├── .coveragerc
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── README.md
│       ├── docs-build-deploy.yml
│       ├── draft-pdf.yml
│       ├── pre-release.yml
│       ├── release.yml
│       └── verify-tests-and-docs.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTORS.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── baselines
│   │   │   └── src
│   │   │       └── print_conversions.py
│   │   ├── bullets.css
│   │   ├── fftshift
│   │   │   └── src
│   │   │       └── plot.py
│   │   └── mmd
│   │       └── src
│   │           ├── BaseCube.mmd
│   │           ├── GriddedNet.mmd
│   │           ├── ImageCube.mmd
│   │           ├── Parametric.mmd
│   │           ├── SingleDish.mmd
│   │           └── SkyModel.mmd
│   ├── api
│   │   ├── analysis.md
│   │   ├── coordinates.md
│   │   ├── crossval.md
│   │   ├── datasets.md
│   │   ├── fourier.md
│   │   ├── geometry.md
│   │   ├── gridding.md
│   │   ├── images.md
│   │   ├── losses.md
│   │   ├── plotting.md
│   │   ├── precomposed.md
│   │   ├── train_test.md
│   │   └── utilities.md
│   ├── background.md
│   ├── changelog.md
│   ├── conf.py
│   ├── developer-documentation.md
│   ├── favicon.ico
│   ├── index.md
│   ├── installation.md
│   ├── logo.png
│   ├── make.bat
│   └── units-and-conventions.md
├── paper
│   ├── .gitignore
│   ├── fig.pdf
│   ├── paper.bib
│   └── paper.md
├── pyproject.toml
├── src
│   └── mpol
│       ├── __init__.py
│       ├── constants.py
│       ├── coordinates.py
│       ├── crossval.py
│       ├── data
│       │   └── mock_data.npz
│       ├── datasets.py
│       ├── exceptions.py
│       ├── fourier.py
│       ├── geometry.py
│       ├── gridding.py
│       ├── images.py
│       ├── input_output.py
│       ├── losses.py
│       ├── onedim.py
│       ├── plot.py
│       ├── precomposed.py
│       ├── tests.mplstyle
│       ├── training.py
│       └── utils.py
└── test
    ├── README.md
    ├── conftest.py
    ├── coordinates_test.py
    ├── crossval_test.py
    ├── datasets_test.py
    ├── fftshift_test.py
    ├── fourier_test.py
    ├── geometry_test.py
    ├── gridder_dataset_export_test.py
    ├── gridder_gridding_test.py
    ├── gridder_imager_test.py
    ├── gridder_init_test.py
    ├── images_test.py
    ├── input_output_test.py
    ├── losses_test.py
    ├── onedim_test.py
    ├── plot_test.py
    ├── plot_utils.py
    ├── train_test_test.py
    └── utils_test.py

/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 |     venv/*
4 | relative_files = True
5 | 
6 | 
7 | [report]
8 | exclude_lines =
9 |     raise NotImplementedError
10 |     except ImportError
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve MPoL
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 | 
13 | **To reproduce**
14 | Steps to reproduce the behavior:
15 | 
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 | 
19 | **Screenshots**
20 | If applicable, add screenshots to help explain your problem.
21 | 
22 | **Desktop (please complete the following information):**
23 | - OS: [e.g. Linux + distro, MacOS, Windows]
24 | - Python version
25 | - MPoL version (`$ python -c "import mpol; print(mpol.__version__)"`)
26 | 
27 | **Suggested fix**
28 | If you have a suggestion for how to fix the bug, please explain.
29 | 
30 | **Additional context**
31 | Add any other context about the problem here.
32 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Is your feature request related to a problem or opportunity? Please describe.**
11 | A clear and concise description of what the problem or opportunity is.
12 | 
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 | 
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 | 
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 | 
22 | **How might you be able to help with the development of this feature?**
23 | I.e., specification of use case, mock data, pull requests with code implementations, tests.
24 | 
--------------------------------------------------------------------------------
/.github/workflows/README.md:
--------------------------------------------------------------------------------
1 | # MPoL GitHub Actions Workflows
2 | 
3 | We use GitHub Actions to continuously integrate and deploy the MPoL codebase. This document summarizes the intended functionality of each of the workflows.
4 | 
5 | * `verify-tests-and-docs.yml` runs the unit tests and (if successful) builds (but does not deploy) the documentation, to ascertain whether the codebase is in a working state (defined as passing tests on all non-experimental Python versions and a successful documentation build). This workflow is intended to run on every commit to the `main` branch as well as every commit to open pull requests.
6 | * `docs-build-deploy.yml` builds and deploys the documentation to [GitHub Pages](https://mpol-dev.github.io/MPoL/). This workflow is intended to run on every commit to the `main` branch, to ensure that the currently deployed documentation matches the current state of the source code. Note that if you are merging a PR, you may find that the docs are built twice: once as part of `verify-tests-and-docs.yml` and again as part of `docs-build-deploy.yml`. This duplication is OK, since it is designed to support small changes implemented directly on `main` as well as changes introduced through branches and PRs.
7 | * `draft-pdf.yml` compiles a draft PDF of the JOSS paper (`paper/paper.md`) on every push, using the Open Journals draft action, and uploads the result as a build artifact.
8 | * `pre-release.yml` tries to install the package on Linux, MacOS, and Windows using all supported Python versions. As the name suggests, this is designed to run in anticipation of a release and is triggered by publishing a pre-release on GitHub.
9 | * `release.yml` runs when a full release is published on GitHub. Note that `pre-release.yml` is not required to have run (or passed) first, but it is a good idea to exercise it manually by first publishing a pre-release.
-------------------------------------------------------------------------------- /.github/workflows/docs-build-deploy.yml: -------------------------------------------------------------------------------- 1 | name: build and deploy docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-24.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: '3.12' 17 | - name: Install doc deps 18 | run: | 19 | pip install .'[dev]' 20 | - name: Install Pandoc dependency 21 | run: | 22 | sudo apt-get install pandoc 23 | - name: Set up node 24 | uses: actions/setup-node@v4 25 | - name: Install mermaid.js dependency 26 | run: | 27 | npm install @mermaid-js/mermaid-cli 28 | - name: Build the docs 29 | run: | 30 | make -C docs clean 31 | make -C docs html MERMAID_PATH="../node_modules/.bin/" 32 | - name: Deploy 33 | uses: peaceiris/actions-gh-pages@v4 34 | with: 35 | github_token: ${{ secrets.GITHUB_TOKEN }} 36 | publish_dir: ./docs/_build/html 37 | -------------------------------------------------------------------------------- /.github/workflows/draft-pdf.yml: -------------------------------------------------------------------------------- 1 | name: build JOSS pdf 2 | 3 | on: 4 | push: 5 | 6 | jobs: 7 | paper: 8 | runs-on: ubuntu-latest 9 | name: JOSS paper draft 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v4 13 | - name: Build draft PDF 14 | uses: openjournals/openjournals-draft-action@master 15 | with: 16 | journal: joss 17 | # This should be the path to the paper within your repo. 18 | paper-path: paper/paper.md 19 | - name: Upload 20 | uses: actions/upload-artifact@v4 21 | with: 22 | name: paper 23 | # This is the output path where Pandoc will write the compiled 24 | # PDF. Note, this should be the same directory as the input 25 | # paper.md 26 | path: paper/paper.pdf -------------------------------------------------------------------------------- /.github/workflows/pre-release.yml: -------------------------------------------------------------------------------- 1 | name: test all operating systems pre-release 2 | 3 | on: 4 | release: 5 | types: 6 | - prereleased 7 | 8 | jobs: 9 | tests: 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | os: [ubuntu-latest, macOS-latest, windows-latest] 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: | 24 | pip install --upgrade pip 25 | - name: Install vanilla package 26 | run: | 27 | pip install . 
28 | - name: Install test deps 29 | run: | 30 | pip install .[test] 31 | - name: Run tests with coverage 32 | run: | 33 | pytest --cov=mpol 34 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: build and upload to PyPI 2 | 3 | on: 4 | release: 5 | types: 6 | - released 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-24.04 11 | steps: 12 | - uses: actions/checkout@v4 13 | - name: Set up Python 14 | uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.12" 17 | - name: Install dependencies 18 | run: | 19 | pip install --upgrade pip 20 | pip install setuptools wheel twine 21 | pip install pep517 --user 22 | - name: Install vanilla package 23 | run: | 24 | pip install . 25 | - name: Build a binary wheel and a source tarball 26 | run: | 27 | python -m pep517.build --source --binary --out-dir dist/ . 28 | - name: Publish distribution to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | with: 31 | user: __token__ 32 | password: ${{ secrets.pypi_password }} 33 | -------------------------------------------------------------------------------- /.github/workflows/verify-tests-and-docs.yml: -------------------------------------------------------------------------------- 1 | name: verify tests pass and docs build 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | tests: 10 | runs-on: ubuntu-24.04 11 | continue-on-error: ${{ matrix.experimental }} 12 | strategy: 13 | fail-fast: true 14 | matrix: 15 | python-version: ["3.10", "3.11", "3.12"] 16 | experimental: [false] 17 | include: 18 | - python-version: "3.13" 19 | experimental: true 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install pip 27 | run: | 28 | pip install --upgrade pip 29 | - name: Install vanilla package 30 | run: | 31 | pip install . 32 | - name: Install test dependencies 33 | run: | 34 | pip install .[test] 35 | - name: Lint with ruff 36 | run: | 37 | ruff check . 38 | - name: Check types with MyPy 39 | run: | 40 | mypy src/mpol --pretty 41 | - name: Run tests with coverage 42 | run: | 43 | pytest --cov=mpol 44 | docs_build: 45 | runs-on: ubuntu-24.04 46 | steps: 47 | - uses: actions/checkout@v4 48 | - name: Set up Python 49 | uses: actions/setup-python@v5 50 | with: 51 | python-version: "3.12" 52 | - name: Install doc dependencies 53 | run: | 54 | pip install .[dev] 55 | - name: Install Pandoc dependency 56 | run: | 57 | sudo apt-get install pandoc 58 | - name: Build the docs 59 | run: | 60 | make -C docs clean 61 | make -C docs html 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | pip-wheel-metadata 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # vscode configs 108 | .vscode 109 | 110 | # mac files 111 | .DS_Store 112 | docs/.DS_Store 113 | 114 | # images from tests 115 | test/*.png 116 | 117 | # virtual environment 118 | venv/ 119 | 120 | # notebooks produced from jupytext 121 | docs/ci-tutorials/*.ipynb 122 | 123 | docs/ci-tutorials/alma.jpg 124 | docs/ci-tutorials/mock_data.npz 125 | 126 | 127 | # tensorboard outputs 128 | docs/ci-tutorials/runs 129 | docs/large-tutorials/runs 130 | docs/large-tutorials/logs 131 | 132 | # testing cache 133 | tmp_cache 134 | 135 | # initialized dirty image 136 | model.pt 137 | dirty_image_model.pt 138 | 139 | # setup file 140 | project_setup.sh 141 | 142 | # likely folders from testing 143 | plotsdir 144 | runs 145 | 146 | # hatch-generated version file 147 | src/mpol/mpol_version.py 148 | 149 | .ruff_cache 150 | 151 | build_joss.sh 152 | prof -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-toml 11 | - id: check-added-large-files 12 | args: ["--maxkb=2000"] 13 | - id: detect-private-key 14 | - id: name-tests-test 15 | - repo: https://github.com/psf/black 16 | rev: 23.1.0 17 | hooks: 18 | - id: black 19 | - repo: https://github.com/PyCQA/isort 20 | rev: 5.12.0 21 | hooks: 22 | - id: isort 23 | args: [] 24 | exclude: docs 25 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or 
sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at iczekala@psu.edu. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTORS.md: -------------------------------------------------------------------------------- 1 | Contributors 2 | * Ian Czekala, `@iancze` 3 | * Jeff Jennings, `@jeffjennings` 4 | * Brianna Zawadzki, `@briannazawadzki` 5 | * Ryan Loomis, `@ryanaloomis` 6 | * Kadri Nizam, `@kadri-nizam` 7 | * Megan Delamer 8 | * Kaylee de Soto, `@kdesoto-astro` 9 | * Robert Frazier, `@RCF42` 10 | * Hannah Grzybowski, `@hgrzy` 11 | * Mary Ogborn 12 | * Tyler Quinn, `@trq5014` 13 | * Kristin Hopley -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 - 2025 Ian Czekala and contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MPoL 2 | 3 | [![Tests](https://github.com/MPoL-dev/MPoL/actions/workflows/verify-tests-and-docs.yml/badge.svg?branch=main)](https://github.com/MPoL-dev/MPoL/actions/workflows/verify-tests-and-docs.yml) 4 | [![build and deploy docs](https://github.com/MPoL-dev/MPoL/actions/workflows/docs-build-deploy.yml/badge.svg?branch=main)](https://github.com/MPoL-dev/MPoL/actions/workflows/docs-build-deploy.yml) 5 | [![DOI](https://zenodo.org/badge/224543208.svg)](https://zenodo.org/badge/latestdoi/224543208) 6 | 7 | 8 | MPoL is a [PyTorch](https://pytorch.org/) *library* built for Regularized Maximum Likelihood (RML) imaging and Bayesian Inference with datasets from interferometers like the Atacama Large Millimeter/Submillimeter Array ([ALMA](https://www.almaobservatory.org/en/home/)) and the Karl G. Jansky Very Large Array ([VLA](https://public.nrao.edu/telescopes/vla/)). 
9 | 
10 | As a PyTorch *library*, MPoL expects that the user will write Python code that uses MPoL primitives as building blocks to solve their interferometric imaging workflow, much the same way the artificial intelligence community writes Python code that uses PyTorch layers to implement new neural network architectures (for [example](https://github.com/pytorch/examples)). You will find MPoL easiest to use if you adhere to PyTorch customs and idioms, e.g., feed-forward neural networks, data storage, GPU acceleration, and train/test optimization loops. Therefore, a basic familiarity with PyTorch is considered a prerequisite for MPoL. A minimal sketch of such an optimization loop appears at the end of this README.
11 | 
12 | MPoL is *not* an imaging application nor a pipeline, though such programs could be built for specialized workflows using MPoL components. We are focused on providing a numerically correct and expressive set of core primitives so the user can leverage the full power of the PyTorch (and Python) ecosystem to solve their research-grade imaging tasks. This is already a significant development and maintenance burden for our small research team, so our immediate scope must necessarily be limited.
13 | 
14 | * **Documentation** is available at [https://mpol-dev.github.io/MPoL/](https://mpol-dev.github.io/MPoL/)
15 | * **Examples** are available at [https://github.com/MPoL-dev/examples/](https://github.com/MPoL-dev/examples/)
16 | 
17 | ## Citation
18 | 
19 | If you use this package or derivatives of it, please cite the following two references:
20 | 
21 |     @software{mpol,
22 |       author = {Ian Czekala and
23 |                 Jeff Jennings and
24 |                 Brianna Zawadzki and
25 |                 Ryan Loomis and
26 |                 Kadri Nizam and
27 |                 Megan Delamer and
28 |                 Kaylee de Soto and
29 |                 Robert Frazier and
30 |                 Hannah Grzybowski and
31 |                 Mary Ogborn and
32 |                 Tyler Quinn},
33 |       title = {MPoL-dev/MPoL: v0.2.0 Release},
34 |       month = nov,
35 |       year = 2023,
36 |       publisher = {Zenodo},
37 |       version = {v0.2.0},
38 |       doi = {10.5281/zenodo.3594081},
39 |       url = {https://doi.org/10.5281/zenodo.3594081}
40 |     }
41 | 
42 | and
43 | 
44 |     @ARTICLE{2023PASP..135f4503Z,
45 |       author = {{Zawadzki}, Brianna and {Czekala}, Ian and {Loomis}, Ryan A. and {Quinn}, Tyler and {Grzybowski}, Hannah and {Frazier}, Robert C. and {Jennings}, Jeff and {Nizam}, Kadri M. and {Jian}, Yina},
46 |       title = "{Regularized Maximum Likelihood Image Synthesis and Validation for ALMA Continuum Observations of Protoplanetary Disks}",
47 |       journal = {\pasp},
48 |       keywords = {Protoplanetary disks, Submillimeter astronomy, Radio interferometry, Deconvolution, Open source software, 1300, 1647, 1346, 1910, 1866, Astrophysics - Earth and Planetary Astrophysics, Astrophysics - Instrumentation and Methods for Astrophysics},
49 |       year = 2023,
50 |       month = jun,
51 |       volume = {135},
52 |       number = {1048},
53 |       eid = {064503},
54 |       pages = {064503},
55 |       doi = {10.1088/1538-3873/acdf84},
56 |       archivePrefix = {arXiv},
57 |       eprint = {2209.11813},
58 |       primaryClass = {astro-ph.EP},
59 |       adsurl = {https://ui.adsabs.harvard.edu/abs/2023PASP..135f4503Z},
60 |       adsnote = {Provided by the SAO/NASA Astrophysics Data System}
61 |     }
62 | 
63 | ---
64 | Copyright Ian Czekala and contributors 2019-24
65 | 
66 | A Million Points of Light are needed to synthesize image cubes from interferometers.
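
To make the "building blocks" idea concrete, here is a minimal sketch of an RML optimization loop. It is illustrative only: `GriddedNet` is the precomposed chain diagrammed in the docs (BaseCube -> HannConvCube -> ImageCube -> FourierCube), but its constructor signature is an assumption here, `losses.nll_gridded` is a hypothetical stand-in for one of the loss functions in `mpol.losses`, and `dataset` stands in for a gridded dataset you would build from your own visibilities with `mpol.gridding`.

```python
import torch
from mpol import coordinates, losses, precomposed

# image dimensions: pixel size in arcsec and number of pixels per side
coords = coordinates.GridCoords(cell_size=0.005, npix=800)

# precomposed BaseCube -> HannConvCube -> ImageCube -> FourierCube module;
# the keyword name `coords` is assumed for illustration
model = precomposed.GriddedNet(coords=coords)

optimizer = torch.optim.Adam(model.parameters(), lr=0.3)

for iteration in range(500):
    optimizer.zero_grad()
    vis = model()  # forward pass: model visibilities from the image cube
    # hypothetical data-fidelity term; see mpol.losses for what is provided.
    # Regularizers (entropy, sparsity, etc.) would be added to this loss.
    loss = losses.nll_gridded(vis, dataset)
    loss.backward()   # autodifferentiate through the FFT and image layers
    optimizer.step()  # update the image-cube pixel parameters
```

Note that the loop itself is ordinary PyTorch: swapping the optimizer, adding regularization terms to the loss, or moving the model and data to a GPU all follow standard PyTorch idioms.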
-------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # makefile for Sphinx documentation 2 | 3 | # You can set these variables from the command line, and also 4 | # from the environment for the first two. 5 | SPHINXOPTS ?= 6 | SPHINXBUILD ?= sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile html clean 15 | 16 | CI-NOTEBOOKS := ci-tutorials/PyTorch.ipynb ci-tutorials/gridder.ipynb ci-tutorials/optimization.ipynb ci-tutorials/crossvalidation.ipynb ci-tutorials/initializedirtyimage.ipynb 17 | clean: 18 | rm -rf _build 19 | rm -rf ci-tutorials/.ipynb_checkpoints 20 | rm -rf ci-tutorials/runs 21 | rm -rf ci-tutorials/alma.jpg 22 | rm -rf ci-tutorials/mock_data.npz 23 | rm -rf _static/baselines/build/baselines.csv 24 | 25 | # baseline table 26 | _static/baselines/build/baselines.csv: _static/baselines/src/print_conversions.py 27 | mkdir -p _static/baselines/build 28 | python _static/baselines/src/print_conversions.py $@ 29 | 30 | # fftshift figure 31 | _static/fftshift/build/plot.png: _static/fftshift/src/plot.py 32 | mkdir -p _static/fftshift/build 33 | python _static/fftshift/src/plot.py $@ 34 | 35 | html: _static/baselines/build/baselines.csv _static/fftshift/build/plot.png 36 | python -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html 37 | -------------------------------------------------------------------------------- /docs/_static/baselines/src/print_conversions.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | from mpol.constants import c_ms 5 | 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser( 9 | description="Save baselines and klambda conversions to CSV." 
10 | ) 11 | parser.add_argument("outfile", help="Destination to save CSV table.") 12 | args = parser.parse_args() 13 | 14 | 15 | header = ["baseline", "100 GHz (Band 3)", "230 GHz (Band 6)", "340 GHz (Band 7)"] 16 | 17 | baselines = np.array([10, 50, 100, 500, 1000, 5000, 10000, 16000]) 18 | frequencies = np.array([100, 230, 340]) * 1e9 # Hz 19 | 20 | 21 | def format_baseline(baseline_m): 22 | if baseline_m < 1e3: 23 | return f"{baseline_m:.0f} m" 24 | elif baseline_m < 1e6: 25 | return f"{baseline_m * 1e-3:.0f} km" 26 | 27 | 28 | def format_lambda(lam): 29 | if lam < 1e3: 30 | return f"{lam:.0f}" + r" :math:`\lambda`" 31 | elif lam < 1e6: 32 | return f"{lam * 1e-3:.0f}" + r" :math:`\mathrm{k}\lambda`" 33 | else: 34 | return f"{lam * 1e-6:.0f}" + r" :math:`\mathrm{M}\lambda`" 35 | 36 | 37 | data = [] 38 | for baseline in baselines: 39 | row = [format_baseline(baseline)] 40 | for frequency in frequencies: 41 | lam = baseline / (c_ms / frequency) 42 | row.append(format_lambda(lam)) 43 | data.append(row) 44 | 45 | with open(args.outfile, "w", newline="") as f: 46 | mywriter = csv.writer(f) 47 | mywriter.writerow(header) 48 | mywriter.writerows(data) 49 | -------------------------------------------------------------------------------- /docs/_static/bullets.css: -------------------------------------------------------------------------------- 1 | .rst-content section ul li { 2 | list-style: disc; 3 | margin-left: 24px 4 | } 5 | -------------------------------------------------------------------------------- /docs/_static/fftshift/src/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from astropy.io import fits 4 | from astropy.utils.data import download_file 5 | from matplotlib import patches 6 | from matplotlib.colors import LogNorm 7 | from matplotlib.gridspec import GridSpec 8 | from mpol import coordinates 9 | 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser(description="Create the fftshift plot") 13 | parser.add_argument("outfile", help="Destination to save plot.") 14 | args = parser.parse_args() 15 | 16 | 17 | fname = download_file( 18 | "https://zenodo.org/record/4711811/files/logo_cont.fits", 19 | cache=True, 20 | show_progress=True, 21 | pkgname="mpol", 22 | ) 23 | coords = coordinates.GridCoords(cell_size=0.007, npix=512) 24 | kw = {"origin": "lower", "interpolation": "none", "extent": coords.img_ext} 25 | kwvis = {"origin": "lower", "interpolation": "none", "extent": coords.vis_ext} 26 | 27 | fig = plt.figure(constrained_layout=False, figsize=(8, 6)) 28 | gs = GridSpec(2, 3, fig) 29 | ax = fig.add_subplot(gs[0, 0]) 30 | ax1 = fig.add_subplot(gs[0, 1]) 31 | ax2 = fig.add_subplot(gs[0, 2]) 32 | ax3 = fig.add_subplot(gs[1, 0]) 33 | ax4 = fig.add_subplot(gs[1, 2]) 34 | 35 | 36 | d = fits.open(fname) 37 | sky_cube = d[0].data 38 | d.close() 39 | ax.imshow(sky_cube, **kw) 40 | ax.set_title("Sky Cube") 41 | ax.set_xlabel(r"$\Delta \alpha \cos \delta \; [{}^{\prime\prime}]$") 42 | ax.set_ylabel(r"$\Delta \delta\; [{}^{\prime\prime}]$") 43 | 44 | flip_cube = np.flip(sky_cube, (1,)) 45 | ax1.imshow(sky_cube, **kw) 46 | ax1.set_title("Flip Cube") 47 | ax1.set_xlabel(r"$\Delta \alpha \cos \delta\; [{}^{\prime\prime}]$") 48 | ax1.set_ylabel(r"$\Delta \delta\; [{}^{\prime\prime}]$") 49 | ax1.invert_xaxis() 50 | 51 | image_cube = np.fft.fftshift(flip_cube, axes=(0, 1)) 52 | ax2.imshow(image_cube, **kw) 53 | ax2.set_title("Packed Cube (Image)") 54 | ax2.xaxis.set_visible(False) 55 | 
ax2.yaxis.set_visible(False) 56 | 57 | visibility_cube = np.fft.fft2(image_cube, axes=(0, 1)) 58 | ax3.imshow(np.abs(visibility_cube), norm=LogNorm(vmin=1e-4, vmax=12000), **kwvis) 59 | ax3.set_title("Packed Cube (Visibility)") 60 | ax3.xaxis.set_visible(False) 61 | ax3.yaxis.set_visible(False) 62 | 63 | ground_cube = np.fft.fftshift(visibility_cube, axes=(0, 1)) 64 | ax4.imshow(np.abs(ground_cube), **kwvis, norm=LogNorm(vmin=1e-4, vmax=12000)) 65 | ax4.set_title("Ground Cube") 66 | ax4.set_xlabel(r"$u$ [k$\lambda$]") 67 | ax4.set_ylabel(r"$v$ [k$\lambda$]") 68 | 69 | arrow_kws = {"mutation_scale": 20, "transform": fig.transFigure, "fc": "black"} 70 | 71 | annotate_kws = { 72 | "xycoords": fig.transFigure, 73 | "va": "center", 74 | "ha": "center", 75 | "weight": "bold", 76 | "fontsize": "large", 77 | } 78 | 79 | text_kws = {"va": "center", "ha": "center", "weight": "bold", "fontsize": "large"} 80 | 81 | y = 0.82 82 | x0, x1 = 0.28, 0.37 83 | arrow_sky_cube_to_flip_cube = patches.FancyArrowPatch( 84 | (x0, y), (x1, y), **arrow_kws, arrowstyle="<->" 85 | ) 86 | fig.patches.append(arrow_sky_cube_to_flip_cube) 87 | fig.text((x0 + x1) / 2, y + 0.05, "flip \n across R.A.", **text_kws) 88 | 89 | x0, x1 = 0.62, 0.72 90 | y = 0.84 91 | arrow_flip_cube_to_packed_cube = patches.FancyArrowPatch( 92 | (x0, y), (x1, y), **arrow_kws, arrowstyle="->" 93 | ) 94 | fig.patches.append(arrow_flip_cube_to_packed_cube) 95 | fig.text((x0 + x1) / 2, y + 0.03, "fftshift", **text_kws) 96 | 97 | x0, x1 = 0.72, 0.62 98 | y = 0.77 99 | arrow_packed_cube_to_flip_cube = patches.FancyArrowPatch( 100 | (x0, y), (x1, y), **arrow_kws, arrowstyle="->" 101 | ) 102 | fig.patches.append(arrow_packed_cube_to_flip_cube) 103 | fig.text((x0 + x1) / 2, y + 0.03, "ifftshift", **text_kws) 104 | 105 | 106 | x_center = 0.5 107 | arrow_packed_image_to_packed_visibility = patches.FancyArrowPatch( 108 | (0.73, 0.59), (0.32, 0.38), **arrow_kws, arrowstyle="->" 109 | ) 110 | fig.patches.append(arrow_packed_image_to_packed_visibility) 111 | fig.text(x_center, 0.50, "fft2", rotation=17, **text_kws) 112 | 113 | arrow_packed_visibility_to_packed_image = patches.FancyArrowPatch( 114 | (0.34, 0.33), (0.75, 0.54), **arrow_kws, arrowstyle="->" 115 | ) 116 | fig.patches.append(arrow_packed_visibility_to_packed_image) 117 | fig.text(x_center, 0.37, "ifft2", rotation=17, **text_kws) 118 | 119 | 120 | x0, x1 = 0.62, 0.31 121 | y = 0.23 122 | arrow_ground_cube_to_packed_cube = patches.FancyArrowPatch( 123 | (x0, y), (x1, y), **arrow_kws, arrowstyle="->" 124 | ) 125 | fig.patches.append(arrow_ground_cube_to_packed_cube) 126 | plt.annotate("fftshift", ((x0 + x1) / 2, y + 0.02), **annotate_kws) 127 | 128 | x0, x1 = 0.31, 0.62 129 | y = 0.16 130 | arrow_packed_cube_to_ground_cube = patches.FancyArrowPatch( 131 | (x0, y), (x1, y), **arrow_kws, arrowstyle="->" 132 | ) 133 | fig.patches.append(arrow_packed_cube_to_ground_cube) 134 | plt.annotate("ifftshift", ((x0 + x1) / 2, y + 0.02), **annotate_kws) 135 | 136 | 137 | fig.subplots_adjust(wspace=0.65, left=0.06, right=0.94, top=0.97, bottom=0.05) 138 | plt.savefig(args.outfile, dpi=300) 139 | -------------------------------------------------------------------------------- /docs/_static/mmd/src/BaseCube.mmd: -------------------------------------------------------------------------------- 1 | graph TD 2 | bc(BaseCube) --> ImageCube 3 | ImageCube --> FourierLayer 4 | FourierLayer --> il([Loss]) 5 | ad[[Dataset]] --> il([Loss]) 6 | -------------------------------------------------------------------------------- 
/docs/_static/mmd/src/GriddedNet.mmd:
--------------------------------------------------------------------------------
1 | graph TD
2 |     subgraph GriddedNet
3 |     bc(BaseCube) --> HannConvCube
4 |     HannConvCube --> ImageCube
5 |     ImageCube --> FourierLayer
6 |     end
7 |     FourierLayer --> il([Loss])
8 |     ad[[Dataset]] --> il([Loss])
9 | 
--------------------------------------------------------------------------------
/docs/_static/mmd/src/ImageCube.mmd:
--------------------------------------------------------------------------------
1 | graph TD
2 |     ic(ImageCube) --> FourierLayer
3 |     FourierLayer --> il([Loss])
4 |     ad[[Dataset]] --> il([Loss])
5 | 
--------------------------------------------------------------------------------
/docs/_static/mmd/src/Parametric.mmd:
--------------------------------------------------------------------------------
1 | graph TD
2 |     pm(DiskModel) --> ImageCube
3 |     ImageCube --> FourierLayer
4 |     FourierLayer --> il([Loss])
5 |     ad[[Dataset]] --> il([Loss])
6 | 
--------------------------------------------------------------------------------
/docs/_static/mmd/src/SingleDish.mmd:
--------------------------------------------------------------------------------
1 | graph TD
2 |     bc(BaseCube) --> ImageCube
3 |     ad[[ALMA Dataset]] --> adc[ALMA DataConnector]
4 |     ImageCube --> FourierCube
5 |     FourierCube --> adc[ALMA DataConnector]
6 |     FourierCube --> pl([PSD-Loss])
7 |     ImageCube --> me([MaxEntropy])
8 |     ImageCube --> idc[IRAM DataConnector]
9 |     id[[IRAM-Dataset]] --> idc[IRAM DataConnector]
10 |     adc[ALMA DataConnector] --> al([ALMA Loss])
11 |     idc[IRAM DataConnector] --> il([IRAM Loss])
12 |     subgraph Loss Calculations
13 |     al([ALMA Loss]) --> tl([Total Loss])
14 |     pl([PSD Loss]) --> tl([Total Loss])
15 |     me([MaxEntropy]) --> tl([Total Loss])
16 |     il([IRAM Loss]) --> tl([Total Loss])
17 |     end
18 | 
19 | 
20 | %% Rounded rectangles have trainable parameters.
21 | %% (static) Datasets are double rectangles.
22 | %% Passthrough layers are rectangles (no trainable parameters)
23 | %% Loss calculations are stadium rectangles.
24 | 
--------------------------------------------------------------------------------
/docs/_static/mmd/src/SkyModel.mmd:
--------------------------------------------------------------------------------
1 | graph TD
2 |     sm(SkyModel) --> ImageCube
3 |     ImageCube --> FourierLayer
4 |     FourierLayer --> DataConnector
5 |     ad[[Dataset]] --> DataConnector
6 |     DataConnector --> il([Loss])
7 | 
--------------------------------------------------------------------------------
/docs/api/analysis.md:
--------------------------------------------------------------------------------
1 | # Analysis
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.onedim
5 |    :members:
6 | ```
--------------------------------------------------------------------------------
/docs/api/coordinates.md:
--------------------------------------------------------------------------------
1 | # Coordinates
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.coordinates
5 |    :members:
6 | ```
--------------------------------------------------------------------------------
/docs/api/crossval.md:
--------------------------------------------------------------------------------
1 | # Cross-validation
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.crossval
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/datasets.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.datasets
5 |    :members:
6 | ```
--------------------------------------------------------------------------------
/docs/api/fourier.md:
--------------------------------------------------------------------------------
1 | # Fourier
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.fourier
5 |    :members:
6 | ```
--------------------------------------------------------------------------------
/docs/api/geometry.md:
--------------------------------------------------------------------------------
1 | # Geometry
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.geometry
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/gridding.md:
--------------------------------------------------------------------------------
1 | # Gridding
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.gridding
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/images.md:
--------------------------------------------------------------------------------
1 | # Images
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.images
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/losses.md:
--------------------------------------------------------------------------------
1 | # Losses
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.losses
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/plotting.md:
--------------------------------------------------------------------------------
1 | # Plotting
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.plot
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/api/precomposed.md:
--------------------------------------------------------------------------------
1 | # Precomposed Modules
2 | 
3 | For convenience, we provide some "precomposed" [modules](https://pytorch.org/docs/stable/notes/modules.html) which may be useful for simple imaging or modeling applications. In general, though, we encourage you to compose your own set of layers if your application requires it. The source code for a precomposed network can provide a useful starting point. We also recommend checking out the PyTorch documentation on [modules](https://pytorch.org/docs/stable/notes/modules.html).
4 | 
5 | ```{eval-rst}
6 | .. automodule:: mpol.precomposed
7 |    :members:
8 | ```
9 | 
--------------------------------------------------------------------------------
/docs/api/train_test.md:
--------------------------------------------------------------------------------
1 | # Training and testing
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.training
5 |    :members:
6 | ```
--------------------------------------------------------------------------------
/docs/api/utilities.md:
--------------------------------------------------------------------------------
1 | # Utilities
2 | 
3 | ```{eval-rst}
4 | .. automodule:: mpol.utils
5 |    :members:
6 | ```
7 | 
--------------------------------------------------------------------------------
/docs/background.md:
--------------------------------------------------------------------------------
1 | # Background and prerequisites
2 | 
3 | ## Radio astronomy
4 | 
5 | A background in radio astronomy, Fourier transforms, and interferometry is a prerequisite for using MPoL but is beyond the scope of this documentation. We recommend reviewing these resources as needed.
6 | 
7 | - The [Essential Radio Astronomy](https://www.cv.nrao.edu/~sransom/web/xxx.html) textbook by James Condon and Scott Ransom, in particular Chapter 3.7 on Radio Interferometry.
8 | - NRAO's [17th Synthesis Imaging Workshop](http://www.cvent.com/events/virtual-17th-synthesis-imaging-workshop/agenda-0d59eb6cd1474978bce811194b2ff961.aspx), with recorded lectures and slides available
9 | - [Interferometry and Synthesis in Radio Astronomy](https://ui.adsabs.harvard.edu/abs/2017isra.book.....T/abstract) by Thompson, Moran, and Swenson. An excellent and comprehensive reference on all things interferometry.
10 | - The [Revisiting the radio interferometer measurement equation](https://ui.adsabs.harvard.edu/abs/2011A%26A...527A.106S/abstract) series by O. Smirnov, 2011
11 | - Ian Czekala's lecture notes on [Radio Interferometry and Imaging](https://iancze.github.io/courses/as5003/lectures/)
12 | 
13 | RML imaging is different from CLEAN imaging, which operates as a deconvolution procedure in the image plane. However, CLEAN is by far the dominant algorithm used to synthesize images from interferometric data at sub-mm and radio wavelengths, and it is useful to have at least a basic understanding of how it works. We recommend
14 | 
15 | - [Interferometry and Synthesis in Radio Astronomy](https://ui.adsabs.harvard.edu/abs/2017isra.book.....T/abstract) Chapter 11.1
16 | - David Wilner's lecture on [Imaging and Deconvolution in Radio Astronomy](https://www.youtube.com/watch?v=mRUZ9eckHZg)
17 | - For a discussion on using both CLEAN and RML techniques to robustly interpret kinematic data of protoplanetary disks, see Section 3 of [Visualizing the Kinematics of Planet Formation](https://ui.adsabs.harvard.edu/abs/2020arXiv200904345D/abstract) by The Disk Dynamics Collaboration
18 | 
19 | ## Statistics and Machine Learning
20 | 
21 | MPoL is built on top of the [PyTorch](https://pytorch.org/) machine learning framework and adopts much of the terminology and design principles of machine learning workflows. As a prerequisite, we recommend at least a basic understanding of statistics and machine learning principles. Two excellent (free) textbooks are
22 | 
23 | - [Dive into Deep Learning](https://d2l.ai/), in particular chapters 1 - 3 to cover the basics of forward models, automatic differentiation, and optimization.
24 | - [Deep Learning: Foundations and Concepts](https://www.bishopbook.com/) for a lengthier treatment of these ideas and other foundational statistical concepts.
25 | 
26 | And we highly recommend the informative and entertaining 3Blue1Brown lectures on [deep learning](https://www.youtube.com/watch?v=aircAruvnKk&list=PLZHQObOWTQDNU6R1_67000Dx_ZCJB-3pi&ab_channel=3Blue1Brown).
27 | 
28 | ## PyTorch
29 | 
30 | As a PyTorch library, MPoL expects that the user will write Python code that uses MPoL primitives as building blocks to solve their interferometric imaging workflow, much the same way the artificial intelligence community writes Python code that uses PyTorch layers to implement new neural network architectures (for [example](https://github.com/pytorch/examples)). You will find MPoL easiest to use if you follow PyTorch customs and idioms, e.g., feed-forward neural networks, data storage, GPU acceleration, and train/test optimization loops. Therefore, a basic familiarity with PyTorch is considered a prerequisite for MPoL.
31 | 
32 | If you are new to PyTorch, we recommend starting with the official [Learn the Basics](https://pytorch.org/tutorials/beginner/basics/intro.html) guide. You can also find high quality introductions on YouTube and in textbooks. A minimal, self-contained example of the core optimization pattern follows below.
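
To make this concrete, here is a tiny, generic PyTorch example (no MPoL code involved) showing the pattern that underlies RML imaging: a forward model with trainable parameters, a scalar loss, automatic differentiation, and an optimizer loop. It fits a straight line to noisy synthetic data.

```python
import torch

# synthetic data drawn from y = 2x + 1 plus noise
x = torch.linspace(0, 1, 50)
y = 2 * x + 1 + 0.05 * torch.randn(50)

# model parameters, flagged for automatic differentiation
slope = torch.tensor(0.0, requires_grad=True)
intercept = torch.tensor(0.0, requires_grad=True)

optimizer = torch.optim.SGD([slope, intercept], lr=0.1)

for step in range(1000):
    optimizer.zero_grad()
    y_model = slope * x + intercept        # the forward model
    loss = torch.mean((y - y_model) ** 2)  # chi^2-like scalar loss
    loss.backward()   # gradients of the loss w.r.t. the parameters
    optimizer.step()  # gradient-descent update

print(f"slope = {slope.item():.2f}, intercept = {intercept.item():.2f}")
```

In an MPoL workflow, the forward model becomes an image cube and its Fourier transform, and the loss becomes the visibility likelihood plus any regularizers, but the loop is the same.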
33 | 
34 | ## RML Imaging
35 | 
36 | MPoL is a modern PyTorch imaging library; however, many of the key concepts behind Regularized Maximum Likelihood imaging have been around for some time. We recommend checking out the following (non-exhaustive) list of resources
37 | 
38 | - [Regularized Maximum Likelihood Image Synthesis and Validation for ALMA Continuum Observations of Protoplanetary Disks](https://ui.adsabs.harvard.edu/abs/2023PASP..135f4503Z/abstract) by Zawadzki et al. 2023
39 | - The fourth paper in the 2019 [Event Horizon Telescope Collaboration series](https://ui.adsabs.harvard.edu/abs/2019ApJ...875L...4E/abstract) describing the imaging principles
40 | - [Maximum entropy image restoration in astronomy](https://ui.adsabs.harvard.edu/abs/1986ARA%26A..24..127N/abstract) AR&A by Narayan and Nityananda 1986
41 | - [Multi-GPU maximum entropy image synthesis for radio astronomy](https://ui.adsabs.harvard.edu/abs/2018A%26C....22...16C/abstract) by Cárcamo et al. 2018
42 | - Dr. Katie Bouman's Ph.D. thesis ["Extreme Imaging via Physical Model Inversion: Seeing Around Corners and Imaging Black Holes"](https://people.csail.mit.edu/klbouman/pw/papers_and_presentations/thesis.pdf)
43 | 
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | 
2 | # -- Project information -----------------------------------------------------
3 | from pkg_resources import DistributionNotFound, get_distribution
4 | 
5 | try:
6 |     __version__ = get_distribution("MPoL").version
7 | except DistributionNotFound:
8 |     __version__ = "unknown version"
9 | 
10 | # https://github.com/mgaitan/sphinxcontrib-mermaid/issues/72
11 | import errno
12 | 
13 | import sphinx.util.osutil
14 | 
15 | sphinx.util.osutil.ENOENT = errno.ENOENT
16 | 
17 | project = "MPoL"
18 | copyright = "2019-24, Ian Czekala"
19 | author = "Ian Czekala"
20 | 
21 | # The full version, including alpha/beta/rc tags
22 | version = __version__
23 | release = __version__
24 | 
25 | # -- General configuration ---------------------------------------------------
26 | extensions = [
27 |     "sphinx.ext.autodoc",
28 |     "sphinx.ext.viewcode",
29 |     "sphinx.ext.napoleon",
30 |     "sphinx_copybutton",
31 |     "sphinxcontrib.mermaid",
32 |     "myst_nb",
33 | ]
34 | 
35 | # add in additional files
36 | source_suffix = {
37 |     ".ipynb": "myst-nb",
38 |     ".rst": "restructuredtext",
39 |     ".myst": "myst-nb",
40 |     ".md": "myst-nb",
41 | }
42 | 
43 | myst_enable_extensions = ["dollarmath", "colon_fence", "amsmath"]
44 | 
45 | autodoc_mock_imports = ["torch", "torchvision"]
46 | autodoc_member_order = "bysource"
47 | # https://github.com/sphinx-doc/sphinx/issues/9709
48 | # bug that if we set this here, we can't list individual members in the
49 | # actual API doc
50 | # autodoc_default_options = {"members": None}
51 | 
52 | # Add any paths that contain templates here, relative to this directory.
53 | templates_path = ["_templates"]
54 | 
55 | # List of patterns, relative to source directory, that match files and
56 | # directories to ignore when looking for source files.
57 | # This pattern also affects html_static_path and html_extra_path.
58 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 59 | 60 | # -- Options for HTML output ------------------------------------------------- 61 | html_theme = "sphinx_book_theme" 62 | html_theme_options = { 63 | "repository_url": "https://github.com/MPoL-dev/MPoL", 64 | "use_repository_button": True, 65 | } 66 | 67 | html_logo = "logo.png" 68 | html_favicon = "favicon.ico" 69 | 70 | master_doc = "index" 71 | 72 | # Add any paths that contain custom static files (such as style sheets) here, 73 | # relative to this directory. They are copied after the builtin static files, 74 | # so a file named "default.css" will overwrite the builtin "default.css". 75 | html_static_path = ["_static"] 76 | 77 | # https://docs.readthedocs.io/en/stable/guides/adding-custom-css.html 78 | html_js_files = ["https://buttons.github.io/buttons.js"] 79 | 80 | # Mermaid configuration 81 | # mermaid_output_format = "svg" 82 | 83 | # zero out any JS, since it doesn't work 84 | # mermaid_init_js = "" 85 | # mermaid_version = "" 86 | 87 | # if os.getenv("CI"): 88 | # # if True, we're running on github actions and need 89 | # # to use the path of the installed mmdc 90 | # # relative to docs/ directory! 91 | # # (mmdc itself not in $PATH automatically, like local) 92 | # mermaid_cmd = "../node_modules/.bin/mmdc" 93 | 94 | nb_execution_mode = "cache" 95 | nb_execution_timeout = -1 96 | nb_execution_raise_on_error = True 97 | # .ipynb are produced using Makefile on own terms, 98 | # # both .md and executed .ipynb are kept in git repo 99 | nb_execution_excludepatterns = [ 100 | "large-tutorials/*.md", 101 | "large-tutorials/*.ipynb", 102 | "**.ipynb_checkpoints", 103 | ] 104 | myst_heading_anchors = 3 105 | -------------------------------------------------------------------------------- /docs/developer-documentation.md: -------------------------------------------------------------------------------- 1 | (developer-documentation-label)= 2 | 3 | # Developer Documentation 4 | 5 | If you find an issue with the code, documentation, or would like to make a specific suggestion for an improvement, please raise an issue on the [Github repository](https://github.com/MPoL-dev/MPoL/issues). If you have a more general query or would just like to discuss a topic, please make a post on our [Github discussions page](https://github.com/MPoL-dev/MPoL/discussions). 6 | 7 | If you are new to contributing to an open source project, we recommend a quick read through the excellent contributing guides of the [exoplanet](https://docs.exoplanet.codes/en/stable/user/dev/) and [astropy](https://docs.astropy.org/en/stable/development/workflow/development_workflow.html) packages. What follows in this guide draws upon many of the suggestions from those two resources. There are many ways to contribute to an open source project like MPoL, and all of them are valuable to us. No contribution is too small---even a typo fix is appreciated! 8 | 9 | The MPoL source repository is hosted on [Github](https://github.com/MPoL-dev/MPoL) as part of the [MPoL-dev](https://github.com/MPoL-dev/) organization. We use a "fork and pull request" model for collaborative development. If you are unfamiliar with this workflow, check out this short Github guide on [forking projects](https://guides.github.com/activities/forking/). 
10 | 11 | ## Development dependencies 12 | 13 | Extra packages required for development can be installed via 14 | 15 | ``` 16 | (venv) $ pip install -e ".[dev]" 17 | ``` 18 | 19 | This directs pip to install whatever package is in the current working directory (`.`) as an editable package (`-e`), using the set of `[dev]` optional packages. There is also a more limited set of packages under `[test]`. You can view these packages in the `pyproject.toml` file. 20 | 21 | ## Testing 22 | 23 | MPoL includes a test suite written using [pytest](https://docs.pytest.org/). We aim for this test suite to be as comprehensive as possible, since this helps us achieve our goal of shipping stable software. 24 | 25 | ### Running tests 26 | 27 | To run all of the tests, change to the root of the repository and invoke 28 | 29 | ``` 30 | $ python -m pytest 31 | ``` 32 | 33 | If a test errors (especially on the `main` branch), please report what went wrong as a bug report issue on the [Github repository](https://github.com/MPoL-dev/MPoL/issues). 34 | 35 | ### Viewing test and debug plots 36 | 37 | Some tests produce temporary files, like plots, that could be useful to view for development or debugging. Normally these are saved to a temporary directory created by the system which will be cleaned up after the tests finish. To preserve them, first create a plot directory (e.g., `plotsdir`) and then run the tests with this `--basetemp` specified 38 | 39 | ``` 40 | $ mkdir plotsdir 41 | $ python -m pytest --basetemp=plotsdir 42 | ``` 43 | 44 | ### Test coverage 45 | 46 | To investigate how well the test suite covers the full range of program functionality, you can run [coverage.py](https://coverage.readthedocs.io/en/coverage-5.5/) through pytest using [pytest-cov](https://pypi.org/project/pytest-cov/), which should already be installed as part of the `[test]` dependencies 47 | 48 | ``` 49 | $ pytest --cov=mpol 50 | $ coverage html 51 | ``` 52 | 53 | And then use your favorite web browser to open `htmlcov/index.html` and view the coverage report. 54 | 55 | For more information on code coverage, see the [coverage.py documentation](https://coverage.readthedocs.io/en/coverage-5.5/). A worthy goal is to reach 100% code coverage with the testing suite. However, 100% coverage *doesn't mean the code is free of bugs*. More important than complete coverage is writing tests that properly probe program functionality. 56 | 57 | ### Test cache 58 | 59 | Several of the tests require mock data that is not practical to package within the github repository itself, and so it is stored on Zenodo and downloaded using astropy caching functions. If you run into trouble with the test cache becoming stale, you can delete it by deleting the `.mpol/cache` folder in your home directory. 60 | 61 | ### Contributing tests 62 | MPoL tests are located within the `test/` directory and follow [pytest](https://docs.pytest.org/en/6.2.x/contents.html#toc) conventions. Please add your new tests to this directory---we love new and useful tests. 63 | 64 | If you are adding new code functionality to the package, please make sure you have also written a set of corresponding tests probing the key interfaces. If you submit a pull request implementing code functionality without any new tests, be prepared for 'tests' to be one of the first suggestions on your pull request. 
Some helpful advice on *which* tests to write is [here](https://docs.python-guide.org/writing/tests/), [here](https://realpython.com/pytest-python-testing/), and [here](https://www.nerdwallet.com/blog/engineering/5-pytest-best-practices/).
65 | 
66 | 
67 | ## Type hinting
68 | 
69 | Core MPoL routines are type-checked with [mypy](https://mypy.readthedocs.io/en/stable/index.html) for 100% coverage. Before you push your changes to the repo, you will want to make sure your code passes type checking locally (otherwise your changes will fail the GitHub Actions continuous integration tests). You can do this from the root of the repo by
70 | 
71 | ```
72 | mypy src/mpol --pretty
73 | ```
74 | 
75 | If you are unfamiliar with typing in Python, we recommend reading the [mypy cheatsheet](https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html) to get started.
76 | 
77 | 
78 | (documentation-build-reference-label)=
79 | ## Documentation
80 | 
81 | MPoL documentation is written as docstrings attached to MPoL classes and functions (using reST) and as individual `.md` files located in the `docs/` folder. The documentation is built using the [Sphinx](https://www.sphinx-doc.org/en/master/) Python documentation generator, with the [MyST-NB](https://myst-nb.readthedocs.io/en/latest/index.html) plugin.
82 | 
83 | 
84 | ### Building documentation
85 | 
86 | To build the documentation, change to the `docs/` folder and run
87 | 
88 | ```
89 | $ make html
90 | ```
91 | 
92 | If successful, the HTML documentation will be available in the `_build/html` folder. You can preview the documentation locally using your favorite web browser by opening up `_build/html/index.html`.
93 | 
94 | You can clean up (delete) all of the built products by running
95 | 
96 | ```
97 | $ make clean
98 | ```
99 | 
100 | ### Contributing tutorials
101 | 
102 | If your tutorial is self-contained and has minimal computational needs (under 30 seconds on a single CPU), provide the source file as a [MyST-NB `.md` file](https://myst-nb.readthedocs.io/en/latest/authoring/basics.html#text-based-notebooks) so that it is built with MyST-NB on GitHub Actions. You can use the small tutorials in `docs/ci-tutorials` as a reference to get started.
103 | 
104 | If your tutorial requires significant computational resources (e.g., a GPU, multiple CPUs, or more than 30 seconds of runtime), execute the notebook on your local computing resources and commit both the `.md` and `.ipynb` files (with output cells) directly to the repository. You can see examples of these in `docs/large-tutorials`.
105 | 
106 | If you're thinking about contributing a tutorial and would like guidance on form or scope, please raise an [issue](https://github.com/MPoL-dev/MPoL/issues) or [discussion](https://github.com/MPoL-dev/MPoL/discussions) on the GitHub repository.
107 | 
108 | ### Older documentation versions
109 | 
110 | In the rare situation where you require documentation for a different (older) version of MPoL, you can swap to an older tag
111 | 
112 | ```
113 | git fetch --tags
114 | git checkout tags/v0.2.0
115 | ```
116 | 
117 | and then build the documentation.
118 | 
119 | ## Releasing a new version of MPoL
120 | 
121 | It is our intent that the `main` branch of the GitHub repository always reflects a stable version of the code that passes all tests. After significant new functionality has been introduced, a tagged release (e.g., `v0.1.1`) is generated from the `main` branch and pushed to PyPI.
122 | 
123 | To do this, follow this checklist in order:
124 | 
125 | 1. Ensure *all* tests are passing on your PR, both locally and on GitHub Actions.
Ensure *all* tests are passing on your PR, both locally and on GitHub Actions. 126 | 2. Ensure the docs build locally without errors or warnings. Check the output by opening `docs/_build/html/index.html` with your web browser. 127 | 3. Perform final edits documenting the changes since the last version in `docs/changelog.md`. Highlight potentially breaking changes and suggest how users might update their workflow. 128 | 4. Check that the contributors in `CONTRIBUTORS.md` are up to date. 129 | 5. Update the copyright year and citation in `README.md`. 130 | * In the citation, update all fields except 'Zenodo', 'doi', and 'url' (the current DOI will cite all versions and the URL will direct to the most recent version). 131 | 6. Merge your PR into `main` using the GitHub interface. 132 | * A new round of tests will be triggered by the merge. Make sure *all* of these pass. 133 | 7. Go to the [Releases](https://github.com/MPoL-dev/MPoL/releases) page, draft release notes, and publish a pre-release. 134 | * Ensure the `pre-release.yml` workflow passes. 135 | 8. Publish the true release. GitHub Actions will automatically upload the built package to PyPI and archive it on Zenodo. 136 | * Verify the `release.yml` workflow passed. 137 | -------------------------------------------------------------------------------- /docs/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPoL-dev/MPoL/f1facdb3737a2b45c5308f0ebd9ee60583edf0f1/docs/favicon.ico -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Million Points of Light (MPoL) 2 | 3 | [![Tests](https://github.com/MPoL-dev/MPoL/actions/workflows/verify-tests-and-docs.yml/badge.svg?branch=main)](https://github.com/MPoL-dev/MPoL/actions/workflows/verify-tests-and-docs.yml) 4 | [![gh-pages docs](https://img.shields.io/badge/community-Github%20Discussions-orange)](https://github.com/MPoL-dev/MPoL/discussions) 5 | 6 | MPoL is a [PyTorch](https://pytorch.org/) *library* built for Regularized Maximum Likelihood (RML) imaging and Bayesian Inference with datasets from interferometers like the Atacama Large Millimeter/Submillimeter Array ([ALMA](https://www.almaobservatory.org/en/home/)) and the Karl G. Jansky Very Large Array ([VLA](https://public.nrao.edu/telescopes/vla/)). 7 | 8 | As a PyTorch *library*, MPoL expects that the user will write Python code to link MPoL primitives as building blocks to solve their interferometric imaging workflow, much the same way the artificial intelligence community uses PyTorch layers to build new neural network architectures (for [example](https://github.com/pytorch/examples)). You will find MPoL easiest to use if you emulate PyTorch customs and idioms, e.g., feed-forward neural networks, data storage, GPU acceleration, and train/test optimization loops. Therefore, a basic familiarity with PyTorch is considered a prerequisite for MPoL. 9 | 10 | MPoL is neither an imaging application nor a pipeline, though MPoL components could be used to build specialized workflows. We are focused on providing a numerically correct and expressive set of core primitives so the user can leverage the full power of the PyTorch (and Python) ecosystem to solve their research-grade imaging tasks. This is already a significant development and maintenance burden for the limited resources of our small research team, so our immediate scope must necessarily be limited.
11 | 12 | To get a sense of what background material MPoL assumes, please look at the [](background.md). If the package is right for your needs, follow the [installation instructions](installation.md). 13 | 14 | This documentation covers the API and a short set of tutorials demonstrating key components of the MPoL library. Longer examples demonstrating how one might use MPoL components to build an imaging workflow are packaged together in the [MPoL-dev/examples](https://github.com/MPoL-dev/examples) repository. 15 | 16 | If you'd like to help build the MPoL package, please check out the [](developer-documentation.md) to get started. For more information about the constellation of packages supporting RML imaging and modeling, check out the MPoL-dev organization [website](https://mpol-dev.github.io/) and [GitHub](https://github.com/MPoL-dev) repository hosting the source code. If you have any questions, please ask us on our [GitHub Discussions page](https://github.com/MPoL-dev/MPoL/discussions). 17 | 18 | *If you use MPoL in your research, please cite us!* See the citation in the repository `README.md`. 19 | 20 | ```{toctree} 21 | :caption: User Guide 22 | :maxdepth: 2 23 | 24 | background 25 | installation 26 | getting_started 27 | ``` 28 | 29 | ```{toctree} 30 | :caption: API 31 | :maxdepth: 2 32 | 33 | api/coordinates 34 | api/datasets 35 | api/fourier 36 | api/gridding 37 | api/images 38 | api/losses 39 | api/geometry 40 | api/utilities 41 | api/precomposed 42 | api/train_test 43 | api/plotting 44 | api/crossval 45 | api/analysis 46 | ``` 47 | 48 | ```{toctree} 49 | :caption: Reference 50 | :maxdepth: 2 51 | 52 | units-and-conventions.md 53 | developer-documentation.md 54 | changelog.md 55 | ``` 56 | 57 | - {ref}`genindex` 58 | - {ref}`changelog-reference-label` 59 | -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installation and Examples 2 | 3 | MPoL requires `python >= 3.10`. 4 | 5 | ## Using pip 6 | 7 | Stable versions are hosted on PyPI. You can install the latest version by 8 | 9 | ``` 10 | $ pip install MPoL 11 | ``` 12 | 13 | Or if you require a specific version of MPoL (e.g., `0.2.0`), you can install via 14 | 15 | ``` 16 | $ pip install MPoL==0.2.0 17 | ``` 18 | 19 | ## From source 20 | 21 | If you'd like to install the package from source to access the latest development version, download or `git clone` the MPoL repository and install 22 | 23 | ``` 24 | $ git clone https://github.com/MPoL-dev/MPoL.git 25 | $ cd MPoL 26 | $ pip install . 27 | ``` 28 | 29 | If you have trouble installing, please raise a [GitHub issue](https://github.com/MPoL-dev/MPoL/issues) with the particulars of your system. 30 | 31 | If you're interested in contributing to the MPoL package, please see the [](developer-documentation.md). 32 | 33 | ## Upgrading 34 | 35 | If you installed from PyPI, to upgrade to the latest stable version of MPoL, do 36 | 37 | ``` 38 | $ pip install --upgrade MPoL 39 | ``` 40 | 41 | If you installed from source, update the repository 42 | 43 | ``` 44 | $ cd MPoL 45 | $ git pull 46 | $ pip install . 47 | ``` 48 | 49 | You can determine your current installed version by 50 | 51 | ``` 52 | $ python 53 | >>> import mpol 54 | >>> print(mpol.__version__) 55 | ``` 56 | 57 | ## Documentation 58 | 59 | The documentation served online ([here](https://mpol-dev.github.io/MPoL/index.html)) corresponds to the `main` branch.
This represents the current state of MPoL and is usually the best place to reference MPoL functionality. However, this documentation may be more current than the last tagged version or the version you have installed. If you require the new features detailed in the documentation, then we recommend installing the package from source (as above). 60 | 61 | In the (likely rare) situation where the latest online documentation significantly diverges from the package version you wish to use (but there are reasons you do not want to build the `main` branch from source), you can access the documentation for that version by [building the older documentation locally](developer-documentation.md#older-documentation-versions). 62 | 63 | ## Getting Started 64 | 65 | Because MPoL is a PyTorch imaging *library*, there are many things one could do with it. Over at the [MPoL-dev/examples](https://github.com/MPoL-dev/examples/) repository, we've collected example scripts for some of the more common workflows, such as diagnostic imaging with {class}`mpol.gridding.DirtyImager`, imaging with a stochastic gradient descent workflow, and visibility inference with Pyro. 66 | -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPoL-dev/MPoL/f1facdb3737a2b45c5308f0ebd9ee60583edf0f1/docs/logo.png -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/units-and-conventions.md: -------------------------------------------------------------------------------- 1 | (units-conventions-label)= 2 | 3 | # Units and Conventions 4 | 5 | ## Fourier transform conventions 6 | 7 | We follow the (reasonably standard) conventions of the Fourier transform (e.g., [Bracewell's](https://ui.adsabs.harvard.edu/abs/2000fta..book.....B/abstract) "system 1"). 8 | 9 | **Forward transform**: 10 | 11 | $$ 12 | F(s) = \int_{-\infty}^\infty f(x) e^{-i 2 \pi x s} \mathrm{d}x 13 | $$ 14 | 15 | **Inverse transform**: 16 | 17 | $$ 18 | f(x) = \int_{-\infty}^\infty F(s) e^{i 2 \pi x s} \mathrm{d}s 19 | $$ 20 | 21 | ## Baseline Convention 22 | 23 | For two antennas `ant1` and `ant2`, the baseline convention describes how the baseline between them is represented in UVW coordinates and whether the positive baseline is measured from `ant1->ant2`, or from `ant2->ant1`.
24 | MPoL follows the former, standard baseline convention as laid out in [Thompson, Moran, and Swenson](https://ui.adsabs.harvard.edu/abs/2017isra.book.....T/abstract) and other radio interferometry textbooks. However, CASA follows a historically complicated convention [described in CASA Memo 2](https://casadocs.readthedocs.io/en/stable/notebooks/memo-series.html). The two conventions differ by a complex conjugation of the visibilities. So, if you find that your image appears upside down and mirrored, you'll want to take ``np.conj`` of your visibilities before proceeding. 25 | 26 | ## Continuous representation of interferometry 27 | 28 | Consider some astronomical source parameterized by its sky brightness distribution $I$. The sky brightness is a function of position on the sky. For small fields of view (typical of single-pointing ALMA or VLA observations) we parameterize the sky direction using the direction cosines $l$ and $m$, which correspond to R.A. and Dec, respectively. In that case, we would have a function $I(l,m)$. The sky brightness is an *intensity*, so it has units of $\mathrm{Jy\,arcsec}^{-2}$ (equivalently $\mathrm{Jy\, beam}^{-1}$, where $\mathrm{beam}$ is the effective area of the synthesized beam). 29 | 30 | The real domain is linked to the Fourier domain, also called the visibility domain, via the Fourier transform 31 | 32 | $$ 33 | {\cal V}(u,v) = \int \int I(l,m) \exp \left \{- 2 \pi i (ul + vm) \right \} \, \mathrm{d}l\,\mathrm{d}m. 34 | $$ 35 | 36 | This integral demonstrates that the units of the visibility function (and samples of it) are $\mathrm{Jy}$. 37 | 38 | ```{note} 39 | This simplified relationship omits many additional effects that modify the cosmic intensity before it is recorded as visibility data. A full treatment of these effects can be mathematically described by the Radio Interferometric Measurement Equation (RIME). See the [Revisiting the radio interferometer measurement equation](https://ui.adsabs.harvard.edu/abs/2011A%26A...527A.106S/abstract) series by O. Smirnov, 2011 for more details. 40 | ``` 41 | 42 | ## Discretized representation 43 | 44 | There are several annoying pitfalls that can arise when dealing with discretized images and Fourier transforms, and most relate back to confusing or ill-specified conventions. The purpose of this page is to explicitly define the conventions used throughout MPoL and make clear how each transformation relates back to the continuous equations. 45 | 46 | ### Pixel fluxes and Cube dimensions 47 | 48 | - Throughout the codebase, any cube representing the sky plane is assumed to have units of $\mathrm{Jy\,arcsec}^{-2}$. 49 | - The image cubes are packed as 3D arrays `(nchan, npix, npix)`. 50 | - The "rows" of the image cube (axis=1) correspond to the $m$ or Dec axis. There are $M$ pixels along the Dec axis. 51 | - The "columns" of the image cube (axis=2) correspond to the $l$ or R.A. axis. There are $L$ pixels along the R.A. axis. 52 | 53 | ### Angular units 54 | 55 | For nearly all user-facing routines, the angular axes (the direction cosines $l$ and $m$ corresponding to R.A. and Dec, respectively) are expected to have units of arcseconds. The sample rate of the image cube is set via the `cell_size` parameter, in units of arcsec. Internally, these quantities are represented in radians. Let the pixel spacing be represented by $\Delta l$ and $\Delta m$, respectively.
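For example, the pixel spacings in radians can be computed from `cell_size` using the `arcsec` conversion constant provided in `mpol.constants` (a minimal sketch; the `cell_size` value here is hypothetical):

```
from mpol.constants import arcsec  # size of one arcsecond in radians

cell_size = 0.005  # [arcsec], hypothetical pixel spacing
dl = cell_size * arcsec  # Delta l [radians]
dm = cell_size * arcsec  # Delta m [radians]
```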
56 | 57 | ### Fourier units 58 | 59 | The sampling rate in the Fourier domain is inversely related to the number of samples and the sampling rate in the image domain. I.e., the grid spacing is 60 | 61 | $$ 62 | \Delta u = \frac{1}{L \Delta l} \\ 63 | \Delta v = \frac{1}{M \Delta m}. 64 | $$ 65 | 66 | If $\Delta l$ and $\Delta m$ are in units of radians, then $\Delta u$ and $\Delta v$ are in units of cycles per radian. Thanks to the geometry of the interferometer, the spatial frequency units can equivalently be expressed as the baseline lengths measured in multiples of the observing wavelength $\lambda$. 67 | 68 | For example, take an observation with ALMA band 6 at an observing frequency of 230 GHz, corresponding to a wavelength of 1.3 mm. A 100 meter baseline between antennas will measure a spatial frequency of $\frac{100\,\mathrm{m} }{ 1.3 \times 10^{-3}\,\mathrm{m}} \approx 77 \mathrm{k}\lambda$ or 77,000 cycles per radian. 69 | 70 | For more information on the relationship between baselines and spatial frequencies, see [TMS Chapter 2.3, equations 2.13 and 2.14](https://ui.adsabs.harvard.edu/abs/2017isra.book.....T/abstract). Internally, MPoL usually represents spatial frequencies in units of $\mathrm{k}\lambda$. 71 | 72 | For reference, here are some typical ALMA baseline lengths and their (approximate) corresponding spatial frequencies at common observing frequencies 73 | 74 | ```{eval-rst} 75 | .. csv-table:: 76 | :file: _static/baselines/build/baselines.csv 77 | :header-rows: 1 78 | ``` 79 | 80 | Occasionally, it is useful to represent the Cartesian Fourier coordinates $u$, $v$ in polar coordinates $q$, $\phi$ 81 | 82 | $$ 83 | q = \sqrt{u^2 + v^2}\\ 84 | \phi = \mathrm{atan2}(v,u). 85 | $$ 86 | 87 | $\phi$ represents the angle between the $+u$ axis and the ray drawn from the origin to the point $(u,v)$. Following the [numerical conventions](https://en.wikipedia.org/wiki/Atan2) of the `arctan2` function, $\phi$ is defined over the range $(-\pi, \pi]$. 88 | 89 | ### The discrete Fourier transform 90 | 91 | Since we are dealing with discrete quantities (pixels), we use the discrete version of the Fourier transform (DFT), carried out by the Fast Fourier Transform (FFT). Throughout the package we use the implementations in numpy or PyTorch: they are mathematically the same, but PyTorch provides the opportunity for autodifferentiation. For both the forward and inverse transforms, we assume that `norm='backward'`, the default setting. This means we don't need to account for the $L$ or $M$ prefactors in the forward transform, but we do need to account for the $U$ and $V$ prefactors in the inverse transform. 92 | 93 | **Forward transform**: As before, we use the forward transform to go from the image plane (sky brightness distribution) to the Fourier plane (visibility function). This is the most common transform used in MPoL because RML can be thought of as a type of forward modeling procedure: we're proposing an image and carrying it to the visibility plane to check its fit with the data. In numpy, the forward FFT is [defined as](https://docs.scipy.org/doc/numpy/reference/routines.fft.html#module-numpy.fft) 94 | 95 | $$ 96 | \mathtt{FFT}(I_{l,m}) = \sum_{l=0}^{L-1} \sum_{m=0}^{M-1} I_{l,m} \exp \left \{- 2 \pi i (ul/L + vm/M) \right \} 97 | $$ 98 | 99 | To make the FFT output an appropriate representation of the continuous forward Fourier transform, we need to account for the spacing of the input samples.
The FFT knows only that it was served a sequence of numbers; it does not know that the samples in $I_{l,m}$ are spaced `cell_size` apart. To fix this, we just need to include the spacing as a prefactor (i.e., converting the $\mathrm{d}l$ to $\Delta l$), following [TMS Eqn A8.18](https://ui.adsabs.harvard.edu/abs/2017isra.book.....T/abstract) 100 | 101 | $$ 102 | V_{u,v} = (\Delta l)(\Delta m) \mathtt{FFT}(I_{l,m}) 103 | $$ 104 | 105 | In this context, the $u,v$ subscripts indicate the elements of the $V$ array. As long as $I_{l,m}$ is in units of $\mathrm{Jy} / (\Delta l \Delta m)$, then $V$ will be in the correct output units (flux, or Jy). 106 | 107 | **Inverse transform**: The inverse transform is used within MPoL to produce a quick diagnostic image from the visibilities (called the "dirty image"). As you might expect, this is the inverse operation of the forward transform. Numpy and PyTorch define the inverse transform as 108 | 109 | $$ 110 | \mathtt{iFFT}({\cal V}_{u,v}) = \frac{1}{U} \frac{1}{V} \sum_{u=0}^{U-1} \sum_{v=0}^{V-1} {\cal V}_{u,v} \exp \left \{2 \pi i (ul/U + vm/V) \right \} 111 | $$ 112 | 113 | If we had a fully sampled grid of ${\cal V}_{u,v}$ values, then the operation we'd want to carry out to produce an image needs to correct for both the cell spacing and the counting terms 114 | 115 | $$ 116 | I_{l,m} = U V (\Delta u)(\Delta v) \mathtt{iFFT}({\cal V}_{u,v}) 117 | $$ 118 | 119 | For more information on this procedure as implemented in MPoL, see the {class}`~mpol.gridding.Gridder` class and the source code of its {func}`~mpol.gridding.Gridder.get_dirty_image` method. When the grid of ${\cal V}_{u,v}$ values is not fully sampled (as in any real-world interferometric observation), there are many subtleties beyond this simple equation that warrant consideration when synthesizing an image via inverse Fourier transform. For more information, consult the seminal [Ph.D. thesis](http://www.aoc.nrao.edu/dissertations/dbriggs/) of Daniel Briggs. 120 | 121 | (cube-orientation-label)= 122 | ### Image Cube Packing for FFTs 123 | 124 | Numerical FFT routines expect that the first element of an input array (i.e., `array[i,0,0]`) corresponds to the zeroth spatial ($l,m$) or frequency ($u,v$) coordinate. This convention is quite different from the way we normally look at images. As described above, MPoL deals with three-dimensional image cubes of shape `(nchan, npix, npix)`, where the "rows" of the image cube (axis=1) correspond to the $m$ or Dec axis, and the "columns" of the image cube (axis=2) correspond to the $l$ or R.A. axis. Normally, the zeroth spatial component $(l,m) = (0,0)$ is in the *center* of the array (at position `array[i,M/2,L/2]`), so that when an array is visualized (say with `matplotlib.pyplot.imshow`, `origin="lower"`), the center of the array appears in the center of the image. 125 | 126 | ```{image} _static/fftshift/build/plot.png 127 | ``` 128 | 129 | Complicating this already non-standard situation is the fact that astronomers usually plot images as seen on the sky: with north ($m$) up and east ($l$) to the left. Throughout the MPoL codebase, we call these cubes 'sky cubes'; see the above figure for a representation. In order to display sky cubes properly with routines like `matplotlib.pyplot.imshow`, when indexed as `array[i,j,k]`, an increasing `k` index must correspond to *decreasing* values of $l$.
(It's OK that an increasing `j` index corresponds to increasing values of $m$; however, we must be certain to include the `origin="lower"` argument when using `matplotlib.pyplot.imshow`.) 130 | 131 | Correctly taking the Fourier transform of a sky cube requires several steps. First, we must flip the cube across the R.A. axis (axis=2) to create an `array[i,j,k]` in which increasing values of `j` and `k` both correspond to increasing values of $m$ and $l$, respectively. We call this intermediate product a 'flip cube.' 132 | 133 | Then, the cube must be packed such that the first element(s) of an input array (i.e., `array[i,0,0]`) correspond to the zeroth spatial coordinates $(l,m) = (0,0)$. Thankfully, we can carry out this operation easily using `fftshift` functions commonly provided by FFT packages like `numpy.fft` or `torch.fft`. We shift across the Dec and R.A. axes (axis=1 and axis=2), leaving the channel axis (axis=0) untouched, to create a 'packed image cube.' MPoL has convenience functions to carry out the flip and packing operations ({func}`mpol.utils.sky_cube_to_packed_cube`) and the inverse process ({func}`mpol.utils.packed_cube_to_sky_cube`). 134 | 135 | After the FFT is correctly applied to the R.A. and Dec dimensions using `fft2`, the output is a packed visibility cube, where the first elements (i.e., `array[i,0,0]`) correspond to the zeroth spatial frequency coordinates $(u,v) = (0,0)$. To translate this cube back into something that's more recognizable when plotted, we can apply the `ifftshift` operation along the $v$ and $u$ axes (axis=1 and axis=2), leaving the channel axis (axis=0) untouched, to create a 'ground visibility cube'. We choose to orient the visibility plane from the perspective of an aerial observer looking down at an interferometric array on the ground, such that north is up and east is to the right; therefore, no additional flip is required for the visibility cube. MPoL has convenience functions to carry out the unpacking operation ({func}`mpol.utils.packed_cube_to_ground_cube`) and the inverse process ({func}`mpol.utils.ground_cube_to_packed_cube`). 136 | 137 | In practice, `fftshift` and `ifftshift` routines operate identically for arrays with an even number of elements (currently required by MPoL).
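The packing round trip can be sketched in a few lines (a minimal, hypothetical example: the cube contents are random and the shape is arbitrary, but `npix` must be even):

```
import torch
from mpol.utils import packed_cube_to_sky_cube, sky_cube_to_packed_cube

# a hypothetical sky cube of shape (nchan, npix, npix)
sky_cube = torch.rand(1, 256, 256, dtype=torch.float64)

packed_cube = sky_cube_to_packed_cube(sky_cube)  # flip across R.A., then fftshift
vis_cube = torch.fft.fft2(packed_cube)  # packed visibility cube

# undo the packing to recover the original orientation
recovered = packed_cube_to_sky_cube(packed_cube)
assert torch.equal(sky_cube, recovered)
```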
138 | -------------------------------------------------------------------------------- /paper/.gitignore: -------------------------------------------------------------------------------- 1 | jats 2 | paper.pdf -------------------------------------------------------------------------------- /paper/fig.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPoL-dev/MPoL/f1facdb3737a2b45c5308f0ebd9ee60583edf0f1/paper/fig.pdf -------------------------------------------------------------------------------- /paper/paper.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: 'Million Points of Light (MPoL): a PyTorch library for radio interferometric imaging and inference' 3 | tags: 4 | - Python 5 | - astronomy 6 | - imaging 7 | - fourier 8 | - radio astronomy 9 | - radio interferometry 10 | - machine learning 11 | - neural networks 12 | authors: 13 | - name: Ian Czekala 14 | orcid: 0000-0002-1483-8811 15 | corresponding: true 16 | affiliation: 1 17 | - name: Jeff Jennings 18 | orcid: 0000-0002-7032-2350 19 | affiliation: 2 20 | - name: Brianna Zawadzki 21 | orcid: 0000-0001-9319-1296 22 | affiliation: 3 23 | - name: Kadri Nizam 24 | orcid: 0000-0002-7217-446X 25 | affiliation: 4 26 | - name: Ryan Loomis 27 | orcid: 0000-0002-8932-1219 28 | affiliation: 5 29 | - name: Megan Delamer 30 | orcid: 0000-0003-1439-2781 31 | affiliation: 4 32 | - name: Kaylee de Soto 33 | orcid: 0000-0002-9886-2834 34 | affiliation: 4 35 | - name: Robert Frazier 36 | orcid: 0000-0001-6569-3731 37 | affiliation: 4 38 | - name: Hannah Grzybowski 39 | orcid: # can't find 40 | affiliation: 4 41 | - name: Mary Ogborn 42 | orcid: 0000-0001-9741-2703 43 | affiliation: 4 44 | - name: Tyler Quinn 45 | orcid: 0000-0002-8974-8095 46 | affiliation: 4 47 | affiliations: 48 | - name: University of St Andrews, Scotland 49 | index: 1 50 | - name: CCA, Flatiron Institute, NY, USA 51 | index: 2 52 | - name: Wesleyan University, CT, USA 53 | index: 3 54 | - name: Pennsylvania State University, PA, USA 55 | index: 4 56 | - name: National Radio Astronomy Observatory, Charlottesville, VA, USA 57 | index: 5 58 | date: 24 January 2025 59 | bibliography: paper.bib 60 | --- 61 | 62 | # Summary 63 | 64 | Astronomical radio interferometers achieve exquisite angular resolution by cross-correlating signal from a cosmic source simultaneously observed by distant pairs of radio telescopes to produce a Fourier-type measurement called a visibility. *Million Points of Light* (`MPoL`) is a Python library supporting feed-forward modeling of interferometric visibility datasets for synthesis imaging and parametric Bayesian inference, built using the autodifferentiable machine learning framework PyTorch. Neural network components provide a rich set of modular and composable building blocks that can be used to express the physical relationships between latent model parameters and observed data following the radio interferometric measurement equation. Industry-grade optimizers make it straightforward to simultaneously solve for the synthesized image and calibration parameters using stochastic gradient descent. 65 | 66 | # Statement of need 67 | 68 | When an astrophysical source is observed by a radio interferometer, there are frequently large gaps in the spatial frequency coverage. 
Therefore, rather than perform a direct Fourier inversion, images must be synthesized from the visibility data using an imaging algorithm; it is common for the incomplete sampling to severely hamper image fidelity [@condon16; @thompson17]. CLEAN is the traditional image synthesis algorithm of the radio interferometry community [@hogbom74], with a modern implementation in the reduction and analysis software CASA [@mcmullin07; @casa22], the standard for current major facility operations [e.g., @hunter23]. CLEAN excels at the rapid imaging of astronomical fields comprising unresolved point sources (e.g. quasars) and marginally resolved sources, but may struggle when the source morphology is not well-matched by the CLEAN basis set (e.g., point sources, Gaussians), a common situation with ring-like protoplanetary disk sources [@disk20, §3]. 69 | 70 | High fidelity imaging algorithms for spatially resolved sources are needed to realize the full scientific potential of groundbreaking observatories like the Atacama Large Millimeter Array (ALMA; @wootten09), the Event Horizon Telescope [@eht19a], and the Square Kilometer Array [@dewdney09] as they deliver significantly improved sensitivity and resolving power compared to previous generation instruments. In the field of planet formation alone, spatially resolved observations from ALMA have rapidly advanced our understanding of protoplanetary disk structures [@andrews20], kinematic signatures of embedded protoplanets [@pinte18], and circumplanetary disks [@benisty21; @casassus22]. Application of higher performance imaging techniques to these groundbreaking datasets [e.g., @casassus22] showed great promise in unlocking further scientific progress. Simultaneously, a flexible, open-source platform could integrate machine learning algorithms and computational imaging techniques from non-astronomy fields. 71 | 72 | # The Million Points of Light (MPoL) library 73 | 74 | `MPoL` is a library designed for feed-forward modeling of interferometric datasets using Python, Numpy [@harris20], and the computationally performant machine learning framework PyTorch [@paszke19], which debuted with @zawadzki23. `MPoL` implements a set of foundational interferometry components using PyTorch `nn.module`, which can be easily combined to build a forward-model of the interferometric dataset(s) at hand. We strive to seamlessly integrate with the PyTorch ecosystem so that users can easily leverage well-established machine learning workflows: optimization with stochastic gradient descent [@bishop23, Ch. 7], straightforward acceleration with GPU(s), and integration with common neural network architectures. 75 | 76 | In a typical feed-forward workflow, `MPoL` users will use foundational components like `BaseCube` and `ImageCube` to define the true-sky model, Fourier layers like `FourierCube` or `NuFFT` [wrapping `torchkbnufft`, @nufft20] to apply the Fourier transform and sample the visibility function at the location of the array baselines, and the negative log likelihood to calculate a data loss. Backpropagation [see @baydin18 for a review] and stochastic gradient descent [e.g., AdamW, @loshchilov17] are used to find the true-sky model that minimizes the loss function. However, because of the aforementioned gaps in spatial frequency coverage, there is technically an infinite number of true-sky images fully consistent with the data likelihood, so regularization loss terms are required. 
`MPoL` supports Regularized Maximum Likelihood (RML) imaging with common regularizers like maximum entropy, sparsity, and others [e.g., as used in @eht19d]; users can also implement custom regularizers with PyTorch. 77 | 78 | `MPoL` also provides several other workflows relevant to astrophysical research. First, by seamlessly coupling with the probabilistic programming language Pyro [@pyro19], `MPoL` supports Bayesian parametric inference of astronomical sources by modeling the data visibilities. Second, users can implement additional data calibration components as their data requires, enabling fine-scale, residual calibration physics to be parameterized and optimized simultaneously with image synthesis [following the radio interferometric measurement equation @hamaker96; @smirnov11a]. Finally, the library also provides convenience utilities like `DirtyImager` (including Briggs robust weighting and UV taper) to confirm the data has been loaded correctly. The MPoL-dev organization also develops the [MPoL-dev/visread](https://mpol-dev.github.io/visread/) package, which is designed to facilitate the extraction of visibility data from CASA's Measurement Set format for use in alternative imaging workflows. 79 | 80 | # Documentation, examples, and scientific results 81 | 82 | MPoL is freely available, open-source software licensed under the MIT license and is developed on GitHub at [MPoL-dev/MPoL](https://github.com/MPoL-dev/MPoL). Installation and API documentation is hosted at [https://mpol-dev.github.io/MPoL/](https://mpol-dev.github.io/MPoL/), and is continuously built with each commit to the `main` branch. As a library, `MPoL` expects researchers to write short scripts using `MPoL` and PyTorch primitives, in much the same way that PyTorch users write scripts for machine learning workflows (e.g., as in the [official PyTorch examples](https://github.com/pytorch/examples)). `MPoL` example projects are hosted on GitHub at [MPoL-dev/examples](https://github.com/MPoL-dev/examples). These include an introduction to generating mock data, a quickstart using stochastic gradient descent, and a Pyro workflow using stochastic variational inference (SVI) to replicate the parametric inference done in @guzman18, among others. In Figure \ref{imlup}, we compare an image obtained with CLEAN to that using `MPoL` and RML, synthesized from the data presented in @huang18b, highlighting the improvement in resolution offered by feed-forward modeling technologies.[^1] 83 | 84 | [^1]: Source code to reproduce this result is available as an [MPoL example](https://github.com/MPoL-dev/examples/tree/main). 85 | 86 | `MPoL` has already been used in a number of scientific publications. @zawadzki23 introduced `MPoL` and explored RML imaging for ALMA observations of protoplanetary disks, finding a 3x improvement in spatial resolution at comparable sensitivity. @dia23 used `MPoL` as a reference imaging implementation to evaluate the performance of their score-based prior algorithm. @huang24 used the parametric inference capabilities of `MPoL` to analyze radial dust substructures in a suite of eight protoplanetary disks in the $\sigma$ Orionis stellar cluster. `MPoL` was selected as an imaging technology of the exoALMA large program, where Zawadzki et al. 2024 *submitted* used RML imaging to obtain high resolution image cubes of molecular line emission in protoplanetary disks in order to identify non-Keplerian features that may trace planet-disk interactions.
87 | 88 | ![Left: the synthesized image produced by the DSHARP ALMA Large Program [@andrews18] using `CASA/tclean`. Right: the regularized maximum likelihood image produced using `MPoL` on the same data. Both images are displayed using a `sqrt` stretch, with the upper limit truncated to 70\% and 40\% of the maximum value for CLEAN and `MPoL`, respectively, to emphasize faint features. The CLEAN algorithm permits negative intensity values, while the `MPoL` algorithm enforces image positivity by construction. Each side of the image is 3 arcseconds. Intensities are shown in units of Jy/arcsec^2^. \label{imlup}](fig.pdf) 89 | 90 | # Similar tools 91 | 92 | Recently, there has been significant work to design robust algorithms to image spatially resolved sources. A non-exhaustive list includes the `RESOLVE` family of algorithms [@junklewitz16], which impose Gaussian random field image priors; the multi-algorithm approach of the Event Horizon Telescope Collaboration [@eht19d], including regularized maximum likelihood techniques; MaxEnt [@carcamo18]; and domain-specific non-parametric 1D approaches like `frank` [@jennings20]. Several approaches have leveraged deep learning, such as score-based priors [@dia23], denoising diffusion probabilistic models [@wang23], and residual-to-residual deep neural networks [@dabbech24]. In contrast to many imaging programs, `MPoL` is designed as a library, and so in theory can support a variety of forward-modeling workflows. 93 | 94 | The parametric modeling capabilities of `MPoL`, provided by integration with `Pyro`, are similar to the `emcee` [@foreman-mackey13] + synthetic visibility workflow provided by the Galario software [@tazzari18]. Since PyTorch enables automatic differentiation, `Pyro` users can utilize HMC/NUTS sampling [@neal12; @hoffman14] or SVI, which offer significant benefits in high-dimensional spaces compared to ensemble MCMC samplers. 95 | 96 | 97 | # Acknowledgements 98 | 99 | We acknowledge funding from ALMA Development Cycle 8 grant AST-1519126. J.H. acknowledges support from the National Science Foundation under Grant No. 2307916. ALMA is a partnership of ESO (representing its member states), NSF (USA) and NINS (Japan), together with NRC (Canada), MOST and ASIAA (Taiwan), and KASI (Republic of Korea), in cooperation with the Republic of Chile. The Joint ALMA Observatory is operated by ESO, AUI/NRAO and NAOJ. The National Radio Astronomy Observatory is a facility of the National Science Foundation operated under cooperative agreement by Associated Universities, Inc.
100 | 101 | # References 102 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling", "hatch-vcs"] 4 | 5 | [project] 6 | authors = [{ name = "Ian Czekala", email = "ic95@st-andrews.ac.uk" }] 7 | classifiers = [ 8 | "Programming Language :: Python :: 3", 9 | "License :: OSI Approved :: MIT License", 10 | "Operating System :: OS Independent", 11 | ] 12 | dependencies = [ 13 | "numpy", 14 | "fast-histogram", 15 | "scipy", 16 | "torch>=1.8.0", 17 | "torchvision", 18 | "torchaudio", 19 | "torchkbnufft", 20 | "astropy", 21 | ] 22 | description = "Regularized Maximum Likelihood Imaging for Radio Astronomy" 23 | dynamic = ["version"] 24 | name = "MPoL" 25 | readme = "README.md" 26 | requires-python = ">=3.8" 27 | 28 | [project.optional-dependencies] 29 | dev = [ 30 | "pytest", 31 | "pytest-cov", 32 | "matplotlib", 33 | "requests", 34 | "astropy", 35 | "tensorboard", 36 | "mypy", 37 | "frank>=1.2.1", 38 | "sphinx>=7.2.0", 39 | "jupytext", 40 | "ipython!=8.7.0", # broken version for syntax highlight https://github.com/spatialaudio/nbsphinx/issues/687 41 | "nbsphinx", 42 | "sphinx_book_theme>=0.9.3", 43 | "sphinx_copybutton", 44 | "jupyter", 45 | "nbconvert", 46 | "sphinxcontrib-mermaid>=0.8.1", 47 | "myst-nb", 48 | "jupyter-cache", 49 | "Pillow", 50 | "asdf", 51 | "pyro-ppl", 52 | "arviz[all]", 53 | "visread>=0.0.4", 54 | "ruff" 55 | ] 56 | test = [ 57 | "pytest", 58 | "pytest-cov", 59 | "matplotlib", 60 | "requests", 61 | "tensorboard", 62 | "mypy", 63 | "visread>=0.0.4", 64 | "frank>=1.2.1", 65 | "ruff" 66 | ] 67 | 68 | [project.urls] 69 | Homepage = "https://mpol-dev.github.io/MPoL/" 70 | Issues = "https://github.com/MPoL-dev/MPoL/issues" 71 | 72 | [tool.hatch.version] 73 | source = "vcs" 74 | 75 | [tool.hatch.build.hooks.vcs] 76 | version-file = "src/mpol/mpol_version.py" 77 | 78 | [tool.black] 79 | line-length = 88 80 | 81 | [tool.mypy] 82 | warn_unused_configs = true 83 | 84 | [[tool.mypy.overrides]] 85 | module = [ 86 | "astropy.*", 87 | "matplotlib.*", 88 | "scipy.*", 89 | "torchkbnufft.*", 90 | "frank.*", 91 | "fast_histogram.*" 92 | ] 93 | ignore_missing_imports = true 94 | 95 | [[tool.mypy.overrides]] 96 | module = [ 97 | "MPoL.constants", 98 | "MPoL.coordinates", 99 | "MPoL.datasets", 100 | "MPoL.fourier", 101 | "MPoL.geometry", 102 | "MPoL.gridding", 103 | "MPoL.images", 104 | "MPoL.losses", 105 | "MPoL.precomposed", 106 | "MPoL.utils" 107 | ] 108 | disallow_untyped_defs = true 109 | 110 | [tool.ruff] 111 | target-version = "py310" 112 | line-length = 88 113 | # will enable after sorting module locations 114 | # select = ["F", "I", "E", "W", "YTT", "B", "Q", "PLE", "PLR", "PLW", "UP"] 115 | lint.ignore = [ 116 | "E741", # Allow ambiguous variable names 117 | "PLR0911", # Allow many return statements 118 | "PLR0913", # Allow many arguments to functions 119 | "PLR0915", # Allow many statements 120 | "PLR2004", # Allow magic numbers in comparisons 121 | ] 122 | exclude = [] -------------------------------------------------------------------------------- /src/mpol/__init__.py: -------------------------------------------------------------------------------- 1 | zenodo_record = 10064221 2 | -------------------------------------------------------------------------------- /src/mpol/constants.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | from astropy.constants import c, k_B 3 | 4 | # convert from arcseconds to radians 5 | arcsec: float = np.pi / (180.0 * 3600) # [radians] = 1/206265 radian/arcsec 6 | 7 | deg: float = np.pi / 180 # [radians] 8 | 9 | kB: float = k_B.cgs.value # [erg K^-1] Boltzmann constant 10 | cc: float = c.cgs.value # [cm s^-1] 11 | c_ms: float = c.value # [m s^-1] 12 | -------------------------------------------------------------------------------- /src/mpol/data/mock_data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MPoL-dev/MPoL/f1facdb3737a2b45c5308f0ebd9ee60583edf0f1/src/mpol/data/mock_data.npz -------------------------------------------------------------------------------- /src/mpol/datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | import torch 7 | from numpy import floating, integer 8 | from numpy.typing import ArrayLike, NDArray 9 | 10 | from mpol import utils 11 | from mpol.coordinates import GridCoords 12 | 13 | 14 | class GriddedDataset(torch.nn.Module): 15 | r""" 16 | Parameters 17 | ---------- 18 | coords : :class:`~mpol.coordinates.GridCoords` 19 | If providing this, cannot provide ``cell_size`` or ``npix``. 20 | vis_gridded : :class:`torch.Tensor` of :class:`torch.complex128` 21 | the gridded visibility data stored in a "packed" format (pre-shifted for fft) 22 | weight_gridded : :class:`torch.Tensor` 23 | the weights corresponding to the gridded visibility data, 24 | also in a packed format 25 | mask : :class:`torch.Tensor` of :class:`torch.bool` 26 | a boolean mask to index the non-zero locations of ``vis_gridded`` and 27 | ``weight_gridded`` in their packed format. 28 | nchan : int 29 | the number of channels in the image (default = 1). 30 | 31 | 32 | After initialization, the GriddedDataset provides the non-zero cells of the 33 | gridded visibilities and weights as a 1D vector via the following instance 34 | variables. This means that any individual channel information has been collapsed. 35 | 36 | :ivar vis_indexed: 1D complex tensor of visibility data 37 | 38 | :ivar weight_indexed: 1D tensor of weight values 39 | 40 | 41 | If you index the output of the Fourier layer in the same manner using ``self.mask``, 42 | then the model and data visibilities can be directly compared using a loss function. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | *, 48 | coords: GridCoords, 49 | vis_gridded: torch.Tensor, 50 | weight_gridded: torch.Tensor, 51 | mask: torch.Tensor, 52 | nchan: int = 1, 53 | ) -> None: 54 | super().__init__() 55 | 56 | self.coords = coords 57 | self.nchan = nchan 58 | 59 | # store variables as buffers of the module 60 | self.register_buffer("vis_gridded", vis_gridded) 61 | self.register_buffer("weight_gridded", weight_gridded) 62 | self.register_buffer("mask", mask) 63 | self.vis_gridded: torch.Tensor 64 | self.weight_gridded: torch.Tensor 65 | self.mask: torch.Tensor 66 | 67 | # pre-index the values 68 | # note that these are *collapsed* across all channels 69 | # 1D array 70 | self.register_buffer("vis_indexed", self.vis_gridded[self.mask]) 71 | self.register_buffer("weight_indexed", self.weight_gridded[self.mask]) 72 | self.vis_indexed: torch.Tensor 73 | self.weight_indexed: torch.Tensor 74 | 75 | def add_mask( 76 | self, 77 | mask: ArrayLike, 78 | ) -> None: 79 | r""" 80 | Apply an additional mask to the data. 
Only works as a data limiting operation 81 | (i.e., ``mask`` is more restrictive than the mask already attached 82 | to the dataset). 83 | 84 | Args: 85 | mask (2D numpy or PyTorch tensor): boolean mask (in packed format) to 86 | apply to dataset. Assumes input will be broadcast across all channels. 87 | """ 88 | 89 | new_2D_mask = torch.Tensor(mask).detach() 90 | new_3D_mask = torch.broadcast_to(new_2D_mask, self.mask.size()) 91 | 92 | # update mask via an AND operation, we will only keep visibilities that are 93 | # 1) part of the original dataset 94 | # 2) valid within the new mask 95 | self.mask = torch.logical_and(self.mask, new_3D_mask) 96 | 97 | # zero out vis_gridded and weight_gridded that may have existed 98 | # but are no longer valid 99 | # These operations on the gridded quantities are only important for routines 100 | # that grab these quantities directly, like residual grid imager 101 | self.vis_gridded[~self.mask] = 0.0 102 | self.weight_gridded[~self.mask] = 0.0 103 | 104 | # update pre-indexed values 105 | self.vis_indexed = self.vis_gridded[self.mask] 106 | self.weight_indexed = self.weight_gridded[self.mask] 107 | 108 | def forward(self, modelVisibilityCube: torch.Tensor) -> torch.Tensor: 109 | """ 110 | Args: 111 | modelVisibilityCube (complex torch.tensor): with shape 112 | ``(nchan, npix, npix)`` to be indexed. In "pre-packed" format, as in 113 | output from :meth:`mpol.fourier.FourierCube.forward()` 114 | 115 | Returns: 116 | torch complex tensor: 1d torch tensor of indexed model samples collapsed 117 | across cube dimensions. 118 | """ 119 | 120 | assert ( 121 | modelVisibilityCube.size()[0] == self.mask.size()[0] 122 | ), "vis and dataset mask do not have the same number of channels." 123 | 124 | # As of Pytorch 1.7.0, complex numbers are partially supported. 125 | # However, masked_select does not yet work (with gradients) 126 | # on the complex vis, so hence this awkward step of selecting 127 | # the reals and imaginaries separately 128 | re = modelVisibilityCube.real.masked_select(self.mask) 129 | im = modelVisibilityCube.imag.masked_select(self.mask) 130 | 131 | # we had trouble returning things as re + 1.0j * im, 132 | # but for some reason torch.complex seems to work OK. 133 | return torch.complex(re, im) 134 | 135 | @property 136 | def ground_mask(self) -> torch.Tensor: 137 | r""" 138 | The boolean mask, arranged in ground format. 139 | 140 | Returns: 141 | torch.boolean : 3D mask cube of shape ``(nchan, npix, npix)`` 142 | 143 | """ 144 | return utils.packed_cube_to_ground_cube(self.mask) 145 | 146 | 147 | class Dartboard: 148 | r""" 149 | A polar coordinate grid relative to a :class:`~mpol.coordinates.GridCoords` object, 150 | reminiscent of a dartboard layout. The main utility of this object is to support 151 | splitting a dataset along radial and azimuthal bins for k-fold cross validation. 152 | 153 | Args: 154 | coords (GridCoords): an object already instantiated from the GridCoords class. 155 | If providing this, cannot provide ``cell_size`` or ``npix``. 156 | q_edges (1D numpy array): an array of radial bin edges to set the dartboard 157 | cells in :math:`[\mathrm{k}\lambda]`. If ``None``, defaults to 12 158 | log-linearly radial bins stretching from 0 to the :math:`q_\mathrm{max}` 159 | represented by ``coords``. 
160 | phi_edges (1D numpy array): an array of azimuthal bin edges to set the 161 | dartboard cells in [radians], over the domain :math:`[0, \pi]`, which is 162 | also implicitly mapped to the domain :math:`[-\pi, \pi]` to preserve the 163 | Hermitian nature of the visibilities. If ``None``, defaults to 164 | 8 equal-spaced azimuthal bins stretched from :math:`0` to :math:`\pi`. 165 | """ 166 | 167 | def __init__( 168 | self, 169 | coords: GridCoords, 170 | q_edges: NDArray[floating[Any]] | None = None, 171 | phi_edges: NDArray[floating[Any]] | None = None, 172 | ) -> None: 173 | self.coords = coords 174 | self.nchan = 1 175 | 176 | # if phi_edges is not given, we'll instantiate 177 | if phi_edges is None: 178 | phi_edges = np.linspace(0, np.pi, num=8 + 1) # [radians] 179 | elif not all(0 <= edge <= np.pi for edge in phi_edges): 180 | raise ValueError("Elements of phi_edges must be between 0 and pi.") 181 | 182 | if q_edges is None: 183 | # set q edges approximately following inspiration from Petry et al. scheme: 184 | # https://ui.adsabs.harvard.edu/abs/2020SPIE11449E..1DP/abstract 185 | # first two bins set to 7m width 186 | # after third bin, bin width increases linearly until it is 187 | # 700m at 16km baseline. 188 | # From 16m to 16km, bin width goes from 7m to 700m. 189 | # --- 190 | # We aren't doing *quite* the same thing, 191 | # just logspacing with a few linear cells at the start. 192 | q_edges = utils.loglinspace(0, self.q_max, N_log=8, M_linear=5) 193 | 194 | self.q_edges = q_edges 195 | self.phi_edges = phi_edges 196 | 197 | @property 198 | def cartesian_qs(self) -> NDArray[floating[Any]]: 199 | return self.coords.packed_q_centers_2D 200 | 201 | @property 202 | def cartesian_phis(self) -> NDArray[floating[Any]]: 203 | return self.coords.packed_phi_centers_2D 204 | 205 | @property 206 | def q_max(self) -> float: 207 | return self.coords.q_max 208 | 209 | def get_polar_histogram( 210 | self, qs: NDArray[floating[Any]], phis: NDArray[floating[Any]] 211 | ) -> NDArray[floating[Any]]: 212 | r""" 213 | Calculate a histogram in polar coordinates, using the bin edges defined by 214 | ``q_edges`` and ``phi_edges`` during initialization. 215 | Data coordinates should include the points for the Hermitian visibilities. 216 | 217 | Args: 218 | qs: 1d array of q values :math:`[\lambda]` 219 | phis: 1d array of datapoint azimuth values [radians] (must be the same 220 | length as qs) 221 | 222 | Returns: 223 | 2d integer numpy array of cell counts, i.e., how many datapoints fell into 224 | each dartboard cell. 225 | """ 226 | 227 | histogram: NDArray 228 | # make a polar histogram 229 | histogram, *_ = np.histogram2d( # type:ignore 230 | qs, phis, bins=[self.q_edges.tolist(), self.phi_edges.tolist()] # type:ignore 231 | ) 232 | 233 | return histogram 234 | 235 | def get_nonzero_cell_indices( 236 | self, qs: NDArray[floating[Any]], phis: NDArray[floating[Any]] 237 | ) -> NDArray[integer[Any]]: 238 | r""" 239 | Return a list of the cell indices that contain data points, using the bin edges 240 | defined by ``q_edges`` and ``phi_edges`` during initialization. 241 | Data coordinates should include the points for the Hermitian visibilities. 242 | 243 | Args: 244 | qs: 1d array of q values :math:`[\lambda]` 245 | phis: 1d array of datapoint azimuth values [radians] (must be the same 246 | length as qs) 247 | 248 | Returns: 249 | list of cell indices where cell contains at least one datapoint. 
250 | """ 251 | 252 | # make a polar histogram 253 | histogram = self.get_polar_histogram(qs, phis) 254 | 255 | indices = np.argwhere(histogram > 0) # [i,j] indexes to go to q, phi 256 | 257 | return indices 258 | 259 | def build_grid_mask_from_cells( 260 | self, cell_index_list: NDArray[integer[Any]] 261 | ) -> NDArray[np.bool_]: 262 | r""" 263 | Create a boolean mask of size ``(npix, npix)`` (in packed format) corresponding 264 | to the ``vis_gridded`` and ``weight_gridded`` quantities of the 265 | :class:`~mpol.datasets.GriddedDataset` . 266 | 267 | Args: 268 | cell_index_list (list): list or iterable containing [q_cell, phi_cell] index 269 | pairs to include in the mask. 270 | 271 | Returns: (numpy array) 2D boolean mask in packed format. 272 | """ 273 | mask = np.zeros_like(self.cartesian_qs, dtype="bool") 274 | 275 | # uses about a Gb..., and this only 256x256 276 | for cell_index in cell_index_list: 277 | qi, pi = cell_index 278 | q_min, q_max = self.q_edges[qi : qi + 2] 279 | p0_min, p0_max = self.phi_edges[pi : pi + 2] 280 | # also include Hermitian values 281 | p1_min, p1_max = self.phi_edges[pi : pi + 2] - np.pi 282 | 283 | # whether or not the q and phi values of the coordinate array 284 | # fit in the q cell and *either of* the regular or Hermitian phi cell 285 | ind = ( 286 | (self.cartesian_qs >= q_min) 287 | & (self.cartesian_qs < q_max) 288 | & ( 289 | ((self.cartesian_phis > p0_min) & (self.cartesian_phis <= p0_max)) 290 | | ((self.cartesian_phis > p1_min) & (self.cartesian_phis <= p1_max)) 291 | ) 292 | ) 293 | 294 | mask[ind] = True 295 | 296 | return mask 297 | -------------------------------------------------------------------------------- /src/mpol/exceptions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | 4 | class CellSizeError(Exception): 5 | ... 6 | 7 | 8 | class WrongDimensionError(Exception): 9 | ... 10 | 11 | 12 | class DataError(Exception): 13 | ... 14 | 15 | 16 | class DimensionMismatchError(Exception): 17 | ... 18 | 19 | 20 | class ThresholdExceededError(Exception): 21 | ... 22 | -------------------------------------------------------------------------------- /src/mpol/geometry.py: -------------------------------------------------------------------------------- 1 | """The geometry package provides routines for projecting and de-projecting sky images. 2 | """ 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def flat_to_observer( 9 | x: torch.Tensor, 10 | y: torch.Tensor, 11 | omega: float = 0.0, 12 | incl: float = 0.0, 13 | Omega: float = 0.0, 14 | ) -> tuple[torch.Tensor, torch.Tensor]: 15 | """Rotate the frame to convert a point in the flat (x,y,z) frame to observer frame 16 | (X,Y,Z). 17 | 18 | It is assumed that the +Z axis points *towards* the observer. It is assumed that the 19 | model is flat in the (x,y) frame (like a flat disk), and so the operations 20 | involving ``z`` are neglected. But the model lives in 3D Cartesian space. 21 | 22 | In order, 23 | 24 | 1. rotate about the z axis by an amount omega -> x1, y1, z1 25 | 2. rotate about the x1 axis by an amount -incl -> x2, y2, z2 26 | 3. rotate about the z2 axis by an amount Omega -> x3, y3, z3 = X, Y, Z 27 | 28 | Inspired by `exoplanet/keplerian.py 29 | `_ 30 | 31 | Note that the (x,y) here are *not* the same as the `x_centers_2D` or `y_centers_2D` 32 | attached to the :class:`mpol.coordinates.GridCoords` object. 
The (x,y) referred to 33 | here are the 'perifocal frame' of the orbit, whereas the (X,Y,Z) are the sky or 34 | observer frame. Typically, the sky observer frame is oriented such that X is North 35 | (pointing up) and Y is East (pointing left). For more detail, see the `exoplanet 36 | docs `_ or `Murray and Correia `_. 37 | 38 | Parameters 39 | ---------- 40 | x : :class:`torch.Tensor` 41 | A tensor representing the x coordinate in the plane of the orbit. 42 | y : :class:`torch.Tensor` 43 | A tensor representing the y coordinate in the plane of the orbit. 44 | omega : float 45 | Argument of periastron [radians]. Default 0.0. 46 | incl : float 47 | Inclination value [radians]. Default 0.0. 48 | Omega : float 49 | Position angle of the ascending node in [radians]. Default 0.0 50 | 51 | Returns 52 | ------- 53 | 2-tuple of :class:`torch.Tensor` 54 | Two tensors representing ``(X, Y)`` in the observer frame. 55 | """ 56 | # Rotation matrices result in a *clockwise* rotation of the axes, 57 | # as defined using the righthand rule. 58 | # For example, looking down the z-axis, 59 | # a positive angle will rotate the x,y axes clockwise. 60 | # A vector in the coordinate system will appear as though it has been 61 | # rotated counter-clockwise. 62 | 63 | # 1) rotate about the z0 axis by omega 64 | cos_omega = np.cos(omega) 65 | sin_omega = np.sin(omega) 66 | 67 | x1 = cos_omega * x - sin_omega * y 68 | y1 = sin_omega * x + cos_omega * y 69 | 70 | # 2) rotate about x1 axis by -incl 71 | x2 = x1 72 | y2 = np.cos(incl) * y1 73 | # z3 = z2, subsequent rotation by Omega doesn't affect it 74 | # Z = -torch.sin(incl) * y1 75 | 76 | # 3) rotate about z2 axis by Omega 77 | cos_Omega = np.cos(Omega) 78 | sin_Omega = np.sin(Omega) 79 | 80 | X = cos_Omega * x2 - sin_Omega * y2 81 | Y = sin_Omega * x2 + cos_Omega * y2 82 | 83 | return X, Y 84 | 85 | 86 | def observer_to_flat( 87 | X: torch.Tensor, 88 | Y: torch.Tensor, 89 | omega: float = 0.0, 90 | incl: float = 0.0, 91 | Omega: float = 0.0, 92 | ) -> tuple[torch.Tensor, torch.Tensor]: 93 | """Rotate the frame to convert a point in the observer frame (X,Y,Z) to the 94 | flat (x,y,z) frame. 95 | 96 | It is assumed that the +Z axis points *towards* the observer. The rotation 97 | operations are the inverse of the :func:`~mpol.geometry.flat_to_observer` operations. 98 | 99 | In order, 100 | 101 | 1. inverse rotation about the Z axis by an amount Omega -> x2, y2, z2 102 | 2. inverse rotation about the x2 axis by an amount -incl -> x1, y1, z1 103 | 3. inverse rotation about the z1 axis by an amount omega -> x, y, z 104 | 105 | Inspired by `exoplanet/keplerian.py 106 | `_ 107 | 108 | Note that the (x,y) here are *not* the same as the `x_centers_2D` or `y_centers_2D` 109 | attached to the :class:`mpol.coordinates.GridCoords` object. The (x,y) referred to 110 | here are the 'perifocal frame' of the orbit, whereas the (X,Y,Z) are the sky or 111 | observer frame. Typically, the sky observer frame is oriented such that X is North 112 | (pointing up) and Y is East (pointing left). For more detail, see the `exoplanet 113 | docs `_ or `Murray and Correia `_. 114 | 115 | Parameters 116 | ---------- 117 | X : :class:`torch.Tensor` 118 | A tensor representing the x coordinate in the plane of the sky. 119 | Y : :class:`torch.Tensor` 120 | A tensor representing the y coordinate in the plane of the sky. 121 | omega : float 122 | A tensor representing an argument of periastron [radians] Default 0.0. 123 | incl : float 124 | A tensor representing an inclination value [radians]. 
Default 0.0. 125 | Omega : float 126 | A tensor representing the position angle of the ascending node in [radians]. 127 | Default 0.0 128 | 129 | Returns 130 | ------- 131 | 2-tuple of :class:`torch.Tensor` 132 | Two tensors representing ``(x, y)`` in the flat frame. 133 | """ 134 | # Rotation matrices result in a *clockwise* rotation of the axes, 135 | # as defined using the righthand rule. 136 | # For example, looking down the z-axis, a positive angle will rotate the 137 | # x,y axes clockwise. 138 | # A vector in the coordinate system will appear as though it has been 139 | # rotated counter-clockwise. 140 | 141 | # 1) inverse rotation about Z axis by Omega -> x2, y2, z2 142 | cos_Omega = np.cos(Omega) 143 | sin_Omega = np.sin(Omega) 144 | 145 | x2 = cos_Omega * X + sin_Omega * Y 146 | y2 = -sin_Omega * X + cos_Omega * Y 147 | 148 | # 2) inverse rotation about x2 axis by incl 149 | x1 = x2 150 | # we don't know Z, but we can solve some equations to find that 151 | # y = Y / cos(i), as expected by intuition 152 | y1 = y2 / np.cos(incl) 153 | 154 | # 3) inverse rotation about the z1 axis by an amount of omega 155 | cos_omega = np.cos(omega) 156 | sin_omega = np.sin(omega) 157 | 158 | x = x1 * cos_omega + y1 * sin_omega 159 | y = -x1 * sin_omega + y1 * cos_omega 160 | 161 | return x, y 162 | -------------------------------------------------------------------------------- /src/mpol/input_output.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from astropy.io import fits 3 | 4 | 5 | class ProcessFitsImage: 6 | """ 7 | Utilities for loading and retrieving metrics of a .fits image 8 | 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the .fits file 13 | channel : int, default=0 14 | Channel of the image to access 15 | """ 16 | 17 | def __init__(self, filename, channel=0): 18 | self._fits_file = filename 19 | self._channel = channel 20 | 21 | 22 | def get_extent(self, header): 23 | """Get extent (in RA and Dec, units of [arcsec]) of image""" 24 | 25 | # get the coordinate labels 26 | nx = header["NAXIS1"] 27 | ny = header["NAXIS2"] 28 | 29 | assert ( 30 | nx % 2 == 0 and ny % 2 == 0 31 | ), f"Image dimensions x {nx} and y {ny} must be even." 32 | 33 | # RA coordinates 34 | CDELT1 = 3600 * header["CDELT1"] # arcsec (converted from decimal deg) 35 | # CRPIX1 = header["CRPIX1"] - 1.0 # Now indexed from 0 36 | 37 | # DEC coordinates 38 | CDELT2 = 3600 * header["CDELT2"] # arcsec 39 | # CRPIX2 = header["CRPIX2"] - 1.0 # Now indexed from 0 40 | 41 | RA = (np.arange(nx) - nx / 2) * CDELT1 # [arcsec] 42 | DEC = (np.arange(ny) - ny / 2) * CDELT2 # [arcsec] 43 | 44 | # extent needs to include extra half-pixels. 
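# (matplotlib's imshow `extent` expects the outer edges of the corner pixels, hence the half-pixel padding below)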
45 | # RA, DEC are pixel centers 46 | 47 | ext = ( 48 | RA[0] - CDELT1 / 2, 49 | RA[-1] + CDELT1 / 2, 50 | DEC[0] - CDELT2 / 2, 51 | DEC[-1] + CDELT2 / 2, 52 | ) # [arcsec] 53 | 54 | return RA, DEC, ext 55 | 56 | 57 | def get_beam(self, hdu_list, header): 58 | """Get the major and minor widths [arcsec], and position angle, of a 59 | clean beam""" 60 | 61 | if header.get("CASAMBM") is not None: 62 | # Get the beam info from average of record array 63 | data2 = hdu_list[1].data 64 | BMAJ = np.median(data2["BMAJ"]) 65 | BMIN = np.median(data2["BMIN"]) 66 | BPA = np.median(data2["BPA"]) 67 | else: 68 | # Get the beam info from the header, like normal 69 | BMAJ = 3600 * header["BMAJ"] 70 | BMIN = 3600 * header["BMIN"] 71 | BPA = header["BPA"] 72 | 73 | return BMAJ, BMIN, BPA 74 | 75 | 76 | def get_image(self, beam=True): 77 | """Load a .fits image and return as a numpy array. Also return image 78 | extent and optionally (`beam`) the clean beam dimensions""" 79 | 80 | hdu_list = fits.open(self._fits_file) 81 | hdu = hdu_list[0] 82 | 83 | if len(hdu.data.shape) in [3, 4]: 84 | image = hdu.data[self._channel] # first channel 85 | else: 86 | image = hdu.data 87 | 88 | image *= 1e3 89 | 90 | if len(image.shape) == 3: 91 | image = np.squeeze(image) 92 | 93 | header = hdu.header 94 | 95 | RA, DEC, ext = self.get_extent(hdu.header) 96 | 97 | if beam: 98 | return image, ext, self.get_beam(hdu_list, header) 99 | else: 100 | return image, ext 101 | -------------------------------------------------------------------------------- /src/mpol/onedim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from mpol.utils import torch2npy 4 | 5 | 6 | def radialI(icube, geom, chan=0, bins=None): 7 | r""" 8 | Obtain a 1D (radial) brightness profile I(r) from an image cube. 9 | 10 | Parameters 11 | ---------- 12 | icube : `mpol.images.ImageCube` object 13 | Instance of the MPoL `images.ImageCube` class 14 | geom : dict 15 | Dictionary of source geometry. Keys: 16 | "incl" : float, unit=[deg] 17 | Inclination 18 | "Omega" : float, unit=[deg] 19 | Position angle of the ascending node 20 | "omega" : float, unit=[deg] 21 | Argument of periastron 22 | "dRA" : float, unit=[arcsec] 23 | Phase center offset in right ascension. Positive is west of north. 24 | "dDec" : float, unit=[arcsec] 25 | Phase center offset in declination. 26 | chan : int, default=0 27 | Channel of the image cube corresponding to the desired image 28 | bins : array, default=None, unit=[arcsec] 29 | Radial bin edges to use in calculating I(r). 
If None, bins will span 30 | the full image, with widths equal to the hypotenuse of the pixels 31 | 32 | Returns 33 | ------- 34 | bin_centers : array, unit=[arcsec] 35 | Radial coordinates of image at center of `bins` 36 | Is : array, unit=[Jy / arcsec^2] (if `image` has these units) 37 | Azimuthally averaged pixel brightness at `rs` 38 | """ 39 | 40 | # projected Cartesian pixel coordinates [arcsec] 41 | xx, yy = icube.coords.sky_x_centers_2D, icube.coords.sky_y_centers_2D 42 | 43 | # shift image center to source center 44 | xc, yc = xx - geom["dRA"], yy - geom["dDec"] 45 | 46 | # deproject image 47 | cos_PA = np.cos(geom["Omega"] * np.pi / 180) 48 | sin_PA = np.sin(geom["Omega"] * np.pi / 180) 49 | xd = xc * cos_PA - yc * sin_PA 50 | yd = xc * sin_PA + yc * cos_PA 51 | xd /= np.cos(geom["incl"] * np.pi / 180) 52 | 53 | # deprojected radial coordinates 54 | rr = np.ravel(np.hypot(xd, yd)) 55 | 56 | if bins is None: 57 | # choose sensible bin size and range 58 | step = np.hypot(icube.coords.cell_size, icube.coords.cell_size) 59 | bins = np.arange(0.0, np.max((abs(xc.ravel()), abs(yc.ravel()))), step) 60 | 61 | bin_counts, bin_edges = np.histogram(a=rr, bins=bins, weights=None) 62 | 63 | # cumulative binned brightness in each annulus 64 | Is, _ = np.histogram( 65 | a=rr, bins=bins, weights=torch2npy(icube.sky_cube[chan]).ravel() 66 | ) 67 | 68 | # mask empty bins 69 | mask = bin_counts == 0 70 | Is = np.ma.masked_where(mask, Is) 71 | 72 | # average binned brightness in each annulus 73 | Is /= bin_counts 74 | 75 | bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 76 | 77 | return bin_centers, Is 78 | 79 | 80 | def radialV(fcube, geom, rescale_flux, chan=0, bins=None): 81 | r""" 82 | Obtain the 1D (radial) visibility model V(q) corresponding to a 2D MPoL 83 | image. 84 | 85 | Parameters 86 | ---------- 87 | fcube : `~mpol.fourier.FourierCube` object 88 | Instance of the MPoL `fourier.FourierCube` class 89 | geom : dict 90 | Dictionary of source geometry. Keys: 91 | "incl" : float, unit=[deg] 92 | Inclination 93 | "Omega" : float, unit=[deg] 94 | Position angle of the ascending node 95 | "omega" : float, unit=[deg] 96 | Argument of periastron 97 | "dRA" : float, unit=[arcsec] 98 | Phase center offset in right ascension. Positive is west of north. 99 | "dDec" : float, unit=[arcsec] 100 | Phase center offset in declination 101 | rescale_flux : bool 102 | If True, the visibility amplitudes are rescaled to account 103 | for the difference between the inclined (observed) brightness and the 104 | assumed face-on brightness, assuming the emission is optically thick. 105 | The source's integrated (2D) flux is assumed to be: 106 | :math:`F = \cos(i) \int_r^{r=R}{I(r) 2 \pi r dr}`. 107 | No rescaling would be appropriate in the optically thin limit. 108 | chan : int, default=0 109 | Channel of the image cube corresponding to the desired image 110 | bins : array, default=None, unit=[k\lambda] 111 | Baseline bin edges to use in calculating V(q). 
If None, bins will span 112 | the model baseline distribution, with widths equal to half the hypotenuse of 113 | the (u, v) cell sizes 114 | 115 | Returns 116 | ------- 117 | bin_centers : array, unit=[:math:`\mathrm{k}\lambda`] 118 | Deprojected baselines at the centers of `bins` 119 | Vs : array, unit=[Jy] 120 | Visibility amplitudes at `bin_centers` 121 | 122 | Notes 123 | ----- 124 | This routine requires the `frank `_ package 125 | """ 126 | from frank.geometry import apply_phase_shift, deproject 127 | 128 | # projected model (u,v) points [k\lambda] 129 | uu, vv = fcube.coords.ground_u_centers_2D, fcube.coords.ground_v_centers_2D 130 | 131 | # visibilities 132 | V = torch2npy(fcube.ground_vis[chan]).ravel() 133 | 134 | # phase-shift the visibilities 135 | Vp = apply_phase_shift( 136 | uu.ravel(), vv.ravel(), V, geom["dRA"], geom["dDec"], inverse=True 137 | ) 138 | 139 | # deproject the (u,v) points 140 | up, vp, _ = deproject( 141 | uu.ravel(), vv.ravel(), geom["incl"], geom["Omega"] 142 | ) 143 | 144 | # if the source is optically thick, rescale the deprojected V(q) 145 | if rescale_flux: 146 | Vp.real /= np.cos(geom["incl"] * np.pi / 180) 147 | 148 | # deprojected baselines 149 | qq = np.hypot(up, vp) 150 | 151 | if bins is None: 152 | # choose sensible bin size and range 153 | step = np.hypot(fcube.coords.du, fcube.coords.dv) / 2 154 | bins = np.arange(0.0, max(qq), step) 155 | 156 | bin_counts, bin_edges = np.histogram(a=qq, bins=bins, weights=None) 157 | 158 | # cumulative binned visibility amplitude in each annulus 159 | Vs, _ = np.histogram(a=qq, bins=bins, weights=Vp) 160 | 161 | # mask empty bins 162 | mask = bin_counts == 0 163 | Vs = np.ma.masked_where(mask, Vs) 164 | 165 | # average binned visibility amplitude in each annulus 166 | Vs /= bin_counts 167 | 168 | bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2 169 | 170 | return bin_centers, Vs 171 | -------------------------------------------------------------------------------- /src/mpol/precomposed.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import torch 4 | 5 | from mpol import fourier, images 6 | from mpol.coordinates import GridCoords 7 | 8 | 9 | class GriddedNet(torch.nn.Module): 10 | r""" 11 | .. note:: 12 | 13 | This module is provided as a starting point. However, we recommend 14 | that you *don't get too comfortable using it* and instead write your own 15 | (custom) modules following PyTorch idioms, potentially 16 | using the source of this routine as a reference point. Using 17 | the torch module system directly is *much more powerful* and expressive. 18 | 19 | A basic but functional network for RML imaging. Designed to optimize a model image 20 | using the entirety of the dataset in a :class:`mpol.datasets.GriddedDataset` 21 | (i.e., gradient descent). For stochastic gradient descent (SGD), where the model 22 | sees only a fraction of the dataset with each iteration, we recommend defining 23 | your own module in your analysis code, following the 'Getting Started' guide. 24 | 25 | 26 | .. mermaid:: ../_static/mmd/src/GriddedNet.mmd 27 | 28 | Parameters 29 | ---------- 30 | coords : :class:`mpol.coordinates.GridCoords` 31 | nchan : int 32 | the number of channels in the base cube. Default = 1. 33 | base_cube : :class:`torch.Tensor` or ``None`` 34 | a pre-packed base cube to initialize the model with. If 35 | None, assumes ``torch.zeros``.
36 | 37 | 38 | After the object is initialized, instance variables can be accessed, for example 39 | 40 | :ivar bcube: the :class:`~mpol.images.BaseCube` instance 41 | :ivar icube: the :class:`~mpol.images.ImageCube` instance 42 | :ivar fcube: the :class:`~mpol.fourier.FourierCube` instance 43 | 44 | For example, you'll likely want to access the ``self.icube.sky_cube`` 45 | at some point. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | coords: GridCoords, 51 | nchan: int = 1, 52 | base_cube: typing.Optional[torch.Tensor] = None, 53 | ) -> None: 54 | super().__init__() 55 | 56 | self.coords = coords 57 | self.nchan = nchan 58 | 59 | self.bcube = images.BaseCube( 60 | coords=self.coords, nchan=self.nchan, base_cube=base_cube 61 | ) 62 | 63 | self.conv_layer = images.HannConvCube(nchan=self.nchan) 64 | 65 | self.icube = images.ImageCube(coords=self.coords, nchan=self.nchan) 66 | self.fcube = fourier.FourierCube(coords=self.coords) 67 | self.nufft = fourier.NuFFT(coords=self.coords, nchan=self.nchan) 68 | 69 | def forward(self) -> torch.Tensor: 70 | r""" 71 | Feed forward to calculate the model visibilities. In this step, a 72 | :class:`~mpol.images.BaseCube` is fed to a :class:`~mpol.images.HannConvCube`, 73 | which is fed to an :class:`~mpol.images.ImageCube`, which is fed to a 74 | :class:`~mpol.fourier.FourierCube` to produce the visibility cube. 75 | 76 | Returns 77 | ------- 78 | complex torch tensor of model visibilities (the dense visibility cube). 79 | """ 80 | x = self.bcube() 81 | x = self.conv_layer(x) 82 | x = self.icube(x) 83 | vis: torch.Tensor = self.fcube(x) 84 | return vis 85 | 86 | def predict_loose_visibilities( 87 | self, uu: torch.Tensor, vv: torch.Tensor 88 | ) -> torch.Tensor: 89 | r""" 90 | Use the :class:`mpol.fourier.NuFFT` to calculate loose model visibilities from 91 | the cube stored to ``self.icube.packed_cube``. 92 | 93 | Parameters 94 | ---------- 95 | uu : :class:`torch.Tensor` of :class:`torch.double` 96 | array of u spatial frequency coordinates, 97 | not including Hermitian pairs. Units of [:math:`\mathrm{k}\lambda`] 98 | vv : :class:`torch.Tensor` of :class:`torch.double` 99 | array of v spatial frequency coordinates, 100 | not including Hermitian pairs. Units of [:math:`\mathrm{k}\lambda`] 101 | 102 | Returns 103 | ------- 104 | :class:`torch.Tensor` of :class:`torch.complex128` 105 | model visibilities corresponding to ``uu`` and ``vv`` locations. 106 | """ 107 | 108 | model_vis: torch.Tensor = self.nufft(self.icube.packed_cube, uu, vv) 109 | return model_vis 110 | -------------------------------------------------------------------------------- /src/mpol/tests.mplstyle: -------------------------------------------------------------------------------- 1 | image.cmap: inferno 2 | figure.figsize: 7.1, 5.0 3 | figure.autolayout: True 4 | savefig.dpi: 300 -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | # Testing for MPoL 2 | 3 | Testing is carried out with `pytest`. Routines for testing the core `MPoL` functionality are included with this package. For more complicated workflows, additional tests may be implemented in outside packages. 4 | 5 | You can install the package dependencies for testing via 6 | 7 | $ pip install .[test] 8 | 9 | after you've cloned the repository and changed to the root of the repository (`MPoL`). This installs the extra packages required for testing (they are listed in `pyproject.toml`).
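Note that some shells (e.g., zsh) treat the square brackets as globbing characters, so you may need to quote the argument:

$ pip install '.[test]'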
10 | 11 | To run all of the tests, from the root of the repository, invoke 12 | 13 | $ python -m pytest 14 | 15 | ## Test cache 16 | 17 | Several of the tests require mock data that is not practical to package within the github repository itself. These files are stored on Zenodo, and for continuous integration tests (e.g., on github workflows), they are downloaded as needed using astropy cache utilities. 18 | 19 | 20 | ## Viewing plots 21 | 22 | Some tests produce temporary files, like plots, that could be useful to view for development or debugging. Normally these are produced to a temporary directory created by the system which will be cleaned up after the tests finish. To preserve them, first create a plot directory and then run the tests with this `basetemp` specified 23 | 24 | $ mkdir plotsdir 25 | $ python -m pytest --basetemp=plotsdir 26 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | from importlib.resources import files 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | import visread.process 7 | from astropy.utils.data import download_file 8 | from mpol import coordinates, fourier, gridding, images, utils 9 | from mpol.__init__ import zenodo_record 10 | 11 | import matplotlib.pyplot as plt 12 | plt.style.use("mpol.tests") 13 | 14 | # private variables to this module 15 | _npz_path = files("mpol.data").joinpath("mock_data.npz") 16 | _nchan = 4 17 | _cell_size = 0.005 18 | 19 | # all of these are fixed quantities that could take a while to load from the 20 | # archive, so we scope them as session 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def img2D_butterfly(): 25 | """Return the 2D source image of the butterfly, for use as a test image cube.""" 26 | archive = np.load(_npz_path) 27 | img = archive["img"] 28 | 29 | # assuming we're going to go with _cell_size, set the total flux of this image 30 | # total flux should be 0.253 Jy from MPoL-examples. 31 | 32 | return img 33 | 34 | @pytest.fixture(scope="session") 35 | def sky_cube(img2D_butterfly): 36 | """Create a sky tensor image cube from the butterfly.""" 37 | print("npix packed cube", img2D_butterfly.shape) 38 | # tile to some nchan, npix, npix 39 | sky_cube = torch.tile(torch.from_numpy(img2D_butterfly), (_nchan, 1, 1)) 40 | return sky_cube 41 | 42 | @pytest.fixture(scope="session") 43 | def packed_cube(img2D_butterfly): 44 | """Create a packed tensor image cube from the butterfly.""" 45 | print("npix packed cube", img2D_butterfly.shape) 46 | # tile to some nchan, npix, npix 47 | sky_cube = torch.tile(torch.from_numpy(img2D_butterfly), (_nchan, 1, 1)) 48 | # convert to packed format 49 | return utils.sky_cube_to_packed_cube(sky_cube) 50 | 51 | 52 | @pytest.fixture(scope="session") 53 | def baselines_m(): 54 | "Return the mock baselines (in meters) produced from the IM Lup DSHARP dataset." 
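# the same .npz archive also provides the 'weight' and 2D 'img' arrays used by the other session-scoped fixtures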
55 | archive = np.load(_npz_path) 56 | return archive["uu"], archive["vv"] 57 | 58 | 59 | @pytest.fixture(scope="session") 60 | def baselines_1D(baselines_m): 61 | uu, vv = baselines_m 62 | 63 | # lambda for now 64 | uu = visread.process.convert_baselines(uu, 230.0e9) 65 | vv = visread.process.convert_baselines(vv, 230.0e9) 66 | 67 | # convert to torch 68 | return torch.as_tensor(uu), torch.as_tensor(vv) 69 | 70 | 71 | @pytest.fixture(scope="session") 72 | def baselines_2D_np(baselines_m): 73 | uu, vv = baselines_m 74 | 75 | u_lam, v_lam = visread.process.broadcast_and_convert_baselines( 76 | uu, vv, np.linspace(230.0, 231.0, num=_nchan) * 1e9 77 | ) 78 | 79 | return u_lam, v_lam 80 | 81 | 82 | @pytest.fixture(scope="session") 83 | def baselines_2D_t(baselines_2D_np): 84 | uu, vv = baselines_2D_np 85 | return torch.as_tensor(uu), torch.as_tensor(vv) 86 | 87 | 88 | @pytest.fixture(scope="session") 89 | def weight_1D_np(): 90 | archive = np.load(_npz_path) 91 | return np.float64(archive["weight"]) 92 | 93 | 94 | @pytest.fixture(scope="session") 95 | def weight_2D_t(baselines_2D_t, weight_1D_np): 96 | weight1D_t = torch.as_tensor(weight_1D_np) 97 | uu, vv = baselines_2D_t 98 | weight = torch.broadcast_to(weight1D_t, uu.size()) 99 | return weight 100 | 101 | 102 | @pytest.fixture(scope="session") 103 | def coords(img2D_butterfly): 104 | npix, _ = img2D_butterfly.shape 105 | # note that this is now the same as the mock image we created 106 | return coordinates.GridCoords(cell_size=_cell_size, npix=npix) 107 | 108 | 109 | @pytest.fixture(scope="session") 110 | def mock_data_t(baselines_2D_t, packed_cube, coords, weight_2D_t): 111 | uu, vv = baselines_2D_t 112 | data, _ = fourier.generate_fake_data(packed_cube, coords, uu, vv, weight_2D_t) 113 | return data 114 | 115 | 116 | @pytest.fixture(scope="session") 117 | def mock_dataset_np(baselines_2D_np, weight_2D_t, mock_data_t): 118 | uu, vv = baselines_2D_np 119 | weight = utils.torch2npy(weight_2D_t) 120 | data = utils.torch2npy(mock_data_t) 121 | data_re = np.real(data) 122 | data_im = np.imag(data) 123 | 124 | return (uu, vv, weight, data_re, data_im) 125 | 126 | 127 | @pytest.fixture(scope="session") 128 | def dataset(mock_dataset_np, coords): 129 | uu, vv, weight, data_re, data_im = mock_dataset_np 130 | 131 | averager = gridding.DataAverager( 132 | coords=coords, 133 | uu=uu, 134 | vv=vv, 135 | weight=weight, 136 | data_re=data_re, 137 | data_im=data_im, 138 | ) 139 | 140 | return averager.to_pytorch_dataset() 141 | 142 | 143 | @pytest.fixture 144 | def dataset_cont(mock_dataset_np, coords): 145 | uu, vv, weight, data_re, data_im = mock_dataset_np 146 | # select only the 0th channel of each 147 | averager = gridding.DataAverager( 148 | coords=coords, 149 | uu=uu[0], 150 | vv=vv[0], 151 | weight=weight[0], 152 | data_re=data_re[0], 153 | data_im=data_im[0], 154 | ) 155 | 156 | return averager.to_pytorch_dataset() 157 | 158 | 159 | @pytest.fixture(scope="session") 160 | def mock_1d_archive(): 161 | # use astropy routines to cache data 162 | fname = download_file( 163 | f"https://zenodo.org/record/{zenodo_record}/files/mock_disk_1d.npz", 164 | cache=True, 165 | pkgname="mpol", 166 | ) 167 | 168 | return np.load(fname, allow_pickle=True) 169 | 170 | 171 | @pytest.fixture 172 | def mock_1d_image_model(mock_1d_archive): 173 | m = mock_1d_archive 174 | rtrue = m["rtrue"] 175 | itrue = m["itrue"] 176 | i2dtrue = m["i2dtrue"] 177 | xmax = ymax = m["xmax"] 178 | geom = m["geometry"] 179 | geom = geom[()] 180 | 181 | coords = coordinates.GridCoords( 182 
| cell_size=xmax * 2 / i2dtrue.shape[0], npix=i2dtrue.shape[0] 183 | ) 184 | 185 | # the center of the array is already at the center of the image --> 186 | # undo this as expected by input to ImageCube 187 | i2dtrue = np.flip(np.fft.fftshift(i2dtrue), 1) 188 | 189 | # pack the numpy image array into an ImageCube 190 | packed_cube = np.broadcast_to(i2dtrue, (1, coords.npix, coords.npix)).copy() 191 | packed_tensor = torch.from_numpy(packed_cube) 192 | bcube = images.BaseCube( 193 | coords=coords, nchan=1, base_cube=packed_tensor, pixel_mapping=lambda x: x 194 | ) 195 | cube_true = images.ImageCube(coords=coords, nchan=1) 196 | # register cube to buffer inside cube_true.cube 197 | cube_true(bcube()) 198 | 199 | return rtrue, itrue, cube_true, xmax, ymax, geom 200 | 201 | 202 | @pytest.fixture 203 | def mock_1d_vis_model(mock_1d_archive): 204 | m = mock_1d_archive 205 | i2dtrue = m["i2dtrue"] 206 | xmax = m["xmax"] 207 | geom = m["geometry"] 208 | geom = geom[()] 209 | 210 | Vtrue_dep = m["vis_dep"] 211 | q_dep = m["baselines_dep"] 212 | 213 | coords = coordinates.GridCoords( 214 | cell_size=xmax * 2 / i2dtrue.shape[0], npix=i2dtrue.shape[0] 215 | ) 216 | 217 | # the center of the array is already at the center of the image --> 218 | # undo this as expected by input to ImageCube 219 | i2dtrue = np.flip(np.fft.fftshift(i2dtrue), 1) 220 | 221 | # pack the numpy image array into an ImageCube 222 | packed_cube = np.broadcast_to(i2dtrue, (1, coords.npix, coords.npix)).copy() 223 | packed_tensor = torch.from_numpy(packed_cube) 224 | bcube = images.BaseCube( 225 | coords=coords, nchan=1, base_cube=packed_tensor, pixel_mapping=lambda x: x 226 | ) 227 | cube_true = images.ImageCube(coords=coords, nchan=1) 228 | 229 | # register image 230 | cube_true(bcube()) 231 | 232 | # create a FourierCube 233 | fcube_true = fourier.FourierCube(coords=coords) 234 | 235 | # take FT of icube to populate fcube 236 | fcube_true.forward(cube_true.sky_cube) 237 | 238 | # insert the vis tensor into the FourierCube ('vis' would typically be 239 | # populated by taking the FFT of an image) 240 | # packed_fcube = np.broadcast_to(Vtrue, (1, len(Vtrue))).copy() 241 | # packed_ftensor = torch.from_numpy(packed_cube) 242 | # fcube_true.ground_cube = packed_tensor 243 | 244 | return fcube_true, Vtrue_dep, q_dep, geom 245 | 246 | 247 | @pytest.fixture 248 | def generic_parameters(tmp_path): 249 | # generic model parameters to test training loop and cross-val loop 250 | regularizers = { 251 | "entropy": {"lambda": 1e-3, "guess": False, "prior_intensity": 1e-10}, 252 | } 253 | 254 | train_pars = { 255 | "epochs": 15, 256 | "convergence_tol": 1e-3, 257 | "regularizers": regularizers, 258 | "train_diag_step": None, 259 | "save_prefix": tmp_path, 260 | "verbose": True, 261 | } 262 | 263 | crossval_pars = train_pars.copy() 264 | crossval_pars["learn_rate"] = 0.5 265 | crossval_pars["kfolds"] = 2 266 | crossval_pars["split_method"] = "random_cell" 267 | crossval_pars["seed"] = 47 268 | crossval_pars["split_diag_fig"] = False 269 | crossval_pars["store_cv_diagnostics"] = True 270 | crossval_pars["device"] = None 271 | 272 | gen_pars = {"train_pars": train_pars, "crossval_pars": crossval_pars} 273 | 274 | return gen_pars 275 | -------------------------------------------------------------------------------- /test/coordinates_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pytest 4 | import torch 5 | from mpol import constants, 
coordinates 6 | from mpol.exceptions import CellSizeError 7 | 8 | 9 | def test_grid_coords_instantiate(): 10 | coordinates.GridCoords(cell_size=0.01, npix=512) 11 | 12 | 13 | def test_grid_coords_equal(): 14 | coords1 = coordinates.GridCoords(cell_size=0.01, npix=512) 15 | coords2 = coordinates.GridCoords(cell_size=0.01, npix=512) 16 | 17 | assert coords1 == coords2 18 | 19 | 20 | def test_grid_coords_unequal_pix(): 21 | coords1 = coordinates.GridCoords(cell_size=0.01, npix=510) 22 | coords2 = coordinates.GridCoords(cell_size=0.01, npix=512) 23 | 24 | assert coords1 != coords2 25 | 26 | 27 | def test_grid_coords_unequal_cell_size(): 28 | coords1 = coordinates.GridCoords(cell_size=0.011, npix=512) 29 | coords2 = coordinates.GridCoords(cell_size=0.01, npix=512) 30 | 31 | assert coords1 != coords2 32 | 33 | 34 | def test_grid_coords_plot_2D_uvq_sky(tmp_path): 35 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 36 | 37 | ikw = {"origin": "lower"} 38 | 39 | fig, ax = plt.subplots(nrows=1, ncols=3) 40 | im = ax[0].imshow(coords.ground_u_centers_2D, **ikw) 41 | plt.colorbar(im, ax=ax[0]) 42 | 43 | im = ax[1].imshow(coords.ground_v_centers_2D, **ikw) 44 | plt.colorbar(im, ax=ax[1]) 45 | 46 | im = ax[2].imshow(coords.ground_q_centers_2D, **ikw) 47 | plt.colorbar(im, ax=ax[2]) 48 | 49 | for a, t in zip(ax, ["u", "v", "q"], strict=False): 50 | a.set_title(t) 51 | 52 | fig.savefig(tmp_path / "sky_uvq.png", dpi=300) 53 | 54 | 55 | def test_grid_coords_plot_2D_uvq_packed(tmp_path): 56 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 57 | 58 | ikw = {"origin": "lower"} 59 | 60 | fig, ax = plt.subplots(nrows=1, ncols=3) 61 | im = ax[0].imshow(coords.packed_u_centers_2D, **ikw) 62 | plt.colorbar(im, ax=ax[0]) 63 | 64 | im = ax[1].imshow(coords.packed_v_centers_2D, **ikw) 65 | plt.colorbar(im, ax=ax[1]) 66 | 67 | im = ax[2].imshow(coords.packed_q_centers_2D, **ikw) 68 | plt.colorbar(im, ax=ax[2]) 69 | 70 | for a, t in zip(ax, ["u", "v", "q"], strict=False): 71 | a.set_title(t) 72 | 73 | fig.savefig(tmp_path / "packed_uvq.png", dpi=300) 74 | 75 | 76 | def test_grid_coords_odd_fail(): 77 | with pytest.raises( 78 | ValueError, match="Image must have a positive and even number of pixels." 
79 | ): 80 | coordinates.GridCoords(cell_size=0.01, npix=511) 81 | 82 | 83 | def test_grid_coords_neg_cell_size(): 84 | with pytest.raises(ValueError, match="cell_size must be a positive real number."): 85 | coordinates.GridCoords(cell_size=-0.01, npix=512) 86 | 87 | 88 | # instantiate a DataAverager object with mock visibilities 89 | def test_grid_coords_fit(baselines_2D_np, baselines_2D_t): 90 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 91 | 92 | uu, vv = baselines_2D_np 93 | coords.check_data_fit(uu, vv) 94 | 95 | uu, vv = baselines_2D_t 96 | coords.check_data_fit(uu, vv) 97 | 98 | 99 | def test_grid_coords_fail(baselines_2D_np, baselines_2D_t): 100 | coords = coordinates.GridCoords(cell_size=0.05, npix=800) 101 | 102 | uu, vv = baselines_2D_np 103 | print("max u data", np.max(uu)) 104 | print("max u grid", coords.max_uv_grid_value) 105 | with pytest.raises(CellSizeError): 106 | coords.check_data_fit(uu, vv) 107 | 108 | uu, vv = baselines_2D_t 109 | print("max u data", torch.max(uu)) 110 | print("max u grid", coords.max_uv_grid_value) 111 | with pytest.raises(CellSizeError): 112 | coords.check_data_fit(uu, vv) 113 | 114 | 115 | def test_tile_vs_meshgrid_implementation(): 116 | coords = coordinates.GridCoords(cell_size=0.05, npix=800) 117 | 118 | x_centers_2d, y_centers_2d = np.meshgrid( 119 | coords.l_centers / constants.arcsec, coords.m_centers / constants.arcsec, indexing="xy" 120 | ) 121 | 122 | ground_u_centers_2D, ground_v_centers_2D = np.meshgrid( 123 | coords.u_centers, coords.v_centers, indexing="xy" 124 | ) 125 | 126 | assert np.all(coords.ground_u_centers_2D == ground_u_centers_2D) 127 | assert np.all(coords.ground_v_centers_2D == ground_v_centers_2D) 128 | assert np.all(coords.x_centers_2D == x_centers_2d) 129 | assert np.all(coords.y_centers_2D == y_centers_2d) 130 | 131 | def test_coords_mock_image(coords, img2D_butterfly): 132 | npix, _ = img2D_butterfly.shape 133 | assert coords.npix == npix, "coords dimensions and mock image have different sizes" 134 | -------------------------------------------------------------------------------- /test/crossval_test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | # from mpol.crossval import CrossValidate, DartboardSplitGridded, RandomCellSplitGridded 7 | from mpol.crossval import DartboardSplitGridded, RandomCellSplitGridded 8 | from mpol.datasets import Dartboard 9 | 10 | # def test_crossvalclass_split_dartboard(coords, imager, dataset, generic_parameters): 11 | # # using the CrossValidate class, split a dataset into train/test subsets 12 | # # using 'dartboard' splitter 13 | 14 | # crossval_pars = generic_parameters["crossval_pars"] 15 | # crossval_pars["split_method"] = "dartboard" 16 | 17 | # cross_validator = CrossValidate(coords, imager, **crossval_pars) 18 | # cross_validator.split_dataset(dataset) 19 | 20 | 21 | # def test_crossvalclass_split_dartboard_1kfold( 22 | # coords, imager, dataset, generic_parameters 23 | # ): 24 | # # using the CrossValidate class, split a dataset into train/test subsets 25 | # # using 'dartboard' splitter with only 1 k-fold; check that the train set 26 | # # has ~80% of the model visibilities 27 | 28 | # crossval_pars = generic_parameters["crossval_pars"] 29 | # crossval_pars["split_method"] = "dartboard" 30 | # crossval_pars["kfolds"] = 1 31 | 32 | # cross_validator = CrossValidate(coords, imager, **crossval_pars) 33 | # split_iterator = 
cross_validator.split_dataset(dataset) 34 | 35 | # for train_set, test_set in split_iterator: 36 | # ntrain = len(train_set.vis_indexed) 37 | # ntest = len(test_set.vis_indexed) 38 | 39 | # ratio = ntrain / (ntrain + ntest) 40 | 41 | # np.testing.assert_allclose(ratio, 0.8, atol=0.05) 42 | 43 | 44 | # def test_crossvalclass_split_randomcell(coords, imager, dataset, generic_parameters): 45 | # # using the CrossValidate class, split a dataset into train/test subsets 46 | # # using 'random_cell' splitter 47 | 48 | # crossval_pars = generic_parameters["crossval_pars"] 49 | # cross_validator = CrossValidate(coords, imager, **crossval_pars) 50 | # cross_validator.split_dataset(dataset) 51 | 52 | 53 | # def test_crossvalclass_split_diagnostics_fig( 54 | # coords, imager, dataset, generic_parameters, tmp_path 55 | # ): 56 | # # using the CrossValidate class, split a dataset into train/test subsets 57 | # # using 'random_cell' splitter, then generate the split diagnostic figure 58 | 59 | # crossval_pars = generic_parameters["crossval_pars"] 60 | # cross_validator = CrossValidate(coords, imager, **crossval_pars) 61 | # split_iterator = cross_validator.split_dataset(dataset) 62 | # split_fig, split_axes = split_diagnostics_fig(split_iterator) 63 | # split_fig.savefig(tmp_path / "split_diagnostics_fig.png", dpi=300) 64 | # plt.close("all") 65 | 66 | 67 | # def test_crossvalclass_kfold(coords, imager, dataset, generic_parameters): 68 | # # using the CrossValidate class, perform k-fold cross-validation 69 | 70 | # crossval_pars = generic_parameters["crossval_pars"] 71 | # # reset some keys to bypass functionality tested elsewhere and speed up test 72 | # crossval_pars["regularizers"] = {} 73 | # crossval_pars["epochs"] = 11 74 | 75 | # cross_validator = CrossValidate(coords, imager, **crossval_pars) 76 | # cross_validator.run_crossval(dataset) 77 | 78 | 79 | def test_randomcellsplit(dataset, generic_parameters): 80 | pars = generic_parameters["crossval_pars"] 81 | RandomCellSplitGridded(dataset, pars["kfolds"], pars["seed"]) 82 | 83 | 84 | def test_dartboardsplit_init(coords, dataset): 85 | dartboard = Dartboard(coords=coords) 86 | 87 | # create cross validator through passing dartboard 88 | DartboardSplitGridded(dataset, 5, dartboard=dartboard) 89 | 90 | # create cross validator through implicit creation of dartboard 91 | DartboardSplitGridded(dataset, 5) 92 | 93 | 94 | def test_hermitian_mask_k(coords, dataset, tmp_path): 95 | dartboard = Dartboard(coords=coords) 96 | chan = 1 97 | 98 | # split these into k samples 99 | k = 5 100 | cv = DartboardSplitGridded(dataset, k, dartboard=dartboard) 101 | 102 | # get the split list indices 103 | indices_l0 = cv.k_split_cell_list[0] 104 | 105 | # create a new mask from this 106 | dartboard_mask = dartboard.build_grid_mask_from_cells(indices_l0) 107 | 108 | # use this mask to index the dataset 109 | masked_dataset = copy.deepcopy(dataset) 110 | masked_dataset.add_mask(dartboard_mask) 111 | 112 | # get updated q and phi values 113 | qs = masked_dataset.coords.packed_q_centers_2D[masked_dataset.mask[chan]] 114 | phis = masked_dataset.coords.packed_phi_centers_2D[masked_dataset.mask[chan]] 115 | 116 | ind = phis <= np.pi 117 | 118 | fig, ax = plt.subplots(nrows=1) 119 | 120 | ax.plot(qs[ind], phis[ind], "o", ms=3) 121 | ax.plot(qs[~ind], phis[~ind] - np.pi, "o", ms=1) 122 | fig.savefig(tmp_path / "hermitian.png", dpi=300) 123 | 124 | 125 | def test_dartboardsplit_iterate_masks(coords, dataset, tmp_path): 126 | dartboard = Dartboard(coords=coords) 127 | 128 | # 
create cross validator through passing dartboard 129 | k = 5 130 | chan = 1 131 | cv = DartboardSplitGridded(dataset, k, dartboard=dartboard) 132 | 133 | fig, ax = plt.subplots(nrows=k, ncols=2, figsize=(6, 12)) 134 | 135 | for k, (train, test) in enumerate(cv): 136 | ax[k, 0].imshow( 137 | np.fft.fftshift(train.mask[chan].detach().numpy()), 138 | interpolation="none", 139 | ) 140 | ax[k, 1].imshow( 141 | np.fft.fftshift(test.mask[chan].detach().numpy()), 142 | interpolation="none", 143 | ) 144 | 145 | ax[0, 0].set_title("train") 146 | ax[0, 1].set_title("test") 147 | fig.savefig(tmp_path / "masks", dpi=300) 148 | -------------------------------------------------------------------------------- /test/datasets_test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from mpol import datasets, fourier, images 8 | 9 | 10 | def test_index(coords, dataset): 11 | # test that we can index a dataset 12 | 13 | flayer = fourier.FourierCube(coords=coords) 14 | 15 | # create a mock cube that includes negative values 16 | nchan = dataset.nchan 17 | mean = torch.full( 18 | (nchan, coords.npix, coords.npix), fill_value=-0.5) 19 | std = torch.full( 20 | (nchan, coords.npix, coords.npix), fill_value=0.5) 21 | 22 | # tensor 23 | base_cube = torch.normal(mean=mean, std=std) 24 | 25 | # layer 26 | basecube = images.BaseCube(coords=coords, nchan=nchan, base_cube=base_cube) 27 | 28 | # try passing through ImageLayer 29 | imagecube = images.ImageCube(coords=coords, nchan=nchan) 30 | 31 | # produce dense model visibility cube 32 | modelVisibilityCube = flayer(imagecube(basecube())) 33 | 34 | # take a basecube, imagecube, and GriddedDataset and predict corresponding visibilities. 
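# (indexing the GriddedDataset with the dense cube returns model visibilities only at grid cells that contain data)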
35 | dataset(modelVisibilityCube) 36 | 37 | 38 | def test_loss_grad(coords, dataset): 39 | # test that we can calculate the gradients through the loss 40 | 41 | flayer = fourier.FourierCube(coords=coords) 42 | nchan = dataset.nchan 43 | basecube = images.BaseCube(coords=coords, nchan=nchan) 44 | imagecube = images.ImageCube(coords=coords, nchan=nchan) 45 | 46 | # produce model visibilities 47 | modelVisibilityCube = flayer(imagecube(basecube())) 48 | samples = dataset(modelVisibilityCube) 49 | 50 | print(samples) 51 | loss = torch.sum(torch.abs(samples)) 52 | 53 | # segfaults on 3.9 54 | # https://github.com/pytorch/pytorch/issues/50014 55 | loss.backward() 56 | 57 | print(basecube.base_cube.grad) 58 | 59 | 60 | def test_dataset_device(dataset): 61 | # if we have a GPU available, test that we can send a dataset to it 62 | 63 | if torch.cuda.is_available(): 64 | dataset = dataset.to("cuda") 65 | dataset = dataset.to("cpu") 66 | else: 67 | pass 68 | 69 | 70 | def test_mask_dataset(dataset): 71 | updated_mask = np.ones_like(dataset.coords.packed_u_centers_2D) 72 | dataset.add_mask(updated_mask) 73 | 74 | 75 | def test_dartboard_init(coords): 76 | datasets.Dartboard(coords=coords) 77 | 78 | 79 | def test_dartboard_histogram(coords, dataset, tmp_path): 80 | 81 | # use default bins 82 | dartboard = datasets.Dartboard(coords=coords) 83 | 84 | # 2D mask for any UV cells that contain visibilities 85 | # in *any* channel 86 | stacked_mask = np.any(dataset.mask.detach().numpy(), axis=0) 87 | 88 | # get qs, phis from dataset and turn into 1D lists 89 | qs = dataset.coords.packed_q_centers_2D[stacked_mask] 90 | phis = dataset.coords.packed_phi_centers_2D[stacked_mask] 91 | 92 | # use dartboard to calculate histogram 93 | H = dartboard.get_polar_histogram(qs, phis) 94 | 95 | fig, ax = plt.subplots(subplot_kw={"projection": "polar"}) 96 | 97 | cmap = copy.copy(matplotlib.colormaps["plasma"]) 98 | cmap.set_under("w") 99 | norm = matplotlib.colors.LogNorm(vmin=1) 100 | 101 | ax.grid(False) 102 | im = ax.pcolormesh( 103 | dartboard.phi_edges, 104 | dartboard.q_edges, 105 | H, 106 | shading="flat", 107 | norm=norm, 108 | cmap=cmap, 109 | zorder=-90, 110 | ) 111 | plt.colorbar(im, ax=ax) 112 | 113 | ax.scatter(phis, qs, s=1.5, rasterized=True, linewidths=0.0, c="k", alpha=0.3) 114 | ax.set_ylim(top=2500) 115 | 116 | fig.savefig(tmp_path / "dartboard.png", dpi=300) 117 | 118 | plt.close("all") 119 | 120 | 121 | def test_dartboard_nonzero(coords, dataset, tmp_path): 122 | 123 | # use default bins 124 | dartboard = datasets.Dartboard(coords=coords) 125 | 126 | # 2D mask for any UV cells that contain visibilities 127 | # in *any* channel 128 | stacked_mask = np.any(dataset.mask.detach().numpy(), axis=0) 129 | 130 | # get qs, phis from dataset and turn into 1D lists 131 | qs = dataset.coords.packed_q_centers_2D[stacked_mask] 132 | phis = dataset.coords.packed_phi_centers_2D[stacked_mask] 133 | 134 | # use dartboard to calculate nonzero cells 135 | indices = dartboard.get_nonzero_cell_indices(qs, phis) 136 | 137 | fig, ax = plt.subplots(nrows=1) 138 | 139 | ax.scatter(*indices.T, s=1.5, rasterized=True, linewidths=0.0, c="k") 140 | ax.set_xlabel("q index") 141 | ax.set_ylabel("phi index") 142 | 143 | fig.savefig(tmp_path / "indices.png", dpi=300) 144 | 145 | plt.close("all") 146 | 147 | 148 | def test_dartboard_mask(coords, dataset, tmp_path): 149 | # use default bins 150 | dartboard = datasets.Dartboard(coords=coords) 151 | 152 | # 2D mask for any UV cells that contain visibilities 153 | # in *any* channel 154 | 
stacked_mask = np.any(dataset.mask.detach().numpy(), axis=0) 155 | 156 | # get qs, phis from dataset and turn into 1D lists 157 | qs = dataset.coords.packed_q_centers_2D[stacked_mask] 158 | phis = dataset.coords.packed_phi_centers_2D[stacked_mask] 159 | 160 | # use dartboard to calculate nonzero cells 161 | indices = dartboard.get_nonzero_cell_indices(qs, phis) 162 | print(indices) 163 | 164 | # get boolean mask from cell indices 165 | mask = np.fft.fftshift(dartboard.build_grid_mask_from_cells(indices)) 166 | 167 | fig, ax = plt.subplots(nrows=1) 168 | 169 | ax.imshow(mask, origin="lower", interpolation="none") 170 | fig.savefig(tmp_path / "mask.png", dpi=300) 171 | 172 | plt.close("all") 173 | 174 | 175 | def test_hermitian_mask_full(coords, dataset, tmp_path): 176 | 177 | dartboard = datasets.Dartboard(coords=coords) 178 | 179 | chan = 1 180 | 181 | # do the indexing of individual points 182 | # plot up as function of q, phi, each point should have two dots (opacity layer). 183 | # make the 184 | 185 | # 2D mask for any UV cells that contain visibilities 186 | # in *any* channel 187 | mask = dataset.mask[chan].detach().numpy() 188 | 189 | # get qs, phis from dataset and turn into 1D lists 190 | qs = dataset.coords.packed_q_centers_2D[mask] 191 | phis = dataset.coords.packed_phi_centers_2D[mask] 192 | 193 | # use dartboard to calculate nonzero cells between 0 and pi 194 | indices = dartboard.get_nonzero_cell_indices(qs, phis) 195 | 196 | # create a new mask from this 197 | dartboard_mask = dartboard.build_grid_mask_from_cells(indices) 198 | 199 | # use this mask to index the dataset 200 | masked_dataset = copy.deepcopy(dataset) 201 | masked_dataset.add_mask(dartboard_mask) 202 | 203 | # get updated q and phi values 204 | qs = masked_dataset.coords.packed_q_centers_2D[masked_dataset.mask[chan]] 205 | phis = masked_dataset.coords.packed_phi_centers_2D[masked_dataset.mask[chan]] 206 | 207 | ind = phis <= np.pi 208 | 209 | fig, ax = plt.subplots(nrows=1) 210 | 211 | ax.plot(qs[ind], phis[ind], "o", ms=3) 212 | ax.plot(qs[~ind], phis[~ind] - np.pi, "o", ms=1) 213 | fig.savefig(tmp_path / "hermitian.png", dpi=300) 214 | -------------------------------------------------------------------------------- /test/fftshift_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def test_mpol_fftshift(tmp_path): 7 | 8 | # create a fake image 9 | xx, yy = np.mgrid[0:20, 0:20] 10 | image_init = xx + 2 * yy 11 | 12 | # initialize it as a torch matrix 13 | image_torch = torch.tensor(image_init) 14 | 15 | # try torch fftshift to see if we understand it correctly 16 | shifted_torch = torch.fft.fftshift(image_torch, dim=(1,)) 17 | shifted_numpy = np.fft.fftshift(image_init, axes=1) 18 | 19 | fig, ax = plt.subplots(ncols=3) 20 | 21 | # compare to actual fftshift and diff 22 | ax[0].imshow(shifted_numpy, origin="upper") 23 | ax[1].imshow(shifted_torch.detach().numpy(), origin="upper") 24 | ax[2].imshow(shifted_numpy - shifted_torch.detach().numpy(), origin="upper") 25 | fig.savefig(str(tmp_path / "fftshift.png")) 26 | 27 | assert np.allclose( 28 | shifted_numpy, shifted_torch.detach().numpy() 29 | ), "fftshifts do not match" 30 | -------------------------------------------------------------------------------- /test/geometry_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mpol import geometry 4 | from 
pytest import approx 5 | 6 | 7 | def test_rotate_points(): 8 | """ 9 | Test rotation from flat 2D frame to observer frame and back 10 | """ 11 | xs = torch.tensor([0.0, 1.0, 2.0]) 12 | ys = torch.tensor([1.0, -1.0, 2.0]) 13 | 14 | omega = 35. * np.pi/180 15 | incl = 30. * np.pi/180 16 | Omega = 210. * np.pi/180 17 | 18 | X, Y = geometry.flat_to_observer(xs, ys, omega=omega, incl=incl, Omega=Omega) 19 | 20 | xs_back, ys_back = geometry.observer_to_flat(X, Y, omega=omega, incl=incl, Omega=Omega) 21 | 22 | 23 | print("original", xs, ys) 24 | print("Observer", X, Y) 25 | print("return", xs_back, ys_back) 26 | 27 | assert xs == approx(xs_back, abs=1e-6) 28 | assert ys == approx(ys_back, abs=1e-6) 29 | 30 | 31 | def test_rotate_coords(coords): 32 | 33 | omega = 35. * np.pi/180 34 | incl = 30. * np.pi/180 35 | Omega = 210. * np.pi/180 36 | 37 | x, y = geometry.observer_to_flat(coords.sky_x_centers_2D, coords.sky_y_centers_2D, omega=omega, incl=incl, Omega=Omega) 38 | 39 | print(x, y) 40 | 41 | 42 | -------------------------------------------------------------------------------- /test/gridder_dataset_export_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from mpol import coordinates, gridding 4 | 5 | 6 | def test_cell_variance_error_pytorch(mock_dataset_np): 7 | """ 8 | Test that the gridder routine errors if we send it data that has the wrong scatter relative to the weight values. 9 | """ 10 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 11 | 12 | uu, vv, weight, data_re, data_im = mock_dataset_np 13 | sigma = np.sqrt(1 / weight) 14 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=2 * sigma, size=uu.shape) 15 | data_im = np.zeros_like(uu) + np.random.normal( 16 | loc=0, scale=2 * sigma, size=uu.shape 17 | ) 18 | 19 | averager = gridding.DataAverager( 20 | coords=coords, 21 | uu=uu, 22 | vv=vv, 23 | weight=weight, 24 | data_re=data_re, 25 | data_im=data_im, 26 | ) 27 | 28 | with pytest.raises(RuntimeError): 29 | averager.to_pytorch_dataset() 30 | -------------------------------------------------------------------------------- /test/gridder_gridding_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pytest 4 | from mpol import coordinates, gridding 5 | 6 | 7 | def test_average_cont(coords, mock_dataset_np): 8 | """ 9 | Test that the gridding operation doesn't error if provided 'continuum'-like 10 | quantities (single-channel). 11 | """ 12 | uu, vv, weight, data_re, data_im = mock_dataset_np 13 | 14 | chan = 0 15 | 16 | averager = gridding.DataAverager( 17 | coords=coords, 18 | uu=uu[chan], 19 | vv=vv[chan], 20 | weight=weight[chan], 21 | data_re=data_re[chan], 22 | data_im=data_im[chan], 23 | ) 24 | 25 | print(averager.uu.shape) 26 | print(averager.nchan) 27 | 28 | averager._grid_visibilities() 29 | 30 | 31 | # test that we're getting the right numbers back for some well defined operations 32 | def test_uniform_ones(mock_dataset_np, tmp_path): 33 | """ 34 | Test that we can grid average a set of visibilities that are just 1. 35 | We should get back entirely 1s. 
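With equal weights and identical unit visibilities, the weighted average in every occupied cell is exactly 1 + 0j, and empty cells remain 0.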
36 | """ 37 | 38 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 39 | 40 | uu, vv, weight, data_re, data_im = mock_dataset_np 41 | weight = 0.1 * np.ones_like(uu) 42 | data_re = np.ones_like(uu) 43 | data_im = np.zeros_like(uu) 44 | 45 | averager = gridding.DataAverager( 46 | coords=coords, 47 | uu=uu, 48 | vv=vv, 49 | weight=weight, 50 | data_re=data_re, 51 | data_im=data_im, 52 | ) 53 | 54 | # with uniform weighting, the gridded values should be == 1 55 | averager._grid_visibilities() 56 | 57 | im = plt.imshow( 58 | averager.ground_cube[1].real, 59 | origin="lower", 60 | extent=averager.coords.vis_ext, 61 | interpolation="none", 62 | ) 63 | plt.colorbar(im) 64 | plt.savefig(tmp_path / "gridded_re.png", dpi=300) 65 | 66 | plt.figure() 67 | 68 | im2 = plt.imshow( 69 | averager.ground_cube[0].imag, 70 | origin="lower", 71 | extent=averager.coords.vis_ext, 72 | interpolation="none", 73 | ) 74 | plt.colorbar(im2) 75 | plt.savefig(tmp_path / "gridded_im.png", dpi=300) 76 | 77 | plt.close("all") 78 | 79 | # if the gridding worked, 80 | # cells with no data should be 0 81 | assert averager.data_re_gridded[~averager.mask] == pytest.approx(0) 82 | 83 | # and cells with data should have real values approximately 1 84 | assert averager.data_re_gridded[averager.mask] == pytest.approx(1) 85 | 86 | # and imaginary values approximately 0 everywhere 87 | assert averager.data_im_gridded == pytest.approx(0) 88 | 89 | 90 | def test_weight_gridding(mock_dataset_np): 91 | uu, vv, weight, data_re, data_im = mock_dataset_np 92 | 93 | # initialize random (positive) weight values 94 | weight = np.random.uniform(low=0.01, high=0.1, size=uu.shape) 95 | data_re = np.ones_like(uu) 96 | data_im = np.ones_like(uu) 97 | 98 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 99 | averager = gridding.DataAverager( 100 | coords=coords, 101 | uu=uu, 102 | vv=vv, 103 | weight=weight, 104 | data_re=data_re, 105 | data_im=data_im, 106 | ) 107 | 108 | averager._grid_weights() 109 | 110 | print("sum of ungridded weights", np.sum(weight)) 111 | 112 | # test that the weights all sum to the same value 113 | print("sum of gridded weights", np.sum(averager.weight_gridded)) 114 | 115 | assert np.sum(weight) == pytest.approx(np.sum(averager.weight_gridded)) 116 | 117 | 118 | # test the standard deviation estimation routines 119 | def test_estimate_stddev(mock_dataset_np, tmp_path): 120 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 121 | 122 | uu, vv, weight, data_re, data_im = mock_dataset_np 123 | weight = 0.1 * np.ones_like(uu) 124 | sigma = np.sqrt(1 / weight) 125 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=sigma, size=uu.shape) 126 | data_im = np.zeros_like(uu) + np.random.normal(loc=0, scale=sigma, size=uu.shape) 127 | 128 | averager = gridding.DataAverager( 129 | coords=coords, 130 | uu=uu, 131 | vv=vv, 132 | weight=weight, 133 | data_re=data_re, 134 | data_im=data_im, 135 | ) 136 | 137 | s_re, s_im = averager._estimate_cell_standard_deviation() 138 | 139 | chan = 0 140 | 141 | fig, ax = plt.subplots(ncols=2, figsize=(7, 4)) 142 | 143 | im = ax[0].imshow(s_re[chan], origin="lower", extent=averager.coords.vis_ext) 144 | ax[0].set_title(r"$s_{i,j}$ real") 145 | plt.colorbar(im, ax=ax[0]) 146 | 147 | im = ax[1].imshow(s_im[chan], origin="lower", extent=averager.coords.vis_ext) 148 | ax[1].set_title(r"$s_{i,j}$ imag") 149 | plt.colorbar(im, ax=ax[1]) 150 | 151 | plt.savefig(tmp_path / "stddev_correct.png", dpi=300) 152 | 153 | plt.close("all") 154 | 155 | 156 | def 
test_estimate_stddev_large(mock_dataset_np, tmp_path): 157 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 158 | 159 | uu, vv, weight, data_re, data_im = mock_dataset_np 160 | weight = 0.1 * np.ones_like(uu) 161 | sigma = np.sqrt(1 / weight) 162 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=2 * sigma, size=uu.shape) 163 | data_im = np.zeros_like(uu) + np.random.normal( 164 | loc=0, scale=2 * sigma, size=uu.shape 165 | ) 166 | 167 | averager = gridding.DataAverager( 168 | coords=coords, 169 | uu=uu, 170 | vv=vv, 171 | weight=weight, 172 | data_re=data_re, 173 | data_im=data_im, 174 | ) 175 | 176 | s_re, s_im = averager._estimate_cell_standard_deviation() 177 | 178 | chan = 0 179 | 180 | fig, ax = plt.subplots(ncols=2, figsize=(7, 4)) 181 | 182 | im = ax[0].imshow(s_re[chan], origin="lower", extent=averager.coords.vis_ext) 183 | ax[0].set_title(r"$s_{i,j}$ real") 184 | plt.colorbar(im, ax=ax[0]) 185 | 186 | im = ax[1].imshow(s_im[chan], origin="lower", extent=averager.coords.vis_ext) 187 | ax[1].set_title(r"$s_{i,j}$ imag") 188 | plt.colorbar(im, ax=ax[1]) 189 | 190 | plt.savefig(tmp_path / "stddev_large.png", dpi=300) 191 | 192 | plt.close("all") 193 | 194 | 195 | def test_max_scatter_pass(mock_dataset_np): 196 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 197 | 198 | uu, vv, weight, data_re, data_im = mock_dataset_np 199 | weight = 0.1 * np.ones_like(uu) 200 | sigma = np.sqrt(1 / weight) 201 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=sigma, size=uu.shape) 202 | data_im = np.zeros_like(uu) + np.random.normal(loc=0, scale=sigma, size=uu.shape) 203 | 204 | averager = gridding.DataAverager( 205 | coords=coords, 206 | uu=uu, 207 | vv=vv, 208 | weight=weight, 209 | data_re=data_re, 210 | data_im=data_im, 211 | ) 212 | 213 | # we want this to return a status of False, indicating no error 214 | d = averager._check_scatter_error() 215 | print(d["median_re"], d["median_im"]) 216 | assert not d["return_status"] 217 | 218 | 219 | def test_max_scatter_fail(mock_dataset_np): 220 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 221 | 222 | uu, vv, weight, data_re, data_im = mock_dataset_np 223 | weight = 0.1 * np.ones_like(uu) 224 | sigma = np.sqrt(1 / weight) 225 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=2 * sigma, size=uu.shape) 226 | data_im = np.zeros_like(uu) + np.random.normal( 227 | loc=0, scale=2 * sigma, size=uu.shape 228 | ) 229 | 230 | averager = gridding.DataAverager( 231 | coords=coords, 232 | uu=uu, 233 | vv=vv, 234 | weight=weight, 235 | data_re=data_re, 236 | data_im=data_im, 237 | ) 238 | 239 | # we want this to return a status of True, indicating an error 240 | d = averager._check_scatter_error() 241 | print(d["median_re"], d["median_im"]) 242 | assert d["return_status"] 243 | -------------------------------------------------------------------------------- /test/gridder_imager_test.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pytest 7 | from mpol import coordinates, gridding 8 | 9 | 10 | # cache an instantiated imager for future imaging ops 11 | @pytest.fixture 12 | def imager(mock_dataset_np, coords): 13 | uu, vv, weight, data_re, data_im = mock_dataset_np 14 | 15 | return gridding.DirtyImager( 16 | coords=coords, 17 | uu=uu, 18 | vv=vv, 19 | weight=weight, 20 | data_re=data_re, 21 | data_im=data_im, 22 | ) 23 | 24 | 25 | # make sure the peak
of the PSF normalizes to 1 for each channel 26 | def test_beam_normalized(imager): 27 | r = -0.5 28 | for weighting in ["uniform", "natural", "briggs"]: 29 | if weighting == "briggs": 30 | imager._grid_visibilities(weighting=weighting, robust=r) 31 | else: 32 | imager._grid_visibilities(weighting=weighting) 33 | beam = imager._get_dirty_beam(imager.C, imager.re_gridded_beam) 34 | 35 | for i in range(imager.nchan): 36 | assert np.max(beam[i]) == pytest.approx(1.0) 37 | 38 | 39 | def test_beam_null(imager, tmp_path): 40 | r = -0.5 41 | imager._grid_visibilities(weighting="briggs", robust=r) 42 | beam = imager._get_dirty_beam(imager.C, imager.re_gridded_beam) 43 | nulled = imager._null_dirty_beam() 44 | 45 | chan = 0 46 | fig, ax = plt.subplots(ncols=2) 47 | 48 | cmap = copy.copy(matplotlib.colormaps["viridis"]) 49 | cmap.set_under("r") 50 | norm = matplotlib.colors.Normalize(vmin=0) 51 | 52 | im = ax[0].imshow( 53 | beam[chan], 54 | origin="lower", 55 | interpolation="none", 56 | extent=imager.coords.img_ext, 57 | cmap=cmap, 58 | norm=norm, 59 | ) 60 | plt.colorbar(im, ax=ax[0]) 61 | 62 | im = ax[1].imshow( 63 | nulled[chan] - 1e-6, 64 | origin="lower", 65 | interpolation="none", 66 | extent=imager.coords.img_ext, 67 | cmap=cmap, 68 | norm=norm, 69 | ) 70 | plt.colorbar(im, ax=ax[1]) 71 | 72 | fig.savefig(tmp_path / "beam_v_nulled.png", dpi=300) 73 | plt.close("all") 74 | 75 | 76 | def test_beam_null_full(imager, tmp_path): 77 | r = -0.5 78 | imager._grid_visibilities(weighting="briggs", robust=r) 79 | beam = imager._get_dirty_beam(imager.C, imager.re_gridded_beam) 80 | nulled = imager._null_dirty_beam(single_channel_estimate=False) 81 | 82 | chan = 0 83 | fig, ax = plt.subplots(ncols=2) 84 | 85 | cmap = copy.copy(matplotlib.colormaps["viridis"]) 86 | cmap.set_under("r") 87 | norm = matplotlib.colors.Normalize(vmin=0) 88 | 89 | im = ax[0].imshow( 90 | beam[chan], 91 | origin="lower", 92 | interpolation="none", 93 | extent=imager.coords.img_ext, 94 | cmap=cmap, 95 | norm=norm, 96 | ) 97 | plt.colorbar(im, ax=ax[0]) 98 | 99 | im = ax[1].imshow( 100 | nulled[chan] - 1e-6, 101 | origin="lower", 102 | interpolation="none", 103 | extent=imager.coords.img_ext, 104 | cmap=cmap, 105 | norm=norm, 106 | ) 107 | plt.colorbar(im, ax=ax[1]) 108 | 109 | fig.savefig(tmp_path / "beam_v_nulled.png", dpi=300) 110 | plt.close("all") 111 | 112 | 113 | def test_beam_area_before_beam(imager): 114 | r = -0.5 115 | imager._grid_visibilities(weighting="briggs", robust=r) 116 | area = imager.get_dirty_beam_area() 117 | print(area) 118 | 119 | 120 | # compare uniform and robust = -2.0 121 | def test_grid_uniform(imager, tmp_path): 122 | kw = {"origin": "lower", "interpolation": "none", "extent": imager.coords.img_ext} 123 | 124 | chan = 0 125 | 126 | img_uniform, beam_uniform = imager.get_dirty_image( 127 | weighting="uniform", check_visibility_scatter=False 128 | ) 129 | 130 | r = -2 131 | img_robust, beam_robust = imager.get_dirty_image( 132 | weighting="briggs", robust=r, check_visibility_scatter=False 133 | ) 134 | 135 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 4.5)) 136 | 137 | ax[0, 0].imshow(beam_uniform[chan], **kw) 138 | ax[0, 0].set_title("uniform") 139 | ax[1, 0].imshow(img_uniform[chan], **kw) 140 | 141 | ax[0, 1].imshow(beam_robust[chan], **kw) 142 | ax[0, 1].set_title(f"robust={r}") 143 | ax[1, 1].imshow(img_robust[chan], **kw) 144 | 145 | # the differences 146 | im = ax[0, 2].imshow(beam_uniform[chan] - beam_robust[chan], **kw) 147 | plt.colorbar(im, ax=ax[0, 2]) 148 | ax[0, 
2].set_title("difference") 149 | im = ax[1, 2].imshow(img_uniform[chan] - img_robust[chan], **kw) 150 | plt.colorbar(im, ax=ax[1, 2]) 151 | 152 | fig.subplots_adjust(left=0.05, right=0.95, wspace=0.02, bottom=0.07, top=0.94) 153 | 154 | fig.savefig(tmp_path / "uniform_v_robust.png", dpi=300) 155 | 156 | assert np.all(np.abs(beam_uniform - beam_robust) < 1e-4) 157 | assert np.all(np.abs(img_uniform - img_robust) < 1e-4) 158 | 159 | plt.close("all") 160 | 161 | 162 | # compare uniform and robust = -2.0 163 | def test_grid_uniform_arcsec2(imager, tmp_path): 164 | kw = {"origin": "lower", "interpolation": "none", "extent": imager.coords.img_ext} 165 | 166 | chan = 0 167 | img_uniform, beam_uniform = imager.get_dirty_image( 168 | weighting="uniform", unit="Jy/arcsec^2", check_visibility_scatter=False 169 | ) 170 | 171 | r = -2 172 | img_robust, beam_robust = imager.get_dirty_image( 173 | weighting="briggs", robust=r, unit="Jy/arcsec^2", check_visibility_scatter=False 174 | ) 175 | 176 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 4.5)) 177 | 178 | ax[0, 0].imshow(beam_uniform[chan], **kw) 179 | ax[0, 0].set_title("uniform") 180 | im = ax[1, 0].imshow(img_uniform[chan], **kw) 181 | plt.colorbar(im, ax=ax[1, 0]) 182 | 183 | ax[0, 1].imshow(beam_robust[chan], **kw) 184 | ax[0, 1].set_title(f"robust={r}") 185 | im = ax[1, 1].imshow(img_robust[chan], **kw) 186 | plt.colorbar(im, ax=ax[1, 1]) 187 | 188 | # the differences 189 | im = ax[0, 2].imshow(beam_uniform[chan] - beam_robust[chan], **kw) 190 | plt.colorbar(im, ax=ax[0, 2]) 191 | ax[0, 2].set_title("difference") 192 | im = ax[1, 2].imshow(img_uniform[chan] - img_robust[chan], **kw) 193 | plt.colorbar(im, ax=ax[1, 2]) 194 | 195 | fig.subplots_adjust(left=0.05, right=0.95, wspace=0.02, bottom=0.07, top=0.94) 196 | 197 | fig.savefig(tmp_path / "uniform_v_robust_arcsec2.png", dpi=300) 198 | 199 | assert np.all(np.abs(beam_uniform - beam_robust) < 1e-4) 200 | assert np.all(np.abs(img_uniform - img_robust) < 6e-3) 201 | 202 | plt.close("all") 203 | 204 | 205 | def test_grid_natural(imager, tmp_path): 206 | kw = {"origin": "lower", "interpolation": "none", "extent": imager.coords.img_ext} 207 | 208 | chan = 0 209 | 210 | img_natural, beam_natural = imager.get_dirty_image( 211 | weighting="natural", check_visibility_scatter=False 212 | ) 213 | 214 | r = 2 215 | img_robust, beam_robust = imager.get_dirty_image( 216 | weighting="briggs", robust=r, check_visibility_scatter=False 217 | ) 218 | 219 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 4.5)) 220 | 221 | ax[0, 0].imshow(beam_natural[chan], **kw) 222 | ax[0, 0].set_title("natural") 223 | ax[1, 0].imshow(img_natural[chan], **kw) 224 | 225 | ax[0, 1].imshow(beam_robust[chan], **kw) 226 | ax[0, 1].set_title(f"robust={r}") 227 | ax[1, 1].imshow(img_robust[chan], **kw) 228 | 229 | # the differences 230 | im = ax[0, 2].imshow(beam_natural[chan] - beam_robust[chan], **kw) 231 | plt.colorbar(im, ax=ax[0, 2]) 232 | ax[0, 2].set_title("difference") 233 | im = ax[1, 2].imshow(img_natural[chan] - img_robust[chan], **kw) 234 | plt.colorbar(im, ax=ax[1, 2]) 235 | 236 | fig.subplots_adjust(left=0.05, right=0.95, wspace=0.02, bottom=0.07, top=0.94) 237 | 238 | fig.savefig(tmp_path / "grid_natural_v_robust.png", dpi=300) 239 | 240 | assert np.all(np.abs(beam_natural - beam_robust) < 1.5e-3) 241 | assert np.all(np.abs(img_natural - img_robust) < 3e-5) 242 | 243 | plt.close("all") 244 | 245 | 246 | def test_grid_natural_arcsec2(imager, tmp_path): 247 | kw = {"origin": "lower", "interpolation": "none", 
"extent": imager.coords.img_ext} 248 | 249 | chan = 0 250 | 251 | img_natural, beam_natural = imager.get_dirty_image( 252 | weighting="natural", unit="Jy/arcsec^2", check_visibility_scatter=False 253 | ) 254 | 255 | r = 2 256 | img_robust, beam_robust = imager.get_dirty_image( 257 | weighting="briggs", robust=r, unit="Jy/arcsec^2", check_visibility_scatter=False 258 | ) 259 | 260 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(8, 4.5)) 261 | 262 | ax[0, 0].imshow(beam_natural[chan], **kw) 263 | ax[0, 0].set_title("natural") 264 | im = ax[1, 0].imshow(img_natural[chan], **kw) 265 | plt.colorbar(im, ax=ax[1, 0]) 266 | 267 | ax[0, 1].imshow(beam_robust[chan], **kw) 268 | ax[0, 1].set_title(f"robust={r}") 269 | im = ax[1, 1].imshow(img_robust[chan], **kw) 270 | plt.colorbar(im, ax=ax[1, 1]) 271 | 272 | # the differences 273 | im = ax[0, 2].imshow(beam_natural[chan] - beam_robust[chan], **kw) 274 | plt.colorbar(im, ax=ax[0, 2]) 275 | ax[0, 2].set_title("difference") 276 | im = ax[1, 2].imshow(img_natural[chan] - img_robust[chan], **kw) 277 | plt.colorbar(im, ax=ax[1, 2]) 278 | 279 | fig.subplots_adjust(left=0.05, right=0.95, wspace=0.02, bottom=0.07, top=0.94) 280 | 281 | fig.savefig(tmp_path / "natural_v_robust_arcsec2.png", dpi=300) 282 | 283 | assert np.all(np.abs(beam_natural - beam_robust) < 1.5e-3) 284 | assert np.all(np.abs(img_natural - img_robust) <2e-4) 285 | 286 | plt.close("all") 287 | 288 | 289 | def test_cell_variance_warning_image(mock_dataset_np): 290 | coords = coordinates.GridCoords(cell_size=0.01, npix=400) 291 | 292 | uu, vv, weight, data_re, data_im = mock_dataset_np 293 | sigma = np.sqrt(1 / weight) 294 | data_re = np.ones_like(uu) + np.random.normal(loc=0, scale=2 * sigma, size=uu.shape) 295 | data_im = np.zeros_like(uu) + np.random.normal( 296 | loc=0, scale=2 * sigma, size=uu.shape 297 | ) 298 | 299 | imager = gridding.DirtyImager( 300 | coords=coords, 301 | uu=uu, 302 | vv=vv, 303 | weight=weight, 304 | data_re=data_re, 305 | data_im=data_im, 306 | ) 307 | 308 | with pytest.warns(RuntimeWarning): 309 | imager.get_dirty_image(weighting="uniform") 310 | -------------------------------------------------------------------------------- /test/gridder_init_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from mpol import coordinates, gridding 4 | from mpol.exceptions import CellSizeError, DataError 5 | 6 | 7 | def test_hermitian_pairs(mock_dataset_np): 8 | # Test to see whether our routine checking whether Hermitian pairs 9 | # exist in the dataset works correctly in the False and True cases 10 | 11 | uu, vv, weight, data_re, data_im = mock_dataset_np 12 | 13 | # should *NOT* contain Hermitian pairs 14 | gridding.verify_no_hermitian_pairs(uu, vv, data_re + 1.0j * data_im) 15 | 16 | # expand the vectors to include complex conjugates 17 | uu = np.concatenate([uu, -uu], axis=1) 18 | vv = np.concatenate([vv, -vv], axis=1) 19 | data_re = np.concatenate([data_re, data_re], axis=1) 20 | data_im = np.concatenate([data_im, -data_im], axis=1) 21 | 22 | # should contain Hermitian pairs 23 | with pytest.raises( 24 | DataError, 25 | match="Hermitian pairs were found in the data. 
22 | # should contain Hermitian pairs
23 | with pytest.raises(
24 | DataError,
25 | match="Hermitian pairs were found in the data. Please provide data without Hermitian pairs.",
26 | ):
27 | gridding.verify_no_hermitian_pairs(uu, vv, data_re + 1.0j * data_im)
28 | 
29 | 
30 | def test_averager_instantiate_cell_npix(mock_dataset_np):
31 | uu, vv, weight, data_re, data_im = mock_dataset_np
32 | 
33 | coords = coordinates.GridCoords(
34 | cell_size=0.005,
35 | npix=800
36 | )
37 | gridding.DataAverager(coords=coords,
38 | uu=uu,
39 | vv=vv,
40 | weight=weight,
41 | data_re=data_re,
42 | data_im=data_im,
43 | )
44 | 
45 | 
46 | def test_averager_instantiate_gridCoord(mock_dataset_np):
47 | uu, vv, weight, data_re, data_im = mock_dataset_np
48 | 
49 | mycoords = coordinates.GridCoords(cell_size=0.005, npix=800)
50 | 
51 | gridding.DataAverager(
52 | coords=mycoords,
53 | uu=uu,
54 | vv=vv,
55 | weight=weight,
56 | data_re=data_re,
57 | data_im=data_im,
58 | )
59 | 
60 | 
61 | def test_averager_instantiate_bounds_fail(mock_dataset_np):
62 | uu, vv, weight, data_re, data_im = mock_dataset_np
63 | 
64 | mycoords = coordinates.GridCoords(cell_size=0.05, npix=800)
65 | 
66 | with pytest.raises(CellSizeError):
67 | gridding.DataAverager(
68 | coords=mycoords,
69 | uu=uu,
70 | vv=vv,
71 | weight=weight,
72 | data_re=data_re,
73 | data_im=data_im,
74 | )
75 | -------------------------------------------------------------------------------- /test/images_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import pytest
4 | import torch
5 | from astropy.io import fits
6 | from mpol import coordinates, images, plot, utils
7 | from plot_utils import imshow_two
8 | 
9 | 
10 | def test_BaseCube_map(coords, tmp_path):
11 | # create a mock cube that includes negative values
12 | nchan = 1
13 | mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
14 | std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)
15 | 
16 | bcube = torch.normal(mean=mean, std=std)
17 | blayer = images.BaseCube(coords=coords, nchan=nchan, base_cube=bcube)
18 | 
19 | # the default softplus function should map everything to positive values
20 | blayer_output = blayer()
21 | 
22 | imshow_two(
23 | tmp_path / "BaseCube_mapped.png",
24 | [bcube, blayer_output],
25 | title=["BaseCube input", "BaseCube output"],
26 | xlabel=["pixel"],
27 | ylabel=["pixel"],
28 | )
29 | 
30 | assert torch.all(blayer_output >= 0)
31 | 
32 | 
33 | def test_instantiate_ImageCube():
34 | coords = coordinates.GridCoords(cell_size=0.015, npix=800)
35 | im = images.ImageCube(coords=coords)
36 | assert im.nchan == 1
37 | 
38 | 
39 | def test_ImageCube_apply_grad(coords):
40 | bcube = images.BaseCube(coords=coords)
41 | imagecube = images.ImageCube(coords=coords)
42 | loss = torch.sum(imagecube(bcube()))
43 | loss.backward()
44 | 
45 | 
46 | def test_to_FITS_pixel_scale(coords, tmp_path):
47 | """Test whether the FITS scale was written correctly."""
48 | bcube = images.BaseCube(coords=coords)
49 | imagecube = images.ImageCube(coords=coords)
50 | imagecube(bcube())
51 | 
52 | # write FITS to file
53 | imagecube.to_FITS(fname=tmp_path / "test_cube_fits_file39.fits", overwrite=True)
54 | 
55 | # read file and check that both pixel scales are correct
56 | fits_header = fits.open(tmp_path / "test_cube_fits_file39.fits")[0].header
57 | assert fits_header["CDELT1"] == pytest.approx(coords.cell_size / 3600)
58 | assert fits_header["CDELT2"] == pytest.approx(coords.cell_size / 3600)
59 | 
60 | 
61 | 
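# [editor's note] a sketch of the unit conversion checked above: MPoL
# cell_size is in arcsec while the FITS CDELT keywords are in degrees, so
# (hypothetical numbers, not from the test fixtures):
#
#   cell_size = 0.005            # [arcsec]
#   cdelt = cell_size / 3600     # ~1.389e-6 [deg]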
62 | def test_HannConvCube(coords, tmp_path):
63 | # create a mock cube that includes negative values
64 | nchan = 1
65 | mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
66 | std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)
67 | 
68 | # The HannConvCube expects to operate on a pre-packed ImageCube, so pack it first
69 | test_cube = torch.normal(mean=mean, std=std)
70 | test_cube_packed = utils.sky_cube_to_packed_cube(test_cube)
71 | 
72 | conv_layer = images.HannConvCube(nchan=nchan)
73 | 
74 | conv_output_packed = conv_layer(test_cube_packed)
75 | conv_output = utils.packed_cube_to_sky_cube(conv_output_packed)
76 | 
77 | imshow_two(
78 | tmp_path / "convcube.png",
79 | [test_cube, conv_output],
80 | title=["input", "convolved"],
81 | xlabel=["pixel"],
82 | ylabel=["pixel"],
83 | )
84 | 
85 | 
86 | def test_HannConvCube_multi_chan(coords):
87 | """Make sure HannConvCube functions with multi-channeled input"""
88 | nchan = 10
89 | mean = torch.full((nchan, coords.npix, coords.npix), fill_value=-0.5)
90 | std = torch.full((nchan, coords.npix, coords.npix), fill_value=0.5)
91 | 
92 | test_cube = torch.normal(mean=mean, std=std)
93 | 
94 | conv_layer = images.HannConvCube(nchan=nchan)
95 | conv_layer(test_cube)
96 | 
97 | 
98 | def test_flux(coords):
99 | """Make sure we can read the flux attribute."""
100 | nchan = 20
101 | bcube = images.BaseCube(coords=coords, nchan=nchan)
102 | im = images.ImageCube(coords=coords, nchan=nchan)
103 | im(bcube())
104 | assert im.flux.size()[0] == nchan
105 | 
106 | 
107 | def test_plot_test_img(packed_cube, coords, tmp_path):
108 | # show only the first channel
109 | chan = 0
110 | fig, ax = plt.subplots(nrows=1)
111 | 
112 | # put back to sky
113 | sky_cube = utils.packed_cube_to_sky_cube(packed_cube)
114 | im = ax.imshow(sky_cube[chan], extent=coords.img_ext, origin="lower")
115 | plt.colorbar(im)
116 | fig.savefig(tmp_path / "sky_cube.png")
117 | 
118 | plt.close("all")
119 | 
120 | 
121 | def test_uv_gaussian_taper(coords, tmp_path):
122 | for r in np.arange(0.0, 0.2, step=0.04):
123 | fig, ax = plt.subplots(ncols=1)
124 | 
125 | taper_2D = images.uv_gaussian_taper(coords, r, r, 0.0)
126 | print(type(taper_2D))
127 | 
128 | norm = plot.get_image_cmap_norm(taper_2D, symmetric=True)
129 | im = ax.imshow(
130 | taper_2D,
131 | extent=coords.vis_ext_Mlam,
132 | origin="lower",
133 | cmap="bwr_r",
134 | norm=norm,
135 | )
136 | plt.colorbar(im, ax=ax)
137 | 
138 | fig.savefig(tmp_path / f"taper{r:.2f}.png")
139 | 
140 | plt.close("all")
141 | 
142 | 
143 | def test_GaussConvImage_kernel(coords, tmp_path):
144 | rs = np.array([0.02, 0.06, 0.10])
145 | nchan = 3
146 | fig, ax = plt.subplots(nrows=len(rs), ncols=nchan, figsize=(10, 10))
147 | for i, r in enumerate(rs):
148 | layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r)
149 | weight = layer.m.weight.detach().numpy()
150 | for j in range(nchan):
151 | im = ax[i, j].imshow(weight[j, 0], interpolation="none", origin="lower")
152 | plt.colorbar(im, ax=ax[i, j])
153 | 
154 | fig.savefig(tmp_path / "filter.png")
155 | plt.close("all")
156 | 
157 | 
158 | def test_GaussConvImage_kernel_rotate(coords, tmp_path):
159 | r = 0.04
160 | Omegas = [0, 20, 40] # degrees
161 | nchan = 3
162 | fig, ax = plt.subplots(nrows=len(Omegas), ncols=nchan, figsize=(10, 10))
163 | for i, Omega in enumerate(Omegas):
164 | layer = images.GaussConvImage(
165 | coords, nchan=nchan, FWHM_maj=r, FWHM_min=0.5 * r, Omega=Omega
166 | )
167 | weight = layer.m.weight.detach().numpy()
168 | for j in range(nchan):
169 | im = ax[i, j].imshow(weight[j, 0], interpolation="none", origin="lower")
170 | plt.colorbar(im, ax=ax[i, j])
171 | 
172 | fig.savefig(tmp_path / "filter.png")
173 | plt.close("all") 174 | 175 | 176 | @pytest.mark.parametrize("FWHM", [0.02, 0.06, 0.1]) 177 | def test_GaussConvImage(sky_cube, coords, tmp_path, FWHM): 178 | chan = 0 179 | nchan = sky_cube.size()[0] 180 | 181 | layer = images.GaussConvImage(coords, nchan=nchan, FWHM_maj=FWHM, FWHM_min=FWHM) 182 | c_sky = layer(sky_cube) 183 | 184 | imgs = [sky_cube[chan], c_sky[chan]] 185 | fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] 186 | title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] 187 | 188 | imshow_two( 189 | tmp_path / f"convolved_{FWHM:.2f}.png", 190 | imgs, 191 | sky=True, 192 | suptitle=f"Image Plane Gauss Convolution FWHM={FWHM}", 193 | title=title, 194 | extent=[coords.img_ext], 195 | ) 196 | 197 | assert pytest.approx(fluxes[0]) == fluxes[1] 198 | 199 | 200 | @pytest.mark.parametrize("Omega", [0, 15, 30, 45]) 201 | def test_GaussConvImage_rotate(sky_cube, coords, tmp_path, Omega): 202 | chan = 0 203 | nchan = sky_cube.size()[0] 204 | 205 | FWHM_maj = 0.10 206 | FWHM_min = 0.05 207 | 208 | layer = images.GaussConvImage( 209 | coords, nchan=nchan, FWHM_maj=FWHM_maj, FWHM_min=FWHM_min, Omega=Omega 210 | ) 211 | c_sky = layer(sky_cube) 212 | 213 | imgs = [sky_cube[chan], c_sky[chan]] 214 | fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] 215 | title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] 216 | 217 | imshow_two( 218 | tmp_path / f"convolved_{Omega:.0f}_deg.png", 219 | imgs, 220 | sky=True, 221 | suptitle=r"Image Plane Gauss Convolution: $\Omega$=" 222 | + f'{Omega}, {FWHM_maj}", {FWHM_min}"', 223 | title=title, 224 | extent=[coords.img_ext], 225 | ) 226 | 227 | assert pytest.approx(fluxes[0], abs=4e-7) == fluxes[1] 228 | 229 | 230 | @pytest.mark.parametrize("FWHM", [0.02, 0.1, 0.2, 0.3, 0.5]) 231 | def test_GaussConvFourier(packed_cube, coords, tmp_path, FWHM): 232 | chan = 0 233 | sky_cube = utils.packed_cube_to_sky_cube(packed_cube) 234 | 235 | layer = images.GaussConvFourier(coords, FWHM, FWHM) 236 | c = layer(packed_cube) 237 | c_sky = utils.packed_cube_to_sky_cube(c) 238 | 239 | imgs = [sky_cube[chan], c_sky[chan]] 240 | fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] 241 | title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] 242 | 243 | imshow_two( 244 | tmp_path / "convolved_FWHM_{:.2f}.png".format(FWHM), 245 | imgs, 246 | sky=True, 247 | suptitle=f"Fourier Plane Gauss Convolution: FWHM={FWHM}", 248 | title=title, 249 | extent=[coords.img_ext], 250 | ) 251 | 252 | assert pytest.approx(fluxes[0], abs=4e-7) == fluxes[1] 253 | 254 | 255 | @pytest.mark.parametrize("Omega", [0, 15, 30, 45]) 256 | def test_GaussConvFourier_rotate(packed_cube, coords, tmp_path, Omega): 257 | chan = 0 258 | sky_cube = utils.packed_cube_to_sky_cube(packed_cube) 259 | 260 | FWHM_maj = 0.10 261 | FWHM_min = 0.05 262 | layer = images.GaussConvFourier( 263 | coords, FWHM_maj=FWHM_maj, FWHM_min=FWHM_min, Omega=Omega 264 | ) 265 | 266 | c = layer(packed_cube) 267 | c_sky = utils.packed_cube_to_sky_cube(c) 268 | 269 | imgs = [sky_cube[chan], c_sky[chan]] 270 | fluxes = [coords.cell_size**2 * torch.sum(img).item() for img in imgs] 271 | title = [f"tot flux: {flux:.3f} Jy" for flux in fluxes] 272 | 273 | imshow_two( 274 | tmp_path / f"convolved_{Omega:.0f}_deg.png", 275 | imgs, 276 | sky=True, 277 | suptitle=r"Fourier Plane Gauss Convolution: $\Omega$=" 278 | + f'{Omega}, {FWHM_maj}", {FWHM_min}"', 279 | title=title, 280 | extent=[coords.img_ext], 281 | ) 282 | 283 | assert pytest.approx(fluxes[0], abs=4e-7) == fluxes[1] 284 
| 285 | 
286 | def test_GaussConvFourier_point(coords, tmp_path):
287 | FWHM = 0.5
288 | 
289 | # create an image with a point source in the center
290 | sky_cube = torch.zeros((1, coords.npix, coords.npix))
291 | cpix = coords.npix // 2
292 | sky_cube[0, cpix, cpix] = 1.0
293 | 
294 | fig, ax = plt.subplots(ncols=2, sharex=True, sharey=True)
295 | # put back to sky
296 | im = ax[0].imshow(sky_cube[0], extent=coords.img_ext, origin="lower")
297 | flux = coords.cell_size**2 * torch.sum(sky_cube[0])
298 | ax[0].set_title(f"tot flux: {flux:.3f} Jy")
299 | plt.colorbar(im, ax=ax[0])
300 | 
301 | # set base resolution
302 | layer = images.GaussConvFourier(coords, FWHM, FWHM)
303 | packed_cube = utils.sky_cube_to_packed_cube(sky_cube)
304 | c = layer(packed_cube)
305 | # put back to sky
306 | c_sky = utils.packed_cube_to_sky_cube(c)
307 | flux = coords.cell_size**2 * torch.sum(c_sky[0])
308 | im = ax[1].imshow(
309 | c_sky[0].detach().numpy(),
310 | extent=coords.img_ext,
311 | origin="lower",
312 | cmap="inferno",
313 | )
314 | ax[1].set_title(f"tot flux: {flux:.3f} Jy")
315 | r = 0.7
316 | ax[1].set_xlim(r, -r)
317 | ax[1].set_ylim(-r, r)
318 | 
319 | plt.colorbar(im, ax=ax[1])
320 | fig.savefig(tmp_path / "point_source_FWHM_{:.2f}.png".format(FWHM))
321 | 
322 | plt.close("all")
323 | -------------------------------------------------------------------------------- /test/input_output_test.py: -------------------------------------------------------------------------------- 1 | 
2 | # from astropy.utils.data import download_file
3 | # from mpol.input_output import ProcessFitsImage
4 | 
5 | 
6 | # def test_ProcessFitsImage():
7 | # # get a .fits file produced with casa
8 | # fname = download_file(
9 | # "https://zenodo.org/record/4711811/files/logo_cube.tclean.fits",
10 | # cache=True,
11 | # show_progress=True,
12 | # pkgname="mpol",
13 | # )
14 | 
15 | # fits_image = ProcessFitsImage(fname)
16 | # clean_im, clean_im_ext, clean_beam = fits_image.get_image(beam=True)
17 | -------------------------------------------------------------------------------- /test/losses_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np
2 | import pytest
3 | import torch
4 | from mpol import coordinates, fourier, losses
5 | 
6 | 
7 | @pytest.fixture
8 | def loose_vis_model(weight_2D_t):
9 | # just random noise is fine for these structural tests
10 | mean = torch.zeros_like(weight_2D_t)
11 | sigma = torch.sqrt(1 / weight_2D_t)
12 | model_re = torch.normal(mean, sigma)
13 | model_im = torch.normal(mean, sigma)
14 | return torch.complex(model_re, model_im)
15 | 
16 | 
17 | @pytest.fixture
18 | def gridded_vis_model(coords, packed_cube):
19 | nchan, npix, _ = packed_cube.size()
20 | coords = coordinates.GridCoords(npix=npix, cell_size=0.005)
21 | 
22 | # use the FourierCube to produce model visibilities
23 | flayer = fourier.FourierCube(coords=coords)
24 | return flayer(packed_cube)
25 | 
26 | 
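# [editor's note] for orientation, a sketch of the quantity exercised below;
# this is the standard definition, not the mpol.losses internals. For data D_i,
# model M_i, and weights w_i = 1/sigma_i^2,
#
#   chi^2 = sum_i w_i * |D_i - M_i|^2
#
# or, with torch tensors `data`, `model`, and `weight`:
#
#   chi2 = torch.sum(weight * torch.abs(data - model) ** 2)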
27 | def test_chi_squared_evaluation(
28 | loose_vis_model, mock_data_t, weight_2D_t
29 | ):
30 | # Because of the way likelihood functions are defined, we would not expect
31 | # the loose chi_squared or log_likelihood function to give the same answers as
32 | # the gridded chi_squared or log_likelihood functions. This is because the normalization
33 | # of a likelihood function is somewhat ill-defined, since its value can change
34 | # based on whether you bin your data. The normalization only really makes
35 | # sense in a full Bayesian setting, where the evidence is computed and the normalization cancels out anyway.
36 | # The important thing is that the *shape* of the likelihood function stays the same as parameters
37 | # are varied.
38 | # more info:
39 | # https://stats.stackexchange.com/questions/97515/what-does-likelihood-is-only-defined-up-to-a-multiplicative-constant-of-proport?noredirect=1&lq=1
40 | 
41 | # calculate the ungridded chi^2
42 | chi_squared = losses._chi_squared(loose_vis_model, mock_data_t, weight_2D_t)
43 | print("loose chi_squared", chi_squared)
44 | 
45 | 
46 | def test_log_likelihood_loose(loose_vis_model, mock_data_t, weight_2D_t):
47 | # calculate the ungridded log likelihood
48 | losses.log_likelihood(loose_vis_model, mock_data_t, weight_2D_t)
49 | 
50 | def test_log_likelihood_gridded(gridded_vis_model, dataset):
51 | losses.log_likelihood_gridded(gridded_vis_model, dataset)
52 | 
53 | 
54 | def test_rchi_evaluation(
55 | loose_vis_model, mock_data_t, weight_2D_t, gridded_vis_model, dataset
56 | ):
57 | # The ungridded and gridded values are not as close as one might expect;
58 | # this likely comes down to how the noise is averaged within each visibility cell,
59 | # and to the definition of the degrees of freedom; see
60 | # https://arxiv.org/abs/1012.3754
61 | 
62 | # calculate the ungridded reduced chi^2
63 | r_chi2 = losses.r_chi_squared(loose_vis_model, mock_data_t, weight_2D_t)
64 | print("loose r_chi^2", r_chi2)
65 | 
66 | # calculate the gridded reduced chi^2
67 | print(gridded_vis_model.size(), dataset.mask.size())
68 | r_chi2_gridded = losses.r_chi_squared_gridded(gridded_vis_model, dataset)
69 | print("gridded r_chi^2", r_chi2_gridded)
70 | 
71 | 
72 | def test_r_chi_1D_zero():
73 | # make identical fake pytorch arrays for data and model
74 | # and assert that the reduced chi^2 loss returns 0
75 | 
76 | N = 10000
77 | weights = torch.ones((N), dtype=torch.float64)
78 | 
79 | model_re = torch.randn_like(weights)
80 | model_im = torch.randn_like(weights)
81 | model_vis = torch.complex(model_re, model_im)
82 | 
83 | data_re = model_re
84 | data_im = model_im
85 | data_vis = torch.complex(data_re, data_im)
86 | 
87 | loss = losses.r_chi_squared(model_vis, data_vis, weights)
88 | assert loss.item() == 0.0
89 | 
90 | 
91 | def test_r_chi_1D_random():
92 | # make fake pytorch arrays that are random
93 | # and then check that the reduced chi^2 evaluates
94 | 
95 | N = 10000
96 | weights = torch.ones((N), dtype=torch.float64)
97 | 
98 | model_re = torch.randn_like(weights)
99 | model_im = torch.randn_like(weights)
100 | model_vis = torch.complex(model_re, model_im)
101 | 
102 | data_re = torch.randn_like(weights)
103 | data_im = torch.randn_like(weights)
104 | data_vis = torch.complex(data_re, data_im)
105 | 
106 | losses.r_chi_squared(model_vis, data_vis, weights)
107 | 
108 | 
109 | def test_r_chi_2D_zero():
110 | # sometimes things come in as a (nchan, nvis) tensor
111 | # make identical fake pytorch arrays in this size,
112 | # and assert that the loss evaluates to zero
113 | 
114 | nchan = 50
115 | nvis = 10000
116 | weights = torch.ones((nchan, nvis), dtype=torch.float64)
117 | 
118 | model_re = torch.randn_like(weights)
119 | model_im = torch.randn_like(weights)
120 | model_vis = torch.complex(model_re, model_im)
121 | 
122 | data_re = model_re
123 | data_im = model_im
124 | data_vis = torch.complex(data_re, data_im)
125 | 
126 | loss = losses.r_chi_squared(model_vis, data_vis, weights)
127 | assert loss.item() == 0.0
128 | 
129 | 
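# [editor's note] a sketch of why identical model and data give zero above:
# a reduced chi^2 divides the chi^2 by the number of real data points, e.g.
#
#   r_chi2 = chi2 / (2 * N)   # each complex visibility -> 2 real numbers
#
# (illustrative convention only; see mpol.losses.r_chi_squared for the
# normalization actually used). chi^2 = 0 when model == data, so r_chi2 == 0.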
130 | def test_r_chi_2D_random():
131 | # sometimes things come in as a (nchan, nvis) tensor
132 | # make random fake pytorch arrays and make sure we can evaluate the function
133 | 
134 | nchan = 50
135 | nvis = 10000
136 | weights = torch.ones((nchan, nvis), dtype=torch.float64)
137 | 
138 | model_re = torch.randn_like(weights)
139 | model_im = torch.randn_like(weights)
140 | model_vis = torch.complex(model_re, model_im)
141 | 
142 | data_re = torch.randn_like(weights)
143 | data_im = torch.randn_like(weights)
144 | data_vis = torch.complex(data_re, data_im)
145 | 
146 | losses.r_chi_squared(model_vis, data_vis, weights)
147 | 
148 | 
149 | def test_loss_scaling():
150 | for N in np.logspace(4, 5, num=10):
151 | # create fake model, resid, and weight
152 | N = int(N)
153 | 
154 | mean = torch.zeros(N)
155 | std = 0.2 * torch.ones(N)
156 | weight = 1 / std**2
157 | 
158 | model_real = torch.ones(N)
159 | model_imag = torch.zeros(N)
160 | model = torch.complex(model_real, model_imag)
161 | 
162 | noise_real = torch.normal(mean, std)
163 | noise_imag = torch.normal(mean, std)
164 | noise = torch.complex(noise_real, noise_imag)
165 | 
166 | data = model + noise
167 | 
168 | nlla = losses.neg_log_likelihood_avg(model, data, weight)
169 | print("N", N, "nlla", nlla)
170 | 
171 | 
172 | def test_entropy_raise_error_negative():
173 | nchan = 50
174 | npix = 512
175 | with pytest.raises(AssertionError):
176 | cube = torch.randn((nchan, npix, npix), dtype=torch.float64)
177 | losses.entropy(cube, 0.01)
178 | 
179 | 
180 | def test_entropy_raise_error_negative_prior():
181 | nchan = 50
182 | npix = 512
183 | with pytest.raises(AssertionError):
184 | cube = torch.ones((nchan, npix, npix), dtype=torch.float64)
185 | losses.entropy(cube, -0.01)
186 | 
187 | 
188 | def test_entropy_cube():
189 | # make a cube that should evaluate within the entropy loss function
190 | 
191 | nchan = 50
192 | npix = 512
193 | 
194 | cube = torch.ones((nchan, npix, npix), dtype=torch.float64)
195 | losses.entropy(cube, 0.01)
196 | 
197 | 
198 | def test_tsv():
199 | # Here we test the accuracy of the losses.TSV() routine relative to what is
200 | # written in equations. Since for-loops in Python are typically slow, it is
201 | # unreasonable to use that format in the TSV() function, so a vectorized form is used.
202 | # Here we test that this vectorized form is calculated correctly and results in the
203 | # same value as would come from the for-loop.
204 | 
205 | # setting the size of our image
206 | npix = 3
207 | 
208 | # creating the test cube
209 | cube = torch.rand((1, npix, npix))
210 | 
211 | # finding the value that our TSV function returns
212 | tsv_val = losses.TSV(cube)
213 | 
214 | for_val = 0
215 | # calculating the TSV loss through a for loop
216 | for i in range(npix - 1):
217 | for j in range(npix - 1):
218 | for_val += (cube[:, i + 1, j] - cube[:, i, j]) ** 2 + (
219 | cube[:, i, j + 1] - cube[:, i, j]
220 | ) ** 2
221 | # asserting that these two values calculated above are equivalent
222 | assert tsv_val == for_val
223 | 
224 | 
225 | def test_tv_image():
226 | # Here we test losses.TV_image(). Since for-loops in Python are typically slow, it is
227 | # unreasonable to use that format in the TV_image() function, so a vectorized form is used.
228 | # Here we test that this vectorized form is calculated correctly and results in the same
229 | # value as would come from the for-loop.
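# [editor's note] for reference, with epsilon = 0 the quantity being checked is
#
#   TV(I) = sum_{i,j} sqrt( (I[i+1, j] - I[i, j])**2 + (I[i, j+1] - I[i, j])**2 )
#
# and a vectorized equivalent of the for-loop below is (sketch, not the
# mpol.losses internals):
#
#   torch.sum(torch.sqrt((cube[:, 1:, :-1] - cube[:, :-1, :-1]) ** 2
#                        + (cube[:, :-1, 1:] - cube[:, :-1, :-1]) ** 2))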
230 | 231 | # setting the size of our image 232 | npix = 3 233 | 234 | # creating the test cube 235 | cube = torch.rand((1, npix, npix)) 236 | 237 | # finding the value that our TV_image function returns, we set epsilon=0 for a simpler for-loop 238 | tsv_val = losses.TV_image(cube, epsilon=0) 239 | for_val = 0 240 | 241 | # calculating the TV_image loss through a for loop 242 | for i in range(npix - 1): 243 | for j in range(npix - 1): 244 | for_val += torch.sqrt( 245 | (cube[:, i + 1, j] - cube[:, i, j]) ** 2 246 | + (cube[:, i, j + 1] - cube[:, i, j]) ** 2 247 | ) 248 | # asserting that these two values calculated above are equivalent 249 | assert tsv_val == for_val 250 | -------------------------------------------------------------------------------- /test/onedim_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from mpol.onedim import radialI, radialV 4 | from mpol.plot import plot_image 5 | from mpol.utils import torch2npy 6 | 7 | 8 | def test_radialI(mock_1d_image_model, tmp_path): 9 | # obtain a 1d radial brightness profile I(r) from an image 10 | 11 | rtrue, itrue, icube, _, _, geom = mock_1d_image_model 12 | 13 | bins = np.linspace(0, 2.0, 100) 14 | 15 | rtest, itest = radialI(icube, geom, bins=bins) 16 | 17 | fig, ax = plt.subplots(ncols=2, figsize=(10,5)) 18 | 19 | plot_image(np.squeeze(torch2npy(icube.sky_cube)), extent=icube.coords.img_ext, 20 | ax=ax[0], clab="Jy / sr") 21 | 22 | ax[1].plot(rtrue, itrue, "k", label="truth") 23 | ax[1].plot(rtest, itest, "r.-", label="recovery") 24 | 25 | ax[0].set_title(f"Geometry:\n{geom}", fontsize=7) 26 | 27 | ax[1].set_xlabel("r [arcsec]") 28 | ax[1].set_ylabel("I [Jy / sr]") 29 | ax[1].legend() 30 | 31 | fig.savefig(tmp_path / "test_radialI.png", dpi=300) 32 | plt.close("all") 33 | 34 | expected = [ 35 | 6.40747314e+10, 4.01920507e+10, 1.44803534e+10, 2.94238627e+09, 36 | 1.28782935e+10, 2.68613199e+10, 2.26564596e+10, 1.81151845e+10, 37 | 1.52128965e+10, 1.05640352e+10, 1.33411204e+10, 1.61124502e+10, 38 | 1.41500539e+10, 1.20121195e+10, 1.11770326e+10, 1.19676913e+10, 39 | 1.20941686e+10, 1.09498286e+10, 9.74236410e+09, 7.99589196e+09, 40 | 5.94787809e+09, 3.82074946e+09, 1.80823933e+09, 4.48414819e+08, 41 | 3.17808840e+08, 5.77317876e+08, 3.98851281e+08, 8.06459834e+08, 42 | 2.88706161e+09, 6.09577814e+09, 6.98556762e+09, 4.47436415e+09, 43 | 1.89511273e+09, 5.96604356e+08, 3.44571640e+08, 5.65906765e+08, 44 | 2.85854589e+08, 2.67589013e+08, 3.98357054e+08, 2.97052261e+08, 45 | 3.82744591e+08, 3.52239791e+08, 2.74336969e+08, 2.28425747e+08, 46 | 1.82290043e+08, 3.16077299e+08, 1.18465538e+09, 3.32239287e+09, 47 | 5.26718846e+09, 5.16458748e+09, 3.58114198e+09, 2.13431954e+09, 48 | 1.40936556e+09, 1.04032244e+09, 9.24050422e+08, 8.46829316e+08, 49 | 6.80909295e+08, 6.83812465e+08, 6.91856237e+08, 5.29227136e+08, 50 | 3.97557293e+08, 3.54893419e+08, 2.60997039e+08, 2.09306498e+08, 51 | 1.93930693e+08, 6.97032407e+07, 6.66090083e+07, 1.40079594e+08, 52 | 7.21775931e+07, 3.23902663e+07, 3.35932300e+07, 7.63318789e+06, 53 | 1.29740981e+07, 1.44300351e+07, 8.06249624e+06, 5.85567843e+06, 54 | 1.42637174e+06, 3.21445075e+06, 1.83763663e+06, 1.16926652e+07, 55 | 2.46918188e+07, 1.60206523e+07, 3.26596592e+06, 1.27837319e+05, 56 | 2.27104612e+04, 4.77267063e+03, 2.90467640e+03, 2.88482230e+03, 57 | 1.43402521e+03, 1.54791996e+03, 7.23397046e+02, 1.02561351e+03, 58 | 5.24845888e+02, 1.47320552e+03, 7.40419174e+02, 4.59029378e-03, 59 | 
0.00000000e+00, 0.00000000e+00, 0.00000000e+00 60 | ] 61 | 62 | np.testing.assert_allclose(itest, expected, rtol=1e-6, 63 | err_msg="test_radialI") 64 | 65 | 66 | def test_radialV(mock_1d_vis_model, tmp_path): 67 | # obtain a 1d radial visibility profile V(q) from 2d visibilities 68 | 69 | fcube, Vtrue_dep, q_dep, geom = mock_1d_vis_model 70 | 71 | bins = np.linspace(1,5e3,100) 72 | 73 | qtest, Vtest = radialV(fcube, geom, rescale_flux=True, bins=bins) 74 | 75 | fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(10,10)) 76 | 77 | ax[0].plot(q_dep / 1e6, Vtrue_dep.real, "k.", label="truth deprojected") 78 | ax[0].plot(qtest / 1e3, Vtest.real, "r.-", label="recovery") 79 | 80 | ax[1].plot(q_dep / 1e6, Vtrue_dep.imag, "k.") 81 | ax[1].plot(qtest / 1e3, Vtest.imag, "r.") 82 | 83 | ax[0].set_xlim(-0.5, 6) 84 | ax[1].set_xlim(-0.5, 6) 85 | ax[1].set_xlabel(r"Baseline [M$\lambda$]") 86 | ax[0].set_ylabel("Re(V) [Jy]") 87 | ax[1].set_ylabel("Im(V) [Jy]") 88 | ax[0].set_title(f"Geometry {geom}", fontsize=10) 89 | ax[0].legend() 90 | 91 | fig.savefig(tmp_path / "test_radialV.png", dpi=300) 92 | plt.close("all") 93 | 94 | expected = [ 95 | -9.61751019e+09, 2.75229026e+09, -4.36137738e+08, -2.30171445e+07, 96 | -2.10099938e+08, 2.86360366e+08, -1.37544187e+07, -3.62764471e+07, 97 | 1.94332782e+07, -4.63579878e+07, 4.38157379e+07, -1.19891002e+07, 98 | 2.47285137e+07, -3.43389203e+07, 7.49974578e+05, 3.68423107e+06, 99 | 9.43443498e+06, -1.16182426e+07, 1.08867793e+07, -8.74943322e+06, 100 | 1.14521810e+07, -6.36361380e+06, 3.58538842e+05, -5.96714707e+06, 101 | 1.04348614e+07, -1.47220982e+06, -1.19522309e+07, -4.09776593e+06, 102 | 7.86540505e+06, 3.60337006e+06, -8.30025685e+06, 4.05093017e+06, 103 | 3.33292357e+06, 2.05733741e+06, -7.65245396e+06, 3.73332165e+06, 104 | 3.40645897e+06, -4.58494946e+06, 3.66101584e+06, -3.69627118e+06, 105 | 5.27955178e+06, 9.75812262e+06, -1.65425072e+07, 5.47225658e+06, 106 | -3.49680316e+06, 8.22030443e+06, -7.32448474e+06, -4.23843848e+06, 107 | 1.27346507e+07, -4.60792496e+06, -2.56148856e+06, 6.29770245e+05, 108 | -2.25521550e+06, 5.35018477e+06, -4.61334469e+06, 3.09166148e+06, 109 | -9.18155255e+05, -1.00736465e+06, 1.12177040e+06, -9.21570359e+05, 110 | 8.70817075e+05, 3.16472432e+04, -1.59681139e+06, 1.16213263e+06, 111 | 3.64004059e+04, -8.49130119e+04, -2.30599556e+05, -1.59965392e+03, 112 | 9.30837779e+05, -3.90387012e+05, -4.75338516e+05, 7.53183050e+04, 113 | 3.41897054e+05, -7.53936979e+05, 1.99039974e+06, -1.90488504e+06, 114 | -4.19283666e+05, 1.53004765e+06, -8.55774990e+05, 6.21661335e+05, 115 | -5.00689314e+05, -6.26249184e+05, 1.94062725e+06, -1.65756778e+06, 116 | 1.03046094e+06, -7.77307547e+05, 1.65177536e+05, 1.07726803e+05, 117 | -2.46681205e+05, 1.18707317e+05, 7.05201176e+04, 1.39152470e+05, 118 | -2.80631868e+04, 4.92257795e+05, -1.52044894e+06, 1.02459630e+06, 119 | 8.42494484e+05, -1.57362080e+06, 7.22603120e+05 120 | ] 121 | 122 | np.testing.assert_allclose(Vtest.real, expected, rtol=1e-6, 123 | err_msg="test_radialV") 124 | -------------------------------------------------------------------------------- /test/plot_test.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # from mpol.plot import image_comparison_fig 4 | 5 | # def test_image_comparison_fig(coords, tmp_path): 6 | # # generate an image comparison figure 7 | 8 | # model = precomposed.GriddedNet(coords=coords, nchan=1) 9 | # model() 10 | 11 | # # just interested in whether the tested functionality runs 12 | # u = v = 
np.repeat(1e3, 1000)
13 | # V = weights = np.ones_like(u)
14 | 
15 | # # .fits file to act as clean image
16 | # fname = download_file(
17 | # "https://zenodo.org/record/4711811/files/logo_cube.tclean.fits",
18 | # cache=True,
19 | # show_progress=True,
20 | # pkgname="mpol",
21 | # )
22 | 
23 | # image_comparison_fig(model, u, v, V, weights, robust=0.5,
24 | # clean_fits=fname,
25 | # share_cscale=False,
26 | # xzoom=[-2, 2], yzoom=[-2, 2],
27 | # title="test",
28 | # save_prefix=None,
29 | # )
30 | -------------------------------------------------------------------------------- /test/plot_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import numpy as np
5 | 
6 | 
7 | def extend_list(l, num=2):
8 | """
9 | Duplicate or truncate a list so that it has num items.
10 | 
11 | l: list
12 | the list of items to potentially duplicate or truncate.
13 | num: int
14 | the final length of the list
15 | 
16 | Returns
17 | -------
18 | list
19 | Length num list of items.
20 | 
21 | Examples
22 | --------
23 | >>> extend_list(["L Plot", "R Plot"])
24 | ["L Plot", "R Plot"]
25 | >>> extend_list(["Plot"]) # both L and R will have "Plot"
26 | ["Plot", "Plot"]
27 | >>> extend_list(["L Plot", "R Plot", "Z Plot"]) # "Z Plot" is ignored
28 | ["L Plot", "R Plot"]
29 | """
30 | if len(l) == 1:
31 | return num * l
32 | else:
33 | return l[:num]
34 | 
35 | def extend_kwargs(kwargs):
36 | """
37 | This is a helper routine for imshow_two, designed to flexibly consume a variety
38 | of options for each of the two plots.
39 | 
40 | kwargs: dict
41 | the kwargs dict provided from the function call
42 | 
43 | Returns
44 | -------
45 | None
46 | kwargs is updated in place so that each value is a length-2 list of items.
47 | """
48 | 
49 | for key, item in kwargs.items():
50 | kwargs[key] = extend_list(item)
51 | 
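# [editor's note] a usage sketch for the two helpers above (hypothetical values):
#
#   kwargs = {"title": ["Plot"], "xlabel": ["pixel", "pixel"]}
#   extend_kwargs(kwargs)
#   kwargs  # -> {"title": ["Plot", "Plot"], "xlabel": ["pixel", "pixel"]}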
52 | def imshow_two(path, imgs, sky=False, suptitle=None, **kwargs):
53 | """Plot two images side by side, with colorbars.
54 | 
55 | 
56 | Parameters
57 | ----------
58 | path : string
59 | path and filename to save figure
60 | imgs : list
61 | length-2 list of images to plot. Arguments are designed to be very permissive. If the image is a PyTorch tensor, the routine converts it to numpy, and then numpy.squeeze is called.
62 | sky: bool
63 | If True, treat images as sky plots and label with offset arcseconds.
64 | title: list
65 | if provided, list of strings corresponding to the title for each subplot. If only one is provided, it is used for both subplots.
66 | xlabel: list
67 | if provided, list of strings for the x-axis label of each subplot; a ylabel list is handled the same way.
68 | 
69 | 
70 | Returns
71 | -------
72 | None
73 | """
74 | 
75 | xx = 7.5 # in
76 | rmargin = 0.8
77 | lmargin = 0.8
78 | tmargin = 0.3 if suptitle is None else 0.5
79 | bmargin = 0.5
80 | middle_sep = 1.3
81 | ax_width = (xx - rmargin - lmargin - middle_sep) / 2
82 | ax_height = ax_width
83 | cax_width = 0.1
84 | cax_sep = 0.15
85 | cax_height = ax_height
86 | yy = bmargin + ax_height + tmargin
87 | 
88 | with mpl.rc_context({'figure.autolayout': False}):
89 | fig = plt.figure(figsize=(xx, yy))
90 | 
91 | ax = []
92 | cax = []
93 | 
94 | extend_kwargs(kwargs)
95 | 
96 | if "extent" not in kwargs:
97 | kwargs["extent"] = [None, None]
98 | 
99 | for i in [0, 1]:
100 | a = fig.add_axes(
101 | [
102 | (lmargin + i * (ax_width + middle_sep)) / xx,
103 | bmargin / yy,
104 | ax_width / xx,
105 | ax_height / yy,
106 | ]
107 | )
108 | ax.append(a)
109 | 
110 | ca = fig.add_axes(
111 | (
112 | [
113 | (lmargin + (i + 1) * ax_width + i * middle_sep + cax_sep) / xx,
114 | bmargin / yy,
115 | cax_width / xx,
116 | cax_height / yy,
117 | ]
118 | )
119 | )
120 | cax.append(ca)
121 | 
122 | img = imgs[i]
123 | img = img.detach().numpy() if torch.is_tensor(img) else img
124 | 
125 | im = a.imshow(np.squeeze(img), origin="lower", interpolation="none", extent=kwargs["extent"][i])
126 | plt.colorbar(im, cax=ca)
127 | 
128 | if "title" in kwargs:
129 | a.set_title(kwargs["title"][i])
130 | 
131 | if sky:
132 | a.set_xlabel(r"$\Delta \alpha\ \cos \delta\;[{}^{\prime\prime}]$")
133 | a.set_ylabel(r"$\Delta \delta\;[{}^{\prime\prime}]$")
134 | else:
135 | if "xlabel" in kwargs:
136 | a.set_xlabel(kwargs["xlabel"][i])
137 | 
138 | if "ylabel" in kwargs:
139 | a.set_ylabel(kwargs["ylabel"][i])
140 | 
141 | if suptitle is not None:
142 | fig.suptitle(suptitle)
143 | fig.savefig(path)
144 | plt.close("all")
145 | -------------------------------------------------------------------------------- /test/train_test_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import torch.optim
5 | from mpol import losses, precomposed
6 | 
7 | # from mpol.plot import train_diagnostics_fig
8 | # from mpol.training import TrainTest, train_to_dirty_image
9 | from mpol.utils import torch2npy
10 | from torch.utils.tensorboard import SummaryWriter
11 | 
12 | # def test_traintestclass_training(coords, imager, dataset, generic_parameters):
13 | # # using the TrainTest class, run a training loop without regularizers
14 | # nchan = dataset.nchan
15 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan)
16 | 
17 | # train_pars = generic_parameters["train_pars"]
18 | 
19 | # # no regularizers
20 | # train_pars["regularizers"] = {}
21 | 
22 | # learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
23 | 
24 | # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
25 | 
26 | # trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars)
27 | # loss, loss_history = trainer.train(model, dataset)
28 | 
29 | 
30 | # def test_traintestclass_training_scheduler(coords, imager, dataset, generic_parameters):
31 | # # using the TrainTest class, run a training loop with regularizers,
32 | # # using the learning rate scheduler
33 | # nchan = dataset.nchan
34 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan)
35 | 
36 | # train_pars = generic_parameters["train_pars"]
37 | 
38 | # learn_rate = generic_parameters["crossval_pars"]["learn_rate"]
39
| 40 | # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) 41 | 42 | # # use a scheduler 43 | # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.995) 44 | # train_pars["scheduler"] = scheduler 45 | 46 | # trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars) 47 | # loss, loss_history = trainer.train(model, dataset) 48 | 49 | 50 | # def test_traintestclass_training_guess(coords, imager, dataset, generic_parameters): 51 | # # using the TrainTest class, run a training loop with regularizers, 52 | # # with a call to the regularizer strength guesser 53 | # nchan = dataset.nchan 54 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan) 55 | 56 | # train_pars = generic_parameters["train_pars"] 57 | 58 | # learn_rate = generic_parameters["crossval_pars"]["learn_rate"] 59 | 60 | # train_pars['regularizers']['entropy']['guess'] = True 61 | 62 | # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) 63 | 64 | # trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars) 65 | # loss, loss_history = trainer.train(model, dataset) 66 | 67 | 68 | # def test_traintestclass_train_diagnostics_fig(coords, imager, dataset, generic_parameters, tmp_path): 69 | # # using the TrainTest class, run a training loop, 70 | # # and generate the train diagnostics figure 71 | # nchan = dataset.nchan 72 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan) 73 | 74 | # train_pars = generic_parameters["train_pars"] 75 | # # bypass TrainTest.loss_lambda_guess 76 | # train_pars["regularizers"] = {} 77 | 78 | # learn_rate = generic_parameters["crossval_pars"]["learn_rate"] 79 | 80 | # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) 81 | 82 | # trainer = TrainTest(imager=imager, optimizer=optimizer, **train_pars) 83 | # loss, loss_history = trainer.train(model, dataset) 84 | 85 | # learn_rates = np.repeat(learn_rate, len(loss_history)) 86 | 87 | # old_mod_im = torch2npy(model.icube.sky_cube[0]) 88 | 89 | # train_fig, train_axes = train_diagnostics_fig(model, 90 | # losses=loss_history, 91 | # learn_rates=learn_rates, 92 | # fluxes=np.zeros(len(loss_history)), 93 | # old_model_image=old_mod_im 94 | # ) 95 | # train_fig.savefig(tmp_path / "train_diagnostics_fig.png", dpi=300) 96 | # plt.close("all") 97 | 98 | 99 | # def test_traintestclass_testing(coords, imager, dataset, generic_parameters): 100 | # # using the TrainTest class, perform a call to test 101 | # nchan = dataset.nchan 102 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan) 103 | 104 | # learn_rate = generic_parameters["crossval_pars"]["learn_rate"] 105 | 106 | # optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate) 107 | 108 | # trainer = TrainTest(imager=imager, optimizer=optimizer) 109 | # trainer.test(model, dataset) 110 | 111 | 112 | def test_standalone_init_train(coords, dataset): 113 | # not using TrainTest class, 114 | # configure a class to train with and test that it initializes 115 | 116 | nchan = dataset.nchan 117 | rml = precomposed.GriddedNet(coords=coords, nchan=nchan) 118 | 119 | vis = rml() 120 | 121 | rml.zero_grad() 122 | 123 | # calculate a loss 124 | loss = losses.r_chi_squared_gridded(vis, dataset) 125 | 126 | # calculate gradients of parameters 127 | loss.backward() 128 | 129 | print(rml.bcube.base_cube.grad) 130 | 131 | 132 | def test_standalone_train_loop(coords, dataset_cont, tmp_path): 133 | # not using TrainTest class, 134 | # set everything up to run on a single channel 135 | # and run a few 
iterations 136 | 137 | nchan = 1 138 | rml = precomposed.GriddedNet(coords=coords, nchan=nchan) 139 | 140 | optimizer = torch.optim.SGD(rml.parameters(), lr=0.001) 141 | 142 | for i in range(50): 143 | rml.zero_grad() 144 | 145 | # get the predicted model 146 | vis = rml() 147 | 148 | # calculate a loss 149 | loss = losses.r_chi_squared_gridded(vis, dataset_cont) 150 | 151 | # calculate gradients of parameters 152 | loss.backward() 153 | 154 | # update the model parameters 155 | optimizer.step() 156 | 157 | # let's see what one channel of the image looks like 158 | fig, ax = plt.subplots(nrows=1) 159 | ax.imshow( 160 | np.squeeze(torch2npy(rml.icube.packed_cube)), 161 | origin="lower", 162 | interpolation="none", 163 | extent=rml.icube.coords.img_ext, 164 | ) 165 | fig.savefig(tmp_path / "trained.png", dpi=300) 166 | plt.close("all") 167 | 168 | 169 | # def test_train_to_dirty_image(coords, dataset, imager): 170 | # # run a training loop against a dirty image 171 | # nchan = dataset.nchan 172 | # model = precomposed.GriddedNet(coords=coords, nchan=nchan) 173 | 174 | # train_to_dirty_image(model, imager, niter=10) 175 | 176 | 177 | def test_tensorboard(coords, dataset_cont): 178 | # not using TrainTest class, 179 | # set everything up to run on a single channel and then 180 | # test the writer function 181 | 182 | nchan = 1 183 | rml = precomposed.GriddedNet(coords=coords, nchan=nchan) 184 | 185 | optimizer = torch.optim.SGD(rml.parameters(), lr=0.001) 186 | 187 | writer = SummaryWriter() 188 | 189 | for i in range(50): 190 | rml.zero_grad() 191 | 192 | # get the predicted model 193 | vis = rml() 194 | 195 | # calculate a loss 196 | loss = losses.r_chi_squared_gridded(vis, dataset_cont) 197 | 198 | writer.add_scalar("loss", loss.item(), i) 199 | 200 | # calculate gradients of parameters 201 | loss.backward() 202 | 203 | # update the model parameters 204 | optimizer.step() 205 | -------------------------------------------------------------------------------- /test/utils_test.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pytest 4 | from mpol import coordinates, utils 5 | 6 | 7 | @pytest.fixture 8 | def imagekw(): 9 | return { 10 | "a": 1, 11 | "delta_x": 0.3, 12 | "delta_y": 0.1, 13 | "sigma_x": 0.3, 14 | "sigma_y": 0.1, 15 | "Omega": 20, 16 | } 17 | 18 | 19 | def test_sky_gaussian(imagekw, tmp_path): 20 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 21 | 22 | ikw = {"origin": "lower"} 23 | 24 | g = utils.sky_gaussian_arcsec( 25 | coords.sky_x_centers_2D, coords.sky_y_centers_2D, **imagekw 26 | ) 27 | 28 | fig, ax = plt.subplots(nrows=1, ncols=1) 29 | im = ax.imshow(g, **ikw, extent=coords.img_ext) 30 | plt.colorbar(im, ax=ax) 31 | fig.savefig(tmp_path / "sky_gauss_2D.png", dpi=300) 32 | 33 | 34 | def test_packed_gaussian(imagekw, tmp_path): 35 | coords = coordinates.GridCoords(cell_size=0.005, npix=800) 36 | 37 | ikw = {"origin": "lower"} 38 | 39 | g = utils.sky_gaussian_arcsec( 40 | coords.packed_x_centers_2D, coords.packed_y_centers_2D, **imagekw 41 | ) 42 | 43 | fig, ax = plt.subplots(nrows=1, ncols=1) 44 | im = ax.imshow(g, **ikw) 45 | plt.colorbar(im, ax=ax) 46 | fig.savefig(tmp_path / "packed_gauss_2D.png", dpi=300) 47 | 48 | 49 | def test_analytic_plot(tmp_path): 50 | # plot the analytic Gaussian and its Fourier transform 51 | 52 | kw = { 53 | "a": 1, 54 | "delta_x": 0.02, # arcsec 55 | "delta_y": -0.01, 56 | "sigma_x": 0.02, 57 | "sigma_y": 0.01, 58 | "Omega": 
20, # degrees
59 | }
60 | 
61 | coords = coordinates.GridCoords(cell_size=0.005, npix=800)
62 | 
63 | img = utils.sky_gaussian_arcsec(
64 | coords.sky_x_centers_2D, coords.sky_y_centers_2D, **kw
65 | ) # Jy/arcsec^2
66 | 
67 | fig, ax = plt.subplots(nrows=1)
68 | im = ax.imshow(img, origin="lower")
69 | ax.set_xlabel("axis2 index")
70 | ax.set_ylabel("axis1 index")
71 | plt.colorbar(im, ax=ax, label=r"$Jy/\mathrm{arcsec}^2$")
72 | fig.savefig(tmp_path / "gaussian_sky.png", dpi=300)
73 | 
74 | img_packed = utils.sky_gaussian_arcsec(
75 | coords.packed_x_centers_2D, coords.packed_y_centers_2D, **kw
76 | ) # Jy/arcsec^2
77 | 
78 | fig, ax = plt.subplots(nrows=1)
79 | ax.imshow(img_packed, origin="lower")
80 | ax.set_xlabel("axis2 index")
81 | ax.set_ylabel("axis1 index")
82 | fig.savefig(tmp_path / "gaussian_packed.png", dpi=300)
83 | 
84 | # calculate the packed FFT
85 | fourier_packed_num = coords.cell_size**2 * np.fft.fft2(img_packed)
86 | 
87 | # calculate the analytical FFT
88 | fourier_packed_an = utils.fourier_gaussian_lambda_arcsec(
89 | coords.packed_u_centers_2D, coords.packed_v_centers_2D, **kw
90 | )
91 | 
92 | ikw = {"origin": "lower", "interpolation": "none"}
93 | 
94 | fig, ax = plt.subplots(nrows=3, ncols=2, figsize=(6, 8))
95 | im = ax[0, 0].imshow(fourier_packed_an.real, **ikw)
96 | plt.colorbar(im, ax=ax[0, 0])
97 | ax[0, 0].set_title("real")
98 | ax[0, 0].set_ylabel("analytical")
99 | im = ax[0, 1].imshow(fourier_packed_an.imag, **ikw)
100 | plt.colorbar(im, ax=ax[0, 1])
101 | ax[0, 1].set_title("imag")
102 | 
103 | im = ax[1, 0].imshow(fourier_packed_num.real, **ikw)
104 | plt.colorbar(im, ax=ax[1, 0])
105 | ax[1, 0].set_ylabel("numerical")
106 | im = ax[1, 1].imshow(fourier_packed_num.imag, **ikw)
107 | plt.colorbar(im, ax=ax[1, 1])
108 | 
109 | diff_real = fourier_packed_an.real - fourier_packed_num.real
110 | diff_imag = fourier_packed_an.imag - fourier_packed_num.imag
111 | im = ax[2, 0].imshow(diff_real, **ikw)
112 | ax[2, 0].set_ylabel("difference")
113 | plt.colorbar(im, ax=ax[2, 0])
114 | im = ax[2, 1].imshow(diff_imag, **ikw)
115 | plt.colorbar(im, ax=ax[2, 1])
116 | 
117 | fig.savefig(tmp_path / "fourier_packed.png", dpi=300)
118 | 
119 | assert np.all(np.abs(diff_real) < 1e-12)
120 | assert np.all(np.abs(diff_imag) < 1e-12)
121 | 
122 | 
123 | def test_loglinspace():
124 | # test that our log linspace routine calculates the correct spacing
125 | array = utils.loglinspace(0, 10, 5, 3)
126 | print(array)
127 | print(np.diff(array))
128 | assert len(array) == 5 + 3
129 | 
130 | 
131 | def test_get_optimal_image_properties(baselines_1D):
132 | # test that get_optimal_image_properties returns sensible cell_size, npix
133 | image_width = 5.0 # [arcsec]
134 | 
135 | u, v = baselines_1D
136 | 
137 | cell_size, npix = utils.get_optimal_image_properties(image_width, u, v)
138 | 
139 | max_data_freq = max(abs(u).max(), abs(v).max())
140 | 
141 | assert utils.get_max_spatial_freq(cell_size, npix) >= max_data_freq
142 | --------------------------------------------------------------------------------
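# [editor's closing note] a sketch of the Nyquist-style criterion behind
# test_get_optimal_image_properties above: to sample a maximum baseline
# q_max [lambda] without aliasing, the pixel size should satisfy
#
#   cell_size [radians] <= 1 / (2 * q_max)
#
# so that the maximum spatial frequency sampled by the grid, roughly
# 1 / (2 * cell_size), is at least q_max (exact units and convention are
# those of mpol.utils.get_max_spatial_freq).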