├── benchmarks
├── __init__.py
├── README.md
├── bbox_benchmarks.py
└── asv.conf.json
├── src
└── scarlet2
│ ├── questionnaire
│ ├── views
│ │ ├── __init__.py
│ │ ├── output_box.html.jinja
│ │ ├── output_box.css
│ │ └── question_box.css
│ ├── __init__.py
│ └── models.py
│ ├── psf.py
│ ├── __init__.py
│ ├── io.py
│ ├── validation_utils.py
│ ├── validation.py
│ ├── nn.py
│ ├── spectrum.py
│ ├── wavelets.py
│ ├── morphology.py
│ ├── bbox.py
│ ├── lsst_utils.py
│ └── source.py
├── docs
├── _static
│ ├── icon.png
│ ├── example_obs_validation.png
│ ├── questionnaire_screenshot.png
│ ├── logo_light.svg
│ └── logo_dark.svg
├── requirements.txt
├── _templates
│ ├── custom-base-template.rst
│ ├── custom-class-template.rst
│ └── custom-module-template.rst
├── 1-howto.md
├── api.rst
├── Makefile
├── make.bat
├── conf.py
├── index.md
├── howto
│ ├── sampling.ipynb
│ └── priors.ipynb
└── 2-questionnaire.md
├── .git_archival.txt
├── .github
├── ISSUE_TEMPLATE
│ ├── 0-general_issue.md
│ ├── 1-bug_report.md
│ └── 2-feature_request.md
├── dependabot.yml
└── workflows
│ ├── publish-to-pypi.yml
│ ├── pre-commit-ci.yml
│ ├── testing-and-coverage.yml
│ ├── smoke-test.yml
│ ├── publish-benchmarks-pr.yml
│ ├── asv-main.yml
│ ├── asv-nightly.yml
│ └── asv-pr.yml
├── .readthedocs.yaml
├── tests
└── scarlet2
│ ├── test_validation_utils.py
│ ├── questionnaire
│ ├── test_models.py
│ ├── data
│ │ ├── example_questionnaire.yaml
│ │ ├── example_questionnaire_switch.yaml
│ │ └── example_questionnaire_followup_switch.yaml
│ └── conftest.py
│ ├── test_scene_fit_checks.py
│ ├── test_quickstart.py
│ ├── test_save_output.py
│ ├── test_moments.py
│ ├── conftest.py
│ ├── test_observation_checks.py
│ └── test_renderer.py
├── .copier-answers.yml
├── .gitattributes
├── LICENSE
├── README.md
├── .setup_dev.sh
├── .gitignore
├── .initialize_new_project.sh
├── .pre-commit-config.yaml
└── pyproject.toml
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/views/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/_static/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/icon.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-issues
2 | myst-nb
3 | sphinx-book-theme
4 | ipywidgets
5 | corner
6 | arviz
7 | galaxygrad >= 0.3.0
--------------------------------------------------------------------------------
/docs/_static/example_obs_validation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/example_obs_validation.png
--------------------------------------------------------------------------------
/docs/_static/questionnaire_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/questionnaire_screenshot.png
--------------------------------------------------------------------------------
/docs/_templates/custom-base-template.rst:
--------------------------------------------------------------------------------
1 | {{ objname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. auto{{ objtype }}:: {{ objname }}
6 |
--------------------------------------------------------------------------------
/.git_archival.txt:
--------------------------------------------------------------------------------
1 | node: f4b070e05b951dfc0a7615f40388b2f95ff151ba
2 | node-date: 2025-12-15T22:12:03+01:00
3 | describe-name: v0.3.0-58-gf4b070e0
4 | ref-names: HEAD -> main
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/0-general_issue.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: General issue
3 | about: Quickly create a general issue
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/__init__.py:
--------------------------------------------------------------------------------
1 | from .questionnaire import QuestionnaireWidget, run_questionnaire
2 |
3 | __all__ = ["QuestionnaireWidget", "run_questionnaire"]
4 |
--------------------------------------------------------------------------------
/docs/1-howto.md:
--------------------------------------------------------------------------------
1 | # How To ...
2 |
3 | ```{toctree}
4 | :maxdepth: 2
5 |
6 | howto/sampling
7 | howto/priors
8 | howto/correlated
9 | howto/multiresolution
10 | howto/timedomain
11 | ```
--------------------------------------------------------------------------------
/docs/_templates/custom-class-template.rst:
--------------------------------------------------------------------------------
1 | {{ objname | escape | underline}}
2 |
3 | .. currentmodule:: {{ module }}
4 |
5 | .. autoclass:: {{ objname }}
6 | :members:
7 | :show-inheritance:
8 | :inherited-members:
9 | :special-members: __call__
10 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "github-actions"
4 | directory: "/"
5 | schedule:
6 | interval: "monthly"
7 | - package-ecosystem: "pip"
8 | directory: "/"
9 | schedule:
10 | interval: "monthly"
11 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: "2"
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.10"
7 |
8 | python:
9 | install:
10 | - method: pip
11 | path: .
12 | extra_requirements:
13 | - docs
14 | - requirements: docs/requirements.txt
15 |
16 | sphinx:
17 | configuration: docs/conf.py
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/views/output_box.html.jinja:
--------------------------------------------------------------------------------
1 |
2 | {{highlighted_code | safe}}
3 |
4 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | API Documentation
2 | =================
3 |
4 | .. autosummary::
5 | :toctree: _autosummary
6 | :template: custom-module-template.rst
7 |
8 | scarlet2
9 |
10 | .. autosummary::
11 | :caption: Modules
12 | :toctree: _autosummary
13 | :template: custom-module-template.rst
14 | :recursive:
15 |
16 | scarlet2.fft
17 | scarlet2.init
18 | scarlet2.interpolation
19 | scarlet2.io
20 | scarlet2.nn
21 | scarlet2.measure
22 | scarlet2.plot
23 | scarlet2.renderer
24 | scarlet2.wavelets
25 |
26 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_validation_utils.py:
--------------------------------------------------------------------------------
1 | from scarlet2.validation_utils import set_validation
2 |
3 |
4 | def test_set_validation():
5 |     """Test setting the validation switch. Note that we have to re-import VALIDATION_SWITCH
6 |     to ensure we are using the current value."""
7 | 
8 |     set_validation(True)
9 |     from scarlet2.validation_utils import VALIDATION_SWITCH  # re-import to read the rebound module-level flag
10 | 
11 |     assert VALIDATION_SWITCH is True
12 | 
13 |     set_validation(False)
14 |     from scarlet2.validation_utils import VALIDATION_SWITCH  # re-import again: the earlier local binding is stale
15 | 
16 |     assert VALIDATION_SWITCH is False
17 |
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Benchmarks
2 |
3 | This directory contains files that will be run via continuous testing either
4 | nightly or after committing code to a pull request.
5 |
6 | The runtime and/or memory usage of the functions defined in these files will be
7 | tracked and reported to give you a sense of the overall performance of your code.
8 |
9 | You are encouraged to add, update, or remove benchmark functions to suit the needs
10 | of your project.
11 |
12 | For more information, see the documentation here: https://lincc-ppt.readthedocs.io/en/latest/practices/ci_benchmarking.html
--------------------------------------------------------------------------------
/.copier-answers.yml:
--------------------------------------------------------------------------------
1 | # Changes here will be overwritten by Copier
2 | _commit: v2.0.7
3 | _src_path: gh:lincc-frameworks/python-project-template
4 | author_email: peter.m.melchior@gmail.com
5 | author_name: Peter Melchior
6 | create_example_module: false
7 | custom_install: true
8 | enforce_style:
9 | - ruff_lint
10 | - ruff_format
11 | failure_notification: []
12 | include_benchmarks: true
13 | include_docs: false
14 | mypy_type_checking: none
15 | package_name: scarlet2
16 | project_license: MIT
17 | project_name: scarlet2
18 | project_organization: pmelchior
19 | python_versions:
20 | - '3.10'
21 | - '3.11'
22 | - '3.12'
23 | - '3.13'
24 | test_lowest_version: none
25 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/views/output_box.css:
--------------------------------------------------------------------------------
1 | .code-output {
2 | background-color: #272822;
3 | color: #f1f1f1;
4 | padding: 10px;
5 | border-radius: 10px;
6 | border: 1px solid #444;
7 | font-family: monospace;
8 | overflow-x: auto;
9 | }
10 |
11 | .copy_button {
12 | font-size: 12px;
13 | padding: 4px 8px;
14 | background-color: #272822;
15 | color: white;
16 | border: none;
17 | border-radius: 4px;
18 | cursor: pointer;
19 | }
20 |
21 | .commentary-box {
22 | background-color: #f6f8fa;
23 | color: #333;
24 | padding: 8px 12px;
25 | border-left: 4px solid #0366d6;
26 | border-radius: 6px;
27 | font-size: 14px;
28 | font-family: sans-serif;
29 | }
30 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/1-bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Tell us about a problem to fix
4 | title: 'Short description'
5 | labels: 'bug'
6 | assignees: ''
7 |
8 | ---
9 | **Bug report**
10 |
11 |
12 | **Before submitting**
13 | Please check the following:
14 |
15 | - [ ] I have described the situation in which the bug arose, including what code was executed, information about my environment, and any applicable data others will need to reproduce the problem.
16 | - [ ] I have included available evidence of the unexpected behavior (including error messages, screenshots, and/or plots) as well as a description of what I expected instead.
17 | - [ ] If I have a solution in mind, I have provided an explanation and/or pseudocode and/or task list.
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/2-feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: 'Short description'
5 | labels: 'enhancement'
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Feature request**
11 |
12 |
13 | **Before submitting**
14 | Please check the following:
15 |
16 | - [ ] I have described the purpose of the suggested change, specifying what I need the enhancement to accomplish, i.e. what problem it solves.
17 | - [ ] I have included any relevant links, screenshots, environment information, and data relevant to implementing the requested feature, as well as pseudocode for how I want to access the new functionality.
18 | - [ ] If I have ideas for how the new feature could be implemented, I have provided explanations and/or pseudocode and/or task lists for the steps.
19 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # For explanation of this file and uses see
2 | # https://git-scm.com/docs/gitattributes
3 | # https://developer.lsst.io/git/git-lfs.html#using-git-lfs-enabled-repositories
4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/git-lfs.html
5 | #
6 | # Used by https://github.com/lsst/afwdata.git
7 | # *.boost filter=lfs diff=lfs merge=lfs -text
8 | # *.dat filter=lfs diff=lfs merge=lfs -text
9 | # *.fits filter=lfs diff=lfs merge=lfs -text
10 | # *.gz filter=lfs diff=lfs merge=lfs -text
11 | #
12 | # apache parquet files
13 | # *.parq filter=lfs diff=lfs merge=lfs -text
14 | #
15 | # sqlite files
16 | # *.sqlite3 filter=lfs diff=lfs merge=lfs -text
17 | #
18 | # gzip files
19 | # *.gz filter=lfs diff=lfs merge=lfs -text
20 | #
21 | # png image files
22 | # *.png filter=lfs diff=lfs merge=lfs -text
23 |
24 | .git_archival.txt export-subst
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/benchmarks/bbox_benchmarks.py:
--------------------------------------------------------------------------------
1 | from scarlet2.bbox import Box
2 |
3 |
4 | def time_bbox_creation():
5 |     """Basic timing benchmark for Box creation"""
6 |     _ = Box((15, 15))  # construction itself is the operation under test; result discarded
7 |
8 |
9 | def mem_bbox_creation():
10 |     """Basic memory benchmark for Box creation"""
11 |     _ = Box((15, 15))  # memory usage of constructing a single Box is what gets tracked (see benchmarks README)
12 |
13 |
14 | class BoxSuite:
15 |     """Suite of benchmarks for methods of the Box class"""
16 | 
17 |     params = [2, 16, 256, 2048]  # edge lengths; setup() runs once per value before each benchmark
18 | 
19 |     def setup(self, edge_length):
20 |         """Create a Box, for each of the different edge_lengths defined in the
21 |         `params` list."""
22 |         self.bb = Box((edge_length, edge_length))
23 | 
24 |     def time_bbox_contains(self, edge_length):
25 |         """Timing benchmark for `contains` method. Note that `edge_length` is
26 |         unused by this method."""
27 |         this_point = (6, 7)  # fixed probe point; the membership result is discarded, only the call is timed
28 |         self.bb.contains(this_point)
29 | 
30 |     def mem_bbox_contains(self, edge_length):
31 |         """Memory benchmark for `contains` method. Note that `edge_length` is
32 |         unused by this method."""
33 |         this_point = (6, 7)  # same fixed probe point as the timing benchmark
34 |         self.bb.contains(this_point)
35 |
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 |
2 | # This workflow will upload a Python Package using Twine when a release is created
3 | # For more information see: https://github.com/pypa/gh-action-pypi-publish#trusted-publishing
4 |
5 | # This workflow uses actions that are not certified by GitHub.
6 | # They are provided by a third-party and are governed by
7 | # separate terms of service, privacy policy, and support
8 | # documentation.
9 |
10 | name: Upload Python Package
11 |
12 | on:
13 | release:
14 | types: [published]
15 |
16 | permissions:
17 | contents: read
18 |
19 | jobs:
20 | deploy:
21 |
22 | runs-on: ubuntu-latest
23 | permissions:
24 | id-token: write
25 | steps:
26 | - uses: actions/checkout@v6
27 | - name: Set up Python
28 | uses: actions/setup-python@v6
29 | with:
30 | python-version: '3.11'
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pip install build
35 | - name: Build package
36 | run: python -m build
37 | - name: Publish package
38 | uses: pypa/gh-action-pypi-publish@release/v1
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023-2025 Peter Melchior
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # _scarlet2_
2 |
3 | _scarlet2_ is an open-source python library for modeling astronomical sources from multi-band, multi-epoch, and
4 | multi-instrument data. It provides non-parametric and parametric models, can handle source overlap (aka blending), and
5 | can integrate neural network priors. It's designed to be modular, flexible, and powerful.
6 |
7 | _scarlet2_ is implemented in [jax](https://jax.readthedocs.io/), layered on top of
8 | the [equinox](https://docs.kidger.site/equinox/)
9 | library. It can be deployed to GPUs and TPUs and supports optimization and sampling approaches.
10 |
11 | ## Installation
12 |
13 | For performance reasons, you should first install `jax` with the suitable `jaxlib` for your platform. After that
14 |
15 | ```
16 | pip install scarlet2
17 | ```
18 |
19 | should do. If you want the latest development version, use
20 |
21 | ```
22 | pip install git+https://github.com/pmelchior/scarlet2.git
23 | ```
24 |
25 | This will allow you to evaluate source models and compute likelihoods of observed data, so you can run your own
26 | optimizer/sampler. If you want a fully fledged library out of the box, you need to install `optax`, `numpyro`, and
27 | `h5py` as well.
--------------------------------------------------------------------------------
/.github/workflows/pre-commit-ci.yml:
--------------------------------------------------------------------------------
1 |
2 | # This workflow runs pre-commit hooks on pushes and pull requests to main
3 | # to enforce coding style. To ensure correct configuration, please refer to:
4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_precommit.html
5 | name: Run pre-commit hooks
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | jobs:
14 | pre-commit-ci:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v6
18 | with:
19 | fetch-depth: 0
20 | - name: Set up Python
21 | uses: actions/setup-python@v6
22 | with:
23 | python-version: '3.11'
24 | - name: Install dependencies
25 | run: |
26 | sudo apt-get update
27 | python -m pip install --upgrade pip
28 | pip install .[dev]
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 | - uses: pre-commit/action@v3.0.1
31 | with:
32 | extra_args: --all-files --verbose
33 | env:
34 | SKIP: "check-lincc-frameworks-template-version,no-commit-to-branch,check-added-large-files,validate-pyproject,sphinx-build,pytest-check"
35 | - uses: pre-commit-ci/lite-action@v1.1.0
36 | if: failure() && github.event_name == 'pull_request' && github.event.pull_request.draft == false
--------------------------------------------------------------------------------
/.github/workflows/testing-and-coverage.yml:
--------------------------------------------------------------------------------
1 |
2 | # This workflow will install Python dependencies, run tests and report code coverage with a variety of Python versions
3 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
4 |
5 | name: Unit test and code coverage
6 |
7 | on:
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | jobs:
14 | build:
15 |
16 | runs-on: ubuntu-latest
17 | strategy:
18 | matrix:
19 | python-version: ['3.10', '3.11', '3.12', '3.13']
20 |
21 | steps:
22 | - uses: actions/checkout@v6
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v6
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 | - name: Install dependencies
28 | run: |
29 | sudo apt-get update
30 | python -m pip install --upgrade pip
31 | pip install -e .[dev]
32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
33 | - name: Run unit tests with pytest
34 | run: |
35 | python -m pytest --cov=scarlet2 --cov-report=xml
36 | - name: Upload coverage report to codecov
37 | uses: codecov/codecov-action@v5
38 | with:
39 | token: ${{ secrets.CODECOV_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.github/workflows/smoke-test.yml:
--------------------------------------------------------------------------------
1 | # This workflow will run daily at 06:45.
2 | # It will install Python dependencies and run tests with a variety of Python versions.
3 | # See documentation for help debugging smoke test issues:
4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_testing.html#version-culprit
5 |
6 | name: Unit test smoke test
7 |
8 | on:
9 |
10 | # Runs this workflow automatically
11 | schedule:
12 |     - cron: '45 6 * * *'
13 |
14 | # Allows you to run this workflow manually from the Actions tab
15 | workflow_dispatch:
16 |
17 | jobs:
18 | build:
19 |
20 | runs-on: ubuntu-latest
21 | strategy:
22 | matrix:
23 | python-version: ['3.10', '3.11', '3.12', '3.13']
24 |
25 | steps:
26 | - uses: actions/checkout@v6
27 | - name: Set up Python ${{ matrix.python-version }}
28 | uses: actions/setup-python@v6
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 | - name: Install dependencies
32 | run: |
33 | sudo apt-get update
34 | python -m pip install --upgrade pip
35 | pip install -e .[dev]
36 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
37 | - name: List dependencies
38 | run: |
39 | pip list
40 | - name: Run unit tests with pytest
41 | run: |
42 | python -m pytest
--------------------------------------------------------------------------------
/tests/scarlet2/questionnaire/test_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from scarlet2.questionnaire.models import Questionnaire
3 |
4 |
5 | def test_validate_model(example_questionnaire_dict):
6 |     """Test that the Questionnaire model validates correctly."""
7 |     questionnaire = Questionnaire.model_validate(example_questionnaire_dict)  # fixture presumably built from data/example_questionnaire.yaml via conftest — values below match that file
8 |     assert questionnaire.initial_template == "{{code}}"
9 |     assert questionnaire.initial_commentary == "This is an example commentary."
10 |     assert len(questionnaire.questions) == 2
11 |     question = questionnaire.questions[0]  # "Example question?" with two answers
12 |     assert question.question == "Example question?"
13 |     assert len(question.answers) == 2
14 |     answer = question.answers[0]
15 |     assert answer.answer == "Example answer"
16 |     assert len(answer.templates) == 1
17 |     assert answer.templates[0].replacement == "code"
18 |     assert len(answer.followups) == 2
19 |     followup = answer.followups[0]
20 |     assert followup.question == "Follow-up question?"
21 |
22 |
23 | def test_model_fails(example_questionnaire_dict):
24 |     """Test that the Questionnaire model raises errors for invalid data."""
25 |     invalid_dict = example_questionnaire_dict.copy()  # NOTE(review): dict.copy() is shallow — only top-level keys are independent
26 |     invalid_dict["initial_template"] = 123
27 | 
28 |     with pytest.raises(ValueError):
29 |         Questionnaire.model_validate(invalid_dict)
30 | 
31 |     invalid_dict = example_questionnaire_dict.copy()
32 |     invalid_dict["questions"][0]["question"] = 123  # NOTE(review): this mutates the fixture's nested dict through the shallow copy — consider copy.deepcopy; confirm fixture scope in conftest
33 | 
34 |     with pytest.raises(ValueError):
35 |         Questionnaire.model_validate(invalid_dict)
36 |
--------------------------------------------------------------------------------
/tests/scarlet2/questionnaire/data/example_questionnaire.yaml:
--------------------------------------------------------------------------------
1 | initial_commentary: This is an example commentary.
2 | initial_template: '{{code}}'
3 | questions:
4 | - question: Example question?
5 | answers:
6 | - answer: Example answer
7 | commentary: This is some commentary.
8 | templates:
9 | - code: example_code {{follow}}
10 | replacement: code
11 | tooltip: This is an example tooltip.
12 | followups:
13 | - answers:
14 | - answer: Follow-up answer
15 | commentary: ''
16 | followups: []
17 | templates:
18 | - code: 'followup_code {{code}}'
19 | replacement: follow
20 | tooltip: This is a follow-up tooltip.
21 | - answer: Second follow-up answer
22 | templates:
23 | - code: 'second_followup_code {{code}}'
24 | replacement: follow
25 | tooltip: This is a second follow-up tooltip.
26 | question: Follow-up question?
27 | - answers:
28 | - answer: Another follow-up answer
29 | templates: []
30 | question: Another follow-up question?
31 | - answer: Another answer
32 | commentary: Some other commentary.
33 | templates:
34 | - code: 'another_code {{code}}'
35 | replacement: code
36 | tooltip: This is another tooltip.
37 | followups: []
38 | - question: Second question?
39 | answers:
40 | - answer: Second answer
41 | commentary: Next commentary.
42 | templates:
43 | - code: second_code
44 | replacement: code
45 | followups: []
46 | tooltip: This is a second tooltip.
47 |
--------------------------------------------------------------------------------
/tests/scarlet2/questionnaire/data/example_questionnaire_switch.yaml:
--------------------------------------------------------------------------------
1 | initial_commentary: This is an example commentary.
2 | initial_template: '{{code}}'
3 | questions:
4 | - question: Example question?
5 | variable: example_var
6 | answers:
7 | - answer: First answer
8 | commentary: This is some commentary.
9 | templates:
10 | - code: 'example_code {{code}}'
11 | replacement: code
12 | tooltip: This is an example tooltip.
13 | - answer: Second answer
14 | commentary: Some other commentary.
15 | followups: []
16 | templates:
17 | - code: 'another_code {{code}}'
18 | replacement: code
19 | tooltip: This is another tooltip.
20 | - answer: Third answer
21 | commentary: Some other commentary.
22 | followups: []
23 | templates:
24 | - code: 'third_code {{code}}'
25 | replacement: code
26 | tooltip: This is another tooltip.
27 | - switch: example_var
28 | cases:
29 | - value: 0
30 | questions:
31 | - question: Second Question 1?
32 | answers:
33 | - answer: Second answer 1
34 | templates:
35 | - code: 'followup_code_1 {{code}}'
36 | replacement: code
37 | - value: 1
38 | questions:
39 | - question: Second Question 2?
40 | answers:
41 | - answer: Second answer 2
42 | templates:
43 | - code: 'followup_code_2 {{code}}'
44 | replacement: code
45 | - value: null
46 | questions:
47 | - question: Default Question?
48 | answers:
49 | - answer: Default answer
50 | templates:
51 | - code: 'default_code {{code}}'
52 | replacement: code
53 |
54 |
55 |
--------------------------------------------------------------------------------
/docs/_static/logo_light.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
24 |
--------------------------------------------------------------------------------
/docs/_templates/custom-module-template.rst:
--------------------------------------------------------------------------------
1 | {{ fullname | escape | underline}}
2 |
3 | .. automodule:: {{ fullname }}
4 |
5 | {% block classes %}
6 | {% if classes %}
7 | .. rubric:: {{ _('Classes') }}
8 |
9 | .. autosummary::
10 | :toctree:
11 | :template: custom-class-template.rst
12 | {% for item in classes %}
13 | {{ item }}
14 | {%- endfor %}
15 | {% endif %}
16 | {% endblock %}
17 |
18 | {% block functions %}
19 | {% if functions %}
20 | .. rubric:: {{ _('Functions') }}
21 |
22 | .. autosummary::
23 | :toctree:
24 | :template: custom-base-template.rst
25 | {% for item in functions %}
26 | {{ item }}
27 | {%- endfor %}
28 | {% endif %}
29 | {% endblock %}
30 |
31 | {% block attributes %}
32 | {% if attributes %}
33 | .. rubric:: Module Attributes
34 |
35 | .. autosummary::
36 | :toctree:
37 | :template: custom-base-template.rst
38 | {% for item in attributes %}
39 | {{ item }}
40 | {%- endfor %}
41 | {% endif %}
42 | {% endblock %}
43 |
44 | {% block exceptions %}
45 | {% if exceptions %}
46 | .. rubric:: {{ _('Exceptions') }}
47 |
48 | .. autosummary::
49 | :toctree:
50 | :template: custom-base-template.rst
51 | {% for item in exceptions %}
52 | {{ item }}
53 | {%- endfor %}
54 | {% endif %}
55 | {% endblock %}
56 |
57 | {% block modules %}
58 | {% if modules %}
59 | .. rubric:: Modules
60 |
61 | .. autosummary::
62 | :toctree:
63 | :template: custom-module-template.rst
64 | :recursive:
65 | {% for item in modules %}
66 | {{ item }}
67 | {%- endfor %}
68 | {% endif %}
69 | {% endblock %}
70 |
--------------------------------------------------------------------------------
/docs/_static/logo_dark.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
24 |
--------------------------------------------------------------------------------
/src/scarlet2/psf.py:
--------------------------------------------------------------------------------
1 | """PSF-related classes"""
2 |
3 | import jax.numpy as jnp
4 |
5 | from .module import Module
6 | from .morphology import GaussianMorphology
7 |
8 |
class PSF(Module):
    """Base class for point spread function models."""

    @property
    def shape(self):
        """Shape of the PSF model.

        Returns
        -------
        tuple
            Shape of the underlying morphology model.
        """
        morph = self.morphology
        return morph.shape
16 |
17 |
class ArrayPSF(PSF):
    """PSF specified directly by an image array

    Warnings
    --------
    The number of pixels in `morphology` should be odd, and the centroid of the
    PSF image should be in the central pixel. If that is not the case, one creates
    an effective shift by the PSF, which is not captured by the coordinate
    convention of the frame, e.g. its :py:attr:`~scarlet2.Frame.wcs`.

    See :issue:`96` for more details.
    """

    morphology: jnp.ndarray
    """The PSF morphology image. Can be 2D (height, width) or 3D (channel, height, width)"""

    def __call__(self):
        """Evaluate PSF

        Returns
        -------
        array
            PSF image, normalized to total flux=1 over the last two axes
        """
        # sum over the spatial axes only, so each channel is normalized independently
        total = self.morphology.sum((-2, -1), keepdims=True)
        return self.morphology / total
43 |
44 |
class GaussianPSF(PSF):
    """Gaussian-shaped PSF"""

    morphology: GaussianMorphology
    """Morphology model"""

    def __init__(self, sigma):
        """Initialize a Gaussian PSF

        Parameters
        ----------
        sigma: float
            Standard deviation of the Gaussian
        """
        self.morphology = GaussianMorphology(sigma)

    def __call__(self):
        """Evaluate the PSF image, normalized to total flux=1

        Returns
        -------
        array
            Normalized Gaussian image
        """
        image = self.morphology()
        # normalize over the spatial axes so the total flux is 1
        norm = image.sum((-2, -1), keepdims=True)
        return image / norm
66 |
--------------------------------------------------------------------------------
/tests/scarlet2/questionnaire/data/example_questionnaire_followup_switch.yaml:
--------------------------------------------------------------------------------
1 | initial_commentary: This is an example commentary.
2 | initial_template: '{{code}}'
3 | questions:
4 | - question: Example question?
5 | answers:
6 | - answer: First answer
7 | commentary: This is some commentary.
8 | templates:
9 | - code: 'example_code {{follow}}'
10 | replacement: code
11 | tooltip: This is an example tooltip.
12 | followups:
13 | - question: Follow-up question?
14 | variable: example_var
15 | answers:
16 | - answer: Follow-up answer
17 | commentary: ''
18 | followups: []
19 | templates:
20 | - code: 'followup_code {{code}}'
21 | replacement: follow
22 | tooltip: This is a follow-up tooltip.
23 | - answer: Second follow-up answer
24 | templates:
25 | - code: 'second_followup_code {{code}}'
26 | replacement: follow
27 | tooltip: This is a second follow-up tooltip.
28 | - answer: Second answer
29 | commentary: Some other commentary.
30 | followups: []
31 | templates:
32 | - code: 'another_code {{code}}'
33 | replacement: code
34 | tooltip: This is another tooltip.
35 | - switch: example_var
36 | cases:
37 | - value: 0
38 | questions:
39 | - question: Second Question 1?
40 | answers:
41 | - answer: Second answer 1
42 | templates:
43 | - code: 'followup_code_1 {{code}}'
44 | replacement: code
45 | - value: 1
46 | questions:
47 | - answers:
48 | - answer: Second answer 2
49 | templates:
50 | - code: 'followup_code_2 {{code}}'
51 | replacement: code
52 | question: Second Question 2?
53 | - value: null
54 | questions:
55 | - question: Default Question?
56 | answers:
57 | - answer: Default answer
58 | templates:
59 | - code: 'default_code {{code}}'
60 | replacement: code
61 |
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/models.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | from typing import Union
3 |
4 | from pydantic import BaseModel, Field
5 |
6 |
class Template(BaseModel):
    """Represents a template for code replacement in the questionnaire.

    When an answer is selected, ``code`` is substituted for the
    ``{{replacement}}`` placeholder in the current code template
    (see the example questionnaire YAML files in tests).
    """

    # Name of the placeholder (e.g. "code") that this template fills in.
    replacement: str
    # Snippet inserted in place of the placeholder; may itself contain
    # further {{placeholder}} markers.
    code: str
12 |
13 |
class Answer(BaseModel):
    """Represents an answer to a question in the questionnaire."""

    # Text shown for this answer choice.
    answer: str
    # Hover/help text for the answer; empty when none is provided.
    tooltip: str = ""
    # Code substitutions applied when this answer is chosen.
    templates: list[Template] = Field(default_factory=list)
    # Questions (or switches) asked next if this answer is chosen.
    # String forward refs are resolved via model_rebuild() below.
    followups: list[Union["Question", "Switch"]] = Field(default_factory=list)
    # Free-form commentary associated with this answer.
    commentary: str = ""
22 |
23 |
class Question(BaseModel):
    """Represents a question in the questionnaire."""

    # Question text presented to the user.
    question: str
    # Optional name under which the chosen answer is recorded so that later
    # Switch blocks can branch on it (presumably the answer's index, matching
    # Case.value — confirm against the questionnaire engine).
    variable: str | None = None
    # Possible answers, in display order.
    answers: list[Answer]
30 |
31 |
class Case(BaseModel):
    """Represents a case in a switch statement within the questionnaire."""

    # Answer value this case matches; None appears to act as the default case
    # (see `value: null` in the example questionnaire YAML).
    value: int | None = None
    # Questions asked when this case is selected.
    questions: list[Union[Question, "Switch"]]
37 |
38 |
class Switch(BaseModel):
    """Represents a switch statement in the questionnaire."""

    # Name of the variable (set by Question.variable) this switch branches on.
    switch: str
    # Branches of the switch; a Case with value=None serves as the fallback.
    cases: list[Case]
44 |
45 |
# Rebuild models to resolve the string forward references used above
# (Answer.followups and Case.questions refer to Question/Switch, which are
# declared after their first use).
Question.model_rebuild()
Answer.model_rebuild()
Case.model_rebuild()
Switch.model_rebuild()
51 |
52 |
class QuestionAnswer(BaseModel):
    """Represents a user's answer to a question."""

    # Text of the question that was answered.
    question: str
    # Text of the chosen answer.
    answer: str
    # Numeric value of the choice (matches Case.value for switch branching).
    value: int
59 |
60 |
class QuestionAnswers(BaseModel):
    """Represents a collection of user answers to questions."""

    # All answers given, in the order they were recorded.
    answers: list[QuestionAnswer] = Field(default_factory=list)
    # When the collection was created.
    # NOTE(review): naive local time; consider datetime.now(timezone.utc) if
    # these records are ever compared across machines — confirm requirements.
    timestamp: datetime = Field(default_factory=datetime.now)
66 |
67 |
class Questionnaire(BaseModel):
    """Represents a questionnaire with an initial template and a list of questions."""

    # Starting code template; contains {{placeholder}} markers to be filled
    # by answer templates (e.g. '{{code}}').
    initial_template: str
    # Commentary shown before any question is answered.
    initial_commentary: str = ""
    # Optional URL where users can submit feedback.
    feedback_url: str | None = None
    # Top-level flow: plain questions and value-dependent switches.
    questions: list[Question | Switch]
75 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_scene_fit_checks.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import patch
2 |
3 | import pytest
4 | from scarlet2.scene import FitValidator
5 | from scarlet2.validation_utils import (
6 | ValidationError,
7 | ValidationInfo,
8 | ValidationWarning,
9 | set_validation,
10 | )
11 |
12 |
@pytest.fixture(autouse=True)
def setup_validation():
    """Automatically disable validation for all tests.

    This permits the creation of intentionally invalid Observation objects.

    NOTE(review): the switch is not restored after each test; fine while every
    test in this module wants validation off — confirm if fixtures are shared.
    """
    set_validation(False)
18 |
19 |
@pytest.mark.parametrize(
    "mocked_chi_value,expected",
    [
        (0.1, ValidationInfo),
        (2.0, ValidationWarning),
        (10.0, ValidationError),
    ],
)
def test_check_goodness_of_fit(scene, parameters, good_obs, mocked_chi_value, expected):
    """Mocked goodness_of_fit return value produces the expected ValidationResult type."""

    # A single cheap iteration is enough; the fit result is only used to
    # construct the validator, not to compute the (mocked) statistic.
    scene_ = scene.fit(good_obs, parameters, max_iter=1, e_rel=1e-4, progress_bar=True)
    checker = FitValidator(scene_, good_obs)

    # Patch on the type so the bound method lookup on the instance is replaced.
    with patch.object(type(good_obs), "goodness_of_fit", return_value=mocked_chi_value) as _:
        results = checker.check_goodness_of_fit()

    assert isinstance(results, expected)
38 |
39 |
@pytest.mark.parametrize(
    "mocked_chi_value,expected",
    [
        (0.1, ValidationInfo),
        (2.0, ValidationWarning),
        (10.0, ValidationError),
    ],
)
def test_check_chi_square_in_box_and_border(scene, parameters, good_obs, mocked_chi_value, expected):
    """Mocked chi-squared evaluation in box and border produces the expected result types.

    Renamed from "chi_squared" to "chi_square" to match the
    ``check_chi_square_in_box_and_border`` method under test.
    """

    # A single cheap iteration is enough; the fit result is only used to
    # construct the validator, not to compute the (mocked) statistic.
    scene_ = scene.fit(good_obs, parameters, max_iter=1, e_rel=1e-4, progress_bar=True)
    checker = FitValidator(scene_, good_obs)
    # One source (key 0) with identical mocked chi^2 inside and outside the box.
    mock_return = {0: {"in": mocked_chi_value, "out": mocked_chi_value}}

    with patch.object(type(good_obs), "eval_chi_square_in_box_and_border", return_value=mock_return) as _:
        results = checker.check_chi_square_in_box_and_border()

    assert all(isinstance(res, expected) for res in results)
59 |
--------------------------------------------------------------------------------
/.github/workflows/publish-benchmarks-pr.yml:
--------------------------------------------------------------------------------
1 | # This workflow publishes a benchmarks comment on a pull request. It is triggered after the
2 | # benchmarks are computed in the asv-pr workflow. This separation of concerns allows us to limit
3 | # access to the target repository private tokens and secrets, increasing the level of security.
4 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/.
5 | name: Publish benchmarks comment to PR
6 |
7 | on:
8 | workflow_run:
9 | workflows: ["Run benchmarks for PR"]
10 | types: [completed]
11 |
12 | jobs:
13 | upload-pr-comment:
14 | runs-on: ubuntu-latest
15 | if: >
16 | github.event.workflow_run.event == 'pull_request' &&
17 | github.event.workflow_run.conclusion == 'success'
18 | permissions:
19 | issues: write
20 | pull-requests: write
21 | steps:
22 | - name: Display Workflow Run Information
23 | run: |
24 | echo "Workflow Run ID: ${{ github.event.workflow_run.id }}"
25 | echo "Head SHA: ${{ github.event.workflow_run.head_sha }}"
26 | echo "Head Branch: ${{ github.event.workflow_run.head_branch }}"
27 | echo "Conclusion: ${{ github.event.workflow_run.conclusion }}"
28 | echo "Event: ${{ github.event.workflow_run.event }}"
29 | - name: Download artifact
30 | uses: dawidd6/action-download-artifact@v11
31 | with:
32 | name: benchmark-artifacts
33 | run_id: ${{ github.event.workflow_run.id }}
34 | - name: Extract artifacts information
35 | id: pr-info
36 | run: |
37 | printf "PR number: $(cat pr)\n"
38 | printf "Output:\n$(cat output)"
39 | printf "pr=$(cat pr)" >> $GITHUB_OUTPUT
40 | - name: Find benchmarks comment
41 | uses: peter-evans/find-comment@v4
42 | id: find-comment
43 | with:
44 | issue-number: ${{ steps.pr-info.outputs.pr }}
45 | comment-author: 'github-actions[bot]'
46 | body-includes: view all benchmarks
47 | - name: Create or update benchmarks comment
48 | uses: peter-evans/create-or-update-comment@v5
49 | with:
50 | comment-id: ${{ steps.find-comment.outputs.comment-id }}
51 | issue-number: ${{ steps.pr-info.outputs.pr }}
52 | body-path: output
53 | edit-mode: replace
--------------------------------------------------------------------------------
/src/scarlet2/__init__.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: E402
2 | """Main namespace for scarlet2"""
3 |
4 |
class Scenery:
    """Class to hold the context for the current scene

    NOTE(review): module-level mutable state; presumably assigned when a
    Scene context is entered — confirm in :py:class:`~scarlet2.Scene`.

    See Also
    --------
    :py:class:`~scarlet2.Scene`
    """

    scene = None
    """Scene of the currently opened context"""
15 |
16 |
class Parameterization:
    """Class to hold the context for the current parameter set

    NOTE(review): module-level mutable state; presumably assigned when a
    Parameters context is entered — confirm in :py:class:`~scarlet2.Parameters`.

    See Also
    --------
    :py:class:`~scarlet2.Parameters`
    """

    parameters = None
    """Parameters of the currently opened context"""
27 |
28 |
29 | from . import init, measure, plot
30 | from .bbox import Box
31 | from .frame import Frame
32 | from .module import Module, Parameter, Parameters, relative_step
33 | from .morphology import (
34 | GaussianMorphology,
35 | Morphology,
36 | ProfileMorphology,
37 | SersicMorphology,
38 | StarletMorphology,
39 | )
40 | from .observation import CorrelatedObservation, Observation
41 | from .psf import PSF, ArrayPSF, GaussianPSF
42 | from .scene import Scene
43 | from .source import Component, DustComponent, PointSource, Source
44 | from .spectrum import Spectrum, StaticArraySpectrum, TransientArraySpectrum
45 | from .validation import check_fit, check_observation, check_scene, check_source
46 | from .validation_utils import VALIDATION_SWITCH, set_validation
47 | from .wavelets import Starlet
48 |
# Public API: names exported for `from scarlet2 import *` and picked up by the
# docs. Keep this list in sync with the imports above.
__all__ = [
    "init",
    "measure",
    "plot",
    "Scenery",
    "Parameterization",
    "Box",
    "Frame",
    "Parameter",
    "Parameters",
    "Module",
    "relative_step",
    "Morphology",
    "ProfileMorphology",
    "GaussianMorphology",
    "SersicMorphology",
    "StarletMorphology",
    "Observation",
    "CorrelatedObservation",
    "PSF",
    "ArrayPSF",
    "GaussianPSF",
    "Scene",
    "Component",
    "DustComponent",
    "Source",
    "PointSource",
    "Spectrum",
    "StaticArraySpectrum",
    "TransientArraySpectrum",
    "Starlet",
    "check_fit",
    "check_observation",
    "check_scene",
    "check_source",
    "VALIDATION_SWITCH",
    "set_validation",
]
88 |
--------------------------------------------------------------------------------
/.setup_dev.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Bash Unofficial strict mode (http://redsymbol.net/articles/unofficial-bash-strict-mode/)
# and (https://disconnected.systems/blog/another-bash-strict-mode/)
set -o nounset   # Any uninitialized variable is an error
set -o errexit   # Exit the script on the failure of any command to execute without error
set -o pipefail  # Fail command pipelines on the failure of any individual step
IFS=$'\n\t' #set internal field separator to avoid iteration errors
# Trap all exits and output something helpful
trap 's=$?; echo "$0: Error on line "$LINENO": $BASH_COMMAND"; exit $s' ERR

# This script should be run by new developers to install this package in
# editable mode and configure their local environment

echo "Checking virtual environment"
if [ "${VIRTUAL_ENV:-missing}" = "missing" ] && [ "${CONDA_PREFIX:-missing}" = "missing" ]; then
    echo 'No virtual environment detected: none of $VIRTUAL_ENV or $CONDA_PREFIX is set.'
    echo
    echo "=== This script is going to install the project in the system python environment ==="
    echo "Proceed? [y/N]"
    # Fixed typo: RESPONCE -> RESPONSE (internal variable, behavior unchanged)
    read -r RESPONSE
    if [ "${RESPONSE}" != "y" ]; then
        echo "See https://lincc-ppt.readthedocs.io/ for details."
        echo "Exiting."
        exit 1
    fi

fi

echo "Checking pip version"
MINIMUM_PIP_VERSION=22
# Split the pip version string (e.g. "23.1.2") into an array of components
pipversion=( $(python -m pip --version | awk '{print $2}' | sed 's/\./\n\t/g') )
if let "${pipversion[0]}<${MINIMUM_PIP_VERSION}"; then
    echo "Insufficient version of pip found. Requires at least version ${MINIMUM_PIP_VERSION}."
    echo "See https://lincc-ppt.readthedocs.io/ for details."
    exit 1
fi

echo "Installing package and runtime dependencies in local environment"
python -m pip install -e . > /dev/null

echo "Installing developer dependencies in local environment"
python -m pip install -e .'[dev]' > /dev/null
if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt > /dev/null; fi

echo "Installing pre-commit"
pre-commit install > /dev/null

#######################################################
# Include any additional configurations below this line
#######################################################
52 |
--------------------------------------------------------------------------------
/.github/workflows/asv-main.yml:
--------------------------------------------------------------------------------
1 | # This workflow will run benchmarks with airspeed velocity (asv),
2 | # store the new results in the "benchmarks" branch and publish them
3 | # to a dashboard on GH Pages.
4 | name: Run ASV benchmarks for main
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 |
10 | env:
11 | PYTHON_VERSION: "3.11"
12 | ASV_VERSION: "0.6.4"
13 | WORKING_DIR: ${{github.workspace}}/benchmarks
14 |
15 | concurrency:
16 | group: ${{github.workflow}}-${{github.ref}}
17 | cancel-in-progress: true
18 |
19 | jobs:
20 | asv-main:
21 | runs-on: ubuntu-latest
22 | permissions:
23 | contents: write
24 | defaults:
25 | run:
26 | working-directory: ${{env.WORKING_DIR}}
27 | steps:
28 | - name: Set up Python ${{env.PYTHON_VERSION}}
29 | uses: actions/setup-python@v6
30 | with:
31 | python-version: ${{env.PYTHON_VERSION}}
32 | - name: Checkout main branch of the repository
33 | uses: actions/checkout@v6
34 | with:
35 | fetch-depth: 0
36 | - name: Install dependencies
37 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}"
38 | - name: Configure git
39 | run: |
40 | git config user.name "github-actions[bot]"
41 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
42 | - name: Create ASV machine config file
43 | run: asv machine --machine gh-runner --yes
44 | - name: Fetch previous results from the "benchmarks" branch
45 | run: |
46 | if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then
47 | git merge origin/benchmarks \
48 | --allow-unrelated-histories \
49 | --no-commit
50 | mv ../_results .
51 | fi
52 | - name: Run ASV for the main branch
53 | run: asv run ALL --skip-existing --verbose || true
54 | - name: Submit new results to the "benchmarks" branch
55 | uses: JamesIves/github-pages-deploy-action@v4
56 | with:
57 | branch: benchmarks
58 | folder: ${{env.WORKING_DIR}}/_results
59 | target-folder: _results
60 | - name: Generate dashboard HTML
61 | run: |
62 | asv show
63 | asv publish
64 | - name: Deploy to Github pages
65 | uses: JamesIves/github-pages-deploy-action@v4
66 | with:
67 | branch: gh-pages
68 | folder: ${{env.WORKING_DIR}}/_html
--------------------------------------------------------------------------------
/src/scarlet2/questionnaire/views/question_box.css:
--------------------------------------------------------------------------------
/* Styles for the questionnaire question-box widget: layout, save button,
   previous-answer list, and hover tooltips. */

/* Add overflow-x: hidden to the question box to prevent horizontal scrolling */
.question-box-container {
    overflow-x: hidden;
    display: flex;
    flex-direction: column;
    min-height: 300px; /* Ensure there's enough height for content */
}

/* Style for the save button container */
.save-button-container {
    margin-top: auto;
    text-align: center;
    padding: 5px;
}

/* Style for the save button */
.save-button {
    font-size: 0.9em;
    background-color: #444444;
    color: white;
}

/* Confirmation message shown after saving */
.save-message {
    margin-top: 10px;
    text-align: left;
}

/* Previously-answered item; positioned so its tooltip can be anchored to it */
.prev-item {
    position: relative;
    overflow: visible;
}
/* Link-like button for revisiting a previous answer */
.prev-btn {
    width: auto;
    text-align: left;
    border: none;
    background: none;
    box-shadow: none;
    color: #555;
    text-decoration: none;
    cursor: pointer;
    font-size: 0.9em;
    margin: 2px;
    padding: 2px 4px;
    overflow: hidden;
    max-width: 100%; /* Ensure button doesn't exceed container width */
    text-overflow: ellipsis; /* Add ellipsis for text that overflows */
}

/* Container for the tooltip */
.prev-item .tooltip {
    position: absolute;
    background: #333;
    color: #fff;
    padding: 6px 8px;
    border-radius: 8px;
    font-size: 12px;
    line-height: 1.3;
    box-shadow: 0 4px 12px rgba(0,0,0,.2);
    white-space: normal;
    min-width: 160px;
    max-width: 320px;
    opacity: 0;
    visibility: hidden;
    transition: opacity .12s ease;
    pointer-events: none;
    z-index: 9999;

    /* Default position to the right of the button */
    left: 100%;
    top: 50%;
    transform: translateY(-50%) translateX(8px);
}

/* When hovering over the container, show the tooltip */
.prev-item:hover .tooltip {
    opacity: 1;
    visibility: visible;
}

/* Arrow for the tooltip */
.prev-item .tooltip::before {
    content: "";
    position: absolute;
    border-width: 6px;
    border-style: solid;
    left: -6px;
    top: 50%;
    transform: translateY(-50%);
    border-color: transparent #333 transparent transparent;
}
91 |
--------------------------------------------------------------------------------
/src/scarlet2/io.py:
--------------------------------------------------------------------------------
1 | """Methods to save and load scenes"""
2 |
3 | import os
4 | import pickle
5 |
6 | import h5py
7 | import numpy as np
8 |
9 |
def model_to_h5(model, filename, id=0, path=".", overwrite=False):
    """Save the scene model to a HDF5 file

    Parameters
    ----------
    filename : str
        Name of the HDF5 file to create
    model : :py:class:`~scarlet2.Module`
        Scene to be stored
    id : int
        HDF5 group to store this `model` under
    path: str, optional
        Explicit path for `filename`. If not set, uses local directory
    overwrite : bool, optional
        Whether to overwrite an existing file with the same path and filename

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `id` already exists in the file and `overwrite` is False

    Notes
    -----
    This is not a pure function hence cannot be utilized within a JAX JIT compilation.
    """
    # create directory if it does not exist (exist_ok avoids a race between
    # the existence check and the creation)
    os.makedirs(path, exist_ok=True)

    model_group = str(id)
    save_h5_path = os.path.join(path, filename)

    # context manager guarantees the file is closed even when we raise below
    # (the previous implementation leaked the handle on the ValueError path)
    with h5py.File(save_h5_path, "a") as f:
        if model_group in f:
            if overwrite:
                del f[model_group]
            else:
                raise ValueError("ID already exists. Set overwrite=True to overwrite the ID.")

        # serialize the model and store the binary blob as a group attribute
        group = f.create_group(model_group)
        blob = pickle.dumps(model)
        group.attrs["model"] = np.void(blob)
55 |
56 |
def model_from_h5(filename, id=0, path="."):
    """
    Load scene model from a HDF5 file

    Parameters
    ----------
    filename : str
        Name of the HDF5 file to load from
    id : int
        HDF5 group to identify the scene by
    path: str, optional
        Explicit path for `filename`. If not set, uses local directory

    Returns
    -------
    :py:class:`~scarlet2.Scene`

    Raises
    ------
    ValueError
        If `id` is not found in the file

    Warnings
    --------
    Deserialization uses :py:func:`pickle.loads`, which can execute arbitrary
    code. Only load files from trusted sources.
    """

    filename = os.path.join(path, filename)
    # context manager guarantees the file is closed even when we raise below
    # (the previous implementation leaked the handle on the ValueError path)
    with h5py.File(filename, "r") as f:
        model_group = str(id)
        if model_group not in f:
            raise ValueError(f"ID {id} not found in the file.")

        group = f.get(model_group)
        binary_blob = group.attrs["model"].tobytes()
        scene = pickle.loads(binary_blob)

    return scene
88 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | _version.py
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_autosummary/
74 | docs/_build/
75 | docs/jupyter_execute/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # VSCode settings
135 | .vscode
136 |
137 | # Default location for saved models
138 | stored_models/
139 |
--------------------------------------------------------------------------------
/.initialize_new_project.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Bash Unofficial strict mode (http://redsymbol.net/articles/unofficial-bash-strict-mode/)
# and (https://disconnected.systems/blog/another-bash-strict-mode/)
set -o nounset   # Any uninitialized variable is an error
set -o errexit   # Exit the script on the failure of any command to execute without error
set -o pipefail  # Fail command pipelines on the failure of any individual step
IFS=$'\n\t' #set internal field separator to avoid iteration errors
# Trap all exits and output something helpful
trap 's=$?; echo "$0: Error on line "$LINENO": $BASH_COMMAND"; exit $s' ERR

echo "Checking virtual environment"
if [ "${VIRTUAL_ENV:-missing}" = "missing" ] && [ "${CONDA_PREFIX:-missing}" = "missing" ]; then
    echo 'No virtual environment detected: none of $VIRTUAL_ENV or $CONDA_PREFIX is set.'
    echo
    echo "=== This script is going to install the project in the system python environment ==="
    echo "Proceed? [y/N]"
    # Fixed typo: RESPONCE -> RESPONSE (internal variable, behavior unchanged)
    read -r RESPONSE
    if [ "${RESPONSE}" != "y" ]; then
        echo "See https://lincc-ppt.readthedocs.io/ for details."
        echo "Exiting."
        exit 1
    fi

fi

echo "Checking pip version"
MINIMUM_PIP_VERSION=22
# Split the pip version string (e.g. "23.1.2") into an array of components
pipversion=( $(python -m pip --version | awk '{print $2}' | sed 's/\./\n\t/g') )
if let "${pipversion[0]}<${MINIMUM_PIP_VERSION}"; then
    echo "Insufficient version of pip found. Requires at least version ${MINIMUM_PIP_VERSION}."
    echo "See https://lincc-ppt.readthedocs.io/ for details."
    exit 1
fi

echo "Initializing local git repository"
{
    # Fixed typo: "git version | git version" piped the command into itself;
    # a single invocation yields the same output.
    gitversion=( $(git version | awk '{print $3}' | sed 's/\./\n\t/g') )
    if let "${gitversion[0]}<2"; then
        # manipulate directly
        git init . && echo 'ref: refs/heads/main' >.git/HEAD
    elif let "${gitversion[0]}==2 & ${gitversion[1]}<34"; then
        # rename master to main
        git init . && { git branch -m master main 2>/dev/null || true; };
    else
        # set the initial branch name to main
        git init --initial-branch=main >/dev/null
    fi
} > /dev/null

echo "Installing package and runtime dependencies in local environment"
python -m pip install -e . > /dev/null

echo "Installing developer dependencies in local environment"
python -m pip install -e .'[dev]' > /dev/null
if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt; fi

echo "Installing pre-commit"
pre-commit install > /dev/null

echo "Committing initial files"
git add . && SKIP="no-commit-to-branch" git commit -m "Initial commit"
63 |
--------------------------------------------------------------------------------
/.github/workflows/asv-nightly.yml:
--------------------------------------------------------------------------------
1 | # This workflow will run daily at 06:45.
2 | # It will run benchmarks with airspeed velocity (asv)
3 | # and compare performance with the previous nightly build.
4 | name: Run benchmarks nightly job
5 |
6 | on:
7 | schedule:
8 | - cron: 45 6 * * *
9 | workflow_dispatch:
10 |
11 | env:
12 | PYTHON_VERSION: "3.11"
13 | ASV_VERSION: "0.6.4"
14 | WORKING_DIR: ${{github.workspace}}/benchmarks
15 | NIGHTLY_HASH_FILE: nightly-hash
16 |
17 | jobs:
18 | asv-nightly:
19 | runs-on: ubuntu-latest
20 | defaults:
21 | run:
22 | working-directory: ${{env.WORKING_DIR}}
23 | steps:
24 | - name: Set up Python ${{env.PYTHON_VERSION}}
25 | uses: actions/setup-python@v6
26 | with:
27 | python-version: ${{env.PYTHON_VERSION}}
28 | - name: Checkout main branch of the repository
29 | uses: actions/checkout@v6
30 | with:
31 | fetch-depth: 0
32 | - name: Install dependencies
33 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}"
34 | - name: Configure git
35 | run: |
36 | git config user.name "github-actions[bot]"
37 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
38 | - name: Create ASV machine config file
39 | run: asv machine --machine gh-runner --yes
40 | - name: Fetch previous results from the "benchmarks" branch
41 | run: |
42 | if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then
43 | git merge origin/benchmarks \
44 | --allow-unrelated-histories \
45 | --no-commit
46 | mv ../_results .
47 | fi
48 | - name: Get nightly dates under comparison
49 | id: nightly-dates
50 | run: |
51 | echo "yesterday=$(date -d yesterday +'%Y-%m-%d')" >> $GITHUB_OUTPUT
52 | echo "today=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
53 | - name: Use last nightly commit hash from cache
54 | uses: actions/cache@v4
55 | with:
56 | path: ${{env.WORKING_DIR}}
57 | key: nightly-results-${{steps.nightly-dates.outputs.yesterday}}
58 | - name: Run comparison of main against last nightly build
59 | run: |
60 | HASH_FILE=${{env.NIGHTLY_HASH_FILE}}
61 | CURRENT_HASH=${{github.sha}}
62 | if [ -f $HASH_FILE ]; then
63 | PREV_HASH=$(cat $HASH_FILE)
64 | asv continuous $PREV_HASH $CURRENT_HASH --verbose || true
65 | asv compare $PREV_HASH $CURRENT_HASH --sort ratio --verbose
66 | fi
67 | echo $CURRENT_HASH > $HASH_FILE
68 | - name: Update last nightly hash in cache
69 | uses: actions/cache@v4
70 | with:
71 | path: ${{env.WORKING_DIR}}
72 | key: nightly-results-${{steps.nightly-dates.outputs.today}}
--------------------------------------------------------------------------------
/.github/workflows/asv-pr.yml:
--------------------------------------------------------------------------------
1 | # This workflow will run benchmarks with airspeed velocity (asv) for pull requests.
2 | # It will compare the performance of the main branch with the performance of the merge
3 | # with the new changes. It then publishes a comment with this assessment by triggering
4 | # the publish-benchmarks-pr workflow.
5 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/.
6 | name: Run benchmarks for PR
7 |
8 | on:
9 | pull_request:
10 | branches: [ main ]
11 | workflow_dispatch:
12 |
13 | concurrency:
14 | group: ${{github.workflow}}-${{github.ref}}
15 | cancel-in-progress: true
16 |
17 | env:
18 | PYTHON_VERSION: "3.11"
19 | ASV_VERSION: "0.6.4"
20 | WORKING_DIR: ${{github.workspace}}/benchmarks
21 | ARTIFACTS_DIR: ${{github.workspace}}/artifacts
22 |
23 | jobs:
24 | asv-pr:
25 | runs-on: ubuntu-latest
26 | defaults:
27 | run:
28 | working-directory: ${{env.WORKING_DIR}}
29 | steps:
30 | - name: Set up Python ${{env.PYTHON_VERSION}}
31 | uses: actions/setup-python@v6
32 | with:
33 | python-version: ${{env.PYTHON_VERSION}}
34 | - name: Checkout PR branch of the repository
35 | uses: actions/checkout@v6
36 | with:
37 | fetch-depth: 0
38 | - name: Display Workflow Run Information
39 | run: |
40 | echo "Workflow Run ID: ${{github.run_id}}"
41 | - name: Install dependencies
42 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}" lf-asv-formatter
43 | - name: Make artifacts directory
44 | run: mkdir -p ${{env.ARTIFACTS_DIR}}
45 | - name: Save pull request number
46 | run: echo ${{github.event.pull_request.number}} > ${{env.ARTIFACTS_DIR}}/pr
47 | - name: Get current job logs URL
48 | uses: Tiryoh/gha-jobid-action@v1
49 | id: jobs
50 | with:
51 | github_token: ${{secrets.GITHUB_TOKEN}}
52 | job_name: ${{github.job}}
53 | - name: Create ASV machine config file
54 | run: asv machine --machine gh-runner --yes
55 | - name: Save comparison of PR against main branch
56 | run: |
57 | git remote add upstream https://github.com/${{github.repository}}.git
58 | git fetch upstream
59 | asv continuous upstream/main HEAD --verbose || true
60 | asv compare upstream/main HEAD --sort ratio --verbose | tee output
61 | python -m lf_asv_formatter --asv_version "$(asv --version | awk '{print $2}')"
62 | printf "\n\nClick [here]($STEP_URL) to view all benchmarks." >> output
63 | mv output ${{env.ARTIFACTS_DIR}}
64 | env:
65 | STEP_URL: ${{steps.jobs.outputs.html_url}}#step:10:1
66 | - name: Upload artifacts (PR number and benchmarks output)
67 | uses: actions/upload-artifact@v5
68 | with:
69 | name: benchmark-artifacts
70 | path: ${{env.ARTIFACTS_DIR}}
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 |
2 | repos:
3 | # Compare the local template version to the latest remote template version
4 | # This hook should always pass. It will print a message if the local version
5 | # is out of date.
6 | - repo: https://github.com/lincc-frameworks/pre-commit-hooks
7 | rev: v0.1.2
8 | hooks:
9 | - id: check-lincc-frameworks-template-version
10 | name: Check template version
11 | description: Compare current template version against latest
12 | verbose: true
13 | # Clear output from jupyter notebooks so that only the input cells are committed.
14 | - repo: local
15 | hooks:
16 | - id: jupyter-nb-clear-output
17 | name: Clear output from Jupyter notebooks
18 | description: Clear output from Jupyter notebooks.
19 | files: \.ipynb$
20 | exclude: ^docs/pre_executed
21 | stages: [pre-commit]
22 | language: system
23 | entry: jupyter nbconvert --clear-output
24 | # Prevents committing directly branches named 'main' and 'master'.
25 | - repo: https://github.com/pre-commit/pre-commit-hooks
26 | rev: v4.4.0
27 | hooks:
28 | - id: no-commit-to-branch
29 | name: Prevent main branch commits
30 | description: Prevent the user from committing directly to the primary branch.
31 | - id: check-added-large-files
32 | name: Check for large files
33 | description: Prevent the user from committing very large files.
34 | args: ['--maxkb=500']
35 | # Verify that pyproject.toml is well formed
36 | - repo: https://github.com/abravalheri/validate-pyproject
37 | rev: v0.12.1
38 | hooks:
39 | - id: validate-pyproject
40 | name: Validate pyproject.toml
41 | description: Verify that pyproject.toml adheres to the established schema.
42 | # Verify that GitHub workflows are well formed
43 | - repo: https://github.com/python-jsonschema/check-jsonschema
44 | rev: 0.28.0
45 | hooks:
46 | - id: check-github-workflows
47 | args: ["--verbose"]
48 | - repo: https://github.com/astral-sh/ruff-pre-commit
49 | # Ruff version.
50 | rev: v0.13.3
51 | hooks:
52 | # Run the linter.
53 | - id: ruff-check
54 | types_or: [ python, pyi ]
55 | args: [ --fix ]
56 | # Run the formatter.
57 | - id: ruff-format
58 | types_or: [ python, pyi ]
59 | # Run unit tests, verify that they pass. Note that coverage is run against
60 | # the ./src directory here because that is what will be committed. In the
61 | # github workflow script, the coverage is run against the installed package
62 | # and uploaded to Codecov by calling pytest like so:
63 | # `python -m pytest --cov= --cov-report=xml`
64 | - repo: local
65 | hooks:
66 | - id: pytest-check
67 | name: Run unit tests
68 | description: Run unit tests with pytest.
69 | entry: bash -c "if python -m pytest --co -qq; then python -m pytest --cov=./src --cov-report=html; fi"
70 | language: system
71 | pass_filenames: false
72 | always_run: true
73 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_quickstart.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: D101
2 | # ruff: noqa: D102
3 | # ruff: noqa: D103
4 | # ruff: noqa: D106
5 |
6 | from functools import partial
7 |
8 | import jax.numpy as jnp
9 | import numpyro.distributions as dist
10 | from huggingface_hub import hf_hub_download
11 | from numpyro.distributions import constraints
12 | from numpyro.infer.initialization import init_to_sample
13 |
14 | from scarlet2 import init
15 | from scarlet2.frame import Frame
16 | from scarlet2.module import Parameter, Parameters, relative_step
17 | from scarlet2.observation import Observation
18 | from scarlet2.psf import ArrayPSF, GaussianPSF
19 | from scarlet2.scene import Scene
20 | from scarlet2.source import Source
21 | from scarlet2.validation_utils import set_validation
22 |
# turn off automatic validation checks
set_validation(False)

# Download the example HSC cutout once; hf_hub_download caches it locally.
filename = hf_hub_download(
    repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
)
file = jnp.load(filename)
data = jnp.asarray(file["images"])
channels = [str(f) for f in file["filters"]]
centers = [(src["y"], src["x"]) for src in file["catalog"]]  # Note: y/x convention!
weights = jnp.asarray(1 / file["variance"])  # inverse-variance weights
psf = jnp.asarray(file["psfs"])

# NOTE(review): constructed but unused — presumably a smoke check of GaussianPSF; confirm
_ = GaussianPSF(0.7)
obs = Observation(data, weights, psf=ArrayPSF(jnp.asarray(psf)), channels=channels)
model_frame = Frame.from_observations(obs)

# Build one source per catalog center inside the Scene context.
with Scene(model_frame) as scene:
    for center in centers:
        center = jnp.array(center)
        try:
            spectrum, morph = init.from_gaussian_moments(obs, center, min_corr=0.99)
        except ValueError:
            # fall back to a generic compact source when moment-based init fails
            spectrum = init.pixel_spectrum(obs, center)
            morph = init.compact_morphology()

        Source(center, spectrum, morph)

# Populated by test_fit(); test_sample() warm-starts from this fitted scene.
scene_ = None
52 |
53 |
def test_fit():
    """Fit the module-level scene to the observation and store the result."""
    global scene_
    spectrum_step = partial(relative_step, factor=0.05)

    # Declare one spectrum and one morphology parameter per source.
    with Parameters(scene) as parameters:
        for idx, source in enumerate(scene.sources):
            Parameter(
                source.spectrum,
                name=f"spectrum:{idx}",
                constraint=constraints.positive,
                stepsize=spectrum_step,
            )
            Parameter(
                source.morphology,
                name=f"morph:{idx}",
                constraint=constraints.positive,
                stepsize=0.1,
            )

    # Keep the fitted scene so test_sample() can warm-start from it.
    scene_ = scene.fit(obs, parameters, max_iter=10, progress_bar=False)
73 |
74 |
def test_sample():
    """Draw a few posterior samples for the first source's spectrum."""
    # using pre-optimized scene for better warm-up
    # NOTE: requires test_fit() to have run first so that scene_ is populated.
    global scene_
    # old style of parameter declaration, check backward compatibility
    parameters = scene_.make_parameters()
    p = scene_.sources[0].spectrum
    prior = dist.Normal(p, scale=1)
    parameters += Parameter(p, name="spectrum:0", prior=prior)

    # Short sampling run (10 samples) just to exercise the API.
    _ = scene_.sample(
        obs, parameters, num_samples=10, dense_mass=True, init_strategy=init_to_sample, progress_bar=False
    )
87 |
88 |
# Allow running this module directly (outside pytest). Order matters:
# test_sample() warm-starts from the scene fitted in test_fit().
if __name__ == "__main__":
    test_fit()
    test_sample()
92 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [project]
3 | name = "scarlet2"
4 | license = {file = "LICENSE"}
5 | readme = "README.md"
6 | authors = [
7 | { name = "Peter Melchior", email = "peter.m.melchior@gmail.com" }
8 | ]
9 | classifiers = [
10 | "Development Status :: 4 - Beta",
11 | "License :: OSI Approved :: MIT License",
12 | "Intended Audience :: Developers",
13 | "Intended Audience :: Science/Research",
14 | "Operating System :: OS Independent",
15 | "Programming Language :: Python",
16 | ]
17 | dynamic = ["version"]
18 | requires-python = ">=3.10"
19 | dependencies = [
20 | "equinox",
21 | "jax==0.6.2",
22 | "astropy",
23 | "numpy",
24 | "matplotlib",
25 | "varname",
26 | "optax",
27 | "numpyro",
28 | "h5py",
29 | "ipywidgets",
30 | "pydantic",
31 | "pyyaml",
32 | "colorama",
33 | "markdown",
34 | "pygments",
35 | "jinja2"
36 | ]
37 |
38 | [project.urls]
39 | "Source Code" = "https://github.com/pmelchior/scarlet2"
40 |
41 | # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes)
42 | [project.optional-dependencies]
43 | dev = [
44 | "asv==0.6.5", # Used to compute performance benchmarks
45 | "huggingface_hub", # Pulls down example data for testing
46 | "jupyter", # Clears output from Jupyter notebooks
47 | "pre-commit", # Used to run checks before finalizing a git commit
48 | "pytest",
49 | "pytest-cov", # Used to report total code coverage
50 | "pytest-mock", # Used to mock objects in tests
51 | "ruff", # Used for static linting of files
52 | ]
53 |
54 | [build-system]
55 | requires = [
56 | "setuptools>=62", # Used to build and package the Python project
57 | "setuptools_scm>=6.2", # Gets release version from git. Makes it available programmatically
58 | ]
59 | build-backend = "setuptools.build_meta"
60 |
61 | [tool.setuptools_scm]
62 | write_to = "src/scarlet2/_version.py"
63 |
64 | [tool.pytest.ini_options]
65 | testpaths = [
66 | "tests",
67 | ]
68 | addopts = "--doctest-modules --doctest-glob=*.rst"
69 |
70 | [tool.ruff]
71 | line-length = 110
72 | target-version = "py310"
73 | [tool.ruff.lint]
74 | select = [
75 | # pycodestyle
76 | "E",
77 | "W",
78 | # Pyflakes
79 | "F",
80 | # pep8-naming
81 | "N",
82 | # pyupgrade
83 | "UP",
84 | # flake8-bugbear
85 | "B",
86 | # flake8-simplify
87 | "SIM",
88 | # isort
89 | "I",
90 | # docstrings
91 | "D101",
92 | "D102",
93 | "D103",
94 | "D106",
95 | "D206",
96 | "D207",
97 | "D208",
98 | "D300",
99 | "D417",
100 | "D419",
101 | # Numpy v2.0 compatibility
102 | "NPY201",
103 | ]
104 | ignore = [
105 | "UP006", # Allow non standard library generics in type hints
106 | "UP007", # Allow Union in type hints
107 | "SIM114", # Allow if with same arms
108 | "B028", # Allow default warning level
109 | "SIM117", # Allow nested with
110 | "UP015", # Allow redundant open parameters
111 | "UP028", # Allow yield in for loop
112 | "UP038", # Allow tuple in `isinstance`
113 | "E731", # Allow assignment of lambda expressions
114 | ]
115 |
116 |
117 | [tool.coverage.run]
118 | omit=["src/scarlet2/_version.py"]
119 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# -- Project information -----------------------------------------------------

project = "scarlet2"
copyright = "2025, Peter Melchior"
author = "Peter Melchior"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
    "sphinx.ext.viewcode",
    "sphinx.ext.intersphinx",
    "sphinx.ext.doctest",
    "sphinx_issues",
    "myst_nb",
]
master_doc = "index"
source_suffix = {
    ".rst": "restructuredtext",
    ".ipynb": "myst-nb",
}

templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "jupyter_execute"]

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
html_title = "scarlet2"
html_favicon = "_static/icon.png"
html_show_sourcelink = False
html_theme_options = {
    "path_to_docs": "docs",
    "repository_url": "https://github.com/pmelchior/scarlet2",
    "repository_branch": "main",
    "logo": {
        "image_light": "_static/logo_light.svg",
        "image_dark": "_static/logo_dark.svg",
    },
    "launch_buttons": {"colab_url": "https://colab.research.google.com"},
    "use_edit_page_button": True,
    "use_issues_button": True,
    "use_repository_button": True,
    "use_download_button": True,
    "show_toc_level": 3,
}
html_baseurl = "https://scarlet2.readthedocs.io/en/latest/"

# -- Autodoc / autosummary settings -------------------------------------------
autoclass_content = "both"
autosummary_generate = True
autosummary_imported_members = False
autosummary_ignore_module_all = False
# Render internal shorthand names with their fully qualified forms.
autodoc_type_aliases = {
    "eqx.Module": "equinox.Module",
    "jnp.ndarray": "jax.numpy.array",
}

# Cross-reference the documentation of external projects.
intersphinx_mapping = {
    "astropy": ("https://docs.astropy.org/en/stable/", None),
    "optax": ("https://optax.readthedocs.io/en/latest/", None),
    "numpyro": ("https://num.pyro.ai/en/stable/", None),
}

issues_github_path = "pmelchior/scarlet2"

# myst-nb: per-cell execution timeout (seconds) when building notebooks.
nb_execution_timeout = 60
nb_execution_excludepatterns = ["_build", "jupyter_execute"]

# Napoleon settings
napoleon_google_docstring = False
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = False
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True
napoleon_use_admonition_for_examples = True
napoleon_use_admonition_for_notes = True
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True
napoleon_preprocess_types = True
napoleon_type_aliases = None
napoleon_attr_annotations = True
--------------------------------------------------------------------------------
/tests/scarlet2/test_save_output.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: D101
2 | # ruff: noqa: D102
3 | # ruff: noqa: D103
4 | # ruff: noqa: D106
5 |
6 | import os
7 |
8 | import astropy.wcs as wcs
9 | import h5py
10 | import jax
11 | import jax.numpy as jnp
12 | from huggingface_hub import hf_hub_download
13 |
14 | from scarlet2 import * # noqa: F403
15 | from scarlet2 import init
16 | from scarlet2.bbox import Box
17 | from scarlet2.frame import Frame
18 | from scarlet2.io import model_from_h5, model_to_h5
19 | from scarlet2.observation import Observation
20 | from scarlet2.psf import GaussianPSF
21 | from scarlet2.scene import Scene
22 | from scarlet2.source import Source
23 | from scarlet2.validation_utils import set_validation
24 |
25 | # turn off automatic validation checks
26 | set_validation(False)
27 |
28 |
def test_save_output():
    """Round-trip a Scene through HDF5 and verify every pytree leaf survives.

    Saves the scene twice (two model ids) into one HDF5 file, reloads it,
    and compares each leaf of the original and the reloaded scene.
    """
    filename = hf_hub_download(
        repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
    )
    file = jnp.load(filename)
    data = jnp.asarray(file["images"])
    centers = [(src["y"], src["x"]) for src in file["catalog"]]  # Note: y/x convention!
    weights = jnp.asarray(1 / file["variance"])
    psf = jnp.asarray(file["psfs"])

    frame_psf = GaussianPSF(0.7)
    model_frame = Frame(Box(data.shape), psf=frame_psf)
    obs = Observation(data, weights, psf=psf)

    with Scene(model_frame) as scene:
        for center in centers:
            center = jnp.array(center)
            try:
                spectrum, morph = init.from_gaussian_moments(obs, center, min_corr=0.99)
            except ValueError:
                # fall back to a generic compact source when moment-based init fails
                spectrum = init.pixel_spectrum(obs, center)
                morph = init.compact_morphology()

            Source(center, spectrum, morph)

    # save the output ("id" renamed to model_id to avoid shadowing the builtin)
    model_id = 1
    filename = "demo_io.h5"
    path = "stored_models"
    model_to_h5(scene, filename, model_id, path=path, overwrite=True)

    # demo that it works to add models to a single file
    model_id = 2
    model_to_h5(scene, filename, model_id, path=path, overwrite=True)

    # load files and show keys
    full_path = os.path.join(path, filename)
    with h5py.File(full_path, "r") as f:
        print(f.keys())

    # print the output
    print(f"Output saved to {full_path}")
    # print the storage size
    print(f"Storage size: {os.path.getsize(full_path) / 1e6:.4f} MB")
    # load the output again
    scene_loaded = model_from_h5(filename, model_id, path=path)
    print("Output loaded from h5 file")

    # compare scenes leaf by leaf; strict=True so a leaf-count mismatch fails
    # loudly instead of being silently truncated by zip
    saved = jax.tree_util.tree_leaves(scene)
    loaded = jax.tree_util.tree_leaves(scene_loaded)
    status = True
    for leaf_saved, leaf_loaded in zip(saved, loaded, strict=True):
        if isinstance(leaf_saved, wcs.WCS):  # wcs doesn't allow direct == comparison...
            if not leaf_saved.wcs.compare(leaf_loaded.wcs):
                status = False
        elif hasattr(leaf_saved, "__iter__"):
            # BUGFIX: use .any() — the previous .all() only flagged a mismatch
            # when *every* element differed, so partial corruption passed silently.
            if (leaf_saved != leaf_loaded).any():
                status = False
        elif leaf_saved != leaf_loaded:
            status = False

    print(f"saved == loaded: {status}")
    assert status, "Loaded leaves not identical to original"
94 |
95 |
# Allow running this test directly, outside of pytest.
if __name__ == "__main__":
    test_save_output()
98 |
--------------------------------------------------------------------------------
/src/scarlet2/validation_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from typing import Any
4 |
5 | from colorama import Back, Fore, Style
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | # A global switch that toggles automated validation checks.
10 | VALIDATION_SWITCH = True
11 |
12 |
def set_validation(state: bool = True):
    """Turn the global validation switch on or off.

    Parameters
    ----------
    state : bool, optional
        When True (default), validation checks are performed automatically;
        when False, they are skipped.
    """
    global VALIDATION_SWITCH
    VALIDATION_SWITCH = state
    status = "enabled" if state else "disabled"
    logger.info(f"Automated validation checks are now {status}.")
26 |
27 |
28 | @dataclass
29 | class ValidationResult:
30 | """Represents a validation result. This is the base dataclass that all the
31 | more specific Validation dataclasses inherit from. Generally, it should
32 | not be instantiated directly, but rather through the more specific
33 | ValidationInfo, ValidationWarning, or ValidationError classes.
34 | """
35 |
36 | message: str
37 | check: str
38 | context: Any | None = None
39 |
40 | def __str__(self):
41 | base = f"{self.message}"
42 | if self.context is not None:
43 | base += f" | Context={self.context})"
44 | return base
45 |
46 |
@dataclass
class ValidationInfo(ValidationResult):
    """A purely informational validation message (non-critical)."""

    def __str__(self):
        tag = f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN} INFO {Style.RESET_ALL}"
        return f"{tag} {super().__str__()}"
53 |
54 |
@dataclass
class ValidationWarning(ValidationResult):
    """A non-critical validation warning that should still be noted."""

    def __str__(self):
        tag = f"{Style.BRIGHT}{Fore.BLACK}{Back.YELLOW} WARN {Style.RESET_ALL}"
        return f"{tag} {super().__str__()}"
61 |
62 |
@dataclass
class ValidationError(ValidationResult):
    """A critical validation error that should be addressed."""

    def __str__(self):
        tag = f"{Style.BRIGHT}{Fore.WHITE}{Back.RED} ERROR {Style.RESET_ALL}"
        return f"{tag} {super().__str__()}"
69 |
70 |
class ValidationMethodCollector(type):
    """Metaclass that gathers a class's validation methods into one list.

    Every class created with this metaclass gets a ``validation_checks``
    attribute: the names of all methods defined in its body whose names
    start with ``check_``, in definition order.
    """

    def __new__(cls, name, bases, namespace):
        """Build the class, then record the names of its ``check_*`` methods."""
        new_class = super().__new__(cls, name, bases, namespace)
        checks = []
        for attr_name, attr_value in namespace.items():
            if attr_name.startswith("check_") and callable(attr_value):
                checks.append(attr_name)
        new_class.validation_checks = checks
        return new_class
85 |
86 |
def print_validation_results(preamble: str, results: list[ValidationResult]):
    """Print validation results as a numbered, formatted list.

    Parameters
    ----------
    preamble : str
        Text printed on the first line, before the results.
    results : list[ValidationResult]
        Validation results to print. Nothing is printed when empty.
    """
    if not results:
        return
    numbered = [f"[{index:03d}] {result}" for index, result in enumerate(results)]
    print(f"{preamble}:\n" + "\n".join(numbered))
102 |
--------------------------------------------------------------------------------
/benchmarks/asv.conf.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | // The version of the config file format. Do not change, unless
4 | // you know what you are doing.
5 | "version": 1,
6 | // The name of the project being benchmarked.
7 | "project": "scarlet2",
8 | // The project's homepage.
9 | "project_url": "https://github.com/pmelchior/scarlet2",
10 | // The URL or local path of the source code repository for the
11 | // project being benchmarked.
12 | "repo": "..",
13 | // List of branches to benchmark. If not provided, defaults to "master"
14 | // (for git) or "tip" (for mercurial).
15 | "branches": [
16 | "HEAD"
17 | ],
18 | "install_command": [
19 | "python -m pip install {wheel_file}"
20 | ],
21 | "build_command": [
22 | "python -m build --wheel -o {build_cache_dir} {build_dir}"
23 | ],
24 | // The DVCS being used. If not set, it will be automatically
25 | // determined from "repo" by looking at the protocol in the URL
26 | // (if remote), or by looking for special directories, such as
27 | // ".git" (if local).
28 | "dvcs": "git",
29 | // The tool to use to create environments. May be "conda",
30 | // "virtualenv" or other value depending on the plugins in use.
31 | // If missing or the empty string, the tool will be automatically
32 | // determined by looking for tools on the PATH environment
33 | // variable.
34 | "environment_type": "virtualenv",
35 | // the base URL to show a commit for the project.
36 | "show_commit_url": "https://github.com/pmelchior/scarlet2/commit/",
37 | // The Pythons you'd like to test against. If not provided, defaults
38 | // to the current version of Python used to run `asv`.
39 | "pythons": [
40 | "3.11"
41 | ],
42 | // The matrix of dependencies to test. Each key is the name of a
43 | // package (in PyPI) and the values are version numbers. An empty
44 | // list indicates to just test against the default (latest)
45 | // version.
46 | "matrix": {
47 | "Cython": [],
48 | "build": [],
49 | "packaging": []
50 | },
51 | // The directory (relative to the current directory) that benchmarks are
52 | // stored in. If not provided, defaults to "benchmarks".
53 | "benchmark_dir": ".",
54 | // The directory (relative to the current directory) to cache the Python
55 | // environments in. If not provided, defaults to "env".
56 | "env_dir": "env",
57 | // The directory (relative to the current directory) that raw benchmark
58 | // results are stored in. If not provided, defaults to "results".
59 | "results_dir": "_results",
60 | // The directory (relative to the current directory) that the html tree
61 | // should be written to. If not provided, defaults to "html".
62 | "html_dir": "_html",
63 | // The number of characters to retain in the commit hashes.
64 | // "hash_length": 8,
65 | // `asv` will cache wheels of the recent builds in each
66 | // environment, making them faster to install next time. This is
67 | // number of builds to keep, per environment.
68 | "build_cache_size": 8
69 | // The commits after which the regression search in `asv publish`
70 | // should start looking for regressions. Dictionary whose keys are
71 | // regexps matching to benchmark names, and values corresponding to
72 | // the commit (exclusive) after which to start looking for
73 | // regressions. The default is to start from the first commit
74 | // with results. If the commit is `null`, regression detection is
75 | // skipped for the matching benchmark.
76 | //
77 | // "regressions_first_commits": {
78 | // "some_benchmark": "352cdf", // Consider regressions only after this commit
79 | // "another_benchmark": null, // Skip regression detection altogether
80 | // }
81 | }
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # _scarlet2_ Documentation
2 |
_scarlet2_ is an open-source Python library for modeling astronomical sources from multi-band, multi-epoch, and
4 | multi-instrument data. It provides non-parametric and parametric models, can handle source overlap (aka blending), and
5 | can integrate neural network priors. It's designed to be modular, flexible, and powerful.
6 |
7 | _scarlet2_ is implemented in [jax](http://jax.readthedocs.io/), layered on top of
8 | the [equinox](https://docs.kidger.site/equinox/)
9 | library. It can be deployed to GPUs and TPUs and supports optimization and sampling approaches.
10 |
11 | ## Installation
12 |
13 | For performance reasons, you should first install `jax` with the suitable `jaxlib` for your platform. After that
14 |
15 | ```
16 | pip install scarlet2
17 | ```
18 |
19 | should do. If you want the latest development version, use
20 |
21 | ```
22 | pip install git+https://github.com/pmelchior/scarlet2.git
23 | ```
24 |
25 | This will allow you to evaluate source models and compute likelihoods of observed data, so you can run your own
26 | optimizer/sampler. If you want a fully fledged library out of the box, you need to install `optax`, `numpyro`, and
27 | `h5py` as well.
28 |
29 | ## Usage
30 |
31 | ```{toctree}
32 | :maxdepth: 2
33 |
34 | 0-quickstart
35 | 1-howto
36 | 2-questionnaire
37 | 3-validation
38 | api
39 | ```
40 |
41 | ## Differences between _scarlet_ and _scarlet2_
42 |
43 | [_scarlet_](https://pmelchior.github.io/scarlet/) was introduced by
44 | [Melchior et al. (2018)](https://doi.org/10.1016/j.ascom.2018.07.001) to solve the deblending problem for the Rubin
45 | Observatory. A stripped down version of it (developed by Fred Moolekamp) runs as part of the Rubin Observatory software
46 | stack and is used for their data releases. We now call this version _scarlet1_.
47 |
48 | _scarlet2_ follows very similar concepts. So, what's different?
49 |
50 | ### Model specification
51 |
52 | _scarlet1_ is designed for a specific purpose: deblending for the Rubin Observatory. That has implications for the
quality and type of data it needs to work with. **_scarlet2_ is much more flexible in handling complex source
configurations**, e.g. strong-lensing systems, supernova host galaxies, transient sources, etc.
55 |
56 | This flexibility led us to carefully design a "language" to construct sources. It allows new source combinations and
57 | is more explicit about the parameters and their initialization.
58 |
59 | ### Compute
60 |
61 | Because some of the constraints in _scarlet1_ are expensive to evaluate, they are implemented in C++, which requires
62 | the installation of a lot of additional code just to get it to run. **_scarlet2_ is implemented entirely in jax.**
63 | A combination of `conda` and `pip` will get it installed. Unlike _scarlet1_, it will also run on GPUs and TPUs, and
64 | performs just-in-time compilation of the model evaluation.
65 |
66 | In addition, we can now interface with deep learning methods. In particular, we can employ neural networks as
67 | data-driven priors, which helps break the degeneracies that arise when multiple components need to be
68 | fit at the same time.
69 |
70 | ### Constraints
71 |
72 | _scarlet1_ uses constrained optimization to help with fitting degeneracies, but that requires non-standard
73 | (namely proximal) optimization because these constraints are not differentiable.
74 | That can lead to problems with calibration, but, more importantly, it prevents the use of gradient-based
75 | optimization or sampling. As a result, we could never calculate errors for _scarlet1_ models.
76 | **_scarlet2_ uses only constraints that can be differentiated.** It supports any continuous optimization or sampling
77 | method, including error estimates.
78 |
79 | ## Ideas, Questions or Problems?
80 |
81 | If you have any of those, head over to our [github repo](https://github.com/pmelchior/scarlet2/) and create an issue.
82 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_moments.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: D101
2 | # ruff: noqa: D102
3 | # ruff: noqa: D103
4 | # ruff: noqa: D106
5 |
6 | import copy
7 |
8 | import astropy.units as u
9 | import jax.numpy as jnp
10 | from astropy.wcs import WCS
11 | from numpy.testing import assert_allclose
12 |
13 | from scarlet2.frame import _flip_matrix, _rot_matrix
14 | from scarlet2.measure import Moments
15 | from scarlet2.morphology import GaussianMorphology
16 |
# create a test image and measure moments
t0 = 10  # reference Gaussian size used throughout the tests
ellipticity = jnp.array((0.3, 0.5))  # reference (e1, e2)
morph = GaussianMorphology(size=t0, ellipticity=ellipticity)
img = morph()
g = Moments(component=img, N=2)  # second-order moments of the reference image

# create trivial WCS for that image
wcs = WCS(naxis=2)
wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
wcs.wcs.pc = jnp.diag(jnp.ones(2))  # identity PC matrix: no rotation/scaling
28 |
29 |
def test_measure_size():
    """Measured moment size must match the generating Gaussian's size t0."""
    assert_allclose(t0, g.size, rtol=1e-3)
32 |
33 |
def test_measure_ellipticity():
    """Measured ellipticity must match the generating Gaussian's (e1, e2)."""
    assert_allclose(ellipticity, g.ellipticity, rtol=2e-3)
36 |
37 |
def test_gaussian_from_moments():
    """A Gaussian rebuilt from measured moments must reproduce the image."""
    # construct a new Gaussian directly from the measured size and ellipticity
    reconstructed = GaussianMorphology(g.size, g.ellipticity)
    assert_allclose(img, reconstructed(), rtol=1e-2)
45 |
46 |
def test_convolve_moments():
    """Convolving then deconvolving with a PSF must round-trip the moments."""
    moments = copy.deepcopy(g)
    psf_size = t0 / 2
    psf_img = GaussianMorphology(size=psf_size, ellipticity=ellipticity)()
    psf_img = psf_img / psf_img.sum()  # normalize the PSF to unit flux
    psf_moments = Moments(psf_img, N=2)

    # Gaussian sizes add in quadrature under convolution
    moments.convolve(psf_moments)
    assert_allclose(moments.size, jnp.sqrt(t0**2 + psf_size**2), rtol=3e-4)
    assert_allclose(moments.ellipticity, ellipticity, rtol=2e-3)

    # deconvolution must restore the original moments
    moments.deconvolve(psf_moments)
    assert_allclose(moments.size, t0, rtol=3e-4)
    assert_allclose(moments.ellipticity, ellipticity, rtol=2e-3)
59 |
60 |
def test_rotate_moments():
    """Rotating moments must transform ellipticity as a spin-2 quantity."""
    moments = copy.deepcopy(g)
    angle = 30 * u.deg
    moments.rotate(angle)  # rotate counter-clockwise by 30 deg

    # theoretical rotation of a spin-2 vector: e -> e * exp(2i * phi)
    e_complex = (ellipticity[0] + 1j * ellipticity[1]) * jnp.exp(2j * angle.to(u.rad).value)
    expected = jnp.array((e_complex.real, e_complex.imag))

    assert_allclose(moments.ellipticity, expected, rtol=2e-3)
73 |
74 |
def test_resize_moments():
    """Shrinking the image and resizing the moments must recover the size."""
    scale = 0.5
    shrunk = GaussianMorphology(size=t0 * scale, ellipticity=ellipticity, shape=img.shape)
    shrunk_moments = Moments(shrunk(), 2)
    shrunk_moments.resize(1 / scale)  # undo the shrink in moment space

    assert_allclose(g.size, shrunk_moments.size, rtol=1e-3)
83 |
84 |
def test_flip_moments():
    """Flipping image and moments together must leave ellipticity unchanged."""
    # same check for a left-right and an up-down flip (in that order)
    for flip_image, flip_method in ((jnp.fliplr, "fliplr"), (jnp.flipud, "flipud")):
        flipped_moments = Moments(flip_image(img), 2)
        getattr(flipped_moments, flip_method)()
        assert_allclose(g.ellipticity, flipped_moments.ellipticity, rtol=1e-3)
95 |
96 |
def test_wcs_transfer_moments_rot90():
    """Moments transferred across a 90-deg-rotated WCS keep size/ellipticity."""
    # create a 90 deg counter-clockwise version of the morph image
    im_ = jnp.rot90(img, axes=(1, 0))
    g_ = Moments(im_)

    # create mock WCS for that image
    wcs_ = WCS(naxis=2)
    wcs_.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    phi = (-90 * u.deg).to(u.rad).value  # clockwise to counteract the rotation above
    wcs_.wcs.pc = _rot_matrix(phi)

    # match WCS: transfer moments from the rotated frame into the trivial one
    g_.transfer(wcs_, wcs)

    # Check that size and ellipticity are conserved
    assert_allclose(g_.size, g.size, rtol=1e-3)
    assert_allclose(g_.ellipticity, g.ellipticity, rtol=1e-2)
114 |
115 |
def test_wcs_transfer_moments():
    """Transferring moments across a rotated+rescaled+flipped WCS conserves
    size and ellipticity."""
    # create a rotated, resized, flipped version of the morph image
    # apply theoretical rotation to spin-2 vector
    a = (30 * u.deg).to(u.rad).value
    ellipticity_ = ellipticity[0] + 1j * ellipticity[1]
    ellipticity_ *= jnp.exp(2j * a)
    ellipticity_ = jnp.array((ellipticity_.real, ellipticity_.imag))
    c = 0.5
    morph2 = GaussianMorphology(size=t0 * c, ellipticity=ellipticity_, shape=img.shape)
    # note order: rescale+rotate, then flip.
    im_ = jnp.flipud(morph2())
    g_ = Moments(im_)

    # create mock WCS for that image
    wcs_ = WCS(naxis=2)
    wcs_.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    # because of the image creation above: altered order (rotation, then flip)
    # this is unusual because for standard WCS, we correct handedness first, then rotate
    wcs_.wcs.pc = 1 / c * (_flip_matrix(-1) @ _rot_matrix(-a))

    # match WCS (`wcs` is the module-level reference WCS defined above this chunk)
    g_.transfer(wcs_, wcs)

    # Check that size and ellipticity are conserved
    assert_allclose(g_.size, g.size, rtol=1e-3)
    assert_allclose(g_.ellipticity, g.ellipticity, rtol=1e-2)
142 |
--------------------------------------------------------------------------------
/src/scarlet2/validation.py:
--------------------------------------------------------------------------------
1 | from .module import ParameterValidator
2 | from .observation import ObservationValidator
3 | from .scene import FitValidator
4 | from .source import SourceValidator
5 | from .validation_utils import ValidationResult
6 |
7 |
def _check(validation_class, **kwargs) -> list[ValidationResult]:
    """Run every registered check of ``validation_class`` and collect results.

    Parameters
    ----------
    validation_class : type
        Class whose ``validation_checks`` attribute names the check methods to run.
    **kwargs : dict
        Constructor arguments for ``validation_class``, e.g. ``scene``,
        ``observation``, or ``source``.

    Returns
    -------
    list[ValidationResult]
        All non-empty results produced by the checks.
    """
    validator = validation_class(**kwargs)
    collected = []
    for check_name in validator.validation_checks:
        outcome = getattr(validator, check_name)()
        if not outcome:
            continue
        # A check may return a single result or a list of results.
        collected.extend(outcome if isinstance(outcome, list) else [outcome])
    return collected
36 |
37 |
def check_fit(scene, observation) -> list[ValidationResult]:
    """Validate a fitted scene against an observation.

    Parameters
    ----------
    scene : Scene
        The fitted scene to check.
    observation : Observation
        The observation used by the fit checks.

    Returns
    -------
    list[ValidationResult]
        Results of all fit-validation checks.
    """
    return _check(validation_class=FitValidator, scene=scene, observation=observation)
56 |
57 |
def check_observation(observation) -> list[ValidationResult]:
    """Check the observation object for consistency.

    Parameters
    ----------
    observation : Observation
        The observation object to check.

    Returns
    -------
    list[ValidationResult]
        Results of all observation-validation checks.
    """
    return _check(validation_class=ObservationValidator, observation=observation)
74 |
75 |
def check_scene(scene) -> list[ValidationResult]:
    """Validate every source of a scene.

    Parameters
    ----------
    scene : Scene
        The scene whose sources are checked.

    Returns
    -------
    list[ValidationResult]
        Concatenated results from checking each source in ``scene.sources``.
    """
    results = []
    for src in scene.sources:
        results += check_source(src)
    return results
95 |
96 |
def check_source(source) -> list[ValidationResult]:
    """Check the source against the various validation rules.

    Parameters
    ----------
    source : Source
        The source object to check.

    Returns
    -------
    list[ValidationResult]
        A list of validation results from the source object checks.
    """

    return _check(validation_class=SourceValidator, **{"source": source})
114 |
115 |
def check_parameters(parameters) -> list[ValidationResult]:
    """Validate each parameter in a parameter list.

    Parameters
    ----------
    parameters : Parameters
        The parameter list to check.

    Returns
    -------
    list[ValidationResult]
        Concatenated results from checking every parameter.
    """
    results = []
    for param in parameters:
        results += check_parameter(param)
    return results
135 |
136 |
def check_parameter(parameter) -> list[ValidationResult]:
    """Validate a single parameter.

    Parameters
    ----------
    parameter : Parameter
        The parameter to check.

    Returns
    -------
    list[ValidationResult]
        Results from the validation checks of this parameter.
    """
    return _check(validation_class=ParameterValidator, parameter=parameter)
152 |
--------------------------------------------------------------------------------
/tests/scarlet2/conftest.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax.numpy as jnp
4 | import numpy as np
5 | import pytest
6 | from huggingface_hub import hf_hub_download
7 | from numpyro.distributions import constraints
8 |
9 | from scarlet2 import init
10 | from scarlet2.frame import Frame
11 | from scarlet2.module import Parameter, Parameters, relative_step
12 | from scarlet2.observation import Observation
13 | from scarlet2.psf import ArrayPSF
14 | from scarlet2.scene import Scene
15 | from scarlet2.source import PointSource, Source
16 |
17 |
@pytest.fixture()
def data_file():
    """Download and load a realistic test file. This is the same data used in the
    quickstart notebook. The data will be manipulated to create invalid inputs for
    the `bad_obs` fixture.

    The npz archive provides the `images`, `filters`, `variance`, `psfs`, and
    `catalog` entries consumed by the fixtures below.
    """
    filename = hf_hub_download(
        repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset"
    )
    return jnp.load(filename)
27 |
28 |
@pytest.fixture()
def good_obs(data_file):
    """Create an observation that should pass all validation checks."""
    data = jnp.asarray(data_file["images"])
    channels = [str(f) for f in data_file["filters"]]
    # inverse-variance weights
    weights = jnp.asarray(1 / data_file["variance"])
    psf = jnp.asarray(data_file["psfs"])

    return Observation(
        data=data,
        weights=weights,
        channels=channels,
        psf=ArrayPSF(psf),
    )
43 |
44 |
@pytest.fixture()
def bad_obs(data_file):
    """Create an observation that should fail multiple validation checks.

    Each corruption below targets a specific check: shape mismatch,
    non-finite weight, negative weight, and a 2-D (non-3-D) PSF.
    """

    data = np.asarray(data_file["images"])
    channels = [str(f) for f in data_file["filters"]]
    weights = np.asarray(1 / data_file["variance"])
    psf = np.asarray(data_file["psfs"])

    weights = weights[:-1]  # Remove the last weight to create a mismatch in dimensions
    weights[0][0] = np.inf  # Set one weight to infinity
    weights[1][0] = -1.0  # Set one weight to a negative value
    psf = psf[:-1]  # Remove the last PSF to create a mismatch in dimensions
    psf = psf[0] + 0.001  # Collapse the PSF to 2-D (fails the 3-D check); the offset presumably breaks normalization — confirm

    return Observation(
        data=data,
        weights=weights,
        channels=channels,
        psf=ArrayPSF(psf),
    )
66 |
67 |
@pytest.fixture()
def good_source(good_obs, data_file):
    """Assemble a valid source from the good observation and the catalog's first entry."""
    model_frame = Frame.from_observations(good_obs)
    with Scene(model_frame) as _:
        centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]])  # Note: y/x convention!
        spectrum, morph = init.from_gaussian_moments(good_obs, centers[0], min_corr=0.99)
        return Source(centers[0], spectrum, morph)
76 |
77 |
@pytest.fixture()
def bad_source(good_obs, data_file):
    """Assemble an invalid source: built from the *good* observation, but with
    a negated (non-physical) spectrum."""
    model_frame = Frame.from_observations(good_obs)
    with Scene(model_frame) as _:
        centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]])  # Note: y/x convention!
        spectrum, morph = init.from_gaussian_moments(good_obs, centers[0], min_corr=0.99)
        return Source(centers[0], spectrum * -1, morph)
