├── .gitignore ├── LICENSE ├── Models ├── cosmo_IllustrisTNG_LH_onlypos_0_lr_1.62e-07_weightdecay_1.00e-07_layers_2_rlink_1.50e-02_channels_64_epochs_300 └── cosmo_SIMBA_LH_onlypos_0_lr_1.09e-06_weightdecay_1.00e-07_layers_4_rlink_1.48e-02_channels_64_epochs_300 ├── PS_files ├── Pk_galaxies_IllustrisTNG_LH_10_kmax=20.0.npy ├── Pk_galaxies_IllustrisTNG_LH_14_kmax=20.0.npy ├── Pk_galaxies_IllustrisTNG_LH_18_kmax=20.0.npy ├── Pk_galaxies_IllustrisTNG_LH_24_kmax=20.0.npy ├── Pk_galaxies_IllustrisTNG_LH_33_kmax=20.0.npy ├── Pk_galaxies_SIMBA_LH_10_kmax=20.0.npy ├── Pk_galaxies_SIMBA_LH_14_kmax=20.0.npy ├── Pk_galaxies_SIMBA_LH_18_kmax=20.0.npy ├── Pk_galaxies_SIMBA_LH_24_kmax=20.0.npy ├── Pk_galaxies_SIMBA_LH_33_kmax=20.0.npy └── k_values.txt ├── README.md ├── Source ├── constants.py ├── load_data.py ├── metalayer.py ├── plotting.py └── training.py ├── crosstest.py ├── hyperparameters.py ├── hyperparams_optimization.py ├── main.py ├── ps_test.py ├── visualize_graph_10.png └── visualize_graphs.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | batchfile.slurm 3 | crosstest_batchfile.slurm 4 | tunning_batchfile.slurm 5 | __pycache__/ 6 | plotresults.py 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Pablo Villanueva Domingo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice 
shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Models/cosmo_IllustrisTNG_LH_onlypos_0_lr_1.62e-07_weightdecay_1.00e-07_layers_2_rlink_1.50e-02_channels_64_epochs_300: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/Models/cosmo_IllustrisTNG_LH_onlypos_0_lr_1.62e-07_weightdecay_1.00e-07_layers_2_rlink_1.50e-02_channels_64_epochs_300 -------------------------------------------------------------------------------- /Models/cosmo_SIMBA_LH_onlypos_0_lr_1.09e-06_weightdecay_1.00e-07_layers_4_rlink_1.48e-02_channels_64_epochs_300: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/Models/cosmo_SIMBA_LH_onlypos_0_lr_1.09e-06_weightdecay_1.00e-07_layers_4_rlink_1.48e-02_channels_64_epochs_300 -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_IllustrisTNG_LH_10_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_IllustrisTNG_LH_10_kmax=20.0.npy 
-------------------------------------------------------------------------------- /PS_files/Pk_galaxies_IllustrisTNG_LH_14_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_IllustrisTNG_LH_14_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_IllustrisTNG_LH_18_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_IllustrisTNG_LH_18_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_IllustrisTNG_LH_24_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_IllustrisTNG_LH_24_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_IllustrisTNG_LH_33_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_IllustrisTNG_LH_33_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_SIMBA_LH_10_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_SIMBA_LH_10_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_SIMBA_LH_14_kmax=20.0.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_SIMBA_LH_14_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_SIMBA_LH_18_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_SIMBA_LH_18_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_SIMBA_LH_24_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_SIMBA_LH_24_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/Pk_galaxies_SIMBA_LH_33_kmax=20.0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PabloVD/CosmoGraphNet/36041d7e3b23ae9e288cd7f178e6c2888b54042b/PS_files/Pk_galaxies_SIMBA_LH_33_kmax=20.0.npy -------------------------------------------------------------------------------- /PS_files/k_values.txt: -------------------------------------------------------------------------------- 1 | 3.559856489878782115e-01 2 | 6.031648866026231293e-01 3 | 8.579160745827347778e-01 4 | 1.113464310595710716e+00 5 | 1.374039914321364408e+00 6 | 1.619154704570298309e+00 7 | 1.870686673430271840e+00 8 | 2.123290040706380832e+00 9 | 2.378144976376236119e+00 10 | 2.630489925835516996e+00 11 | 2.878653162679494670e+00 12 | 3.131144366275339319e+00 13 | 3.388159578333816313e+00 14 | 3.638741848210948149e+00 15 | 3.888315910732085268e+00 16 | 4.139052155837465996e+00 17 | 4.390305777018236988e+00 18 | 
4.642172339120373081e+00 19 | 4.894678666149941293e+00 20 | 5.147269052361176378e+00 21 | 5.398003206656323272e+00 22 | 5.647798235368590269e+00 23 | 5.900534302989873581e+00 24 | 6.151460297009074374e+00 25 | 6.402630176204441526e+00 26 | 6.655964674748322096e+00 27 | 6.908782583105771202e+00 28 | 7.159060853508338873e+00 29 | 7.409414762591408632e+00 30 | 7.659227149364629561e+00 31 | 7.911559474606142395e+00 32 | 8.163750941790897997e+00 33 | 8.415757210439384295e+00 34 | 8.667638839289207198e+00 35 | 8.918801124864609520e+00 36 | 9.170480108345369530e+00 37 | 9.420566033467249412e+00 38 | 9.671642335271435797e+00 39 | 9.923426157758477473e+00 40 | 1.017396902749181820e+01 41 | 1.042533700753922155e+01 42 | 1.067702948358784276e+01 43 | 1.092969337256462303e+01 44 | 1.118211872033766419e+01 45 | 1.143346516925797296e+01 46 | 1.168350508557724687e+01 47 | 1.193379937775929100e+01 48 | 1.218553055762879112e+01 49 | 1.243711650801278878e+01 50 | 1.268867268647359303e+01 51 | 1.294075027177769144e+01 52 | 1.319267287020158363e+01 53 | 1.344470571782862756e+01 54 | 1.369480150283709641e+01 55 | 1.394591277066507828e+01 56 | 1.419764645764527167e+01 57 | 1.444766774168416212e+01 58 | 1.469906835825429070e+01 59 | 1.495059118111036511e+01 60 | 1.520261612590052991e+01 61 | 1.545415944503929495e+01 62 | 1.570456072545780124e+01 63 | 1.595705800128773966e+01 64 | 1.620829297830731619e+01 65 | 1.645966591028190962e+01 66 | 1.671125708846596325e+01 67 | 1.696160259490052624e+01 68 | 1.721370276108077135e+01 69 | 1.746543424102250341e+01 70 | 1.771598473288305442e+01 71 | 1.796653767411944003e+01 72 | 1.821765658501893270e+01 73 | 1.847053788120423334e+01 74 | 1.872257696809484173e+01 75 | 1.897394097105746269e+01 76 | 1.922542454726159988e+01 77 | 1.947607703355479103e+01 78 | 1.972678839839982601e+01 79 | 1.997802698857814363e+01 80 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # CosmoGraphNet 2 | 3 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.6485804.svg)](https://doi.org/10.5281/zenodo.6485804) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2204.13713-B31B1B.svg)](http://arxiv.org/abs/2204.13713) 5 | 6 | Graph Neural Networks to predict the cosmological parameters and the galaxy power spectrum from galaxy catalogues. 7 | 8 | A graph is created from a galaxy catalogue with information about the 3D position and intrinsic galactic properties. A Graph Neural Network is then applied to infer the cosmological parameters or the galaxy power spectrum. Galaxy catalogues are extracted from the [CAMELS](https://camels.readthedocs.io/en/latest/index.html) hydrodynamic simulations, which are specially suited for Machine Learning purposes. Neural network architectures are defined using the package [PyTorch-geometric](https://pytorch-geometric.readthedocs.io/en/latest/). 9 | 10 | See the paper [arXiv:2204.13713](https://arxiv.org/abs/2204.13713) for more details. 11 | 12 | 13 | 14 | 15 | ## Description of the scripts 16 | 17 | Here is a brief overview of the scripts included: 18 | 19 | - `main.py`: main driver to train and test the network. 20 | 21 | - `hyperparameters.py`: script with the definition of the hyperparameters employed by the networks. 22 | 23 | - `crosstest.py`: tests a pre-trained model. 24 | 25 | - `hyperparams_optimization.py`: optimizes the hyperparameters using `optuna`. 26 | 27 | - `ps_test.py`: tests the power spectrum neural networks on point distributions with different clustering properties. 28 | 29 | - `visualize_graphs.py`: displays graphs from galaxy catalogues in 2D or 3D. 30 | 31 | 32 | The folder `Source` contains scripts with auxiliary routines: 33 | 34 | * `constants.py`: basic constants and initialization. 35 | 36 | * `load_data.py`: contains routines to load data from simulation files.
37 | 38 | * `plotting.py`: includes functions for displaying the results from the neural nets. 39 | 40 | * `metalayer.py`: includes the definition of the Graph Neural Network architecture. 41 | 42 | * `training.py`: includes routines for training and testing the net. 43 | 44 | 45 | ## Requirements 46 | 47 | The libraries required for training the models and computing some statistics are: 48 | * `numpy` 49 | * `pytorch` 50 | * `pytorch-geometric` 51 | * `matplotlib` 52 | * `scipy` 53 | * `sklearn` 54 | * [`optuna`](https://optuna.readthedocs.io/en/stable/index.html) (only for optimization in `hyperparams_optimization.py`) 55 | * [`Pylians`](https://pylians3.readthedocs.io/en/master/) (only for computing power spectra in `ps_test.py`) 56 | 57 | 58 | ## Usage 59 | 60 | The scripts implemented here are designed to train Graph Neural Networks for two tasks. The desired task is chosen in `hyperparameters.py` with the `outmode` flag: 61 | 1. Infer cosmological parameters from galaxy catalogues. Set `outmode = "cosmo"`. 62 | 2. Predict the power spectrum from galaxy catalogues. Set `outmode = "ps"`. 63 | 64 | Here is some advice on using the scripts described above: 65 | 1. To search for the optimal hyperparameters, run `hyperparams_optimization.py`. 66 | 2. To train a model with a given set of parameters defined in `hyperparameters.py`, run `main.py`. The hyperparameters currently present in `hyperparameters.py` correspond to the optimal values for each suite when all galactic features are employed (see the paper). Modify them according to the task. 67 | 3. Once a model is trained to perform cosmological parameter inference, run `crosstest.py` to test it on the training simulation suite and cross-test it on the other suite included in CAMELS (IllustrisTNG and SIMBA). It needs a pretrained model. 68 | 4.
# [README.md, continued]
# If a model has been trained to predict the power spectrum from CAMELS galaxy catalogues, evaluate its extrapolation performance on different point distributions by running `ps_test.py`. It needs a pretrained model.
#
# ## Citation
#
# If you use the code, please link this repository, and cite [arXiv:2204.13713](https://arxiv.org/abs/2204.13713) and the DOI [10.5281/zenodo.6485804](https://doi.org/10.5281/zenodo.6485804).
#
# ## Contact
#
# Feel free to contact me at for comments, questions and suggestions.
#
# -------------------------------------------------------------------------------- /Source/constants.py: --------------------------------------------------------------------------------
#----------------------------------------------------------------------
# List of constants and some common functions
# Author: Pablo Villanueva Domingo
# Last update: 4/22
#----------------------------------------------------------------------

