├── COPYRIGHT
├── python
└── lsst
│ ├── __init__.py
│ └── multiprofit
│ ├── __init__.py
│ ├── limits.py
│ ├── errors.py
│ ├── config.py
│ ├── model_utils.py
│ ├── utils.py
│ ├── observationconfig.py
│ ├── priors.py
│ ├── modelconfig.py
│ ├── asinhstretchsigned.py
│ ├── fit_catalog.py
│ ├── psfmodel_utils.py
│ ├── transforms.py
│ ├── fit_bootstrap_model.py
│ ├── sourceconfig.py
│ └── componentconfig.py
├── tests
├── SConscript
├── test_psfmodel_utils.py
├── test_fit_psf.py
├── test_observationconfig.py
├── test_plots.py
├── test_componentconfig.py
├── test_modelconfig.py
├── test_sourceconfig.py
├── test_fit_bootstrap_model.py
└── test_modeller.py
├── requirements.txt
├── examples
├── 222.51551376,0.09749601_g_psf.fits
├── 222.51551376,0.09749601_i_psf.fits
├── 222.51551376,0.09749601_r_psf.fits
├── 222.51551376,0.09749601_300x300_g.fits
├── 222.51551376,0.09749601_300x300_i.fits
├── 222.51551376,0.09749601_300x300_r.fits
├── 222.51551376,0.09749601_300x300_mask_inv_highsn.npz
├── test_gaussians.py
├── plot_sersic_mix.py
├── test_gaussian_gradients.py
├── test_mgsersic.py
├── fithsc.py
└── fithsc.ipynb
├── doc
├── .gitignore
├── index.rst
├── manifest.yaml
├── conf.py
└── lsst.multiprofit
│ └── index.rst
├── SConstruct
├── .github
└── workflows
│ └── rebase_checker.yaml
├── ups
└── multiprofit.table
├── .gitignore
├── .pre-commit-config.yaml
├── README.rst
└── pyproject.toml
/COPYRIGHT:
--------------------------------------------------------------------------------
1 | Copyright 2018-2021 The Trustees of Princeton University
2 |
--------------------------------------------------------------------------------
/python/lsst/__init__.py:
--------------------------------------------------------------------------------
1 | import pkgutil
2 |
# Namespace-package hook: extend ``lsst.__path__`` so ``lsst.*`` subpackages
# installed in other sys.path entries can be imported alongside this one.
__path__ = pkgutil.extend_path(__path__, __name__)
4 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/__init__.py:
--------------------------------------------------------------------------------
1 | import pkgutil
2 |
# Namespace-package hook: extend ``lsst.multiprofit.__path__`` with any
# same-named package directories found on other sys.path entries.
__path__ = pkgutil.extend_path(__path__, __name__)
4 |
--------------------------------------------------------------------------------
/tests/SConscript:
--------------------------------------------------------------------------------
1 | # -*- python -*-
2 | from lsst.sconsUtils import scripts
3 | scripts.BasicSConscript.tests(pyList=[])
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | astropy
2 | galsim
3 | gauss2d
4 | matplotlib
5 | numpy
6 | pydantic
7 | pytest
8 | scipy
9 | seaborn
10 |
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_g_psf.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_g_psf.fits
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_i_psf.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_i_psf.fits
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_r_psf.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_r_psf.fits
--------------------------------------------------------------------------------
/doc/.gitignore:
--------------------------------------------------------------------------------
1 | # Doxygen products
2 | html
3 | xml
4 | *.tag
5 | *.inc
6 | doxygen.conf
7 |
8 | # Sphinx products
9 | _build
10 | py-api
11 |
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_300x300_g.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_g.fits
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_300x300_i.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_i.fits
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_300x300_r.fits:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_r.fits
--------------------------------------------------------------------------------
/SConstruct:
--------------------------------------------------------------------------------
1 | # -*- python -*-
2 | from lsst.sconsUtils import scripts
# Python-only package: disableCc turns off the C/C++ build. noCfgFile
# presumably skips generating an sconsUtils .cfg file -- TODO confirm.
scripts.BasicSConstruct("multiprofit", disableCc=True, noCfgFile=True)
5 |
--------------------------------------------------------------------------------
/examples/222.51551376,0.09749601_300x300_mask_inv_highsn.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsst-dm/legacy-multiprofit/HEAD/examples/222.51551376,0.09749601_300x300_mask_inv_highsn.npz
--------------------------------------------------------------------------------
/.github/workflows/rebase_checker.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | name: Check that 'main' is not merged into the development branch
3 |
4 | on: pull_request
5 |
6 | jobs:
7 | call-workflow:
8 | uses: lsst/rubin_workflows/.github/workflows/rebase_checker.yaml@main
9 |
--------------------------------------------------------------------------------
/ups/multiprofit.table:
--------------------------------------------------------------------------------
1 | # List EUPS dependencies of this package here.
2 | # - Any package whose API is used directly should be listed explicitly.
3 | # - Common third-party packages can be assumed to be recursively included by
4 | # the "sconsUtils" package.
5 | setupRequired(sconsUtils)
6 | setupRequired(gauss2dfit)
7 |
8 | envPrepend(PYTHONPATH, ${PRODUCT_DIR}/python)
9 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | #################################
2 | multiprofit documentation preview
3 | #################################
4 |
5 | .. This page is for local development only. It isn't published to pipelines.lsst.io.
6 |
7 | .. Link the index pages of package and module documentation directories (listed in manifest.yaml).
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 |
12 | lsst.multiprofit/index
13 |
--------------------------------------------------------------------------------
/doc/manifest.yaml:
--------------------------------------------------------------------------------
1 | # Documentation manifest.
2 |
3 | # List of names of Python modules in this package.
4 | # For each module there is a corresponding module doc subdirectory.
5 | modules:
6 | - "lsst.multiprofit"
7 |
8 | # Name of the static content directories (subdirectories of `_static`).
9 | # Static content directories are usually named after the package.
10 | # Most packages do not need a static content directory (leave commented out).
11 | # statics:
12 | # - "_static/multiprofit"
13 |
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | """Sphinx configuration file for an LSST stack package.
2 |
3 | This configuration only affects single-package Sphinx documentation builds.
4 | For more information, see:
5 | https://developer.lsst.io/stack/building-single-package-docs.html
6 | """
7 |
8 | from documenteer.conf.pipelinespkg import * # noqa: F403, import *
9 |
# Package name, reused below for the theme logo text and the HTML titles.
project = "multiprofit"
# html_theme_options comes from the documenteer star import above.
html_theme_options["logotext"] = project  # noqa: F405, unknown name
html_title = project
html_short_title = project
14 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Build
2 | build/
3 | dist/
4 | multiprofit.egg-info/
5 | __pycache__/
6 | .coverage
7 | .sconsign.dblite
8 | config.log
9 | version.py
10 |
11 | # Generated by pip install -e .
12 | python/lsst_multiprofit.egg-info/
13 |
14 | # IDE folders and files
15 | .cache/
16 | .idea/
17 | .theia/
18 | .vscode/
19 | compile_commands.json
20 |
21 | # test outputs
22 | tests/.tests
23 |
24 | # pytest plugins
25 | .hypothesis/
26 | prof/
27 |
28 | # in-source notebooks
29 | examples/.ipynb_checkpoints
30 |
--------------------------------------------------------------------------------
/examples/test_gaussians.py:
--------------------------------------------------------------------------------
1 | from timeit import default_timer as timer
2 |
3 | from test_utils import gaussian_test
4 |
# Benchmark the Gaussian evaluation paths exposed by test_utils.gaussian_test
# and print one summary line per returned result.
start = timer()
results = gaussian_test(
    nbenchmark=100,
    do_like=True,
    do_residual=True,
    do_grad=True,
    do_jac=True,
    do_meas_modelfit=False,
    nsub=4,
)
for result in results:
    # Size, axis ratio and angle of the case, then its 'string' entry.
    print(
        f"re={result['reff']:.3f} q={result['axrat']:.2f}"
        f" ang={result['ang']:2.1f} {result['string']}"
    )
print(f"Test complete in {timer() - start:.2f}s")
18 |
--------------------------------------------------------------------------------
/examples/plot_sersic_mix.py:
--------------------------------------------------------------------------------
1 | import lsst.gauss2d.fit as g2f
2 | from lsst.multiprofit.plots import plot_sersicmix_interp
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | from scipy.interpolate import CubicSpline
6 |
# Interpolators to compare, each paired with what is presumably its matplotlib
# linestyle. The scipy entry is a (class, kwargs) tuple, not an instance.
interps = {
    "lin": (g2f.LinearSersicMixInterpolator(), "-"),
    "gsl-csp": (g2f.GSLSersicMixInterpolator(interp_type=g2f.InterpType.cspline), (0, (8, 8))),
    "scipy-csp": ((CubicSpline, {}), (0, (4, 4))),
}

# Sersic index ranges to plot, each sampled at 100 log-spaced points.
n_ser_ranges = ((0.5, 0.7), (0.8, 1.2), (2.2, 4.4))
for n_min, n_max in n_ser_ranges:
    log_min = np.log10(n_min)
    log_max = np.log10(n_max)
    n_ser = 10 ** np.linspace(log_min, log_max, 100)
    plot_sersicmix_interp(interps=interps, n_ser=n_ser, figsize=(10, 8))
    plt.tight_layout()
    plt.show()
18 |
--------------------------------------------------------------------------------
/examples/test_gaussian_gradients.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from test_utils import gradient_test
3 |
np.random.seed(1)
# Source (reff, axrat, angle) combinations to test.
params_source = [(4.457911011776755, 0.6437167462922668, 44.55485360075)]
# PSF (reff, axrat, angle) combinations to test against each source.
params_psf = [
    (0, 0, 0),
    (1.5121054822774742, 0.9135936343054303, 50.30562156585181),
    (3.7442185156735914, 0.8695066738347554, -39.40729158864958),
]
for reff, axrat, ang in params_source:
    for reff_psf, axrat_psf, ang_psf in params_psf:
        # Use the loop's source size and angle rather than fixed values, so
        # every configuration listed in params_source is actually exercised.
        # NOTE(review): axrat is not forwarded; confirm whether gradient_test
        # accepts an ``axrat`` argument and pass it too if so.
        grads, dlls, diffabs = gradient_test(
            xdim=23,
            ydim=19,
            reff=reff,
            angle=ang,
            reff_psf=reff_psf,
            axrat_psf=axrat_psf,
            angle_psf=ang_psf,
            printout=False,
            plot=False,
        )
        print("Gradient ", grads)
        print("Finite Diff. ", dlls)
        print("FinD. - Grad.", dlls - grads)
        print("FinD. / Grad.", dlls / grads)
        print("Jacobian sum abs. diff.", diffabs)
27 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 | - repo: https://github.com/pre-commit/pre-commit-hooks
6 | rev: v4.5.0
7 | hooks:
8 | - id: end-of-file-fixer
9 | - id: trailing-whitespace
10 | - repo: https://github.com/psf/black-pre-commit-mirror
11 | rev: 23.3.0
12 | hooks:
13 | - id: black
14 | # It is recommended to specify the latest version of Python
15 | # supported by your project here, or alternatively use
16 | # pre-commit's default_language_version, see
17 | # https://pre-commit.com/#top_level-default_language_version
18 | - repo: https://github.com/pycqa/isort
19 | rev: 5.13.2
20 | hooks:
21 | - id: isort
22 | name: isort (python)
23 | - repo: https://github.com/astral-sh/ruff-pre-commit
24 | # Ruff version.
25 | rev: v0.1.8
26 | hooks:
27 | - id: ruff
28 | - repo: https://github.com/numpy/numpydoc
29 | rev: "v1.6.0"
30 | hooks:
31 | - id: numpydoc-validation
32 |
--------------------------------------------------------------------------------
/examples/test_mgsersic.py:
--------------------------------------------------------------------------------
1 | from timeit import default_timer as timer
2 |
3 | import numpy as np
4 | from test_utils import mgsersic_test
5 |
start = timer()
reffs = [2.0, 5.0]
angs = np.linspace(0, 90, 7)
axrats = [1, 0.5, 0.2, 0.1, 0.01]
nsers = [0.5, 1.0, 2.0, 4.0, 6.0]

# Plot one representative case before running the full (unplotted) grid.
mgsersic_test(reff=reffs[1], nser=2, axrat=axrats[1], angle=angs[2], do_galsim=True, plot=True)

for nser in nsers:
    for reff in reffs:
        for axrat in axrats:
            for ang in angs:
                diffs = mgsersic_test(
                    reff=reff,
                    nser=nser,
                    axrat=axrat,
                    angle=ang,
                    do_galsim=True,
                    plot=False,
                    flux=1.0,
                )
                diff_abs = np.sum(np.abs(diffs["gs_ser_nopix"]))
                print(
                    f"Test: nser={nser} reff={reff} axrat={axrat} ang={ang}"
                    f" sum(abs(diff))={diff_abs:.3f}%"
                )
# Report the total runtime; ``start`` was otherwise recorded but never used.
print(f"Test complete in {timer() - start:.2f}s")
32 |
--------------------------------------------------------------------------------
/tests/test_psfmodel_utils.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | from lsst.multiprofit.psfmodel_utils import make_psf_source
24 |
25 |
def test_make_psf_source():
    """Check that a three-element input yields a source with three Gaussians."""
    values = [2, 3, 4]
    gaussians = make_psf_source(values).gaussians(g2f.Channel.NONE)
    assert gaussians.size == len(values)
29 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/limits.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from lsst.gauss2d.fit import LimitsD as Limits
23 |
24 | __all__ = ["limits_ref"]
25 |
26 |
# TODO: Replace with a parameter factory and/or profile factory
# Reference Limits keyed by short parameter-type names. The "none" entry is a
# default-constructed Limits (presumably unbounded -- confirm in gauss2d.fit).
limits_ref = dict(
    none=Limits(),
    axrat=Limits(min=1e-2, max=1),
    con=Limits(min=1, max=10),
    fluxfrac=Limits(min=0.001, max=0.999),
    n_ser=Limits(min=0.3, max=6.0),
    n_ser_multigauss=Limits(min=0.5, max=6.0),
    rho=Limits(min=-0.999, max=0.999),
)
37 |
--------------------------------------------------------------------------------
/tests/test_fit_psf.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from lsst.multiprofit.fit_psf import CatalogPsfFitterConfig, CatalogPsfFitterConfigData
23 | import pytest
24 |
25 |
@pytest.fixture(scope="module")
def fitter_config() -> CatalogPsfFitterConfig:
    """Return a default-constructed PSF fitter config shared by this module."""
    return CatalogPsfFitterConfig()


@pytest.fixture(scope="module")
def fitter_config_data(fitter_config) -> CatalogPsfFitterConfigData:
    """Return config data wrapping the module-scoped default fitter config."""
    return CatalogPsfFitterConfigData(config=fitter_config)


def test_fitter_config_data(fitter_config_data):
    """Check that the parameters and PSF model are accessible from config data."""
    parameters = fitter_config_data.parameters
    psf_model = fitter_config_data.psf_model
    # Assert on the results so the test states an expectation instead of
    # silently discarding them.
    assert parameters is not None
    assert psf_model is not None
40 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/errors.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from abc import abstractmethod
23 |
__all__ = ["CatalogError", "NoDataError", "PsfRebuildFitFlagError"]


class CatalogError(RuntimeError):
    """RuntimeError that can be caught and flagged in a column."""

    # abstractmethod documents the contract for subclasses; it is not enforced
    # because this class's metaclass is ``type`` rather than ``abc.ABCMeta``.
    @classmethod
    @abstractmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""


class NoDataError(CatalogError):
    """RuntimeError for when there is no data to fit."""

    @classmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""
        return "no_data_flag"


class PsfRebuildFitFlagError(CatalogError):
    """RuntimeError for when a PSF can't be rebuilt because the fit failed."""

    # Derives from CatalogError (still a RuntimeError) because, like the other
    # catalog errors, it provides a column_name for flagging.
    @classmethod
    def column_name(cls) -> str:
        """Return the standard column name for this error."""
        return "psf_fit_flag"
51 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | This is the original repository for multiprofit and is now superseded by
2 | `multiprofit <https://github.com/lsst/multiprofit>`_ in the lsst organization.
3 |
4 | MultiProFit
5 | ###########
6 |
7 | .. todo image:: https://travis-ci.org/ICRAR/multiprofit.svg?branch=master
8 | .. todo :target: https://travis-ci.org/lsst-dm/multiprofit
9 |
10 | .. todo image:: https://img.shields.io/pypi/v/multiprofit.svg
11 | .. todo :target: https://pypi.python.org/pypi/multiprofit
12 |
13 | .. todo image:: https://img.shields.io/pypi/pyversions/multiprofit.svg
14 | .. todo :target: https://pypi.python.org/pypi/multiprofit
15 |
16 | *multiprofit* is a Python astronomical source modelling code, inspired by `ProFit <https://github.com/ICRAR/ProFit>`_, but made for LSST Data Management. MultiProFit means Multiple Profile Fitting. The
18 | multi- aspect can be multi-object, multi-component, multi-band, multi-instrument, and someday multi-epoch.
19 |
20 | *multiprofit* can fit any kind of imaging data while modelling sources as Gaussian mixtures - including
21 | approximations to Sersic profiles - using a Gaussian pixel-convolved point spread function. It can also use
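22 | `GalSim <https://github.com/GalSim-developers/GalSim>`_ or `libprofit <https://github.com/ICRAR/libprofit>`_
23 | via `pyprofit <https://github.com/ICRAR/pyprofit>`_ to generate true Sersic and/or other supported
24 | models convolved with arbitrary PSF images or models.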
25 |
26 | *multiprofit* has support for multi-object fitting and experimental support for multi-band fitting, albeit
27 | currently limited to pixel-matched images of identical dimensions. Unlike ProFit, Bayesian MCMC is not
28 | available (yet).
29 |
30 | *multiprofit* requires Python 3, along with `pybind11 <https://github.com/pybind/pybind11>`_ for C++ bindings,
31 | and `gauss2d <https://github.com/lsst/gauss2d>`_ for evaluating Gaussian mixtures. It can be installed
32 | with pip from the root of a source checkout::
33 |
34 |     pip install --user .
35 |
36 | .. todo *multiprofit* is available in `PyPI `_
37 | .. and thus can be easily installed via::
38 |
39 | .. pip install multiprofit
40 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/config.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from typing import Any
23 |
24 | import lsst.pex.config as pexConfig
25 |
26 | __all__ = ["set_config_from_dict"]
27 |
28 |
def set_config_from_dict(
    config: pexConfig.Config | pexConfig.dictField.Dict | pexConfig.configDictField.ConfigDict | dict,
    overrides: dict[str, Any],
):
    """Set `lsst.pex.config` params from a dict.

    Parameters
    ----------
    config
        A config, dictField or configDictField object.
    overrides
        A dict of key-value pairs to override in the config.

    Notes
    -----
    Dict-like configs (anything with ``__getitem__``) are made to match
    ``overrides`` exactly: keys absent from ``overrides`` are deleted first.
    Nested dict values are applied recursively to the corresponding item or
    attribute. Failures to set individual non-dict values are logged as
    warnings and skipped, so one bad key does not abort the remaining ones.
    """
    import logging

    is_config_dict = hasattr(config, "__getitem__")
    if is_config_dict:
        # Snapshot the keys first so deleting entries does not disturb the
        # iteration.
        for key in tuple(config.keys()):
            if key not in overrides:
                del config[key]
    for key, value in overrides.items():
        if isinstance(value, dict):
            attr = config[key] if is_config_dict else getattr(config, key)
            set_config_from_dict(attr, value)
        else:
            try:
                if is_config_dict:
                    config[key] = value
                else:
                    setattr(config, key, value)
            except Exception as e:
                # Best-effort: report which key failed and why, rather than
                # printing a bare message to stdout, then continue.
                logging.getLogger(__name__).warning(
                    "Failed to set config key %r to %r: %s", key, value, e
                )
60 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/model_utils.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from typing import Any
23 |
24 | import lsst.gauss2d as g2
25 | import lsst.gauss2d.fit as g2f
26 |
27 | __all__ = ["make_image_gaussians", "make_psf_model_null"]
28 |
29 |
def make_image_gaussians(
    gaussians_source: g2.Gaussians,
    gaussians_kernel: g2.Gaussians | None = None,
    **kwargs: Any,
) -> g2.ImageD:
    """Render an image of source Gaussians convolved with kernel Gaussians.

    Parameters
    ----------
    gaussians_source
        Gaussians representing components of sources.
    gaussians_kernel
        Gaussians representing the smoothing kernel. If None, a single
        default-constructed gauss2d Gaussian is used.
    **kwargs
        Additional keyword arguments to pass to gauss2d.make_gaussians_pixel_D
        (i.e. image size, etc.).

    Returns
    -------
    image
        The rendered image of the given Gaussians.
    """
    kernels = g2.Gaussians([g2.Gaussian()]) if gaussians_kernel is None else gaussians_kernel
    # Pair every source Gaussian with every kernel Gaussian, source-major.
    convolved = []
    for source in gaussians_source:
        for kernel in kernels:
            convolved.append(g2.ConvolvedGaussian(source=source, kernel=kernel))
    return g2.make_gaussians_pixel_D(gaussians=g2.ConvolvedGaussians(convolved), **kwargs)
61 |
62 |
def make_psf_model_null() -> g2f.PsfModel:
    """Return a default (null) PSF model.

    Returns
    -------
    model
        A PSF model made of one normalized Gaussian of zero size.
    """
    # One zero-size default Gaussian; the True flag presumably requests the
    # normalization described above -- confirm in gauss2d.fit.
    gaussians = g2f.GaussianComponent.make_uniq_default_gaussians([0], True)
    return g2f.PsfModel(gaussians)
73 |
--------------------------------------------------------------------------------
/tests/test_observationconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 | import lsst.gauss2d as g2
22 | import lsst.gauss2d.fit as g2f
23 | from lsst.multiprofit.observationconfig import CoordinateSystemConfig, ObservationConfig
24 | import pytest
25 |
26 |
@pytest.fixture(scope="module")
def kwargs_coordsys():
    """Return keyword arguments for a non-default CoordinateSystemConfig."""
    return dict(dx1=0.4, dy2=1.6, x_min=-51.3, y_min=1684.5)


@pytest.fixture(scope="module")
def config_coordsys(kwargs_coordsys) -> CoordinateSystemConfig:
    """Return a CoordinateSystemConfig built from ``kwargs_coordsys``."""
    return CoordinateSystemConfig(**kwargs_coordsys)


@pytest.fixture(scope="module")
def coordsys(config_coordsys) -> g2.CoordinateSystem:
    """Return the coordinate system made by ``config_coordsys``."""
    return config_coordsys.make_coordinate_system()


def test_CoordinateSystemConfig(kwargs_coordsys, coordsys):
    """Check that each config value is reflected on the coordinate system."""
    for name in kwargs_coordsys:
        expected = kwargs_coordsys[name]
        assert getattr(coordsys, name) == expected


def test_ObservationConfig():
    """Check observation channels, plane shapes and plane independence."""
    n_cols, n_rows = 15, 17
    # Planes are expected to be shaped (rows, columns).
    shape = [n_rows, n_cols]
    config = ObservationConfig(n_cols=n_cols, n_rows=n_rows)
    observation = config.make_observation()
    assert observation.channel == g2f.Channel.NONE

    planes = ("image", "mask_inv", "sigma_inv")
    for plane in planes:
        assert getattr(observation, plane).shape == shape

    # Setting a band should switch the channel of newly-made observations.
    config.band = "red"
    observation_red = config.make_observation()
    assert observation_red.channel == g2f.Channel.get("red")

    for plane in planes:
        image = getattr(observation, plane)
        image_red = getattr(observation_red, plane)
        # Each observation must own distinct plane objects.
        assert image is not image_red
        # Zero both first, presumably because new planes are uninitialized,
        # so the equality comparison is meaningful.
        image.fill(0)
        image_red.fill(0)
        assert image == image_red
72 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/utils.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from typing import Any
23 |
24 | import lsst.gauss2d.fit as g2f
25 | import numpy
26 | import numpy as np
27 |
28 | __all__ = ["ArbitraryAllowedConfig", "FrozenArbitraryAllowedConfig", "get_params_uniq", "normalize"]
29 |
30 |
31 | class ArbitraryAllowedConfig:
32 | """Pydantic config to allow arbitrary typed Fields."""
33 |
34 | arbitrary_types_allowed = True
35 | # Disallow any extra kwargs
36 | extra = "forbid"
37 |
38 |
class FrozenArbitraryAllowedConfig(ArbitraryAllowedConfig):
    """Pydantic config permitting arbitrary field types, for frozen classes.

    This adds no settings to its parent. Immutability is not configured
    here and is presumably requested separately, e.g. with ``frozen=True``
    on the decorated class.
    """
41 |
42 |
def get_params_uniq(parametric: g2f.Parametric, **kwargs: Any):
    """Return the unique parameters of an object that pass a filter.

    Parameters
    ----------
    parametric
        The parametric object whose parameters are retrieved.
    **kwargs
        Keyword arguments forwarded to `g2f.ParamFilter`.

    Returns
    -------
    params
        The de-duplicated matching parameters, as returned by
        `g2f.params_unique`.
    """
    param_filter = g2f.ParamFilter(**kwargs)
    # The result should always match the order-preserving de-duplication
    # list({p: None for p in params}.keys())
    return g2f.params_unique(parametric.parameters(paramfilter=param_filter))
62 |
63 |
def normalize(ndarray: numpy.ndarray, return_sum: bool = False):
    """Normalize a numpy array in place by dividing it by its sum.

    Parameters
    ----------
    ndarray
        The array to normalize. It is modified in place.
    return_sum
        Whether to also return the sum computed before normalizing.

    Returns
    -------
    ndarray
        The input array object itself, after normalization.
    total
        The sum of the array before normalization. This is only returned
        if ``return_sum`` is True.
    """
    total = np.sum(ndarray)
    # In-place division: the caller's array is normalized too
    ndarray /= total
    return (ndarray, total) if return_sum else ndarray
86 |
--------------------------------------------------------------------------------
/doc/lsst.multiprofit/index.rst:
--------------------------------------------------------------------------------
1 | .. py:currentmodule:: lsst.multiprofit
2 |
3 | .. _lsst.multiprofit:
4 |
5 | #######################
6 | lsst.multiprofit
7 | #######################
8 |
9 | MultiProFit is a Python astronomical source modelling code. See the README for more information about the package.
10 |
11 | .. .. _lsst.multiprofit-using:
12 |
13 | .. Using lsst.multiprofit
14 | .. =============================
15 |
16 | .. toctree linking to topics related to using the module's APIs.
17 |
18 | .. .. toctree::
19 | .. :maxdepth: 1
20 |
21 | .. _lsst.multiprofit-contributing:
22 |
23 | Contributing
24 | ============
25 |
26 | ``lsst.multiprofit`` is developed at https://github.com/lsst-dm/multiprofit.
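27 | You can find Jira issues for this module through `search <https://rubinobs.atlassian.net/issues/?jql=project%20%3D%20DM%20AND%20text%20~%20%22multiprofit%22>`_.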
28 |
29 | .. If there are topics related to developing this module (rather than using it), link to this from a toctree placed here.
30 |
31 | .. .. toctree::
32 | .. :maxdepth: 1
33 |
34 | .. _lsst.multiprofit-command-line-taskref:
35 |
36 | Task reference
37 | ==============
38 |
39 | ``lsst.multiprofit`` tasks are implemented in `meas_extensions_multiprofit <https://github.com/lsst-dm/meas_extensions_multiprofit>`_.
40 |
41 | Configurations
42 | --------------
43 |
44 | .. lsst-configs::
45 | :root: lsst.multiprofit
46 | :toctree: configs
47 |
48 | .. .. _lsst.multiprofit-scripts:
49 |
50 | .. Script reference
51 | .. ================
52 |
53 | .. .. TODO: Add an item to this toctree for each script reference topic in the scripts subdirectory.
54 |
55 | .. toctree::
56 | :maxdepth: 1
57 |
58 | .. _lsst.multiprofit-pyapi:
59 |
60 | Python API reference
61 | ====================
62 |
63 | .. automodapi:: lsst.multiprofit.asinhstretchsigned
64 | :no-inheritance-diagram:
65 |
66 | .. automodapi:: lsst.multiprofit.componentconfig
67 | :no-inheritance-diagram:
68 |
69 | .. automodapi:: lsst.multiprofit.config
70 | :no-inheritance-diagram:
71 |
72 | .. automodapi:: lsst.multiprofit.fit_bootstrap_model
73 | :no-inheritance-diagram:
74 |
75 | .. automodapi:: lsst.multiprofit.fit_catalog
76 | :no-inheritance-diagram:
77 |
78 | .. automodapi:: lsst.multiprofit.fit_source
79 | :no-inheritance-diagram:
80 |
81 | .. automodapi:: lsst.multiprofit.limits
82 | :no-inheritance-diagram:
83 |
84 | .. automodapi:: lsst.multiprofit.modeller
85 | :no-inheritance-diagram:
86 |
87 | .. automodapi:: lsst.multiprofit.plots
88 | :no-inheritance-diagram:
89 |
90 | .. automodapi:: lsst.multiprofit.priors
91 | :no-inheritance-diagram:
92 |
93 | .. automodapi:: lsst.multiprofit.psfmodel_utils
94 | :no-inheritance-diagram:
95 |
96 | .. automodapi:: lsst.multiprofit.transforms
97 | :no-inheritance-diagram:
98 |
99 | .. automodapi:: lsst.multiprofit.utils
100 | :no-inheritance-diagram:
101 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "lsst-versions >= 1.3.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "lsst-multiprofit"
7 | description = "Astronomical image and source model fitting code."
8 | license = {file = "LICENSE"}
9 | readme = "README.rst"
10 | authors = [
11 | {name="Rubin Observatory Data Management", email="dm-admin@lists.lsst.org"},
12 | ]
13 | classifiers = [
14 | "Intended Audience :: Science/Research",
15 | "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
16 | "Operating System :: OS Independent",
17 | "Programming Language :: Python :: 3",
18 | "Programming Language :: Python :: 3.10",
19 | "Programming Language :: Python :: 3.11",
20 | "Topic :: Scientific/Engineering :: Astronomy",
21 | ]
22 | keywords = [
23 | "astronomy",
24 | "astrophysics",
25 | "fitting",
26 | "lsst",
27 | "models",
28 | "modeling",
29 | ]
30 | requires-python = ">=3.10.0"
31 | dependencies = [
32 | "astropy",
33 | "lsst-gauss2d",
34 | "lsst-gauss2d-fit",
35 | "lsst-pex-config",
36 | "lsst-utils",
37 | "importlib_resources",
38 | "matplotlib",
39 | "numpy",
40 | "pydantic",
41 | "scipy",
42 | ]
43 | dynamic = ["version"]
44 |
45 | [project.urls]
46 | "Homepage" = "https://github.com/lsst-dm/multiprofit"
47 |
48 | [project.optional-dependencies]
49 | galsim = ["galsim"]
50 | test = [
51 | "pytest",
52 | ]
53 |
54 | [tool.setuptools.packages.find]
55 | where = ["python"]
56 |
57 | [tool.setuptools.dynamic]
58 | version = { attr = "lsst_versions.get_lsst_version" }
59 |
60 | [tool.black]
61 | line-length = 110
62 | target-version = ["py311"]
63 | force-exclude = [
64 | "examples/fithsc.py",
65 | ]
66 |
67 | [tool.isort]
68 | profile = "black"
69 | line_length = 110
70 | force_sort_within_sections = true
71 |
72 | [tool.ruff]
73 | exclude = [
74 | "__init__.py",
75 | "examples/fithsc.py",
76 | "examples/test_utils.py",
77 | "tests/*.py",
78 | ]
79 | ignore = [
80 | "N802",
81 | "N803",
82 | "N806",
83 | "N812",
84 | "N815",
85 | "N816",
86 | "N999",
87 | "D107",
88 | "D105",
89 | "D102",
90 | "D104",
91 | "D100",
92 | "D200",
93 | "D205",
94 | "D400",
95 | ]
96 | line-length = 110
97 | select = [
98 | "E", # pycodestyle
99 | "F", # pyflakes
100 | "N", # pep8-naming
101 | "W", # pycodestyle
102 | "D", # pydocstyle
103 | ]
104 | target-version = "py311"
105 |
106 | [tool.ruff.pycodestyle]
107 | max-doc-length = 79
108 |
109 | [tool.ruff.pydocstyle]
110 | convention = "numpy"
111 |
112 | [tool.numpydoc_validation]
113 | checks = [
114 | "all", # All except the rules listed below.
115 | "ES01", # No extended summary required.
116 | "EX01", # Example section.
117 | "GL01", # Summary text can start on same line as """
118 | "GL08", # Do not require docstring.
119 | "PR04", # numpydoc parameter types are redundant with type hints
120 | "RT01", # Unfortunately our @property trigger this.
121 | "RT02", # Does not want named return value. DM style says we do.
122 | "SA01", # See Also section.
123 | "SS05", # pydocstyle is better at finding infinitive verb.
124 | "SS06", # Summary can go into second line.
125 | ]
126 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/observationconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d as g2
23 | import lsst.gauss2d.fit as g2f
24 | import lsst.pex.config as pexConfig
25 |
26 |
class CoordinateSystemConfig(pexConfig.Config):
    """Configuration for a gauss2d CoordinateSystem."""

    dx1 = pexConfig.Field[float](
        default=1.0,
        doc="The x-axis pixel scale",
        optional=False,
    )
    dy2 = pexConfig.Field[float](
        default=1.0,
        doc="The y-axis pixel scale",
        optional=False,
    )
    x_min = pexConfig.Field[float](
        default=0.0,
        doc="The x-axis coordinate of the bottom left corner",
        optional=False,
    )
    y_min = pexConfig.Field[float](
        default=0.0,
        doc="The y-axis coordinate of the bottom left corner",
        optional=False,
    )

    def make_coordinate_system(self) -> g2.CoordinateSystem:
        """Construct a coordinate system from this config.

        Returns
        -------
        coordsys
            A `g2.CoordinateSystem` built from this config's ``dx1``,
            ``dy2``, ``x_min`` and ``y_min`` values.
        """
        kwargs = {name: getattr(self, name) for name in ("dx1", "dy2", "x_min", "y_min")}
        return g2.CoordinateSystem(**kwargs)
45 |
46 |
class ObservationConfig(pexConfig.Config):
    """Configuration for a gauss2d.fit Observation."""

    band = pexConfig.Field[str](default="None", doc="The name of the band", optional=False)
    coordsys = pexConfig.ConfigField[CoordinateSystemConfig](doc="The coordinate system config")
    n_rows = pexConfig.Field[int](doc="The number of rows in the image")
    n_cols = pexConfig.Field[int](doc="The number of columns in the image")

    def make_observation(self) -> g2f.ObservationD:
        """Make an observation with newly-constructed images.

        Returns
        -------
        observation
            An observation in the channel for ``band``. Its image, sigma_inv
            and mask_inv images each have ``n_rows`` x ``n_cols`` pixels.
            Pixel values are not explicitly set here; callers should
            initialize them.
        """
        coordsys = None
        if self.coordsys:
            coordsys = self.coordsys.make_coordinate_system()
        image_kwargs = dict(n_rows=self.n_rows, n_cols=self.n_cols, coordsys=coordsys)
        return g2f.ObservationD(
            image=g2.ImageD(**image_kwargs),
            sigma_inv=g2.ImageD(**image_kwargs),
            mask_inv=g2.ImageB(**image_kwargs),
            channel=g2f.Channel.get(self.band),
        )
67 |
68 |
class PsfObservationConfig(ObservationConfig):
    """Configuration for a gauss2d.fit Observation used for PSF fitting.

    The ``band`` field must keep its "None" string value.
    """

    def validate(self):
        """Validate the config, additionally requiring an unset band.

        Raises
        ------
        ValueError
            Raised if ``band`` is not the string "None".
        """
        super().validate()
        band_is_unset = self.band == "None"
        if not band_is_unset:
            raise ValueError("band must be None for PSF fitting")
76 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/priors.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | import lsst.pex.config as pexConfig
24 | import numpy as np
25 |
26 | from .transforms import transforms_ref
27 |
__all__ = [
    "ShapePriorConfig",
    "get_hst_size_prior",
]
29 |
30 |
class ShapePriorConfig(pexConfig.Config):
    """Configuration for a shape prior.

    The axis ratio and size priors are each Gaussian. Each one is only
    enabled if its standard deviation is positive and finite; the means
    do not affect whether a prior is used.
    """

    prior_axrat_mean = pexConfig.Field[float](
        default=0.7,
        doc="Prior mean on axis ratio",
    )
    prior_axrat_stddev = pexConfig.Field[float](
        default=0,
        doc="Prior std. dev. on axis ratio (prior ignored if not >0)",
    )
    prior_size_mean = pexConfig.Field[float](
        default=1,
        doc="Prior mean on size_major",
    )
    prior_size_stddev = pexConfig.Field[float](
        default=0,
        doc="Prior std. dev. on size_major (prior ignored if not >0)",
    )

    def get_shape_prior(self, ellipse: g2f.ParametricEllipse) -> g2f.ShapePrior | None:
        """Make a shape prior for an ellipse, if any prior is enabled.

        Parameters
        ----------
        ellipse
            The ellipse to apply the prior to.

        Returns
        -------
        prior
            A shape prior with a Gaussian size and/or axis ratio term, or
            None if neither prior is enabled. A disabled term is passed to
            `g2f.ShapePrior` as None.
        """
        use_prior_axrat = (self.prior_axrat_stddev > 0) and np.isfinite(self.prior_axrat_stddev)
        use_prior_size = (self.prior_size_stddev > 0) and np.isfinite(self.prior_size_stddev)

        if not (use_prior_axrat or use_prior_size):
            return None

        # The size mean parameter uses the "log10" transform
        prior_size = (
            g2f.ParametricGaussian1D(
                g2f.MeanParameterD(self.prior_size_mean, transform=transforms_ref["log10"]),
                g2f.StdDevParameterD(self.prior_size_stddev),
            )
            if use_prior_size
            else None
        )
        # The axis ratio mean parameter uses the "logit_axrat_prior" transform
        prior_axrat = (
            g2f.ParametricGaussian1D(
                g2f.MeanParameterD(self.prior_axrat_mean, transform=transforms_ref["logit_axrat_prior"]),
                g2f.StdDevParameterD(self.prior_axrat_stddev),
            )
            if use_prior_axrat
            else None
        )
        return g2f.ShapePrior(ellipse, prior_size, prior_axrat)
74 |
75 |
def get_hst_size_prior(mag_psf_i):
    """Return the mean and stddev for an HST-based size prior.

    The size is the major axis half-light radius.

    Parameters
    ----------
    mag_psf_i
        The i-band PSF magnitude(s) of the source(s). Values are clipped
        to the range [10, 30] before use.

    Returns
    -------
    mean
        The prior mean(s), with the same shape as ``mag_psf_i``.
    stddev
        The prior standard deviation, which is a constant 0.2.

    Notes
    -----
    Return values are log10 scaled in units of arcseconds.
    The input should be a PSF mag because other magnitudes - even Gaussian -
    can be unreliable for low S/N (non-)detections.
    """
    mag_clipped = np.clip(mag_psf_i, 10, 30)
    mean = 0.75 * (19 - mag_clipped) / 6.5
    stddev = 0.2
    return mean, stddev
93 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/modelconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import string
23 | from typing import Iterable
24 |
25 | import lsst.gauss2d.fit as g2f
26 | import lsst.pex.config as pexConfig
27 |
28 | from .componentconfig import Fluxes
29 | from .sourceconfig import SourceConfig
30 |
31 |
class ModelConfig(pexConfig.Config):
    """Configuration for a gauss2d.fit Model."""

    sources = pexConfig.ConfigDictField[str, SourceConfig](doc="The configuration for sources")

    @staticmethod
    def format_label(label: str, name_source: str) -> str:
        """Substitute a source name into a label template.

        Parameters
        ----------
        label
            The label, as a `string.Template`. Any ``$name_source`` or
            ``${name_source}`` placeholders are replaced.
        name_source
            The name of the source to substitute.

        Returns
        -------
        label_formatted
            The formatted label. Any other placeholders are left as-is.
        """
        return string.Template(label).safe_substitute(name_source=name_source)

    def get_integral_label_default(self, sourceconfig: SourceConfig) -> str:
        """Get the default integral label template for a source.

        Parameters
        ----------
        sourceconfig
            The configuration of the source.

        Returns
        -------
        label
            A label template. If `has_prefix_source` is True, it is prefixed
            with a source name placeholder that `format_label` fills in.
        """
        # format_label uses string.Template, which only substitutes
        # $-prefixed placeholders; a bare {name_source} would be left in
        # the label as literal text.
        prefix = "src: ${name_source} " if self.has_prefix_source() else ""
        return f"{prefix}{sourceconfig.get_integral_label_default()}"

    def has_prefix_source(self) -> bool:
        """Return whether labels should be prefixed with the source name.

        Returns
        -------
        has_prefix
            True if there are multiple sources, or if there is exactly one
            source and its name is non-empty. False otherwise, including
            when there are no sources.
        """
        if len(self.sources) > 1:
            return True
        # Default to "" so that a config with no sources yields False
        # instead of raising StopIteration
        return bool(next(iter(self.sources.keys()), ""))

    def make_sources(
        self,
        component_group_fluxes_srcs: Iterable[list[list[Fluxes]]],
        label_integral: str | None = None,
    ) -> tuple[list[g2f.Source], list[g2f.Prior]]:
        """Make sources and their priors from this config.

        Parameters
        ----------
        component_group_fluxes_srcs
            Component group fluxes for each source, in the same order as
            ``sources``. Must support ``len`` and have one entry per source.
        label_integral
            A `string.Template` label for integral parameters. Source name
            placeholders are filled in by `format_label`. If None, the
            default from `get_integral_label_default` is used for each
            source.

        Returns
        -------
        sources
            The sources, in the same order as ``sources``.
        priors
            The priors of all of the sources, concatenated.

        Raises
        ------
        ValueError
            Raised if ``component_group_fluxes_srcs`` is None or does not
            have exactly one entry per source.
        """
        n_src = len(self.sources)
        # Avoid calling len(None) while building the error message
        n_fluxes = None if component_group_fluxes_srcs is None else len(component_group_fluxes_srcs)
        if n_fluxes != n_src:
            raise ValueError(f"len(component_group_fluxes_srcs)={n_fluxes} != {n_src=}")

        sources = []
        priors = []
        for component_group_fluxes, (name_src, config_src) in zip(
            component_group_fluxes_srcs, self.sources.items()
        ):
            if label_integral is None:
                label_integral_src = self.get_integral_label_default(config_src)
            else:
                label_integral_src = label_integral

            source, priors_src = config_src.make_source(
                component_group_fluxes=component_group_fluxes,
                label_integral=self.format_label(label=label_integral_src, name_source=name_src),
            )
            sources.append(source)
            priors.extend(priors_src)

        return sources, priors

    def make_model(
        self,
        component_group_fluxes_srcs: Iterable[list[list[Fluxes]]],
        data: g2f.DataD,
        psf_models: list[g2f.PsfModel],
        label_integral: str | None = None,
    ) -> g2f.ModelD:
        """Make a model from this config.

        Parameters
        ----------
        component_group_fluxes_srcs
            Component group fluxes for each source; see `make_sources`.
        data
            The data to model.
        psf_models
            The PSF models, passed to `g2f.ModelD` as ``psfmodels``.
        label_integral
            The integral parameter label template; see `make_sources`.

        Returns
        -------
        model
            The model, including the sources' priors.
        """
        sources, priors = self.make_sources(
            component_group_fluxes_srcs=component_group_fluxes_srcs,
            label_integral=label_integral,
        )

        model = g2f.ModelD(data=data, psfmodels=psf_models, sources=sources, priors=priors)

        return model
89 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/asinhstretchsigned.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 |
23 | import astropy.visualization as apvis
24 | import numpy as np
25 |
26 | # This is all hacked from astropy's AsinhStretch
27 |
__all__ = [
    "AsinhStretchSigned",
    "SinhStretchSigned",
]
29 |
30 |
31 | def _prepare(values: np.ndarray, clip: bool = True, out: np.ndarray | None = None):
32 | """Return clipped and/or copied values from input.
33 |
34 | Parameters
35 | ----------
36 | values
37 | The values to copy/clip from.
38 | clip
39 | Whether to clip values to between 0 and 1 (inclusive).
40 | out
41 | An existing array to assign to.
42 |
43 | Returns
44 | -------
45 | prepared
46 | The prepared values.
47 | """
48 | if clip:
49 | return np.clip(values, 0.0, 1.0, out=out)
50 | else:
51 | if out is None:
52 | return np.array(values, copy=True)
53 | else:
54 | out[:] = np.asarray(values)
55 | return out
56 |
57 |
class AsinhStretchSigned(apvis.BaseStretch):
    r"""
    A signed asinh stretch, symmetric about 0.5.

    The stretch is given by:

    .. math::
        y = 0.5\left(1 + \frac{{\rm asinh}(2(x - 0.5) / a)}{{\rm asinh}(1 / a)}\right).

    Because asinh is an odd function, this equals stretching the absolute
    distance from 0.5 and then restoring its sign, which is how it is
    computed.

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula. The value of
        this parameter is where the asinh curve transitions from linear
        to logarithmic behavior, expressed as a fraction of the
        normalized image. Should be in the range between 0 and 1 (this
        is not validated). Default is 0.1.
    """  # noqa: W505

    def __init__(self, a=0.1):
        super().__init__()
        self.a = a

    def __call__(self, values, clip=True, out=None):
        """Apply the stretch to values.

        Parameters
        ----------
        values : array-like
            The input values, normally in the range [0, 1].
        clip : bool, optional
            Whether to clip ``values`` to [0, 1] before stretching.
        out : `~numpy.ndarray`, optional
            An array to write the result into. If None, a new array is
            created.

        Returns
        -------
        result : `~numpy.ndarray`
            The stretched values.
        """
        values = _prepare(values, clip=clip, out=out)
        # Shift and scale so that 0.5 maps to 0 and [0, 1] to [-1, 1]
        values *= 2
        values -= 1
        signs = np.sign(values)
        np.abs(values, out=values)
        np.true_divide(values, self.a, out=values)
        np.arcsinh(values, out=values)
        np.true_divide(values, np.arcsinh(1.0 / self.a), out=values)
        # Restore the sign and map [-1, 1] back to [0, 1]
        np.true_divide(1.0 + signs * values, 2.0, out=values)
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation.

        Returns
        -------
        inverse
            The inverse stretch, a `SinhStretchSigned` with
            ``a = 1 / asinh(1 / self.a)``.
        """
        return SinhStretchSigned(a=1.0 / np.arcsinh(1.0 / self.a))
104 |
105 |
class SinhStretchSigned(apvis.BaseStretch):
    r"""
    A signed sinh stretch, symmetric about 0.5.

    The stretch is given by:

    .. math::
        y = 0.5\left(1 + \frac{{\rm sinh}(2(x - 0.5) / a)}{{\rm sinh}(1 / a)}\right)

    Parameters
    ----------
    a : float, optional
        The ``a`` parameter used in the above formula. Default is 1/3.
    """

    def __init__(self, a=1.0 / 3.0):
        super().__init__()
        self.a = a

    def __call__(self, values, clip=True, out=None):
        """Apply the stretch to values.

        Parameters
        ----------
        values : array-like
            The input values, normally in the range [0, 1].
        clip : bool, optional
            Whether to clip ``values`` to [0, 1] before stretching.
        out : `~numpy.ndarray`, optional
            An array to write the result into. If None, a new array is
            created.

        Returns
        -------
        result : `~numpy.ndarray`
            The stretched values.
        """
        values = _prepare(values, clip=clip, out=out)
        # Shift and scale so that 0.5 maps to 0 and [0, 1] to [-1, 1]
        values *= 2.0
        values -= 1.0
        # sinh is odd, so no explicit sign handling is needed
        np.true_divide(values, self.a, out=values)
        np.sinh(values, out=values)
        np.true_divide(values, np.sinh(1.0 / self.a), out=values)
        # Map [-1, 1] back to [0, 1]
        values += 1.0
        values /= 2.0
        return values

    @property
    def inverse(self):
        """A stretch object that performs the inverse operation.

        Returns
        -------
        inverse
            The inverse stretch, an `AsinhStretchSigned` with
            ``a = 1 / sinh(1 / self.a)``.
        """
        return AsinhStretchSigned(a=1.0 / np.sinh(1.0 / self.a))
148 |
--------------------------------------------------------------------------------
/tests/test_plots.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | from lsst.multiprofit.componentconfig import CentroidConfig, GaussianComponentConfig, ParameterConfig
24 | from lsst.multiprofit.model_utils import make_psf_model_null
25 | from lsst.multiprofit.modelconfig import ModelConfig
26 | from lsst.multiprofit.observationconfig import CoordinateSystemConfig, ObservationConfig
27 | from lsst.multiprofit.plots import abs_mag_sol_lsst, bands_weights_lsst, plot_model_rgb
28 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
29 | import numpy as np
30 | import pytest
31 |
# Inverse noise sigma: filled into every observation's sigma_inv and used
# to scale the added noise
sigma_inv = 1.0e4
33 |
34 |
@pytest.fixture(scope="module")
def channels() -> dict[str, g2f.Channel]:
    """Map each band in bands_weights_lsst to its gauss2d.fit Channel."""
    channels_by_band = {}
    for band in bands_weights_lsst:
        channels_by_band[band] = g2f.Channel.get(band)
    return channels_by_band
38 |
39 |
@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    """Make one observation per channel, each with its own size and origin.

    Every observation has a constant sigma_inv and an all-ones mask_inv.
    The image values are not set here.
    """
    # Base image shape and origin, plus per-channel increments
    n_rows, n_cols = 16, 21
    x_min, y_min = 0, 0
    dn_rows, dn_cols = 1, -2
    dx_min, dy_min = -2, 1

    def make_observation(idx: int, band: str) -> g2f.ObservationD:
        config = ObservationConfig(
            band=band,
            coordsys=CoordinateSystemConfig(x_min=x_min + idx * dx_min, y_min=y_min + idx * dy_min),
            n_rows=n_rows + idx * dn_rows,
            n_cols=n_cols + idx * dn_cols,
        )
        observation = config.make_observation()
        observation.sigma_inv.fill(sigma_inv)
        observation.mask_inv.fill(1)
        return observation

    return g2f.DataD([make_observation(idx, band) for idx, band in enumerate(channels)])
64 |
65 |
@pytest.fixture(scope="module")
def psf_model():
    """Return the null PSF model from make_psf_model_null."""
    model_psf = make_psf_model_null()
    return model_psf
69 |
70 |
@pytest.fixture(scope="module")
def psf_models(psf_model, channels) -> list[g2f.PsfModel]:
    """Return the same PSF model object once per channel."""
    return [psf_model for _ in channels]
74 |
75 |
@pytest.fixture(scope="module")
def model(channels, data, psf_models):
    """Make a single-Gaussian model and fill the data with noisy model images.

    Per-band fluxes are derived from abs_mag_sol_lsst. Gaussian noise with
    standard deviation 1/sigma_inv is added using a fixed random seed.
    """
    # Presumably an AB magnitude -> Jansky conversion (8.9 ~ AB zero point)
    fluxes_group = [
        {channels[band]: 10 ** (-0.4 * (mag - 8.9)) for band, mag in abs_mag_sol_lsst.items()}
    ]

    centroid = CentroidConfig(
        x=ParameterConfig(value_initial=6., fixed=True),
        y=ParameterConfig(value_initial=11., fixed=True),
    )
    gaussian = GaussianComponentConfig(
        rho=ParameterConfig(value_initial=0.1),
        size_x=ParameterConfig(value_initial=3.8),
        size_y=ParameterConfig(value_initial=5.1),
    )
    modelconfig = ModelConfig(
        sources={
            "src": SourceConfig(
                component_groups={
                    "": ComponentGroupConfig(
                        centroids={"default": centroid},
                        components_gauss={"": gaussian},
                    ),
                },
            ),
        },
    )

    model = modelconfig.make_model([[fluxes_group]], data=data, psf_models=psf_models)
    model.setup_evaluators(g2f.EvaluatorMode.image)
    model.evaluate()

    rng = np.random.default_rng(1)
    for output, observation in zip(model.outputs, model.data):
        image = observation.image.data
        image.flat = output.data.flat + rng.standard_normal(image.size) / sigma_inv
    return model
109 |
110 |
def test_plot_model_rgb(model):
    """Check plot_model_rgb with explicit RGB parameters and a chi histogram."""
    fig, ax, fig_gs, ax_gs, *_ = plot_model_rgb(
        model,
        minimum=0,
        stretch=0.15,
        Q=4,
        weights=bands_weights_lsst,
        plot_chi_hist=True,
    )
    for plot_obj in (fig, ax, fig_gs, ax_gs):
        assert plot_obj is not None
119 |
120 |
def test_plot_model_rgb_auto(model):
    """Check plot_model_rgb with automatic RGB minimum and stretch."""
    plot_kwargs = dict(
        Q=6,
        weights=bands_weights_lsst,
        rgb_min_auto=True,
        rgb_stretch_auto=True,
        plot_singleband=False,
        plot_chi_hist=False,
    )
    fig, ax, *_ = plot_model_rgb(model, **plot_kwargs)
    assert fig is not None
    assert ax is not None
128 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/fit_catalog.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from abc import ABC, abstractmethod
23 | from collections.abc import Iterable
24 |
25 | import astropy.units as u
26 | import lsst.pex.config as pexConfig
27 | import pydantic
28 | from pydantic.dataclasses import dataclass
29 |
30 | from .modeller import ModelFitConfig
31 | from .utils import ArbitraryAllowedConfig
32 |
__all__ = [
    "CatalogExposureABC",
    "ColumnInfo",
    "CatalogFitterConfig",
]
34 |
35 |
class CatalogExposureABC(ABC):
    """Abstract interface for a catalog paired with an exposure."""

    # TODO: add get_exposure (with Any return type?)

    @abstractmethod
    def get_catalog(self) -> Iterable:
        """Get the catalog for the exposure, which must be iterable by row."""
44 |
45 |
@dataclass(frozen=True, kw_only=True, config=ArbitraryAllowedConfig)
class ColumnInfo:
    """Metadata describing a single catalog column.

    Instances are immutable and all fields must be passed by keyword.
    """

    dtype: str = pydantic.Field(title="Column data type name (numpy or otherwise)")
    key: str = pydantic.Field(title="Column key (name)")
    description: str = pydantic.Field(default="", title="Column description")
    unit: u.UnitBase | None = pydantic.Field(default=None, title="Column unit (astropy)")
54 |
55 |
class CatalogFitterConfig(pexConfig.Config):
    """Configuration for generic MultiProFit fitting tasks."""

    column_id = pexConfig.Field[str](default="id", doc="Catalog index column key")
    compute_errors = pexConfig.ChoiceField[str](
        default="INV_HESSIAN_BESTFIT",
        doc="Whether/how to compute sqrt(variances) of each free parameter",
        allowed={
            "NONE": "no errors computed",
            "INV_HESSIAN": "inverse hessian using noisy image as data",
            "INV_HESSIAN_BESTFIT": "inverse hessian using best-fit model as data",
        },
    )
    compute_errors_from_jacobian = pexConfig.Field[bool](
        default=True,
        doc="Whether to estimate the Hessian from the Jacobian first, with finite differencing as a backup",
    )
    compute_errors_no_covar = pexConfig.Field[bool](
        default=True,
        doc="Whether to compute parameter errors independently, ignoring covariances",
    )
    config_fit = pexConfig.ConfigField[ModelFitConfig](default=ModelFitConfig(), doc="Fitter configuration")
    fit_centroid = pexConfig.Field[bool](default=True, doc="Fit centroid parameters")
    fit_linear_init = pexConfig.Field[bool](default=True, doc="Fit linear parameters after initialization")
    fit_linear_final = pexConfig.Field[bool](default=True, doc="Fit linear parameters after optimization")
    # The keys are used as boolean flag column names in schema()
    flag_errors = pexConfig.DictField(
        default={},
        keytype=str,
        itemtype=str,
        doc="Names of exceptions to catch, keyed by name of flag column to set",
    )
    prefix_column = pexConfig.Field[str](default="mpf_", doc="Column name prefix")

    def schema(
        self,
        bands: list[str] | None = None,
    ) -> list[ColumnInfo]:
        """Return the schema as an ordered list of columns.

        Parameters
        ----------
        bands
            A list of band names to prefix band-dependent columns with.
            Band prefixes should not be used if None. This base
            implementation has no band-dependent columns and does not use
            it; presumably it is for subclasses that add such columns.

        Returns
        -------
        schema
            An ordered list of ColumnInfo instances. These are the id
            column, iteration count, timings, reduced chi-squared and an
            unknown error flag, followed by one boolean column per
            ``flag_errors`` key.
        """
        schema = [
            ColumnInfo(key=self.column_id, dtype="i8"),
            ColumnInfo(key="n_iter", dtype="i4"),
            ColumnInfo(key="time_eval", dtype="f8", unit=u.s),
            ColumnInfo(key="time_fit", dtype="f8", unit=u.s),
            ColumnInfo(key="time_full", dtype="f8", unit=u.s),
            ColumnInfo(key="chisq_red", dtype="f8"),
            ColumnInfo(key="unknown_flag", dtype="bool"),
        ]
        schema.extend(ColumnInfo(key=key, dtype="bool") for key in self.flag_errors.keys())
        # Subclasses should always write out centroids even if not fitting
        # They are helpful for reconstructing models
        return schema
119 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/psfmodel_utils.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | import numpy as np
24 |
25 | from .limits import limits_ref
26 |
27 | __all__ = ["make_psf_source"]
28 |
29 |
30 | # TODO: This function should be replaced with SourceConfig.make_source
def make_psf_source(
    sigma_xs: list[float] | None = None,
    sigma_ys: list[float] | None = None,
    rhos: list[float] | None = None,
    fracs: list[float] | None = None,
    transforms: dict[str, g2f.TransformD] | None = None,
    limits_rho: g2f.LimitsD | None = None,
) -> g2f.Source:
    """Make a Gaussian mixture PSF source from parameter values.

    Parameters
    ----------
    sigma_xs
        Gaussian sigma_x values. Defaults to sigma_ys if that is set,
        otherwise to [1.5, 3.0].
    sigma_ys
        Gaussian sigma_y values. Defaults to sigma_xs.
    rhos
        Gaussian rho values. Defaults to 0 for every Gaussian.
    fracs
        Initial values of each Gaussian's ProperFractionParameterD, which
        are chained through FractionalIntegralModel instances. Defaults to
        ``np.arange(1, n + 1) / n`` for n Gaussians. The last value is
        always replaced by 1.0 and fixed. The input is not modified.
    transforms
        Dict of transforms by variable name (frac/rho/sigma). Any that are
        not set default to Logit/LogitLimited/Log10, respectively. The input
        dict is not modified.
    limits_rho
        Limits for rho parameters. Defaults to limits_ref['rho'].

    Returns
    -------
    source
        A source model with Gaussians initialized as specified, all sharing
        a single centroid initialized at (0, 0).

    Raises
    ------
    ValueError
        Raised if there are no Gaussians, if the parameter lists differ in
        length, or if any value is out of range (negative sigmas or fracs,
        or rho outside limits_rho).

    Notes
    -----
    Parameter lists must all be the same length.
    """
    if limits_rho is None:
        limits_rho = limits_ref["rho"]
    if sigma_xs is None:
        # Fall back to sigma_ys if given, otherwise a default two-Gaussian mix
        sigma_xs = [1.5, 3.0] if sigma_ys is None else sigma_ys
    if sigma_ys is None:
        sigma_ys = sigma_xs
    n_gaussians = len(sigma_xs)
    if n_gaussians == 0:
        raise ValueError(f"{n_gaussians=}!>0")
    if rhos is None:
        rhos = [0.0] * n_gaussians
    # Copy into a new list so that setting the final fraction below never
    # mutates the caller's input (and also works for tuples)
    fracs = list(np.arange(1, n_gaussians + 1) / n_gaussians if fracs is None else fracs)
    if transforms is None:
        transforms = {}
    # Build a new dict instead of filling defaults into the caller's dict
    transforms = {
        "frac": transforms.get("frac", g2f.LogitTransformD()),
        "rho": transforms.get("rho", g2f.LogitLimitedTransformD(limits=limits_rho)),
        "sigma": transforms.get("sigma", g2f.Log10TransformD()),
    }

    if (len(sigma_ys) != n_gaussians) or (len(rhos) != n_gaussians) or (len(fracs) != n_gaussians):
        raise ValueError(f"{len(sigma_ys)=} and/or {len(rhos)=} and/or {len(fracs)=} != {n_gaussians=}")

    errors = []
    for idx, (sigma_x, sigma_y, rho, frac) in enumerate(zip(sigma_xs, sigma_ys, rhos, fracs)):
        if not ((sigma_x >= 0) and (sigma_y >= 0)):
            errors.append(f"sigma_xs[{idx}]={sigma_x} and/or sigma_ys[{idx}]={sigma_y} !>=0")
        if not (limits_rho.check(rho)):
            errors.append(f"rhos[{idx}]={rho} !within({limits_rho=})")
        if not (frac >= 0):
            errors.append(f"fracs[{idx}]={frac} !>=0")
    if errors:
        raise ValueError("; ".join(errors))
    # The final fraction takes everything that remains and is fixed below
    fracs[-1] = 1.0

    components = [None] * n_gaussians
    cenx = g2f.CentroidXParameterD(0, limits=g2f.LimitsD(min=0, max=100))
    ceny = g2f.CentroidYParameterD(0, limits=g2f.LimitsD(min=0, max=100))
    # A single centroid shared by every Gaussian component
    centroid = g2f.CentroidParameters(cenx, ceny)

    n_last = n_gaussians - 1
    last = None

    for c in range(n_gaussians):
        is_last = c == n_last
        # The first fractional model wraps a fixed unit total integral; each
        # later one wraps the previous component's fractional model.
        last = g2f.FractionalIntegralModel(
            [
                (
                    g2f.Channel.NONE,
                    g2f.ProperFractionParameterD(fracs[c], fixed=is_last, transform=transforms["frac"]),
                )
            ],
            g2f.LinearIntegralModel([(g2f.Channel.NONE, g2f.IntegralParameterD(1.0, fixed=True))])
            if (c == 0)
            else last,
            is_last,
        )
        components[c] = g2f.GaussianComponent(
            g2f.GaussianParametricEllipse(
                g2f.SigmaXParameterD(sigma_xs[c], transform=transforms["sigma"]),
                g2f.SigmaYParameterD(sigma_ys[c], transform=transforms["sigma"]),
                g2f.RhoParameterD(rhos[c], transform=transforms["rho"]),
            ),
            centroid,
            last,
        )
    return g2f.Source(components)
137 |
--------------------------------------------------------------------------------
/tests/test_componentconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | from lsst.multiprofit.componentconfig import (
24 | EllipticalComponentConfig,
25 | GaussianComponentConfig,
26 | ParameterConfig,
27 | SersicComponentConfig,
28 | SersicIndexParameterConfig,
29 | )
30 | from lsst.multiprofit.config import set_config_from_dict
31 | from lsst.multiprofit.utils import get_params_uniq
32 | import numpy as np
33 | import pytest
34 |
35 |
@pytest.fixture(scope="module")
def centroid_limits():
    # Unbounded limits for centroid parameters. np.inf is used because the
    # np.Inf alias was removed in NumPy 2.0.
    limits = g2f.LimitsD(min=-np.inf, max=np.inf)
    return limits
40 |
41 |
@pytest.fixture(scope="module")
def centroid(centroid_limits):
    # Both coordinates start at the origin and are held fixed
    return g2f.CentroidParameters(
        g2f.CentroidXParameterD(0, limits=centroid_limits, fixed=True),
        g2f.CentroidYParameterD(0, limits=centroid_limits, fixed=True),
    )
48 |
49 |
@pytest.fixture(scope="module")
def channels():
    # Map each band name to the g2f channel returned by Channel.get
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))
53 |
54 |
def test_EllipticalComponentConfig():
    """Round-trip a default config through toDict and set_config_from_dict."""
    original = EllipticalComponentConfig()
    restored = EllipticalComponentConfig()
    set_config_from_dict(restored, original.toDict())
    assert original == restored
60 |
61 |
def test_GaussianComponentConfig(centroid):
    config = GaussianComponentConfig(
        rho=ParameterConfig(value_initial=0),
        size_x=ParameterConfig(value_initial=1.4),
        size_y=ParameterConfig(value_initial=1.6),
    )
    channel = g2f.Channel.NONE
    # First component: a free fraction of a linear integral model
    first = config.make_component(
        centroid=centroid,
        integral_model=g2f.FractionalIntegralModel(
            [(channel, g2f.ProperFractionParameterD(0.5, fixed=False))],
            model=config.make_linear_integral_model({channel: 1.0}),
        ),
    )
    # Second (final) component: a fixed fraction chained off the first's model
    second = config.make_component(
        centroid=centroid,
        integral_model=g2f.FractionalIntegralModel(
            [(channel, g2f.ProperFractionParameterD(1.0, fixed=True))],
            model=first.integral_model,
            is_final=True,
        ),
    )
    components = (first, second)
    n_components = len(components)
    for idx, component_data in enumerate(components):
        component = component_data.component
        assert component.centroid is centroid
        assert len(component_data.priors) == 0

        fluxes = list(get_params_uniq(component, nonlinear=False))
        assert len(fluxes) == 1
        assert isinstance(fluxes[0], g2f.IntegralParameterD)

        fracs = [
            param
            for param in get_params_uniq(component, linear=False)
            if isinstance(param, g2f.ProperFractionParameterD)
        ]
        # NOTE(review): idx only reaches n_components - 1, so the final term
        # is always 0; confirm what count was intended for the last component
        n_fracs_expected = idx + (idx == 0) - (idx == n_components)
        assert len(fracs) == n_fracs_expected
96 |
97 |
def test_SersicConfig(centroid, channels):
    rho, size_x, size_y, sersic_index = -0.3, 1.4, 1.6, 3.2
    config = SersicComponentConfig(
        rho=ParameterConfig(value_initial=rho),
        size_x=ParameterConfig(value_initial=size_x),
        size_y=ParameterConfig(value_initial=size_y),
        sersic_index=SersicIndexParameterConfig(value_initial=sersic_index),
    )
    # Fluxes of 1, 2, 3... in channel order
    fluxes = {channel: idx + 1.0 for idx, channel in enumerate(channels.values())}
    component_data = config.make_component(
        centroid=centroid,
        integral_model=config.make_linear_integral_model(fluxes),
    )
    assert component_data.component is not None
    # Exactly one prior is expected while the Sersic index has a default prior
    assert len(component_data.priors) == 1

    params = get_params_uniq(component_data.component)
    values_init = {
        g2f.RhoParameterD: rho,
        g2f.ReffXParameterD: size_x,
        g2f.ReffYParameterD: size_y,
        g2f.SersicIndexParameterD: sersic_index,
    }
    fluxes_label = {
        config.format_label(config.get_integral_label_default(), name_channel=channel.name): flux_channel
        for channel, flux_channel in fluxes.items()
    }
    for param in params:
        if isinstance(param, g2f.IntegralParameterD):
            assert fluxes_label[param.label] == param.value
            continue
        value_init = values_init.get(param.__class__)
        if value_init:
            assert param.value == value_init
134 |
--------------------------------------------------------------------------------
/tests/test_modelconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d as g2
23 | import lsst.gauss2d.fit as g2f
24 | from lsst.multiprofit.componentconfig import (
25 | CentroidConfig,
26 | GaussianComponentConfig,
27 | ParameterConfig,
28 | SersicComponentConfig,
29 | SersicIndexParameterConfig,
30 | )
31 | from lsst.multiprofit.modelconfig import ModelConfig
32 | from lsst.multiprofit.observationconfig import ObservationConfig
33 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
34 | import numpy as np
35 | import pytest
36 |
37 |
@pytest.fixture(scope="module")
def channels() -> dict[str, g2f.Channel]:
    # Map each band name to the g2f channel returned by Channel.get
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))
41 |
42 |
@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    config = ObservationConfig(n_rows=13, n_cols=19)

    # A single config is reused; only its band changes between observations
    def observe(name_band):
        config.band = name_band
        return config.make_observation()

    return g2f.DataD([observe(name_band) for name_band in channels])
51 |
52 |
@pytest.fixture(scope="module")
def psf_model():
    rho, size_x, size_y = 0.25, 1.6, 1.2
    drho, dsize_x, dsize_y = -0.4, 1.1, 1.9

    n_components = 3
    flux_total = 2.*(n_components + 1)
    fluxes = [x/flux_total for x in range(1, 1 + n_components)]

    # Gaussian components whose shapes step linearly with their index
    components_gauss = {}
    for idx in range(n_components):
        components_gauss[str(idx)] = GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
        )
    config = SourceConfig(
        component_groups={"src": ComponentGroupConfig(components_gauss=components_gauss)},
    )
    config.validate()
    channel = g2f.Channel.NONE
    # One list (for the single group) of per-component flux dicts
    fluxes_group = [{channel: flux} for flux in fluxes]
    model, _ = config.make_psf_model([fluxes_group])
    return model
86 |
87 |
@pytest.fixture(scope="module")
def psf_models(psf_model, channels) -> list[g2f.PsfModel]:
    # The very same PSF model object is reused for every channel
    return [psf_model for _ in channels]
91 |
92 |
@pytest.fixture(scope="module")
def modelconfig_fluxes(channels):
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    list_channels = list(channels.values())
    components_sersic = {}
    fluxes_mix = []
    for idx, name in enumerate(names):
        components_sersic[name] = SersicComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
            # Only the first ("PS") component has a fixed Sersic index
            sersic_index=SersicIndexParameterConfig(
                value_initial=sersicn + idx * dsersicn,
                fixed=(idx == 0),
                prior_mean=None,
            ),
        )
        fluxes_mix.append({
            channel: flux + idx_channel*dflux*idx
            for idx_channel, channel in enumerate(list_channels)
        })

    centroids = {
        "default": CentroidConfig(
            x=ParameterConfig(value_initial=15.8, fixed=True),
            y=ParameterConfig(value_initial=14.3, fixed=False),
        ),
    }
    group = ComponentGroupConfig(centroids=centroids, components_sersic=components_sersic)
    modelconfig = ModelConfig(sources={"src": SourceConfig(component_groups={"mix": group})})
    return modelconfig, fluxes_mix
135 |
136 |
def test_ModelConfig(modelconfig_fluxes, data, psf_models):
    modelconfig, fluxes = modelconfig_fluxes
    model = modelconfig.make_model([[fluxes]], data=data, psf_models=psf_models)
    assert model is not None
    assert model.data is data
    for obs in model.data:
        obs.sigma_inv.fill(1.)
        obs.mask_inv.fill(1)

    # obs.image will not return a holding pointer, so wrap each existing
    # data buffer in a new ImageD and use those as the evaluation outputs
    outputs = [[g2.ImageD(obs.image.data)] for obs in model.data]
    model.setup_evaluators(g2f.EvaluatorMode.image, outputs=outputs)
    model.evaluate()
    # Presumably the model image now equals the data, so loglike sums to zero
    model.setup_evaluators(g2f.EvaluatorMode.loglike)
    loglike = model.evaluate()
    assert np.sum(loglike) == 0
153 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/transforms.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from typing import Any
23 |
24 | import lsst.gauss2d.fit as g2f
25 | import numpy as np
26 |
27 | from .limits import limits_ref
28 |
29 | __all__ = ["get_logit_limited", "verify_transform_derivative", "transforms_ref"]
30 |
31 |
def get_logit_limited(lower: float, upper: float, factor: float = 1.0, name: str | None = None):
    """Build a logit transform rescaled to span [lower, upper] instead of [0, 1].

    Parameters
    ----------
    lower
        Lower bound of the spanned range.
    upper
        Upper bound of the spanned range.
    factor
        Multiplicative factor applied to the transformed value.
    name
        Descriptive name for the limits. If None, a name encoding the
        bounds and factor is generated.

    Returns
    -------
    transform
        The configured LogitLimitedTransformD.
    """
    if name is None:
        name = f"LogitLimitedTransformD(min={lower}, max={upper}, factor={factor})"
    limits = g2f.LimitsD(min=lower, max=upper, name=name)
    return g2f.LogitLimitedTransformD(limits=limits, factor=factor)
61 |
62 |
63 | def verify_transform_derivative(
64 | transform: g2f.TransformD,
65 | value_transformed: float,
66 | derivative: float | None = None,
67 | abs_max: float = 1e6,
68 | dx_ratios=None,
69 | **kwargs: Any,
70 | ):
71 | """Verify that the derivative of a transform class is correct.
72 |
73 | Parameters
74 | ----------
75 | transform
76 | The transform to verify.
77 | value_transformed
78 | The un-transformed value at which to verify the transform.
79 | derivative
80 | The nominal derivative at value_transformed.
81 | Must equal transform.derivative(value_transformed).
82 | abs_max
83 | The x value to skip verification if np.abs(derivative) > x.
84 | dx_ratios
85 | Iterable of signed ratios to set dx for finite differencing.
86 | dx = value*ratio (untransformed). Only used if dx is None.
87 | **kwargs
88 | Keyword arguments to pass to np.isclose when comparing derivatives to
89 | finite differences.
90 |
91 | Raises
92 | ------
93 | RuntimeError
94 | Raised if the transform derivative doesn't match finite differences
95 | within the specified tolerances.
96 |
97 | Notes
98 | -----
99 | derivative should only be specified if it has previously been computed for
100 | the exact value_transformed, to avoid re-computing it unnecessarily.
101 |
102 | Default dx_ratios are [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14].
103 | Verification will test all ratios until at least one passes.
104 | """
105 | # Skip testing finite differencing if the derivative is very large
106 | # This might happen e.g. near the limits of the transformation
107 | # TODO: Check if better finite differencing is possible for large values
108 | if abs_max is None:
109 | abs_max = 1e8
110 | value = transform.reverse(value_transformed)
111 | if derivative is None:
112 | derivative = transform.derivative(value)
113 | is_close = np.abs(derivative) > abs_max
114 | if not is_close:
115 | if dx_ratios is None:
116 | dx_ratios = [1e-4, 1e-6, 1e-8, 1e-10, 1e-12, 1e-14]
117 | for ratio in dx_ratios:
118 | dx = value * ratio
119 | fin_diff = (transform.forward(value + dx) - value_transformed) / dx
120 | if not np.isfinite(fin_diff):
121 | fin_diff = -(transform.forward(value - dx) - value_transformed) / dx
122 | is_close = np.isclose(derivative, fin_diff, **kwargs)
123 | if is_close:
124 | break
125 | if not is_close:
126 | raise RuntimeError(
127 | f"{transform} derivative={derivative:.8e} != last "
128 | f"finite diff.={fin_diff:8e} with dx={dx} dx_abs_max={abs_max}"
129 | )
130 |
131 |
# Reference transform instances keyed by short name, created once at import.
# The "logit_*" entries are LogitLimitedTransformD instances built with
# get_logit_limited, spanning the ranges given in their names.
transforms_ref = {
    "none": g2f.UnitTransformD(),
    "log": g2f.LogTransformD(),
    "log10": g2f.Log10TransformD(),
    "inverse": g2f.InverseTransformD(),
    "logit": g2f.LogitTransformD(),
    # Range taken from limits_ref["fluxfrac"]
    "logit_fluxfrac": get_logit_limited(
        limits_ref["fluxfrac"].min,
        limits_ref["fluxfrac"].max,
        name=f"ref_logit_fluxfrac[{limits_ref['fluxfrac'].min}, {limits_ref['fluxfrac'].max}]",
    ),
    # Range taken from limits_ref["rho"]
    "logit_rho": get_logit_limited(
        limits_ref["rho"].min,
        limits_ref["rho"].max,
        name=f"ref_logit_rho[{limits_ref['rho'].min}, {limits_ref['rho'].max}]",
    ),
    "logit_axrat": get_logit_limited(1e-4, 1, name="ref_logit_axrat[1e-4, 1]"),
    # NOTE(review): slightly wider than [0, 1]; presumably so that prior
    # evaluation stays finite at the bounds - confirm
    "logit_axrat_prior": get_logit_limited(-0.001, 1.001, name="ref_logit_axrat_prior[-0.001, 1.001]"),
    "logit_sersic": get_logit_limited(0.49, 6.01, name="ref_logit_sersic[0.49, 6.01]"),
}
152 |
--------------------------------------------------------------------------------
/tests/test_sourceconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import lsst.gauss2d.fit as g2f
23 | from lsst.multiprofit.componentconfig import (
24 | GaussianComponentConfig,
25 | ParameterConfig,
26 | SersicComponentConfig,
27 | SersicIndexParameterConfig,
28 | )
29 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
30 | from lsst.multiprofit.utils import get_params_uniq
31 | import numpy as np
32 | import pytest
33 |
34 |
@pytest.fixture(scope="module")
def centroid_limits():
    # Unbounded limits for centroid parameters. np.inf is used because the
    # np.Inf alias was removed in NumPy 2.0.
    limits = g2f.LimitsD(min=-np.inf, max=np.inf)
    return limits
39 |
40 |
@pytest.fixture(scope="module")
def centroid(centroid_limits):
    # Both coordinates start at the origin and are held fixed
    return g2f.CentroidParameters(
        g2f.CentroidXParameterD(0, limits=centroid_limits, fixed=True),
        g2f.CentroidYParameterD(0, limits=centroid_limits, fixed=True),
    )
47 |
48 |
@pytest.fixture(scope="module")
def channels():
    # Map each band name to the g2f channel returned by Channel.get
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))
52 |
53 |
def test_ComponentGroupConfig(centroid):
    """Check that a ValueError is raised when the name "x" is used for both
    a Gaussian and a Sersic component in one group."""
    with pytest.raises(ValueError):
        group = ComponentGroupConfig(
            components_gauss={"x": GaussianComponentConfig()},
            components_sersic={"x": SersicComponentConfig()},
        )
        group.validate()
61 |
62 |
def test_SourceConfig_base():
    # Neither a default config nor one with an empty component_groups dict
    # may pass construction plus validation without a ValueError
    for kwargs in ({}, {"component_groups": {}}):
        with pytest.raises(ValueError):
            SourceConfig(**kwargs).validate()
71 |
72 |
def test_SourceConfig_fractional(centroid):
    rho, size_x, size_y = -0.3, 1.4, 1.6
    drho, dsize_x, dsize_y = 0.5, 1.6, 1.3
    n_components = 2

    def make_gaussian(idx):
        # Shape parameters step linearly with the component index
        return GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
        )

    group = ComponentGroupConfig(
        components_gauss={str(idx): make_gaussian(idx) for idx in range(n_components)},
        is_fractional=True,
    )
    config = SourceConfig(component_groups={"src": group})
    config.validate()

    channel = g2f.Channel.NONE
    # One list (for the single group) of per-component flux dicts
    fluxes_group = [{channel: 1.0}, {channel: 0.5}]
    psf_model, priors = config.make_psf_model([fluxes_group])
    assert len(priors) == 0
    assert len(psf_model.components) == n_components
105 |
106 |
def test_SourceConfig_linear(centroid, channels):
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 0.5, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 2.8, 13.9

    names = ("PS", "Sersic")
    components_sersic = {}
    for idx, name in enumerate(names):
        components_sersic[name] = SersicComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
            # Only the first ("PS") component has a fixed Sersic index
            sersic_index=SersicIndexParameterConfig(
                value_initial=sersicn + idx * dsersicn,
                fixed=(idx == 0),
                prior_mean=None,
            ),
        )
    config = SourceConfig(
        component_groups={"src": ComponentGroupConfig(components_sersic=components_sersic)}
    )

    n_comps = len(config.component_groups["src"].components_sersic)
    list_channels = list(channels.values())
    fluxes = []
    for idx_comp in range(n_comps):
        fluxes.append({
            channel: flux + idx_channel*dflux*idx_comp
            for idx_channel, channel in enumerate(list_channels)
        })
    source, priors = config.make_source([fluxes])
    assert len(priors) == 0

    def label_integral(name_group, component_group, name_comp, config_comp, channel):
        # Integral labels are nested: component label, then group, then source
        label_comp = config_comp.format_label(
            label=config.get_integral_label_default(),
            name_channel=channel.name,
        )
        label_group = component_group.format_label(label=label_comp, name_component=name_comp)
        return config.format_label(label_group, name_group=name_group)

    for idx, component in enumerate(source.components):
        params = get_params_uniq(component)
        values_init = {
            g2f.RhoParameterD: rho + idx*drho,
            g2f.ReffXParameterD: size_x + idx*dsize_x,
            g2f.ReffYParameterD: size_y + idx*dsize_y,
            g2f.SersicIndexParameterD: sersicn + idx*dsersicn,
        }
        for name_group, component_group in config.component_groups.items():
            fluxes_comp = fluxes[idx]
            name_comp = names[idx]
            config_comp = component_group.components_sersic[name_comp]
            fluxes_label = {
                label_integral(name_group, component_group, name_comp, config_comp, channel): fluxes_comp[channel]
                for channel in channels.values()
            }
            for param in params:
                if isinstance(param, g2f.IntegralParameterD):
                    assert fluxes_label[param.label] == param.value
                else:
                    value_init = values_init.get(param.__class__)
                    if value_init:
                        assert param.value == value_init
168 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/fit_bootstrap_model.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | from functools import cached_property
23 | import logging
24 | from typing import Any, Mapping, Sequence
25 |
26 | import astropy
27 | import lsst.gauss2d.fit as g2f
28 | import lsst.pex.config as pexConfig
29 | import numpy as np
30 | import pydantic
31 | from pydantic.dataclasses import dataclass
32 |
33 | from .config import set_config_from_dict
34 | from .fit_psf import CatalogExposurePsfABC, CatalogPsfFitterConfig, CatalogPsfFitterConfigData
35 | from .fit_source import CatalogExposureSourcesABC, CatalogSourceFitterABC, CatalogSourceFitterConfigData
36 | from .model_utils import make_image_gaussians
37 | from .observationconfig import ObservationConfig
38 | from .utils import FrozenArbitraryAllowedConfig, get_params_uniq
39 |
40 | __all__ = [
41 | "CatalogBootstrapConfig",
42 | "CatalogExposurePsfBootstrap",
43 | "CatalogExposureSourcesBootstrap",
44 | "CatalogPsfBootstrapConfig",
45 | "CatalogSourceBootstrapConfig",
46 | "CatalogSourceFitterBootstrap",
47 | "NoisyObservationConfig",
48 | ]
49 |
50 |
class CatalogBootstrapConfig(pexConfig.Config):
    """Configuration for a bootstrap source catalog fitter."""

    n_sources = pexConfig.Field[int](default=1, doc="Number of sources")

    @cached_property
    def catalog(self):
        """A table with a single "id" column numbering sources from 0.

        Built on first access from n_sources and cached on the instance.
        """
        return astropy.table.Table({"id": np.arange(self.n_sources)})
60 |
61 |
class ObservationNoiseConfig(pexConfig.Config):
    """Settings for the noise added to an Observation.

    Background levels are expressed in user-defined flux units; multiply by
    the gain to convert them to counts.
    """

    background = pexConfig.Field[float](default=1e-4, doc="Background flux per pixel")
    gain = pexConfig.Field[float](default=1.0, doc="Multiplicative factor to convert flux to counts")
71 |
72 |
class NoisyObservationConfig(ObservationConfig, ObservationNoiseConfig):
    """Observation configuration combined with noise settings."""
75 |
76 |
class NoisyPsfObservationConfig(ObservationConfig, ObservationNoiseConfig):
    """PSF observation configuration combined with noise settings."""
79 |
80 |
class CatalogPsfBootstrapConfig(CatalogBootstrapConfig):
    """Configure a catalog of noisy PSF observations used for bootstrapping.

    Every row is a stacked, normalized image of an arbitrary number of point
    sources.
    """

    observation = pexConfig.ConfigField[NoisyPsfObservationConfig](
        default=NoisyPsfObservationConfig,
        doc="The PSF image configuration",
    )
91 |
92 |
class CatalogSourceBootstrapConfig(CatalogBootstrapConfig):
    """Configure a catalog of noisy source observations used for
    bootstrapping.

    Every row is a PSF-convolved observation of the sources in a single band.
    """

    observation = pexConfig.ConfigField[NoisyObservationConfig](
        default=NoisyObservationConfig,
        doc="The source image configuration",
    )
104 |
105 |
@dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig)
class CatalogExposurePsfBootstrap(CatalogExposurePsfABC, CatalogPsfFitterConfigData):
    """Dataclass for a PSF-convolved bootstrap fitter.

    PSF images are rendered from the PSF model made by ``self.config``, and
    per-source Gaussian noise is added using the background and gain from
    ``config_boot.observation``.
    """

    config_boot: CatalogPsfBootstrapConfig = pydantic.Field(title="The configuration for bootstrapping")

    @cached_property
    def image(self) -> np.ndarray:
        """The noise-free PSF image, computed on first access and cached.

        As a side effect, the parameter values of ``self.psf_model`` are set
        to those of a freshly made initial PSF model.
        """
        psf_model_init = self.config.make_psf_model()
        # A hacky way to initialize the psf_model property to the same values
        # NOTE(review): parameters are paired by position, which assumes both
        # models yield their unique parameters in the same order
        # TODO: Include this functionality in fit_psf.py
        for param_init, param in zip(get_params_uniq(psf_model_init), get_params_uniq(self.psf_model)):
            param.value = param_init.value
        image = make_image_gaussians(
            psf_model_init.gaussians(g2f.Channel.NONE),
            n_rows=self.config_boot.observation.n_rows,
            n_cols=self.config_boot.observation.n_cols,
        )
        return image.data

    def get_catalog(self) -> astropy.table.Table:
        """Return the bootstrap catalog from config_boot."""
        return self.config_boot.catalog

    def get_psf_image(
        self, source: astropy.table.Row | Mapping[str, Any], config: CatalogPsfFitterConfig | None = None
    ) -> np.ndarray:
        """Return the PSF image plus Gaussian noise for one source.

        The noise variance is (image + background)/gain, and the random
        generator is seeded with ``source["id"]`` so each source's noise is
        reproducible. The ``config`` argument is not used here.
        """
        rng = np.random.default_rng(source["id"])
        image = self.image
        config_obs = self.config_boot.observation
        return image + rng.standard_normal(image.shape)*np.sqrt(
            (image + config_obs.background)/config_obs.gain)

    def __post_init__(self):
        # Freeze the bootstrap config so it cannot change after construction
        self.config_boot.freeze()
140 |
141 |
@dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig)
class CatalogExposureSourcesBootstrap(CatalogExposureSourcesABC):
    """A CatalogExposure for bootstrap fitting of source catalogs.

    PSF models are initialized from rows of a table of PSF fit results, and
    each source observation is created from the bootstrap config.
    """

    config_boot: CatalogSourceBootstrapConfig = pydantic.Field(
        title="A CatalogSourceBootstrapConfig to be frozen")
    table_psf_fits: astropy.table.Table = pydantic.Field(title="PSF fit parameters for the catalog")

    @cached_property
    def channel(self):
        """The channel of the configured observation band (cached)."""
        return g2f.Channel.get(self.config_boot.observation.band)

    def get_catalog(self) -> astropy.table.Table:
        """Return the bootstrap catalog from the bootstrap config."""
        return self.config_boot.catalog

    def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:
        """Return the PSF model after initializing it from a PSF fit row."""
        model = self.psf_model_data.psf_model
        # NOTE(review): this indexes table_psf_fits by position using the
        # source id, i.e. it assumes ids equal row indices - confirm
        self.psf_model_data.init_psf_model(self.table_psf_fits[params["id"]])
        return model

    def get_source_observation(self, source: Mapping[str, Any]) -> g2f.ObservationD:
        """Make a new observation from the bootstrap observation config."""
        return self.config_boot.observation.make_observation()

    def __post_init__(self):
        # Rebuild the PSF fitter config serialized into the table metadata
        serialized = self.table_psf_fits.meta["config"]
        config_psf = CatalogPsfFitterConfig()
        set_config_from_dict(config_psf, serialized)
        # The dataclass is frozen, so bypass __setattr__ to store derived data
        object.__setattr__(self, "psf_model_data", CatalogPsfFitterConfigData(config=config_psf))
173 |
174 |
@dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig)
class CatalogSourceFitterBootstrap(CatalogSourceFitterABC):
    """A catalog fitter that bootstraps a single model.

    This fitter generates a different noisy image of the specified model for
    each row. The resulting catalog can be used to examine performance and
    statistics of the best-fit parameters.
    """

    def get_model_radec(self, source: Mapping[str, Any], cen_x: float, cen_y: float) -> tuple[float, float]:
        """Return the pixel centroid as floats, without any conversion."""
        return float(cen_x), float(cen_y)

    def initialize_model(
        self,
        model: g2f.ModelD,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        centroid_pixel_offset: float = 0,
    ):
        """Initialize a model's free parameters and generate its noisy data.

        Parameters
        ----------
        model
            The model to initialize. The image, sigma_inv and mask_inv data
            of each of its observations are overwritten.
        source
            The catalog row being fit. Its "id" seeds the noise generator.
        catexps
            The catalog exposures, in the same order as ``model.data``.
            Each must have a ``config_boot`` attribute.
        values_init
            Initial values of free parameters, keyed by parameter. Centroid
            parameters are always set to the center of the bounding box of
            all observations instead.
        centroid_pixel_offset
            Unused.
        """
        if values_init is None:
            values_init = {}

        # Compute the bounding box spanning all observations
        min_x, max_x = np.inf, -np.inf
        min_y, max_y = np.inf, -np.inf
        for observation in model.data:
            coordsys = observation.image.coordsys
            x_min = coordsys.x_min
            min_x = min(min_x, x_min)
            max_x = max(max_x, x_min + observation.image.n_cols*coordsys.dx1)
            y_min = coordsys.y_min
            min_y = min(min_y, y_min)
            max_y = max(max_y, y_min + observation.image.n_rows*coordsys.dy2)

        cen_x = (min_x + max_x) / 2.0
        cen_y = (min_y + max_y) / 2.0

        params_limits_init = {
            g2f.CentroidXParameterD: (cen_x, (min_x, max_x)),
            g2f.CentroidYParameterD: (cen_y, (min_y, max_y)),
        }

        params_free = get_params_uniq(model, fixed=False)
        for param in params_free:
            value_init, limits_new = params_limits_init.get(
                type(param),
                (values_init.get(param), None)
            )
            if value_init is not None:
                param.value = value_init
            if limits_new:
                # Relax the minimum first, presumably so that setting the new
                # maximum can never conflict with the old minimum
                param.limits.min = -np.inf
                param.limits.max = limits_new[1]
                param.limits.min = limits_new[0]

        # Should be done in get_source_observation, but it gets called first
        # ... and therefore does not have the initialization above
        # Also, this must be done per-iteration because PSF parameters vary
        model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
        model.evaluate()

        # The offset is to keep the rng seed different from the PSF image seed
        # It doesn't really need to be so large but it's reasonably safe
        rng = np.random.default_rng(source["id"] + 10000000)

        for idx_obs, observation in enumerate(model.data):
            config_obs = catexps[idx_obs].config_boot.observation
            image, sigma_inv = observation.image, observation.sigma_inv
            # TODO: This doesn't raise or warn or anything if setting to
            # the wrong (differently sized output). That seems dangerous.
            image.data.flat = model.outputs[idx_obs].data.flat
            # sigma_inv temporarily holds sigma, computed from the noiseless
            # model image, and is inverted after the noise is added
            sigma_inv.data.flat = np.sqrt((image.data + config_obs.background)/config_obs.gain)
            image.data.flat += sigma_inv.data.flat * rng.standard_normal(image.data.size)
            sigma_inv.data.flat = (1.0 / sigma_inv.data).flat
            # This is mandatory because C++ construction does no initialization
            # (could instead initialize in get_source_observation)
            # TODO: Do some timings to see which is more efficient
            observation.mask_inv.data.flat = 1

    def validate_fit_inputs(
        self,
        catalog_multi: Sequence,
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData = None,
        logger: logging.Logger = None,
        **kwargs: Any,
    ) -> None:
        """Check that every catexp carries a CatalogSourceBootstrapConfig.

        Raises
        ------
        ValueError
            If any element of ``catexps`` lacks a ``config_boot`` attribute
            of type `CatalogSourceBootstrapConfig`.
        """
        errors = [
            f"catexps[{idx}] = {catexp} does not have a config_boot attr of type "
            f"{CatalogSourceBootstrapConfig}"
            for idx, catexp in enumerate(catexps)
            if not isinstance(getattr(catexp, "config_boot", None), CatalogSourceBootstrapConfig)
        ]
        if errors:
            raise ValueError(
                "CatalogSourceFitterBootstrap fit inputs have validation errors:\n" + "\n".join(errors)
            )
268 |
--------------------------------------------------------------------------------
/tests/test_fit_bootstrap_model.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
# along with this program.  If not, see https://www.gnu.org/licenses/.
21 |
22 | import astropy.table
23 | import lsst.gauss2d.fit as g2f
24 | from lsst.multiprofit.componentconfig import (
25 | CentroidConfig,
26 | FluxFractionParameterConfig,
27 | FluxParameterConfig,
28 | GaussianComponentConfig,
29 | ParameterConfig,
30 | SersicComponentConfig,
31 | SersicIndexParameterConfig,
32 | )
33 | from lsst.multiprofit.fit_bootstrap_model import (
34 | CatalogExposurePsfBootstrap,
35 | CatalogExposureSourcesBootstrap,
36 | CatalogPsfBootstrapConfig,
37 | CatalogSourceBootstrapConfig,
38 | CatalogSourceFitterBootstrap,
39 | NoisyObservationConfig,
40 | NoisyPsfObservationConfig,
41 | )
42 | from lsst.multiprofit.fit_psf import CatalogPsfFitter, CatalogPsfFitterConfig, CatalogPsfFitterConfigData
43 | from lsst.multiprofit.fit_source import CatalogSourceFitterConfig, CatalogSourceFitterConfigData
44 | from lsst.multiprofit.modelconfig import ModelConfig
45 | from lsst.multiprofit.modeller import ModelFitConfig
46 | from lsst.multiprofit.observationconfig import CoordinateSystemConfig
47 | from lsst.multiprofit.plots import ErrorValues, plot_catalog_bootstrap, plot_loglike
48 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
49 | from lsst.multiprofit.utils import get_params_uniq
50 | import numpy as np
51 | import pytest
52 |
# Size (rows, cols) of the first band's source image
shape_img = (23, 27)
# Values used to initialize the Sersic source model and its priors
reff_x_src = 2.5
reff_y_src = 3.6
rho_src = -0.25
nser_src = 2.0  # NOTE(review): not referenced elsewhere in this module

# TODO: These can be parameterized; should they be?
compute_errors_no_covar = True
compute_errors_from_jacobian = True
include_point_source = False
n_sources = 3
# Set to True for interactive debugging (but don't commit)
plot = False
63 |
64 |
@pytest.fixture(scope="module")
def channels():
    """Map each test band name to its gauss2d.fit Channel."""
    bands = ("R", "G", "B")
    return dict(zip(bands, map(g2f.Channel.get, bands)))
68 |
69 |
@pytest.fixture(scope="module")
def config_fitter_psfs(channels) -> dict[g2f.Channel, CatalogExposurePsfBootstrap]:
    """Build a two-Gaussian fractional PSF bootstrap exposure per channel.

    The image size and Gaussian parameters change with the channel's index
    so that each band's PSF differs slightly.
    """
    config_datas = {}
    for idx, channel in enumerate(channels.values()):
        n_rows, n_cols = 17 + idx*2, 15 + idx*2
        centroid = CentroidConfig(
            x=ParameterConfig(value_initial=n_cols/2.),
            y=ParameterConfig(value_initial=n_rows/2.),
        )
        comp_inner = GaussianComponentConfig(
            flux=FluxParameterConfig(value_initial=1.0, fixed=True),
            fluxfrac=FluxFractionParameterConfig(value_initial=0.5, fixed=False),
            size_x=ParameterConfig(value_initial=1.5 + 0.1*idx),
            size_y=ParameterConfig(value_initial=1.7 + 0.13*idx),
            rho=ParameterConfig(value_initial=-0.035 - 0.007*idx),
        )
        comp_outer = GaussianComponentConfig(
            size_x=ParameterConfig(value_initial=3.1 + 0.24*idx),
            size_y=ParameterConfig(value_initial=2.7 + 0.16*idx),
            rho=ParameterConfig(value_initial=0.06 + 0.012*idx),
            fluxfrac=FluxFractionParameterConfig(value_initial=1.0, fixed=True),
        )
        group = ComponentGroupConfig(
            centroids={"default": centroid},
            components_gauss={"comp1": comp_inner, "comp2": comp_outer},
            is_fractional=True,
        )
        config = CatalogPsfFitterConfig(model=SourceConfig(component_groups={"": group}))
        config_boot = CatalogPsfBootstrapConfig(
            observation=NoisyPsfObservationConfig(n_rows=n_rows, n_cols=n_cols, gain=1e5),
            n_sources=n_sources,
        )
        config_datas[channel] = CatalogExposurePsfBootstrap(config=config, config_boot=config_boot)

    return config_datas
112 |
113 |
@pytest.fixture(scope="module")
def config_fitter_source(channels) -> CatalogSourceFitterConfigData:
    """Build the source fitter config: one Sersic (plus optional point source)."""
    config_fit = ModelFitConfig(fit_linear_iter=3)

    components_gauss = {}
    if include_point_source:
        components_gauss["ps"] = GaussianComponentConfig(
            flux=FluxParameterConfig(value_initial=1000),
            rho=ParameterConfig(value_initial=0, fixed=True),
            size_x=ParameterConfig(value_initial=0, fixed=True),
            size_y=ParameterConfig(value_initial=0, fixed=True),
        )
    sersic = SersicComponentConfig(
        prior_size_mean=reff_y_src,
        prior_size_stddev=1.0,
        prior_axrat_mean=reff_x_src / reff_y_src,
        prior_axrat_stddev=0.2,
        flux=FluxParameterConfig(value_initial=5000),
        rho=ParameterConfig(value_initial=rho_src),
        size_x=ParameterConfig(value_initial=reff_x_src),
        size_y=ParameterConfig(value_initial=reff_y_src),
        sersic_index=SersicIndexParameterConfig(fixed=False, value_initial=1.0),
    )
    group = ComponentGroupConfig(
        components_gauss=components_gauss,
        components_sersic={"ser": sersic},
    )
    config_model = ModelConfig(sources={"": SourceConfig(component_groups={"": group})})

    config = CatalogSourceFitterConfig(
        config_fit=config_fit,
        config_model=config_model,
        convert_cen_xy_to_radec=False,
        compute_errors_no_covar=compute_errors_no_covar,
        compute_errors_from_jacobian=compute_errors_from_jacobian,
    )
    return CatalogSourceFitterConfigData(channels=tuple(channels.values()), config=config)
158 |
159 |
@pytest.fixture(scope="module")
def tables_psf_fits(config_fitter_psfs) -> dict[g2f.Channel, astropy.table.Table]:
    """Fit each channel's bootstrapped PSF catalog; return result tables by channel."""
    fitter = CatalogPsfFitter()
    fits = {}
    for channel, catexp_psf in config_fitter_psfs.items():
        # The bootstrap object serves as both the catexp and the config data
        fits[channel] = fitter.fit(catexp=catexp_psf, config_data=catexp_psf)
    return fits
171 |
172 |
@pytest.fixture(scope="module")
def config_data_sources(
    config_fitter_psfs, tables_psf_fits,
) -> dict[g2f.Channel, CatalogExposureSourcesBootstrap]:
    """Build a source bootstrap exposure per channel from its PSF fit table.

    Image sizes grow and coordinate system origins shift with the channel's
    index, so the bands do not overlap exactly.
    """
    config_datas = {}
    for idx, channel in enumerate(config_fitter_psfs):
        table_psf = tables_psf_fits[channel]
        observation = NoisyObservationConfig(
            n_rows=shape_img[0] + idx*2,
            n_cols=shape_img[1] + idx*2,
            band=channel.name,
            background=100,
            coordsys=CoordinateSystemConfig(x_min=-2 + 3*idx, y_min=5 - 4*idx),
        )
        config_boot = CatalogSourceBootstrapConfig(observation=observation, n_sources=n_sources)
        config_datas[channel] = CatalogExposureSourcesBootstrap(
            config_boot=config_boot,
            table_psf_fits=table_psf,
        )

    return config_datas
196 |
197 |
def test_fit_psf(config_fitter_psfs, tables_psf_fits):
    """Check that PSF fits succeed and recover the true parameter values."""
    # (parameter type, (atol, rtol)) checked in order; others use the default
    # TODO: come up with better (noise-dependent) thresholds here
    tolerances_by_type = (
        (g2f.IntegralParameterD, (0, 0.02)),
        (g2f.ProperFractionParameterD, (0.1, 0.01)),
        (g2f.RhoParameterD, (0.05, 0.1)),
    )
    tolerances_default = (0.01, 0.1)

    for channel, results in tables_psf_fits.items():
        assert len(results) == n_sources
        assert np.sum(results["mpf_psf_unknown_flag"]) == 0
        assert all(np.isfinite(list(results[0].values())))

        config_psf = config_fitter_psfs[channel].config
        model_true = config_psf.make_psf_model()
        psfdata = CatalogPsfFitterConfigData(config=config_psf)
        model_fit = psfdata.psf_model
        psfdata.init_psf_model(results[0])
        assert len(model_true.components) == len(model_fit.components)

        params_true = model_true.parameters()
        params_fit = model_fit.parameters()
        assert len(params_true) == len(params_fit)
        for param_true, param_fit in zip(params_true, params_fit):
            assert param_fit.fixed == param_true.fixed
            if param_true.fixed:
                assert param_true.value == param_fit.value
                continue
            atol, rtol = next(
                (tols for kind, tols in tolerances_by_type if isinstance(param_true, kind)),
                tolerances_default,
            )
            assert np.isclose(param_true.value, param_fit.value, atol=atol, rtol=rtol)
227 |
228 |
def test_fit_source(config_fitter_source, config_data_sources):
    """Fit bootstrapped multiband source catalogs and compare error estimates.

    The catalog of noisy realizations is fit, checked for flagged or
    non-finite results, and plotted against reference parameter values.
    Parameter variances are then computed several ways (Hessian-based with
    and without return_negative, Hessian evaluated on the noiseless model
    image, and the modeller's default) and plotted with plot_loglike.
    """
    fitter = CatalogSourceFitterBootstrap()
    # We don't have or need a multiband input catalog, so just pretend the
    # first band's catalog is one
    catalog_multi = next(iter(config_data_sources.values())).get_catalog()
    catexps = list(config_data_sources.values())
    results = fitter.fit(catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source)
    assert len(results) == n_sources

    # Presumably the best-fit model for row 0, since results are passed in
    model = fitter.get_model(
        0, catalog_multi=catalog_multi, catexps=catexps, config_data=config_fitter_source, results=results
    )

    # Build a reference model from the configured initial values (with
    # centroids reset by initialize_model) on the same data and PSF models.
    # NOTE(review): model_true is constructed with data=model.data; if that
    # data is shared, initialize_model also overwrites model's observation
    # images with a new noisy realization - confirm this is intended.
    model_sources, priors = config_fitter_source.config.make_sources(
        channels=list(config_data_sources.keys())
    )
    model_true = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model_sources)
    fitter.initialize_model(model_true, catalog_multi[0], catexps=catexps)
    params_true = tuple(param.value for param in get_params_uniq(model_true, fixed=False))
    plot_catalog_bootstrap(
        results, histtype="step", paramvals_ref=params_true, plot_total_fluxes=True, plot_colors=True
    )
    if plot:
        import matplotlib.pyplot as plt

        plt.show()

    assert np.sum(results["mpf_unknown_flag"]) == 0
    assert all(np.isfinite(list(results[0].values())))

    # Diagonal-only Hessian variances, without and then with return_negative
    variances = []
    for return_negative in (False, True):
        variances.append(
            fitter.modeller.compute_variances(
                model, transformed=False, options=g2f.HessianOptions(return_negative=return_negative),
                use_diag_only=True,
            )
        )
        assert np.all(variances[-1] > 0)
        if return_negative:
            # NOTE(review): given the assert above, this clamp appears to
            # be a no-op
            variances = np.array(variances)
            variances[variances <= 0] = 0
            variances = list(variances)

    # Bootstrap errors
    # Temporarily replace each observation's image with the noiseless model
    # output, compute variances, then restore the original images below
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.image)
    model.evaluate()
    img_data_old = []
    for obs, output in zip(model.data, model.outputs):
        img_data_old.append(obs.image.data.copy())
        img = obs.image.data
        img.flat = output.data.flat
    # return_negative still holds its last loop value (True) here
    options_hessian = g2f.HessianOptions(return_negative=return_negative)
    variances_bootstrap = fitter.modeller.compute_variances(model, transformed=False, options=options_hessian)
    variances_bootstrap_diag = fitter.modeller.compute_variances(
        model, transformed=False, options=options_hessian, use_diag_only=True
    )
    for obs, img_datum_old in zip(model.data, img_data_old):
        obs.image.data.flat = img_datum_old.flat
    # No Hessian options given: these are labelled as Jacobian-based below
    variances_jac = fitter.modeller.compute_variances(model, transformed=False)
    variances_jac_diag = fitter.modeller.compute_variances(model, transformed=False, use_diag_only=True)

    errors_plot = {
        "inv_hess": ErrorValues(values=np.sqrt(variances[0]), kwargs_plot={"linestyle": "-", "color": "r"}),
        "-inv_hess": ErrorValues(values=np.sqrt(variances[1]), kwargs_plot={"linestyle": "--", "color": "r"}),
        "inv_jac": ErrorValues(values=np.sqrt(variances_jac), kwargs_plot={"linestyle": "-.", "color": "r"}),
        "boot_hess": ErrorValues(
            values=np.sqrt(variances_bootstrap), kwargs_plot={"linestyle": "-", "color": "b"}
        ),
        "boot_diag": ErrorValues(
            values=np.sqrt(variances_bootstrap_diag), kwargs_plot={"linestyle": "--", "color": "b"}
        ),
        "boot_jac_diag": ErrorValues(
            values=np.sqrt(variances_jac_diag), kwargs_plot={"linestyle": "-.", "color": "m"}
        ),
    }
    fig, ax = plot_loglike(model, errors=errors_plot, values_reference=params_true)
    if plot:
        # plt was imported in the earlier `if plot:` branch
        plt.tight_layout()
        plt.show()
308 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/sourceconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
# along with this program.  If not, see https://www.gnu.org/licenses/.
21 |
22 | import string
23 |
24 | import lsst.gauss2d.fit as g2f
25 | import lsst.pex.config as pexConfig
26 |
27 | from .componentconfig import (
28 | CentroidConfig,
29 | EllipticalComponentConfig,
30 | Fluxes,
31 | GaussianComponentConfig,
32 | SersicComponentConfig,
33 | )
34 |
__all__ = [
    "ComponentConfigs",
    "CentroidConfig",
    "ComponentGroupConfig",
    "SourceConfig",
]

# Mapping of component name to its configuration
ComponentConfigs = dict[str, EllipticalComponentConfig]
40 |
41 |
class ComponentGroupConfig(pexConfig.Config):
    """Configuration for a group of gauss2d.fit Components.

    ComponentGroups may have linked CentroidParameters
    and IntegralModels, e.g. if is_fractional is True.

    Notes
    -----
    Gaussian components are generated first, then Sersic.

    This config class has no equivalent in gauss2dfit, because gauss2dfit
    models define parameter dependencies implicitly. This class implements
    only a subset of typical use cases, i.e. PSFs sharing a fractional
    integral model with fixed unit flux, and galaxies/PSF components sharing
    a single common centroid.
    If greater flexibility in linking parameter values is needed,
    users must assemble their own gauss2dfit models directly.
    """

    centroids = pexConfig.ConfigDictField[str, CentroidConfig](
        doc="Centroids by key, which can be a component name or 'default'. "
        "The 'default' key-value pair must be specified if it is needed.",
        default={"default": CentroidConfig()},
    )
    # TODO: Change this to just one EllipticalComponentConfig field
    # when pex_config supports derived types in ConfigDictField
    # (possibly DM-41049)
    components_gauss = pexConfig.ConfigDictField[str, GaussianComponentConfig](
        doc="Gaussian Components in the source",
        optional=False,
        default={},
    )
    components_sersic = pexConfig.ConfigDictField[str, SersicComponentConfig](
        doc="Sersic Components in the component mixture",
        optional=False,
        default={},
    )
    is_fractional = pexConfig.Field[bool](doc="Whether the integral_model is fractional", default=False)
    transform_fluxfrac_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux fraction parameters",
        default="logit_fluxfrac",
        optional=True,
    )
    transform_flux_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux parameters",
        default="log10",
        optional=True,
    )

    @staticmethod
    def format_label(label: str, name_component: str) -> str:
        """Substitute ``${name_component}`` in a label template.

        Other placeholders are left unchanged (string.Template.safe_substitute).
        """
        return string.Template(label).safe_substitute(name_component=name_component)

    @staticmethod
    def get_integral_label_default() -> str:
        """Return the default integral label template, prefixed by component name."""
        return "comp: ${name_component} " + EllipticalComponentConfig.get_integral_label_default()

    def get_component_configs(self) -> ComponentConfigs:
        """Return all component configs by name: Gaussians first, then Sersics.

        A Sersic component with the same name as a Gaussian one replaces it;
        validate() reports such duplicates as errors.
        """
        component_configs: ComponentConfigs = dict(self.components_gauss)
        for name, component in self.components_sersic.items():
            component_configs[name] = component
        return component_configs

    @staticmethod
    def get_fluxes_default(
        channels: tuple[g2f.Channel, ...], component_configs: ComponentConfigs, is_fractional: bool = False,
    ) -> list[Fluxes]:
        """Get default per-channel fluxes from the components' initial values.

        Parameters
        ----------
        channels
            The channels to assign values to.
        component_configs
            The component configs to read initial values from, in order.
        is_fractional
            Whether the fluxes are for a fractional integral model.

        Returns
        -------
        fluxes
            A list of Fluxes in the format expected by make_components.
            If is_fractional, the first item holds the total flux (the first
            component's initial flux), followed by the initial flux fractions
            of all but the last component (whose fraction is fixed at 1).
            Otherwise, each item holds one component's initial flux.

        Raises
        ------
        ValueError
            If component_configs is empty.
        """
        if len(component_configs) == 0:
            raise ValueError("Must provide at least one ComponentConfig")
        configs = tuple(component_configs.values())
        if is_fractional:
            # Always include the total flux, even for a single component,
            # since make_components requires it as the first item
            values = [configs[0].flux.value_initial]
            values.extend(config.fluxfrac.value_initial for config in configs[:-1])
        else:
            values = [config.flux.value_initial for config in configs]
        return [{channel: value for channel in channels} for value in values]

    def make_components(
        self,
        component_fluxes: list[Fluxes],
        label_integral: str | None = None,
    ) -> tuple[list[g2f.Component], list[g2f.Prior]]:
        """Make a list of gauss2d.fit.Component from this configuration.

        Parameters
        ----------
        component_fluxes
            A list of Fluxes to populate an appropriate
            `gauss2d.fit.IntegralModel` with.
            If self.is_fractional, the first item in the list must be
            total fluxes while the remainder are fractions (the final
            fraction is always fixed at 1.0 and must not be provided).
        label_integral
            A label to apply to integral parameters. Can reference the
            relevant component name with ${name_component}.

        Returns
        -------
        components
            The initialized components, Gaussians first, then Sersics.
        priors
            The priors from all of the components.

        Raises
        ------
        ValueError
            If the number of fluxes does not match the number of components,
            or (if fractional) a component's fractions are not for the same
            channels as the total fluxes.
        """
        component_configs = self.get_component_configs()
        fluxes_first = component_fluxes[0]
        channels = fluxes_first.keys()
        # In fractional mode, the first item holds total fluxes and the final
        # component takes no fractions (fixed at 1), hence the trailing None
        fluxes_all = (component_fluxes[1:] + [None]) if self.is_fractional else component_fluxes
        if len(fluxes_all) != len(component_configs):
            raise ValueError(f"{len(fluxes_all)=} != {len(component_configs)=}")
        priors = []
        idx_final = len(component_configs) - 1
        components = []
        last = None

        centroid_default = None
        for idx, (fluxes_component, (name_component, config_comp)) in enumerate(
            zip(fluxes_all, component_configs.items())
        ):
            label_integral_comp = self.format_label(
                label_integral if label_integral is not None else (
                    config_comp.get_integral_label_default()
                ),
                name_component=name_component,
            )

            if self.is_fractional:
                if idx == 0:
                    # The first component's linear model holds the total fluxes
                    last = config_comp.make_linear_integral_model(
                        fluxes=fluxes_first,
                        label_integral=label_integral_comp,
                    )

                is_final = idx == idx_final
                if is_final:
                    params_frac = [
                        (channel, g2f.ProperFractionParameterD(1.0, fixed=True))
                        for channel in channels
                    ]
                else:
                    if fluxes_component.keys() != channels:
                        raise ValueError(f"{name_component=} {fluxes_component=}")
                    params_frac = [
                        (
                            channel,
                            config_comp.make_fluxfrac_parameter(value=fluxfrac),
                        ) for channel, fluxfrac in fluxes_component.items()
                    ]

                integral_model = g2f.FractionalIntegralModel(
                    params_frac,
                    model=last,
                    is_final=is_final,
                )
                # TODO: Omitting this crucial step should raise but doesn't
                # There shouldn't be two IntegralModels with the same last
                # especially not one is_final and one not
                last = integral_model
            else:
                integral_model = config_comp.make_linear_integral_model(
                    fluxes_component,
                    label_integral=label_integral_comp,
                )

            centroid = self.centroids.get(name_component)
            if centroid is None:
                # Components without their own centroid config share a single
                # default centroid, created on first use
                if centroid_default is None:
                    centroid_default = self.centroids["default"].make_centroid()
                centroid = centroid_default
            componentdata = config_comp.make_component(
                centroid=centroid,
                integral_model=integral_model,
            )
            components.append(componentdata.component)
            priors.extend(componentdata.priors)
        return components, priors

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            If a name is used for both a Gaussian and a Sersic component, or
            a component has no centroid config and no 'default' is specified.
        """
        super().validate()
        errors = []
        components: ComponentConfigs = dict(self.components_gauss)

        for name, component in self.components_sersic.items():
            if name in components:
                errors.append(
                    f"key={name} cannot be used in both self.components_gauss and self.components_sersic"
                )
            components[name] = component

        keys = set(self.centroids.keys())
        has_default = "default" in keys
        for name in components:
            if (name not in keys) and not has_default:
                errors.append(f"component {name=} has no entry in self.centroids and default not specified")
        if errors:
            newline = "\n"
            raise ValueError(f"ComponentGroupConfig has validation errors:\n{newline.join(errors)}")
244 |
245 |
246 | class SourceConfig(pexConfig.Config):
247 | """Configuration for a gauss2d.fit Source.
248 |
249 | Sources may contain components with distinct centroids that may be linked
250 | by a prior (e.g. a galaxy + AGN + star clusters),
251 | although such priors are not yet implemented.
252 | """
253 |
254 | component_groups = pexConfig.ConfigDictField[str, ComponentGroupConfig](
255 | doc="Components in the source",
256 | optional=False,
257 | )
258 |
259 | def _make_components_priors(
260 | self,
261 | component_group_fluxes: list[list[Fluxes]],
262 | label_integral: str,
263 | validate_psf: bool = False,
264 | ) -> [list[g2f.Component], list[g2f.Prior]]:
265 | if len(component_group_fluxes) != len(self.component_groups):
266 | raise ValueError(f"{len(component_group_fluxes)=} != {len(self.component_groups)=}")
267 | components = []
268 | priors = []
269 | if validate_psf:
270 | keys_expected = tuple((g2f.Channel.NONE,))
271 | for component_fluxes, (name_group, component_group) in zip(
272 | component_group_fluxes, self.component_groups.items()
273 | ):
274 | if validate_psf:
275 | for idx, fluxes_comp in enumerate(component_fluxes):
276 | keys = tuple(fluxes_comp.keys())
277 | if keys != keys_expected:
278 | raise ValueError(
279 | f"{name_group=} comp[{idx}] {keys=} != {keys_expected=} with {validate_psf=}"
280 | )
281 |
282 | components_i, priors_i = component_group.make_components(
283 | component_fluxes=component_fluxes,
284 | label_integral=self.format_label(label=label_integral, name_group=name_group),
285 | )
286 | components.extend(components_i)
287 | priors.extend(priors_i)
288 |
289 | return components, priors
290 |
291 | @staticmethod
292 | def format_label(label: str, name_group: str) -> str:
293 | return string.Template(label).safe_substitute(name_group=name_group)
294 |
295 | def get_component_configs(self) -> ComponentConfigs:
296 | has_prefix_group = self.has_prefix_group()
297 | component_configs = {}
298 | for name_group, config_group in self.component_groups.items():
299 | prefix_group = f"{name_group}_" if has_prefix_group else ""
300 | for name_comp, component_config in config_group.get_component_configs().items():
301 | component_configs[f"{prefix_group}{name_comp}"] = component_config
302 | return component_configs
303 |
304 | def get_integral_label_default(self) -> str:
305 | prefix = "mix: ${name_group} " if self.has_prefix_group() else ""
306 | return f"{prefix}{ComponentGroupConfig.get_integral_label_default()}"
307 |
308 | def has_prefix_group(self) -> bool:
309 | return (len(self.component_groups) > 1) or next(iter(self.component_groups.keys()))
310 |
311 | def make_source(
312 | self,
313 | component_group_fluxes: list[list[Fluxes]],
314 | label_integral: str | None = None,
315 | ) -> [g2f.Source, list[g2f.Prior]]:
316 | """Make a gauss2d.fit.Source from this configuration.
317 |
318 | Parameters
319 | ----------
320 | component_group_fluxes
321 | A list of Fluxes for each of the self.component_groups to use
322 | when calling make_components.
323 | label_integral
324 | A label to apply to integral parameters. Can reference the
325 | relevant component mixture name with ${name_group}.
326 |
327 | Returns
328 | -------
329 | source
330 | An appropriate gauss2d.fit.Source.
331 | priors
332 | A list of priors from all constituent components.
333 | """
334 | if label_integral is None:
335 | label_integral = self.get_integral_label_default()
336 | components, priors = self._make_components_priors(
337 | component_group_fluxes=component_group_fluxes,
338 | label_integral=label_integral,
339 | )
340 | source = g2f.Source(components)
341 | return source, priors
342 |
def make_psf_model(
    self,
    component_group_fluxes: list[list[Fluxes]],
    label_integral: str | None = None,
) -> tuple[g2f.PsfModel, list[g2f.Prior]]:
    """Make a gauss2d.fit.PsfModel from this configuration.

    This method will validate that the arguments make a valid PSF model,
    i.e. with a unity total flux, and only one config for the none band.
    The validation is requested by passing ``validate_psf=True`` to
    ``_make_components_priors``.

    Parameters
    ----------
    component_group_fluxes
        A list of Fluxes for each of the self.component_groups to use
        when calling make_components.
    label_integral
        A label to apply to integral parameters. Can reference the
        relevant component mixture name with ${name_group}. Defaults to
        the result of ``get_integral_label_default()`` prefixed by "PSF ".

    Returns
    -------
    psf_model
        An appropriate gauss2d.fit.PsfModel.
    priors
        A list of priors from all constituent components.
    """
    if label_integral is None:
        label_integral = f"PSF {self.get_integral_label_default()}"
    components, priors = self._make_components_priors(
        component_group_fluxes=component_group_fluxes,
        label_integral=label_integral,
        validate_psf=True,
    )
    model = g2f.PsfModel(components=components)

    return model, priors
379 |
def validate(self):
    """Validate this config, additionally requiring a component group.

    Raises
    ------
    ValueError
        Raised if ``self.component_groups`` is empty.
    """
    super().validate()
    if self.component_groups:
        return
    raise ValueError("Must have at least one componentgroup")
384 |
--------------------------------------------------------------------------------
/python/lsst/multiprofit/componentconfig.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see .
21 |
22 | import string
23 |
24 | import lsst.gauss2d.fit as g2f
25 | import lsst.pex.config as pexConfig
26 | import pydantic
27 | from pydantic.dataclasses import dataclass
28 |
29 | from .limits import limits_ref
30 | from .priors import ShapePriorConfig
31 | from .transforms import transforms_ref
32 | from .utils import FrozenArbitraryAllowedConfig
33 |
# Public API: the names exported by ``from lsst.multiprofit.componentconfig import *``.
34 | __all__ = [
35 | "ParameterConfig",
36 | "FluxFractionParameterConfig",
37 | "FluxParameterConfig",
38 | "CentroidConfig",
39 | "ComponentData",
40 | "Fluxes",
41 | "EllipticalComponentConfig",
42 | "GaussianComponentConfig",
43 | "SersicIndexParameterConfig",
44 | "SersicComponentConfig",
45 | ]
46 |
47 |
class ParameterConfig(pexConfig.Config):
    """Configuration for a single fit parameter.

    Holds the parameter's initial value and whether it is held fixed
    (rather than free) during fitting.
    """

    fixed = pexConfig.Field[bool](doc="Whether parameter is fixed or not (free)", default=False)
    value_initial = pexConfig.Field[float](doc="Initial value", default=0)
53 |
54 |
class FluxParameterConfig(ParameterConfig):
    """Configuration for flux parameters (IntegralParameterD).

    Fluxes start at 1.0 rather than the base default of zero: a zero
    initial flux prevents linear fitting from working correctly at first.
    """

    def setDefaults(self):
        """Set the default initial flux value to 1.0."""
        super().setDefaults()
        self.value_initial = 1.0
65 |
66 |
class FluxFractionParameterConfig(ParameterConfig):
    """Configuration for flux fraction parameters (ProperFractionParameterD).

    The safest initial value for a flux fraction is 0.5, because if it's set
    to one, downstream fractions will be zero, while if it's set to zero,
    linear fitting will not work correctly initially.
    """

    def setDefaults(self):
        """Set the default initial fraction to 0.5, per the class docstring."""
        super().setDefaults()
        self.value_initial = 0.5
78 |
79 |
class CentroidConfig(pexConfig.Config):
    """Configuration for the x and y centroid parameters of a component."""

    x = pexConfig.ConfigField[ParameterConfig](doc="The x-axis centroid configuration")
    y = pexConfig.ConfigField[ParameterConfig](doc="The y-axis centroid configuration")

    def make_centroid(self) -> g2f.CentroidParameters:
        """Make centroid parameters from this configuration.

        Returns
        -------
        centroid
            New `gauss2d.fit.CentroidParameters` whose x and y parameters
            take their initial values and fixed flags from ``self.x`` and
            ``self.y``, each with a default-constructed `gauss2d.fit.LimitsD`.
        """
        param_x = g2f.CentroidXParameterD(
            self.x.value_initial, fixed=self.x.fixed, limits=g2f.LimitsD()
        )
        param_y = g2f.CentroidYParameterD(
            self.y.value_initial, fixed=self.y.fixed, limits=g2f.LimitsD()
        )
        return g2f.CentroidParameters(x=param_x, y=param_y)
95 |
96 |
# Immutable bundle returned by the make_component methods of the component
# configs in this module. frozen=True forbids reassigning fields after
# construction; kw_only=True requires all fields to be passed by keyword.
97 | @dataclass(kw_only=True, frozen=True, config=FrozenArbitraryAllowedConfig)
98 | class ComponentData:
99 | """Dataclass for a Component config."""
100 |
101 | component: g2f.Component = pydantic.Field(title="The component instance")
102 | integral_model: g2f.IntegralModel = pydantic.Field(title="The component's integral_model")
103 | priors: list[g2f.Prior] = pydantic.Field(title="The priors associated with the component")  # may be empty
104 |
105 |
# Type alias: initial flux value for each channel, as consumed by
# EllipticalComponentConfig.make_linear_integral_model.
106 | Fluxes = dict[g2f.Channel, float]
107 |
108 |
class EllipticalComponentConfig(ShapePriorConfig):
    """Configuration for an elliptically-symmetric component.

    This class can be initialized but cannot implement make_component;
    subclasses must implement make_component, get_size_label and
    get_type_name.
    """

    fluxfrac = pexConfig.ConfigField[FluxFractionParameterConfig](
        doc="Fractional flux parameter(s) config",
        default=None,
    )
    flux = pexConfig.ConfigField[FluxParameterConfig](
        doc="Flux parameter(s) config",
        default=FluxParameterConfig,
    )

    rho = pexConfig.ConfigField[ParameterConfig](doc="Rho parameter config")
    size_x = pexConfig.ConfigField[ParameterConfig](doc="x-axis size parameter config")
    size_y = pexConfig.ConfigField[ParameterConfig](doc="y-axis size parameter config")
    transform_flux_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux parameters",
        default="log10",
        optional=True,
    )
    transform_fluxfrac_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux fraction parameters",
        default="logit_fluxfrac",
        optional=True,
    )
    transform_rho_name = pexConfig.Field[str](
        doc="The name of the reference transform for rho parameters",
        default="logit_rho",
        optional=True,
    )
    transform_size_name = pexConfig.Field[str](
        doc="The name of the reference transform for size parameters",
        default="log10",
        optional=True,
    )

    def format_label(self, label: str, name_channel: str) -> str:
        """Format a label for a band-dependent parameter.

        Parameters
        ----------
        label
            The label to format. ``${type_component}`` and ``${name_channel}``
            placeholders are substituted; any other placeholders (such as
            ``${name_group}``) are left intact by ``safe_substitute``.
        name_channel
            The name of the channel to format with.

        Returns
        -------
        label_formatted
            The formatted label.
        """
        label_formatted = string.Template(label).safe_substitute(
            type_component=self.get_type_name(), name_channel=name_channel,
        )
        return label_formatted

    @staticmethod
    def get_integral_label_default() -> str:
        """Return the default integral label template."""
        return "${type_component} ${name_channel}-band"

    def get_size_label(self) -> str:
        """Return the label for the component's size parameters."""
        raise NotImplementedError("EllipticalComponent does not implement get_size_label")

    def get_type_name(self) -> str:
        """Return a descriptive component name."""
        raise NotImplementedError("EllipticalComponent does not implement get_type_name")

    @staticmethod
    def _get_transform_ref(name: str | None) -> g2f.TransformD | None:
        """Return the reference transform called ``name``, or None if unset/empty."""
        return transforms_ref[name] if name else None

    def get_transform_fluxfrac(self) -> g2f.TransformD | None:
        """Return the transform for flux fraction parameters, if any."""
        return self._get_transform_ref(self.transform_fluxfrac_name)

    def get_transform_flux(self) -> g2f.TransformD | None:
        """Return the transform for flux parameters, if any."""
        return self._get_transform_ref(self.transform_flux_name)

    def get_transform_rho(self) -> g2f.TransformD | None:
        """Return the transform for rho parameters, if any."""
        return self._get_transform_ref(self.transform_rho_name)

    def get_transform_size(self) -> g2f.TransformD | None:
        """Return the transform for size parameters, if any."""
        return self._get_transform_ref(self.transform_size_name)

    def make_component(
        self,
        centroid: g2f.CentroidParameters,
        integral_model: g2f.IntegralModel,
    ) -> ComponentData:
        """Make a Component reflecting the current configuration.

        Parameters
        ----------
        centroid
            Centroid parameters for the component.
        integral_model
            The integral_model for this component.

        Returns
        -------
        component_data
            An appropriate ComponentData including the initialized component.

        Raises
        ------
        NotImplementedError
            Always, for this base class; subclasses must override.

        Notes
        -----
        The default `gauss2d.fit.LinearIntegralModel` can be populated with
        unit fluxes (`gauss2d.fit.IntegralParameterD` instances) to prepare
        for linear least squares fitting.
        """
        raise NotImplementedError("EllipticalComponent cannot implement make_component")

    def make_gaussianparametricellipse(self) -> g2f.GaussianParametricEllipse:
        """Make a Gaussian ellipse from the size_x, size_y and rho configs.

        Returns
        -------
        ellipse
            A `gauss2d.fit.GaussianParametricEllipse` with sigma_x, sigma_y
            and rho parameters initialized from their configs and using the
            configured size and rho transforms.
        """
        transform_size = self.get_transform_size()
        transform_rho = self.get_transform_rho()
        ellipse = g2f.GaussianParametricEllipse(
            sigma_x=g2f.SigmaXParameterD(
                self.size_x.value_initial, transform=transform_size, fixed=self.size_x.fixed
            ),
            sigma_y=g2f.SigmaYParameterD(
                self.size_y.value_initial, transform=transform_size, fixed=self.size_y.fixed
            ),
            rho=g2f.RhoParameterD(self.rho.value_initial, transform=transform_rho, fixed=self.rho.fixed),
        )
        return ellipse

    def make_fluxfrac_parameter(
        self,
        value: float | None,
        label: str | None = None,
        **kwargs
    ) -> g2f.ProperFractionParameterD:
        """Make a flux fraction parameter from this configuration.

        Parameters
        ----------
        value
            The initial value. If None, ``self.fluxfrac.value_initial`` is used.
        label
            The parameter label. None is replaced by an empty string.
        **kwargs
            Additional keyword arguments for
            `gauss2d.fit.ProperFractionParameterD`. ``fixed``, ``transform``
            and ``label`` are always set here and cannot be passed again.

        Returns
        -------
        parameter
            The new flux fraction parameter.
        """
        parameter = g2f.ProperFractionParameterD(
            value if value is not None else self.fluxfrac.value_initial,
            fixed=self.fluxfrac.fixed,
            transform=self.get_transform_fluxfrac(),
            label=label if label is not None else "",
            **kwargs
        )
        return parameter

    def make_flux_parameter(
        self,
        value: float | None,
        label: str | None = None,
        **kwargs
    ) -> g2f.IntegralParameterD:
        """Make a flux (integral) parameter from this configuration.

        Parameters
        ----------
        value
            The initial value. If None, ``self.flux.value_initial`` is used.
        label
            The parameter label. None is replaced by an empty string.
        **kwargs
            Additional keyword arguments for `gauss2d.fit.IntegralParameterD`.
            ``fixed``, ``transform`` and ``label`` are always set here and
            cannot be passed again.

        Returns
        -------
        parameter
            The new flux parameter.
        """
        parameter = g2f.IntegralParameterD(
            value if value is not None else self.flux.value_initial,
            fixed=self.flux.fixed,
            transform=self.get_transform_flux(),
            label=label if label is not None else "",
            **kwargs
        )
        return parameter

    def make_linear_integral_model(
        self,
        fluxes: Fluxes,
        label_integral: str | None = None,
        **kwargs
    ) -> g2f.IntegralModel:
        """Make a gauss2d.fit.LinearIntegralModel for this component.

        Parameters
        ----------
        fluxes
            Initial values for the flux parameters by channel. A None value
            falls back to ``self.flux.value_initial``.
        label_integral
            A label template to apply to integral parameters. Can reference
            the relevant channel with ``${name_channel}`` and the component
            type with ``${type_component}``. Defaults to the result of
            ``get_integral_label_default()``.
        **kwargs
            Additional keyword arguments to pass to make_flux_parameter.
            Some parameters cannot be overridden from their configs.

        Returns
        -------
        integral_model
            The requested gauss2d.fit.IntegralModel.
        """
        if label_integral is None:
            label_integral = self.get_integral_label_default()
        integral_model = g2f.LinearIntegralModel(
            [
                (
                    channel,
                    self.make_flux_parameter(
                        flux,
                        label=self.format_label(label_integral, name_channel=channel.name),
                        **kwargs,
                    ),
                )
                for channel, flux in fluxes.items()
            ]
        )
        return integral_model

    @staticmethod
    def set_size_x(component: g2f.EllipticalComponent, size_x: float) -> None:
        """Set ``component.ellipse.sigma_x`` to ``size_x``."""
        component.ellipse.sigma_x = size_x

    @staticmethod
    def set_size_y(component: g2f.EllipticalComponent, size_y: float) -> None:
        """Set ``component.ellipse.sigma_y`` to ``size_y``."""
        component.ellipse.sigma_y = size_y

    @staticmethod
    def set_rho(component: g2f.EllipticalComponent, rho: float) -> None:
        """Set ``component.ellipse.rho`` to ``rho``."""
        component.ellipse.rho = rho
317 |
318 |
class GaussianComponentConfig(EllipticalComponentConfig):
    """Configuration for a gauss2d.fit Gaussian component."""

    # NOTE(review): this field is not read within this module; flux fraction
    # transforms are looked up via ``transform_fluxfrac_name`` instead.
    transform_frac_name = pexConfig.Field[str](
        doc="The name of the reference transform for flux fraction parameters",
        default="log10",
        optional=True,
    )

    def get_size_label(self) -> str:
        """Return "sigma", the label for Gaussian size parameters."""
        return "sigma"

    def get_type_name(self) -> str:
        """Return "Gaussian", the descriptive component name."""
        return "Gaussian"

    def make_component(
        self,
        centroid: g2f.CentroidParameters,
        integral_model: g2f.IntegralModel,
    ) -> ComponentData:
        """Make a Gaussian component reflecting the current configuration.

        Parameters
        ----------
        centroid
            Centroid parameters for the component.
        integral_model
            The integral_model for this component.

        Returns
        -------
        component_data
            ComponentData holding a `gauss2d.fit.GaussianComponent`, the
            integral model and the shape prior (if one is configured).
        """
        ellipse = self.make_gaussianparametricellipse()
        shape_prior = self.get_shape_prior(ellipse)
        priors = [shape_prior] if shape_prior is not None else []
        gaussian = g2f.GaussianComponent(
            centroid=centroid,
            ellipse=ellipse,
            integral=integral_model,
        )
        return ComponentData(component=gaussian, integral_model=integral_model, priors=priors)
351 |
352 |
class SersicIndexParameterConfig(ParameterConfig):
    """Configuration for a gauss2d.fit Sersic index parameter."""

    prior_mean = pexConfig.Field[float](doc="Mean for the prior (untransformed)", default=1.0, optional=True)
    prior_stddev = pexConfig.Field[float](doc="Std. dev. for the prior", default=0.5, optional=True)
    prior_transformed = pexConfig.Field[bool](
        doc="Whether the prior should be in transformed values", default=True,
    )

    def get_prior(self, param: g2f.SersicIndexParameterD) -> g2f.Prior | None:
        """Make a Gaussian prior for a Sersic index parameter.

        Parameters
        ----------
        param
            The Sersic index parameter. Its ``transform`` is used when
            ``self.prior_transformed`` is True.

        Returns
        -------
        prior
            A `gauss2d.fit.GaussianPrior`, or None if ``self.prior_mean``
            is None.

        Notes
        -----
        For transformed priors, the mean is the transformed ``prior_mean`` and
        the standard deviation is the transformed width of the interval
        ``[prior_mean - prior_stddev/2, prior_mean + prior_stddev/2]``.
        """
        if self.prior_mean is None:
            return None
        if self.prior_transformed:
            mean = param.transform.forward(self.prior_mean)
            half_width = self.prior_stddev/2.
            stddev = (
                param.transform.forward(self.prior_mean + half_width)
                - param.transform.forward(self.prior_mean - half_width)
            )
        else:
            mean = self.prior_mean
            stddev = self.prior_stddev
        return g2f.GaussianPrior(
            param=param, mean=mean, stddev=stddev, transformed=self.prior_transformed,
        )

    def setDefaults(self):
        """Set the default initial Sersic index to 0.5."""
        super().setDefaults()
        self.value_initial = 0.5

    def validate(self):
        """Validate the config.

        Raises
        ------
        ValueError
            Raised if a prior mean is set but it is not positive, or if the
            prior std. dev. is unset or not positive.
        """
        super().validate()
        if self.prior_mean is not None:
            if not self.prior_mean > 0.:
                raise ValueError("Sersic index prior mean must be > 0")
            # prior_stddev is optional, so guard against None before comparing
            if self.prior_stddev is None or not self.prior_stddev > 0.:
                raise ValueError("Sersic index prior std. dev. must be > 0")
384 |
385 |
class SersicComponentConfig(EllipticalComponentConfig):
    """Configuration for a gauss2d.fit Sersic component.

    Notes
    -----
    make_component will return a `gauss2d.fit.GaussianComponent` if the Sersic
    index is fixed at 0.5, or a `gauss2d.fit.SersicMixComponent` otherwise.
    """

    # Class-level cache of interpolators by order, shared by all instances.
    _interpolators: dict[int, g2f.SersicMixInterpolator] = {}

    order = pexConfig.ChoiceField[int](doc="Sersic mix order", allowed={4: "Four", 8: "Eight"}, default=4)
    sersic_index = pexConfig.ConfigField[SersicIndexParameterConfig](doc="Sersic index config")

    def get_interpolator(self, order: int) -> g2f.SersicMixInterpolator:
        """Return a Sersic mix interpolator of the given order.

        The interpolator is constructed on first request for an order and
        cached in ``_interpolators`` for reuse. A GSL interpolator is used if
        available in gauss2d.fit, otherwise a linear one.

        Parameters
        ----------
        order
            The order of the Sersic mixture.

        Returns
        -------
        interpolator
            The (cached) interpolator for ``order``.
        """
        interpolator = self._interpolators.get(order)
        if interpolator is None:
            type_interpolator = (
                g2f.GSLSersicMixInterpolator
                if hasattr(g2f, "GSLSersicMixInterpolator")
                else g2f.LinearSersicMixInterpolator
            )
            interpolator = type_interpolator(order=order)
            self._interpolators[order] = interpolator
        return interpolator

    def get_size_label(self) -> str:
        """Return "reff", the label for Sersic size parameters."""
        return "reff"

    def get_type_name(self) -> str:
        """Return a descriptive name, noting fixed-Gaussian Sersic components."""
        return "Gaussian (fixed Sersic)" if self.is_gaussian_fixed() else "Sersic"

    def is_gaussian_fixed(self) -> bool:
        """Return whether the Sersic index is fixed at 0.5 (a Gaussian)."""
        return self.sersic_index.value_initial == 0.5 and self.sersic_index.fixed

    def make_component(
        self,
        centroid: g2f.CentroidParameters,
        integral_model: g2f.IntegralModel,
    ) -> ComponentData:
        """Make a component reflecting the current configuration.

        Parameters
        ----------
        centroid
            Centroid parameters for the component.
        integral_model
            The integral_model for this component.

        Returns
        -------
        component_data
            ComponentData holding a `gauss2d.fit.GaussianComponent` if the
            Sersic index is fixed at 0.5, or a `gauss2d.fit.SersicMixComponent`
            otherwise. Priors include the Sersic index prior (only for a free
            index) and the shape prior, each when configured.
        """
        if self.is_gaussian_fixed():
            ellipse = self.make_gaussianparametricellipse()
            component = g2f.GaussianComponent(
                centroid=centroid,
                ellipse=ellipse,
                integral=integral_model,
            )
            priors = []
        else:
            transform_size = self.get_transform_size()
            transform_rho = self.get_transform_rho()
            ellipse = g2f.SersicParametricEllipse(
                size_x=g2f.ReffXParameterD(
                    self.size_x.value_initial, transform=transform_size, fixed=self.size_x.fixed
                ),
                size_y=g2f.ReffYParameterD(
                    self.size_y.value_initial, transform=transform_size, fixed=self.size_y.fixed
                ),
                rho=g2f.RhoParameterD(self.rho.value_initial, transform=transform_rho, fixed=self.rho.fixed),
            )
            sersic_index = g2f.SersicMixComponentIndexParameterD(
                value=self.sersic_index.value_initial,
                fixed=self.sersic_index.fixed,
                transform=transforms_ref["logit_sersic"] if not self.sersic_index.fixed else None,
                interpolator=self.get_interpolator(order=self.order),
                limits=limits_ref["n_ser_multigauss"],
            )
            component = g2f.SersicMixComponent(
                centroid=centroid,
                ellipse=ellipse,
                integral=integral_model,
                sersicindex=sersic_index,
            )
            # A fixed index needs no prior
            prior_sersic = None if sersic_index.fixed else self.sersic_index.get_prior(sersic_index)
            priors = [] if prior_sersic is None else [prior_sersic]
        prior_shape = self.get_shape_prior(ellipse)
        if prior_shape is not None:
            priors.append(prior_shape)
        return ComponentData(
            component=component,
            integral_model=integral_model,
            priors=priors,
        )

    def validate(self):
        """Validate the config (no checks beyond the base class)."""
        super().validate()
472 |
--------------------------------------------------------------------------------
/examples/fithsc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | # # Fitting HSC data in multiband mode using MultiProFit
5 |
6 | # In[1]:
7 |
8 |
9 | # Import required packages
10 | import time
11 | from typing import Any, Iterable, Mapping
12 |
13 | from astropy.coordinates import SkyCoord
14 | from astropy.io.ascii import Csv
15 | import astropy.io.fits as fits
16 | import astropy.table as apTab
17 | import astropy.visualization as apVis
18 | from astropy.wcs import WCS
19 | import lsst.gauss2d as g2
20 | import lsst.gauss2d.fit as g2f
21 | from lsst.multiprofit.componentconfig import SersicComponentConfig, SersicIndexParameterConfig
22 | from lsst.multiprofit.fit_psf import (
23 | CatalogExposurePsfABC,
24 | CatalogPsfFitter,
25 | CatalogPsfFitterConfig,
26 | CatalogPsfFitterConfigData,
27 | )
28 | from lsst.multiprofit.fit_source import (
29 | CatalogExposureSourcesABC,
30 | CatalogSourceFitterABC,
31 | CatalogSourceFitterConfig,
32 | CatalogSourceFitterConfigData,
33 | )
34 | from lsst.multiprofit.modelconfig import ModelConfig
35 | from lsst.multiprofit.plots import plot_model_rgb
36 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
37 | from lsst.multiprofit.utils import ArbitraryAllowedConfig, get_params_uniq
38 | import matplotlib as mpl
39 | import matplotlib.pyplot as plt
40 | import numpy as np
41 | import pydantic
42 | from pydantic.dataclasses import dataclass
43 |
# In[2]:


# Define settings
band_ref = 'i'
# Band weights, also used to scale each band in the RGB image below
bands = {'i': 0.87108833, 'r': 0.97288654, 'g': 1.44564678}
band_multi = ''.join(bands)
channels = {band: g2f.Channel.get(band) for band in bands}

# This is in the WCS, but may as well keep full precision
scale_pixel_hsc = 0.168

# Common to all FITS
hdu_img, hdu_mask, hdu_var = 1, 2, 3

# Masks (each name appears once; a repeated name would only OR the same bit twice)
bad_masks = (
    'BAD', 'SAT', 'INTRP', 'CR', 'EDGE', 'CLIPPED', 'NO_DATA', 'CROSSTALK',
    'UNMASKEDNAN', 'SUSPECT', 'REJECTED', 'SENSOR_EDGE',
)
maskbits = tuple(f'MP_{b}' for b in bad_masks)

# A pre-defined bitmask to exclude regions with low SN
read_mask_highsn = True
write_mask_highsn = False

# matplotlib settings
mpl.rcParams['image.origin'] = 'lower'
mpl.rcParams['figure.dpi'] = 120
73 |
74 |
75 | # In[3]:
76 |
77 |
78 | # Define source to fit
79 | id_gama, z = 79635, 0.0403  # NOTE(review): id_gama and z are not used later in this script as shown
80 | """
81 | Acquired from https://hsc-release.mtk.nao.ac.jp/datasearch/catalog_jobs with query:
82 |
83 | SELECT object_id, ra, dec,
84 | g_cmodel_mag, g_cmodel_magerr, r_cmodel_mag, r_cmodel_magerr, i_cmodel_mag, i_cmodel_magerr,
85 | g_psfflux_mag, g_psfflux_magerr, r_psfflux_mag, r_psfflux_magerr, i_psfflux_mag, i_psfflux_magerr,
86 | g_kronflux_mag, g_kronflux_magerr, r_kronflux_mag, r_kronflux_magerr, i_kronflux_mag, i_kronflux_magerr,
87 | g_sdssshape_shape11, g_sdssshape_shape11err, g_sdssshape_shape22, g_sdssshape_shape22err,
88 | g_sdssshape_shape12, g_sdssshape_shape12err,
89 | r_sdssshape_shape11, r_sdssshape_shape11err, r_sdssshape_shape22, r_sdssshape_shape22err,
90 | r_sdssshape_shape12, r_sdssshape_shape12err,
91 | i_sdssshape_shape11, i_sdssshape_shape11err, i_sdssshape_shape22, i_sdssshape_shape22err,
92 | i_sdssshape_shape12, i_sdssshape_shape12err
93 | FROM pdr3_wide.forced
94 | LEFT JOIN pdr3_wide.forced2 USING (object_id)
95 | WHERE isprimary AND conesearch(coord, 222.51551376, 0.09749601, 35.64)
96 | AND (r_kronflux_mag < 26 OR i_kronflux_mag < 26) AND NOT i_kronflux_flag AND NOT r_kronflux_flag;
97 | """
98 | cat = Csv()  # astropy CSV reader for the query results above
99 | cat.header.splitter.escapechar = '#'  # '#' is the escape character for the header splitter
100 | cat = cat.read('fithsc_src.csv')  # Relative path: resolved against the current working directory
101 |
102 | prefix = '222.51551376,0.09749601_'
103 | prefix_img = f'{prefix}300x300_'
104 |
105 | # Read data, acquired with:
106 | # https://github.com/taranu/astro_imaging/blob/4d5a8e095e6a3944f1fbc19318b1dbc22b22d9ca/examples/HSC.ipynb
107 | # (with get_mask=True, get_variance=True,)
108 | images, psfs = {}, {}
109 | for band in bands:  # The FITS files stay open; later cells read their HDUs
110 | images[band] = fits.open(f'{prefix_img}{band}.fits')
111 | psfs[band] = fits.open(f'{prefix}{band}_psf.fits')
112 |
113 | wcs = WCS(images[band_ref][hdu_img])  # WCS of the reference-band image
114 | cat['x'], cat['y'] = wcs.world_to_pixel(SkyCoord(cat['ra'], cat['dec'], unit='deg'))  # Pixel coordinates per catalog row
115 |
116 |
# In[4]:


# Plot image
# RGB channels follow the order of `bands` (i, r, g -> R, G, B), each scaled
# by its band weight; use hdu_img rather than a literal HDU index.
img_rgb = apVis.make_lupton_rgb(*[img[hdu_img].data*bands[band] for band, img in images.items()])
plt.scatter(cat['x'], cat['y'], s=10, c='g', marker='x')
plt.imshow(img_rgb)
plt.title("gri image with detected objects")
plt.show()
126 |
127 |
128 | # In[5]:
129 |
130 |
131 | # Generate a rough mask around other sources
132 | bright = (cat['i_cmodel_mag'] < 23) | (cat['i_psfflux_mag'] < 23)  # i-band cmodel or PSF mag brighter than 23
133 |
134 | img_ref = images[band_ref][hdu_img].data
135 |
136 | mask_inverse = np.ones(img_ref.shape, dtype=bool)  # True where pixels are kept; zeroed around other sources below
137 | y_cen, x_cen = (x/2. for x in img_ref.shape)  # Image centre, where the target galaxy is assumed to lie
138 | y, x = np.indices(img_ref.shape)  # Pixel coordinate grids for distance computations
139 |
140 | idx_src_main, row_main = None, None
141 |
142 | sizes_override = {  # Manual mask radii (pixels) by object_id, replacing the shape-based radius
143 | 42305088563206480: 8.,
144 | }
145 |
146 | for src in cat[bright]:
147 | id_src, x_src, y_src = (src[col] for col in ['object_id', 'x', 'y'])
148 | dist = np.hypot(x_src - x_cen, y_src - y_cen)
149 | if dist > 20:  # Sources more than 20 pixels from the centre are masked
150 | dists = np.hypot(y - y_src, x - x_src)
151 | mag = np.nanmin([src['i_cmodel_mag'], src['r_cmodel_mag'], src['i_psfflux_mag'], src['r_psfflux_mag']])
152 | if (radius_mask := sizes_override.get(id_src)) is None:  # No override: twice the sdssshape trace radius
153 | radius_mask = 2*np.sqrt(
154 | src[f'{band_ref}_sdssshape_shape11'] + src[f'{band_ref}_sdssshape_shape22']
155 | )/scale_pixel_hsc  # presumably converts arcsec to pixels - TODO confirm shape units
156 | if (radius_mask > 10) and (mag > 21):  # Cap the radius for large but faint sources
157 | radius_mask = 5
158 | mask_inverse[dists < radius_mask] = 0
159 | print(f'Masking src=({id_src} at {x_src:.3f}, {y_src:.3f}) dist={dist:.3f}'
160 | f', mag={mag:.3f}, radius_mask={radius_mask:.3f}')
161 | elif dist < 2:  # The source within 2 pixels of the centre is the one to fit
162 | idx_src_main = id_src
163 | row_main = src
164 | print(f"{idx_src_main=} {dict(src)=}")
165 |
166 | tab_row_main = apTab.Table(row_main)  # NOTE(review): row_main stays None if no source lies within 2 pixels of the centre
167 |
168 | if read_mask_highsn:  # Optionally also exclude the pre-computed low-S/N regions
169 | mask_highsn = np.load(f'{prefix_img}mask_inv_highsn.npz')['mask_inv']
170 | mask_inverse *= mask_highsn
171 |
172 | plt.imshow(mask_inverse)
173 | plt.title("Fitting mask")
174 | plt.show()
175 |
176 |
# In[6]:


# Fit PSF
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogExposurePsf(CatalogExposurePsfABC):
    """A CatalogExposurePsfABC returning the same PSF image for every source."""

    catalog: apTab.Table = pydantic.Field(title="The detected object catalog")
    img: np.ndarray = pydantic.Field(title="The PSF image")

    def get_catalog(self) -> Iterable:
        """Return the catalog of sources to fit PSFs for."""
        return self.catalog

    def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.ndarray:
        """Return the PSF image; ``source`` is ignored as one image is used for all."""
        return self.img
191 |
192 | config_psf = CatalogPsfFitterConfig(column_id='object_id')
193 | fitter_psf = CatalogPsfFitter()
194 | catalog_psf = apTab.Table({'object_id': [tab_row_main['object_id']]})  # Catalog with only the main source's object_id
195 | results_psf = {}
196 |
197 | # Keep a separate config_data_psf per band, because each one caches its PSF model;
198 | # those must not be shared!
199 | config_data_psfs = {}
200 | for band, psf_file in psfs.items():
201 | config_data_psf = CatalogPsfFitterConfigData(config=config_psf)
202 | catexp = CatalogExposurePsf(catalog=catalog_psf, img=psf_file[0].data)  # PSF image from the primary HDU (index 0)
203 | t_start = time.time()
204 | result = fitter_psf.fit(config_data=config_data_psf, catexp=catexp)
205 | t_end = time.time()
206 | results_psf[band] = result  # Later used as table_psf_fits for this band
207 | config_data_psfs[band] = config_data_psf
208 | print(f"Fit {band}-band PSF in {t_end - t_start:.2e}s; result:")
209 | print(dict(result[0]))
210 |
211 |
212 | # In[7]:
213 |
214 |
215 | # Set fit configs
216 | config_source = CatalogSourceFitterConfig(  # One source ("src"): a fixed-index disk plus a fixed-index bulge
217 | column_id='object_id',
218 | config_model=ModelConfig(
219 | sources={
220 | "src": SourceConfig(
221 | component_groups={
222 | "": ComponentGroupConfig(  # A single unnamed group, so integral labels get no group prefix
223 | components_sersic={
224 | 'disk': SersicComponentConfig(
225 | sersic_index=SersicIndexParameterConfig(value_initial=1., fixed=True),  # Exponential (n=1), fixed
226 | prior_size_stddev=0.5,
227 | prior_axrat_stddev=0.2,
228 | ),
229 | 'bulge': SersicComponentConfig(
230 | sersic_index=SersicIndexParameterConfig(value_initial=4., fixed=True),  # de Vaucouleurs (n=4), fixed
231 | prior_size_stddev=0.1,
232 | prior_axrat_stddev=0.2,
233 | ),
234 | },
235 | ),
236 | }
237 | ),
238 | },
239 | ),
240 | )
241 | config_data_source = CatalogSourceFitterConfigData(  # Config data for a joint fit of all channels
242 | channels=list(channels.values()),
243 | config=config_source,
244 | )
245 |
246 |
# In[8]:


# Setup exposure with band-specific image, mask and variance
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogExposureSources(CatalogExposureSourcesABC):
    """A single-band exposure: one observation plus its PSF fit results."""

    config_data_psf: CatalogPsfFitterConfigData = pydantic.Field(title="The PSF fit config")
    observation: g2f.ObservationD = pydantic.Field(title="The observation to fit")
    table_psf_fits: apTab.Table = pydantic.Field(title="The table of PSF fit parameters")

    def get_catalog(self) -> Iterable:
        """Return the table of PSF fit parameters as this exposure's catalog."""
        catalog = self.table_psf_fits
        return catalog

    def get_source_observation(self, source: Mapping[str, Any]) -> g2f.ObservationD:
        """Return the observation; ``source`` is unused as one image covers all."""
        return self.observation

    def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:
        """Initialize the cached PSF model from ``params`` and return it."""
        config_data_psf = self.config_data_psf
        config_data_psf.init_psf_model(params)
        return config_data_psf.psf_model

    @property
    def channel(self) -> g2f.Channel:
        """The channel of this exposure's observation."""
        return self.observation.channel
270 |
271 |
@dataclass(frozen=True, config=ArbitraryAllowedConfig)
class CatalogSourceFitter(CatalogSourceFitterABC):
    """A source fitter initializing models from HSC catalog measurements.

    Centroids are initialized from the catalog pixel coordinates, and sizes
    and shape prior means from the ``self.band`` sdssshape moments.
    """

    band: str = pydantic.Field(title="The reference band for initialization and priors")
    scale_pixel: float = pydantic.Field(title="The pixel scale in arcsec")
    wcs_ref: WCS = pydantic.Field(title="The WCS for the coadded image")

    def initialize_model(
        self,
        model: g2f.ModelD,
        source: Mapping[str, Any],
        catexps: list[CatalogExposureSourcesABC],
        values_init: Mapping[g2f.ParameterD, float] | None = None,
        centroid_pixel_offset: float = 0,
        **kwargs
    ):
        """Set initial values and limits of the model's free parameters.

        Parameters
        ----------
        model
            The model to initialize.
        source
            The catalog row for the source, with ``x``, ``y`` and
            ``{band}_sdssshape_shape11/22/12`` columns.
        catexps
            The per-band exposures (not used here).
        values_init
            Initial values for free parameters whose type has no default
            entry in ``params_limits_init`` below.
        centroid_pixel_offset
            Not used here.
        **kwargs
            Not used here.
        """
        if values_init is None:
            values_init = {}
        x, y = source['x'], source['y']
        # Use the configured reference band, not a module-level loop variable
        band = self.band
        # Rescale the moments to pixel units (presumably from arcsec^2)
        scale_sq = self.scale_pixel**(-2)
        ellipse = g2.Ellipse(g2.Covariance(
            sigma_x_sq=source[f'{band}_sdssshape_shape11']*scale_sq,
            sigma_y_sq=source[f'{band}_sdssshape_shape22']*scale_sq,
            cov_xy=source[f'{band}_sdssshape_shape12']*scale_sq,
        ))
        size_major = g2.EllipseMajor(ellipse).r_major
        # An R_eff larger than the box size is problematic
        # Also should stop unreasonable size proposals; log10 transform isn't enough
        # TODO: Try logit for r_eff?
        limits_size = g2f.LimitsD(1e-5, np.sqrt(x*x + y*y))
        params_limits_init = {
            # Should set limits based on image size, but this shortcut is fine
            # for this particular object
            g2f.CentroidXParameterD: (x, g2f.LimitsD(0, 2*x)),
            g2f.CentroidYParameterD: (y, g2f.LimitsD(0, 2*y)),
            g2f.ReffXParameterD: (ellipse.sigma_x, limits_size),
            g2f.ReffYParameterD: (ellipse.sigma_y, limits_size),
            # There is a sign convention difference
            g2f.RhoParameterD: (-ellipse.rho, None),
            g2f.IntegralParameterD: (1.0, g2f.LimitsD(1e-10, 1e10)),
        }
        params_free = get_params_uniq(model, fixed=False)
        for param in params_free:
            value_init, limits_new = params_limits_init.get(
                type(param),
                (values_init.get(param), None)
            )
            if value_init is not None:
                param.value = value_init
            if limits_new:
                # For slightly arcane reasons, we must set a new limits object
                # Changing limits values is unreliable
                param.limits = limits_new
        for prior in model.priors:
            if isinstance(prior, g2f.ShapePrior):
                prior.prior_size.mean_parameter.value = size_major

    def validate_fit_inputs(
        self,
        catalog_multi,
        catexps: list[CatalogExposureSourcesABC],
        config_data: CatalogSourceFitterConfigData | None = None,
        logger=None,
        **kwargs: Any,
    ) -> None:
        """Validate fit inputs by deferring entirely to the base class."""
        super().validate_fit_inputs(
            catalog_multi=catalog_multi, catexps=catexps, config_data=config_data,
            logger=logger, **kwargs
        )
342 |
343 |
# In[9]:


# Set up Fitter, Observations and CatalogExposureSources
# Use the reference band explicitly; `band` here would be a leftover loop variable
fitter = CatalogSourceFitter(band=band_ref, scale_pixel=scale_pixel_hsc, wcs_ref=wcs)

observations = {}
catexps = {}

for band in bands:
    data = images[band]
    # There are better ways to use bitmasks, but this will do
    header = data[hdu_mask].header
    bitmask = data[hdu_mask].data
    mask = np.zeros_like(bitmask, dtype='bool')
    for bit in maskbits:
        mask |= ((bitmask & 2**header[bit]) != 0)

    # Keep pixels with no bad mask bits that are also in the fitting mask
    mask = (mask == 0) & mask_inverse
    sigma_inv = 1.0/np.sqrt(data[hdu_var].data)
    sigma_inv[mask != 1] = 0

    observation = g2f.ObservationD(
        image=g2.ImageD(data[hdu_img].data),
        sigma_inv=g2.ImageD(sigma_inv),
        mask_inv=g2.ImageB(mask),
        channel=g2f.Channel.get(band),
    )
    observations[band] = observation
    catexps[band] = CatalogExposureSources(
        config_data_psf=config_data_psfs[band],
        observation=observation,
        table_psf_fits=results_psf[band],
    )
378 |
379 |
380 | # In[10]:
381 |
382 |
383 | # Now do the multi-band fit
384 | t_start = time.time()
385 | result_multi = fitter.fit(  # Joint fit of all bands; result_multi is reused below to set best-fit params
386 | catalog_multi=tab_row_main,
387 | catexps=list(catexps.values()),
388 | config_data=config_data_source,
389 | )
390 | t_end = time.time()
391 | print(f"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2e}s; result:")
392 | print(dict(result_multi[0]))  # The first result row
393 |
394 |
# In[11]:


# Fit in each band separately
# Iterate band names only: the values of `bands` are RGB weights, not
# observations, and binding them to `observation` would clobber that global.
results = {}
for band in bands:
    config_data_source_band = CatalogSourceFitterConfigData(
        channels=[channels[band]],
        config=config_source,
    )
    t_start = time.time()
    result = fitter.fit(
        catalog_multi=tab_row_main,
        catexps=[catexps[band]],
        config_data=config_data_source_band,
    )
    t_end = time.time()
    results[band] = result
    print(f"Fit {band}-band bulge-disk model in {t_end - t_start:.2f}s; result:")
    print(dict(result[0]))
415 |
416 |
417 | # In[12]:
418 |
419 |
# Make a model for the best-fit params
data, psf_models = config_source.make_model_data(idx_row=0, catexps=list(catexps.values()))
# Access sources_priors only once so that the priors are guaranteed to
# reference the same sources that are passed to the model
sources_priors = config_data_source.sources_priors
sources, priors = sources_priors[0], sources_priors[1]
model = g2f.ModelD(data=data, psfmodels=psf_models, sources=sources, priors=priors)
params = get_params_uniq(model, fixed=False)
result_multi_row = dict(result_multi[0])
# This is the last column before fit params
idx_last = list(result_multi_row).index('mpf_unknown_flag')
# Set params to best fit values (assumes the remaining columns are in the
# same order as the free params)
for param, value in zip(params, list(result_multi_row.values())[idx_last + 1:]):
    param.value = value
model.setup_evaluators(g2f.EvaluatorMode.loglike_image)
# Print the loglikelihoods, which are from the data and end with the (sum of all) priors
loglikes = model.evaluate()
print(f"{loglikes=}")
434 |
435 |
436 | # ### Multiband Residuals
437 | #
438 | # What's with the structure in the residuals? Most broadly, a point source + exponential disk + deVauc bulge model is totally inadequate for this galaxy for several possible reasons:
439 | #
440 | # 1. The disk isn't exactly exponential (n=1)
441 | # 2. The disk has colour gradients not accounted for in this model*
442 | # 3. If the galaxy even has a bulge, it's very weak and definitely not a de Vaucouleurs (n=4) profile; it may be an exponential "pseudobulge"
443 | #
444 | # \*MultiProFit can do more general Gaussian mixture models (linear or non-linear), which may be explored in a future iteration of this notebook, but these generally do not improve the accuracy of photometry for smaller/fainter galaxies.
445 | #
446 | # Note that the two scalings of the residual plots (98%ile and +/- 20 sigma) end up looking very similar.
447 | #
448 |
449 | # In[13]:
450 |
451 |
# Make some basic plots; when requested, plot_model_rgb also returns a
# high-SN inverse mask computed from the model
_, _, _, _, mask_inv_highsn = plot_model_rgb(
    model, weights=bands, high_sn_threshold=0.2 if write_mask_highsn else None,
)
plt.show()

# Write the newly-computed high SN mask to a compressed file.
# It is saved unpacked (full image shape) so that it can be read back directly
# with np.load(...)['mask_inv'] and multiplied into the fitting mask.
if write_mask_highsn:
    plt.figure()
    plt.imshow(mask_inv_highsn, cmap='gray')
    plt.show()
    np.savez_compressed(f'{prefix_img}mask_inv_highsn.npz', mask_inv=mask_inv_highsn)

# TODO: Some features still missing from plot_model_rgb
# residual histograms, param values, better labels, etc
468 |
469 |
470 | # ### More exercises for the reader
471 | #
472 | # These are of the sort that the author hasn't gotten around to yet because they're far from trivial. Try:
473 | #
474 | # 0. Use the WCS to compute ra, dec and errors thereof.
475 | # Hint: override CatalogSourceFitter.get_model_radec
476 | #
477 | # 1. Replace the real data with simulated data.
478 | # Make new observations using model.evaluate and add noise based on the variance maps.
479 | # Try fitting again and see how well results converge depending on the initialization scheme.
480 | #
481 | # 2. Fit every other source individually.
482 | # Try subtracting the best-fit galaxy model from above first.
483 | # Hint: get_source_observation should be redefined to return a smaller postage stamp around the nominal centroid.
484 | # Pass the full catalog (excluding the central galaxy) to catalog_multi.
485 | #
486 | # 3. Fit all sources simultaneously.
487 | # Redefine CatalogFitterConfig.make_model_data to make a model with multiple sources, using the catexp catalogs
488 | # initialize_model will no longer need to do anything
489 | # catalog_multi should still be a single row
490 |
--------------------------------------------------------------------------------
/tests/test_modeller.py:
--------------------------------------------------------------------------------
1 | # This file is part of multiprofit.
2 | #
3 | # Developed for the LSST Data Management System.
4 | # This product includes software developed by the LSST Project
5 | # (https://www.lsst.org).
6 | # See the COPYRIGHT file at the top-level directory of this distribution
7 | # for details of code ownership.
8 | #
9 | # This program is free software: you can redistribute it and/or modify
10 | # it under the terms of the GNU General Public License as published by
11 | # the Free Software Foundation, either version 3 of the License, or
12 | # (at your option) any later version.
13 | #
14 | # This program is distributed in the hope that it will be useful,
15 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 | # GNU General Public License for more details.
18 | #
19 | # You should have received a copy of the GNU General Public License
20 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
21 |
22 | import math
23 | import time
24 |
25 | import lsst.gauss2d as g2
26 | import lsst.gauss2d.fit as g2f
27 | from lsst.multiprofit.componentconfig import (
28 | CentroidConfig,
29 | FluxFractionParameterConfig,
30 | FluxParameterConfig,
31 | GaussianComponentConfig,
32 | ParameterConfig,
33 | SersicComponentConfig,
34 | SersicIndexParameterConfig,
35 | )
36 | from lsst.multiprofit.model_utils import make_image_gaussians, make_psf_model_null
37 | from lsst.multiprofit.modelconfig import ModelConfig
38 | from lsst.multiprofit.modeller import FitInputs, LinearGaussians, Modeller, fitmethods_linear
39 | from lsst.multiprofit.observationconfig import CoordinateSystemConfig, ObservationConfig
40 | from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig
41 | from lsst.multiprofit.utils import get_params_uniq
42 | import numpy as np
43 | import pytest
44 |
# Inverse sigma assigned to every pixel of the (noiseless) test observations
sigma_inv = 10_000.0
46 |
47 |
@pytest.fixture(scope="module")
def channels() -> dict[str, g2f.Channel]:
    """Return the R, G and B channels, keyed by band name, in that order."""
    channels_by_band = {}
    for band in ("R", "G", "B"):
        channels_by_band[band] = g2f.Channel.get(band)
    return channels_by_band
51 |
52 |
@pytest.fixture(scope="module")
def data(channels) -> g2f.DataD:
    """Make blank observations, one per channel, each with a different shape
    and coordinate system origin.

    Every pixel has a zero image value, an inverse mask of 1 and the
    module-level ``sigma_inv`` as its inverse sigma.
    """
    # Base (n_rows, n_cols) and (x_min, y_min), and their per-channel increments
    shape_base, shape_step = (25, 27), (2, -3)
    origin_base, origin_step = (0, 0), (-1, 1)

    def make_blank(offset: int, band: str) -> g2f.ObservationD:
        observation = ObservationConfig(
            band=band,
            coordsys=CoordinateSystemConfig(
                x_min=origin_base[0] + offset*origin_step[0],
                y_min=origin_base[1] + offset*origin_step[1],
            ),
            n_rows=shape_base[0] + offset*shape_step[0],
            n_cols=shape_base[1] + offset*shape_step[1],
        ).make_observation()
        observation.image.fill(0)
        observation.sigma_inv.fill(sigma_inv)
        observation.mask_inv.fill(1)
        return observation

    return g2f.DataD([make_blank(offset, band) for offset, band in enumerate(channels)])
78 |
79 |
@pytest.fixture(scope="module")
def psf_models(channels) -> list[g2f.PsfModel]:
    """Make one fractional two-Gaussian PSF model per channel.

    Component shapes vary with the component index and (more weakly) with the
    channel index. Only the first component is given an explicit (fixed) total
    flux and (free) flux fraction config.
    """
    rho, size_x, size_y = 0.12, 1.6, 1.2
    drho, dsize_x, dsize_y = -0.3, 1.1, 1.9
    drho_chan, dsize_x_chan, dsize_y_chan = 0.03, 0.12, 0.14
    frac, dfrac = 0.62, -0.08
    n_components = 2

    def make_component(idx: int, idx_chan: int, frac_chan: float) -> GaussianComponentConfig:
        kwargs_flux = {}
        if idx == 0:
            kwargs_flux = {
                "flux": FluxParameterConfig(value_initial=1.0, fixed=True),
                "fluxfrac": FluxFractionParameterConfig(value_initial=frac_chan, fixed=False),
            }
        return GaussianComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho + idx_chan*drho_chan),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x + idx_chan*dsize_x_chan),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y + idx_chan*dsize_y_chan),
            **kwargs_flux,
        )

    models = []
    for idx_chan in range(len(channels)):
        frac_chan = frac + idx_chan*dfrac
        config = SourceConfig(
            component_groups={
                'psf': ComponentGroupConfig(
                    components_gauss={
                        str(idx): make_component(idx, idx_chan, frac_chan)
                        for idx in range(n_components)
                    },
                    is_fractional=True,
                )
            },
        )
        config.validate()
        fluxes = [
            group.get_fluxes_default(
                channels=(g2f.Channel.NONE,),
                component_configs=group.get_component_configs(),
                is_fractional=group.is_fractional,
            )
            for group in config.component_groups.values()
        ]
        # The priors are not needed for these tests
        psf_model, _ = config.make_psf_model(fluxes)
        models.append(psf_model)
    return models
124 |
125 |
@pytest.fixture(scope="module")
def model(channels, data, psf_models) -> g2f.ModelD:
    """Make a single-source model with two Sersic components and a fixed centroid.

    The "exp" component has a fixed Sersic index starting at 1 and the "dev"
    component a free index starting at 4. Fluxes differ per channel only for
    the second component.
    """
    rho, size_x, size_y, sersicn, flux = 0.4, 1.5, 1.9, 1.0, 4.7
    drho, dsize_x, dsize_y, dsersicn, dflux = -0.9, 2.5, 5.4, 3.0, 13.9

    # Linear interpolators fail to compute accurate likelihoods at knot values
    interptype = g2f.SersicMixComponentIndexParameterD(
        interpolator=SersicComponentConfig().get_interpolator(4)
    ).interptype
    is_linear_interp = interptype == g2f.InterpType.linear
    channels_list = list(channels.values())

    components_sersic = {}
    fluxes_group = []
    for idx, name in enumerate(("exp", "dev")):
        components_sersic[name] = SersicComponentConfig(
            rho=ParameterConfig(value_initial=rho + idx*drho),
            size_x=ParameterConfig(value_initial=size_x + idx*dsize_x),
            size_y=ParameterConfig(value_initial=size_y + idx*dsize_y),
            sersic_index=SersicIndexParameterConfig(
                # Add a small offset since 1.0 and 4.0 are bound to be knots
                value_initial=sersicn + idx * dsersicn + 1e-4*is_linear_interp,
                fixed=idx == 0,
                prior_mean=None,
            ),
        )
        fluxes_group.append({
            channel: flux + idx_channel*dflux*idx
            for idx_channel, channel in enumerate(channels_list)
        })

    group = ComponentGroupConfig(
        components_sersic=components_sersic,
        centroids={"default": CentroidConfig(
            x=ParameterConfig(value_initial=12.14, fixed=True),
            y=ParameterConfig(value_initial=13.78, fixed=True),
        )},
    )
    modelconfig = ModelConfig(sources={"src": SourceConfig(component_groups={"": group})})
    return modelconfig.make_model([[fluxes_group]], data=data, psf_models=psf_models)
174 |
175 |
@pytest.fixture(scope="module")
def model_jac(model) -> g2f.ModelD:
    """Return a new model sharing the data, PSF models and sources of ``model``.

    No priors are passed, and the evaluators are set up separately (e.g. for
    Jacobian evaluation) from those of ``model``.
    """
    return g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model.sources)
180 |
181 |
@pytest.fixture(scope="module")
def psf_observations(psf_models) -> list[g2f.ObservationD]:
    """Render each PSF model, centered, into a small noisy observation.

    Gaussian noise with a standard deviation of 1e-4 is added to each rendered
    image; every observation has an inverse mask of 1 and inverse sigma of 1e3.
    """
    config = ObservationConfig(n_rows=17, n_cols=19)
    rng = np.random.default_rng(1)
    observations = []

    for psf_model in psf_models:
        observation = config.make_observation()
        # make_image_gaussians can only be called with an owning pointer,
        # whereas observation.image is a reference, so make a new ImageD
        # from observation.image.data instead
        image = g2.ImageD(observation.image.data)

        # Move every Gaussian to the image center so the kernel is centered
        gaussians_source = psf_model.gaussians(g2f.Channel.NONE)
        x_cen, y_cen = image.n_cols/2., image.n_rows/2.
        for idx in range(len(gaussians_source)):
            centroid = gaussians_source.at(idx).centroid
            centroid.x = x_cen
            centroid.y = y_cen

        make_image_gaussians(
            gaussians_source=gaussians_source,
            gaussians_kernel=g2.Gaussians([g2.Gaussian()]),
            output=image,
        )
        image.data.flat += 1e-4 * rng.standard_normal(image.data.size)

        observation.mask_inv.fill(1)
        observation.sigma_inv.fill(1e3)
        observations.append(observation)

    return observations
211 |
212 |
@pytest.fixture(scope="module")
def psf_fit_models(psf_models, psf_observations):
    """Make one model per PSF, fitting the PSF model's components as a single
    source to its own PSF observation, all sharing one null PSF model."""
    psf_null = [make_psf_model_null()]
    models = []
    for psf_model, observation in zip(psf_models, psf_observations):
        data_psf = g2f.DataD([observation])
        source = g2f.Source(psf_model.components)
        models.append(g2f.ModelD(data_psf, psf_null, [source]))
    return models
220 |
221 |
def test_model_evaluation(channels, model, model_jac):
    """Check that the Jacobian verifies against finite differences and that
    the loglikelihoods from the loglike and Jacobian evaluators agree."""
    # The model can't be evaluated before its evaluators are set up
    with pytest.raises(RuntimeError):
        model.evaluate()

    # For debugging purposes
    printout = False
    # Freeze the PSF params - they can't be fit anyway
    for m in (model, model_jac):
        for psf_model in m.psfmodels:
            for param in psf_model.parameters():
                param.fixed = True

    model.setup_evaluators(print=printout)
    model.evaluate()

    n_priors = 0
    n_obs = len(model.data)
    n_rows = np.zeros(n_obs, dtype=int)
    n_cols = np.zeros(n_obs, dtype=int)
    datasizes = np.zeros(n_obs, dtype=int)
    ranges_params = [None] * n_obs
    params_free = tuple(get_params_uniq(model_jac, fixed=False))

    # There's one extra validation array
    n_params_jac = len(params_free) + 1
    assert n_params_jac > 1

    rng = np.random.default_rng(2)

    for idx_obs, observation in enumerate(model.data):
        output = model.outputs[idx_obs]
        # Replace the image with the model plus Gaussian noise
        observation.image.data.flat = (
            output.data.flat + rng.standard_normal(output.data.size) / observation.sigma_inv.data.flat
        )
        n_rows[idx_obs] = observation.image.n_rows
        n_cols[idx_obs] = observation.image.n_cols
        datasizes[idx_obs] = n_rows[idx_obs] * n_cols[idx_obs]
        params = tuple(get_params_uniq(model, fixed=False, channel=observation.channel))
        # Jacobian index of each of this observation's free params, offset by
        # one for the validation array (which is always index 0)
        ranges_params[idx_obs] = [0] + [params_free.index(param) + 1 for param in params]

    # Every observation should have the same number of free params
    n_free_first = len(ranges_params[0])
    assert all(len(rp) == n_free_first for rp in ranges_params[1:])

    jacobians = [None] * n_obs
    residuals = [None] * n_obs
    datasize = np.sum(datasizes) + n_priors
    jacobian = np.zeros((datasize, n_params_jac))
    residual = np.zeros(datasize)

    # Make per-observation images backed by slices of the full arrays
    offset = 0
    for idx_obs in range(n_obs):
        end = offset + datasizes[idx_obs]
        shape = (n_rows[idx_obs], n_cols[idx_obs])
        jacobians[idx_obs] = [
            g2.ImageD(jacobian[offset:end, idx_jac].view().reshape(shape))
            for idx_jac in range(n_params_jac)
        ]
        residuals[idx_obs] = g2.ImageD(residual[offset:end].view().reshape(shape))
        offset = end

    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
    loglike_init = model.evaluate()

    model_jac.setup_evaluators(
        evaluatormode=g2f.EvaluatorMode.jacobian,
        outputs=jacobians,
        residuals=residuals,
        print=printout,
    )
    # verify_jacobian reports problems through its return value (as checked in
    # test_psf_model_fit), so it must be inspected for the check to mean anything
    errors = model_jac.verify_jacobian()
    assert len(errors) == 0, f"Jacobian verification failed: {errors}"
    loglike_jac = model_jac.evaluate()

    assert all(np.isclose(loglike_init, loglike_jac))
302 |
303 |
@pytest.fixture(scope="module")
def psf_models_linear_gaussians(channels, psf_models):
    """Make LinearGaussians for each PSF model.

    The first linear, channel-independent parameter (presumably the total
    flux) is temporarily freed while making the LinearGaussians and then
    restored to its original fixed state, since test_psf_model_fit requires
    integral parameters to be fixed.
    """
    gaussians = []
    for psf_model in psf_models:
        params = psf_model.parameters(
            paramfilter=g2f.ParamFilter(nonlinear=False, channel=g2f.Channel.NONE)
        )
        param_first = params[0]
        fixed_orig = param_first.fixed
        param_first.fixed = False
        try:
            gaussians.append(LinearGaussians.make(psf_model, is_psf=True))
        finally:
            # Restore even if make() raises: the PSF models are module-scoped
            # and shared with other fixtures and tests
            param_first.fixed = fixed_orig
    return gaussians
314 |
315 |
def test_make_psf_source_linear(psf_models, psf_models_linear_gaussians):
    """Check that every PSF Gaussian is either free or fixed in its LinearGaussians."""
    for psf_model, linear_gaussians in zip(psf_models, psf_models_linear_gaussians):
        n_gaussians = len(psf_model.gaussians(g2f.Channel.NONE))
        n_free = len(linear_gaussians.gaussians_free)
        n_fixed = len(linear_gaussians.gaussians_fixed)
        assert n_gaussians == n_free + n_fixed
322 |
323 |
def test_modeller(model):
    """Fit a noisy realization of the model from the true and perturbed
    initial values, first without and then with Gaussian priors on every free
    parameter, checking that fits improve the loglikelihood."""
    # For debugging purposes
    printout = False
    model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike_image)
    # Get the model images
    model.evaluate()
    rng = np.random.default_rng(3)

    # Replace each image with the model plus Gaussian noise
    for idx_obs, observation in enumerate(model.data):
        output = model.outputs[idx_obs]
        observation.image.data.flat = (
            output.data.flat + rng.standard_normal(output.data.size) / observation.sigma_inv.data.flat
        )

    # Freeze the PSF params - they can't be fit anyway
    for psf_model in model.psfmodels:
        for param in psf_model.parameters():
            param.fixed = True

    params_free = tuple(get_params_uniq(model, fixed=False))
    values_true = tuple(param.value for param in params_free)

    modeller = Modeller()

    dloglike = model.compute_loglike_grad(verify=True, findiff_frac=1e-8, findiff_add=1e-8)
    assert all(np.isfinite(dloglike))

    kwargs_fit = dict(ftol=1e-6, xtol=1e-6)

    for delta_param in (0, 0.2):
        model = g2f.ModelD(data=model.data, psfmodels=model.psfmodels, sources=model.sources)
        if delta_param != 0:
            # Perturb each param from its true value in transformed space,
            # going the other way if the increment raises (presumably because
            # it would be out of bounds)
            for param, value_true in zip(params_free, values_true):
                param.value = value_true
                try:
                    param.value_transformed += delta_param
                except RuntimeError:
                    param.value_transformed -= delta_param

        model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
        loglike_init = np.array(model.evaluate())
        time_start = time.process_time()
        results = modeller.fit_model(model, **kwargs_fit)
        time_fit = time.process_time() - time_start
        params_best = results.params_best

        for param, value in zip(params_free, params_best):
            param.value_transformed = value

        loglike_noprior = model.evaluate()
        assert np.sum(loglike_noprior) > np.sum(loglike_init)

        errors = modeller.compute_variances(model)
        # TODO: This should check >0, and < (some reasonable value), but the scipy least squares
        # does not do a great job optimizing and the loglike_grad isn't even negative...
        assert np.all(np.isfinite(errors))

        if printout:
            print(
                f"got loglike={loglike_noprior} (init={sum(loglike_init)})"
                f" from modeller.fit_model in t={time_fit:.3e}, x={params_best},"
                f" results: \n{results}"
            )

        loglike_noprior_sum = sum(loglike_noprior)
        for offset in (0, 1e-6):
            for param, value in zip(params_free, params_best):
                param.value_transformed = value
            # Gaussian priors on the transformed values, centered on the
            # best-fit values plus the offset
            priors = tuple(
                g2f.GaussianPrior(param, param.value_transformed + offset, 1.0, transformed=True)
                for param in params_free
            )
            if offset == 0:
                for p in priors:
                    assert p.evaluate().loglike == 0
                    assert p.loglike_const_terms[0] == -math.log(math.sqrt(2 * math.pi))
            model = g2f.ModelD(
                data=model.data, psfmodels=model.psfmodels, sources=model.sources, priors=priors
            )
            model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
            loglike_init = sum(model.evaluate())
            if offset == 0:
                assert np.isclose(loglike_init, loglike_noprior_sum, rtol=1e-10, atol=1e-10)
            else:
                assert loglike_init < loglike_noprior_sum

            time_start = time.process_time()
            results = modeller.fit_model(model, **kwargs_fit)
            time_fit = time.process_time() - time_start
            loglike_new = -results.result.cost
            for param, value in zip(params_free, results.params_best):
                param.value_transformed = value

            model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike)
            loglike_model = sum(model.evaluate())
            assert np.isclose(loglike_new, loglike_model, rtol=1e-10, atol=1e-10)
            # This should be > 0. TODO: Determine why it isn't always
            assert (loglike_new - loglike_init) > -1e-3

            if printout:
                print(
                    f"got loglike={loglike_new} (first={loglike_noprior})"
                    f" from modeller.fit_model in t={time_fit:.3e}, x={results.params_best},"
                    f" results: \n{results}"
                )
            # Adding a suitably-scaled prior far from the truth should always
            # worsen loglike, but doesn't - why? noise bias? bad convergence?
            # assert (loglike_new >= loglike_noprior) == (offset == 0)
432 |
433 |
def test_psf_model_fit(psf_fit_models):
    """Verify each PSF fit model's Jacobian against finite differences.

    All parameters of the model's source are freed, except integral (total
    flux) parameters, which must already be fixed.
    """
    # For debugging purposes: set to True to print evaluator details and plot a
    # finite-difference image next to the model Jacobian when verification fails
    plot = False
    for model in psf_fit_models:
        for param in get_params_uniq(model.sources[0]):
            # Fitting the total flux won't work in a fractional model (yet)
            if isinstance(param, g2f.IntegralParameterD):
                assert param.fixed
            else:
                param.fixed = False
        # Necessary whenever parameters are freed/fixed
        model.setup_evaluators(g2f.EvaluatorMode.jacobian, force=True)
        errors = model.verify_jacobian(rtol=5e-4, atol=5e-4, findiff_add=1e-6, findiff_frac=1e-6)
        if errors:
            print(model.parameters())
            if plot:
                import matplotlib.pyplot as plt

                fitinputs = FitInputs.from_model(model)
                model.setup_evaluators(
                    evaluatormode=g2f.EvaluatorMode.jacobian,
                    outputs=fitinputs.jacobians,
                    residuals=fitinputs.residuals,
                    print=True,
                    force=True,
                )
                model.evaluate(print=True)
                # Index 0 appears to be the (all-zero) validation array
                assert (fitinputs.jacobians[0][0].data == 0).all()
                assert np.sum(np.abs(fitinputs.jacobians[0][1].data)) > 0
                model.setup_evaluators(evaluatormode=g2f.EvaluatorMode.loglike_image)
                model.evaluate()
                outputs = model.outputs
                images_before = [g2.ImageD(img.data.copy()) for img in outputs]
                # Assumes Jacobian index 1 is the first free param, as in
                # test_model_evaluation - TODO: confirm for FitInputs
                param_first = tuple(get_params_uniq(model, fixed=False))[0]
                delta = 1e-5
                param_first.value -= delta
                model.evaluate()
                jacobian = fitinputs.jacobians[0][1].data
                for image_before, output in zip(images_before, outputs):
                    diff = (output.data - image_before.data) / delta
                    fig, ax = plt.subplots(1, 2)
                    ax[0].imshow(diff)
                    ax[1].imshow(jacobian)
                    plt.show()
        assert len(errors) == 0, f"Jacobian verification failed: {errors}"
476 |
477 |
def test_psf_models_linear_gaussians(data, psf_models_linear_gaussians, psf_observations):
    """Check that a linear fit of each PSF model's Gaussians to its PSF image
    returns a non-empty result."""
    for gaussians_linear, observation_psf in zip(psf_models_linear_gaussians, psf_observations):
        result = Modeller.fit_gaussians_linear(
            gaussians_linear=gaussians_linear,
            observation=observation_psf,
            fitmethods=fitmethods_linear,
            plot=False,
        )
        assert len(result) > 0
490 |
491 |
def test_modeller_fit_linear(model):
    """Check that a validated linear fit of the model returns a result."""
    results = Modeller().fit_model_linear(model, validate=True)
    # TODO: add more here
    assert results is not None
497 |
--------------------------------------------------------------------------------
/examples/fithsc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "collapsed": true
7 | },
8 | "source": [
9 | "# Fitting HSC data in multiband mode using MultiProFit"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "outputs": [],
16 | "source": [
17 | "# Import required packages\n",
18 | "import time\n",
19 | "from typing import Any, Iterable, Mapping\n",
20 | "\n",
21 | "from astropy.coordinates import SkyCoord\n",
22 | "from astropy.io.ascii import Csv\n",
23 | "import astropy.io.fits as fits\n",
24 | "import astropy.table as apTab\n",
25 | "import astropy.visualization as apVis\n",
26 | "from astropy.wcs import WCS\n",
27 | "import lsst.gauss2d as g2\n",
28 | "import lsst.gauss2d.fit as g2f\n",
29 | "from lsst.multiprofit.componentconfig import SersicComponentConfig, SersicIndexParameterConfig\n",
30 | "from lsst.multiprofit.fit_psf import (\n",
31 | " CatalogExposurePsfABC,\n",
32 | " CatalogPsfFitter,\n",
33 | " CatalogPsfFitterConfig,\n",
34 | " CatalogPsfFitterConfigData,\n",
35 | ")\n",
36 | "from lsst.multiprofit.fit_source import (\n",
37 | " CatalogExposureSourcesABC,\n",
38 | " CatalogSourceFitterABC,\n",
39 | " CatalogSourceFitterConfig,\n",
40 | " CatalogSourceFitterConfigData,\n",
41 | ")\n",
42 | "from lsst.multiprofit.modelconfig import ModelConfig\n",
43 | "from lsst.multiprofit.plots import plot_model_rgb\n",
44 | "from lsst.multiprofit.sourceconfig import ComponentGroupConfig, SourceConfig\n",
45 | "from lsst.multiprofit.utils import ArbitraryAllowedConfig, get_params_uniq\n",
46 | "import matplotlib as mpl\n",
47 | "import matplotlib.pyplot as plt\n",
48 | "import numpy as np\n",
49 | "import pydantic\n",
50 | "from pydantic.dataclasses import dataclass"
51 | ],
52 | "metadata": {
53 | "collapsed": false,
54 | "ExecuteTime": {
55 | "end_time": "2024-06-14T06:17:42.558190725Z",
56 | "start_time": "2024-06-14T06:17:41.498267418Z"
57 | }
58 | }
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 2,
63 | "metadata": {
64 | "ExecuteTime": {
65 | "end_time": "2024-06-14T06:17:42.608230886Z",
66 | "start_time": "2024-06-14T06:17:42.560250073Z"
67 | }
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# Define settings\n",
72 | "band_ref = 'i'\n",
73 | "bands = {'i': 0.87108833, 'r': 0.97288654, 'g': 1.44564678}\n",
74 | "band_multi = ''.join(bands)\n",
75 | "channels = {band: g2f.Channel.get(band) for band in bands}\n",
76 | "\n",
77 | "# This is in the WCS, but may as well keep full precision\n",
78 | "scale_pixel_hsc = 0.168\n",
79 | "\n",
80 | "# Common to all FITS\n",
81 | "hdu_img, hdu_mask, hdu_var = 1, 2, 3\n",
82 | "\n",
83 | "# Masks\n",
84 | "bad_masks = (\n",
85 | " 'BAD', 'SAT', 'INTRP', 'CR', 'EDGE', 'CLIPPED', 'NO_DATA', 'CROSSTALK',\n",
86 | " 'NO_DATA', 'UNMASKEDNAN', 'SUSPECT', 'REJECTED', 'SENSOR_EDGE',\n",
87 | ")\n",
88 | "maskbits = tuple(f'MP_{b}' for b in bad_masks)\n",
89 | "\n",
90 | "# A pre-defined bitmask to exclude regions with low SN\n",
91 | "read_mask_highsn = True\n",
92 | "write_mask_highsn = False\n",
93 | "\n",
94 | "# matplotlib settings\n",
95 | "mpl.rcParams['image.origin'] = 'lower'\n",
96 | "mpl.rcParams['figure.dpi'] = 120"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": 3,
102 | "metadata": {
103 | "ExecuteTime": {
104 | "end_time": "2024-06-14T06:17:42.613129061Z",
105 | "start_time": "2024-06-14T06:17:42.569183931Z"
106 | }
107 | },
108 | "outputs": [],
109 | "source": [
110 | "# Define source to fit\n",
111 | "id_gama, z = 79635, 0.0403\n",
112 | "\"\"\"\n",
113 | "Acquired from https://hsc-release.mtk.nao.ac.jp/datasearch/catalog_jobs with query:\n",
114 | "\n",
115 | "SELECT object_id, ra, dec,\n",
116 | " g_cmodel_mag, g_cmodel_magerr, r_cmodel_mag, r_cmodel_magerr, i_cmodel_mag, i_cmodel_magerr,\n",
117 | " g_psfflux_mag, g_psfflux_magerr, r_psfflux_mag, r_psfflux_magerr, i_psfflux_mag, i_psfflux_magerr,\n",
118 | " g_kronflux_mag, g_kronflux_magerr, r_kronflux_mag, r_kronflux_magerr, i_kronflux_mag, i_kronflux_magerr,\n",
119 | " g_sdssshape_shape11, g_sdssshape_shape11err, g_sdssshape_shape22, g_sdssshape_shape22err,\n",
120 | " g_sdssshape_shape12, g_sdssshape_shape12err,\n",
121 | " r_sdssshape_shape11, r_sdssshape_shape11err, r_sdssshape_shape22, r_sdssshape_shape22err,\n",
122 | " r_sdssshape_shape12, r_sdssshape_shape12err,\n",
123 | " i_sdssshape_shape11, i_sdssshape_shape11err, i_sdssshape_shape22, i_sdssshape_shape22err,\n",
124 | " i_sdssshape_shape12, i_sdssshape_shape12err\n",
125 | "FROM pdr3_wide.forced\n",
126 | "LEFT JOIN pdr3_wide.forced2 USING (object_id)\n",
127 | "WHERE isprimary AND conesearch(coord, 222.51551376, 0.09749601, 35.64)\n",
128 | "AND (r_kronflux_mag < 26 OR i_kronflux_mag < 26) AND NOT i_kronflux_flag AND NOT r_kronflux_flag;\n",
129 | "\"\"\"\n",
130 | "cat = Csv()\n",
131 | "cat.header.splitter.escapechar = '#'\n",
132 | "cat = cat.read('fithsc_src.csv')\n",
133 | "\n",
134 | "prefix = '222.51551376,0.09749601_'\n",
135 | "prefix_img = f'{prefix}300x300_'\n",
136 | "\n",
137 | "# Read data, acquired with:\n",
138 | "# https://github.com/taranu/astro_imaging/blob/4d5a8e095e6a3944f1fbc19318b1dbc22b22d9ca/examples/HSC.ipynb\n",
139 | "# (with get_mask=True, get_variance=True,)\n",
140 | "images, psfs = {}, {}\n",
141 | "for band in bands:\n",
142 | " images[band] = fits.open(f'{prefix_img}{band}.fits')\n",
143 | " psfs[band] = fits.open(f'{prefix}{band}_psf.fits')\n",
144 | "\n",
145 | "wcs = WCS(images[band_ref][hdu_img])\n",
146 | "cat['x'], cat['y'] = wcs.world_to_pixel(SkyCoord(cat['ra'], cat['dec'], unit='deg'))"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {
153 | "is_executing": true,
154 | "ExecuteTime": {
155 | "start_time": "2024-06-14T06:17:42.599466290Z"
156 | }
157 | },
158 | "outputs": [],
159 | "source": [
160 | "# Plot image\n",
161 | "img_rgb = apVis.make_lupton_rgb(*[img[1].data*bands[band] for band, img in images.items()])\n",
162 | "plt.scatter(cat['x'], cat['y'], s=10, c='g', marker='x')\n",
163 | "plt.imshow(img_rgb)\n",
164 | "plt.title(\"gri image with detected objects\")\n",
165 | "plt.show()"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": null,
171 | "metadata": {
172 | "is_executing": true
173 | },
174 | "outputs": [],
175 | "source": [
176 | "# Generate a rough mask around other sources\n",
177 | "bright = (cat['i_cmodel_mag'] < 23) | (cat['i_psfflux_mag'] < 23)\n",
178 | "\n",
179 | "img_ref = images[band_ref][hdu_img].data\n",
180 | "\n",
181 | "mask_inverse = np.ones(img_ref.shape, dtype=bool)\n",
182 | "y_cen, x_cen = (x/2. for x in img_ref.shape)\n",
183 | "y, x = np.indices(img_ref.shape)\n",
184 | "\n",
185 | "idx_src_main, row_main = None, None\n",
186 | "\n",
187 | "sizes_override = {\n",
188 | " 42305088563206480: 8.,\n",
189 | "}\n",
190 | "\n",
191 | "for src in cat[bright]:\n",
192 | " id_src, x_src, y_src = (src[col] for col in ['object_id', 'x', 'y'])\n",
193 | " dist = np.hypot(x_src - x_cen, y_src - y_cen)\n",
194 | " if dist > 20:\n",
195 | " dists = np.hypot(y - y_src, x - x_src)\n",
196 | " mag = np.nanmin([src['i_cmodel_mag'], src['r_cmodel_mag'], src['i_psfflux_mag'], src['r_psfflux_mag']])\n",
197 | " if (radius_mask := sizes_override.get(id_src)) is None:\n",
198 | " radius_mask = 2*np.sqrt(\n",
199 | " src[f'{band_ref}_sdssshape_shape11'] + src[f'{band_ref}_sdssshape_shape22']\n",
200 | " )/scale_pixel_hsc\n",
201 | " if (radius_mask > 10) and (mag > 21):\n",
202 | " radius_mask = 5\n",
203 | " mask_inverse[dists < radius_mask] = 0\n",
204 | " print(f'Masking src=({id_src} at {x_src:.3f}, {y_src:.3f}) dist={dist:.3f}'\n",
205 | " f', mag={mag:.3f}, radius_mask={radius_mask:.3f}')\n",
206 | " elif dist < 2:\n",
207 | " idx_src_main = id_src\n",
208 | " row_main = src\n",
209 | " print(f\"{idx_src_main=} {dict(src)=}\")\n",
210 | "\n",
211 | "tab_row_main = apTab.Table(row_main)\n",
212 | "\n",
213 | "if read_mask_highsn:\n",
214 | " mask_highsn = np.load(f'{prefix_img}mask_inv_highsn.npz')['mask_inv']\n",
215 | " mask_inverse *= mask_highsn\n",
216 | "\n",
217 | "plt.imshow(mask_inverse)\n",
218 | "plt.title(\"Fitting mask\")\n",
219 | "plt.show()"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": null,
225 | "metadata": {
226 | "is_executing": true
227 | },
228 | "outputs": [],
229 | "source": [
230 | "# Fit PSF: wrap each band's PSF image in a minimal CatalogExposurePsf and fit it\n",
231 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n",
232 | "class CatalogExposurePsf(CatalogExposurePsfABC):\n",
233 | "    catalog: apTab.Table = pydantic.Field(title=\"The detected object catalog\")\n",
234 | "    img: np.ndarray = pydantic.Field(title=\"The PSF image\")\n",
235 | "\n",
236 | "    def get_catalog(self) -> Iterable:\n",
237 | "        return self.catalog\n",
238 | "\n",
239 | "    def get_psf_image(self, source: apTab.Row | Mapping[str, Any]) -> np.ndarray:\n",
240 | "        return self.img  # the same PSF image is returned for every source\n",
241 | "\n",
242 | "config_psf = CatalogPsfFitterConfig(column_id='object_id')\n",
243 | "fitter_psf = CatalogPsfFitter()\n",
244 | "catalog_psf = apTab.Table({'object_id': [tab_row_main['object_id']]})\n",
245 | "results_psf = {}\n",
246 | "\n",
247 | "# Keep a separate config_data_psf per band: each one holds a cached PSF model,\n",
248 | "# so they must not be shared between bands!\n",
249 | "config_data_psfs = {}\n",
250 | "for band, psf_file in psfs.items():\n",
251 | "    config_data_psf = CatalogPsfFitterConfigData(config=config_psf)\n",
252 | "    catexp = CatalogExposurePsf(catalog=catalog_psf, img=psf_file[0].data)\n",
253 | "    t_start = time.time()\n",
254 | "    result = fitter_psf.fit(config_data=config_data_psf, catexp=catexp)\n",
255 | "    t_end = time.time()\n",
256 | "    results_psf[band] = result\n",
257 | "    config_data_psfs[band] = config_data_psf\n",
258 | "    print(f\"Fit {band}-band PSF in {t_end - t_start:.2e}s; result:\")\n",
259 | "    print(dict(result[0]))"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "metadata": {
266 | "is_executing": true
267 | },
268 | "outputs": [],
269 | "source": [
270 | "# Set fit configs\n",
271 | "# Bulge-disk model: fixed-index Sersic components (n=1 disk, n=4 bulge)\n",
272 | "_components_sersic = {\n",
273 | "    'disk': SersicComponentConfig(\n",
274 | "        sersic_index=SersicIndexParameterConfig(value_initial=1., fixed=True),\n",
275 | "        prior_size_stddev=0.5,\n",
276 | "        prior_axrat_stddev=0.2,\n",
277 | "    ),\n",
278 | "    'bulge': SersicComponentConfig(\n",
279 | "        sersic_index=SersicIndexParameterConfig(value_initial=4., fixed=True),\n",
280 | "        prior_size_stddev=0.1,\n",
281 | "        prior_axrat_stddev=0.2,\n",
282 | "    ),\n",
283 | "}\n",
284 | "# A single source with one unnamed component group\n",
285 | "_source_config = SourceConfig(\n",
286 | "    component_groups={\"\": ComponentGroupConfig(components_sersic=_components_sersic)},\n",
287 | ")\n",
288 | "config_source = CatalogSourceFitterConfig(\n",
289 | "    column_id='object_id',\n",
290 | "    config_model=ModelConfig(\n",
291 | "        sources={\"src\": _source_config},\n",
292 | "    ),\n",
293 | ")\n",
294 | "# Drop the temporaries so they do not linger in the notebook namespace\n",
295 | "del _components_sersic, _source_config\n",
296 | "config_data_source = CatalogSourceFitterConfigData(\n",
297 | "    channels=list(channels.values()),\n",
298 | "    config=config_source,\n",
299 | ")"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": null,
305 | "metadata": {
306 | "is_executing": true
307 | },
308 | "outputs": [],
309 | "source": [
310 | "# Setup exposure with band-specific image, mask and variance\n",
311 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n",
312 | "class CatalogExposureSources(CatalogExposureSourcesABC):\n",
313 | "    config_data_psf: CatalogPsfFitterConfigData = pydantic.Field(title=\"The PSF fit config\")\n",
314 | "    observation: g2f.ObservationD = pydantic.Field(title=\"The observation to fit\")\n",
315 | "    table_psf_fits: apTab.Table = pydantic.Field(title=\"The table of PSF fit parameters\")\n",
316 | "\n",
317 | "    @property\n",
318 | "    def channel(self) -> g2f.Channel:\n",
319 | "        return self.observation.channel\n",
320 | "\n",
321 | "    def get_catalog(self) -> Iterable:\n",
322 | "        return self.table_psf_fits\n",
323 | "\n",
324 | "    def get_psf_model(self, params: Mapping[str, Any]) -> g2f.PsfModel:\n",
325 | "        self.config_data_psf.init_psf_model(params)  # expected to (re)set .psf_model from params\n",
326 | "        return self.config_data_psf.psf_model\n",
327 | "\n",
328 | "    def get_source_observation(self, source: Mapping[str, Any]) -> g2f.ObservationD:\n",
329 | "        return self.observation  # one observation shared by every source\n",
330 | "\n",
331 | "\n",
332 | "@dataclass(frozen=True, config=ArbitraryAllowedConfig)\n",
333 | "class CatalogSourceFitter(CatalogSourceFitterABC):\n",
334 | "    band: str = pydantic.Field(title=\"The reference band for initialization and priors\")\n",
335 | "    scale_pixel: float = pydantic.Field(title=\"The pixel scale in arcsec\")\n",
336 | "    wcs_ref: WCS = pydantic.Field(title=\"The WCS for the coadded image\")\n",
337 | "\n",
338 | "    def initialize_model(\n",
339 | "        self,\n",
340 | "        model: g2f.ModelD,\n",
341 | "        source: Mapping[str, Any],\n",
342 | "        catexps: list[CatalogExposureSourcesABC],\n",
343 | "        values_init: Mapping[g2f.ParameterD, float] | None = None,\n",
344 | "        centroid_pixel_offset: float = 0,\n",
345 | "        **kwargs\n",
346 | "    ):\n",
347 | "        if values_init is None:\n",
348 | "            values_init = {}\n",
349 | "        x, y = source['x'], source['y']  # catalog centroid, used as the initial guess\n",
350 | "        scale_sq = self.scale_pixel**(-2)  # converts arcsec^2 moments to pixel^2\n",
351 | "        ellipse = g2.Ellipse(g2.Covariance(  # moments from the reference band (self.band)\n",
352 | "            sigma_x_sq=source[f'{self.band}_sdssshape_shape11']*scale_sq,\n",
353 | "            sigma_y_sq=source[f'{self.band}_sdssshape_shape22']*scale_sq,\n",
354 | "            cov_xy=source[f'{self.band}_sdssshape_shape12']*scale_sq,\n",
355 | "        ))\n",
356 | "        size_major = g2.EllipseMajor(ellipse).r_major\n",
357 | "        limits_size = g2f.LimitsD(1e-5, np.sqrt(x*x + y*y))\n",
358 | "        # An R_eff larger than the box size is problematic\n",
359 | "        # Also should stop unreasonable size proposals; log10 transform isn't enough\n",
360 | "        # TODO: Try logit for r_eff?\n",
361 | "        params_limits_init = {\n",
362 | "            # Should set limits based on image size, but this shortcut is fine\n",
363 | "            # for this particular object\n",
364 | "            g2f.CentroidXParameterD: (x, g2f.LimitsD(0, 2*x)),\n",
365 | "            g2f.CentroidYParameterD: (y, g2f.LimitsD(0, 2*y)),\n",
366 | "            g2f.ReffXParameterD: (ellipse.sigma_x, limits_size),\n",
367 | "            g2f.ReffYParameterD: (ellipse.sigma_y, limits_size),\n",
368 | "            # There is a sign convention difference\n",
369 | "            g2f.RhoParameterD: (-ellipse.rho, None),\n",
370 | "            g2f.IntegralParameterD: (1.0, g2f.LimitsD(1e-10, 1e10)),\n",
371 | "        }\n",
372 | "        params_free = get_params_uniq(model, fixed=False)\n",
373 | "        for param in params_free:\n",
374 | "            type_param = type(param)\n",
375 | "            value_init, limits_new = params_limits_init.get(  # per-type defaults win over values_init\n",
376 | "                type_param,\n",
377 | "                (values_init.get(param), None)\n",
378 | "            )\n",
379 | "            if value_init is not None:\n",
380 | "                param.value = value_init\n",
381 | "            if limits_new:\n",
382 | "                # For slightly arcane reasons, we must set a new limits object\n",
383 | "                # Changing limits values is unreliable\n",
384 | "                param.limits = limits_new\n",
385 | "        for prior in model.priors:\n",
386 | "            if isinstance(prior, g2f.ShapePrior):\n",
387 | "                prior.prior_size.mean_parameter.value = size_major  # centre size prior on initial size\n",
388 | "\n",
389 | "    # Defers entirely to the base class; kept as an explicit override point\n",
390 | "    def validate_fit_inputs(\n",
391 | "        self,\n",
392 | "        catalog_multi,\n",
393 | "        catexps: list[CatalogExposureSourcesABC],\n",
394 | "        config_data: CatalogSourceFitterConfigData | None = None,\n",
395 | "        logger=None,\n",
396 | "        **kwargs: Any,\n",
397 | "    ) -> None:\n",
398 | "        super().validate_fit_inputs(\n",
399 | "            catalog_multi=catalog_multi, catexps=catexps, config_data=config_data,\n",
400 | "            logger=logger, **kwargs\n",
401 | "        )"
402 | ]
403 | },
404 | {
405 | "cell_type": "code",
406 | "execution_count": null,
407 | "metadata": {
408 | "is_executing": true
409 | },
410 | "outputs": [],
411 | "source": [
412 | "# Set up Fitter, Observations and CatalogExposureSources\n",
413 | "fitter = CatalogSourceFitter(band=band_ref, scale_pixel=scale_pixel_hsc, wcs_ref=wcs)  # same reference band as the masking\n",
414 | "\n",
415 | "observations = {}\n",
416 | "catexps = {}\n",
417 | "\n",
418 | "for band in bands:\n",
419 | "    data = images[band]\n",
420 | "    # There are better ways to use bitmasks, but this will do\n",
421 | "    header = data[hdu_mask].header\n",
422 | "    bitmask = data[hdu_mask].data\n",
423 | "    mask = np.zeros_like(bitmask, dtype='bool')\n",
424 | "    for bit in maskbits:\n",
425 | "        mask |= ((bitmask & 2**header[bit]) != 0)\n",
426 | "\n",
427 | "    mask = (mask == 0) & mask_inverse  # usable: no flagged bits and inside the fitting mask\n",
428 | "    sigma_inv = 1.0/np.sqrt(data[hdu_var].data)\n",
429 | "    sigma_inv[mask != 1] = 0  # zero weight for excluded pixels\n",
430 | "\n",
431 | "    observation = g2f.ObservationD(\n",
432 | "        image=g2.ImageD(data[hdu_img].data),\n",
433 | "        sigma_inv=g2.ImageD(sigma_inv),\n",
434 | "        mask_inv=g2.ImageB(mask),\n",
435 | "        channel=g2f.Channel.get(band),\n",
436 | "    )\n",
437 | "    observations[band] = observation\n",
438 | "    catexps[band] = CatalogExposureSources(\n",
439 | "        config_data_psf=config_data_psfs[band],\n",
440 | "        observation=observation,\n",
441 | "        table_psf_fits=results_psf[band],\n",
442 | "    )"
443 | ]
444 | },
445 | {
446 | "cell_type": "code",
447 | "execution_count": null,
448 | "metadata": {
449 | "scrolled": true,
450 | "is_executing": true
451 | },
452 | "outputs": [],
453 | "source": [
454 | "# Now do the multi-band fit\n",
455 | "# (all bands at once, using every band's CatalogExposureSources)\n",
456 | "t_start = time.time()\n",
457 | "result_multi = fitter.fit(catalog_multi=tab_row_main, catexps=list(catexps.values()),\n",
458 | "                          config_data=config_data_source)\n",
459 | "t_end = time.time()\n",
460 | "\n",
461 | "# Report the timing, then the first row of the fit results\n",
462 | "print(f\"Fit {','.join(bands.keys())}-band bulge-disk model in {t_end - t_start:.2e}s; result:\")\n",
463 | "print(dict(result_multi[0]))"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": null,
469 | "metadata": {
470 | "scrolled": true,
471 | "is_executing": true
472 | },
473 | "outputs": [],
474 | "source": [
475 | "# Fit in each band separately (bands maps band name -> RGB weight; only the keys are needed)\n",
476 | "results = {}\n",
477 | "for band in bands:\n",
478 | "    config_data_source_band = CatalogSourceFitterConfigData(\n",
479 | "        channels=[channels[band]],\n",
480 | "        config=config_source,\n",
481 | "    )\n",
482 | "    t_start = time.time()\n",
483 | "    result = fitter.fit(\n",
484 | "        catalog_multi=tab_row_main,\n",
485 | "        catexps=[catexps[band]],\n",
486 | "        config_data=config_data_source_band,\n",
487 | "    )\n",
488 | "    t_end = time.time()\n",
489 | "    results[band] = result\n",
490 | "    print(f\"Fit {band}-band bulge-disk model in {t_end - t_start:.2e}s; result:\")\n",
491 | "    print(dict(result[0]))"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": null,
497 | "metadata": {
498 | "is_executing": true
499 | },
500 | "outputs": [],
501 | "source": [
502 | "# Make a model for the best-fit params\n",
503 | "data, psf_models = config_source.make_model_data(idx_row=0, catexps=list(catexps.values()))\n",
504 | "model = g2f.ModelD(data=data, psfmodels=psf_models,\n",
505 | "                   sources=config_data_source.sources_priors[0], priors=config_data_source.sources_priors[1])\n",
506 | "params = get_params_uniq(model, fixed=False)\n",
507 | "result_multi_row = dict(result_multi[0])\n",
508 | "# Best-fit values are the columns after 'mpf_unknown_flag', the last column before the fit params\n",
509 | "idx_last = next(idx for idx, column in enumerate(result_multi_row) if column == 'mpf_unknown_flag')\n",
510 | "for param, value in zip(params, list(result_multi_row.values())[idx_last + 1:]):\n",
511 | "    param.value = value\n",
512 | "model.setup_evaluators(g2f.EvaluatorMode.loglike_image)\n",
513 | "# Print the loglikelihoods: the data terms first, ending with the (sum of all) priors\n",
514 | "loglikes = model.evaluate()\n",
515 | "print(f\"{loglikes=}\")"
516 | ]
517 | },
518 | {
519 | "cell_type": "markdown",
520 | "metadata": {},
521 | "source": [
522 | "### Multiband Residuals\n",
523 | "\n",
524 | "What's with the structure in the residuals? Most broadly, an exponential disk + de Vaucouleurs bulge model is totally inadequate for this galaxy, for several possible reasons:\n",
525 | "\n",
526 | "1. The disk isn't exactly exponential (n=1)\n",
527 | "2. The disk has colour gradients not accounted for in this model*\n",
528 | "3. If the galaxy even has a bulge, it's very weak and definitely not a de Vaucouleurs (n=4) profile; it may be an exponential \"pseudobulge\"\n",
529 | "\n",
530 | "\\*MultiProFit can do more general Gaussian mixture models (linear or non-linear), which may be explored in a future iteration of this notebook, but these generally do not improve the accuracy of photometry for smaller/fainter galaxies.\n",
531 | "\n",
532 | "Note that the two scalings of the residual plots (98th percentile and +/- 20 sigma) end up looking very similar.\n"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "metadata": {
539 | "is_executing": true
540 | },
541 | "outputs": [],
542 | "source": [
543 | "# Make some basic plots\n",
544 | "_, _, _, _, mask_inv_highsn = plot_model_rgb(\n",
545 | "    model, weights=bands, high_sn_threshold=0.2 if write_mask_highsn else None,\n",
546 | ")\n",
547 | "plt.show()\n",
548 | "\n",
549 | "# Write the newly computed high SN mask to a compressed file\n",
550 | "if write_mask_highsn:\n",
551 | "    plt.figure()\n",
552 | "    plt.imshow(mask_inv_highsn, cmap='gray')\n",
553 | "    plt.show()\n",
554 | "    # Stored unpacked: the mask-loading cell multiplies the 'mask_inv' array in directly\n",
555 | "    np.savez_compressed(f'{prefix_img}mask_inv_highsn.npz', mask_inv=mask_inv_highsn)\n",
556 | "\n",
557 | "# TODO: Some features still missing from plot_model_rgb\n",
558 | "# residual histograms, param values, better labels, etc"
559 | ]
560 | },
561 | {
562 | "cell_type": "markdown",
563 | "metadata": {},
564 | "source": [
565 | "### More exercises for the reader\n",
566 | "\n",
567 | "These are of the sort that the author hasn't gotten around to yet because they're far from trivial. Try:\n",
568 | "\n",
569 | "0. Use the WCS to compute ra, dec and errors thereof.\n",
570 | "Hint: override CatalogSourceFitter.get_model_radec.\n",
571 | "\n",
572 | "1. Replace the real data with simulated data.\n",
573 | "Make new observations using model.evaluate and add noise based on the variance maps.\n",
574 | "Try fitting again and see how well results converge depending on the initialization scheme.\n",
575 | "\n",
576 | "2. Fit every other source individually.\n",
577 | "Try subtracting the best-fit galaxy model from above first.\n",
578 | "Hint: get_source_observation should be redefined to return a smaller postage stamp around the nominal centroid.\n",
579 | "Pass the full catalog (excluding the central galaxy) to catalog_multi.\n",
580 | "\n",
581 | "3. Fit all sources simultaneously.\n",
582 | "Redefine CatalogFitterConfig.make_model_data to make a model with multiple sources, using the catexp catalogs.\n",
583 | "initialize_model will no longer need to do anything.\n",
584 | "catalog_multi should still be a single row."
585 | ]
586 | }
587 | ],
588 | "metadata": {
589 | "kernelspec": {
590 | "display_name": "Python 3 (ipykernel)",
591 | "language": "python",
592 | "name": "python3"
593 | },
594 | "language_info": {
595 | "codemirror_mode": {
596 | "name": "ipython",
597 | "version": 3
598 | },
599 | "file_extension": ".py",
600 | "mimetype": "text/x-python",
601 | "name": "python",
602 | "nbconvert_exporter": "python",
603 | "pygments_lexer": "ipython3",
604 | "version": "3.11.4"
605 | }
606 | },
607 | "nbformat": 4,
608 | "nbformat_minor": 1
609 | }
610 |
--------------------------------------------------------------------------------