├── README.md ├── .gitignore └── DINGO_Tutorial.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # Tutorial: Introduction to DINGO 2 | 3 | This tutorial provides a first introduction to using the [DINGO package](https://github.com/dingo-gw/dingo) for analyzing gravitational wave data with neural posterior estimation. 4 | It uses a 2D toy example to illustrate how to train a DINGO model from scratch in a simplified setting. Furthermore, the tutorial shows how to download and use an already trained model to obtain posterior samples for GW150914. 5 | With this tutorial, we hope to make getting started with DINGO easy. 6 | 7 | ## Getting started 8 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/annalena-k/tutorial-dingo-introduction/blob/main/DINGO_Tutorial.ipynb) 9 | 10 | To get started quickly, run the tutorial in Google Colab by clicking the button above. 11 | 12 | To run it locally (which may be faster), ensure that you create and activate a Python environment with the `dingo-gw` package. 13 | 14 | With `pip`, this can be done with 15 | ``` 16 | python3 -m venv dingo-venv 17 | source dingo-venv/bin/activate 18 | pip install dingo-gw jupyterlab 19 | ``` 20 | 21 | If using `conda`, this can be done with 22 | ``` 23 | conda create -c conda-forge -n venv-dingo dingo-gw jupyterlab 24 | conda activate venv-dingo 25 | ``` 26 | 27 | You can start the Jupyter server by executing the command `jupyter lab` in the folder containing the notebook, which should open a new browser window. Finally, you can click on the notebook `DINGO_Tutorial.ipynb` and start with the tutorial! 28 | (Since you have already installed `dingo-gw`, you do not have to execute the notebook cells containing `!pip install ...` commands.) 29 | 30 | If you are looking for a more general introduction to posterior estimation of gravitational wave data without DINGO, please check out the tutorial ["GW Parameter Inference with Machine Learning"](https://github.com/stephengreen/gw-school-corfu-2023). 31 | 32 | ### Updates 33 | * (03.06.2025) Updated to `dingo=0.8.3`; additional minor changes. 34 | * (30.01.2025) Updated to be compatible with `dingo` version 0.7.0. 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | .DS_Store 165 | -------------------------------------------------------------------------------- /DINGO_Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "PrqX-YAHh_KB" 7 | }, 8 | "source": [ 9 | "# DINGO Tutorial\n", 10 | "\n", 11 | "**Welcome to the DINGO tutorial!**\n", 12 | "\n", 13 | "This notebook provides an introduction to the DINGO machine-learning package for gravitational wave (GW) parameter estimation. It shows how to train a DINGO model from scratch, or how to use an already-trained model to obtain posterior samples for GW data. We hope that this will serve as a starting point for those interested in using DINGO.\n", 14 | "\n", 15 | "For a more general introduction to neural posterior estimation of gravitational waves (without DINGO), see the tutorial [\"GW Parameter Inference with Machine Learning\"](https://github.com/stephengreen/gw-school-corfu-2023).\n", 16 | "\n", 17 | "Since it takes several days to train a production-level DINGO model (14 GW parameters, ~100 million network parameters, ~10 million training examples), we make two simplifications:\n", 18 | "- Restrict the parameter space to two parameters (chirp mass and mass ratio).\n", 19 | "- Use a much smaller network and training dataset.\n", 20 | "\nThis reduces performance, but allows the tutorial to run in a reasonable amount of time.\n", 21 | "\n", 22 | "## Tutorial Structure\n", 23 | "1. Generating a waveform dataset\n", 24 | "2. Generating a noise dataset\n", 25 | "3. Training\n", 26 | "4. Inference on injections\n", 27 | "5. Inference on a real event with `dingo_pipe`\n", 28 | "6. Inference with a pre-trained DINGO model (from Zenodo)\n", 29 | "\n", 30 | "Update 21.10.2025: If you run the code in a Google Colab notebook, you have to downgrade the Python version, since [`bilby-pipe` currently requires `python<3.12.7`](https://git.ligo.org/lscsoft/bilby_pipe/-/merge_requests/660). To do this, click `Runtime` -> `Change runtime type` -> Select `2025.07` as the `Runtime version`.\n", 31 | "\n", 32 | "To run on a GPU in Colab, select a GPU under `Runtime` -> `Change runtime type`.\n", 33 | "To run on a CPU, change `device = cuda` to `device = cpu` in the training and inference settings." 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "I0Wnnk1rlRFe" 40 | }, 41 | "source": [ 42 | "## 0. Installation\n", 43 | "\n", 44 | "We use `pip`, but dingo can also be installed using `conda`.\n",
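"\n", "If you prefer `conda`, the environment from the README can be created with:\n", "```\n", "conda create -c conda-forge -n venv-dingo dingo-gw jupyterlab\n", "conda activate venv-dingo\n", "```"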
45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "id": "I88mekKClTda" 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "%%capture\n", 56 | "!pip3 install dingo-gw" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "id": "LP3pxcCwK42O" 63 | }, 64 | "source": [ 65 | "To check whether `dingo` is correctly installed and ready to use, run:" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": { 72 | "colab": { 73 | "base_uri": "https://localhost:8080/" 74 | }, 75 | "id": "imgDaLuRK_pP", 76 | "outputId": "85e62a21-fa94-4406-8f12-74685ff7f980" 77 | }, 78 | "outputs": [], 79 | "source": [ 80 | "import warnings\n", 81 | "warnings.filterwarnings(\"ignore\", \"Wswiglal-redir-stdio\")\n", 82 | "import lal\n", 83 | "\n", 84 | "import dingo\n", 85 | "print(dingo.__version__) # Should be >= 0.8.3" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "id": "7gALmOI8V5Nt" 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "import matplotlib.pyplot as plt\n", 97 | "import numpy as np\n", 98 | "import os\n", 99 | "import pickle\n", 100 | "import yaml" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "metadata": { 106 | "id": "3_Ft58QApNXR" 107 | }, 108 | "source": [ 109 | "Prepare the folder structure for the tutorial:" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "D0C_WqEfpN2A" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "# Remove automatically created folder 'sample_data'\n", 121 | "if os.path.isdir(\"sample_data\"):\n", 122 | " os.system(\"rm -rf sample_data\")\n", 123 | "# Create folders required for the tutorial\n", 124 | "os.makedirs('01_training_data/asd_dataset', exist_ok=True)\n", 125 | "os.makedirs('01_training_data/waveform_dataset', exist_ok=True)\n", 126 | "os.makedirs('02_training', exist_ok=True)\n", 127 | "os.makedirs('03_inference/injection', exist_ok=True)\n", 128 | "os.makedirs('04_exercise/lum_dist_marginalization', exist_ok=True)\n", 129 | "os.makedirs('04_exercise/with_lum_dist', exist_ok=True)\n", 130 | "os.makedirs('05_pretrained_model/init_train_dir', exist_ok=True)\n", 131 | "os.makedirs('05_pretrained_model/main_train_dir', exist_ok=True)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "id": "XzYmfqu1lNH2" 138 | }, 139 | "source": [ 140 | "## 1. Generating a waveform dataset\n", 141 | "**Step 1: Prepare settings file**\n", 142 | "\n", 143 | "DINGO uses YAML files to configure all settings for generating training data and training models.\n", 144 | "\n", 145 | "A waveform dataset settings file specifies the following:\n", 146 | "- **domain**: Supported domain types are `UniformFrequencyDomain` and `MultibandedFrequencyDomain`. Specifies, e.g., the frequency range $[f_\mathrm{min}, f_\mathrm{max}]$ and the frequency resolution $\delta f$. DINGO currently does not support training directly in the time domain; data from time-domain waveform generators is transformed and saved in the frequency domain.\n", 147 | "- **waveform_generator**: Specifies the waveform approximant (e.g., IMRPhenomPv2, IMRPhenomXPHM, SEOBNRv4PHM, SEOBNRv5PHM) and the reference frequency. 
Note that SEOBNRv5 waveforms require `pyseobnr` to be installed separately (optional).\n", 148 | "- **intrinsic_prior**: The parameter space is split into intrinsic and extrinsic components: *Intrinsic* parameters are those needed to generate waveform polarizations. *Extrinsic* parameters are those that can be sampled and applied rapidly during training. Luminosity distance and time of coalescence are considered as both intrinsic and extrinsic. They are needed to generate polarizations, but they can also be easily transformed during training to augment the dataset. We therefore fix them to fiducial values for generating polarizations.\n", 149 | "- **num_samples**: Number of waveforms to generate. In a production setting, it is recommended to train a model with (at least) $5 \cdot 10^6$ waveforms, with larger parameter spaces requiring more samples.\n", 150 | "- **compression**: For large datasets, it is recommended to save compressed waveforms using a [singular-value decomposition (SVD)](https://en.wikipedia.org/wiki/Singular_value_decomposition). Since we only generate a dataset with $10^4$ waveforms, we do not apply compression. Details about the SVD settings can be found [here](https://dingo-gw.readthedocs.io/en/latest/waveform_dataset.html#generating-a-simple-dataset).\n", 151 | "\n", 152 | "*For this tutorial, we analyze only the chirp mass $\mathcal{M} = (m_1 m_2)^{3/5} / (m_1 + m_2)^{1/5}$ and the mass ratio $q = m_2 / m_1 \leq 1$, fixing all other parameters.*\n", 153 | "\n", 154 | "More information about generating the waveform dataset can be found in the sections [Generating Waveforms](https://dingo-gw.readthedocs.io/en/latest/generating_waveforms.html) and [Building a waveform dataset](https://dingo-gw.readthedocs.io/en/latest/waveform_dataset.html) of the DINGO documentation.\n", 155 | "\n", 156 | "For details on how to parallelize this step over multiple cores on an `htcondor` cluster, see one of the [production tutorials](https://dingo-gw.readthedocs.io/en/latest/example_npe_model.html)." 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": { 163 | "id": "Skxp2SKkqDP6" 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "waveform_dataset_settings = \"\"\"\n", 168 | "domain:\n", 169 | " type: UniformFrequencyDomain\n", 170 | " f_min: 20.0\n", 171 | " f_max: 256.0\n", 172 | " delta_f: 0.5 # Expressions like 1.0/8.0 would require eval and are not supported\n", 173 | "\n", 174 | "waveform_generator:\n", 175 | " approximant: IMRPhenomPv2\n", 176 | " f_ref: 20.0\n", 177 | " spin_conversion_phase: 0.0 # Reference phase when converting from spin angles to Cartesian spins. If None, use phase parameter.\n", 178 | " # f_start: 15.0 # Optional setting useful for EOB waveforms. Overrides f_min when generating waveforms.\n", 179 | "\n", 180 | "# Dataset only samples over intrinsic parameters. 
Extrinsic parameters are chosen at train time.\n", 181 | "intrinsic_prior:\n", 182 | " mass_1: bilby.core.prior.Constraint(minimum=20.0, maximum=40.0)\n", 183 | " mass_2: bilby.core.prior.Constraint(minimum=20.0, maximum=40.0)\n", 184 | " chirp_mass: bilby.gw.prior.UniformInComponentsChirpMass(minimum=15.0, maximum=50.0)\n", 185 | " mass_ratio: bilby.gw.prior.UniformInComponentsMassRatio(minimum=0.125, maximum=1.0)\n", 186 | " theta_jn: 2.624497413635254\n", 187 | " tilt_1: 2.0111560821533203\n", 188 | " tilt_2: 1.0743615627288818\n", 189 | " a_1: 0.925635814666748\n", 190 | " a_2: 0.5538952350616455\n", 191 | " phi_jl: 5.561878204345703\n", 192 | " phase: 0.9604566579018894\n", 193 | " # Reference values for fixed (extrinsic) parameters. These are needed to generate a waveform.\n", 194 | " luminosity_distance: 100.0 # Mpc\n", 195 | " geocent_time: 0.0 # s\n", 196 | "\n", 197 | "# Dataset size\n", 198 | "num_samples: 10_000\n", 199 | "\n", 200 | "compression: None\n", 201 | "\"\"\"\n", 202 | "waveform_dataset_settings = yaml.safe_load(waveform_dataset_settings)\n", 203 | "with open('01_training_data/waveform_dataset/waveform_dataset_settings.yaml', 'w') as outfile:\n", 204 | " yaml.dump(waveform_dataset_settings, outfile, default_flow_style=False)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": { 210 | "id": "Fuq9gUhWqDdd" 211 | }, 212 | "source": [ 213 | "\n", 214 | "**Step 2: Generate waveform dataset**\n", 215 | "\n", 216 | "To generate the waveform dataset, execute the following command:" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": { 223 | "colab": { 224 | "base_uri": "https://localhost:8080/" 225 | }, 226 | "id": "TIS56_UYlQUE", 227 | "outputId": "fba1875b-40a2-4c00-ce49-7dd93d115543" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "!dingo_generate_dataset --settings 01_training_data/waveform_dataset/waveform_dataset_settings.yaml --out_file 01_training_data/waveform_dataset/waveform_dataset.hdf5" 232 | ] 233 | }, 234 | { 235 | "cell_type": "markdown", 236 | "metadata": { 237 | "id": "3eCI_XVAbyy4" 238 | }, 239 | "source": [ 240 | "Load the dataset and visualize an exemplary waveform:" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": { 247 | "colab": { 248 | "base_uri": "https://localhost:8080/" 249 | }, 250 | "id": "M0cnwWp4RA_4", 251 | "outputId": "4d1991c5-66a1-4fe5-a236-b1dc666a619b" 252 | }, 253 | "outputs": [], 254 | "source": [ 255 | "from dingo.gw.dataset.waveform_dataset import WaveformDataset\n", 256 | "\n", 257 | "# Load dataset\n", 258 | "waveform_dataset_path = '01_training_data/waveform_dataset/waveform_dataset.hdf5'\n", 259 | "wfd = WaveformDataset(file_name=waveform_dataset_path)\n", 260 | "print(\"Each element wfd[i] contains an intrinsic waveform. It is a nested dict with structure\")\n", 261 | "for k, v in wfd[0].items():\n", 262 | " print(\" - \" + k)\n", 263 | " if isinstance(v, dict):\n", 264 | " for k1 in v:\n", 265 | " print(\" - \" + k1)" 266 | ] 267 | },
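{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A quick look at the 'parameters' entry listed above: chirp mass and mass ratio\n", "# vary across the dataset, while the remaining intrinsic parameters are fixed to\n", "# the reference values from the settings file.\n", "print(wfd[0]['parameters'])" ] },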
268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "metadata": { 272 | "colab": { 273 | "base_uri": "https://localhost:8080/", 274 | "height": 469 275 | }, 276 | "id": "N5vRbXV6Ud76", 277 | "outputId": "65e4b343-fc32-40ad-9733-82b294aef70a" 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "# Plot a sample waveform\n", 282 | "f_domain = wfd.domain.sample_frequencies\n", 283 | "data_sample = wfd[0]['waveform']\n", 284 | "\n", 285 | "plt.plot(f_domain, data_sample['h_cross'].real, c=\"blue\", label=r'$h_{\times}$ real')\n", 286 | "plt.plot(f_domain, data_sample['h_cross'].imag, c=\"cornflowerblue\", label=r'$h_{\times}$ imag')\n", 287 | "plt.plot(f_domain, data_sample['h_plus'].real, c=\"tab:red\", label=r'$h_+$ real')\n", 288 | "plt.plot(f_domain, data_sample['h_plus'].imag, c=\"coral\", label=r'$h_+$ imag')\n", 289 | "plt.xlabel(r\"Frequency $f$ [Hz]\")\n", 290 | "plt.ylabel(r\"Polarization $h$\")\n", 291 | "plt.legend();" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": { 297 | "id": "_tqiDrAHlUZ3" 298 | }, 299 | "source": [ 300 | "## 2. Generating a noise dataset\n", 301 | "In gravitational wave data analysis, the noise is assumed to be stationary and Gaussian in each frequency bin, with the standard deviation given by the amplitude spectral density (ASD). Similar to extrinsic parameters, detector noise is repeatedly sampled during training and added to the simulated signal. This augments the training set with new noise realizations for each epoch, reducing overfitting.\n", 302 | "To take into account that the ASD can drift during an observing run, we create an ASD dataset based on stretches of measured noise in the LIGO and Virgo detectors.\n", 303 | "For a detailed explanation of the detector noise, please see the [DINGO documentation](https://dingo-gw.readthedocs.io/en/latest/noise_dataset.html) and the paper [Dax+2021](https://arxiv.org/abs/2106.12594).\n", 304 | "\n", 305 | "**Step 1: Prepare settings file**\n", 306 | "\nTo generate the ASD dataset, we need to prepare a settings file which contains the following information within `dataset_settings`:\n", 307 | "- **f_min, f_max** (Optional)\n", 308 | "- **f_s**: Sampling rate. This should be at least twice the value of `f_max` expected to be used.\n", 309 | "- **time_psd**: The entire length of data from which to estimate a power spectral density (PSD, $\mathrm{ASD} = \sqrt{\mathrm{PSD}}$) using Welch's method. Periodograms are calculated on segments of this and then averaged using the median method.\n", 310 | "- **T**: The length of each segment on which to take the DFT and calculate a periodogram.\n", 311 | "- **window**: Parameters of the window function used before taking the DFT of data segments.\n", 312 | "- **channels** (Optional): The channels from which to download the data. By default, open data from [GWOSC](https://gwosc.org) is used.\n", 313 | "- **time_gap**: This sets the time that is skipped between consecutive PSD estimates. If set < 0, the time segments overlap.\n", 314 | "- **num_psds_max**: If this is set to 0, all available PSDs will be downloaded. For values > 0, a subset of all available PSDs is used.\n", 315 | "- **detectors**: This list specifies the detectors for which we want to download ASDs. 
Options: H1, L1, V1.\n", 316 | "- **observing_run**: This sets the observing run of the ASDs. Options: O1, O2, O3.\n", 317 | "\n", 318 | "In a DINGO training run where the initial layers of the embedding network are seeded with SVD components, we typically use two ASD datasets: one with a single, fiducial ASD that is fixed during the first training stage, and a second one with all ASDs of the selected observing run. In this tutorial, we will train the network with a single ASD.\n", 319 | "\n", 320 | "More information about the noise modeling with DINGO can be found in the section [Detector Noise](https://dingo-gw.readthedocs.io/en/latest/noise_dataset.html) of the documentation. It also includes examples for `htcondor` settings that allow running this step on a cluster." 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "id": "_3bcFGmiq1Sv" 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "asd_dataset_settings = \"\"\"\n", 332 | "dataset_settings:\n", 333 | "# f_min: 0 # defaults to 0\n", 334 | "# f_max: 2048 # defaults to f_s/2\n", 335 | " f_s: 4096\n", 336 | " time_psd: 1024\n", 337 | " T: 2.0\n", 338 | " window:\n", 339 | " roll_off: 0.4\n", 340 | " type: tukey\n", 341 | " time_gap: 0 # specifies the time skipped between two consecutive PSD estimates. If set < 0, the time segments overlap\n", 342 | " num_psds_max: 1 # if set > 0, only a subset of all available PSDs will be used\n", 343 | " detectors:\n", 344 | " - H1\n", 345 | " - L1\n", 346 | " observing_run: O1\n", 347 | "\"\"\"\n", 348 | "asd_dataset_settings = yaml.safe_load(asd_dataset_settings)\n", 349 | "with open('01_training_data/asd_dataset/asd_dataset_settings.yaml', 'w') as outfile:\n", 350 | " yaml.dump(asd_dataset_settings, outfile, default_flow_style=False)" 351 | ] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "IGHv3EGTq2Ag" 357 | }, 358 | "source": [ 359 | "**Step 2: Generate ASD dataset**\n", 360 | "\n", 361 | "The following downloads the first valid ASD time segment of the selected observing run (recall `num_psds_max: 1`). For the first observing run O1, this is 1126073529 to 1126074553. It then calculates the ASD using Welch's method with median averaging and saves it to an HDF5 file." 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "colab": { 369 | "base_uri": "https://localhost:8080/" 370 | }, 371 | "id": "MO5gB3Iz7z8H", 372 | "outputId": "0876dbf0-c4e9-41c9-a624-fd196c4a3fc5" 373 | }, 374 | "outputs": [], 375 | "source": [ 376 | "!dingo_generate_asd_dataset --settings_file 01_training_data/asd_dataset/asd_dataset_settings.yaml --data_dir 01_training_data/asd_dataset --out_name 01_training_data/asd_dataset/asd_fiducial_O1.hdf5" 377 | ] 378 | }, 379 | { 380 | "cell_type": "markdown", 381 | "metadata": { 382 | "id": "FFppPrhn74xd" 383 | }, 384 | "source": [ 385 | "However, we are interested in analyzing the first detected GW event, GW150914. 
Therefore, we specify the start and end time of this event in `time_segment.pkl` and provide it as an additional argument to `dingo_generate_asd_dataset`:" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": null, 391 | "metadata": { 392 | "id": "Av8XjoV6uzV_" 393 | }, 394 | "outputs": [], 395 | "source": [ 396 | "time_GW150914 = 1126259462.391 - 0.0114 # maxlog geocent time of GW150914\n", 397 | "asd_start_time = time_GW150914 - asd_dataset_settings[\"dataset_settings\"][\"T\"]\n", 398 | "asd_end_time = time_GW150914\n", 399 | "time_segments = {\n", 400 | " 'H1': [[asd_start_time, asd_end_time]],\n", 401 | " 'L1': [[asd_start_time, asd_end_time]]\n", 402 | "}\n", 403 | "\n", 404 | "with open('01_training_data/asd_dataset/time_segment_GW150914.pkl', 'wb') as f:\n", 405 | " pickle.dump(time_segments, f)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "metadata": { 412 | "colab": { 413 | "base_uri": "https://localhost:8080/" 414 | }, 415 | "id": "hGK6VNrazQnp", 416 | "outputId": "b094ec31-037b-4d79-eb2f-a32a8dc7847b" 417 | }, 418 | "outputs": [], 419 | "source": [ 420 | "!dingo_generate_asd_dataset --settings_file 01_training_data/asd_dataset/asd_dataset_settings.yaml --data_dir 01_training_data/asd_dataset --out_name 01_training_data/asd_dataset/asd_GW150914.hdf5 --time_segments_file 01_training_data/asd_dataset/time_segment_GW150914.pkl" 421 | ] 422 | }, 423 | { 424 | "cell_type": "markdown", 425 | "metadata": { 426 | "id": "mE8qD8RNcwVA" 427 | }, 428 | "source": [ 429 | "Load and visualize an exemplary ASD:" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": { 436 | "colab": { 437 | "base_uri": "https://localhost:8080/", 438 | "height": 496 439 | }, 440 | "id": "SSJizLYSlhtO", 441 | "outputId": "39de6c73-2276-4410-8bf8-afbebc286951" 442 | }, 443 | "outputs": [], 444 | "source": [ 445 | "from dingo.gw.noise.asd_dataset import ASDDataset\n", 446 | "\n", 447 | "# Load ASD dataset\n", 448 | "asd_dataset_path = '01_training_data/asd_dataset/asd_GW150914.hdf5'\n", 449 | "asds = ASDDataset(file_name=asd_dataset_path, domain_update=wfd.domain.domain_dict)\n", 450 | "# Get ASD sample\n", 451 | "asd_sample = asds.sample_random_asds()\n", 452 | "print(\"One datapoint contains:\", asd_sample.keys())\n", 453 | "\n", 454 | "plt.loglog(f_domain, asd_sample['H1'], label=r'H1')\n", 455 | "plt.loglog(f_domain, asd_sample['L1'], label=r'L1')\n", 456 | "plt.xlabel(r\"Frequency [Hz]\")\n", 457 | "plt.ylabel(r\"Noise ASD $[1/\sqrt{\mathrm{Hz}}]$\")\n", 458 | "plt.xlim([wfd.domain.f_min, wfd.domain.f_max])\n", 459 | "plt.ylim([1.e-24, 1.e-20])\n", 460 | "plt.legend();" 461 | ] 462 | },
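{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Illustrative sketch (not part of the DINGO pipeline): draw one frequency-domain\n", "# noise realization from the H1 ASD. For stationary Gaussian noise, the real and\n", "# imaginary parts of each frequency bin have standard deviation\n", "# sigma(f) = ASD(f) * sqrt(T) / 2 (a common convention, used, e.g., by bilby).\n", "T = 1.0 / wfd.domain.delta_f\n", "sigma = asd_sample['H1'] * np.sqrt(T) / 2.0\n", "noise = sigma * (np.random.randn(len(sigma)) + 1j * np.random.randn(len(sigma)))\n", "plt.loglog(f_domain, np.abs(noise), label='|noise realization|', alpha=0.5)\n", "plt.loglog(f_domain, asd_sample['H1'], label='H1 ASD')\n", "plt.xlabel(r\"Frequency [Hz]\")\n", "plt.xlim([wfd.domain.f_min, wfd.domain.f_max])\n", "plt.legend();" ] },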
463 | { 464 | "cell_type": "markdown", 465 | "metadata": { 466 | "id": "R4xKhg7TlZ1B" 467 | }, 468 | "source": [ 469 | "## 3. Training a simple model\n", 470 | "**Step 1: Prepare config file**\n", 471 | "\n", 472 | "The train settings YAML file has several sections:\n", 473 | "\n", 474 | "**(a) data**: Specifies the training data.\n", 475 | "- **waveform_dataset_path**: Path to the waveform dataset containing intrinsic waveforms.\n", 476 | "- **train_fraction**: Fraction of the waveform dataset to be used for training. The remainder is used to compute the test loss.\n", 477 | "- **window**: Defines the window function to use when Fourier-transforming from time-domain data. It is used here to calculate a window factor for simulating data. See discussion [here](https://dingo-gw.readthedocs.io/en/latest/noise_dataset.html#ref-window-factor).\n", 478 | "- **detectors**: List of detectors that should be used during training. Options: H1, L1, V1. Note: Dingo models are trained for a *fixed set of detectors*. This must be selected prior to training, and a new model must be trained if one wishes to analyze data from a different set of detectors. Thus, e.g., separate models must be trained for HL and HLV configurations.\n", 479 | "- **extrinsic_prior**: Specifies the extrinsic prior. Default options are available.\n", 480 | "- **ref_time**: Reference time for the interferometer locations and orientations. See the important note [here](https://dingo-gw.readthedocs.io/en/latest/training_transforms.html#ref-ref-time).\n", 481 | "- **inference_parameters**: Parameters to infer with the model. Must be a subset of `sample[\"parameters\"]`. By specifying a strict subset, this can be used to marginalize over parameters. The default setting points to `dingo.gw.prior.default_inference_parameters`." 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": { 488 | "colab": { 489 | "base_uri": "https://localhost:8080/" 490 | }, 491 | "id": "6bMHJfdNDH5f", 492 | "outputId": "63f9b0f4-b86a-4e4d-ab50-1eee1e66e278" 493 | }, 494 | "outputs": [], 495 | "source": [ 496 | "from dingo.gw.prior import default_inference_parameters\n", 497 | "default_inference_parameters" 498 | ] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": { 503 | "id": "nHGVQBonECLz" 504 | }, 505 | "source": [ 506 | "\n", 507 | "**(b) model**: Specifies the neural model.\n", 508 | "- **posterior_model_type**: Usually `normalizing_flow` (alternative: `flow_matching`).\n", 509 | "- **posterior_kwargs**: Specifies the density estimator. For `normalizing_flow` this includes the number of flow steps and the kwargs of the flow transform, `base_transform_kwargs`.\n", 510 | "- **embedding_kwargs**: Specifies the embedding network. This is used to compress the data to a latent representation, which is passed as context to the posterior model. Arguments include, e.g., the dimension of the output vector `output_dim`, the hidden dimensions `hidden_dims`, and how to obtain the SVD components for the initialization of the first layer (`svd`).\n", 511 | "\n", 512 | "**(c) training**: Specifies the training procedure, optionally divided into stages. Each stage contains:\n", 513 | "- **epochs**: Number of epochs.\n", 514 | "- **asd_dataset_path**: Path to the ASD dataset. If the DINGO model is trained in two stages, this should just contain a single fiducial ASD per detector for the first stage and a full ASD dataset per detector for the second stage.\n", 515 | "- **freeze_rb_layer**: Whether to freeze the reduced basis layer in the embedding network in this stage.\n", 516 | "- **optimizer**: Optimizer settings.\n", 517 | "- **scheduler**: Scheduler settings.\n", 518 | "- **batch_size**: Number of data points to use per batch.\n", 519 | "\n", 520 | "\n", 521 | "**(d) local**: Technical settings (no influence on the final model).\n", 522 | "- **device**: Which device to train on. Options: `cpu`, or `cuda` for training on a GPU.\n", 523 | "- **num_workers**: Number of workers used to preprocess the data before training. 
(`num_workers >0` does not work on Mac, see [post](https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206))\n", 524 | "- **runtime_limits**: Specifies the maximum time per run and the maximum number of epochs per run.\n", 525 | "- **checkpoint_epochs**: Sets after how many epochs a checkpoint is saved.\n", 526 | "\n", 527 | "For more information on typical values to be used for production, see the tutorials for [training an NPE model](https://dingo-gw.readthedocs.io/en/latest/example_npe_model.html) and [training a GNPE model](https://dingo-gw.readthedocs.io/en/latest/example_gnpe_model.html). The documentation contains more details about the [network architecture](https://dingo-gw.readthedocs.io/en/latest/network_architecture.html) as well as the [training procedure](https://dingo-gw.readthedocs.io/en/latest/training.html).\n", 528 | "\n", 529 | "Since we want to infer the masses of GW150914, we fix the extrinsic priors to delta functions at the maximum-likelihood values of the inferred parameters." 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": { 536 | "id": "lqbIyHwavoUJ" 537 | }, 538 | "outputs": [], 539 | "source": [ 540 | "train_settings = \"\"\"\n", 541 | "data:\n", 542 | " waveform_dataset_path: 01_training_data/waveform_dataset/waveform_dataset.hdf5 # Contains intrinsic waveforms\n", 543 | " train_fraction: 0.95\n", 544 | " window: # Needed to calculate window factor for simulated data\n", 545 | " type: tukey\n", 546 | " f_s: 4096\n", 547 | " T: 2.0\n", 548 | " roll_off: 0.4\n", 549 | " detectors:\n", 550 | " - H1\n", 551 | " - L1\n", 552 | " extrinsic_prior: # Sampled at train time\n", 553 | " dec: bilby.core.prior.analytical.DeltaFunction(-1.2616009712219238)\n", 554 | " ra: bilby.core.prior.analytical.DeltaFunction(1.4557750225067139)\n", 555 | " geocent_time: bilby.core.prior.analytical.DeltaFunction(0.011423417367041111)\n", 556 | " psi: bilby.core.prior.analytical.DeltaFunction(1.2124483585357666)\n", 557 | " luminosity_distance: bilby.core.prior.analytical.DeltaFunction(488.2327880859375)\n", 558 | " ref_time: 1126259462.391\n", 559 | " inference_parameters:\n", 560 | " - chirp_mass\n", 561 | " - mass_ratio\n", 562 | "\n", 563 | "# Model architecture\n", 564 | "model:\n", 565 | " # kwargs for neural spline flow\n", 566 | " posterior_model_type: normalizing_flow\n", 567 | " posterior_kwargs:\n", 568 | " num_flow_steps: 5 # 30\n", 569 | " base_transform_kwargs:\n", 570 | " hidden_dim: 64 # 1024\n", 571 | " num_transform_blocks: 5\n", 572 | " activation: elu\n", 573 | " dropout_probability: 0.0\n", 574 | " batch_norm: True\n", 575 | " num_bins: 8\n", 576 | " base_transform_type: rq-coupling\n", 577 | " # kwargs for embedding net\n", 578 | " embedding_kwargs:\n", 579 | " output_dim: 64 # 128\n", 580 | " hidden_dims: [1024, 512, 256, 64]\n", 581 | " activation: elu\n", 582 | " dropout: 0.0\n", 583 | " batch_norm: True\n", 584 | " svd:\n", 585 | " num_training_samples: 1000\n", 586 | " num_validation_samples: 100\n", 587 | " size: 50\n", 588 | "\n", 589 | "# The first (and only) stage of training.\n", 590 | "training:\n", 591 | " stage_0:\n", 592 | " epochs: 15\n", 593 | " asd_dataset_path: 01_training_data/asd_dataset/asd_GW150914.hdf5\n", 594 | " freeze_rb_layer: True\n", 595 | " optimizer:\n", 596 | " type: adam\n", 597 | " lr: 0.0001\n", 598 | " scheduler:\n", 599 | " type: cosine\n", 600 | " T_max: 15\n", 601 | " batch_size: 512\n", 602 | " # stage_1:\n", 603 | " # epochs: 5\n", 604 | " # 
asd_dataset_path: 01_training_data/asd_dataset/asd_GW150914.hdf5\n", 605 | " # freeze_rb_layer: False\n", 606 | " # optimizer:\n", 607 | " # type: adam\n", 608 | " # lr: 1.e-5\n", 609 | " # scheduler:\n", 610 | " # type: cosine\n", 611 | " # T_max: 5\n", 612 | " # batch_size: 64\n", 613 | "\n", 614 | "# Local settings for training that have no impact on the final trained network.\n", 615 | "local:\n", 616 | " device: cuda # [cpu, cuda] Set this to 'cuda' for training on a GPU.\n", 617 | " num_workers: 2 # num_workers >0 does not work on Mac, see https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206\n", 618 | " runtime_limits:\n", 619 | " max_time_per_run: 36000\n", 620 | " max_epochs_per_run: 30\n", 621 | " checkpoint_epochs: 15\n", 622 | "\"\"\"\n", 623 | "train_settings = yaml.safe_load(train_settings)\n", 624 | "with open('02_training/train_settings.yaml', 'w') as outfile:\n", 625 | " yaml.dump(train_settings, outfile, default_flow_style=False)" 626 | ] 627 | }, 628 | { 629 | "cell_type": "markdown", 630 | "metadata": { 631 | "id": "Mk5_2-Dcvotl" 632 | }, 633 | "source": [ 634 | "**Step 2: Start training run.**\n", 635 | "\n", 636 | "The following command initiates training. Various pieces of information are output during training, including SVD performance (e.g., the truncation of the SVD basis and the resulting mismatch), how the new model is initialized, the number of fixed and learnable parameters of the model, the epoch number with training and validation loss, etc.\n", 637 | "\n", 638 | "Training for 15 epochs takes approximately 10-15 minutes." 639 | ] 640 | }, 641 | { 642 | "cell_type": "code", 643 | "execution_count": null, 644 | "metadata": { 645 | "colab": { 646 | "base_uri": "https://localhost:8080/" 647 | }, 648 | "id": "81J8e7dnlaGv", 649 | "outputId": "d0bbbbfa-f8b8-4eab-ac72-ea3efd212274" 650 | }, 651 | "outputs": [], 652 | "source": [ 653 | "!dingo_train --settings_file 02_training/train_settings.yaml --train_dir 02_training" 654 | ] 655 | }, 656 | { 657 | "cell_type": "markdown", 658 | "metadata": { 659 | "id": "-VF96La4kCoS" 660 | }, 661 | "source": [ 662 | "The loss values are stored in `history.txt`, which contains the columns: epoch number, training loss, validation loss, and learning rate.\n", 663 | "\n", 664 | "One can load this file and plot the training and validation loss:" 665 | ] 666 | }, 667 | { 668 | "cell_type": "code", 669 | "execution_count": null, 670 | "metadata": { 671 | "colab": { 672 | "base_uri": "https://localhost:8080/", 673 | "height": 449 674 | }, 675 | "id": "1sJr-YOLjcer", 676 | "outputId": "4e3a21b8-ce1d-40f7-c602-e71e12b94c40" 677 | }, 678 | "outputs": [], 679 | "source": [ 680 | "# Load history.txt\n", 681 | "filename = '02_training/history.txt'\n", 682 | "data = np.loadtxt(filename, delimiter=\"\\t\")\n", 683 | "\n", 684 | "# Plot loss values\n", 685 | "plt.plot(data[:,0], data[:,1], label=\"training loss\")\n", 686 | "plt.plot(data[:,0], data[:,2], label=\"validation loss\")\n", 687 | "plt.xlabel(\"Epoch\")\n", 688 | "plt.ylabel(\"Loss\")\n", 689 | "plt.legend();" 690 | ] 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "metadata": { 695 | "id": "1c1aHvTqlm7N" 696 | }, 697 | "source": [ 698 | "## 4. Injections\n", 699 | "\n", 700 | "Here, we use the **Sampler API** to [prepare and run an injection](https://dingo-gw.readthedocs.io/en/latest/example_injection.html).\n", 701 | "1. 
Load the model (`dingo.core.posterior_models.normalizing_flow.NormalizingFlowPosteriorModel`) into the `dingo.gw.inference.gw_samplers.GWSampler` class.\n", 702 | "2. Instantiate the `dingo.gw.injection.Injection` class based on the settings the posterior model was trained on (accessible via `PosteriorModel.metadata`).\n", 703 | "3. Specify an ASD dataset, e.g. the fiducial ASD dataset the network was trained on.\n", 704 | "4. Sample from the prior and generate an injection.\n", 705 | "5. Insert the generated injection data into the sampler through `sampler.context`.\n", 706 | "6. Generate samples and convert the sampler instance into a `dingo.gw.result.Result`.\n", 707 | "7. Run importance sampling on the result object.\n", 708 | "8. Visualize the corner plot and save the result.\n", 709 | "\n", 710 | "This approach is designed to produce injections entirely consistent with the training data." 711 | ] 712 | }, 713 | { 714 | "cell_type": "code", 715 | "execution_count": null, 716 | "metadata": { 717 | "id": "b-BTzAwWEiAY" 718 | }, 719 | "outputs": [], 720 | "source": [ 721 | "from dingo.core.posterior_models.normalizing_flow import NormalizingFlowPosteriorModel\n", 722 | "from dingo.gw.inference.gw_samplers import GWSampler\n", 723 | "from dingo.gw.injection import Injection" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": null, 729 | "metadata": { 730 | "colab": { 731 | "base_uri": "https://localhost:8080/", 732 | "height": 964 733 | }, 734 | "id": "5gTcp3G6E63q", 735 | "outputId": "3340ef72-3bbf-481a-c6ed-fe71b3ab1668" 736 | }, 737 | "outputs": [], 738 | "source": [ 739 | "model_path = \"02_training/model_latest.pt\"\n", 740 | "asd_path = \"01_training_data/asd_dataset/asd_GW150914.hdf5\"\n", 741 | "\n", 742 | "# Load the network into the GWSampler class\n", 743 | "pm = NormalizingFlowPosteriorModel(model_filename=model_path, device=\"cpu\")\n", 744 | "sampler = GWSampler(model=pm)\n", 745 | "\n", 746 | "# Generate an injection consistent with the data the model was trained on.\n", 747 | "injection = Injection.from_posterior_model_metadata(pm.metadata)\n", 748 | "injection.asd = ASDDataset(asd_path, ifos=[\"H1\", \"L1\"])\n", 749 | "theta = injection.prior.sample() # Random injection from prior.\n", 750 | "inj = injection.injection(theta)\n", 751 | "\n", 752 | "# Generate 10,000 samples from the DINGO model based on the generated injection data.\n", 753 | "sampler.context = inj\n", 754 | "sampler.run_sampler(10_000)\n", 755 | "result = sampler.to_result()\n", 756 | "\n", 757 | "# The following is only needed for importance-sampling the result.\n", 758 | "result.importance_sample(num_processes=8)\n", 759 | "\n", 760 | "# Make a corner plot and save the result.\n", 761 | "result.print_summary()\n", 762 | "kwargs = {\"legend_font_size\": 15, \"truth_color\": \"black\"}\n", 763 | "parameters = ['chirp_mass', 'mass_ratio']\n", 764 | "truths = [theta[p] for p in parameters] # Same order as `parameters`.\n", 765 | "result.plot_corner(parameters=parameters,\n", 766 | " filename=\"03_inference/injection/corner.pdf\",\n", 767 | " truths=truths,\n", 768 | " **kwargs)\n", 769 | "result.to_file(\"03_inference/injection/result.hdf5\")" 770 | ] 771 | }, 772 | { 773 | "cell_type": "markdown", 774 | "metadata": { 775 | "id": "fl7Un8Sn-jvZ" 776 | }, 777 | "source": [ 778 | "Since we trained the model in a simplified setting (small waveform dataset, small model, short training), we don't expect high sample efficiencies (> 1%).\n",
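"\n", "The sample efficiency printed by `result.print_summary()` reflects the effective sample size of the importance weights. A minimal sketch of this computation (assuming, as in recent `dingo` versions, that the weights are stored in the `weights` column of `result.samples`):\n", "\n", "```python\n", "import numpy as np\n", "\n", "w = np.asarray(result.samples[\"weights\"])  # importance weights\n", "n_eff = w.sum() ** 2 / (w**2).sum()  # Kish effective sample size\n", "print(f\"sample efficiency: {n_eff / len(w):.2%}\")\n", "```"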
779 | ] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "metadata": { 784 | "id": "Gb3Yz2gRlfno" 785 | }, 786 | "source": [ 787 | "## 5. Real events with `dingo_pipe`\n", 788 | "\n", 789 | "Dingo includes a command-line tool [`dingo_pipe`](https://dingo-gw.readthedocs.io/en/latest/dingo_pipe.html) for automating inference tasks. This is based very closely on the [`bilby_pipe`](https://lscsoft.docs.ligo.org/bilby_pipe/master/user-interface.html) package, with suitable modifications.\n", 790 | "\n", 791 | "**Step 1: Prepare `.ini` file for `dingo_pipe`.**\n", 792 | "\nSimilar to `bilby_pipe`, `dingo_pipe` requires an `.ini` file as input that specifies the arguments of the subsequently performed steps.\n", 793 | "1. **Job submission arguments**: Specify cluster- and system-specific details here.\n", 794 | "2. **Sampler arguments**: Determine which model should be used to generate the samples.\n", 795 | "3. **Data generation arguments**: Define which event data should be downloaded.\n", 796 | "4. **Plotting arguments**: Specify which plots should be generated automatically." 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": { 803 | "id": "kmI3PWYtEyEK" 804 | }, 805 | "outputs": [], 806 | "source": [ 807 | "dingo_pipe_GW150914 = \"\"\"\n", 808 | "################################################################################\n", 809 | "## Job submission arguments\n", 810 | "################################################################################\n", 811 | "\n", 812 | "local = True\n", 813 | "submit = False\n", 814 | "accounting = dingo\n", 815 | "request-cpus-importance-sampling = 2\n", 816 | "simple-submission = False\n", 817 | "\n", 818 | "################################################################################\n", 819 | "## Sampler arguments\n", 820 | "################################################################################\n", 821 | "\n", 822 | "model = 02_training/model_latest.pt\n", 823 | "device = cuda\n", 824 | "num-samples = 5000\n", 825 | "batch-size = 5000\n", 826 | "importance-sampling-settings = {}\n", 827 | "\n", 828 | "################################################################################\n", 829 | "## Data generation arguments\n", 830 | "################################################################################\n", 831 | "\n", 832 | "trigger-time = 1126259462.3885767 # GW150914 # condition network on maxlog geocent time 1126259462.4 - 0.011423417367041111\n", 833 | "label = GW150914\n", 834 | "outdir = 03_inference/outdir_GW150914\n", 835 | "channel-dict = {H1:GWOSC, L1:GWOSC}\n", 836 | "psd-length = 128\n", 837 | "# sampling-frequency = 2048.0\n", 838 | "# importance-sampling-updates = {'duration': 4.0}\n", 839 | "\n", 840 | "################################################################################\n", 841 | "## Plotting arguments\n", 842 | "################################################################################\n", 843 | "\n", 844 | "plot-corner = true\n", 845 | "plot-weights = true\n", 846 | "plot-log-probs = true\n", 847 | "\"\"\"\n", 848 | "with open('03_inference/GW150914.ini', 'w') as outfile:\n", 849 | " outfile.write(dingo_pipe_GW150914)" 850 | ] 851 | }, 852 | { 853 | "cell_type": "markdown", 854 | "metadata": { 855 | "id": "ESQyNZ9UEydD" 856 | }, 857 | "source": [ 858 | "**Step 2: Run `dingo_pipe` for GW150914**\n", 859 | "\n", 860 | "If `local=True`, you can run all steps of `dingo_pipe` automatically by executing the following command:" 861 | ] 862 | },
863 | { 864 | "cell_type": "code", 865 | "execution_count": null, 866 | "metadata": { 867 | "colab": { 868 | "base_uri": "https://localhost:8080/" 869 | }, 870 | "id": "ky-pl9tili4L", 871 | "outputId": "123a284b-253d-4658-b263-2010b35d1be6" 872 | }, 873 | "outputs": [], 874 | "source": [ 875 | "!dingo_pipe 03_inference/GW150914.ini" 876 | ] 877 | }, 878 | { 879 | "cell_type": "markdown", 880 | "metadata": { 881 | "id": "ZZy58i40C8Ki" 882 | }, 883 | "source": [ 884 | "This script produces several files in the folder `03_inference/outdir_GW150914`; the subfolder `results` contains the generated (importance-sampled) samples and various plots.\n", 885 | "\n", 886 | "The officially reported values are: chirp mass $\mathcal{M}_c = 27.6^{+2.0}_{-2.0} M_{⊙}$ and mass ratio $q = 0.82^{+0.17}_{-0.20}$ (from [GWOSC](https://gwosc.org/events/GW150914/#P150914), [arXiv:1602.03837](https://arxiv.org/abs/1602.03837)). Similar to the injections, we don't expect good results from this simplified toy setting." 887 | ] 888 | }, 889 | { 890 | "cell_type": "code", 891 | "execution_count": null, 892 | "metadata": {}, 893 | "outputs": [], 894 | "source": [ 895 | "from dingo.gw.result import Result\n", 896 | "\n", 897 | "# Adjust path as needed\n", 898 | "result = Result(file_name=\"03_inference/outdir_GW150914/result/GW150914_data0_1126259462-3885767_importance_sampling.hdf5\")\n", 899 | "result.samples.describe()" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": {}, 906 | "outputs": [], 907 | "source": [ 908 | "result.print_summary()" 909 | ] 910 | },
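{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Rough comparison with the published values quoted above. This is only a sketch:\n", "# the quantiles below ignore the importance weights, and the chirp mass is the\n", "# detector-frame value, so treat the numbers as indicative.\n", "for p in ['chirp_mass', 'mass_ratio']:\n", "    q = result.samples[p].quantile([0.16, 0.5, 0.84])\n", "    print(f\"{p}: {q[0.5]:.2f} (+{q[0.84] - q[0.5]:.2f} / -{q[0.5] - q[0.16]:.2f})\")" ] },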
911 | { 912 | "cell_type": "markdown", 913 | "metadata": { 914 | "id": "4mk6pLoMNWLV" 915 | }, 916 | "source": [ 917 | "**Exercises**:\n", 918 | "1. Replace the delta function over the luminosity distance in the prior of the DINGO model with a uniform distribution over $d_L \in [100, 1000]$ Mpc.\n", 919 | "2. Train a marginalized model & perform inference.\n", 920 | "3. Train a model that infers the luminosity distance as well.\n", 921 | "\n", 922 | "You can use the folder `04_exercise` and its subfolders to save the settings files and models." 923 | ] 924 | }, 925 | { 926 | "cell_type": "markdown", 927 | "metadata": { 928 | "id": "_9nA80C7m_VI" 929 | }, 930 | "source": [ 931 | "## 6. Inference with a pre-trained DINGO model (from Zenodo)\n", 932 | "\n", 933 | "Larger networks are more expensive to train, often taking ~1 week on an A100. We therefore supply some pre-trained networks on Zenodo. These production models rely on the algorithm [\"Group Equivariant Neural Posterior Estimation\" (GNPE)](https://dingo-gw.readthedocs.io/en/latest/gnpe.html), which incorporates (approximate) joint symmetries of data and parameters into the `dingo` model and simplifies the data with respect to time shifts.\n", 934 | "\n", 935 | "Practically, this means that instead of one `dingo` model, we must specify two separate models:\n", 936 | "1. **Initialization model**: Learns a proxy for the arrival times of the signal in each detector and acts as a starting point for the Gibbs sampler. Based on the output of this model, the strain data is transformed and simplified before serving as an input to the main model.\n", 937 | "2. **Main model**: Is conditioned on the proxy variable and learns the full posterior distribution.\n", 938 | "\n", 939 | "\n", 940 | "**Step 1: Download trained DINGO model from Zenodo**\n", 941 | "\n", 942 | "Use the [O1 GNPE networks](https://zenodo.org/records/12156303), which are based on `IMRPhenomXPHM` and a luminosity distance prior of $d_L \in [100, 1000]$ Mpc.\n", 943 | "The following download might take some time depending on your internet connection." 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "execution_count": null, 949 | "metadata": { 950 | "id": "wnbcPG7A_B6U" 951 | }, 952 | "outputs": [], 953 | "source": [ 954 | "!wget -q --directory-prefix=05_pretrained_model/init_train_dir https://zenodo.org/records/12156303/files/model_init.pt # init model\n", 955 | "!wget -q --directory-prefix=05_pretrained_model/main_train_dir https://zenodo.org/records/12156303/files/model.pt # main model" 956 | ] 957 | }, 958 | { 959 | "cell_type": "markdown", 960 | "metadata": { 961 | "id": "r39wvOn6rNBu" 962 | }, 963 | "source": [ 964 | "**Step 2: Prepare INI file**\n", 965 | "\n", 966 | "For GNPE, the `GW150914.ini` file has to be adapted to specify the paths to the initialization and main models. We must also include `num-gnpe-iterations` to specify the number of Gibbs sampling iterations.\n", 967 | "\n", 968 | "We set `recover-log-prob = False` and `importance-sample = False` to speed up execution for the tutorial. If you change these values to `True`, the log probability values have to be recovered for importance sampling with GNPE, which takes more time." 969 | ] 970 | }, 971 | { 972 | "cell_type": "code", 973 | "execution_count": null, 974 | "metadata": { 975 | "id": "puWRrqSLq7v2" 976 | }, 977 | "outputs": [], 978 | "source": [ 979 | "dingo_pipe_GW150914 = \"\"\"\n", 980 | "################################################################################\n", 981 | "## Job submission arguments\n", 982 | "################################################################################\n", 983 | "\n", 984 | "local = True\n", 985 | "submit = False\n", 986 | "accounting = dingo\n", 987 | "request-cpus-importance-sampling = 2\n", 988 | "simple-submission = False\n", 989 | "\n", 990 | "################################################################################\n", 991 | "## Sampler arguments\n", 992 | "################################################################################\n", 993 | "\n", 994 | "model-init = 05_pretrained_model/init_train_dir/model_init.pt\n", 995 | "model = 05_pretrained_model/main_train_dir/model.pt\n", 996 | "device = cuda\n", 997 | "num-gnpe-iterations = 10 #30\n", 998 | "num-samples = 10_000\n", 999 | "batch-size = 5000\n", 1000 | "recover-log-prob = False #True\n", 1001 | "prior-dict-updates = {\n", 1002 | "luminosity_distance = bilby.gw.prior.UniformComovingVolume(minimum=100, maximum=2000, name='luminosity_distance'),\n", 1003 | "}\n", 1004 | "importance-sample = False #True\n", 1005 | "\n", 1006 | "################################################################################\n", 1007 | "## Data generation arguments\n", 1008 | "################################################################################\n", 1009 | "\n", 1010 | "trigger-time = GW150914\n", 1011 | "label = GW150914\n", 1012 | "outdir = 05_pretrained_model/outdir_GW150914\n", 1013 | "channel-dict = {H1:GWOSC, L1:GWOSC}\n", 1014 | "psd-length = 128\n", 1015 | "# sampling-frequency = 2048.0\n", 1016 | "# importance-sampling-updates = {'duration': 4.0}\n", 1017 | "\n", 1018 | 
"################################################################################\n", 1019 | "## Plotting arguments\n", 1020 | "################################################################################\n", 1021 | "\n", 1022 | "plot-corner = true\n", 1023 | "plot-weights = true\n", 1024 | "plot-log-probs = true\n", 1025 | "\"\"\"\n", 1026 | "with open('05_pretrained_model/GW150914.ini', 'w') as outfile:\n", 1027 | " outfile.write(dingo_pipe_GW150914)" 1028 | ] 1029 | }, 1030 | { 1031 | "cell_type": "markdown", 1032 | "metadata": { 1033 | "id": "E0SIkaFEYdDq" 1034 | }, 1035 | "source": [ 1036 | "**Step 3: Run `dingo_pipe` on pretrained DINGO model**\n", 1037 | "\n", 1038 | "Finally, we can run `dingo_pipe` with the specified `.ini` file to obtain the results for the complete posterior distribution." 1039 | ] 1040 | }, 1041 | { 1042 | "cell_type": "markdown", 1043 | "metadata": {}, 1044 | "source": [ 1045 | "Due to a recent change in the window factor usage, the current models on Zenodo are not compatible with `dingo-gw>=0.9.0` and we expect low sample efficiency. If you want to run inference with the GNPE models, downgrade the package to `dingo-gw==0.8.7`.\n", 1046 | "The following code illustrates the usage of `dingo_pipe` and showcases how inference works in principle." 1047 | ] 1048 | }, 1049 | { 1050 | "cell_type": "code", 1051 | "execution_count": null, 1052 | "metadata": { 1053 | "colab": { 1054 | "base_uri": "https://localhost:8080/" 1055 | }, 1056 | "id": "cPFEGcr_Xlpi", 1057 | "outputId": "0f9a18c3-5ec3-4ebb-8c30-b52874f9f382" 1058 | }, 1059 | "outputs": [], 1060 | "source": [ 1061 | "!dingo_pipe 05_pretrained_model/GW150914.ini" 1062 | ] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": { 1067 | "id": "11QlQ1XeBHC7" 1068 | }, 1069 | "source": [ 1070 | "The full 14-dimensional corner plot based on DINGO samples can be found in `05_pretrained_model/outdir_GW150914/results`.\n", 1071 | "\n", 1072 | "**Exercises:**\n", 1073 | "1. Run an injection with the GNPE model. You can find some hints in the [GNPE](https://dingo-gw.readthedocs.io/en/latest/gnpe.html) and [injection](https://dingo-gw.readthedocs.io/en/latest/example_injection.html) docs.\n", 1074 | "\n", 1075 | "You will need to use the `GNPESampler` class that can be imported via\n", 1076 | "`from dingo.core.samplers import GNPESampler`." 1077 | ] 1078 | } 1079 | ], 1080 | "metadata": { 1081 | "accelerator": "GPU", 1082 | "colab": { 1083 | "gpuType": "T4", 1084 | "provenance": [] 1085 | }, 1086 | "kernelspec": { 1087 | "display_name": "Python 3 (ipykernel)", 1088 | "language": "python", 1089 | "name": "python3" 1090 | }, 1091 | "language_info": { 1092 | "codemirror_mode": { 1093 | "name": "ipython", 1094 | "version": 3 1095 | }, 1096 | "file_extension": ".py", 1097 | "mimetype": "text/x-python", 1098 | "name": "python", 1099 | "nbconvert_exporter": "python", 1100 | "pygments_lexer": "ipython3", 1101 | "version": "3.12.6" 1102 | } 1103 | }, 1104 | "nbformat": 4, 1105 | "nbformat_minor": 4 1106 | } 1107 | --------------------------------------------------------------------------------