import numpy as np
import torch
import os
import random

# Random seeds: fix the torch, numpy and Python RNGs so that runs are reproducible
torch.manual_seed(12345)
np.random.seed(12345)
random.seed(12345)

# use GPUs if available
if torch.cuda.is_available():
    print("CUDA Available")
    device = torch.device('cuda')
else:
    print('CUDA Not Available')
    device = torch.device('cpu')

#--- PARAMETERS AND CONSTANTS ---#

# Reduced Hubble constant
hred = 0.7

# Root path for simulations
simpathroot = "/projects/QUIJOTE/CAMELS/Sims/"

# Box size in comoving kpc/h
boxsize = 25.e3

# Validation and test size (fractions of the whole dataset)
valid_size, test_size = 0.15, 0.15

# Batch size
batch_size = 25

# Number of k bins in the power spectrum (matches the 79 entries of PS_files/k_values.txt)
ps_size = 79

#--- FUNCTIONS ---#

def colorsuite(suite):
    """Return the plotting colour of a CAMELS simulation suite.

    "purple" for IllustrisTNG, "dodgerblue" for SIMBA, and None for any other name.
    """
    if suite == "IllustrisTNG":
        return "purple"
    if suite == "SIMBA":
        return "dodgerblue"
    return None

# -------------------------------------------------------------------------------- /Source/load_data.py: --------------------------------------------------------------------------------
#----------------------------------------------------
# Routine for loading the CAMELS galaxy catalogues
# Author: Pablo Villanueva Domingo
# Last update: 4/22
#----------------------------------------------------

import functools
import h5py
from torch_geometric.data import Data, DataLoader
from Source.constants import *
from Source.plotting import *
import scipy.spatial as SS

Nstar_th = 20 # Minimum number of stellar particles required to consider a galaxy


def normalize_params(params):
    """Min-max normalize the six CAMELS parameters.

    Each entry is mapped with (p - minimum) / (maximum - minimum) using the fixed bounds
    below, so values lying inside the bounds end up in [0, 1].
    """
    minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
    maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
    return (params - minimum)/(maximum - minimum)


def normalize_ps(ps):
    """Standardize ps along axis 0 (zero mean and unit standard deviation per column).

    NOTE: a column with zero spread causes a division by zero. Not called by the loaders
    in this module.
    """
    mean, std = ps.mean(axis=0), ps.std(axis=0)
    return (ps - mean)/std


def _wrap_periodic(vectors, threshold):
    """Shift, in place, components beyond +/-threshold back by one unit box length.

    Both masks are evaluated before any shift, so every component is corrected at most
    once. Returns the same (modified) array for convenience.
    """
    above = vectors > threshold
    below = vectors < -threshold
    vectors[above] -= 1.  # Boxsize normalized to 1
    vectors[below] += 1.  # Boxsize normalized to 1
    return vectors


