├── COPYRIGHT ├── python └── lsst │ ├── __init__.py │ └── multiprofit │ ├── __init__.py │ ├── limits.py │ ├── errors.py │ ├── config.py │ ├── model_utils.py │ ├── utils.py │ ├── observationconfig.py │ ├── priors.py │ ├── modelconfig.py │ ├── asinhstretchsigned.py │ ├── fit_catalog.py │ ├── psfmodel_utils.py │ ├── transforms.py │ ├── fit_bootstrap_model.py │ ├── sourceconfig.py │ └── componentconfig.py ├── tests ├── SConscript ├── test_psfmodel_utils.py ├── test_fit_psf.py ├── test_observationconfig.py ├── test_plots.py ├── test_componentconfig.py ├── test_modelconfig.py ├── test_sourceconfig.py ├── test_fit_bootstrap_model.py └── test_modeller.py ├── requirements.txt ├── examples ├── 222.51551376,0.09749601_g_psf.fits ├── 222.51551376,0.09749601_i_psf.fits ├── 222.51551376,0.09749601_r_psf.fits ├── 222.51551376,0.09749601_300x300_g.fits ├── 222.51551376,0.09749601_300x300_i.fits ├── 222.51551376,0.09749601_300x300_r.fits ├── 222.51551376,0.09749601_300x300_mask_inv_highsn.npz ├── test_gaussians.py ├── plot_sersic_mix.py ├── test_gaussian_gradients.py ├── test_mgsersic.py ├── fithsc.py └── fithsc.ipynb ├── doc ├── .gitignore ├── index.rst ├── manifest.yaml ├── conf.py └── lsst.multiprofit │ └── index.rst ├── SConstruct ├── .github └── workflows │ └── rebase_checker.yaml ├── ups └── multiprofit.table ├── .gitignore ├── .pre-commit-config.yaml ├── README.rst └── pyproject.toml /COPYRIGHT: -------------------------------------------------------------------------------- 1 | Copyright 2018-2021 The Trustees of Princeton University 2 | -------------------------------------------------------------------------------- /python/lsst/__init__.py: -------------------------------------------------------------------------------- 1 | import pkgutil 2 | 3 | __path__ = pkgutil.extend_path(__path__, __name__) 4 | -------------------------------------------------------------------------------- /python/lsst/multiprofit/__init__.py: 
-------------------------------------------------------------------------------- 1 | import pkgutil 2 | 3 | __path__ = pkgutil.extend_path(__path__, __name__) 4 | -------------------------------------------------------------------------------- /tests/SConscript: -------------------------------------------------------------------------------- 1 | # -*- python -*- 2 | from lsst.sconsUtils import scripts 3 | scripts.BasicSConscript.tests(pyList=[]) 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astropy 2 | galsim 3 | gauss2d 4 | matplotlib 5 | numpy 6 | pydantic 7 | pytest 8 | scipy 9 | seaborn 10 | -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_g_psf.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_g_psf.fits -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_i_psf.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_i_psf.fits -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_r_psf.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_r_psf.fits -------------------------------------------------------------------------------- /doc/.gitignore: -------------------------------------------------------------------------------- 1 | # Doxygen products 2 | html 3 | xml 4 | *.tag 5 | *.inc 6 | doxygen.conf 7 | 8 | # Sphinx products 9 | _build 10 | 
py-api 11 | -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_300x300_g.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_g.fits -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_300x300_i.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_i.fits -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_300x300_r.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_r.fits -------------------------------------------------------------------------------- /SConstruct: -------------------------------------------------------------------------------- 1 | # -*- python -*- 2 | from lsst.sconsUtils import scripts 3 | # Python-only package 4 | scripts.BasicSConstruct("multiprofit", disableCc=True, noCfgFile=True) 5 | -------------------------------------------------------------------------------- /examples/222.51551376,0.09749601_300x300_mask_inv_highsn.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_mask_inv_highsn.npz -------------------------------------------------------------------------------- /.github/workflows/rebase_checker.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: Check that 'main' is not merged into the development branch 3 | 4 | on: pull_request 5 | 6 | 
jobs: 7 | call-workflow: 8 | uses: lsst/rubin_workflows/.github/workflows/rebase_checker.yaml@main 9 | -------------------------------------------------------------------------------- /ups/multiprofit.table: -------------------------------------------------------------------------------- 1 | # List EUPS dependencies of this package here. 2 | # - Any package whose API is used directly should be listed explicitly. 3 | # - Common third-party packages can be assumed to be recursively included by 4 | # the "sconsUtils" package. 5 | setupRequired(sconsUtils) 6 | setupRequired(gauss2dfit) 7 | 8 | envPrepend(PYTHONPATH, ${PRODUCT_DIR}/python) 9 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | ################################ 2 | multiprofit documentation preview 3 | ################################ 4 | 5 | .. This page is for local development only. It isn't published to pipelines.lsst.io. 6 | 7 | .. Link the index pages of package and module documentation directions (listed in manifest.yaml). 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | 12 | lsst.multiprofit/index 13 | -------------------------------------------------------------------------------- /doc/manifest.yaml: -------------------------------------------------------------------------------- 1 | # Documentation manifest. 2 | 3 | # List of names of Python modules in this package. 4 | # For each module there is a corresponding module doc subdirectory. 5 | modules: 6 | - "lsst.multiprofit" 7 | 8 | # Name of the static content directories (subdirectories of `_static`). 9 | # Static content directories are usually named after the package. 10 | # Most packages do not need a static content directory (leave commented out). 
11 | # statics: 12 | # - "_static/multiprofit" 13 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | """Sphinx configuration file for an LSST stack package. 2 | 3 | This configuration only affects single-package Sphinx documentation builds. 4 | For more information, see: 5 | https://developer.lsst.io/stack/building-single-package-docs.html 6 | """ 7 | 8 | from documenteer.conf.pipelinespkg import * # noqa: F403, import * 9 | 10 | project = "multiprofit" 11 | html_theme_options["logotext"] = project # noqa: F405, unknown name 12 | html_title = project 13 | html_short_title = project 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build 2 | build/ 3 | dist/ 4 | multiprofit.egg-info/ 5 | __pycache__/ 6 | .coverage 7 | .sconsign.dblite 8 | config.log 9 | version.py 10 | 11 | # Generated by pip install -e . 
12 | python/lsst_multiprofit.egg-info/ 13 | 14 | # IDE folders and files 15 | .cache/ 16 | .idea/ 17 | .theia/ 18 | .vscode/ 19 | compile_commands.json 20 | 21 | # test outputs 22 | tests/.tests 23 | 24 | # pytest plugins 25 | .hypothesis/ 26 | prof/ 27 | 28 | # in-source notebooks 29 | examples/.ipynb_checkpoints 30 | -------------------------------------------------------------------------------- /examples/test_gaussians.py: -------------------------------------------------------------------------------- 1 | from timeit import default_timer as timer 2 | 3 | from test_utils import gaussian_test 4 | 5 | start = timer() 6 | test = gaussian_test( 7 | nbenchmark=100, 8 | do_like=True, 9 | do_residual=True, 10 | do_grad=True, 11 | do_jac=True, 12 | do_meas_modelfit=False, 13 | nsub=4, 14 | ) 15 | for x in test: 16 | print(f"re={x['reff']:.3f} q={x['axrat']:.2f}" f" ang={x['ang']:2.1f} { x['string']}") 17 | print(f"Test complete in {timer() - start:.2f}s") 18 | -------------------------------------------------------------------------------- /examples/plot_sersic_mix.py: -------------------------------------------------------------------------------- 1 | import lsst.gauss2d.fit as g2f 2 | from lsst.multiprofit.plots import plot_sersicmix_interp 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from scipy.interpolate import CubicSpline 6 | 7 | interps = { 8 | "lin": (g2f.LinearSersicMixInterpolator(), "-"), 9 | "gsl-csp": (g2f.GSLSersicMixInterpolator(interp_type=g2f.InterpType.cspline), (0, (8, 8))), 10 | "scipy-csp": ((CubicSpline, {}), (0, (4, 4))), 11 | } 12 | 13 | for n_low, n_hi in ((0.5, 0.7), (0.8, 1.2), (2.2, 4.4)): 14 | n_ser = 10 ** np.linspace(np.log10(n_low), np.log10(n_hi), 100) 15 | plot_sersicmix_interp(interps=interps, n_ser=n_ser, figsize=(10, 8)) 16 | plt.tight_layout() 17 | plt.show() 18 | -------------------------------------------------------------------------------- /examples/test_gaussian_gradients.py: 
# Example script: call gradient_test for several PSF parameter sets and print
# the returned gradients next to the values it labels as finite differences.

# Seed NumPy's global random number generator with a fixed value.
np.random.seed(1)
# NOTE(review): reff, axrat and ang unpacked by this outer loop are never used
# below; gradient_test is called with the literals reff=5 and angle=20 and is
# given no source axis ratio. Confirm whether the loop values were meant to be
# passed through.
for reff, axrat, ang in [(4.457911011776755, 0.6437167462922668, 44.55485360075)]:
    # Each tuple supplies the reff_psf, axrat_psf and angle_psf arguments.
    for reff_psf, axrat_psf, ang_psf in [
        (0, 0, 0),
        (1.5121054822774742, 0.9135936343054303, 50.30562156585181),
        (3.7442185156735914, 0.8695066738347554, -39.40729158864958),
    ]:
        grads, dlls, diffabs = gradient_test(
            xdim=23,
            ydim=19,
            reff=5,
            angle=20,
            reff_psf=reff_psf,
            axrat_psf=axrat_psf,
            angle_psf=ang_psf,
            printout=False,
            plot=False,
        )
        # Report both sets of values, then their difference and ratio.
        print("Gradient ", grads)
        print("Finite Diff. ", dlls)
        print("FinD. - Grad.", dlls - grads)
        print("FinD. / Grad.", dlls / grads)
        print("Jacobian sum abs. diff.", diffabs)
# Example script: run mgsersic_test once with plotting enabled, then over a
# grid of Sersic index, size, axis ratio and angle, printing the summed
# absolute difference for each combination.

# Started here so the total run time can be reported at the end.
start = timer()
reffs = [2.0, 5.0]
angs = np.linspace(0, 90, 7)
axrats = [1, 0.5, 0.2, 0.1, 0.01]
nsers = [0.5, 1.0, 2.0, 4.0, 6.0]

# Single call with plot=True before the (non-plotting) grid below.
mgsersic_test(reff=reffs[1], nser=2, axrat=axrats[1], angle=angs[2], do_galsim=True, plot=True)

for nser in nsers:
    for reff in reffs:
        for axrat in axrats:
            for ang in angs:
                diffs = mgsersic_test(
                    reff=reff,
                    nser=nser,
                    axrat=axrat,
                    angle=ang,
                    do_galsim=True,
                    plot=False,
                    flux=1.0,
                )
                diff_abs = np.sum(np.abs(diffs["gs_ser_nopix"]))
                print(
                    f"Test: nser={nser} reff={reff} axrat={axrat} ang={ang}"
                    f" sum(abs(diff))={diff_abs:.3f}%"
                )

# The timer was previously started but never read; report the elapsed time
# in the same format as the sibling test_gaussians.py example.
print(f"Test complete in {timer() - start:.2f}s")
def test_make_psf_source():
    """Check that a three-element input yields three Gaussians."""
    values = [2, 3, 4]
    psf_source = make_psf_source(values)
    gaussians = psf_source.gaussians(g2f.Channel.NONE)
    assert gaussians.size == 3
# TODO: Replace with a parameter factory and/or profile factory
# Reference limits looked up by a short string key. Every value is an instance
# of lsst.gauss2d.fit.LimitsD (imported in this module as Limits).
limits_ref = {
    # Default-constructed limits: no explicit min or max is passed.
    "none": Limits(),
    "axrat": Limits(min=1e-2, max=1),
    "con": Limits(min=1, max=10),
    "fluxfrac": Limits(min=0.001, max=0.999),
    "n_ser": Limits(min=0.3, max=6.0),
    # Same upper bound as "n_ser" but a higher lower bound (0.5 rather than 0.3).
    "n_ser_multigauss": Limits(min=0.5, max=6.0),
    "rho": Limits(min=-0.999, max=0.999),
}
def test_fitter_config_data(fitter_config_data):
    """Smoke-test attribute access on the config data fixture.

    Nothing is asserted about the values; the test only fails if reading
    ``parameters`` or ``psf_model`` raises.
    """
    fitter_config_data.parameters
    fitter_config_data.psf_model
class CatalogError(RuntimeError):
    """RuntimeError that can be caught and flagged in a column.

    Subclasses are expected to override `column_name`.

    Notes
    -----
    The ``abstractmethod`` decorator on `column_name` is only a marker here:
    this class does not use ``abc.ABCMeta``, so instantiation is not blocked
    and the base implementation simply returns None.
    """

    @classmethod
    @abstractmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""


class NoDataError(CatalogError):
    """RuntimeError for when there is no data to fit."""

    # This is a concrete implementation, so it must not be marked abstract.
    @classmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""
        return "no_data_flag"


# NOTE(review): this class offers the same column_name interface as
# CatalogError but derives from RuntimeError directly, so handlers written as
# ``except CatalogError`` will not catch it. Confirm whether that is intended.
class PsfRebuildFitFlagError(RuntimeError):
    """RuntimeError for when a PSF can't be rebuilt because the fit failed."""

    @classmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""
        return "psf_fit_flag"
19 | 20 | *multiprofit* can fit any kind of imaging data while modelling sources as Gaussian mixtures - including 21 | approximations to Sersic profiles - using a Gaussian pixel-convolved point spread function. It can also use 22 | `GalSim <https://github.com/GalSim-developers/GalSim/>`_ or `libprofit <https://github.com/ICRAR/libprofit/>`_ 23 | via `pyprofit <https://github.com/ICRAR/pyprofit/>`_ to generate true Sersic and/or other supported 24 | models convolved with arbitrary PSF images or models. 25 | 26 | *multiprofit* has support for multi-object fitting and experimental support for multi-band fitting, albeit 27 | currently limited to pixel-matched images of identical dimensions. Unlike ProFit, Bayesian MCMC is not 28 | available (yet). 29 | 30 | *multiprofit* requires Python 3, along with `pybind11 <https://github.com/pybind/pybind11>`_ for C++ bindings, 31 | and `gauss2d <https://github.com/lsst-dm/gauss2d/>`_ for evaluating Gaussian mixtures. It can be installed 32 | using setup.py like so: 33 | 34 | python3 setup.py install --user 35 | 36 | .. todo *multiprofit* is available in `PyPI <https://pypi.python.org/pypi/multiprofit>`_ 37 | .. and thus can be easily installed via:: 38 | 39 | .. pip install multiprofit 40 | -------------------------------------------------------------------------------- /python/lsst/multiprofit/config.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version.
def set_config_from_dict(
    config: pexConfig.Config | pexConfig.dictField.Dict | pexConfig.configDictField.ConfigDict | dict,
    overrides: dict[str, Any],
):
    """Set `lsst.pex.config` params from a dict.

    Parameters
    ----------
    config
        A config, dictField or configDictField object.
    overrides
        A dict of key-value pairs to override in the config.

    Notes
    -----
    Any ``config`` that has a ``__getitem__`` attribute is treated as
    dict-like: its existing keys that are absent from ``overrides`` are
    deleted first, and values are then read and written by item access.
    Any other ``config`` is read and written with ``getattr``/``setattr``.

    Override values that are themselves dicts are applied recursively to the
    object already stored under that key or attribute.

    Setting values is best-effort: if assigning a non-dict value raises, the
    failure is logged as a warning and the remaining overrides are still
    applied.
    """
    # Imported locally so this function does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    is_config_dict = hasattr(config, "__getitem__")
    if is_config_dict:
        # Snapshot the keys before deleting so the loop does not iterate
        # over a container that is being modified.
        for key in tuple(config.keys()):
            if key not in overrides:
                del config[key]
    for key, value in overrides.items():
        if isinstance(value, dict):
            attr = config[key] if is_config_dict else getattr(config, key)
            set_config_from_dict(attr, value)
        else:
            try:
                if is_config_dict:
                    config[key] = value
                else:
                    setattr(config, key, value)
            except Exception as e:
                # Deliberately broad: keep applying the other overrides, but
                # report which key failed instead of printing the bare error.
                logging.getLogger(__name__).warning(
                    "Could not set config override %r to %r: %s", key, value, e
                )
def make_image_gaussians(
    gaussians_source: g2.Gaussians,
    gaussians_kernel: g2.Gaussians | None = None,
    **kwargs: Any,
) -> g2.ImageD:
    """Render an image of source Gaussians convolved with kernel Gaussians.

    Parameters
    ----------
    gaussians_source
        Gaussians representing components of sources.
    gaussians_kernel
        Gaussians representing the smoothing kernel. If None, a kernel
        holding one default-constructed Gaussian is used.
    **kwargs
        Additional keyword arguments to pass to gauss2d.make_gaussians_pixel_D
        (i.e. image size, etc.).

    Returns
    -------
    image
        The rendered image of the given Gaussians.
    """
    kernels = g2.Gaussians([g2.Gaussian()]) if gaussians_kernel is None else gaussians_kernel
    # Pair every source with every kernel, with sources as the outer loop.
    convolved = []
    for source in gaussians_source:
        for kernel in kernels:
            convolved.append(g2.ConvolvedGaussian(source=source, kernel=kernel))
    return g2.make_gaussians_pixel_D(gaussians=g2.ConvolvedGaussians(convolved), **kwargs)
