├── .gitattributes ├── .github └── workflows │ └── upload_to_pypi.yml ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs ├── Makefile ├── boone2021.rst ├── conf.py ├── index.rst ├── installation.rst ├── make.bat ├── photoz.rst ├── reference.rst ├── sncosmo.rst └── usage.rst ├── notebooks ├── classification.ipynb ├── compare_latent_dimension.ipynb ├── detecting_novel_transients.ipynb ├── parsnip_model.ipynb ├── photoz.ipynb ├── sncosmo_mcmc_posterior.ipynb └── spectra_comparison.ipynb ├── parsnip ├── __init__.py ├── classifier.py ├── instruments.py ├── light_curve.py ├── models │ ├── plasticc.pt │ ├── plasticc_photoz.pt │ └── ps1.pt ├── parsnip.py ├── plotting.py ├── settings.py ├── sncosmo.py └── utils.py ├── pyproject.toml ├── scripts ├── parsnip_build_plasticc_combined ├── parsnip_predict └── parsnip_train ├── setup.cfg └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | notebooks/* linguist-documentation 2 | docs/* linguist-documentation 3 | -------------------------------------------------------------------------------- /.github/workflows/upload_to_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | 13 | - uses: actions/setup-python@v2 14 | name: Install Python 15 | with: 16 | python-version: '3.x' 17 | 18 | - name: Install build 19 | run: python -m pip install build 20 | 21 | - name: Build sdist and wheel 22 | run: python -m build 23 | 24 | - name: Upload to PyPI 25 | uses: pypa/gh-action-pypi-publish@release/v1 26 | with: 27 | user: __token__ 28 | password: ${{ secrets.PYPI_API_TOKEN }} 29 | -------------------------------------------------------------------------------- /.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | 12 | # Optionally build your docs in additional formats such as PDF 13 | formats: 14 | - pdf 15 | 16 | # Optionally set the version of Python and requirements required to build your docs 17 | python: 18 | version: "3.8" 19 | install: 20 | - method: pip 21 | path: . 22 | extra_requirements: 23 | - docs 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kyle Boone 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include parsnip/models/plasticc.pt 2 | include parsnip/models/ps1.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ParSNIP 2 | 3 | Deep generative modeling of astronomical transient light curves 4 | 5 | [![Documentation Status](https://readthedocs.org/projects/parsnip/badge/?version=latest)](https://parsnip.readthedocs.io/en/latest/?badge=latest) 6 | 7 | ## About 8 | 9 | ParSNIP learns a generative model of transients from a large dataset 10 | of transient light curves. This code has many applications including 11 | classification of transients, cosmological distance estimation, and 12 | identifying novel transients. A full description of the algorithms 13 | in this code can be found in Boone 2021 (submitted to ApJ). 14 | 15 | ## Installation and Usage 16 | 17 | Instructions on how to install and use ParSNIP can be found on the [ParSNIP 18 | readthedocs page](https://parsnip.readthedocs.io/en/latest/). 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/boone2021.rst: -------------------------------------------------------------------------------- 1 | ********************** 2 | Reproducing Boone 2021 3 | ********************** 4 | 5 | Overview 6 | ======== 7 | 8 | The details of the ParSNIP model are documented in Boone 2021. To reproduce all of the 9 | results in that paper, follow the steps below. 10 | 11 | Installing ParSNIP 12 | ================== 13 | 14 | Install the ParSNIP software package following the instructions on the 15 | :doc:`installation` page. 16 | 17 | Downloading the data 18 | ==================== 19 | 20 | From the desired working directory, run the following scripts on the command line to 21 | download the PLAsTiCC and PS1 datasets to the `./data/` directory. 22 | 23 | Download PS1:: 24 | 25 | $ lcdata_download_ps1 26 | 27 | Download PLAsTiCC (warning, this can take a long time):: 28 | 29 | $ lcdata_download_plasticc 30 | 31 | Build a combined PLAsTiCC training set for ParSNIP:: 32 | 33 | $ parsnip_build_plasticc_combined 34 | 35 | 36 | Training the ParSNIP model 37 | ========================== 38 | 39 | Note: Model training is much faster if a GPU is available. By default, ParSNIP will 40 | attempt to use the GPU if there is one and fall back to CPU if not. This can be overridden 41 | by passing e.g. `--device cpu` to the `parsnip_train` script where `cpu` is the desired 42 | PyTorch device.
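The GPU-with-CPU-fallback behavior described in the note above can be sketched in plain Python. This is an illustrative sketch only, not the actual `parsnip` implementation; `resolve_device` is a hypothetical helper name:

```python
def resolve_device(requested=None):
    """Pick a PyTorch device string, preferring an explicit request.

    Mirrors the behavior described above: use the GPU when one is
    available and fall back to CPU otherwise.
    """
    if requested is not None:
        # An explicit --device argument always wins.
        return requested
    try:
        import torch
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    except ImportError:
        # Without PyTorch installed, only the CPU can be reported.
        return 'cpu'
```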
43 | 44 | Train a PS1 model using the full dataset (1 hour):: 45 | 46 | $ parsnip_train \ 47 | ./models/parsnip_ps1.pt \ 48 | ./data/ps1.h5 49 | 50 | Train a PS1 model with a held-out validation set (1 hour):: 51 | 52 | $ parsnip_train \ 53 | ./models/parsnip_ps1_validation.pt \ 54 | ./data/ps1.h5 \ 55 | --split_train_test 56 | 57 | Train a PLAsTiCC model using the full dataset (1 day):: 58 | 59 | $ parsnip_train \ 60 | ./models/parsnip_plasticc.pt \ 61 | ./data/plasticc_combined.h5 62 | 63 | Train a PLAsTiCC model with a held-out validation set (1 day):: 64 | 65 | $ parsnip_train \ 66 | ./models/parsnip_plasticc_validation.pt \ 67 | ./data/plasticc_combined.h5 \ 68 | --split_train_test 69 | 70 | 71 | Generate predictions 72 | ==================== 73 | 74 | Generate predictions for the PS1 dataset (< 1 min):: 75 | 76 | parsnip_predict ./predictions/parsnip_predictions_ps1.h5 \ 77 | ./models/parsnip_ps1.pt \ 78 | ./data/ps1.h5 79 | 80 | Generate predictions for the PS1 dataset with 100-fold augmentation (3 min):: 81 | 82 | parsnip_predict ./predictions/parsnip_predictions_ps1_aug_100.h5 \ 83 | ./models/parsnip_ps1.pt \ 84 | ./data/ps1.h5 \ 85 | --augments 100 86 | 87 | Generate predictions for the PLAsTiCC combined training dataset (7 min):: 88 | 89 | parsnip_predict ./predictions/parsnip_predictions_plasticc_combined.h5 \ 90 | ./models/parsnip_plasticc.pt \ 91 | ./data/plasticc_combined.h5 92 | 93 | Generate predictions for the PLAsTiCC training set with 100-fold augmentation (4 min):: 94 | 95 | parsnip_predict ./predictions/parsnip_predictions_plasticc_train_aug_100.h5 \ 96 | ./models/parsnip_plasticc.pt \ 97 | ./data/plasticc_train.h5 \ 98 | --augments 100 99 | 100 | Generate predictions for the full PLAsTiCC dataset (1 hour):: 101 | 102 | parsnip_predict ./predictions/parsnip_predictions_plasticc_test.h5 \ 103 | ./models/parsnip_plasticc.pt \ 104 | ./data/plasticc_test.h5 105 | 106 | Figures and analysis 107 | ==================== 108 | 109 | All of the figures 
and analysis in Boone 2021 were done with `Jupyter notebooks that are 110 | available on GitHub `_. To rerun 111 | these notebooks, copy the notebooks folder to the working directory and run the 112 | notebooks from within that folder. 113 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import importlib.metadata 8 | 9 | # -- Project information ----------------------------------------------------- 10 | 11 | project = 'parsnip' 12 | copyright = '2021, Kyle Boone' 13 | author = 'Kyle Boone' 14 | 15 | # The full version, including alpha/beta/rc tags 16 | release = importlib.metadata.version('astro-parsnip') 17 | 18 | 19 | # -- General configuration --------------------------------------------------- 20 | 21 | intersphinx_mapping = { 22 | 'python': ('https://docs.python.org/3/', None), 23 | 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 24 | 'astropy': ('http://docs.astropy.org/en/stable/', None), 25 | } 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.autosummary', 33 | 'sphinx.ext.intersphinx', 34 | 'sphinx.ext.mathjax', 35 | 'sphinx.ext.napoleon', 36 | ] 37 | 38 | autosummary_generate = ["reference.rst"] 39 | 40 | # Napoleon settings 41 | napoleon_google_docstring = False 42 | napoleon_numpy_docstring = True 43 | napoleon_include_init_with_doc = False 44 | napoleon_include_private_with_doc = False 45 | napoleon_include_special_with_doc = True 46 | napoleon_use_admonition_for_examples = False 47 | napoleon_use_admonition_for_notes = False 48 | napoleon_use_admonition_for_references = False 49 | napoleon_use_ivar = False 50 | napoleon_use_param = True 51 | napoleon_use_rtype = True 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ['_templates'] 55 | 56 | # List of patterns, relative to source directory, that match files and 57 | # directories to ignore when looking for source files. 58 | # This pattern also affects html_static_path and html_extra_path. 59 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 60 | 61 | # The reST default role (used for this markup: `text`) to use for all 62 | # documents. 63 | default_role = 'obj' 64 | 65 | # The name of the Pygments (syntax highlighting) style to use. 66 | pygments_style = 'sphinx' 67 | 68 | # -- Options for HTML output ------------------------------------------------- 69 | 70 | # The theme to use for HTML and HTML Help pages. See the documentation for 71 | # a list of builtin themes. 72 | # 73 | html_theme = 'sphinx_rtd_theme' 74 | 75 | master_doc = 'index' 76 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ParSNIP 2 | ======= 3 | 4 | About 5 | ----- 6 | 7 | ParSNIP is a package for learning generative models of transient light curves. 
This code 8 | has many applications including classification of transients, cosmological distance 9 | estimation, and identifying novel transients. 10 | 11 | .. toctree:: 12 | :maxdepth: 1 13 | :titlesonly: 14 | 15 | installation 16 | usage 17 | sncosmo 18 | boone2021 19 | photoz 20 | reference 21 | 22 | Source code: https://github.com/kboone/parsnip 23 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ************ 2 | Installation 3 | ************ 4 | 5 | ParSNIP requires Python 3.6+ and depends on the following Python packages: 6 | 7 | - `astropy `_ 8 | - `extinction `_ 9 | - `lcdata `_ 10 | - `lightgbm `_ 11 | - `matplotlib `_ 12 | - `numpy `_ 13 | - `scipy `_ 14 | - `PyTorch `_ 15 | - `scikit-learn `_ 16 | - `tqdm `_ 17 | 18 | Install using pip (recommended) 19 | =============================== 20 | 21 | ParSNIP is available on PyPI. To install the latest release:: 22 | 23 | pip install astro-parsnip 24 | 25 | 26 | Install development version 27 | =========================== 28 | 29 | The ParSNIP source code can be found on `github `_. 30 | 31 | To install it:: 32 | 33 | git clone git://github.com/kboone/parsnip 34 | cd parsnip 35 | pip install -e . 36 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/photoz.rst: -------------------------------------------------------------------------------- 1 | ******************************* 2 | Including Photometric Redshifts 3 | ******************************* 4 | 5 | Overview 6 | ======== 7 | 8 | The base ParSNIP model described in Boone 2021 assumes that the redshift of each 9 | transient is known. In Boone et al. 2022 (in prep.), ParSNIP was extended to handle 10 | datasets that only have photometric redshifts available. ParSNIP uses the photometric 11 | redshift as a prior and predicts the redshift of each transient. Currently ParSNIP only 12 | supports Gaussian photometric redshifts like the ones in the PLAsTiCC dataset, but it is 13 | straightforward to include more complex photometric redshift priors. 14 | 15 | The `plasticc_photoz` built-in model was trained on the PLAsTiCC dataset and uses 16 | photometric redshifts instead of true redshifts. It can be loaded with the following 17 | command: 18 | 19 | >>> model = parsnip.load_model('plasticc_photoz') 20 | 21 | This model assumes that each transient has metadata with a `hostgal_photoz` key 22 | containing the mean photometric redshift prediction and a `hostgal_photoz_err` key 23 | containing the photometric redshift uncertainty.
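A Gaussian photometric redshift prior of this form is simple to write down. The sketch below (plain Python for illustration, not the parsnip internals) evaluates the log-prior for a candidate redshift given the `hostgal_photoz` and `hostgal_photoz_err` metadata values:

```python
import math

def gaussian_photoz_log_prior(z, hostgal_photoz, hostgal_photoz_err):
    """Log of a Gaussian prior N(hostgal_photoz, hostgal_photoz_err) at z."""
    sigma = hostgal_photoz_err
    # Normalization term of the Gaussian density.
    norm = -math.log(sigma * math.sqrt(2.0 * math.pi))
    # Quadratic penalty for moving away from the photometric redshift.
    return norm - 0.5 * ((z - hostgal_photoz) / sigma) ** 2
```

The prior peaks at the photometric redshift and falls off with the quoted uncertainty, which is why narrower `hostgal_photoz_err` values constrain the predicted redshift more tightly.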
24 | 25 | Training ParSNIP with photometric redshifts 26 | =========================================== 27 | 28 | The following steps can be used to train a model that uses photometric redshifts on the 29 | PLAsTiCC dataset and generate predictions for both the training and test datasets. You 30 | should first follow the steps in :doc:`boone2021` to download the PLAsTiCC dataset. 31 | 32 | Photometric redshifts are enabled by passing the `--predict_redshift` flag to 33 | `parsnip_train`. Model training can be unstable at early epochs when the redshift is 34 | being predicted, so we recommend using larger batch sizes and starting the training with 35 | a lower learning rate. A batch size of 256 and a learning rate of 5e-4 are stable for 36 | the PLAsTiCC dataset. 37 | 38 | Note: Model training is much faster if a GPU is available. By default, ParSNIP will 39 | attempt to use the GPU if there is one and fall back to CPU if not. This can be overridden 40 | by passing e.g. `--device cpu` to the `parsnip_train` script where `cpu` is the desired 41 | PyTorch device. 42 | 43 | Train the PLAsTiCC model using the full dataset (1 day):: 44 | 45 | $ parsnip_train \ 46 | ./models/parsnip_plasticc_photoz.pt \ 47 | ./data/plasticc_combined.h5 \ 48 | --batch_size 256 \ 49 | --learning_rate 5e-4 \ 50 | --predict_redshift 51 | 52 | Generate predictions for the PLAsTiCC training set with 100-fold augmentation (4 min):: 53 | 54 | parsnip_predict ./predictions/parsnip_predictions_plasticc_photoz_train_aug_100.h5 \ 55 | ./models/parsnip_plasticc_photoz.pt \ 56 | ./data/plasticc_train.h5 \ 57 | --augments 100 58 | 59 | Generate predictions for the full PLAsTiCC dataset (1 hour):: 60 | 61 | parsnip_predict ./predictions/parsnip_predictions_plasticc_photoz_test.h5 \ 62 | ./models/parsnip_plasticc_photoz.pt \ 63 | ./data/plasticc_test.h5 64 | 65 | By default, ParSNIP uses a spectroscopic redshift prior with a width of 0.01 during 66 | training.
This can be adjusted using the `specz_error` flag to `parsnip_train`. For 67 | example, running `parsnip_train ... --specz_error 0.05` will use a prior with a width of 68 | 0.05 instead. 69 | 70 | Figures and analysis 71 | ==================== 72 | 73 | All of the figures and analysis in Boone et al. 2022 (in prep.) can be reproduced with 74 | `a Jupyter notebook that is available on GitHub 75 | `_. To rerun this 76 | notebook on a newly trained model, copy the notebooks folder to the working directory 77 | and run the notebook from within that folder. 78 | -------------------------------------------------------------------------------- /docs/reference.rst: -------------------------------------------------------------------------------- 1 | *************** 2 | Reference / API 3 | *************** 4 | 5 | .. currentmodule:: parsnip 6 | 7 | 8 | Models 9 | ====== 10 | 11 | *Loading/saving a model* 12 | 13 | .. autosummary:: 14 | :toctree: api 15 | 16 | ParsnipModel 17 | load_model 18 | ParsnipModel.save 19 | ParsnipModel.to 20 | 21 | *Interacting with a dataset* 22 | 23 | .. autosummary:: 24 | :toctree: api 25 | 26 | ParsnipModel.preprocess 27 | ParsnipModel.augment_light_curves 28 | ParsnipModel.get_data_loader 29 | ParsnipModel.fit 30 | ParsnipModel.score 31 | 32 | *Generating model predictions* 33 | 34 | .. autosummary:: 35 | :toctree: api 36 | 37 | ParsnipModel.predict 38 | ParsnipModel.predict_dataset 39 | ParsnipModel.predict_dataset_augmented 40 | ParsnipModel.predict_light_curve 41 | ParsnipModel.predict_spectrum 42 | ParsnipModel.predict_sncosmo 43 | 44 | *Individual parts of the model* 45 | 46 | .. autosummary:: 47 | :toctree: api 48 | 49 | ParsnipModel.forward 50 | ParsnipModel.encode 51 | ParsnipModel.decode 52 | ParsnipModel.decode_spectra 53 | ParsnipModel.loss_function 54 | 55 | 56 | Datasets 57 | ======== 58 | 59 | *Loading datasets* 60 | 61 | .. 
autosummary:: 62 | :toctree: api 63 | 64 | load_dataset 65 | load_datasets 66 | parse_dataset 67 | 68 | *Parsers for specific instruments* 69 | 70 | .. autosummary:: 71 | :toctree: api 72 | 73 | parse_plasticc 74 | parse_ps1 75 | parse_ztf 76 | 77 | *Tools for manipulating datasets* 78 | 79 | .. autosummary:: 80 | :toctree: api 81 | 82 | split_train_test 83 | get_bands 84 | 85 | 86 | Plotting 87 | ======== 88 | 89 | .. autosummary:: 90 | :toctree: api 91 | 92 | plot_light_curve 93 | plot_representation 94 | plot_spectrum 95 | plot_spectra 96 | plot_sne_space 97 | plot_confusion_matrix 98 | get_band_plot_color 99 | get_band_plot_marker 100 | 101 | 102 | Classification 103 | ============== 104 | 105 | .. autosummary:: 106 | :toctree: api 107 | 108 | Classifier 109 | extract_top_classifications 110 | weighted_multi_logloss 111 | 112 | 113 | SNCosmo Interface 114 | ================= 115 | 116 | .. autosummary:: 117 | :toctree: api 118 | 119 | ParsnipSncosmoSource 120 | ParsnipModel.predict_sncosmo 121 | 122 | 123 | Custom Neural Network Layers 124 | ============================ 125 | 126 | .. autosummary:: 127 | :toctree: api 128 | 129 | ResidualBlock 130 | Conv1dBlock 131 | GlobalMaxPoolingTime 132 | 133 | 134 | Settings 135 | ======== 136 | 137 | .. autosummary:: 138 | :toctree: api 139 | 140 | parse_settings 141 | parse_int_list 142 | build_default_argparse 143 | update_derived_settings 144 | update_settings_version 145 | 146 | 147 | Light curve utilities 148 | ===================== 149 | 150 | .. autosummary:: 151 | :toctree: api 152 | 153 | preprocess_light_curve 154 | time_to_grid 155 | grid_to_time 156 | get_band_effective_wavelength 157 | calculate_band_mw_extinctions 158 | should_correct_background 159 | 160 | 161 | General utilities 162 | ================= 163 | 164 | .. 
autosummary:: 165 | :toctree: api 166 | 167 | nmad 168 | frac_to_mag 169 | parse_device 170 | -------------------------------------------------------------------------------- /docs/sncosmo.rst: -------------------------------------------------------------------------------- 1 | ***************** 2 | SNCosmo Interface 3 | ***************** 4 | 5 | Overview 6 | ======== 7 | 8 | ParSNIP provides an SNCosmo interface with an implementation of the `sncosmo.Source` 9 | class. To load the built-in ParSNIP model trained on the PLAsTiCC dataset:: 10 | 11 | >>> import parsnip 12 | >>> source = parsnip.ParsnipSncosmoSource('plasticc') 13 | 14 | This source can be used in any SNCosmo models or methods. For example:: 15 | 16 | >>> import sncosmo 17 | >>> model = sncosmo.Model(source=source) 18 | 19 | >>> model.param_names 20 | ['z', 't0', 'amplitude', 'color', 's1', 's2', 's3'] 21 | 22 | >>> data = sncosmo.load_example_data() 23 | >>> result, fitted_model = sncosmo.fit_lc( 24 | ... data, model, 25 | ... ['z', 't0', 'amplitude', 's1', 's2', 's3', 'color'], 26 | ... bounds={'z': (0.3, 0.7)}, 27 | ... ) 28 | 29 | Note that ParSNIP is a generative model in that it predicts the full spectral time 30 | series of each transient. When used with the SNCosmo interface, it can operate on light 31 | curves observed in any bands, not just the ones that it was trained on. 32 | 33 | Predicting the model parameters with variational inference 34 | ========================================================== 35 | 36 | The ParSNIP model uses variational inference to predict the posterior distribution over 37 | all of the parameters of the model. An SNCosmo model can be initialized with the result 38 | of this prediction:: 39 | 40 | >>> parsnip_model = parsnip.load_model( ... 
) 41 | >>> sncosmo_model = parsnip_model.predict_sncosmo(light_curve) 42 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ***** 2 | Usage 3 | ***** 4 | 5 | Overview 6 | ======== 7 | 8 | ParSNIP is a generative model of astronomical transient light curves. It is designed to 9 | work with light curves in `sncosmo` format using the `lcdata` package to handle large 10 | datasets. See the `lcdata` documentation for details on how to download or ingest 11 | different datasets. 12 | 13 | Training a model 14 | ================ 15 | 16 | ParSNIP provides a built-in script called `parsnip_train` that can be used to train a 17 | model on an `lcdata` dataset. It takes as input the path that the model will be saved to 18 | along with a list of paths to datasets. For example:: 19 | 20 | 21 | $ parsnip_train ./model.pt ./dataset_1.h5 ./dataset_2.h5 22 | 23 | will train a model named `model.pt` using the datasets `dataset_1.h5` and 24 | `dataset_2.h5`. 25 | 26 | Generating predictions 27 | ====================== 28 | 29 | The `parsnip_predict` script can be used to generate predictions given an `lcdata` 30 | dataset and a pretrained ParSNIP model. To run it:: 31 | 32 | $ parsnip_predict ./predictions.h5 ./model.h5 ./dataset.h5 33 | 34 | will generate predictions to the file named `predictions.h5` using the dataset 35 | `dataset.h5` and the model `model.h5`. 36 | 37 | Loading a dataset in Python 38 | =========================== 39 | 40 | ParSNIP is designed to work with `lcdata` datasets. `lcdata` datasets are guaranteed to 41 | be in a specific format, but they may include instrument-specific quirks, light curves 42 | that are not compatible with ParSNIP, or metadata in unusual formats (e.g. PLAsTiCC 43 | types are random integers). ParSNIP includes tools to clean up datasets from a range of 44 | different surveys and reject invalid light curves. 
Given an `lcdata` dataset, this can 45 | be done with:: 46 | 47 | >>> dataset = parsnip.parse_dataset(raw_dataset, kind='ps1') 48 | 49 | Here `kind` specifies the type of dataset, in this case one from PanSTARRS-1. Currently 50 | supported options include: 51 | 52 | * ps1 53 | * ztf 54 | * plasticc 55 | 56 | A convenience function is also included to read `lcdata` datasets in HDF5 format and 57 | parse them automatically:: 58 | 59 | >>> dataset = parsnip.load_dataset('/path/to/data.h5') 60 | 61 | This function will attempt to determine the dataset kind from the filename. This can be 62 | overridden with the `kind` keyword as in the previous example. 63 | 64 | Loading a model in Python 65 | ========================= 66 | 67 | Once a model has been trained, ParSNIP has a vast Python API for manipulating it and 68 | using it to generate predictions and plots. To load a model in Python:: 69 | 70 | >>> import parsnip 71 | >>> model = parsnip.load_model('/path/to/model.h5') 72 | 73 | There are several built-in models included that can be loaded by specifying their name. 74 | Currently, these are: 75 | 76 | * `plasticc` trained on the PLAsTiCC dataset. 77 | * `ps1` trained on the PS1 dataset from Villar et al. 2020. 78 | * `plasticc_photoz` trained on the PLAsTiCC dataset. Uses the photometric redshifts 79 | instead of the true redshifts. 80 | 81 | To load one of these built-in models:: 82 | 83 | >>> model = parsnip.load_model('plasticc') 84 | 85 | Assuming that you have a light curve in `sncosmo` format, some examples of what can be 86 | done with a model include: 87 | 88 | Predict the latent representation of a light curve:: 89 | 90 | >>> model.predict(light_curve) 91 | { 92 | 'object_id': 'PS0909006', 93 | ... 94 | 's1': 0.19424194, 95 | 's1_error': 0.44743112, 96 | 's2': -0.051611423, 97 | 's2_error': 1.0143535, 98 | ... 
99 | } 100 | 101 | Plot the predicted light curve:: 102 | 103 | >>> parsnip.plot_light_curve(light_curve, model) 104 | 105 | Plot the predicted spectrum at a given time:: 106 | 107 | >>> parsnip.plot_spectrum(light_curve, model, time=53000.) 108 | 109 | See the :doc:`reference` page for a list of all of the built-in methods, or the 110 | `notebooks that were used to make figures for Boone et al. 111 | 2021 `_ for examples. 112 | 113 | Classifying light curves 114 | ======================== 115 | 116 | To classify light curves, we first need to predict their representations using a ParSNIP 117 | model. This can be done either with the `parsnip_predict` script described previously or 118 | by operating in memory on an `lcdata` Dataset object:: 119 | 120 | >>> predictions = model.predict_dataset(dataset) 121 | >>> print(predictions) 122 | object_id ra dec ... s3 s3_error 123 | --------- -------- -------- ... ------------- ----------- 124 | PS0909006 333.9503 1.1848 ... 0.19424233 0.4474311 125 | PS0909010 37.1182 -4.0789 ... -0.40881702 0.59658796 126 | PS0910012 52.4718 -28.0867 ... -2.142636 0.08176677 127 | PS0910016 35.3073 -3.91 ... -0.31671444 0.5740286 128 | ... ... ... ... ... ... 129 | 130 | A classifier can be trained on a set of predictions with:: 131 | 132 | >>> classifier = parsnip.Classifier() 133 | >>> classifier.train(predictions) 134 | 135 | The classifier can then be used to generate predictions for a new dataset with:: 136 | 137 | >>> classifier.classify(new_predictions) 138 | object_id SLSN SNII SNIIn SNIa SNIbc 139 | --------- ----- ----- ----- ----- ----- 140 | PS0909006 0.009 0.025 0.031 0.858 0.077 141 | PS0909010 0.001 0.002 0.017 0.954 0.024 142 | PS0910016 0.002 0.002 0.017 0.948 0.032 143 | PSc000001 0.003 0.936 0.038 0.003 0.021 144 | PSc090022 0.960 0.001 0.037 0.001 0.000 145 | ... ... ... ... ... ... 146 | 147 | For more details and examples, see the `classification demo notebook 148 | `_.
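A probability table like the one above can be reduced to a single label per object by taking the most probable class, which is the role of `extract_top_classifications` in the API. A plain-Python sketch of that step (the toy `probabilities` dict echoes two rows of the table above and is not parsnip's actual data structure):

```python
def extract_top(probabilities):
    """Map each object ID to its most probable class label."""
    return {
        object_id: max(class_probs, key=class_probs.get)
        for object_id, class_probs in probabilities.items()
    }

# Toy probabilities mirroring two rows of the classifier output above.
probabilities = {
    'PS0909006': {'SLSN': 0.009, 'SNII': 0.025, 'SNIIn': 0.031,
                  'SNIa': 0.858, 'SNIbc': 0.077},
    'PSc090022': {'SLSN': 0.960, 'SNII': 0.001, 'SNIIn': 0.037,
                  'SNIa': 0.001, 'SNIbc': 0.000},
}
```

With these inputs, `extract_top` labels `PS0909006` as a SNIa and `PSc090022` as a SLSN, matching the highest entries in each row of the table.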
149 | -------------------------------------------------------------------------------- /notebooks/compare_latent_dimension.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# ParSNIP latent dimension size\n", 8 | "\n", 9 | "In this notebook, we compare the results of the ParSNIP model when trained with a range of different latent dimension sizes. For each latent dimension size we trained three separate models on the PS1 dataset. We calculate the loss function for each of these models on both the PS1 training and test datasets and use this to determine how many dimensions are necessary/useful." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# Load the dataset" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import numpy as np\n", 26 | "from matplotlib import pyplot as plt\n", 27 | "\n", 28 | "import parsnip" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "Parsing 'ps1.h5' as PanSTARRS dataset ...\n", 41 | "Preprocessing dataset: 100%|██████████| 2885/2885 [00:02<00:00, 1094.12it/s]\n", 42 | "CPU times: user 4.36 s, sys: 376 ms, total: 4.74 s\n", 43 | "Wall time: 3.82 s\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "%%time\n", 49 | "# Load the dataset\n", 50 | "dataset = parsnip.load_dataset('../data/ps1.h5')\n", 51 | "\n", 52 | "# Preprocess it\n", 53 | "base_model = parsnip.load_model('../models/parsnip_ps1.pt')\n", 54 | "dataset = base_model.preprocess(dataset)\n", 55 | "\n", 56 | "# Split into train/test\n", 57 | "train_dataset, test_dataset = parsnip.split_train_test(dataset)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": 
{}, 63 | "source": [ 64 | "# Score the dataset with a series of models trained with different latent sizes." 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": {}, 71 | "outputs": [ 72 | { 73 | "name": "stdout", 74 | "output_type": "stream", 75 | "text": [ 76 | "Dimension 1:\n", 77 | " Model #1: train=63.61, test=77.52\n", 78 | " Model #2: train=62.58, test=80.70\n", 79 | " Model #3: train=63.08, test=79.65\n", 80 | "\n", 81 | "Dimension 2:\n", 82 | " Model #1: train=59.51, test=74.15\n", 83 | " Model #2: train=59.19, test=73.48\n", 84 | " Model #3: train=59.23, test=72.21\n", 85 | "\n", 86 | "Dimension 3:\n", 87 | " Model #1: train=58.20, test=71.60\n", 88 | " Model #2: train=58.46, test=71.78\n", 89 | " Model #3: train=58.27, test=72.92\n", 90 | "\n", 91 | "Dimension 4:\n", 92 | " Model #1: train=59.24, test=71.67\n", 93 | " Model #2: train=58.15, test=71.87\n", 94 | " Model #3: train=58.47, test=72.86\n", 95 | "\n", 96 | "Dimension 5:\n", 97 | " Model #1: train=58.76, test=72.37\n", 98 | " Model #2: train=58.28, test=72.79\n", 99 | " Model #3: train=58.50, test=71.44\n", 100 | "\n", 101 | "Dimension 6:\n", 102 | " Model #1: train=58.29, test=72.49\n", 103 | " Model #2: train=58.97, test=72.00\n", 104 | " Model #3: train=58.72, test=72.65\n", 105 | "\n", 106 | "Dimension 7:\n", 107 | " Model #1: train=58.05, test=71.86\n", 108 | " Model #2: train=59.18, test=72.54\n", 109 | " Model #3: train=58.78, test=72.49\n", 110 | "\n", 111 | "Dimension 8:\n", 112 | " Model #1: train=58.68, test=71.27\n", 113 | " Model #2: train=58.54, test=72.11\n", 114 | " Model #3: train=58.18, test=71.78\n", 115 | "\n", 116 | "Dimension 9:\n", 117 | " Model #1: train=58.34, test=73.77\n", 118 | " Model #2: train=58.90, test=73.51\n", 119 | " Model #3: train=58.93, test=71.81\n", 120 | "\n", 121 | "Dimension 10:\n", 122 | " Model #1: train=58.16, test=72.39\n", 123 | " Model #2: train=58.26, test=72.30\n", 124 | " Model #3: train=59.03, 
test=72.38\n", 125 | "\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "train_scores = []\n", 131 | "test_scores = []\n", 132 | "latent_sizes = []\n", 133 | "\n", 134 | "base_rounds = 10\n", 135 | "\n", 136 | "for latent_size in range(1, 11):\n", 137 | " size_train_scores = []\n", 138 | " size_test_scores = []\n", 139 | " print(f'Dimension {latent_size}:')\n", 140 | " for model_idx in range(1, 4):\n", 141 | " model = parsnip.load_model(f'../models/latent_{latent_size}_{model_idx}.pt', device='cuda')\n", 142 | "\n", 143 | " train_score = model.score(train_dataset, rounds=base_rounds, return_components=True)\n", 144 | " test_score = model.score(test_dataset, rounds=10 * base_rounds, return_components=True)\n", 145 | "\n", 146 | " print(f' Model #{model_idx}: train={np.sum(train_score):.2f}, test={np.sum(test_score):.2f}')\n", 147 | "\n", 148 | " size_train_scores.append(train_score)\n", 149 | " size_test_scores.append(test_score)\n", 150 | " train_scores.append(size_train_scores)\n", 151 | " test_scores.append(size_test_scores)\n", 152 | " print(\"\")\n", 153 | "\n", 154 | "train_scores = np.array(train_scores)\n", 155 | "test_scores = np.array(test_scores)\n", 156 | "latent_sizes = np.array(latent_sizes)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 4, 162 | "metadata": {}, 163 | "outputs": [ 164 | { 165 | "data": { 166 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAEoCAYAAABBxKqlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAA8gElEQVR4nO3deXwU9f348dd7s7lIQgiQcIgWsByKYEDEEjyCqFVR8KziSf3W+wBsq7ZfRfCqfmsVj6pV6/GzKFo8APHGoK2xWkDkEBDl0CiEcIUEcm32/ftjJjEJOTbH7uR4P2Efu5mZnc97r/d85jOf+YyoKsYYY9oen9cBGGOMaRpL4MYY00ZZAjfGmDbKErgxxrRRlsCNMaaNsgRujDFtlN/rAELRvXt37du3r9dhGGNMxC1dunS7qqbWNq9NJPC+ffuyZMkSr8MwxpiIE5HNdc2zJhRjjGmjLIEbY0wbZQncGGPaqDbRBm6MaRllZWXk5ORQXFzsdSimhri4OPr06UN0dHTIz7EEbkwHkpOTQ1JSEn379kVEvA7HuFSVHTt2kJOTQ79+/UJ+njWhGNOBFBcX061bN0verYyI0K1bt0bvGVkCN6aDseTdOjXlc2nXCTwYDJKbm4uNeW5M65CZmcm7775bbdqsWbO45ppr6n1OxXkgp556Krt3795vmRkzZnD//ffXW/Ybb7zBV199Vfn39OnT+eCDDxoRfcu45557Wmxd7TaBB4NBxo4dS58+fcjMzCQYDHodkjEd3qRJk5gzZ061aXPmzGHSpEkhPf+tt96iS5cuTSq7ZgK/4447OOGEE5q0ruawBB6CvLw8srOzCQQCZGdnk5eX53VIxnR455xzDm+++SYlJSUAbNq0iR9//JGjjz6aq6++mpEjRzJkyBBuv/32Wp/ft29ftm/fDsDdd9/NoEGDOOGEE1i3bl3lMk899RRHHnkkhx9+OGeffTb79u0jOzub+fPn8/vf/5709HS+/fZbJk+ezNy5cwFYtGgRw4cPZ+jQoVx22WWV8fXt25fbb7+dESNGMHToUNauXbtfTKtXr2bUqFGkp6czbNgw1q9fD8A//vGPyulXXnkl5eXl3HLLLRQVFZGens6FF17Y7Pez3fZCSUtLIyMjg+zsbDIyMkhLS/M6JGNala333EPJmv0TUnPEHjKYnn/8Y53zu3XrxqhRo3jnnXeYOHEic+bM4bzzzkNEuPvuu+natSvl5eWMGzeOFStWMGzYsFrXs3TpUubMmcMXX3xBIBBgxIgRHHHEEQCcddZZXH755QDceuut/P3vf+f6669nwoQJnHbaaZxzzjnV1lVcXMzkyZNZtGgRAwcO5JJLLuHxxx9n6tSpAHTv3p1ly5bx2GOPcf/99/P0009Xe/4TTzzBlClTuPDCCyktLaW8vJw1a9bw8ssv88knnxAdHc0111zD7Nmzuffee3n00UdZvnx5E9/h6tptDVxEyMrKIicnh8WLF9uBG2NaiarNKFWbT1555RVGjBjB8OHDWb16dbXmjpr+9a9/ceaZZ9KpUyc6d+7MhAkTKuetWrWKY445hqFDhzJ79mxWr15dbzzr1q2jX79+DBw4EIBLL72Ujz/+uHL+WWedBcARRxzBpk2b9nv+6NGjueeee7jvvvvYvHkz8fHxLFq0iKVLl3LkkUeSnp7OokWL2LBhQ2hvUCO02xo4gM/no0ePHl6HYUyrVF9NOZzOOOMMbrzxRpYtW0ZRUREjRoxg48aN3H///fz3v/8lJSWFyZMnN9ilrq5K2eTJk3njjTc4/PDDee6551i8eHG962mok0NsbCwAUVFRBAKB/eZfcMEFHHXUUSxcuJBf/vKXPP3006gql156KX/605/qXXdzhbUGLiLTRGS1iKwSkZdEJE5EuorI+yKy3r1PCWcMxpjWJTExkczMTC677LLK2veePXtISEggOTmZ3Nxc3n777XrXceyxx/L6669TVFREQUEBCxYsqJxXUFBAr169KCs
rY/bs2ZXTk5KSKCgo2G9dgwcPZtOmTXzzzTcAvPDCCxx33HEhv54NGzbQv39/brjhBiZMmMCKFSsYN24cc+fOZdu2bQDs3LmTzZudQQWjo6MpKysLef31CVsCF5EDgBuAkap6GBAFnA/cAixS1QHAIvdvY0wHMmnSJL788kvOP/98AA4//HCGDx/OkCFDuOyyyxgzZky9zx8xYgTnnXce6enpnH322RxzzDGV8+68806OOuooTjzxRAYPHlw5/fzzz+fPf/4zw4cP59tvv62cHhcXx7PPPsu5557L0KFD8fl8XHXVVSG/lpdffpnDDjuM9PR01q5dyyWXXMKhhx7KXXfdxUknncSwYcM48cQT2bJlCwBXXHEFw4YNa5GDmBKuPtJuAv8PcDiwB3gDeBh4BMhU1S0i0gtYrKqD6lvXyJEj1cYDN6b51qxZwyGHHOJ1GKYOtX0+IrJUVUfWtnzYauCq+gNwP/AdsAXIV9X3gB6qusVdZgtQa/cQEblCRJaIyBLrAmiMMfsLZxNKCjAR6Af0BhJE5KJQn6+qT6rqSFUdmZpa69WEjDGmQwvnQcwTgI2qmqeqZcBrQAaQ6zad4N5vC1cAdiq9MaY9C2cC/w74hYh0Eqe/zzhgDTAfuNRd5lJgXjgKt1PpjTHtXdj6gavqZyIyF1gGBIAvgCeBROAVEfkfnCR/bjjKz3nxRabn/EDP/gezNecHcl58kYMuCrkFxxhjWr2wnsijqrcDNQc1KMGpjYdN/oIF7Lv/L/R2r2zROzqafff/hfzkZJJPPz2cRRtjTMS0y1Pptz04C61xFpcWF7PtwVneBGSMAWDHjh2kp6eTnp5Oz549OeCAAyr/Li0trfe5S5Ys4YYbbmiwjIyMjJYKt1FacpTBUIWtH3hLamw/8DWHHAq1vS4RDllT9/gKxrR3rakf+IwZM0hMTOR3v/td5bRAIIDf3zZH+EhMTKSwsLBZ62g1/cC95O/Vq1HTjTHemTx5MjfeeCNjx47l5ptv5vPPPycjI4Phw4eTkZFROVTs4sWLOe200wAn+V922WVkZmbSv39/Hn744cr1JSYmVi6fmZnJOeecw+DBg7nwwgsre6S99dZbDB48mKOPPpobbrihcr1VeTlMbKja5qauAWnTprLltunVmlEkLo60aVO9C8qYVua+z+9j7c6WHU52cNfB3Dzq5kY/7+uvv+aDDz4gKiqKPXv28PHHH+P3+/nggw/44x//yKuvvrrfc9auXUtWVhYFBQUMGjSIq6++er8run/xxResXr2a3r17M2bMGD755BNGjhzJlVdeyccff0y/fv3qvJiEl8PEhqpdJvCKA5XbHpxF4McfQYS0W262A5jGtFLnnnsuUVFRAOTn53PppZeyfv16RKTOgZ/Gjx9PbGwssbGxpKWlkZubS58+faotM2rUqMpp6enpbNq0icTERPr371959fdJkybx5JNP7rf+0aNHc/fdd5OTk8NZZ53FgAEDqg0TC1BUVOTptQbaZQIHJ4knn346Jd98w4YJEynbuMnrkIxpVZpSUw6XhISEyse33XYbY8eO5fXXX2fTpk1kZmbW+pyKYV6h7qFea1sm1ON+Xg4TG6p22QZeVezPf07yGWew68UXKfvhB6/DMcY0ID8/nwMOOACA5557rsXXP3jwYDZs2FB5cYaXX3651uW8HCY2VO0+gQOkXn8diJD36F+9DsUY04CbbrqJP/zhD4wZM4by8vIWX398fDyPPfYYJ598MkcffTQ9evQgOTl5v+W8HCY2VO2yG2Ftcu/7P3Y+/zz9571B7IABLRSZMW1La+pG6KXCwkISExNRVa699loGDBjAtGnTvA7LuhHWpdsVl+Pr1MlO5jHG8NRTT5Gens6QIUPIz8/nyiuv9DqkJmm3BzFr8qek0O03/0PerIfYt2wZnUaM8DokY4xHpk2b1ipq3M3VYWrgAF0vuYSo1O5s+8sDNsSsMabN61AJ3NepE6nXXEPR0qUUfvS
R1+EY4wmrvLROTflcOlQCB+hyzjlEH3QQeQ88iIbhCLcxrVlcXBw7duywJN7KqCo7duwgLi6uUc/rMG3gFSQ6mtQpN/Djb3/HnjffJHniRK9DMiZi+vTpQ05ODnad2dYnLi5uvzNJG9JhuhFWpcEgG885h2D+Hvq//Ra+mJgWW7cxxrQk60ZYg/h8pN34W8p++IHdc2o/C8sYY1q7dp3A67uoccKYDDr94hdsf+IJygv3ehCdMcY0T7tN4A1d1FhESLtxGuU7d7Lz2Wc9itIYY5qu3SbwvLw8srOzCQQCZGdn13rQJn7YMJJOOomdzz5LYMcOD6I0xpima7cJPC0tjYyMDPx+PxkZGXWO2Zs6dSrBkhK2P/G3CEdojDHN024TuIiQlZVFTk4OixcvRkRqXS62fz+6nHUWu+bMoTQnJ8JRGmNM07XbBA7g8/no0aNHncm7QvfrrkV8PvKqXFfPGGNau3adwEMV3aMHXS++iD0L3qTYvYCqMca0dpbAXd0uvxxfUhJ5DzzodSjGGBMSS+CuqORkul3+Gwo/+oh9LXjWpzHGhIsl8Cq6XnQR/rQ0tt3/FxvsxxjT6lkCr8IXH0/3a6+laPlyCrOyvA7HGGPqZQm8hi5nn0VM377kPWjDzRpjWjdL4DWI30/q1KmUrP+G/HnzvQ7HGGPqZAm8Fkm/PIm4ww4j75FHCJaUeB2OMcbUyhJ4LUSEtN/eSGDLFna9+JLX4RhjTK0sgdchYfRoEjIy2PG3v1FeUOB1OMYYsx9L4PVI/e2NlO/ezY5nnvE6FGOM2Y8l8HrEDxlC51NPYedzzxOwawgaY1oZS+ANSJ0yBS0rY/vjj3sdijHGVGMJvAExP/sZXc45m12v/JPS777zOhxjjKlkCTwE3a+5BomOJu8hG27WGNN6WAIPQXRaGl0vuYQ9CxdS/NVXXodjjDGAJfCQdfvN/xCVnMw2G27WGNNKWAIPUVRSEt2uuIK9//43e//zmdfhGGNM+BK4iAwSkeVVbntEZKqIzBCRH6pMPzVcMbS0lIsuxN+zJ9seeMCGmzXGeC5sCVxV16lquqqmA0cA+4DX3dkPVsxT1bfCFUNL88XGknr9dRSvWEHB++97HY4xpoOLVBPKOOBbVd0cofLCJnniRGIOPpi8B2ehgYDX4RhjOrBIJfDzgaqjQl0nIitE5BkRSYlQDC1C/H7Spk2ldONG8t94w+twjDEdWNgTuIjEABOAf7qTHgcOBtKBLcBf6njeFSKyRESW5LWy09gTx40j/vDDyXvkUYLFxV6HY4zpoCJRAz8FWKaquQCqmquq5aoaBJ4CRtX2JFV9UlVHqurI1NTUCIQZOhEh9bc3EsjNZdfs2V6HY4zpoCKRwCdRpflERHpVmXcmsCoCMbS4hFGjSDj2GLY/+RTle/Z4HY4xpgMKawIXkU7AicBrVSb/n4isFJEVwFhgWjhjCKe0G28kmJ/Pjqf/7nUoxpgOyB/OlavqPqBbjWkXh7PMSIobPJjOp53Gzv/3/0i58EKie6R5HZIxpgOxMzGbKXXKDWh5Odsfe8zrUIwxHYwl8GaKOfBAUn71K3bPnUvJxo3V5gWDQXJzc+2sTWNMWFgCbwHdr74KiY0l7+GfhpsNBoOMHTuWPn36kJmZSTAY9DBCY0x7ZAm8Bfi7d6fb5EspePsdilatBiAvL4/s7GwCgQDZ2dm0tr7sxpi2zxJ4C+l62WVEpaSQ94BzXlJaWhoZGRn4/X4yMjJIS7MDnMaYlmUJvIVEJSbS/aor2Zv9KXuzsxERsrKyyMnJYfHixYiI1yEa0yG152NRlsBbUJdJk4ju3Zttf3kADQbx+Xz06NHDkrcxHmnvx6IsgbcgX0wM3W+4nuLVqyl47z2vwzGmw2vvx6Isgbew5NNPJ3bAAPIenEV5SUm73XUzpi1IS0tj9OjRREVFMXr06HZ3LMoSeAuTqChSp02jdPNmlqUPZ/uxx/HJkMPYNX++16EZ0+GoKiJSeWt
vlamQEriIRIlIbxE5qOIW7sDasvLCAhAhURWfCN2CQbbeNp38BQu8Ds2YDqW9N6E0OBaKiFwP3A7kAhVHABQYFsa42rS8WQ9BzS19SQnbHpxF8umnexOUMR1Q7GefkzVgICnl5eyKiiL2s89hQvv5DYYymNUUYJCq7gh3MO1FYMuW2qf/+CMlGzYQ279/hCMypuPJX7CArdOn0y0YhIo94enTEaHdVKRCaUL5HsgPdyDtib9XrzrnbTh1PBt/dR47Z88msGtXBKMypmPZ9uAstMYVs7S4mG0PzvImoDAIJYFvABaLyB9E5MaKW7gDa8vSpk1F4uKqTZO4OHrcditpN9+MlpaSe+ddrD/2OL6/7jr2vP8+Wlra4nG0hhMYWkMMpmMK/PhjndN3zp5N0YoVBMPwu4ukUJpQvnNvMe7NNKBi92zbg7MIbNmCv1cv0qZNrZze7deTKV67lvx588lfsIDCDxYRlZxM5/HjST5jInFDhzb75J+KExiys7PJyMggKysLny+ynY5aSwx5eXmkpaXZCVUdQHDfPvIXLmT3nJfrXsjnI/fOu5zH0dHEDRpE/LChxB02lPhhQ4np1w+JiopMwM0kodaMRCQJUFUtDG9I+xs5cqQuWbIk0sVGhAYC7M3OJv+NeRQsWoSWlBDTrx/JEyeSPOF0onv3btJ6c3Nz6dOnD4FAAL/fT05ODj169Gjh6Ft3DLYB6TiKv/6a3XNeJn/+fIKFhcQOHEjckCHkv/UWlJRULidxcfS8YyYJRx5J0cqVFK9cSdHKVRSvXElw714AfAkJxA0ZUi2p+3v18uzzE5Glqjqy1nkNJXAROQx4AejqTtoOXKKqq1s0ynq05wReVXlBAQXvvkv+G/PY577eTkcdRfLEiSSddBJRiQkhr0tVyczMrExeXozH4nUMtgH5KY72uBEJlpZS8O677JrzMkVLlyIxMXQ+5WS6nHc+8cPTERHyFyyoc0+4Kg0GKd240UnqK1ZStHIlJWvXomVlAER17078YYcRN2wo8UOHEnfYYfhTUiLyOpubwLOB/1XVLPfvTOAeVc1o4Tjr1FESeFWlOTnkz59P/rx5lG3+DomLI+nEE0meOJGE0b8IaRevNfxwvYzB6w3Id//4B+tvn0FPv5+tgQADZs7goIsuilj54Lz/Nd8DLzYiLal082Z2vfIK+a+9TvmuXUT/7CBSzjuf5DPPaNGkGiwtpWTdOopWrKB45SqKVq6kdMOGyi7C0QcdVD2pH3oovvj4/dYT6kakLs1N4F+q6uENTQunjpjAK6gqRcuXkz9vHnveepvgnj3409JInnA6yRMnEjtggNchtmqBQIC1a9dy6KGHRjRx5S9YwJbbplfrBSFxcfS6846IdmHb9MILfDNjJr38frYEAvx8xu30vbjtXZZWAwEKsrLY/dIc9mZnQ1QUSePGkXL+eXT6xS+QCH225YWFFK9aTdHKn5J6ZbfhqChiBwwgfuhhxA11knrxunVsnTGzWd+D5ibw14FlOM0oABcBI1X1jJBKbwEdOYFXFSwpoTBrMfnz5lH4r39BIEDcoYeSfMZEOo8fj79bt4ZXEkHNrXk0167581l9yx8qT+IYcu+fSJkwocnr02CQ4L4igoUFBAsLKS8oIFi4l2BhAeWFhQQLCp3phQXs/udctKhov3VIXByJYzPxxcQgMTFIdAwSHe08rnaLrpzuqzmv6vLR1ZevWBa/nz1vvsmPt95WrQ2Y2Fh633Vnm+kHXbZ1K7tf+Se7584lsG0b/p496fKrc+ly9jmt5iLigbw8ilauqpbUg/luz2uR/U/qA/y9ezPgw0Uhrb+5CTwFmAkcDQjwMTBDVSPWibmtJvBwNh8Eduxgz8K3yJ83j+LVqyEqisRjjiH5jIkkjh1LwXvveZo8va6B5i9YsH/yiokhbeoUOh05iuDeKgm4oMD920nA+yfkiumFtf4Ya/IlJjrL1iGmXz+0tPSnW1kZwbIycNtbW0QdiQOc9ty
fv/sOvoTQj6lEkgaD7P3kE3bNeZnCrCxQJeGYo0k5/3wSjz0W8YfSec47qkrZd99RtHIVP/7ud7UvJMIha74KaX3NSuCtQVtM4JE8gFWyfr3bXj6fwLZtSFycc/ClvLxymdqSp6o6P/LycjQYhGAQLQ9CsNx5XNe08nI0qBAsd+ZpEC135lcsn3P99ZTv2P/kXV9KCj1vuRkNBNCyAFpW5j4uQwNlTtx1zKPicdV5VedXmVe2ZYsTTyNIXBy+xESiEhPxJSbiS6p4nORMT0rEl1BlelISvgR3esXfnTohPh/rjx9Xaz/k+mpeGgw6r6VKYq+W6EtLCZaWoqVlaJl7X7lsjQ1CaSk7Hn+i3tcbfdBBxA0aSOyAgcQOGkTcoIFEH3RQizdHhFqRCezYwe7XXmP3y69QlpNDVNeudDn7bLqc9yti+vRp0ZgipSnfg5qalMBFZJaqThWRBThjn1Sjqk3fF22ktpjAvegBoeXl7P3Pf8i57vpad98RQaKjncRdkXBbIYl2mgOIjkb8fufvKvdEVzyO/mlexXx3Xv68ukd/7PPYY/gSE4hKchKzLzGRqIQEp+mhhXi9BwJ1Jw9f1650u/giitd9Tcm6dZRu3lz5XZD4eGIHDHAS+8BBxA4aSNzAgUR16dKkGBqqyKgqRUuWsOulOex5/30oK6PTkUeSMul8kk44oUU/Ey+0xPegvgRe375IRZv3/aEGa35ScU3Mii9uJMYhlqgoEseM2e/04UqqdL3kYhAfRPkQXxT4fEiUD3xR4JO6p0X5nJqZL8qZJzWWiXKf5y7zwy23EKylBh6Vmkrff7zgJuLoymRdkYSJimqR5qa9/11SZ80n6fixzV5/Q5JPP52gavU2+DtmRrQZK23a1FqTR88/3FItjmBRESXffEvJ1+soXreOknVfU/De++z+59zKZfw9e7rJfFBlbT2mb19nQ1uP2kYD7NGjB+V79pD/xjx2vfwypd9+iy8piZRJ55Ny3nnEHnxwy78ZHqk8qe+BBwls3Yq/V0/Spk1rse9BKG3gU1T1oYamhVNbrIGDd13oWmK3rbm8roF6XT543w8dmn4gV1UJbMuj5Ot1lKxb59TWv/6akg0bKtvqJTqamIMPrlJbdxJ7VPfuld/33fMXsOqWWyrLH3jttQR+/IE9C99Ci4uJGzaMlPPOo/Opp9TaBa8ltIbutM3R3IOYy1R1RI1pX6jq8BaMsV5tNYF7pTUkL/C+H7TXvWC87ocOLb8R0dJSSjZuqp7Y160jsG1b5TJRXbsSO2gg4o9m76efQiBQbR0SHU3yGWfQ5fzziB8ypMmxhKK1nEzVHE1tA58EXIDT++RfVWYlAeWqekJLB1oXS+CN53XyBCgvL6dbt27k5+eTnJzMjh07iGojY0y0FK9rf5HaiAR27aJk3dc/NcN8vZ7iVatq7QkT1bMnAxdntXgMtWkNe0HN1dQE/jOgH/An4JYqswqAFaoaqPWJYWAJvPFaQ/LMzc3lgAMOoLy8nKioKH744YeI/3i8TqCtgVfvwZrBh9Q+oxFd6JqrNewFNVd9CbzOfQlV3ayqi4ELgc9U9SNV/QhYA7TNPj0dyPbt2yl0+yIXFhayffv2iMeQlpbGmDFj8Pv9jBkzJuIXlK3Yfe7Tpw+ZmZkEW2mvm3Dz+Xz06NEj4onLX8dAbPWNl9/SRISsrCxycnLaZPJuSCiNQa/w06XUAMqBf4YnHNNSvE6e4P2Pp71fD7G1q2tc/LRpUyMah1cbsEgIJYH7VbVy1HP3cdvunNkBeJ08K3j546noyun3+yPWldP8JPn00+lxx0x2+HwEVdnh89Ejwl0p27tQzknNE5EJqjofQEQm4gwpa1q5iuTZUVVsxDp6G7iXSo86irHrv/7pIOJRR3kdUrsSSg38KuCPIvKdiHwP3AxcGd6wjGkZ7Xn3uS2wvaDwarAGrqrfAr8QkUScXisF4Q/LGNMe2F5QeDWYwEUkFjgb6Av4Kz4AVb0jrJE
ZY9qFjt6UF06htIHPA/KBpUBJA8saY4yJkFASeB9VPTnskRhjjGmUUA5iZovI0LBHYowxplFCqYEfDUwWkY04TSgCqKoOC2tkxhhj6hVKAj8l7FEYY4xptFCaULSOW71EZJCILK9y2yMiU0Wkq4i8LyLr3fuU5r0EY4zpmEKpgS/ESdgCxOGMULgOqHcgX1VdB6QDiEgU8APwOs7IhotU9V4RucX9++Ymxm+MMR1WKCfyVDuAKSIjaPyZmOOAb1V1s3sqfqY7/XlgMZbAjTGm0Rp9aQpVXQYc2cinnQ+85D7uoapb3HVtAezcWmOMaYJQzsS8scqfPmAEEPK4nCISA0wA/tCYwETkCuAKgIMOOqgxTzXGmA4hlBp4UpVbLE6b+MRGlHEKsExVc92/c0WkF4B7v622J6nqk6o6UlVHpqamNqI4Y4zpGOqsgYvIC6p6MbC7mVegn8RPzScA84FLgXvd+3nNWLcxxnRY9dXAj3Cvi3mZiKS43f8qb6GsXEQ6AScCr1WZfC9wooisd+fd29TgjTGmI6uvDfwJ4B2gP85AVlXHgVR3er1UdR/Qrca0HTi9UowxxjRDfRc1flhVDwGeUdX+qtqvyq3B5G2MMSa8GjyIqapXRyIQY4wxjdPofuDGGGNaB0vgxhjTRjWYwEUkQUR87uOBIjJBRKLDH5oxxpj6hFID/xiIE5EDgEXAr4HnwhmUMcaYhoWSwMXtDngW8IiqngkcGt6wjDHGNCSkBC4io4ELcU6jh9CGoTXGGBNGoSTwqTgDUb2uqqtFpD+QFdaojDHGNCiU8cA/Aj4CcA9mblfVG8IdmDHGmPqF0gvlRRHpLCIJwFfAOhH5ffhDM8YYU59QmlAOVdU9wBnAW8BBwMXhDMoYY0zDQkng0W6/7zOAeapaRggXNTbGGBNeoSTwvwGbgATgY3eI2T3hDMoYY0zDQjmI+TDwcJVJm0VkbPhCMsYYE4pQDmImi8gDIrLEvf0FpzZujDHGQ6E0oTwDFAC/cm97gGfDGZQxxpiGhXJG5cGqenaVv2eKyPIwxWOMMSZEodTAi0Tk6Io/RGQMUBS+kIwxxoQilBr41cDzIpKMc13MncDkcAZljDGmYaH0QlkOHC4ind2/rQuhMca0AnUmcBG5sY7pAKjqA2GKyRhjTAjqq4EnRSwKY4wxjVZnAlfVmZEMxBhjTOPYRY2NMaaNsgRujDFtlCVwY4xpo+pM4CIyq8rjKTXmPRe+kIwxxoSivhr4sVUeX1pj3rAwxGKMMaYR6kvgUsdjY4wxrUB9/cB9IpKCk+QrHlck8qiwR2aMMaZe9SXwZGApPyXtZeEPxxhjTKjqO5GnbwTjMMYY00iN6kYoIgeLyP+KyKpwBWSMMSY0oVxSrZeITBWRz4HVOLX2SWGPzBhjTL3q6wd+uYh8CHwEdAd+A2xR1ZmqujJSARpjjKldfQcx/wp8ClygqksAREQjEpUxxpgG1ZfAewPnAg+ISA/gFSA6IlEZY4xpUJ1NKKq6XVUfV9VjgROAfGCbiKwRkXsiFqExxpha1dcG/qiIZACo6veqer+qHgGcAZREKD5jjDF1qK8XynrgLyKySUTuE5F0AFVdZxd7MMYY79XXhPKQqo4GjsO5Ev2zbvPJdBEZELEIjTHG1KrBfuCqullV71PV4cAFwJnA2lBWLiJdRGSuiKx1k/9oEZkhIj+IyHL3dmozX4MxxnRIoZzIEy0ip4vIbOBt4Gvg7BDX/xDwjqoOBg4H1rjTH1TVdPf2VlMCN8aYjq7OboQiciLOGZfjgc+BOcAVqro3lBWLSGecMcUnA6hqKVAqYiPTGmNMS6ivBv5HnBN5DlHV01V1dqjJ29UfyMNpO/9CRJ4WkQR33nUiskJEnnGHqd2PiFwhIktEZEleXl4jijXGmI6hvoOYY1X1KVXd2cR1+4ERwONu+/le4BbgceBgIB3YAvyljvKfVNWRqjoyNTW1iSEYY0z7Fc6LGucAOar6mfv3XGC
EquaqarmqBoGngFFhjMEYY9qtsCVwVd0KfC8ig9xJ44CvRKRXlcXOBGxoWmOMaYL6xkJpCdcDs0UkBtgA/Bp42D0pSIFNwJVhjsEYY9qlsCZwVV0OjKwx+eJwlmmMMR1FONvAjTHGhJElcGOMaaMsgRtjTBtlCdwYY9ooS+DGGNNGWQI3xpg2ql0n8GAwSG5uLqp2LWZjTPvTbhN4MBhk7Nix9OnTh8zMTILBoNchGWNMi2q3CTwvL4/s7GwCgQDZ2dnYiIbGmPam3SbwtLQ0MjIy8Pv9ZGRkkJaW5nVIxhjTosI9FopnRISsrCzy8vJIS0vDLiRhjGlv2m0NfOGGhZz82smc+M6J/PLVX7Jww0KvQzLGmBbVLmvgCzcs5PZPbqckWALAlr1bmJE9A4Dx/cd7GJkxxrScdlkDf2jZQ5XJu0JxeTEPLXvIo4iMMabltcsEvnXv1kZNN8aYtqhdJvCeCT0bNd0YY9qidpnAp4yYQlxU3H7TJw+ZHPlgjDEmTNplAh/ffzwzMmbQK6EXgtA9vjuxUbG88c0b7Cvb53V4xhjTItplAgc4pe8plDxSwprfrCH41yAPHPcA63at4+Z/3Ux5sNzr8IwxptnabQKveSr9oJhB3HTkTSz+fjEPLn3Q6/CMMabZ2m0Cr+1U+gsPuZBJgyfx/FfP88+v/+l1iMYY0yzt8kQeqPtU+puOvInvC77n7v/cTZ/EPozuPdrjSI0xpmnabQ0cwOfz0aNHj2rjoPh9fv587J/pl9yP3y7+LRt2b/AwQmOMabp2ncDrkhiTyF/H/ZWYqBiuWXQNO4t3eh2SMcY0WodM4AC9E3vz8PEPs71oO1OzplJSXtLwk4wxphXpsAkcYFjqMO4++m6+2PYF0z+ZbpdeM8a0Ke32IGaoftn3l3y35zse/uJh+ib35erDr/Y6JGOMCUmHT+AAvxn6Gzbt2cRjyx/jZ0k/49T+p3odkjHGNKhDN6FUEBFuH307R/Q4gts+uY3l25Z7HZIxxjTIErgrJiqGWZmz6JnQkylZU8gpyPE6JGOMqZcl8Cq6xHXh0XGPUhYs47pF11FQWuB1SMYYUydL4DX0S+7HrMxZbN6zmd8u/i1lwTKvQzLGmFpZAq/FqF6jmD56Op9u+ZR7P7vXuhcaY1ol64VShzMHnMmmPZt4ZtUz9E3uy8WHXux1SMYYU40l8HpMGTGF7/Z8x5//+2cOTDqQzAMzvQ7JGGMqWRNKPXzi455j7uHQbody08c3sXbnWq9DMsaYSpbAGxDvj+eR4x+hc0xnrlt0Hdv2bfM6JGOMASyBhyS1Uyp/HfdXCkoLuP7D6+26msaYVsESeIgGdR3E/x37f6zduZY//vuPBDXodUjGmA7OEngjHHfgcfx+5O9Z9N0iZi2b5XU4xpgOznqhNNKFh1zIpj2beHbVs/Tt3JezBpzldUjGmA4qrDVwEekiInNFZK2IrBGR0SLSVUTeF5H17n1KOGNoaSLCLaNuIaN3Bnd+eiefbfnM65CMMR1UuJtQHgLeUdXBwOHAGuAWYJGqDgAWuX+3KX6fn/uPu5+fdf4Z0xZPY2P+Rq9DMsZ0QGFL4CLSGTgW+DuAqpaq6m5gIvC8u9jzwBnhiiGckmKSeHTco0T7orl20bXsKt7ldUjGmA4mnDXw/kAe8KyIfCEiT4tIAtBDVbcAuPdpYYwhrPok9eGhsQ+RuzeXqVlTKS0v9TokY0wHEs4E7gdGAI+r6nBgL41oLhGRK0RkiYgsycvLC1eMzZaels5dR9/Fsm3LmPnpTBv4yhgTMeFM4DlAjqpWHOWbi5PQc0WkF4B7X+upjar6pKqOVNWRqampYQyz+U7pdwrXpl/L/G/n89TKp7wOxxjTQYQtgavqVuB7ERnkThoHfAXMBy51p10KzAtXDJF05bArOa3/aTzyxSO8s+kdr8MxxnQA4e4Hfj0wW0RigA3Ar3E2Gq+IyP8A3wHnhjmGiBARZmbM5MfCH7n137fSO6E
3h3U7jLy8PNLS0hARr0M0xrQz0hbabEeOHKlLlizxOoyQ7CrexQULL2BXyS727d5HsFOQqL1R3PXLuzj956d7HZ4xpo0RkaWqOrK2eXYqfQtLiUvhV4N+xd6yvWiCIiIEE4PM/HQmCzcs9Dq8iAsGg+Tm5trBXWPCwBJ4GLy09qX9ppUES3hw6YMeROOdYDDI2LFj6dOnD5mZmQSDNgCYF2wj2n5ZAg+DrXu31jo9d18u13xwDXO/nsv2ou0Rjiry8vLyyM7OJhAIkJ2dTWvuDtpe2Ua0dQjXRtQSeBj0TOhZ6/QEfwIb8jcw89OZHP/K8Vz81sU8u+pZNu/ZHOEIIyMtLY2MjAz8fj8ZGRmkpbXZc7aazOvar21Evf8MwrkRtQQeBlNGTCEuKq7atLioOG4bfRtvn/U2r054lWvSr6GkvIQHlj7Aaa+fxhlvnMFDyx5iZd7KdjPWuIiQlZVFTk4Oixcv7nA9cVpD7bejb0Rbw2cQzo2oJfAwGN9/PDMyZtAroReC0CuhFzMyZjC+/3hEhIEpA7nq8Kt45fRXeO/s97hl1C10j+/Os6ue5YK3LuDEf57IXf+5i+wfsikrL/P65TSLz+ejR48eniVvL2tfc76cQ+6EXAY9NYjcCbnM+XJOxGN4a+NbxF4Xy+CnBxN7fSxvbXwr4jF46cUvXmTrhK0MemoQWyds5cUvXox4DOHciFo3wlYkvySfj3M+Juv7LP79w78pChSRGJ3IMX2O4fiDjufo3keTGJPodZghW7hhIQ8te4ite7fSM6EnU0ZMYXz/8RErv6L2lZ2dTUZGBllZWfh8kamzLNywkBnZMyguL66cFhcVV7kh7ygxgPM5eHE+xJvfvsntn9xOqf40RlGMxDBzzExOO/i0iMUBzXsP6utGaAm8lSoOFPPZls/48PsPWfz9YnYW7yTaF82oXqM4/sDjGXvgWFI7td4hBt789k1mZM+gJFhSOS3WF8v00dM5/eDTI/JDfvGLF7njwzvwd/MT2BFg+vHTuWD4Bc1eb3mwnOLyYvaV7aMoUMS+wL5qj4sCRdz3+X3sKd2z33MT/AmcNdC5CEjFb09RVLXafYXK6e68yul1PKfqOj/Y/EG15F2hc0xnbv3FrSTFJFXeOsd0Jikmidio2Ga/P1Ut+GYBt75zK+WJ5UQVRnHXyU0/H6KkvIRdxbucW4lzv7tkd7VpVf/eUbyjznWlxafRJa4LXWK7kBybTJfY/R9XzO8S24WkmCR80rSNf3MrMpbA27jyYDlf5n3Jh999yIfff8j3Bd8DMCx1GMcfeDzHH3Q8/ZL77fe8cNR8ghpkV/EuthdtJ68oj7x9eewo3kHevjzyivLYXrSd7UXbySnIqZaIqvKJj9io2Oo3fyxxUXHERMXsf++ve3pt64mNiuU/P/6HR794tNoGJMYXw68P+zXpaen1Jt99ZfsqHxeV/TS9Yl5tSbExOvk7VX4e4v5z/gsi8tM0qL5cjXm1Pafq5/xD4Q+Nji3GF0NiTGJlQq956xzTmaRo53Fty8VFxVXGsHDDQm7/5Pb9NuIzx8zk5L4nk1+az+7i3ZXJeFfJLnYX72Zn8U4nEbt/V8wrChTVGrMglQk3JTaFlLgUusR24dX1r9b5Os/8+ZnsLtldecsvySe/JJ9yLa91eZ/46BzTef8k75Zb10bg/c3vN3svyBJ4O6KqfLP7G7K+z+LD7z5k9Y7VAPRL7leZzA/rfhgLv13Ire/eSnlCeUhngpaWl1Ym5e37tv/0uMr99n3b2VG8o9YveVJ0Et07dad7vHN7e+PbdZZ1+dDLKSkvqX4LlFASdO9rm+c+rusH1hLi/fF08neiU3SnWh/H++PpFN2p2uPKedHVl5n8zmS27dt/nLZeCb1475z3wvYaqjpp7kls2btlv+lp8Wk8ddJT7CndQ0FpwU+3soLqf7u3qsuVBusfMtnv81cm9R8Lf6QsuP8xnIoNUF0b+E7+TqTEpZASm1KZlLvEdaF
rXFe6xFZJ0u68zjGdifJFhfz66/oMghqksKyQ/OL8yg1Ifkl+tSRf9fGuYmd+fRt0QWp9nY35HlgCb8e27t1amcyXbF1CQAMkxSSxt3QvQX464h4t0Zw98Gz6JPXZLynnFeXVursvCF3jupLaKZXu8d1JjU+tTNCpnVJJjU+lW3w3usd3J94fX+25jf3xhKosWEZpeSnFgWLnvrzGfaCYkvISpi2eVuc6Xjjlhf2ScZw/rsm7yLVpDe3P4YihpLykwSRfcXt7U90b8asOv6paMq6oNafEpbRYM06kPoPiQHGdCf7R5Y/W+hxBWHHpipDWX18Ct4sat3E9E3oyafAkJg2eRH5JPv/64V/MzJ5ZLXkDlGkZc9Y5vSBifDGVSblfcj9G9hxJanxq5bSKZJ0Sl4Lf17SvyJQRU2r98UwZMaXpLxaI9kUT7YsmITqh3uV6JfSqcwOSnpberBhCMb7/eILBYLW9oOm/nB7Rg4cVZbXkgeTYqFhi42PpHt+9wWWX5y2v8zO4Nv3aJscQqvH9x5Ofn1/tOMgfjv9Di38Gcf44evp71nr+x6vrX631PajrXJHGshp4OzTs+WF17p5+MukTkqKTInIQsSUPYDVWfe2vkUqiubm59OnTh0AggN/vJycnhx49ekSk7Kq86gXSGvZCVJXMzMzKnkiRPh+hJb6HNphVB1PX1r1XQi86x3SO2Bd4VNIovpr2Fat/vZqvpn3FqKRRESkXfuqL7yv0oar4Cn0R7z7XGk6i8fJElvH9xzNtyDTKtpehqpRtL2PakGkR/Qy8Ppms4j0o3V6KqlK6vbRF3wOrgbdDraHmA97XfsC72mdrKd/rvYDW8B3wmqpy3HHHVb4HH330UaPeAzuI2QF5fRJNBa8TWEfXGhKofQfsRB5L4MY0kSXQts16oRjTgVWMR2PaHzuIaYwxbZQlcGOMaaMsgRtjTBtlCdwYY9ooS+DGGNNGWQI3xpg2yhK4Mca0UZbAjTGmjWoTZ2KKSB6wuYlP7w5sb8FwrPy2F0NHL781xNDRy29ODD9T1Vqvn9gmEnhziMiSuk5DtfI7RgwdvfzWEENHLz9cMVgTijHGtFGWwI0xpo3qCAn8SSvfc17H0NHLB+9j6OjlQxhiaPdt4MYY0151hBq4Mca0S+0ygYvIMyKyTURWeVT+gSKSJSJrRGS1iDTvUuxNiyFORD4XkS/dGGZGOgY3jigR+UJE3vSo/E0islJElotIxK8KIiJdRGSuiKx1vw+jI1j2IPd1V9z2iMjUSJXvxjDN/f6tEpGXRCQukuW7MUxxy18diddfW/4Rka4i8r6IrHfvU1qirHaZwIHngJM9LD8A/FZVDwF+AVwrIodGOIYS4HhVPRxIB04WkV9EOAaAKcAaD8qtaqyqpnvUjewh4B1VHQwcTgTfC1Vd577udOAIYB/weqTKF5EDgBuAkap6GBAFnB+p8t0YDgMuB0bhvP+niciAMBf7HPvnn1uARao6AFjk/t1s7TKBq+rHwE4Py9+iqsvcxwU4P9oDIhyDqmqh+2e0e4voAQ8R6QOMB56OZLmthYh0Bo4F/g6gqqWqutujcMYB36pqU0+Iayo/EC8ifqAT8GOEyz8E+I+q7lPVAPARcGY4C6wj/0wEnncfPw+c0RJltcsE3pqISF9gOPCZB2VHichyYBvwvqpGOoZZwE1AMMLlVqXAeyKyVESuiHDZ/YE84Fm3GelpEUmIcAwVzgdeimSBqvoDcD/wHbAFyFfV9yIZA7AKOFZEuolIJ+BU4MAIxwDQQ1W3gFPBA9JaYqWWwMNIRBKBV4Gpqron0uWrarm7+9wHGOXuTkaEiJwGbFPVpZEqsw5jVHUEcApOU9axESzbD4wAHlfV4cBeWmjXuTFEJAaYAPwzwuWm4NQ8+wG9gQQRuSiSMajqGuA+4H3gHeBLnCbOdsESeJiISDRO8p6tqq95GYu7276YyB4XGANMEJFNwBzgeBH5RwTLB0BVf3Tvt+G0/46
KYPE5QE6VPZ+5OAk90k4BlqlqboTLPQHYqKp5qloGvAZkRDgGVPXvqjpCVY/FadpYH+kYgFwR6QXg3m9riZVaAg8DERGcds81qvqARzGkikgX93E8zo9pbaTKV9U/qGofVe2Ls/v+oapGtPYlIgkiklTxGDgJZ5c6IlR1K/C9iAxyJ40DvopU+VVMIsLNJ67vgF+ISCf3NzEODw5oi0iae38QcBbevBfzgUvdx5cC81pipf6WWElrIyIvAZlAdxHJAW5X1b9HMIQxwMXASrcNGuCPqvpWBGPoBTwvIlE4G+pXVNWTrnwe6gG87uQO/MCLqvpOhGO4HpjtNmNsAH4dycLddt8TgSsjWS6Aqn4mInOBZTjNFl/gzRmRr4pIN6AMuFZVd4WzsNryD3Av8IqI/A/Ohu3cFinLzsQ0xpi2yZpQjDGmjbIEbowxbZQlcGOMaaMsgRtjTBtlCdwYY9ooS+CtnIiUuyPJrXZHFrxRRHzuvJEi8rBHcWWHYZ3Picg57uOnKwYAE5E/tmAZN7ijAs6uMb3B99IdWfCaBpZp0vsiIleJyCWNWL6wgfkNxhpCGZNFpHdz1mHCy7oRtnIiUqiqie7jNOBF4BNVvd3byFqeiDwHvKmqc2tMr3wPWqCMtcApqrqxCc/t68a335AEIhKlquUtEGKosdT7ntQXayPKWAz8TlUjPgyvCY3VwNsQ93TwK4DrxJEp7jjbIjJDRJ4XkffEGQP7LBH5P3HGwn7HPbUfETlCRD5yB3d6t8rpvYtF5D5xxhD/WkSOcacPcactF5EVFUNxVtQA3Tj+LM54yytF5Dx3eqa7zoqxsGe7Z+MhItNF5L/uc56smF6V+9yRInIvzmh2y9113ClVxlcXkbtF5IZann+ju/5V4o4BLSJP4AwwNV9EptVYvuZ7+Ywbw4Yq678XONiN5c/uc7JE5EVgZY33pb7Xf6+IfOW+n/dXKfN37uOfi8gH4uxxLRORg+v6TohIoogscpdbKSITa4vVXfb37vu+Qtzx4UWkrzh7JE+Js5f3nojEi7MnNBLnJKTl4pzNW7XcG6q8hjlVXsMLIvKhOONeX95AjIjIJe46vhSRF9xpqSLyqhvrf0VkTF2vv8NTVbu14htQWMu0XThnGWbi1LIAZgD/xhk29nCcsZ9Pcee9jjN8ZTSQDaS6088DnnEfLwb+4j4+FfjAffwIcKH7OAaIrxoXcDbOQEFRbkzf4ZwFmgnk4wyk5QM+BY52n9O1ymt5ATjdffwccE6VeEbWfA+AvjjjeuCu91ugW4335wichJoAJAKrgeHuvE1A91re05rvZTYQC3QHdrjvXV9gVY3n7AX61fy86nr9QFdgHT/t/XapUubv3MefAWe6j+OATnV9L3DOMO3sPu4OfANILbGehHMWpLjxvIkz1G1fnLMk093lXgEuqvkZ1FL+j0BsLa/hSyDejeV7nEGs6opxiPtedK/6vcDZy6z4rhyEMySF57/F1nhrl6fSdwD71Vhdb6tqmYisxEmoFaeNr8T5oQ4CDgPedyuDUTjDfFaoGHRrqbs8OInnf8UZ2/s1Va05ENDRwEvqNB/kishHwJHAHuBzVc0BEGdIgb44G5mxInITzvjQXXES7IJQXriqbhKRHSIyHGeD8YWq7qglptdVda9b9mvAMTincodqoaqWACUiss0tqzafa93NMbW9/v8AxcDTIrIQJ5FWEmfslgNU9XX39RY3EKcA94gzymIQZ9z52mI9yb1VvAeJwACcDe5GVV3uTq/62ddnBU7t/A3gjSrT56lqEVAkIlk4g4ctrCPG44G5qrrdfa0VY2ifABxaZcess4gkqTO2vqnCEngbIyL9gXKc0cwOqTG7BEBVgyJSpm4VBudH48f5sa9W1bou61Xi3pe7y6OqL4rIZzgXZnhXRH6jqh9WDamecEuqPC4H/OJcUusxnJrd9yIyA6eW2RhPA5OBnsAztcyvL6ZQ7Rd7Hcvtbcw6VDUgIqNwBnY6H7g
OJ5FVaGzsFwKpwBHuxnsTtb+fAvxJVf9WbaLTVl4zzmrNJXUYj1ODnwDcJiJD3Ok1D6ppPTFKLcuDs4cw2t0QmHpYG3gbIiKpwBPAo1WSc2OsA1LFvS6jiERX+eHVVWZ/YIOqPowzotqwGot8DJwnzsUjUnF+1J/Xs8qK5LJdnPHSzwkh7jJx2/Bdr+MMjXsk8G4ty38MnCHOKHgJOFdg+VcI5TSkAEhqzgrc15yszsBmU3Eud1dJnXHjc0TkDHf5WHEGpKpLMs6462UiMhb4WR2xvgtc5paPiBwg7ih99aj19YrTC+pAVc3CuWBHF5waPcBEca7H2g2nGem/9cS4CPiVuywi0tWd/h7Ohq2ivPQG4uywrAbe+sW7u9/ROG2VLwBNGqJWVUvdg1MPi0gyzuc/C6cJoy7nAReJSBmwFbijxvzXgdE4bZ8K3KSqW0VkcB0x7BaRp3CadTbh/MAb8iSwQkSWqeqF7uvIAnZrLT0/VHWZOD1aKjYkT6tqY5pPaqWqO0TkE3EuVvs2TtNAYyUB89w9EQGm1bLMxcDfROQOnBH0zsUZybA2s4EF4lyweTnukME1Y1XV34vIIcCnbtNEIXARTo27Ls8BT4hIEdVrxFHAP9zvkAAPup8rOO/5Qpy26ztV9UdxumzWFuNqEbkb+EhEynGadybjXEfzryKyAuc7+jFwVT1xdljWjdC0OW4NcBlwbi1t8sYjbnNYoare73UsHYU1oZg2RZyTe77BucK3JW/ToVkN3Bhj2iirgRtjTBtlCdwYY9ooS+DGGNNGWQI3xpg2yhK4Mca0UZbAjTGmjfr/JeIrlQoY7s4AAAAASUVORK5CYII=\n", 167 | "text/plain": [ 168 | "
" 169 | ] 170 | }, 171 | "metadata": { 172 | "needs_background": "light" 173 | }, 174 | "output_type": "display_data" 175 | } 176 | ], 177 | "source": [ 178 | "plt.figure(figsize=(5, 4), constrained_layout=True)\n", 179 | "\n", 180 | "train_loss = np.sum(train_scores, axis=2)\n", 181 | "test_loss = np.sum(test_scores, axis=2)\n", 182 | "\n", 183 | "plt.scatter(np.tile(np.arange(1, 11), (3, 1)).T, train_loss, s=5, c='k')\n", 184 | "plt.scatter(np.tile(np.arange(1, 11), (3, 1)).T, test_loss, s=5, c='k')\n", 185 | "plt.scatter(np.arange(1, 11), np.mean(test_loss, axis=1), c='C3')\n", 186 | "plt.scatter(np.arange(1, 11), np.mean(train_loss, axis=1), c='C2')\n", 187 | "plt.plot(np.arange(1, 11), np.mean(test_loss, axis=1), c='C3', label='Validation set')\n", 188 | "plt.plot(np.arange(1, 11), np.mean(train_loss, axis=1), c='C2', label='Training set')\n", 189 | "plt.legend()\n", 190 | "plt.ylabel('VAE loss function')\n", 191 | "plt.xlabel('Dimensionality of intrinsic latent space')\n", 192 | "plt.xticks(np.arange(10) + 1)\n", 193 | "\n", 194 | "plt.savefig('./figures/vae_dimensionality.pdf')" 195 | ] 196 | } 197 | ], 198 | "metadata": { 199 | "kernelspec": { 200 | "display_name": "Kyle's conda", 201 | "language": "python", 202 | "name": "kyle_conda" 203 | }, 204 | "language_info": { 205 | "codemirror_mode": { 206 | "name": "ipython", 207 | "version": 3 208 | }, 209 | "file_extension": ".py", 210 | "mimetype": "text/x-python", 211 | "name": "python", 212 | "nbconvert_exporter": "python", 213 | "pygments_lexer": "ipython3", 214 | "version": "3.8.6" 215 | } 216 | }, 217 | "nbformat": 4, 218 | "nbformat_minor": 4 219 | } 220 | -------------------------------------------------------------------------------- /parsnip/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import * 2 | from .instruments import * 3 | from .light_curve import * 4 | from .parsnip import * 5 | from .plotting import * 6 | from .settings 
import * 7 | from .sncosmo import * 8 | from .utils import * 9 | -------------------------------------------------------------------------------- /parsnip/classifier.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import StratifiedKFold 2 | import astropy.table 3 | import lightgbm 4 | import numpy as np 5 | import os 6 | import pickle 7 | 8 | 9 | def extract_top_classifications(classifications): 10 | """Extract the top classification for each row a classifications Table. 11 | 12 | This is a bit complicated when working with astropy Tables. 13 | 14 | Parameters 15 | ---------- 16 | classifications : `~astropy.table.Table` 17 | Classifications table output from a `Classifier` 18 | 19 | Returns 20 | ------- 21 | `numpy.array` 22 | numpy array with the top type for each light curve 23 | """ 24 | types = classifications.colnames[1:] 25 | dtype = classifications[types[0]].dtype 26 | probabilities = classifications[types].as_array().view((dtype, len(types))) 27 | top_types = np.array(types)[probabilities.argmax(axis=1)] 28 | 29 | return top_types 30 | 31 | 32 | def weighted_multi_logloss(true_types, classifications): 33 | """Calculate a weighted log loss metric. 34 | 35 | This is the metric used for the PLAsTiCC challenge (with class weights set to 1) 36 | as described in Malz et al. 2019 37 | 38 | Parameters 39 | ---------- 40 | true_types : `~numpy.ndarray` 41 | True types for each object 42 | classifications : `~astropy.table.Table` 43 | Classifications table output from a `~Classifier` 44 | 45 | Returns 46 | ------- 47 | [type] 48 | [description] 49 | """ 50 | total_logloss = 0. 
51 | unique_types = np.unique(true_types) 52 | for type_name in unique_types: 53 | type_mask = true_types == type_name 54 | type_predictions = classifications[type_name][type_mask] 55 | type_loglosses = ( 56 | -np.log(type_predictions) 57 | / len(unique_types) 58 | / len(type_predictions) 59 | ) 60 | total_logloss += np.sum(type_loglosses) 61 | return total_logloss 62 | 63 | 64 | class Classifier(): 65 | """LightGBM classifier that operates on ParSNIP predictions""" 66 | def __init__(self): 67 | # Keys to use 68 | self.keys = [ 69 | 'color', 70 | 'color_error', 71 | 's1', 72 | 's1_error', 73 | 's2', 74 | 's2_error', 75 | 's3', 76 | 's3_error', 77 | 'luminosity', 78 | 'luminosity_error', 79 | 'reference_time_error', 80 | ] 81 | 82 | def extract_features(self, predictions): 83 | """Extract features used for classification 84 | 85 | The features to use are specified by the `keys` attribute. 86 | 87 | Parameters 88 | ---------- 89 | predictions : `~astropy.table.Table` 90 | Predictions output from `ParsnipModel.predict_dataset` 91 | 92 | Returns 93 | ------- 94 | `~numpy.ndarray` 95 | Extracted features that will be used for classification 96 | """ 97 | return np.array([predictions[i].data for i in self.keys]).T 98 | 99 | def train(self, predictions, num_folds=10, labels=None, target_label=None, 100 | reweight=True, min_child_weight=1000.): 101 | """Train a classifier on the predictions from a ParSNIP model 102 | 103 | Parameters 104 | ---------- 105 | predictions : `~astropy.table.Table` 106 | Predictions output from `ParsnipModel.predict_dataset` 107 | num_folds : int, optional 108 | Number of K-folds to use, by default 10 109 | labels : List[str], optional 110 | True labels for each light curve, by default None 111 | target_label : str, optional 112 | If specified, do one-vs-all classification for the given label, by default 113 | None 114 | reweight : bool, optional 115 | If true, weight all light curves so that each type has the same total 116 | weight, by default 
True 117 | min_child_weight : float, optional 118 | `min_child_weight` parameter for LightGBM, by default 1000 119 | 120 | Returns 121 | ------- 122 | `~astropy.table.Table` 123 | K-folding out-of-sample predictions for each light curve 124 | """ 125 | print("Training classifier with keys:") 126 | for key in self.keys: 127 | print(f" {key}") 128 | 129 | if labels is None: 130 | # Use default labels 131 | labels = predictions['type'] 132 | 133 | if target_label is not None: 134 | # Single class classification 135 | labels = labels == target_label 136 | class_names = np.array([target_label, 'Other']) 137 | 138 | numeric_labels = (~labels).astype(int) 139 | else: 140 | # Multi-class classification 141 | class_names = np.unique(labels) 142 | 143 | # Assign numbers to the labels so that we can guarantee a consistent 144 | # ordering. 145 | label_map = {j: i for i, j in enumerate(class_names)} 146 | numeric_labels = np.array([label_map[i] for i in labels]) 147 | 148 | # Assign folds while making sure that we keep all augmentations of the same 149 | # object in the same fold. 150 | if num_folds > 1: 151 | if 'augmented' in predictions.colnames: 152 | original_mask = ~predictions['augmented'] 153 | else: 154 | original_mask = np.ones(len(predictions), dtype=bool) 155 | 156 | object_ids = predictions['original_object_id'][original_mask] 157 | original_labels = numeric_labels[original_mask] 158 | 159 | kf = StratifiedKFold(num_folds, random_state=1, shuffle=True) 160 | 161 | fold_map = {} 162 | 163 | for fold_idx, (train_index, test_index) in enumerate( 164 | kf.split(object_ids, original_labels)): 165 | test_ids = object_ids[test_index] 166 | for tid in test_ids: 167 | fold_map[tid] = fold_idx 168 | 169 | predictions['fold'] = [fold_map[i] for i in 170 | predictions['original_object_id']] 171 | else: 172 | predictions['fold'] = -1 173 | 174 | classifiers = [] 175 | 176 | features = self.extract_features(predictions) 177 | 178 | # Normalize by the class counts. 
We normalize so that the average weight is 1 179 | # across all objects, and so that the sum of weights for each class is the same. 180 | if reweight: 181 | count_names, class_counts = np.unique(numeric_labels, return_counts=True) 182 | norm = np.mean(class_counts) 183 | class_weights = {name: norm / count for name, count in zip(count_names, 184 | class_counts)} 185 | weights = np.array([class_weights[i] for i in numeric_labels]) 186 | else: 187 | weights = np.ones_like(numeric_labels) 188 | 189 | # Calculate out-of-sample classifications with K-fold cross-validation if we 190 | # are doing that. 191 | if num_folds > 1: 192 | classifications = np.zeros((len(numeric_labels), len(class_names))) 193 | 194 | for fold in range(num_folds): 195 | if target_label is not None: 196 | # Single class classification 197 | lightgbm_params = { 198 | "objective": "binary", 199 | "metric": "binary_logloss", 200 | "min_child_weight": min_child_weight, 201 | } 202 | else: 203 | lightgbm_params = { 204 | "objective": "multiclass", 205 | "num_class": len(class_names), 206 | "metric": "multi_logloss", 207 | "min_child_weight": min_child_weight, 208 | } 209 | 210 | train_index = predictions['fold'] != fold 211 | test_index = ~train_index 212 | 213 | fit_params = {"verbose": 100, "sample_weight": weights[train_index]} 214 | 215 | if num_folds > 1: 216 | fit_params["eval_set"] = [(features[test_index], 217 | numeric_labels[test_index])] 218 | fit_params["eval_sample_weight"] = [weights[test_index]] 219 | 220 | classifier = lightgbm.LGBMClassifier(**lightgbm_params) 221 | classifier.fit(features[train_index], 222 | numeric_labels[train_index], **fit_params) 223 | 224 | classifiers.append(classifier) 225 | 226 | if num_folds > 1: 227 | # Out of sample predictions 228 | classifications[test_index] = classifier.predict_proba( 229 | features[test_index] 230 | ) 231 | 232 | # Keep the trained classifiers 233 | self.classifiers = classifiers 234 | self.class_names = class_names 235 | 236 | if 
num_folds == 1: 237 | # Only had a single fold, so do in sample predictions 238 | classifications = classifiers[0].predict_proba(features) 239 | 240 | classifications = astropy.table.hstack([ 241 | predictions['object_id'], 242 | astropy.table.Table(classifications, names=class_names) 243 | ]) 244 | 245 | return classifications 246 | 247 | def classify(self, predictions): 248 | """Classify light curves using predictions from a `~ParsnipModel` 249 | 250 | If the classifier was trained with K-folding, we average the classification 251 | probabilities over all folds. 252 | 253 | Parameters 254 | ---------- 255 | predictions : `~astropy.table.Table` 256 | Predictions output from `ParsnipModel.predict_dataset` 257 | 258 | Returns 259 | ------- 260 | Returns 261 | ------- 262 | `~astropy.table.Table` 263 | Predictions for each light curve 264 | """ 265 | features = self.extract_features(predictions) 266 | 267 | classifications = 0. 268 | 269 | for classifier in self.classifiers: 270 | classifications += classifier.predict_proba(features) 271 | 272 | classifications /= len(self.classifiers) 273 | 274 | classifications = astropy.table.hstack([ 275 | predictions['object_id'], 276 | astropy.table.Table(classifications, names=self.class_names) 277 | ]) 278 | 279 | return classifications 280 | 281 | def write(self, path): 282 | """Write the classifier out to disk 283 | 284 | Parameters 285 | ---------- 286 | path : str 287 | Path to write to 288 | """ 289 | os.makedirs(os.path.dirname(path), exist_ok=True) 290 | 291 | with open(path, 'wb') as f: 292 | pickle.dump(self, f) 293 | 294 | @classmethod 295 | def load(cls, path): 296 | """Load a classifier that was saved to disk 297 | 298 | Parameters 299 | ---------- 300 | path : str 301 | Path where the classifier was saved 302 | 303 | Returns 304 | ------- 305 | `~Classifier` 306 | Loaded classifier 307 | """ 308 | with open(path, 'rb') as f: 309 | classifier = pickle.load(f) 310 | 311 | return classifier 312 | 
-------------------------------------------------------------------------------- /parsnip/instruments.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from functools import lru_cache, reduce 3 | 4 | import extinction 5 | import lcdata 6 | import numpy as np 7 | import sncosmo 8 | 9 | """This file contains instrument specific definitions.""" 10 | 11 | 12 | band_info = { 13 | # Information about all of the different bands and how to handle them. We assume 14 | # that all data from the same telescope should be processed the same way. 15 | 16 | # Band name Correct Correct Plot color Plot marker 17 | # Background MWEBV 18 | # PanSTARRS 19 | 'ps1::g': (True, True, 'C0', 'o'), 20 | 'ps1::r': (True, True, 'C2', '^'), 21 | 'ps1::i': (True, True, 'C1', 'v',), 22 | 'ps1::z': (True, True, 'C3', '<'), 23 | 24 | # PLAsTICC 25 | 'lsstu': (True, False, 'C6', 'o'), 26 | 'lsstg': (True, False, 'C4', 'v'), 27 | 'lsstr': (True, False, 'C0', '^'), 28 | 'lssti': (True, False, 'C2', '<'), 29 | 'lsstz': (True, False, 'C3', '>'), 30 | 'lssty': (True, False, 'goldenrod', 's'), 31 | 32 | # ZTF 33 | 'ztfg': (False, True, 'C0', 'o'), 34 | 'ztfr': (False, True, 'C2', '^'), 35 | 'ztfi': (False, True, 'C1', 'v'), 36 | 37 | # SWIFT 38 | 'uvot::u': (False, True, 'C6', '<'), 39 | 'uvot::b': (False, True, 'C3', '>'), 40 | 'uvot::v': (False, True, 'goldenrod', 's'), 41 | 'uvot::uvm2': (False, True, 'C5', 'p'), 42 | 'uvot::uvw1': (False, True, 'C7', 'P'), 43 | 'uvot::uvw2': (False, True, 'C8', '*'), 44 | } 45 | 46 | 47 | def calculate_band_mw_extinctions(bands): 48 | """Calculate the Milky Way extinction corrections for a set of bands 49 | 50 | Multiply mwebv by these values to get the extinction that should be applied to 51 | each band for a specific light curve. For bands that have already been corrected, we 52 | set this value to 0. 
53 | 54 | Parameters 55 | ---------- 56 | bands : List[str] 57 | Bands to calculate the extinction for 58 | 59 | Returns 60 | ------- 61 | `~numpy.ndarray` 62 | Milky Way extinction in each band 63 | 64 | Raises 65 | ------ 66 | KeyError 67 | If any bands are not available in band_info in instruments.py 68 | """ 69 | band_mw_extinctions = [] 70 | 71 | for band_name in bands: 72 | # Check if we should be correcting the extinction for this band. 73 | try: 74 | should_correct = band_info[band_name][1] 75 | except KeyError: 76 | raise KeyError(f"Can't handle band {band_name}. Add it to band_info in " 77 | "instruments.py") 78 | 79 | if should_correct: 80 | band = sncosmo.get_bandpass(band_name) 81 | band_mw_extinctions.append(extinction.fm07(np.array([band.wave_eff]), 82 | 3.1)[0]) 83 | else: 84 | band_mw_extinctions.append(0.) 85 | 86 | band_mw_extinctions = np.array(band_mw_extinctions) 87 | 88 | return band_mw_extinctions 89 | 90 | 91 | def should_correct_background(bands): 92 | """Determine if we should correct the background levels for a set of bands 93 | 94 | Parameters 95 | ---------- 96 | bands : List[str] 97 | Bands to lookup 98 | 99 | Returns 100 | ------- 101 | `~numpy.ndarray` 102 | Boolean for each band indicating if it needs background correction 103 | 104 | Raises 105 | ------ 106 | KeyError 107 | If any bands are not available in band_info in instruments.py 108 | """ 109 | band_correct_background = [] 110 | 111 | for band_name in bands: 112 | # Check if we should be correcting the background level for this band. 113 | try: 114 | should_correct = band_info[band_name][0] 115 | except KeyError: 116 | raise KeyError(f"Can't handle band {band_name}. Add it to band_info in " 117 | "instruments.py") 118 | 119 | band_correct_background.append(should_correct) 120 | 121 | band_correct_background = np.array(band_correct_background) 122 | 123 | return band_correct_background 124 | 125 | 126 | def get_band_plot_color(band): 127 | """Return the plot color for a given band. 
128 | 129 | If the band does not yet have a color assigned to it, then a random color 130 | will be assigned (in a systematic way). 131 | 132 | Parameters 133 | ---------- 134 | band : str 135 | Name of the band to use. 136 | 137 | Returns 138 | ------- 139 | str 140 | Matplotlib color to use when plotting the band 141 | """ 142 | if band in band_info: 143 | return band_info[band][2] 144 | 145 | # Systematic random colors. We use the hash of the band name. 146 | # Note: hash() uses a random offset in python 3 so it isn't consistent 147 | # between runs! 148 | hasher = hashlib.md5() 149 | hasher.update(band.encode("utf8")) 150 | hex_color = "#%s" % hasher.hexdigest()[-6:] 151 | 152 | return hex_color 153 | 154 | 155 | def get_band_plot_marker(band): 156 | """Return the plot marker for a given band. 157 | 158 | If the band does not yet have a marker assigned to it, then we use the 159 | default circle. 160 | 161 | Parameters 162 | ---------- 163 | band : str 164 | Name of the band to use. 165 | 166 | Returns 167 | ------- 168 | str 169 | Matplotlib marker to use when plotting the band 170 | """ 171 | if band in band_info: 172 | return band_info[band][3] 173 | else: 174 | return 'o' 175 | 176 | 177 | def parse_ps1(dataset, reject_invalid=True, label_map=None, verbose=True): 178 | """Parse a PanSTARRS-1 dataset 179 | 180 | Parameters 181 | ---------- 182 | dataset : `~lcdata.Dataset` 183 | PanSTARRS-1 dataset to parse 184 | 185 | Returns 186 | ------- 187 | `~lcdata.Dataset` 188 | Parsed dataset 189 | """ 190 | # Light curves in the unsupervised set from Villar et al. 2020 don't have valid 191 | # redshifts. 192 | dataset.meta['original_redshift'] = dataset.meta['redshift'] 193 | dataset.meta['redshift'][~dataset.meta['unsupervised']] = np.nan 194 | 195 | # Labels to use for classification. Note that all of the non-supernova-like light 196 | # curves get rejected by the previous cut from Villar et al. 2020. 
197 | if label_map is None: 198 | label_map = { 199 | 'AGN': 'AGN', 200 | 'Bad': 'Bad', 201 | 'Bad (rise)': 'Bad', 202 | 'FELT': 'FELT', 203 | 'FELT (Bronze)': 'FELT', 204 | 'Lensed SNIa': 'Lensed SNIa', 205 | 'QSO': 'QSO', 206 | 'SLSN': 'SLSN', 207 | 'SNII': 'SNII', 208 | 'SNIIb?': 'SNII', 209 | 'SNIIn': 'SNIIn', 210 | 'SNIa': 'SNIa', 211 | 'SNIax': 'SNIax', 212 | 'SNIbc (Ib)': 'SNIbc', 213 | 'SNIbc (Ic)': 'SNIbc', 214 | 'SNIbc (Ic-BL)': 'SNIbc', 215 | 'SNIbn': 'SNIbc', 216 | 'TDE': 'TDE', 217 | 'Unknown': 'Unknown', 218 | 'VAR': 'VAR' 219 | } 220 | dataset.meta['type'] = [label_map[i] for i in dataset.meta['type']] 221 | 222 | return dataset 223 | 224 | 225 | def parse_ztf(dataset, reject_invalid=True, label_map=None, valid_classes=None, verbose=True): 226 | """Parse a ZTF dataset 227 | 228 | Parameters 229 | ---------- 230 | dataset : `~lcdata.Dataset` 231 | ZTF dataset to parse 232 | 233 | Returns 234 | ------- 235 | `~lcdata.Dataset` 236 | Parsed dataset 237 | """ 238 | lcs = [] 239 | invalid_count = 0 240 | for lc in dataset.light_curves: 241 | # Some ZTF datasets replace lower limits with a flux of zero. This is bad. Throw 242 | # out all of those observations because we can't handle them. 243 | lc = lc[(lc['flux'] != 0.) 
& (lc['fluxerr'] != 0.)] 244 | if len(lc) == 0: 245 | invalid_count += 1 246 | continue 247 | 248 | lcs.append(lc) 249 | if verbose: 250 | print(f"Rejecting {invalid_count} light curves with no good observations.") 251 | dataset = lcdata.from_light_curves(lcs) 252 | 253 | # Clean up labels 254 | types = [str(i).replace(' ', '').replace('?', '') for i in dataset.meta['type']] 255 | 256 | if label_map is None: 257 | label_map = { 258 | 'AGN': 'Galaxy', 259 | 'Bogus': 'Bad', 260 | 'CLAGN': 'Galaxy', 261 | 'CV': 'Star', 262 | 'CVCandidate': 'Star', 263 | 'Duplicate': 'Bad', 264 | 'Galaxy': 'Galaxy', 265 | 'Gap': 'Peculiar', 266 | 'GapI': 'Peculiar', 267 | 'GapI-Ca-rich': 'Peculiar', 268 | 'IIP': 'SNII', 269 | 'ILRT': 'Peculiar', 270 | 'LBV': 'Star', 271 | 'LINER': 'Galaxy', 272 | 'LRN': 'Star', 273 | 'NLS1': 'Galaxy', 274 | 'None': 'Unknown', 275 | 'Nova': 'Star', 276 | 'Q': 'Galaxy', 277 | 'QSO': 'Galaxy', 278 | 'SLSN-I': 'SLSN', 279 | 'SLSN-I.5': 'SLSN', 280 | 'SLSN-II': 'SLSN', 281 | 'SLSN-R': 'SLSN', 282 | 'SN': 'Unknown', 283 | 'SNII': 'SNII', 284 | 'SNII-pec': 'SNII', 285 | 'SNIIL': 'SNII', 286 | 'SNIIP': 'SNII', 287 | 'SNIIb': 'SNII', 288 | 'SNIIn': 'SNII', 289 | 'SNIa': 'SNIa', 290 | 'SNIa-91T': 'SNIa', 291 | 'SNIa-91T-like': 'SNIa', 292 | 'SNIa-91bg': 'SNIa', 293 | 'SNIa-99aa': 'SNIa', 294 | 'SNIa-CSM': 'SNIa', 295 | 'SNIa-norm': 'SNIa', 296 | 'SNIa-pec': 'SNIa', 297 | 'SNIa00cx-like': 'SNIa', 298 | 'SNIa02cx-like': 'SNIa', 299 | 'SNIa02ic-like': 'SNIa', 300 | 'SNIa91T': 'SNIa', 301 | 'SNIa91T-like': 'SNIa', 302 | 'SNIa91bg-like': 'SNIa', 303 | 'SNIapec': 'SNIa', 304 | 'SNIax': 'SNIa', 305 | 'SNIb': 'SNIbc', 306 | 'SNIb/c': 'SNIbc', 307 | 'SNIbn': 'SNIbc', 308 | 'SNIbpec': 'SNIbc', 309 | 'SNIc': 'SNIbc', 310 | 'SNIc-BL': 'SNIbc', 311 | 'SNIc-broad': 'SNIbc', 312 | 'Star': 'Star', 313 | 'TDE': 'TDE', 314 | 'Var': 'Star', 315 | 'asteroid': 'Asteroid', 316 | 'blazar': 'Galaxy', 317 | 'bogus': 'Bad', 318 | 'duplicate': 'Bad', 319 | 'galaxy': 'Galaxy', 320 | 'nan': 
'Unknown', 321 | 'nova': 'Star', 322 | 'old': 'Bad', 323 | 'rock': 'Asteroid', 324 | 'star': 'Star', 325 | 'stellar': 'Star', 326 | 'unclassified': 'Unknown', 327 | 'unk': 'Unknown', 328 | 'unknown': 'Unknown', 329 | 'Unknown': 'Unknown', 330 | 'varstar': 'Star', 331 | } 332 | 333 | dataset.meta['original_type'] = dataset.meta['type'] 334 | dataset.meta['type'] = [label_map[i] for i in types] 335 | 336 | # Drop light curves that aren't supernova-like 337 | if valid_classes is None: 338 | valid_classes = [ 339 | 'SNIa', 340 | 'SNII', 341 | 'Unknown', 342 | # 'Galaxy', 343 | 'SNIbc', 344 | 'SLSN', 345 | # 'Star', 346 | 'TDE', 347 | # 'Bad', 348 | 'Peculiar', 349 | ] 350 | if reject_invalid: 351 | mask = np.isin(dataset.meta['type'], valid_classes) 352 | if verbose: 353 | print(f"Rejecting {np.sum(~mask)} non-supernova-like light curves.") 354 | dataset = dataset[mask] 355 | 356 | return dataset 357 | 358 | 359 | def parse_plasticc(dataset, reject_invalid=True, verbose=True): 360 | """Parse a PLAsTiCC dataset 361 | 362 | Parameters 363 | ---------- 364 | dataset : `~lcdata.Dataset` 365 | PLAsTiCC dataset to parse 366 | 367 | Returns 368 | ------- 369 | `lcdata.Dataset` 370 | Parsed dataset 371 | """ 372 | # Set invalid speczs to nan. 
373 | dataset.meta['hostgal_specz'][dataset.meta['hostgal_specz'] < 0] = np.nan 374 | 375 | # Throw out light curves that don't look like supernovae 376 | valid_classes = [ 377 | 'SNIa', 378 | 'SNIa-91bg', 379 | 'SNIax', 380 | 'SNII', 381 | 'SNIbc', 382 | 'SLSN-I', 383 | 'TDE', 384 | 'KN', 385 | 'ILOT', 386 | 'CaRT', 387 | 'PISN', 388 | # 'AGN', 389 | # 'RRL', 390 | # 'M-dwarf', 391 | # 'EB', 392 | # 'Mira', 393 | # 'muLens-Single', 394 | # 'muLens-Binary', 395 | # 'muLens-String', 396 | ] 397 | 398 | if reject_invalid: 399 | mask = np.isin(dataset.meta['type'], valid_classes) 400 | if verbose: 401 | print(f"Rejecting {np.sum(~mask)} non-supernova-like light curves.") 402 | dataset = dataset[mask] 403 | 404 | return dataset 405 | 406 | 407 | def parse_dataset(dataset, path_or_name=None, kind=None, reject_invalid=True, require_redshift=True, label_map=None, valid_classes=None, verbose=True): 408 | """Parse a dataset from the lcdata package. 409 | 410 | We cut out observations that are not relevant for the ParSNIP model (e.g. galactic 411 | ones), and update the class labels. 412 | 413 | We try to guess the kind of dataset from the filename. If this doesn't work, specify 414 | the kind explicitly instead. 415 | 416 | Parameters 417 | ---------- 418 | dataset : `~lcdata.Dataset` 419 | Dataset to parse 420 | path_or_name : str, optional 421 | Name of the dataset, or path to it, by default None 422 | kind : str, optional 423 | Kind of dataset, by default None 424 | reject_invalid : bool, optional 425 | Whether to reject invalid light curves, by default True 426 | label_map : dict, optional 427 | Overwriting the default classification label mapping with 428 | a custom dict 429 | verbose : bool, optional 430 | If True, print parsing information, by default True 431 | 432 | Returns 433 | ------- 434 | `~lcdata.Dataset` 435 | Parsed dataset 436 | """ 437 | if kind is None and path_or_name is not None: 438 | # Parse the dataset to figure out what we need to do with it. 
439 | parse_name = path_or_name.lower().split('/')[-1] 440 | if 'ps1' in parse_name or 'panstarrs' in parse_name: 441 | if verbose: 442 | print(f"Parsing '{parse_name}' as PanSTARRS dataset...") 443 | kind = 'ps1' 444 | elif 'plasticc' in parse_name: 445 | if verbose: 446 | print(f"Parsing '{parse_name}' as PLAsTiCC dataset...") 447 | kind = 'plasticc' 448 | elif 'ztf' in parse_name: 449 | if verbose: 450 | print(f"Parsing '{parse_name}' as ZTF dataset...") 451 | kind = 'ztf' 452 | else: 453 | if verbose: 454 | print(f"Unknown dataset type '{parse_name}'. Using default parsing. " 455 | "Specify how to parse it in instruments.py if necessary.") 456 | kind = 'default' 457 | 458 | if kind == 'ps1': 459 | dataset = parse_ps1(dataset=dataset, reject_invalid=reject_invalid, label_map=label_map, verbose=verbose) 460 | elif kind == 'plasticc': 461 | dataset = parse_plasticc(dataset=dataset, reject_invalid=reject_invalid, verbose=verbose) 462 | elif kind == 'ztf': 463 | dataset = parse_ztf(dataset=dataset, reject_invalid=reject_invalid, label_map=label_map, valid_classes=valid_classes, verbose=verbose) 464 | elif kind == 'default': 465 | # Don't do anything by default 466 | pass 467 | else: 468 | if verbose: 469 | print(f"Unknown dataset type '{kind}'. Using default parsing. " 470 | "Specify how to parse it in instruments.py if necessary.") 471 | 472 | # Throw out light curves that don't have valid redshifts. 
473 | if require_redshift and reject_invalid: 474 | redshift_mask = np.isnan(dataset.meta['redshift']) 475 | if np.any(redshift_mask): 476 | if verbose: 477 | print(f"Rejecting {np.sum(redshift_mask)} light curves with missing " 478 | "redshifts.") 479 | dataset = dataset[~redshift_mask] 480 | 481 | if verbose: 482 | print(f"Dataset contains {len(dataset)} light curves.") 483 | 484 | return dataset 485 | 486 | 487 | def load_dataset(path, kind=None, in_memory=True, reject_invalid=True, 488 | require_redshift=True, label_map=None, valid_classes=None, verbose=True): 489 | """Load a dataset using the lcdata package. 490 | 491 | This can be any lcdata HDF5 dataset. We use `~parse_dataset` to clean things up for 492 | ParSNIP by rejecting irrelevant light curves (e.g. galactic ones) and updating class 493 | labels. 494 | 495 | We try to guess the dataset type from the filename. If this doesn't work, specify 496 | the kind explicitly instead. 497 | 498 | Parameters 499 | ---------- 500 | path : str 501 | Path to the dataset on disk 502 | kind : str, optional 503 | Kind of dataset, by default we will attempt to determine it from the filename 504 | in_memory : bool, optional 505 | If False, don't load the light curves into memory, and only load the metadata. 506 | See `lcdata.Dataset` for details. 
507 | reject_invalid : bool, optional 508 | Whether to reject invalid light curves, by default True 509 | label_map : dict, optional 510 | Overwriting the default classification label mapping with 511 | a custom dict 512 | verbose : bool, optional 513 | If True, print parsing information, by default True 514 | 515 | Returns 516 | ------- 517 | `~lcdata.Dataset` 518 | Loaded dataset 519 | """ 520 | dataset = lcdata.read_hdf5(path, in_memory=in_memory) 521 | dataset = parse_dataset( 522 | dataset=dataset, 523 | path_or_name=path, 524 | kind=kind, 525 | reject_invalid=reject_invalid, 526 | require_redshift=require_redshift, 527 | label_map=label_map, 528 | valid_classes=valid_classes, 529 | verbose=verbose 530 | ) 531 | 532 | return dataset 533 | 534 | 535 | def load_datasets(dataset_paths, kind=None, reject_invalid=True, require_redshift=True, label_map=None, valid_classes=None, verbose=True): 536 | """Load a list of datasets and merge them 537 | 538 | Parameters 539 | ---------- 540 | dataset_paths : List[str] 541 | Paths to each dataset to load 542 | verbose : bool, optional 543 | If True, print parsing information, by default True 544 | 545 | Returns 546 | ------- 547 | `~lcdata.Dataset` 548 | Loaded dataset 549 | """ 550 | # Load the dataset(s). 551 | datasets = [] 552 | for dataset_name in dataset_paths: 553 | datasets.append(load_dataset( 554 | path=dataset_name, 555 | kind=kind, 556 | reject_invalid=reject_invalid, 557 | require_redshift=require_redshift, 558 | label_map=label_map, 559 | valid_classes=valid_classes, 560 | verbose=verbose 561 | )) 562 | 563 | # Add all of the datasets together 564 | dataset = reduce(lambda i, j: i+j, datasets) 565 | 566 | return dataset 567 | 568 | 569 | def split_train_test(dataset): 570 | """Split a dataset into training and testing parts. 571 | 572 | We train on 90%, and test on 10%. We use a fixed algorithm to split the train and 573 | test so that we don't have to keep track of what we did. 
574 | 575 | Parameters 576 | ---------- 577 | dataset : `~lcdata.Dataset` 578 | Dataset to split 579 | 580 | Returns 581 | ------- 582 | `~lcdata.Dataset` 583 | Training dataset 584 | `~lcdata.Dataset` 585 | Test dataset 586 | """ 587 | # Keep part of the dataset for validation 588 | train_mask = np.ones(len(dataset), dtype=bool) 589 | train_mask[::10] = False 590 | test_mask = ~train_mask 591 | 592 | train_dataset = dataset[train_mask] 593 | test_dataset = dataset[test_mask] 594 | 595 | return train_dataset, test_dataset 596 | 597 | 598 | @lru_cache(maxsize=None) 599 | def get_band_effective_wavelength(band): 600 | """Calculate the effective wavelength of a band 601 | 602 | The results of this calculation are cached, and the effective wavelength will only 603 | be calculated once for each band. 604 | 605 | Parameters 606 | ---------- 607 | band : str 608 | Name of a band in the `sncosmo` band registry 609 | 610 | Returns 611 | ------- 612 | float 613 | Effective wavelength of the band. 614 | """ 615 | return sncosmo.get_bandpass(band).wave_eff 616 | 617 | 618 | def get_bands(dataset): 619 | """Retrieve a list of bands in a dataset 620 | 621 | Parameters 622 | ---------- 623 | dataset : `~lcdata.Dataset` 624 | Dataset to retrieve the bands from 625 | 626 | Returns 627 | ------- 628 | List[str] 629 | List of bands in the dataset sorted by effective wavelength 630 | """ 631 | bands = set() 632 | for lc in dataset.light_curves: 633 | bands = bands.union(lc['band']) 634 | 635 | sorted_bands = np.array(sorted(bands, key=get_band_effective_wavelength)) 636 | 637 | return sorted_bands 638 | -------------------------------------------------------------------------------- /parsnip/light_curve.py: -------------------------------------------------------------------------------- 1 | import lcdata 2 | import numpy as np 3 | import scipy.stats 4 | 5 | from astropy.stats import biweight_location 6 | 7 | SIDEREAL_SCALE = 86400. 
/ 86164.0905 8 | 9 | 10 | def _determine_time_grid(light_curve): 11 | """Determine the time grid that will be used for a light curve 12 | 13 | ParSNIP evaluates all light curves on a grid internally for the encoder. This 14 | function determines where to line up that grid. 15 | 16 | Parameters 17 | ---------- 18 | light_curve : `~astropy.table.Table` 19 | Light curve 20 | 21 | Returns 22 | ------- 23 | float 24 | Reference time for the time grid 25 | """ 26 | time = light_curve['time'] 27 | sidereal_time = time * SIDEREAL_SCALE 28 | 29 | # Initial guess of the phase. Round everything to 0.1 days, and find the decimal 30 | # that has the largest count. 31 | mode, count = scipy.stats.mode(np.round(sidereal_time % 1 + 0.05, 1), keepdims=True) 32 | guess_offset = mode[0] - 0.05 33 | 34 | # Shift everything by the guessed offset 35 | guess_shift_time = sidereal_time - guess_offset 36 | 37 | # Do a proper estimate of the offset. 38 | sidereal_offset = guess_offset + np.median((guess_shift_time + 0.5) % 1) - 0.5 39 | 40 | # Shift everything by the final offset estimate. 41 | shift_time = sidereal_time - sidereal_offset 42 | 43 | # Determine the reference time for the light curve. 44 | # This is tricky to do right. We want to roughly estimate where the "peak" of 45 | # the light curve is. Oftentimes we see low signal-to-noise observations that 46 | # are much larger than the peak flux though. This algorithm tries to find a 47 | # nice balance to handle that. 48 | 49 | # Find the five highest signal-to-noise observations 50 | s2n = light_curve['flux'] / light_curve['fluxerr'] 51 | s2n_mask = np.argsort(s2n)[-5:] 52 | 53 | # If we have very few observations, only keep the ones above signal-to-noise of 54 | # 5 if possible. Sometimes we only have a single point on the rise so far, so 55 | # we don't want to include a bunch of bad observations in our determination of 56 | # the time. 57 | s2n_mask_2 = s2n[s2n_mask] > 5. 
58 | if np.any(s2n_mask_2): 59 | cut_times = shift_time[s2n_mask][s2n_mask_2] 60 | else: 61 | # No observations with signal-to-noise above 5. Just use whatever we 62 | # have... 63 | cut_times = shift_time[s2n_mask] 64 | 65 | max_time = np.round(np.median(cut_times)) 66 | 67 | # Convert back to a reference time in the original units. This reference time 68 | # corresponds to the reference of the grid in sidereal time. 69 | reference_time = ((max_time + sidereal_offset) / SIDEREAL_SCALE) 70 | return reference_time 71 | 72 | 73 | def time_to_grid(time, reference_time): 74 | """Convert a time in the original units to one on the internal ParSNIP grid 75 | 76 | Parameters 77 | ---------- 78 | time : float 79 | Real time to convert 80 | reference_time : float 81 | Reference time for the grid 82 | 83 | Returns 84 | ------- 85 | float 86 | Time on the internal grid 87 | """ 88 | return (time - reference_time) * SIDEREAL_SCALE 89 | 90 | 91 | def grid_to_time(grid_time, reference_time): 92 | """Convert a time on the internal grid to a time in the original units 93 | 94 | Parameters 95 | ---------- 96 | grid_time : float 97 | Time on the internal grid 98 | reference_time : float 99 | Reference time for the grid 100 | 101 | Returns 102 | ------- 103 | float 104 | Time in original units 105 | """ 106 | return grid_time / SIDEREAL_SCALE + reference_time 107 | 108 | 109 | def preprocess_light_curve(light_curve, settings, raise_on_invalid=True, 110 | ignore_missing_redshift=False): 111 | """Preprocess a light curve for the ParSNIP model 112 | 113 | Parameters 114 | ---------- 115 | light_curve : `~astropy.Table` 116 | Raw light curve 117 | settings : dict 118 | ParSNIP model settings 119 | raise_on_invalid : bool 120 | Whether to raise a ValueError for invalid light curves. If False, None is 121 | returned instead. By default, True. 122 | ignore_missing_redshift : bool 123 | Whether to ignore missing redshifts, by default False. 
If False, a missing 124 | redshift value will cause a light curve to be invalid. 125 | 126 | Returns 127 | ------- 128 | `~astropy.Table` 129 | Preprocessed light curve 130 | 131 | Raises 132 | ------ 133 | ValueError 134 | For any invalid light curves that cannot be handled by ParSNIP if 135 | raise_on_invalid is True. The error message will describe why the light curve is 136 | invalid. 137 | """ 138 | if light_curve.meta.get('parsnip_preprocessed', False): 139 | # Already preprocessed 140 | return light_curve 141 | 142 | # Parse the light curve with lcdata to ensure that all of the columns/metadata have 143 | # standard names. 144 | try: 145 | light_curve = lcdata.parse_light_curve(light_curve) 146 | except ValueError as e: 147 | if raise_on_invalid: 148 | raise 149 | else: 150 | lcdata.utils.warn_first_time("invalid_lc_format", 151 | f"Failed to parse light curve: {e}") 152 | return None 153 | 154 | if (not settings['predict_redshift'] 155 | and not ignore_missing_redshift 156 | and not np.isfinite(light_curve.meta['redshift'])): 157 | # For models that don't predict the redshift, we require that each light curve 158 | # has a valid redshift. 159 | message = "No redshift available for light curve and model requires redshift." 160 | if raise_on_invalid: 161 | raise ValueError(message) 162 | else: 163 | lcdata.utils.warn_first_time("missing_redshift", message) 164 | return None 165 | 166 | # Align the observations to a grid in sidereal time. 167 | reference_time = _determine_time_grid(light_curve) 168 | 169 | # Build a preprocessed light curve object. 170 | new_lc = light_curve.copy() 171 | 172 | # Map each band to its corresponding index. 173 | band_map = {j: i for i, j in enumerate(settings['bands'])} 174 | new_lc['band_index'] = [band_map.get(i, -1) for i in new_lc['band']] 175 | 176 | # Cut out any observations that are outside of the window that we are 177 | # considering. 
178 | grid_times = time_to_grid(new_lc['time'], reference_time) 179 | time_indices = np.round(grid_times).astype(int) + settings['time_window'] // 2 180 | time_mask = ( 181 | (time_indices >= -settings['time_pad']) 182 | & (time_indices < settings['time_window'] + settings['time_pad']) 183 | ) 184 | new_lc['grid_time'] = grid_times 185 | new_lc['time_index'] = time_indices 186 | 187 | # Correct background levels for bands that need it. 188 | for band_idx, do_correction in enumerate(settings['band_correct_background']): 189 | if not do_correction: 190 | continue 191 | 192 | band_mask = new_lc['band_index'] == band_idx 193 | # Find observations outside of our window. 194 | outside_obs = new_lc[~time_mask & band_mask] 195 | if len(outside_obs) == 0: 196 | # No outside observations, don't update the background level. 197 | continue 198 | 199 | # Estimate the background level and subtract it. 200 | background = biweight_location(outside_obs['flux']) 201 | new_lc['flux'][band_mask] -= background 202 | 203 | # Cut out observations that are in unused bands or outside of the time window. 204 | band_mask = new_lc['band_index'] != -1 205 | new_lc = new_lc[band_mask & time_mask] 206 | 207 | if len(new_lc) == 0: 208 | # No valid observations for this light curve. 209 | message = (f"Light curve has no usable observations! Valid bands are " 210 | f"{settings['bands']}.") 211 | if raise_on_invalid: 212 | raise ValueError(message) 213 | else: 214 | lcdata.utils.warn_first_time("unusable_observations", message) 215 | return None 216 | 217 | # Correct for Milky Way extinction if desired. 218 | band_extinctions = ( 219 | settings['band_mw_extinctions'] * new_lc.meta.get('mwebv', 0.) 220 | ) 221 | extinction_scales = 10**(0.4 * band_extinctions[new_lc['band_index']]) 222 | new_lc['flux'] *= extinction_scales 223 | new_lc['fluxerr'] *= extinction_scales 224 | 225 | # Scale the light curve so that its peak has an amplitude of roughly 1. 
We use 226 | # the brightest observation with signal-to-noise above 5 if there is one, or 227 | # simply the brightest observation otherwise. 228 | s2n = new_lc['flux'] / new_lc['fluxerr'] 229 | s2n_mask = s2n > 5. 230 | if np.any(s2n_mask): 231 | scale = np.max(new_lc['flux'][s2n_mask]) 232 | else: 233 | scale = np.max(new_lc['flux']) 234 | 235 | new_lc.meta['parsnip_reference_time'] = reference_time 236 | new_lc.meta['parsnip_scale'] = scale 237 | new_lc.meta['parsnip_preprocessed'] = True 238 | 239 | return new_lc 240 | -------------------------------------------------------------------------------- /parsnip/models/plasticc.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LSSTDESC/parsnip/f1cfd56ce49621cea202f0a95f23573ea8310528/parsnip/models/plasticc.pt -------------------------------------------------------------------------------- /parsnip/models/plasticc_photoz.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LSSTDESC/parsnip/f1cfd56ce49621cea202f0a95f23573ea8310528/parsnip/models/plasticc_photoz.pt -------------------------------------------------------------------------------- /parsnip/models/ps1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LSSTDESC/parsnip/f1cfd56ce49621cea202f0a95f23573ea8310528/parsnip/models/ps1.pt -------------------------------------------------------------------------------- /parsnip/parsnip.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import functools 3 | import multiprocessing 4 | import numpy as np 5 | import os 6 | import sys 7 | 8 | from astropy.cosmology import Planck18 9 | import astropy.table 10 | import extinction 11 | import sncosmo 12 | import pkg_resources 13 | import lcdata 14 | 15 | import torch 16 | import torch.utils.data 17 | from torch 
import nn, optim 18 | from torch.nn import functional as F 19 | from torch.utils.data import DataLoader 20 | 21 | from .light_curve import preprocess_light_curve, grid_to_time, time_to_grid, \ 22 | SIDEREAL_SCALE 23 | from .utils import frac_to_mag, parse_device, replace_nan_grads 24 | from .settings import parse_settings, default_model 25 | from .sncosmo import ParsnipSncosmoSource 26 | 27 | 28 | class ResidualBlock(nn.Module): 29 | """1D residual convolutional neural network block 30 | 31 | This module operates on 1D sequences. The input will be padded so that the length 32 | of the sequences is left unchanged. 33 | 34 | Parameters 35 | ---------- 36 | in_channels : int 37 | Number of channels for the input 38 | out_channels : int 39 | Number of channels for the output 40 | dilation : int 41 | Dilation to use in the convolution 42 | """ 43 | def __init__(self, in_channels, out_channels, dilation): 44 | super().__init__() 45 | 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | 49 | if self.out_channels < self.in_channels: 50 | raise Exception("out_channels must be >= in_channels.") 51 | 52 | self.conv1 = nn.Conv1d(in_channels, out_channels, 3, dilation=dilation, 53 | padding=dilation) 54 | self.conv2 = nn.Conv1d(out_channels, out_channels, 3, 55 | dilation=dilation, padding=dilation) 56 | 57 | def forward(self, x): 58 | out = self.conv1(x) 59 | out = F.relu(out) 60 | out = self.conv2(out) 61 | 62 | # Add back in the input. If it is smaller than the output, pad it first. 63 | if self.in_channels < self.out_channels: 64 | pad_size = self.out_channels - self.in_channels 65 | pad_x = F.pad(x, (0, 0, 0, pad_size)) 66 | else: 67 | pad_x = x 68 | 69 | # Residual connection 70 | out = out + pad_x 71 | 72 | out = F.relu(out) 73 | 74 | return out 75 | 76 | 77 | class Conv1dBlock(nn.Module): 78 | """1D convolutional neural network block 79 | 80 | This module operates on 1D sequences. 
The input will be padded so that the length of the 81 | sequences is left unchanged. 82 | 83 | Parameters 84 | ---------- 85 | in_channels : int 86 | Number of channels for the input 87 | out_channels : int 88 | Number of channels for the output 89 | dilation : int 90 | Dilation to use in the convolution 91 | """ 92 | def __init__(self, in_channels, out_channels, dilation): 93 | super().__init__() 94 | 95 | self.in_channels = in_channels 96 | self.out_channels = out_channels 97 | 98 | self.conv = nn.Conv1d(in_channels, out_channels, 5, dilation=dilation, 99 | padding=2*dilation) 100 | 101 | self.relu = nn.ReLU(inplace=True) 102 | 103 | def forward(self, x): 104 | out = self.conv(x) 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | 110 | class GlobalMaxPoolingTime(nn.Module): 111 | """Time max pooling layer for 1D sequences 112 | 113 | This layer applies global max pooling over the time dimension to eliminate it 114 | while preserving the channel dimension. 115 | """ 116 | def forward(self, x): 117 | out, inds = torch.max(x, 2) 118 | return out 119 | 120 | 121 | class ParsnipModel(nn.Module): 122 | """Generative model of transient light curves 123 | 124 | This class represents a generative model of transient light curves. Given a set of 125 | latent variables representing a transient, it can predict the full spectral time 126 | series of that transient. It can also use variational inference to predict the 127 | posterior distribution over the latent variables for a given light curve. 128 | 129 | Parameters 130 | ---------- 131 | path : str 132 | Path to where the model should be stored on disk. 133 | bands : List[str] 134 | Bands that the model uses as input for variational inference 135 | device : str 136 | PyTorch device to use for the model 137 | threads : int 138 | Number of threads to use 139 | settings : dict 140 | Settings for the model. 
Any settings specified here will override the defaults 141 | set in settings.py 142 | ignore_unknown_settings : bool 143 | If True, ignore any specified settings that are unknown. Otherwise, 144 | raise a KeyError if an unknown setting is specified. By default False. 145 | """ 146 | def __init__(self, path, bands, device='cpu', threads=8, settings={}, 147 | ignore_unknown_settings=False): 148 | super().__init__() 149 | 150 | # Parse settings 151 | self.settings = parse_settings(bands, settings, 152 | ignore_unknown_settings=ignore_unknown_settings) 153 | 154 | self.path = path 155 | self.threads = threads 156 | 157 | # Setup the device 158 | self.device = parse_device(device) 159 | torch.set_num_threads(self.threads) 160 | 161 | # Setup the bands 162 | self._setup_band_weights() 163 | 164 | # Setup the color law. We scale this so that the color law has a B-V color of 1, 165 | # meaning that a coefficient multiplying the color law is the B-V color. 166 | color_law = extinction.fm07(self.model_wave, 3.1) 167 | self.color_law = torch.FloatTensor(color_law).to(self.device) 168 | 169 | # Setup the timing 170 | self.input_times = (torch.arange(self.settings['time_window'], 171 | device=self.device) 172 | - self.settings['time_window'] // 2) 173 | 174 | # Build the model 175 | self._build_model() 176 | 177 | # Set up the training 178 | self.epoch = 0 179 | optim_kwargs = { 180 | 'params': self.parameters(), 181 | 'lr': self.settings['learning_rate'], 182 | } 183 | if self.settings['optimizer'].lower() == 'adam': 184 | self.optimizer = optim.Adam(**optim_kwargs) 185 | elif self.settings['optimizer'].lower() == 'sgd': 186 | self.optimizer = optim.SGD(momentum=self.settings['sgd_momentum'], **optim_kwargs) 187 | else: 188 | raise ValueError('Unknown optimizer "{}"'.format(self.settings['optimizer'])) 189 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 190 | self.optimizer, factor=self.settings['scheduler_factor'], verbose=True 191 | ) 192 | 193 | # Send the
model weights to the desired device 194 | self.to(self.device, force=True) 195 | 196 | def to(self, device, force=False): 197 | """Send the model to the specified device 198 | 199 | Parameters 200 | ---------- 201 | device : str 202 | PyTorch device 203 | force : bool, optional 204 | If True, force the model to be sent to the device even if it is there 205 | already (useful if only parts of the model are there), by default False 206 | """ 207 | new_device = parse_device(device) 208 | 209 | if self.device == new_device and not force: 210 | # Already on that device 211 | return 212 | 213 | self.device = new_device 214 | 215 | # Send all of the weights 216 | super().to(self.device) 217 | 218 | # Send all of the variables that we create manually 219 | self.color_law = self.color_law.to(self.device) 220 | self.input_times = self.input_times.to(self.device) 221 | 222 | self.band_interpolate_locations = \ 223 | self.band_interpolate_locations.to(self.device) 224 | self.band_interpolate_weights = self.band_interpolate_weights.to(self.device) 225 | 226 | def save(self): 227 | """Save the model""" 228 | os.makedirs(os.path.dirname(self.path), exist_ok=True) 229 | torch.save([self.settings, self.state_dict()], self.path) 230 | 231 | def _setup_band_weights(self): 232 | """Setup the interpolation for the band weights used for photometry""" 233 | # Build the model in log wavelength 234 | model_log_wave = np.linspace(np.log10(self.settings['min_wave']), 235 | np.log10(self.settings['max_wave']), 236 | self.settings['spectrum_bins']) 237 | model_spacing = model_log_wave[1] - model_log_wave[0] 238 | 239 | band_spacing = model_spacing / self.settings['band_oversampling'] 240 | band_max_log_wave = ( 241 | np.log10(self.settings['max_wave'] * (1 + self.settings['max_redshift'])) 242 | + band_spacing 243 | ) 244 | 245 | # Oversampling must be odd. 
246 | assert self.settings['band_oversampling'] % 2 == 1 247 | pad = (self.settings['band_oversampling'] - 1) // 2 248 | band_log_wave = np.arange(np.log10(self.settings['min_wave']), 249 | band_max_log_wave, band_spacing) 250 | band_wave = 10**(band_log_wave) 251 | band_pad_log_wave = np.arange( 252 | np.log10(self.settings['min_wave']) - band_spacing * pad, 253 | band_max_log_wave + band_spacing * pad, 254 | band_spacing 255 | ) 256 | band_pad_dwave = ( 257 | 10**(band_pad_log_wave + band_spacing / 2.) 258 | - 10**(band_pad_log_wave - band_spacing / 2.) 259 | ) 260 | 261 | ref = sncosmo.get_magsystem(self.settings['magsys']) 262 | 263 | band_weights = [] 264 | 265 | for band_name in self.settings['bands']: 266 | band = sncosmo.get_bandpass(band_name) 267 | band_transmission = band(10**(band_pad_log_wave)) 268 | 269 | # Convolve the bands to match the sampling of the spectrum. 270 | band_conv_transmission = np.convolve( 271 | band_transmission * band_pad_dwave, 272 | np.ones(self.settings['band_oversampling']), 273 | mode='valid' 274 | ) 275 | 276 | band_weight = ( 277 | band_wave 278 | * band_conv_transmission 279 | / sncosmo.constants.HC_ERG_AA 280 | / ref.zpbandflux(band) 281 | * 10**(0.4 * -20.) 282 | ) 283 | 284 | band_weights.append(band_weight) 285 | 286 | # Get the locations that should be sampled at redshift 0. We can scale these to 287 | # get the locations at any redshift. 288 | band_interpolate_locations = torch.arange( 289 | 0, 290 | self.settings['spectrum_bins'] * self.settings['band_oversampling'], 291 | self.settings['band_oversampling'] 292 | ) 293 | 294 | # Save the variables that we need to do interpolation. 
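The band weights above are built on a log-wavelength grid precisely so that redshifting becomes a constant index shift, which is what makes the later interpolation cheap. A minimal standalone sketch of that property, using made-up grid parameters rather than the model's actual settings:

```python
import numpy as np

# Hypothetical log-wavelength grid (not the model's actual settings).
min_wave, max_wave, n_bins = 1000.0, 11000.0, 300
log_wave = np.linspace(np.log10(min_wave), np.log10(max_wave), n_bins)
spacing = log_wave[1] - log_wave[0]

z = 0.1
# log10(lambda * (1 + z)) = log10(lambda) + log10(1 + z), so every grid point
# moves by the same (generally fractional) number of grid indices.
index_shift = np.log10(1 + z) / spacing

shifted = np.log10(10**log_wave * (1 + z))
assert np.allclose((shifted - log_wave) / spacing, index_shift)
```

Because the shift is usually fractional, the model interpolates linearly between the two neighboring precomputed weight columns, as done in `_calculate_band_weights`.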
295 | self.band_interpolate_locations = band_interpolate_locations.to(self.device) 296 | self.band_interpolate_spacing = band_spacing 297 | self.band_interpolate_weights = torch.FloatTensor(band_weights).to(self.device) 298 | self.model_wave = 10**(model_log_wave) 299 | 300 | def _calculate_band_weights(self, redshifts): 301 | """Calculate the band weights for a given set of redshifts 302 | 303 | We have precomputed the weights for each bandpass, so we simply interpolate 304 | those weights at the desired redshifts. We are working in log-wavelength, so a 305 | change in redshift just gives us a shift in indices. 306 | 307 | Parameters 308 | ---------- 309 | redshifts : List[float] 310 | Redshifts to calculate the band weights at 311 | 312 | Returns 313 | ------- 314 | `~torch.FloatTensor` 315 | Band weights for each redshift/band combination 316 | """ 317 | # Figure out the locations to sample at for each redshift. 318 | locs = ( 319 | self.band_interpolate_locations 320 | + torch.log10(1 + redshifts)[:, None] / self.band_interpolate_spacing 321 | ) 322 | flat_locs = locs.flatten() 323 | 324 | # Linear interpolation 325 | int_locs = flat_locs.long() 326 | remainders = flat_locs - int_locs 327 | 328 | start = self.band_interpolate_weights[..., int_locs] 329 | end = self.band_interpolate_weights[..., int_locs + 1] 330 | 331 | flat_result = remainders * end + (1 - remainders) * start 332 | result = flat_result.reshape((-1,) + locs.shape).permute(1, 2, 0) 333 | 334 | # We need an extra term of 1 + z from the filter contraction. 335 | result /= (1 + redshifts)[:, None, None] 336 | 337 | return result 338 | 339 | def _test_band_weights(self, redshift, source='salt2-extended'): 340 | """Test the accuracy of the band weights 341 | 342 | We compare sncosmo photometry to the photometry calculated by this class.
343 | 344 | Parameters 345 | ---------- 346 | redshift : float 347 | Redshift to evaluate the model at 348 | source : str, optional 349 | SNCosmo source to use, by default 'salt2-extended' 350 | """ 351 | model = sncosmo.Model(source=source) 352 | 353 | # sncosmo photometry 354 | model.set(z=redshift) 355 | sncosmo_photometry = model.bandflux(self.settings['bands'], 0., zp=-20., 356 | zpsys=self.settings['magsys']) 357 | 358 | # parsnip photometry 359 | model.set(z=0.) 360 | model_flux = model._flux(0., self.model_wave)[0] 361 | band_weights = self._calculate_band_weights( 362 | torch.FloatTensor([redshift]))[0].numpy() 363 | parsnip_photometry = np.sum(model_flux[:, None] * band_weights, axis=0) 364 | 365 | print(f"z = {redshift}") 366 | print(f"sncosmo photometry: {sncosmo_photometry}") 367 | print(f"parsnip photometry: {parsnip_photometry}") 368 | print(f"ratio: {parsnip_photometry / sncosmo_photometry}") 369 | 370 | def preprocess(self, dataset, chunksize=64, verbose=True): 371 | """Preprocess an lcdata dataset 372 | 373 | The preprocessing will be done in parallel over multiple workers. Set 374 | `ParsnipModel.threads` to change how many are used. If the dataset is already 375 | preprocessed, then nothing will be done and it will be returned as is. 376 | 377 | Parameters 378 | ---------- 379 | dataset : `~lcdata.Dataset` 380 | Dataset to preprocess 381 | chunksize : int, optional 382 | Number of light curves to process at a time, by default 64 383 | verbose : bool, optional 384 | Whether to show a progress bar, by default True 385 | 386 | Returns 387 | ------- 388 | `~lcdata.Dataset` 389 | Preprocessed dataset 390 | """ 391 | import lcdata 392 | 393 | # Check if we were given a preprocessed dataset. We flag preprocessed data 394 | # with the parsnip_preprocessed key in the metadata.
395 | if ('parsnip_preprocessed' in dataset.meta.keys() 396 | and np.all(dataset.meta['parsnip_preprocessed'])): 397 | return dataset 398 | 399 | if self.threads == 1: 400 | iterator = dataset.light_curves 401 | if verbose: 402 | iterator = tqdm(dataset.light_curves, file=sys.stdout, 403 | desc="Preprocessing dataset") 404 | 405 | # Run on a single core without multiprocessing 406 | preprocessed_light_curves = [] 407 | for lc in iterator: 408 | preprocessed_light_curves.append( 409 | preprocess_light_curve(lc, self.settings, raise_on_invalid=False) 410 | ) 411 | else: 412 | # Run in parallel using a multiprocessing pool. 413 | func = functools.partial(preprocess_light_curve, settings=self.settings, 414 | raise_on_invalid=False) 415 | 416 | with multiprocessing.Pool(self.threads) as p: 417 | iterator = p.imap(func, dataset.light_curves, chunksize=chunksize) 418 | if verbose: 419 | iterator = tqdm(iterator, total=len(dataset.light_curves), 420 | file=sys.stdout, desc="Preprocessing dataset") 421 | preprocessed_light_curves = list(iterator) 422 | 423 | # Check if any light curves failed to process 424 | none_count = 0 425 | for lc in preprocessed_light_curves: 426 | if lc is None: 427 | none_count += 1 428 | if none_count > 0: 429 | print(f"WARNING: Rejecting {none_count}/{len(preprocessed_light_curves)} " 430 | "light curves. 
Consider using 'parsnip.load_dataset()' or " 431 | "'parsnip.parse_dataset()' to load/parse the dataset and hopefully " 432 | "avoid this.") 433 | preprocessed_light_curves = [i for i in preprocessed_light_curves if i is 434 | not None] 435 | 436 | dataset = lcdata.from_light_curves(preprocessed_light_curves) 437 | return dataset 438 | 439 | def augment_light_curves(self, light_curves, as_table=True): 440 | """Augment a set of light curves 441 | 442 | Parameters 443 | ---------- 444 | light_curves : List[`~astropy.table.Table`] 445 | List of light curves to augment 446 | as_table : bool, optional 447 | Whether to return the light curves as astropy Tables, by default True. 448 | Constructing new tables is relatively slow, so internally we skip this step 449 | when training the ParSNIP model. 450 | 451 | Returns 452 | ------- 453 | List 454 | Augmented light curves 455 | """ 456 | # Check if we have a list of light curves or a single one and handle it 457 | # appropriately. 458 | if isinstance(light_curves, astropy.table.Table): 459 | # Single object. Wrap it so that we can process it as an array. We'll unwrap 460 | # it at the end. 461 | single = True 462 | light_curves = [light_curves] 463 | else: 464 | single = False 465 | 466 | new_light_curves = np.empty(shape=len(light_curves), dtype=object) 467 | 468 | for idx, lc in enumerate(light_curves): 469 | # Convert the table to a numpy recarray. This is much faster to work with. 470 | data = lc.as_array() 471 | 472 | # Make a copy of the metadata to work off of. 473 | meta = lc.meta.copy(use_cache=not as_table) 474 | 475 | # Randomly drop observations and make a copy of the light curve. 476 | drop_frac = np.random.uniform(0, 0.5) 477 | mask = np.random.rand(len(data)) > drop_frac 478 | data = data[mask] 479 | 480 | # Shift the time randomly. 
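The observation-dropping step above is the simplest of the augmentations: a per-light-curve drop fraction is drawn uniformly from [0, 0.5) and applied as a random mask. A minimal standalone sketch with hypothetical data:

```python
import numpy as np

rng = np.random.default_rng(42)

# Hypothetical light curve with 100 observations.
flux = rng.normal(size=100)

# Draw a drop fraction in [0, 0.5) and randomly mask out roughly that
# fraction of the observations, as in augment_light_curves above.
drop_frac = rng.uniform(0, 0.5)
mask = rng.random(len(flux)) > drop_frac
augmented_flux = flux[mask]

assert len(augmented_flux) <= len(flux)
```

On average a fraction `drop_frac` of points is removed, but the exact count varies from draw to draw since each observation is kept independently.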
481 | time_shift = np.round( 482 | np.random.normal(0., self.settings['time_sigma']) 483 | ).astype(int) 484 | meta['parsnip_reference_time'] += time_shift / SIDEREAL_SCALE 485 | data['grid_time'] -= time_shift 486 | data['time_index'] -= time_shift 487 | 488 | # Add noise to the observations 489 | if np.random.rand() < 0.5 and len(data) > 0: 490 | # Choose an overall scale for the noise from a lognormal 491 | # distribution. 492 | noise_scale = np.random.lognormal(-4., 1.) * meta['parsnip_scale'] 493 | 494 | # Choose the noise levels for each observation from a lognormal 495 | # distribution. 496 | noise_sigmas = np.random.lognormal(np.log(noise_scale), 1., len(data)) 497 | 498 | # Add the noise to the observations. 499 | noise = np.random.normal(0., noise_sigmas) 500 | data['flux'] += noise 501 | data['fluxerr'] = np.sqrt(data['fluxerr']**2 + noise_sigmas**2) 502 | 503 | # Scale the amplitude that we input to the model randomly. 504 | amp_scale = np.exp(np.random.normal(0, 0.5)) 505 | meta['parsnip_scale'] *= amp_scale 506 | 507 | # Convert back to an astropy Table if desired. This is somewhat slow, so we 508 | # skip it internally when training the model. 509 | if as_table: 510 | new_lc = astropy.table.Table(data, meta=meta) 511 | else: 512 | new_lc = (data, meta) 513 | 514 | new_light_curves[idx] = new_lc 515 | 516 | if single: 517 | return new_light_curves[0] 518 | else: 519 | return new_light_curves 520 | 521 | def _get_data(self, light_curves): 522 | """Extract data needed by ParSNIP from a set of light curves. 523 | 524 | Parameters 525 | ---------- 526 | light_curves : List[`~astropy.table.Table`] 527 | Light curves to extract data from 528 | 529 | Returns 530 | ------- 531 | data : dict 532 | A dictionary with the following keys: 533 | - 'input_data' : A `~torch.FloatTensor` that is used as input to the ParSNIP 534 | encoder. 
535 | - 'compare_data' : A `~torch.FloatTensor` containing data that is used for 536 | comparisons with the output of the ParSNIP decoder. 537 | - 'redshift' : A `~torch.FloatTensor` containing the redshifts of each 538 | light curve. 539 | - 'band_indices' : A `~torch.LongTensor` containing the band indices for 540 | each observation that will be compared. 541 | - 'photoz' : A `~torch.FloatTensor` containing the photozs of each 542 | light curve. Only available if the 'predict_redshift' model setting is 543 | True. 544 | - 'photoz_error' : A `~torch.FloatTensor` containing the photoz errors of 545 | each light curve. Only available if the 'predict_redshift' model setting 546 | is True. 547 | """ 548 | redshifts = [] 549 | if self.settings['predict_redshift']: 550 | photozs = [] 551 | photoz_errors = [] 552 | 553 | compare_data = [] 554 | compare_band_indices = [] 555 | 556 | # Build a grid for the input 557 | grid_flux = np.zeros((len(light_curves), len(self.settings['bands']), 558 | self.settings['time_window'])) 559 | grid_weights = np.zeros_like(grid_flux) 560 | 561 | for idx, lc in enumerate(light_curves): 562 | # Convert the table to a numpy recarray. This is much faster to work with. 563 | # For augmentation, we skip creating a Table because that is slow and just 564 | # keep the recarray. Handle that too. 565 | if isinstance(lc, astropy.table.Table): 566 | lc_data = lc.as_array() 567 | lc_meta = lc.meta 568 | else: 569 | lc_data, lc_meta = lc 570 | 571 | # Extract the redshift. 572 | if self.settings['predict_redshift']: 573 | # Note: this uses the keys for PLAsTiCC and should be adapted to handle 574 | # more general surveys. 575 | redshifts.append(lc_meta['hostgal_specz']) 576 | photozs.append(lc_meta['hostgal_photoz']) 577 | photoz_errors.append(lc_meta['hostgal_photoz_err']) 578 | else: 579 | redshifts.append(lc_meta['redshift']) 580 | 581 | # Mask out observations that are outside of our window.
582 | mask = (lc_data['time_index'] >= 0) & (lc_data['time_index'] < 583 | self.settings['time_window']) 584 | lc_data = lc_data[mask] 585 | 586 | # Scale the flux and fluxerr appropriately. Note that applying the mask 587 | # makes a copy of the array, so this won't affect the original data. 588 | lc_data['flux'] /= lc_meta['parsnip_scale'] 589 | lc_data['fluxerr'] /= lc_meta['parsnip_scale'] 590 | 591 | # Calculate weights with an error floor included. Note that this is typically a 592 | # very large number. For the comparison this doesn't matter, but for the 593 | # input we scale it by the error floor so that it becomes a number between 0 594 | # and 1. 595 | weights = 1 / (lc_data['fluxerr']**2 + self.settings['error_floor']**2) 596 | 597 | # Fill in the input array. 598 | grid_flux[idx, lc_data['band_index'], lc_data['time_index']] = \ 599 | lc_data['flux'] 600 | grid_weights[idx, lc_data['band_index'], lc_data['time_index']] = \ 601 | self.settings['error_floor']**2 * weights 602 | 603 | # Stack all of the data that will be used for comparisons and convert it to 604 | # a torch tensor. 605 | obj_compare_data = torch.FloatTensor(np.vstack([ 606 | lc_data['grid_time'], 607 | lc_data['flux'], 608 | lc_data['fluxerr'], 609 | weights, 610 | ])) 611 | compare_data.append(obj_compare_data.T) 612 | compare_band_indices.append(torch.LongTensor(lc_data['band_index'].copy())) 613 | 614 | # Gather the input data. 615 | redshifts = np.array(redshifts) 616 | if self.settings['predict_redshift']: 617 | photozs = np.array(photozs) 618 | photoz_errors = np.array(photoz_errors) 619 | 620 | extra_input_data = []  # Extra features to add to the input; stays empty if unused. 621 | if self.settings['input_redshift']: 622 | if self.settings['predict_redshift']: 623 | extra_input_data = [photozs, photoz_errors] 624 | else: 625 | extra_input_data = [redshifts] 626 | 627 | # Stack everything together.
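The weighting scheme described in the comments above can be sketched standalone: the raw inverse-variance weights are unbounded, but multiplying by the squared error floor maps them into (0, 1] for the encoder input. The numbers below are made up for illustration:

```python
import numpy as np

# Hypothetical error floor and per-observation flux uncertainties.
error_floor = 0.01
fluxerr = np.array([0.005, 0.02, 1.0])

# Raw inverse-variance weights with the error floor folded in. These can be
# very large for well-measured points.
weights = 1 / (fluxerr**2 + error_floor**2)

# For the encoder input, scaling by error_floor**2 maps the weights into
# (0, 1]: well-measured points approach 1, noisy points approach 0.
input_weights = error_floor**2 * weights

assert np.all(input_weights > 0) and np.all(input_weights <= 1)
```

The unscaled weights are kept separately for the likelihood comparison, where only their relative inverse-variance meaning matters.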
628 | input_data = np.concatenate( 629 | [i[:, None, None].repeat(self.settings['time_window'], axis=2) for i in 630 | extra_input_data] 631 | + [grid_flux, grid_weights], 632 | axis=1 633 | ) 634 | 635 | # Convert to torch tensors 636 | input_data = torch.FloatTensor(input_data).to(self.device) 637 | redshifts = torch.FloatTensor(redshifts).to(self.device) 638 | 639 | # Pad all of the compare data to have the same shape. 640 | compare_data = nn.utils.rnn.pad_sequence(compare_data, batch_first=True) 641 | compare_data = compare_data.permute(0, 2, 1) 642 | compare_band_indices = nn.utils.rnn.pad_sequence(compare_band_indices, 643 | batch_first=True) 644 | compare_data = compare_data.to(self.device) 645 | compare_band_indices = compare_band_indices.to(self.device) 646 | 647 | data = { 648 | 'input_data': input_data, 649 | 'compare_data': compare_data, 650 | 'redshift': redshifts, 651 | 'band_indices': compare_band_indices, 652 | } 653 | 654 | if self.settings['predict_redshift']: 655 | data['photoz'] = torch.FloatTensor(photozs).to(self.device) 656 | data['photoz_error'] = torch.FloatTensor(photoz_errors).to(self.device) 657 | 658 | return data 659 | 660 | def _build_model(self): 661 | """Build the model""" 662 | input_size = len(self.settings['bands']) * 2 663 | if self.settings['input_redshift']: 664 | if self.settings['predict_redshift']: 665 | input_size += 2 666 | else: 667 | input_size += 1 668 | 669 | if self.settings['encode_block'] == 'conv1d': 670 | encode_block = Conv1dBlock 671 | elif self.settings['encode_block'] == 'residual': 672 | encode_block = ResidualBlock 673 | else: 674 | raise Exception(f"Unknown block {self.settings['encode_block']}.") 675 | 676 | # Encoder architecture. We start with an input of size input_size x 677 | # time_window. We apply a series of convolutional blocks that produce 678 | # outputs of the same size. The type of block is specified by 679 | # settings['encode_block'].
Each convolutional block has a dilation that is 680 | # given by settings['encode_conv_dilations']. 681 | if (len(self.settings['encode_conv_architecture']) != 682 | len(self.settings['encode_conv_dilations'])): 683 | raise Exception("Layer sizes and dilations must have the same length!") 684 | 685 | encode_layers = [] 686 | 687 | # Convolutional layers. 688 | last_size = input_size 689 | for layer_size, dilation in zip(self.settings['encode_conv_architecture'], 690 | self.settings['encode_conv_dilations']): 691 | encode_layers.append( 692 | encode_block(last_size, layer_size, dilation) 693 | ) 694 | last_size = layer_size 695 | 696 | # Fully connected layers for the encoder following the convolution blocks. 697 | # These are Conv1D layers with a kernel size of 1 that mix within the time 698 | # indexes. 699 | for layer_size in self.settings['encode_fc_architecture']: 700 | encode_layers.append(nn.Conv1d(last_size, layer_size, 1)) 701 | encode_layers.append(nn.ReLU()) 702 | last_size = layer_size 703 | 704 | self.encode_layers = nn.Sequential(*encode_layers) 705 | 706 | # Fully connected layers for the time-indexing layer. These are Conv1D layers 707 | # with a kernel size of 1 that mix within time indexes. 708 | time_last_size = last_size 709 | encode_time_layers = [] 710 | for layer_size in self.settings['encode_time_architecture']: 711 | encode_time_layers.append(nn.Conv1d(time_last_size, layer_size, 1)) 712 | encode_time_layers.append(nn.ReLU()) 713 | time_last_size = layer_size 714 | 715 | # Final layer, go down to a single channel with no activation function. 716 | encode_time_layers.append(nn.Conv1d(time_last_size, 1, 1)) 717 | self.encode_time_layers = nn.Sequential(*encode_time_layers) 718 | 719 | # Fully connected layers to calculate the latent space parameters for the VAE. 
720 | encode_latent_layers = [] 721 | latent_last_size = last_size 722 | for layer_size in self.settings['encode_latent_prepool_architecture']: 723 | encode_latent_layers.append(nn.Conv1d(latent_last_size, layer_size, 1)) 724 | encode_latent_layers.append(nn.ReLU()) 725 | latent_last_size = layer_size 726 | 727 | # Apply a global max pooling over the time channels. 728 | encode_latent_layers.append(GlobalMaxPoolingTime()) 729 | 730 | # Apply fully connected layers to get the embedding. 731 | for layer_size in self.settings['encode_latent_postpool_architecture']: 732 | encode_latent_layers.append(nn.Linear(latent_last_size, layer_size)) 733 | encode_latent_layers.append(nn.ReLU()) 734 | latent_last_size = layer_size 735 | 736 | self.encode_latent_layers = nn.Sequential(*encode_latent_layers) 737 | 738 | # Finally, use a last FC layer to get mu and logvar 739 | mu_size = self.settings['latent_size'] + 1 740 | logvar_size = self.settings['latent_size'] + 2 741 | 742 | if self.settings['predict_redshift']: 743 | # Predict the redshift 744 | mu_size += 1 745 | logvar_size += 1 746 | 747 | self.encode_mu_layer = nn.Linear(latent_last_size, mu_size) 748 | self.encode_logvar_layer = nn.Linear(latent_last_size, logvar_size) 749 | 750 | # MLP decoder. We start with an input that is the intrinsic latent space + one 751 | # dimension for time, and output a spectrum of size 752 | # self.settings['spectrum_bins']. We also have hidden layers with sizes given 753 | # by self.settings['decode_layers']. We implement this using a Conv1D layer 754 | # with a kernel size of 1 for computational reasons so that it decodes multiple 755 | # spectra for each transient all at the same time, but the decodes are all done 756 | # independently so this is really an MLP. 
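The decoder comment above notes that a `Conv1d` with kernel size 1 is "really an MLP": the same linear map is applied independently at every time/phase position. A minimal numpy sketch of that equivalence, with hypothetical shapes and random weights:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical channel counts and sequence length.
in_channels, out_channels, length = 4, 3, 10
W = rng.normal(size=(out_channels, in_channels))  # stand-in for the layer weights
x = rng.normal(size=(in_channels, length))

# A kernel-size-1 convolution over the whole sequence is one matrix multiply...
conv_out = W @ x

# ...and matches applying the same linear map to each position separately.
mlp_out = np.stack([W @ x[:, i] for i in range(length)], axis=1)
assert np.allclose(conv_out, mlp_out)
```

This is why the decoder can evaluate many phases per transient in one pass while each spectrum is still decoded independently.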
757 | decode_last_size = self.settings['latent_size'] + 1 758 | decode_layers = [] 759 | for layer_size in self.settings['decode_architecture']: 760 | decode_layers.append(nn.Conv1d(decode_last_size, layer_size, 1)) 761 | decode_layers.append(nn.Tanh()) 762 | decode_last_size = layer_size 763 | 764 | # Final layer. Use a FC layer to get us to the correct number of bins, and use 765 | # a softplus activation function to get positive flux. 766 | decode_layers.append(nn.Conv1d(decode_last_size, 767 | self.settings['spectrum_bins'], 1)) 768 | decode_layers.append(nn.Softplus()) 769 | 770 | self.decode_layers = nn.Sequential(*decode_layers) 771 | 772 | def get_data_loader(self, dataset, augment=False, **kwargs): 773 | """Get a PyTorch DataLoader for an lcdata Dataset 774 | 775 | Parameters 776 | ---------- 777 | dataset : `~lcdata.Dataset` 778 | Dataset to load 779 | augment : bool, optional 780 | Whether to augment the dataset, by default False 781 | 782 | Returns 783 | ------- 784 | `~torch.utils.data.DataLoader` 785 | PyTorch DataLoader for the dataset 786 | """ 787 | # Preprocess the dataset if it isn't already. 788 | dataset = self.preprocess(dataset) 789 | 790 | if augment: 791 | # Reset the metadata caches that we use to speed up augmenting. 792 | for lc in dataset.light_curves: 793 | lc.meta.copy(update_cache=True) 794 | 795 | # To speed things up, don't create new astropy.Table objects for the 796 | # augmented light curves. The `forward` method can handle the result that 797 | # is returned by `augment_light_curves`. 
798 | collate_fn = functools.partial(self.augment_light_curves, as_table=False) 799 | else: 800 | collate_fn = list 801 | 802 | return DataLoader(dataset.light_curves, batch_size=self.settings['batch_size'], 803 | collate_fn=collate_fn, **kwargs) 804 | 805 | def encode(self, input_data): 806 | """Predict the latent variables for a set of light curves 807 | 808 | We use variational inference, and predict the parameters of a posterior 809 | distribution over the latent space. 810 | 811 | Parameters 812 | ---------- 813 | input_data : `~torch.FloatTensor` 814 | Input data representing a set of gridded light curves 815 | 816 | Returns 817 | ------- 818 | `~torch.FloatTensor` 819 | Mean predictions for each latent variable 820 | `~torch.FloatTensor` 821 | Log-variance predictions for each latent variable 822 | """ 823 | # Apply common encoder blocks 824 | e = self.encode_layers(input_data) 825 | 826 | # Reference time branch. First, apply additional FC layers to get to an output 827 | # that has a single channel. 828 | e_time = self.encode_time_layers(e) 829 | 830 | # Apply the time-indexing layer to calculate the reference time. This is a 831 | # special layer that is invariant to translations of the input. 832 | t_vec = torch.nn.functional.softmax(torch.squeeze(e_time, 1), dim=1) 833 | ref_time_mu = ( 834 | torch.sum(t_vec * self.input_times, 1) 835 | / self.settings['time_sigma'] 836 | ) 837 | 838 | # Latent space branch. 839 | e_latent = self.encode_latent_layers(e) 840 | 841 | # Predict mu and logvar 842 | encoding_mu = self.encode_mu_layer(e_latent) 843 | encoding_logvar = self.encode_logvar_layer(e_latent) 844 | 845 | # Prepend the time mu value to get the full encoding. 846 | encoding_mu = torch.cat([ref_time_mu[:, None], encoding_mu], 1) 847 | 848 | # Constrain the logvar so that it doesn't go to crazy values and throw 849 | # everything off with floating point precision errors. 
This will not be a 850 | # concern for a properly trained model, but things can go wrong early in the 851 | # training at high learning rates. 852 | encoding_logvar = torch.clamp(encoding_logvar, None, 5.) 853 | 854 | return encoding_mu, encoding_logvar 855 | 856 | def decode_spectra(self, encoding, phases, color, amplitude=None): 857 | """Predict the spectra at a given set of latent variables 858 | 859 | Parameters 860 | ---------- 861 | encoding : `~torch.FloatTensor` 862 | Coordinates in the ParSNIP intrinsic latent space for each light curve 863 | phases : `~torch.FloatTensor` 864 | Phases to decode each light curve at 865 | color : `~torch.FloatTensor` 866 | Color of each light curve 867 | amplitude : `~torch.FloatTensor`, optional 868 | Amplitude to scale each light curve by, by default no scaling will be 869 | applied. 870 | 871 | Returns 872 | ------- 873 | `~torch.FloatTensor` 874 | Predicted spectra 875 | """ 876 | scale_phases = phases / (self.settings['time_window'] // 2) 877 | 878 | repeat_encoding = encoding[:, :, None].expand((-1, -1, scale_phases.shape[1])) 879 | stack_encoding = torch.cat([repeat_encoding, scale_phases[:, None, :]], 1) 880 | 881 | # Apply intrinsic decoder 882 | model_spectra = self.decode_layers(stack_encoding) 883 | 884 | if color is not None: 885 | # Apply colors 886 | apply_colors = 10**(-0.4 * color[:, None] * self.color_law[None, :]) 887 | model_spectra = model_spectra * apply_colors[..., None] 888 | 889 | if amplitude is not None: 890 | # Apply amplitude 891 | model_spectra = model_spectra * amplitude[:, None, None] 892 | 893 | return model_spectra 894 | 895 | def decode(self, encoding, ref_times, color, times, redshifts, band_indices, 896 | amplitude=None): 897 | """Predict the light curves for a given set of latent variables 898 | 899 | Parameters 900 | ---------- 901 | encoding : `~torch.FloatTensor` 902 | Coordinates in the ParSNIP intrinsic latent space for each light curve 903 | ref_times : `~torch.FloatTensor` 904 | 
Reference time for each light curve 905 | color : `~torch.FloatTensor` 906 | Color of each light curve 907 | times : `~torch.FloatTensor` 908 | Times to predict each light curve at 909 | redshifts : `~torch.FloatTensor` 910 | Redshift of each light curve 911 | band_indices : `~torch.LongTensor` 912 | Band indices for each observation 913 | amplitude : `~torch.FloatTensor`, optional 914 | Amplitude to scale each light curve by, by default no scaling will be 915 | applied 916 | 917 | Returns 918 | ------- 919 | `~torch.FloatTensor` 920 | Model spectra 921 | `~torch.FloatTensor` 922 | Model photometry 923 | """ 924 | phases = ( 925 | (times - ref_times[:, None]) 926 | / (1 + redshifts[:, None]) 927 | ) 928 | 929 | # Generate the restframe spectra 930 | model_spectra = self.decode_spectra(encoding, phases, color, amplitude) 931 | 932 | # Figure out the weights for each band 933 | band_weights = self._calculate_band_weights(redshifts) 934 | num_batches = band_indices.shape[0] 935 | num_observations = band_indices.shape[1] 936 | batch_indices = ( 937 | torch.arange(num_batches, device=encoding.device) 938 | .repeat_interleave(num_observations) 939 | ) 940 | obs_band_weights = ( 941 | band_weights[batch_indices, :, band_indices.flatten()] 942 | .reshape((num_batches, num_observations, -1)) 943 | .permute(0, 2, 1) 944 | ) 945 | 946 | # Sum over each filter. 
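The band-weight gather above uses paired integer index arrays to pick, for every observation, the weight column of its band. A numpy sketch of the same advanced-indexing pattern with hypothetical shapes:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical shapes: band_weights is (batch, spectrum_bins, n_bands) and
# band_indices is (batch, n_observations).
batch, bins, n_bands, n_obs = 2, 5, 3, 4
band_weights = rng.normal(size=(batch, bins, n_bands))
band_indices = rng.integers(0, n_bands, size=(batch, n_obs))

# Pair each observation with its batch index, gather its band's weight
# column, then restore the (batch, bins, n_obs) layout.
batch_indices = np.repeat(np.arange(batch), n_obs)
obs_band_weights = (
    band_weights[batch_indices, :, band_indices.flatten()]
    .reshape(batch, n_obs, bins)
    .transpose(0, 2, 1)
)

# Spot-check one entry against a direct lookup.
assert np.allclose(obs_band_weights[1, :, 2],
                   band_weights[1, :, band_indices[1, 2]])
```

Note that when integer index arrays are separated by a slice, numpy (like torch) moves the indexed axes to the front, which is why the reshape/transpose is needed afterwards.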
947 | model_flux = torch.sum(model_spectra * obs_band_weights, axis=1) 948 | 949 | return model_spectra, model_flux 950 | 951 | def _reparameterize(self, mu, logvar, sample=True): 952 | if sample: 953 | std = torch.exp(0.5*logvar) 954 | eps = torch.randn_like(std) 955 | return mu + eps*std 956 | else: 957 | return mu 958 | 959 | def _sample(self, encoding_mu, encoding_logvar, sample=True): 960 | sample_encoding = self._reparameterize(encoding_mu, encoding_logvar, 961 | sample=sample) 962 | 963 | time_sigma = self.settings['time_sigma'] 964 | color_sigma = self.settings['color_sigma'] 965 | 966 | if self.settings['predict_redshift']: 967 | redshift = torch.exp(sample_encoding[:, -1] - 1) 968 | sample_encoding = sample_encoding[:, :-1] 969 | else: 970 | redshift = torch.zeros_like(sample_encoding[:, 0]) 971 | 972 | # Rescale variables 973 | ref_times = sample_encoding[:, 0] * time_sigma 974 | color = sample_encoding[:, 1] * color_sigma 975 | encoding = sample_encoding[:, 2:] 976 | 977 | # Constrain the color and reference time so that things don't go to crazy values 978 | # and throw everything off with floating point precision errors. This will not 979 | # be a concern for a properly trained model, but things can go wrong early in 980 | # the training at high learning rates. 981 | ref_times = torch.clamp(ref_times, -10. * time_sigma, 10. * time_sigma) 982 | color = torch.clamp(color, -10. * color_sigma, 10. * color_sigma) 983 | redshift = torch.clamp(redshift, 0., self.settings['max_redshift']) 984 | 985 | return redshift, ref_times, color, encoding 986 | 987 | def forward(self, light_curves, sample=True, to_numpy=False): 988 | """Run a set of light curves through the full ParSNIP model 989 | 990 | We use variational inference to predict the latent representation of each light 991 | curve, and we then use the generative model to predict the light curves for 992 | those representations. 
993 | 994 | Parameters 995 | ---------- 996 | light_curves : List[`~astropy.table.Table`] 997 | List of light curves 998 | sample : bool, optional 999 | If True (default), sample from the posterior distribution. If False, use the 1000 | MAP. 1001 | to_numpy : bool, optional 1002 | Whether to convert the outputs to numpy arrays, by default False 1003 | 1004 | Returns 1005 | ------- 1006 | dict 1007 | Result dictionary. If to_numpy is True, all of the elements will be numpy 1008 | arrays. Otherwise, they will be PyTorch tensors on the model's device. 1009 | """ 1010 | # Extract the data that we need and move it to the right device. 1011 | data = self._get_data(light_curves) 1012 | 1013 | # Encode the light curves. 1014 | encoding_mu, encoding_logvar = self.encode(data['input_data']) 1015 | 1016 | # Sample from the latent space. 1017 | predicted_redshifts, ref_times, color, encoding = self._sample( 1018 | encoding_mu, encoding_logvar, sample=sample 1019 | ) 1020 | 1021 | if self.settings['predict_redshift']: 1022 | use_redshifts = predicted_redshifts 1023 | else: 1024 | use_redshifts = data['redshift'] 1025 | 1026 | time = data['compare_data'][:, 0] 1027 | obs_flux = data['compare_data'][:, 1] 1028 | obs_fluxerr = data['compare_data'][:, 2] 1029 | obs_weight = data['compare_data'][:, 3] 1030 | 1031 | # Decode the light curves 1032 | model_spectra, model_flux = self.decode( 1033 | encoding, ref_times, color, time, use_redshifts, data['band_indices'] 1034 | ) 1035 | 1036 | # Analytically evaluate the conditional distribution for the amplitude and 1037 | # sample from it. 
1038 | amplitude_mu, amplitude_logvar = self._compute_amplitude(obs_weight, model_flux, 1039 | obs_flux) 1040 | amplitude = self._reparameterize(amplitude_mu, amplitude_logvar, sample=sample) 1041 | model_flux = model_flux * amplitude[:, None] 1042 | model_spectra = model_spectra * amplitude[:, None, None] 1043 | 1044 | result = { 1045 | 'ref_times': ref_times, 1046 | 'color': color, 1047 | 'encoding': encoding, 1048 | 'amplitude': amplitude, 1049 | 'redshift': data['redshift'], 1050 | 'predicted_redshift': predicted_redshifts, 1051 | 'time': time, 1052 | 'obs_flux': obs_flux, 1053 | 'obs_fluxerr': obs_fluxerr, 1054 | 'obs_weight': obs_weight, 1055 | 'band_indices': data['band_indices'], 1056 | 'model_flux': model_flux, 1057 | 'model_spectra': model_spectra, 1058 | 'encoding_mu': encoding_mu, 1059 | 'encoding_logvar': encoding_logvar, 1060 | 'amplitude_mu': amplitude_mu, 1061 | 'amplitude_logvar': amplitude_logvar, 1062 | } 1063 | 1064 | if self.settings['predict_redshift']: 1065 | result['photoz'] = data['photoz'] 1066 | result['photoz_error'] = data['photoz_error'] 1067 | 1068 | if to_numpy: 1069 | result = {k: v.detach().cpu().numpy() for k, v in result.items()} 1070 | 1071 | return result 1072 | 1073 | def _compute_amplitude(self, weight, model_flux, flux): 1074 | num = torch.sum(weight * model_flux * flux, axis=1) 1075 | denom = torch.sum(weight * model_flux * model_flux, axis=1) 1076 | 1077 | # With augmentation, can very rarely end up with no light curve points. Handle 1078 | # that gracefully by setting the amplitude to 0 with a very large uncertainty. 1079 | denom[denom == 0.] = 1e-5 1080 | 1081 | amplitude_mu = num / denom 1082 | amplitude_logvar = torch.log(1. 
/ denom) 1083 | 1084 | return amplitude_mu, amplitude_logvar 1085 | 1086 | def loss_function(self, result, return_components=False, return_individual=False): 1087 | """Compute the loss function for a set of light curves 1088 | 1089 | Parameters 1090 | ---------- 1091 | result : dict 1092 | Output of `~ParsnipModel.forward` 1093 | return_components : bool, optional 1094 | Whether to return the individual parts of the loss function, by default 1095 | False. 1096 | return_individual : bool, optional 1097 | Whether to return the loss function for each light curve individually, by 1098 | default False. 1099 | 1100 | Returns 1101 | ------- 1102 | float or `~torch.FloatTensor` 1103 | If return_components and return_individual are False, return a single value 1104 | representing the loss function for a set of light curves. 1105 | If return_components is True, then we return a set of four values 1106 | representing the negative log likelihood, the KL divergence, the 1107 | regularization penalty, and the amplitude probability. 1108 | If return_individual is True, then we return the loss function for each 1109 | light curve individually. 
1110 | """ 1111 | # Reconstruction likelihood 1112 | nll = (0.5 * result['obs_weight'] 1113 | * (result['obs_flux'] - result['model_flux'])**2) 1114 | 1115 | # KL divergence 1116 | kld = -0.5 * (1 + result['encoding_logvar'] 1117 | - result['encoding_mu'].pow(2) 1118 | - result['encoding_logvar'].exp()) 1119 | 1120 | # Regularization of spectra 1121 | diff = ( 1122 | (result['model_spectra'][:, 1:, :] - result['model_spectra'][:, :-1, :]) 1123 | / (result['model_spectra'][:, 1:, :] + result['model_spectra'][:, :-1, :]) 1124 | ) 1125 | penalty = self.settings['penalty'] * diff**2 1126 | 1127 | # Amplitude probability for the importance sampling integral 1128 | amp_prob = -0.5 * ((result['amplitude'] - result['amplitude_mu'])**2 1129 | / result['amplitude_logvar'].exp()) 1130 | 1131 | # Redshift error 1132 | if self.settings['predict_redshift']: 1133 | # Prior from photoz estimate 1134 | photoz_diff = result['predicted_redshift'] - result['photoz'] 1135 | redshift_nll = 0.5 * photoz_diff**2 / result['photoz_error']**2 1136 | 1137 | # Prior from true redshift 1138 | mask = ~torch.isnan(result['redshift']) 1139 | diff_redshifts = (result['predicted_redshift'][mask] 1140 | - result['redshift'][mask]) 1141 | redshift_nll[mask] += ( 1142 | 0.5 * diff_redshifts**2 / self.settings['specz_error']**2 1143 | ) 1144 | else: 1145 | redshift_nll = torch.zeros_like(amp_prob) 1146 | 1147 | if return_individual: 1148 | nll = torch.sum(nll, axis=1) 1149 | kld = torch.sum(kld, axis=1) 1150 | penalty = torch.sum(torch.sum(penalty, axis=2), axis=1) 1151 | else: 1152 | nll = torch.sum(nll) 1153 | kld = torch.sum(kld) 1154 | penalty = torch.sum(penalty) 1155 | amp_prob = torch.sum(amp_prob) 1156 | redshift_nll = torch.sum(redshift_nll) 1157 | 1158 | if return_components: 1159 | return torch.stack([nll, kld, penalty, amp_prob, redshift_nll]) 1160 | else: 1161 | return nll + kld + penalty + amp_prob + redshift_nll 1162 | 1163 | def score(self, dataset, rounds=1, return_components=False, 
sample=True): 1164 | """Evaluate the loss function on a given dataset. 1165 | 1166 | Parameters 1167 | ---------- 1168 | dataset : `~lcdata.Dataset` 1169 | Dataset to run on 1170 | rounds : int, optional 1171 | Number of rounds to use for evaluation. VAEs are stochastic, so the loss 1172 | function is not deterministic. By running for multiple rounds, the 1173 | uncertainty on the loss function can be decreased. Default 1. 1174 | return_components : bool, optional 1175 | Whether to return the individual parts of the loss function, by default 1176 | False. See `~ParsnipModel.loss_function` for details. 1177 | 1178 | Returns 1179 | ------- 1180 | loss 1181 | Computed loss function 1182 | """ 1183 | self.eval() 1184 | 1185 | total_loss = 0 1186 | total_count = 0 1187 | 1188 | loader = self.get_data_loader(dataset) 1189 | 1190 | # Compute the loss 1191 | for round in range(rounds): 1192 | for batch_lcs in loader: 1193 | result = self.forward(batch_lcs, sample=sample) 1194 | loss = self.loss_function(result, return_components) 1195 | 1196 | if return_components: 1197 | total_loss += loss.detach().cpu().numpy() 1198 | else: 1199 | total_loss += loss.item() 1200 | total_count += len(batch_lcs) 1201 | 1202 | loss = total_loss / total_count 1203 | 1204 | return loss 1205 | 1206 | def fit(self, dataset, max_epochs=1000, augment=True, test_dataset=None): 1207 | """Fit the model to a dataset 1208 | 1209 | Parameters 1210 | ---------- 1211 | dataset : `~lcdata.Dataset` 1212 | Dataset to fit to 1213 | max_epochs : int, optional 1214 | Maximum number of epochs, by default 1000 1215 | augment : bool, optional 1216 | Whether to use augmentation, by default True 1217 | test_dataset : `~lcdata.Dataset`, optional 1218 | Test dataset that will be scored at the end of each epoch, by default None 1219 | """ 1220 | # The model is stochastic, so the loss function will have a fair bit of noise. 
1221 | # If the dataset is small, we run through several augmentations of it every 1222 | # epoch to get the noise down. 1223 | repeats = int(np.ceil(25000 / len(dataset))) 1224 | 1225 | loader = self.get_data_loader(dataset, augment=augment, shuffle=True) 1226 | 1227 | if test_dataset is not None: 1228 | test_dataset = self.preprocess(test_dataset) 1229 | 1230 | while self.epoch < max_epochs: 1231 | self.train() 1232 | train_loss = 0 1233 | train_count = 0 1234 | 1235 | with tqdm(range(len(loader) * repeats), file=sys.stdout) as pbar: 1236 | for repeat in range(repeats): 1237 | # Training step 1238 | for batch_lcs in loader: 1239 | self.optimizer.zero_grad() 1240 | result = self.forward(batch_lcs) 1241 | 1242 | loss = self.loss_function(result) 1243 | 1244 | loss.backward() 1245 | replace_nan_grads(self.parameters()) 1246 | train_loss += loss.item() 1247 | self.optimizer.step() 1248 | 1249 | train_count += len(batch_lcs) 1250 | 1251 | total_loss = train_loss / train_count 1252 | batch_loss = loss.item() / len(batch_lcs) 1253 | 1254 | pbar.set_description( 1255 | f'Epoch {self.epoch:4d}: Loss: {total_loss:8.4f} ' 1256 | f'({batch_loss:8.4f})', 1257 | refresh=False 1258 | ) 1259 | pbar.update() 1260 | 1261 | if test_dataset is not None: 1262 | # Calculate the test loss 1263 | test_loss = self.score(test_dataset) 1264 | pbar.set_description( 1265 | f'Epoch {self.epoch:4d}: Loss: {total_loss:8.4f}, ' 1266 | f'Test loss: {test_loss:8.4f}', 1267 | ) 1268 | else: 1269 | pbar.set_description( 1270 | f'Epoch {self.epoch:4d}: Loss: {total_loss:8.4f}' 1271 | ) 1272 | 1273 | self.scheduler.step(train_loss) 1274 | 1275 | # Checkpoint and save the model 1276 | self.save() 1277 | 1278 | # Check if the learning rate is below our threshold, and exit if it is. 
1279 | lr = self.optimizer.param_groups[0]['lr'] 1280 | if lr < self.settings['min_learning_rate']: 1281 | break 1282 | 1283 | self.epoch += 1 1284 | 1285 | def predict(self, light_curves, augment=False): 1286 | """Generate predictions for a light curve or set of light curves. 1287 | 1288 | Parameters 1289 | ---------- 1290 | light_curves : `~astropy.table.Table` or List[`~astropy.table.Table`] 1291 | Light curve(s) to generate predictions for. 1292 | augment : bool, optional 1293 | Whether to augment the light curve(s), by default False 1294 | 1295 | Returns 1296 | ------- 1297 | `~astropy.table.Table` or dict 1298 | Table (for multiple light curves) or dict (for a single light curve) 1299 | containing the predictions. 1300 | """ 1301 | # Check if we have a list of light curves or a single one and handle it 1302 | # appropriately. 1303 | if isinstance(light_curves, astropy.table.Table): 1304 | # Single object. Wrap it so that we can process it as an array. We'll unwrap 1305 | # it at the end. 1306 | single = True 1307 | light_curves = [light_curves] 1308 | else: 1309 | single = False 1310 | 1311 | # Wrap the light curves in an lcdata dataset and use that to process them. 1312 | dataset = lcdata.from_light_curves(light_curves) 1313 | predictions = self.predict_dataset(dataset, augment=augment) 1314 | 1315 | if single: 1316 | return dict(zip(predictions[0].keys(), predictions[0].values())) 1317 | else: 1318 | return predictions 1319 | 1320 | def predict_dataset(self, dataset, augment=False): 1321 | """Generate predictions for a dataset 1322 | 1323 | Parameters 1324 | ---------- 1325 | dataset : `~lcdata.Dataset` 1326 | Dataset to generate predictions for. 1327 | augment : bool, optional 1328 | Whether to perform augmentation, False by default. 1329 | 1330 | Returns 1331 | ------- 1332 | predictions : `~astropy.table.Table` 1333 | astropy Table with one row for each light curve and columns with each of the 1334 | predicted values. 
1335 | """ 1336 | predictions = [] 1337 | 1338 | dataset = self.preprocess(dataset, verbose=len(dataset) > 100) 1339 | loader = self.get_data_loader(dataset, augment=augment) 1340 | 1341 | for batch_lcs in loader: 1342 | # Run the data through the model. 1343 | result = self.forward(batch_lcs, to_numpy=True, sample=False) 1344 | 1345 | # Pull out the reference time and reference scale. Note that if we are 1346 | # working with an augmented dataset, get_data_loader doesn't construct a 1347 | # full astropy Table to save time. Handle either case. 1348 | parsnip_reference_time = [] 1349 | parsnip_scale = [] 1350 | for lc in batch_lcs: 1351 | if isinstance(lc, astropy.table.Table): 1352 | lc_meta = lc.meta 1353 | else: 1354 | lc_data, lc_meta = lc 1355 | parsnip_reference_time.append(lc_meta['parsnip_reference_time']) 1356 | parsnip_scale.append(lc_meta['parsnip_scale']) 1357 | parsnip_reference_time = np.array(parsnip_reference_time) 1358 | parsnip_scale = np.array(parsnip_scale) 1359 | 1360 | encoding_mu = result['encoding_mu'] 1361 | encoding_err = np.sqrt(np.exp(result['encoding_logvar'])) 1362 | 1363 | # Update the reference time. 1364 | reference_time_offset = ( 1365 | encoding_mu[:, 0] * self.settings['time_sigma'] / SIDEREAL_SCALE 1366 | ) 1367 | reference_time = parsnip_reference_time + reference_time_offset 1368 | reference_time_error = ( 1369 | encoding_err[:, 0] * self.settings['time_sigma'] / SIDEREAL_SCALE 1370 | ) 1371 | 1372 | amplitude_mu = result['amplitude_mu'] * parsnip_scale 1373 | amplitude_error = ( 1374 | np.sqrt(np.exp(result['amplitude_logvar'])) * parsnip_scale 1375 | ) 1376 | 1377 | # Pull out the keys that we care about saving. 
1378 | batch_predictions = { 1379 | 'reference_time': reference_time, 1380 | 'reference_time_error': reference_time_error, 1381 | 'color': encoding_mu[:, 1] * self.settings['color_sigma'], 1382 | 'color_error': encoding_err[:, 1] * self.settings['color_sigma'], 1383 | 'amplitude': amplitude_mu, 1384 | 'amplitude_error': amplitude_error, 1385 | } 1386 | 1387 | for idx in range(self.settings['latent_size']): 1388 | batch_predictions[f's{idx+1}'] = encoding_mu[:, 2 + idx] 1389 | batch_predictions[f's{idx+1}_error'] = encoding_err[:, 2 + idx] 1390 | 1391 | if self.settings['predict_redshift']: 1392 | pred_redshift = np.clip( 1393 | np.exp(encoding_mu[:, -1] - 1), 1394 | 0, self.settings['max_redshift'] 1395 | ) 1396 | pred_redshift_pos = np.exp(encoding_mu[:, -1] + encoding_err[:, -1] - 1) 1397 | pred_redshift_neg = np.exp(encoding_mu[:, -1] - encoding_err[:, -1] - 1) 1398 | pred_redshift_error = (pred_redshift_pos - pred_redshift_neg) / 2. 1399 | batch_predictions['predicted_redshift'] = pred_redshift 1400 | batch_predictions['predicted_redshift_error'] = pred_redshift_error 1401 | 1402 | # Calculate other useful features. 1403 | time = result['time'] 1404 | obs_flux = result['obs_flux'] 1405 | obs_fluxerr = result['obs_fluxerr'] 1406 | model_flux = result['model_flux'] 1407 | fluxerr_mask = obs_fluxerr == 0 1408 | obs_fluxerr[fluxerr_mask] = -1. 1409 | 1410 | # Signal-to-noise 1411 | s2n = obs_flux / obs_fluxerr 1412 | s2n[fluxerr_mask] = 0. 1413 | batch_predictions['total_s2n'] = np.sqrt(np.sum(s2n**2, axis=1)) 1414 | 1415 | # Number of observations 1416 | batch_predictions['count'] = np.sum(~fluxerr_mask, axis=1) 1417 | 1418 | # Number of observations with signal-to-noise above some threshold. 1419 | batch_predictions['count_s2n_3'] = np.sum(s2n > 3, axis=1) 1420 | batch_predictions['count_s2n_5'] = np.sum(s2n > 5, axis=1) 1421 | 1422 | # Number of observations with signal-to-noise above some threshold in 1423 | # different time windows. 
1424 | compare_time = reference_time_offset[:, None] 1425 | mask_pre = time < compare_time - 50. 1426 | mask_rise = (time >= compare_time - 50.) & (time < compare_time) 1427 | mask_fall = (time >= compare_time) & (time < compare_time + 50.) 1428 | mask_post = (time >= compare_time + 50.) 1429 | mask_s2n = s2n > 3 1430 | batch_predictions['count_s2n_3_pre'] = np.sum(mask_pre & mask_s2n, axis=1) 1431 | batch_predictions['count_s2n_3_rise'] = np.sum(mask_rise & mask_s2n, axis=1) 1432 | batch_predictions['count_s2n_3_fall'] = np.sum(mask_fall & mask_s2n, axis=1) 1433 | batch_predictions['count_s2n_3_post'] = np.sum(mask_post & mask_s2n, axis=1) 1434 | 1435 | # Chi-square 1436 | all_chisq = (obs_flux - model_flux)**2 / obs_fluxerr**2 1437 | all_chisq[fluxerr_mask] = 0. 1438 | batch_predictions['model_chisq'] = np.sum(all_chisq, axis=1) 1439 | batch_predictions['model_dof'] = ( 1440 | batch_predictions['count'] 1441 | - 3 # amplitude, color, reference time 1442 | - self.settings['latent_size'] 1443 | ) 1444 | 1445 | predictions.append(astropy.table.Table(batch_predictions)) 1446 | 1447 | predictions = astropy.table.vstack(predictions, 'exact') 1448 | 1449 | # Drop any old predictions from the metadata, and merge it in. 1450 | meta = dataset.meta.copy() 1451 | common_columns = set(predictions.colnames) & set(meta.colnames) 1452 | meta.remove_columns(common_columns) 1453 | predictions = astropy.table.hstack([meta, predictions], 'exact') 1454 | 1455 | # Estimate the absolute luminosity. 1456 | # Figure out which light curves we can calculate the luminosity for. 1457 | amplitudes = predictions['amplitude'].copy() 1458 | amplitude_mask = amplitudes > 0. 1459 | if self.settings['predict_redshift']: 1460 | redshifts = predictions['predicted_redshift'].copy() 1461 | else: 1462 | redshifts = predictions['redshift'].copy() 1463 | redshift_mask = redshifts > 0. 
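The per-object summary features computed above (total signal-to-noise, counts, chi-square) reduce to a few masked array reductions. A minimal NumPy sketch of the same arithmetic on toy values (all numbers here are hypothetical, not from a real light curve):

```python
import numpy as np

# Toy observed and model fluxes for a single light curve.
obs_flux = np.array([1.0, 5.0, 10.0, 4.0])
obs_fluxerr = np.array([1.0, 1.0, 2.0, 0.0])  # last point has no uncertainty
model_flux = np.array([1.5, 5.5, 9.0, 4.0])

# Mask observations with no flux uncertainty so they drop out of the sums,
# mirroring the fluxerr_mask handling in predict_dataset.
fluxerr_mask = obs_fluxerr == 0.
obs_fluxerr[fluxerr_mask] = -1.

# Total signal-to-noise.
s2n = obs_flux / obs_fluxerr
s2n[fluxerr_mask] = 0.
total_s2n = np.sqrt(np.sum(s2n**2))

# Chi-square of the model against the observations.
all_chisq = (obs_flux - model_flux)**2 / obs_fluxerr**2
all_chisq[fluxerr_mask] = 0.
model_chisq = np.sum(all_chisq)
```

In the real code these reductions run with `axis=1` so that every light curve in the batch gets its own feature values.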
1464 | amplitude_error_mask = predictions['amplitude_error'] < 0.5 * amplitudes 1465 | luminosity_mask = amplitude_mask & redshift_mask & amplitude_error_mask 1466 | 1467 | # Mask out invalid data for luminosities 1468 | redshifts[~luminosity_mask] = 1. 1469 | amplitudes[~luminosity_mask] = 1. 1470 | frac_diff = predictions['amplitude_error'] / amplitudes 1471 | frac_diff[~luminosity_mask] = 0.5 1472 | 1473 | luminosity = ( 1474 | -2.5*np.log10(amplitudes) 1475 | + self.settings['zeropoint'] 1476 | - Planck18.distmod(redshifts).value 1477 | ) 1478 | luminosity[~luminosity_mask] = np.nan 1479 | predictions['luminosity'] = luminosity 1480 | 1481 | # Luminosity uncertainty 1482 | int_mag_err = frac_to_mag(frac_diff) 1483 | int_mag_err[~luminosity_mask] = np.nan 1484 | predictions['luminosity_error'] = int_mag_err 1485 | 1486 | # Remove the processing flag. 1487 | del predictions['parsnip_preprocessed'] 1488 | 1489 | return predictions 1490 | 1491 | def predict_dataset_augmented(self, dataset, augments=10): 1492 | """Generate predictions for a dataset with augmentation 1493 | 1494 | This will first generate predictions for the dataset without augmentation, 1495 | and will then generate predictions for the dataset with augmentation the 1496 | given number of times. This returns a dataframe in the same format as 1497 | `~predict_dataset`, but with the following additional columns: 1498 | - original_object_id: the original object_id for each augmentation. 1499 | - augmented: True for augmented light curves, False for original ones. 1500 | 1501 | Parameters 1502 | ---------- 1503 | dataset : `~lcdata.Dataset` 1504 | Dataset to generate predictions for. 1505 | augments : int, optional 1506 | Number of times to augment the dataset, by default 10 1507 | 1508 | Returns 1509 | ------- 1510 | predictions : `~astropy.table.Table` 1511 | astropy Table with one row for each light curve and columns with each of the 1512 | predicted values. 
1513 | """ 1514 | # First pass without augmentation. 1515 | pred = self.predict_dataset(dataset) 1516 | pred['original_object_id'] = pred['object_id'] 1517 | pred['augmented'] = False 1518 | 1519 | predictions = [pred] 1520 | 1521 | # Next passes with augmentation. 1522 | for idx in tqdm(range(augments), file=sys.stdout): 1523 | pred = self.predict_dataset(dataset, augment=True) 1524 | pred['original_object_id'] = pred['object_id'] 1525 | pred['augmented'] = True 1526 | pred['object_id'] = [i + f'_aug_{idx+1}' for i in pred['object_id']] 1527 | predictions.append(pred) 1528 | 1529 | predictions = astropy.table.vstack(predictions, 'exact') 1530 | return predictions 1531 | 1532 | def _predict_time_series(self, light_curve, pred_times, pred_bands, sample, count): 1533 | # Preprocess the light curve if it wasn't already. 1534 | light_curve = preprocess_light_curve(light_curve, self.settings) 1535 | 1536 | # Convert given times to our internal times. 1537 | grid_times = time_to_grid(pred_times, 1538 | light_curve.meta['parsnip_reference_time']) 1539 | 1540 | grid_times = torch.FloatTensor(grid_times)[None, :].to(self.device) 1541 | pred_bands = torch.LongTensor(pred_bands)[None, :].to(self.device) 1542 | 1543 | if count is not None: 1544 | # Predict multiple light curves 1545 | light_curves = [light_curve] * count 1546 | grid_times = grid_times.repeat(count, 1) 1547 | pred_bands = pred_bands.repeat(count, 1) 1548 | else: 1549 | light_curves = [light_curve] 1550 | 1551 | # Sample VAE parameters 1552 | result = self.forward(light_curves, sample) 1553 | 1554 | # Do the predictions 1555 | if self.settings['predict_redshift']: 1556 | redshifts = result['predicted_redshift'] 1557 | else: 1558 | redshifts = result['redshift'] 1559 | 1560 | model_spectra, model_flux = self.decode( 1561 | result['encoding'], 1562 | result['ref_times'], 1563 | result['color'], 1564 | grid_times, 1565 | redshifts, 1566 | pred_bands, 1567 | result['amplitude'], 1568 | ) 1569 | 1570 | model_flux = 
model_flux.cpu().detach().numpy() 1571 | model_spectra = model_spectra.cpu().detach().numpy() 1572 | 1573 | if count is None: 1574 | # Get rid of the batch index 1575 | model_flux = model_flux[0] 1576 | model_spectra = model_spectra[0] 1577 | 1578 | cpu_result = {k: v.detach().cpu().numpy() for k, v in result.items()} 1579 | 1580 | # Scale everything to the original light curve scale. 1581 | model_flux *= light_curve.meta['parsnip_scale'] 1582 | model_spectra *= light_curve.meta['parsnip_scale'] 1583 | 1584 | return model_flux, model_spectra, cpu_result 1585 | 1586 | def predict_light_curve(self, light_curve, sample=False, count=None, sampling=1., 1587 | pad=50.): 1588 | """Predict the flux of a light curve on a grid 1589 | 1590 | Parameters 1591 | ---------- 1592 | light_curve : `~astropy.table.Table` 1593 | Light curve to predict 1594 | sample : bool, optional 1595 | If True, sample from the latent variable posteriors. Otherwise, 1596 | use the MAP. By default False. 1597 | count : int, optional 1598 | Number of light curves to predict, by default None (single prediction) 1599 | sampling : int, optional 1600 | Grid sampling in days, by default 1. 1601 | pad : int, optional 1602 | Number of days before and after the light curve observations to predict the 1603 | light curve at, by default 50. 
1604 | 1605 | Returns 1606 | ------- 1607 | `~numpy.ndarray` 1608 | Times that the model was sampled at 1609 | `~numpy.ndarray` 1610 | Flux of the model in each band 1611 | `~numpy.ndarray` 1612 | Model result from ParsnipModel.forward 1613 | """ 1614 | # Figure out where to sample the light curve 1615 | min_time = np.min(light_curve['time']) - pad 1616 | max_time = np.max(light_curve['time']) + pad 1617 | model_times = np.arange(min_time, max_time + sampling, sampling) 1618 | 1619 | band_indices = np.arange(len(self.settings['bands'])) 1620 | 1621 | pred_times = np.tile(model_times, len(band_indices)) 1622 | pred_bands = np.repeat(band_indices, len(model_times)) 1623 | 1624 | model_flux, model_spectra, model_result = self._predict_time_series( 1625 | light_curve, pred_times, pred_bands, sample, count 1626 | ) 1627 | 1628 | # Reshape model_flux so that it has the shape (batch, band, time) 1629 | model_flux = model_flux.reshape((-1, len(band_indices), len(model_times))) 1630 | 1631 | if count is None: 1632 | # Get rid of the batch index for a single prediction, matching the 1633 | # count=None convention used by _predict_time_series. 1634 | model_flux = model_flux[0] 1635 | 1636 | return model_times, model_flux, model_result 1637 | 1638 | def predict_spectrum(self, light_curve, time, sample=False, count=None): 1639 | """Predict the spectrum of a light curve at a given time 1640 | 1641 | Parameters 1642 | ---------- 1643 | light_curve : `~astropy.table.Table` 1644 | Light curve 1645 | time : float 1646 | Time to predict the spectrum at 1647 | sample : bool, optional 1648 | If True, sample from the latent variable posteriors. Otherwise, 1649 | use the MAP. By default False.
1649 | count : int, optional 1650 | Number of spectra to predict, by default None (single prediction) 1651 | 1652 | Returns 1653 | ------- 1654 | `~numpy.ndarray` 1655 | Predicted spectrum at the wavelengths specified by 1656 | `~ParsnipModel.model_wave` 1657 | """ 1658 | pred_times = [time] 1659 | pred_bands = [0] 1660 | 1661 | model_flux, model_spectra, model_result = self._predict_time_series( 1662 | light_curve, pred_times, pred_bands, sample, count 1663 | ) 1664 | 1665 | return model_spectra[..., 0] 1666 | 1667 | def predict_sncosmo(self, light_curve, sample=False): 1668 | """Package the predictions for a light curve as an sncosmo model 1669 | 1670 | This method performs variational inference on a light curve to predict its 1671 | latent representation. It then initializes an SNCosmo model with that 1672 | representation. 1673 | 1674 | Parameters 1675 | ---------- 1676 | light_curve : `~astropy.table.Table` 1677 | Light curve 1678 | sample : bool, optional 1679 | If True, sample from the latent variable posteriors. Otherwise, 1680 | use the MAP. By default False. 1681 | 1682 | Returns 1683 | ------- 1684 | `~ParsnipSncosmoModel` 1685 | SNCosmo model initialized with the light curve's predicted latent 1686 | representation 1687 | """ 1688 | light_curve = preprocess_light_curve(light_curve, self.settings) 1689 | 1690 | # Run through the model to predict parameters. 1691 | result = self.forward([light_curve], sample=sample, to_numpy=True) 1692 | 1693 | # Build the sncosmo model. 
1694 | model = sncosmo.Model(source=ParsnipSncosmoSource(self)) 1695 | 1696 | meta = light_curve.meta 1697 | if self.settings['predict_redshift']: 1698 | model['z'] = result['predicted_redshift'][0] 1699 | else: 1700 | model['z'] = meta['redshift'] 1701 | 1702 | model['t0'] = grid_to_time(result['ref_times'][0], 1703 | meta['parsnip_reference_time']) 1704 | model['color'] = result['color'][0] 1705 | 1706 | # Note: ZP of amplitude is 25, and we use an internal offset of 20 for building 1707 | # the model so that things are close to 1. Combined, that means that we need to 1708 | # apply an offset of 45 mag when calculating the amplitude for sncosmo. 1709 | model['amplitude'] = ( 1710 | light_curve.meta['parsnip_scale'] * result['amplitude'][0] 1711 | * 10**(-0.4 * (20 + self.settings['zeropoint'])) 1712 | ) 1713 | 1714 | for i in range(self.settings['latent_size']): 1715 | model[f's{i+1}'] = result['encoding'][0, i] 1716 | 1717 | return model 1718 | 1719 | def predict_redshift_distribution(self, light_curve, min_redshift=0., 1720 | max_redshift=None, sampling=0.01): 1721 | """Predict the redshift distribution for a light curve. 1722 | 1723 | Given observations y, and latent variables s, we want to compute the redshift 1724 | distribution p(z|y) marginalized over the latent variables. Working this out 1725 | with Bayes' theorem:: 1726 | 1727 | p(z|y) = Integral[p(y|s,z) p(s,z) ds] / p(y) 1728 | 1729 | p(y|s,z) is the term that we compute as the negative log-likelihood in our loss 1730 | function. We assume that p(s,z) is constant. p(y) just contributes an overall 1731 | normalization term and can be ignored. 1732 | 1733 | The correct way to evaluate this function would be to perform a Monte Carlo 1734 | integration p(y|s,z) like we currently do to marginalize over the amplitude. 1735 | However, that procedure is stochastic and requires many computations to average 1736 | out. 
Here we instead approximate p(z|y) by simply evaluating p(y|s,z) at the MAP 1737 | value of the model parameters. This is not correct, but should provide a 1738 | reasonable approximation to the integral in most cases and only requires a 1739 | single evaluation per redshift. 1740 | 1741 | We evaluate this approximate redshift distribution on a prespecified grid of 1742 | redshifts and normalize so that the distribution sums to 1. 1743 | 1744 | Parameters 1745 | ---------- 1746 | light_curve : `~astropy.table.Table` 1747 | Light curve 1748 | min_redshift : float, optional 1749 | Minimum redshift to consider, by default 0. 1750 | max_redshift : float, optional 1751 | Maximum redshift to consider, by default specified by 1752 | settings['max_redshift']. 1753 | sampling : float, optional 1754 | Sampling to use, by default 0.01. 1755 | 1756 | Returns 1757 | ------- 1758 | `~numpy.ndarray` 1759 | Redshifts that the probability distribution was evaluated at 1760 | `~numpy.ndarray` 1761 | Redshift probability distribution 1762 | """ 1763 | if max_redshift is None: 1764 | max_redshift = self.settings['max_redshift'] 1765 | sample_redshifts = np.arange(min_redshift, max_redshift + sampling / 100., 1766 | sampling) 1767 | if min_redshift == 0: 1768 | # Having the redshift be equal to zero can cause problems. Handle this 1769 | # gracefully. 
1770 | sample_redshifts[0] = 0.0001 1771 | 1772 | light_curve = preprocess_light_curve(light_curve, self.settings, 1773 | ignore_missing_redshift=True) 1774 | 1775 | lcs = [] 1776 | for redshift in sample_redshifts: 1777 | redshift_lc = light_curve.copy() 1778 | redshift_lc.meta['redshift'] = redshift 1779 | lcs.append(redshift_lc) 1780 | 1781 | result = self.forward(lcs, sample=False) 1782 | nll = self.loss_function(result, return_individual=True, 1783 | return_components=True)[0] 1784 | nll = nll.detach().cpu().numpy() 1785 | 1786 | # Normalize the probability distribution 1787 | prob = np.exp(-(nll - np.min(nll))) 1788 | prob = prob / np.sum(prob) / sampling 1789 | 1790 | return sample_redshifts, prob 1791 | 1792 | def predict_redshift(self, light_curve, min_redshift=0., max_redshift=None, 1793 | sampling=0.01): 1794 | """Predict the redshift of a light curve. 1795 | 1796 | This evaluates the MAP estimate of the redshift. 1797 | 1798 | Parameters 1799 | ---------- 1800 | light_curve : `~astropy.table.Table` 1801 | Light curve 1802 | 1803 | Returns 1804 | ------- 1805 | float 1806 | MAP estimate of the redshift 1807 | """ 1808 | redshifts, redshift_distribution = self.predict_redshift_distribution( 1809 | light_curve, 1810 | min_redshift=min_redshift, 1811 | max_redshift=max_redshift, 1812 | sampling=sampling 1813 | ) 1814 | return redshifts[np.argmax(redshift_distribution)] 1815 | 1816 | 1817 | def load_model(path=None, device='cpu', threads=8): 1818 | """Load a ParSNIP model. 1819 | 1820 | Parameters 1821 | ---------- 1822 | path : str, optional 1823 | Path to the model on disk, or name of a model. If not specified, the 1824 | default_model specified in settings.py is loaded. 
1825 | device : str, optional 1826 | Torch device to load the model to, by default 'cpu' 1827 | threads : int, optional 1828 | Number of threads to use, by default 8 1829 | 1830 | Returns 1831 | ------- 1832 | `~ParsnipModel` 1833 | Loaded model 1834 | """ 1835 | if path is None: 1836 | path = default_model 1837 | print(f"Loading default ParSNIP model '{path}'") 1838 | 1839 | # Figure out if we were given the path to a model or a built-in model. 1840 | if '.' not in path: 1841 | # We were given the name of a built-in model. 1842 | resource_path = f'models/{path}.pt' 1843 | full_path = pkg_resources.resource_filename('parsnip', resource_path) 1844 | if not os.path.exists(full_path): 1845 | raise ValueError(f"No built-in model named '{path}'") 1846 | path = full_path 1847 | 1848 | # Load the model data 1849 | use_device = parse_device(device) 1850 | settings, state_dict = torch.load(path, use_device) 1851 | 1852 | # Instantiate the model 1853 | model = ParsnipModel(path, settings['bands'], use_device, threads, settings) 1854 | model.load_state_dict(state_dict) 1855 | 1856 | return model 1857 | -------------------------------------------------------------------------------- /parsnip/plotting.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from matplotlib.gridspec import GridSpec 6 | from sklearn.metrics import confusion_matrix 7 | 8 | from .classifier import extract_top_classifications 9 | from .instruments import (get_band_effective_wavelength, get_band_plot_color, 10 | get_band_plot_marker) 11 | from .light_curve import preprocess_light_curve 12 | 13 | 14 | def _get_reference_time(light_curve): 15 | if 'reference_time' in light_curve.meta: 16 | # Reference time calculated from running through the full ParSNIP model. 
17 | return light_curve.meta['reference_time'] 18 | elif 'parsnip_reference_time' in light_curve.meta: 19 | # Initial estimate of the reference time. 20 | return light_curve.meta['parsnip_reference_time'] 21 | else: 22 | # No estimate of the reference time. Just show the light curve as is. 23 | return 0. 24 | 25 | 26 | def plot_light_curve(light_curve, model=None, count=100, show_uncertainty_bands=True, 27 | show_missing_bandpasses=False, percentile=68, normalize_flux=False, 28 | sncosmo_model=None, sncosmo_label='SNCosmo Model', ax=None): 29 | """Plot a light curve 30 | 31 | Parameters 32 | ---------- 33 | light_curve : `~astropy.table.Table` 34 | Light curve to plot 35 | model : `~ParsnipModel`, optional 36 | ParSNIP model to show, by default None 37 | count : int, optional 38 | Number of samples from the ParSNIP model, by default 100 39 | show_uncertainty_bands : bool, optional 40 | If True (default), show uncertainty bands. Otherwise, show individual draws. 41 | show_missing_bandpasses : bool, optional 42 | Whether to show model predictions for bandpasses where there is no data, by 43 | default False 44 | percentile : int, optional 45 | Percentile for the uncertainty bands, by default 68 46 | normalize_flux : bool, optional 47 | Whether to normalize the flux, by default False 48 | sncosmo_model : `~sncosmo.Model`, optional 49 | SNCosmo model to show, by default None 50 | sncosmo_label : str, optional 51 | Legend label for the SNCosmo model, by default 'SNCosmo Model' 52 | ax : axis, optional 53 | Matplotlib axis to use for the plot, by default one will be created 54 | """ 55 | if model is not None: 56 | light_curve = preprocess_light_curve(light_curve, model.settings) 57 | 58 | if ax is None: 59 | fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True) 60 | 61 | used_bandpasses = [] 62 | 63 | if normalize_flux: 64 | flux_scale = 1. / light_curve.meta['parsnip_scale'] 65 | else: 66 | flux_scale = 1. 
67 | 68 | reference_time = _get_reference_time(light_curve) 69 | 70 | # Group the observations by band, and order the bands by central wavelength. 71 | band_groups = light_curve.group_by('band').groups 72 | band_dict = dict(zip(band_groups.keys['band'], band_groups)) 73 | band_order = sorted(band_dict.keys(), key=get_band_effective_wavelength) 74 | 75 | # Plot the observations 76 | for band_name in band_order: 77 | band_data = band_dict[band_name] 78 | 79 | c = get_band_plot_color(band_name) 80 | marker = get_band_plot_marker(band_name) 81 | 82 | band_time = band_data['time'] 83 | band_flux = band_data['flux'] * flux_scale 84 | band_fluxerr = band_data['fluxerr'] * flux_scale 85 | band_time = band_time - reference_time 86 | 87 | ax.errorbar(band_time, band_flux, band_fluxerr, ls='', c=c, label=band_name, 88 | elinewidth=1, marker=marker) 89 | 90 | used_bandpasses.append(band_name) 91 | 92 | # Plot the model if we have one. 93 | if model is not None: 94 | max_model = 0. 95 | label_model = True 96 | 97 | model_times, model_flux, model_result = model.predict_light_curve( 98 | light_curve, sample=True, count=count 99 | ) 100 | 101 | model_times = model_times - reference_time 102 | model_flux = model_flux * flux_scale 103 | 104 | for band_idx, band_name in enumerate(model.settings['bands']): 105 | if band_name not in used_bandpasses and not show_missing_bandpasses: 106 | continue 107 | 108 | c = get_band_plot_color(band_name) 109 | marker = get_band_plot_marker(band_name) 110 | 111 | if label_model: 112 | label = 'ParSNIP Model' 113 | label_model = False 114 | else: 115 | label = None 116 | 117 | if count == 0: 118 | # Single prediction 119 | ax.plot(model_times, model_flux[band_idx], c=c, label=label) 120 | band_max_model = np.max(model_flux[band_idx]) 121 | elif show_uncertainty_bands: 122 | # Multiple predictions, show error bands. 123 | percentile_offset = (100 - percentile) / 2. 
124 | flux_median = np.median(model_flux[:, band_idx], axis=0) 125 | flux_min = np.percentile(model_flux[:, band_idx], percentile_offset, 126 | axis=0) 127 | flux_max = np.percentile(model_flux[:, band_idx], 128 | 100 - percentile_offset, axis=0) 129 | ax.plot(model_times, flux_median, c=c, label=label) 130 | ax.fill_between(model_times, flux_min, 131 | flux_max, color=c, alpha=0.3) 132 | band_max_model = np.max(flux_median) 133 | else: 134 | # Multiple predictions, show raw light curves 135 | ax.plot(model_times, model_flux[:, band_idx].T, c=c, alpha=0.1) 136 | band_max_model = np.max(model_flux) 137 | 138 | max_model = max(max_model, band_max_model) 139 | 140 | ax.set_ylim(-0.2 * max_model, 1.2 * max_model) 141 | 142 | # Plot an SNCosmo model if we have one. 143 | if sncosmo_model is not None: 144 | model_times = np.arange(sncosmo_model.mintime(), sncosmo_model.maxtime(), 0.5) 145 | 146 | label_model = True 147 | 148 | for band_idx, band_name in enumerate(model.settings['bands']): 149 | if band_name not in used_bandpasses and not show_missing_bandpasses: 150 | continue 151 | 152 | try: 153 | flux = flux_scale * sncosmo_model.bandflux( 154 | band_name, model_times, zp=model.settings['zeropoint'], zpsys='ab' 155 | ) 156 | except ValueError: 157 | # Outside of wavelength range 158 | continue 159 | 160 | c = get_band_plot_color(band_name) 161 | if label_model: 162 | label = sncosmo_label 163 | label_model = False 164 | else: 165 | label = None 166 | 167 | ax.plot(model_times - reference_time, flux, c=c, ls='--', label=label) 168 | 169 | ax.legend() 170 | 171 | if reference_time != 0.: 172 | ax.set_xlabel(f'Relative Time (days + {reference_time:.2f})') 173 | else: 174 | ax.set_xlabel('Time (days)') 175 | 176 | if normalize_flux: 177 | ax.set_ylabel('Normalized Flux') 178 | else: 179 | ax.set_ylabel(f'Flux ($ZP_{{AB}}$={model.settings["zeropoint"]})') 180 | 181 | 182 | def normalize_spectrum_flux(wave, flux, min_wave=5500., max_wave=6500.): 183 | """Normalize the flux 
of a spectrum 184 | 185 | The flux will be normalized so that it averages to 1 in a given window. 186 | 187 | Parameters 188 | ---------- 189 | wave : `~numpy.ndarray` 190 | Wavelengths 191 | flux : `~numpy.ndarray` 192 | Flux values 193 | min_wave : float, optional 194 | Minimum wavelength to consider for normalization, by default 5500. 195 | max_wave : float, optional 196 | Maximum wavelength to consider for normalization, by default 6500. 197 | 198 | Returns 199 | ------- 200 | `~numpy.ndarray` 201 | Normalized flux 202 | """ 203 | cut = (wave > min_wave) & (wave < max_wave) 204 | scale = np.mean(flux[..., cut], axis=-1) 205 | return (flux.T / scale).T 206 | 207 | 208 | def plot_spectrum(light_curve, model, time, count=100, show_uncertainty_bands=True, 209 | percentile=68, ax=None, c=None, label=None, offset=None, 210 | normalize_flux=False, normalize_min_wave=5500., 211 | normalize_max_wave=6500., spectrum_label=None, 212 | spectrum_label_wave=7500., spectrum_label_offset=0.2, flux_scale=1.): 213 | """Plot the spectrum of a light curve predicted by a ParSNIP model 214 | 215 | Parameters 216 | ---------- 217 | light_curve : `~astropy.table.Table` 218 | Light curve 219 | model : `~ParsnipModel` 220 | Model to use for the prediction 221 | time : float 222 | Time to predict the spectrum at 223 | count : int, optional 224 | Number of spectra to sample, by default 100 225 | show_uncertainty_bands : bool, optional 226 | Whether to show uncertainty bands, by default True 227 | percentile : int, optional 228 | Percentile for the uncertainty bands, by default 68 229 | ax : axis, optional 230 | Matplotlib axis to use, by default None 231 | c : str, optional 232 | Color for the plot, by default None 233 | label : str, optional 234 | Label for the plot, by default None 235 | offset : float, optional 236 | Constant offset to add to the flux for plotting, by default None 237 | normalize_flux : bool, optional 238 | Whether to normalize the flux, by default False 239 | 
normalize_min_wave : float, optional 240 | Minimum wavelength of the normalization window, by default 5500. 241 | normalize_max_wave : float, optional 242 | Maximum wavelength of the normalization window, by default 6500. 243 | spectrum_label : str, optional 244 | Label to plot near the spectrum, by default None 245 | spectrum_label_wave : float, optional 246 | Wavelength to plot the spectrum label at, by default 7500. 247 | spectrum_label_offset : float, optional 248 | Y offset for the spectrum label, by default 0.2 249 | flux_scale : float, optional 250 | Scale to multiply the flux by, by default 1. 251 | """ 252 | light_curve = preprocess_light_curve(light_curve, model.settings) 253 | 254 | if ax is None: 255 | fig, ax = plt.subplots(figsize=(8, 6), dpi=100) 256 | 257 | model_wave = model.model_wave 258 | model_spectra = model.predict_spectrum(light_curve, time, sample=True, count=count) 259 | 260 | if normalize_flux: 261 | model_spectra = normalize_spectrum_flux( 262 | model_wave, model_spectra, normalize_min_wave, normalize_max_wave 263 | ) 264 | 265 | model_spectra *= flux_scale 266 | 267 | if offset is not None: 268 | model_spectra += offset 269 | 270 | if count == 0: 271 | # Single prediction 272 | ax.plot(model_wave, model_spectra, c=c, label=label) 273 | elif show_uncertainty_bands: 274 | # Multiple predictions, show error bands. 275 | percentile_offset = (100 - percentile) / 2. 276 | flux_median = np.median(model_spectra, axis=0) 277 | flux_min = np.percentile(model_spectra, percentile_offset, axis=0) 278 | flux_max = np.percentile( 279 | model_spectra, 100 - percentile_offset, axis=0) 280 | ax.plot(model_wave, flux_median, c=c, label=label) 281 | ax.fill_between(model_wave, flux_min, flux_max, color=c, alpha=0.3) 282 | else: 283 | # Multiple predictions, show raw light curves 284 | ax.plot(model_wave, model_spectra.T, c=c, alpha=0.1) 285 | 286 | if spectrum_label is not None: 287 | # Show a label above the spectrum. 
288 | wave_idx = np.searchsorted(model.model_wave, spectrum_label_wave) 289 | label_height = spectrum_label_offset + np.mean(model_spectra[..., wave_idx]) 290 | ax.text(spectrum_label_wave, label_height, spectrum_label) 291 | 292 | ax.set_xlabel('Wavelength ($\\AA$)') 293 | if normalize_flux: 294 | ax.set_ylabel('Normalized Flux') 295 | else: 296 | ax.set_ylabel('Flux') 297 | 298 | 299 | def plot_spectra(light_curve, model, times=[0., 10., 20., 30.], flux_scale=1., 300 | ax=None, sncosmo_model=None, sncosmo_label='SNCosmo Model', 301 | spectrum_label_offset=0.2): 302 | """Plot the spectral time series of a light curve predicted by a ParSNIP model 303 | 304 | Parameters 305 | ---------- 306 | light_curve : `~astropy.table.Table` 307 | Light curve 308 | model : `~ParsnipModel` 309 | Model to use for the predictions 310 | times : list, optional 311 | Times to predict the spectra at, by default [0., 10., 20., 30.] 312 | flux_scale : float, optional 313 | Scale to multiply the flux by, by default 1.
314 | ax : axis, optional 315 | Matplotlib axis, by default None 316 | sncosmo_model : `~sncosmo.Model`, optional 317 | SNCosmo model to overplot, by default None 318 | sncosmo_label : str, optional 319 | Label for the SNCosmo model, by default 'SNCosmo Model' 320 | spectrum_label_offset : float, optional 321 | Offset of the time labels for each spectrum, by default 0.2 322 | """ 323 | light_curve = preprocess_light_curve(light_curve, model.settings) 324 | 325 | wave = model.model_wave 326 | redshift = light_curve.meta['redshift'] 327 | scale = flux_scale / light_curve.meta['parsnip_scale'] 328 | 329 | if ax is None: 330 | fig, ax = plt.subplots(figsize=(5, 4), constrained_layout=True) 331 | 332 | reference_time = _get_reference_time(light_curve) 333 | 334 | for plot_idx, time in enumerate(times): 335 | plot_offset = len(times) - plot_idx - 1 336 | 337 | plot_time = time + reference_time 338 | 339 | if plot_idx == 0: 340 | use_label = 'ParSNIP Model' 341 | use_sncosmo_label = sncosmo_label 342 | else: 343 | use_label = None 344 | use_sncosmo_label = None 345 | 346 | plot_spectrum(light_curve, model, plot_time, ax=ax, c='C2', 347 | flux_scale=scale, offset=plot_offset, label=use_label, 348 | spectrum_label=f'{time:+.1f} days', 349 | spectrum_label_offset=spectrum_label_offset) 350 | 351 | if sncosmo_model is not None: 352 | sncosmo_flux = ( 353 | sncosmo_model._flux(plot_time, wave * (1 + redshift))[0] 354 | * 10**(0.4 * 45) 355 | * (1 + redshift) 356 | ) 357 | ax.plot(wave, scale * sncosmo_flux + plot_offset, c='k', alpha=0.3, 358 | label=use_sncosmo_label) 359 | 360 | ax.set_ylabel('Normalized Flux + Offset') 361 | ax.legend() 362 | 363 | 364 | def plot_sne_space(light_curve, model, name, min_wave=10000., max_wave=0., time_diff=0., 365 | min_time=-10000., max_time=100000., source=None, kernel=5, 366 | flux_scale=0.5, label_wave=9000., label_offset=0.2, figsize=(5, 6)): 367 | """Compare a ParSNIP spectrum prediction to a real spectrum from sne.space 368 | 369 | 
Parameters 370 | ---------- 371 | light_curve : `~astropy.table.Table` 372 | Light curve 373 | model : `~ParsnipModel` 374 | ParSNIP model to use for the prediction 375 | name : str 376 | Name of the light curve on sne.space 377 | min_wave : float, optional 378 | Ignore any spectra that don't have data bluer than this wavelength, by default 379 | 10000. 380 | max_wave : float, optional 381 | Ignore any spectra that don't have data redder than this wavelength, by default 382 | 0. 383 | time_diff : float, optional 384 | Minimum time between spectra, by default 0. 385 | min_time : float, optional 386 | Ignore any spectra before this time, by default -10000. 387 | max_time : float, optional 388 | Ignore any spectra after this time, by default 100000. 389 | source : str, optional 390 | Ignore any spectra not from this source, by default None 391 | kernel : int, optional 392 | Smooth the spectra by a median filter kernel of this size, by default 5 393 | flux_scale : float, optional 394 | Scale the flux by this amount, by default 0.5 395 | label_wave : float, optional 396 | Show labels with the times of each spectrum at this wavelength, by default 9000.
397 | label_offset : float, optional 398 | Y offset to use for the labels, by default 0.2 399 | figsize : tuple, optional 400 | Figure size, by default (5, 6) 401 | """ 402 | import json 403 | import urllib.request 404 | from scipy.signal import medfilt 405 | 406 | light_curve = preprocess_light_curve(light_curve, model.settings) 407 | 408 | redshift = light_curve.meta['redshift'] 409 | reference_time = _get_reference_time(light_curve) 410 | 411 | fig, ax = plt.subplots(figsize=figsize, constrained_layout=True) 412 | 413 | url = f'https://api.sne.space/{name}/spectra/time+data+instrument+telescope+source' 414 | with urllib.request.urlopen(url) as request: 415 | data = json.loads(request.read().decode()) 416 | 417 | spectra = data[name]['spectra'] 418 | 419 | plot_idx = 0 420 | last_time = 1e9 421 | 422 | for spec in spectra[::-1]: 423 | # Go in reverse order so that we can plot the spectra with the 424 | # first one on top. 425 | spec_time, spec_data, _, _, spec_source = spec 426 | 427 | spec_time = float(spec_time) 428 | spec_data = np.array(spec_data, dtype=float) 429 | 430 | spec_wave = spec_data[:, 0] / (1 + redshift) 431 | spec_flux = spec_data[:, 1] 432 | spec_flux = medfilt(spec_flux, kernel) 433 | 434 | if spec_wave[0] > min_wave or spec_wave[-1] < max_wave: 435 | # print("skipping wave") 436 | continue 437 | 438 | if last_time - spec_time < time_diff: 439 | continue 440 | 441 | if (spec_time - reference_time < min_time 442 | or spec_time - reference_time > max_time): 443 | continue 444 | 445 | if source is not None and spec_source != source: 446 | continue 447 | 448 | last_time = spec_time 449 | 450 | normalize_min_wave = max([5500., spec_wave[0]]) 451 | normalize_max_wave = min([6500., spec_wave[-1]]) 452 | 453 | plot_offset_scale = 1.
454 | plot_offset = (plot_idx * plot_offset_scale) 455 | 456 | plot_spectrum( 457 | light_curve, 458 | model, 459 | spec_time, 460 | normalize_flux=True, 461 | normalize_min_wave=normalize_min_wave, 462 | normalize_max_wave=normalize_max_wave, 463 | flux_scale=flux_scale, 464 | ax=ax, 465 | offset=plot_offset, 466 | c='C2' 467 | ) 468 | 469 | spec_flux = flux_scale * normalize_spectrum_flux( 470 | spec_wave, spec_flux, normalize_min_wave, normalize_max_wave 471 | ) 472 | ax.plot(spec_wave, spec_flux + plot_offset, c='k') 473 | 474 | plt.text(label_wave, plot_offset + label_offset, 475 | f'${spec_time - reference_time:.1f}$ days', ha='right') 476 | 477 | plot_idx += 1 478 | 479 | plt.legend(['ParSNIP Model', 'Observed Spectra']) 480 | plt.title("") 481 | plt.xlabel('Rest-Frame Wavelength ($\\AA$)') 482 | plt.ylabel('Normalized Flux + Offset') 483 | plt.xlim(1500., 10500.) 484 | 485 | 486 | def plot_confusion_matrix(predictions, classifications, figsize=(5, 4), title=None, 487 | verbose=True): 488 | """Plot a confusion matrix 489 | 490 | Adapted from example that used to be at 491 | http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html 492 | 493 | Parameters 494 | ---------- 495 | predictions : `~astropy.table.Table` 496 | Predictions from `~ParsnipModel.predict_dataset` 497 | classifications : `~astropy.table.Table` 498 | Classifications from a `~Classifier` 499 | figsize : tuple, optional 500 | Figure size, by default (5, 4) 501 | title : str, optional 502 | Figure title, by default None 503 | verbose : bool, optional 504 | Whether to print additional statistics, by default True 505 | """ 506 | # true_types = np.char.decode(predictions['type']) 507 | true_types = predictions["type"] 508 | predicted_types = extract_top_classifications(classifications) 509 | 510 | if len(classifications.columns) == 3 and classifications.colnames[2] == 'Other': 511 | # Single class classification. 
All labels other than the target one are grouped 512 | # as "Other". 513 | single_type = classifications.colnames[1] 514 | true_types[true_types != single_type] = 'Other' 515 | 516 | type_names = classifications.colnames[1:] 517 | 518 | fig, ax = plt.subplots(figsize=figsize, constrained_layout=True) 519 | cm = confusion_matrix(true_types, predicted_types, labels=type_names, 520 | normalize='true') 521 | 522 | im = ax.imshow(cm, interpolation='nearest', 523 | cmap=plt.cm.Blues, vmin=0, vmax=1) 524 | tick_marks = np.arange(len(type_names)) 525 | ax.set_xticks(tick_marks, type_names, rotation=60, ha='right') 526 | ax.set_yticks(tick_marks, type_names) 527 | 528 | fmt = '.2f' 529 | thresh = cm.max() / 2. 530 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 531 | plt.text(j, i, format(cm[i, j], fmt), 532 | horizontalalignment="center", 533 | color="white" if cm[i, j] > thresh else "black") 534 | 535 | ax.set_ylabel('True Type') 536 | ax.set_xlabel('Predicted Type') 537 | if title is not None: 538 | ax.set_title(title) 539 | 540 | # Make a colorbar that is lined up with the plot 541 | from mpl_toolkits.axes_grid1 import make_axes_locatable 542 | ax1 = plt.gca() 543 | divider = make_axes_locatable(ax1) 544 | cax = divider.append_axes("right", size="4%", pad=0.25) 545 | plt.colorbar(im, cax=cax, label='Fraction of objects') 546 | 547 | if verbose: 548 | # Print out stats. 549 | print("Macro averaged completeness (Villar et al. 
2020): " 550 | f"{np.diag(cm).mean():.4f}") 551 | print(f"Fraction correct: {np.mean(true_types == predicted_types):.4f}") 552 | 553 | return ax 554 | 555 | 556 | def plot_representation(predictions, plot_labels, mask=None, idx1=1, idx2=2, idx3=None, 557 | max_count=1000, show_legend=True, legend_ncol=1, marker='o', 558 | markersize=5, ax=None): 559 | """Plot the representation of a ParSNIP model 560 | 561 | Parameters 562 | ---------- 563 | predictions : `~astropy.table.Table` 564 | Predictions for a dataset from `~ParsnipModel.predict_dataset` 565 | plot_labels : List[str] 566 | Labels for each of the classes 567 | mask : `~np.array`, optional 568 | Mask to apply to the predictions, by default None 569 | idx1 : int, optional 570 | Intrinsic latent variable to plot on the x axis, by default 1 571 | idx2 : int, optional 572 | Intrinsic latent variable to plot on the y axis, by default 2 573 | idx3 : int, optional 574 | If specified, show a three paneled plot with this latent variable in the extra 575 | two panels plotted against the other ones 576 | max_count : int, optional 577 | Maximum number of light curves to show of each type, by default 1000 578 | show_legend : bool, optional 579 | Whether to show the legend, by default True 580 | legend_ncol : int, optional 581 | Number of columns to use in the legend, by default 1 582 | marker : str, optional 583 | Matplotlib marker to use, by default 'o' 584 | markersize : int, optional 585 | Matplotlib marker size, by default 5 586 | ax : axis, optional 587 | Matplotlib axis, by default None 588 | """ 589 | color_map = { 590 | 'SNIa': 'C0', 591 | 'SNIax': 'C9', 592 | 'SNIa-91bg': 'lightgreen', 593 | 594 | 'SLSN': 'C2', 595 | 'SLSN-I': 'C2', 596 | 'SNII': 'C1', 597 | 'SNIIn': 'C3', 598 | 'SNIbc': 'C4', 599 | 600 | 'KN': 'C5', 601 | 602 | 'CaRT': 'C3', 603 | 'ILOT': 'C6', 604 | 'PISN': 'C8', 605 | 'TDE': 'C7', 606 | 607 | 'FELT': 'C5', 608 | 'Peculiar': 'C5', 609 | } 610 | 611 | if idx3 is not None: 612 | if ax is not
None: 613 | raise Exception("Can't make 3D plot with prespecified axis.") 614 | 615 | fig = plt.figure(figsize=(8, 8), constrained_layout=True) 616 | 617 | gs = GridSpec(2, 2, figure=fig) 618 | 619 | ax12 = fig.add_subplot(gs[1, 0]) 620 | ax13 = fig.add_subplot(gs[0, 0], sharex=ax12) 621 | ax32 = fig.add_subplot(gs[1, 1], sharey=ax12) 622 | legend_ax = fig.add_subplot(gs[0, 1]) 623 | legend_ax.axis('off') 624 | 625 | plot_vals = [ 626 | (idx1, idx2, ax12), 627 | (idx1, idx3, ax13), 628 | (idx3, idx2, ax32), 629 | ] 630 | else: 631 | if ax is None: 632 | fig, ax = plt.subplots(figsize=(6, 6), constrained_layout=True) 633 | 634 | plot_vals = [ 635 | (idx1, idx2, ax) 636 | ] 637 | 638 | for xidx, yidx, ax in plot_vals: 639 | if mask is not None: 640 | cut_predictions = predictions[~mask] 641 | ax.scatter(cut_predictions[f's{xidx}'], cut_predictions[f's{yidx}'], 642 | c='k', s=3, alpha=0.1, label='Unknown') 643 | valid_predictions = predictions[mask] 644 | else: 645 | valid_predictions = predictions 646 | 647 | for type_name in plot_labels: 648 | type_predictions = valid_predictions[valid_predictions['type'] == 649 | type_name] 650 | 651 | color = color_map[type_name] 652 | 653 | markers, caps, bars = ax.errorbar( 654 | type_predictions[f's{xidx}'][:max_count], 655 | type_predictions[f's{yidx}'][:max_count], 656 | xerr=type_predictions[f's{xidx}_error'][:max_count], 657 | yerr=type_predictions[f's{yidx}_error'][:max_count], 658 | label=type_name, 659 | ls='', 660 | marker=marker, 661 | markersize=markersize, 662 | c=color, 663 | ) 664 | 665 | [bar.set_alpha(0.3) for bar in bars] 666 | 667 | if idx3 is not None: 668 | ax12.set_xlabel(f'$s_{idx1}$') 669 | ax12.set_ylabel(f'$s_{idx2}$') 670 | ax13.set_ylabel(f'$s_{idx3}$') 671 | ax13.tick_params(labelbottom=False) 672 | ax32.set_xlabel(f'$s_{idx3}$') 673 | ax32.tick_params(labelleft=False) 674 | 675 | if show_legend: 676 | handles, labels = ax12.get_legend_handles_labels() 677 | legend_ax.legend(handles=handles, 
labels=labels, loc='center', 678 | ncol=legend_ncol) 679 | else: 680 | ax.set_xlabel(f'$s_{idx1}$') 681 | ax.set_ylabel(f'$s_{idx2}$') 682 | 683 | if show_legend: 684 | ax.legend(ncol=legend_ncol) 685 | -------------------------------------------------------------------------------- /parsnip/settings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from .instruments import calculate_band_mw_extinctions, should_correct_background 4 | 5 | default_model = 'plasticc' 6 | 7 | default_settings = { 8 | 'model_version': 2, 9 | 10 | 'input_redshift': True, 11 | 12 | 'predict_redshift': False, 13 | 'specz_error': 0.01, 14 | 15 | 'min_wave': 1000., 16 | 'max_wave': 11000., 17 | 'spectrum_bins': 300, 18 | 'max_redshift': 4., 19 | 'band_oversampling': 51, 20 | 'time_window': 300, 21 | 'time_pad': 100, 22 | 'time_sigma': 20., 23 | 'color_sigma': 0.3, 24 | 'magsys': 'ab', 25 | 'error_floor': 0.01, 26 | 'zeropoint': 25.0, 27 | 28 | 'batch_size': 128, 29 | 'learning_rate': 1e-3, 30 | 'scheduler_factor': 0.5, 31 | 'min_learning_rate': 1e-5, 32 | 'penalty': 1e-3, 33 | 'optimizer': 'Adam', # 'Adam' or 'SGD' 34 | 'sgd_momentum': 0.9, 35 | 36 | 'latent_size': 3, 37 | 'encode_block': 'residual', 38 | 'encode_conv_architecture': [40, 80, 120, 160, 200, 200, 200], 39 | 'encode_conv_dilations': [1, 2, 4, 8, 16, 32, 64], 40 | 'encode_fc_architecture': [200], 41 | 'encode_time_architecture': [200], 42 | 'encode_latent_prepool_architecture': [200], 43 | 'encode_latent_postpool_architecture': [200], 44 | 'decode_architecture': [40, 80, 160], 45 | 46 | # Settings that will be filled later. 
47 | 'derived_settings_calculated': None, 48 | 'bands': None, 49 | 'band_mw_extinctions': None, 50 | 'band_correct_background': None, 51 | } 52 | 53 | 54 | def update_derived_settings(settings): 55 | """Update the derived settings for a model 56 | 57 | This calculates the Milky Way extinctions in each band, and determines whether 58 | background correction should be applied. 59 | 60 | Parameters 61 | ---------- 62 | settings : dict 63 | Input settings 64 | 65 | Returns 66 | ------- 67 | dict 68 | Updated settings with derived settings calculated 69 | """ 70 | 71 | # Figure out what Milky Way extinction correction to apply for each band. 72 | settings['band_mw_extinctions'] = calculate_band_mw_extinctions(settings['bands']) 73 | 74 | # Figure out if we want to do background correction for each band. 75 | settings['band_correct_background'] = should_correct_background(settings['bands']) 76 | 77 | # Flag that the derived settings have been calculated so that we don't redo it when 78 | # loading a model from disk. 79 | settings['derived_settings_calculated'] = True 80 | 81 | return settings 82 | 83 | 84 | def update_settings_version(settings): 85 | """Update settings to a new version 86 | 87 | Parameters 88 | ---------- 89 | settings : dict 90 | Old settings 91 | 92 | Returns 93 | ------- 94 | dict 95 | Updated settings 96 | """ 97 | # Version 2, added redshift prediction.
98 | if settings['model_version'] < 2: 99 | settings['predict_redshift'] = False 100 | settings['specz_error'] = 0.05 101 | 102 | settings['model_version'] = default_settings['model_version'] 103 | 104 | return settings 105 | 106 | 107 | def parse_settings(bands, settings={}, ignore_unknown_settings=False): 108 | """Parse the settings for a ParSNIP model 109 | 110 | Parameters 111 | ---------- 112 | bands : List[str] 113 | Bands to use in the encoder model 114 | settings : dict, optional 115 | Settings to override, by default {} 116 | ignore_unknown_settings : bool, optional 117 | If False (default), raise a KeyError if there are any unknown settings. 118 | Otherwise, do nothing. 119 | 120 | Returns 121 | ------- 122 | dict 123 | Parsed settings dictionary 124 | 125 | Raises 126 | ------ 127 | KeyError 128 | If there are unknown keys in the input settings 129 | """ 130 | if 'derived_settings_calculated' in settings: 131 | # We are loading a prebuilt model, don't recalculate everything. 132 | prebuilt_model = True 133 | else: 134 | prebuilt_model = False 135 | 136 | use_settings = default_settings.copy() 137 | use_settings['bands'] = bands 138 | 139 | for key, value in settings.items(): 140 | if key not in default_settings: 141 | if ignore_unknown_settings: 142 | continue 143 | else: 144 | raise KeyError(f"Unknown setting '{key}' with value '{value}'.") 145 | else: 146 | use_settings[key] = value 147 | 148 | if not prebuilt_model: 149 | use_settings = update_derived_settings(use_settings) 150 | 151 | if use_settings['model_version'] != default_settings['model_version']: 152 | # Update the settings to the latest version 153 | use_settings = update_settings_version(use_settings) 154 | 155 | return use_settings 156 | 157 | 158 | def parse_int_list(text): 159 | """Parse a string into a list of integers 160 | 161 | For example, the string "1,2,3,4" will be parsed to [1, 2, 3, 4].
162 | 163 | Parameters 164 | ---------- 165 | text : str 166 | String to parse 167 | 168 | Returns 169 | ------- 170 | List[int] 171 | Parsed integer list 172 | """ 173 | result = [int(i) for i in text.split(',')] 174 | return result 175 | 176 | 177 | def build_default_argparse(description): 178 | """Build an argparse object that can handle all of the ParSNIP model settings. 179 | 180 | The resulting parsed namespace can be passed to parse_settings to get a ParSNIP 181 | settings object. 182 | 183 | Parameters 184 | ---------- 185 | description : str 186 | Description for the argument parser 187 | 188 | Returns 189 | ------- 190 | `~argparse.ArgumentParser` 191 | Argument parser with the ParSNIP model settings added as arguments 192 | """ 193 | parser = argparse.ArgumentParser(description=description) 194 | 195 | for key, value in default_settings.items(): 196 | if value is None: 197 | # Derived setting, not something that should be specified. 198 | continue 199 | 200 | if isinstance(value, bool): 201 | # Handle booleans. 
202 | if value: 203 | parser.add_argument(f'--no_{key}', action='store_false', dest=key) 204 | else: 205 | parser.add_argument(f'--{key}', action='store_true', dest=key) 206 | elif isinstance(value, list): 207 | # Handle lists of integers 208 | parser.add_argument(f'--{key}', type=parse_int_list, default=value) 209 | else: 210 | # Handle other object types 211 | parser.add_argument(f'--{key}', type=type(value), default=value) 212 | 213 | return parser 214 | -------------------------------------------------------------------------------- /parsnip/sncosmo.py: -------------------------------------------------------------------------------- 1 | from scipy.interpolate import interp1d 2 | import numpy as np 3 | import os 4 | import sncosmo 5 | import torch 6 | 7 | import parsnip 8 | 9 | from .light_curve import SIDEREAL_SCALE 10 | 11 | 12 | class ParsnipSncosmoSource(sncosmo.Source): 13 | """SNCosmo interface for a ParSNIP model 14 | 15 | Parameters 16 | ---------- 17 | model : `~ParsnipModel` or str, optional 18 | ParSNIP model to use, or path to a model on disk. 19 | """ 20 | def __init__(self, model=None): 21 | if not isinstance(model, parsnip.ParsnipModel): 22 | model = parsnip.load_model(model) 23 | 24 | self._model = model 25 | 26 | model_name = os.path.splitext(os.path.basename(model.path))[0] 27 | self.name = f'parsnip_{model_name}' 28 | self._param_names = ( 29 | ['amplitude', 'color'] 30 | + [f's{i+1}' for i in range(self._model.settings['latent_size'])] 31 | ) 32 | self.param_names_latex = ( 33 | ['A', 'c'] + [f's_{i+1}' for i in 34 | range(self._model.settings['latent_size'])] 35 | ) 36 | self.version = 1 37 | 38 | self._parameters = np.zeros(len(self._param_names)) 39 | self._parameters[0] = 1. 40 | 41 | def _flux(self, phase, wave): 42 | # Generate predictions at the given phase. 
43 | encoding = (torch.FloatTensor(self._parameters[2:])[None, :] 44 | .to(self._model.device)) 45 | phase = phase * SIDEREAL_SCALE 46 | phase = torch.FloatTensor(phase)[None, :].to(self._model.device) 47 | color = torch.FloatTensor([self._parameters[1]]).to(self._model.device) 48 | amplitude = (torch.FloatTensor([self._parameters[0]]).to(self._model.device)) 49 | 50 | model_spectra = self._model.decode_spectra(encoding, phase, color, amplitude) 51 | model_spectra = model_spectra.detach().cpu().numpy()[0] 52 | 53 | flux = interp1d(self._model.model_wave, model_spectra.T)(wave) 54 | 55 | return flux 56 | 57 | def minphase(self): 58 | return (-self._model.settings['time_window'] // 2 59 | - self._model.settings['time_pad']) 60 | 61 | def maxphase(self): 62 | return (self._model.settings['time_window'] // 2 63 | + self._model.settings['time_pad']) 64 | 65 | def minwave(self): 66 | return self._model.settings['min_wave'] 67 | 68 | def maxwave(self): 69 | return self._model.settings['max_wave'] 70 | -------------------------------------------------------------------------------- /parsnip/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def nmad(x): 6 | """Calculate the normalized median absolute deviation (NMAD) 7 | 8 | Parameters 9 | ---------- 10 | x : `~numpy.ndarray` 11 | Data to calculate the NMAD of 12 | 13 | Returns 14 | ------- 15 | float 16 | NMAD of the input 17 | """ 18 | return 1.4826 * np.median(np.abs(x - np.median(x))) 19 | 20 | 21 | def frac_to_mag(fractional_difference): 22 | """Convert a fractional difference to a difference in magnitude. 23 | 24 | Because this transformation is asymmetric for larger fractional changes, we 25 | take the average of positive and negative differences. 26 | 27 | This supports numpy broadcasting.
28 | 29 | Parameters 30 | ---------- 31 | fractional_difference : float 32 | Fractional flux difference 33 | 34 | Returns 35 | ------- 36 | float 37 | Difference in magnitudes 38 | """ 39 | pos_mag = 2.5 * np.log10(1 + fractional_difference) 40 | neg_mag = 2.5 * np.log10(1 - fractional_difference) 41 | mag_diff = (pos_mag - neg_mag) / 2.0 42 | 43 | return mag_diff 44 | 45 | 46 | def parse_device(device): 47 | """Figure out which PyTorch device to use 48 | 49 | Parameters 50 | ---------- 51 | device : str 52 | Requested device 53 | 54 | Returns 55 | ------- 56 | str 57 | Device to use 58 | """ 59 | # Figure out which device to run on. 60 | try: 61 | backend = getattr(torch.backends, device) 62 | is_available = getattr(backend, 'is_available') 63 | except AttributeError: 64 | device_available = False 65 | else: 66 | device_available = is_available() 67 | 68 | if device == 'cpu': 69 | # Requested CPU. 70 | use_device = 'cpu' 71 | elif device == 'cuda' and torch.cuda.is_available(): 72 | use_device = 'cuda' 73 | elif device_available: 74 | use_device = device 75 | else: 76 | print(f"WARNING: Device '{device}' not available, using 'cpu' instead.") 77 | use_device = 'cpu' 78 | 79 | return use_device 80 | 81 | 82 | def replace_nan_grads(parameters, value=0.0): 83 | """Replace NaN gradients 84 | 85 | Parameters 86 | ---------- 87 | parameters : Iterator[torch.Tensor] 88 | Model parameters, usually you can get them by `model.parameters()` 89 | value : float, optional 90 | Value to replace NaNs with 91 | """ 92 | for p in parameters: 93 | if p.grad is None: 94 | continue 95 | grads = p.grad.data 96 | grads[torch.isnan(grads)] = value 97 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | build-backend = "setuptools.build_meta" 7 | 
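As a quick sanity check on the two numerical helpers in `parsnip/utils.py` above, the snippet below restates their definitions so it runs on its own with only numpy; the printed values follow directly from the formulas in the file:

```python
import numpy as np

def nmad(x):
    # Same definition as in parsnip/utils.py: the MAD scaled by 1.4826 so it
    # matches the standard deviation for normally distributed data.
    return 1.4826 * np.median(np.abs(x - np.median(x)))

def frac_to_mag(fractional_difference):
    # Average of the positive and negative magnitude differences, as above.
    pos_mag = 2.5 * np.log10(1 + fractional_difference)
    neg_mag = 2.5 * np.log10(1 - fractional_difference)
    return (pos_mag - neg_mag) / 2.0

x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
print(nmad(x))           # 1.4826: the outlier barely moves the estimate
print(frac_to_mag(0.1))  # roughly 0.109 mag for a 10% flux difference
```

Note how the outlier at 100 leaves the NMAD essentially unchanged, which is exactly why it is preferred over the standard deviation for photometric residuals.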
-------------------------------------------------------------------------------- /scripts/parsnip_build_plasticc_combined: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import lcdata 5 | import argparse 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser( 10 | description='Build a PLAsTiCC dataset to train on. The full dataset is too big ' 11 | 'to fit in memory, so we work with only part of it. This script creates a ' 12 | 'dataset that contains all of the training and DDF light curves, and 10% of ' 13 | 'the WFD light curves', 14 | argument_default=argparse.SUPPRESS 15 | ) 16 | 17 | parser.add_argument('data_directory', default='./data/', nargs='?', 18 | help='Default: ./data/') 19 | 20 | args = parser.parse_args() 21 | 22 | basedir = args.data_directory 23 | 24 | train_path = os.path.join(basedir, 'plasticc_train.h5') 25 | test_path = os.path.join(basedir, 'plasticc_test.h5') 26 | out_path = os.path.join(basedir, 'plasticc_combined.h5') 27 | 28 | if not os.path.exists(train_path) or not os.path.exists(test_path): 29 | print(f"PLAsTiCC dataset not found in directory '{basedir}'. Download it by " 30 | "running lcdata_download_plasticc") 31 | sys.exit() 32 | 33 | if os.path.exists(out_path): 34 | print(f"PLAsTiCC combined dataset already exists at '{out_path}'.") 35 | sys.exit() 36 | 37 | print("Loading PLAsTiCC training data...") 38 | plasticc_train = lcdata.read_hdf5(train_path) 39 | 40 | print("Loading PLAsTiCC test metadata...") 41 | plasticc_test = lcdata.read_hdf5(test_path, in_memory=False) 42 | 43 | print("Loading DDF light curves...") 44 | # The DDF light curves are all in the first chunk of the test dataset. Loading the 45 | # first 1% of the dataset includes all of them. 
46 | chunk_size = len(plasticc_test) // 100 47 | plasticc_ddf = plasticc_test[:chunk_size].load() 48 | 49 | print("Loading 10% of the WFD light curves...") 50 | chunk_size = len(plasticc_test) // 10 51 | plasticc_wfd = plasticc_test[5*chunk_size:6*chunk_size].load() 52 | 53 | plasticc_combined = plasticc_train + plasticc_ddf + plasticc_wfd 54 | 55 | print("Writing out combined dataset.") 56 | plasticc_combined.write_hdf5(out_path) 57 | 58 | print(f"Done! Combined dataset now available at '{out_path}'.") 59 | -------------------------------------------------------------------------------- /scripts/parsnip_predict: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from tqdm import tqdm 3 | import argparse 4 | import astropy.table 5 | import lcdata 6 | import os 7 | import parsnip 8 | import sys 9 | import time 10 | 11 | 12 | if __name__ == '__main__': 13 | start_time = time.time() 14 | 15 | parser = argparse.ArgumentParser( 16 | description='Generate predictions from a ParSNIP model for a dataset.' 
17 | ) 18 | parser.add_argument('predictions_path') 19 | parser.add_argument('model_path') 20 | parser.add_argument('dataset_path') 21 | 22 | parser.add_argument('--overwrite', action='store_true') 23 | parser.add_argument('--chunk_size', default=10000, type=int) 24 | parser.add_argument('--augments', default=0, type=int) 25 | 26 | parser.add_argument('--device', default='cuda') 27 | parser.add_argument('--threads', default=8, type=int) 28 | 29 | args = vars(parser.parse_args()) 30 | 31 | predictions_path = args['predictions_path'] 32 | if os.path.exists(predictions_path): 33 | if args['overwrite']: 34 | print(f"Predictions '{predictions_path}' already exist, overwriting!") 35 | else: 36 | print(f"Predictions '{predictions_path}' already exist, skipping!") 37 | sys.exit() 38 | 39 | # Load the model 40 | model = parsnip.load_model( 41 | args['model_path'], 42 | device=args['device'], 43 | threads=args['threads'], 44 | ) 45 | 46 | # Load the metadata for the dataset. We parse the dataset in chunks since we can't 47 | # necessarily fit large datasets all in memory. 48 | dataset = parsnip.load_dataset( 49 | args['dataset_path'], 50 | require_redshift=not model.settings['predict_redshift'], 51 | in_memory=False 52 | ) 53 | 54 | # Parse the dataset in chunks. For large datasets, we can't fit them all in memory 55 | # at the same time. 56 | if isinstance(dataset, lcdata.HDF5Dataset): 57 | chunk_size = args['chunk_size'] 58 | num_chunks = dataset.count_chunks(chunk_size) 59 | chunks = tqdm(dataset.iterate_chunks(chunk_size), total=num_chunks, 60 | file=sys.stdout) 61 | else: 62 | chunks = [dataset] 63 | 64 | # Optionally, the dataset can be augmented a given number of times. 
65 | augments = args['augments'] 66 | 67 | predictions = [] 68 | 69 | for chunk in chunks: 70 | # Preprocess the light curves 71 | chunk = model.preprocess(chunk, verbose=False) 72 | 73 | # Generate the prediction 74 | if augments == 0: 75 | chunk_predictions = model.predict_dataset(chunk) 76 | else: 77 | chunk_predictions = model.predict_dataset_augmented(chunk, 78 | augments=augments) 79 | predictions.append(chunk_predictions) 80 | 81 | predictions = astropy.table.vstack(predictions, 'exact') 82 | 83 | # Save the predictions 84 | os.makedirs(os.path.dirname(predictions_path), exist_ok=True) 85 | 86 | # By default, assume that we are writing to HDF5 format. In this case, we serialize 87 | # the table to preserve masked columns and data types. Note that the output will 88 | # only be able to be read by astropy.table.Table. 89 | try: 90 | predictions.write(predictions_path, overwrite=True, serialize_meta=True, 91 | path='/predictions') 92 | except TypeError: 93 | # Writing to some other format that doesn't support serialize_meta. 94 | print(f"WARNING: filetype given by '{predictions_path}' may not handle masked " 95 | "columns correctly. HDF5 format (extension .h5) is recommended.") 96 | predictions.write(predictions_path, overwrite=True) 97 | 98 | # Calculate time taken in minutes 99 | end_time = time.time() 100 | elapsed_time = (end_time - start_time) / 60. 
101 | print(f"Total time: {elapsed_time:.2f} min") 102 | -------------------------------------------------------------------------------- /scripts/parsnip_train: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import numpy as np 3 | import os 4 | import sys 5 | 6 | import parsnip 7 | import time 8 | 9 | 10 | if __name__ == '__main__': 11 | start_time = time.time() 12 | 13 | parser = parsnip.build_default_argparse('Train a ParSNIP model.') 14 | 15 | parser.add_argument('model_path') 16 | parser.add_argument('dataset_paths', nargs='+') 17 | 18 | parser.add_argument('--overwrite', action='store_true') 19 | parser.add_argument('--max_epochs', type=int, default=1000) 20 | parser.add_argument('--split_train_test', action='store_true') 21 | parser.add_argument('--bands', default=None) 22 | 23 | parser.add_argument('--device', default='cuda') 24 | parser.add_argument('--threads', default=8, type=int) 25 | 26 | # Parse the arguments 27 | args = vars(parser.parse_args()) 28 | 29 | # Figure out if we have already trained a model at this path. 30 | model_path = args['model_path'] 31 | if os.path.exists(model_path): 32 | if args['overwrite']: 33 | print(f"Model '{model_path}' already exists, overwriting!") 34 | else: 35 | print(f"Model '{model_path}' already exists, skipping!") 36 | sys.exit() 37 | 38 | dataset = parsnip.load_datasets( 39 | args['dataset_paths'], 40 | require_redshift=not args['predict_redshift'], 41 | ) 42 | 43 | # Figure out which bands we want to use for the model. If specific ones were 44 | # specified on the command line, use those. Otherwise, use all available bands. 
45 | bands = args.pop('bands') 46 | if bands is None: 47 | bands = parsnip.get_bands(dataset) 48 | else: 49 | bands = bands.split(',') 50 | 51 | model = parsnip.ParsnipModel( 52 | model_path, 53 | bands, 54 | device=args['device'], 55 | threads=args['threads'], 56 | settings=args, 57 | ignore_unknown_settings=True 58 | ) 59 | 60 | dataset = model.preprocess(dataset) 61 | 62 | if args['split_train_test']: 63 | train_dataset, test_dataset = parsnip.split_train_test(dataset) 64 | model.fit(train_dataset, test_dataset=test_dataset, 65 | max_epochs=args['max_epochs']) 66 | else: 67 | train_dataset = dataset 68 | model.fit(train_dataset, max_epochs=args['max_epochs']) 69 | 70 | # Save the score to a file for quick comparisons. If we have a small dataset, 71 | # repeat the dataset several times when calculating the score. 72 | rounds = int(np.ceil(25000 / len(train_dataset))) 73 | 74 | train_score = model.score(train_dataset, rounds=rounds) 75 | if args['split_train_test']: 76 | test_score = model.score(test_dataset, rounds=10 * rounds) 77 | else: 78 | test_score = -1. 79 | 80 | end_time = time.time() 81 | 82 | # Time taken in minutes 83 | elapsed_time = (end_time - start_time) / 60. 
84 | 85 | with open('./parsnip_results.log', 'a') as f: 86 | print(f'{model_path} {model.epoch} {elapsed_time:.2f} {train_score:.4f} ' 87 | f'{test_score:.4f}', file=f) 88 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = astro-parsnip 3 | version = 1.4.2 4 | author = Kyle Boone 5 | author_email = kyboone@uw.edu 6 | description = Deep generative modeling of astronomical transient light curves 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | url = https://github.com/kboone/parsnip 10 | classifiers = 11 | Programming Language :: Python :: 3 12 | License :: OSI Approved :: MIT License 13 | Operating System :: OS Independent 14 | 15 | [options] 16 | packages = find: 17 | python_requires = >=3.6 18 | install_requires = 19 | astropy 20 | extinction 21 | lcdata>=1.1.1 22 | lightgbm>=2.3.1,<3 23 | matplotlib 24 | numpy 25 | scikit-learn 26 | scipy 27 | sncosmo>=2.6 28 | torch 29 | tqdm 30 | scripts = 31 | scripts/parsnip_build_plasticc_combined 32 | scripts/parsnip_predict 33 | scripts/parsnip_train 34 | include_package_data = True 35 | 36 | [options.package_data] 37 | parsnip = models/*.pt 38 | 39 | [options.extras_require] 40 | docs = # Required to build the docs. 41 | numpy 42 | sphinx 43 | sphinx_rtd_theme 44 | pillow 45 | numpydoc 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | setup() 3 | --------------------------------------------------------------------------------
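One loose end in this excerpt: the settings parser in `parsnip/settings.py` passes a `parse_int_list` function as the argparse `type` for list-valued options, but its definition falls outside the chunk shown above. A minimal sketch consistent with that usage (an assumption for illustration, not the actual implementation; the `--decode_layers` option name is likewise hypothetical) could be:

```python
import argparse

def parse_int_list(value):
    # Hypothetical helper: convert a comma-separated string such as
    # "40,80,160" into a list of ints for a list-valued model setting.
    return [int(part) for part in value.split(',')]

parser = argparse.ArgumentParser()
parser.add_argument('--decode_layers', type=parse_int_list, default=[40, 80, 160])
args = parser.parse_args(['--decode_layers', '10,20'])
print(args.decode_layers)  # [10, 20]
```

Any callable taking the raw string works as an argparse `type`, which is how the real helper lets users override list-valued settings from the command line.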