def get_edges(pos, r_link, use_loops):
    """Link galaxies closer than r_link and compute invariant edge features.

    Args:
        pos: galaxy positions normalized to the unit box, one row per galaxy
            (presumably (N, 3) -- TODO confirm).
        r_link: linking radius, in units of the box size.
        use_loops: if True, also add a self-loop for every node.

    Returns:
        edge_index: int array of shape (2, E) holding every linked pair in both
            directions (followed by the self-loops when requested).
        edge_attr: array of shape (E, 3) with columns
            [distance / r_link,
             cosine between the centroid-relative vectors of the two ends,
             cosine between the source's centroid-relative vector and the separation].
            Self-loops get [0, 1, 0].
    """
    # 1. Get edges
    # Periodic KDTree on the unit box. boxsize is slightly above 1, presumably so that
    # coordinates equal to exactly 1.0 (not wrapped by sim_graph) are accepted.
    kd_tree = SS.KDTree(pos, leafsize=16, boxsize=1.0001)
    pairs = kd_tree.query_pairs(r=r_link, output_type="ndarray")

    # Each pair (i, j) is returned once: append (j, i) so the graph is symmetric
    pairs = np.concatenate([pairs, pairs[:, ::-1]], axis=0).astype(int)

    # pytorch-geometric layout: shape (2, num_edges)
    edge_index = pairs.T

    # 2. Get edge attributes
    src, dst = edge_index
    diff = pos[src] - pos[dst]

    # A separation component larger than r_link means the pair was linked across the
    # periodic boundary: bring it back to the minimum image
    _wrap_periodic(diff, r_link)
    dist = np.linalg.norm(diff, axis=1)

    # Position vectors of both ends relative to the catalogue centroid, also wrapped
    centroid = np.mean(pos, axis=0)
    src_vec = _wrap_periodic(pos[src] - centroid, 0.5)
    dst_vec = _wrap_periodic(pos[dst] - centroid, 0.5)

    unit_src = src_vec / np.linalg.norm(src_vec, axis=1).reshape(-1, 1)
    unit_dst = dst_vec / np.linalg.norm(dst_vec, axis=1).reshape(-1, 1)
    unit_diff = diff / dist.reshape(-1, 1)

    # Row-wise dot products between unit vectors
    cos1 = np.sum(unit_src * unit_dst, axis=1)
    cos2 = np.sum(unit_src * unit_diff, axis=1)

    # Distance is normalized by the linking radius
    edge_attr = np.column_stack([dist / r_link, cos1, cos2])

    # Add loops
    if use_loops:
        num_nodes = pos.shape[0]
        loops = np.tile(np.arange(num_nodes), (2, 1))
        loop_attr = np.zeros((num_nodes, 3))
        loop_attr[:, 1] = 1.
        edge_index = np.append(edge_index, loops, 1)
        edge_attr = np.append(edge_attr, loop_attr, 0)
        edge_index = edge_index.astype(int)

    return edge_index, edge_attr


@functools.lru_cache(maxsize=8)
def _load_param_table(param_file):
    """Read a whitespace-separated text table as strings, once per path.

    The cached array is shared between calls, so it is made read-only.
    """
    table = np.loadtxt(param_file, dtype=str)
    table.setflags(write=False)
    return table


@functools.lru_cache(maxsize=8)
def _load_ps_table(param_file):
    """Load a .npy table of power spectra, once per path (returned read-only)."""
    table = np.load(param_file)
    table.setflags(write=False)
    return table


def sim_graph(simnumber, param_file, hparams):
    """Create a cosmic graph from the galaxy catalogue of one simulation.

    Args:
        simnumber: number of the simulation; selects the catalogue directory and the row
            of the target table.
        param_file: target table, a text file with the cosmological + astrophysical
            parameters when hparams.outmode == "cosmo", or a .npy file of power spectra
            when hparams.outmode == "ps".
        hparams: hyperparameters object; reads simsuite, simset, r_link, only_positions,
            outmode, pred_params and snap.

    Returns:
        torch_geometric Data with node features x, target y, global feature u (log10 of
        the number of selected galaxies), edge_index and edge_attr.

    Raises:
        ValueError: if hparams.outmode is neither "cosmo" nor "ps".
    """
    simsuite, simset, r_link = hparams.simsuite, hparams.simset, hparams.r_link
    only_positions, outmode, pred_params = hparams.only_positions, hparams.outmode, hparams.pred_params

    # Name of the galaxy catalogue
    simpath = simpathroot + simsuite + "/" + simset + "_"
    catalogue = simpath + str(simnumber) + "/fof_subhalo_tab_0" + hparams.snap + ".hdf5"

    # Read the catalogue; the context manager closes the file even if a read fails
    with h5py.File(catalogue, 'r') as f:
        pos = f['/Subhalo/SubhaloPos'][:]/boxsize
        Mstar = f['/Subhalo/SubhaloMassType'][:, 4]  # Msun/h
        Rstar = f["Subhalo/SubhaloHalfmassRadType"][:, 4]
        Metal = f["Subhalo/SubhaloStarMetallicity"][:]
        Vmax = f["Subhalo/SubhaloVmax"][:]
        Nstar = f['/Subhalo/SubhaloLenType'][:, 4]  # number of stars

    # Some positions are slightly outside the box: wrap them back
    pos[pos < 0.0] += 1.0
    pos[pos > 1.0] -= 1.0

    # Select only galaxies with more than Nstar_th star particles
    keep = Nstar > Nstar_th
    pos, Mstar, Rstar, Metal, Vmax = (arr[keep] for arr in (pos, Mstar, Rstar, Metal, Vmax))

    # Target of the GNN: either the cosmological parameters or the power spectrum
    if outmode == "cosmo":
        table = _load_param_table(param_file)
        # The first and last columns of each row are skipped
        params = np.array(table[simnumber, 1:-1], dtype=np.float32)
        params = normalize_params(params)
        params = params[:pred_params]  # Consider only the first pred_params parameters
        y = np.reshape(params, (1, params.shape[0]))
    elif outmode == "ps":
        ps = np.log10(_load_ps_table(param_file)[simnumber])
        y = np.reshape(ps, (1, ps_size))
    else:
        raise ValueError(f"Unknown outmode {outmode!r}; expected 'cosmo' or 'ps'")

    # Number of galaxies as global feature
    u = np.log10(pos.shape[0]).reshape(1, 1)

    # Node features: log10(1 + value) of each galaxy property
    Mstar = np.log10(1. + Mstar)
    Rstar = np.log10(1. + Rstar)
    Metal = np.log10(1. + Metal)
    Vmax = np.log10(1. + Vmax)
    tab = np.column_stack((Mstar, Rstar, Metal, Vmax))
    x = torch.tensor(tab, dtype=torch.float32)

    # Self-loops are only added when node features are used. With only_positions the
    # features in x are still stored, but the position-only input layers are expected
    # to ignore their values.
    use_loops = not only_positions

    # Get edges and edge features
    edge_index, edge_attr = get_edges(pos, r_link, use_loops)

    # Construct the graph
    graph = Data(x=x,
                 y=torch.tensor(y, dtype=torch.float32),
                 u=torch.tensor(u, dtype=torch.float32),
                 edge_index=torch.tensor(edge_index, dtype=torch.long),
                 edge_attr=torch.tensor(edge_attr, dtype=torch.float32))

    return graph


def split_datasets(dataset):
    """Shuffle dataset in place and split it into train, validation and test loaders.

    The first floor(valid_size * n) graphs form the validation set, the next
    floor(test_size * n) the test set, and the rest the training set.
    """
    random.shuffle(dataset)

    num_graphs = len(dataset)
    split_valid = int(np.floor(valid_size * num_graphs))
    split_test = split_valid + int(np.floor(test_size * num_graphs))

    train_dataset = dataset[split_test:]
    valid_dataset = dataset[:split_valid]
    test_dataset = dataset[split_valid:split_test]

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

    return train_loader, valid_loader, test_loader

######################################################################################

def _ps_file(hparams):
    """Path of the precomputed galaxy power spectra for the current suite and snapshot."""
    return "PS_files/Pk_galaxies_"+hparams.simsuite+"_LH_"+hparams.snap+"_kmax=20.0.npy"