86 |
87 |
@pytest.fixture()
def scene(good_obs, data_file):
    """Assemble a scene from the good observation and the data file."""
    model_frame = Frame.from_observations(good_obs)
    centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]])  # Note: y/x convention!

    with Scene(model_frame) as scene:
        for i, center in enumerate(centers):
            if i == 0:  # we know source 0 is a star
                spectrum = init.pixel_spectrum(good_obs, center, correct_psf=True)
                PointSource(center, spectrum)
            else:
                # fall back to a compact morphology when moment-based init fails
                try:
                    spectrum, morph = init.from_gaussian_moments(good_obs, center, min_corr=0.99)
                except ValueError:
                    spectrum = init.pixel_spectrum(good_obs, center)
                    morph = init.compact_morphology()

                Source(center, spectrum, morph)

    return scene
109 |
110 |
@pytest.fixture()
def parameters(scene, good_obs):
    """Create parameters for the scene: a positive spectrum for every source,
    a movable center for the point source (index 0), and a unit-interval
    morphology for all other sources."""
    spec_step = partial(relative_step, factor=0.05)
    morph_step = partial(relative_step, factor=1e-3)

    with Parameters(scene) as parameters:
        for i in range(len(scene.sources)):
            Parameter(
                scene.sources[i].spectrum,
                name=f"spectrum:{i}",
                constraint=constraints.positive,
                stepsize=spec_step,
            )
            if i == 0:
                # source 0 is the point source: fit its center, not a morphology
                Parameter(scene.sources[i].center, name=f"center:{i}", stepsize=0.1)
            else:
                Parameter(
                    scene.sources[i].morphology,
                    name=f"morph:{i}",
                    constraint=constraints.unit_interval,
                    stepsize=morph_step,
                )

    return parameters
