├── benchmarks ├── __init__.py ├── README.md ├── bbox_benchmarks.py └── asv.conf.json ├── src └── scarlet2 │ ├── questionnaire │ ├── views │ │ ├── __init__.py │ │ ├── output_box.html.jinja │ │ ├── output_box.css │ │ └── question_box.css │ ├── __init__.py │ └── models.py │ ├── psf.py │ ├── __init__.py │ ├── io.py │ ├── validation_utils.py │ ├── validation.py │ ├── nn.py │ ├── spectrum.py │ ├── wavelets.py │ ├── morphology.py │ ├── bbox.py │ ├── lsst_utils.py │ └── source.py ├── docs ├── _static │ ├── icon.png │ ├── example_obs_validation.png │ ├── questionnaire_screenshot.png │ ├── logo_light.svg │ └── logo_dark.svg ├── requirements.txt ├── _templates │ ├── custom-base-template.rst │ ├── custom-class-template.rst │ └── custom-module-template.rst ├── 1-howto.md ├── api.rst ├── Makefile ├── make.bat ├── conf.py ├── index.md ├── howto │ ├── sampling.ipynb │ └── priors.ipynb └── 2-questionnaire.md ├── .git_archival.txt ├── .github ├── ISSUE_TEMPLATE │ ├── 0-general_issue.md │ ├── 1-bug_report.md │ └── 2-feature_request.md ├── dependabot.yml └── workflows │ ├── publish-to-pypi.yml │ ├── pre-commit-ci.yml │ ├── testing-and-coverage.yml │ ├── smoke-test.yml │ ├── publish-benchmarks-pr.yml │ ├── asv-main.yml │ ├── asv-nightly.yml │ └── asv-pr.yml ├── .readthedocs.yaml ├── tests └── scarlet2 │ ├── test_validation_utils.py │ ├── questionnaire │ ├── test_models.py │ ├── data │ │ ├── example_questionnaire.yaml │ │ ├── example_questionnaire_switch.yaml │ │ └── example_questionnaire_followup_switch.yaml │ └── conftest.py │ ├── test_scene_fit_checks.py │ ├── test_quickstart.py │ ├── test_save_output.py │ ├── test_moments.py │ ├── conftest.py │ ├── test_observation_checks.py │ └── test_renderer.py ├── .copier-answers.yml ├── .gitattributes ├── LICENSE ├── README.md ├── .setup_dev.sh ├── .gitignore ├── .initialize_new_project.sh ├── .pre-commit-config.yaml └── pyproject.toml /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/scarlet2/questionnaire/views/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/_static/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/icon.png -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-issues 2 | myst-nb 3 | sphinx-book-theme 4 | ipywidgets 5 | corner 6 | arviz 7 | galaxygrad >= 0.3.0 -------------------------------------------------------------------------------- /docs/_static/example_obs_validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/example_obs_validation.png -------------------------------------------------------------------------------- /docs/_static/questionnaire_screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pmelchior/scarlet2/HEAD/docs/_static/questionnaire_screenshot.png -------------------------------------------------------------------------------- /docs/_templates/custom-base-template.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: f4b070e05b951dfc0a7615f40388b2f95ff151ba 2 | node-date: 2025-12-15T22:12:03+01:00 3 | describe-name: v0.3.0-58-gf4b070e0 4 | ref-names: HEAD -> main -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/0-general_issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: General issue 3 | about: Quickly create a general issue 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- -------------------------------------------------------------------------------- /src/scarlet2/questionnaire/__init__.py: -------------------------------------------------------------------------------- 1 | from .questionnaire import QuestionnaireWidget, run_questionnaire 2 | 3 | __all__ = ["QuestionnaireWidget", "run_questionnaire"] 4 | -------------------------------------------------------------------------------- /docs/1-howto.md: -------------------------------------------------------------------------------- 1 | # How To ... 2 | 3 | ```{toctree} 4 | :maxdepth: 2 5 | 6 | howto/sampling 7 | howto/priors 8 | howto/correlated 9 | howto/multiresolution 10 | howto/timedomain 11 | ``` -------------------------------------------------------------------------------- /docs/_templates/custom-class-template.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. 
autoclass:: {{ objname }} 6 | :members: 7 | :show-inheritance: 8 | :inherited-members: 9 | :special-members: __call__ 10 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "monthly" 11 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.10" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 12 | extra_requirements: 13 | - docs 14 | - requirements: docs/requirements.txt 15 | 16 | sphinx: 17 | configuration: docs/conf.py -------------------------------------------------------------------------------- /src/scarlet2/questionnaire/views/output_box.html.jinja: -------------------------------------------------------------------------------- 1 |
2 | {{highlighted_code | safe}} 3 |
4 | 8 |
9 |
10 |
11 | {{commentary_html | safe}} 12 |
-------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ================= 3 | 4 | .. autosummary:: 5 | :toctree: _autosummary 6 | :template: custom-module-template.rst 7 | 8 | scarlet2 9 | 10 | .. autosummary:: 11 | :caption: Modules 12 | :toctree: _autosummary 13 | :template: custom-module-template.rst 14 | :recursive: 15 | 16 | scarlet2.fft 17 | scarlet2.init 18 | scarlet2.interpolation 19 | scarlet2.io 20 | scarlet2.nn 21 | scarlet2.measure 22 | scarlet2.plot 23 | scarlet2.renderer 24 | scarlet2.wavelets 25 | 26 | -------------------------------------------------------------------------------- /tests/scarlet2/test_validation_utils.py: -------------------------------------------------------------------------------- 1 | from scarlet2.validation_utils import set_validation 2 | 3 | 4 | def test_set_validation(): 5 | """Test that `set_validation` toggles VALIDATION_SWITCH. The name is re-imported after each 6 | call because `from ... import` binds the value current at import time, not later updates.""" 7 | 8 | set_validation(True) 9 | from scarlet2.validation_utils import VALIDATION_SWITCH 10 | 11 | assert VALIDATION_SWITCH is True 12 | 13 | set_validation(False)  # NOTE(review): leaves validation disabled for later tests; confirm vs. default 14 | from scarlet2.validation_utils import VALIDATION_SWITCH 15 | 16 | assert VALIDATION_SWITCH is False 17 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | This directory contains files that will be run via continuous testing either 4 | nightly or after committing code to a pull request. 5 | 6 | The runtime and/or memory usage of the functions defined in these files will be 7 | tracked and reported to give you a sense of the overall performance of your code.
8 | 9 | You are encouraged to add, update, or remove benchmark functions to suit the needs 10 | of your project. 11 | 12 | For more information, see the documentation here: https://lincc-ppt.readthedocs.io/en/latest/practices/ci_benchmarking.html -------------------------------------------------------------------------------- /.copier-answers.yml: -------------------------------------------------------------------------------- 1 | # Changes here will be overwritten by Copier 2 | _commit: v2.0.7 3 | _src_path: gh:lincc-frameworks/python-project-template 4 | author_email: peter.m.melchior@gmail.com 5 | author_name: Peter Melchior 6 | create_example_module: false 7 | custom_install: true 8 | enforce_style: 9 | - ruff_lint 10 | - ruff_format 11 | failure_notification: [] 12 | include_benchmarks: true 13 | include_docs: false 14 | mypy_type_checking: none 15 | package_name: scarlet2 16 | project_license: MIT 17 | project_name: scarlet2 18 | project_organization: pmelchior 19 | python_versions: 20 | - '3.10' 21 | - '3.11' 22 | - '3.12' 23 | - '3.13' 24 | test_lowest_version: none 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /src/scarlet2/questionnaire/views/output_box.css: -------------------------------------------------------------------------------- 1 | .code-output { 2 | background-color: #272822; 3 | color: #f1f1f1; 4 | padding: 10px; 5 | border-radius: 10px; 6 | border: 1px solid #444; 7 | font-family: monospace; 8 | overflow-x: auto; 9 | } 10 | 11 | .copy_button { 12 | font-size: 12px; 13 | padding: 4px 8px; 14 | background-color: #272822; 15 | color: white; 16 | border: none; 17 | border-radius: 4px; 18 | cursor: pointer; 19 | } 20 | 21 | .commentary-box { 22 | background-color: #f6f8fa; 23 | color: #333; 24 | padding: 8px 12px; 25 | border-left: 4px solid #0366d6; 26 | border-radius: 6px; 27 | font-size: 14px; 28 | font-family: sans-serif; 29 | } 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Tell us about a problem to fix 4 | title: 'Short description' 5 | labels: 'bug' 6 | assignees: '' 7 | 8 | --- 9 | **Bug report** 10 | 11 | 12 | **Before submitting** 13 | Please check the following: 14 | 15 | - [ ] I have described the situation in which the bug arose, including what code was executed, information about my environment, and any applicable data others will need to reproduce the problem. 16 | - [ ] I have included available evidence of the unexpected behavior (including error messages, screenshots, and/or plots) as well as a description of what I expected instead. 17 | - [ ] If I have a solution in mind, I have provided an explanation and/or pseudocode and/or task list. 
18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: 'Short description' 5 | labels: 'enhancement' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Feature request** 11 | 12 | 13 | **Before submitting** 14 | Please check the following: 15 | 16 | - [ ] I have described the purpose of the suggested change, specifying what I need the enhancement to accomplish, i.e. what problem it solves. 17 | - [ ] I have included any relevant links, screenshots, environment information, and data relevant to implementing the requested feature, as well as pseudocode for how I want to access the new functionality. 18 | - [ ] If I have ideas for how the new feature could be implemented, I have provided explanations and/or pseudocode and/or task lists for the steps. 19 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # For explanation of this file and uses see 2 | # https://git-scm.com/docs/gitattributes 3 | # https://developer.lsst.io/git/git-lfs.html#using-git-lfs-enabled-repositories 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/git-lfs.html 5 | # 6 | # Used by https://github.com/lsst/afwdata.git 7 | # *.boost filter=lfs diff=lfs merge=lfs -text 8 | # *.dat filter=lfs diff=lfs merge=lfs -text 9 | # *.fits filter=lfs diff=lfs merge=lfs -text 10 | # *.gz filter=lfs diff=lfs merge=lfs -text 11 | # 12 | # apache parquet files 13 | # *.parq filter=lfs diff=lfs merge=lfs -text 14 | # 15 | # sqlite files 16 | # *.sqlite3 filter=lfs diff=lfs merge=lfs -text 17 | # 18 | # gzip files 19 | # *.gz filter=lfs diff=lfs merge=lfs -text 20 | # 21 | # png image files 22 | # *.png filter=lfs diff=lfs merge=lfs -text 23 
| 24 | .git_archival.txt export-subst -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /benchmarks/bbox_benchmarks.py: -------------------------------------------------------------------------------- 1 | from scarlet2.bbox import Box 2 | 3 | 4 | def time_bbox_creation(): 5 | """Basic timing benchmark for Box creation""" 6 | _ = Box((15, 15)) 7 | 8 | 9 | def mem_bbox_creation(): 10 | """Basic memory benchmark for Box creation""" 11 | _ = Box((15, 15)) 12 | 13 | 14 | class BoxSuite: 15 | """Suite of benchmarks for methods of the Box class""" 16 | 17 | params = [2, 16, 256, 2048] 18 | 19 | def setup(self, edge_length): 20 | """Create a Box, for each of the different edge_lengths defined in the 21 | `params` list.""" 22 | self.bb = Box((edge_length, edge_length)) 23 | 24 | def time_bbox_contains(self, edge_length): 25 | """Timing benchmark for `contains` method. 
Note that `edge_length` is 26 | unused by this method.""" 27 | this_point = (6, 7) 28 | self.bb.contains(this_point) 29 | 30 | def mem_bbox_contains(self, edge_length): 31 | """Memory benchmark for `contains` method. Note that `edge_length` is 32 | unused by this method.""" 33 | this_point = (6, 7) 34 | self.bb.contains(this_point) 35 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow will upload a Python Package using Twine when a release is created 3 | # For more information see: https://github.com/pypa/gh-action-pypi-publish#trusted-publishing 4 | 5 | # This workflow uses actions that are not certified by GitHub. 6 | # They are provided by a third-party and are governed by 7 | # separate terms of service, privacy policy, and support 8 | # documentation. 9 | 10 | name: Upload Python Package 11 | 12 | on: 13 | release: 14 | types: [published] 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | permissions: 24 | id-token: write 25 | steps: 26 | - uses: actions/checkout@v6 27 | - name: Set up Python 28 | uses: actions/setup-python@v6 29 | with: 30 | python-version: '3.11' 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install build 35 | - name: Build package 36 | run: python -m build 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@release/v1 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-2025 Peter Melchior 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without 
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # _scarlet2_ 2 | 3 | _scarlet2_ is an open-source python library for modeling astronomical sources from multi-band, multi-epoch, and 4 | multi-instrument data. It provides non-parametric and parametric models, can handle source overlap (aka blending), and 5 | can integrate neural network priors. It's designed to be modular, flexible, and powerful. 6 | 7 | _scarlet2_ is implemented in [jax](http://jax.readthedocs.io/), layered on top of 8 | the [equinox](https://docs.kidger.site/equinox/) 9 | library. It can be deployed to GPUs and TPUs and supports optimization and sampling approaches. 10 | 11 | ## Installation 12 | 13 | For performance reasons, you should first install `jax` with the suitable `jaxlib` for your platform. After that 14 | 15 | ``` 16 | pip install scarlet2 17 | ``` 18 | 19 | should do. 
If you want the latest development version, use 20 | 21 | ``` 22 | pip install git+https://github.com/pmelchior/scarlet2.git 23 | ``` 24 | 25 | This will allow you to evaluate source models and compute likelihoods of observed data, so you can run your own 26 | optimizer/sampler. If you want a fully fledged library out of the box, you need to install `optax`, `numpyro`, and 27 | `h5py` as well. -------------------------------------------------------------------------------- /.github/workflows/pre-commit-ci.yml: -------------------------------------------------------------------------------- 1 | 2 | # This workflow runs pre-commit hooks on pushes and pull requests to main 3 | # to enforce coding style. To ensure correct configuration, please refer to: 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_precommit.html 5 | name: Run pre-commit hooks 6 | 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | jobs: 14 | pre-commit-ci: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v6 18 | with: 19 | fetch-depth: 0 20 | - name: Set up Python 21 | uses: actions/setup-python@v6 22 | with: 23 | python-version: '3.11' 24 | - name: Install dependencies 25 | run: | 26 | sudo apt-get update 27 | python -m pip install --upgrade pip 28 | pip install .[dev] 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | - uses: pre-commit/action@v3.0.1 31 | with: 32 | extra_args: --all-files --verbose 33 | env: 34 | SKIP: "check-lincc-frameworks-template-version,no-commit-to-branch,check-added-large-files,validate-pyproject,sphinx-build,pytest-check" 35 | - uses: pre-commit-ci/lite-action@v1.1.0 36 | if: failure() && github.event_name == 'pull_request' && github.event.pull_request.draft == false -------------------------------------------------------------------------------- /.github/workflows/testing-and-coverage.yml: 
-------------------------------------------------------------------------------- 1 | 2 | # This workflow will install Python dependencies, run tests and report code coverage with a variety of Python versions 3 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 4 | 5 | name: Unit test and code coverage 6 | 7 | on: 8 | push: 9 | branches: [ main ] 10 | pull_request: 11 | branches: [ main ] 12 | 13 | jobs: 14 | build: 15 | 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: ['3.10', '3.11', '3.12', '3.13'] 20 | 21 | steps: 22 | - uses: actions/checkout@v6 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v6 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | sudo apt-get update 30 | python -m pip install --upgrade pip 31 | pip install -e .[dev] 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | - name: Run unit tests with pytest 34 | run: | 35 | python -m pytest --cov=scarlet2 --cov-report=xml 36 | - name: Upload coverage report to codecov 37 | uses: codecov/codecov-action@v5 38 | with: 39 | token: ${{ secrets.CODECOV_TOKEN }} 40 | -------------------------------------------------------------------------------- /.github/workflows/smoke-test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run daily at 06:45. 2 | # It will install Python dependencies and run tests with a variety of Python versions. 
3 | # See documentation for help debugging smoke test issues: 4 | # https://lincc-ppt.readthedocs.io/en/latest/practices/ci_testing.html#version-culprit 5 | 6 | name: Unit test smoke test 7 | 8 | on: 9 | 10 | # Runs this workflow automatically 11 | schedule: 12 | - cron: 45 6 * * * 13 | 14 | # Allows you to run this workflow manually from the Actions tab 15 | workflow_dispatch: 16 | 17 | jobs: 18 | build: 19 | 20 | runs-on: ubuntu-latest 21 | strategy: 22 | matrix: 23 | python-version: ['3.10', '3.11', '3.12', '3.13'] 24 | 25 | steps: 26 | - uses: actions/checkout@v6 27 | - name: Set up Python ${{ matrix.python-version }} 28 | uses: actions/setup-python@v6 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | - name: Install dependencies 32 | run: | 33 | sudo apt-get update 34 | python -m pip install --upgrade pip 35 | pip install -e .[dev] 36 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 37 | - name: List dependencies 38 | run: | 39 | pip list 40 | - name: Run unit tests with pytest 41 | run: | 42 | python -m pytest -------------------------------------------------------------------------------- /tests/scarlet2/questionnaire/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from scarlet2.questionnaire.models import Questionnaire 3 | 4 | 5 | def test_validate_model(example_questionnaire_dict): 6 | """Test that the Questionnaire model validates correctly.""" 7 | questionnaire = Questionnaire.model_validate(example_questionnaire_dict) 8 | assert questionnaire.initial_template == "{{code}}" 9 | assert questionnaire.initial_commentary == "This is an example commentary." 10 | assert len(questionnaire.questions) == 2 11 | question = questionnaire.questions[0] 12 | assert question.question == "Example question?" 
13 | assert len(question.answers) == 2 14 | answer = question.answers[0] 15 | assert answer.answer == "Example answer" 16 | assert len(answer.templates) == 1 17 | assert answer.templates[0].replacement == "code" 18 | assert len(answer.followups) == 2 19 | followup = answer.followups[0] 20 | assert followup.question == "Follow-up question?" 21 | 22 | 23 | def test_model_fails(example_questionnaire_dict): 24 | """Test that the Questionnaire model raises errors for invalid data.""" 25 | invalid_dict = example_questionnaire_dict.copy() 26 | invalid_dict["initial_template"] = 123 27 | 28 | with pytest.raises(ValueError): 29 | Questionnaire.model_validate(invalid_dict) 30 | 31 | invalid_dict = example_questionnaire_dict.copy() 32 | invalid_dict["questions"][0]["question"] = 123 33 | 34 | with pytest.raises(ValueError): 35 | Questionnaire.model_validate(invalid_dict) 36 | -------------------------------------------------------------------------------- /tests/scarlet2/questionnaire/data/example_questionnaire.yaml: -------------------------------------------------------------------------------- 1 | initial_commentary: This is an example commentary. 2 | initial_template: '{{code}}' 3 | questions: 4 | - question: Example question? 5 | answers: 6 | - answer: Example answer 7 | commentary: This is some commentary. 8 | templates: 9 | - code: example_code {{follow}} 10 | replacement: code 11 | tooltip: This is an example tooltip. 12 | followups: 13 | - answers: 14 | - answer: Follow-up answer 15 | commentary: '' 16 | followups: [] 17 | templates: 18 | - code: 'followup_code {{code}}' 19 | replacement: follow 20 | tooltip: This is a follow-up tooltip. 21 | - answer: Second follow-up answer 22 | templates: 23 | - code: 'second_followup_code {{code}}' 24 | replacement: follow 25 | tooltip: This is a second follow-up tooltip. 26 | question: Follow-up question? 27 | - answers: 28 | - answer: Another follow-up answer 29 | templates: [] 30 | question: Another follow-up question? 
31 | - answer: Another answer 32 | commentary: Some other commentary. 33 | templates: 34 | - code: 'another_code {{code}}' 35 | replacement: code 36 | tooltip: This is another tooltip. 37 | followups: [] 38 | - question: Second question? 39 | answers: 40 | - answer: Second answer 41 | commentary: Next commentary. 42 | templates: 43 | - code: second_code 44 | replacement: code 45 | followups: [] 46 | tooltip: This is a second tooltip. 47 | -------------------------------------------------------------------------------- /tests/scarlet2/questionnaire/data/example_questionnaire_switch.yaml: -------------------------------------------------------------------------------- 1 | initial_commentary: This is an example commentary. 2 | initial_template: '{{code}}' 3 | questions: 4 | - question: Example question? 5 | variable: example_var 6 | answers: 7 | - answer: First answer 8 | commentary: This is some commentary. 9 | templates: 10 | - code: 'example_code {{code}}' 11 | replacement: code 12 | tooltip: This is an example tooltip. 13 | - answer: Second answer 14 | commentary: Some other commentary. 15 | followups: [] 16 | templates: 17 | - code: 'another_code {{code}}' 18 | replacement: code 19 | tooltip: This is another tooltip. 20 | - answer: Third answer 21 | commentary: Some other commentary. 22 | followups: [] 23 | templates: 24 | - code: 'third_code {{code}}' 25 | replacement: code 26 | tooltip: This is another tooltip. 27 | - switch: example_var 28 | cases: 29 | - value: 0 30 | questions: 31 | - question: Second Question 1? 32 | answers: 33 | - answer: Second answer 1 34 | templates: 35 | - code: 'followup_code_1 {{code}}' 36 | replacement: code 37 | - value: 1 38 | questions: 39 | - question: Second Question 2? 40 | answers: 41 | - answer: Second answer 2 42 | templates: 43 | - code: 'followup_code_2 {{code}}' 44 | replacement: code 45 | - value: null 46 | questions: 47 | - question: Default Question? 
48 | answers: 49 | - answer: Default answer 50 | templates: 51 | - code: 'default_code {{code}}' 52 | replacement: code 53 | 54 | 55 | -------------------------------------------------------------------------------- /docs/_static/logo_light.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 12 | 14 | 17 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /docs/_templates/custom-module-template.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block classes %} 6 | {% if classes %} 7 | .. rubric:: {{ _('Classes') }} 8 | 9 | .. autosummary:: 10 | :toctree: 11 | :template: custom-class-template.rst 12 | {% for item in classes %} 13 | {{ item }} 14 | {%- endfor %} 15 | {% endif %} 16 | {% endblock %} 17 | 18 | {% block functions %} 19 | {% if functions %} 20 | .. rubric:: {{ _('Functions') }} 21 | 22 | .. autosummary:: 23 | :toctree: 24 | :template: custom-base-template.rst 25 | {% for item in functions %} 26 | {{ item }} 27 | {%- endfor %} 28 | {% endif %} 29 | {% endblock %} 30 | 31 | {% block attributes %} 32 | {% if attributes %} 33 | .. rubric:: Module Attributes 34 | 35 | .. autosummary:: 36 | :toctree: 37 | :template: custom-base-template.rst 38 | {% for item in attributes %} 39 | {{ item }} 40 | {%- endfor %} 41 | {% endif %} 42 | {% endblock %} 43 | 44 | {% block exceptions %} 45 | {% if exceptions %} 46 | .. rubric:: {{ _('Exceptions') }} 47 | 48 | .. autosummary:: 49 | :toctree: 50 | :template: custom-base-template.rst 51 | {% for item in exceptions %} 52 | {{ item }} 53 | {%- endfor %} 54 | {% endif %} 55 | {% endblock %} 56 | 57 | {% block modules %} 58 | {% if modules %} 59 | .. rubric:: Modules 60 | 61 | .. 
autosummary:: 62 | :toctree: 63 | :template: custom-module-template.rst 64 | :recursive: 65 | {% for item in modules %} 66 | {{ item }} 67 | {%- endfor %} 68 | {% endif %} 69 | {% endblock %} 70 | -------------------------------------------------------------------------------- /docs/_static/logo_dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 12 | 14 | 17 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /src/scarlet2/psf.py: -------------------------------------------------------------------------------- 1 | """PSF-related classes""" 2 | 3 | import jax.numpy as jnp 4 | 5 | from .module import Module 6 | from .morphology import GaussianMorphology 7 | 8 | 9 | class PSF(Module): 10 | """PSF base class; subclasses provide the `morphology` attribute that `shape` reads""" 11 | 12 | @property 13 | def shape(self): 14 | """Shape of the PSF model""" 15 | return self.morphology.shape 16 | 17 | 18 | class ArrayPSF(PSF): 19 | """PSF defined by an image array 20 | 21 | Warnings 22 | -------- 23 | The number of pixels in `morphology` should be odd, and the centroid of the 24 | PSF image should be in the central pixel. If that is not the case, one creates 25 | an effective shift by the PSF, which is not captured by the coordinate 26 | convention of the frame, e.g. its :py:attr:`~scarlet2.Frame.wcs`. 27 | 28 | See :issue:`96` for more details. 29 | """ 30 | 31 | morphology: jnp.ndarray 32 | """The PSF morphology image.
Can be 2D (height, width) or 3D (channel, height, width)""" 33 | 34 | def __call__(self): 35 | """Evaluate PSF 36 | 37 | Returns 38 | ------- 39 | array 40 | Image with the shape of `morphology`, normalized to total flux=1 over the last two axes 41 | """ 42 | return self.morphology / self.morphology.sum((-2, -1), keepdims=True) 43 | 44 | 45 | class GaussianPSF(PSF): 46 | """Gaussian-shaped PSF""" 47 | 48 | morphology: GaussianMorphology 49 | """Morphology model""" 50 | 51 | def __init__(self, sigma): 52 | """Initialize Gaussian PSF 53 | 54 | Parameters 55 | ---------- 56 | sigma: float 57 | Standard deviation of Gaussian 58 | """ 59 | self.morphology = GaussianMorphology(sigma) 60 | 61 | def __call__(self): 62 | """Evaluate PSF, normalized to total flux=1 over the last two axes""" 63 | morph = self.morphology() 64 | morph /= morph.sum((-2, -1), keepdims=True) 65 | return morph 66 | -------------------------------------------------------------------------------- /tests/scarlet2/questionnaire/data/example_questionnaire_followup_switch.yaml: -------------------------------------------------------------------------------- 1 | initial_commentary: This is an example commentary. 2 | initial_template: '{{code}}' 3 | questions: 4 | - question: Example question? 5 | answers: 6 | - answer: First answer 7 | commentary: This is some commentary. 8 | templates: 9 | - code: 'example_code {{follow}}' 10 | replacement: code 11 | tooltip: This is an example tooltip. 12 | followups: 13 | - question: Follow-up question? 14 | variable: example_var 15 | answers: 16 | - answer: Follow-up answer 17 | commentary: '' 18 | followups: [] 19 | templates: 20 | - code: 'followup_code {{code}}' 21 | replacement: follow 22 | tooltip: This is a follow-up tooltip. 23 | - answer: Second follow-up answer 24 | templates: 25 | - code: 'second_followup_code {{code}}' 26 | replacement: follow 27 | tooltip: This is a second follow-up tooltip. 28 | - answer: Second answer 29 | commentary: Some other commentary.
30 | followups: [] 31 | templates: 32 | - code: 'another_code {{code}}' 33 | replacement: code 34 | tooltip: This is another tooltip. 35 | - switch: example_var 36 | cases: 37 | - value: 0 38 | questions: 39 | - question: Second Question 1? 40 | answers: 41 | - answer: Second answer 1 42 | templates: 43 | - code: 'followup_code_1 {{code}}' 44 | replacement: code 45 | - value: 1 46 | questions: 47 | - answers: 48 | - answer: Second answer 2 49 | templates: 50 | - code: 'followup_code_2 {{code}}' 51 | replacement: code 52 | question: Second Question 2? 53 | - value: null 54 | questions: 55 | - question: Default Question? 56 | answers: 57 | - answer: Default answer 58 | templates: 59 | - code: 'default_code {{code}}' 60 | replacement: code 61 | -------------------------------------------------------------------------------- /src/scarlet2/questionnaire/models.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Union 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class Template(BaseModel): 8 | """Represents a template for code replacement in the questionnaire.""" 9 | 10 | replacement: str 11 | code: str 12 | 13 | 14 | class Answer(BaseModel): 15 | """Represents an answer to a question in the questionnaire.""" 16 | 17 | answer: str 18 | tooltip: str = "" 19 | templates: list[Template] = Field(default_factory=list) 20 | followups: list[Union["Question", "Switch"]] = Field(default_factory=list) 21 | commentary: str = "" 22 | 23 | 24 | class Question(BaseModel): 25 | """Represents a question in the questionnaire.""" 26 | 27 | question: str 28 | variable: str | None = None 29 | answers: list[Answer] 30 | 31 | 32 | class Case(BaseModel): 33 | """Represents a case in a switch statement within the questionnaire.""" 34 | 35 | value: int | None = None 36 | questions: list[Union[Question, "Switch"]] 37 | 38 | 39 | class Switch(BaseModel): 40 | """Represents a switch statement in 
the questionnaire.""" 41 | 42 | switch: str 43 | cases: list[Case] 44 | 45 | 46 | # Rebuild models to support self-referencing types and forward references 47 | Question.model_rebuild() 48 | Answer.model_rebuild() 49 | Case.model_rebuild() 50 | Switch.model_rebuild() 51 | 52 | 53 | class QuestionAnswer(BaseModel): 54 | """Represents a user's answer to a question.""" 55 | 56 | question: str 57 | answer: str 58 | value: int 59 | 60 | 61 | class QuestionAnswers(BaseModel): 62 | """Represents a collection of user answers to questions.""" 63 | 64 | answers: list[QuestionAnswer] = Field(default_factory=list) 65 | timestamp: datetime = Field(default_factory=datetime.now) 66 | 67 | 68 | class Questionnaire(BaseModel): 69 | """Represents a questionnaire with an initial template and a list of questions.""" 70 | 71 | initial_template: str 72 | initial_commentary: str = "" 73 | feedback_url: str | None = None 74 | questions: list[Question | Switch] 75 | -------------------------------------------------------------------------------- /tests/scarlet2/test_scene_fit_checks.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | import pytest 4 | from scarlet2.scene import FitValidator 5 | from scarlet2.validation_utils import ( 6 | ValidationError, 7 | ValidationInfo, 8 | ValidationWarning, 9 | set_validation, 10 | ) 11 | 12 | 13 | @pytest.fixture(autouse=True) 14 | def setup_validation(): 15 | """Automatically disable validation for all tests. 
This permits the creation 16 | of intentionally invalid Observation objects.""" 17 | set_validation(False) 18 | 19 | 20 | @pytest.mark.parametrize( 21 | "mocked_chi_value,expected", 22 | [ 23 | (0.1, ValidationInfo), 24 | (2.0, ValidationWarning), 25 | (10.0, ValidationError), 26 | ], 27 | ) 28 | def test_check_goodness_of_fit(scene, parameters, good_obs, mocked_chi_value, expected): 29 | """Mocked goodness_of_fit return to produces the expected ValidationResult type""" 30 | 31 | scene_ = scene.fit(good_obs, parameters, max_iter=1, e_rel=1e-4, progress_bar=True) 32 | checker = FitValidator(scene_, good_obs) 33 | 34 | with patch.object(type(good_obs), "goodness_of_fit", return_value=mocked_chi_value) as _: 35 | results = checker.check_goodness_of_fit() 36 | 37 | assert isinstance(results, expected) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "mocked_chi_value,expected", 42 | [ 43 | (0.1, ValidationInfo), 44 | (2.0, ValidationWarning), 45 | (10.0, ValidationError), 46 | ], 47 | ) 48 | def test_check_chi_squared_in_box_and_border(scene, parameters, good_obs, mocked_chi_value, expected): 49 | """Mocked chi-squared evaluation in box and border.""" 50 | 51 | scene_ = scene.fit(good_obs, parameters, max_iter=1, e_rel=1e-4, progress_bar=True) 52 | checker = FitValidator(scene_, good_obs) 53 | mock_return = {0: {"in": mocked_chi_value, "out": mocked_chi_value}} 54 | 55 | with patch.object(type(good_obs), "eval_chi_square_in_box_and_border", return_value=mock_return) as _: 56 | results = checker.check_chi_square_in_box_and_border() 57 | 58 | assert all(isinstance(res, expected) for res in results) 59 | -------------------------------------------------------------------------------- /.github/workflows/publish-benchmarks-pr.yml: -------------------------------------------------------------------------------- 1 | # This workflow publishes a benchmarks comment on a pull request. It is triggered after the 2 | # benchmarks are computed in the asv-pr workflow. 
This separation of concerns allows us limit 3 | # access to the target repository private tokens and secrets, increasing the level of security. 4 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/. 5 | name: Publish benchmarks comment to PR 6 | 7 | on: 8 | workflow_run: 9 | workflows: ["Run benchmarks for PR"] 10 | types: [completed] 11 | 12 | jobs: 13 | upload-pr-comment: 14 | runs-on: ubuntu-latest 15 | if: > 16 | github.event.workflow_run.event == 'pull_request' && 17 | github.event.workflow_run.conclusion == 'success' 18 | permissions: 19 | issues: write 20 | pull-requests: write 21 | steps: 22 | - name: Display Workflow Run Information 23 | run: | 24 | echo "Workflow Run ID: ${{ github.event.workflow_run.id }}" 25 | echo "Head SHA: ${{ github.event.workflow_run.head_sha }}" 26 | echo "Head Branch: ${{ github.event.workflow_run.head_branch }}" 27 | echo "Conclusion: ${{ github.event.workflow_run.conclusion }}" 28 | echo "Event: ${{ github.event.workflow_run.event }}" 29 | - name: Download artifact 30 | uses: dawidd6/action-download-artifact@v11 31 | with: 32 | name: benchmark-artifacts 33 | run_id: ${{ github.event.workflow_run.id }} 34 | - name: Extract artifacts information 35 | id: pr-info 36 | run: | 37 | printf "PR number: $(cat pr)\n" 38 | printf "Output:\n$(cat output)" 39 | printf "pr=$(cat pr)" >> $GITHUB_OUTPUT 40 | - name: Find benchmarks comment 41 | uses: peter-evans/find-comment@v4 42 | id: find-comment 43 | with: 44 | issue-number: ${{ steps.pr-info.outputs.pr }} 45 | comment-author: 'github-actions[bot]' 46 | body-includes: view all benchmarks 47 | - name: Create or update benchmarks comment 48 | uses: peter-evans/create-or-update-comment@v5 49 | with: 50 | comment-id: ${{ steps.find-comment.outputs.comment-id }} 51 | issue-number: ${{ steps.pr-info.outputs.pr }} 52 | body-path: output 53 | edit-mode: replace -------------------------------------------------------------------------------- 
/src/scarlet2/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: E402 2 | """Main namespace for scarlet2""" 3 | 4 | 5 | class Scenery: 6 | """Class to hold the context for the current scene 7 | 8 | See Also 9 | -------- 10 | :py:class:`~scarlet2.Scene` 11 | """ 12 | 13 | scene = None 14 | """Scene of the currently opened context""" 15 | 16 | 17 | class Parameterization: 18 | """Class to hold the context for the current parameter set 19 | 20 | See Also 21 | -------- 22 | :py:class:`~scarlet2.Parameters` 23 | """ 24 | 25 | parameters = None 26 | """Parameters of the currently opened context""" 27 | 28 | 29 | from . import init, measure, plot 30 | from .bbox import Box 31 | from .frame import Frame 32 | from .module import Module, Parameter, Parameters, relative_step 33 | from .morphology import ( 34 | GaussianMorphology, 35 | Morphology, 36 | ProfileMorphology, 37 | SersicMorphology, 38 | StarletMorphology, 39 | ) 40 | from .observation import CorrelatedObservation, Observation 41 | from .psf import PSF, ArrayPSF, GaussianPSF 42 | from .scene import Scene 43 | from .source import Component, DustComponent, PointSource, Source 44 | from .spectrum import Spectrum, StaticArraySpectrum, TransientArraySpectrum 45 | from .validation import check_fit, check_observation, check_scene, check_source 46 | from .validation_utils import VALIDATION_SWITCH, set_validation 47 | from .wavelets import Starlet 48 | 49 | # for * imports and docs 50 | __all__ = [ 51 | "init", 52 | "measure", 53 | "plot", 54 | "Scenery", 55 | "Parameterization", 56 | "Box", 57 | "Frame", 58 | "Parameter", 59 | "Parameters", 60 | "Module", 61 | "relative_step", 62 | "Morphology", 63 | "ProfileMorphology", 64 | "GaussianMorphology", 65 | "SersicMorphology", 66 | "StarletMorphology", 67 | "Observation", 68 | "CorrelatedObservation", 69 | "PSF", 70 | "ArrayPSF", 71 | "GaussianPSF", 72 | "Scene", 73 | "Component", 74 | "DustComponent", 75 | "Source", 
76 | "PointSource", 77 | "Spectrum", 78 | "StaticArraySpectrum", 79 | "TransientArraySpectrum", 80 | "Starlet", 81 | "check_fit", 82 | "check_observation", 83 | "check_scene", 84 | "check_source", 85 | "VALIDATION_SWITCH", 86 | "set_validation", 87 | ] 88 | -------------------------------------------------------------------------------- /.setup_dev.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Bash Unofficial strict mode (http://redsymbol.net/articles/unofficial-bash-strict-mode/) 4 | # and (https://disconnected.systems/blog/another-bash-strict-mode/) 5 | set -o nounset # Any uninitialized variable is an error 6 | set -o errexit # Exit the script on the failure of any command to execute without error 7 | set -o pipefail # Fail command pipelines on the failure of any individual step 8 | IFS=$'\n\t' #set internal field separator to avoid iteration errors 9 | # Trap all exits and output something helpful 10 | trap 's=$?; echo "$0: Error on line "$LINENO": $BASH_COMMAND"; exit $s' ERR 11 | 12 | # This script should be run by new developers to install this package in 13 | # editable mode and configure their local environment 14 | 15 | echo "Checking virtual environment" 16 | if [ "${VIRTUAL_ENV:-missing}" = "missing" ] && [ "${CONDA_PREFIX:-missing}" = "missing" ]; then 17 | echo 'No virtual environment detected: none of $VIRTUAL_ENV or $CONDA_PREFIX is set.' 18 | echo 19 | echo "=== This script is going to install the project in the system python environment ===" 20 | echo "Proceed? [y/N]" 21 | read -r RESPONCE 22 | if [ "${RESPONCE}" != "y" ]; then 23 | echo "See https://lincc-ppt.readthedocs.io/ for details." 24 | echo "Exiting." 
25 | exit 1 26 | fi 27 | 28 | fi 29 | 30 | echo "Checking pip version" 31 | MINIMUM_PIP_VERSION=22 32 | pipversion=( $(python -m pip --version | awk '{print $2}' | sed 's/\./\n\t/g') ) 33 | if let "${pipversion[0]}<${MINIMUM_PIP_VERSION}"; then 34 | echo "Insufficient version of pip found. Requires at least version ${MINIMUM_PIP_VERSION}." 35 | echo "See https://lincc-ppt.readthedocs.io/ for details." 36 | exit 1 37 | fi 38 | 39 | echo "Installing package and runtime dependencies in local environment" 40 | python -m pip install -e . > /dev/null 41 | 42 | echo "Installing developer dependencies in local environment" 43 | python -m pip install -e .'[dev]' > /dev/null 44 | if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt > /dev/null; fi 45 | 46 | echo "Installing pre-commit" 47 | pre-commit install > /dev/null 48 | 49 | ####################################################### 50 | # Include any additional configurations below this line 51 | ####################################################### 52 | -------------------------------------------------------------------------------- /.github/workflows/asv-main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run benchmarks with airspeed velocity (asv), 2 | # store the new results in the "benchmarks" branch and publish them 3 | # to a dashboard on GH Pages. 
4 | name: Run ASV benchmarks for main 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | 10 | env: 11 | PYTHON_VERSION: "3.11" 12 | ASV_VERSION: "0.6.4" 13 | WORKING_DIR: ${{github.workspace}}/benchmarks 14 | 15 | concurrency: 16 | group: ${{github.workflow}}-${{github.ref}} 17 | cancel-in-progress: true 18 | 19 | jobs: 20 | asv-main: 21 | runs-on: ubuntu-latest 22 | permissions: 23 | contents: write 24 | defaults: 25 | run: 26 | working-directory: ${{env.WORKING_DIR}} 27 | steps: 28 | - name: Set up Python ${{env.PYTHON_VERSION}} 29 | uses: actions/setup-python@v6 30 | with: 31 | python-version: ${{env.PYTHON_VERSION}} 32 | - name: Checkout main branch of the repository 33 | uses: actions/checkout@v6 34 | with: 35 | fetch-depth: 0 36 | - name: Install dependencies 37 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}" 38 | - name: Configure git 39 | run: | 40 | git config user.name "github-actions[bot]" 41 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 42 | - name: Create ASV machine config file 43 | run: asv machine --machine gh-runner --yes 44 | - name: Fetch previous results from the "benchmarks" branch 45 | run: | 46 | if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then 47 | git merge origin/benchmarks \ 48 | --allow-unrelated-histories \ 49 | --no-commit 50 | mv ../_results . 
51 | fi 52 | - name: Run ASV for the main branch 53 | run: asv run ALL --skip-existing --verbose || true 54 | - name: Submit new results to the "benchmarks" branch 55 | uses: JamesIves/github-pages-deploy-action@v4 56 | with: 57 | branch: benchmarks 58 | folder: ${{env.WORKING_DIR}}/_results 59 | target-folder: _results 60 | - name: Generate dashboard HTML 61 | run: | 62 | asv show 63 | asv publish 64 | - name: Deploy to Github pages 65 | uses: JamesIves/github-pages-deploy-action@v4 66 | with: 67 | branch: gh-pages 68 | folder: ${{env.WORKING_DIR}}/_html -------------------------------------------------------------------------------- /src/scarlet2/questionnaire/views/question_box.css: -------------------------------------------------------------------------------- 1 | /* Add overflow-x: hidden to the question box to prevent horizontal scrolling */ 2 | .question-box-container { 3 | overflow-x: hidden; 4 | display: flex; 5 | flex-direction: column; 6 | min-height: 300px; /* Ensure there's enough height for content */ 7 | } 8 | 9 | /* Style for the save button container */ 10 | .save-button-container { 11 | margin-top: auto; 12 | text-align: center; 13 | padding: 5px; 14 | } 15 | 16 | /* Style for the save button */ 17 | .save-button { 18 | font-size: 0.9em; 19 | background-color: #444444; 20 | color: white; 21 | } 22 | 23 | .save-message { 24 | margin-top: 10px; 25 | text-align: left; 26 | } 27 | 28 | .prev-item { 29 | position: relative; 30 | overflow: visible; 31 | } 32 | .prev-btn { 33 | width: auto; 34 | text-align: left; 35 | border: none; 36 | background: none; 37 | box-shadow: none; 38 | color: #555; 39 | text-decoration: none; 40 | cursor: pointer; 41 | font-size: 0.9em; 42 | margin: 2px; 43 | padding: 2px 4px; 44 | overflow: hidden; 45 | max-width: 100%; /* Ensure button doesn't exceed container width */ 46 | text-overflow: ellipsis; /* Add ellipsis for text that overflows */ 47 | } 48 | 49 | /* Container for the tooltip */ 50 | .prev-item .tooltip { 51 | 
position: absolute; 52 | background: #333; 53 | color: #fff; 54 | padding: 6px 8px; 55 | border-radius: 8px; 56 | font-size: 12px; 57 | line-height: 1.3; 58 | box-shadow: 0 4px 12px rgba(0,0,0,.2); 59 | white-space: normal; 60 | min-width: 160px; 61 | max-width: 320px; 62 | opacity: 0; 63 | visibility: hidden; 64 | transition: opacity .12s ease; 65 | pointer-events: none; 66 | z-index: 9999; 67 | 68 | /* Default position to the right of the button */ 69 | left: 100%; 70 | top: 50%; 71 | transform: translateY(-50%) translateX(8px); 72 | } 73 | 74 | /* When hovering over the container, show the tooltip */ 75 | .prev-item:hover .tooltip { 76 | opacity: 1; 77 | visibility: visible; 78 | } 79 | 80 | /* Arrow for the tooltip */ 81 | .prev-item .tooltip::before { 82 | content: ""; 83 | position: absolute; 84 | border-width: 6px; 85 | border-style: solid; 86 | left: -6px; 87 | top: 50%; 88 | transform: translateY(-50%); 89 | border-color: transparent #333 transparent transparent; 90 | } 91 | -------------------------------------------------------------------------------- /src/scarlet2/io.py: -------------------------------------------------------------------------------- 1 | """Methods to save and load scenes""" 2 | 3 | import os 4 | import pickle 5 | 6 | import h5py 7 | import numpy as np 8 | 9 | 10 | def model_to_h5(model, filename, id=0, path=".", overwrite=False): 11 | """Save the scene model to a HDF5 file 12 | 13 | Parameters 14 | ---------- 15 | filename : str 16 | Name of the HDF5 file to create 17 | model : :py:class:`~scarlet2.Module` 18 | Scene to be stored 19 | id : int 20 | HDF5 group to store this `model` under 21 | path: str, optional 22 | Explicit path for `filename`. 
If not set, uses local directory 23 | overwrite : bool, optional 24 | Whether to overwrite an existing file with the same path and filename 25 | 26 | Returns 27 | ------- 28 | None 29 | 30 | Notes 31 | ----- 32 | This is not a pure function hence cannot be utilized within a JAX JIT compilation. 33 | """ 34 | # create directory if it does not exist 35 | if not os.path.exists(path): 36 | os.makedirs(path) 37 | 38 | # first serialize the model into a pytree 39 | model_group = str(id) 40 | save_h5_path = os.path.join(path, filename) 41 | 42 | f = h5py.File(save_h5_path, "a") 43 | # create a group for the scene 44 | if model_group in f: 45 | if overwrite: 46 | del f[model_group] 47 | else: 48 | raise ValueError("ID already exists. Set overwrite=True to overwrite the ID.") 49 | 50 | # save the binary to HDF5 51 | group = f.create_group(model_group) 52 | model = pickle.dumps(model) 53 | group.attrs["model"] = np.void(model) 54 | f.close() 55 | 56 | 57 | def model_from_h5(filename, id=0, path="."): 58 | """ 59 | Load scene model from a HDF5 file 60 | 61 | Parameters 62 | ---------- 63 | filename : str 64 | Name of the HDF5 file to load from 65 | id : int 66 | HDF5 group to identify the scene by 67 | path: str, optional 68 | Explicit path for `filename`. 
If not set, uses local directory 69 | 70 | Returns 71 | ------- 72 | :py:class:`~scarlet2.Scene` 73 | """ 74 | 75 | filename = os.path.join(path, filename) 76 | f = h5py.File(filename, "r") 77 | model_group = str(id) 78 | if model_group not in f: 79 | raise ValueError(f"ID {id} not found in the file.") 80 | 81 | group = f.get(model_group) 82 | out = group.attrs["model"] 83 | binary_blob = out.tobytes() 84 | scene = pickle.loads(binary_blob) 85 | f.close() 86 | 87 | return scene 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | _version.py 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_autosummary/ 74 | docs/_build/ 75 | docs/jupyter_execute/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # VSCode settings 135 | .vscode 136 | 137 | # Default location for saved models 138 | stored_models/ 139 | -------------------------------------------------------------------------------- /.initialize_new_project.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Bash Unofficial strict mode (http://redsymbol.net/articles/unofficial-bash-strict-mode/) 4 | # and (https://disconnected.systems/blog/another-bash-strict-mode/) 5 | set -o nounset # Any uninitialized variable is an error 6 | set -o errexit # Exit the script on the failure of any command to execute without error 7 | set -o pipefail # Fail command pipelines on the failure of any individual step 8 | IFS=$'\n\t' #set internal field separator to avoid iteration errors 9 | # Trap all exits and output something helpful 10 | trap 's=$?; echo "$0: Error on line "$LINENO": $BASH_COMMAND"; exit $s' ERR 11 | 12 | echo "Checking virtual environment" 13 | if [ "${VIRTUAL_ENV:-missing}" = "missing" ] && [ "${CONDA_PREFIX:-missing}" = "missing" ]; then 14 | echo 'No virtual environment detected: none of $VIRTUAL_ENV or $CONDA_PREFIX is set.' 15 | echo 16 | echo "=== This script is going to install the project in the system python environment ===" 17 | echo "Proceed? 
[y/N]" 18 | read -r RESPONCE 19 | if [ "${RESPONCE}" != "y" ]; then 20 | echo "See https://lincc-ppt.readthedocs.io/ for details." 21 | echo "Exiting." 22 | exit 1 23 | fi 24 | 25 | fi 26 | 27 | echo "Checking pip version" 28 | MINIMUM_PIP_VERSION=22 29 | pipversion=( $(python -m pip --version | awk '{print $2}' | sed 's/\./\n\t/g') ) 30 | if let "${pipversion[0]}<${MINIMUM_PIP_VERSION}"; then 31 | echo "Insufficient version of pip found. Requires at least version ${MINIMUM_PIP_VERSION}." 32 | echo "See https://lincc-ppt.readthedocs.io/ for details." 33 | exit 1 34 | fi 35 | 36 | echo "Initializing local git repository" 37 | { 38 | gitversion=( $(git version | git version | awk '{print $3}' | sed 's/\./\n\t/g') ) 39 | if let "${gitversion[0]}<2"; then 40 | # manipulate directly 41 | git init . && echo 'ref: refs/heads/main' >.git/HEAD 42 | elif let "${gitversion[0]}==2 & ${gitversion[1]}<34"; then 43 | # rename master to main 44 | git init . && { git branch -m master main 2>/dev/null || true; }; 45 | else 46 | # set the initial branch name to main 47 | git init --initial-branch=main >/dev/null 48 | fi 49 | } > /dev/null 50 | 51 | echo "Installing package and runtime dependencies in local environment" 52 | python -m pip install -e . > /dev/null 53 | 54 | echo "Installing developer dependencies in local environment" 55 | python -m pip install -e .'[dev]' > /dev/null 56 | if [ -f docs/requirements.txt ]; then python -m pip install -r docs/requirements.txt; fi 57 | 58 | echo "Installing pre-commit" 59 | pre-commit install > /dev/null 60 | 61 | echo "Committing initial files" 62 | git add . && SKIP="no-commit-to-branch" git commit -m "Initial commit" 63 | -------------------------------------------------------------------------------- /.github/workflows/asv-nightly.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run daily at 06:45. 
2 | # It will run benchmarks with airspeed velocity (asv) 3 | # and compare performance with the previous nightly build. 4 | name: Run benchmarks nightly job 5 | 6 | on: 7 | schedule: 8 | - cron: 45 6 * * * 9 | workflow_dispatch: 10 | 11 | env: 12 | PYTHON_VERSION: "3.11" 13 | ASV_VERSION: "0.6.4" 14 | WORKING_DIR: ${{github.workspace}}/benchmarks 15 | NIGHTLY_HASH_FILE: nightly-hash 16 | 17 | jobs: 18 | asv-nightly: 19 | runs-on: ubuntu-latest 20 | defaults: 21 | run: 22 | working-directory: ${{env.WORKING_DIR}} 23 | steps: 24 | - name: Set up Python ${{env.PYTHON_VERSION}} 25 | uses: actions/setup-python@v6 26 | with: 27 | python-version: ${{env.PYTHON_VERSION}} 28 | - name: Checkout main branch of the repository 29 | uses: actions/checkout@v6 30 | with: 31 | fetch-depth: 0 32 | - name: Install dependencies 33 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}" 34 | - name: Configure git 35 | run: | 36 | git config user.name "github-actions[bot]" 37 | git config user.email "41898282+github-actions[bot]@users.noreply.github.com" 38 | - name: Create ASV machine config file 39 | run: asv machine --machine gh-runner --yes 40 | - name: Fetch previous results from the "benchmarks" branch 41 | run: | 42 | if git ls-remote --exit-code origin benchmarks > /dev/null 2>&1; then 43 | git merge origin/benchmarks \ 44 | --allow-unrelated-histories \ 45 | --no-commit 46 | mv ../_results . 
47 | fi 48 | - name: Get nightly dates under comparison 49 | id: nightly-dates 50 | run: | 51 | echo "yesterday=$(date -d yesterday +'%Y-%m-%d')" >> $GITHUB_OUTPUT 52 | echo "today=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT 53 | - name: Use last nightly commit hash from cache 54 | uses: actions/cache@v4 55 | with: 56 | path: ${{env.WORKING_DIR}} 57 | key: nightly-results-${{steps.nightly-dates.outputs.yesterday}} 58 | - name: Run comparison of main against last nightly build 59 | run: | 60 | HASH_FILE=${{env.NIGHTLY_HASH_FILE}} 61 | CURRENT_HASH=${{github.sha}} 62 | if [ -f $HASH_FILE ]; then 63 | PREV_HASH=$(cat $HASH_FILE) 64 | asv continuous $PREV_HASH $CURRENT_HASH --verbose || true 65 | asv compare $PREV_HASH $CURRENT_HASH --sort ratio --verbose 66 | fi 67 | echo $CURRENT_HASH > $HASH_FILE 68 | - name: Update last nightly hash in cache 69 | uses: actions/cache@v4 70 | with: 71 | path: ${{env.WORKING_DIR}} 72 | key: nightly-results-${{steps.nightly-dates.outputs.today}} -------------------------------------------------------------------------------- /.github/workflows/asv-pr.yml: -------------------------------------------------------------------------------- 1 | # This workflow will run benchmarks with airspeed velocity (asv) for pull requests. 2 | # It will compare the performance of the main branch with the performance of the merge 3 | # with the new changes. It then publishes a comment with this assessment by triggering 4 | # the publish-benchmarks-pr workflow. 5 | # Based on https://securitylab.github.com/research/github-actions-preventing-pwn-requests/. 
6 | name: Run benchmarks for PR 7 | 8 | on: 9 | pull_request: 10 | branches: [ main ] 11 | workflow_dispatch: 12 | 13 | concurrency: 14 | group: ${{github.workflow}}-${{github.ref}} 15 | cancel-in-progress: true 16 | 17 | env: 18 | PYTHON_VERSION: "3.11" 19 | ASV_VERSION: "0.6.4" 20 | WORKING_DIR: ${{github.workspace}}/benchmarks 21 | ARTIFACTS_DIR: ${{github.workspace}}/artifacts 22 | 23 | jobs: 24 | asv-pr: 25 | runs-on: ubuntu-latest 26 | defaults: 27 | run: 28 | working-directory: ${{env.WORKING_DIR}} 29 | steps: 30 | - name: Set up Python ${{env.PYTHON_VERSION}} 31 | uses: actions/setup-python@v6 32 | with: 33 | python-version: ${{env.PYTHON_VERSION}} 34 | - name: Checkout PR branch of the repository 35 | uses: actions/checkout@v6 36 | with: 37 | fetch-depth: 0 38 | - name: Display Workflow Run Information 39 | run: | 40 | echo "Workflow Run ID: ${{github.run_id}}" 41 | - name: Install dependencies 42 | run: pip install "asv[virtualenv]==${{env.ASV_VERSION}}" lf-asv-formatter 43 | - name: Make artifacts directory 44 | run: mkdir -p ${{env.ARTIFACTS_DIR}} 45 | - name: Save pull request number 46 | run: echo ${{github.event.pull_request.number}} > ${{env.ARTIFACTS_DIR}}/pr 47 | - name: Get current job logs URL 48 | uses: Tiryoh/gha-jobid-action@v1 49 | id: jobs 50 | with: 51 | github_token: ${{secrets.GITHUB_TOKEN}} 52 | job_name: ${{github.job}} 53 | - name: Create ASV machine config file 54 | run: asv machine --machine gh-runner --yes 55 | - name: Save comparison of PR against main branch 56 | run: | 57 | git remote add upstream https://github.com/${{github.repository}}.git 58 | git fetch upstream 59 | asv continuous upstream/main HEAD --verbose || true 60 | asv compare upstream/main HEAD --sort ratio --verbose | tee output 61 | python -m lf_asv_formatter --asv_version "$(asv --version | awk '{print $2}')" 62 | printf "\n\nClick [here]($STEP_URL) to view all benchmarks." 
>> output 63 | mv output ${{env.ARTIFACTS_DIR}} 64 | env: 65 | STEP_URL: ${{steps.jobs.outputs.html_url}}#step:10:1 66 | - name: Upload artifacts (PR number and benchmarks output) 67 | uses: actions/upload-artifact@v5 68 | with: 69 | name: benchmark-artifacts 70 | path: ${{env.ARTIFACTS_DIR}} -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | 2 | repos: 3 | # Compare the local template version to the latest remote template version 4 | # This hook should always pass. It will print a message if the local version 5 | # is out of date. 6 | - repo: https://github.com/lincc-frameworks/pre-commit-hooks 7 | rev: v0.1.2 8 | hooks: 9 | - id: check-lincc-frameworks-template-version 10 | name: Check template version 11 | description: Compare current template version against latest 12 | verbose: true 13 | # Clear output from jupyter notebooks so that only the input cells are committed. 14 | - repo: local 15 | hooks: 16 | - id: jupyter-nb-clear-output 17 | name: Clear output from Jupyter notebooks 18 | description: Clear output from Jupyter notebooks. 19 | files: \.ipynb$ 20 | exclude: ^docs/pre_executed 21 | stages: [pre-commit] 22 | language: system 23 | entry: jupyter nbconvert --clear-output 24 | # Prevent committing directly to branches named 'main' and 'master'. 25 | - repo: https://github.com/pre-commit/pre-commit-hooks 26 | rev: v4.4.0 27 | hooks: 28 | - id: no-commit-to-branch 29 | name: Prevent main branch commits 30 | description: Prevent the user from committing directly to the primary branch. 31 | - id: check-added-large-files 32 | name: Check for large files 33 | description: Prevent the user from committing very large files.
34 | args: ['--maxkb=500'] 35 | # Verify that pyproject.toml is well formed 36 | - repo: https://github.com/abravalheri/validate-pyproject 37 | rev: v0.12.1 38 | hooks: 39 | - id: validate-pyproject 40 | name: Validate pyproject.toml 41 | description: Verify that pyproject.toml adheres to the established schema. 42 | # Verify that GitHub workflows are well formed 43 | - repo: https://github.com/python-jsonschema/check-jsonschema 44 | rev: 0.28.0 45 | hooks: 46 | - id: check-github-workflows 47 | args: ["--verbose"] 48 | - repo: https://github.com/astral-sh/ruff-pre-commit 49 | # Ruff version. 50 | rev: v0.13.3 51 | hooks: 52 | # Run the linter. 53 | - id: ruff-check 54 | types_or: [ python, pyi ] 55 | args: [ --fix ] 56 | # Run the formatter. 57 | - id: ruff-format 58 | types_or: [ python, pyi ] 59 | # Run unit tests, verify that they pass. Note that coverage is run against 60 | # the ./src directory here because that is what will be committed. In the 61 | # github workflow script, the coverage is run against the installed package 62 | # and uploaded to Codecov by calling pytest like so: 63 | # `python -m pytest --cov= --cov-report=xml` 64 | - repo: local 65 | hooks: 66 | - id: pytest-check 67 | name: Run unit tests 68 | description: Run unit tests with pytest. 
69 | entry: bash -c "if python -m pytest --co -qq; then python -m pytest --cov=./src --cov-report=html; fi" 70 | language: system 71 | pass_filenames: false 72 | always_run: true 73 | -------------------------------------------------------------------------------- /tests/scarlet2/test_quickstart.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101 2 | # ruff: noqa: D102 3 | # ruff: noqa: D103 4 | # ruff: noqa: D106 5 | 6 | from functools import partial 7 | 8 | import jax.numpy as jnp 9 | import numpyro.distributions as dist 10 | from huggingface_hub import hf_hub_download 11 | from numpyro.distributions import constraints 12 | from numpyro.infer.initialization import init_to_sample 13 | 14 | from scarlet2 import init 15 | from scarlet2.frame import Frame 16 | from scarlet2.module import Parameter, Parameters, relative_step 17 | from scarlet2.observation import Observation 18 | from scarlet2.psf import ArrayPSF, GaussianPSF 19 | from scarlet2.scene import Scene 20 | from scarlet2.source import Source 21 | from scarlet2.validation_utils import set_validation 22 | 23 | # turn off automatic validation checks 24 | set_validation(False) 25 | 26 | filename = hf_hub_download( 27 | repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset" 28 | ) 29 | file = jnp.load(filename) 30 | data = jnp.asarray(file["images"]) 31 | channels = [str(f) for f in file["filters"]] 32 | centers = [(src["y"], src["x"]) for src in file["catalog"]] # Note: y/x convention! 
33 | weights = jnp.asarray(1 / file["variance"]) 34 | psf = jnp.asarray(file["psfs"]) 35 | 36 | _ = GaussianPSF(0.7) 37 | obs = Observation(data, weights, psf=ArrayPSF(jnp.asarray(psf)), channels=channels) 38 | model_frame = Frame.from_observations(obs) 39 | 40 | with Scene(model_frame) as scene: 41 | for center in centers: 42 | center = jnp.array(center) 43 | try: 44 | spectrum, morph = init.from_gaussian_moments(obs, center, min_corr=0.99) 45 | except ValueError: 46 | spectrum = init.pixel_spectrum(obs, center) 47 | morph = init.compact_morphology() 48 | 49 | Source(center, spectrum, morph) 50 | 51 | scene_ = None 52 | 53 | 54 | def test_fit(): 55 | global scene_ 56 | spec_step = partial(relative_step, factor=0.05) 57 | 58 | # fitting 59 | with Parameters(scene) as parameters: 60 | for i in range(len(scene.sources)): 61 | Parameter( 62 | scene.sources[i].spectrum, 63 | name=f"spectrum:{i}", 64 | constraint=constraints.positive, 65 | stepsize=spec_step, 66 | ) 67 | Parameter( 68 | scene.sources[i].morphology, name=f"morph:{i}", constraint=constraints.positive, stepsize=0.1 69 | ) 70 | 71 | maxiter = 10 72 | scene_ = scene.fit(obs, parameters, max_iter=maxiter, progress_bar=False) 73 | 74 | 75 | def test_sample(): 76 | # using pre-optimized scene for better warm-up 77 | global scene_ 78 | # old style of parameter declaration, check backward compatibility 79 | parameters = scene_.make_parameters() 80 | p = scene_.sources[0].spectrum 81 | prior = dist.Normal(p, scale=1) 82 | parameters += Parameter(p, name="spectrum:0", prior=prior) 83 | 84 | _ = scene_.sample( 85 | obs, parameters, num_samples=10, dense_mass=True, init_strategy=init_to_sample, progress_bar=False 86 | ) 87 | 88 | 89 | if __name__ == "__main__": 90 | test_fit() 91 | test_sample() 92 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [project] 3 | name = "scarlet2" 4 
| license = {file = "LICENSE"} 5 | readme = "README.md" 6 | authors = [ 7 | { name = "Peter Melchior", email = "peter.m.melchior@gmail.com" } 8 | ] 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "License :: OSI Approved :: MIT License", 12 | "Intended Audience :: Developers", 13 | "Intended Audience :: Science/Research", 14 | "Operating System :: OS Independent", 15 | "Programming Language :: Python", 16 | ] 17 | dynamic = ["version"] 18 | requires-python = ">=3.10" 19 | dependencies = [ 20 | "equinox", 21 | "jax==0.6.2", 22 | "astropy", 23 | "numpy", 24 | "matplotlib", 25 | "varname", 26 | "optax", 27 | "numpyro", 28 | "h5py", 29 | "ipywidgets", 30 | "pydantic", 31 | "pyyaml", 32 | "colorama", 33 | "markdown", 34 | "pygments", 35 | "jinja2" 36 | ] 37 | 38 | [project.urls] 39 | "Source Code" = "https://github.com/pmelchior/scarlet2" 40 | 41 | # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) 42 | [project.optional-dependencies] 43 | dev = [ 44 | "asv==0.6.5", # Used to compute performance benchmarks 45 | "huggingface_hub", # Pulls down example data for testing 46 | "jupyter", # Clears output from Jupyter notebooks 47 | "pre-commit", # Used to run checks before finalizing a git commit 48 | "pytest", 49 | "pytest-cov", # Used to report total code coverage 50 | "pytest-mock", # Used to mock objects in tests 51 | "ruff", # Used for static linting of files 52 | ] 53 | 54 | [build-system] 55 | requires = [ 56 | "setuptools>=62", # Used to build and package the Python project 57 | "setuptools_scm>=6.2", # Gets release version from git. 
Makes it available programmatically 58 | ] 59 | build-backend = "setuptools.build_meta" 60 | 61 | [tool.setuptools_scm] 62 | write_to = "src/scarlet2/_version.py" 63 | 64 | [tool.pytest.ini_options] 65 | testpaths = [ 66 | "tests", 67 | ] 68 | addopts = "--doctest-modules --doctest-glob=*.rst" 69 | 70 | [tool.ruff] 71 | line-length = 110 72 | target-version = "py310" 73 | [tool.ruff.lint] 74 | select = [ 75 | # pycodestyle 76 | "E", 77 | "W", 78 | # Pyflakes 79 | "F", 80 | # pep8-naming 81 | "N", 82 | # pyupgrade 83 | "UP", 84 | # flake8-bugbear 85 | "B", 86 | # flake8-simplify 87 | "SIM", 88 | # isort 89 | "I", 90 | # docstrings 91 | "D101", 92 | "D102", 93 | "D103", 94 | "D106", 95 | "D206", 96 | "D207", 97 | "D208", 98 | "D300", 99 | "D417", 100 | "D419", 101 | # Numpy v2.0 compatibility 102 | "NPY201", 103 | ] 104 | ignore = [ 105 | "UP006", # Allow non standard library generics in type hints 106 | "UP007", # Allow Union in type hints 107 | "SIM114", # Allow if with same arms 108 | "B028", # Allow default warning level 109 | "SIM117", # Allow nested with 110 | "UP015", # Allow redundant open parameters 111 | "UP028", # Allow yield in for loop 112 | "UP038", # Allow tuple in `isinstance` 113 | "E731", # Allow assignment of lambda expressions 114 | ] 115 | 116 | 117 | [tool.coverage.run] 118 | omit=["src/scarlet2/_version.py"] 119 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Path setup -------------------------------------------------------------- 7 | 8 | # -- Project information ----------------------------------------------------- 9 | 10 | project = "scarlet2" 11 | copyright = "2025, Peter Melchior" 12 | author = "Peter Melchior" 13 | 14 | # -- General configuration --------------------------------------------------- 15 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 16 | 17 | extensions = [ 18 | "sphinx.ext.autodoc", 19 | "sphinx.ext.autosummary", 20 | "sphinx.ext.mathjax", 21 | "sphinx.ext.napoleon", 22 | "sphinx.ext.viewcode", 23 | "sphinx.ext.intersphinx", 24 | "sphinx.ext.doctest", 25 | "sphinx_issues", 26 | "myst_nb", 27 | ] 28 | master_doc = "index" 29 | source_suffix = { 30 | ".rst": "restructuredtext", 31 | ".ipynb": "myst-nb", 32 | } 33 | 34 | templates_path = ["_templates"] 35 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "jupyter_execute"] 36 | 37 | # -- Options for HTML output ------------------------------------------------- 38 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 39 | 40 | html_theme = "sphinx_book_theme" 41 | html_static_path = ["_static"] 42 | html_title = "scarlet2" 43 | html_favicon = "_static/icon.png" 44 | html_show_sourcelink = False 45 | html_theme_options = { 46 | "path_to_docs": "docs", 47 | "repository_url": "https://github.com/pmelchior/scarlet2", 48 | "repository_branch": "main", 49 | "logo": { 50 | "image_light": "_static/logo_light.svg", 51 | "image_dark": "_static/logo_dark.svg", 52 | }, 53 | "launch_buttons": {"colab_url": "https://colab.research.google.com"}, 54 | "use_edit_page_button": True, 55 | "use_issues_button": True, 56 | "use_repository_button": True, 57 | "use_download_button": True, 58 | "show_toc_level": 3, 59 | } 60 | 
html_baseurl = "https://scarlet2.readthedocs.io/en/latest/" 61 | 62 | autoclass_content = "both" 63 | autosummary_generate = True 64 | autosummary_imported_members = False 65 | autosummary_ignore_module_all = False 66 | autodoc_type_aliases = { 67 | "eqx.Module": "equinox.Module", 68 | "jnp.ndarray": "jax.numpy.array", 69 | } 70 | 71 | intersphinx_mapping = { 72 | "astropy": ("https://docs.astropy.org/en/stable/", None), 73 | "optax": ("https://optax.readthedocs.io/en/latest/", None), 74 | "numpyro": ("https://num.pyro.ai/en/stable/", None), 75 | } 76 | 77 | issues_github_path = "pmelchior/scarlet2" 78 | 79 | nb_execution_timeout = 60 80 | nb_execution_excludepatterns = ["_build", "jupyter_execute"] 81 | 82 | # Napoleon settings 83 | napoleon_google_docstring = False 84 | napoleon_numpy_docstring = True 85 | napoleon_include_init_with_doc = False 86 | napoleon_include_private_with_doc = False 87 | napoleon_include_special_with_doc = True 88 | napoleon_use_admonition_for_examples = True 89 | napoleon_use_admonition_for_notes = True 90 | napoleon_use_admonition_for_references = False 91 | napoleon_use_ivar = False 92 | napoleon_use_param = True 93 | napoleon_use_rtype = True 94 | napoleon_preprocess_types = True 95 | napoleon_type_aliases = None 96 | napoleon_attr_annotations = True 97 | -------------------------------------------------------------------------------- /tests/scarlet2/test_save_output.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101 2 | # ruff: noqa: D102 3 | # ruff: noqa: D103 4 | # ruff: noqa: D106 5 | 6 | import os 7 | 8 | import astropy.wcs as wcs 9 | import h5py 10 | import jax 11 | import jax.numpy as jnp 12 | from huggingface_hub import hf_hub_download 13 | 14 | from scarlet2 import * # noqa: F403 15 | from scarlet2 import init 16 | from scarlet2.bbox import Box 17 | from scarlet2.frame import Frame 18 | from scarlet2.io import model_from_h5, model_to_h5 19 | from scarlet2.observation 
import Observation 20 | from scarlet2.psf import GaussianPSF 21 | from scarlet2.scene import Scene 22 | from scarlet2.source import Source 23 | from scarlet2.validation_utils import set_validation 24 | 25 | # turn off automatic validation checks 26 | set_validation(False) 27 | 28 | 29 | def test_save_output(): 30 | filename = hf_hub_download( 31 | repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset" 32 | ) 33 | file = jnp.load(filename) 34 | data = jnp.asarray(file["images"]) 35 | centers = [(src["y"], src["x"]) for src in file["catalog"]] # Note: y/x convention! 36 | weights = jnp.asarray(1 / file["variance"]) 37 | psf = jnp.asarray(file["psfs"]) 38 | 39 | frame_psf = GaussianPSF(0.7) 40 | model_frame = Frame(Box(data.shape), psf=frame_psf) 41 | obs = Observation(data, weights, psf=psf) 42 | 43 | with Scene(model_frame) as scene: 44 | for center in centers: 45 | center = jnp.array(center) 46 | try: 47 | spectrum, morph = init.from_gaussian_moments(obs, center, min_corr=0.99) 48 | except ValueError: 49 | spectrum = init.pixel_spectrum(obs, center) 50 | morph = init.compact_morphology() 51 | 52 | Source(center, spectrum, morph) 53 | 54 | # save the output 55 | id = 1 56 | filename = "demo_io.h5" 57 | path = "stored_models" 58 | model_to_h5(scene, filename, id, path=path, overwrite=True) 59 | 60 | # demo that it works to add models to a single file 61 | id = 2 62 | model_to_h5(scene, filename, id, path=path, overwrite=True) 63 | 64 | # load files and show keys 65 | full_path = os.path.join(path, filename) 66 | with h5py.File(full_path, "r") as f: 67 | print(f.keys()) 68 | 69 | # print the output 70 | print(f"Output saved to {full_path}") 71 | # print the storage size 72 | print(f"Storage size: {os.path.getsize(full_path) / 1e6:.4f} MB") 73 | # load the output and plot the sources 74 | scene_loaded = model_from_h5(filename, id, path=path) 75 | print("Output loaded from h5 file") 76 | 77 | # compare scenes 78 | saved = 
jax.tree_util.tree_leaves(scene) 79 | loaded = jax.tree_util.tree_leaves(scene_loaded) 80 | status = True 81 | for leaf_saved, leaf_loaded in zip(saved, loaded, strict=False): 82 | if isinstance(leaf_saved, wcs.WCS): # wcs doesn't allow direct == comparison... 83 | if not leaf_saved.wcs.compare(leaf_loaded.wcs): 84 | status = False 85 | elif hasattr(leaf_saved, "__iter__"): 86 | if (leaf_saved != leaf_loaded).all(): 87 | status = False 88 | else: 89 | if leaf_saved != leaf_loaded: 90 | status = False 91 | 92 | print(f"saved == loaded: {status}") 93 | assert status, "Loaded leaves not identical to original" 94 | 95 | 96 | if __name__ == "__main__": 97 | test_save_output() 98 | -------------------------------------------------------------------------------- /src/scarlet2/validation_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import Any 4 | 5 | from colorama import Back, Fore, Style 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | # A global switch that toggles automated validation checks. 10 | VALIDATION_SWITCH = True 11 | 12 | 13 | def set_validation(state: bool = True): 14 | """Set the global validation switch. 15 | 16 | Parameters 17 | ---------- 18 | state : bool, optional 19 | If True, validation checks will be automatically performed. If False, 20 | they will be skipped. Defaults to True. 21 | """ 22 | 23 | global VALIDATION_SWITCH 24 | VALIDATION_SWITCH = state 25 | logger.info(f"Automated validation checks are now {'enabled' if VALIDATION_SWITCH else 'disabled'}.") 26 | 27 | 28 | @dataclass 29 | class ValidationResult: 30 | """Represents a validation result. This is the base dataclass that all the 31 | more specific Validation dataclasses inherit from. Generally, it should 32 | not be instantiated directly, but rather through the more specific 33 | ValidationInfo, ValidationWarning, or ValidationError classes. 
34 | """ 35 | 36 | message: str 37 | check: str 38 | context: Any | None = None 39 | 40 | def __str__(self): 41 | base = f"{self.message}" 42 | if self.context is not None: 43 | base += f" | Context={self.context})" 44 | return base 45 | 46 | 47 | @dataclass 48 | class ValidationInfo(ValidationResult): 49 | """Represents a validation info message that is informative but not critical.""" 50 | 51 | def __str__(self): 52 | return f"{Style.BRIGHT}{Fore.BLACK}{Back.GREEN} INFO {Style.RESET_ALL} {super().__str__()}" 53 | 54 | 55 | @dataclass 56 | class ValidationWarning(ValidationResult): 57 | """Represents a validation warning that is not critical but should be noted.""" 58 | 59 | def __str__(self): 60 | return f"{Style.BRIGHT}{Fore.BLACK}{Back.YELLOW} WARN {Style.RESET_ALL} {super().__str__()}" 61 | 62 | 63 | @dataclass 64 | class ValidationError(ValidationResult): 65 | """Represents a validation error that is critical and should be addressed.""" 66 | 67 | def __str__(self): 68 | return f"{Style.BRIGHT}{Fore.WHITE}{Back.RED} ERROR {Style.RESET_ALL} {super().__str__()}" 69 | 70 | 71 | class ValidationMethodCollector(type): 72 | """Metaclass that collects all validation methods in a class into a single list. 73 | For any class that uses this metaclass, all methods that start with "check_" 74 | will be automatically collected into a class attribute named `validation_checks`. 75 | """ 76 | 77 | def __new__(cls, name, bases, namespace): 78 | """Creates a list of callable methods when a new instances of a class is 79 | created.""" 80 | cls = super().__new__(cls, name, bases, namespace) 81 | cls.validation_checks = [ 82 | attr for attr, value in namespace.items() if callable(value) and attr.startswith("check_") 83 | ] 84 | return cls 85 | 86 | 87 | def print_validation_results(preamble: str, results: list[ValidationResult]): 88 | """Print the validation results in a formatted manner. 
89 | 90 | Parameters 91 | ---------- 92 | preamble : str 93 | A string to print before the validation results. 94 | results : list[ValidationResult] 95 | A list of validation results to print. Nothing is printed if the list is empty. 96 | """ 97 | if len(results) == 0: 98 | return 99 | print( 100 | f"{preamble}:\n" + "\n".join(f"[{str(i).zfill(3)}] {str(result)}" for i, result in enumerate(results)) 101 | ) 102 | -------------------------------------------------------------------------------- /benchmarks/asv.conf.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | // The version of the config file format. Do not change, unless 4 | // you know what you are doing. 5 | "version": 1, 6 | // The name of the project being benchmarked. 7 | "project": "scarlet2", 8 | // The project's homepage. 9 | "project_url": "https://github.com/pmelchior/scarlet2", 10 | // The URL or local path of the source code repository for the 11 | // project being benchmarked. 12 | "repo": "..", 13 | // List of branches to benchmark. If not provided, defaults to "master" 14 | // (for git) or "tip" (for mercurial). 15 | "branches": [ 16 | "HEAD" 17 | ], 18 | "install_command": [ 19 | "python -m pip install {wheel_file}" 20 | ], 21 | "build_command": [ 22 | "python -m build --wheel -o {build_cache_dir} {build_dir}" 23 | ], 24 | // The DVCS being used. If not set, it will be automatically 25 | // determined from "repo" by looking at the protocol in the URL 26 | // (if remote), or by looking for special directories, such as 27 | // ".git" (if local). 28 | "dvcs": "git", 29 | // The tool to use to create environments. May be "conda", 30 | // "virtualenv" or other value depending on the plugins in use. 31 | // If missing or the empty string, the tool will be automatically 32 | // determined by looking for tools on the PATH environment 33 | // variable. 34 | "environment_type": "virtualenv", 35 | // the base URL to show a commit for the project.
36 | "show_commit_url": "https://github.com/pmelchior/scarlet2/commit/", 37 | // The Pythons you'd like to test against. If not provided, defaults 38 | // to the current version of Python used to run `asv`. 39 | "pythons": [ 40 | "3.11" 41 | ], 42 | // The matrix of dependencies to test. Each key is the name of a 43 | // package (in PyPI) and the values are version numbers. An empty 44 | // list indicates to just test against the default (latest) 45 | // version. 46 | "matrix": { 47 | "Cython": [], 48 | "build": [], 49 | "packaging": [] 50 | }, 51 | // The directory (relative to the current directory) that benchmarks are 52 | // stored in. If not provided, defaults to "benchmarks". 53 | "benchmark_dir": ".", 54 | // The directory (relative to the current directory) to cache the Python 55 | // environments in. If not provided, defaults to "env". 56 | "env_dir": "env", 57 | // The directory (relative to the current directory) that raw benchmark 58 | // results are stored in. If not provided, defaults to "results". 59 | "results_dir": "_results", 60 | // The directory (relative to the current directory) that the html tree 61 | // should be written to. If not provided, defaults to "html". 62 | "html_dir": "_html", 63 | // The number of characters to retain in the commit hashes. 64 | // "hash_length": 8, 65 | // `asv` will cache wheels of the recent builds in each 66 | // environment, making them faster to install next time. This is 67 | // number of builds to keep, per environment. 68 | "build_cache_size": 8 69 | // The commits after which the regression search in `asv publish` 70 | // should start looking for regressions. Dictionary whose keys are 71 | // regexps matching to benchmark names, and values corresponding to 72 | // the commit (exclusive) after which to start looking for 73 | // regressions. The default is to start from the first commit 74 | // with results. If the commit is `null`, regression detection is 75 | // skipped for the matching benchmark. 
76 | // 77 | // "regressions_first_commits": { 78 | // "some_benchmark": "352cdf", // Consider regressions only after this commit 79 | // "another_benchmark": null, // Skip regression detection altogether 80 | // } 81 | } -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # _scarlet2_ Documentation 2 | 3 | _scarlet2_ is an open-source python library for modeling astronomical sources from multi-band, multi-epoch, and 4 | multi-instrument data. It provides non-parametric and parametric models, can handle source overlap (aka blending), and 5 | can integrate neural network priors. It's designed to be modular, flexible, and powerful. 6 | 7 | _scarlet2_ is implemented in [jax](http://jax.readthedocs.io/), layered on top of 8 | the [equinox](https://docs.kidger.site/equinox/) 9 | library. It can be deployed to GPUs and TPUs and supports optimization and sampling approaches. 10 | 11 | ## Installation 12 | 13 | For performance reasons, you should first install `jax` with the suitable `jaxlib` for your platform. After that 14 | 15 | ``` 16 | pip install scarlet2 17 | ``` 18 | 19 | should do. If you want the latest development version, use 20 | 21 | ``` 22 | pip install git+https://github.com/pmelchior/scarlet2.git 23 | ``` 24 | 25 | This will allow you to evaluate source models and compute likelihoods of observed data, so you can run your own 26 | optimizer/sampler. If you want a fully fledged library out of the box, you need to install `optax`, `numpyro`, and 27 | `h5py` as well. 28 | 29 | ## Usage 30 | 31 | ```{toctree} 32 | :maxdepth: 2 33 | 34 | 0-quickstart 35 | 1-howto 36 | 2-questionnaire 37 | 3-validation 38 | api 39 | ``` 40 | 41 | ## Differences between _scarlet_ and _scarlet2_ 42 | 43 | [_scarlet_](https://pmelchior.github.io/scarlet/) was introduced by 44 | [Melchior et al. 
(2018)](https://doi.org/10.1016/j.ascom.2018.07.001) to solve the deblending problem for the Rubin 45 | Observatory. A stripped-down version of it (developed by Fred Moolekamp) runs as part of the Rubin Observatory software 46 | stack and is used for their data releases. We now call this version _scarlet1_. 47 | 48 | _scarlet2_ follows very similar concepts. So, what's different? 49 | 50 | ### Model specification 51 | 52 | _scarlet1_ is designed for a specific purpose: deblending for the Rubin Observatory. That has implications for the 53 | quality and type of data it needs to work with. **_scarlet2_ is much more flexible and can handle complex source 54 | configurations**, e.g. strong-lensing systems, supernova host galaxies, transient sources, etc. 55 | 56 | This flexibility led us to carefully design a "language" to construct sources. It allows new source combinations and 57 | is more explicit about the parameters and their initialization. 58 | 59 | ### Compute 60 | 61 | Because some of the constraints in _scarlet1_ are expensive to evaluate, they are implemented in C++, which requires 62 | the installation of a lot of additional code just to get it to run. **_scarlet2_ is implemented entirely in jax.** 63 | A combination of `conda` and `pip` will get it installed. Unlike _scarlet1_, it will also run on GPUs and TPUs, and 64 | performs just-in-time compilation of the model evaluation. 65 | 66 | In addition, we can now interface with deep learning methods. In particular, we can employ neural networks as 67 | data-driven priors, which helps break the degeneracies that arise when multiple components need to be 68 | fit at the same time. 69 | 70 | ### Constraints 71 | 72 | _scarlet1_ uses constrained optimization to help with fitting degeneracies, but that requires non-standard 73 | (namely proximal) optimization because these constraints are not differentiable.
74 | That can lead to problems with calibration, but, more importantly, it prevents the use of gradient-based 75 | optimization or sampling. As a result, we could never calculate errors for _scarlet1_ models. 76 | **_scarlet2_ uses only constraints that can be differentiated.** It supports any continuous optimization or sampling 77 | method, including error estimates. 78 | 79 | ## Ideas, Questions or Problems? 80 | 81 | If you have any of those, head over to our [github repo](https://github.com/pmelchior/scarlet2/) and create an issue. 82 | -------------------------------------------------------------------------------- /tests/scarlet2/test_moments.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101 2 | # ruff: noqa: D102 3 | # ruff: noqa: D103 4 | # ruff: noqa: D106 5 | 6 | import copy 7 | 8 | import astropy.units as u 9 | import jax.numpy as jnp 10 | from astropy.wcs import WCS 11 | from numpy.testing import assert_allclose 12 | 13 | from scarlet2.frame import _flip_matrix, _rot_matrix 14 | from scarlet2.measure import Moments 15 | from scarlet2.morphology import GaussianMorphology 16 | 17 | # create a test image and measure moments 18 | t0 = 10 19 | ellipticity = jnp.array((0.3, 0.5)) 20 | morph = GaussianMorphology(size=t0, ellipticity=ellipticity) 21 | img = morph() 22 | g = Moments(component=img, N=2) 23 | 24 | # create trivial WCS for that image 25 | wcs = WCS(naxis=2) 26 | wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"] 27 | wcs.wcs.pc = jnp.diag(jnp.ones(2)) 28 | 29 | 30 | def test_measure_size(): 31 | assert_allclose(t0, g.size, rtol=1e-3) 32 | 33 | 34 | def test_measure_ellipticity(): 35 | assert_allclose(ellipticity, g.ellipticity, rtol=2e-3) 36 | 37 | 38 | def test_gaussian_from_moments(): 39 | # generate Gaussian from moments 40 | t = g.size 41 | ellipticity = g.ellipticity 42 | morph2 = GaussianMorphology(t, ellipticity) 43 | 44 | assert_allclose(img, morph2(), rtol=1e-2) 45 | 46 | 47 | def 
test_convolve_moments(): 48 | g_ = copy.deepcopy(g) 49 | tp = t0 / 2 50 | psf = GaussianMorphology(size=tp, ellipticity=ellipticity)() 51 | psf /= psf.sum() 52 | p = Moments(psf, N=2) 53 | g_.convolve(p) 54 | assert_allclose(g_.size, jnp.sqrt(t0**2 + tp**2), rtol=3e-4) 55 | assert_allclose(g_.ellipticity, ellipticity, rtol=2e-3) 56 | g_.deconvolve(p) 57 | assert_allclose(g_.size, t0, rtol=3e-4) 58 | assert_allclose(g_.ellipticity, ellipticity, rtol=2e-3) 59 | 60 | 61 | def test_rotate_moments(): 62 | # rotate moments counter-clockwise 30 deg 63 | g_ = copy.deepcopy(g) 64 | a = 30 * u.deg 65 | g_.rotate(a) 66 | 67 | # apply theoretical rotation to spin-2 vector 68 | ellipticity_ = ellipticity[0] + 1j * ellipticity[1] 69 | ellipticity_ *= jnp.exp(2j * a.to(u.rad).value) 70 | ellipticity_ = jnp.array((ellipticity_.real, ellipticity_.imag)) 71 | 72 | assert_allclose(g_.ellipticity, ellipticity_, rtol=2e-3) 73 | 74 | 75 | def test_resize_moments(): 76 | # resize the image 77 | c = 0.5 78 | morph2 = GaussianMorphology(size=t0 * c, ellipticity=ellipticity, shape=img.shape) 79 | g2 = Moments(morph2(), 2) 80 | g2.resize(1 / c) 81 | 82 | assert_allclose(g.size, g2.size, rtol=1e-3) 83 | 84 | 85 | def test_flip_moments(): 86 | img2 = jnp.fliplr(img) 87 | g2 = Moments(img2, 2) 88 | g2.fliplr() 89 | assert_allclose(g.ellipticity, g2.ellipticity, rtol=1e-3) 90 | 91 | img3 = jnp.flipud(img) 92 | g3 = Moments(img3, 2) 93 | g3.flipud() 94 | assert_allclose(g.ellipticity, g3.ellipticity, rtol=1e-3) 95 | 96 | 97 | def test_wcs_transfer_moments_rot90(): 98 | # create a 90 deg counter-clockwise version of the morph image 99 | im_ = jnp.rot90(img, axes=(1, 0)) 100 | g_ = Moments(im_) 101 | 102 | # create mock WCS for that image 103 | wcs_ = WCS(naxis=2) 104 | wcs_.wcs.ctype = ["RA---TAN", "DEC--TAN"] 105 | phi = (-90 * u.deg).to(u.rad).value # clockwise to counteract the rotation above 106 | wcs_.wcs.pc = _rot_matrix(phi) 107 | 108 | # match WCS 109 | g_.transfer(wcs_, wcs) 110 | 111 | # 
Check that size and ellipticity are conserved 112 | assert_allclose(g_.size, g.size, rtol=1e-3) 113 | assert_allclose(g_.ellipticity, g.ellipticity, rtol=1e-2) 114 | 115 | 116 | def test_wcs_transfer_moments(): 117 | # create a rotated, resized, flipped version of the morph image 118 | # apply theoretical rotation to spin-2 vector 119 | a = (30 * u.deg).to(u.rad).value 120 | ellipticity_ = ellipticity[0] + 1j * ellipticity[1] 121 | ellipticity_ *= jnp.exp(2j * a) 122 | ellipticity_ = jnp.array((ellipticity_.real, ellipticity_.imag)) 123 | c = 0.5 124 | morph2 = GaussianMorphology(size=t0 * c, ellipticity=ellipticity_, shape=img.shape) 125 | # note order: rescale+rotate, then flip. 126 | im_ = jnp.flipud(morph2()) 127 | g_ = Moments(im_) 128 | 129 | # create mock WCS for that image 130 | wcs_ = WCS(naxis=2) 131 | wcs_.wcs.ctype = ["RA---TAN", "DEC--TAN"] 132 | # because of the image creation above: altered order (rotation, then flip) 133 | # this is unusual because for standard WCS, we correct handedness first, then rotate 134 | wcs_.wcs.pc = 1 / c * (_flip_matrix(-1) @ _rot_matrix(-a)) 135 | 136 | # match WCS 137 | g_.transfer(wcs_, wcs) 138 | 139 | # Check that size and ellipticity are conserved 140 | assert_allclose(g_.size, g.size, rtol=1e-3) 141 | assert_allclose(g_.ellipticity, g.ellipticity, rtol=1e-2) 142 | -------------------------------------------------------------------------------- /src/scarlet2/validation.py: -------------------------------------------------------------------------------- 1 | from .module import ParameterValidator 2 | from .observation import ObservationValidator 3 | from .scene import FitValidator 4 | from .source import SourceValidator 5 | from .validation_utils import ValidationResult 6 | 7 | 8 | def _check(validation_class, **kwargs) -> list[ValidationResult]: 9 | """Check the object against the validation rules defined in the validation_class. 
10 | 11 | Parameters 12 | ---------- 13 | validation_class : type 14 | The class containing the validation checks. 15 | **kwargs : dict 16 | Keyword arguments to pass to the validation class constructor. These should be 17 | the inputs required by the validation classes, such as `scene`, `observation`, 18 | or `source`. 19 | 20 | Returns 21 | ------- 22 | list[ValidationResult] 23 | A list of validation results returned from the validation checks for the 24 | given object. 25 | """ 26 | validator = validation_class(**kwargs) 27 | validation_results = [] 28 | for check in validator.validation_checks: 29 | if error := getattr(validator, check)(): 30 | if isinstance(error, list): 31 | validation_results.extend(error) 32 | else: 33 | validation_results.append(error) 34 | 35 | return validation_results 36 | 37 | 38 | def check_fit(scene, observation) -> list[ValidationResult]: 39 | """Check the scene after fitting against the various validation rules. 40 | 41 | Parameters 42 | ---------- 43 | scene : Scene 44 | The scene object to check. 45 | observation : Observation 46 | The observation object to use for checks. 47 | 48 | Returns 49 | ------- 50 | list[ValidationResult] 51 | A list of validation results returned from the validation checks for the 52 | scene fit results. 53 | """ 54 | 55 | return _check(validation_class=FitValidator, **{"scene": scene, "observation": observation}) 56 | 57 | 58 | def check_observation(observation) -> list[ValidationResult]: 59 | """Check the observation object for consistency 60 | 61 | Parameters 62 | ---------- 63 | observation: Observation 64 | The observation object to check. 65 | 66 | Returns 67 | ------- 68 | list[ValidationResult] 69 | A list of validation results from the validation check of the observation 70 | object. 
71 | """ 72 | 73 | return _check(validation_class=ObservationValidator, **{"observation": observation}) 74 | 75 | 76 | def check_scene(scene) -> list[ValidationResult]: 77 | """Check the scene against the various validation rules. 78 | 79 | Parameters 80 | ---------- 81 | scene : Scene 82 | The scene object to check. 83 | 84 | Returns 85 | ------- 86 | list[ValidationResult] 87 | A list of validation results from the validation checks of the scene. 88 | """ 89 | 90 | validation_results = [] 91 | for source in scene.sources: 92 | validation_results.extend(check_source(source)) 93 | 94 | return validation_results 95 | 96 | 97 | def check_source(source) -> list[ValidationResult]: 98 | """Check the source against the various validation rules. 99 | 100 | Parameters 101 | ---------- 102 | source : Source 103 | The source object to check. 104 | scene : Scene 105 | The scene that the source is part of. 106 | 107 | Returns 108 | ------- 109 | list[ValidationResult] 110 | A list of validation results from the source object checks. 111 | """ 112 | 113 | return _check(validation_class=SourceValidator, **{"source": source}) 114 | 115 | 116 | def check_parameters(parameters) -> list[ValidationResult]: 117 | """Check the parameter list against the various validation rules. 118 | 119 | Parameters 120 | ---------- 121 | parameters : Parameters 122 | The parameters list to check 123 | 124 | Returns 125 | ------- 126 | list[ValidationResult] 127 | A list of validation results from the validation checks of the parameters. 128 | """ 129 | 130 | validation_results = [] 131 | for p in parameters: 132 | validation_results.extend(check_parameter(p)) 133 | 134 | return validation_results 135 | 136 | 137 | def check_parameter(parameter) -> list[ValidationResult]: 138 | """Check the parameter against the various validation rules. 
139 | 140 | Parameters 141 | ---------- 142 | parameter : Parameter 143 | The parameter to check 144 | 145 | Returns 146 | ------- 147 | list[ValidationResult] 148 | A list of validation results from the validation checks of the parameters. 149 | """ 150 | 151 | return _check(validation_class=ParameterValidator, **{"parameter": parameter}) 152 | -------------------------------------------------------------------------------- /tests/scarlet2/conftest.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import pytest 6 | from huggingface_hub import hf_hub_download 7 | from numpyro.distributions import constraints 8 | 9 | from scarlet2 import init 10 | from scarlet2.frame import Frame 11 | from scarlet2.module import Parameter, Parameters, relative_step 12 | from scarlet2.observation import Observation 13 | from scarlet2.psf import ArrayPSF 14 | from scarlet2.scene import Scene 15 | from scarlet2.source import PointSource, Source 16 | 17 | 18 | @pytest.fixture() 19 | def data_file(): 20 | """Download and load a realistic test file. This is the same data used in the 21 | quickstart notebook. 
The data will be manipulated to create invalid inputs for 22 | the `bad_obs` fixture.""" 23 | filename = hf_hub_download( 24 | repo_id="astro-data-lab/scarlet-test-data", filename="hsc_cosmos_35.npz", repo_type="dataset" 25 | ) 26 | return jnp.load(filename) 27 | 28 | 29 | @pytest.fixture() 30 | def good_obs(data_file): 31 | """Create an observation that should pass all validation checks.""" 32 | data = jnp.asarray(data_file["images"]) 33 | channels = [str(f) for f in data_file["filters"]] 34 | weights = jnp.asarray(1 / data_file["variance"]) 35 | psf = jnp.asarray(data_file["psfs"]) 36 | 37 | return Observation( 38 | data=data, 39 | weights=weights, 40 | channels=channels, 41 | psf=ArrayPSF(psf), 42 | ) 43 | 44 | 45 | @pytest.fixture() 46 | def bad_obs(data_file): 47 | """Create an observation that should fail multiple validation checks.""" 48 | 49 | data = np.asarray(data_file["images"]) 50 | channels = [str(f) for f in data_file["filters"]] 51 | weights = np.asarray(1 / data_file["variance"]) 52 | psf = np.asarray(data_file["psfs"]) 53 | 54 | weights = weights[:-1] # Remove the last weight to create a mismatch in dimensions 55 | weights[0][0] = np.inf # Set one weight to infinity 56 | weights[1][0] = -1.0 # Set one weight to a negative value 57 | psf = psf[:-1] # Remove the last PSF to create a mismatch in dimensions 58 | psf = psf[0] + 0.001 59 | 60 | return Observation( 61 | data=data, 62 | weights=weights, 63 | channels=channels, 64 | psf=ArrayPSF(psf), 65 | ) 66 | 67 | 68 | @pytest.fixture() 69 | def good_source(good_obs, data_file): 70 | """Assemble a source from the good observation and the data file.""" 71 | model_frame = Frame.from_observations(good_obs) 72 | with Scene(model_frame) as _: 73 | centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]]) # Note: y/x convention! 
74 | spectrum, morph = init.from_gaussian_moments(good_obs, centers[0], min_corr=0.99) 75 | return Source(centers[0], spectrum, morph) 76 | 77 | 78 | @pytest.fixture() 79 | def bad_source(good_obs, data_file): 80 | """Assemble a source from the bad observation and the data file.""" 81 | model_frame = Frame.from_observations(good_obs) 82 | with Scene(model_frame) as _: 83 | centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]]) # Note: y/x convention! 84 | spectrum, morph = init.from_gaussian_moments(good_obs, centers[0], min_corr=0.99) 85 | return Source(centers[0], spectrum * -1, morph) 86 | 87 | 88 | @pytest.fixture() 89 | def scene(good_obs, data_file): 90 | """Assemble a scene from the good observation and the data file.""" 91 | model_frame = Frame.from_observations(good_obs) 92 | centers = jnp.array([(src["y"], src["x"]) for src in data_file["catalog"]]) # Note: y/x convention! 93 | 94 | with Scene(model_frame) as scene: 95 | for i, center in enumerate(centers): 96 | if i == 0: # we know source 0 is a star 97 | spectrum = init.pixel_spectrum(good_obs, center, correct_psf=True) 98 | PointSource(center, spectrum) 99 | else: 100 | try: 101 | spectrum, morph = init.from_gaussian_moments(good_obs, center, min_corr=0.99) 102 | except ValueError: 103 | spectrum = init.pixel_spectrum(good_obs, center) 104 | morph = init.compact_morphology() 105 | 106 | Source(center, spectrum, morph) 107 | 108 | return scene 109 | 110 | 111 | @pytest.fixture() 112 | def parameters(scene, good_obs): 113 | """Create parameters for the scene.""" 114 | spec_step = partial(relative_step, factor=0.05) 115 | morph_step = partial(relative_step, factor=1e-3) 116 | 117 | with Parameters(scene) as parameters: 118 | for i in range(len(scene.sources)): 119 | Parameter( 120 | scene.sources[i].spectrum, 121 | name=f"spectrum:{i}", 122 | constraint=constraints.positive, 123 | stepsize=spec_step, 124 | ) 125 | if i == 0: 126 | Parameter(scene.sources[i].center, name=f"center:{i}", 
stepsize=0.1) 127 | else: 128 | Parameter( 129 | scene.sources[i].morphology, 130 | name=f"morph:{i}", 131 | constraint=constraints.unit_interval, 132 | stepsize=morph_step, 133 | ) 134 | 135 | return parameters 136 | -------------------------------------------------------------------------------- /src/scarlet2/nn.py: -------------------------------------------------------------------------------- 1 | # To be removed as part of issue #168 2 | # ruff: noqa: D101 3 | # ruff: noqa: D102 4 | # ruff: noqa: D103 5 | # ruff: noqa: D106 6 | 7 | """Neural network priors""" 8 | 9 | from functools import partial 10 | 11 | try: 12 | import numpyro.distributions as dist 13 | import numpyro.distributions.constraints as constraints 14 | 15 | except ImportError as err: 16 | raise ImportError("scarlet2.nn requires numpyro.") from err 17 | 18 | import equinox as eqx 19 | import jax.numpy as jnp 20 | from jax import custom_vjp, vjp 21 | 22 | 23 | def pad_fwd(x, model_shape): 24 | """Zero-pads the input image to the model size 25 | 26 | Parameters 27 | ---------- 28 | x : jnp.array 29 | data to be padded 30 | model_shape : tuple 31 | shape of the prior model to be used 32 | 33 | Returns 34 | ------- 35 | x : jnp.array 36 | data padded to same size as model_shape 37 | pad: tuple 38 | padding amount in every dimension 39 | """ 40 | assert all(model_shape[d] >= x.shape[d] for d in range(x.ndim)), ( 41 | "Model size must be larger than data size" 42 | ) 43 | if model_shape == x.shape: 44 | pad = 0 45 | return x, pad 46 | 47 | pad = tuple( 48 | # even padding 49 | (int(gap / 2), int(gap / 2)) 50 | if (gap := model_shape[d] - x.shape[d]) % 2 == 0 51 | # uneven padding 52 | else (int(gap // 2), int(gap // 2) + 1) 53 | # over all dimensions 54 | for d in range(x.ndim) 55 | ) 56 | # perform the zero-padding 57 | x = jnp.pad(x, pad, "constant", constant_values=0) 58 | return x, pad 59 | 60 | 61 | # reverse pad back to original size 62 | def pad_back(x, pad): 63 | """Removes the zero-padding 
from the input image 64 | 65 | Parameters 66 | --------- 67 | x : jnp.array 68 | padded data to same size as model_shape 69 | pad: tuple 70 | padding amount in every dimension 71 | 72 | Returns 73 | ------- 74 | x : jnp.array 75 | data returned to it pre-pad shape 76 | """ 77 | slices = tuple(slice(low, -hi) if hi > 0 else slice(low, None) for (low, hi) in pad) 78 | return x[slices] 79 | 80 | 81 | # calculate score function (jacobian of log-probability) 82 | def calc_grad(x, model): 83 | """Calculates the gradient of the log-prior 84 | using the ScoreNet model chosen 85 | 86 | Parameters 87 | ---------- 88 | x: 89 | array of the data 90 | model: 91 | the model to calculate the score function 92 | 93 | Returns 94 | ------- 95 | score_func : array of the score function 96 | """ 97 | # cast to float32, expand to (batch, shape), and pad to match the shape of the score model 98 | x_, pad = pad_fwd(jnp.float32(x), model.shape) 99 | 100 | # run score model, expects (batch, shape) 101 | if jnp.ndim(x) == len(model.shape): 102 | x_ = jnp.expand_dims(x_, axis=0) 103 | score_func = model.func(x_) 104 | if jnp.ndim(x) == len(model.shape): 105 | score_func = jnp.squeeze(score_func, axis=0) 106 | 107 | # remove padding 108 | if pad != 0: 109 | score_func = pad_back(score_func, pad) 110 | return score_func 111 | 112 | 113 | # jax gradient function to calculate jacobian 114 | def vgrad(f, x): 115 | y, vjp_fn = vjp(f, x) 116 | return vjp_fn(jnp.ones(y.shape))[0] 117 | 118 | 119 | # Here we define a custom vjp for the log_prob function 120 | # such that for gradient calls in jax, the score prior 121 | # is returned 122 | 123 | 124 | @partial(custom_vjp, nondiff_argnums=(0,)) 125 | def _log_prob(model, x): 126 | return 0.0 127 | 128 | 129 | def _log_prob_fwd(model, x): 130 | score_func = calc_grad(x, model) 131 | return 0.0, score_func # cannot directly call log_prob in Class object 132 | 133 | 134 | def _log_prob_bwd(model, res, g): 135 | score_func = res # Get residuals computed in 
f_fwd 136 | return (g * score_func,) # create the vector (g) jacobian (score_func) product 137 | 138 | 139 | # register the custom vjp 140 | _log_prob.defvjp(_log_prob_fwd, _log_prob_bwd) 141 | 142 | 143 | class ScorePrior(dist.Distribution): 144 | class ScoreWrapper(eqx.Module): 145 | func: callable 146 | shape: tuple 147 | 148 | support = constraints.real_vector 149 | _model = ScoreWrapper 150 | 151 | def __init__(self, model, shape, *args, **kwargs): 152 | """Score-matching neural network to represent the prior distribution 153 | 154 | This class is used to calculate the gradient of the log-probability of the prior distribution. 155 | A custom vjp is created to return the score when calling `jax.grad()`. 156 | 157 | Parameters 158 | ---------- 159 | model: callable 160 | Returns the score value given parameter: `model(x) -> score` 161 | shape: tuple 162 | Shape of the parameter the model can accept 163 | *args: tuple 164 | List of unnamed parameter for model, e.g. `model(x, *args) -> score` 165 | **kwargs: dict 166 | List of named parameter for model, e.g. 
`model(x, **kwargs) -> score` 167 | """ 168 | # helper class that ensures the model function binds the args/kwargs and has a shape 169 | wrapper = ScorePrior.ScoreWrapper(partial(model.__call__, *args, **kwargs), shape) 170 | self._model = wrapper 171 | 172 | super().__init__( 173 | validate_args=None, 174 | ) 175 | 176 | def __call__(self, x): 177 | return self._model.func(x) 178 | 179 | def sample(self, key, sample_shape=()): 180 | # TODO: add ability to draw samples from the prior, if desired 181 | raise NotImplementedError 182 | 183 | def mean(self): 184 | raise NotImplementedError 185 | 186 | def log_prob(self, x): 187 | return _log_prob(self._model, x) 188 | -------------------------------------------------------------------------------- /src/scarlet2/spectrum.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax.numpy as jnp 3 | 4 | from . import Scenery 5 | from .module import Module 6 | 7 | 8 | class Spectrum(Module): 9 | """Spectrum base class""" 10 | 11 | @property 12 | def shape(self): 13 | """Shape (1D) of the spectrum model""" 14 | raise NotImplementedError 15 | 16 | 17 | class StaticArraySpectrum(Spectrum): 18 | """Static (non-variable) source in a transient scene 19 | 20 | In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels` 21 | are overloaded and defined with a spectral and a temporal component, e.g. 22 | `channel = (band, epoch)`. 23 | This class is for models that do not vary in time, i.e. only have a spectral dependency. 24 | The length of :py:attr:`data` is thus given by the number of distinct spectral bands. 25 | """ 26 | 27 | data: jnp.array 28 | """Data to describe the static spectrum 29 | 30 | The order in this array should be given by :py:attr:`bands`. 
31 | """ 32 | bands: list 33 | """Identifier for the list of unique bands in the model frame channels""" 34 | _channelindex: jnp.array = eqx.field(repr=False) 35 | 36 | def __init__(self, data, bands, band_selector=lambda channel: channel[0]): 37 | """ 38 | Parameters 39 | ---------- 40 | data: array 41 | Spectrum without temporal variation. Contains as many elements as there 42 | are spectral channels in the model. 43 | bands: list, array 44 | Identifier for the list of unique bands in the model frame channels 45 | band_selector: callable, optional 46 | Identify the spectral "band" component from the name/ID used in the 47 | channels of the model frame 48 | 49 | Examples 50 | -------- 51 | >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)] 52 | >>> spectrum = jnp.ones(2) 53 | >>> bands = ['G','R'] 54 | >>> band_selector = lambda channel: channel[0] 55 | >>> StaticArraySpectrum(spectrum, bands, band_selector=band_selector) 56 | 57 | This constructs a 2-element spectrum to describe the spectral properties 58 | in all epochs 0,1,2. 
59 | 60 | See Also 61 | -------- 62 | TransientArraySpectrum 63 | """ 64 | try: 65 | frame = Scenery.scene.frame 66 | except AttributeError: 67 | print("Source can only be created within the context of a Scene") 68 | print("Use 'with Scene(frame) as scene: Source(...)'") 69 | raise 70 | 71 | self.data = data 72 | self.bands = bands 73 | self._channelindex = jnp.array([self.bands.index(band_selector(c)) for c in frame.channels]) 74 | 75 | def __call__(self): 76 | """What to run when the StaticArraySpectrum is called""" 77 | return self.data[self._channelindex] 78 | 79 | @property 80 | def shape(self): 81 | """The shape of the spectrum data""" 82 | return (len(self._channelindex),) 83 | 84 | 85 | class TransientArraySpectrum(Spectrum): 86 | """Variable source in a transient scene with possible quiescent periods 87 | 88 | In the frames of transient scenes, the attribute :py:attr:`~scarlet2.Frame.channels` 89 | are overloaded and defined with a spectral and a temporal component, e.g. 90 | `channel = (band, epoch)`. This class is for models that vary in time, especially 91 | if they have periods of inactivity. The length of :py:attr:`data` is given by 92 | the number channels in the model frame, but during inactive epochs, the emission 93 | is set to zero. 94 | """ 95 | 96 | data: jnp.array 97 | """Data to describe the variable spectrum. 98 | 99 | The length of this vector is identical to the number of channels in the model frame. 100 | """ 101 | epochs: list 102 | """Identifier for the list of active epochs. If set to `None`, all epochs are 103 | considered active""" 104 | _epochmultiplier: jnp.array = eqx.field(repr=False) 105 | 106 | def __init__(self, data, epochs=None, epoch_selector=lambda channel: channel[1]): 107 | """ 108 | Parameters 109 | ---------- 110 | data: array 111 | Spectrum array. Contains as many elements as there are spectro-temporal 112 | channels in the model. 
113 | epochs: list, array, optional 114 | List of temporal "epoch" identifiers for the active phases of the source. 115 | epoch_selector: callable, optional 116 | Identify the temporal "epoch" component from the name/ID used in the 117 | channels of the model frame 118 | 119 | Examples 120 | -------- 121 | >>> # model channels: [('G',0),('G',1),('R',0),('R',1),('R',2)] 122 | >>> spectrum = jnp.ones(5) 123 | >>> epochs = [0, 1] 124 | >>> epoch_selector = lambda channel: channel[1] 125 | >>> TransientArraySpectrum(spectrum, epochs, epoch_selector=epoch_selector) 126 | 127 | This sets the spectrum to active during epochs 0 and 1, and mask the 128 | spectrum element for `('R',2)` with zero. 129 | 130 | See Also 131 | -------- 132 | StaticArraySpectrum 133 | """ 134 | try: 135 | frame = Scenery.scene.frame 136 | except AttributeError: 137 | print("Source can only be created within the context of a Scene") 138 | print("Use 'with Scene(frame) as scene: Source(...)'") 139 | raise 140 | self.data = data 141 | self.epochs = epochs 142 | self._epochmultiplier = jnp.array( 143 | [1.0 if epoch_selector(c) in epochs else 0.0 for c in frame.channels] 144 | ) 145 | 146 | def __call__(self): 147 | """What to run when the TransientArraySpectrum is called""" 148 | return jnp.multiply(self.data, self._epochmultiplier) 149 | 150 | @property 151 | def shape(self): 152 | """The shape of the spectrum data""" 153 | return self.data.shape 154 | -------------------------------------------------------------------------------- /tests/scarlet2/test_observation_checks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from scarlet2.observation import Observation, ObservationValidator 4 | from scarlet2.psf import ArrayPSF 5 | from scarlet2.validation_utils import ValidationError, ValidationInfo, set_validation 6 | 7 | 8 | @pytest.fixture(autouse=True) 9 | def setup_validation(): 10 | """Automatically disable validation 
for all tests. This permits the creation 11 | of intentionally invalid Observation objects.""" 12 | set_validation(False) 13 | 14 | 15 | def test_weights_non_negative_returns_error(bad_obs): 16 | """Test that the weights in the observation are non-negative.""" 17 | checker = ObservationValidator(bad_obs) 18 | 19 | results = checker.check_weights_non_negative() 20 | 21 | assert isinstance(results, ValidationError) 22 | 23 | 24 | def test_weights_finite_returns_error(bad_obs): 25 | """Test that the weights in the observation are finite.""" 26 | checker = ObservationValidator(bad_obs) 27 | 28 | results = checker.check_weights_finite() 29 | 30 | assert isinstance(results, ValidationError) 31 | 32 | 33 | def test_weights_non_negative_returns_none(good_obs): 34 | """Test that the weights in the observation are non-negative.""" 35 | checker = ObservationValidator(good_obs) 36 | 37 | results = checker.check_weights_non_negative() 38 | 39 | assert isinstance(results, ValidationInfo) 40 | 41 | 42 | def test_weights_finite_returns_none(good_obs): 43 | """Test that the weights in the observation are finite.""" 44 | checker = ObservationValidator(good_obs) 45 | 46 | results = checker.check_weights_finite() 47 | 48 | assert isinstance(results, ValidationInfo) 49 | 50 | 51 | def test_data_and_weights_shape_returns_error(bad_obs): 52 | """Test that the data and weights have the same shape.""" 53 | checker = ObservationValidator(bad_obs) 54 | 55 | results = checker.check_data_and_weights_shape() 56 | 57 | assert isinstance(results, ValidationError) 58 | assert results.message == "Data and weights must have the same shape." 
59 | 60 | 61 | def test_data_and_weights_shape_returns_none(good_obs): 62 | """Test that the data and weights have the same shape.""" 63 | checker = ObservationValidator(good_obs) 64 | 65 | results = checker.check_data_and_weights_shape() 66 | 67 | assert isinstance(results, ValidationInfo) 68 | 69 | 70 | def test_num_channels_matches_data_returns_none(good_obs): 71 | """Test that the number of channels in the observation matches the data.""" 72 | checker = ObservationValidator(good_obs) 73 | 74 | results = checker.check_num_channels_matches_data() 75 | 76 | assert isinstance(results, ValidationInfo) 77 | 78 | 79 | def test_data_finite_for_non_zero_weights_returns_none(good_obs): 80 | """Test that the data in the observation is finite where weights are greater than zero.""" 81 | checker = ObservationValidator(good_obs) 82 | 83 | results = checker.check_data_finite_for_non_zero_weights() 84 | 85 | assert isinstance(results, ValidationInfo) 86 | 87 | 88 | def test_data_finite_for_non_zero_weights_returns_none_with_infinity(): 89 | """Test that non-finite data does not raise an error when weights are zero.""" 90 | obs = Observation( 91 | data=np.array([[np.inf, 2.0, 3.0], [4.0, 5.0, 6.0]]), 92 | weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 93 | channels=[0, 1], 94 | ) 95 | 96 | checker = ObservationValidator(obs) 97 | results = checker.check_data_finite_for_non_zero_weights() 98 | 99 | assert isinstance(results, ValidationInfo) 100 | 101 | 102 | def test_data_finite_for_non_zero_weights_returns_error_with_infinity(): 103 | """Test that non-finite data raises an error when weights are non-zero.""" 104 | obs = Observation( 105 | data=np.array([[np.inf, np.inf, 3.0], [4.0, 5.0, 6.0]]), 106 | weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 107 | channels=[0, 1], 108 | ) 109 | 110 | checker = ObservationValidator(obs) 111 | results = checker.check_data_finite_for_non_zero_weights() 112 | 113 | assert isinstance(results, ValidationError) 114 | assert 
results.message == "Data in the observation must be finite." 115 | 116 | 117 | def test_number_of_psf_channels(good_obs): 118 | """Test that the number of PSF channels matches the observation channels.""" 119 | checker = ObservationValidator(good_obs) 120 | 121 | results = checker.check_number_of_psf_channels() 122 | assert isinstance(results, ValidationInfo) 123 | 124 | 125 | def test_number_of_psf_channels_returns_error(bad_obs): 126 | """Test that the number of PSF channels does not match the observation channels.""" 127 | checker = ObservationValidator(bad_obs) 128 | 129 | results = checker.check_psf_has_3_dimensions() 130 | 131 | assert isinstance(results, ValidationError) 132 | assert results.message == "PSF must be 3-dimensional." 133 | 134 | 135 | def test_validation_on_runs_observation_checks(capsys): 136 | """Set auto-validation to True, expect that the non-finite data raises an error 137 | when weights are non-zero.""" 138 | set_validation(True) 139 | _ = Observation( 140 | data=np.array([[np.inf, np.inf, 3.0], [4.0, 5.0, 6.0]]), 141 | weights=np.array([[0.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 142 | channels=[0, 1], 143 | ) 144 | captured = capsys.readouterr() 145 | assert "Observation validation results:" in captured.out 146 | 147 | 148 | def test_psf_centroid_inconsistent_returns_error(data_file): 149 | """Test that when the PSF centroids are not consistent across bands, an error 150 | is raised.""" 151 | 152 | data = np.asarray(data_file["images"]) 153 | channels = [str(f) for f in data_file["filters"]] 154 | weights = np.asarray(1 / data_file["variance"]) 155 | psf = np.asarray(data_file["psfs"]) 156 | psf[0][1:, :] = psf[0][:-1, :] # Shift the PSF to create an inconsistency 157 | 158 | bad_obs = Observation( 159 | data=data, 160 | weights=weights, 161 | channels=channels, 162 | psf=ArrayPSF(psf), 163 | ) 164 | 165 | checker = ObservationValidator(bad_obs) 166 | results = checker.check_psf_centroid_consistent() 167 | 168 | assert isinstance(results, 
ValidationError) 169 | assert results.message == "PSF centroid is not the same in all channels." 170 | 171 | 172 | def test_psf_centroid_consistent_no_error(data_file): 173 | """Test that when the PSF centroids are consistent across bands, no errors 174 | are raised.""" 175 | 176 | data = np.asarray(data_file["images"]) 177 | channels = [str(f) for f in data_file["filters"]] 178 | weights = np.asarray(1 / data_file["variance"]) 179 | psf = np.zeros_like(data) 180 | psf[:, 1, 1] = 1 181 | 182 | bad_obs = Observation( 183 | data=data, 184 | weights=weights, 185 | channels=channels, 186 | psf=ArrayPSF(psf), 187 | ) 188 | 189 | checker = ObservationValidator(bad_obs) 190 | 191 | results = checker.check_psf_centroid_consistent() 192 | assert isinstance(results, ValidationInfo) 193 | -------------------------------------------------------------------------------- /tests/scarlet2/test_renderer.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D101 2 | # ruff: noqa: D102 3 | # ruff: noqa: D103 4 | # ruff: noqa: D106 5 | 6 | import astropy.units as u 7 | import jax.numpy as jnp 8 | from numpy.testing import assert_allclose 9 | 10 | import scarlet2 11 | from scarlet2 import ArrayPSF 12 | from scarlet2.frame import _flip_matrix, _rot_matrix, _wcs_default, get_affine 13 | from scarlet2.measure import Moments 14 | from scarlet2.morphology import GaussianMorphology 15 | 16 | # resampling renderer to test 17 | cls = scarlet2.renderer.ResamplingRenderer 18 | 19 | # create a Gaussian as model 20 | T = 10 21 | eps = jnp.array((0.5, 0.3)) 22 | model = GaussianMorphology(size=T, ellipticity=eps, shape=(150, 151))()[None, ...] 
# test even & odd 23 | g = scarlet2.measure.Moments(component=model[0], N=2) 24 | 25 | # make a Frame 26 | model_frame = scarlet2.Frame(scarlet2.Box(model.shape)) 27 | 28 | 29 | def test_rescale(): 30 | # how does this model look if we change the WCS scale 31 | # if scale becomes larger, the image gets smaller 32 | scale = 3.1 33 | shape = (int(model.shape[-2] / scale) + 1, int(model.shape[-1] / scale)) 34 | wcs_obs = _wcs_default(shape) 35 | m = scarlet2.frame.get_affine(wcs_obs) 36 | m = scale * m 37 | wcs_obs.wcs.pc = m 38 | 39 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs) 40 | renderer = cls(model_frame, obs_frame) 41 | model_ = renderer(model) 42 | 43 | # for testing outputs... 44 | # print(renderer) 45 | # fig, ax = plt.subplots(1, 2, figsize=(8, 4)) 46 | # ax[0].imshow(model[0]) 47 | # ax[1].imshow(model_[0]) 48 | 49 | # undo resizing 50 | g_obs = Moments(model_[0]) 51 | g_obs.resize(scale) 52 | assert_allclose(g_obs.flux, g.flux, atol=3e-3) 53 | shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(obs_frame.bbox.spatial.center) 54 | assert_allclose(shift_, (0, 0), atol=1e-4) 55 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 56 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 57 | 58 | 59 | def test_rotate(): 60 | # how does this model look if we rotated the WCS scale 61 | # because we change the frame, this will appear as a rotation in the opposite direction 62 | shape = model.shape 63 | phi = (30 * u.deg).to(u.rad).value 64 | wcs_obs = _wcs_default(shape) 65 | m = scarlet2.frame.get_affine(wcs_obs) 66 | r = _rot_matrix(phi) 67 | m = r @ m 68 | wcs_obs.wcs.pc = m 69 | 70 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs) 71 | renderer = cls(model_frame, obs_frame) 72 | model_ = renderer(model) 73 | 74 | # rotate to correct for the counter-rotation of the frame 75 | g_obs = Moments(model_[0]) 76 | g_obs.rotate(phi) 77 | assert_allclose(g_obs.flux, g.flux, atol=3e-3) 78 | assert_allclose(g_obs.centroid, g.centroid, 
atol=1e-4) 79 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 80 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 81 | 82 | 83 | def test_flip(): 84 | shape = model.shape 85 | wcs_obs = _wcs_default(shape) 86 | m = scarlet2.frame.get_affine(wcs_obs) 87 | f = _flip_matrix(-1) 88 | m = f @ m 89 | wcs_obs.wcs.pc = m 90 | 91 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs) 92 | renderer = cls(model_frame, obs_frame) 93 | model_ = renderer(model) 94 | 95 | # undo the flip 96 | g_obs = Moments(model_[0]) 97 | g_obs.flipud() 98 | assert_allclose(g_obs.flux, g.flux, atol=3e-3) 99 | assert_allclose(g_obs.centroid, g.centroid, atol=1e-4) 100 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 101 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 102 | 103 | 104 | def test_translation(): 105 | shift = jnp.array((12, -1.9)) 106 | shape = model.shape 107 | wcs_obs = _wcs_default(shape) 108 | wcs_obs.wcs.crpix += shift[::-1] # x/y 109 | 110 | shape = model.shape 111 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), wcs=wcs_obs) 112 | renderer = cls(model_frame, obs_frame) 113 | model_ = renderer(model) 114 | 115 | g_obs = Moments(model_[0]) 116 | shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(g.centroid) 117 | 118 | assert_allclose(g_obs.flux, g.flux, rtol=2e-3) 119 | assert_allclose(shift_, -shift, atol=1e-4) 120 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 121 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 122 | 123 | 124 | def test_convolution(): 125 | # create model PSF and convolve g 126 | t_p = 0.7 127 | eps_p = None 128 | psf = GaussianMorphology(size=t_p, ellipticity=eps_p)() 129 | psf /= psf.sum() 130 | p = Moments(psf, N=2) 131 | psf = ArrayPSF(psf[None, ...]) 132 | model_frame_ = scarlet2.Frame(scarlet2.Box(model.shape), psf=psf) 133 | 134 | # create obs PSF 135 | t_obs, eps_obs = 5, jnp.array((-0.1, 0.1)) 136 | psf_obs = GaussianMorphology(size=t_obs, ellipticity=eps_obs)() 137 | psf_obs /= psf_obs.sum() 
138 | p_obs = Moments(psf_obs, N=2) 139 | psf_obs = ArrayPSF(psf_obs[None, ...]) 140 | 141 | # render: deconvolve, reconcolve 142 | shape = model.shape 143 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), psf=psf_obs) 144 | renderer = cls(model_frame_, obs_frame) 145 | model_ = renderer(model) 146 | g_obs = Moments(model_[0]) 147 | 148 | g_obs.deconvolve(p_obs).convolve(p) 149 | assert_allclose(g_obs.flux, g.flux, atol=3e-3) 150 | assert_allclose(g_obs.centroid, g.centroid, atol=1e-4) 151 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 152 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 153 | 154 | 155 | def test_all(): 156 | scale = 2.1 157 | phi = (70 * u.deg).to(u.rad).value 158 | shift = jnp.array((1.4, -0.456)) 159 | shape = (int(model.shape[1] // scale), int(model.shape[2] // scale)) 160 | wcs_obs = _wcs_default(shape) 161 | m = get_affine(wcs_obs) 162 | wcs_obs.wcs.pc = scale * _rot_matrix(phi) @ _flip_matrix(-1) @ m 163 | wcs_obs.wcs.crpix += shift[::-1] # x/y 164 | 165 | # create model PSF and convolve g 166 | t_p = 1 167 | eps_p = None 168 | psf = GaussianMorphology(size=t_p, ellipticity=eps_p)() 169 | psf /= psf.sum() 170 | p = Moments(psf, N=2) 171 | psf = ArrayPSF(psf[None, ...]) 172 | model_frame_ = scarlet2.Frame(scarlet2.Box(model.shape), psf=psf) 173 | 174 | # obs PSF 175 | t_obs, eps_obs = 3, jnp.array((0.1, -0.1)) 176 | psf_obs = GaussianMorphology(size=t_obs, ellipticity=eps_obs)() 177 | psf_obs /= psf_obs.sum() 178 | p_obs = Moments(psf_obs, N=2) 179 | psf_obs = ArrayPSF(psf_obs[None, ...]) 180 | 181 | obs_frame = scarlet2.Frame(scarlet2.Box(shape), psf=psf_obs, wcs=wcs_obs) 182 | renderer = cls(model_frame_, obs_frame) 183 | model_ = renderer(model) 184 | 185 | g_obs = Moments(model_[0]) 186 | g_obs.deconvolve(p_obs).flipud().rotate(phi).resize(scale).convolve(p) 187 | shift_ = jnp.asarray(g_obs.centroid) - jnp.asarray(obs_frame.bbox.spatial.center) 188 | 189 | assert_allclose(g_obs.flux, g.flux, rtol=2e-3) 190 | 
assert_allclose(shift_, -shift, atol=1e-4) 191 | assert_allclose(g_obs.size, g.size, rtol=3e-5) 192 | assert_allclose(g_obs.ellipticity, g.ellipticity, atol=3e-5) 193 | -------------------------------------------------------------------------------- /src/scarlet2/wavelets.py: -------------------------------------------------------------------------------- 1 | """Wavelet functions""" 2 | # from https://github.com/pmelchior/scarlet/blob/master/scarlet/wavelet.py 3 | 4 | import jax.numpy as jnp 5 | 6 | 7 | class Starlet: 8 | """Wavelet transform of a images (2D or 3D) with the 'a trou' algorithm. 9 | 10 | The transform is performed by convolving the image by a seed starlet: the transform of an all-zero 11 | image with its central pixel set to one. This requires 2-fold padding of the image and an odd pad 12 | shape. The fft of the seed starlet is cached so that it can be reused in the transform of other 13 | images that have the same shape. 14 | """ 15 | 16 | def __init__(self, image, coefficients, generation, convolve2d): 17 | """ 18 | Parameters 19 | ---------- 20 | image: array 21 | Image in real space. 22 | coefficients: array 23 | Starlet transform of the image. 24 | generation: int 25 | The generation of the starlet transform (either `1` or `2`). 26 | convolve2d: array 27 | The filter used to convolve the image and create the wavelets. 28 | When `convolve2d` is `None` this uses a cubic bspline. 
29 | """ 30 | self._image = image 31 | self._coeffs = coefficients 32 | self._generation = generation 33 | self._convolve2d = convolve2d 34 | self._norm = None 35 | 36 | @property 37 | def image(self): 38 | """The real space image""" 39 | return self._image 40 | 41 | @property 42 | def coefficients(self): 43 | """Starlet coefficients""" 44 | return self._coeffs 45 | 46 | @staticmethod 47 | def from_image(image, scales=None, generation=2, convolve2d=None): 48 | """Generate a set of starlet coefficients for an image 49 | 50 | Parameters 51 | ---------- 52 | image: array-like 53 | The image that is converted into starlet coefficients 54 | scales: int 55 | The number of starlet scales to use. 56 | If `scales` is `None` then the maximum number of scales is used. 57 | Note: this is the length of the coefficients-1, as in the notation 58 | of `Starck et al. 2011`. 59 | generation: int 60 | The generation of the starlet transform (either `1` or `2`). 61 | convolve2d: array-like 62 | The filter used to convolve the image and create the wavelets. 63 | When `convolve2D` is `None` this uses a cubic bspline. 64 | 65 | Returns 66 | ------- 67 | result: Starlet 68 | The resulting `Starlet` that contains the image, starlet coefficients, 69 | as well as the parameters used to generate the coefficients. 70 | """ 71 | if scales is None: 72 | scales = get_scales(image.shape) 73 | coefficients = starlet_transform(image, scales, generation, convolve2d) 74 | return Starlet(image, coefficients, generation, convolve2d) 75 | 76 | 77 | def bspline_convolve(image, scale): 78 | """Convolve an image with a bpsline at a given scale. 79 | 80 | This uses the spline 81 | `h1d = jnp.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])` 82 | from Starck et al. 2011. 83 | 84 | Parameters 85 | ---------- 86 | image: 2D array 87 | The image or wavelet coefficients to convolve. 88 | scale: int 89 | The wavelet scale for the convolution. 
This sets the 90 | spacing between adjacent pixels with the spline. 91 | 92 | """ 93 | # Filter for the scarlet transform. Here bspline 94 | h1d = jnp.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16]) 95 | j = scale 96 | 97 | slice0 = slice(None, -(2 ** (j + 1))) 98 | slice1 = slice(None, -(2**j)) 99 | slice3 = slice(2**j, None) 100 | slice4 = slice(2 ** (j + 1), None) 101 | 102 | # row 103 | col = image * h1d[2] 104 | col = col.at[slice4].add(image[slice0] * h1d[0]) 105 | col = col.at[slice3].add(image[slice1] * h1d[1]) 106 | col = col.at[slice1].add(image[slice3] * h1d[3]) 107 | col = col.at[slice0].add(image[slice4] * h1d[4]) 108 | 109 | # column 110 | result = col * h1d[2] 111 | result = result.at[:, slice4].add(col[:, slice0] * h1d[0]) 112 | result = result.at[:, slice3].add(col[:, slice1] * h1d[1]) 113 | result = result.at[:, slice1].add(col[:, slice3] * h1d[3]) 114 | result = result.at[:, slice0].add(col[:, slice4] * h1d[4]) 115 | return result 116 | 117 | 118 | def starlet_transform(image, scales=None, generation=2, convolve2d=None): 119 | """Perform a scarlet transform, or 2nd gen starlet transform. 120 | 121 | Parameters 122 | ---------- 123 | image: 2D array 124 | The image to transform into starlet coefficients. 125 | scales: int 126 | The number of scale to transform with starlets. 127 | The total dimension of the starlet will have 128 | `scales+1` dimensions, since it will also hold 129 | the image at all scales higher than `scales`. 130 | generation: int 131 | The generation of the transform. 132 | This must be `1` or `2`. Default is `2`. 133 | convolve2d: function 134 | The filter function to use to convolve the image 135 | with starlets in 2D. 136 | 137 | Returns 138 | ------- 139 | starlet: array with dimension (scales+1, Ny, Nx) 140 | The starlet dictionary for the input `image`. 
141 | """ 142 | assert len(image.shape) == 2, f"Image should be 2D, got {len(image.shape)}" 143 | assert generation in (1, 2), f"generation should be 1 or 2, got {generation}" 144 | 145 | scales = get_scales(image.shape, scales) 146 | c = image 147 | if convolve2d is None: 148 | convolve2d = bspline_convolve 149 | 150 | ## wavelet set of coefficients. 151 | starlet = jnp.zeros((scales + 1,) + image.shape) 152 | for j in range(scales): 153 | gen1 = convolve2d(c, j) 154 | 155 | if generation == 2: 156 | gen2 = convolve2d(gen1, j) 157 | starlet = starlet.at[j].set(c - gen2) 158 | else: 159 | starlet = starlet.at[j].set(c - gen1) 160 | 161 | c = gen1 162 | 163 | starlet = starlet.at[-1].set(c) 164 | return starlet 165 | 166 | 167 | def starlet_reconstruction(starlets, generation=2, convolve2d=None): 168 | """Reconstruct an image from a dictionary of starlets 169 | 170 | Parameters 171 | ---------- 172 | starlets: array with dimension (scales+1, Ny, Nx) 173 | The starlet dictionary used to reconstruct the image. 174 | generation: int 175 | The generation of the transform. 176 | This must be `1` or `2`. Default is `2`. 177 | convolve2d: function 178 | The filter function to use to convolve the image 179 | with starlets in 2D. 180 | 181 | Returns 182 | ------- 183 | image: 2D array 184 | The image reconstructed from the input `starlet`. 185 | """ 186 | if generation == 1: 187 | return jnp.sum(starlets, axis=0) 188 | if convolve2d is None: 189 | convolve2d = bspline_convolve 190 | scales = len(starlets) - 1 191 | 192 | c = starlets[-1] 193 | for i in range(1, scales + 1): 194 | j = scales - i 195 | cj = convolve2d(c, j) 196 | c = cj + starlets[j] 197 | return c 198 | 199 | 200 | def get_scales(image_shape, scales=None): 201 | """Get the number of scales to use in the starlet transform. 
202 | 203 | Parameters 204 | ---------- 205 | image_shape: tuple 206 | The 2D shape of the image that is being transformed 207 | scales: int 208 | The number of scale to transform with starlets. 209 | The total dimension of the starlet will have 210 | `scales+1` dimensions, since it will also hold 211 | the image at all scales higher than `scales`. 212 | """ 213 | # Number of levels for the Starlet decomposition 214 | max_scale = int(jnp.log2(min(image_shape[-2:]))) - 1 215 | if (scales is None) or scales > max_scale: 216 | scales = max_scale 217 | return int(scales) 218 | -------------------------------------------------------------------------------- /src/scarlet2/morphology.py: -------------------------------------------------------------------------------- 1 | import astropy.units as u 2 | import equinox as eqx 3 | import jax.numpy as jnp 4 | import jax.scipy 5 | 6 | from . import Scenery, measure 7 | from .module import Module 8 | from .wavelets import starlet_reconstruction, starlet_transform 9 | 10 | 11 | class Morphology(Module): 12 | """Morphology base class""" 13 | 14 | @property 15 | def shape(self): 16 | """Shape (2D) of the morphology model""" 17 | raise NotImplementedError 18 | 19 | 20 | class ProfileMorphology(Morphology): 21 | """Base class for morphologies based on a radial profile""" 22 | 23 | size: float 24 | """Size of the profile 25 | 26 | Can be given as an astropy angle, which will be transformed with the WCS of the current 27 | :py:class:`~scarlet2.Scene`. 
28 | """ 29 | ellipticity: (None, jnp.array) 30 | """Ellipticity of the profile""" 31 | _shape: tuple = eqx.field(repr=False) 32 | 33 | def __init__(self, size, ellipticity=None, shape=None): 34 | if isinstance(size, u.Quantity): 35 | try: 36 | size = Scenery.scene.frame.u_to_pixel(size) 37 | except AttributeError: 38 | print("`size` defined in astropy units can only be used within the context of a Scene") 39 | print("Use 'with Scene(frame) as scene: (...)'") 40 | raise 41 | 42 | self.size = size 43 | self.ellipticity = ellipticity 44 | 45 | # default shape: square 15x size 46 | if shape is None: 47 | # explicit call to int() to avoid bbox sizes being jax-traced 48 | size = int(jnp.ceil(15 * self.size)) 49 | # odd shapes for unique center pixel 50 | if size % 2 == 0: 51 | size += 1 52 | shape = (size, size) 53 | self._shape = shape 54 | 55 | @property 56 | def shape(self): 57 | """Shape of the bounding box for the profile. 58 | 59 | If not set during `__init__`, uses a square box with an odd number of pixels 60 | not smaller than `10*size`. 
61 | """ 62 | return self._shape 63 | 64 | def f(self, r2): 65 | """Radial profile function 66 | 67 | Parameters 68 | ---------- 69 | r2: float or array 70 | Radius (distance from the center) squared 71 | """ 72 | raise NotImplementedError 73 | 74 | def __call__(self, delta_center=jnp.zeros(2)): # noqa: B008 75 | """Evaluate the model""" 76 | _y = jnp.arange(-(self.shape[-2] // 2), self.shape[-2] // 2 + 1, dtype=float) - delta_center[-2] 77 | _x = jnp.arange(-(self.shape[-1] // 2), self.shape[-1] // 2 + 1, dtype=float) - delta_center[-1] 78 | 79 | if self.ellipticity is None: 80 | r2 = _y[:, None] ** 2 + _x[None, :] ** 2 81 | else: 82 | e1, e2 = self.ellipticity 83 | g_factor = 1 / (1.0 + jnp.sqrt(1.0 - (e1**2 + e2**2))) 84 | g1, g2 = self.ellipticity * g_factor 85 | __x = ((1 - g1) * _x[None, :] - g2 * _y[:, None]) / jnp.sqrt(1 - (g1**2 + g2**2)) 86 | __y = (-g2 * _x[None, :] + (1 + g1) * _y[:, None]) / jnp.sqrt(1 - (g1**2 + g2**2)) 87 | r2 = __y**2 + __x**2 88 | 89 | r2 /= self.size**2 90 | r2 = jnp.maximum(r2, 1e-3) # prevents infs at R2 = 0 91 | morph = self.f(r2) 92 | return morph 93 | 94 | 95 | class GaussianMorphology(ProfileMorphology): 96 | """Gaussian radial profile""" 97 | 98 | def f(self, r2): 99 | """Radial profile function 100 | 101 | Parameters 102 | ---------- 103 | r2: float or array 104 | Radius (distance from the center) squared 105 | """ 106 | return jnp.exp(-r2 / 2) 107 | 108 | def __call__(self, delta_center=jnp.zeros(2)): # noqa: B008 109 | """Evaluate the model""" 110 | # faster circular 2D Gaussian: instead of N^2 evaluations, use outer product of 2 1D Gaussian evals 111 | if self.ellipticity is None: 112 | _y = jnp.arange(-(self.shape[-2] // 2), self.shape[-2] // 2 + 1, dtype=float) - delta_center[-2] 113 | _x = jnp.arange(-(self.shape[-1] // 2), self.shape[-1] // 2 + 1, dtype=float) - delta_center[-1] 114 | 115 | # with pixel integration 116 | f = lambda x, s: 0.5 * ( # noqa: E731 117 | 1 118 | - jax.scipy.special.erfc((0.5 - x) / 
jnp.sqrt(2) / s) 119 | + 1 120 | - jax.scipy.special.erfc((0.5 + x) / jnp.sqrt(2) / s) 121 | ) 122 | # # without pixel integration 123 | # f = lambda x, s: jnp.exp(-(x ** 2) / (2 * s ** 2)) / (jnp.sqrt(2 * jnp.pi) * s) 124 | 125 | return jnp.outer(f(_y, self.size), f(_x, self.size)) 126 | 127 | else: 128 | return super().__call__(delta_center) 129 | 130 | @staticmethod 131 | def from_image(image): 132 | """Create Gaussian radial profile from the 2nd moments of `image` 133 | 134 | Parameters 135 | ---------- 136 | image: array 137 | 2D array to measure :py:class:`~scarlet2.measure.Moments` from. 138 | 139 | Returns 140 | ------- 141 | GaussianMorphology 142 | """ 143 | assert image.ndim == 2 144 | center = measure.centroid(image) 145 | # compute moments and create Gaussian from it 146 | g = measure.Moments(image, center=center, N=2) 147 | return GaussianMorphology.from_moments(g, shape=image.shape) 148 | 149 | @staticmethod 150 | def from_moments(g, shape=None): 151 | """Create Gaussian radial profile from the moments `g` 152 | 153 | Parameters 154 | ---------- 155 | g: :py:class:`~scarlet2.measure.Moments` 156 | Moments, order >= 2 157 | shape: tuple 158 | Shape of the bounding box 159 | 160 | Returns 161 | ------- 162 | GaussianMorphology 163 | """ 164 | t = g.size 165 | ellipticity = g.ellipticity 166 | 167 | # create image of Gaussian with these 2nd moments 168 | if jnp.isfinite(t) and jnp.isfinite(ellipticity).all(): 169 | morph = GaussianMorphology(t, ellipticity, shape=shape) 170 | else: 171 | raise ValueError( 172 | f"Gaussian morphology not possible with size={t}, and ellipticity={ellipticity}!" 
173 | ) 174 | return morph 175 | 176 | 177 | class SersicMorphology(ProfileMorphology): 178 | """Sersic radial profile""" 179 | 180 | n: float 181 | """Sersic index""" 182 | 183 | def __init__(self, n, size, ellipticity=None, shape=None): 184 | self.n = n 185 | super().__init__(size, ellipticity=ellipticity, shape=shape) 186 | 187 | def f(self, r2): 188 | """Radial profile function 189 | 190 | Parameters 191 | ---------- 192 | r2: float or array 193 | Radius (distance from the center) squared 194 | """ 195 | n = self.n 196 | n2 = n * n 197 | # simplest form of bn: Capaccioli (1989) 198 | # bn = 1.9992 * n - 0.3271 199 | # 200 | # better treatment in Ciotti & Bertin (1999), eq. 18 201 | # stable to n > 0.36, with errors < 10^5 202 | bn = 2 * n - 0.333333 + 0.009877 / n + 0.001803 / n2 + 0.000114 / (n2 * n) - 0.000072 / (n2 * n2) 203 | 204 | # MacArthur, Courteau, & Holtzman (2003), eq. A2 205 | # much more stable for n < 0.36 206 | # not using it here to avoid if clause in jitted code 207 | # bn = 0.01945 - 0.8902 * n + 10.95 * n2 - 19.67 * n2 * n + 13.43 * n2 * n2 208 | 209 | # Graham & Driver 2005, eq. 
1 210 | # we're given R^2, so we use R2^(0.5/n) instead of 1/n 211 | return jnp.exp(-bn * (r2 ** (0.5 / n) - 1)) 212 | 213 | 214 | prox_plus = lambda x: jnp.maximum(x, 0) # noqa: E731 215 | prox_soft = lambda x, thresh: jnp.sign(x) * prox_plus(jnp.abs(x) - thresh) # noqa: E731 216 | prox_soft_plus = lambda x, thresh: prox_plus(prox_soft(x, thresh)) # noqa: E731 217 | 218 | 219 | class StarletMorphology(Morphology): 220 | """Morphology in the starlet basis 221 | 222 | See Also 223 | -------- 224 | scarlet2.wavelets.Starlet 225 | """ 226 | 227 | coeffs: jnp.ndarray 228 | """Starlet coefficients""" 229 | l1_thresh: float = eqx.field(default=0) 230 | """L1 threshold for coefficient to create sparse representation""" 231 | positive: bool = eqx.field(default=True) 232 | """Whether the coefficients are restricted to non-negative values""" 233 | 234 | def __call__(self, **kwargs): 235 | """Evaluate the model""" 236 | f = prox_soft_plus if self.positive else prox_soft 237 | return starlet_reconstruction(f(self.coeffs, self.l1_thresh)) 238 | 239 | @property 240 | def shape(self): 241 | """Shape (2D) of the morphology model""" 242 | return self.coeffs.shape[-2:] # wavelet coeffs: scales x n1 x n2 243 | 244 | @staticmethod 245 | def from_image(image, **kwargs): 246 | """Create starlet morphology from `image` 247 | 248 | Parameters 249 | ---------- 250 | image: array 251 | 2D image array to determine coefficients from. 
252 | kwargs: dict 253 | Additional arguments for `__init__` 254 | 255 | Returns 256 | ------- 257 | StarletMorphology 258 | """ 259 | # Starlet transform of image (n1,n2) into coefficient with 3 dimensions: (scales+1,n1,n2) 260 | coeffs = starlet_transform(image) 261 | return StarletMorphology(coeffs, **kwargs) 262 | -------------------------------------------------------------------------------- /docs/howto/sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Sample from Posterior\n", 8 | "\n", 9 | "_scarlet2_ can provide samples from the posterior distribution to pass to downstream operations and as the most precise option for uncertainty quantification. In principle, we can get posterior samples for every parameter, and this can be done with any sampler by evaluating the log-posterior distribution. For this guide we will use the Hamiltonian Monte Carlo sampler from numpyro, for which we created a convenient front-end in _scarlet2_.\n", 10 | "\n", 11 | "We start from the [quickstart tutorial](../0-quickstart), loading the same data and the best-fitting model." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# Import Packages and setup\n", 21 | "import jax.numpy as jnp\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "\n", 24 | "from scarlet2 import *\n", 25 | "\n", 26 | "set_validation(False)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Create Observation\n", 34 | "\n", 35 | "We need to create the {py:class}`~scarlet2.Observation` because it contains the {py:func}`~scarlet2.Observation.log_likelihood` method we need for the posterior:" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "# load the data\n", 45 | "from huggingface_hub import hf_hub_download\n", 46 | "\n", 47 | "filename = hf_hub_download(\n", 48 | " repo_id=\"astro-data-lab/scarlet-test-data\", filename=\"hsc_cosmos_35.npz\", repo_type=\"dataset\"\n", 49 | ")\n", 50 | "file = jnp.load(filename)\n", 51 | "data = jnp.asarray(file[\"images\"])\n", 52 | "channels = [str(f) for f in file[\"filters\"]]\n", 53 | "centers = jnp.array([(src[\"y\"], src[\"x\"]) for src in file[\"catalog\"]])\n", 54 | "weights = jnp.asarray(1 / file[\"variance\"])\n", 55 | "psf = jnp.asarray(file[\"psfs\"])\n", 56 | "\n", 57 | "# create the observation\n", 58 | "obs = Observation(\n", 59 | " data,\n", 60 | " weights,\n", 61 | " psf=ArrayPSF(psf),\n", 62 | " channels=channels,\n", 63 | ")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": { 69 | "collapsed": false 70 | }, 71 | "source": [ 72 | "## Load Model\n", 73 | "\n", 74 | "We can make use of the best-fit model from the Quickstart guide as the starting point of the sampler." 
75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "import scarlet2.io\n", 84 | "\n", 85 | "id = 35\n", 86 | "filename = \"hsc_cosmos.h5\"\n", 87 | "scene = scarlet2.io.model_from_h5(filename, path=\"..\", id=id)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "collapsed": false 94 | }, 95 | "source": [ 96 | "Let's have a look:" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "norm = plot.AsinhAutomaticNorm(obs)\n", 106 | "plot.scene(scene, observation=obs, norm=norm, add_boxes=True)\n", 107 | "plt.show()" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Define Parameters with Prior\n", 115 | "\n", 116 | "In principle, we can get posterior samples for every parameter. We will demonstrate by sampling from the spectrum and the center position of the point source #0. We therefore need to set the `prior` attribute for each of these parameters; the attribute `stepsize` is ignored, but `constraint` cannot be used when `prior` is set." 
117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "import numpyro.distributions as dist\n", 126 | "\n", 127 | "C = len(channels)\n", 128 | "with scarlet2.Parameters(scene) as parameters:\n", 129 | " # rough guess of source brightness across bands\n", 130 | " p1 = scene.sources[0].spectrum\n", 131 | " prior1 = dist.Uniform(low=jnp.zeros(C), high=500 * jnp.ones(C))\n", 132 | " Parameter(p1, name=\"spectrum\", prior=prior1)\n", 133 | "\n", 134 | " # initial position was integer pixel coordinate\n", 135 | " # assume 0.5 pixel uncertainty\n", 136 | " p2 = scene.sources[0].center\n", 137 | " prior2 = dist.Normal(centers[0], scale=0.5)\n", 138 | " Parameter(p2, name=\"center\", prior=prior2)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": { 144 | "collapsed": false 145 | }, 146 | "source": [ 147 | "```{warning}\n", 148 | "You are responsible for setting reasonable priors, which describe what you know about the parameter before having looked at the data. In the example above, the spectrum gets a wide flat prior, and the center prior uses the position `centers[0]`, which is given by the original detection catalog.
Neither uses information from the optimized `scene`.\n", 149 | "\n", 150 | "Also: if in doubt about how much the prior choices matter, vary them within reason.\n", 151 | "```\n", 152 | "\n", 153 | "## Run Sampler\n", 154 | "\n", 155 | "Then we can run numpyro's {py:class}`~numpyro.infer.hmc.NUTS` sampler with a call to {py:func}`~scarlet2.Scene.sample`, which is analogous to {py:func}`~scarlet2.Scene.fit`:" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "mcmc = scene.sample(\n", 165 | " obs,\n", 166 | " parameters,\n", 167 | " num_warmup=100,\n", 168 | " num_samples=1000,\n", 169 | " progress_bar=False,\n", 170 | ")\n", 171 | "mcmc.print_summary()" 172 | ] 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "metadata": { 177 | "collapsed": false 178 | }, 179 | "source": [ 180 | "## Access Samples\n", 181 | "\n", 182 | "The samples can be accessed from the MCMC chain and are listed as arrays under the names chosen above for the respective `Parameter`."
183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import pprint\n", 192 | "\n", 193 | "samples = mcmc.get_samples()\n", 194 | "pprint.pprint(samples)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "To create versions of the scene for any of the samples, we first select a few at random and then use the method {py:func}`scarlet2.Module.replace` to set their values at the locations identified by `parameters`:" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "# get values for three random samples\n", 211 | "S = 3\n", 212 | "import jax.random\n", 213 | "\n", 214 | "seed = 42\n", 215 | "key = jax.random.key(seed)\n", 216 | "idxs = jax.random.randint(key, shape=(S,), minval=0, maxval=mcmc.num_samples)\n", 217 | "\n", 218 | "values = [[spectrum, center] for spectrum, center in zip(samples[\"spectrum\"][idxs], samples[\"center\"][idxs])]\n", 219 | "\n", 220 | "# create versions of the scene with these posterior samples\n", 221 | "scenes = [scene.replace(parameters, v) for v in values]\n", 222 | "\n", 223 | "# display the source model\n", 224 | "fig, axes = plt.subplots(1, S, figsize=(10, 4))\n", 225 | "for s in range(S):\n", 226 | " source_array = scenes[s].sources[0]()\n", 227 | " axes[s].imshow(plot.img_to_rgb(source_array, norm=norm))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "markdown", 232 | "metadata": { 233 | "collapsed": false 234 | }, 235 | "source": [ 236 | "The differences are imperceptible for this source, which tells us that the data were highly informative. But we can measure, e.g.,
the total fluxes for each sample" 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "execution_count": null, 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "print(f\"-------------- {channels}\")\n", 246 | "for i, scene in enumerate(scenes):\n", 247 | " print(f\"Flux Sample {i}: {measure.flux(scene.sources[0])}\")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## Visualize Posterior\n", 255 | "\n", 256 | "We can also visualize the posterior distributions, e.g. with the [`corner`](https://corner.readthedocs.io/en/latest/) package:" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "import corner\n", 266 | "\n", 267 | "corner.corner(mcmc);" 268 | ] 269 | } 270 | ], 271 | "metadata": { 272 | "celltoolbar": "Raw Cell Format", 273 | "kernelspec": { 274 | "display_name": "scarlet2", 275 | "language": "python", 276 | "name": "python3" 277 | }, 278 | "language_info": { 279 | "codemirror_mode": { 280 | "name": "ipython", 281 | "version": 3 282 | }, 283 | "file_extension": ".py", 284 | "mimetype": "text/x-python", 285 | "name": "python", 286 | "nbconvert_exporter": "python", 287 | "pygments_lexer": "ipython3", 288 | "version": "3.12.11" 289 | } 290 | }, 291 | "nbformat": 4, 292 | "nbformat_minor": 4 293 | } 294 | -------------------------------------------------------------------------------- /src/scarlet2/bbox.py: -------------------------------------------------------------------------------- 1 | import equinox as eqx 2 | import jax.numpy as jnp 3 | 4 | 5 | class Box(eqx.Module): 6 | """Bounding Box for data array 7 | 8 | A Bounding box describes the location of a data array in the model coordinate system. 9 | It is used to identify spatial and channel overlap and to map from model 10 | to observed frames and back. 11 | 12 | The `BBox` code is agnostic about the meaning of the dimensions. 
13 | We generally use this convention: 14 | 15 | - 2D shapes denote (Height, Width) 16 | - 3D shapes denote (Channels, Height, Width) 17 | """ 18 | 19 | shape: tuple 20 | """Size of the array""" 21 | origin: tuple 22 | """Start coordinate (in 2D: lower-left corner) of the array in model frame""" 23 | 24 | def __init__(self, shape, origin=None): 25 | """ 26 | Parameters 27 | ---------- 28 | shape: tuple 29 | Size of the array 30 | origin: tuple, optional 31 | Start coordinate (in 2D: lower-left corner) of the array in model frame 32 | """ 33 | self.shape = tuple(shape) 34 | if origin is None: 35 | origin = (0,) * len(shape) 36 | self.origin = tuple(origin) 37 | 38 | @staticmethod 39 | def from_bounds(*bounds): 40 | """Initialize a box from its bounds 41 | 42 | Parameters 43 | ---------- 44 | bounds: tuple of (min,max) pairs 45 | Min/Max coordinate for every dimension 46 | 47 | Returns 48 | ------- 49 | bbox: Box 50 | A new box bounded by the input bounds. 51 | """ 52 | shape = tuple(max(0, cmax - cmin) for cmin, cmax in bounds) 53 | origin = (cmin for cmin, cmax in bounds) 54 | return Box(shape, origin=origin) 55 | 56 | @staticmethod 57 | def from_data(X, min_value=0): # noqa: N803 58 | """Define box where `X` is above `min_value` 59 | 60 | Parameters 61 | ---------- 62 | X : jnp.ndarray 63 | Data to threshold 64 | min_value : float 65 | Minimum value of the result. 
66 | 67 | Returns 68 | ------- 69 | bbox : :class:`scarlet2.bbox.Box` 70 | Bounding box for the thresholded `X` 71 | """ 72 | sel = min_value < X 73 | if sel.any(): 74 | nonzero = jnp.where(sel) 75 | bounds = [] 76 | for dim in range(len(X.shape)): 77 | bounds.append((nonzero[dim].min(), nonzero[dim].max() + 1)) 78 | else: 79 | bounds = [[0, 0]] * len(X.shape) 80 | return Box.from_bounds(*bounds) 81 | 82 | def contains(self, p): 83 | """Whether the box contains a given coordinate `p`""" 84 | if len(p) != self.D: 85 | raise ValueError(f"Dimension mismatch in {p} and {self.D}") 86 | 87 | for d in range(self.D): 88 | if p[d] < self.origin[d] or p[d] >= self.origin[d] + self.shape[d]: 89 | return False 90 | return True 91 | 92 | def get_extent(self): 93 | """Return the start and end coordinates.""" 94 | return [self.start[-1], self.stop[-1], self.start[-2], self.stop[-2]] 95 | 96 | @property 97 | def D(self): # noqa: N802 98 | """Dimensionality of this BBox""" 99 | return len(self.shape) 100 | 101 | @property 102 | def start(self): 103 | """Tuple of start coordinates""" 104 | return self.origin 105 | 106 | @property 107 | def stop(self): 108 | """Tuple of stop coordinates""" 109 | return tuple(o + s for o, s in zip(self.origin, self.shape, strict=False)) 110 | 111 | @property 112 | def center(self): 113 | """Tuple of center coordinates""" 114 | return tuple(o + s // 2 for o, s in zip(self.origin, self.shape, strict=False)) 115 | 116 | @property 117 | def bounds(self): 118 | """Bounds of the box""" 119 | return tuple((o, o + s) for o, s in zip(self.origin, self.shape, strict=False)) 120 | 121 | @property 122 | def slices(self): 123 | """Bounds of the box as slices""" 124 | return tuple([slice(o, o + s) for o, s in zip(self.origin, self.shape, strict=False)]) 125 | 126 | @property 127 | def spatial(self): 128 | """Spatial component of higher-dimensional box""" 129 | assert self.D >= 2 130 | return self[-2:] 131 | 132 | def set_center(self, pos): 133 | """Center box at 
given position""" 134 | pos_ = tuple(_.item() for _ in pos) 135 | origin = tuple(o + p - c for o, p, c in zip(self.origin, pos_, self.center, strict=False)) 136 | object.__setattr__(self, "origin", origin) 137 | 138 | def grow(self, delta): 139 | """Grow the Box by the given delta in each direction""" 140 | if not hasattr(delta, "__iter__"): 141 | delta = [delta] * self.D 142 | origin = tuple([self.origin[d] - delta[d] for d in range(self.D)]) 143 | shape = tuple([self.shape[d] + 2 * delta[d] for d in range(self.D)]) 144 | return Box(shape, origin=origin) 145 | 146 | def shrink(self, delta): 147 | """Shrink the Box by the given delta in each direction""" 148 | if not hasattr(delta, "__iter__"): 149 | delta = [delta] * self.D 150 | origin = tuple([self.origin[d] + delta[d] for d in range(self.D)]) 151 | shape = tuple([self.shape[d] - 2 * delta[d] for d in range(self.D)]) 152 | return Box(shape, origin=origin) 153 | 154 | def __or__(self, other): 155 | """Union of two bounding boxes 156 | 157 | Parameters 158 | ---------- 159 | other: `Box` 160 | The other bounding box in the union 161 | 162 | Returns 163 | ------- 164 | result: `Box` 165 | The smallest rectangular box that contains *both* boxes. 166 | """ 167 | if other.D != self.D: 168 | raise ValueError(f"Dimension mismatch in the boxes {other} and {self}") 169 | bounds = [] 170 | for d in range(self.D): 171 | bounds.append((min(self.start[d], other.start[d]), max(self.stop[d], other.stop[d]))) 172 | return Box.from_bounds(*bounds) 173 | 174 | def __and__(self, other): 175 | """Intersection of two bounding boxes 176 | 177 | If there is no intersection between the two bounding 178 | boxes then an empty bounding box is returned. 179 | 180 | Parameters 181 | ---------- 182 | other: `Box` 183 | The other bounding box in the intersection 184 | 185 | Returns 186 | ------- 187 | result: `Box` 188 | The rectangular box that is in the overlap region 189 | of both boxes. 
190 | """ 191 | if other.D != self.D: 192 | raise ValueError(f"Dimension mismatch in the boxes {other} and {self}") 193 | assert other.D == self.D 194 | bounds = [] 195 | for d in range(self.D): 196 | bounds.append((max(self.start[d], other.start[d]), min(self.stop[d], other.stop[d]))) 197 | return Box.from_bounds(*bounds) 198 | 199 | def __getitem__(self, i): 200 | s_ = self.shape[i] 201 | o_ = self.origin[i] 202 | if not hasattr(s_, "__iter__"): 203 | s_ = (s_,) 204 | o_ = (o_,) 205 | return Box(s_, origin=o_) 206 | 207 | def __add__(self, offset): 208 | if not hasattr(offset, "__iter__"): 209 | offset = (offset,) * self.D 210 | origin = tuple([a + o for a, o in zip(self.origin, offset, strict=False)]) 211 | return Box(self.shape, origin=origin) 212 | 213 | def __sub__(self, offset): 214 | if not hasattr(offset, "__iter__"): 215 | offset = (offset,) * self.D 216 | origin = tuple([a - o for a, o in zip(self.origin, offset, strict=False)]) 217 | return Box(self.shape, origin=origin) 218 | 219 | def __matmul__(self, bbox): 220 | bounds = self.bounds + bbox.bounds 221 | return Box.from_bounds(*bounds) 222 | 223 | def __copy__(self): 224 | return Box(self.shape, origin=self.origin) 225 | 226 | def __eq__(self, other): 227 | return self.shape == other.shape and self.origin == other.origin 228 | 229 | def __hash__(self): 230 | return hash((self.shape, self.origin)) 231 | 232 | 233 | def overlap_slices(bbox1, bbox2, return_boxes=False): 234 | """Slices of bbox1 and bbox2 that overlap 235 | 236 | Parameters 237 | ---------- 238 | bbox1: `~scarlet.bbox.Box` 239 | The first box to use for comparing overlap. 240 | bbox2: `~scarlet.bbox.Box` 241 | The second box to use for comparing overlap. 242 | return_boxes: bool 243 | If True return new boxes corresponding to the overlapping portion of 244 | each of the input boxes. If False, return the overlapping portion of the 245 | original boxes. Default False. 
246 | 247 | Returns 248 | ------- 249 | slices: tuple of slices 250 | The slice of an array bounded by `bbox1` and 251 | the slice of an array bounded by `bbox2` in the 252 | overlapping region. 253 | """ 254 | overlap = bbox1 & bbox2 255 | _bbox1 = overlap - bbox1.origin 256 | _bbox2 = overlap - bbox2.origin 257 | if return_boxes: 258 | return _bbox1, _bbox2 259 | slices = ( 260 | _bbox1.slices, 261 | _bbox2.slices, 262 | ) 263 | return slices 264 | 265 | 266 | def insert_into(image, sub, bbox): 267 | """Insert `sub` into `image` according to this bbox 268 | 269 | Inverse operation to :func:`~scarlet.bbox.Box.extract_from`. 270 | 271 | Parameters 272 | ---------- 273 | image: array 274 | Full image 275 | sub: array 276 | Smaller sub-image 277 | bbox: Box 278 | Bounding box that describes the shape and position of `sub` in the pixel coordinates of `image`. 279 | Returns 280 | ------- 281 | image: array 282 | Image with `sub` inserted at `bbox`. 283 | """ 284 | imbox = Box(image.shape) 285 | 286 | im_slices, sub_slices = overlap_slices(imbox, bbox) 287 | try: 288 | image[im_slices] = sub[sub_slices] # numpy arrays 289 | except TypeError: 290 | image = image.at[im_slices].set(sub[sub_slices]) # jax arrays 291 | return image 292 | -------------------------------------------------------------------------------- /src/scarlet2/lsst_utils.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import lsst.afw.geom as afw_geom 3 | import lsst.geom as geom 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | from astropy import units as u 7 | from astropy.wcs import WCS 8 | from lsst.afw.fits import MemFileManager 9 | from lsst.afw.image import ExposureF 10 | from lsst.geom import Point2D 11 | from lsst.meas.algorithms import WarpedPsf 12 | from lsst.pipe.tasks.registerImage import RegisterConfig, RegisterTask 13 | from pyvo.dal.adhoc import DatalinkResults, SodaQuery 14 | 15 | import scarlet2 16 | 17 | 
18 | def warp_img(ref_img, img_to_warp, ref_wcs, wcs_to_warp): 19 | """Warp and rotate an image onto the coordinate system of another image 20 | 21 | Parameters 22 | ---------- 23 | ref_img: 'ExposureF' 24 | is the reference image for the re-projection 25 | img_to_warp: 'ExposureF' 26 | the image to rotate and warp onto the reference image's wcs 27 | ref_wcs: 'wcs object' 28 | the wcs of the reference image (i.e. ref_img.getWcs() ) 29 | wcs_to_warp: 'wcs object' 30 | the wcs of the image to warp (i.e. img_to_warp.getWcs() ) 31 | Returns 32 | ------- 33 | warped_exp: 'ExposureF' 34 | a reprojected, rotated image that is aligned and matched to ref_image 35 | """ 36 | config = RegisterConfig() 37 | task = RegisterTask(name="register", config=config) 38 | warped_exp = task.warpExposure(img_to_warp, wcs_to_warp, ref_wcs, ref_img.getBBox()) 39 | 40 | return warped_exp 41 | 42 | 43 | def read_cutout_mem(sq): 44 | """Read the cutout into memory 45 | 46 | Parameters 47 | ---------- 48 | sq : 'dict' 49 | returned from SodaQuery.from_resource() 50 | 51 | Returns 52 | ------- 53 | exposure : 'ExposureF' 54 | the cutout in exposureF format 55 | """ 56 | 57 | cutout_bytes = sq.execute_stream().read() 58 | sq.raise_if_error() 59 | mem = MemFileManager(len(cutout_bytes)) 60 | mem.setData(cutout_bytes, len(cutout_bytes)) 61 | exposure = ExposureF(mem) 62 | 63 | return exposure 64 | 65 | 66 | def make_image_cutout(tap_service, ra, dec, data_id, cutout_size=0.01, imtype=None): 67 | """Wrapper function to generate a cutout using the cutout tool 68 | 69 | Parameters 70 | ---------- 71 | tap_service : `pyvo.dal.tap.TAPService` 72 | the TAP service to use for querying the cutouts 73 | ra, dec : 'float' 74 | the ra and dec of the cutout center 75 | data_id : 'dict' 76 | the dataId of the image to make a cutout from. 
The format 77 | must correspond to that provided for parameter 'imtype' ('visit' and 'detector' for calexp; 'tract', 'patch' and 'band' otherwise) 78 | cutout_size : 'float', optional 79 | radius in degrees of the circular cutout region (used as the radius of the SODA circle) 80 | imtype : 'string', optional 81 | string containing the type of LSST image to generate 82 | a cutout of (e.g. deepCoadd, calexp). Any value other than 83 | 'calexp' (including None) selects a deepCoadd. 84 | 85 | Returns 86 | ------- 87 | exposure : 'ExposureF' 88 | the cutout in exposureF format 89 | """ 90 | 91 | sphere_point = geom.SpherePoint(ra * geom.degrees, dec * geom.degrees) 92 | 93 | if imtype == "calexp": 94 | query = ( 95 | "SELECT access_format, access_url, dataproduct_subtype, " 96 | + "lsst_visit, lsst_detector, lsst_band " 97 | + "FROM dp02_dc2_catalogs.ObsCore WHERE dataproduct_type = 'image' " 98 | + "AND obs_collection = 'LSST.DP02' " 99 | + "AND dataproduct_subtype = 'lsst.calexp' " 100 | + "AND lsst_visit = " 101 | + str(data_id["visit"]) 102 | + " " 103 | + "AND lsst_detector = " 104 | + str(data_id["detector"]) 105 | ) 106 | results = tap_service.search(query) 107 | 108 | else: 109 | # tract and patch are read from data_id as given; they are not looked up from ra/dec 110 | tract = data_id["tract"] 111 | patch = data_id["patch"] 112 | 113 | # band is required in data_id: no default is applied, so a missing key raises KeyError 114 | band = data_id["band"] 115 | 116 | query = ( 117 | "SELECT access_format, access_url, dataproduct_subtype, " 118 | + "lsst_patch, lsst_tract, lsst_band " 119 | + "FROM dp02_dc2_catalogs.ObsCore WHERE dataproduct_type = 'image' " 120 | + "AND obs_collection = 'LSST.DP02' " 121 | + "AND dataproduct_subtype = 'lsst.deepCoadd_calexp' " 122 | + "AND lsst_tract = " 123 | + str(tract) 124 | + " " 125 | + "AND lsst_patch = " 126 | + str(patch) 127 | + " " 128 | + "AND lsst_band = " 129 | + "'" 130 | + str(band) 131 | + "' " 132 | ) 133 | results = tap_service.search(query) 134 | 135 | # Get datalink 136 | data_link_url = results[0].getdataurl() 137 | auth_session = tap_service._session 138 | dl =
DatalinkResults.from_result_url(data_link_url, session=auth_session) 139 | 140 | # from_resource: creates a instance from 141 | # a number of records and a Datalink Resource. 142 | sq = SodaQuery.from_resource(dl, dl.get_adhocservice_by_id("cutout-sync"), session=auth_session) 143 | 144 | sq.circle = ( 145 | sphere_point.getRa().asDegrees() * u.deg, 146 | sphere_point.getDec().asDegrees() * u.deg, 147 | cutout_size * u.deg, 148 | ) 149 | 150 | exposure = read_cutout_mem(sq) 151 | 152 | return exposure 153 | 154 | 155 | def dia_source_to_observations(cutout_size_pix, dia_src, service, plot_images=False): 156 | """Convert a DIA source to a list of scarlet2 Observations 157 | 158 | Parameters 159 | ---------- 160 | cutout_size_pix : 'int' 161 | the size of the cutout in pixels 162 | dia_src : `astropy.table.Table` 163 | the DIA source table containing the sources to make cutouts for 164 | service : `pyvo.dal.tap.TAPService` 165 | the TAP service to use for querying the cutouts 166 | i.e. the result from lsst.rsp.get_tap_service() 167 | plot_images : 'bool', optional 168 | whether to plot the images as they are processed 169 | (default is False) 170 | 171 | Returns 172 | ------- 173 | observations : 'list of scarlet2.Observation' 174 | a list of scarlet2 Observations, one for each DIA source 175 | channels_sc2 : 'list of tuples' 176 | a list of tuples containing the band and channel number for each observation 177 | """ 178 | # TODO: - add back in the plotting functionality? 
179 | # - figure out how to get the WCS from the cutout without writing to disk 180 | cutout_size = cutout_size_pix * 0.2 / 3600.0 181 | ra = dia_src["ra"][0] 182 | dec = dia_src["decl"][0] 183 | 184 | observations = [] 185 | channels_sc2 = [] 186 | img_ref = None 187 | wcs_ref = None 188 | 189 | first_time = dia_src["midPointTai"][0] 190 | vmin = -200 191 | vmax = 300 192 | 193 | for i, src in enumerate(dia_src): 194 | ccd_visit_id = src["ccdVisitId"] 195 | band = str(src["filterName"]) 196 | visit = str(ccd_visit_id)[:-3] 197 | detector = str(ccd_visit_id)[-3:] 198 | visit = int(visit) 199 | detector = int(detector) 200 | data_id_calexp = {"visit": visit, "detector": detector} 201 | 202 | if i == 0: 203 | img = make_image_cutout( 204 | service, ra, dec, cutout_size=cutout_size, imtype="calexp", dataId=data_id_calexp 205 | ) 206 | img_ref = img 207 | # no warping is needed for the reference 208 | img_warped = img_ref 209 | offset = geom.Extent2D(geom.Point2I(0, 0) - img_ref.getXY0()) 210 | shifted_wcs = img_ref.getWcs().copyAtShiftedPixelOrigin(offset) 211 | wcs_ref = WCS(shifted_wcs.getFitsMetadata()) 212 | else: 213 | img = make_image_cutout( 214 | service, ra, dec, cutout_size=cutout_size * 50.0, imtype="calexp", dataId=data_id_calexp 215 | ) 216 | img_warped = warp_img(img_ref, img, img_ref.getWcs(), img.getWcs()) 217 | im_arr = img_warped.image.array 218 | var_arr = img_warped.variance.array 219 | 220 | # reshape image array 221 | n1, n2 = im_arr.shape 222 | image_sc2 = im_arr.reshape(1, n1, n2) 223 | 224 | # reshape variance array 225 | n1, n2 = var_arr.shape 226 | weight_sc2 = 1 / var_arr.reshape(1, n1, n2) 227 | 228 | # other transformations 229 | point_tuple = (int(img_warped.image.array.shape[0] / 2), int(img_warped.image.array.shape[1] / 2)) 230 | point_image = Point2D(point_tuple) 231 | xy_transform = afw_geom.makeWcsPairTransform(img.wcs, img_warped.wcs) 232 | psf_w = WarpedPsf(img.getPsf(), xy_transform) 233 | point_tuple = 
(int(img_warped.image.array.shape[0] / 2), int(img_warped.image.array.shape[1] / 2)) 234 | point_image = Point2D(point_tuple) 235 | psf_warped = psf_w.computeImage(point_image).convertF() 236 | 237 | # filter out images with no overlapping data at point 238 | if np.sum(psf_warped.array) == 0: 239 | print("PSF model unavailable, skipping") 240 | continue 241 | 242 | # reshape psf array 243 | n1, n2 = psf_warped.array.shape 244 | psf_sc2 = psf_warped.array.reshape(1, n1, n2) 245 | 246 | obs = scarlet2.Observation( 247 | jnp.array(image_sc2).astype(float), 248 | weights=jnp.array(weight_sc2).astype(float), 249 | psf=scarlet2.ArrayPSF(jnp.array(psf_sc2).astype(float)), 250 | wcs=wcs_ref, 251 | channels=[(band, str(i))], 252 | ) 253 | channels_sc2.append((band, str(i))) 254 | observations.append(obs) 255 | 256 | if plot_images: 257 | _, ax = plt.subplots(1, 1, figsize=(2, 2)) 258 | plt.imshow(im_arr, origin="lower", cmap="gray", vmin=vmin, vmax=vmax) 259 | ax.set_xticklabels([]) 260 | ax.set_yticklabels([]) 261 | ax.set_xticks([]) 262 | ax.set_yticks([]) 263 | ax.text( 264 | 0.1, 265 | 0.9, 266 | r"$\Delta$t=" + str(round(src["midPointTai"] - first_time, 2)), 267 | color="white", 268 | fontsize=12, 269 | ) 270 | plt.show() 271 | plt.close() 272 | return observations, channels_sc2 273 | -------------------------------------------------------------------------------- /docs/howto/priors.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Use Neural Priors\n", 8 | "\n", 9 | "In [the sampling tutorial](sampling), we have demonstrated how to define parameters with priors.\n", 10 | "This guide shows you how to set up and use neural network priors.\n", 11 | "We make use of the related package [`galaxygrad`](https://github.com/SampsonML/galaxygrad), which can be pip-installed.\n", 12 | "\n", 13 | "This guide will follow the [Quick Start 
Guide](../0-quickstart), with changes in the initialization and parameter specification. We assume that you have a full installation of _scarlet2_ including `optax`, `numpyro`, `h5py` and `galaxygrad`.\n", 14 | "\n", 15 | "More details about the use of a score-based prior model for diffusion can be found in the paper \"Score-matching neural networks for improved multi-band source separation\", [Sampson et al., 2024, A&C, 49, 100875](http://ui.adsabs.harvard.edu/abs/2024A&C....4900875S)." 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# Import Packages and setup\n", 25 | "import jax.numpy as jnp\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "\n", 28 | "import scarlet2 as sc2\n", 29 | "\n", 30 | "sc2.set_validation(False)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Create Observation\n", 38 | "\n", 39 | "Again we import the test data and create the observation:" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "from huggingface_hub import hf_hub_download\n", 49 | "\n", 50 | "filename = hf_hub_download(\n", 51 | " repo_id=\"astro-data-lab/scarlet-test-data\", filename=\"hsc_cosmos_35.npz\", repo_type=\"dataset\"\n", 52 | ")\n", 53 | "file = jnp.load(filename)\n", 54 | "data = jnp.asarray(file[\"images\"])\n", 55 | "channels = [str(f) for f in file[\"filters\"]]\n", 56 | "centers = jnp.array([(src[\"y\"], src[\"x\"]) for src in file[\"catalog\"]])\n", 57 | "weights = jnp.asarray(1 / file[\"variance\"])\n", 58 | "psf = jnp.asarray(file[\"psfs\"])\n", 59 | "\n", 60 | "# create the observation\n", 61 | "obs = sc2.Observation(\n", 62 | " data,\n", 63 | " weights,\n", 64 | " psf=sc2.ArrayPSF(psf),\n", 65 | " channels=channels,\n", 66 | ")\n", 67 | "model_frame = sc2.Frame.from_observations(obs)" 68 | ] 69 | }, 70 | { 71 | 
"cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "## Initialize Sources" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "with sc2.Scene(model_frame) as scene:\n", 84 | " for i, center in enumerate(centers):\n", 85 | " if i == 0: # we know source 0 is a star\n", 86 | " spectrum = sc2.init.pixel_spectrum(obs, center, correct_psf=True)\n", 87 | " sc2.PointSource(center, spectrum)\n", 88 | " else:\n", 89 | " try:\n", 90 | " spectrum, morph = sc2.init.from_gaussian_moments(obs, center, min_corr=0.99)\n", 91 | " except ValueError:\n", 92 | " spectrum = sc2.init.pixel_spectrum(obs, center)\n", 93 | " morph = sc2.init.compact_morphology()\n", 94 | "\n", 95 | " sc2.Source(center, spectrum, morph)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "## Load Neural Prior" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "# load in the model you wish to use\n", 112 | "from galaxygrad import get_prior\n", 113 | "from scarlet2.nn import ScorePrior\n", 114 | "\n", 115 | "# instantiate the prior class\n", 116 | "temp = 2e-2 # values in the range of [1e-3, 1e-1] produce good results\n", 117 | "prior32 = get_prior(\"hsc32\")\n", 118 | "prior64 = get_prior(\"hsc64\")\n", 119 | "prior32 = ScorePrior(prior32, prior32.shape(), t=temp)\n", 120 | "prior64 = ScorePrior(prior64, prior64.shape(), t=temp)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "The prior model itself is in the form of a score-based diffusion model, which matches the score function, i.e. the gradient of the log-likelihood of the training data with respect to the parameters. For an image-based parameterization, the free parameters are the pixels, which means the gradient has the same shape as the image. 
`galaxygrad` provides several pre-trained models; here we use priors that were trained on deblended, isolated sources in HSC data, with shapes of 32x32 and 64x64, respectively. These sizes denote the maximum image size for which each prior was trained.\n", 128 | "\n", 129 | "We import {py:class}`~scarlet2.nn.ScorePrior` to use with our prior. It automatically zero-pads any smaller image array up to the specified size and provides a custom gradient path that calls the underlying score model during optimization or HMC sampling. The `temp` argument refers to a fixed temperature for the diffusion process. For speed, we run a single diffusion step with the given temperature." 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": { 135 | "collapsed": false 136 | }, 137 | "source": [ 138 | "## Define Parameters with Prior\n", 139 | "\n", 140 | "We use the same fitting routine as in the Quickstart guide, but replace `constraints.positive` with `prior=prior` in the Parameter containing the source morphologies. It is also useful to reduce the step size for the morphology updates because large jumps can lead to unstable prior gradients."
141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "from functools import partial\n", 150 | "from numpyro.distributions import constraints\n", 151 | "\n", 152 | "spec_step = partial(sc2.relative_step, factor=0.05)\n", 153 | "morph_step = partial(sc2.relative_step, factor=1e-3)\n", 154 | "\n", 155 | "with sc2.Parameters(scene) as parameters:\n", 156 | " for i in range(len(scene.sources)):\n", 157 | " sc2.Parameter(\n", 158 | " scene.sources[i].spectrum,\n", 159 | " name=f\"spectrum:{i}\",\n", 160 | " constraint=constraints.positive,\n", 161 | " stepsize=spec_step,\n", 162 | " )\n", 163 | " if i == 0:\n", 164 | " sc2.Parameter(scene.sources[i].center, name=f\"center:{i}\", stepsize=0.1)\n", 165 | " else:\n", 166 | " # chose a prior of suitable size\n", 167 | " prior = prior32 if max(scene.sources[i].morphology.shape) <= 32 else prior64\n", 168 | " sc2.Parameter(\n", 169 | " scene.sources[i].morphology,\n", 170 | " name=f\"morph:{i}\",\n", 171 | " prior=prior, # attach the prior here\n", 172 | " stepsize=morph_step,\n", 173 | " )" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": {}, 179 | "source": [ 180 | "Note that the use of a `prior` is incompatible with the use of a `constraint`." 
181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": { 186 | "collapsed": false 187 | }, 188 | "source": [ 189 | "We again perform the fitting:" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "maxiter = 1000\n", 199 | "print(\"Initial likelihood:\", obs.log_likelihood(scene()))\n", 200 | "scene_ = scene.fit(obs, parameters, max_iter=maxiter, e_rel=1e-4, progress_bar=True)\n", 201 | "print(\"Optimized likelihood:\", obs.log_likelihood(scene_()))" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": { 207 | "collapsed": false 208 | }, 209 | "source": [ 210 | "The fit reaches values quite comparable to the run with the positivity constraints in the quickstart guide." 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "## Check Results" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "norm = sc2.plot.AsinhAutomaticNorm(obs)\n", 227 | "sc2.plot.scene(\n", 228 | " scene_,\n", 229 | " obs,\n", 230 | " norm=norm,\n", 231 | " show_model=True,\n", 232 | " show_rendered=True,\n", 233 | " show_observed=True,\n", 234 | " show_residual=True,\n", 235 | " add_boxes=True,\n", 236 | ")\n", 237 | "plt.show()" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "sc2.plot.sources(\n", 247 | " scene_,\n", 248 | " norm=norm,\n", 249 | " observation=obs,\n", 250 | " show_model=True,\n", 251 | " show_rendered=True,\n", 252 | " show_observed=True,\n", 253 | " show_spectrum=False,\n", 254 | " add_labels=False,\n", 255 | " add_boxes=True,\n", 256 | ")\n", 257 | "plt.show()" 258 | ] 259 | }, 260 | { 261 | "cell_type": "markdown", 262 | "metadata": {}, 263 | "source": [ 264 | "The results for most of the 
galaxies look very reasonable now, in particular for the fainter ones. They remain compact and not overly affected by noise. Source #1 has minor artifacts and picks up neighboring objects, indicating that this prior has not been trained (yet) on as many larger galaxies and is therefore still somewhat weak. An update will fix this soon." 265 | ] 266 | } 267 | ], 268 | "metadata": { 269 | "celltoolbar": "Raw Cell Format", 270 | "kernelspec": { 271 | "display_name": "scarlet2", 272 | "language": "python", 273 | "name": "python3" 274 | }, 275 | "language_info": { 276 | "codemirror_mode": { 277 | "name": "ipython", 278 | "version": 3 279 | }, 280 | "file_extension": ".py", 281 | "mimetype": "text/x-python", 282 | "name": "python", 283 | "nbconvert_exporter": "python", 284 | "pygments_lexer": "ipython3", 285 | "version": "3.12.11" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 4 290 | } 291 | -------------------------------------------------------------------------------- /tests/scarlet2/questionnaire/conftest.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from importlib.resources import files 4 | from pathlib import Path 5 | 6 | import yaml 7 | from ipywidgets import HTML, Button, HBox, VBox 8 | from pygments import highlight 9 | from pygments.formatters.html import HtmlFormatter 10 | from pygments.lexers.python import PythonLexer 11 | from pytest import fixture 12 | from scarlet2.questionnaire.models import QuestionAnswer, QuestionAnswers, Questionnaire 13 | from scarlet2.questionnaire.questionnaire import ( 14 | OUTPUT_BOX_LAYOUT, 15 | OUTPUT_BOX_STYLE_FILE, 16 | QUESTION_BOX_LAYOUT, 17 | VIEWS_PACKAGE_PATH, 18 | ) 19 | 20 | 21 | @fixture 22 | def data_dir(): 23 | """Path to the data directory containing the example questionnaire YAML file.""" 24 | return Path(__file__).parent / "data" 25 | 26 | 27 | @fixture 28 | def example_questionnaire_dict(data_dir): 29 | """An example 
questionnaire dictionary""" 30 | yaml_path = data_dir / "example_questionnaire.yaml" 31 | with yaml_path.open("r") as f: 32 | return yaml.safe_load(f) 33 | 34 | 35 | @fixture 36 | def example_questionnaire(example_questionnaire_dict): 37 | """An example Questionnaire model instance, validated from example_questionnaire_dict""" 38 | return Questionnaire.model_validate(example_questionnaire_dict) 39 | 40 | 41 | @fixture 42 | def example_questionnaire_with_switch_dict(data_dir): 43 | """An example questionnaire dictionary with a switch question, loaded from example_questionnaire_switch.yaml""" 44 | yaml_path = data_dir / "example_questionnaire_switch.yaml" 45 | with yaml_path.open("r") as f: 46 | return yaml.safe_load(f) 47 | 48 | 49 | @fixture 50 | def example_questionnaire_with_switch(example_questionnaire_with_switch_dict): 51 | """An example Questionnaire model instance with a switch question""" 52 | return Questionnaire.model_validate(example_questionnaire_with_switch_dict) 53 | 54 | 55 | @fixture 56 | def questionnaire_with_followup_switch_example_dict(data_dir): 57 | """An example questionnaire dictionary loaded from example_questionnaire_followup_switch.yaml (presumably a switch inside a follow-up question; see the YAML file)""" 58 | yaml_path = data_dir / "example_questionnaire_followup_switch.yaml" 59 | with yaml_path.open("r") as f: 60 | return yaml.safe_load(f) 61 | 62 | 63 | @fixture 64 | def example_questionnaire_with_followup_switch(questionnaire_with_followup_switch_example_dict): 65 | """An example Questionnaire model instance validated from the follow-up switch example dictionary""" 66 | return Questionnaire.model_validate(questionnaire_with_followup_switch_example_dict) 67 | 68 | 69 | @fixture 70 | def example_questionnaire_with_feedback(example_questionnaire): 71 | """An example Questionnaire model instance with a feedback URL (a deep copy, so the base fixture is not mutated)""" 72 | questionnaire = example_questionnaire.model_copy(deep=True) 73 | questionnaire.feedback_url = "https://example.com/feedback" 74 | return questionnaire 75 | 76 | 77 | @fixture 78 | def example_question_answers(example_questionnaire): 79 | """An example list of question answers for the example questionnaire: answer 0 of the first question, then answers 1 and 0 of its first and second follow-ups""" 80 | answer_inds = [0, 1, 0]
# indices of answers to select for each question 81 | questions = [ 82 | example_questionnaire.questions[0], 83 | example_questionnaire.questions[0].answers[0].followups[0], 84 | example_questionnaire.questions[0].answers[0].followups[1], 85 | ] 86 | qas = [ 87 | QuestionAnswer(question=q.question, answer=q.answers[i].answer, value=i) 88 | for q, i in zip(questions, answer_inds, strict=False) 89 | ] 90 | return QuestionAnswers(answers=qas) 91 | 92 | 93 | class Helpers: 94 | """Helper functions for testing the QuestionnaireWidget""" 95 | 96 | @staticmethod 97 | def get_answer_button(widget, answer_index): 98 | """Get an answer button from the question box children. 99 | 100 | Args: 101 | widget: The QuestionnaireWidget instance 102 | answer_index: The index of the answer button to get 103 | 104 | Returns: 105 | The answer button widget 106 | """ 107 | css_offset = 1 108 | prev_questions_offset = len(widget.question_answers) 109 | question_label_offset = 1 110 | 111 | # Get the buttons container which is at index after the question label 112 | buttons_container_index = css_offset + prev_questions_offset + question_label_offset 113 | buttons_container = widget.question_box.children[buttons_container_index] 114 | 115 | # Return the specific button from the buttons container 116 | return buttons_container.children[answer_index] 117 | 118 | @staticmethod 119 | def get_prev_question_button(widget, question_index): 120 | """Get a previous question button from the question box children. 
121 | 122 | Args: 123 | widget: The QuestionnaireWidget instance 124 | question_index: The index of the previous question to get 125 | 126 | Returns: 127 | The previous question button widget 128 | """ 129 | css_offset = 1 130 | 131 | container = widget.question_box.children[css_offset + question_index] 132 | # The button is the first child of the container 133 | return container.children[0] 134 | 135 | @staticmethod 136 | def assert_widget_ui_matches_state(widget): 137 | """Assert that the widget's UI matches its internal state.""" 138 | assert isinstance(widget.ui, HBox) 139 | assert widget.ui.children == (widget.question_box, widget.output_box) 140 | 141 | assert isinstance(widget.output_box, VBox) 142 | assert widget.output_box.children == (widget.output_container,) 143 | assert widget.output_box.layout == OUTPUT_BOX_LAYOUT 144 | 145 | assert isinstance(widget.output_container, HTML) 146 | 147 | # check output container contains css from output_box css file 148 | css_file = files(VIEWS_PACKAGE_PATH).joinpath(OUTPUT_BOX_STYLE_FILE) 149 | with css_file.open("r") as f: 150 | css_content = f.read() 151 | 152 | assert css_content in widget.output_container.value 153 | 154 | output_code = re.sub(r"\{\{\s*\w+\s*\}\}", "", widget.code_output) 155 | 156 | html = widget.output_container.value 157 | 158 | # regex to capture the JS argument 159 | match = re.search(r"navigator\.clipboard\.writeText\((.*?)\)", html) 160 | assert match, "No copy button found" 161 | 162 | actual_arg = match.group(1) 163 | expected_arg = json.dumps(output_code) # exactly how Jinja|tojson would encode it 164 | 165 | assert actual_arg.strip() == expected_arg 166 | 167 | formatter = HtmlFormatter(style="monokai", noclasses=True) 168 | highlighted_code = highlight(output_code, PythonLexer(), formatter) 169 | 170 | assert highlighted_code in widget.output_container.value 171 | 172 | assert isinstance(widget.question_box, VBox) 173 | assert widget.question_box.layout == QUESTION_BOX_LAYOUT 174 | 175 
| css_snippet_count = 1 176 | save_button_container_count = 1 177 | 178 | # If there's a current question, we have: 179 | # - CSS snippet 180 | # - Previous question containers 181 | # - Current question label 182 | # - Buttons container 183 | # - Save button container 184 | if widget.current_question: 185 | expected_children_count = ( 186 | css_snippet_count + len(widget.question_answers) + 1 + 1 + save_button_container_count 187 | ) 188 | # If there's no current question, we have: 189 | # - CSS snippet 190 | # - Previous question containers 191 | # - Final message container 192 | # - Save button container 193 | else: 194 | expected_children_count = ( 195 | css_snippet_count + len(widget.question_answers) + 1 + save_button_container_count 196 | ) 197 | 198 | assert len(widget.question_box.children) == expected_children_count 199 | 200 | # Skip the CSS snippet 201 | css_offset = 1 202 | 203 | for i in range(len(widget.question_answers)): 204 | # Add the CSS offset to the index 205 | child_index = i + css_offset 206 | assert isinstance(widget.question_box.children[child_index], HBox) 207 | # The button is the first child of the container 208 | btn = widget.question_box.children[child_index].children[0] 209 | question = widget.question_answers[i][0] 210 | assert question.question in btn.description 211 | ans_index = widget.question_answers[i][1] 212 | assert question.answers[ans_index].answer in btn.description 213 | 214 | if widget.current_question is not None: 215 | # Add the CSS offset to the index 216 | qs_ind = len(widget.question_answers) + css_offset 217 | 218 | assert isinstance(widget.question_box.children[qs_ind], HTML) 219 | assert widget.current_question.question in widget.question_box.children[qs_ind].value 220 | 221 | # Get the buttons container which is at index after the question label 222 | buttons_container_index = qs_ind + 1 223 | buttons_container = widget.question_box.children[buttons_container_index] 224 | assert isinstance(buttons_container, 
VBox) 225 | 226 | # Check each button in the buttons container 227 | for btn, ans in zip(buttons_container.children, widget.current_question.answers, strict=False): 228 | assert isinstance(btn, Button) 229 | assert btn.description == ans.answer 230 | assert btn.tooltip == ans.tooltip 231 | 232 | else: 233 | # When there's no current question, we have: 234 | # - CSS snippet 235 | # - Previous question containers 236 | # - Final message container 237 | # - Save button container 238 | 239 | # Get the final message container which is at index before the save button container 240 | final_message_index = len(widget.question_box.children) - 2 241 | final_message_container = widget.question_box.children[final_message_index] 242 | assert isinstance(final_message_container, VBox) 243 | 244 | # The final message is in the HTML child of the container 245 | final_message_html = final_message_container.children[0] 246 | assert isinstance(final_message_html, HTML) 247 | final_message = final_message_html.value 248 | assert "You're done" in final_message 249 | 250 | # Check for feedback URL if present in the questionnaire 251 | if widget.feedback_url: 252 | assert widget.feedback_url in final_message 253 | assert "feedback form" in final_message 254 | 255 | # Check that the last child is the save button container 256 | save_button_container = widget.question_box.children[-1] 257 | assert isinstance(save_button_container, VBox) 258 | assert len(save_button_container.children) > 0 259 | save_button = save_button_container.children[-1] 260 | assert isinstance(save_button, Button) 261 | assert save_button.description == "Save Answers" 262 | 263 | 264 | @fixture 265 | def helpers(): 266 | """Provide helper functions for testing.""" 267 | return Helpers() 268 | -------------------------------------------------------------------------------- /docs/2-questionnaire.md: -------------------------------------------------------------------------------- 1 | # Questionnaire 2 | 3 | ## User Guide 
- Interactive Project Setup 4 | 5 | The scarlet2 questionnaire is an interactive tool designed to help you quickly set up a new scarlet2 project 6 | by generating a customized code template that matches your use case. 7 | 8 | ### Running the Questionnaire 9 | 10 | To run the questionnaire, you need to import and call the `run_questionnaire` function from the `scarlet2.questionnaire` module: 11 | 12 | ```python 13 | from scarlet2.questionnaire import run_questionnaire 14 | run_questionnaire() 15 | ``` 16 | 17 | ![Questionnaire Screenshot](_static/questionnaire_screenshot.png) 18 | 19 | This will launch an interactive widget in your Jupyter notebook that guides you through a series of questions about your project and your data. 20 | 21 | The questionnaire presents one question at a time with multiple-choice answers. For each answer you select, 22 | the questionnaire will update the code template to match your choices and display some explanatory text to 23 | help you understand the code being generated. 24 | 25 | You can navigate through the questions using the answer buttons. If you want to change a previous answer, 26 | you can click the previous question in the list to go back to that question. 27 | 28 | ### Using the Generated Template 29 | 30 | As you progress through the questionnaire, a code template is dynamically generated based on your answers. 31 | The code will outline the steps needed to set up your scarlet2 project, but it will require some manual editing 32 | to fill in specific details about your data. For example, the code will include definitions of example values 33 | for variables like `channels` that you will need to replace with values that match your data, and will 34 | reference variables like `data` and `psf` that you will need to change to use your actual data. 35 | 36 | In the bottom right of the code output, you can click the "📋 Copy" button at any time to copy the generated 37 | code to your clipboard. 
38 | 39 | > Note: The template is a starting point that you can modify to fit your specific data. It provides the structure for your project based on your use case. 40 | 41 | ### Saving Your Progress 42 | 43 | If you want to save your questionnaire answers and return to them later, you can use the `Save Answers` 44 | button in the bottom left of the questionnaire widget. This will save your answers to a YAML file that you can 45 | load later to continue where you left off. 46 | 47 | The YAML file will be saved to your current working directory with a filename like 48 | `scarlet2_questionnaire_timestamp.yaml`. To change where the file is saved, run the questionnaire with the 49 | `save_path` argument: 50 | 51 | ```python 52 | from scarlet2.questionnaire import run_questionnaire 53 | run_questionnaire(save_path="path/to/save/directory") 54 | ``` 55 | 56 | To load a previously saved questionnaire, use the `answer_path` argument to specify the path to your saved 57 | YAML file: 58 | 59 | ```python 60 | from scarlet2.questionnaire import run_questionnaire 61 | run_questionnaire(answer_path="path/to/saved/answers.yaml") 62 | ``` 63 | 64 | ### Feedback and Issues 65 | 66 | If you encounter any issues or have suggestions for improving the questionnaire, [please fill out our feedback form!](https://docs.google.com/forms/d/e/1FAIpQLScKHbiqxhizacgzRx3xHEdqqjgtZBsxjtQZFJlYBdLcbOnfBg/viewform) 67 | 68 | ## Developer Guide - Questionnaire Architecture 69 | 70 | The questionnaire module is designed to be extensible and maintainable. This section explains how the questionnaire works internally and how to modify or extend it. 71 | 72 | ### Module Structure 73 | 74 | The questionnaire module consists of several key components: 75 | 76 | 1. **questions.yaml**: Stores the actual questions, answers, and templates 77 | 2. **models.py**: Defines the data structures used by the questionnaire that map to the YAML file 78 | 3.
**questionnaire.py**: Contains the main `QuestionnaireWidget` class that uses ipywidgets to render the UI and handle user interactions 79 | 80 | ### YAML Structure 81 | 82 | Setting up the questions in the questionnaire is done by modifying the `questions.yaml` file. 83 | The questions are defined in this YAML file with the following structure: 84 | 85 | ```yaml 86 | initial_template: "{{code}}" # The starting template with placeholders 87 | initial_commentary: "Welcome message" # (Optional) The initial commentary text before any questions are answered 88 | questions: # List of top-level questions 89 | - question: "Question text" # Each question object has a question text, answers, and optionally a variable 90 | variable: "variable_name" # Optional variable to store the answer to be referenced later 91 | answers: 92 | - answer: "Answer option 1" # Each answer has an answer text 93 | tooltip: "Helpful tooltip" # Optional tooltip for the answer that appears on hover of the button 94 | templates: # A list of code snippets to apply if selected 95 | - replacement: code # The placeholder to replace 96 | code: "# Your code here\n{{next_placeholder}}" # The replacement code 97 | commentary: "Explanation of this choice" # The commentary text to display when this answer is selected, can include markdown formatting 98 | followups: # Additional questions to ask immediately if this answer is selected. This list of followups matches the structure of top-level questions, and can include question objects or switch/case objects 99 | - question: "Follow-up question" 100 | answers: [...] 101 | - question: "..." # More questions 102 | answers: [...] 103 | # ... 104 | - switch: "variable_name" # Conditional branching based on a previous answer 105 | cases: 106 | - value: 0 # If the question with variable "variable_name" was answered with the first answer (index 0) 107 | questions: [...] # The questions to ask in this case. 
This list matches the structure of top-level questions, and can include question objects or switch/case objects 108 | - value: null # The default case if no other case matches or if the variable was not set (e.g. the question was skipped). If there is no default case, the switch is skipped. 109 | questions: [...] 110 | ``` 111 | 112 | The questionnaire starts with an `initial_template` that contains placeholders (e.g. `{{code}}`) that will be replaced 113 | as the user answers questions. Each question has a list of possible answers, and each answer can specify one or more 114 | code snippets to replace the specified placeholders in the template, as well as commentary text to display. 115 | The commentary can include markdown formatting which will be rendered in the commentary box. 116 | Answers can also specify follow-up questions that are asked immediately after the current question if that answer is selected. 117 | 118 | The flow of questions can also include conditional branching using `switch` objects that check the value of a 119 | previously answered question (by its variable name) and present different sets of questions based on the answer. 120 | 121 | > Note: In YAML, single and double quotes behave differently. Single quotes will treat backslashes as literal 122 | > characters, while double quotes will interpret backslashes as escape characters. For example, to include a 123 | > newline character in a string, you would use double quotes: `"Line 1\nLine 2"`. If you used single quotes, 124 | > it would be treated as the literal text `Line 1\nLine 2` without a newline. 125 | 126 | The YAML file is packaged with the module as it is built, and is loaded using the `importlib.resources` module, 127 | which allows access to package data files even in a zipped package.
With `pyproject.toml` and `setuptools-scm`, 128 | any package data files that are tracked by git are automatically included in the package, and so the 129 | `questions.yaml` file is included automatically when the package is built and deployed. 130 | (See the [setuptools documentation](https://setuptools.pypa.io/en/latest/userguide/datafiles.html) for more details.) 131 | 132 | ### Data Models 133 | 134 | The questionnaire uses Pydantic models to define its data structures, which are found in `models.py`. 135 | These models match the structure of the YAML file and are used to parse and validate the questionnaire data 136 | loaded from the YAML file. 137 | 138 | Pydantic allows defining dataclasses with type hints and validation, making it easier to work with structured 139 | data like the questionnaire configuration. When the YAML file is loaded, it is parsed into these models for 140 | use in the questionnaire logic. If the YAML structure does not match the expected models, Pydantic will raise 141 | validation errors when the questionnaire is initialized. 142 | 143 | ### QuestionnaireWidget Class 144 | 145 | The `QuestionnaireWidget` class is responsible for rendering the questionnaire UI and handling user interactions. 146 | 147 | It uses `ipywidgets` to render the questionnaire in the output of a Jupyter cell. The class maintains the 148 | state of the questionnaire, including the current question, the user's answers, and the generated code template. 149 | 150 | The state consists of: 151 | 152 | - `self.code_output`: The current code template with placeholders. This is updated as the user answers questions. Regex is used to find and replace placeholders in the template when updating with answer templates. 153 | - `self.commentary`: The current commentary text that explains the generated code. This is set with each answer. 154 | - `self.questions_stack`: A stack of questions to ask. This allows for handling follow-up questions and branching logic.
The next question is popped from the stack when needed. 155 | - `self.question_answers`: A list of tuples of (question, answer_index) representing the user's answers to each question. This is used to display the history of questions and answers in the UI. 156 | - `self.variables`: A dictionary mapping variable names to the index of the selected answer. This is used for switch/case branching. 157 | 158 | The UI consists of two parts: a question area, which lists the previous questions and answers and shows a set of buttons 159 | for the current question's answers; and an output area, which displays the generated code template and 160 | commentary text. 161 | 162 | The question area uses ipywidgets components to display the questions and answer buttons, and to handle the user 163 | input. The output area uses an HTML widget to display the generated code with syntax highlighting, the copy 164 | button, and the commentary text. Since the output area is written in HTML, the code defining the output area 165 | is stored in separate HTML template files (`output_box.html.jinja` and `output_box.css`) for easier editing. 166 | These template files are also included in the package using `setuptools-scm` the same way as `questions.yaml`. 167 | The `jinja2` templating engine is used to render the HTML with the generated code and commentary. For syntax 168 | highlighting, the `pygments` library is used to convert the code into HTML with appropriate styling, and the 169 | commentary text is converted from markdown to HTML using the `markdown` library. 170 | 171 | Key methods include: 172 | - `_get_next_question()`: Determines the next question to display. A stack is used to manage the flow of questions, including handling follow-up questions and switch/case branching. 173 | - `_handle_answer()`: Processes user answers and updates the state. 174 | - `_update_template()`: Applies template changes based on answers. 175 | - `_render_output_box()`: Creates the UI for displaying the generated code.
Uses Jinja2, pygments, and markdown for rendering. 176 | -------------------------------------------------------------------------------- /src/scarlet2/source.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | from astropy.coordinates import SkyCoord 6 | 7 | from . import Scenery 8 | from .bbox import Box, overlap_slices 9 | from .module import Module 10 | from .morphology import Morphology 11 | from .spectrum import Spectrum 12 | from .validation_utils import ( 13 | ValidationMethodCollector, 14 | ) 15 | 16 | 17 | class Component(Module): 18 | """Single component of a hyperspectral model 19 | 20 | The parameterization of the 3D model (channel, height, width) is defined by 21 | the outer product of `spectrum` and `morphology`. That means that there is no 22 | variation of the spectrum in spatial direction. The `center` coordinate is only 23 | needed to define the bounding box and place the component in the model frame. 24 | """ 25 | 26 | center: jnp.ndarray 27 | """Center position, in pixel coordinates of the model frame""" 28 | spectrum: (jnp.array, Spectrum) 29 | """Spectrum model""" 30 | morphology: (jnp.array, Morphology) 31 | """Morphology model""" 32 | bbox: Box 33 | """Bounding box of the model, in pixel coordinates of the model frame""" 34 | 35 | def __init__(self, center, spectrum, morphology): 36 | """ 37 | Parameters 38 | ---------- 39 | center: array, :py:class:`astropy.coordinates.SkyCoord` 40 | Center position. If given as astropy sky coordinate, it will be 41 | transformed with the WCS of the model frame. 42 | spectrum: :py:class:`~scarlet2.Spectrum` 43 | The spectrum of the component. 44 | morphology: :py:class:`~scarlet2.Morphology` 45 | The morphology of the component. 
46 | 47 | Examples 48 | -------- 49 | To uniquely determine coordinates, the creation of components is restricted 50 | to a context defined by a :py:class:`~scarlet2.Scene`, which define the 51 | :py:class:`~scarlet2.Frame` of the model. 52 | 53 | >>> with Scene(model_frame) as scene: 54 | >>> component = Component(center, spectrum, morphology) 55 | """ 56 | self.spectrum = spectrum 57 | self.morphology = morphology 58 | 59 | if isinstance(center, SkyCoord): 60 | try: 61 | center = Scenery.scene.frame.get_pixel(center) 62 | except AttributeError: 63 | print("`center` defined in sky coordinates can only be created within the context of a Scene") 64 | print("Use 'with Scene(frame) as scene: (...)'") 65 | raise 66 | self.center = center 67 | 68 | box = Box(spectrum.shape) 69 | box2d = Box(morphology.shape) 70 | box2d.set_center(center.astype(int)) 71 | self.bbox = box @ box2d 72 | 73 | def __call__(self): 74 | """What to run when Component is called""" 75 | # Boxed and centered model 76 | delta_center = (self.center[-2] - self.bbox.center[-2], self.center[-1] - self.bbox.center[-1]) 77 | spectrum = self.spectrum() if isinstance(self.spectrum, Module) else self.spectrum 78 | morph = ( 79 | self.morphology(delta_center=delta_center) 80 | if isinstance(self.morphology, Module) 81 | else self.morphology 82 | ) 83 | return spectrum[:, None, None] * morph[None, :, :] 84 | 85 | 86 | class DustComponent(Component): 87 | """Component with negative exponential model 88 | 89 | This component is meant to describe the dust attenuation, :math:`\\exp(-\\tau)`, 90 | where :math:`\\tau` is the hyperspectral model defined by the base :py:class:`~scarlet2.Component`. 91 | """ 92 | 93 | def __call__(self): 94 | """What to run when DustComponent is called""" 95 | return jnp.exp(-super().__call__()) 96 | 97 | 98 | class Source(Component): 99 | """Source model 100 | 101 | The class is the basic parameterization for sources in :py:class:`~scarlet2.Scene`. 
102 | """ 103 | 104 | components: list 105 | """List of components in this source""" 106 | component_ops: list 107 | """List of operators to combine `components` for the final model""" 108 | 109 | def __init__(self, center, spectrum, morphology): 110 | """ 111 | Parameters 112 | ---------- 113 | center: array, :py:class:`astropy.coordinates.SkyCoord` 114 | Center position. If given as astropy sky coordinate, it will be 115 | transformed with the WCS of the model frame. 116 | spectrum: array, :py:class:`~scarlet2.Spectrum` 117 | The spectrum of the source. 118 | morphology: array, :py:class:`~scarlet2.Morphology` 119 | The morphology of the source. 120 | 121 | Examples 122 | -------- 123 | A source declaration is restricted to a context of a :py:class:`~scarlet2.Scene`, 124 | which defines the :py:class:`~scarlet2.Frame` of the entire model. 125 | 126 | >>> with Scene(model_frame) as scene: 127 | >>> source = Source(center, spectrum, morphology) 128 | 129 | A source can comprise one or multiple :py:class:`~scarlet2.Component`, 130 | which can be added by :py:func:`add_component` or operators `+=` 131 | (for an additive component) or `*=` (for a multiplicative component). 
132 | 133 | >>> with Scene(model_frame) as scene: 134 | >>> source = Source(center, spectrum, morphology) 135 | >>> source *= DustComponent(center, dust_spectrum, dust_morphology) 136 | """ 137 | # set the base component 138 | super().__init__(center, spectrum, morphology) 139 | # create the empty component list 140 | self.components = list() 141 | self.component_ops = list() 142 | 143 | # add this source to the active scene 144 | try: 145 | Scenery.scene.sources.append(self) 146 | except AttributeError: 147 | print("Source can only be created within the context of a Scene") 148 | print("Use 'with Scene(frame) as scene: Source(...)'") 149 | raise 150 | 151 | def add_component(self, component, op): 152 | """Add `component` to this source 153 | 154 | Parameters 155 | ---------- 156 | component: :py:class:`~scarlet2.Component` 157 | The component to include in this source. It will be combined with the 158 | previous component according to the operator `op`. 159 | op: callable 160 | Operator to combine this `component` with those before it in the list :py:attr:`components`. 161 | Conventional operators from the :py:mod:`operator` package can be used. 
162 | Signature: op(x,y) -> z, where all terms have the same shapes 163 | """ 164 | assert isinstance(component, (Source, Component)) 165 | 166 | # if component is a source, it's already registered in scene 167 | # remove it from scene to not call it twice 168 | if isinstance(component, Source): 169 | try: 170 | Scenery.scene.sources.remove(component) 171 | except AttributeError: 172 | print("Source can only be modified within the context of a Scene") 173 | print("Use 'with Scene(frame) as scene: Source(...)'") 174 | raise 175 | except ValueError: 176 | pass 177 | 178 | # adding a full source will maintain its ownership of components: 179 | # hierarchical definition of sources withing sources 180 | self.components.append(component) 181 | self.component_ops.append(op) 182 | return self 183 | 184 | def __iadd__(self, component): 185 | return self.add_component(component, operator.add) 186 | 187 | def __imul__(self, component): 188 | return self.add_component(component, operator.mul) 189 | 190 | def __call__(self): 191 | """What to run when Source is called""" 192 | base = super() 193 | model = base.__call__() 194 | for component, op in zip(self.components, self.component_ops, strict=False): 195 | model_ = component() 196 | # cut out regions from model and model_ 197 | bbox, bbox_ = overlap_slices(base.bbox, component.bbox, return_boxes=True) 198 | sub_model = jax.lax.dynamic_slice(model, bbox.start, bbox.shape) 199 | sub_model_ = jax.lax.dynamic_slice(model_, bbox_.start, bbox_.shape) 200 | # combine with operator 201 | sub_model = op(sub_model, sub_model_) 202 | # add model_ back in full model 203 | model = jax.lax.dynamic_update_slice(model, sub_model, bbox.start) 204 | return model 205 | 206 | 207 | class PointSource(Source): 208 | """Point source model""" 209 | 210 | def __init__(self, center, spectrum): 211 | """Model for point sources 212 | 213 | Because the morphology is determined by the model PSF, it does not need to be provided. 
214 | 215 | Parameters 216 | ---------- 217 | center: array, :py:class:`astropy.coordinates.SkyCoord` 218 | Center position. If given as astropy sky coordinate, it will be 219 | transformed with the WCS of the model frame. 220 | spectrum: array, :py:class:`~scarlet2.Spectrum` 221 | The spectrum of the point source. 222 | 223 | Examples 224 | -------- 225 | A source declaration is restricted to a context of a :py:class:`~scarlet2.Scene`, 226 | which defines the :py:class:`~scarlet2.Frame` of the entire model. 227 | 228 | >>> with Scene(model_frame) as scene: 229 | >>> point_source = PointSource(center, spectrum) 230 | """ 231 | try: 232 | frame = Scenery.scene.frame 233 | except AttributeError: 234 | print("Source can only be created within the context of a Scene") 235 | print("Use 'with Scene(frame) as scene: Source(...)'") 236 | raise 237 | if frame.psf is None: 238 | raise AttributeError("PointSource can only be create with a PSF in the model frame") 239 | morphology = frame.psf.morphology 240 | 241 | super().__init__(center, spectrum, morphology) 242 | 243 | 244 | class SourceValidator(metaclass=ValidationMethodCollector): 245 | """A class containing all of the validation checks for Source objects. 246 | 247 | Note that the metaclass is defined as `MethodCollector`, which collects all 248 | validation methods in this class into a single class attribute list called 249 | `validation_checks`. This allows for easy iteration over all checks.""" 250 | 251 | def __init__(self, source: Source): 252 | self.source = source 253 | 254 | # def check_source_has_positive_contribution(self) -> ValidationResult: 255 | # """Check that the source has a positive contribution i.e. that the result 256 | # of evaluating self.source() does not contain negative values. 257 | # 258 | # Returns 259 | # ------- 260 | # ValidationResult 261 | # A subclass of ValidationResult indicating the result of the check. 
262 | # """ 263 | # model = self.source() 264 | # if jnp.any(model < 0): 265 | # return ValidationError( 266 | # "Source model has negative contributions.", 267 | # check=self.__class__.__name__, 268 | # context={"source": self.source}, 269 | # ) 270 | # else: 271 | # return ValidationInfo( 272 | # "Source model has positive contributions.", 273 | # check=self.__class__.__name__, 274 | # ) 275 | --------------------------------------------------------------------------------