def _sim_graphs(param_file, hparams):
    """Build the graphs of simulations 0 .. hparams.n_sims - 1 with the current hparams."""
    return [sim_graph(simnumber, param_file, hparams) for simnumber in range(hparams.n_sims)]


def create_dataset(hparams):
    """Load the galaxy catalogues and build the list of graphs.

    For outmode "cosmo" only the suite and snapshot in hparams are used. For outmode "ps"
    the other suite and the snapshots 18 and 10 are added as well. hparams.simsuite and
    hparams.snap are changed while loading and restored before returning.

    Raises:
        ValueError: if hparams.outmode is neither "cosmo" nor "ps".
    """
    # Target file depending on the task: inferring cosmo parameters or predicting power spectrum
    if hparams.outmode == "cosmo":
        param_file = simpathroot + "CosmoAstroSeed_params_" + hparams.simsuite + ".txt"
    elif hparams.outmode == "ps":
        param_file = _ps_file(hparams)
    else:
        raise ValueError(f"Unknown outmode {hparams.outmode!r}; expected 'cosmo' or 'ps'")

    dataset = _sim_graphs(param_file, hparams)

    if hparams.outmode == "ps":
        # sim_graph reads the suite and snapshot from hparams, so they are switched below;
        # the caller's values are restored afterwards. The loading order is kept fixed
        # because it determines the seeded shuffle in split_datasets.
        orig_suite, orig_snap = hparams.simsuite, hparams.snap
        try:
            # Add the other suite for predicting the power spectrum
            hparams.simsuite = hparams.flip_suite()
            dataset += _sim_graphs(_ps_file(hparams), hparams)

            # Add other snapshots from other redshifts, for both suites
            # Snapshot redshift
            # 004: z=3, 010: z=2, 014: z=1.5, 018: z=1, 024: z=0.5, 033: z=0
            #for snap in [24,18,14,10]:
            for snap in [18, 10]:
                hparams.snap = str(snap)
                dataset += _sim_graphs(_ps_file(hparams), hparams)

                hparams.simsuite = hparams.flip_suite()
                dataset += _sim_graphs(_ps_file(hparams), hparams)
        finally:
            hparams.simsuite, hparams.snap = orig_suite, orig_snap

    gals = np.array([graph.x.shape[0] for graph in dataset])
    print("Total of galaxies", gals.sum(0), "Mean of", gals.mean(0),"per simulation, Std of", gals.std(0))

    return dataset

# -------------------------------------------------------------------------------- /Source/metalayer.py: --------------------------------------------------------------------------------
#----------------------------------------------------
# Graph Neural Network architecture
# Author: Pablo Villanueva Domingo
# Last update: 4/22
#----------------------------------------------------

import torch
import torch.nn.functional as F
from torch_cluster import knn_graph, radius_graph
from torch.nn import Sequential, Linear, ReLU, ModuleList
from torch_geometric.nn import MessagePassing, MetaLayer, LayerNorm
from torch_scatter import scatter_mean, scatter_sum, scatter_max, scatter_min, scatter_add
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool
from Source.constants import *

# Model for updating edge attributes
class EdgeModel(torch.nn.Module):
    """Edge update: MLP over [source node features, target node features, edge features].

    Args:
        node_in: number of node features.
        node_out: not used here; kept so all block models share one signature.
        edge_in: number of input edge features.
        edge_out: number of output edge features. With residuals=True the MLP output is
            added to the input edge features, so edge_out must equal edge_in.
        hid_channels: width of the hidden layer.
        residuals: add the input edge features to the MLP output.
        norm: append a LayerNorm after the MLP.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, hid_channels, residuals=True, norm=False):
        super().__init__()

        self.residuals = residuals
        self.norm = norm

        layers = [Linear(node_in*2 + edge_in, hid_channels),
                  ReLU(),
                  Linear(hid_channels, edge_out)]
        if self.norm:
            layers.append(LayerNorm(edge_out))

        self.edge_mlp = Sequential(*layers)

    def forward(self, src, dest, edge_attr, u, batch):
        # src, dest: [E, F_x], where E is the number of edges.
        # edge_attr: [E, F_e]
        # u: [B, F_u], where B is the number of graphs (not used by this model).
        # batch: [E] with max entry B - 1 (not used by this model).

        out = torch.cat([src, dest, edge_attr], dim=1)
        #out = torch.cat([src, dest, edge_attr, u[batch]], 1)
        out = self.edge_mlp(out)
        if self.residuals:
            out = out + edge_attr
        return out

# Model for updating node attributes
class NodeModel(torch.nn.Module):
    """Node update: MLP over [x, sum/max/mean of edge features per target node, u].

    The first Linear layer takes node_in + 3*edge_out + 1 inputs, so the edge features
    passed to forward must have edge_out columns and the global features u exactly one
    column. With residuals=True the MLP output is added to x, so node_out must equal
    node_in.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, hid_channels, residuals=True, norm=False):
        super().__init__()

        self.residuals = residuals
        self.norm = norm

        layers = [Linear(node_in + 3*edge_out + 1, hid_channels),
                  ReLU(),
                  Linear(hid_channels, node_out)]
        if self.norm:
            layers.append(LayerNorm(node_out))

        self.node_mlp = Sequential(*layers)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # x: [N, F_x], where N is the number of nodes.
        # edge_index: [2, E] with max entry N - 1.
        # edge_attr: [E, F_e]
        # u: [B, F_u]
        # batch: [N] with max entry B - 1.

        # Edge features are aggregated onto the node in the second row of edge_index
        col = edge_index[1]
        num_nodes = x.size(0)

        # Multipooling layer: sum, max and mean over each node's edges
        out1 = scatter_add(edge_attr, col, dim=0, dim_size=num_nodes)
        out2 = scatter_max(edge_attr, col, dim=0, dim_size=num_nodes)[0]
        out3 = scatter_mean(edge_attr, col, dim=0, dim_size=num_nodes)
        out = torch.cat([x, out1, out2, out3, u[batch]], dim=1)

        out = self.node_mlp(out)
        if self.residuals:
            out = out + x
        return out