65 | 66 | Returns 67 | ------- 68 | model 69 | A null PSF model consisting of a single, normalized, zero-size 70 | Gaussian. 71 | """ 72 | return g2f.PsfModel(g2f.GaussianComponent.make_uniq_default_gaussians([0], True)) 73 | -------------------------------------------------------------------------------- /tests/test_observationconfig.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version. 13 | # 14 | # This program is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | # 19 | # You should have received a copy of the GNU General Public License 20 | # along with this program. If not, see . 
def test_ObservationConfig():
    """Check observations built from one config, before and after setting band.

    Verifies each plane's shape and the channel, then that a second
    observation made from the same config has distinct plane objects that
    compare equal once both are zero-filled.
    """
    n_cols, n_rows = 15, 17
    shape_expected = [n_rows, n_cols]
    plane_names = ("image", "mask_inv", "sigma_inv")
    config = ObservationConfig(n_cols=n_cols, n_rows=n_rows)

    obs_default = config.make_observation()
    assert obs_default.channel == g2f.Channel.NONE
    for name in plane_names:
        assert getattr(obs_default, name).shape == shape_expected

    config.band = "red"
    obs_red = config.make_observation()
    assert obs_red.channel == g2f.Channel.get("red")
    for name in plane_names:
        plane_default = getattr(obs_default, name)
        plane_red = getattr(obs_red, name)
        assert plane_default is not plane_red
        # Zero-fill both planes first so the equality check compares
        # initialized contents.
        plane_default.fill(0)
        plane_red.fill(0)
        assert plane_default == plane_red
class ArbitraryAllowedConfig:
    """Pydantic config to allow arbitrary typed Fields.

    The class only carries the two class-level settings below; it defines no
    methods.
    """

    # Permit field types that are arbitrary classes.
    arbitrary_types_allowed = True
    # Disallow any extra kwargs
    extra = "forbid"


class FrozenArbitraryAllowedConfig(ArbitraryAllowedConfig):
    """Pydantic config to allow arbitrary typed Fields for frozen classes.

    NOTE(review): this subclass defines no attributes of its own, so its
    settings are exactly those inherited from ArbitraryAllowedConfig.
    Presumably the "frozen" behaviour is configured where the class is used;
    confirm against the callers.
    """
def normalize(ndarray: numpy.ndarray, return_sum: bool = False):
    """Divide a numpy array by its own sum, in place.

    Parameters
    ----------
    ndarray
        The array to normalize. It is modified in place.
    return_sum
        Whether to also return the sum computed before normalizing.

    Returns
    -------
    ndarray
        The input ndarray (the same object that was passed in).
    total
        The sum of the array before it was divided. Only returned if
        return_sum is True.
    """
    total = np.sum(ndarray)
    # Write the quotient back into the input so callers keep their reference
    np.divide(ndarray, total, out=ndarray)
    return (ndarray, total) if return_sum else ndarray
_lsst.multiprofit-command-line-taskref: 35 | 36 | Task reference 37 | ============== 38 | 39 | ``lsst.multiprofit`` tasks are implemented in `meas_extensions_multiprofit `_. 40 | 41 | Configurations 42 | -------------- 43 | 44 | .. lsst-configs:: 45 | :root: lsst.multiprofit 46 | :toctree: configs 47 | 48 | .. .. _lsst.multiprofit-scripts: 49 | 50 | .. Script reference 51 | .. ================ 52 | 53 | .. .. TODO: Add an item to this toctree for each script reference topic in the scripts subdirectory. 54 | 55 | .. toctree:: 56 | :maxdepth: 1 57 | 58 | .. _lsst.multiprofit-pyapi: 59 | 60 | Python API reference 61 | ==================== 62 | 63 | .. automodapi:: lsst.multiprofit.asinhstretchsigned 64 | :no-inheritance-diagram: 65 | 66 | .. automodapi:: lsst.multiprofit.componentconfig 67 | :no-inheritance-diagram: 68 | 69 | .. automodapi:: lsst.multiprofit.config 70 | :no-inheritance-diagram: 71 | 72 | .. automodapi:: lsst.multiprofit.fit_bootstrap_model 73 | :no-inheritance-diagram: 74 | 75 | .. automodapi:: lsst.multiprofit.fit_catalog 76 | :no-inheritance-diagram: 77 | 78 | .. automodapi:: lsst.multiprofit.fit_source 79 | :no-inheritance-diagram: 80 | 81 | .. automodapi:: lsst.multiprofit.limits 82 | :no-inheritance-diagram: 83 | 84 | .. automodapi:: lsst.multiprofit.modeller 85 | :no-inheritance-diagram: 86 | 87 | .. automodapi:: lsst.multiprofit.plots 88 | :no-inheritance-diagram: 89 | 90 | .. automodapi:: lsst.multiprofit.priors 91 | :no-inheritance-diagram: 92 | 93 | .. automodapi:: lsst.multiprofit.psfmodel_utils 94 | :no-inheritance-diagram: 95 | 96 | .. automodapi:: lsst.multiprofit.transforms 97 | :no-inheritance-diagram: 98 | 99 | .. 
automodapi:: lsst.multiprofit.utils 100 | :no-inheritance-diagram: 101 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "lsst-versions >= 1.3.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "lsst-multiprofit" 7 | description = "Astronomical image and source model fitting code." 8 | license = {file = "LICENSE"} 9 | readme = "README.rst" 10 | authors = [ 11 | {name="Rubin Observatory Data Management", email="dm-admin@lists.lsst.org"}, 12 | ] 13 | classifiers = [ 14 | "Intended Audience :: Science/Research", 15 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", 16 | "Operating System :: OS Independent", 17 | "Programming Language :: Python :: 3", 18 | "Programming Language :: Python :: 3.10", 19 | "Programming Language :: Python :: 3.11", 20 | "Topic :: Scientific/Engineering :: Astronomy", 21 | ] 22 | keywords = [ 23 | "astronomy", 24 | "astrophysics", 25 | "fitting", 26 | "lsst", 27 | "models", 28 | "modeling", 29 | ] 30 | requires-python = ">=3.10.0" 31 | dependencies = [ 32 | "astropy", 33 | "lsst-gauss2d", 34 | "lsst-gauss2d-fit", 35 | "lsst-pex-config", 36 | "lsst-utils", 37 | "importlib_resources", 38 | "matplotlib", 39 | "numpy", 40 | "pydantic", 41 | "scipy", 42 | ] 43 | dynamic = ["version"] 44 | 45 | [project.urls] 46 | "Homepage" = "https://github.com/lsst-dm/multiprofit" 47 | 48 | [project.optional-dependencies] 49 | galsim = ["galsim"] 50 | test = [ 51 | "pytest", 52 | ] 53 | 54 | [tool.setuptools.packages.find] 55 | where = ["python"] 56 | 57 | [tool.setuptools.dynamic] 58 | version = { attr = "lsst_versions.get_lsst_version" } 59 | 60 | [tool.black] 61 | line-length = 110 62 | target-version = ["py311"] 63 | force-exclude = [ 64 | "examples/fithsc.py", 65 | ] 66 | 67 | [tool.isort] 68 | profile = "black" 69 | line_length = 110 
70 | force_sort_within_sections = true 71 | 72 | [tool.ruff] 73 | exclude = [ 74 | "__init__.py", 75 | "examples/fithsc.py", 76 | "examples/test_utils.py", 77 | "tests/*.py", 78 | ] 79 | ignore = [ 80 | "N802", 81 | "N803", 82 | "N806", 83 | "N812", 84 | "N815", 85 | "N816", 86 | "N999", 87 | "D107", 88 | "D105", 89 | "D102", 90 | "D104", 91 | "D100", 92 | "D200", 93 | "D205", 94 | "D400", 95 | ] 96 | line-length = 110 97 | select = [ 98 | "E", # pycodestyle 99 | "F", # pyflakes 100 | "N", # pep8-naming 101 | "W", # pycodestyle 102 | "D", # pydocstyle 103 | ] 104 | target-version = "py311" 105 | 106 | [tool.ruff.pycodestyle] 107 | max-doc-length = 79 108 | 109 | [tool.ruff.pydocstyle] 110 | convention = "numpy" 111 | 112 | [tool.numpydoc_validation] 113 | checks = [ 114 | "all", # All except the rules listed below. 115 | "ES01", # No extended summary required. 116 | "EX01", # Example section. 117 | "GL01", # Summary text can start on same line as """ 118 | "GL08", # Do not require docstring. 119 | "PR04", # numpydoc parameter types are redundant with type hints 120 | "RT01", # Unfortunately our @property triggers this. 121 | "RT02", # Does not want named return value. DM style says we do. 122 | "SA01", # See Also section. 123 | "SS05", # pydocstyle is better at finding infinitive verb. 124 | "SS06", # Summary can go into second line. 125 | ] 126 | -------------------------------------------------------------------------------- /python/lsst/multiprofit/observationconfig.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 
class CoordinateSystemConfig(pexConfig.Config):
    """Configuration for a gauss2d CoordinateSystem."""

    dx1 = pexConfig.Field[float](doc="The x-axis pixel scale", optional=False, default=1.0)
    dy2 = pexConfig.Field[float](doc="The y-axis pixel scale", optional=False, default=1.0)
    x_min = pexConfig.Field[float](
        doc="The x-axis coordinate of the bottom left corner", optional=False, default=0.0
    )
    y_min = pexConfig.Field[float](
        doc="The y-axis coordinate of the bottom left corner", optional=False, default=0.0
    )

    def make_coordinate_system(self) -> g2.CoordinateSystem:
        """Build a gauss2d CoordinateSystem from the configured values.

        Returns
        -------
        coordsys
            A CoordinateSystem constructed with this config's dx1, dy2,
            x_min and y_min.
        """
        field_names = ("dx1", "dy2", "x_min", "y_min")
        kwargs = {name: getattr(self, name) for name in field_names}
        return g2.CoordinateSystem(**kwargs)
class PsfObservationConfig(ObservationConfig):
    """Configuration for a gauss2d.fit Observation used for PSF fitting."""

    def validate(self):
        """Run the parent validation, then require band to be "None".

        Raises
        ------
        ValueError
            Raised if band is set to anything other than the string "None".
        """
        super().validate()
        band_is_unset = self.band == "None"
        if not band_is_unset:
            raise ValueError("band must be None for PSF fitting")
class ShapePriorConfig(pexConfig.Config):
    """Configuration for a shape prior."""

    prior_axrat_mean = pexConfig.Field[float](
        default=0.7,
        doc="Prior mean on axis ratio",
    )
    prior_axrat_stddev = pexConfig.Field[float](
        default=0,
        doc="Prior std. dev. on axis ratio (prior ignored if not >0)",
    )
    prior_size_mean = pexConfig.Field[float](
        default=1,
        doc="Prior mean on size_major",
    )
    prior_size_stddev = pexConfig.Field[float](
        default=0,
        doc="Prior std. dev. on size_major (prior ignored if not >0)",
    )

    def get_shape_prior(self, ellipse: g2f.ParametricEllipse) -> g2f.ShapePrior | None:
        """Make a shape prior for an ellipse from the configured values.

        Parameters
        ----------
        ellipse
            The ellipse that the prior is constructed for.

        Returns
        -------
        prior
            A ShapePrior with a size and/or axis ratio prior, or None if
            neither std. dev. is positive and finite.
        """
        # Each prior is enabled only by a positive, finite std. dev.
        use_prior_axrat = (self.prior_axrat_stddev > 0) and np.isfinite(self.prior_axrat_stddev)
        use_prior_size = (self.prior_size_stddev > 0) and np.isfinite(self.prior_size_stddev)

        if use_prior_axrat or use_prior_size:
            prior_size = (
                g2f.ParametricGaussian1D(
                    g2f.MeanParameterD(self.prior_size_mean, transform=transforms_ref["log10"]),
                    g2f.StdDevParameterD(self.prior_size_stddev),
                )
                if use_prior_size
                else None
            )
            prior_axrat = (
                g2f.ParametricGaussian1D(
                    g2f.MeanParameterD(self.prior_axrat_mean, transform=transforms_ref["logit_axrat_prior"]),
                    g2f.StdDevParameterD(self.prior_axrat_stddev),
                )
                if use_prior_axrat
                else None
            )
            return g2f.ShapePrior(ellipse, prior_size, prior_axrat)
        return None
78 | 79 | The size is major axis half-light radius. 80 | 81 | Parameters 82 | ---------- 83 | mag_psf_i 84 | The i-band PSF magnitudes of the source(s). 85 | 86 | Notes 87 | ----- 88 | Return values are log10 scaled in units of arcseconds. 89 | The input should be a PSF mag because other magnitudes - even Gaussian - 90 | can be unreliable for low S/N (non-)detections. 91 | """ 92 | return 0.75 * (19 - np.clip(mag_psf_i, 10, 30)) / 6.5, 0.2 93 | -------------------------------------------------------------------------------- /python/lsst/multiprofit/modelconfig.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version. 13 | # 14 | # This program is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | # 19 | # You should have received a copy of the GNU General Public License 20 | # along with this program. If not, see . 
class ModelConfig(pexConfig.Config):
    """Configuration for a gauss2d.fit Model."""

    sources = pexConfig.ConfigDictField[str, SourceConfig](doc="The configuration for sources")

    @staticmethod
    def format_label(label: str, name_source: str) -> str:
        """Substitute the source name into a label template.

        Parameters
        ----------
        label
            A string.Template-style label; ``$name_source`` or
            ``${name_source}`` is replaced and any other placeholder is
            left untouched.
        name_source
            The name of the source to substitute.

        Returns
        -------
        label
            The label with the source name substituted.
        """
        return string.Template(label).safe_substitute(name_source=name_source)

    def get_integral_label_default(self, sourceconfig: SourceConfig) -> str:
        """Return the default integral label template for a source.

        Parameters
        ----------
        sourceconfig
            The source config supplying the rest of the default label.

        Returns
        -------
        label
            The source's default label, prefixed by a source name
            placeholder if has_prefix_source() is True.
        """
        # string.Template only substitutes $-prefixed placeholders; a bare
        # "{name_source}" would pass through format_label unchanged.
        prefix = "src: ${name_source} " if self.has_prefix_source() else ""
        return f"{prefix}{sourceconfig.get_integral_label_default()}"

    def has_prefix_source(self) -> bool:
        """Return whether default labels should include the source name.

        Returns
        -------
        has_prefix
            True if there are multiple sources, or if the only source has a
            non-empty name. False if there are no sources.
        """
        if len(self.sources) > 1:
            return True
        return any(bool(name) for name in self.sources.keys())

    def make_sources(
        self,
        component_group_fluxes_srcs: Iterable[list[list[Fluxes]]],
        label_integral: str | None = None,
    ) -> tuple[list[g2f.Source], list[g2f.Prior]]:
        """Make the sources and priors configured by this model config.

        Parameters
        ----------
        component_group_fluxes_srcs
            The fluxes for each source, in the order of self.sources.
        label_integral
            A label template applied to every source. If None, the default
            from get_integral_label_default is used for each source.

        Returns
        -------
        sources
            One source per entry in self.sources.
        priors
            The priors returned by every source config, concatenated.

        Raises
        ------
        ValueError
            Raised if component_group_fluxes_srcs is None or its length
            does not match the number of sources.
        """
        if component_group_fluxes_srcs is None:
            raise ValueError("component_group_fluxes_srcs must not be None")
        # Materialize so that len() is valid for any iterable
        fluxes_srcs = list(component_group_fluxes_srcs)
        n_fluxes = len(fluxes_srcs)
        n_src = len(self.sources)
        if n_fluxes != n_src:
            raise ValueError(f"len(component_group_fluxes_srcs)={n_fluxes} != {n_src=}")

        sources = []
        priors = []
        for component_group_fluxes, (name_src, config_src) in zip(fluxes_srcs, self.sources.items()):
            label_integral_src = (
                label_integral if label_integral is not None else self.get_integral_label_default(config_src)
            )

            source, priors_src = config_src.make_source(
                component_group_fluxes=component_group_fluxes,
                label_integral=self.format_label(label=label_integral_src, name_source=name_src),
            )
            sources.append(source)
            priors.extend(priors_src)

        return sources, priors

    def make_model(
        self,
        component_group_fluxes_srcs: Iterable[list[list[Fluxes]]],
        data: g2f.DataD,
        psf_models: list[g2f.PsfModel],
        label_integral: str | None = None,
    ) -> g2f.ModelD:
        """Make a model of the configured sources for the given data.

        Parameters
        ----------
        component_group_fluxes_srcs
            The fluxes for each source, passed to make_sources.
        data
            The data to construct the model with.
        psf_models
            The PSF models to construct the model with.
        label_integral
            A label template passed to make_sources.

        Returns
        -------
        model
            A model built from the data, PSF models, sources and priors.
        """
        sources, priors = self.make_sources(
            component_group_fluxes_srcs=component_group_fluxes_srcs,
            label_integral=label_integral,
        )

        model = g2f.ModelD(data=data, psfmodels=psf_models, sources=sources, priors=priors)

        return model
class AsinhStretchSigned(apvis.BaseStretch):
    r"""
    A signed asinh stretch.

    Input values are mapped from [0, 1] onto [-1, 1], stretched
    symmetrically about zero and mapped back, so 0.5 is the fixed point.

    The stretch is given by:

    .. math::
        y = 0.5\left(1 + {\rm sign}(x - 0.5)\frac{{\rm asinh}(2|x - 0.5| / a)}{{\rm asinh}(1 / a)}\right).

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula.  The value of
        this parameter is where the asinh curve transitions from linear
        to logarithmic behavior, expressed as a fraction of the
        normalized image.  Must be in the range between 0 and 1.
        Default is 0.1.
    """  # noqa: W505

    def __init__(self, a=0.1):
        super().__init__()
        self.a = a

    def __call__(self, values, clip=True, out=None):
        """Apply the stretch to values.

        Parameters
        ----------
        values
            The values to stretch.
        clip
            Whether to clip values to between 0 and 1 before stretching.
        out
            An existing array to write the result into.

        Returns
        -------
        values
            The stretched values, computed in place on the prepared array.
        """
        values = _prepare(values, clip=clip, out=out)
        # Rescale to 2x - 1 so that the stretch is symmetric about zero
        values *= 2
        values -= 1
        # Keep the signs; the asinh is applied to the magnitudes only
        signs = np.sign(values)
        np.abs(values, out=values)
        np.true_divide(values, self.a, out=values)
        np.arcsinh(values, out=values)
        # Normalize by asinh(1 / a), the stretched value of unit magnitude
        np.true_divide(values, np.arcsinh(1.0 / self.a), out=values)
        # Restore the signs and shift back by (1 + y) / 2
        np.true_divide(1.0 + signs * values, 2.0, out=values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation.

        Returns
        -------
        inverse
            The inverse stretch.
        """
        return SinhStretchSigned(a=1.0 / np.arcsinh(1.0 / self.a))
