├── LICENSE ├── README.md ├── instrument.py ├── synthetic_data.py ├── util.py ├── spender_model.py └── train_aestra_synthetic.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yan Liang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AESTRA Architecture 2 | AESTRA (Auto-Encoding STellar Radial-velocity and Activity) is a deep learning method for precise radial velocity measurements in the presence of stellar activity noise. 3 | The architecture combines a convolutional radial-velocity estimator and a spectrum auto-encoder called [spender](https://github.com/pmelchior/spender). For an in-depth understanding of the spectrum auto-encoder, see [Melchior et al. 
2023](https://iopscience.iop.org/article/10.3847/1538-3881/ace0ff) and [Liang et al. 2023](https://iopscience.iop.org/article/10.3847/1538-3881/ace100). 4 | 5 | **Liang, Y., Winn, J. N., & Melchior, P.** (2023). *AESTRA: Deep Learning for Precise Radial Velocity Estimation in the Presence of Stellar Activity*. Manuscript submitted. 6 | 7 | The input is a collection of hundreds or more spectra of a single star, spanning a variety of activity states and the orbital phases of any planets that may be present. 8 | 9 | ![AESTRA_Diagram_R1](https://github.com/yanliang-astro/aestra/assets/71669502/f6d6f40f-98a7-4d9a-bc90-8db084545d2a) 10 | 11 | Training AESTRA does not require a spectral template, a line list, or indeed any prior knowledge about the star. 12 | The spectrum auto-encoder is trained with a fidelity loss that ensures the observed spectra, including their activity-induced variations, are reconstructed accurately. 13 | The RV estimator network is trained with an RV-consistency loss, which requires it to recover the velocity offset injected into an artificially Doppler-shifted "augment" spectrum.
14 | 15 | -------------------------------------------------------------------------------- /instrument.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import numpy as np 3 | from torch import nn 4 | from torchinterp1d import Interp1d 5 | 6 | class BaseInstrument(nn.Module): 7 | def __init__(self, 8 | wave_obs, 9 | lsf=None, 10 | calibration=None, 11 | ): 12 | 13 | super(BaseInstrument, self).__init__() 14 | 15 | self.calibration = calibration 16 | if lsf is not None: 17 | # construct conv1d layer 18 | self.lsf = nn.Conv1d(1, 1, len(lsf), bias=False, padding='same') 19 | # if LSF should be fit, set `requires_grad=True` 20 | self.lsf.weight = nn.Parameter(lsf.flip(0).reshape(1,1,-1), requires_grad=False) 21 | else: 22 | self.lsf = None 23 | 24 | # register wavelength tensors on the same device as the entire model 25 | self.register_buffer('wave_obs', wave_obs) 26 | self.register_buffer('skyline_mask', skylines_mask(wave_obs)) 27 | 28 | @property 29 | def name(self): 30 | return self.__class__.__name__ 31 | 32 | def skylines_mask(waves, intensity_limit=2, radii=5): 33 | this_dir, this_filename = os.path.split(__file__) 34 | filename = os.path.join(this_dir, "sky-lines.txt") 35 | f=open(filename,"r") 36 | content = f.readlines() 37 | f.close() 38 | 39 | skylines = [[10*float(line.split()[0]),float(line.split()[1])] for line in content if not line[0]=="#" ] 40 | skylines = np.array(skylines) 41 | 42 | n_lines = 0 43 | mask = ~(waves>0) 44 | 45 | for line in skylines: 46 | line_pos, intensity = line 47 | if line_pos>waves.max():continue 48 | if intensity(line_pos-radii))] = True 51 | 52 | non_zero = torch.count_nonzero(mask) 53 | return mask 54 | 55 | # allow registry of new instruments 56 | # see https://effectivepython.com/2015/02/02/register-class-existence-with-metaclasses 57 | instrument_register = {} 58 | 59 | def register_class(target_class): 60 | instrument_register[target_class.__name__] = 
target_class 61 | 62 | class Meta(type): 63 | def __new__(meta, name, bases, class_dict): 64 | cls = type.__new__(meta, name, bases, class_dict) 65 | # remove those that are directly derived from the base class 66 | if BaseInstrument not in bases: 67 | register_class(cls) 68 | return cls 69 | 70 | class Instrument(BaseInstrument, metaclass=Meta): 71 | pass 72 | -------------------------------------------------------------------------------- /synthetic_data.py: -------------------------------------------------------------------------------- 1 | import glob, os, urllib.request 2 | import numpy as np 3 | import pickle 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torchinterp1d import Interp1d 8 | import astropy.io.fits as fits 9 | import astropy.table as aTable 10 | from functools import partial 11 | 12 | from instrument import Instrument 13 | from util import BatchedFilesDataset, load_batch, cubic_transform 14 | 15 | class Synthetic(Instrument): 16 | _wave_obs = torch.arange(5000,5050,0.025, dtype=torch.double) 17 | wave_rest = torch.linspace(4999.0, 5051.,2100, dtype=torch.double) 18 | c = 299792458. 
# m/s 19 | 20 | def __init__(self, lsf=None, calibration=None): 21 | super().__init__(Synthetic._wave_obs, lsf=lsf, calibration=calibration) 22 | 23 | @classmethod 24 | def get_data_loader(cls, dir, select=None, which=None, tag=None, batch_size=30, shuffle=False): 25 | files = cls.list_batches(dir, select=select, 26 | which=which, tag=tag) 27 | if which in ["train", "valid"]: 28 | subset = slice(0,4) 29 | else: 30 | subset = None 31 | load_fct = partial(load_batch, subset=subset) 32 | data = BatchedFilesDataset(files, load_fct, shuffle=shuffle) 33 | return DataLoader(data, batch_size=batch_size) 34 | 35 | @classmethod 36 | def list_batches(cls, dir, select=None, which=None, tag=None): 37 | if tag is None:tag = "chunk50" 38 | if select is None:select = cls.__mro__[0].__name__ 39 | filename = f"{select}{tag}_*.pkl" 40 | batch_files = glob.glob(dir + "/" + filename) 41 | batches = [item for item in batch_files if not "copy" in item] 42 | 43 | NBATCH = len(batches) 44 | train_batches = batches#[:int(0.9*NBATCH)] 45 | #valid_batches = batches[int(0.9*NBATCH):int(0.95*NBATCH)] 46 | valid_batches = test_batches = batches[int(0.9*NBATCH):] 47 | 48 | if which == "test": return test_batches 49 | elif which == "valid": return valid_batches 50 | elif which == "train": return train_batches 51 | else: return batches 52 | 53 | @classmethod 54 | def save_batch(cls, dir, batch, select=None, tag=None, counter=None): 55 | if tag is None: 56 | tag = f"chunk{len(batch[-1])}" 57 | if select is None:select = cls.__mro__[0].__name__ 58 | if counter is None: 59 | counter = "" 60 | filename = os.path.join(dir, f"{select}{tag}_{counter}.pkl") 61 | with open(filename, 'wb') as f: 62 | pickle.dump(batch, f) 63 | 64 | @classmethod 65 | def save_in_batches(cls, dir, files, select=None, tag=None, batch_size=30): 66 | N = len(files) 67 | idx = np.arange(0, N, batch_size) 68 | batches = np.array_split(files, idx[1:]) 69 | for counter, ids_ in zip(idx, batches): 70 | print (f"saving batch {counter} 
/ {N}") 71 | print("batch size:",len(ids_)) 72 | batch = cls.make_batch(ids_) 73 | cls.save_batch(dir, batch, select, tag=tag, counter=counter) 74 | 75 | @classmethod 76 | def augment_spectra(cls,batch,z,noise=True,ratio=0.20): 77 | spec, w, _, ID = batch[:4] 78 | batch_size, spec_size = spec.shape 79 | device = spec.device 80 | wave_obs = cls._wave_obs.to(device) 81 | 82 | # uniform distribution of redshift offsets 83 | z_lim = 2e-8 # 6 m/s 84 | z_offset = z_lim*(torch.rand(batch_size,1, device=device)-0.5) 85 | 86 | wave_redshifted = wave_obs - wave_obs*z_offset 87 | 88 | # redshift interpolation 89 | spec_new = cubic_transform(wave_obs, spec, wave_redshifted) 90 | if noise: 91 | spec_noise = torch.normal(mean=0,std=w[0]**(-0.5), 92 | size=spec.shape,device=device) 93 | noise_mask = torch.rand(spec.shape).to(device)>ratio 94 | spec_noise[noise_mask]=0 95 | spec_new += spec_noise 96 | if spec.dtype==torch.float32:spec_new=spec_new.float() 97 | # ensure extrapolated values have zero weights 98 | wmin = wave_obs.min() 99 | wmax = wave_obs.max() 100 | out = (wave_redshiftedwmax) 101 | spec_new[out] = 1 102 | w_new = torch.clone(w) 103 | #acf = cls.acf_ccf(wave_obs,spec) 104 | return spec_new, w_new, _, z_offset 105 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import io, os, sys, time, random 4 | import numpy as np 5 | import pickle 6 | from scipy.special import gamma 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import IterableDataset 11 | from itertools import chain 12 | import pickle, humanize, psutil, GPUtil, io, random 13 | from torchinterp1d import Interp1d 14 | from torchcubicspline import natural_cubic_spline_coeffs 15 | from astropy.timeseries import LombScargle 16 | 17 | def cubic_evaluate(coeffs, tnew): 18 | t = 
coeffs[0] 19 | a,b,c,d = [item.squeeze(-1) for item in coeffs[1:]] 20 | maxlen = b.size(-1) - 1 21 | index = torch.bucketize(tnew, t) - 1 22 | index = index.clamp(0, maxlen) # clamp because t may go outside of [t[0], t[-1]]; this is fine 23 | # will never access the last element of self._t; this is correct behaviour 24 | fractional_part = tnew - t[index] 25 | 26 | batch_size, spec_size = tnew.shape 27 | batch_ind = torch.arange(batch_size,device=tnew.device) 28 | batch_ind = batch_ind.repeat((spec_size,1)).T 29 | 30 | inner = c[batch_ind, index] + d[batch_ind, index] * fractional_part 31 | inner = b[batch_ind, index] + inner * fractional_part 32 | return a[batch_ind, index] + inner * fractional_part 33 | 34 | def cubic_transform(xrest, yrest, wave_shifted): 35 | #wave_shifted = - xobs * z + xobs 36 | #print("xrest:",xrest.shape,"yrest:",yrest.shape) 37 | coeffs = natural_cubic_spline_coeffs(xrest, yrest.unsqueeze(-1)) 38 | out = cubic_evaluate(coeffs, wave_shifted) 39 | #print("out:",out.shape) 40 | return out 41 | 42 | def moving_mean(x,y,w=None,n=20,skip_weight=True): 43 | dx = (x.max()-x.min())/n 44 | xgrid = np.linspace(x.min(),x.max(),n+2) 45 | xgrid = xgrid[1:-1] 46 | ygrid = np.zeros_like(xgrid) 47 | delta_y = np.zeros_like(xgrid) 48 | for i,xmid in enumerate(xgrid): 49 | mask = x>(xmid-dx) 50 | mask *= x<(xmid+dx) 51 | if skip_weight: 52 | ygrid[i] = np.mean(y[mask]) 53 | delta_y[i] = y[mask].std()/np.sqrt(mask.sum()) 54 | else: 55 | ygrid[i] = np.average(y[mask],weights=w[mask]) 56 | delta_y[i] = np.sqrt(np.cov(y[mask], aweights=w[mask]))/np.sqrt(mask.sum()) 57 | return xgrid,ygrid,delta_y 58 | 59 | ''' 60 | def calculate_fft(time,signal): 61 | time_interval = time[1]-time[0] 62 | # Perform the FFT 63 | fft = np.fft.fft(signal) 64 | # Calculate the frequency axis 65 | freq_axis = np.fft.fftfreq(len(signal), time_interval) 66 | real = freq_axis>0 67 | p_axis = 1.0/freq_axis[real] 68 | # Only show the real part of the power spectrum 69 | power_spectrum = 
np.real(fft * np.conj(fft)) 70 | power_spectrum /= max(power_spectrum[real]) 71 | return p_axis,power_spectrum[real] 72 | ''' 73 | def plot_fft(timestamp,signals,fname,labels,period=100,fs=14): 74 | cs = ["grey","k","b","r"] 75 | alphas = [1,1,1,0.7] 76 | lw = [2,2,2,2] 77 | fig,ax = plt.subplots(figsize=(4,2.5),constrained_layout=True) 78 | pmax=0 79 | for i,ts in enumerate(signals[:len(cs)]): 80 | if "encode" in labels[i]:continue 81 | if "doppler" in labels[i]:continue 82 | frequency, power = LombScargle(timestamp, ts).autopower() 83 | p_axis = 1.0/frequency 84 | # Plot the result 85 | ax.plot(p_axis,power, c=cs[i],lw=lw[i],label="%s"%(labels[i]), alpha=alphas[i]) 86 | if power.max()>pmax: pmax = power.max() 87 | ax.set_xlim(1,299) 88 | ax.set_ylim(0,1.1*pmax) 89 | ax.set_xlabel('Period [days]');ax.set_ylabel('Power') 90 | ax.axvline(period,ls="--",c="grey",zorder=-10,label="$P_{true}$") 91 | if "uniform" in fname: 92 | ax.set_yticks([0.01,0.02,0.03]) 93 | title = r"$\mathbf{Case\ I \ (N=1000)}$" 94 | elif "dynamic" in fname: 95 | ax.set_yticks([0.05,0.10,0.15,0.20]) 96 | title = r"$\mathbf{Case\ II \ (N=200)}$" 97 | else:title="test" 98 | ax.legend(fontsize=fs,title=title) 99 | plt.savefig("[%s]periodogram.png"%fname,dpi=300) 100 | #with open("results-%s.pkl"%fname,"wb") as f: 101 | # pickle.dump(signals,f) 102 | # pickle.dump(labels,f) 103 | return 104 | 105 | def plot_sphere(pos,radius,ax,c="grey",alpha=0.5,zorder=0): 106 | u = np.linspace(0, 2 * np.pi, 100) 107 | v = np.linspace(0, np.pi, 100) 108 | x = radius * np.outer(np.cos(u), np.sin(v)) + pos[0] 109 | y = radius * np.outer(np.sin(u), np.sin(v)) + pos[1] 110 | z = radius* np.outer(np.ones(np.size(u)), np.cos(v)) + pos[2] 111 | # Plot the surface 112 | ax.plot_surface(x, y, z, alpha=alpha, zorder=zorder,color=c) 113 | return 114 | 115 | def density_plot(points,bins=30): 116 | x,y,z = points 117 | fig, ax = plt.subplots() 118 | density,X,Y,_ = ax.hist2d(x, y, bins=bins) 119 | #print("X,Y",X,Y) 120 | X, Y 
= np.meshgrid(X[1:],Y[1:]) 121 | mesh_dict = {"XY":[X,Y,density]} 122 | return mesh_dict 123 | 124 | def visualize_encoding(points,points_aug,RV_encode,radius=0,tag=None): 125 | 126 | axis_mean = points.mean(axis=1,keepdims=True) 127 | axis_std = points.std(axis=1,keepdims=True) 128 | points -= axis_mean 129 | points /= axis_std 130 | 131 | points_aug -= axis_mean 132 | points_aug /= axis_std 133 | 134 | rand = np.random.randint(points.shape[1],size=(points.shape[1])) 135 | print("rand:",rand.shape) 136 | N = len(rand) 137 | dist = ((points-points[:,rand])**2).sum(axis=0) 138 | dist_aug = ((points-points_aug)**2).sum(axis=0) 139 | 140 | print("random pairs: %.5f"%dist.mean(),dist.shape) 141 | print("augment pairs: %.5f"%dist_aug.mean(),dist_aug.shape) 142 | 143 | bins = np.logspace(-4,1,20) 144 | fig,ax = plt.subplots(figsize=(4,2.5),constrained_layout=True) 145 | _=ax.hist(dist,label=r"$\langle \Delta s_{rand} \rangle $: %.3f"%dist.mean(), 146 | color="b",bins=bins,log=False,histtype="stepfilled",alpha=0.7) 147 | _=ax.hist(dist_aug,label=r"$\langle \Delta s_{aug} \rangle$: %.3f"%dist_aug.mean(), 148 | color="r",bins=bins,log=False,histtype="stepfilled",alpha=0.7) 149 | ax.legend(loc=2);ax.set_xlabel("latent distance $\Delta s$");ax.set_ylabel("N") 150 | ax.set_xscale('log') 151 | plt.savefig("[%s]histogram.png"%tag,dpi=300) 152 | 153 | import matplotlib.colors 154 | 155 | elev=20;azim=150; dtr = np.pi/180.0 156 | viewpoint = np.array([np.cos(elev*dtr)*np.cos(azim*dtr), 157 | np.cos(elev*dtr)*np.sin(azim*dtr), 158 | np.sin(elev*dtr)]) 159 | dist = 8 160 | viewpoint *= dist 161 | print("viewpoint:",viewpoint.shape,"points:",points.shape) 162 | depth = ((points-viewpoint[:,None])**2).sum(axis=0)**0.5 163 | depth /= depth.min() 164 | size = 40/depth**2+5 165 | #colors = points[0] 166 | colors = RV_encode 167 | print("colors:",colors.min(),colors.max()) 168 | print("RV_encode:",RV_encode.min(),RV_encode.max()) 169 | #print("depth:",depth.shape) 170 | 
#print(size.min(),size.mean(),size.max()) 171 | # 3D rendering 172 | 173 | b_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["b","skyblue"]) 174 | r_cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["r","salmon"]) 175 | 176 | fig = plt.figure(figsize = (10, 8)) 177 | ax = plt.axes(projection ="3d") 178 | # Add x, y gridlines 179 | #ax.grid(b = True, color ='grey',linestyle ='-.', linewidth = 0.3,alpha = 0.2) 180 | pic = ax.scatter(points[0], points[1], points[2], s=size, marker="o", 181 | alpha=1,c=colors,cmap="viridis") 182 | #ax.scatter(points_aug[0], points_aug[1], points_aug[2], s=size, marker="o",alpha=1,c=colors,cmap=r_cmap) 183 | #for i in range(N):plt.plot([points[0,i],points_aug[0,i]],[points[1,i],points_aug[1,i]],[points[2,i],points_aug[2,i]],c="grey",lw=0.5) 184 | 185 | xlim=(-4, 5) 186 | ylim=(-3, 5) 187 | zlim=(-4, 7) 188 | 189 | ms=5;c="darkgrey" 190 | ax.scatter(points[0], points[1],[zlim[0]]*N,s=ms,c=c,alpha=1) 191 | ax.scatter(points_aug[0], points_aug[1],[zlim[0]]*N, 192 | s=ms,c=c,alpha=1) 193 | ms=5;c="grey" 194 | ax.scatter(points[0],[ylim[0]]*N, points[2],s=ms,c=c,alpha=1) 195 | ax.scatter(points_aug[0],[ylim[0]]*N, points_aug[2], 196 | s=ms,c=c,alpha=1) 197 | 198 | pos = [0,4,0] 199 | fs = 20 200 | # plot a sphere 201 | #if radius > 0:plot_sphere(pos,radius,ax,alpha=0.5,zorder=0) 202 | ax.set_proj_type('persp', focal_length=0.5) 203 | ax.set_xlabel("$s_1$",fontsize=fs) 204 | ax.set_ylabel("$s_2$",fontsize=fs) 205 | ax.set_zlabel("$s_3$",fontsize=fs) 206 | ax.xaxis.labelpad=-10 207 | ax.yaxis.labelpad=-10 208 | ax.zaxis.labelpad=-10 209 | 210 | ax.set_xticklabels([]);ax.set_yticklabels([]);ax.set_zticklabels([]) 211 | ax.view_init(elev=elev,azim=azim,roll=0) 212 | ax.dist=dist 213 | ax.set(xlim=xlim, ylim=ylim, zlim=zlim) 214 | cbar = fig.colorbar(pic, ax=ax,location = 'top', pad=0.0, shrink=0.4) 215 | #cbar.ax.set_xticks([]) 216 | #cbar.ax.set_xticklabels([-2,-1,0,1],fontsize=12) 217 | 
cbar.set_label("$v_{encode}$[m/s]",fontsize=16,labelpad=10) 218 | #ax.set_aspect('equal') 219 | #plt.subplots_adjust(left=0.08, bottom=0.08, right=0.95, top=0.98) 220 | plt.savefig("[%s]R1-3D.png"%tag,dpi=300) 221 | exit() 222 | return 223 | 224 | ############ Functions for creating batched files ############### 225 | class CPU_Unpickler(pickle.Unpickler): 226 | def find_class(self, module, name): 227 | if module == 'torch.storage' and name == '_load_from_bytes': 228 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 229 | else: return super().find_class(module, name) 230 | 231 | def load_batch(batch_name, subset=None): 232 | with open(batch_name, 'rb') as f: 233 | if torch.cuda.is_available(): 234 | batch = pickle.load(f) 235 | else: 236 | batch = CPU_Unpickler(f).load() 237 | 238 | if subset is not None: 239 | return batch[subset] 240 | return batch 241 | 242 | # based on https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd 243 | class BatchedFilesDataset(IterableDataset): 244 | 245 | def __init__(self, file_list, load_fct, shuffle=False, shuffle_instance=False): 246 | assert len(file_list), "File list cannot be empty" 247 | self.file_list = file_list 248 | self.shuffle = shuffle 249 | self.shuffle_instance = shuffle_instance 250 | self.load_fct = load_fct 251 | 252 | def process_data(self, idx): 253 | if self.shuffle: 254 | idx = random.randint(0, len(self.file_list) -1) 255 | batch_name = self.file_list[idx] 256 | data = self.load_fct(batch_name) 257 | data = list(zip(*data)) 258 | if self.shuffle_instance: 259 | random.shuffle(data) 260 | for x in data: 261 | yield x 262 | 263 | def get_stream(self): 264 | return chain.from_iterable(map(self.process_data, range(len(self.file_list)))) 265 | 266 | def __iter__(self): 267 | return self.get_stream() 268 | 269 | def __len__(self): 270 | return len(self.file_list) 271 | 272 | 273 | def mem_report(): 274 | print("CPU RAM Free: " + humanize.naturalsize( 
psutil.virtual_memory().available )) 275 | 276 | if torch.cuda.device_count() ==0: return 277 | 278 | GPUs = GPUtil.getGPUs() 279 | for i, gpu in enumerate(GPUs): 280 | print('GPU {:d} ... Mem Free: {:.0f}MB / {:.0f}MB | Utilization {:3.0f}%'.format(i, gpu.memoryFree, gpu.memoryTotal, gpu.memoryUtil*100)) 281 | return 282 | 283 | 284 | def resample_to_restframe(wave_obs,wave_rest,y,w,z): 285 | wave_z = (wave_rest.unsqueeze(1)*(1 + z)).T 286 | wave_obs = wave_obs.repeat(y.shape[0],1) 287 | # resample observed spectra to restframe 288 | yrest = Interp1d()(wave_obs, y, wave_z) 289 | wrest = Interp1d()(wave_obs, w, wave_z) 290 | 291 | # interpolation = extrapolation outside of observed region, need to mask 292 | msk = (wave_z<=wave_obs.min())|(wave_z>=wave_obs.max()) 293 | # yrest[msk]=0 # not needed because all spectral elements are weighted 294 | wrest[msk]=0 295 | return yrest,wrest 296 | 297 | def generate_lines(xrange,max_amp=0.7,width=0.3,n_lines=100): 298 | amps = np.random.uniform(low=0.01,high=max_amp,size=n_lines) 299 | sigmas = np.random.normal(loc=width,scale=0.1*width,size=n_lines) 300 | line_loc = np.random.uniform(low=(xrange[0]+width),high=(xrange[1]-width),size=n_lines) 301 | sigmas = np.maximum(sigmas,0.01) 302 | lines = {"loc":line_loc,"amp":amps,"sigma":sigmas} 303 | return lines 304 | 305 | def evaluate_lines(wave,lines,z=0,depth=1,skew=0,broaden=1,window=5): 306 | abs_lines = np.ones_like(wave) 307 | line_location = lines["loc"]+lines["loc"]*z 308 | for i,loc in enumerate(line_location): 309 | amp,sigma = lines["amp"][i],broaden*lines["sigma"][i] 310 | mask = (wave>(loc-window*sigma))*(wave<(loc+window*sigma)) 311 | if skew>0:signal = gamma_profile(wave[mask],amp,loc,sigma, skew) 312 | else:signal = amp*np.exp(-0.5*((wave[mask]-loc)/sigma)**2) 313 | abs_lines[mask] *= (1-depth*signal) 314 | return abs_lines 315 | 316 | def gauss(x, *p): 317 | amp, mu, sigma, b = p 318 | return amp*np.exp(-(x-mu)**2/(2.*sigma**2))+b 319 | 320 | 321 | def 
gamma_profile(x, amp, mu, sigma, skew): 322 | a = 4/skew**2; b=2*a; sigma_0 = a**0.5/b 323 | mu0 = (a-1)/b 324 | y = np.zeros_like(x) 325 | xloc = ((x-mu)/sigma)*sigma_0 + mu0 326 | mask = xloc>0 327 | y[mask] = ((xloc[mask])**(a-1))*np.exp(-b*(xloc[mask])) 328 | y/=y.max() 329 | return amp*y 330 | -------------------------------------------------------------------------------- /spender_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torchinterp1d import Interp1d 5 | from util import cubic_transform 6 | 7 | #### Simple MLP #### 8 | class MLP(nn.Module): 9 | def __init__(self, 10 | n_in, 11 | n_out, 12 | n_hidden=(16, 16, 16), 13 | act=(nn.LeakyReLU(), nn.LeakyReLU(), nn.LeakyReLU(), nn.LeakyReLU()), 14 | dropout=0): 15 | super(MLP, self).__init__() 16 | 17 | layer = [] 18 | n_ = [n_in, *n_hidden, n_out] 19 | for i in range(len(n_)-1): 20 | layer.append(nn.Linear(n_[i], n_[i+1])) 21 | layer.append(act[i]) 22 | layer.append(nn.Dropout(p=dropout)) 23 | self.mlp = nn.Sequential(*layer) 24 | 25 | def forward(self, x): 26 | return self.mlp(x) 27 | 28 | 29 | class SpeculatorActivation(nn.Module): 30 | """Activation function from the Speculator paper 31 | .. 
math: 32 | a(\mathbf{x}) = \left[\boldsymbol{\gamma} + (1+e^{-\boldsymbol\beta\odot\mathbf{x}})^{-1}(1-\boldsymbol{\gamma})\right]\odot\mathbf{x} 33 | Paper: Alsing et al., 2020, ApJS, 249, 5 34 | Parameters 35 | ---------- 36 | n_parameter: int 37 | Number of parameters for the activation function to act on 38 | plus_one: bool 39 | Whether to add 1 to the output 40 | """ 41 | 42 | def __init__(self, n_parameter, plus_one=False): 43 | super().__init__() 44 | self.plus_one = plus_one 45 | self.beta = nn.Parameter(torch.randn(n_parameter), requires_grad=True) 46 | self.gamma = nn.Parameter(torch.randn(n_parameter), requires_grad=True) 47 | 48 | def forward(self, x): 49 | """Forward method 50 | Parameters 51 | ---------- 52 | x: `torch.tensor` 53 | Returns 54 | ------- 55 | x': `torch.tensor`, same shape as `x` 56 | """ 57 | # eq 8 in Alsing+2020 58 | x = (self.gamma + (1 - self.gamma) * torch.sigmoid(self.beta * x)) * x 59 | if self.plus_one: 60 | return x + 1 61 | return x 62 | 63 | 64 | class RVEstimator(nn.Module): 65 | def __init__(self, 66 | n_in, 67 | sizes = [5,10], 68 | n_hidden=(128, 64, 32), 69 | act=(nn.PReLU(128),nn.PReLU(64),nn.PReLU(32), nn.Identity()), 70 | dropout=0): 71 | super(RVEstimator, self).__init__() 72 | 73 | filters = [128,64] 74 | self.conv1,self.conv2 = self._conv_blocks(filters, sizes, dropout=dropout) 75 | self.n_feature = filters[-1] * ((n_in //sizes[0])//sizes[1]) 76 | 77 | self.pool1, self.pool2 = tuple(nn.MaxPool1d(s) for s in sizes[:2]) 78 | print("self.n_feature:",self.n_feature) 79 | self.mlp = MLP(self.n_feature, 1, n_hidden=n_hidden, act=act, dropout=dropout) 80 | self.flatten = nn.Flatten() 81 | self.softmax = nn.Softmax(dim=-1) 82 | 83 | def _conv_blocks(self, filters, sizes, dropout=0): 84 | convs = [] 85 | for i in range(len(filters)): 86 | f_in = 1 if i == 0 else filters[i-1] 87 | f = filters[i] 88 | s = sizes[i] 89 | p = s // 2 90 | conv = nn.Conv1d(in_channels=f_in, 91 | out_channels=f, 92 | kernel_size=s, 93 | padding=p, 
94 | ) 95 | norm = nn.InstanceNorm1d(f) 96 | act = nn.PReLU(num_parameters=f) 97 | drop = nn.Dropout(p=dropout) 98 | convs.append(nn.Sequential(conv, norm, act, drop)) 99 | return tuple(convs) 100 | 101 | def forward(self, x): 102 | # compression 103 | x = x.unsqueeze(1) 104 | x = self.pool1(self.conv1(x)) 105 | x = self.pool2(self.conv2(x)) 106 | x = self.softmax(x) 107 | x = self.flatten(x) 108 | x = self.mlp(x) 109 | return x 110 | 111 | 112 | class NullRVEstimator(nn.Module): 113 | def __init__(self): 114 | super(NullRVEstimator, self).__init__() 115 | 116 | def forward(self, x): 117 | return torch.zeros((x.shape[0],1),device=x.device) 118 | 119 | #### Spectrum encoder #### 120 | #### based on Serra 2018 #### 121 | #### with robust feature combination from Geisler 2020 #### 122 | class SpectrumEncoder(nn.Module): 123 | def __init__(self, 124 | instrument, 125 | n_latent, 126 | n_hidden=(128, 64, 32), 127 | act=(nn.PReLU(128), nn.PReLU(64), nn.PReLU(32), nn.Identity()), 128 | n_aux=0, 129 | dropout=0): 130 | 131 | super(SpectrumEncoder, self).__init__() 132 | self.instrument = instrument 133 | self.n_latent = n_latent 134 | self.n_aux = n_aux 135 | 136 | filters = [128, 256, 512] 137 | sizes = [5, 11, 21] 138 | self.conv1, self.conv2, self.conv3 = self._conv_blocks(filters, sizes, dropout=dropout) 139 | self.n_feature = filters[-1] // 2 140 | 141 | # pools and softmax work for spectra and weights 142 | self.pool1, self.pool2 = tuple(nn.MaxPool1d(s, padding=s//2) for s in sizes[:2]) 143 | self.softmax = nn.Softmax(dim=-1) 144 | 145 | # small MLP to go from CNN features to latents 146 | self.mlp = MLP(self.n_feature + n_aux, self.n_latent, n_hidden=n_hidden, act=act, dropout=dropout) 147 | 148 | def _conv_blocks(self, filters, sizes, dropout=0): 149 | convs = [] 150 | for i in range(len(filters)): 151 | f_in = 1 if i == 0 else filters[i-1] 152 | f = filters[i] 153 | s = sizes[i] 154 | p = s // 2 155 | conv = nn.Conv1d(in_channels=f_in, 156 | out_channels=f, 157 | 
kernel_size=s, 158 | padding=p, 159 | ) 160 | norm = nn.InstanceNorm1d(f) 161 | act = nn.PReLU(num_parameters=f) 162 | drop = nn.Dropout(p=dropout) 163 | convs.append(nn.Sequential(conv, norm, act, drop)) 164 | return tuple(convs) 165 | 166 | def _downsample(self, x): 167 | # compression 168 | x = x.unsqueeze(1) 169 | x = self.pool1(self.conv1(x)) 170 | x = self.pool2(self.conv2(x)) 171 | x = self.conv3(x) 172 | C = x.shape[1] // 2 173 | # split half channels into attention value and key 174 | h, a = torch.split(x, [C, C], dim=1) 175 | 176 | return h, a 177 | 178 | def forward(self, x, aux=None): 179 | # run through CNNs 180 | h, a = self._downsample(x) 181 | # softmax attention 182 | a = self.softmax(a) 183 | # apply attention 184 | x = torch.sum(h * a, dim=2) 185 | # redshift depending feature combination to final latents 186 | if aux is not None and aux is not False: 187 | x = torch.cat((x, aux), dim=-1) 188 | x = self.mlp(x) 189 | return x 190 | 191 | @property 192 | def n_parameters(self): 193 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 194 | 195 | 196 | #### Spectrum decoder #### 197 | #### Simple MLP but with explicit redshift and instrument path #### 198 | class SpectrumDecoder(MLP): 199 | def __init__(self, 200 | wave_rest, 201 | spec_rest, 202 | n_latent=5, 203 | n_hidden=(64, 256, 1024), 204 | act=None, 205 | dropout=0, 206 | datatag="mockdata", 207 | ): 208 | 209 | if act==None: 210 | act = [nn.LeakyReLU() for i in range(len(n_hidden)+1)] 211 | #act = [SpeculatorActivation(n) for n in n_hidden] 212 | #act.append(SpeculatorActivation(len(wave_rest))) 213 | 214 | super(SpectrumDecoder, self).__init__( 215 | n_latent, 216 | len(wave_rest), 217 | n_hidden=n_hidden, 218 | act=act, 219 | dropout=dropout, 220 | ) 221 | 222 | self.n_latent = n_latent 223 | #self.decode_act = nn.Identity() 224 | self.decode_act = nn.LeakyReLU() 225 | # register wavelength tensors on the same device as the entire model 226 | if spec_rest is None: 227 | 
self.spec_rest= torch.nn.Parameter(torch.randn(len(wave_rest))) 228 | else: self.spec_rest= torch.nn.Parameter(spec_rest) 229 | self.register_buffer('wave_rest', wave_rest) 230 | 231 | def decode(self, s): 232 | x = super().forward(s) 233 | x = -self.decode_act(-x) 234 | return x 235 | 236 | def forward(self, s): 237 | return self.decode(s) 238 | 239 | def transform(self, spectrum_restframe, z, instrument=None): 240 | xx = self.wave_rest 241 | 242 | if instrument in [False, None]: 243 | wave_obs = self.wave_rest 244 | else: 245 | wave_obs = instrument.wave_obs 246 | 247 | wave_redshifted = - wave_obs * z + wave_obs 248 | spectrum = cubic_transform(xx, spectrum_restframe, wave_redshifted) 249 | 250 | # convolve with LSF 251 | if instrument.lsf is not None: 252 | spectrum = instrument.lsf(spectrum.unsqueeze(1)).squeeze(1) 253 | 254 | # apply calibration function to observed spectrum 255 | if instrument is not None and instrument.calibration is not None: 256 | spectrum = instrument.calibration(wave_obs, spectrum) 257 | 258 | return spectrum 259 | 260 | @property 261 | def n_parameters(self): 262 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 263 | 264 | 265 | # Combine spectrum encoder and decoder 266 | class BaseAutoencoder(nn.Module): 267 | def __init__(self, 268 | encoder, 269 | decoder, 270 | rv_estimator, 271 | normalize=False, 272 | ): 273 | 274 | super(BaseAutoencoder, self).__init__() 275 | assert encoder.n_latent == decoder.n_latent 276 | self.encoder = encoder 277 | self.decoder = decoder 278 | self.rv_estimator = rv_estimator 279 | self.normalize = normalize 280 | 281 | def encode(self, x, aux=None): 282 | return self.encoder(x, aux=aux) 283 | 284 | def decode(self, x): 285 | return self.decoder(x) 286 | 287 | def estimate_rv(self,x): 288 | # estimate z 289 | return self.rv_estimator(x) 290 | 291 | def _forward(self, x, w, s, z, instrument=None, aux=None): 292 | if w.dim()==1:w=w.unsqueeze(1) 293 | 294 | if instrument is None: 295 | 
instrument = self.encoder.instrument 296 | 297 | if self.decoder.spec_rest == None: baseline = 1.0 298 | else: baseline = self.decoder.spec_rest 299 | spectrum_activity = self.decode(s) 300 | spectrum_restframe = baseline+spectrum_activity 301 | spectrum_observed = self.decoder.transform(spectrum_restframe, z, instrument=instrument) 302 | 303 | if self.normalize: 304 | c = self._normalization(x, spectrum_observed, w=w) 305 | spectrum_observed = spectrum_observed * c 306 | spectrum_restframe = spectrum_restframe * c 307 | 308 | return spectrum_activity, spectrum_restframe, spectrum_observed 309 | 310 | def forward(self, x, w, s, z, instrument=None, aux=None): 311 | spectrum_activity, spectrum_restframe, spectrum_observed = self._forward(x, w, s, z, instrument=instrument, aux=aux) 312 | return spectrum_observed 313 | 314 | def loss(self, x, w, s, z, instrument=None, aux=None, individual=False): 315 | spectrum_observed = self.forward(x, w, s, z, instrument=instrument, aux=aux) 316 | return self._loss(x, w, spectrum_observed, individual=individual) 317 | 318 | def _loss(self, x, w, spectrum_observed, individual=False): 319 | # loss = total squared deviation in units of variance 320 | # if the model is identical to observed spectrum (up to the noise), 321 | # then loss per object = D (number of non-zero bins) 322 | 323 | # to make it to order unity for comparing losses, divide out L (number of bins) 324 | # instead of D, so that spectra with more valid bins have larger impact 325 | if w.dim()==1:w=w.unsqueeze(1) 326 | loss_ind = torch.sum(w * (x - spectrum_observed).pow(2), dim=1) / x.shape[1] 327 | 328 | if individual: 329 | return loss_ind 330 | 331 | return torch.sum(loss_ind) 332 | 333 | def _normalization(self, x, m, w=None): 334 | # apply constant factor c that minimizes (c*m - x)^2 335 | if w is None: 336 | w = 1 337 | mw = m*w 338 | c = (mw * x).sum(dim=-1) / (mw * m).sum(dim=-1) 339 | return c.unsqueeze(-1) 340 | 341 | @property 342 | def n_parameter(self): 343 
| return sum(p.numel() for p in self.parameters() if p.requires_grad) 344 | 345 | @property 346 | def wave_obs(self): 347 | return self.encoder.instrument.wave_obs 348 | 349 | @property 350 | def wave_rest(self): 351 | return self.decoder.wave_rest 352 | 353 | class SpectrumAutoencoder(BaseAutoencoder): 354 | def __init__(self, 355 | instrument, 356 | wave_rest, 357 | spec_rest=None, 358 | rv_estimator=None, 359 | n_latent=10, 360 | n_aux=0, 361 | n_hidden=(64, 256, 1024), 362 | act=None, 363 | normalize=False, 364 | ): 365 | 366 | encoder = SpectrumEncoder(instrument, n_latent, n_aux=n_aux) 367 | 368 | decoder = SpectrumDecoder( 369 | wave_rest, 370 | spec_rest, 371 | n_latent, 372 | n_hidden=n_hidden, 373 | act=act, 374 | ) 375 | 376 | if rv_estimator==None: 377 | rv_estimator = RVEstimator(instrument.wave_obs.shape[0],sizes = [10,10]) 378 | #rv_estimator = NullRVEstimator() 379 | 380 | super(SpectrumAutoencoder, self).__init__( 381 | encoder, 382 | decoder, 383 | rv_estimator, 384 | normalize=normalize, 385 | ) 386 | -------------------------------------------------------------------------------- /train_aestra_synthetic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import time, argparse, os 4 | import numpy as np 5 | import functools 6 | import torch 7 | from torch import nn 8 | from torch import optim 9 | from accelerate import Accelerator 10 | # allows one to run fp16_train.py from home directory 11 | import sys;sys.path.insert(1, './') 12 | from spender_model import SpectrumAutoencoder,NullRVEstimator 13 | from synthetic_data import Synthetic 14 | from util import mem_report 15 | from functools import partial 16 | from util import BatchedFilesDataset, load_batch 17 | from torch.utils.data import DataLoader,Dataset 18 | from torchinterp1d import Interp1d 19 | from line_profiler import LineProfiler 20 | from scipy.special import digamma 21 | 22 | def corrcoef(tensor, rowvar=True, 
bias=False): 23 | """Estimate a corrcoef matrix (np.corrcoef) 24 | https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110 25 | """ 26 | tensor = tensor if rowvar else tensor.transpose(-1, -2) 27 | tensor = tensor - tensor.mean(dim=-1, keepdim=True) 28 | factor = 1 / (tensor.shape[-1] - int(not bool(bias))) 29 | covmat = factor * tensor @ tensor.transpose(-1, -2).conj() 30 | std = torch.diag(covmat)**0.5 31 | covmat /= std[:, None] 32 | covmat /= std[None, :] 33 | return covmat 34 | 35 | def avgdigamma(dist,radius): 36 | num_points = torch.count_nonzero(dist<(radius-1e-15),dim=0).double() 37 | return (torch.digamma(num_points)).mean() 38 | 39 | def mutual_information(x0, y0, k=50, base=2): 40 | x = (x0-x0.mean(dim=0))/x0.std(dim=0) 41 | y = (y0-y0.mean())/y0.std() 42 | """Mutual information of x and y 43 | """ 44 | assert x.shape[0] == y.shape[0], "Arrays should have same length" 45 | assert y.shape[1] == 1, "Single value function" 46 | assert k <= x.shape[0] - 1, "Set k smaller than num. 
samples - 1" 47 | # Find nearest neighbors in joint space, 48 | points = torch.cat((x, y),dim=1) 49 | 50 | # Find nearest neighbors in joint space, p=inf means max-norm 51 | distmat = torch.abs(points[:,None]-points[None,:]) 52 | dvec = torch.kthvalue(torch.amax(distmat,2),k+1,dim=1)[0] 53 | 54 | a, b, c, d = ( 55 | avgdigamma(torch.amax(distmat[:,:,:-1],2),dvec), 56 | avgdigamma(distmat[:,:,-1],dvec), 57 | digamma(k), 58 | digamma(x.shape[0]), 59 | ) 60 | return (-a - b + c + d) / np.log(base) 61 | 62 | def prepare_train(seq,niter=100000): 63 | for d in seq: 64 | if not "iteration" in d:d["iteration"]=niter 65 | if not "encoder" in d:d.update({"encoder":d["data"]}) 66 | return seq 67 | 68 | def build_ladder(train_sequence): 69 | n_iter = sum([item['iteration'] for item in train_sequence]) 70 | 71 | ladder = np.zeros(n_iter,dtype='int') 72 | n_start = 0 73 | for i,mode in enumerate(train_sequence): 74 | n_end = n_start+mode['iteration'] 75 | ladder[n_start:n_end]= i 76 | n_start = n_end 77 | return ladder 78 | 79 | def get_all_parameters(models,instruments): 80 | model_params = [] 81 | # multiple encoders 82 | for model in models: 83 | model_params += model.encoder.parameters() 84 | model_params += model.rv_estimator.parameters() 85 | # 1 decoder 86 | model_params += model.decoder.parameters() 87 | dicts = [{'params':model_params}] 88 | 89 | n_parameters = sum([p.numel() for p in model_params if p.requires_grad]) 90 | 91 | instr_params = [] 92 | # instruments 93 | for inst in instruments: 94 | if inst==None:continue 95 | instr_params += inst.parameters() 96 | s = [p.numel() for p in inst.parameters()] 97 | #print("Adding %d parameters..."%sum(s)) 98 | if instr_params != []: 99 | dicts.append({'params':instr_params,'lr': 1e-4}) 100 | n_parameters += sum([p.numel() for p in instr_params if p.requires_grad]) 101 | print("parameter dict:",dicts[1]) 102 | return dicts,n_parameters 103 | 104 | def consistency_loss(s, s_aug, individual=False, sigma_s=0.02): 105 | 
batch_size, s_size = s.shape 106 | ds = torch.sum((s_aug - s)**2/(sigma_s)**2,dim=1)/(s_size) 107 | cons_loss = torch.sigmoid(ds)-0.5 # zero = perfect alignment 108 | if individual: 109 | return cons_loss 110 | return cons_loss.sum() 111 | 112 | def z_offset_loss(z_off, z_off_true, sigma_z=1e-9,individual=False): 113 | z_loss = ((z_off - z_off_true)/sigma_z)**2 114 | if individual:return z_loss 115 | return z_loss.sum() 116 | 117 | def restframe_weight(model,instrument,xrange=[5000.,5050.],sn=1000): 118 | x = model.decoder.wave_rest 119 | w = torch.zeros_like(x).float() 120 | w[(x>xrange[0])*(x0 314 | losses = losses[:,:,non_zero,:] 315 | 316 | epoch = len(losses[0][0]) 317 | 318 | n_epoch += epoch 319 | detailed_loss = np.zeros((2, n_encoder, n_epoch, n_loss)) 320 | detailed_loss[:, :, :epoch, :] = losses 321 | 322 | if verbose: 323 | losses = tuple(detailed_loss[0, :, epoch-1, :]) 324 | vlosses = tuple(detailed_loss[1, :, epoch-1, :]) 325 | print(f'====> Epoch: {epoch-1}') 326 | print('TRAINING Losses:', losses) 327 | print('VALIDATION Losses:', vlosses) 328 | except: # OK if losses are empty 329 | print("loss empty...") 330 | pass 331 | 332 | if outfile is None: 333 | outfile = "checkpoint.pt" 334 | 335 | for epoch_ in range(epoch, n_epoch): 336 | 337 | mode = train_sequence[ladder[epoch_ - epoch]] 338 | 339 | # turn on/off model decoder 340 | for p in models[0].decoder.parameters(): 341 | p.requires_grad = mode['decoder'] 342 | models[0].decoder.spec_rest.requires_grad = mode['spec_rest'] 343 | 344 | slope = ANNEAL_SCHEDULE[(epoch_ - epoch)%len(ANNEAL_SCHEDULE)] 345 | if n_epoch-epoch_<=10: slope=0 # turn off similarity 346 | 347 | if verbose and similarity: 348 | print("similarity info:",slope) 349 | 350 | for which in range(n_encoder): 351 | 352 | # turn on/off encoder 353 | print("Encoder:",mode['encoder'][which]) 354 | for p in models[which].encoder.parameters(): 355 | p.requires_grad = mode['encoder'][which] 356 | # turn on/off rv_estimator 357 | print("RV 
estimator:",mode['rv'][which]) 358 | for p in models[which].rv_estimator.parameters(): 359 | p.requires_grad = mode['rv'][which] 360 | 361 | # optional: training on single dataset 362 | if not mode['data'][which]: 363 | continue 364 | 365 | models[which].train() 366 | instruments[which].train() 367 | 368 | n_sample = 0 369 | for k, batch in enumerate(trainloaders[which]): 370 | batch_size = len(batch[0]) 371 | losses = get_losses( 372 | models[which], 373 | instruments[which], 374 | batch, 375 | template_data, 376 | aug_fct=aug_fcts[which], 377 | similarity=similarity, 378 | consistency=consistency, 379 | flexibility=flexibility, 380 | slope=slope, 381 | skipfid=skipfid 382 | ) 383 | # sum up all losses 384 | loss = functools.reduce(lambda a, b: a+b , losses) 385 | accelerator.backward(loss) 386 | # clip gradients: stabilizes training with similarity 387 | accelerator.clip_grad_norm_(model_parameters[0]['params'], 1.0) 388 | # once per batch 389 | optimizer.step() 390 | optimizer.zero_grad() 391 | 392 | # logging: training 393 | detailed_loss[0][which][epoch_] += tuple( l.item() if hasattr(l, 'item') else 0 for l in losses ) 394 | n_sample += batch_size 395 | 396 | # stop after n_batch 397 | if n_batch is not None and k == n_batch - 1: 398 | break 399 | detailed_loss[0][which][epoch_] /= n_sample 400 | 401 | scheduler.step() 402 | 403 | with torch.no_grad(): 404 | for which in range(n_encoder): 405 | models[which].eval() 406 | instruments[which].eval() 407 | 408 | n_sample = 0 409 | for k, batch in enumerate(validloaders[which]): 410 | batch_size = len(batch[0]) 411 | losses = get_losses( 412 | models[which], 413 | instruments[which], 414 | batch, 415 | template_data, 416 | aug_fct=aug_fcts[which], 417 | similarity=similarity, 418 | consistency=consistency, 419 | flexibility=flexibility, 420 | slope=slope, 421 | skipfid=skipfid 422 | ) 423 | # logging: validation 424 | detailed_loss[1][which][epoch_] += tuple( l.item() if hasattr(l, 'item') else 0 for l in losses ) 
425 | n_sample += batch_size 426 | 427 | # stop after n_batch 428 | if n_batch is not None and k == n_batch - 1: 429 | break 430 | 431 | detailed_loss[1][which][epoch_] /= n_sample 432 | 433 | if verbose: 434 | #mem_report() 435 | losses = tuple(detailed_loss[0, :, epoch_, :]) 436 | vlosses = tuple(detailed_loss[1, :, epoch_, :]) 437 | print('====> Epoch: %i'%(epoch_)) 438 | print('TRAINING Losses:', losses) 439 | print('VALIDATION Losses:', vlosses) 440 | 441 | if epoch_ % 60 == 0 or epoch_ == n_epoch - 1: 442 | args = models 443 | checkpoint(accelerator, args, optimizer, scheduler, n_encoder, outfile, detailed_loss) 444 | 445 | 446 | if __name__ == "__main__": 447 | 448 | parser = argparse.ArgumentParser() 449 | parser.add_argument("data", help="dataset name") 450 | parser.add_argument("dir", help="data file directory") 451 | parser.add_argument("outfile", help="output file name") 452 | parser.add_argument("-n", "--latents", help="latent dimensionality", type=int, default=2) 453 | parser.add_argument("-b", "--batch_size", help="batch size", type=int, default=512) 454 | parser.add_argument("-l", "--batch_number", help="number of batches per epoch", type=int, default=None) 455 | parser.add_argument("-r", "--rate", help="learning rate", type=float, default=1e-3) 456 | parser.add_argument("-z", "--rv_file", help="rv estimator", type=str, default="None") 457 | parser.add_argument("-it", "--iteration", help="number of interation", type=int, default=100000) 458 | parser.add_argument("-s", "--similarity", help="add similarity loss", action="store_true") 459 | parser.add_argument("-skipfid", "--skipfid", help="skip fidelity loss", action="store_true",default=False) 460 | parser.add_argument("-c", "--consistency", help="add consistency loss", action="store_true") 461 | parser.add_argument("-d", "--double", help="double precision", action="store_true",default=False) 462 | parser.add_argument("-init", "--init", help="initialize restframe", action="store_true",default=False) 
463 | parser.add_argument("-f", "--flexibility", help="constrian model flexibility", action="store_true",default=False) 464 | parser.add_argument("-C", "--clobber", help="continue training of existing model", action="store_true") 465 | parser.add_argument("-v", "--verbose", help="verbose printing", action="store_true") 466 | args = parser.parse_args() 467 | 468 | # define instruments 469 | instruments = [ Synthetic() ] 470 | n_encoder = len(instruments) 471 | 472 | # restframe wavelength for reconstructed spectra 473 | #lmbda_min = 4999.0;lmbda_max = 5011.0;bins = 1200 474 | wave_rest = Synthetic.wave_rest 475 | 476 | # data loaders 477 | trainloaders = [ get_data_loader(args.dir, select=args.data, which="train", 478 | batch_size=args.batch_size,double=args.double) for inst in instruments ] 479 | validloaders = [ get_data_loader(args.dir, select=args.data, which="train", 480 | batch_size=args.batch_size,double=args.double) for inst in instruments ] 481 | 482 | template_data = load_batch("%s%s-template.pkl"%(args.dir,args.data)) 483 | 484 | if args.init: init_restframe = load_batch("%s%s-rest.pkl"%(args.dir,args.data))[0] 485 | else: 486 | init_restframe = Interp1d()(instruments[0].wave_obs, template_data[0], wave_rest) 487 | 488 | if args.double: 489 | template_data = [item.double() for item in template_data] 490 | if args.init: init_restframe = init_restframe.double() 491 | 492 | # get augmentation function 493 | aug_fcts = [ Synthetic.augment_spectra ] 494 | 495 | # define training sequence 496 | FULL = {"data":[True],"encoder":[True],"rv":[True], 497 | "decoder":True,"spec_rest":True} 498 | train_sequence = prepare_train([FULL],niter=args.iteration) 499 | 500 | annealing_step = 1000 501 | ANNEAL_SCHEDULE = np.linspace(0.0,1.0,annealing_step) 502 | if args.verbose and args.similarity: 503 | print("similarity_slope:",len(ANNEAL_SCHEDULE),ANNEAL_SCHEDULE) 504 | 505 | # define and train the model 506 | n_hidden = (64, 256, 1024) 507 | models = [ 
SpectrumAutoencoder(instrument, 508 | wave_rest, 509 | spec_rest=init_restframe, 510 | n_latent=args.latents, 511 | n_hidden=n_hidden, 512 | n_aux=0, 513 | normalize=False) 514 | for instrument in instruments ] 515 | print("RVEstimator:",models[0].rv_estimator) 516 | 517 | # use same decoder 518 | if n_encoder==2:models[1].decoder = models[0].decoder 519 | if args.double:[model.double() for model in models] 520 | n_epoch = sum([item['iteration'] for item in train_sequence]) 521 | init_t = time.time() 522 | if args.verbose: 523 | print("torch.cuda.device_count():",torch.cuda.device_count()) 524 | print (f"--- Model {args.outfile} ---") 525 | 526 | # check if outfile already exists, continue only of -c is set 527 | if os.path.isfile(args.outfile) and not args.clobber: 528 | raise SystemExit("\nOutfile exists! Set option -C to continue training.") 529 | losses = None 530 | if os.path.isfile(args.outfile): 531 | if args.verbose: 532 | print("\nLoading file %s"%args.outfile) 533 | models, losses = load_model(args.outfile, models, instruments) 534 | 535 | if os.path.isfile(args.rv_file): 536 | if args.verbose: 537 | print("\nUpdating RV estimator based on file %s"%args.rv_file) 538 | models = update_rv_estimator(args.rv_file, models, instruments) 539 | 540 | profiler = LineProfiler() 541 | profiler.add_function(partial) 542 | profiler.add_function(load_batch) 543 | lpWrapper = profiler(train) 544 | lpWrapper(models, instruments, trainloaders, validloaders, template_data, n_epoch=n_epoch, 545 | n_batch=args.batch_number, lr=args.rate, aug_fcts=aug_fcts, similarity=args.similarity, consistency=args.consistency, flexibility=args.flexibility, skipfid=args.skipfid,outfile=args.outfile, losses=losses, verbose=args.verbose) 546 | 547 | profiler.print_stats() 548 | #train(models, instruments, trainloaders, validloaders, template_data, n_epoch=n_epoch,n_batch=args.batch_number, lr=args.rate, aug_fcts=aug_fcts, similarity=args.similarity, consistency=args.consistency, 
outfile=args.outfile, losses=losses, verbose=args.verbose) 549 | 550 | if args.verbose: 551 | print("--- %s seconds ---" % (time.time()-init_t)) 552 | --------------------------------------------------------------------------------