# First edge model for updating edge attributes when no initial node features are provided
class EdgeModelIn(torch.nn.Module):
    """Input edge update: MLP over the edge features only (no residual connection).

    Node features, u and batch are accepted for interface compatibility but ignored.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, hid_channels, norm=False):
        super().__init__()

        self.norm = norm

        layers = [Linear(edge_in, hid_channels),
                  ReLU(),
                  Linear(hid_channels, edge_out)]
        if self.norm:
            layers.append(LayerNorm(edge_out))

        self.edge_mlp = Sequential(*layers)

    def forward(self, src, dest, edge_attr, u, batch):
        return self.edge_mlp(edge_attr)
# First node model for updating node attributes when no initial node features are provided
class NodeModelIn(torch.nn.Module):
    """Input node update built only from edge information.

    New node features are an MLP over [sum, max and mean of the edge features of each
    node, u]. Only the number of rows of x is used; its values are ignored. The first
    Linear layer takes 3*edge_out + 1 inputs, so u must have exactly one column.
    """

    def __init__(self, node_in, node_out, edge_in, edge_out, hid_channels, norm=False):
        super().__init__()

        self.norm = norm

        layers = [Linear(3*edge_out + 1, hid_channels),
                  ReLU(),
                  Linear(hid_channels, node_out)]
        if self.norm:
            layers.append(LayerNorm(node_out))

        self.node_mlp = Sequential(*layers)

    def forward(self, x, edge_index, edge_attr, u, batch):
        # Edge features are aggregated onto the node in the second row of edge_index
        col = edge_index[1]
        num_nodes = x.size(0)

        # Multipooling layer: sum, max and mean over each node's edges
        out1 = scatter_add(edge_attr, col, dim=0, dim_size=num_nodes)
        out2 = scatter_max(edge_attr, col, dim=0, dim_size=num_nodes)[0]
        out3 = scatter_mean(edge_attr, col, dim=0, dim_size=num_nodes)
        out = torch.cat([out1, out2, out3, u[batch]], dim=1)

        return self.node_mlp(out)

# Graph Neural Network architecture, based on the Graph Network (arXiv:1806.01261)
# Employing the MetaLayer implementation in Pytorch-Geometric
class GNN(torch.nn.Module):
    """Graph Network: an encoder block, n_layers - 1 hidden blocks and a pooled MLP head.

    After message passing, node features are pooled per graph (sum, mean and max),
    concatenated with the global feature u (one column) and mapped to dim_out outputs.

    Args:
        node_features: number of input node features (not used when only_positions).
        n_layers: total number of message-passing blocks, encoder included.
        hidden_channels: width of node/edge latent features and of the hidden layers.
        linkradius: linking radius; stored as self.link_r, not used by this class.
        dim_out: number of outputs.
        only_positions: if True, the encoder builds node features from edge features only.
        residuals: use residual connections in the hidden blocks.
    """

    def __init__(self, node_features, n_layers, hidden_channels, linkradius, dim_out, only_positions, residuals=True):
        super().__init__()

        self.n_layers = n_layers
        self.link_r = linkradius
        self.dim_out = dim_out
        self.only_positions = only_positions

        # Number of input node features (not used if only_positions)
        node_in = node_features
        # Input edge features: distance / r_link and two cosines of angles between
        # centroid-relative position vectors and the separation (see get_edges)
        edge_in = 3
        node_out = hidden_channels
        edge_out = hidden_channels
        hid_channels = hidden_channels

        layers = []

        # Encoder graph block
        # If use only positions, node features are created from the aggregation of edge attributes of neighbors
        # (node_model is built before edge_model in every block, which fixes the seeded initialisation order)
        if self.only_positions:
            inlayer = MetaLayer(node_model=NodeModelIn(node_in, node_out, edge_in, edge_out, hid_channels),
                                edge_model=EdgeModelIn(node_in, node_out, edge_in, edge_out, hid_channels))
        else:
            inlayer = MetaLayer(node_model=NodeModel(node_in, node_out, edge_in, edge_out, hid_channels, residuals=False),
                                edge_model=EdgeModel(node_in, node_out, edge_in, edge_out, hid_channels, residuals=False))

        layers.append(inlayer)

        # Hidden blocks take the latent sizes as input (required by the residual connections)
        node_in = node_out
        edge_in = edge_out

        # Hidden graph blocks
        for _ in range(n_layers-1):
            lay = MetaLayer(node_model=NodeModel(node_in, node_out, edge_in, edge_out, hid_channels, residuals=residuals),
                            edge_model=EdgeModel(node_in, node_out, edge_in, edge_out, hid_channels, residuals=residuals))
            layers.append(lay)

        self.layers = ModuleList(layers)

        # Final aggregation layer: input is [sum, mean, max pooling of nodes, u]
        self.outlayer = Sequential(Linear(3*node_out+1, hid_channels),
                                   ReLU(),
                                   Linear(hid_channels, hid_channels),
                                   ReLU(),
                                   Linear(hid_channels, hid_channels),
                                   ReLU(),
                                   Linear(hid_channels, self.dim_out))

    def forward(self, data):

        h, edge_index, edge_attr, u = data.x, data.edge_index, data.edge_attr, data.u

        # Message passing layers
        for layer in self.layers:
            h, edge_attr, _ = layer(h, edge_index, edge_attr, u, data.batch)

        # Multipooling layer
        addpool = global_add_pool(h, data.batch)
        meanpool = global_mean_pool(h, data.batch)
        maxpool = global_max_pool(h, data.batch)

        out = torch.cat([addpool, meanpool, maxpool, u], dim=1)

        # Final linear layer
        return self.outlayer(out)

# -------------------------------------------------------------------------------- /Source/plotting.py: --------------------------------------------------------------------------------
#----------------------------------------------------------------------
# Script for plotting some statistics 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 4/22 5 | #---------------------------------------------------------------------- 6 | 7 | import matplotlib.pyplot as plt 8 | from Source.constants import * 9 | from sklearn.metrics import r2_score 10 | from matplotlib.offsetbox import AnchoredText 11 | from matplotlib.ticker import MultipleLocator 12 | from matplotlib.lines import Line2D 13 | import matplotlib as mpl 14 | mpl.rcParams.update({'font.size': 12}) 15 | 16 | # Plot loss trends 17 | def plot_losses(train_losses, valid_losses, test_loss, err_min, hparams): 18 | 19 | epochs = hparams.n_epochs 20 | plt.plot(range(epochs), np.exp(train_losses), "r-",label="Training") 21 | plt.plot(range(epochs), np.exp(valid_losses), "b:",label="Validation") 22 | plt.legend() 23 | plt.yscale("log") 24 | plt.title(f"Test loss: {test_loss:.2e}, Minimum relative error: {err_min:.2e}") 25 | plt.savefig("Plots/loss_"+hparams.name_model()+".png", bbox_inches='tight', dpi=300) 26 | plt.close() 27 | 28 | # Remove normalization of cosmo parameters 29 | def denormalize(trues, outputs, errors, minpar, maxpar): 30 | 31 | trues = minpar + trues*(maxpar - minpar) 32 | outputs = minpar + outputs*(maxpar - minpar) 33 | errors = errors*(maxpar - minpar) 34 | return trues, outputs, errors 35 | 36 | # Scatter plot of true vs predicted cosmological parameter 37 | def plot_out_true_scatter(hparams, cosmoparam, testsuite = False): 38 | 39 | figscat, axscat = plt.subplots(figsize=(6,5)) 40 | suite, simset = hparams.simsuite, hparams.simset 41 | col = colorsuite(suite) 42 | 43 | # Load true values and predicted means and standard deviations 44 | outputs = np.load("Outputs/outputs_"+hparams.name_model()+".npy") 45 | trues = np.load("Outputs/trues_"+hparams.name_model()+".npy") 46 | errors = np.load("Outputs/errors_"+hparams.name_model()+".npy") 47 | 48 | # There is a (0,0) initial point, fix it 49 | outputs = outputs[1:] 50 | trues = trues[1:] 51 | errors 
= errors[1:] 52 | 53 | # Choose cosmo param and denormalize 54 | if cosmoparam=="Om": 55 | minpar, maxpar = 0.1, 0.5 56 | outputs, trues, errors = outputs[:,0], trues[:,0], errors[:,0] 57 | elif cosmoparam=="Sig": 58 | minpar, maxpar = 0.6, 1.0 59 | outputs, trues, errors = outputs[:,1], trues[:,1], errors[:,1] 60 | trues, outputs, errors = denormalize(trues, outputs, errors, minpar, maxpar) 61 | 62 | # Compute the number of points lying within 1 or 2 sigma regions from their uncertainties 63 | cond_success_1sig, cond_success_2sig = np.abs(outputs-trues)<=np.abs(errors), np.abs(outputs-trues)<=2.*np.abs(errors) 64 | tot_points = outputs.shape[0] 65 | successes1sig, successes2sig = outputs[cond_success_1sig].shape[0], outputs[cond_success_2sig].shape[0] 66 | 67 | # Compute the linear correlation coefficient 68 | r2 = r2_score(trues,outputs) 69 | err_rel = np.mean(np.abs((trues - outputs)/(trues)), axis=0) 70 | chi2s = (outputs-trues)**2./errors**2. 71 | chi2 = chi2s[chi2s<1.e4].mean() # Remove some outliers which make explode the chi2 72 | print("R^2={:.2f}, Relative error={:.2e}, Chi2={:.2f}".format(r2, err_rel, chi2)) 73 | print("A fraction of succeses of", successes1sig/tot_points, "at 1 sigma,", successes2sig/tot_points, "at 2 sigmas") 74 | 75 | # Sort by true value 76 | indsort = trues.argsort() 77 | outputs, trues, errors = outputs[indsort], trues[indsort], errors[indsort] 78 | 79 | # Compute mean and std region within several bins 80 | truebins, binsize = np.linspace(trues[0], trues[-1], num=10, retstep=True) 81 | means, stds = [], [] 82 | for i, bin in enumerate(truebins[:-1]): 83 | cond = (trues>=bin) & (trues {:.2e}). 
Saving model ...".format(valid_loss_min,test_loss)) 121 | torch.save(model.state_dict(), "Models/"+hparams.name_model()) 122 | valid_loss_min = test_loss 123 | err_min = err 124 | 125 | if verbose: print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.2e}, Validation Loss: {test_loss:.2e}, Error: {err:.2e}') 126 | 127 | return train_losses, valid_losses 128 | -------------------------------------------------------------------------------- /crosstest.py: -------------------------------------------------------------------------------- 1 | #------------------------------------------------ 2 | # Test a model already trained 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 3/21 5 | #------------------------------------------------ 6 | 7 | from main import * 8 | from hyperparameters import hparams 9 | 10 | #--- MAIN ---# 11 | 12 | time_ini = time.time() 13 | 14 | for path in ["Plots", "Models", "Outputs"]: 15 | if not os.path.exists(path): 16 | os.mkdir(path) 17 | 18 | # Test a pretrained model 19 | hparams.training = False 20 | 21 | main(hparams) 22 | 23 | # Test the pretrained model in the other CAMELS suite 24 | hparams.simsuite = hparams.flip_suite() 25 | 26 | main(hparams, testsuite=True) 27 | 28 | print("Finished. 
Time elapsed:",datetime.timedelta(seconds=time.time()-time_ini)) 29 | -------------------------------------------------------------------------------- /hyperparameters.py: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------- 2 | # Hyperparameters definition 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 4/22 5 | #---------------------------------------------------- 6 | 7 | # Hyperparameters class 8 | class hyperparameters(): 9 | def __init__(self, outmode, only_positions, learning_rate, weight_decay, n_layers, hidden_channels, r_link, n_epochs, simsuite, simset="LH", n_sims=1000, training=True, pred_params=2): 10 | 11 | # Choose the output to be predicted, either the cosmological parameters ("cosmo") or the power spectrum ("ps") 12 | self.outmode = outmode 13 | # 1 for using only positions as features, 0 for using additional galactic features 14 | # 1 only for outmode = "cosmo" 15 | self.only_positions = only_positions 16 | if self.outmode == "ps": 17 | self.only_positions = 1 18 | # Learning rate 19 | self.learning_rate = learning_rate 20 | # Weight decay 21 | self.weight_decay = weight_decay 22 | # Number of graph layers 23 | self.n_layers = n_layers 24 | # Hidden channels 25 | self.hidden_channels = hidden_channels 26 | # Linking radius 27 | self.r_link = r_link 28 | # Number of epochs 29 | self.n_epochs = n_epochs 30 | # Simulation suite, choose between "IllustrisTNG" and "SIMBA" 31 | self.simsuite = simsuite 32 | # Simulation set, choose between "CV" and "LH" 33 | self.simset = simset 34 | # Number of simulations considered, maximum 27 for CV and 1000 for LH 35 | self.n_sims = n_sims 36 | # If training, set to True, otherwise loads a pretrained model and tests it 37 | self.training = training 38 | # Number of cosmo/astro params to be predicted, starting from Omega_m, sigma_8, etc. 
39 | # Only for outmode = "cosmo" 40 | self.pred_params = pred_params 41 | # Snapshot of the simulation, indicating redshift 4: z=3, 10: z=2, 14: z=1.5, 18: z=1, 24: z=0.5, 33: z=0 42 | self.snap = "33" 43 | 44 | # Name of the model and hyperparameters 45 | def name_model(self): 46 | return self.outmode+"_"+self.simsuite+"_"+self.simset+"_onlypos_"+str(self.only_positions)+"_lr_{:.2e}_weightdecay_{:.2e}_layers_{:d}_rlink_{:.2e}_channels_{:d}_epochs_{:d}".format(self.learning_rate, self.weight_decay, self.n_layers, self.r_link, self.hidden_channels, self.n_epochs) 47 | 48 | # Return the other CAMELS simulation suite 49 | def flip_suite(self): 50 | if self.simsuite=="IllustrisTNG": 51 | new_simsuite = "SIMBA" 52 | elif self.simsuite=="SIMBA": 53 | new_simsuite = "IllustrisTNG" 54 | return new_simsuite 55 | 56 | 57 | #--- HYPERPARAMETER CHOICES ---# 58 | 59 | #""" 60 | # IllustrisTNG best model 61 | hparams = hyperparameters(outmode = "cosmo", # Choose the output to be predicted, either the cosmological parameters ("cosmo") or the power spectrum ("ps") 62 | only_positions = 0, # 1 for using only positions as features, 0 for using additional galactic features 63 | learning_rate = 1.619e-07, # Learning rate 64 | weight_decay = 1.e-07, # Weight decay 65 | n_layers = 2, # Number of hidden graph layers 66 | r_link = 0.015, # Linking radius 67 | hidden_channels = 64, # Hidden channels 68 | n_epochs = 300, # Number of epochs 69 | simsuite = "IllustrisTNG", # Simulation suite, choose between "IllustrisTNG" and "SIMBA" 70 | pred_params = 1 # Number of cosmo/astro params to be predicted, starting from Omega_m, sigma_8, etc. 
(Only for outmode = "cosmo") 71 | ) 72 | """ 73 | # SIMBA best model 74 | hparams = hyperparameters(outmode = "cosmo", # Choose the output to be predicted, either the cosmological parameters ("cosmo") or the power spectrum ("ps") 75 | only_positions = 0, # 1 for using only positions as features, 0 for using additional galactic features 76 | learning_rate = 1.087e-06, # Learning rate 77 | weight_decay = 1.e-07, # Weight decay 78 | n_layers = 4, # Number of hidden graph layers 79 | r_link = 0.0148, # Linking radius 80 | hidden_channels = 64, # Hidden channels 81 | n_epochs = 300, # Number of epochs 82 | simsuite = "SIMBA", # Simulation suite, choose between "IllustrisTNG" and "SIMBA" 83 | pred_params = 1 # Number of cosmo/astro params to be predicted, starting from Omega_m, sigma_8, etc. (Only for outmode = "cosmo") 84 | ) 85 | #""" 86 | -------------------------------------------------------------------------------- /hyperparams_optimization.py: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------------------------- 2 | # Script for optimizing the hyperparameters of the network using optuna 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 4/22 5 | #---------------------------------------------------------------------- 6 | 7 | import optuna 8 | from main import * 9 | from optuna.visualization import plot_optimization_history, plot_contour, plot_param_importances # it needs plotly and kaleido 10 | from hyperparameters import hparams 11 | 12 | # Simulation type 13 | simsuite = "IllustrisTNG" 14 | simset = "LH" 15 | n_sims = 1000 16 | # Number of epochs 17 | n_epochs = 300 18 | 19 | # Objective function to minimize 20 | def objective(trial): 21 | 22 | # Hyperparameters to optimize 23 | learning_rate = trial.suggest_float("learning_rate", 1e-8, 1e-4, log=True) 24 | #weight_decay = trial.suggest_float("weight_decay", 1e-8, 1e-6, log=True) 25 | weight_decay = 1.e-7 26 | n_layers = 
trial.suggest_int("n_layers", 1, 5) 27 | hidden_channels = trial.suggest_categorical("hidden_channels", [64, 128, 256]) 28 | r_link = trial.suggest_float("r_link", 5.e-3, 5.e-2, log=True) 29 | 30 | # Some verbose 31 | print('\nTrial number: {}'.format(trial.number)) 32 | print('learning_rate: {}'.format(learning_rate)) 33 | #print('weight_decay: {}'.format(weight_decay)) 34 | print('n_layers: {}'.format(n_layers)) 35 | print('hidden_channels: {}'.format(hidden_channels)) 36 | print('r_link: {}'.format(r_link)) 37 | 38 | # Hyperparameters to be optimized 39 | hparams.learning_rate = learning_rate 40 | hparams.weight_decay = weight_decay 41 | hparams.n_layers = n_layers 42 | hparams.hidden_channels = hidden_channels 43 | hparams.r_link = r_link 44 | 45 | # Default params 46 | hparams.n_epochs = n_epochs 47 | hparams.simsuite = simsuite 48 | hparams.simset = simset 49 | hparams.n_sims = n_sims 50 | 51 | # Run main routine 52 | min_test_loss = main(hparams, verbose = False) 53 | 54 | if torch.cuda.is_available(): 55 | torch.cuda.empty_cache() 56 | 57 | return min_test_loss 58 | 59 | 60 | #--- MAIN ---# 61 | 62 | if __name__ == "__main__": 63 | 64 | time_ini = time.time() 65 | 66 | for path in ["Plots", "Models", "Outputs"]: 67 | if not os.path.exists(path): 68 | os.mkdir(path) 69 | 70 | # Optuna parameters 71 | storage = "sqlite:///"+os.getcwd()+"/optuna_"+simsuite+"_"+simset 72 | study_name = "gnn" 73 | n_trials = 30 74 | 75 | # Define sampler and start optimization 76 | sampler = optuna.samplers.TPESampler(n_startup_trials=10) 77 | study = optuna.create_study(study_name=study_name, sampler=sampler, storage=storage, load_if_exists=True) 78 | study.optimize(objective, n_trials, gc_after_trial=True) 79 | 80 | # Print info for best trial 81 | print("Best trial:") 82 | trial = study.best_trial 83 | print(" Value: ", trial.value) 84 | print(" Params: ") 85 | for key, value in trial.params.items(): 86 | print(" {}: {}".format(key, value)) 87 | 88 | hparams.learning_rate = 
trial.params["learning_rate"] 89 | hparams.n_layers = trial.params["n_layers"] 90 | hparams.hidden_channels = trial.params["hidden_channels"] 91 | hparams.r_link = trial.params["r_link"] 92 | 93 | # Save best model and plots 94 | if not os.path.exists("Best"): 95 | os.mkdir("Best") 96 | # Change nominal suite to read correct files (actually in ps mode both suites are employed) 97 | if hparams.outmode=="ps": 98 | hparams.simsuite = hparams.flip_suite() 99 | files = [] 100 | files.append( "Plots/out_true_Om_"+hparams.name_model()+".png" ) 101 | files.append( "Plots/out_true_Sig_"+hparams.name_model()+".png" ) 102 | files.append( "Plots/loss_"+hparams.name_model()+".png" ) 103 | files.append( "Plots/ps_"+hparams.name_model()+".png" ) 104 | files.append( "Plots/rel_err_"+hparams.name_model()+".png" ) 105 | files.append( "Models/"+hparams.name_model() ) 106 | for file in files: 107 | if os.path.exists(file): 108 | os.system("cp "+file+" Best/.") 109 | 110 | # Visualization of optimization results 111 | fig = plot_optimization_history(study) 112 | fig.write_image("Plots/optuna_optimization_history_"+simsuite+".png") 113 | 114 | fig = plot_contour(study)#, params=["learning_rate", "weight_decay", "r_link"])#, "use_model"]) 115 | fig.write_image("Plots/optuna_contour_"+simsuite+".png") 116 | 117 | fig = plot_param_importances(study) 118 | fig.write_image("Plots/plot_param_importances_"+simsuite+".png") 119 | 120 | print("Finished. 
Time elapsed:",datetime.timedelta(seconds=time.time()-time_ini)) 121 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------- 2 | # Main routine for training and testing GNN models 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 4/22 5 | #---------------------------------------------------- 6 | 7 | import time, datetime, psutil 8 | from Source.metalayer import * 9 | from Source.training import * 10 | from Source.plotting import * 11 | from Source.load_data import * 12 | 13 | 14 | # Main routine to train the neural net 15 | # If testsuite==True, it takes a model already pretrained in the other suite and tests it in the selected one 16 | def main(hparams, verbose = True, testsuite = False): 17 | 18 | # Load data and create dataset 19 | dataset = create_dataset(hparams) 20 | node_features = dataset[0].x.shape[1] 21 | 22 | # Split dataset among training, validation and testing datasets 23 | train_loader, valid_loader, test_loader = split_datasets(dataset) 24 | 25 | # Size of the output of the GNN 26 | if hparams.outmode=="cosmo": 27 | dim_out=2*hparams.pred_params 28 | elif hparams.outmode=="ps": 29 | dim_out=ps_size 30 | 31 | # Initialize model 32 | model = GNN(node_features=node_features, 33 | n_layers=hparams.n_layers, 34 | hidden_channels=hparams.hidden_channels, 35 | linkradius=hparams.r_link, 36 | dim_out=dim_out, 37 | only_positions=hparams.only_positions) 38 | model.to(device) 39 | if verbose: print("Model: " + hparams.name_model()+"\n") 40 | 41 | # Print the memory (in GB) being used now: 42 | process = psutil.Process() 43 | print("Memory being used (GB):",process.memory_info().rss/1.e9) 44 | 45 | # Train the net 46 | if hparams.training: 47 | if verbose: print("Training!\n") 48 | train_losses, valid_losses = training_routine(model, train_loader, valid_loader, hparams, verbose) 
49 | 50 | # Test the net 51 | if verbose: print("\nTesting!\n") 52 | 53 | # If test in other suite, change the suite for loading the model 54 | if testsuite==True: 55 | hparams.simsuite = hparams.flip_suite() # change for loading the model 56 | 57 | # Load the trained model 58 | state_dict = torch.load("Models/"+hparams.name_model(), map_location=device) 59 | model.load_state_dict(state_dict) 60 | 61 | if testsuite==True: hparams.simsuite = hparams.flip_suite() # change after loading the model 62 | 63 | # Test the model 64 | test_loss, rel_err = test(test_loader, model, hparams) 65 | if verbose: print("Test Loss: {:.2e}, Relative error: {:.2e}".format(test_loss, rel_err)) 66 | 67 | # Plot loss trends 68 | if hparams.training: 69 | plot_losses(train_losses, valid_losses, test_loss, rel_err, hparams) 70 | 71 | # Plot true vs predicted cosmo parameters 72 | if hparams.outmode=="cosmo": 73 | plot_out_true_scatter(hparams, "Om", testsuite) 74 | if hparams.pred_params==2: 75 | plot_out_true_scatter(hparams, "Sig", testsuite) 76 | 77 | # Plot power spectrum and relative error 78 | elif hparams.outmode=="ps": 79 | plot_ps(hparams) 80 | 81 | return test_loss 82 | 83 | 84 | #--- MAIN ---# 85 | 86 | if __name__ == "__main__": 87 | 88 | time_ini = time.time() 89 | 90 | for path in ["Plots", "Models", "Outputs"]: 91 | if not os.path.exists(path): 92 | os.mkdir(path) 93 | 94 | # Load default parameters 95 | from hyperparameters import hparams 96 | 97 | main(hparams) 98 | 99 | print("Finished. 
Time elapsed:",datetime.timedelta(seconds=time.time()-time_ini)) 100 | -------------------------------------------------------------------------------- /ps_test.py: -------------------------------------------------------------------------------- 1 | #---------------------------------------------------- 2 | # Compute the power spectrum of different point distirbutions with the GNN trained in CAMELS 3 | # Author: Pablo Villanueva Domingo 4 | # Last update: 4/22 5 | #---------------------------------------------------- 6 | 7 | import time, datetime 8 | from Source.metalayer import * 9 | from Source.training import * 10 | from Source.plotting import * 11 | from Source.load_data import * 12 | from visualize_graphs import visualize_graph 13 | #import powerbox as pbox 14 | import MAS_library as MASL 15 | import Pk_library as PKL 16 | 17 | # Power spectrum parameters 18 | BoxLen = 25.0 19 | grid = 512 20 | MAS = 'CIC' 21 | kmax = 20.0 #h/Mpc 22 | axis = 0 23 | threads = 28 24 | vol = (boxsize/1.e3)**3. 
# (Mpc/h)^3 25 | 26 | 27 | #--- POINT DISTRIBUTIONS ---# 28 | 29 | # Poisson point process 30 | def poisson_process(num_points): 31 | 32 | #pos = torch.rand((num_points,3)) 33 | pos = np.random.uniform(0., 1., (num_points,3)) 34 | 35 | return pos 36 | 37 | # Neyman-Scott process with a gaussian kernel (Thomas point process) 38 | # Based on https://hpaulkeeler.com/simulating-a-thomas-cluster-point-process/ 39 | def neynmanscott_process(num_parents, num_daughters, sigma): 40 | 41 | # Generate parents 42 | x_par = poisson_process(num_parents) 43 | 44 | # Simulate Poisson point process for the daughters 45 | numbPointsDaughter = np.random.poisson(num_daughters, x_par.shape[0]) 46 | numbPoints = sum(numbPointsDaughter) 47 | 48 | # Generate the relative locations as independent normal variables 49 | x_daug = np.random.normal(0, sigma, size=(numbPoints,3)) # (relative) x coordinaets 50 | 51 | # Replicate parent points (ie centres of disks/clusters) and center daughters around them 52 | xx = np.repeat(x_par, numbPointsDaughter, axis=0) 53 | xx = xx + x_daug 54 | 55 | # Retain only those points inside simulation window 56 | booleInside=((xx[:,0]>=0)&(xx[:,0]<=1)&(xx[:,1]>=0)&(xx[:,1]<=1)&(xx[:,2]>=0)&(xx[:,2]<=1)) 57 | xx = xx[booleInside] 58 | 59 | return xx 60 | 61 | # Soneira-Peebles point process (Soneira & Peebles 1977, 1978) 62 | def soneira_peebles_model(lamb, eta, n_levels, R0): 63 | 64 | # Radius for first level 65 | Rparent = R0 66 | 67 | # Generate parents 68 | #num_parents = max(1,np.random.poisson(eta)) 69 | num_parents = eta 70 | xparents = poisson_process(num_parents) 71 | 72 | xtot = [] 73 | xtot.extend(xparents) 74 | 75 | # Iterate over each level 76 | for n in range(2,n_levels+1): 77 | Rparent = Rparent/lamb 78 | pointsx = [] 79 | 80 | for ipar in range(len(xparents)): 81 | 82 | num_points = np.random.poisson(eta) 83 | #num_points = eta 84 | x_daug = xparents[ipar] + np.random.normal(0, Rparent, size=(num_points,3)) 85 | pointsx.extend(x_daug) 86 | 87 | 
xparents = pointsx 88 | xtot.extend(pointsx) 89 | 90 | xx = np.array(xtot) 91 | 92 | # Retain only those points inside simulation window 93 | booleInside=((xx[:,0]>=0)&(xx[:,0]<=1)&(xx[:,1]>=0)&(xx[:,1]<=1)&(xx[:,2]>=0)&(xx[:,2]<=1)) 94 | xx = xx[booleInside] 95 | 96 | return xx 97 | 98 | #--- OTHER ROUTINES ---# 99 | 100 | # Routine to compute the power spectrum using Pylians 101 | def compute_ps(pos): 102 | 103 | pos = pos.cpu().detach().numpy() 104 | 105 | pos = pos*BoxLen 106 | 107 | # Construct galaxy 3D density field 108 | delta = np.zeros((grid,grid,grid), dtype=np.float32) 109 | MASL.MA(pos, delta, BoxLen, MAS, verbose=False) 110 | delta /= np.mean(delta, dtype=np.float64) 111 | delta -= 1.0 112 | 113 | # Compute the power spectrum 114 | Pk = PKL.Pk(delta, BoxLen, axis, MAS, threads, verbose=False) 115 | k = Pk.k3D 116 | Pk0 = Pk.Pk[:,0] # Monopole 117 | 118 | indexes = np.where(kNstar_th)[0] 96 | pos = pos[indexes] 97 | Mstar = Mstar[indexes] 98 | 99 | tab = np.column_stack((pos, Mstar)) 100 | 101 | #edge_index, edge_attr = get_edges(pos, r_link, use_loops=False) 102 | edge_index = radius_graph(torch.tensor(pos,dtype=torch.float32), r=r_link, loop=False) 103 | 104 | data = Data(x=tab, edge_index=torch.tensor(edge_index, dtype=torch.long)) 105 | 106 | if showgraph: 107 | #visualize_graph(data, simnumber, "2d", edge_index) 108 | visualize_graph(data, simnumber, projection="3d", edge_index=data.edge_index) 109 | 110 | if get_degree: 111 | degrees.append( degree(edge_index[0], data.num_nodes).numpy() ) 112 | 113 | if get_degree: 114 | plot_degree_distribution(degrees) 115 | 116 | 117 | 118 | 119 | #--- MAIN ---# 120 | 121 | if __name__=="__main__": 122 | 123 | time_ini = time.time() 124 | 125 | for path in ["Plots"]: 126 | if not os.path.exists(path): 127 | os.mkdir(path) 128 | 129 | # Linking radius 130 | r_link = 0.05 131 | # Simulation suite, choose between "IllustrisTNG" and "SIMBA" 132 | simsuite = "IllustrisTNG" 133 | # Number of simulations considered, 
maximum 27 for CV and 1000 for LH 134 | n_sims = 20 135 | 136 | display_graphs(simsuite, n_sims, r_link) 137 | 138 | print("Finished. Time elapsed:",datetime.timedelta(seconds=time.time()-time_ini)) 139 | --------------------------------------------------------------------------------