math:: 113 | y = 0.5\left(1 + \frac{{\rm sinh}((2x - 1) / a)}{{\rm sinh}(1 / a)}\right) 114 | 115 | Parameters 116 | ---------- 117 | a : float, optional 118 | The ``a`` parameter used in the above formula. Default is 1/3. 119 | """ 120 | 121 | def __init__(self, a=1.0 / 3.0): 122 | super().__init__() 123 | self.a = a 124 | 125 | # [docs] 126 | 127 | def __call__(self, values, clip=True, out=None): 128 | values = _prepare(values, clip=clip, out=out) 129 | values *= 2.0 130 | values -= 1.0 131 | np.true_divide(values, self.a, out=values) 132 | np.sinh(values, out=values) 133 | np.true_divide(values, np.sinh(1.0 / self.a), out=values) 134 | values += 1.0 135 | values /= 2.0 136 | return values 137 | 138 | @property 139 | def inverse(self): 140 | """A stretch object that performs the inverse operation. 141 | 142 | Returns 143 | ------- 144 | inverse 145 | The inverse stretch. 146 | """ 147 | return AsinhStretchSigned(a=1.0 / np.sinh(1.0 / self.a)) 148 | -------------------------------------------------------------------------------- /tests/test_plots.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version. 13 | # 14 | # This program is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    """Make one observation per channel, varying size and origin by index."""
    base_rows, base_cols = 16, 21
    origin_x, origin_y = 0, 0

    # Per-index increments applied to the image dimensions and origin
    step_rows, step_cols = 1, -2
    step_x, step_y = -2, 1

    def make_observation(index, band):
        coordsys = CoordinateSystemConfig(
            x_min=origin_x + index*step_x,
            y_min=origin_y + index*step_y,
        )
        obs = ObservationConfig(
            band=band,
            coordsys=coordsys,
            n_rows=base_rows + index*step_rows,
            n_cols=base_cols + index*step_cols,
        ).make_observation()
        obs.sigma_inv.fill(sigma_inv)
        obs.mask_inv.fill(1)
        return obs

    return g2f.DataD([make_observation(index, band) for index, band in enumerate(channels)])
def test_plot_model_rgb(model):
    """Check that RGB plotting with chi histograms returns figures and axes."""
    fig, ax, fig_gs, ax_gs, *_ = plot_model_rgb(
        model, minimum=0, stretch=0.15, Q=4, weights=bands_weights_lsst, plot_chi_hist=True,
    )
    for returned in (fig, ax, fig_gs, ax_gs):
        assert returned is not None
4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version. 13 | # 14 | # This program is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | # 19 | # You should have received a copy of the GNU General Public License 20 | # along with this program. If not, see . 21 | 22 | from abc import ABC, abstractmethod 23 | from collections.abc import Iterable 24 | 25 | import astropy.units as u 26 | import lsst.pex.config as pexConfig 27 | import pydantic 28 | from pydantic.dataclasses import dataclass 29 | 30 | from .modeller import ModelFitConfig 31 | from .utils import ArbitraryAllowedConfig 32 | 33 | __all__ = ["CatalogExposureABC", "ColumnInfo", "CatalogFitterConfig"] 34 | 35 | 36 | class CatalogExposureABC(ABC): 37 | """Interface for catalog-exposure pairs.""" 38 | 39 | # TODO: add get_exposure (with Any return type?) 
class CatalogFitterConfig(pexConfig.Config):
    """Configuration for generic MultiProFit fitting tasks."""

    column_id = pexConfig.Field[str](default="id", doc="Catalog index column key")
    compute_errors = pexConfig.ChoiceField[str](
        default="INV_HESSIAN_BESTFIT",
        doc="Whether/how to compute sqrt(variances) of each free parameter",
        allowed={
            "NONE": "no errors computed",
            "INV_HESSIAN": "inverse hessian using noisy image as data",
            "INV_HESSIAN_BESTFIT": "inverse hessian using best-fit model as data",
        },
    )
    compute_errors_from_jacobian = pexConfig.Field[bool](
        default=True,
        doc="Whether to estimate the Hessian from the Jacobian first, with finite differencing as a backup",
    )
    compute_errors_no_covar = pexConfig.Field[bool](
        default=True,
        doc="Whether to compute parameter errors independently, ignoring covariances",
    )
    config_fit = pexConfig.ConfigField[ModelFitConfig](default=ModelFitConfig(), doc="Fitter configuration")
    fit_centroid = pexConfig.Field[bool](default=True, doc="Fit centroid parameters")
    fit_linear_init = pexConfig.Field[bool](default=True, doc="Fit linear parameters after initialization")
    fit_linear_final = pexConfig.Field[bool](default=True, doc="Fit linear parameters after optimization")
    flag_errors = pexConfig.DictField(
        default={},
        keytype=str,
        itemtype=str,
        doc="Flag column names to set, keyed by name of exception to catch",
    )
    prefix_column = pexConfig.Field[str](default="mpf_", doc="Column name prefix")

    def schema(
        self,
        bands: list[str] | None = None,
    ) -> list[ColumnInfo]:
        """Return the schema as an ordered list of columns.

        Parameters
        ----------
        bands
            A list of band names to prefix band-dependent columns with.
            Band prefixes should not be used if None. This implementation
            adds no band-dependent columns and does not read bands.

        Returns
        -------
        schema
            An ordered list of ColumnInfo instances: the id column, the
            fit summary columns, then one bool column per key of
            flag_errors.
        """
        schema = [
            ColumnInfo(key=self.column_id, dtype="i8"),
            ColumnInfo(key="n_iter", dtype="i4"),
            ColumnInfo(key="time_eval", dtype="f8", unit=u.s),
            ColumnInfo(key="time_fit", dtype="f8", unit=u.s),
            ColumnInfo(key="time_full", dtype="f8", unit=u.s),
            ColumnInfo(key="chisq_red", dtype="f8"),
            ColumnInfo(key="unknown_flag", dtype="bool"),
        ]
        schema.extend([ColumnInfo(key=key, dtype="bool") for key in self.flag_errors.keys()])
        # Subclasses should always write out centroids even if not fitting
        # They are helpful for reconstructing models
        return schema
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .

import lsst.gauss2d.fit as g2f
import numpy as np

from .limits import limits_ref

__all__ = ["make_psf_source"]


# TODO: This function should be replaced with SourceConfig.make_source
def make_psf_source(
    sigma_xs: list[float] | None = None,
    sigma_ys: list[float] | None = None,
    rhos: list[float] | None = None,
    fracs: list[float] | None = None,
    transforms: dict[str, g2f.TransformD] | None = None,
    limits_rho: g2f.LimitsD | None = None,
) -> g2f.Source:
    """Make a Gaussian mixture PSF source from parameter values.

    Parameters
    ----------
    sigma_xs
        Gaussian sigma_x values. Defaults to sigma_ys if that is given,
        otherwise to [1.5, 3.0].
    sigma_ys
        Gaussian sigma_y values. Defaults to sigma_xs.
    rhos
        Gaussian rho values. Defaults to zero for every Gaussian.
    fracs
        Gaussian fraction values, one per Gaussian. The last fraction is
        always overridden to 1.0 and its parameter is fixed.
        Defaults to (1, ..., n)/n.
    transforms
        Dict of transforms by variable name (frac/rho/sigma). If not set,
        will default to Logit/LogitLimited/Log10, respectively.
    limits_rho
        Limits for rho parameters. Defaults to limits_ref['rho'].

    Returns
    -------
    source
        A source model with Gaussians initialized as specified.

    Raises
    ------
    ValueError
        Raised if no Gaussians are specified, if the parameter lists have
        different lengths, or if any value is out of bounds.

    Notes
    -----
    Parameter lists must all be the same length. Neither ``fracs`` nor
    ``transforms`` is modified; private copies are made.
    """
    if limits_rho is None:
        limits_rho = limits_ref["rho"]
    if sigma_xs is None:
        # Mirror sigma_ys when only it was given so both lists stay in step
        sigma_xs = [1.5, 3.0] if sigma_ys is None else sigma_ys
    if sigma_ys is None:
        sigma_ys = sigma_xs
    n_gaussians = len(sigma_xs)
    if n_gaussians == 0:
        raise ValueError(f"{n_gaussians=}!>0")
    if rhos is None:
        rhos = [0.0] * n_gaussians
    if fracs is None:
        fracs = np.arange(1, n_gaussians + 1) / n_gaussians
    # Copy so that overriding the last fraction below never touches the
    # caller's list
    fracs = list(fracs)

    # Copy so that filling in defaults never touches the caller's dict
    transforms = dict(transforms) if transforms is not None else {}
    if "frac" not in transforms:
        transforms["frac"] = g2f.LogitTransformD()
    if "rho" not in transforms:
        transforms["rho"] = g2f.LogitLimitedTransformD(limits=limits_rho)
    if "sigma" not in transforms:
        transforms["sigma"] = g2f.Log10TransformD()

    if (len(sigma_ys) != n_gaussians) or (len(rhos) != n_gaussians) or (len(fracs) != n_gaussians):
        raise ValueError(f"{len(sigma_ys)=} and/or {len(rhos)=} and/or {len(fracs)=} != {n_gaussians=}")

    # Collect every out-of-bounds value so they are all reported at once
    errors = []
    for idx, (sigma_x, sigma_y, rho, frac) in enumerate(zip(sigma_xs, sigma_ys, rhos, fracs)):
        if not ((sigma_x >= 0) and (sigma_y >= 0)):
            errors.append(f"sigma_xs[{idx}]={sigma_x} and/or sigma_ys[{idx}]={sigma_y} !>=0")
        if not (limits_rho.check(rho)):
            errors.append(f"rhos[{idx}]={rho} !within({limits_rho=})")
        if not (frac >= 0):
            errors.append(f"fracs[{idx}]={frac} !>=0")
    if errors:
        raise ValueError("; ".join(errors))
    # The final fraction parameter is fixed, and its value is forced to one
    fracs[-1] = 1.0

    # All Gaussians share a single centroid
    cenx = g2f.CentroidXParameterD(0, limits=g2f.LimitsD(min=0, max=100))
    ceny = g2f.CentroidYParameterD(0, limits=g2f.LimitsD(min=0, max=100))
    centroid = g2f.CentroidParameters(cenx, ceny)

    n_last = n_gaussians - 1
    components = []
    last = None

    for c in range(n_gaussians):
        is_last = c == n_last
        # Each fractional model wraps the previous component's model; the
        # first wraps a fixed unit linear integral
        last = g2f.FractionalIntegralModel(
            [
                (
                    g2f.Channel.NONE,
                    g2f.ProperFractionParameterD(fracs[c], fixed=is_last, transform=transforms["frac"]),
                )
            ],
            g2f.LinearIntegralModel([(g2f.Channel.NONE, g2f.IntegralParameterD(1.0, fixed=True))])
            if (c == 0)
            else last,
            is_last,
        )
        components.append(
            g2f.GaussianComponent(
                g2f.GaussianParametricEllipse(
                    g2f.SigmaXParameterD(sigma_xs[c], transform=transforms["sigma"]),
                    g2f.SigmaYParameterD(sigma_ys[c], transform=transforms["sigma"]),
                    g2f.RhoParameterD(rhos[c], transform=transforms["rho"]),
                ),
                centroid,
                last,
            )
        )
    return g2f.Source(components)


# --------------------------------------------------------------------------------
# /tests/test_componentconfig.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
import lsst.gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import (
    EllipticalComponentConfig,
    GaussianComponentConfig,
    ParameterConfig,
    SersicComponentConfig,
    SersicIndexParameterConfig,
)
from lsst.multiprofit.config import set_config_from_dict
from lsst.multiprofit.utils import get_params_uniq
import numpy as np
import pytest


@pytest.fixture(scope="module")
def centroid_limits():
    """Return unbounded limits for centroid parameters."""
    # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
    limits = g2f.LimitsD(min=-np.inf, max=np.inf)
    return limits


@pytest.fixture(scope="module")
def centroid(centroid_limits):
    """Return a fixed centroid at the origin."""
    cenx = g2f.CentroidXParameterD(0, limits=centroid_limits, fixed=True)
    ceny = g2f.CentroidYParameterD(0, limits=centroid_limits, fixed=True)
    centroid = g2f.CentroidParameters(cenx, ceny)
    return centroid


@pytest.fixture(scope="module")
def channels():
    """Return the R, G and B channels keyed by band name."""
    return {band: g2f.Channel.get(band) for band in ("R", "G", "B")}


def test_EllipticalComponentConfig():
    """Check that a default config survives a toDict round trip."""
    config = EllipticalComponentConfig()
    config2 = EllipticalComponentConfig()
    set_config_from_dict(config2, config.toDict())
    assert config == config2


def test_GaussianComponentConfig(centroid):
    """Check two Gaussian components chained by fractional integrals."""
    config = GaussianComponentConfig(
        rho=ParameterConfig(value_initial=0),
        size_x=ParameterConfig(value_initial=1.4),
        size_y=ParameterConfig(value_initial=1.6),
    )
    channel = g2f.Channel.NONE
    component_data1 = config.make_component(
        centroid=centroid,
        integral_model=g2f.FractionalIntegralModel(
            [(channel, g2f.ProperFractionParameterD(0.5, fixed=False))],
            model=config.make_linear_integral_model({channel: 1.0}),
        ),
    )
    # The second component wraps the first one's integral model and
    # closes the chain with a fixed fraction of one
    component_data2 = config.make_component(
        centroid=centroid,
        integral_model=g2f.FractionalIntegralModel(
            [(channel, g2f.ProperFractionParameterD(1.0, fixed=True))],
            model=component_data1.integral_model,
            is_final=True,
        ),
    )
    components = (component_data1, component_data2)
    n_components = len(components)
    for idx, component_data in enumerate(components):
        component = component_data.component
        assert component.centroid is centroid
        assert len(component_data.priors) == 0
        fluxes = list(get_params_uniq(component, nonlinear=False))
        assert len(fluxes) == 1
        assert isinstance(fluxes[0], g2f.IntegralParameterD)
        fracs = [param for param in get_params_uniq(component, linear=False)
                 if isinstance(param, g2f.ProperFractionParameterD)]
        # NOTE(review): idx never reaches n_components inside enumerate(),
        # so the last term is always zero - confirm the intended count.
        assert len(fracs) == (idx + (idx == 0) - (idx == n_components))


def test_SersicConfig(centroid, channels):
    """Check that a Sersic component keeps its initial parameter values."""
    rho, size_x, size_y, sersic_index = -0.3, 1.4, 1.6, 3.2
    config = SersicComponentConfig(
        rho=ParameterConfig(value_initial=rho),
        size_x=ParameterConfig(value_initial=size_x),
        size_y=ParameterConfig(value_initial=size_y),
        sersic_index=SersicIndexParameterConfig(value_initial=sersic_index),
    )
    fluxes = {
        channel: 1.0 + idx
        for idx, channel in enumerate(channels.values())
    }
    integral_model = config.make_linear_integral_model(fluxes)
    component_data = config.make_component(
        centroid=centroid,
        integral_model=integral_model,
    )
    assert component_data.component is not None
    # As long as there's a default Sersic index prior
    assert len(component_data.priors) == 1
    params = get_params_uniq(component_data.component)
    values_init = {
        g2f.RhoParameterD: rho,
        g2f.ReffXParameterD: size_x,
        g2f.ReffYParameterD: size_y,
        g2f.SersicIndexParameterD: sersic_index,
    }
    # Expected flux per channel, keyed by the label given to its parameter
    fluxes_label = {
        config.format_label(config.get_integral_label_default(), name_channel=channel.name):
            fluxes[channel] for channel in fluxes.keys()
    }
    for param in params:
        if isinstance(param, g2f.IntegralParameterD):
            assert fluxes_label[param.label] == param.value
        elif value_init := values_init.get(param.__class__):
            assert param.value == value_init


# --------------------------------------------------------------------------------
# /tests/test_modelconfig.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
import lsst.gauss2d as g2
import lsst.gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import (
    CentroidConfig,
    GaussianComponentConfig,
    ParameterConfig,
    SersicComponentConfig,
    SersicIndexParameterConfig,
)
from lsst.multiprofit.modelconfig import ModelConfig
from lsst.multiprofit.observationconfig import ObservationConfig
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
import numpy as np
import pytest


@pytest.fixture(scope="module")
def channels() -> dict[str, g2f.Channel]:
    """Return the R, G and B channels keyed by band name."""
    band_names = ("R", "G", "B")
    return {name: g2f.Channel.get(name) for name in band_names}


@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    """Return data holding one 13x19 observation per band."""
    config_obs = ObservationConfig(n_rows=13, n_cols=19)

    def observe(band):
        # The same config is re-used, switching only the band each time
        config_obs.band = band
        return config_obs.make_observation()

    return g2f.DataD([observe(band) for band in channels])


@pytest.fixture(scope="module")
def psf_model():
    """Return a three-Gaussian PSF model with linearly spaced parameters."""
    rho, size_x, size_y = 0.25, 1.6, 1.2
    drho, dsize_x, dsize_y = -0.4, 1.1, 1.9

    n_components = 3
    flux_total = 2.*(n_components + 1)

    components_gauss = {}
    for idx in range(n_components):
        components_gauss[str(idx)] = GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
        )

    config_source = SourceConfig(
        component_groups={'src': ComponentGroupConfig(components_gauss=components_gauss)},
    )
    config_source.validate()

    channel = g2f.Channel.NONE
    # One flux dict per component, nested within the single group
    fluxes_group = [{channel: x/flux_total} for x in range(1, 1 + n_components)]
    model, _ = config_source.make_psf_model([fluxes_group])
    return model


@pytest.fixture(scope="module")
def psf_models(psf_model, channels) -> list[g2f.PsfModel]:
    """Return the same PSF model repeated once per channel."""
    return [psf_model for _ in range(len(channels))]


@pytest.fixture(scope="module")
def modelconfig_fluxes(channels):
    """Return a two-component Sersic model config and its initial fluxes."""
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    components_sersic = {
        name: SersicComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
            sersic_index=SersicIndexParameterConfig(
                value_initial=sersicn + idx * dsersicn,
                fixed=idx == 0,
                prior_mean=None,
            ),
        )
        for idx, name in enumerate(names)
    }
    # One dict of per-channel fluxes for each component, in the same order
    fluxes_mix = [
        {
            channel: flux + idx_channel*dflux*idx
            for idx_channel, channel in enumerate(channels.values())
        }
        for idx in range(len(names))
    ]

    centroid_default = CentroidConfig(
        x=ParameterConfig(value_initial=15.8, fixed=True),
        y=ParameterConfig(value_initial=14.3, fixed=False),
    )
    group_mix = ComponentGroupConfig(
        centroids={"default": centroid_default},
        components_sersic=components_sersic,
    )
    modelconfig = ModelConfig(
        sources={'src': SourceConfig(component_groups={'mix': group_mix})},
    )
    return modelconfig, fluxes_mix


def test_ModelConfig(modelconfig_fluxes, data, psf_models):
    """Check that a model fit to its own image has zero log-likelihood."""
    config_model, fluxes_init = modelconfig_fluxes
    model = config_model.make_model([[fluxes_init]], data=data, psf_models=psf_models)
    assert model is not None
    assert model.data is data
    for obs in model.data:
        obs.sigma_inv.fill(1.)
        obs.mask_inv.fill(1)

    # Set the outputs to new images that refer to the existing data
    # because obs.image will not return a holding pointer
    images_out = []
    for obs in model.data:
        images_out.append([g2.ImageD(obs.image.data)])
    model.setup_evaluators(g2f.EvaluatorMode.image, outputs=images_out)
    model.evaluate()
    model.setup_evaluators(g2f.EvaluatorMode.loglike)
    loglikes = model.evaluate()
    assert np.sum(loglikes) == 0


# --------------------------------------------------------------------------------
# /python/lsst/multiprofit/transforms.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
from typing import Any

import lsst.gauss2d.fit as g2f
import numpy as np

from .limits import limits_ref

__all__ = ["get_logit_limited", "verify_transform_derivative", "transforms_ref"]


def get_logit_limited(lower: float, upper: float, factor: float = 1.0, name: str | None = None):
    """Get a logit transform stretched to span a different range than [0,1].

    Parameters
    ----------
    lower
        The lower limit of the range to span.
    upper
        The upper limit of the range to span.
    factor
        A multiplicative factor to apply to the transformed result.
    name
        A descriptive name for the transform. If None, a name listing the
        limits and factor is generated.

    Returns
    -------
    transform
        A modified logit transform as specified.
    """
    return g2f.LogitLimitedTransformD(
        limits=g2f.LimitsD(
            min=lower,
            max=upper,
            name=name
            if name is not None
            else f"LogitLimitedTransformD(min={lower}, max={upper}, factor={factor})",
        ),
        factor=factor,
    )


def verify_transform_derivative(
    transform: g2f.TransformD,
    value_transformed: float,
    derivative: float | None = None,
    abs_max: float | None = 1e6,
    dx_ratios=None,
    **kwargs: Any,
):
    """Verify that the derivative of a transform class is correct.

    Parameters
    ----------
    transform
        The transform to verify.
    value_transformed
        The transformed value at which to verify the transform. The
        untransformed value is recovered with transform.reverse.
    derivative
        The nominal derivative at the untransformed value. If None, it is
        computed with transform.derivative.
    abs_max
        The x value to skip verification if np.abs(derivative) > x.
        None is treated as 1e8.
    dx_ratios
        Iterable of signed ratios to set dx for finite differencing.
        dx = value*ratio (untransformed), or the ratio itself if the
        untransformed value is exactly zero.
    **kwargs
        Keyword arguments to pass to np.isclose when comparing derivatives to
        finite differences.

    Raises
    ------
    RuntimeError
        Raised if the transform derivative doesn't match finite differences
        within the specified tolerances.
    ValueError
        Raised if dx_ratios is empty.

    Notes
    -----
    derivative should only be specified if it has previously been computed for
    the exact value_transformed, to avoid re-computing it unnecessarily.

    Default dx_ratios are [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14].
    Verification will test all ratios until at least one passes.
    """
    # Skip testing finite differencing if the derivative is very large
    # This might happen e.g. near the limits of the transformation
    # TODO: Check if better finite differencing is possible for large values
    if abs_max is None:
        abs_max = 1e8
    value = transform.reverse(value_transformed)
    if derivative is None:
        derivative = transform.derivative(value)
    if np.abs(derivative) > abs_max:
        return
    if dx_ratios is None:
        dx_ratios = [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14]

    dx, fin_diff = None, None
    for ratio in dx_ratios:
        # A relative step vanishes at zero; use the ratio as an absolute step
        dx = value * ratio if value != 0 else ratio
        fin_diff = (transform.forward(value + dx) - value_transformed) / dx
        if not np.isfinite(fin_diff):
            # The forward step left the valid domain; step backward instead
            fin_diff = -(transform.forward(value - dx) - value_transformed) / dx
        if np.isclose(derivative, fin_diff, **kwargs):
            return

    if dx is None:
        # No ratio was tried, so there is no finite difference to report
        raise ValueError("dx_ratios must contain at least one ratio")
    raise RuntimeError(
        f"{transform} derivative={derivative:.8e} != last "
        f"finite diff.={fin_diff:.8e} with dx={dx} abs_max={abs_max}"
    )


transforms_ref = {
    "none": g2f.UnitTransformD(),
    "log": g2f.LogTransformD(),
    "log10": g2f.Log10TransformD(),
    "inverse": g2f.InverseTransformD(),
    "logit": g2f.LogitTransformD(),
    "logit_fluxfrac": get_logit_limited(
        limits_ref["fluxfrac"].min,
        limits_ref["fluxfrac"].max,
        name=f"ref_logit_fluxfrac[{limits_ref['fluxfrac'].min}, {limits_ref['fluxfrac'].max}]",
    ),
    "logit_rho": get_logit_limited(
        limits_ref["rho"].min,
        limits_ref["rho"].max,
        name=f"ref_logit_rho[{limits_ref['rho'].min}, {limits_ref['rho'].max}]",
    ),
    "logit_axrat": get_logit_limited(1e-4, 1, name="ref_logit_axrat[1e-4, 1]"),
    "logit_axrat_prior": get_logit_limited(-0.001, 1.001, name="ref_logit_axrat_prior[-0.001, 1.001]"),
    "logit_sersic": get_logit_limited(0.49, 6.01, name="ref_logit_sersic[0.49, 6.01]"),
}


# --------------------------------------------------------------------------------
# /tests/test_sourceconfig.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
import lsst.gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import (
    GaussianComponentConfig,
    ParameterConfig,
    SersicComponentConfig,
    SersicIndexParameterConfig,
)
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
from lsst.multiprofit.utils import get_params_uniq
import numpy as np
import pytest


@pytest.fixture(scope="module")
def centroid_limits():
    """Return unbounded limits for centroid parameters."""
    # np.inf rather than the np.Inf alias, which NumPy 2.0 removed
    limits = g2f.LimitsD(min=-np.inf, max=np.inf)
    return limits


@pytest.fixture(scope="module")
def centroid(centroid_limits):
    """Return a fixed centroid at the origin."""
    cenx = g2f.CentroidXParameterD(0, limits=centroid_limits, fixed=True)
    ceny = g2f.CentroidYParameterD(0, limits=centroid_limits, fixed=True)
    centroid = g2f.CentroidParameters(cenx, ceny)
    return centroid


@pytest.fixture(scope="module")
def channels():
    """Return the R, G and B channels keyed by band name."""
    return {band: g2f.Channel.get(band) for band in ("R", "G", "B")}


def test_ComponentGroupConfig(centroid):
    """Check that a name shared by two component types is rejected."""
    with pytest.raises(ValueError):
        config = ComponentGroupConfig(
            components_gauss={"x": GaussianComponentConfig()},
            components_sersic={"x": SersicComponentConfig()},
        )
        config.validate()


def test_SourceConfig_base():
    """Check that a source config without component groups is rejected."""
    with pytest.raises(ValueError):
        config = SourceConfig()
        config.validate()

    with pytest.raises(ValueError):
        config = SourceConfig(component_groups={})
        config.validate()


def test_SourceConfig_fractional(centroid):
    """Check that a fractional Gaussian group builds a PSF model."""
    rho, size_x, size_y = -0.3, 1.4, 1.6
    drho, dsize_x, dsize_y = 0.5, 1.6, 1.3

    n_components = 2
    config = SourceConfig(
        component_groups={
            'src': ComponentGroupConfig(
                components_gauss={
                    str(idx): GaussianComponentConfig(
                        rho=ParameterConfig(value_initial=rho + idx*drho),
                        size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
                        size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
                    )
                    for idx in range(n_components)
                },
                is_fractional=True,
            )
        },
    )
    config.validate()
    channel = g2f.Channel.NONE
    psf_model, priors = config.make_psf_model(
        [
            [
                {channel: 1.0},
                {channel: 0.5},
            ]
        ],
    )
    assert len(priors) == 0
    assert len(psf_model.components) == n_components


def test_SourceConfig_linear(centroid, channels):
    """Check that Sersic components keep their initial values and fluxes."""
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    config = SourceConfig(
        component_groups={
            'src': ComponentGroupConfig(
                components_sersic={
                    name: SersicComponentConfig(
                        rho=ParameterConfig(value_initial=rho + idx*drho),
                        size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
                        size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
                        sersic_index=SersicIndexParameterConfig(
                            value_initial=sersicn + idx * dsersicn,
                            fixed=idx == 0,
                            prior_mean=None,
                        ),
                    )
                    for idx, name in enumerate(names)
                }
            ),
        }
    )
    # One dict of per-channel fluxes for each component
    fluxes = [
        {
            channel: flux + idx_channel*dflux*idx_comp
            for idx_channel, channel in enumerate(channels.values())
        }
        for idx_comp in range(len(config.component_groups["src"].components_sersic))
    ]
    source, priors = config.make_source([fluxes])
    assert len(priors) == 0
    for idx, component in enumerate(source.components):
        params = get_params_uniq(component)
        values_init = {
            g2f.RhoParameterD: rho + idx*drho,
            g2f.ReffXParameterD: size_x + idx*dsize_x,
            g2f.ReffYParameterD: size_y + idx*dsize_y,
            g2f.SersicIndexParameterD: sersicn + idx*dsersicn,
        }
        for name_group, component_group in config.component_groups.items():
            fluxes_comp = fluxes[idx]
            name_comp = names[idx]
            config_comp = component_group.components_sersic[name_comp]
            # Expected flux per channel, keyed by the fully-qualified label
            # built from the component, group and source label formats
            fluxes_label = {
                config.format_label(
                    component_group.format_label(
                        label=config_comp.format_label(label=config.get_integral_label_default(),
                                                       name_channel=channel.name),
                        name_component=name_comp,
                    ),
                    name_group=name_group,
                ): fluxes_comp[channel]
                for channel in channels.values()
            }
            for param in params:
                if isinstance(param, g2f.IntegralParameterD):
                    assert fluxes_label[param.label] == param.value
                elif value_init := values_init.get(param.__class__):
                    assert param.value == value_init


# --------------------------------------------------------------------------------
# /python/lsst/multiprofit/fit_bootstrap_model.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see .
21 | 22 | from functools import cached_property 23 | import logging 24 | from typing import Any, Mapping, Sequence 25 | 26 | import astropy 27 | import lsst.gauss2d.fit as g2f 28 | import lsst.pex.config as pexConfig 29 | import numpy as np 30 | import pydantic 31 | from pydantic.dataclasses import dataclass 32 | 33 | from .config import set_config_from_dict 34 | from .fit_psf import CatalogExposurePsfABC, CatalogPsfFitterConfig, CatalogPsfFitterConfigData 35 | from .fit_source import CatalogExposureSourcesABC, CatalogSourceFitterABC, CatalogSourceFitterConfigData 36 | from .model_utils import make_image_gaussians 37 | from .observationconfig import ObservationConfig 38 | from .utils import FrozenArbitraryAllowedConfig, get_params_uniq 39 | 40 | __all__ = [ 41 | "CatalogBootstrapConfig", 42 | "CatalogExposurePsfBootstrap", 43 | "CatalogExposureSourcesBootstrap", 44 | "CatalogPsfBootstrapConfig", 45 | "CatalogSourceBootstrapConfig", 46 | "CatalogSourceFitterBootstrap", 47 | "NoisyObservationConfig", 48 | ] 49 | 50 | 51 | class CatalogBootstrapConfig(pexConfig.Config): 52 | """Configuration for a bootstrap source catalog fitter.""" 53 | 54 | n_sources = pexConfig.Field[int](doc="Number of sources", default=1) 55 | 56 | @cached_property 57 | def catalog(self): 58 | catalog = astropy.table.Table({"id": np.arange(self.n_sources)}) 59 | return catalog 60 | 61 | 62 | class ObservationNoiseConfig(pexConfig.Config): 63 | """Configuration for noise to be added to an Observation. 64 | 65 | The background level is in user-defined flux units, should be multiplied 66 | by the gain to obtain counts. 
67 | """ 68 | 69 | background = pexConfig.Field[float](doc="Background flux per pixel", default=1e-4) 70 | gain = pexConfig.Field[float](doc="Multiplicative factor to convert flux to counts", default=1.0) 71 | 72 | 73 | class NoisyObservationConfig(ObservationConfig, ObservationNoiseConfig): 74 | """Configuration for an observation with noise.""" 75 | 76 | 77 | class NoisyPsfObservationConfig(ObservationConfig, ObservationNoiseConfig): 78 | """Configuration for a PSF observation with noise.""" 79 | 80 | 81 | class CatalogPsfBootstrapConfig(CatalogBootstrapConfig): 82 | """Configuration for a catalog of noisy PSF observations for bootstrapping. 83 | 84 | Each row is a stacked and normalized image of any number of point sources. 85 | """ 86 | 87 | observation = pexConfig.ConfigField[NoisyPsfObservationConfig]( 88 | doc="The PSF image configuration", 89 | default=NoisyPsfObservationConfig, 90 | ) 91 | 92 | 93 | class CatalogSourceBootstrapConfig(CatalogBootstrapConfig): 94 | """Configuration for a catalog of noisy source observations 95 | for bootstrapping. 96 | 97 | Each row is a PSF-convolved observation of the sources in one band. 
98 | """ 99 | 100 | observation = pexConfig.ConfigField[NoisyObservationConfig]( 101 | doc="The source image configuration", 102 | default=NoisyObservationConfig, 103 | ) 104 | 105 | 106 | @dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig) 107 | class CatalogExposurePsfBootstrap(CatalogExposurePsfABC, CatalogPsfFitterConfigData): 108 | """Dataclass for a PSF-convolved bootstrap fitter.""" 109 | 110 | config_boot: CatalogPsfBootstrapConfig = pydantic.Field(title="The configuration for bootstrapping") 111 | 112 | @cached_property 113 | def image(self) -> np.ndarray: 114 | psf_model_init = self.config.make_psf_model() 115 | # A hacky way to initialize the psf_model property to the same values 116 | # TODO: Include this functionality in fit_psf.py 117 | for param_init, param in zip(get_params_uniq(psf_model_init), get_params_uniq(self.psf_model)): 118 | param.value = param_init.value 119 | image = make_image_gaussians( 120 | psf_model_init.gaussians(g2f.Channel.NONE), 121 | n_rows=self.config_boot.observation.n_rows, 122 | n_cols=self.config_boot.observation.n_cols, 123 | ) 124 | return image.data 125 | 126 | def get_catalog(self) -> astropy.table.Table: 127 | return self.config_boot.catalog 128 | 129 | def get_psf_image( 130 | self, source: astropy.table.Row | Mapping[str, Any], config: CatalogPsfFitterConfig | None = None 131 | ) -> np.ndarray: 132 | rng = np.random.default_rng(source["id"]) 133 | image = self.image 134 | config_obs = self.config_boot.observation 135 | return image + rng.standard_normal(image.shape)*np.sqrt( 136 | (image + config_obs.background)/config_obs.gain) 137 | 138 | def __post_init__(self): 139 | self.config_boot.freeze() 140 | 141 | 142 | @dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig) 143 | class CatalogExposureSourcesBootstrap(CatalogExposureSourcesABC): 144 | """A CatalogExposure for bootstrap fitting of source catalogs.""" 145 | 146 | config_boot: CatalogSourceBootstrapConfig = 
pydantic.Field( 147 | title="A CatalogSourceBootstrapConfig to be frozen") 148 | table_psf_fits: astropy.table.Table = pydantic.Field(title="PSF fit parameters for the catalog") 149 | 150 | @cached_property 151 | def channel(self): 152 | channel = g2f.Channel.get(self.config_boot.observation.band) 153 | return channel 154 | 155 | def get_catalog(self) -> astropy.table.Table: 156 | return self.config_boot.catalog 157 | 158 | def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel: 159 | psf_model = self.psf_model_data.psf_model 160 | self.psf_model_data.init_psf_model(self.table_psf_fits[params["id"]]) 161 | return psf_model 162 | 163 | def get_source_observation(self, source: Mapping[str, Any]) -> g2f.ObservationD: 164 | obs = self.config_boot.observation.make_observation() 165 | return obs 166 | 167 | def __post_init__(self): 168 | config_dict = self.table_psf_fits.meta["config"] 169 | config = CatalogPsfFitterConfig() 170 | set_config_from_dict(config, config_dict) 171 | config_data = CatalogPsfFitterConfigData(config=config) 172 | object.__setattr__(self, "psf_model_data", config_data) 173 | 174 | 175 | @dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig) 176 | class CatalogSourceFitterBootstrap(CatalogSourceFitterABC): 177 | """A catalog fitter that bootstraps a single model. 178 | 179 | This fitter generates a different noisy image of the specified model for 180 | each row. The resulting catalog can be used to examine performance and 181 | statistics of the best-fit parameters. 
182 | """ 183 | 184 | def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float) -> tuple[float, float]: 185 | return float(cen_x), float(cen_y) 186 | 187 | def initialize_model( 188 | self, 189 | model: g2f.ModelD, 190 | source: Mapping[str, Any], 191 | catexps: list[CatalogExposureSourcesABC], 192 | values_init: Mapping[g2f.ParameterD, float] | None = None, 193 | centroid_pixel_offset: float = 0, 194 | ): 195 | if values_init is None: 196 | values_init = {} 197 | min_x, max_x = np.Inf, -np.Inf 198 | min_y, max_y = np.Inf, -np.Inf 199 | for idx_obs, observation in enumerate(model.data): 200 | x_min = observation.image.coordsys.x_min 201 | min_x = min(min_x, x_min) 202 | max_x = max(max_x, x_min + observation.image.n_cols*observation.image.coordsys.dx1) 203 | y_min = observation.image.coordsys.y_min 204 | min_y = min(min_y, y_min) 205 | max_y = max(max_y, y_min + observation.image.n_rows*observation.image.coordsys.dy2) 206 | 207 | cen_x = (min_x + max_x) / 2.0 208 | cen_y = (min_y + max_y) / 2.0 209 | 210 | params_limits_init = { 211 | g2f.CentroidXParameterD: (cen_x, (min_x, max_x)), 212 | g2f.CentroidYParameterD: (cen_y, (min_y, max_y)), 213 | } 214 | 215 | params_free = get_params_uniq(model, fixed=False) 216 | for param in params_free: 217 | value_init, limits_new = params_limits_init.get( 218 | type(param), 219 | (values_init.get(param), None) 220 | ) 221 | if value_init is not None: 222 | param.value = value_init 223 | if limits_new: 224 | param.limits.min = -np.Inf 225 | param.limits.max = limits_new[1] 226 | param.limits.min = limits_new[0] 227 | 228 | # Should be done in get_source_observation, but it gets called first 229 | # ... 
and therefore does not have the initialization above 230 | # Also, this must be done per-iteration because PSF parameters vary 231 | model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image) 232 | model.evaluate() 233 | 234 | # The offset is to keep the rng seed different from the PSF image seed 235 | # It doesn't really need to be so large but it's reasonably safe 236 | rng = np.random.default_rng(source["id"] + 10000000) 237 | 238 | for idx_obs, observation in enumerate(model.data): 239 | config_obs = catexps[idx_obs].config_boot.observation 240 | image, sigma_inv = observation.image, observation.sigma_inv 241 | # TODO: This doesn't raise or warn or anything if setting to 242 | # the wrong (differently sized output). That seems dangerous. 243 | image.data.flat = model.outputs[idx_obs].data.flat 244 | sigma_inv.data.flat = np.sqrt((image.data + config_obs.background)/config_obs.gain) 245 | image.data.flat += sigma_inv.data.flat * rng.standard_normal(image.data.size) 246 | sigma_inv.data.flat = (1.0 / sigma_inv.data).flat 247 | # This is mandatory because C++ construction does no initialization 248 | # (could instead initialize in get_source_observation) 249 | # TODO: Do some timings to see which is more efficient 250 | observation.mask_inv.data.flat = 1 251 | 252 | def validate_fit_inputs( 253 | self, 254 | catalog_multi: Sequence, 255 | catexps: list[CatalogExposureSourcesABC], 256 | config_data: CatalogSourceFitterConfigData = None, 257 | logger: logging.Logger = None, 258 | **kwargs: Any, 259 | ) -> None: 260 | errors = [] 261 | for idx, catexp in enumerate(catexps): 262 | if not ( 263 | (config_boot := getattr(catexp, "config_boot", None)) 264 | and isinstance(config_boot, CatalogSourceBootstrapConfig) 265 | ): 266 | errors.append(f"catexps[{idx=}] = {catexp} does not have a config_boot attr of type" 267 | f"{CatalogSourceBootstrapConfig}") 268 | -------------------------------------------------------------------------------- 
# /tests/test_fit_bootstrap_model.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import astropy.table
import lsst.gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import (
    CentroidConfig,
    FluxFractionParameterConfig,
    FluxParameterConfig,
    GaussianComponentConfig,
    ParameterConfig,
    SersicComponentConfig,
    SersicIndexParameterConfig,
)
from lsst.multiprofit.fit_bootstrap_model import (
    CatalogExposurePsfBootstrap,
    CatalogExposureSourcesBootstrap,
    CatalogPsfBootstrapConfig,
    CatalogSourceBootstrapConfig,
    CatalogSourceFitterBootstrap,
    NoisyObservationConfig,
    NoisyPsfObservationConfig,
)
from lsst.multiprofit.fit_psf import CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData
from lsst.multiprofit.fit_source import CatalogSourceFitterConfig, CatalogSourceFitterConfigData
from lsst.multiprofit.modelconfig import ModelConfig
from lsst.multiprofit.modeller import ModelFitConfig
from lsst.multiprofit.observationconfig import CoordinateSystemConfig
from lsst.multiprofit.plots import ErrorValues, plot_catalog_bootstrap, plot_loglike
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
from lsst.multiprofit.utils import get_params_uniq
import numpy as np
import pytest

shape_img = (23, 27)
reff_x_src, reff_y_src, rho_src, nser_src = 2.5, 3.6, -0.25, 2.0

# TODO: These can be parameterized; should they be?
compute_errors_no_covar = True
compute_errors_from_jacobian = True
include_point_source = False
n_sources = 3
# Set to True for interactive debugging (but don't commit)
plot = False


@pytest.fixture(scope="module")
def channels():
    """Return the R, G and B channels keyed by band name."""
    return {band: g2f.Channel.get(band) for band in ("R", "G", "B")}


@pytest.fixture(scope="module")
def config_fitter_psfs(channels) -> dict[g2f.Channel, CatalogExposurePsfBootstrap]:
    """Return a two-Gaussian PSF bootstrap fitter config for each channel.

    The image size and the initial shape parameters grow with the index of
    the channel so that no two channels share a PSF.
    """
    config_datas = {}
    for idx, channel in enumerate(channels.values()):
        n_rows = 17 + idx*2
        n_cols = 15 + idx*2
        centroid = CentroidConfig(
            x=ParameterConfig(value_initial=n_cols/2.),
            y=ParameterConfig(value_initial=n_rows/2.),
        )
        comp1 = GaussianComponentConfig(
            flux=FluxParameterConfig(value_initial=1.0, fixed=True),
            fluxfrac=FluxFractionParameterConfig(value_initial=0.5, fixed=False),
            size_x=ParameterConfig(value_initial=1.5 + 0.1*idx),
            size_y=ParameterConfig(value_initial=1.7 + 0.13*idx),
            rho=ParameterConfig(value_initial=-0.035 - 0.007*idx),
        )
        comp2 = GaussianComponentConfig(
            size_x=ParameterConfig(value_initial=3.1 + 0.24*idx),
            size_y=ParameterConfig(value_initial=2.7 + 0.16*idx),
            rho=ParameterConfig(value_initial=0.06 + 0.012*idx),
            fluxfrac=FluxFractionParameterConfig(value_initial=1.0, fixed=True),
        )
        group = ComponentGroupConfig(
            centroids={"default": centroid},
            components_gauss={"comp1": comp1, "comp2": comp2},
            is_fractional=True,
        )
        config = CatalogPsfFitterConfig(model=SourceConfig(component_groups={"": group}))
        config_boot = CatalogPsfBootstrapConfig(
            observation=NoisyPsfObservationConfig(n_rows=n_rows, n_cols=n_cols, gain=1e5),
            n_sources=n_sources,
        )
        config_datas[channel] = CatalogExposurePsfBootstrap(config=config, config_boot=config_boot)

    return config_datas


@pytest.fixture(scope="module")
def config_fitter_source(channels) -> CatalogSourceFitterConfigData:
    """Return a source fitter config with one Sersic component.

    A fixed-shape Gaussian point source is added only if
    include_point_source is set.
    """
    if include_point_source:
        components_gauss = {
            "ps": GaussianComponentConfig(
                flux=FluxParameterConfig(value_initial=1000),
                rho=ParameterConfig(value_initial=0, fixed=True),
                size_x=ParameterConfig(value_initial=0, fixed=True),
                size_y=ParameterConfig(value_initial=0, fixed=True),
            )
        }
    else:
        components_gauss = {}
    components_sersic = {
        "ser": SersicComponentConfig(
            prior_size_mean=reff_y_src,
            prior_size_stddev=1.0,
            prior_axrat_mean=reff_x_src / reff_y_src,
            prior_axrat_stddev=0.2,
            flux=FluxParameterConfig(value_initial=5000),
            rho=ParameterConfig(value_initial=rho_src),
            size_x=ParameterConfig(value_initial=reff_x_src),
            size_y=ParameterConfig(value_initial=reff_y_src),
            sersic_index=SersicIndexParameterConfig(fixed=False, value_initial=1.0),
        ),
    }
    group = ComponentGroupConfig(
        components_gauss=components_gauss,
        components_sersic=components_sersic,
    )
    config = CatalogSourceFitterConfig(
        config_fit=ModelFitConfig(fit_linear_iter=3),
        config_model=ModelConfig(
            sources={"": SourceConfig(component_groups={"": group})},
        ),
        convert_cen_xy_to_radec=False,
        compute_errors_no_covar=compute_errors_no_covar,
        compute_errors_from_jacobian=compute_errors_from_jacobian,
    )
    return CatalogSourceFitterConfigData(
        channels=tuple(channels.values()),
        config=config,
    )


@pytest.fixture(scope="module")
def tables_psf_fits(config_fitter_psfs) -> dict[g2f.Channel, astropy.table.Table]:
    """Return the PSF fit result table for each channel."""
    fitter = CatalogPsfFitter()
    # Each bootstrap object serves as both the catexp and the config data
    return {
        channel: fitter.fit(catexp=config_fitter_psf, config_data=config_fitter_psf)
        for channel, config_fitter_psf in config_fitter_psfs.items()
    }


@pytest.fixture(scope="module")
def config_data_sources(
    config_fitter_psfs, tables_psf_fits,
) -> dict[g2f.Channel, CatalogExposureSourcesBootstrap]:
    """Return a source bootstrap catexp for each channel.

    Image sizes and coordinate system origins differ between channels.
    """
    config_datas = {}
    for idx, channel in enumerate(config_fitter_psfs):
        observation = NoisyObservationConfig(
            n_rows=shape_img[0] + idx*2,
            n_cols=shape_img[1] + idx*2,
            band=channel.name,
            background=100,
            coordsys=CoordinateSystemConfig(x_min=-2 + 3*idx, y_min=5 - 4*idx),
        )
        config_boot = CatalogSourceBootstrapConfig(observation=observation, n_sources=n_sources)
        config_datas[channel] = CatalogExposureSourcesBootstrap(
            config_boot=config_boot,
            table_psf_fits=tables_psf_fits[channel],
        )

    return config_datas


def test_fit_psf(config_fitter_psfs, tables_psf_fits):
    """Check that PSF fits succeed and recover the initial parameters."""
    # (type, (atol, rtol)) pairs, checked in order; the first match wins
    # TODO: come up with better (noise-dependent) thresholds here
    tolerances = (
        (g2f.IntegralParameterD, (0, 0.02)),
        (g2f.ProperFractionParameterD, (0.1, 0.01)),
        (g2f.RhoParameterD, (0.05, 0.1)),
    )
    tolerance_default = (0.01, 0.1)

    for band, results in tables_psf_fits.items():
        assert len(results) == n_sources
        assert np.sum(results["mpf_psf_unknown_flag"]) == 0
        assert all(np.isfinite(list(results[0].values())))
        config_data_psf = config_fitter_psfs[band]
        psf_model_init = config_data_psf.config.make_psf_model()
        psfdata = CatalogPsfFitterConfigData(config=config_data_psf.config)
        psf_model_fit = psfdata.psf_model
        psfdata.init_psf_model(results[0])
        assert len(psf_model_init.components) == len(psf_model_fit.components)
        params_init = psf_model_init.parameters()
        params_fit = psf_model_fit.parameters()
        assert len(params_init) == len(params_fit)
        for p_init, p_meas in zip(params_init, params_fit):
            assert p_meas.fixed == p_init.fixed
            if p_meas.fixed:
                assert p_init.value == p_meas.value
                continue
            atol, rtol = next(
                (tolerance for type_param, tolerance in tolerances if isinstance(p_init, type_param)),
                tolerance_default,
            )
            assert np.isclose(p_init.value, p_meas.value, atol=atol, rtol=rtol)


def test_fit_source(config_fitter_source, config_data_sources):
    """Check that source fits succeed and exercise the error estimates."""
    fitter = CatalogSourceFitterBootstrap()
    catexps = list(config_data_sources.values())
    # We don't have or need multiband input catalog, so just pretend the first one is
    catalog_multi = catexps[0].get_catalog()
    results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
    assert len(results) == n_sources

    model = fitter.get_model(
        0, catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source, results=results
    )

    model_sources, priors = config_fitter_source.config.make_sources(
        channels=list(config_data_sources.keys())
    )
    model_true = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model_sources)
    fitter.initialize_model(model_true, catalog_multi[0], catexps=catexps)
    params_true = tuple(param.value for param in get_params_uniq(model_true, fixed=False))
    plot_catalog_bootstrap(
        results, histtype="step", paramvals_ref=params_true, plot_total_fluxes=True, plot_colors=True
    )
    if plot:
        import matplotlib.pyplot as plt

        plt.show()

    assert np.sum(results["mpf_unknown_flag"]) == 0
    assert all(np.isfinite(list(results[0].values())))

    variances = []
    for return_negative in (False, True):
        variances.append(
            fitter.modeller.compute_variances(
                model, transformed=False, options=g2f.HessianOptions(return_negative=return_negative),
                use_diag_only=True,
            )
        )
        assert np.all(variances[-1] > 0)
        if return_negative:
            variances = np.array(variances)
            variances[variances <= 0] = 0
            variances = list(variances)

    # Bootstrap errors
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
    model.evaluate()
    # Swap the noisy data for the model images, remembering the originals
    img_data_old = []
    for obs, output in zip(model.data, model.outputs):
        img = obs.image.data
        img_data_old.append(img.copy())
        img.flat = output.data.flat
    # return_negative=True is the value left behind by the loop above
    options_hessian = g2f.HessianOptions(return_negative=True)
    variances_bootstrap = fitter.modeller.compute_variances(model, transformed=False, options=options_hessian)
    variances_bootstrap_diag = fitter.modeller.compute_variances(
        model, transformed=False, options=options_hessian, use_diag_only=True
    )
    # Restore the noisy data
    for obs, img_datum_old in zip(model.data, img_data_old):
        obs.image.data.flat = img_datum_old.flat
    variances_jac = fitter.modeller.compute_variances(model, transformed=False)
    variances_jac_diag = fitter.modeller.compute_variances(model, transformed=False, use_diag_only=True)

    errors_plot = {
        "inv_hess": ErrorValues(values=np.sqrt(variances[0]), kwargs_plot={"linestyle": "-", "color": "r"}),
        "-inv_hess": ErrorValues(values=np.sqrt(variances[1]), kwargs_plot={"linestyle": "--", "color": "r"}),
        "inv_jac": ErrorValues(values=np.sqrt(variances_jac), kwargs_plot={"linestyle": "-.", "color": "r"}),
        "boot_hess": ErrorValues(
            values=np.sqrt(variances_bootstrap), kwargs_plot={"linestyle": "-", "color": "b"}
        ),
        "boot_diag": ErrorValues(
            values=np.sqrt(variances_bootstrap_diag), kwargs_plot={"linestyle": "--", "color": "b"}
        ),
        "boot_jac_diag": ErrorValues(
            values=np.sqrt(variances_jac_diag), kwargs_plot={"linestyle": "-.", "color": "m"}
        ),
    }
    fig, ax = plot_loglike(model, errors=errors_plot, values_reference=params_true)
    if plot:
        plt.tight_layout()
        plt.show()

# --------------------------------------------------------------------------------
# /python/lsst/multiprofit/sourceconfig.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import string

import lsst.gauss2d.fit as g2f
import lsst.pex.config as pexConfig

from .componentconfig import (
    CentroidConfig,
    EllipticalComponentConfig,
    Fluxes,
    GaussianComponentConfig,
    SersicComponentConfig,
)

__all__ = [
    "ComponentConfigs", "CentroidConfig", "ComponentGroupConfig", "SourceConfig",
]

ComponentConfigs = dict[str, EllipticalComponentConfig]


class ComponentGroupConfig(pexConfig.Config):
    """Configuration for a group of gauss2d.fit Components.

    ComponentGroups may have linked CentroidParameters
    and IntegralModels, e.g. if is_fractional is True.

    Notes
    -----
    Gaussian components are generated first, then Sersic.

    This config class has no equivalent in gauss2dfit, because gauss2dfit
    models parameter dependencies implicitly. This class implements only a
    subset of typical use cases, i.e. PSFs sharing a fractional integral
    model with fixed unit flux, and galaxies/PSF components sharing a single
    common centroid.
    If greater flexibility in linking parameter values is needed,
    users must assemble their own gauss2dfit models directly.
    """

    centroids = pexConfig.ConfigDictField[str, CentroidConfig](
        doc="Centroids by key, which can be a component name or 'default'. "
            "The 'default' key-value pair must be specified if it is needed.",
        default={"default": CentroidConfig()},
    )
    # TODO: Change this to just one EllipticalComponentConfig field
    # when pex_config supports derived types in ConfigDictField
    # (possibly DM-41049)
    components_gauss = pexConfig.ConfigDictField[str, GaussianComponentConfig](
        doc="Gaussian Components in the source",
        optional=False,
        default={},
    )
    components_sersic = pexConfig.ConfigDictField[str, SersicComponentConfig](
        doc="Sersic Components in the component mixture",
        optional=False,
        default={},
    )
    is_fractional = pexConfig.Field[bool](doc="Whether the integral_model is fractional", default=False)
    transform_fluxfrac_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux fraction parameters",
        default="logit_fluxfrac",
        optional=True,
    )
    transform_flux_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux parameters",
        default="log10",
        optional=True,
    )

    @staticmethod
    def format_label(label: str, name_component: str) -> str:
        """Substitute ${name_component} in a label template.

        Other placeholders are left as they are.
        """
        return string.Template(label).safe_substitute(name_component=name_component)

    @staticmethod
    def get_integral_label_default() -> str:
        """Return the default integral label template for a component."""
        return "comp: ${name_component} " + EllipticalComponentConfig.get_integral_label_default()

    def get_component_configs(self) -> ComponentConfigs:
        """Return all component configs by name, Gaussians first.

        A Sersic component replaces a Gaussian component of the same name;
        validate reports such duplicates as errors.
        """
        component_configs: ComponentConfigs = dict(self.components_gauss)
        component_configs.update(self.components_sersic.items())
        return component_configs

    @staticmethod
    def get_fluxes_default(
        channels: tuple[g2f.Channel], component_configs: ComponentConfigs, is_fractional: bool = False,
    ) -> list[Fluxes]:
        """Return the initial fluxes configured for a set of components.

        Parameters
        ----------
        channels
            The channels to return fluxes for. Every channel is given the
            same initial value.
        component_configs
            The component configs to read initial values from.
        is_fractional
            Whether the components share a fractional integral model.

        Returns
        -------
        fluxes
            One Fluxes per component if not is_fractional. Otherwise, the
            total fluxes (from the first component) followed by the flux
            fractions of every component but the last.

        Raises
        ------
        ValueError
            Raised if component_configs is empty.
        """
        if not component_configs:
            raise ValueError("Must provide at least one ComponentConfig")
        configs = tuple(component_configs.values())

        def make_fluxes(value: float) -> Fluxes:
            return {channel: value for channel in channels}

        if not is_fractional:
            return [make_fluxes(config.flux.value_initial) for config in configs]
        # The total flux is always the first entry. It used to be added only
        # while iterating over the non-final components, so a group with a
        # single component got an empty list that make_components rejects.
        fluxes = [make_fluxes(configs[0].flux.value_initial)]
        fluxes.extend(make_fluxes(config.fluxfrac.value_initial) for config in configs[:-1])
        return fluxes

    def make_components(
        self,
        component_fluxes: list[Fluxes],
        label_integral: str | None = None,
    ) -> tuple[list[g2f.Component], list[g2f.Prior]]:
        """Make a list of gauss2d.fit.Component from this configuration.

        Parameters
        ----------
        component_fluxes
            A list of Fluxes to populate an appropriate
            `gauss2d.fit.IntegralModel` with.
            If self.is_fractional, the first item in the list must be
            total fluxes while the remainder are fractions (the final
            fraction is always fixed at 1.0 and must not be provided).
        label_integral
            A label to apply to integral parameters. Can reference the
            relevant component name with ${name_component}.

        Returns
        -------
        components
            The components, in the order of get_component_configs.
        priors
            The priors of all of the components.

        Raises
        ------
        ValueError
            Raised if the number of fluxes does not match the number of
            components, or if the channels of a set of flux fractions
            differ from those of the first entry.
        """
        component_configs = self.get_component_configs()
        fluxes_first = component_fluxes[0]
        channels = fluxes_first.keys()
        # The final fraction is implicit, so a placeholder stands in for it
        fluxes_all = (component_fluxes[1:] + [None]) if self.is_fractional else component_fluxes
        if len(fluxes_all) != len(component_configs):
            raise ValueError(f"{len(fluxes_all)=} != {len(component_configs)=}")
        priors = []
        idx_final = len(component_configs) - 1
        components = []
        last = None

        centroid_default = None
        for idx, (fluxes_component, (name_component, config_comp)) in enumerate(
            zip(fluxes_all, component_configs.items())
        ):
            label_integral_comp = self.format_label(
                label_integral if label_integral is not None else (
                    config_comp.get_integral_label_default()
                ),
                name_component=name_component,
            )

            if self.is_fractional:
                if idx == 0:
                    last = config_comp.make_linear_integral_model(
                        fluxes=fluxes_first,
                        label_integral=label_integral_comp,
                    )

                is_final = idx == idx_final
                if is_final:
                    params_frac = [
                        (channel, g2f.ProperFractionParameterD(1.0, fixed=True))
                        for channel in channels
                    ]
                else:
                    if fluxes_component.keys() != channels:
                        raise ValueError(f"{name_component=} {fluxes_component=}")
                    params_frac = [
                        (
                            channel,
                            config_comp.make_fluxfrac_parameter(value=fluxfrac),
                        ) for channel, fluxfrac in fluxes_component.items()
                    ]

                integral_model = g2f.FractionalIntegralModel(
                    params_frac,
                    model=last,
                    is_final=is_final,
                )
                # TODO: Omitting this crucial step should raise but doesn't
                # There shouldn't be two IntegralModels with the same last
                # especially not one is_final and one not
                last = integral_model
            else:
                integral_model = config_comp.make_linear_integral_model(
                    fluxes_component,
                    label_integral=label_integral_comp,
                )

            config_centroid = self.centroids.get(name_component)
            if config_centroid:
                # A per-component entry is a CentroidConfig and must be
                # turned into parameters, just like the default entry is.
                # It used to be passed on without conversion.
                centroid = config_centroid.make_centroid()
            else:
                # The default centroid is made once and shared
                if centroid_default is None:
                    centroid_default = self.centroids["default"].make_centroid()
                centroid = centroid_default
            componentdata = config_comp.make_component(
                centroid=centroid,
                integral_model=integral_model,
            )
            components.append(componentdata.component)
            priors.extend(componentdata.priors)
        return components, priors

    def validate(self):
        """Validate component names and centroid availability.

        Raises
        ------
        ValueError
            Raised if a name is used for both a Gaussian and a Sersic
            component, or if a component has no centroid entry and there
            is no 'default' centroid.
        """
        super().validate()
        errors = []
        components: ComponentConfigs = dict(self.components_gauss)

        for name, component in self.components_sersic.items():
            if name in components:
                errors.append(
                    f"key={name} cannot be used in both self.components_gauss and self.components_sersic"
                )
            components[name] = component

        keys = set(self.centroids.keys())
        if "default" not in keys:
            errors.extend(
                f"component {name=} has no entry in self.centroids and default not specified"
                for name in components
                if name not in keys
            )
        if errors:
            newline = "\n"
            raise ValueError(f"ComponentGroupConfig has validation errors:\n{newline.join(errors)}")


class SourceConfig(pexConfig.Config):
    """Configuration for a gauss2d.fit Source.

    Sources may contain components with distinct centroids that may be linked
    by a prior (e.g. a galaxy + AGN + star clusters),
    although such priors are not yet implemented.
    """

    component_groups = pexConfig.ConfigDictField[str, ComponentGroupConfig](
        doc="Components in the source",
        optional=False,
    )

    def _make_components_priors(
        self,
        component_group_fluxes: list[list[Fluxes]],
        label_integral: str,
        validate_psf: bool = False,
    ) -> tuple[list[g2f.Component], list[g2f.Prior]]:
        """Make the components and priors of every component group.

        Parameters
        ----------
        component_group_fluxes
            A list of Fluxes for each of the self.component_groups.
        label_integral
            The label template to apply to integral parameters.
        validate_psf
            Whether to require that every Fluxes has the none channel as
            its only key.

        Returns
        -------
        components
            The components of all groups, in group order.
        priors
            The priors of all groups, in group order.

        Raises
        ------
        ValueError
            Raised if the number of flux lists differs from the number of
            groups, or if validate_psf is set and any keys are unexpected.
        """
        if len(component_group_fluxes) != len(self.component_groups):
            raise ValueError(f"{len(component_group_fluxes)=} != {len(self.component_groups)=}")
        components = []
        priors = []
        keys_expected = (g2f.Channel.NONE,)
        for component_fluxes, (name_group, component_group) in zip(
            component_group_fluxes, self.component_groups.items()
        ):
            if validate_psf:
                for idx, fluxes_comp in enumerate(component_fluxes):
                    keys = tuple(fluxes_comp.keys())
                    if keys != keys_expected:
                        raise ValueError(
                            f"{name_group=} comp[{idx}] {keys=} != {keys_expected=} with {validate_psf=}"
                        )

            components_i, priors_i = component_group.make_components(
                component_fluxes=component_fluxes,
                label_integral=self.format_label(label=label_integral, name_group=name_group),
            )
            components.extend(components_i)
            priors.extend(priors_i)

        return components, priors

    @staticmethod
    def format_label(label: str, name_group: str) -> str:
        """Substitute ${name_group} in a label template.

        Other placeholders are left as they are.
        """
        return string.Template(label).safe_substitute(name_group=name_group)

    def get_component_configs(self) -> ComponentConfigs:
        """Return the component configs of all groups by name.

        Names are prefixed with the group name and an underscore if
        has_prefix_group returns True.
        """
        has_prefix_group = self.has_prefix_group()
        component_configs = {}
        for name_group, config_group in self.component_groups.items():
            prefix_group = f"{name_group}_" if has_prefix_group else ""
            for name_comp, component_config in config_group.get_component_configs().items():
                component_configs[f"{prefix_group}{name_comp}"] = component_config
        return component_configs

    def get_integral_label_default(self) -> str:
        """Return the default integral label template for this source."""
        prefix = "mix: ${name_group} " if self.has_prefix_group() else ""
        return f"{prefix}{ComponentGroupConfig.get_integral_label_default()}"

    def has_prefix_group(self) -> bool:
        """Return whether group names are needed to tell components apart.

        This is the case if there is more than one group, or if the only
        group has a non-empty name.
        """
        if len(self.component_groups) > 1:
            return True
        # Return a real bool rather than the group name, and treat a config
        # without groups (which validate rejects) as having no prefix
        # instead of raising StopIteration
        return bool(next(iter(self.component_groups.keys()), ""))

    def make_source(
        self,
        component_group_fluxes: list[list[Fluxes]],
        label_integral: str | None = None,
    ) -> tuple[g2f.Source, list[g2f.Prior]]:
        """Make a gauss2d.fit.Source from this configuration.

        Parameters
        ----------
        component_group_fluxes
            A list of Fluxes for each of the self.component_groups to use
            when calling make_components.
        label_integral
            A label to apply to integral parameters. Can reference the
            relevant component mixture name with ${name_group}.

        Returns
        -------
        source
            An appropriate gauss2d.fit.Source.
        priors
            A list of priors from all constituent components.
        """
        if label_integral is None:
            label_integral = self.get_integral_label_default()
        components, priors = self._make_components_priors(
            component_group_fluxes=component_group_fluxes,
            label_integral=label_integral,
        )
        source = g2f.Source(components)
        return source, priors

    def make_psf_model(
        self,
        component_group_fluxes: list[list[Fluxes]],
        label_integral: str | None = None,
    ) -> tuple[g2f.PsfModel, list[g2f.Prior]]:
        """Make a gauss2d.fit.PsfModel from this configuration.

        This method validates that every set of fluxes is specified for the
        none channel only.

        Parameters
        ----------
        component_group_fluxes
            A list of Fluxes for each of the self.component_groups to use
            when calling make_components.
        label_integral
            A label to apply to integral parameters. Can reference the
            relevant component mixture name with ${name_group}.

        Returns
        -------
        psf_model
            An appropriate gauss2d.fit.PsfModel.
        priors
            A list of priors from all constituent components.
        """
        if label_integral is None:
            label_integral = f"PSF {self.get_integral_label_default()}"
        components, priors = self._make_components_priors(
            component_group_fluxes=component_group_fluxes,
            label_integral=label_integral,
            validate_psf=True,
        )
        model = g2f.PsfModel(components=components)

        return model, priors

    def validate(self):
        """Validate that there is at least one component group.

        Raises
        ------
        ValueError
            Raised if component_groups is empty.
        """
        super().validate()
        if not self.component_groups:
            raise ValueError("Must have at least one componentgroup")

# --------------------------------------------------------------------------------
# /python/lsst/multiprofit/componentconfig.py:
# --------------------------------------------------------------------------------
# This file is part of multiprofit.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
class FluxFractionParameterConfig(ParameterConfig):
    """Configuration for flux fraction parameters (ProperFractionParameterD).

    The safest initial value for a flux fraction is 0.5, because if it's set
    to one, downstream fractions will be zero, while if it's set to zero,
    linear fitting will not work correctly initially.
    """

    def setDefaults(self):
        """Set the default initial value to the safe midpoint of 0.5."""
        super().setDefaults()
        # 0.5 rather than 1.0: a fraction of one would leave no flux for any
        # downstream fractional component (see the class docstring).
        self.value_initial = 0.5
def format_label(self, label: str, name_channel: str) -> str:
    """Fill in the placeholders of a band-dependent parameter label.

    Parameters
    ----------
    label
        The label template. It may reference ``${type_component}`` and
        ``${name_channel}``; any other placeholder is left untouched.
    name_channel
        The name of the channel to substitute for ``${name_channel}``.

    Returns
    -------
    label_formatted
        The label with the recognized placeholders substituted.
    """
    substitutions = {
        "type_component": self.get_type_name(),
        "name_channel": name_channel,
    }
    # safe_substitute never raises on unknown or malformed placeholders.
    template = string.Template(label)
    return template.safe_substitute(substitutions)
def make_fluxfrac_parameter(
    self,
    value: float | None,
    label: str | None = None,
    **kwargs
) -> g2f.ProperFractionParameterD:
    """Make a flux fraction parameter reflecting the current configuration.

    Parameters
    ----------
    value
        The initial value for the parameter. If None, the configured
        ``fluxfrac.value_initial`` is used instead.
    label
        The label for the parameter. If None, an empty label is used.
    **kwargs
        Additional keyword arguments to pass to the
        `gauss2d.fit.ProperFractionParameterD` constructor.

    Returns
    -------
    parameter
        The new flux fraction parameter, with ``fixed`` and the transform
        taken from this config.
    """
    parameter = g2f.ProperFractionParameterD(
        # As in make_flux_parameter, an explicit value takes precedence and
        # the configured initial value is only the fallback for None.
        value if value is not None else self.fluxfrac.value_initial,
        fixed=self.fluxfrac.fixed,
        transform=self.get_transform_fluxfrac(),
        label=label if label is not None else "",
        **kwargs
    )
    return parameter
gauss2d.fit.LinearIntegralModel for this component. 271 | 272 | Parameters 273 | ---------- 274 | fluxes 275 | Configurations, including initial values, for the flux 276 | parameters by channel. 277 | label_integral 278 | A label to apply to integral parameters. Can reference the 279 | relevant channel with e.g. {channel.name}. 280 | **kwargs 281 | Additional keyword arguments to pass to make_flux_parameter. 282 | Some parameters cannot be overriden from their configs. 283 | 284 | Returns 285 | ------- 286 | integral_model 287 | The requested gauss2d.fit.IntegralModel. 288 | """ 289 | if label_integral is None: 290 | label_integral = self.get_integral_label_default() 291 | integral_model = g2f.LinearIntegralModel( 292 | [ 293 | ( 294 | channel, 295 | self.make_flux_parameter( 296 | flux, 297 | label=self.format_label(label_integral, name_channel=channel.name), 298 | **kwargs, 299 | ), 300 | ) 301 | for channel, flux in fluxes.items() 302 | ] 303 | ) 304 | return integral_model 305 | 306 | @staticmethod 307 | def set_size_x(component: g2f.EllipticalComponent, size_x: float) -> None: 308 | component.ellipse.sigma_x = size_x 309 | 310 | @staticmethod 311 | def set_size_y(component: g2f.EllipticalComponent, size_y: float) -> None: 312 | component.ellipse.sigma_y = size_y 313 | 314 | @staticmethod 315 | def set_rho(component: g2f.EllipticalComponent, rho: float) -> None: 316 | component.ellipse.rho = rho 317 | 318 | 319 | class GaussianComponentConfig(EllipticalComponentConfig): 320 | """Configuration for a gauss2d.fit Gaussian component.""" 321 | 322 | transform_frac_name = pexConfig.Field[str]( 323 | doc="The name of the reference transform for flux fraction parameters", 324 | default="log10", 325 | optional=True, 326 | ) 327 | 328 | def get_size_label(self) -> str: 329 | return "sigma" 330 | 331 | def get_type_name(self) -> str: 332 | return "Gaussian" 333 | 334 | def make_component( 335 | self, 336 | centroid: g2f.CentroidParameters, 337 | integral_model: 
class SersicIndexParameterConfig(ParameterConfig):
    """Configuration for a gauss2d.fit Sersic index parameter."""

    prior_mean = pexConfig.Field[float](doc="Mean for the prior (untransformed)", default=1.0, optional=True)
    prior_stddev = pexConfig.Field[float](doc="Std. dev. for the prior", default=0.5, optional=True)
    # This is an on/off switch (default True), so it is a bool field like
    # ParameterConfig.fixed rather than a float.
    prior_transformed = pexConfig.Field[bool](
        doc="Whether the prior should be in transformed values", default=True,
    )

    def get_prior(self, param: g2f.SersicIndexParameterD) -> g2f.Prior | None:
        """Make a Gaussian prior on a Sersic index parameter.

        Parameters
        ----------
        param
            The parameter to make a prior for.

        Returns
        -------
        prior
            A `gauss2d.fit.GaussianPrior` for the parameter, or None if
            ``prior_mean`` is None.

        Notes
        -----
        If ``prior_transformed`` is set, the mean is passed through the
        parameter's forward transform, and the std. dev. is the difference
        between the transformed values half a (untransformed) std. dev.
        above and below the mean.
        """
        if self.prior_mean is None:
            return None
        if self.prior_transformed:
            forward = param.transform.forward
            half_width = self.prior_stddev/2.
            mean = forward(self.prior_mean)
            stddev = forward(self.prior_mean + half_width) - forward(self.prior_mean - half_width)
        else:
            mean, stddev = self.prior_mean, self.prior_stddev
        return g2f.GaussianPrior(
            param=param, mean=mean, stddev=stddev, transformed=self.prior_transformed,
        )

    def setDefaults(self):
        """Set the default initial Sersic index to 0.5 (a Gaussian)."""
        super().setDefaults()
        self.value_initial = 0.5

    def validate(self):
        """Validate the prior settings.

        Raises
        ------
        ValueError
            Raised if a prior mean is set and either it or the prior
            std. dev. is not positive (or the std. dev. is unset).
        """
        super().validate()
        if self.prior_mean is not None:
            if not self.prior_mean > 0.:
                raise ValueError("Sersic index prior mean must be > 0")
            # prior_stddev is optional; report None as an invalid config
            # instead of failing on the None > 0 comparison.
            if self.prior_stddev is None or not self.prior_stddev > 0.:
                raise ValueError("Sersic index prior std. dev. must be > 0")
def get_interpolator(self, order: int):
    """Return the Sersic mixture interpolator for a given order.

    Parameters
    ----------
    order
        The order of the Sersic mixture.

    Returns
    -------
    interpolator
        The interpolator for this order. It is created on first request
        and stored in the class-level ``_interpolators`` dict, so later
        calls with the same order return the same instance.

    Notes
    -----
    The GSL-based interpolator is used if gauss2d.fit provides it;
    otherwise, the linear interpolator is used.
    """
    interpolator = self._interpolators.get(order)
    if interpolator is None:
        # Only construct an interpolator on a cache miss; passing a fresh
        # instance as the dict.get default would build one on every call
        # and never store it.
        type_interpolator = (
            g2f.GSLSersicMixInterpolator
            if hasattr(g2f, "GSLSersicMixInterpolator")
            else g2f.LinearSersicMixInterpolator
        )
        interpolator = type_interpolator(order=order)
        self._interpolators[order] = interpolator
    return interpolator
transform=transform_rho, fixed=self.rho.fixed), 445 | ) 446 | sersic_index = g2f.SersicMixComponentIndexParameterD( 447 | value=self.sersic_index.value_initial, 448 | fixed=self.sersic_index.fixed, 449 | transform=transforms_ref["logit_sersic"] if not self.sersic_index.fixed else None, 450 | interpolator=self.get_interpolator(order=self.order), 451 | limits=limits_ref["n_ser_multigauss"], 452 | ) 453 | component = g2f.SersicMixComponent( 454 | centroid=centroid, 455 | ellipse=ellipse, 456 | integral=integral_model, 457 | sersicindex=sersic_index, 458 | ) 459 | prior = self.sersic_index.get_prior(sersic_index) if not sersic_index.fixed else None 460 | priors = [prior] if prior else [] 461 | prior = self.get_shape_prior(ellipse) 462 | if prior: 463 | priors.append(prior) 464 | return ComponentData( 465 | component=component, 466 | integral_model=integral_model, 467 | priors=priors, 468 | ) 469 | 470 | def validate(self): 471 | super().validate() 472 | -------------------------------------------------------------------------------- /examples/fithsc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Fitting HSC data in multiband mode using MultiProFit 5 | 6 | # In[1]: 7 | 8 | 9 | # Import required packages 10 | import time 11 | from typing import Any, Iterable, Mapping 12 | 13 | from astropy.coordinates import SkyCoord 14 | from astropy.io.ascii import Csv 15 | import astropy.io.fits as fits 16 | import astropy.table as apTab 17 | import astropy.visualization as apVis 18 | from astropy.wcs import WCS 19 | import lsst.gauss2d as g2 20 | import lsst.gauss2d.fit as g2f 21 | from lsst.multiprofit.componentconfig import SersicComponentConfig, SersicIndexParameterConfig 22 | from lsst.multiprofit.fit_psf import ( 23 | CatalogExposurePsfABC, 24 | CatalogPsfFitter, 25 | CatalogPsfFitterConfig, 26 | CatalogPsfFitterConfigData, 27 | ) 28 | from lsst.multiprofit.fit_source import ( 29 | 
CatalogExposureSourcesABC, 30 | CatalogSourceFitterABC, 31 | CatalogSourceFitterConfig, 32 | CatalogSourceFitterConfigData, 33 | ) 34 | from lsst.multiprofit.modelconfig import ModelConfig 35 | from lsst.multiprofit.plots import plot_model_rgb 36 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig 37 | from lsst.multiprofit.utils import ArbitraryAllowedConfig, get_params_uniq 38 | import matplotlib as mpl 39 | import matplotlib.pyplot as plt 40 | import numpy as np 41 | import pydantic 42 | from pydantic.dataclasses import dataclass 43 | 44 | # In[2]: 45 | 46 | 47 | # Define settings 48 | band_ref = 'i' 49 | bands = {'i': 0.87108833, 'r': 0.97288654, 'g': 1.44564678} 50 | band_multi = ''.join(bands) 51 | channels = {band: g2f.Channel.get(band) for band in bands} 52 | 53 | # This is in the WCS, but may as well keep full precision 54 | scale_pixel_hsc = 0.168 55 | 56 | # Common to all FITS 57 | hdu_img, hdu_mask, hdu_var = 1, 2, 3 58 | 59 | # Masks 60 | bad_masks = ( 61 | 'BAD', 'SAT', 'INTRP', 'CR', 'EDGE', 'CLIPPED', 'NO_DATA', 'CROSSTALK', 62 | 'NO_DATA', 'UNMASKEDNAN', 'SUSPECT', 'REJECTED', 'SENSOR_EDGE', 63 | ) 64 | maskbits = tuple(f'MP_{b}' for b in bad_masks) 65 | 66 | # A pre-defined bitmask to exclude regions with low SN 67 | read_mask_highsn = True 68 | write_mask_highsn = False 69 | 70 | # matplotlib settings 71 | mpl.rcParams['image.origin'] = 'lower' 72 | mpl.rcParams['figure.dpi'] = 120 73 | 74 | 75 | # In[3]: 76 | 77 | 78 | # Define source to fit 79 | id_gama, z = 79635, 0.0403 80 | """ 81 | Acquired from https://hsc-release.mtk.nao.ac.jp/datasearch/catalog_jobs with query: 82 | 83 | SELECT object_id, ra, dec, 84 | g_cmodel_mag, g_cmodel_magerr, r_cmodel_mag, r_cmodel_magerr, i_cmodel_mag, i_cmodel_magerr, 85 | g_psfflux_mag, g_psfflux_magerr, r_psfflux_mag, r_psfflux_magerr, i_psfflux_mag, i_psfflux_magerr, 86 | g_kronflux_mag, g_kronflux_magerr, r_kronflux_mag, r_kronflux_magerr, i_kronflux_mag, i_kronflux_magerr, 87 | 
g_sdssshape_shape11, g_sdssshape_shape11err, g_sdssshape_shape22, g_sdssshape_shape22err, 88 | g_sdssshape_shape12, g_sdssshape_shape12err, 89 | r_sdssshape_shape11, r_sdssshape_shape11err, r_sdssshape_shape22, r_sdssshape_shape22err, 90 | r_sdssshape_shape12, r_sdssshape_shape12err, 91 | i_sdssshape_shape11, i_sdssshape_shape11err, i_sdssshape_shape22, i_sdssshape_shape22err, 92 | i_sdssshape_shape12, i_sdssshape_shape12err 93 | FROM pdr3_wide.forced 94 | LEFT JOIN pdr3_wide.forced2 USING (object_id) 95 | WHERE isprimary AND conesearch(coord, 222.51551376, 0.09749601, 35.64) 96 | AND (r_kronflux_mag < 26 OR i_kronflux_mag < 26) AND NOT i_kronflux_flag AND NOT r_kronflux_flag; 97 | """ 98 | cat = Csv() 99 | cat.header.splitter.escapechar = '#' 100 | cat = cat.read('fithsc_src.csv') 101 | 102 | prefix = '222.51551376,0.09749601_' 103 | prefix_img = f'{prefix}300x300_' 104 | 105 | # Read data, acquired with: 106 | # https://github.com/taranu/astro_imaging/blob/4d5a8e095e6a3944f1fbc19318b1dbc22b22d9ca/examples/HSC.ipynb 107 | # (with get_mask=True, get_variance=True,) 108 | images, psfs = {}, {} 109 | for band in bands: 110 | images[band] = fits.open(f'{prefix_img}{band}.fits') 111 | psfs[band] = fits.open(f'{prefix}{band}_psf.fits') 112 | 113 | wcs = WCS(images[band_ref][hdu_img]) 114 | cat['x'], cat['y'] = wcs.world_to_pixel(SkyCoord(cat['ra'], cat['dec'], unit='deg')) 115 | 116 | 117 | # In[4]: 118 | 119 | 120 | # Plot image 121 | img_rgb = apVis.make_lupton_rgb(*[img[1].data*bands[band] for band, img in images.items()]) 122 | plt.scatter(cat['x'], cat['y'], s=10, c='g', marker='x') 123 | plt.imshow(img_rgb) 124 | plt.title("gri image with detected objects") 125 | plt.show() 126 | 127 | 128 | # In[5]: 129 | 130 | 131 | # Generate a rough mask around other sources 132 | bright = (cat['i_cmodel_mag'] < 23) | (cat['i_psfflux_mag'] < 23) 133 | 134 | img_ref = images[band_ref][hdu_img].data 135 | 136 | mask_inverse = np.ones(img_ref.shape, dtype=bool) 137 | y_cen, x_cen 
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogExposurePsf(CatalogExposurePsfABC):
    """A catalog paired with a single PSF image used for all of its rows."""

    catalog: apTab.Table = pydantic.Field(title="The detected object catalog")
    img: np.ndarray = pydantic.Field(title="The PSF image")

    def get_catalog(self) -> Iterable:
        """Return the stored catalog."""
        return self.catalog

    def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.ndarray:
        """Return the PSF image for a source.

        The same stored image is returned regardless of ``source``.
        """
        return self.img
apTab.Table({'object_id': [tab_row_main['object_id']]}) 195 | results_psf = {} 196 | 197 | # Keep a separate configdata_psf per band, because it has a cached PSF model 198 | # those should not be shared! 199 | config_data_psfs = {} 200 | for band, psf_file in psfs.items(): 201 | config_data_psf = CatalogPsfFitterConfigData(config=config_psf) 202 | catexp = CatalogExposurePsf(catalog=catalog_psf, img=psf_file[0].data) 203 | t_start = time.time() 204 | result = fitter_psf.fit(config_data=config_data_psf, catexp=catexp) 205 | t_end = time.time() 206 | results_psf[band] = result 207 | config_data_psfs[band] = config_data_psf 208 | print(f"Fit {band}-band PSF in {t_end - t_start:.2e}s; result:") 209 | print(dict(result[0])) 210 | 211 | 212 | # In[7]: 213 | 214 | 215 | # Set fit configs 216 | config_source = CatalogSourceFitterConfig( 217 | column_id='object_id', 218 | config_model=ModelConfig( 219 | sources={ 220 | "src": SourceConfig( 221 | component_groups={ 222 | "": ComponentGroupConfig( 223 | components_sersic={ 224 | 'disk': SersicComponentConfig( 225 | sersic_index=SersicIndexParameterConfig(value_initial=1., fixed=True), 226 | prior_size_stddev=0.5, 227 | prior_axrat_stddev=0.2, 228 | ), 229 | 'bulge': SersicComponentConfig( 230 | sersic_index=SersicIndexParameterConfig(value_initial=4., fixed=True), 231 | prior_size_stddev=0.1, 232 | prior_axrat_stddev=0.2, 233 | ), 234 | }, 235 | ), 236 | } 237 | ), 238 | }, 239 | ), 240 | ) 241 | config_data_source = CatalogSourceFitterConfigData( 242 | channels=list(channels.values()), 243 | config=config_source, 244 | ) 245 | 246 | 247 | # In[8]: 248 | 249 | 250 | # Setup exposure with band-specific image, mask and variance 251 | @dataclass(frozen=True, config=ArbitraryAllowedConfig) 252 | class CatalogExposureSources(CatalogExposureSourcesABC): 253 | config_data_psf: CatalogPsfFitterConfigData = pydantic.Field(title="The PSF fit config") 254 | observation: g2f.ObservationD = pydantic.Field(title="The observation to fit") 
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogSourceFitter(CatalogSourceFitterABC):
    """A source fitter that initializes models from catalog measurements.

    Initial centroids come from the source's ``x``/``y`` columns and initial
    shapes from the ``{band}_sdssshape_shape*`` columns of the configured
    reference band.
    """

    band: str = pydantic.Field(title="The reference band for initialization and priors")
    scale_pixel: float = pydantic.Field(title="The pixel scale in arcsec")
    wcs_ref: WCS = pydantic.Field(title="The WCS for the coadded image")

    def initialize_model(
        self,
        model: g2f.ModelD,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        centroid_pixel_offset: float = 0,
        **kwargs
    ):
        """Set initial values and limits for a model's free parameters.

        Parameters
        ----------
        model
            The model whose free parameters (and shape priors) are updated
            in place.
        source
            The catalog row for the source. It must have ``x`` and ``y``
            values and ``{band}_sdssshape_shape11/22/12`` values for this
            fitter's ``band``.
        catexps
            The catalog-exposure pairs being fit (unused here).
        values_init
            Initial values for parameters whose type has no catalog-based
            initializer below. Entries for other parameters are ignored.
        centroid_pixel_offset
            Unused by this implementation.
        **kwargs
            Unused by this implementation.
        """
        if values_init is None:
            values_init = {}
        x, y = source['x'], source['y']
        scale_sq = self.scale_pixel**(-2)
        # Read the shape columns for this fitter's own reference band rather
        # than whatever a module-level `band` variable happens to hold.
        band = self.band
        ellipse = g2.Ellipse(g2.Covariance(
            sigma_x_sq=source[f'{band}_sdssshape_shape11']*scale_sq,
            sigma_y_sq=source[f'{band}_sdssshape_shape22']*scale_sq,
            cov_xy=source[f'{band}_sdssshape_shape12']*scale_sq,
        ))
        size_major = g2.EllipseMajor(ellipse).r_major
        limits_size = g2f.LimitsD(1e-5, np.sqrt(x*x + y*y))
        # An R_eff larger than the box size is problematic
        # Also should stop unreasonable size proposals; log10 transform isn't enough
        # TODO: Try logit for r_eff?
        params_limits_init = {
            # Should set limits based on image size, but this shortcut is fine
            # for this particular object
            g2f.CentroidXParameterD: (x, g2f.LimitsD(0, 2*x)),
            # The y centroid starts at y (not x), matching its limits
            g2f.CentroidYParameterD: (y, g2f.LimitsD(0, 2*y)),
            g2f.ReffXParameterD: (ellipse.sigma_x, limits_size),
            g2f.ReffYParameterD: (ellipse.sigma_y, limits_size),
            # There is a sign convention difference
            g2f.RhoParameterD: (-ellipse.rho, None),
            g2f.IntegralParameterD: (1.0, g2f.LimitsD(1e-10, 1e10)),
        }
        params_free = get_params_uniq(model, fixed=False)
        for param in params_free:
            type_param = type(param)
            # Types without an entry above fall back to values_init, no limits
            value_init, limits_new = params_limits_init.get(
                type_param,
                (values_init.get(param), None)
            )
            if value_init is not None:
                param.value = value_init
            if limits_new:
                # For slightly arcane reasons, we must set a new limits object
                # Changing limits values is unreliable
                param.limits = limits_new
        # Center each shape prior's size on the measured major axis
        for prior in model.priors:
            if isinstance(prior, g2f.ShapePrior):
                prior.prior_size.mean_parameter.value = size_major

    def validate_fit_inputs(
        self,
        catalog_multi,
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData = None,
        logger=None,
        **kwargs: Any,
    ) -> None:
        """Validate fit inputs by delegating to the base class."""
        super().validate_fit_inputs(
            catalog_multi=catalog_multi, catexps=catexps, config_data=config_data,
            logger=logger, **kwargs
        )
dtype='bool') 359 | for bit in maskbits: 360 | mask |= ((bitmask & 2**header[bit]) != 0) 361 | 362 | mask = (mask == 0) & mask_inverse 363 | sigma_inv = 1.0/np.sqrt(data[hdu_var].data) 364 | sigma_inv[mask != 1] = 0 365 | 366 | observation = g2f.ObservationD( 367 | image=g2.ImageD(data[hdu_img].data), 368 | sigma_inv=g2.ImageD(sigma_inv), 369 | mask_inv=g2.ImageB(mask), 370 | channel=g2f.Channel.get(band), 371 | ) 372 | observations[band] = observation 373 | catexps[band] = CatalogExposureSources( 374 | config_data_psf=config_data_psfs[band], 375 | observation=observation, 376 | table_psf_fits=results_psf[band], 377 | ) 378 | 379 | 380 | # In[10]: 381 | 382 | 383 | # Now do the multi-band fit 384 | t_start = time.time() 385 | result_multi = fitter.fit( 386 | catalog_multi=tab_row_main, 387 | catexps=list(catexps.values()), 388 | config_data=config_data_source, 389 | ) 390 | t_end = time.time() 391 | print(f"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2e}s; result:") 392 | print(dict(result_multi[0])) 393 | 394 | 395 | # In[11]: 396 | 397 | 398 | # Fit in each band separately 399 | results = {} 400 | for band, observation in bands.items(): 401 | config_data_source_band = CatalogSourceFitterConfigData( 402 | channels=[channels[band]], 403 | config=config_source, 404 | ) 405 | t_start = time.time() 406 | result = fitter.fit( 407 | catalog_multi=tab_row_main, 408 | catexps=[catexps[band]], 409 | config_data=config_data_source_band, 410 | ) 411 | t_end = time.time() 412 | results[band] = result 413 | print(f"Fit {band}-band bulge-disk model in {t_end - t_start:.2f}s; result:") 414 | print(dict(result[0])) 415 | 416 | 417 | # In[12]: 418 | 419 | 420 | # Make a model for the best-fit params 421 | data, psf_models = config_source.make_model_data(idx_row=0, catexps=list(catexps.values())) 422 | model = g2f.ModelD(data=data, psfmodels=psf_models, sources=config_data_source.sources_priors[0], priors=config_data_source.sources_priors[1]) 423 | 
params = get_params_uniq(model, fixed=False) 424 | result_multi_row = dict(result_multi[0]) 425 | # This is the last column before fit params 426 | idx_last = next(idx for idx, column in enumerate(result_multi_row.keys()) if column == 'mpf_unknown_flag') 427 | # Set params to best fit values 428 | for param, (column, value) in zip(params, list(result_multi_row.items())[idx_last+1:]): 429 | param.value = value 430 | model.setup_evaluators(g2f.EvaluatorMode.loglike_image) 431 | # Print the loglikelihoods, which are from the data and end with the (sum of all) priors 432 | loglikes = model.evaluate() 433 | print(f"{loglikes=}") 434 | 435 | 436 | # ### Multiband Residuals 437 | # 438 | # What's with the structure in the residuals? Most broadly, a point source + exponential disk + deVauc bulge model is totally inadequate for this galaxy for several possible reasons: 439 | # 440 | # 1. The disk isn't exactly exponential (n=1) 441 | # 2. The disk has colour gradients not accounted for in this model* 442 | # 3. If the galaxy even has a bulge, it's very weak and def. not a deVaucouleurs (n=4) profile; it may be an exponential "pseudobulge" 443 | # 444 | # \*MultiProFit can do more general Gaussian mixture models (linear or non-linear), which may be explored in a future iteration of this notebook, but these are generally do not improve the accuracy of photometry for smaller/fainter galaxies. 445 | # 446 | # Note that the two scalings of the residual plots (98%ile and +/- 20 sigma) end up looking very similar. 
# Make some basic plots
_, _, _, _, mask_inv_highsn = plot_model_rgb(
    model, weights=bands, high_sn_threshold=0.2 if write_mask_highsn else None,
)
plt.show()

# Write the newly computed high SN mask to a compressed file
if write_mask_highsn:
    plt.figure()
    # Show and save the mask returned by plot_model_rgb above, not the
    # mask_highsn array that is only defined when read_mask_highsn is set
    plt.imshow(mask_inv_highsn, cmap='gray')
    plt.show()
    # Saved unpacked under the 'mask_inv' key, because the loading code
    # multiplies the loaded array into the fitting mask without unpacking
    np.savez_compressed(f'{prefix_img}mask_inv_highsn.npz', mask_inv=mask_inv_highsn)

# TODO: Some features still missing from plot_model_rgb
# residual histograms, param values, better labels, etc
487 | # Redefine CatalogFitterConfig.make_model_data to make a model with multiple sources, using the catexp catalogs 488 | # initialize_model will no longer need to do anything 489 | # catalog_multi should still be a single row 490 | -------------------------------------------------------------------------------- /tests/test_modeller.py: -------------------------------------------------------------------------------- 1 | # This file is part of multiprofit. 2 | # 3 | # Developed for the LSST Data Management System. 4 | # This product includes software developed by the LSST Project 5 | # (https://www.lsst.org). 6 | # See the COPYRIGHT file at the top-level directory of this distribution 7 | # for details of code ownership. 8 | # 9 | # This program is free software: you can redistribute it and/or modify 10 | # it under the terms of the GNU General Public License as published by 11 | # the Free Software Foundation, either version 3 of the License, or 12 | # (at your option) any later version. 13 | # 14 | # This program is distributed in the hope that it will be useful, 15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 | # GNU General Public License for more details. 18 | # 19 | # You should have received a copy of the GNU General Public License 20 | # along with this program. If not, see . 

import math
import time

import lsst.gauss2d as g2
import lsst.gauss2d.fit as g2f
from lsst.multiprofit.componentconfig import (
    CentroidConfig,
    FluxFractionParameterConfig,
    FluxParameterConfig,
    GaussianComponentConfig,
    ParameterConfig,
    SersicComponentConfig,
    SersicIndexParameterConfig,
)
from lsst.multiprofit.model_utils import make_image_gaussians, make_psf_model_null
from lsst.multiprofit.modelconfig import ModelConfig
from lsst.multiprofit.modeller import FitInputs, LinearGaussians, Modeller, fitmethods_linear
from lsst.multiprofit.observationconfig import CoordinateSystemConfig, ObservationConfig
from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
from lsst.multiprofit.utils import get_params_uniq
import numpy as np
import pytest

# Uniform inverse sigma written into every pixel of the fixture observations
sigma_inv = 1e4


@pytest.fixture(scope="module")
def channels() -> dict[str, g2f.Channel]:
    """Return one channel per band ("R", "G", "B"), keyed by band name."""
    return {band: g2f.Channel.get(band) for band in ("R", "G", "B")}


@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    """Build one blank observation per channel, wrapped in a DataD.

    Each successive observation changes its size and origin by the ``d*``
    increments below, so no two observations share the same shape. Images
    are zero-filled, sigma_inv is the module-level constant and mask_inv is
    filled with 1.
    """
    n_rows, n_cols = 25, 27
    x_min, y_min = 0, 0

    # Per-observation increments applied to the size and origin above
    dn_rows, dn_cols = 2, -3
    dx_min, dy_min = -1, 1

    observations = []
    for idx, band in enumerate(channels):
        config = ObservationConfig(
            band=band,
            coordsys=CoordinateSystemConfig(
                x_min=x_min + idx*dx_min,
                y_min=y_min + idx*dy_min,
            ),
            n_rows=n_rows + idx*dn_rows,
            n_cols=n_cols + idx*dn_cols,
        )
        observation = config.make_observation()
        observation.image.fill(0)
        observation.sigma_inv.fill(sigma_inv)
        observation.mask_inv.fill(1)
        observations.append(observation)
    return g2f.DataD(observations)


@pytest.fixture(scope="module")
def psf_models(channels) -> list[g2f.PsfModel]:
    """Build one two-component, fractional Gaussian PSF model per channel.

    Shape parameters are offset per component (``d*``) and per channel
    (``d*_chan``) so every Gaussian is distinct. Only the first component
    is given explicit flux / flux-fraction configs.
    """
    rho, size_x, size_y = 0.12, 1.6, 1.2
    drho, dsize_x, dsize_y = -0.3, 1.1, 1.9
    drho_chan, dsize_x_chan, dsize_y_chan = 0.03, 0.12, 0.14
    frac, dfrac = 0.62, -0.08

    n_components = 2
    psf_models = []

    for idx_chan, channel in enumerate(channels.values()):
        frac_chan = frac + idx_chan*dfrac
        config = SourceConfig(
            component_groups={
                'psf': ComponentGroupConfig(
                    components_gauss={
                        str(idx): GaussianComponentConfig(
                            rho=ParameterConfig(value_initial=rho + idx*drho + idx_chan*drho_chan),
                            size_x=ParameterConfig(
                                value_initial=size_x + idx*dsize_x + idx_chan*dsize_x_chan),
                            size_y=ParameterConfig(
                                value_initial=size_y + idx*dsize_y + idx_chan*dsize_y_chan),
                            **({
                                "flux": FluxParameterConfig(value_initial=1.0, fixed=True),
                                "fluxfrac": FluxFractionParameterConfig(value_initial=frac_chan, fixed=False),
                            } if (idx == 0) else {})
                        )
                        for idx in range(n_components)
                    },
                    is_fractional=True,
                )
            },
        )
        config.validate()
        # The returned priors are not used by this fixture
        psf_model, priors = config.make_psf_model([
            component_group.get_fluxes_default(
                channels=(g2f.Channel.NONE,),
                component_configs=component_group.get_component_configs(),
                is_fractional=component_group.is_fractional,
            )
            for component_group in config.component_groups.values()
        ])
        psf_models.append(psf_model)
    return psf_models


@pytest.fixture(scope="module")
def model(channels, data, psf_models) -> g2f.ModelD:
    """Build a single-source model with "exp" and "dev" Sersic components.

    The centroid is fixed; the Sersic index is fixed for the first ("exp")
    component only. Fluxes are assigned per component and per channel.
    """
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 1.0, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 3.0, 13.9

    components_sersic = {}
    fluxes_group = []

    # Linear interpolators fail to compute accurate likelihoods at knot values
    is_linear_interp = g2f.SersicMixComponentIndexParameterD(
        interpolator=SersicComponentConfig().get_interpolator(4)
    ).interptype == g2f.InterpType.linear

    for idx, name in enumerate(("exp", "dev")):
        components_sersic[name] = SersicComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
            sersic_index=SersicIndexParameterConfig(
                # Add a small offset since 1.0 and 4.0 are bound to be knots
                value_initial=sersicn + idx * dsersicn + 1e-4*is_linear_interp,
                fixed=idx == 0,
                prior_mean=None,
            ),
        )
        # idx == 0 gives the same flux in every channel; idx == 1 adds a
        # channel-dependent increment
        fluxes_comp = {
            channel: flux + idx_channel*dflux*idx
            for idx_channel, channel in enumerate(channels.values())
        }
        fluxes_group.append(fluxes_comp)

    modelconfig = ModelConfig(
        sources={
            "src": SourceConfig(
                component_groups={
                    "": ComponentGroupConfig(
                        components_sersic=components_sersic,
                        centroids={"default": CentroidConfig(
                            x=ParameterConfig(value_initial=12.14, fixed=True),
                            y=ParameterConfig(value_initial=13.78, fixed=True),
                        )},
                    ),
                }
            ),
        },
    )
    model = modelconfig.make_model([[fluxes_group]], data=data, psf_models=psf_models)
    return model


@pytest.fixture(scope="module")
def model_jac(model) -> g2f.ModelD:
    """Return a second ModelD built from the same data, PSF models and sources as ``model``."""
    model_jac = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model.sources)
    return model_jac


@pytest.fixture(scope="module")
def psf_observations(psf_models) -> list[g2f.ObservationD]:
    """Make one 17x19 observation per PSF model for the PSF fitting tests.

    For each PSF model the Gaussians are centred on the image, rendered
    with ``make_image_gaussians`` and perturbed with seeded Gaussian noise.
    """
    config = ObservationConfig(n_rows=17, n_cols=19)
    rng = np.random.default_rng(1)

    observations = []
    for psf_model in psf_models:
        observation = config.make_observation()
        # Have to make a duplicate image here because one can only call
        # make_image_gaussians with an owning pointer, whereas
        # observation.image is a reference
        image = g2.ImageD(observation.image.data)
        # Make the kernel centered
        gaussians_source = psf_model.gaussians(g2f.Channel.NONE)
        for idx in range(len(gaussians_source)):
            gaussian_idx = gaussians_source.at(idx)
            gaussian_idx.centroid.x = image.n_cols/2.
            gaussian_idx.centroid.y = image.n_rows/2.
        gaussians_kernel = g2.Gaussians([g2.Gaussian()])
        make_image_gaussians(
            gaussians_source=gaussians_source,
            gaussians_kernel=gaussians_kernel,
            output=image,
        )
        image.data.flat += 1e-4 * rng.standard_normal(image.data.size)
        observation.mask_inv.fill(1)
        observation.sigma_inv.fill(1e3)
        observations.append(observation)
    return observations


@pytest.fixture(scope="module")
def psf_fit_models(psf_models, psf_observations):
    """Pair each PSF model with its observation in a single-observation ModelD.

    The PSF model's components become the source, and a null PSF model from
    ``make_psf_model_null`` is used as the model's PSF.
    """
    psf_null = [make_psf_model_null()]
    return [
        g2f.ModelD(g2f.DataD([observation]), psf_null, [g2f.Source(psf_model.components)])
        for psf_model, observation in zip(psf_models, psf_observations)
    ]


def test_model_evaluation(channels, model, model_jac):
    """Check evaluation, Jacobian verification and log-likelihood consistency.

    Evaluating before evaluators are set up must raise RuntimeError. The
    test then writes noisy model images into the observations, builds
    Jacobian and residual images as views of two shared numpy arrays, and
    asserts that the log-likelihoods from ``model`` (loglike mode) and
    ``model_jac`` (jacobian mode) agree.
    """
    with pytest.raises(RuntimeError):
        model.evaluate()

    printout = False
    # Freeze the PSF params - they can't be fit anyway
    for m in (model, model_jac):
        for psf_model in m.psfmodels:
            params = psf_model.parameters()
            for param in params:
                param.fixed = True

    model.setup_evaluators(print=printout)
    model.evaluate()

    n_priors = 0
    n_obs = len(model.data)
    n_rows = np.zeros(n_obs, dtype=int)
    n_cols = np.zeros(n_obs, dtype=int)
    datasizes = np.zeros(n_obs, dtype=int)
    ranges_params = [None] * n_obs
    params_free = tuple(get_params_uniq(model_jac, fixed=False))

    # There's one extra validation array
    n_params_jac = len(params_free) + 1
    assert n_params_jac > 1

    rng = np.random.default_rng(2)

    for idx_obs in range(n_obs):
        observation = model.data[idx_obs]
        output = model.outputs[idx_obs]
        # Replace the data with the model image plus noise scaled by 1/sigma_inv
        observation.image.data.flat = (
            output.data.flat + rng.standard_normal(output.data.size) / observation.sigma_inv.data.flat
        )
        n_rows[idx_obs] = observation.image.n_rows
        n_cols[idx_obs] = observation.image.n_cols
        datasizes[idx_obs] = n_rows[idx_obs] * n_cols[idx_obs]
        params = tuple(get_params_uniq(model, fixed=False, channel=observation.channel))
        n_params_obs = len(params)
        # Entry 0 stays 0; entry i + 1 is the 1-based position of this
        # observation's i-th free parameter within params_free
        ranges_params_obs = [0] * (n_params_obs + 1)
        for idx_param in range(n_params_obs):
            ranges_params_obs[idx_param + 1] = params_free.index(params[idx_param]) + 1
        ranges_params[idx_obs] = ranges_params_obs

    n_free_first = len(ranges_params[0])
    assert all([len(rp) == n_free_first for rp in ranges_params[1:]])

    jacobians = [None] * n_obs
    residuals = [None] * n_obs
    datasize = np.sum(datasizes) + n_priors
    jacobian = np.zeros((datasize, n_params_jac))
    residual = np.zeros(datasize)
    # jacobian_prior = self.jacobian[datasize:, ].view()

    # Carve per-observation images out of the shared arrays: each observation
    # owns the row range [offset, end)
    offset = 0
    for idx_obs in range(n_obs):
        size_obs = datasizes[idx_obs]
        end = offset + size_obs
        shape = (n_rows[idx_obs], n_cols[idx_obs])
        jacobians_obs = [None] * n_params_jac
        for idx_jac in range(n_params_jac):
            jacobians_obs[idx_jac] = g2.ImageD(jacobian[offset:end, idx_jac].view().reshape(shape))
        jacobians[idx_obs] = jacobians_obs
        residuals[idx_obs] = g2.ImageD(residual[offset:end].view().reshape(shape))
        offset = end

    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
    loglike_init = model.evaluate()

    model_jac.setup_evaluators(
        evaluatormode=g2f.EvaluatorMode.jacobian,
        outputs=jacobians,
        residuals=residuals,
        print=printout,
    )
    model_jac.verify_jacobian()
    loglike_jac = model_jac.evaluate()

    assert all(np.isclose(loglike_init, loglike_jac))


@pytest.fixture(scope="module")
def psf_models_linear_gaussians(channels, psf_models):
    """Build a LinearGaussians object for each PSF model.

    The first linear parameter of each PSF model is freed only for the
    duration of ``LinearGaussians.make`` and is always fixed again, even if
    that call raises, so the shared ``psf_models`` are left unchanged.
    """
    gaussians = [None] * len(psf_models)
    for idx, psf_model in enumerate(psf_models):
        params = psf_model.parameters(paramfilter=g2f.ParamFilter(nonlinear=False, channel=g2f.Channel.NONE))
        params[0].fixed = False
        try:
            gaussians[idx] = LinearGaussians.make(psf_model, is_psf=True)
        finally:
            # If this is not done, test_psf_model_fit will fail
            params[0].fixed = True
    return gaussians


def test_make_psf_source_linear(psf_models, psf_models_linear_gaussians):
    """Check that free plus fixed linear Gaussians account for every PSF Gaussian."""
    for psf_model, linear_gaussians in zip(psf_models, psf_models_linear_gaussians):
        gaussians = psf_model.gaussians(g2f.Channel.NONE)
        assert len(gaussians) == (
            len(linear_gaussians.gaussians_free) + len(linear_gaussians.gaussians_fixed)
        )


def test_modeller(model):
    """Fit the model with Modeller, with and without priors.

    Noisy model images are written into the observations. The model is
    then fit from the true parameters and from perturbed ones, and each
    fit must improve the log-likelihood. After each of those fits, Gaussian
    priors centred on (or slightly offset from) the best-fit values are
    added and the model is refit; the reported cost must match a fresh
    evaluation.
    """
    # For debugging purposes
    printout = False
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike_image)
    # Get the model images
    model.evaluate()
    rng = np.random.default_rng(3)

    for idx_obs, observation in enumerate(model.data):
        output = model.outputs[idx_obs]
        observation.image.data.flat = (
            output.data.flat + rng.standard_normal(output.data.size) / observation.sigma_inv.data.flat
        )

    # Freeze the PSF params - they can't be fit anyway
    for psf_model in model.psfmodels:
        for param in psf_model.parameters():
            param.fixed = True

    params_free = tuple(get_params_uniq(model, fixed=False))
    values_true = tuple(param.value for param in params_free)

    modeller = Modeller()

    dloglike = model.compute_loglike_grad(verify=True, findiff_frac=1e-8, findiff_add=1e-8)
    assert all(np.isfinite(dloglike))

    kwargs_fit = dict(ftol=1e-6, xtol=1e-6)

    for delta_param in (0, 0.2):
        model = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model.sources)
        values_init = values_true
        if delta_param != 0:
            for param, value_init in zip(params_free, values_init):
                param.value = value_init
                # Step the other way if the perturbed value is rejected
                try:
                    param.value_transformed += delta_param
                except RuntimeError:
                    param.value_transformed -= delta_param

        model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
        loglike_init = np.array(model.evaluate())
        # Time each fit on its own so the debug printout is meaningful on
        # every iteration
        time_start = time.process_time()
        results = modeller.fit_model(model, **kwargs_fit)
        time_fit = time.process_time() - time_start
        params_best = results.params_best

        for param, value in zip(params_free, params_best):
            param.value_transformed = value

        loglike_noprior = model.evaluate()
        assert np.sum(loglike_noprior) > np.sum(loglike_init)

        errors = modeller.compute_variances(model)
        # TODO: This should check >0, and < (some reasonable value), but the scipy least squares
        # does not do a great job optimizing and the loglike_grad isn't even negative...
        assert np.all(np.isfinite(errors))

        if printout:
            print(
                f"got loglike={loglike_noprior} (init={np.sum(loglike_init)})"
                f" from modeller.fit_model in t={time_fit:.3e}, x={params_best},"
                f" results: \n{results}"
            )

        loglike_noprior_sum = sum(loglike_noprior)
        for offset in (0, 1e-6):
            # Reset to the no-prior best fit before attaching priors
            for param, value in zip(params_free, params_best):
                param.value_transformed = value
            priors = tuple(
                g2f.GaussianPrior(param, param.value_transformed + offset, 1.0, transformed=True)
                for param in params_free
            )
            if offset == 0:
                # A prior centred exactly on the current value adds nothing
                # to the log-likelihood apart from its constant term
                for p in priors:
                    assert p.evaluate().loglike == 0
                    assert p.loglike_const_terms[0] == -math.log(math.sqrt(2 * math.pi))
            model = g2f.ModelD(
                data=model.data, psfmodels=model.psfmodels, sources=model.sources, priors=priors
            )
            model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
            loglike_init = sum(loglike_eval for loglike_eval in model.evaluate())
            if offset == 0:
                assert np.isclose(loglike_init, loglike_noprior_sum, rtol=1e-10, atol=1e-10)
            else:
                assert loglike_init < loglike_noprior_sum

            time_start = time.process_time()
            results = modeller.fit_model(model, **kwargs_fit)
            time_fit = time.process_time() - time_start
            loglike_new = -results.result.cost
            for param, value in zip(params_free, results.params_best):
                param.value_transformed = value

            model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
            loglike_model = sum(loglike_eval for loglike_eval in model.evaluate())
            assert np.isclose(loglike_new, loglike_model, rtol=1e-10, atol=1e-10)
            # This should be > 0. TODO: Determine why it isn't always
            assert (loglike_new - loglike_init) > -1e-3

            if printout:
                print(
                    f"got loglike={loglike_new} (first={loglike_noprior})"
                    f" from modeller.fit_model in t={time_fit:.3e}, x={results.params_best},"
                    f" results: \n{results}"
                )
            # Adding a suitably-scaled prior far from the truth should always
            # worsen loglikel, but doesn't - why? noise bias? bad convergence?
            # assert (loglike_new >= loglike_noprior) == (offset == 0)


def test_psf_model_fit(psf_fit_models):
    """Verify the Jacobian of every PSF fit model with all shape parameters free.

    Every parameter except the integral (total flux) parameters is freed,
    the evaluators are rebuilt, and ``verify_jacobian`` must report no
    errors. If it does report errors, the parameters are printed and a
    finite-difference image is plotted next to a Jacobian image before the
    final assertion fails.
    """
    for model in psf_fit_models:
        params = get_params_uniq(model.sources[0])
        for param in params:
            # Fitting the total flux won't work in a fractional model (yet)
            if isinstance(param, g2f.IntegralParameterD):
                assert param.fixed
            else:
                param.fixed = False
        # Necessary whenever parameters are freed/fixed
        model.setup_evaluators(g2f.EvaluatorMode.jacobian, force=True)
        errors = model.verify_jacobian(rtol=5e-4, atol=5e-4, findiff_add=1e-6, findiff_frac=1e-6)
        if errors:
            # Diagnostics only: everything in this branch runs just before
            # the assertion at the bottom of the loop fails
            import matplotlib.pyplot as plt
            print(model.parameters())

            fitinputs = FitInputs.from_model(model)
            model.setup_evaluators(
                evaluatormode=g2f.EvaluatorMode.jacobian,
                outputs=fitinputs.jacobians,
                residuals=fitinputs.residuals,
                print=True,
                force=True,
            )
            model.evaluate(print=True)
            assert (fitinputs.jacobians[0][0].data == 0).all()
            assert np.sum(np.abs(fitinputs.jacobians[0][1].data)) > 0
            model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike_image)
            model.evaluate()
            outputs = model.outputs
            # Copies of the model images before the parameter is perturbed
            diffs = [g2.ImageD(img.data.copy()) for img in outputs]
            delta = 1e-5
            # NOTE(review): `param` is whatever the loop above left behind
            # (the last parameter), while the plot below shows Jacobian
            # index 1 - confirm these refer to the same parameter. The
            # perturbed value is never restored.
            param.value -= delta
            model.evaluate()
            for diff, output in zip(diffs, outputs):
                # NOTE(review): the value was decreased, so (new - old) / delta
                # looks like minus the derivative - confirm the intended sign
                diff = (output.data - diff.data) / delta
                jacobian = fitinputs.jacobians[0][1].data
                fig, ax = plt.subplots(1, 2)
                ax[0].imshow(diff)
                ax[1].imshow(jacobian)
                plt.show()
        assert len(errors) == 0


def test_psf_models_linear_gaussians(data, psf_models_linear_gaussians, psf_observations):
    results = [None] * len(psf_observations)
    for idx, (gaussians_linear, observation_psf) in enumerate(
        zip(psf_models_linear_gaussians, psf_observations)
    ):
        results[idx] =
Modeller.fit_gaussians_linear( 484 | gaussians_linear=gaussians_linear, 485 | observation=observation_psf, 486 | fitmethods=fitmethods_linear, 487 | plot=False, 488 | ) 489 | assert len(results[idx]) > 0 490 | 491 | 492 | def test_modeller_fit_linear(model): 493 | modeller = Modeller() 494 | results = modeller.fit_model_linear(model, validate=True) 495 | # TODO: add more here 496 | assert results is not None 497 | -------------------------------------------------------------------------------- /examples/fithsc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "collapsed": true 7 | }, 8 | "source": [ 9 | "# Fitting HSC data in multiband mode using MultiProFit" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "outputs": [], 16 | "source": [ 17 | "# Import required packages\n", 18 | "import time\n", 19 | "from typing import Any, Iterable, Mapping\n", 20 | "\n", 21 | "from astropy.coordinates import SkyCoord\n", 22 | "from astropy.io.ascii import Csv\n", 23 | "import astropy.io.fits as fits\n", 24 | "import astropy.table as apTab\n", 25 | "import astropy.visualization as apVis\n", 26 | "from astropy.wcs import WCS\n", 27 | "import lsst.gauss2d as g2\n", 28 | "import lsst.gauss2d.fit as g2f\n", 29 | "from lsst.multiprofit.componentconfig import SersicComponentConfig, SersicIndexParameterConfig\n", 30 | "from lsst.multiprofit.fit_psf import (\n", 31 | " CatalogExposurePsfABC,\n", 32 | " CatalogPsfFitter,\n", 33 | " CatalogPsfFitterConfig,\n", 34 | " CatalogPsfFitterConfigData,\n", 35 | ")\n", 36 | "from lsst.multiprofit.fit_source import (\n", 37 | " CatalogExposureSourcesABC,\n", 38 | " CatalogSourceFitterABC,\n", 39 | " CatalogSourceFitterConfig,\n", 40 | " CatalogSourceFitterConfigData,\n", 41 | ")\n", 42 | "from lsst.multiprofit.modelconfig import ModelConfig\n", 43 | "from lsst.multiprofit.plots import plot_model_rgb\n", 
44 | "from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig\n", 45 | "from lsst.multiprofit.utils import ArbitraryAllowedConfig, get_params_uniq\n", 46 | "import matplotlib as mpl\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "import numpy as np\n", 49 | "import pydantic\n", 50 | "from pydantic.dataclasses import dataclass" 51 | ], 52 | "metadata": { 53 | "collapsed": false, 54 | "ExecuteTime": { 55 | "end_time": "2024-06-14T06:17:42.558190725Z", 56 | "start_time": "2024-06-14T06:17:41.498267418Z" 57 | } 58 | } 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "metadata": { 64 | "ExecuteTime": { 65 | "end_time": "2024-06-14T06:17:42.608230886Z", 66 | "start_time": "2024-06-14T06:17:42.560250073Z" 67 | } 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# Define settings\n", 72 | "band_ref = 'i'\n", 73 | "bands = {'i': 0.87108833, 'r': 0.97288654, 'g': 1.44564678}\n", 74 | "band_multi = ''.join(bands)\n", 75 | "channels = {band: g2f.Channel.get(band) for band in bands}\n", 76 | "\n", 77 | "# This is in the WCS, but may as well keep full precision\n", 78 | "scale_pixel_hsc = 0.168\n", 79 | "\n", 80 | "# Common to all FITS\n", 81 | "hdu_img, hdu_mask, hdu_var = 1, 2, 3\n", 82 | "\n", 83 | "# Masks\n", 84 | "bad_masks = (\n", 85 | " 'BAD', 'SAT', 'INTRP', 'CR', 'EDGE', 'CLIPPED', 'NO_DATA', 'CROSSTALK',\n", 86 | " 'NO_DATA', 'UNMASKEDNAN', 'SUSPECT', 'REJECTED', 'SENSOR_EDGE',\n", 87 | ")\n", 88 | "maskbits = tuple(f'MP_{b}' for b in bad_masks)\n", 89 | "\n", 90 | "# A pre-defined bitmask to exclude regions with low SN\n", 91 | "read_mask_highsn = True\n", 92 | "write_mask_highsn = False\n", 93 | "\n", 94 | "# matplotlib settings\n", 95 | "mpl.rcParams['image.origin'] = 'lower'\n", 96 | "mpl.rcParams['figure.dpi'] = 120" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 3, 102 | "metadata": { 103 | "ExecuteTime": { 104 | "end_time": "2024-06-14T06:17:42.613129061Z", 105 | "start_time": 
"2024-06-14T06:17:42.569183931Z" 106 | } 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "# Define source to fit\n", 111 | "id_gama, z = 79635, 0.0403\n", 112 | "\"\"\"\n", 113 | "Acquired from https://hsc-release.mtk.nao.ac.jp/datasearch/catalog_jobs with query:\n", 114 | "\n", 115 | "SELECT object_id, ra, dec,\n", 116 | " g_cmodel_mag, g_cmodel_magerr, r_cmodel_mag, r_cmodel_magerr, i_cmodel_mag, i_cmodel_magerr,\n", 117 | " g_psfflux_mag, g_psfflux_magerr, r_psfflux_mag, r_psfflux_magerr, i_psfflux_mag, i_psfflux_magerr,\n", 118 | " g_kronflux_mag, g_kronflux_magerr, r_kronflux_mag, r_kronflux_magerr, i_kronflux_mag, i_kronflux_magerr,\n", 119 | " g_sdssshape_shape11, g_sdssshape_shape11err, g_sdssshape_shape22, g_sdssshape_shape22err,\n", 120 | " g_sdssshape_shape12, g_sdssshape_shape12err,\n", 121 | " r_sdssshape_shape11, r_sdssshape_shape11err, r_sdssshape_shape22, r_sdssshape_shape22err,\n", 122 | " r_sdssshape_shape12, r_sdssshape_shape12err,\n", 123 | " i_sdssshape_shape11, i_sdssshape_shape11err, i_sdssshape_shape22, i_sdssshape_shape22err,\n", 124 | " i_sdssshape_shape12, i_sdssshape_shape12err\n", 125 | "FROM pdr3_wide.forced\n", 126 | "LEFT JOIN pdr3_wide.forced2 USING (object_id)\n", 127 | "WHERE isprimary AND conesearch(coord, 222.51551376, 0.09749601, 35.64)\n", 128 | "AND (r_kronflux_mag < 26 OR i_kronflux_mag < 26) AND NOT i_kronflux_flag AND NOT r_kronflux_flag;\n", 129 | "\"\"\"\n", 130 | "cat = Csv()\n", 131 | "cat.header.splitter.escapechar = '#'\n", 132 | "cat = cat.read('fithsc_src.csv')\n", 133 | "\n", 134 | "prefix = '222.51551376,0.09749601_'\n", 135 | "prefix_img = f'{prefix}300x300_'\n", 136 | "\n", 137 | "# Read data, acquired with:\n", 138 | "# https://github.com/taranu/astro_imaging/blob/4d5a8e095e6a3944f1fbc19318b1dbc22b22d9ca/examples/HSC.ipynb\n", 139 | "# (with get_mask=True, get_variance=True,)\n", 140 | "images, psfs = {}, {}\n", 141 | "for band in bands:\n", 142 | " images[band] = 
fits.open(f'{prefix_img}{band}.fits')\n", 143 | " psfs[band] = fits.open(f'{prefix}{band}_psf.fits')\n", 144 | "\n", 145 | "wcs = WCS(images[band_ref][hdu_img])\n", 146 | "cat['x'], cat['y'] = wcs.world_to_pixel(SkyCoord(cat['ra'], cat['dec'], unit='deg'))" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": { 153 | "is_executing": true, 154 | "ExecuteTime": { 155 | "start_time": "2024-06-14T06:17:42.599466290Z" 156 | } 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "# Plot image\n", 161 | "img_rgb = apVis.make_lupton_rgb(*[img[1].data*bands[band] for band, img in images.items()])\n", 162 | "plt.scatter(cat['x'], cat['y'], s=10, c='g', marker='x')\n", 163 | "plt.imshow(img_rgb)\n", 164 | "plt.title(\"gri image with detected objects\")\n", 165 | "plt.show()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": { 172 | "is_executing": true 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "# Generate a rough mask around other sources\n", 177 | "bright = (cat['i_cmodel_mag'] < 23) | (cat['i_psfflux_mag'] < 23)\n", 178 | "\n", 179 | "img_ref = images[band_ref][hdu_img].data\n", 180 | "\n", 181 | "mask_inverse = np.ones(img_ref.shape, dtype=bool)\n", 182 | "y_cen, x_cen = (x/2. 
for x in img_ref.shape)\n", 183 | "y, x = np.indices(img_ref.shape)\n", 184 | "\n", 185 | "idx_src_main, row_main = None, None\n", 186 | "\n", 187 | "sizes_override = {\n", 188 | " 42305088563206480: 8.,\n", 189 | "}\n", 190 | "\n", 191 | "for src in cat[bright]:\n", 192 | " id_src, x_src, y_src = (src[col] for col in ['object_id', 'x', 'y'])\n", 193 | " dist = np.hypot(x_src - x_cen, y_src - y_cen)\n", 194 | " if dist > 20:\n", 195 | " dists = np.hypot(y - y_src, x - x_src)\n", 196 | " mag = np.nanmin([src['i_cmodel_mag'], src['r_cmodel_mag'], src['i_psfflux_mag'], src['r_psfflux_mag']])\n", 197 | " if (radius_mask := sizes_override.get(id_src)) is None:\n", 198 | " radius_mask = 2*np.sqrt(\n", 199 | " src[f'{band_ref}_sdssshape_shape11'] + src[f'{band_ref}_sdssshape_shape22']\n", 200 | " )/scale_pixel_hsc\n", 201 | " if (radius_mask > 10) and (mag > 21):\n", 202 | " radius_mask = 5\n", 203 | " mask_inverse[dists < radius_mask] = 0\n", 204 | " print(f'Masking src=({id_src} at {x_src:.3f}, {y_src:.3f}) dist={dist:.3f}'\n", 205 | " f', mag={mag:.3f}, radius_mask={radius_mask:.3f}')\n", 206 | " elif dist < 2:\n", 207 | " idx_src_main = id_src\n", 208 | " row_main = src\n", 209 | " print(f\"{idx_src_main=} {dict(src)=}\")\n", 210 | "\n", 211 | "tab_row_main = apTab.Table(row_main)\n", 212 | "\n", 213 | "if read_mask_highsn:\n", 214 | " mask_highsn = np.load(f'{prefix_img}mask_inv_highsn.npz')['mask_inv']\n", 215 | " mask_inverse *= mask_highsn\n", 216 | "\n", 217 | "plt.imshow(mask_inverse)\n", 218 | "plt.title(\"Fitting mask\")\n", 219 | "plt.show()" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "metadata": { 226 | "is_executing": true 227 | }, 228 | "outputs": [], 229 | "source": [ 230 | "# Fit PSF\n", 231 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n", 232 | "class CatalogExposurePsf(CatalogExposurePsfABC):\n", 233 | " catalog: apTab.Table = pydantic.Field(title=\"The detected object catalog\")\n", 234 | " 
img: np.ndarray = pydantic.Field(title=\"The PSF image\")\n", 235 | "\n", 236 | " def get_catalog(self) -> Iterable:\n", 237 | " return self.catalog\n", 238 | "\n", 239 | " def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.array:\n", 240 | " return self.img\n", 241 | "\n", 242 | "config_psf = CatalogPsfFitterConfig(column_id='object_id')\n", 243 | "fitter_psf = CatalogPsfFitter()\n", 244 | "catalog_psf = apTab.Table({'object_id': [tab_row_main['object_id']]})\n", 245 | "results_psf = {}\n", 246 | "\n", 247 | "# Keep a separate configdata_psf per band, because it has a cached PSF model\n", 248 | "# those should not be shared!\n", 249 | "config_data_psfs = {}\n", 250 | "for band, psf_file in psfs.items():\n", 251 | " config_data_psf = CatalogPsfFitterConfigData(config=config_psf)\n", 252 | " catexp = CatalogExposurePsf(catalog=catalog_psf, img=psf_file[0].data)\n", 253 | " t_start = time.time()\n", 254 | " result = fitter_psf.fit(config_data=config_data_psf, catexp=catexp)\n", 255 | " t_end = time.time()\n", 256 | " results_psf[band] = result\n", 257 | " config_data_psfs[band] = config_data_psf\n", 258 | " print(f\"Fit {band}-band PSF in {t_end - t_start:.2e}s; result:\")\n", 259 | " print(dict(result[0]))" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "metadata": { 266 | "is_executing": true 267 | }, 268 | "outputs": [], 269 | "source": [ 270 | "# Set fit configs\n", 271 | "config_source = CatalogSourceFitterConfig(\n", 272 | " column_id='object_id',\n", 273 | " config_model=ModelConfig(\n", 274 | " sources={\n", 275 | " \"src\": SourceConfig(\n", 276 | " component_groups={\n", 277 | " \"\": ComponentGroupConfig(\n", 278 | " components_sersic={\n", 279 | " 'disk': SersicComponentConfig(\n", 280 | " sersic_index=SersicIndexParameterConfig(value_initial=1., fixed=True),\n", 281 | " prior_size_stddev=0.5,\n", 282 | " prior_axrat_stddev=0.2,\n", 283 | " ),\n", 284 | " 'bulge': SersicComponentConfig(\n", 285 | 
" sersic_index=SersicIndexParameterConfig(value_initial=4., fixed=True),\n", 286 | " prior_size_stddev=0.1,\n", 287 | " prior_axrat_stddev=0.2,\n", 288 | " ),\n", 289 | " },\n", 290 | " ),\n", 291 | " }\n", 292 | " ),\n", 293 | " },\n", 294 | " ),\n", 295 | ")\n", 296 | "config_data_source = CatalogSourceFitterConfigData(\n", 297 | " channels=list(channels.values()),\n", 298 | " config=config_source,\n", 299 | ")" 300 | ] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": { 306 | "is_executing": true 307 | }, 308 | "outputs": [], 309 | "source": [ 310 | "# Setup exposure with band-specific image, mask and variance\n", 311 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n", 312 | "class CatalogExposureSources(CatalogExposureSourcesABC):\n", 313 | " config_data_psf: CatalogPsfFitterConfigData = pydantic.Field(title=\"The PSF fit config\")\n", 314 | " observation: g2f.ObservationD = pydantic.Field(title=\"The observation to fit\")\n", 315 | " table_psf_fits: apTab.Table = pydantic.Field(title=\"The table of PSF fit parameters\")\n", 316 | "\n", 317 | " @property\n", 318 | " def channel(self) -> g2f.Channel:\n", 319 | " return self.observation.channel\n", 320 | "\n", 321 | " def get_catalog(self) -> Iterable:\n", 322 | " return self.table_psf_fits\n", 323 | "\n", 324 | " def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:\n", 325 | " self.config_data_psf.init_psf_model(params)\n", 326 | " return self.config_data_psf.psf_model\n", 327 | "\n", 328 | " def get_source_observation(self, source: Mapping[str, Any]) -> g2f.ObservationD:\n", 329 | " return self.observation\n", 330 | "\n", 331 | "\n", 332 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n", 333 | "class CatalogSourceFitter(CatalogSourceFitterABC):\n", 334 | " band: str = pydantic.Field(title=\"The reference band for initialization and priors\")\n", 335 | " scale_pixel: float = pydantic.Field(title=\"The pixel scale in arcsec\")\n", 336 
| " wcs_ref: WCS = pydantic.Field(title=\"The WCS for the coadded image\")\n", 337 | "\n", 338 | " def initialize_model(\n", 339 | " self,\n", 340 | " model: g2f.ModelD,\n", 341 | " source: Mapping[str, Any],\n", 342 | " catexps: list[CatalogExposureSourcesABC],\n", 343 | " values_init: Mapping[g2f.ParameterD, float] | None = None,\n", 344 | " centroid_pixel_offset: float = 0,\n", 345 | " **kwargs\n", 346 | " ):\n", 347 | " if values_init is None:\n", 348 | " values_init = {}\n", 349 | " x, y = source['x'], source['y']\n", 350 | " scale_sq = self.scale_pixel**(-2)\n", 351 | " ellipse = g2.Ellipse(g2.Covariance(\n", 352 | " sigma_x_sq=source[f'{band}_sdssshape_shape11']*scale_sq,\n", 353 | " sigma_y_sq=source[f'{band}_sdssshape_shape22']*scale_sq,\n", 354 | " cov_xy=source[f'{band}_sdssshape_shape12']*scale_sq,\n", 355 | " ))\n", 356 | " size_major = g2.EllipseMajor(ellipse).r_major\n", 357 | " limits_size = g2f.LimitsD(1e-5, np.sqrt(x*x + y*y))\n", 358 | " # An R_eff larger than the box size is problematic\n", 359 | " # Also should stop unreasonable size proposals; log10 transform isn't enough\n", 360 | " # TODO: Try logit for r_eff?\n", 361 | " params_limits_init = {\n", 362 | " # Should set limits based on image size, but this shortcut is fine\n", 363 | " # for this particular object\n", 364 | " g2f.CentroidXParameterD: (x, g2f.LimitsD(0, 2*x)),\n", 365 | " g2f.CentroidYParameterD: (x, g2f.LimitsD(0, 2*y)),\n", 366 | " g2f.ReffXParameterD: (ellipse.sigma_x, limits_size),\n", 367 | " g2f.ReffYParameterD: (ellipse.sigma_y, limits_size),\n", 368 | " # There is a sign convention difference\n", 369 | " g2f.RhoParameterD: (-ellipse.rho, None),\n", 370 | " g2f.IntegralParameterD: (1.0, g2f.LimitsD(1e-10, 1e10)),\n", 371 | " }\n", 372 | " params_free = get_params_uniq(model, fixed=False)\n", 373 | " for param in params_free:\n", 374 | " type_param = type(param)\n", 375 | " value_init, limits_new = params_limits_init.get(\n", 376 | " type_param,\n", 377 | " 
(values_init.get(param), None)\n", 378 | " )\n", 379 | " if value_init is not None:\n", 380 | " param.value = value_init\n", 381 | " if limits_new:\n", 382 | " # For slightly arcane reasons, we must set a new limits object\n", 383 | " # Changing limits values is unreliable\n", 384 | " param.limits = limits_new\n", 385 | " for prior in model.priors:\n", 386 | " if isinstance(prior, g2f.ShapePrior):\n", 387 | " prior.prior_size.mean_parameter.value = size_major\n", 388 | "\n", 389 | "\n", 390 | " def validate_fit_inputs(\n", 391 | " self,\n", 392 | " catalog_multi,\n", 393 | " catexps: list[CatalogExposureSourcesABC],\n", 394 | " config_data: CatalogSourceFitterConfigData = None,\n", 395 | " logger = None,\n", 396 | " **kwargs: Any,\n", 397 | " ) -> None:\n", 398 | " super().validate_fit_inputs(\n", 399 | " catalog_multi=catalog_multi, catexps=catexps, config_data=config_data,\n", 400 | " logger=logger, **kwargs\n", 401 | " )" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": null, 407 | "metadata": { 408 | "is_executing": true 409 | }, 410 | "outputs": [], 411 | "source": [ 412 | "# Set up Fitter, Observations and CatalogExposureSources\n", 413 | "fitter = CatalogSourceFitter(band=band, scale_pixel=scale_pixel_hsc, wcs_ref=wcs)\n", 414 | "\n", 415 | "observations = {}\n", 416 | "catexps = {}\n", 417 | "\n", 418 | "for band in bands:\n", 419 | " data = images[band]\n", 420 | " # There are better ways to use bitmasks, but this will do\n", 421 | " header = data[hdu_mask].header\n", 422 | " bitmask = data[hdu_mask].data\n", 423 | " mask = np.zeros_like(bitmask, dtype='bool')\n", 424 | " for bit in maskbits:\n", 425 | " mask |= ((bitmask & 2**header[bit]) != 0)\n", 426 | "\n", 427 | " mask = (mask == 0) & mask_inverse\n", 428 | " sigma_inv = 1.0/np.sqrt(data[hdu_var].data)\n", 429 | " sigma_inv[mask != 1] = 0\n", 430 | "\n", 431 | " observation = g2f.ObservationD(\n", 432 | " image=g2.ImageD(data[hdu_img].data),\n", 433 | " 
sigma_inv=g2.ImageD(sigma_inv),\n", 434 | " mask_inv=g2.ImageB(mask),\n", 435 | " channel=g2f.Channel.get(band),\n", 436 | " )\n", 437 | " observations[band] = observation\n", 438 | " catexps[band] = CatalogExposureSources(\n", 439 | " config_data_psf=config_data_psfs[band],\n", 440 | " observation=observation,\n", 441 | " table_psf_fits=results_psf[band],\n", 442 | " )" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": { 449 | "scrolled": true, 450 | "is_executing": true 451 | }, 452 | "outputs": [], 453 | "source": [ 454 | "# Now do the multi-band fit\n", 455 | "t_start = time.time()\n", 456 | "result_multi = fitter.fit(\n", 457 | " catalog_multi=tab_row_main,\n", 458 | " catexps=list(catexps.values()),\n", 459 | " config_data=config_data_source,\n", 460 | ")\n", 461 | "t_end = time.time()\n", 462 | "print(f\"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2e}s; result:\")\n", 463 | "print(dict(result_multi[0]))" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": { 470 | "scrolled": true, 471 | "is_executing": true 472 | }, 473 | "outputs": [], 474 | "source": [ 475 | "# Fit in each band separately\n", 476 | "results = {}\n", 477 | "for band, observation in bands.items():\n", 478 | " config_data_source_band = CatalogSourceFitterConfigData(\n", 479 | " channels=[channels[band]],\n", 480 | " config=config_source,\n", 481 | " )\n", 482 | " t_start = time.time()\n", 483 | " result = fitter.fit(\n", 484 | " catalog_multi=tab_row_main,\n", 485 | " catexps=[catexps[band]],\n", 486 | " config_data=config_data_source_band,\n", 487 | " )\n", 488 | " t_end = time.time()\n", 489 | " results[band] = result\n", 490 | " print(f\"Fit {band}-band bulge-disk model in {t_end - t_start:.2f}s; result:\")\n", 491 | " print(dict(result[0]))" 492 | ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": { 498 | 
"is_executing": true 499 | }, 500 | "outputs": [], 501 | "source": [ 502 | "# Make a model for the best-fit params\n", 503 | "data, psf_models = config_source.make_model_data(idx_row=0, catexps=list(catexps.values()))\n", 504 | "model = g2f.ModelD(data=data, psfmodels=psf_models, sources=config_data_source.sources_priors[0], priors=config_data_source.sources_priors[1])\n", 505 | "params = get_params_uniq(model, fixed=False)\n", 506 | "result_multi_row = dict(result_multi[0])\n", 507 | "# This is the last column before fit params\n", 508 | "idx_last = next(idx for idx, column in enumerate(result_multi_row.keys()) if column == 'mpf_unknown_flag')\n", 509 | "# Set params to best fit values\n", 510 | "for param, (column, value) in zip(params, list(result_multi_row.items())[idx_last+1:]):\n", 511 | " param.value = value\n", 512 | "model.setup_evaluators(g2f.EvaluatorMode.loglike_image)\n", 513 | "# Print the loglikelihoods, which are from the data and end with the (sum of all) priors\n", 514 | "loglikes = model.evaluate()\n", 515 | "print(f\"{loglikes=}\")" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "### Multiband Residuals\n", 523 | "\n", 524 | "What's with the structure in the residuals? Most broadly, a point source + exponential disk + deVauc bulge model is totally inadequate for this galaxy for several possible reasons:\n", 525 | "\n", 526 | "1. The disk isn't exactly exponential (n=1)\n", 527 | "2. The disk has colour gradients not accounted for in this model*\n", 528 | "3. If the galaxy even has a bulge, it's very weak and def. 
not a deVaucouleurs (n=4) profile; it may be an exponential \"pseudobulge\"\n", 529 | "\n", 530 | "\\*MultiProFit can do more general Gaussian mixture models (linear or non-linear), which may be explored in a future iteration of this notebook, but these generally do not improve the accuracy of photometry for smaller/fainter galaxies.\n", 531 | "\n", 532 | "Note that the two scalings of the residual plots (98%ile and +/- 20 sigma) end up looking very similar.\n" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "metadata": { 539 | "is_executing": true 540 | }, 541 | "outputs": [], 542 | "source": [ 543 | "# Make some basic plots\n", 544 | "_, _, _, _, mask_inv_highsn = plot_model_rgb(\n", 545 | " model, weights=bands, high_sn_threshold=0.2 if write_mask_highsn else None,\n", 546 | ")\n", 547 | "plt.show()\n", 548 | "\n", 549 | "# Write the high SN bitmask to a compressed, bitpacked file\n", 550 | "if write_mask_highsn:\n", 551 | " plt.figure()\n", 552 | " plt.imshow(mask_inv_highsn, cmap='gray')\n", 553 | " plt.show()\n", 554 | " packed = np.packbits(mask_inv_highsn, bitorder='little')\n", 555 | " np.savez_compressed(f'{prefix_img}mask_inv_highsn.npz', mask_inv=mask_inv_highsn)\n", 556 | "\n", 557 | "# TODO: Some features still missing from plot_model_rgb\n", 558 | "# residual histograms, param values, better labels, etc" 559 | ] 560 | }, 561 | { 562 | "cell_type": "markdown", 563 | "metadata": {}, 564 | "source": [ 565 | "### More exercises for the reader\n", 566 | "\n", 567 | "These are of the sort that the author hasn't gotten around to yet because they're far from trivial. Try:\n", 568 | "\n", 569 | "0. Use the WCS to compute ra, dec and errors thereof.\n", 570 | "Hint: override CatalogSourceFitter.get_model_radec\n", 571 | "\n", 572 | "1. 
Replace the real data with simulated data.\n", 573 | "Make new observations using model.evaluate and add noise based on the variance maps.\n", 574 | "Try fitting again and see how well results converge depending on the initialization scheme.\n", 575 | "\n", 576 | "2. Fit every other source individually.\n", 577 | "Try subtracting the best-fit galaxy model from above first.\n", 578 | "Hint: get_source_observation should be redefined to return a smaller postage stamp around the nominal centroid.\n", 579 | "Pass the full catalog (excluding the central galaxy) to catalog_multi.\n", 580 | "\n", 581 | "3. Fit all sources simultaneously.\n", 582 | "Redefine CatalogFitterConfig.make_model_data to make a model with multiple sources, using the catexp catalogs\n", 583 | "initialize_model will no longer need to do anything\n", 584 | "catalog_multi should still be a single row" 585 | ] 586 | } 587 | ], 588 | "metadata": { 589 | "kernelspec": { 590 | "display_name": "Python 3 (ipykernel)", 591 | "language": "python", 592 | "name": "python3" 593 | }, 594 | "language_info": { 595 | "codemirror_mode": { 596 | "name": "ipython", 597 | "version": 3 598 | }, 599 | "file_extension": ".py", 600 | "mimetype": "text/x-python", 601 | "name": "python", 602 | "nbconvert_exporter": "python", 603 | "pygments_lexer": "ipython3", 604 | "version": "3.11.4" 605 | } 606 | }, 607 | "nbformat": 4, 608 | "nbformat_minor": 1 609 | } 610 | --------------------------------------------------------------------------------