136 |
--------------------------------------------------------------------------------
/src/scarlet2/nn.py:
--------------------------------------------------------------------------------
1 | # To be removed as part of issue #168
2 | # ruff: noqa: D101
3 | # ruff: noqa: D102
4 | # ruff: noqa: D103
5 | # ruff: noqa: D106
6 |
7 | """Neural network priors"""
8 |
9 | from functools import partial
10 |
11 | try:
12 | import numpyro.distributions as dist
13 | import numpyro.distributions.constraints as constraints
14 |
15 | except ImportError as err:
16 | raise ImportError("scarlet2.nn requires numpyro.") from err
17 |
18 | import equinox as eqx
19 | import jax.numpy as jnp
20 | from jax import custom_vjp, vjp
21 |
22 |
def pad_fwd(x, model_shape):
    """Zero-pads the input image to the model size

    Parameters
    ----------
    x : jnp.array
        data to be padded
    model_shape : tuple
        shape of the prior model to be used; every dimension must be at least
        as large as the corresponding dimension of `x`

    Returns
    -------
    x : jnp.array
        data padded to same size as model_shape
    pad : tuple or int
        per-dimension (before, after) padding amounts, or the sentinel ``0``
        when no padding was needed (see `calc_grad`)

    Raises
    ------
    ValueError
        if any dimension of `model_shape` is smaller than that of `x`
    """
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if any(model_shape[d] < x.shape[d] for d in range(x.ndim)):
        raise ValueError("Model size must be larger than data size")

    if model_shape == x.shape:
        # sentinel 0 signals "no padding" to callers
        return x, 0

    # split each gap evenly; for odd gaps the extra pixel goes at the end
    pad = tuple(
        (gap // 2, gap - gap // 2)
        for gap in (model_shape[d] - x.shape[d] for d in range(x.ndim))
    )
    # perform the zero-padding
    x = jnp.pad(x, pad, "constant", constant_values=0)
    return x, pad
59 |
60 |
# reverse pad back to original size
def pad_back(x, pad):
    """Strips the zero-padding applied by :func:`pad_fwd`

    Parameters
    ----------
    x : jnp.array
        padded data, with the shape of the prior model
    pad : tuple
        per-dimension (before, after) padding amounts

    Returns
    -------
    x : jnp.array
        data returned to its pre-padding shape
    """
    # slice(before, -after) trims both margins; `after == 0` must map to
    # slice(before, None) because an end index of -0 selects nothing
    trim = []
    for before, after in pad:
        trim.append(slice(before, -after if after else None))
    return x[tuple(trim)]
79 |
80 |
# calculate score function (jacobian of log-probability)
def calc_grad(x, model):
    """Calculates the gradient of the log-prior
    using the ScoreNet model chosen

    Parameters
    ----------
    x:
        array of the data
    model:
        wrapper providing `func` (the score network) and `shape` (the input
        shape the network expects, without the batch dimension)

    Returns
    -------
    score_func : array of the score function, cropped back to the shape of `x`
    """
    # cast to float32, expand to (batch, shape), and pad to match the shape of the score model
    x_, pad = pad_fwd(jnp.float32(x), model.shape)

    # run score model, expects (batch, shape); add a singleton batch axis when
    # the input is unbatched, and remove it again from the result
    if jnp.ndim(x) == len(model.shape):
        x_ = jnp.expand_dims(x_, axis=0)
    score_func = model.func(x_)
    if jnp.ndim(x) == len(model.shape):
        score_func = jnp.squeeze(score_func, axis=0)

    # remove padding (pad == 0 is the "no padding" sentinel from pad_fwd)
    if pad != 0:
        score_func = pad_back(score_func, pad)
    return score_func
111 |
112 |
# jax gradient function: vector-Jacobian product with an all-ones cotangent,
# i.e. the gradient of sum(f(x)) with respect to x
def vgrad(f, x):
    value, pullback = vjp(f, x)
    (grad,) = pullback(jnp.ones(value.shape))
    return grad
117 |
118 |
119 | # Here we define a custom vjp for the log_prob function
120 | # such that for gradient calls in jax, the score prior
121 | # is returned
122 |
123 |
@partial(custom_vjp, nondiff_argnums=(0,))
def _log_prob(model, x):
    # The value is a placeholder: only the custom gradient (the score) matters.
    return 0.0
127 |
128 |
def _log_prob_fwd(model, x):
    # Forward pass: placeholder value, with the score stashed as the residual.
    score_func = calc_grad(x, model)
    return 0.0, score_func  # cannot directly call log_prob in Class object
132 |
133 |
def _log_prob_bwd(model, res, g):
    # Backward pass: multiply the cotangent by the stored score.
    score_func = res  # Get residuals computed in f_fwd
    return (g * score_func,)  # create the vector (g) jacobian (score_func) product
137 |
138 |
# register the custom vjp so that jax.grad of _log_prob returns the score
_log_prob.defvjp(_log_prob_fwd, _log_prob_bwd)
141 |
142 |
class ScorePrior(dist.Distribution):
    """numpyro distribution whose `log_prob` gradient is a learned score function.

    The log-probability *value* is a placeholder (0.0); only its custom
    gradient — the score returned by the wrapped network — is meaningful.
    """

    class ScoreWrapper(eqx.Module):
        # binds the score network together with the input shape it expects
        func: callable
        shape: tuple

    support = constraints.real_vector
    # NOTE(review): this class-level attribute appears to be a placeholder; it
    # is overwritten with a ScoreWrapper *instance* in __init__ — confirm intent.
    _model = ScoreWrapper

    def __init__(self, model, shape, *args, **kwargs):
        """Score-matching neural network to represent the prior distribution

        This class is used to calculate the gradient of the log-probability of the prior distribution.
        A custom vjp is created to return the score when calling `jax.grad()`.

        Parameters
        ----------
        model: callable
            Returns the score value given parameter: `model(x) -> score`
        shape: tuple
            Shape of the parameter the model can accept
        *args: tuple
            List of unnamed parameter for model, e.g. `model(x, *args) -> score`
        **kwargs: dict
            List of named parameter for model, e.g. `model(x, **kwargs) -> score`
        """
        # helper class that ensures the model function binds the args/kwargs and has a shape
        wrapper = ScorePrior.ScoreWrapper(partial(model.__call__, *args, **kwargs), shape)
        self._model = wrapper

        super().__init__(
            validate_args=None,
        )

    def __call__(self, x):
        """Evaluate the wrapped score network at `x`."""
        return self._model.func(x)

    def sample(self, key, sample_shape=()):
        # TODO: add ability to draw samples from the prior, if desired
        raise NotImplementedError

    def mean(self):
        raise NotImplementedError

    def log_prob(self, x):
        """Placeholder log-probability; its jax gradient is the score (see _log_prob)."""
        return _log_prob(self._model, x)
188 |
--------------------------------------------------------------------------------
/src/scarlet2/spectrum.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | import jax.numpy as jnp
3 |
4 | from . import Scenery
5 | from .module import Module
6 |
7 |
class Spectrum(Module):
    """Spectrum base class"""

    @property
    def shape(self):
        """Shape (1D) of the spectrum model; must be implemented by subclasses"""
        raise NotImplementedError
15 |
16 |
class StaticArraySpectrum(Spectrum):
    """Static (non-variable) source in a transient scene

    In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels`
    are overloaded and defined with a spectral and a temporal component, e.g.
    `channel = (band, epoch)`.
    This class is for models that do not vary in time, i.e. only have a spectral dependency.
    The length of :py:attr:`data` is thus given by the number of distinct spectral bands.
    """

    data: jnp.array
    """Data to describe the static spectrum

    The order in this array should be given by :py:attr:`bands`.
    """
    bands: list
    """Identifier for the list of unique bands in the model frame channels"""
    _channelindex: jnp.array = eqx.field(repr=False)

    def __init__(self, data, bands, band_selector=lambda channel: channel[0]):
        """
        Parameters
        ----------
        data: array
            Spectrum without temporal variation. Contains as many elements as there
            are spectral channels in the model.
        bands: list, array
            Identifier for the list of unique bands in the model frame channels.
            NOTE(review): must support `.index()`, i.e. behave like a list — confirm
            for plain arrays.
        band_selector: callable, optional
            Identify the spectral "band" component from the name/ID used in the
            channels of the model frame

        Examples
        --------
        >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)]
        >>> spectrum = jnp.ones(2)
        >>> bands = ['G','R']
        >>> band_selector = lambda channel: channel[0]
        >>> StaticArraySpectrum(spectrum, bands, band_selector=band_selector)

        This constructs a 2-element spectrum to describe the spectral properties
        in all epochs 0,1,2.

        See Also
        --------
        TransientArraySpectrum
        """
        # spectra need the model frame, which is only defined inside a Scene context
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise

        self.data = data
        self.bands = bands
        # map every model-frame channel onto the index of its band in `bands`
        self._channelindex = jnp.array([self.bands.index(band_selector(c)) for c in frame.channels])

    def __call__(self):
        """Expand the per-band spectrum to one value per model-frame channel"""
        return self.data[self._channelindex]

    @property
    def shape(self):
        """The shape of the spectrum as seen by the model frame (one entry per channel)"""
        return (len(self._channelindex),)
83 |
84 |
class TransientArraySpectrum(Spectrum):
    """Variable source in a transient scene with possible quiescent periods

    In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels`
    are overloaded and defined with a spectral and a temporal component, e.g.
    `channel = (band, epoch)`. This class is for models that vary in time, especially
    if they have periods of inactivity. The length of :py:attr:`data` is given by
    the number channels in the model frame, but during inactive epochs, the emission
    is set to zero.
    """

    data: jnp.array
    """Data to describe the variable spectrum.

    The length of this vector is identical to the number of channels in the model frame.
    """
    epochs: list
    """Identifier for the list of active epochs. If set to `None`, all epochs are
    considered active"""
    _epochmultiplier: jnp.array = eqx.field(repr=False)

    def __init__(self, data, epochs=None, epoch_selector=lambda channel: channel[1]):
        """
        Parameters
        ----------
        data: array
            Spectrum array. Contains as many elements as there are spectro-temporal
            channels in the model.
        epochs: list, array, optional
            List of temporal "epoch" identifiers for the active phases of the source.
            If `None`, every epoch is considered active and no masking is applied.
        epoch_selector: callable, optional
            Identify the temporal "epoch" component from the name/ID used in the
            channels of the model frame

        Examples
        --------
        >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)]
        >>> spectrum = jnp.ones(5)
        >>> epochs = [0, 1]
        >>> epoch_selector = lambda channel: channel[1]
        >>> TransientArraySpectrum(spectrum, epochs, epoch_selector=epoch_selector)

        This sets the spectrum to active during epochs 0 and 1, and mask the
        spectrum element for `('R',2)` with zero.

        See Also
        --------
        StaticArraySpectrum
        """
        # spectra need the model frame, which is only defined inside a Scene context
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise
        self.data = data
        self.epochs = epochs
        if epochs is None:
            # Documented default: `None` means every epoch is active, i.e. no masking.
            # (The previous implementation raised TypeError on `in None` here.)
            self._epochmultiplier = jnp.ones(len(frame.channels))
        else:
            self._epochmultiplier = jnp.array(
                [1.0 if epoch_selector(c) in epochs else 0.0 for c in frame.channels]
            )

    def __call__(self):
        """Return the spectrum with inactive epochs masked to zero"""
        return jnp.multiply(self.data, self._epochmultiplier)

    @property
    def shape(self):
        """The shape of the spectrum data"""
        return self.data.shape
154 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_observation_checks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from scarlet2.observation import Observation, ObservationValidator
4 | from scarlet2.psf import ArrayPSF
5 | from scarlet2.validation_utils import ValidationError, ValidationInfo, set_validation
6 |
7 |
@pytest.fixture(autouse=True)
def setup_validation():
    """Automatically disable validation for all tests in this module. This
    permits the creation of intentionally invalid Observation objects; tests
    that need validation back on (e.g. the auto-validation test) re-enable it."""
    set_validation(False)
13 |
14 |
def test_weights_non_negative_returns_error(bad_obs):
    """A negative weight entry must yield a ValidationError."""
    validator = ObservationValidator(bad_obs)
    outcome = validator.check_weights_non_negative()
    assert isinstance(outcome, ValidationError)
22 |
23 |
def test_weights_finite_returns_error(bad_obs):
    """An infinite weight entry must yield a ValidationError."""
    validator = ObservationValidator(bad_obs)
    outcome = validator.check_weights_finite()
    assert isinstance(outcome, ValidationError)
31 |
32 |
def test_weights_non_negative_returns_none(good_obs):
    """Valid weights pass the non-negativity check with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_weights_non_negative()
    assert isinstance(outcome, ValidationInfo)
40 |
41 |
def test_weights_finite_returns_none(good_obs):
    """Finite weights pass the finiteness check with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_weights_finite()
    assert isinstance(outcome, ValidationInfo)
49 |
50 |
def test_data_and_weights_shape_returns_error(bad_obs):
    """Mismatched data/weights shapes must yield a ValidationError with the
    expected message."""
    validator = ObservationValidator(bad_obs)
    outcome = validator.check_data_and_weights_shape()
    assert isinstance(outcome, ValidationError)
    assert outcome.message == "Data and weights must have the same shape."
59 |
60 |
def test_data_and_weights_shape_returns_none(good_obs):
    """Matching data/weights shapes pass the check with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_data_and_weights_shape()
    assert isinstance(outcome, ValidationInfo)
68 |
69 |
def test_num_channels_matches_data_returns_none(good_obs):
    """A channel list matching the data's first axis passes with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_num_channels_matches_data()
    assert isinstance(outcome, ValidationInfo)
77 |
78 |
def test_data_finite_for_non_zero_weights_returns_none(good_obs):
    """Finite data everywhere weights are non-zero passes with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_data_finite_for_non_zero_weights()
    assert isinstance(outcome, ValidationInfo)
86 |
87 |
def test_data_finite_for_non_zero_weights_returns_none_with_infinity():
    """Non-finite data is tolerated wherever the corresponding weight is zero."""
    obs = Observation(
        data=np.array([[np.inf, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        channels=[0, 1],
    )

    outcome = ObservationValidator(obs).check_data_finite_for_non_zero_weights()
    assert isinstance(outcome, ValidationInfo)
100 |
101 |
def test_data_finite_for_non_zero_weights_returns_error_with_infinity():
    """Non-finite data under a non-zero weight must yield a ValidationError."""
    obs = Observation(
        data=np.array([[np.inf, np.inf, 3.0], [4.0, 5.0, 6.0]]),
        weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        channels=[0, 1],
    )

    outcome = ObservationValidator(obs).check_data_finite_for_non_zero_weights()
    assert isinstance(outcome, ValidationError)
    assert outcome.message == "Data in the observation must be finite."
115 |
116 |
def test_number_of_psf_channels(good_obs):
    """A PSF with one slice per observation channel passes with a ValidationInfo."""
    validator = ObservationValidator(good_obs)
    outcome = validator.check_number_of_psf_channels()
    assert isinstance(outcome, ValidationInfo)
123 |
124 |
def test_number_of_psf_channels_returns_error(bad_obs):
    """The bad observation's PSF was collapsed to 2-D, so the dimensionality
    check must fail.

    NOTE(review): the test name mentions PSF *channels*, but the body exercises
    ``check_psf_has_3_dimensions`` — consider renaming.
    """
    checker = ObservationValidator(bad_obs)

    results = checker.check_psf_has_3_dimensions()

    assert isinstance(results, ValidationError)
    assert results.message == "PSF must be 3-dimensional."
133 |
134 |
def test_validation_on_runs_observation_checks(capsys):
    """With auto-validation enabled, constructing an invalid Observation prints
    a validation report."""
    set_validation(True)
    _ = Observation(
        data=np.array([[np.inf, np.inf, 3.0], [4.0, 5.0, 6.0]]),
        weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
        channels=[0, 1],
    )

    printed = capsys.readouterr()
    assert "Observation validation results:" in printed.out
146 |
147 |
def test_psf_centroid_inconsistent_returns_error(data_file):
    """Test that when the PSF centroids are not consistent across bands, an error
    is raised."""

    data = np.asarray(data_file["images"])
    channels = [str(f) for f in data_file["filters"]]
    weights = np.asarray(1 / data_file["variance"])
    psf = np.asarray(data_file["psfs"])
    # Shift band 0 by one row (duplicating the first row), moving its centroid
    # relative to the other bands to create an inconsistency
    psf[0][1:, :] = psf[0][:-1, :]

    bad_obs = Observation(
        data=data,
        weights=weights,
        channels=channels,
        psf=ArrayPSF(psf),
    )

    checker = ObservationValidator(bad_obs)
    results = checker.check_psf_centroid_consistent()

    assert isinstance(results, ValidationError)
    assert results.message == "PSF centroid is not the same in all channels."
170 |
171 |
def test_psf_centroid_consistent_no_error(data_file):
    """Test that when the PSF centroids are consistent across bands, no errors
    are raised."""

    data = np.asarray(data_file["images"])
    channels = [str(f) for f in data_file["filters"]]
    weights = np.asarray(1 / data_file["variance"])
    # delta-function PSF at pixel (1, 1) in every band -> identical centroids
    psf = np.zeros_like(data)
    psf[:, 1, 1] = 1

    # renamed from `bad_obs`: this observation is deliberately *valid*, and the
    # old name shadowed (and contradicted) the `bad_obs` fixture
    obs = Observation(
        data=data,
        weights=weights,
        channels=channels,
        psf=ArrayPSF(psf),
    )

    checker = ObservationValidator(obs)

    results = checker.check_psf_centroid_consistent()
    assert isinstance(results, ValidationInfo)
193 |
--------------------------------------------------------------------------------
/tests/scarlet2/test_renderer.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: D101
2 | # ruff: noqa: D102
3 | # ruff: noqa: D103
4 | # ruff: noqa: D106
5 |
6 | import astropy.units as u
7 | import jax.numpy as jnp
8 | from numpy.testing import assert_allclose
9 |
10 | import scarlet2
11 | from scarlet2 import ArrayPSF
12 | from scarlet2.frame import _flip_matrix, _rot_matrix, _wcs_default, get_affine
13 | from scarlet2.measure import Moments
14 | from scarlet2.morphology import GaussianMorphology
15 |
# resampling renderer to test
cls = scarlet2.renderer.ResamplingRenderer

# create a Gaussian as model; `g` holds its reference moments, which every
# test below compares against after undoing the respective WCS transform
T = 10
eps = jnp.array((0.5, 0.3))
model = GaussianMorphology(size=T, ellipticity=eps, shape=(150, 151))()[None, ...]  # test even & odd
g = scarlet2.measure.Moments(component=model[0], N=2)

# make a Frame matching the model shape
model_frame = scarlet2.Frame(scarlet2.Box(model.shape))
27 |
28 |
def test_rescale():
    """Rendering into a frame with a larger pixel scale shrinks the image;
    undoing the resize on the measured moments must recover the reference."""
    # how does this model look if we change the WCS scale
    # if scale becomes larger, the image gets smaller
    scale = 3.1
    shape = (int(model.shape[-2] / scale) + 1, int(model.shape[-1] / scale))
    wcs_obs = _wcs_default(shape)
    m = scarlet2.frame.get_affine(wcs_obs)
    m = scale * m
    wcs_obs.wcs.pc = m

    obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs)
    renderer = cls(model_frame, obs_frame)
    model_ = renderer(model)

    # for testing outputs...
    # print(renderer)
    # fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    # ax[0].imshow(model[0])
    # ax[1].imshow(model_[0])

    # undo resizing
    g_obs = Moments(model_[0])
    g_obs.resize(scale)
    assert_allclose(g_obs.flux, g.flux, atol=3e-3)
    # rendered image must remain centered in the observation frame
    shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(obs_frame.bbox.spatial.center)
    assert_allclose(shift_, (0, 0), atol=1e-4)
    assert_allclose(g_obs.size, g.size, rtol=3e-5)
    assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
57 |
58 |
def test_rotate():
    """A rotated observation WCS appears as a counter-rotation of the image;
    rotating the measured moments back must recover the reference moments."""
    shape = model.shape
    angle = (30 * u.deg).to(u.rad).value
    wcs_obs = _wcs_default(shape)
    affine = get_affine(wcs_obs)
    wcs_obs.wcs.pc = _rot_matrix(angle) @ affine

    obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs)
    rendered = cls(model_frame, obs_frame)(model)

    # rotate to correct for the counter-rotation of the frame
    g_obs = Moments(rendered[0])
    g_obs.rotate(angle)
    assert_allclose(g_obs.flux, g.flux, atol=3e-3)
    assert_allclose(g_obs.centroid, g.centroid, atol=1e-4)
    assert_allclose(g_obs.size, g.size, rtol=3e-5)
    assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
81 |
82 |
83 | def test_flip():
84 | shape = model.shape
85 | wcs_obs = _wcs_default(shape)
86 | m = scarlet2.frame.get_affine(wcs_obs)
87 | f = _flip_matrix(-1)
88 | m = f @ m
89 | wcs_obs.wcs.pc = m
90 |
91 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs)
92 | renderer = cls(model_frame, obs_frame)
93 | model_ = renderer(model)
94 |
95 | # undo the flip
96 | g_obs = Moments(model_[0])
97 | g_obs.flipud()
98 | assert_allclose(g_obs.flux, g.flux, atol=3e-3)
99 | assert_allclose(g_obs.centroid, g.centroid, atol=1e-4)
100 | assert_allclose(g_obs.size, g.size, rtol=3e-5)
101 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
102 |
103 |
104 | def test_translation():
105 | shift = jnp.array((12, -1.9))
106 | shape = model.shape
107 | wcs_obs = _wcs_default(shape)
108 | wcs_obs.wcs.crpix += shift[::-1] # x/y
109 |
110 | shape = model.shape
111 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs)
112 | renderer = cls(model_frame, obs_frame)
113 | model_ = renderer(model)
114 |
115 | g_obs = Moments(model_[0])
116 | shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(g.centroid)
117 |
118 | assert_allclose(g_obs.flux, g.flux, rtol=2e-3)
119 | assert_allclose(shift_, -shift, atol=1e-4)
120 | assert_allclose(g_obs.size, g.size, rtol=3e-5)
121 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
122 |
123 |
    def test_convolution():
        """Moments of a PSF-convolved rendering are recovered by deconvolve/convolve."""
        # create model PSF and convolve g
        t_p = 0.7
        eps_p = None
        psf = GaussianMorphology(size=t_p, ellipticity=eps_p)()
        psf /= psf.sum()
        p = Moments(psf, N=2)
        psf = ArrayPSF(psf[None, ...])
        model_frame_ = scarlet2.Frame(scarlet2.Box(model.shape), psf=psf)

        # create obs PSF
        t_obs, eps_obs = 5, jnp.array((-0.1, 0.1))
        psf_obs = GaussianMorphology(size=t_obs, ellipticity=eps_obs)()
        psf_obs /= psf_obs.sum()
        p_obs = Moments(psf_obs, N=2)
        psf_obs = ArrayPSF(psf_obs[None, ...])

        # render: deconvolve, reconvolve
        shape = model.shape
        obs_frame = scarlet2.Frame(scarlet2.Box(shape), psf=psf_obs)
        renderer = cls(model_frame_, obs_frame)
        model_ = renderer(model)
        g_obs = Moments(model_[0])

        # undo the PSF difference in moment space: remove the obs PSF, re-apply the model PSF
        g_obs.deconvolve(p_obs).convolve(p)
        assert_allclose(g_obs.flux, g.flux, atol=3e-3)
        assert_allclose(g_obs.centroid, g.centroid, atol=1e-4)
        assert_allclose(g_obs.size, g.size, rtol=3e-5)
        assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
153 |
154 |
    def test_all():
        """Combined scale + rotation + flip + shift + PSF change, undone in moment space."""
        scale = 2.1
        phi = (70 * u.deg).to(u.rad).value
        shift = jnp.array((1.4, -0.456))
        shape = (int(model.shape[1] // scale), int(model.shape[2] // scale))
        wcs_obs = _wcs_default(shape)
        m = get_affine(wcs_obs)
        # compose all frame transformations into one affine matrix
        wcs_obs.wcs.pc = scale * _rot_matrix(phi) @ _flip_matrix(-1) @ m
        wcs_obs.wcs.crpix += shift[::-1]  # x/y

        # create model PSF and convolve g
        t_p = 1
        eps_p = None
        psf = GaussianMorphology(size=t_p, ellipticity=eps_p)()
        psf /= psf.sum()
        p = Moments(psf, N=2)
        psf = ArrayPSF(psf[None, ...])
        model_frame_ = scarlet2.Frame(scarlet2.Box(model.shape), psf=psf)

        # obs PSF
        t_obs, eps_obs = 3, jnp.array((0.1, -0.1))
        psf_obs = GaussianMorphology(size=t_obs, ellipticity=eps_obs)()
        psf_obs /= psf_obs.sum()
        p_obs = Moments(psf_obs, N=2)
        psf_obs = ArrayPSF(psf_obs[None, ...])

        obs_frame = scarlet2.Frame(scarlet2.Box(shape), psf=psf_obs, wcs=wcs_obs)
        renderer = cls(model_frame_, obs_frame)
        model_ = renderer(model)

        g_obs = Moments(model_[0])
        # undo operations applied in the reverse order of the forward transformations
        g_obs.deconvolve(p_obs).flipud().rotate(phi).resize(scale).convolve(p)
        # centroid compared to the box center: the frame moved, not the source
        shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(obs_frame.bbox.spatial.center)

        assert_allclose(g_obs.flux, g.flux, rtol=2e-3)
        assert_allclose(shift_, -shift, atol=1e-4)
        assert_allclose(g_obs.size, g.size, rtol=3e-5)
        assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5)
193 |
--------------------------------------------------------------------------------
/src/scarlet2/wavelets.py:
--------------------------------------------------------------------------------
1 | """Wavelet functions"""
2 | # from https://github.com/pmelchior/scarlet/blob/master/scarlet/wavelet.py
3 |
4 | import jax.numpy as jnp
5 |
6 |
class Starlet:
    """Wavelet transform of images (2D or 3D) with the 'a trou' algorithm.

    The transform is performed by convolving the image by a seed starlet: the transform of an all-zero
    image with its central pixel set to one. This requires 2-fold padding of the image and an odd pad
    shape. The fft of the seed starlet is cached so that it can be reused in the transform of other
    images that have the same shape.
    """

    def __init__(self, image, coefficients, generation, convolve2d):
        """
        Parameters
        ----------
        image: array
            Image in real space.
        coefficients: array
            Starlet transform of the image.
        generation: int
            The generation of the starlet transform (either `1` or `2`).
        convolve2d: callable or None
            The filter function used to convolve the image and create the wavelets.
            When `convolve2d` is `None` this uses a cubic bspline.
        """
        self._image = image
        self._coeffs = coefficients
        self._generation = generation
        self._convolve2d = convolve2d
        self._norm = None  # NOTE(review): not used within this class; presumably a lazily filled cache

    @property
    def image(self):
        """The real space image"""
        return self._image

    @property
    def coefficients(self):
        """Starlet coefficients, shape (scales+1,) + image.shape"""
        return self._coeffs

    @staticmethod
    def from_image(image, scales=None, generation=2, convolve2d=None):
        """Generate a set of starlet coefficients for an image

        Parameters
        ----------
        image: array-like
            The image that is converted into starlet coefficients
        scales: int
            The number of starlet scales to use.
            If `scales` is `None` then the maximum number of scales is used.
            Note: this is the length of the coefficients-1, as in the notation
            of `Starck et al. 2011`.
        generation: int
            The generation of the starlet transform (either `1` or `2`).
        convolve2d: callable or None
            The filter function used to convolve the image and create the wavelets.
            When `convolve2d` is `None` this uses a cubic bspline.

        Returns
        -------
        result: Starlet
            The resulting `Starlet` that contains the image, starlet coefficients,
            as well as the parameters used to generate the coefficients.
        """
        if scales is None:
            scales = get_scales(image.shape)
        coefficients = starlet_transform(image, scales, generation, convolve2d)
        return Starlet(image, coefficients, generation, convolve2d)
75 |
76 |
def bspline_convolve(image, scale):
    """Convolve an image with a bspline at a given scale.

    This uses the spline
    `[1/16, 1/4, 3/8, 1/4, 1/16]`
    from Starck et al. 2011, applied separably along both axes with a pixel
    spacing of `2**scale` (the 'a trou' hole spacing).

    Parameters
    ----------
    image: 2D array
        The image or wavelet coefficients to convolve.
    scale: int
        The wavelet scale for the convolution. This sets the
        spacing between adjacent pixels with the spline.

    Returns
    -------
    2D array
        The convolved image.
    """
    # Filter for the scarlet transform. Here bspline
    kernel = jnp.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
    step = 2**scale

    lo_far = slice(None, -2 * step)
    lo_near = slice(None, -step)
    hi_near = slice(step, None)
    hi_far = slice(2 * step, None)

    def _smooth_axis0(arr):
        # center tap, then the four shifted taps (zero boundary: shifts drop off the edge)
        out = arr * kernel[2]
        out = out.at[hi_far].add(arr[lo_far] * kernel[0])
        out = out.at[hi_near].add(arr[lo_near] * kernel[1])
        out = out.at[lo_near].add(arr[hi_near] * kernel[3])
        out = out.at[lo_far].add(arr[hi_far] * kernel[4])
        return out

    # separable filter: smooth along axis 0, then (via transpose) along axis 1
    return _smooth_axis0(_smooth_axis0(image).T).T
116 |
117 |
def starlet_transform(image, scales=None, generation=2, convolve2d=None):
    """Perform a starlet transform (1st or 2nd generation).

    Parameters
    ----------
    image: 2D array
        The image to transform into starlet coefficients.
    scales: int
        The number of scale to transform with starlets.
        The total dimension of the starlet will have
        `scales+1` dimensions, since it will also hold
        the image at all scales higher than `scales`.
    generation: int
        The generation of the transform.
        This must be `1` or `2`. Default is `2`.
    convolve2d: function
        The filter function to use to convolve the image
        with starlets in 2D.

    Returns
    -------
    starlet: array with dimension (scales+1, Ny, Nx)
        The starlet dictionary for the input `image`.
    """
    assert len(image.shape) == 2, f"Image should be 2D, got {len(image.shape)}"
    assert generation in (1, 2), f"generation should be 1 or 2, got {generation}"

    scales = get_scales(image.shape, scales)
    if convolve2d is None:
        convolve2d = bspline_convolve

    # coefficient stack: one detail plane per scale plus the final smooth plane
    coeffs = jnp.zeros((scales + 1,) + image.shape)
    residual = image
    for level in range(scales):
        smoothed = convolve2d(residual, level)
        if generation == 2:
            # 2nd generation: detail is the residual minus a doubly smoothed image
            detail = residual - convolve2d(smoothed, level)
        else:
            detail = residual - smoothed
        coeffs = coeffs.at[level].set(detail)
        residual = smoothed

    # the coarsest (smooth) approximation goes into the last plane
    return coeffs.at[-1].set(residual)
165 |
166 |
def starlet_reconstruction(starlets, generation=2, convolve2d=None):
    """Reconstruct an image from a dictionary of starlets

    Parameters
    ----------
    starlets: array with dimension (scales+1, Ny, Nx)
        The starlet dictionary used to reconstruct the image.
    generation: int
        The generation of the transform.
        This must be `1` or `2`. Default is `2`.
    convolve2d: function
        The filter function to use to convolve the image
        with starlets in 2D.

    Returns
    -------
    image: 2D array
        The image reconstructed from the input `starlet`.
    """
    # 1st generation transform telescopes: the coefficients simply sum to the image
    if generation == 1:
        return jnp.sum(starlets, axis=0)
    if convolve2d is None:
        convolve2d = bspline_convolve

    # 2nd generation: start from the smooth plane, alternate smoothing and adding details
    image = starlets[-1]
    for level in reversed(range(len(starlets) - 1)):
        image = convolve2d(image, level) + starlets[level]
    return image
198 |
199 |
def get_scales(image_shape, scales=None):
    """Get the number of scales to use in the starlet transform.

    Parameters
    ----------
    image_shape: tuple
        The 2D shape of the image that is being transformed
    scales: int
        The number of scale to transform with starlets.
        The total dimension of the starlet will have
        `scales+1` dimensions, since it will also hold
        the image at all scales higher than `scales`.

    Returns
    -------
    int
        The requested number of scales, capped at the maximum
        supported by the smaller spatial dimension of the image.
    """
    # largest number of decomposition levels the image size supports
    max_scale = int(jnp.log2(min(image_shape[-2:]))) - 1
    if scales is None or scales > max_scale:
        return max_scale
    return int(scales)
218 |
--------------------------------------------------------------------------------
/src/scarlet2/morphology.py:
--------------------------------------------------------------------------------
1 | import astropy.units as u
2 | import equinox as eqx
3 | import jax.numpy as jnp
4 | import jax.scipy
5 |
6 | from . import Scenery, measure
7 | from .module import Module
8 | from .wavelets import starlet_reconstruction, starlet_transform
9 |
10 |
class Morphology(Module):
    """Morphology base class

    Subclasses evaluate to an image of the source morphology.
    """

    @property
    def shape(self):
        """Shape (2D) of the morphology model; must be implemented by subclasses"""
        raise NotImplementedError
18 |
19 |
class ProfileMorphology(Morphology):
    """Base class for morphologies based on a radial profile"""

    size: float
    """Size of the profile

    Can be given as an astropy angle, which will be transformed with the WCS of the current
    :py:class:`~scarlet2.Scene`.
    """
    ellipticity: (None, jnp.array)
    """Ellipticity of the profile"""
    _shape: tuple = eqx.field(repr=False)

    def __init__(self, size, ellipticity=None, shape=None):
        """
        Parameters
        ----------
        size: float or astropy.units.Quantity
            Size of the profile, in pixels or as an angular quantity (requires a Scene context).
        ellipticity: array, optional
            Two-component ellipticity of the profile.
        shape: tuple, optional
            Shape of the bounding box; defaults to an odd square of about `15 * size` pixels.
        """
        if isinstance(size, u.Quantity):
            try:
                size = Scenery.scene.frame.u_to_pixel(size)
            except AttributeError:
                print("`size` defined in astropy units can only be used within the context of a Scene")
                print("Use 'with Scene(frame) as scene: (...)'")
                raise

        self.size = size
        self.ellipticity = ellipticity

        # default shape: square 15x size
        if shape is None:
            # explicit call to int() to avoid bbox sizes being jax-traced
            size = int(jnp.ceil(15 * self.size))
            # odd shapes for unique center pixel
            if size % 2 == 0:
                size += 1
            shape = (size, size)
        self._shape = shape

    @property
    def shape(self):
        """Shape of the bounding box for the profile.

        If not set during `__init__`, uses a square box with an odd number of pixels
        not smaller than `15*size`.
        """
        return self._shape

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared
        """
        raise NotImplementedError

    def __call__(self, delta_center=jnp.zeros(2)):  # noqa: B008
        """Evaluate the model"""
        # pixel grids centered on the (odd-sized) box center, offset by delta_center
        _y = jnp.arange(-(self.shape[-2] // 2), self.shape[-2] // 2 + 1, dtype=float) - delta_center[-2]
        _x = jnp.arange(-(self.shape[-1] // 2), self.shape[-1] // 2 + 1, dtype=float) - delta_center[-1]

        if self.ellipticity is None:
            r2 = _y[:, None] ** 2 + _x[None, :] ** 2
        else:
            # convert ellipticity (e1, e2) to shear (g1, g2), then evaluate the
            # squared radius in the sheared coordinate system
            e1, e2 = self.ellipticity
            g_factor = 1 / (1.0 + jnp.sqrt(1.0 - (e1**2 + e2**2)))
            g1, g2 = self.ellipticity * g_factor
            __x = ((1 - g1) * _x[None, :] - g2 * _y[:, None]) / jnp.sqrt(1 - (g1**2 + g2**2))
            __y = (-g2 * _x[None, :] + (1 + g1) * _y[:, None]) / jnp.sqrt(1 - (g1**2 + g2**2))
            r2 = __y**2 + __x**2

        r2 /= self.size**2
        r2 = jnp.maximum(r2, 1e-3)  # prevents infs at R2 = 0
        morph = self.f(r2)
        return morph
93 |
94 |
class GaussianMorphology(ProfileMorphology):
    """Gaussian radial profile"""

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared

        Returns
        -------
        float or array
            Unnormalized Gaussian profile `exp(-r2 / 2)`
        """
        return jnp.exp(-r2 / 2)

    def __call__(self, delta_center=jnp.zeros(2)):  # noqa: B008
        """Evaluate the model

        Circular profiles (no ellipticity) use a separable outer product of two
        pixel-integrated 1D Gaussians; elliptical profiles fall back to the
        generic radial evaluation of the parent class.
        """
        # faster circular 2D Gaussian: instead of N^2 evaluations, use outer product of 2 1D Gaussian evals
        if self.ellipticity is None:
            _y = jnp.arange(-(self.shape[-2] // 2), self.shape[-2] // 2 + 1, dtype=float) - delta_center[-2]
            _x = jnp.arange(-(self.shape[-1] // 2), self.shape[-1] // 2 + 1, dtype=float) - delta_center[-1]

            # with pixel integration: since 1 - erfc(z) = erf(z), this integrates
            # the 1D Gaussian over the pixel [x - 0.5, x + 0.5]
            f = lambda x, s: 0.5 * (  # noqa: E731
                1
                - jax.scipy.special.erfc((0.5 - x) / jnp.sqrt(2) / s)
                + 1
                - jax.scipy.special.erfc((0.5 + x) / jnp.sqrt(2) / s)
            )
            # # without pixel integration
            # f = lambda x, s: jnp.exp(-(x ** 2) / (2 * s ** 2)) / (jnp.sqrt(2 * jnp.pi) * s)

            return jnp.outer(f(_y, self.size), f(_x, self.size))

        else:
            return super().__call__(delta_center)

    @staticmethod
    def from_image(image):
        """Create Gaussian radial profile from the 2nd moments of `image`

        Parameters
        ----------
        image: array
            2D array to measure :py:class:`~scarlet2.measure.Moments` from.

        Returns
        -------
        GaussianMorphology
        """
        assert image.ndim == 2
        center = measure.centroid(image)
        # compute moments and create Gaussian from it
        g = measure.Moments(image, center=center, N=2)
        return GaussianMorphology.from_moments(g, shape=image.shape)

    @staticmethod
    def from_moments(g, shape=None):
        """Create Gaussian radial profile from the moments `g`

        Parameters
        ----------
        g: :py:class:`~scarlet2.measure.Moments`
            Moments, order >= 2
        shape: tuple
            Shape of the bounding box

        Returns
        -------
        GaussianMorphology

        Raises
        ------
        ValueError
            If the measured size or ellipticity is not finite.
        """
        t = g.size
        ellipticity = g.ellipticity

        # create image of Gaussian with these 2nd moments
        if jnp.isfinite(t) and jnp.isfinite(ellipticity).all():
            morph = GaussianMorphology(t, ellipticity, shape=shape)
        else:
            raise ValueError(
                f"Gaussian morphology not possible with size={t}, and ellipticity={ellipticity}!"
            )
        return morph
175 |
176 |
class SersicMorphology(ProfileMorphology):
    """Sersic radial profile"""

    n: float
    """Sersic index"""

    def __init__(self, n, size, ellipticity=None, shape=None):
        """
        Parameters
        ----------
        n: float
            Sersic index
        size: float
            Size of the profile
        ellipticity: array, optional
            Ellipticity of the profile
        shape: tuple, optional
            Shape of the bounding box
        """
        self.n = n
        super().__init__(size, ellipticity=ellipticity, shape=shape)

    def f(self, r2):
        """Radial profile function

        Parameters
        ----------
        r2: float or array
            Radius (distance from the center) squared
        """
        index = self.n
        index2 = index * index
        # b_n coefficient from Ciotti & Bertin (1999), eq. 18: stable for n > 0.36,
        # errors < 10^5. The simpler Capaccioli (1989) form bn = 1.9992 n - 0.3271,
        # and the MacArthur, Courteau & Holtzman (2003) eq. A2 fit (better for
        # n < 0.36) are not used to avoid an if clause in jitted code.
        bn = (
            2 * index
            - 0.333333
            + 0.009877 / index
            + 0.001803 / index2
            + 0.000114 / (index2 * index)
            - 0.000072 / (index2 * index2)
        )

        # Graham & Driver 2005, eq. 1: given R^2, use R2^(0.5/n) instead of R^(1/n)
        return jnp.exp(-bn * (r2 ** (0.5 / index) - 1))
212 |
213 |
def prox_plus(x):
    """Projection onto the non-negative orthant: `max(x, 0)`."""
    return jnp.maximum(x, 0)


def prox_soft(x, thresh):
    """Soft-thresholding operator: shrink `|x|` by `thresh`, keeping the sign."""
    return jnp.sign(x) * prox_plus(jnp.abs(x) - thresh)


def prox_soft_plus(x, thresh):
    """Soft-thresholding followed by projection onto non-negative values."""
    return prox_plus(prox_soft(x, thresh))
217 |
218 |
class StarletMorphology(Morphology):
    """Morphology described by its coefficients in the starlet basis

    See Also
    --------
    scarlet2.wavelets.Starlet
    """

    coeffs: jnp.ndarray
    """Starlet coefficients"""
    l1_thresh: float = eqx.field(default=0)
    """L1 threshold for coefficient to create sparse representation"""
    positive: bool = eqx.field(default=True)
    """Whether the coefficients are restricted to non-negative values"""

    def __call__(self, **kwargs):
        """Evaluate the model: soft-threshold the coefficients, then invert the transform"""
        if self.positive:
            sparse = prox_soft_plus(self.coeffs, self.l1_thresh)
        else:
            sparse = prox_soft(self.coeffs, self.l1_thresh)
        return starlet_reconstruction(sparse)

    @property
    def shape(self):
        """Shape (2D) of the morphology model"""
        # coefficients are stored as (scales, n1, n2); the spatial part is the last two axes
        return self.coeffs.shape[-2:]

    @staticmethod
    def from_image(image, **kwargs):
        """Create starlet morphology from `image`

        Parameters
        ----------
        image: array
            2D image array to determine coefficients from.
        kwargs: dict
            Additional arguments for `__init__`

        Returns
        -------
        StarletMorphology
        """
        # transform of an (n1, n2) image yields coefficients of shape (scales+1, n1, n2)
        return StarletMorphology(starlet_transform(image), **kwargs)
262 |
--------------------------------------------------------------------------------
/docs/howto/sampling.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Sample from Posterior\n",
8 | "\n",
9 | "_scarlet2_ can provide samples from the posterior distribution to pass to downstream operations and as the most precise option for uncertainty quantification. In principle, we can get posterior samples for every parameter, and this can be done with any sampler by evaluating the log-posterior distribution. For this guide we will use the Hamiltonian Monte Carlo sampler from numpyro, for which we created a convenient front-end in _scarlet2_.\n",
10 | "\n",
11 | "We start from the [quickstart tutorial](../0-quickstart), loading the same data and the best-fitting model."
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "# Import Packages and setup\n",
21 | "import jax.numpy as jnp\n",
22 | "import matplotlib.pyplot as plt\n",
23 | "\n",
24 | "from scarlet2 import *\n",
25 | "\n",
26 | "set_validation(False)"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {},
32 | "source": [
33 | "## Create Observation\n",
34 | "\n",
35 | "We need to create the {py:class}`~scarlet2.Observation` because it contains the {py:func}`~scarlet2.Observation.log_likelihood` method we need for the posterior:"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "# load the data\n",
45 | "from huggingface_hub import hf_hub_download\n",
46 | "\n",
47 | "filename = hf_hub_download(\n",
48 | " repo_id=\"astro-data-lab/scarlet-test-data\", filename=\"hsc_cosmos_35.npz\", repo_type=\"dataset\"\n",
49 | ")\n",
50 | "file = jnp.load(filename)\n",
51 | "data = jnp.asarray(file[\"images\"])\n",
52 | "channels = [str(f) for f in file[\"filters\"]]\n",
53 | "centers = jnp.array([(src[\"y\"], src[\"x\"]) for src in file[\"catalog\"]])\n",
54 | "weights = jnp.asarray(1 / file[\"variance\"])\n",
55 | "psf = jnp.asarray(file[\"psfs\"])\n",
56 | "\n",
57 | "# create the observation\n",
58 | "obs = Observation(\n",
59 | " data,\n",
60 | " weights,\n",
61 | " psf=ArrayPSF(psf),\n",
62 | " channels=channels,\n",
63 | ")"
64 | ]
65 | },
66 | {
67 | "cell_type": "markdown",
68 | "metadata": {
69 | "collapsed": false
70 | },
71 | "source": [
72 | "## Load Model\n",
73 | "\n",
74 | "We can make use of the best-fit model from the Quickstart guide as the starting point of the sampler."
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "import scarlet2.io\n",
84 | "\n",
85 | "id = 35\n",
86 | "filename = \"hsc_cosmos.h5\"\n",
87 | "scene = scarlet2.io.model_from_h5(filename, path=\"..\", id=id)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {
93 | "collapsed": false
94 | },
95 | "source": [
96 | "Let's have a look:"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "norm = plot.AsinhAutomaticNorm(obs)\n",
106 | "plot.scene(scene, observation=obs, norm=norm, add_boxes=True)\n",
107 | "plt.show()"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "## Define Parameters with Prior\n",
115 | "\n",
116 | "In principle, we can get posterior samples for every parameter. We will demonstrate by sampling from the spectrum and the center position of the point source #0. We therefore need to set the `prior` attribute for each of these parameters; the attribute `stepsize` is ignored, but `constraint` cannot be used when `prior` is set."
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "import numpyro.distributions as dist\n",
126 | "\n",
127 | "C = len(channels)\n",
128 | "with scarlet2.Parameters(scene) as parameters:\n",
129 | " # rough guess of source brightness across bands\n",
130 | " p1 = scene.sources[0].spectrum\n",
131 | " prior1 = dist.Uniform(low=jnp.zeros(C), high=500 * jnp.ones(C))\n",
132 | " Parameter(p1, name=\"spectrum\", prior=prior1)\n",
133 | "\n",
134 | " # initial position was integer pixel coordinate\n",
135 | " # assume 0.5 pixel uncertainty\n",
136 | " p2 = scene.sources[0].center\n",
137 | " prior2 = dist.Normal(centers[0], scale=0.5)\n",
138 | " Parameter(p2, name=\"center\", prior=prior2)"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "metadata": {
144 | "collapsed": false
145 | },
146 | "source": [
147 | "```{warning}\n",
            "You are responsible for setting reasonable priors, which describe what you know about the parameter before having looked at the data. In the example above, the spectrum gets a wide flat prior, and the center prior uses the position `centers[0]`, which is given by the original detection catalog. Neither uses information from the optimized `scene`.\n",
149 | "\n",
150 | "Also: If in doubt how much prior choices matter, vary them within reason.\n",
151 | "```\n",
152 | "\n",
153 | "## Run Sampler\n",
154 | "\n",
155 | "Then we can run numpyro's {py:class}`~numpyro.infer.hmc.NUTS` sampler with a call to {py:func}`~scarlet2.Scene.sample`, which is analogous to {py:func}`~scarlet2.Scene.fit`:"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "mcmc = scene.sample(\n",
165 | " obs,\n",
166 | " parameters,\n",
167 | " num_warmup=100,\n",
168 | " num_samples=1000,\n",
169 | " progress_bar=False,\n",
170 | ")\n",
171 | "mcmc.print_summary()"
172 | ]
173 | },
174 | {
175 | "cell_type": "markdown",
176 | "metadata": {
177 | "collapsed": false
178 | },
179 | "source": [
180 | "## Access Samples\n",
181 | "\n",
182 | "The samples can be accessed from the MCMC chain and are listed as arrays under the names chosen above for the respective `Parameter`."
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "import pprint\n",
192 | "\n",
193 | "samples = mcmc.get_samples()\n",
194 | "pprint.pprint(samples)"
195 | ]
196 | },
197 | {
198 | "cell_type": "markdown",
199 | "metadata": {},
200 | "source": [
201 | "To create versions of the scene for any of the samples, we first select a few at random and then use the method {py:func}`scarlet2.Module.replace` to set their values at the locations identified by `parameters`:"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "# get values for three random samples\n",
211 | "S = 3\n",
212 | "import jax.random\n",
213 | "\n",
214 | "seed = 42\n",
215 | "key = jax.random.key(seed)\n",
216 | "idxs = jax.random.randint(key, shape=(S,), minval=0, maxval=mcmc.num_samples)\n",
217 | "\n",
218 | "values = [[spectrum, center] for spectrum, center in zip(samples[\"spectrum\"][idxs], samples[\"center\"][idxs])]\n",
219 | "\n",
220 | "# create versions of the scene with these posterior samples\n",
221 | "scenes = [scene.replace(parameters, v) for v in values]\n",
222 | "\n",
223 | "# display the source model\n",
224 | "fig, axes = plt.subplots(1, S, figsize=(10, 4))\n",
225 | "for s in range(S):\n",
226 | " source_array = scenes[s].sources[0]()\n",
227 | " axes[s].imshow(plot.img_to_rgb(source_array, norm=norm))"
228 | ]
229 | },
230 | {
231 | "cell_type": "markdown",
232 | "metadata": {
233 | "collapsed": false
234 | },
235 | "source": [
            "The differences are imperceptible for this source, which tells us that the data were highly informative. But we can measure, e.g., the total fluxes for each sample"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "metadata": {},
243 | "outputs": [],
244 | "source": [
245 | "print(f\"-------------- {channels}\")\n",
246 | "for i, scene in enumerate(scenes):\n",
247 | " print(f\"Flux Sample {i}: {measure.flux(scene.sources[0])}\")"
248 | ]
249 | },
250 | {
251 | "cell_type": "markdown",
252 | "metadata": {},
253 | "source": [
254 | "## Visualize Posterior\n",
255 | "\n",
256 | "We can also visualize the posterior distributions, e.g. with the [`corner`](https://corner.readthedocs.io/en/latest/) package:"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": null,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "import corner\n",
266 | "\n",
267 | "corner.corner(mcmc);"
268 | ]
269 | }
270 | ],
271 | "metadata": {
272 | "celltoolbar": "Raw Cell Format",
273 | "kernelspec": {
274 | "display_name": "scarlet2",
275 | "language": "python",
276 | "name": "python3"
277 | },
278 | "language_info": {
279 | "codemirror_mode": {
280 | "name": "ipython",
281 | "version": 3
282 | },
283 | "file_extension": ".py",
284 | "mimetype": "text/x-python",
285 | "name": "python",
286 | "nbconvert_exporter": "python",
287 | "pygments_lexer": "ipython3",
288 | "version": "3.12.11"
289 | }
290 | },
291 | "nbformat": 4,
292 | "nbformat_minor": 4
293 | }
294 |
--------------------------------------------------------------------------------
/src/scarlet2/bbox.py:
--------------------------------------------------------------------------------
1 | import equinox as eqx
2 | import jax.numpy as jnp
3 |
4 |
5 | class Box(eqx.Module):
6 | """Bounding Box for data array
7 |
8 | A Bounding box describes the location of a data array in the model coordinate system.
9 | It is used to identify spatial and channel overlap and to map from model
10 | to observed frames and back.
11 |
12 | The `BBox` code is agnostic about the meaning of the dimensions.
13 | We generally use this convention:
14 |
15 | - 2D shapes denote (Height, Width)
16 | - 3D shapes denote (Channels, Height, Width)
17 | """
18 |
19 | shape: tuple
20 | """Size of the array"""
21 | origin: tuple
22 | """Start coordinate (in 2D: lower-left corner) of the array in model frame"""
23 |
24 | def __init__(self, shape, origin=None):
25 | """
26 | Parameters
27 | ----------
28 | shape: tuple
29 | Size of the array
30 | origin: tuple, optional
31 | Start coordinate (in 2D: lower-left corner) of the array in model frame
32 | """
33 | self.shape = tuple(shape)
34 | if origin is None:
35 | origin = (0,) * len(shape)
36 | self.origin = tuple(origin)
37 |
38 | @staticmethod
39 | def from_bounds(*bounds):
40 | """Initialize a box from its bounds
41 |
42 | Parameters
43 | ----------
44 | bounds: tuple of (min,max) pairs
45 | Min/Max coordinate for every dimension
46 |
47 | Returns
48 | -------
49 | bbox: Box
50 | A new box bounded by the input bounds.
51 | """
52 | shape = tuple(max(0, cmax - cmin) for cmin, cmax in bounds)
53 | origin = (cmin for cmin, cmax in bounds)
54 | return Box(shape, origin=origin)
55 |
56 | @staticmethod
57 | def from_data(X, min_value=0): # noqa: N803
58 | """Define box where `X` is above `min_value`
59 |
60 | Parameters
61 | ----------
62 | X : jnp.ndarray
63 | Data to threshold
64 | min_value : float
65 | Minimum value of the result.
66 |
67 | Returns
68 | -------
69 | bbox : :class:`scarlet2.bbox.Box`
70 | Bounding box for the thresholded `X`
71 | """
72 | sel = min_value < X
73 | if sel.any():
74 | nonzero = jnp.where(sel)
75 | bounds = []
76 | for dim in range(len(X.shape)):
77 | bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1))
78 | else:
79 | bounds = [[0, 0]] * len(X.shape)
80 | return Box.from_bounds(*bounds)
81 |
82 | def contains(self, p):
83 | """Whether the box contains a given coordinate `p`"""
84 | if len(p) != self.D:
85 | raise ValueError(f"Dimension mismatch in {p} and {self.D}")
86 |
87 | for d in range(self.D):
88 | if p[d] < self.origin[d] or p[d] >= self.origin[d] + self.shape[d]:
89 | return False
90 | return True
91 |
92 | def get_extent(self):
93 | """Return the start and end coordinates."""
94 | return [self.start[-1], self.stop[-1], self.start[-2], self.stop[-2]]
95 |
    @property
    def D(self):  # noqa: N802
        """Dimensionality of this BBox, i.e. the number of axes in `shape`."""
        return len(self.shape)
100 |
    @property
    def start(self):
        """Tuple of start coordinates (alias for `origin`)."""
        return self.origin
105 |
106 | @property
107 | def stop(self):
108 | """Tuple of stop coordinates"""
109 | return tuple(o + s for o, s in zip(self.origin, self.shape, strict=False))
110 |
111 | @property
112 | def center(self):
113 | """Tuple of center coordinates"""
114 | return tuple(o + s // 2 for o, s in zip(self.origin, self.shape, strict=False))
115 |
116 | @property
117 | def bounds(self):
118 | """Bounds of the box"""
119 | return tuple((o, o + s) for o, s in zip(self.origin, self.shape, strict=False))
120 |
121 | @property
122 | def slices(self):
123 | """Bounds of the box as slices"""
124 | return tuple([slice(o, o + s) for o, s in zip(self.origin, self.shape, strict=False)])
125 |
    @property
    def spatial(self):
        """Spatial component (the last two axes) of a higher-dimensional box"""
        # NOTE(review): `assert` is stripped under `python -O`; if callers rely on
        # this guard, a ValueError would be more robust — confirm intent.
        assert self.D >= 2
        return self[-2:]
131 |
    def set_center(self, pos):
        """Center box at given position.

        Moves `origin` so that `center` coincides with `pos`; `shape` is unchanged.

        Parameters
        ----------
        pos: tuple
            Target center coordinates; elements are assumed to be 0-d arrays or
            scalars exposing `.item()` (e.g. jax/numpy scalars) — TODO confirm.
        """
        pos_ = tuple(_.item() for _ in pos)
        origin = tuple(o + p - c for o, p, c in zip(self.origin, pos_, self.center, strict=False))
        # bypass attribute-setting restrictions (presumably a frozen/equinox-style
        # module — confirm) to update `origin` in place
        object.__setattr__(self, "origin", origin)
137 |
138 | def grow(self, delta):
139 | """Grow the Box by the given delta in each direction"""
140 | if not hasattr(delta, "__iter__"):
141 | delta = [delta] * self.D
142 | origin = tuple([self.origin[d] - delta[d] for d in range(self.D)])
143 | shape = tuple([self.shape[d] + 2 * delta[d] for d in range(self.D)])
144 | return Box(shape, origin=origin)
145 |
146 | def shrink(self, delta):
147 | """Shrink the Box by the given delta in each direction"""
148 | if not hasattr(delta, "__iter__"):
149 | delta = [delta] * self.D
150 | origin = tuple([self.origin[d] + delta[d] for d in range(self.D)])
151 | shape = tuple([self.shape[d] - 2 * delta[d] for d in range(self.D)])
152 | return Box(shape, origin=origin)
153 |
154 | def __or__(self, other):
155 | """Union of two bounding boxes
156 |
157 | Parameters
158 | ----------
159 | other: `Box`
160 | The other bounding box in the union
161 |
162 | Returns
163 | -------
164 | result: `Box`
165 | The smallest rectangular box that contains *both* boxes.
166 | """
167 | if other.D != self.D:
168 | raise ValueError(f"Dimension mismatch in the boxes {other} and {self}")
169 | bounds = []
170 | for d in range(self.D):
171 | bounds.append((min(self.start[d], other.start[d]), max(self.stop[d], other.stop[d])))
172 | return Box.from_bounds(*bounds)
173 |
174 | def __and__(self, other):
175 | """Intersection of two bounding boxes
176 |
177 | If there is no intersection between the two bounding
178 | boxes then an empty bounding box is returned.
179 |
180 | Parameters
181 | ----------
182 | other: `Box`
183 | The other bounding box in the intersection
184 |
185 | Returns
186 | -------
187 | result: `Box`
188 | The rectangular box that is in the overlap region
189 | of both boxes.
190 | """
191 | if other.D != self.D:
192 | raise ValueError(f"Dimension mismatch in the boxes {other} and {self}")
193 | assert other.D == self.D
194 | bounds = []
195 | for d in range(self.D):
196 | bounds.append((max(self.start[d], other.start[d]), min(self.stop[d], other.stop[d])))
197 | return Box.from_bounds(*bounds)
198 |
199 | def __getitem__(self, i):
200 | s_ = self.shape[i]
201 | o_ = self.origin[i]
202 | if not hasattr(s_, "__iter__"):
203 | s_ = (s_,)
204 | o_ = (o_,)
205 | return Box(s_, origin=o_)
206 |
207 | def __add__(self, offset):
208 | if not hasattr(offset, "__iter__"):
209 | offset = (offset,) * self.D
210 | origin = tuple([a + o for a, o in zip(self.origin, offset, strict=False)])
211 | return Box(self.shape, origin=origin)
212 |
213 | def __sub__(self, offset):
214 | if not hasattr(offset, "__iter__"):
215 | offset = (offset,) * self.D
216 | origin = tuple([a - o for a, o in zip(self.origin, offset, strict=False)])
217 | return Box(self.shape, origin=origin)
218 |
    def __matmul__(self, bbox):
        """Construct a higher-dimensional box whose bounds are the concatenation of
        `self.bounds` and `bbox.bounds` (presumably used to combine e.g. a channel
        box with a spatial box — confirm with callers)."""
        bounds = self.bounds + bbox.bounds
        return Box.from_bounds(*bounds)
222 |
    def __copy__(self):
        """Return a new Box with the same shape and origin."""
        return Box(self.shape, origin=self.origin)
225 |
226 | def __eq__(self, other):
227 | return self.shape == other.shape and self.origin == other.origin
228 |
    def __hash__(self):
        """Hash on (shape, origin), consistent with `__eq__`."""
        return hash((self.shape, self.origin))
231 |
232 |
def overlap_slices(bbox1, bbox2, return_boxes=False):
    """Slices of bbox1 and bbox2 that overlap

    Parameters
    ----------
    bbox1: `~scarlet.bbox.Box`
        The first box to use for comparing overlap.
    bbox2: `~scarlet.bbox.Box`
        The second box to use for comparing overlap.
    return_boxes: bool
        If True return new boxes corresponding to the overlapping portion of
        each of the input boxes. If False, return the overlapping portion of the
        original boxes. Default False.

    Returns
    -------
    slices: tuple of slices
        The slice of an array bounded by `bbox1` and
        the slice of an array bounded by `bbox2` in the
        overlapping region.
    """
    common = bbox1 & bbox2
    # express the overlap in each box's local (array) coordinates
    local1 = common - bbox1.origin
    local2 = common - bbox2.origin
    if return_boxes:
        return local1, local2
    return local1.slices, local2.slices
264 |
265 |
def insert_into(image, sub, bbox):
    """Insert `sub` into `image` according to this bbox

    Inverse operation to :func:`~scarlet.bbox.Box.extract_from`.

    Parameters
    ----------
    image: array
        Full image
    sub: array
        Smaller sub-image
    bbox: Box
        Bounding box that describes the shape and position of `sub` in the pixel coordinates of `image`.

    Returns
    -------
    image: array
        Image with `sub` inserted at `bbox`.
    """
    image_box = Box(image.shape)
    image_slices, sub_slices = overlap_slices(image_box, bbox)
    try:
        # mutable (numpy) arrays support in-place assignment
        image[image_slices] = sub[sub_slices]
    except TypeError:
        # immutable (jax) arrays require a functional update
        image = image.at[image_slices].set(sub[sub_slices])
    return image
292 |
--------------------------------------------------------------------------------
/src/scarlet2/lsst_utils.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | import lsst.afw.geom as afw_geom
3 | import lsst.geom as geom
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | from astropy import units as u
7 | from astropy.wcs import WCS
8 | from lsst.afw.fits import MemFileManager
9 | from lsst.afw.image import ExposureF
10 | from lsst.geom import Point2D
11 | from lsst.meas.algorithms import WarpedPsf
12 | from lsst.pipe.tasks.registerImage import RegisterConfig, RegisterTask
13 | from pyvo.dal.adhoc import DatalinkResults, SodaQuery
14 |
15 | import scarlet2
16 |
17 |
def warp_img(ref_img, img_to_warp, ref_wcs, wcs_to_warp):
    """Warp and rotate an image onto the coordinate system of another image

    Parameters
    ----------
    ref_img: 'ExposureF'
        is the reference image for the re-projection
    img_to_warp: 'ExposureF'
        the image to rotate and warp onto the reference image's wcs
    ref_wcs: 'wcs object'
        the wcs of the reference image (i.e. ref_img.getWcs() )
    wcs_to_warp: 'wcs object'
        the wcs of the image to warp (i.e. img_to_warp.getWcs() )

    Returns
    -------
    warped_exp: 'ExposureF'
        a reprojected, rotated image that is aligned and matched to ref_image
    """
    register_task = RegisterTask(name="register", config=RegisterConfig())
    # warp onto the reference WCS, clipped to the reference image's bounding box
    return register_task.warpExposure(img_to_warp, wcs_to_warp, ref_wcs, ref_img.getBBox())
41 |
42 |
def read_cutout_mem(sq):
    """Read the cutout into memory

    Parameters
    ----------
    sq : 'dict'
        returned from SodaQuery.from_resource()

    Returns
    -------
    exposure : 'ExposureF'
        the cutout in exposureF format
    """
    raw = sq.execute_stream().read()
    sq.raise_if_error()
    # stage the FITS bytes in an in-memory file and build the exposure from it
    manager = MemFileManager(len(raw))
    manager.setData(raw, len(raw))
    return ExposureF(manager)
64 |
65 |
def make_image_cutout(tap_service, ra, dec, data_id, cutout_size=0.01, imtype=None):
    """Wrapper function to generate a cutout using the cutout tool

    Parameters
    ----------
    tap_service : `pyvo.dal.tap.TAPService`
        the TAP service to use for querying the cutouts
    ra, dec : 'float'
        the ra and dec of the cutout center
    data_id : 'dict'
        the dataId of the image to make a cutout from. The format
        must correspond to that provided for parameter 'imtype'
    cutout_size : 'float', optional
        edge length in degrees of the cutout
    imtype : 'string', optional
        string containing the type of LSST image to generate
        a cutout of (e.g. deepCoadd, calexp). If imtype=None,
        the function will assume a deepCoadd.

    Returns
    -------
    exposure : 'ExposureF'
        the cutout in exposureF format
    """

    sphere_point = geom.SpherePoint(ra * geom.degrees, dec * geom.degrees)

    if imtype == "calexp":
        # a calexp is identified by visit and detector
        query = (
            "SELECT access_format, access_url, dataproduct_subtype, "
            "lsst_visit, lsst_detector, lsst_band "
            "FROM dp02_dc2_catalogs.ObsCore WHERE dataproduct_type = 'image' "
            "AND obs_collection = 'LSST.DP02' "
            "AND dataproduct_subtype = 'lsst.calexp' "
            f"AND lsst_visit = {data_id['visit']} "
            f"AND lsst_detector = {data_id['detector']}"
        )
    else:
        # deepCoadd (default): identified by tract, patch and band
        tract = data_id["tract"]
        patch = data_id["patch"]
        # NOTE(review): despite the original intent to make 'band' optional,
        # a data_id without a 'band' key raises KeyError here — confirm.
        band = data_id["band"]

        query = (
            "SELECT access_format, access_url, dataproduct_subtype, "
            "lsst_patch, lsst_tract, lsst_band "
            "FROM dp02_dc2_catalogs.ObsCore WHERE dataproduct_type = 'image' "
            "AND obs_collection = 'LSST.DP02' "
            "AND dataproduct_subtype = 'lsst.deepCoadd_calexp' "
            f"AND lsst_tract = {tract} "
            f"AND lsst_patch = {patch} "
            f"AND lsst_band = '{band}' "
        )
    results = tap_service.search(query)

    # Get datalink
    data_link_url = results[0].getdataurl()
    auth_session = tap_service._session
    dl = DatalinkResults.from_result_url(data_link_url, session=auth_session)

    # from_resource: creates an instance from
    # a number of records and a Datalink Resource.
    sq = SodaQuery.from_resource(dl, dl.get_adhocservice_by_id("cutout-sync"), session=auth_session)

    # circular cutout centered on (ra, dec) with radius cutout_size (degrees)
    sq.circle = (
        sphere_point.getRa().asDegrees() * u.deg,
        sphere_point.getDec().asDegrees() * u.deg,
        cutout_size * u.deg,
    )

    return read_cutout_mem(sq)
153 |
154 |
def dia_source_to_observations(cutout_size_pix, dia_src, service, plot_images=False):
    """Convert a DIA source to a list of scarlet2 Observations

    Parameters
    ----------
    cutout_size_pix : 'int'
        the size of the cutout in pixels
    dia_src : `astropy.table.Table`
        the DIA source table containing the sources to make cutouts for
    service : `pyvo.dal.tap.TAPService`
        the TAP service to use for querying the cutouts
        i.e. the result from lsst.rsp.get_tap_service()
    plot_images : 'bool', optional
        whether to plot the images as they are processed
        (default is False)

    Returns
    -------
    observations : 'list of scarlet2.Observation'
        a list of scarlet2 Observations, one for each DIA source
    channels_sc2 : 'list of tuples'
        a list of tuples containing the band and channel number for each observation
    """
    # TODO: - add back in the plotting functionality?
    #       - figure out how to get the WCS from the cutout without writing to disk
    # pixels -> degrees (presumably the 0.2 arcsec/pixel LSST scale — confirm)
    cutout_size = cutout_size_pix * 0.2 / 3600.0
    ra = dia_src["ra"][0]
    dec = dia_src["decl"][0]

    observations = []
    channels_sc2 = []
    img_ref = None
    wcs_ref = None

    first_time = dia_src["midPointTai"][0]
    vmin = -200
    vmax = 300

    for i, src in enumerate(dia_src):
        ccd_visit_id = src["ccdVisitId"]
        band = str(src["filterName"])
        # ccdVisitId encodes the visit in the leading digits and the detector
        # in the last three digits
        visit = int(str(ccd_visit_id)[:-3])
        detector = int(str(ccd_visit_id)[-3:])
        data_id_calexp = {"visit": visit, "detector": detector}

        if i == 0:
            # BUG FIX: keyword is `data_id` (was `dataId`, which raised TypeError)
            img = make_image_cutout(
                service, ra, dec, cutout_size=cutout_size, imtype="calexp", data_id=data_id_calexp
            )
            img_ref = img
            # no warping is needed for the reference
            img_warped = img_ref
            offset = geom.Extent2D(geom.Point2I(0, 0) - img_ref.getXY0())
            shifted_wcs = img_ref.getWcs().copyAtShiftedPixelOrigin(offset)
            wcs_ref = WCS(shifted_wcs.getFitsMetadata())
        else:
            # larger cutout so the warp onto the reference frame has full coverage
            img = make_image_cutout(
                service, ra, dec, cutout_size=cutout_size * 50.0, imtype="calexp", data_id=data_id_calexp
            )
            img_warped = warp_img(img_ref, img, img_ref.getWcs(), img.getWcs())
        im_arr = img_warped.image.array
        var_arr = img_warped.variance.array

        # reshape image array to (channel, height, width)
        n1, n2 = im_arr.shape
        image_sc2 = im_arr.reshape(1, n1, n2)

        # inverse-variance weights, same shape as the image
        n1, n2 = var_arr.shape
        weight_sc2 = 1 / var_arr.reshape(1, n1, n2)

        # evaluate the warped PSF at the cutout center
        # (removed a duplicated recomputation of point_tuple/point_image)
        # NOTE(review): the point is built as (shape[0]/2, shape[1]/2), i.e. (y, x),
        # while lsst Point2D is (x, y) — harmless for square cutouts; confirm.
        point_tuple = (int(img_warped.image.array.shape[0] / 2), int(img_warped.image.array.shape[1] / 2))
        point_image = Point2D(point_tuple)
        xy_transform = afw_geom.makeWcsPairTransform(img.wcs, img_warped.wcs)
        psf_w = WarpedPsf(img.getPsf(), xy_transform)
        psf_warped = psf_w.computeImage(point_image).convertF()

        # filter out images with no overlapping data at point
        if np.sum(psf_warped.array) == 0:
            print("PSF model unavailable, skipping")
            continue

        # reshape psf array to (channel, height, width)
        n1, n2 = psf_warped.array.shape
        psf_sc2 = psf_warped.array.reshape(1, n1, n2)

        obs = scarlet2.Observation(
            jnp.array(image_sc2).astype(float),
            weights=jnp.array(weight_sc2).astype(float),
            psf=scarlet2.ArrayPSF(jnp.array(psf_sc2).astype(float)),
            wcs=wcs_ref,
            channels=[(band, str(i))],
        )
        channels_sc2.append((band, str(i)))
        observations.append(obs)

        if plot_images:
            _, ax = plt.subplots(1, 1, figsize=(2, 2))
            plt.imshow(im_arr, origin="lower", cmap="gray", vmin=vmin, vmax=vmax)
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.text(
                0.1,
                0.9,
                r"$\Delta$t=" + str(round(src["midPointTai"] - first_time, 2)),
                color="white",
                fontsize=12,
            )
            plt.show()
            plt.close()
    return observations, channels_sc2
273 |
--------------------------------------------------------------------------------
/docs/howto/priors.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Use Neural Priors\n",
8 | "\n",
9 | "In [the sampling tutorial](sampling), we have demonstrated how to define parameters with priors.\n",
10 | "This guide shows you how to set up and use neural network priors.\n",
11 | "We make use of the related package [`galaxygrad`](https://github.com/SampsonML/galaxygrad), which can be pip-installed.\n",
12 | "\n",
13 | "This guide will follow the [Quick Start Guide](../0-quickstart), with changes in the initialization and parameter specification. We assume that you have a full installation of _scarlet2_ including `optax`, `numpyro`, `h5py` and `galaxygrad`.\n",
14 | "\n",
15 | "More details about the use of a score-based prior model for diffusion can be found in the paper \"Score-matching neural networks for improved multi-band source separation\", [Sampson et al., 2024, A&C, 49, 100875](http://ui.adsabs.harvard.edu/abs/2024A&C....4900875S)."
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "# Import Packages and setup\n",
25 | "import jax.numpy as jnp\n",
26 | "import matplotlib.pyplot as plt\n",
27 | "\n",
28 | "import scarlet2 as sc2\n",
29 | "\n",
30 | "sc2.set_validation(False)"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "## Create Observation\n",
38 | "\n",
39 | "Again we import the test data and create the observation:"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "from huggingface_hub import hf_hub_download\n",
49 | "\n",
50 | "filename = hf_hub_download(\n",
51 | " repo_id=\"astro-data-lab/scarlet-test-data\", filename=\"hsc_cosmos_35.npz\", repo_type=\"dataset\"\n",
52 | ")\n",
53 | "file = jnp.load(filename)\n",
54 | "data = jnp.asarray(file[\"images\"])\n",
55 | "channels = [str(f) for f in file[\"filters\"]]\n",
56 | "centers = jnp.array([(src[\"y\"], src[\"x\"]) for src in file[\"catalog\"]])\n",
57 | "weights = jnp.asarray(1 / file[\"variance\"])\n",
58 | "psf = jnp.asarray(file[\"psfs\"])\n",
59 | "\n",
60 | "# create the observation\n",
61 | "obs = sc2.Observation(\n",
62 | " data,\n",
63 | " weights,\n",
64 | " psf=sc2.ArrayPSF(psf),\n",
65 | " channels=channels,\n",
66 | ")\n",
67 | "model_frame = sc2.Frame.from_observations(obs)"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "## Initialize Sources"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "with sc2.Scene(model_frame) as scene:\n",
84 | " for i, center in enumerate(centers):\n",
85 | " if i == 0: # we know source 0 is a star\n",
86 | " spectrum = sc2.init.pixel_spectrum(obs, center, correct_psf=True)\n",
87 | " sc2.PointSource(center, spectrum)\n",
88 | " else:\n",
89 | " try:\n",
90 | " spectrum, morph = sc2.init.from_gaussian_moments(obs, center, min_corr=0.99)\n",
91 | " except ValueError:\n",
92 | " spectrum = sc2.init.pixel_spectrum(obs, center)\n",
93 | " morph = sc2.init.compact_morphology()\n",
94 | "\n",
95 | " sc2.Source(center, spectrum, morph)"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "## Load Neural Prior"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "# load in the model you wish to use\n",
112 | "from galaxygrad import get_prior\n",
113 | "from scarlet2.nn import ScorePrior\n",
114 | "\n",
115 | "# instantiate the prior class\n",
116 | "temp = 2e-2 # values in the range of [1e-3, 1e-1] produce good results\n",
117 | "prior32 = get_prior(\"hsc32\")\n",
118 | "prior64 = get_prior(\"hsc64\")\n",
119 | "prior32 = ScorePrior(prior32, prior32.shape(), t=temp)\n",
120 | "prior64 = ScorePrior(prior64, prior64.shape(), t=temp)"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
"The prior model itself is in the form of a score-based diffusion model, which matches the score function, i.e. the gradient of the log-likelihood of the training data with respect to the parameters. For an image-based parameterization, the free parameters are the pixels, which means the gradient has the same shape as the image. `galaxygrad` provides several pre-trained models, here we use a prior that was trained on deblended isolated sources in HSC data, with the shapes of 32x32 or 64x64, respectively. These sizes denote the maximum image size for which the prior is trained.\n",
128 | "\n",
129 | "We import {py:class}`~scarlet2.nn.ScorePrior` to use with our prior. It automatically zero-pads any smaller image array up to the specified size and provides a custom gradient path that calls the underlying score model during optimization or HMC sampling. The `temp` argument refers to a fixed temperature for the diffusion process. For speed, we run a single diffusion step with the given temperature."
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {
135 | "collapsed": false
136 | },
137 | "source": [
138 | "## Define Parameters with Prior\n",
139 | "\n",
"We use the same fitting routine as in the Quickstart guide, but replace `constraints.positive` with `prior=prior` in the Parameter containing the source morphologies. It is also useful to reduce the step size for the morphology updates because large jumps can lead to unstable prior gradients."
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": null,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "from functools import partial\n",
150 | "from numpyro.distributions import constraints\n",
151 | "\n",
152 | "spec_step = partial(sc2.relative_step, factor=0.05)\n",
153 | "morph_step = partial(sc2.relative_step, factor=1e-3)\n",
154 | "\n",
155 | "with sc2.Parameters(scene) as parameters:\n",
156 | " for i in range(len(scene.sources)):\n",
157 | " sc2.Parameter(\n",
158 | " scene.sources[i].spectrum,\n",
159 | " name=f\"spectrum:{i}\",\n",
160 | " constraint=constraints.positive,\n",
161 | " stepsize=spec_step,\n",
162 | " )\n",
163 | " if i == 0:\n",
164 | " sc2.Parameter(scene.sources[i].center, name=f\"center:{i}\", stepsize=0.1)\n",
165 | " else:\n",
"            # choose a prior of suitable size\n",
167 | " prior = prior32 if max(scene.sources[i].morphology.shape) <= 32 else prior64\n",
168 | " sc2.Parameter(\n",
169 | " scene.sources[i].morphology,\n",
170 | " name=f\"morph:{i}\",\n",
171 | " prior=prior, # attach the prior here\n",
172 | " stepsize=morph_step,\n",
173 | " )"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {},
179 | "source": [
180 | "Note that the use of a `prior` is incompatible with the use of a `constraint`."
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {
186 | "collapsed": false
187 | },
188 | "source": [
189 | "We again perform the fitting:"
190 | ]
191 | },
192 | {
193 | "cell_type": "code",
194 | "execution_count": null,
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "maxiter = 1000\n",
199 | "print(\"Initial likelihood:\", obs.log_likelihood(scene()))\n",
200 | "scene_ = scene.fit(obs, parameters, max_iter=maxiter, e_rel=1e-4, progress_bar=True)\n",
201 | "print(\"Optimized likelihood:\", obs.log_likelihood(scene_()))"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {
207 | "collapsed": false
208 | },
209 | "source": [
210 | "The fit reaches values quite comparable to the run with the positivity constraints in the quickstart guide."
211 | ]
212 | },
213 | {
214 | "cell_type": "markdown",
215 | "metadata": {},
216 | "source": [
217 | "## Check Results"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": null,
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "norm = sc2.plot.AsinhAutomaticNorm(obs)\n",
227 | "sc2.plot.scene(\n",
228 | " scene_,\n",
229 | " obs,\n",
230 | " norm=norm,\n",
231 | " show_model=True,\n",
232 | " show_rendered=True,\n",
233 | " show_observed=True,\n",
234 | " show_residual=True,\n",
235 | " add_boxes=True,\n",
236 | ")\n",
237 | "plt.show()"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "sc2.plot.sources(\n",
247 | " scene_,\n",
248 | " norm=norm,\n",
249 | " observation=obs,\n",
250 | " show_model=True,\n",
251 | " show_rendered=True,\n",
252 | " show_observed=True,\n",
253 | " show_spectrum=False,\n",
254 | " add_labels=False,\n",
255 | " add_boxes=True,\n",
256 | ")\n",
257 | "plt.show()"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {},
263 | "source": [
264 | "The results for most of the galaxies look very reasonable now, in particular for the fainter ones. They remain compact and not overly affected by noise. Source #1 has minor artifacts and picks up neighboring objects, indicating that this prior has not been trained (yet) on as many larger galaxies and is therefore still somewhat weak. An update will fix this soon."
265 | ]
266 | }
267 | ],
268 | "metadata": {
269 | "celltoolbar": "Raw Cell Format",
270 | "kernelspec": {
271 | "display_name": "scarlet2",
272 | "language": "python",
273 | "name": "python3"
274 | },
275 | "language_info": {
276 | "codemirror_mode": {
277 | "name": "ipython",
278 | "version": 3
279 | },
280 | "file_extension": ".py",
281 | "mimetype": "text/x-python",
282 | "name": "python",
283 | "nbconvert_exporter": "python",
284 | "pygments_lexer": "ipython3",
285 | "version": "3.12.11"
286 | }
287 | },
288 | "nbformat": 4,
289 | "nbformat_minor": 4
290 | }
291 |
--------------------------------------------------------------------------------
/tests/scarlet2/questionnaire/conftest.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from importlib.resources import files
4 | from pathlib import Path
5 |
6 | import yaml
7 | from ipywidgets import HTML, Button, HBox, VBox
8 | from pygments import highlight
9 | from pygments.formatters.html import HtmlFormatter
10 | from pygments.lexers.python import PythonLexer
11 | from pytest import fixture
12 | from scarlet2.questionnaire.models import QuestionAnswer, QuestionAnswers, Questionnaire
13 | from scarlet2.questionnaire.questionnaire import (
14 | OUTPUT_BOX_LAYOUT,
15 | OUTPUT_BOX_STYLE_FILE,
16 | QUESTION_BOX_LAYOUT,
17 | VIEWS_PACKAGE_PATH,
18 | )
19 |
20 |
@fixture
def data_dir():
    """Path to the data directory containing the example questionnaire YAML file."""
    here = Path(__file__).parent
    return here / "data"
25 |
26 |
@fixture
def example_questionnaire_dict(data_dir):
    """An example questionnaire dictionary"""
    with (data_dir / "example_questionnaire.yaml").open("r") as stream:
        return yaml.safe_load(stream)
33 |
34 |
@fixture
def example_questionnaire(example_questionnaire_dict):
    """An example `Questionnaire` model instance, validated from the example dictionary."""
    return Questionnaire.model_validate(example_questionnaire_dict)
39 |
40 |
@fixture
def example_questionnaire_with_switch_dict(data_dir):
    """An example questionnaire dictionary with a switch question"""
    with (data_dir / "example_questionnaire_switch.yaml").open("r") as stream:
        return yaml.safe_load(stream)
47 |
48 |
@fixture
def example_questionnaire_with_switch(example_questionnaire_with_switch_dict):
    """An example `Questionnaire` model instance containing a switch question."""
    return Questionnaire.model_validate(example_questionnaire_with_switch_dict)
53 |
54 |
@fixture
def questionnaire_with_followup_switch_example_dict(data_dir):
    """An example questionnaire dictionary with a follow-up switch question."""
    yaml_path = data_dir / "example_questionnaire_followup_switch.yaml"
    with yaml_path.open("r") as f:
        return yaml.safe_load(f)
61 |
62 |
@fixture
def example_questionnaire_with_followup_switch(questionnaire_with_followup_switch_example_dict):
    """An example `Questionnaire` model instance with a follow-up switch question."""
    return Questionnaire.model_validate(questionnaire_with_followup_switch_example_dict)
67 |
68 |
@fixture
def example_questionnaire_with_feedback(example_questionnaire):
    """An example Questionnaire model instance with a feedback URL"""
    # deep-copy so the base fixture's instance is not mutated
    with_feedback = example_questionnaire.model_copy(deep=True)
    with_feedback.feedback_url = "https://example.com/feedback"
    return with_feedback
75 |
76 |
@fixture
def example_question_answers(example_questionnaire):
    """An example list of question answers for the example questionnaire."""
    answer_inds = [0, 1, 0]  # indices of answers to select for each question
    # the first top-level question plus its first answer's two follow-ups
    questions = [
        example_questionnaire.questions[0],
        example_questionnaire.questions[0].answers[0].followups[0],
        example_questionnaire.questions[0].answers[0].followups[1],
    ]
    qas = [
        QuestionAnswer(question=q.question, answer=q.answers[i].answer, value=i)
        for q, i in zip(questions, answer_inds, strict=False)
    ]
    return QuestionAnswers(answers=qas)
91 |
92 |
93 | class Helpers:
94 | """Helper functions for testing the QuestionnaireWidget"""
95 |
96 | @staticmethod
97 | def get_answer_button(widget, answer_index):
98 | """Get an answer button from the question box children.
99 |
100 | Args:
101 | widget: The QuestionnaireWidget instance
102 | answer_index: The index of the answer button to get
103 |
104 | Returns:
105 | The answer button widget
106 | """
107 | css_offset = 1
108 | prev_questions_offset = len(widget.question_answers)
109 | question_label_offset = 1
110 |
111 | # Get the buttons container which is at index after the question label
112 | buttons_container_index = css_offset + prev_questions_offset + question_label_offset
113 | buttons_container = widget.question_box.children[buttons_container_index]
114 |
115 | # Return the specific button from the buttons container
116 | return buttons_container.children[answer_index]
117 |
118 | @staticmethod
119 | def get_prev_question_button(widget, question_index):
120 | """Get a previous question button from the question box children.
121 |
122 | Args:
123 | widget: The QuestionnaireWidget instance
124 | question_index: The index of the previous question to get
125 |
126 | Returns:
127 | The previous question button widget
128 | """
129 | css_offset = 1
130 |
131 | container = widget.question_box.children[css_offset + question_index]
132 | # The button is the first child of the container
133 | return container.children[0]
134 |
135 | @staticmethod
136 | def assert_widget_ui_matches_state(widget):
137 | """Assert that the widget's UI matches its internal state."""
138 | assert isinstance(widget.ui, HBox)
139 | assert widget.ui.children == (widget.question_box, widget.output_box)
140 |
141 | assert isinstance(widget.output_box, VBox)
142 | assert widget.output_box.children == (widget.output_container,)
143 | assert widget.output_box.layout == OUTPUT_BOX_LAYOUT
144 |
145 | assert isinstance(widget.output_container, HTML)
146 |
147 | # check output container contains css from output_box css file
148 | css_file = files(VIEWS_PACKAGE_PATH).joinpath(OUTPUT_BOX_STYLE_FILE)
149 | with css_file.open("r") as f:
150 | css_content = f.read()
151 |
152 | assert css_content in widget.output_container.value
153 |
154 | output_code = re.sub(r"\{\{\s*\w+\s*\}\}", "", widget.code_output)
155 |
156 | html = widget.output_container.value
157 |
158 | # regex to capture the JS argument
159 | match = re.search(r"navigator\.clipboard\.writeText\((.*?)\)", html)
160 | assert match, "No copy button found"
161 |
162 | actual_arg = match.group(1)
163 | expected_arg = json.dumps(output_code) # exactly how Jinja|tojson would encode it
164 |
165 | assert actual_arg.strip() == expected_arg
166 |
167 | formatter = HtmlFormatter(style="monokai", noclasses=True)
168 | highlighted_code = highlight(output_code, PythonLexer(), formatter)
169 |
170 | assert highlighted_code in widget.output_container.value
171 |
172 | assert isinstance(widget.question_box, VBox)
173 | assert widget.question_box.layout == QUESTION_BOX_LAYOUT
174 |
175 | css_snippet_count = 1
176 | save_button_container_count = 1
177 |
178 | # If there's a current question, we have:
179 | # - CSS snippet
180 | # - Previous question containers
181 | # - Current question label
182 | # - Buttons container
183 | # - Save button container
184 | if widget.current_question:
185 | expected_children_count = (
186 | css_snippet_count + len(widget.question_answers) + 1 + 1 + save_button_container_count
187 | )
188 | # If there's no current question, we have:
189 | # - CSS snippet
190 | # - Previous question containers
191 | # - Final message container
192 | # - Save button container
193 | else:
194 | expected_children_count = (
195 | css_snippet_count + len(widget.question_answers) + 1 + save_button_container_count
196 | )
197 |
198 | assert len(widget.question_box.children) == expected_children_count
199 |
200 | # Skip the CSS snippet
201 | css_offset = 1
202 |
203 | for i in range(len(widget.question_answers)):
204 | # Add the CSS offset to the index
205 | child_index = i + css_offset
206 | assert isinstance(widget.question_box.children[child_index], HBox)
207 | # The button is the first child of the container
208 | btn = widget.question_box.children[child_index].children[0]
209 | question = widget.question_answers[i][0]
210 | assert question.question in btn.description
211 | ans_index = widget.question_answers[i][1]
212 | assert question.answers[ans_index].answer in btn.description
213 |
214 | if widget.current_question is not None:
215 | # Add the CSS offset to the index
216 | qs_ind = len(widget.question_answers) + css_offset
217 |
218 | assert isinstance(widget.question_box.children[qs_ind], HTML)
219 | assert widget.current_question.question in widget.question_box.children[qs_ind].value
220 |
221 | # Get the buttons container which is at index after the question label
222 | buttons_container_index = qs_ind + 1
223 | buttons_container = widget.question_box.children[buttons_container_index]
224 | assert isinstance(buttons_container, VBox)
225 |
226 | # Check each button in the buttons container
227 | for btn, ans in zip(buttons_container.children, widget.current_question.answers, strict=False):
228 | assert isinstance(btn, Button)
229 | assert btn.description == ans.answer
230 | assert btn.tooltip == ans.tooltip
231 |
232 | else:
233 | # When there's no current question, we have:
234 | # - CSS snippet
235 | # - Previous question containers
236 | # - Final message container
237 | # - Save button container
238 |
239 | # Get the final message container which is at index before the save button container
240 | final_message_index = len(widget.question_box.children) - 2
241 | final_message_container = widget.question_box.children[final_message_index]
242 | assert isinstance(final_message_container, VBox)
243 |
244 | # The final message is in the HTML child of the container
245 | final_message_html = final_message_container.children[0]
246 | assert isinstance(final_message_html, HTML)
247 | final_message = final_message_html.value
248 | assert "You're done" in final_message
249 |
250 | # Check for feedback URL if present in the questionnaire
251 | if widget.feedback_url:
252 | assert widget.feedback_url in final_message
253 | assert "feedback form" in final_message
254 |
255 | # Check that the last child is the save button container
256 | save_button_container = widget.question_box.children[-1]
257 | assert isinstance(save_button_container, VBox)
258 | assert len(save_button_container.children) > 0
259 | save_button = save_button_container.children[-1]
260 | assert isinstance(save_button, Button)
261 | assert save_button.description == "Save Answers"
262 |
263 |
@fixture
def helpers():
    """Expose the :class:`Helpers` utility collection to tests as a fixture."""
    helper_instance = Helpers()
    return helper_instance
268 |
--------------------------------------------------------------------------------
/docs/2-questionnaire.md:
--------------------------------------------------------------------------------
1 | # Questionnaire
2 |
3 | ## User Guide - Interactive Project Setup
4 |
5 | The scarlet2 questionnaire is an interactive tool designed to help you quickly set up a new scarlet2 project
6 | by generating a customized code template that matches your use case.
7 |
8 | ### Running the Questionnaire
9 |
10 | To run the questionnaire, you need to import and call the `run_questionnaire` function from the `scarlet2.questionnaire` module:
11 |
12 | ```python
13 | from scarlet2.questionnaire import run_questionnaire
14 | run_questionnaire()
15 | ```
16 |
17 | 
18 |
19 | This will launch an interactive widget in your Jupyter notebook that guides you through a series of questions about your project and your data.
20 |
21 | The questionnaire presents one question at a time with multiple-choice answers. For each answer you select,
22 | the questionnaire will update the code template to match your choices and display some explanatory text to
23 | help you understand the code being generated.
24 |
25 | You can navigate through the questions using the answer buttons. If you want to change a previous answer,
26 | you can click the previous question in the list to go back to that question.
27 |
28 | ### Using the Generated Template
29 |
30 | As you progress through the questionnaire, a code template is dynamically generated based on your answers.
31 | The code will outline the steps needed to set up your scarlet2 project, but it will require some manual editing
32 | to fill in specific details about your data. For example, the code will include definitions of example values
33 | for variables like `channels` that you will need to replace with values that match your data, and will
34 | reference variables like `data` and `psf` that you will need to change to use your actual data.
35 |
36 | In the bottom right of the code output, you can click the "📋 Copy" button at any time to copy the generated
37 | code to your clipboard.
38 |
39 | > Note: The template is a starting point that you can modify to fit your specific data. It provides the structure for your project based on your use case.
40 |
41 | ### Saving Your Progress
42 |
43 | If you want to save your questionnaire answers and return to them later, you can use the `Save Answers`
44 | button in the bottom left of the questionnaire widget. This will save your answers to a yaml file that you can
45 | then load later and continue from.
46 |
47 | The yaml file will be saved to your current working directory with a filename like
48 | `scarlet2_questionnaire_timestamp.yaml`. To change where the file is saved, run the questionnaire with the
49 | `save_path` argument:
50 |
51 | ```python
52 | from scarlet2.questionnaire import run_questionnaire
53 | run_questionnaire(save_path="path/to/save/directory")
54 | ```
55 |
56 | To load a previously saved questionnaire, use the `answer_path` argument to specify the path to your saved
57 | yaml file:
58 |
59 | ```python
60 | from scarlet2.questionnaire import run_questionnaire
61 | run_questionnaire(answer_path="path/to/saved/answers.yaml")
62 | ```
63 |
64 | ### Feedback and Issues
65 |
66 | If you encounter any issues or have suggestions for improving the questionnaire, [please fill out our feedback form!](https://docs.google.com/forms/d/e/1FAIpQLScKHbiqxhizacgzRx3xHEdqqjgtZBsxjtQZFJlYBdLcbOnfBg/viewform)
67 |
68 | ## Developer Guide - Questionnaire Architecture
69 |
70 | The questionnaire module is designed to be extensible and maintainable. This section explains how the questionnaire works internally and how to modify or extend it.
71 |
72 | ### Module Structure
73 |
74 | The questionnaire module consists of several key components:
75 |
76 | 1. **questions.yaml**: Stores the actual questions, answers, and templates
77 | 2. **models.py**: Defines the data structures used by the questionnaire that map to the YAML file
78 | 3. **questionnaire.py**: Contains the main `QuestionnaireWidget` class that uses ipywidgets to render the UI and handle user interactions
79 |
80 | ### YAML Structure
81 |
82 | Setting up the questions in the questionnaire is done by modifying the `questions.yaml` file.
83 | The questions are defined in this YAML file with the following structure:
84 |
85 | ```yaml
86 | initial_template: "{{code}}" # The starting template with placeholders
87 | initial_commentary: "Welcome message" # (Optional) The initial commentary text before any questions are answered
88 | questions: # List of top-level questions
89 | - question: "Question text" # Each question object has a question text, answers, and optionally a variable
90 | variable: "variable_name" # Optional variable to store the answer to be referenced later
91 | answers:
92 | - answer: "Answer option 1" # Each answer has an answer text
93 | tooltip: "Helpful tooltip" # Optional tooltip for the answer that appears on hover of the button
94 | templates: # A list of code snippets to apply if selected
95 | - replacement: code # The placeholder to replace
96 | code: "# Your code here\n{{next_placeholder}}" # The replacement code
97 | commentary: "Explanation of this choice" # The commentary text to display when this answer is selected, can include markdown formatting
98 | followups: # Additional questions to ask immediately if this answer is selected. This list of followups matches the structure of top-level questions, and can include question objects or switch/case objects
99 | - question: "Follow-up question"
100 | answers: [...]
101 | - question: "..." # More questions
102 | answers: [...]
103 | # ...
104 | - switch: "variable_name" # Conditional branching based on a previous answer
105 | cases:
106 | - value: 0 # If the question with variable "variable_name" was answered with the first answer (index 0)
107 | questions: [...] # The questions to ask in this case. This list matches the structure of top-level questions, and can include question objects or switch/case objects
      - value: null # The default case if no other case matches or if the variable was not set (e.g. the question was skipped). If there is no default case, the switch is skipped
109 | questions: [...]
110 | ```
111 |
112 | The questionnaire starts with an `initial_template` that contains placeholders (e.g. `{{code}}`) that will be replaced
113 | as the user answers questions. Each question has a list of possible answers, and each answer can specify one or more
114 | code snippets to replace the specified placeholders in the template, as well as commentary text to display.
115 | The commentary can include markdown formatting which will be rendered in the commentary box.
116 | Answers can also specify follow-up questions that are asked immediately after the current question if that answer is selected.
117 |
118 | The flow of questions can also include conditional branching using `switch` objects that check the value of a
119 | previously answered question (by its variable name) and present different sets of questions based on the answer.
120 |
121 | > Note: In YAML, single and double quotes behave differently. Single quotes will treat backslashes as literal
122 | > characters, while double quotes will interpret backslashes as escape characters. For example, to include a
123 | > newline character in a string, you would use double quotes: `"Line 1\nLine 2"`. If you used single quotes,
124 | > it would be treated as the literal text `Line 1\nLine 2` without a newline.
125 |
126 | The yaml file is packaged with the module as it is built, and is loaded using the `importlib.resources` module
127 | that allows access to package data files even in a zipped package. With `pyproject.toml` and `setuptools-scm`,
128 | any package data files that are tracked by git are automatically included in the package, and so the
129 | `questions.yaml` file is included automatically when the package is built and deployed.
[See the setuptools documentation for more details.](https://setuptools.pypa.io/en/latest/userguide/datafiles.html)
131 |
132 | ### Data Models
133 |
134 | The questionnaire uses Pydantic models to define its data structures which are found in `models.py`.
135 | These models match the structure of the YAML file and are used to parse and validate the questionnaire data
136 | loaded from the YAML file.
137 |
138 | Pydantic allows defining dataclasses with type hints and validation, making it easier to work with structured
139 | data like the questionnaire configuration. When the YAML file is loaded, it is parsed into these models for
140 | use in the questionnaire logic. If the YAML structure does not match the expected models, Pydantic will raise
141 | validation errors when the questionnaire is initialized.
142 |
143 | ### QuestionnaireWidget Class
144 |
145 | The `QuestionnaireWidget` class is responsible for rendering the questionnaire UI and handling user interactions.
146 |
147 | It uses `ipywidgets` to render the questionnaire in the output of a jupyter cell. The class maintains the
148 | state of the questionnaire, including the current question, the user's answers, and the generated code template.
149 |
150 | The state consists of:
151 |
152 | - `self.code_output`: The current code template with placeholders. This is updated as the user answers questions. Regex is used to find and replace placeholders in the template when updating with answer templates.
153 | - `self.commentary`: The current commentary text that explains the generated code. This is set with each answer.
154 | - `self.questions_stack`: A stack of questions to ask. This allows for handling follow-up questions and branching logic. The next question is popped from the stack when needed.
155 | - `self.question_answers`: A list of tuples of (question, answer_index) representing the user's answers to each question. This is used to display the history of questions and answers in the UI.
156 | - `self.variables`: A dictionary mapping variable names to the index of the selected answer. This is used for switch/case branching.
157 |
158 | The UI consists of a question area with a list of the previous questions and answers and a set of buttons
159 | for the current question's answers, and an output area that displays the generated code template and
160 | commentary text.
161 |
162 | The question area uses ipywidgets components to display the questions and answer buttons, and handle the user
input. The output area uses an HTML widget to display the generated code with syntax highlighting, the copy
164 | button, and the commentary text. Since the output area is written in HTML, the code defining the output area
165 | is stored in separate HTML template files (`output_box.html.jinja` and `output_box.css`) for easier editing.
166 | These template files are also included in the package using `setuptools-scm` the same way as `questions.yaml`.
167 | The `jinja2` templating engine is used to render the HTML with the generated code and commentary. For syntax
168 | highlighting, the `pygments` library is used to convert the code into HTML with appropriate styling, and the
169 | commentary text is converted from markdown to HTML using the `markdown` library.
170 |
171 | Key methods include:
172 | - `_get_next_question()`: Determines the next question to display. A stack is used to manage the flow of questions, including handling follow-up questions and switch/case branching.
173 | - `_handle_answer()`: Processes user answers and updates the state.
174 | - `_update_template()`: Applies template changes based on answers.
175 | - `_render_output_box()`: Creates the UI for displaying the generated code. Uses Jinja2, pygments, and markdown for rendering.
176 |
--------------------------------------------------------------------------------
/src/scarlet2/source.py:
--------------------------------------------------------------------------------
1 | import operator
2 |
3 | import jax
4 | import jax.numpy as jnp
5 | from astropy.coordinates import SkyCoord
6 |
7 | from . import Scenery
8 | from .bbox import Box, overlap_slices
9 | from .module import Module
10 | from .morphology import Morphology
11 | from .spectrum import Spectrum
12 | from .validation_utils import (
13 | ValidationMethodCollector,
14 | )
15 |
16 |
class Component(Module):
    """Single component of a hyperspectral model

    The parameterization of the 3D model (channel, height, width) is defined by
    the outer product of `spectrum` and `morphology`. That means that there is no
    variation of the spectrum in spatial direction. The `center` coordinate is only
    needed to define the bounding box and place the component in the model frame.
    """

    center: jnp.ndarray
    """Center position, in pixel coordinates of the model frame"""
    spectrum: (jnp.array, Spectrum)
    """Spectrum model"""
    morphology: (jnp.array, Morphology)
    """Morphology model"""
    bbox: Box
    """Bounding box of the model, in pixel coordinates of the model frame"""

    def __init__(self, center, spectrum, morphology):
        """
        Parameters
        ----------
        center: array, :py:class:`astropy.coordinates.SkyCoord`
            Center position. If given as astropy sky coordinate, it will be
            transformed with the WCS of the model frame.
        spectrum: :py:class:`~scarlet2.Spectrum`
            The spectrum of the component.
        morphology: :py:class:`~scarlet2.Morphology`
            The morphology of the component.

        Raises
        ------
        AttributeError
            If `center` is a sky coordinate and no Scene context is active
            (re-raised after printing a usage hint).

        Examples
        --------
        To uniquely determine coordinates, the creation of components is restricted
        to a context defined by a :py:class:`~scarlet2.Scene`, which define the
        :py:class:`~scarlet2.Frame` of the model.

        >>> with Scene(model_frame) as scene:
        >>>     component = Component(center, spectrum, morphology)
        """
        self.spectrum = spectrum
        self.morphology = morphology

        # Sky coordinates require the model frame's WCS, which is only
        # reachable through the currently active Scene (Scenery.scene).
        if isinstance(center, SkyCoord):
            try:
                center = Scenery.scene.frame.get_pixel(center)
            except AttributeError:
                print("`center` defined in sky coordinates can only be created within the context of a Scene")
                print("Use 'with Scene(frame) as scene: (...)'")
                raise
        self.center = center

        # Combine the 1D spectrum box (channels) with the 2D morphology box
        # (height, width), centered on the integer pixel position of `center`,
        # into the 3D bounding box of this component in the model frame.
        box = Box(spectrum.shape)
        box2d = Box(morphology.shape)
        box2d.set_center(center.astype(int))
        self.bbox = box @ box2d

    def __call__(self):
        """What to run when Component is called"""
        # Boxed and centered model
        # delta_center: (y, x) offset of the requested center from the integer
        # bbox center — presumably used by Module morphologies for sub-pixel
        # shifts (confirm against the Morphology implementation).
        delta_center = (self.center[-2] - self.bbox.center[-2], self.center[-1] - self.bbox.center[-1])
        # spectrum/morphology may be plain arrays or callable Modules;
        # evaluate them only in the Module case
        spectrum = self.spectrum() if isinstance(self.spectrum, Module) else self.spectrum
        morph = (
            self.morphology(delta_center=delta_center)
            if isinstance(self.morphology, Module)
            else self.morphology
        )
        # Outer product: (C,) x (H, W) -> (C, H, W)
        return spectrum[:, None, None] * morph[None, :, :]
84 |
85 |
class DustComponent(Component):
    """Component with negative exponential model

    This component describes dust attenuation, :math:`\\exp(-\\tau)`,
    where :math:`\\tau` is the hyperspectral model defined by the base
    :py:class:`~scarlet2.Component`.
    """

    def __call__(self):
        """What to run when DustComponent is called"""
        # Evaluate the base component as the optical depth, then attenuate
        tau = super().__call__()
        return jnp.exp(-tau)
96 |
97 |
class Source(Component):
    """Source model

    The class is the basic parameterization for sources in :py:class:`~scarlet2.Scene`.
    """

    components: list
    """List of components in this source"""
    component_ops: list
    """List of operators to combine `components` for the final model"""

    def __init__(self, center, spectrum, morphology):
        """
        Parameters
        ----------
        center: array, :py:class:`astropy.coordinates.SkyCoord`
            Center position. If given as astropy sky coordinate, it will be
            transformed with the WCS of the model frame.
        spectrum: array, :py:class:`~scarlet2.Spectrum`
            The spectrum of the source.
        morphology: array, :py:class:`~scarlet2.Morphology`
            The morphology of the source.

        Raises
        ------
        AttributeError
            If no Scene context is active (re-raised after printing a usage hint).

        Examples
        --------
        A source declaration is restricted to a context of a :py:class:`~scarlet2.Scene`,
        which defines the :py:class:`~scarlet2.Frame` of the entire model.

        >>> with Scene(model_frame) as scene:
        >>>     source = Source(center, spectrum, morphology)

        A source can comprise one or multiple :py:class:`~scarlet2.Component`,
        which can be added by :py:func:`add_component` or operators `+=`
        (for an additive component) or `*=` (for a multiplicative component).

        >>> with Scene(model_frame) as scene:
        >>>     source = Source(center, spectrum, morphology)
        >>>     source *= DustComponent(center, dust_spectrum, dust_morphology)
        """
        # set the base component
        super().__init__(center, spectrum, morphology)
        # create the empty component list
        self.components = list()
        self.component_ops = list()

        # add this source to the active scene
        try:
            Scenery.scene.sources.append(self)
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise

    def add_component(self, component, op):
        """Add `component` to this source

        Parameters
        ----------
        component: :py:class:`~scarlet2.Component`
            The component to include in this source. It will be combined with the
            previous component according to the operator `op`.
        op: callable
            Operator to combine this `component` with those before it in the list :py:attr:`components`.
            Conventional operators from the :py:mod:`operator` package can be used.
            Signature: op(x,y) -> z, where all terms have the same shapes

        Returns
        -------
        Source
            This source, so that calls and operators can be chained.
        """
        assert isinstance(component, (Source, Component))

        # if component is a source, it's already registered in scene
        # remove it from scene to not call it twice
        if isinstance(component, Source):
            try:
                Scenery.scene.sources.remove(component)
            except AttributeError:
                print("Source can only be modified within the context of a Scene")
                print("Use 'with Scene(frame) as scene: Source(...)'")
                raise
            except ValueError:
                # component was not in the scene's source list: nothing to remove
                pass

        # adding a full source will maintain its ownership of components:
        # hierarchical definition of sources within sources
        self.components.append(component)
        self.component_ops.append(op)
        return self

    def __iadd__(self, component):
        """`source += component`: add `component` additively"""
        return self.add_component(component, operator.add)

    def __imul__(self, component):
        """`source *= component`: add `component` multiplicatively"""
        return self.add_component(component, operator.mul)

    def __call__(self):
        """What to run when Source is called"""
        # start from the base component's own model
        base = super()
        model = base.__call__()
        for component, op in zip(self.components, self.component_ops, strict=False):
            model_ = component()
            # cut out regions from model and model_
            bbox, bbox_ = overlap_slices(base.bbox, component.bbox, return_boxes=True)
            sub_model = jax.lax.dynamic_slice(model, bbox.start, bbox.shape)
            sub_model_ = jax.lax.dynamic_slice(model_, bbox_.start, bbox_.shape)
            # combine with operator
            sub_model = op(sub_model, sub_model_)
            # add model_ back in full model
            model = jax.lax.dynamic_update_slice(model, sub_model, bbox.start)
        return model
205 |
206 |
class PointSource(Source):
    """Point source model"""

    def __init__(self, center, spectrum):
        """Model for point sources

        Because the morphology is determined by the model PSF, it does not need to be provided.

        Parameters
        ----------
        center: array, :py:class:`astropy.coordinates.SkyCoord`
            Center position. If given as astropy sky coordinate, it will be
            transformed with the WCS of the model frame.
        spectrum: array, :py:class:`~scarlet2.Spectrum`
            The spectrum of the point source.

        Raises
        ------
        AttributeError
            If no Scene context is active, or if the model frame has no PSF.

        Examples
        --------
        A source declaration is restricted to a context of a :py:class:`~scarlet2.Scene`,
        which defines the :py:class:`~scarlet2.Frame` of the entire model.

        >>> with Scene(model_frame) as scene:
        >>>     point_source = PointSource(center, spectrum)
        """
        # The morphology comes from the model frame's PSF, so the frame must
        # be available through the active Scene context and must define a PSF.
        try:
            frame = Scenery.scene.frame
        except AttributeError:
            print("Source can only be created within the context of a Scene")
            print("Use 'with Scene(frame) as scene: Source(...)'")
            raise
        if frame.psf is None:
            # fixed typo in the error message: "create" -> "created"
            raise AttributeError("PointSource can only be created with a PSF in the model frame")
        morphology = frame.psf.morphology

        super().__init__(center, spectrum, morphology)
242 |
243 |
class SourceValidator(metaclass=ValidationMethodCollector):
    """Collection of validation checks for :py:class:`Source` objects.

    The ``ValidationMethodCollector`` metaclass gathers every validation
    method defined on this class into a single class-level list named
    ``validation_checks``, so callers can iterate over all checks without
    naming each one individually."""

    def __init__(self, source: Source):
        # The source instance that the validation checks will inspect
        self.source = source

    # def check_source_has_positive_contribution(self) -> ValidationResult:
    #     """Check that the source has a positive contribution i.e. that the result
    #     of evaluating self.source() does not contain negative values.
    #
    #     Returns
    #     -------
    #     ValidationResult
    #         A subclass of ValidationResult indicating the result of the check.
    #     """
    #     model = self.source()
    #     if jnp.any(model < 0):
    #         return ValidationError(
    #             "Source model has negative contributions.",
    #             check=self.__class__.__name__,
    #             context={"source": self.source},
    #         )
    #     else:
    #         return ValidationInfo(
    #             "Source model has positive contributions.",
    #             check=self.__class__.__name__,
    #         )
--------------------------------------------------------------------------------