├── .all-contributorsrc ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── BUG-REPORT.yml │ ├── FEATURE-REQUEST.yml │ └── config.yml ├── pull_request_template.md └── workflows │ ├── docs.yml │ ├── publish.yml │ ├── style.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── deepsensor ├── __init__.py ├── active_learning │ ├── __init__.py │ ├── acquisition_fns.py │ └── algorithms.py ├── config.py ├── data │ ├── __init__.py │ ├── loader.py │ ├── processor.py │ ├── sources.py │ ├── task.py │ └── utils.py ├── errors.py ├── eval │ ├── __init__.py │ └── metrics.py ├── model │ ├── __init__.py │ ├── convnp.py │ ├── defaults.py │ ├── model.py │ ├── nps.py │ └── pred.py ├── plot.py ├── py.typed ├── tensorflow │ └── __init__.py ├── torch │ └── __init__.py └── train │ ├── __init__.py │ └── train.py ├── docs ├── _config.yml ├── _static │ ├── index_api.svg │ ├── index_community.svg │ ├── index_community2.pdf │ ├── index_community2.png │ ├── index_contribute.svg │ ├── index_getting_started.svg │ └── index_user_guide.svg ├── _toc.yml ├── community │ ├── code_of_conduct.md │ ├── contributing.md │ ├── faq.md │ ├── index.md │ └── roadmap.md ├── contact.md ├── getting-started │ ├── data_requirements.ipynb │ ├── index.md │ ├── installation.md │ └── overview.md ├── index.md ├── reference │ ├── active_learning │ │ ├── acquisition_fns.rst │ │ ├── algorithms.rst │ │ └── index.md │ ├── data │ │ ├── index.md │ │ ├── loader.rst │ │ ├── processor.rst │ │ ├── sources.rst │ │ ├── task.rst │ │ └── utils.rst │ ├── index.md │ ├── model │ │ ├── convnp.rst │ │ ├── defaults.rst │ │ ├── index.md │ │ ├── model.rst │ │ ├── nps.rst │ │ └── pred.rst │ ├── plot.rst │ ├── tensorflow │ │ └── index.rst │ ├── torch │ │ └── index.rst │ └── train │ │ ├── index.md │ │ └── train.rst ├── references.bib ├── research_ideas.md ├── resources.md └── user-guide │ ├── acquisition_functions.ipynb │ ├── active_learning.ipynb │ ├── convnp.ipynb │ ├── data_processor.ipynb │ ├── deepsensor_design.md │ ├── extending.ipynb │ ├── index.md │ ├── prediction.ipynb │ ├── task.ipynb │ ├── task_loader.ipynb │ └── training.ipynb ├── figs ├── DeepSensorLogo.png ├── DeepSensorLogo2.png ├── convnp_arch.png ├── deepsensor_application_examples.png └── deepsensor_design.png ├── pyproject.toml ├── tests ├── __init__.py ├── test_active_learning.py ├── test_data_processor.py ├── test_model.py ├── test_plotting.py ├── test_task.py ├── test_task_loader.py ├── test_training.py └── utils.py └── tox.ini /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "projectName": "deepsensor", 3 | "projectOwner": "alan-turing-institute", 4 | "skipCI": true, 5 | "files": [ 6 | "README.md" 7 | ], 8 | "commitType": "docs", 9 | "commitConvention": "angular", 10 | "contributorsPerLine": 7, 11 | "contributorsSortAlphabetically": true, 12 | "contributors": [ 13 | { 14 | "login": "kallewesterling", 15 | "name": "Kalle Westerling", 16 | "avatar_url": "https://avatars.githubusercontent.com/u/7298727?v=4", 17 | "profile": "http://www.westerling.nu", 18 | "contributions": [ 19 | "doc", 20 | "infra", 21 | "ideas", 22 | "projectManagement", 23 | "promotion", 24 | "question" 25 | ] 26 | }, 27 | { 28 | "login": "tom-andersson", 29 | "name": "Tom Andersson", 30 | "avatar_url": "https://avatars.githubusercontent.com/u/26459412?v=4", 31 | "profile": "https://www.bas.ac.uk/profile/tomand", 32 | "contributions": [ 33 | "code", 34 | "research", 
35 | "maintenance", 36 | "bug", 37 | "test", 38 | "tutorial", 39 | "doc", 40 | "review", 41 | "talk", 42 | "question" 43 | ] 44 | }, 45 | { 46 | "login": "acocac", 47 | "name": "Alejandro ©", 48 | "avatar_url": "https://avatars.githubusercontent.com/u/13321552?v=4", 49 | "profile": "https://github.com/acocac", 50 | "contributions": [ 51 | "userTesting", 52 | "bug", 53 | "mentoring", 54 | "ideas", 55 | "research", 56 | "code", 57 | "test" 58 | ] 59 | }, 60 | { 61 | "login": "wesselb", 62 | "name": "Wessel", 63 | "avatar_url": "https://avatars.githubusercontent.com/u/1444448?v=4", 64 | "profile": "http://wessel.ai", 65 | "contributions": [ 66 | "research", 67 | "code", 68 | "ideas" 69 | ] 70 | }, 71 | { 72 | "login": "scotthosking", 73 | "name": "Scott Hosking", 74 | "avatar_url": "https://avatars.githubusercontent.com/u/10783052?v=4", 75 | "profile": "https://scotthosking.com", 76 | "contributions": [ 77 | "fundingFinding", 78 | "ideas", 79 | "projectManagement" 80 | ] 81 | }, 82 | { 83 | "login": "patel-zeel", 84 | "name": "Zeel B Patel", 85 | "avatar_url": "https://avatars.githubusercontent.com/u/59758528?v=4", 86 | "profile": "http://patel-zeel.github.io", 87 | "contributions": [ 88 | "bug", 89 | "code", 90 | "userTesting", 91 | "ideas" 92 | ] 93 | }, 94 | { 95 | "login": "jonas-scholz123", 96 | "name": "Jonas Scholz", 97 | "avatar_url": "https://avatars.githubusercontent.com/u/37850411?v=4", 98 | "profile": "https://github.com/jonas-scholz123", 99 | "contributions": [ 100 | "userTesting", 101 | "research", 102 | "code", 103 | "bug", 104 | "ideas" 105 | ] 106 | }, 107 | { 108 | "login": "nilsleh", 109 | "name": "Nils Lehmann", 110 | "avatar_url": "https://avatars.githubusercontent.com/u/35272119?v=4", 111 | "profile": "https://nilsleh.info/", 112 | "contributions": [ 113 | "ideas", 114 | "userTesting", 115 | "bug" 116 | ] 117 | }, 118 | { 119 | "login": "kenzaxtazi", 120 | "name": "Kenza Tazi", 121 | "avatar_url": "https://avatars.githubusercontent.com/u/43008274?v=4", 122 | "profile": "http://kenzaxtazi.github.io", 123 | "contributions": [ 124 | "ideas" 125 | ] 126 | }, 127 | { 128 | "login": "polpel", 129 | "name": "Paolo Pelucchi", 130 | "avatar_url": "https://avatars.githubusercontent.com/u/56694450?v=4", 131 | "profile": "https://github.com/polpel", 132 | "contributions": [ 133 | "userTesting", 134 | "bug" 135 | ] 136 | }, 137 | { 138 | "login": "RohitRathore1", 139 | "name": "Rohit Singh Rathaur", 140 | "avatar_url": "https://avatars.githubusercontent.com/u/42641738?v=4", 141 | "profile": "https://rohitrathore.netlify.app/", 142 | "contributions": [ 143 | "code" 144 | ] 145 | }, 146 | { 147 | "login": "magnusross", 148 | "name": "Magnus Ross", 149 | "avatar_url": "https://avatars.githubusercontent.com/u/51709759?v=4", 150 | "profile": "http://magnusross.github.io/about", 151 | "contributions": [ 152 | "tutorial", 153 | "data" 154 | ] 155 | }, 156 | { 157 | "login": "annavaughan", 158 | "name": "Anna Vaughan", 159 | "avatar_url": "https://avatars.githubusercontent.com/u/45528489?v=4", 160 | "profile": "https://github.com/annavaughan", 161 | "contributions": [ 162 | "research" 163 | ] 164 | }, 165 | { 166 | "login": "ots22", 167 | "name": "ots22", 168 | "avatar_url": "https://avatars.githubusercontent.com/u/5434836?v=4", 169 | "profile": "https://github.com/ots22", 170 | "contributions": [ 171 | "ideas" 172 | ] 173 | }, 174 | { 175 | "login": "JimCircadian", 176 | "name": "Jim Circadian", 177 | "avatar_url": "https://avatars.githubusercontent.com/u/731727?v=4", 178 | "profile": 
"http://inconsistentrecords.co.uk", 179 | "contributions": [ 180 | "ideas", 181 | "projectManagement", 182 | "maintenance" 183 | ] 184 | }, 185 | { 186 | "login": "davidwilby", 187 | "name": "David Wilby", 188 | "avatar_url": "https://avatars.githubusercontent.com/u/24752124?v=4", 189 | "profile": "http://davidwilby.dev", 190 | "contributions": [ 191 | "doc", 192 | "test", 193 | "maintenance", 194 | "bug" 195 | ] 196 | }, 197 | { 198 | "login": "vinayakrana", 199 | "name": "vinayakrana", 200 | "avatar_url": "https://avatars.githubusercontent.com/u/95575600?v=4", 201 | "profile": "https://github.com/vinayakrana", 202 | "contributions": [ 203 | "doc" 204 | ] 205 | }, 206 | { 207 | "login": "DaniJonesOcean", 208 | "name": "Dani Jones", 209 | "avatar_url": "https://avatars.githubusercontent.com/u/11757453?v=4", 210 | "profile": "https://github.com/DaniJonesOcean", 211 | "contributions": [ 212 | "bug" 213 | ] 214 | }, 215 | { 216 | "login": "holzwolf", 217 | "name": "holzwolf", 218 | "avatar_url": "https://avatars.githubusercontent.com/u/135216528?v=4", 219 | "profile": "https://github.com/holzwolf", 220 | "contributions": [ 221 | "bug" 222 | ] 223 | } 224 | ] 225 | } 226 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | notebooks/** linguist-vendored 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/BUG-REPORT.yml: -------------------------------------------------------------------------------- 1 | name: "🐛 Bug Report" 2 | description: Report a suspected bug in the DeepSensor package. 3 | title: "Potential bug: " 4 | labels: [ 5 | "bug" 6 | ] 7 | body: 8 | - type: textarea 9 | id: description 10 | attributes: 11 | label: "Description" 12 | description: Please enter a specific description of your issue 13 | placeholder: Short and specific description of your incident... 14 | validations: 15 | required: true 16 | - type: textarea 17 | id: reprod 18 | attributes: 19 | label: "Reproduction steps" 20 | description: Please enter the steps that you took so we can try to reproduce the issue. 21 | value: | 22 | 1. Go to '...' 23 | 2. Click on '....' 24 | 3. Scroll down to '....' 25 | 4. See error 26 | render: bash 27 | validations: 28 | required: true 29 | - type: textarea 30 | id: version 31 | attributes: 32 | label: "Version" 33 | description: Please tell us which version of DeepSensor you're using. 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: screenshot 38 | attributes: 39 | label: "Screenshots" 40 | description: If applicable, add screenshots to help explain your problem. 41 | value: | 42 | ![DESCRIPTION](LINK.png) 43 | render: bash 44 | validations: 45 | required: false 46 | - type: dropdown 47 | id: os 48 | attributes: 49 | label: "OS" 50 | description: Which OS are you using? 51 | multiple: true 52 | options: 53 | - Windows 54 | - Linux 55 | - Mac 56 | validations: 57 | required: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/FEATURE-REQUEST.yml: -------------------------------------------------------------------------------- 1 | name: "💡 Feature Request" 2 | description: Submit a new feature request for the DeepSensor package. 
3 | title: "Potential new feature: <title>" 4 | labels: [ 5 | "enhancement" 6 | ] 7 | body: 8 | - type: textarea 9 | id: summary 10 | attributes: 11 | label: "Summary" 12 | description: Provide a brief explanation of the feature. 13 | placeholder: Describe your feature request in a few lines. 14 | validations: 15 | required: true 16 | - type: textarea 17 | id: basic_example 18 | attributes: 19 | label: "Basic Example" 20 | description: Indicate here some basic examples of your feature. 21 | placeholder: A few specific words about your feature request. 22 | validations: 23 | required: true 24 | - type: textarea 25 | id: drawbacks 26 | attributes: 27 | label: "Drawbacks" 28 | description: What are the drawbacks/impacts of your feature request? 29 | placeholder: Identify the drawbacks and impacts while being neutral on your feature request. 30 | validations: 31 | required: true 32 | - type: textarea 33 | id: unresolved_question 34 | attributes: 35 | label: "Unresolved questions" 36 | description: What questions still remain unresolved? 37 | placeholder: Identify any unresolved issues. 38 | validations: 39 | required: false 40 | - type: textarea 41 | id: implementation_pr 42 | attributes: 43 | label: "Implementation PR" 44 | description: Pull request used (if you have submitted a pull request). 45 | placeholder: "#<Pull Request ID>" 46 | validations: 47 | required: false 48 | - type: textarea 49 | id: reference_issues 50 | attributes: 51 | label: "Reference Issues" 52 | description: Issues that the new feature might resolve (if it addresses any existing issues). 53 | placeholder: "#<Issues IDs>" 54 | validations: 55 | required: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Slack channel 4 | url: https://ai4environment.slack.com/archives/C05NQ76L87R 5 | about: Join our Slack channel to ask and answer questions about DeepSensor. 6 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | <!-- Thank you for your contribution to DeepSensor! --> 2 | <!-- Please fill out the details in the pull request description below as appropriate --> 3 | <!-- Ensure that you abide by the code of conduct --> 4 | 5 | ## :pencil: Description 6 | <!-- Please provide a clear description of the changes that are introduced in this pull request and why --> 7 | <!-- It also helps to explain any justification or options you're aware of --> 8 | <!-- If this pull request relates to an issue, please include the issue number here, e.g. #42 --> 9 | 10 | 11 | ## :white_check_mark: Checklist before requesting a review 12 | (See the contributing guide for more details on these steps.) 13 | - [ ] I have installed developer dependencies with `pip install .[dev]` and running `pre-commit install` (or alternatively, manually running `ruff format` before commiting) 14 | 15 | If changing or adding source code: 16 | - [ ] tests are included and are passing (run `pytest`). 17 | - [ ] documentation is included or updated as relevant, including docstrings. 18 | 19 | If changing or adding documentation: 20 | - [ ] docs build successfully (`jupyter-book build docs --all`) and the changes look good from a manual inspection of the HTML in `docs/_build/html/`. 
21 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | on: [push, pull_request, workflow_dispatch] 3 | permissions: 4 | contents: write 5 | jobs: 6 | docs: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: actions/setup-python@v5 11 | with: 12 | python-version: 3.11 13 | - uses: nikeee/setup-pandoc@v1 14 | - name: "Upgrade pip" 15 | run: pip install --upgrade pip 16 | - name: "Install dependencies" 17 | run: pip install -e .[docs] 18 | - name: "Run jupyterbook" 19 | run: jupyter-book build docs --all 20 | - name: "Deploy" 21 | uses: peaceiris/actions-gh-pages@v4 22 | if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/docs') }} 23 | with: 24 | publish_branch: gh-pages 25 | github_token: ${{ secrets.GITHUB_TOKEN }} 26 | publish_dir: docs/_build/html/ 27 | force_orphan: true 28 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python package using Twine when a release is 2 | # created. For more information see the following link: 3 | # https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 4 | 5 | name: Publish to PyPI 6 | 7 | on: 8 | release: 9 | types: [published] 10 | 11 | jobs: 12 | deploy: 13 | name: Build and publish Python 🐍 distributions 📦 to PyPI and TestPyPI 14 | runs-on: ubuntu-latest 15 | 16 | # https://github.com/pypa/gh-action-pypi-publish#trusted-publishing 17 | permissions: 18 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 19 | 20 | environment: 21 | name: pypi 22 | url: https://pypi.org/p/deepsensor 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python 27 | uses: actions/setup-python@v4 28 | with: 29 | python-version: "3.x" 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install -U setuptools wheel twine build 35 | 36 | - name: Build and publish 37 | run: python -m build 38 | 39 | - name: Publish distribution 📦 to PyPI 40 | uses: pypa/gh-action-pypi-publish@release/v1 41 | with: 42 | verbose: true -------------------------------------------------------------------------------- /.github/workflows/style.yml: -------------------------------------------------------------------------------- 1 | name: Code Style 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: [3.8] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install Ruff 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install ruff 23 | - name: Check style 24 | run: | 25 | ruff check -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | pull_request: 6 | workflow_dispatch: 7 | schedule: 8 | - cron: "5 0 * * 2" # run weekly tests 9 | 10 | jobs: 11 | test: 12 | runs-on: ${{ matrix.os }} 13 
| strategy: 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: ['3.8', '3.11'] 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Free Disk Space (Ubuntu) 25 | uses: jlumbroso/free-disk-space@main 26 | - name: Print space 27 | run: df -h 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install .[dev,testing] 32 | pip install tox-gh-actions 33 | - name: Test with tox 34 | run: tox 35 | - name: Run coveralls 36 | run: coveralls --service=github 37 | env: 38 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 39 | COVERALLS_FLAG_NAME: ${{ matrix.test-name }} 40 | COVERALLS_PARALLEL: true 41 | 42 | 43 | finish: 44 | name: Finish Coveralls 45 | needs: test 46 | runs-on: ubuntu-latest 47 | steps: 48 | - name: Finish Coveralls 49 | uses: coverallsapp/github-action@v2 50 | with: 51 | github-token: ${{ secrets.github_token }} 52 | parallel-finished: true 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | *.xml 3 | *.pyc 4 | .vscode/ 5 | .DS_Store 6 | .tox-info.* 7 | .coverage 8 | build/ 9 | dist/* 10 | .tox/ 11 | _build 12 | *.png 13 | deepsensor.egg-info/ 14 | .venv/ 15 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | # Ruff version. 4 | rev: v0.7.0 5 | hooks: 6 | # Run the linter. 7 | - id: ruff 8 | # Run the formatter. 9 | - id: ruff-format -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | # This CITATION.cff file was generated with cffinit. 2 | # Visit https://bit.ly/cffinit to generate yours today! 3 | 4 | cff-version: 1.2.0 5 | title: 'DeepSensor: A Python package for modelling environmental data with convolutional neural processes' 6 | message: >- 7 | If you use DeepSensor in your research, please cite it 8 | using the information below. 9 | type: software 10 | authors: 11 | - given-names: Tom Robin 12 | family-names: Andersson 13 | email: tomandersson3@gmail.com 14 | affiliation: Google DeepMind 15 | orcid: 'https://orcid.org/0000-0002-1556-9932' 16 | repository-code: 'https://github.com/alan-turing-institute/deepsensor' 17 | abstract: >- 18 | DeepSensor is a Python package for modelling environmental 19 | data with convolutional neural processes (ConvNPs). 20 | ConvNPs are versatile deep learning models capable of 21 | ingesting multiple environmental data streams of varying 22 | modalities and resolutions, handling missing data, and 23 | predicting at arbitrary target locations with uncertainty. 24 | DeepSensor allows users to tackle a diverse array of 25 | environmental prediction tasks, including downscaling 26 | (super-resolution), sensor placement, gap-filling, and 27 | forecasting. The library includes a user-friendly 28 | pandas/xarray interface, automatic unnormalisation of 29 | model predictions, active learning functionality, 30 | integration with both PyTorch and TensorFlow, and model 31 | customisation. 
DeepSensor streamlines and simplifies the 32 | environmental data modelling pipeline, enabling 33 | researchers and practitioners to harness the potential of 34 | ConvNPs for complex environmental prediction challenges. 35 | keywords: 36 | - machine learning 37 | - environmental science 38 | - neural processes 39 | - active learning 40 | license: MIT 41 | version: 0.4.2 42 | date-released: '2024-10-20' 43 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | We want DeepSensor to be an open, welcoming, diverse, inclusive, and healthy community. 4 | We do not tolerate rude or disrespectful behavior toward anyone. 5 | For most people, this should be enough information to just know what we mean. 6 | However, for more specific expectations, see our code of conduct below. 7 | 8 | ## Our Pledge 9 | 10 | We as members, contributors, and leaders pledge to make participation in our community a 11 | harassment-free experience for everyone, regardless of age, body size, visible or invisible 12 | disability, ethnicity, sex characteristics, gender identity and expression, level of experience, 13 | education, socio-economic status, nationality, personal appearance, race, religion, or sexual 14 | identity and orientation. 15 | 16 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and 17 | healthy community. 18 | 19 | ## Our Standards 20 | 21 | Examples of behavior that contributes to a positive environment for our community include: 22 | 23 | - Demonstrating empathy and kindness toward other people 24 | - Being respectful of differing opinions, viewpoints, and experiences 25 | - Giving and gracefully accepting constructive feedback 26 | - Accepting responsibility and apologizing to those affected by our mistakes, and learning from the 27 | experience 28 | - Focusing on what is best not just for us as individuals, but for the overall community 29 | 30 | Examples of unacceptable behavior include: 31 | 32 | - The use of sexualized language or imagery, and sexual attention or advances of any kind 33 | - Trolling, insulting or derogatory comments, and personal or political attacks 34 | - Public or private harassment 35 | - Publishing others' private information, such as a physical or email address, without their 36 | explicit permission 37 | - Other conduct which could reasonably be considered inappropriate in a professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior 42 | and will take appropriate and fair corrective action in response to any behavior that they deem 43 | inappropriate, threatening, offensive, or harmful. 44 | 45 | Community leaders have the right and responsibility to remove, edit, or reject comments, commits, 46 | code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and 47 | will communicate reasons for moderation decisions when appropriate. 48 | 49 | ## Scope 50 | 51 | This Code of Conduct applies within all community spaces, and also applies when an individual is 52 | officially representing the community in public spaces. Examples of representing our community 53 | include using an official e-mail address, posting via an official social media account, or acting as 54 | an appointed representative at an online or offline event. 
55 | 56 | ## Enforcement 57 | 58 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community 59 | leaders responsible for enforcement via tomandersson3@gmail.com. All complaints will be reviewed and 60 | investigated promptly and fairly. 61 | 62 | All community leaders are obligated to respect the privacy and security of the reporter of any 63 | incident. 64 | 65 | ## Enforcement Guidelines 66 | 67 | Community leaders will follow these Community Impact Guidelines in determining the consequences for 68 | any action they deem in violation of this Code of Conduct: 69 | 70 | ### 1. Correction 71 | 72 | **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or 73 | unwelcome in the community. 74 | 75 | **Consequence**: A private, written warning from community leaders, providing clarity around the 76 | nature of the violation and an explanation of why the behavior was inappropriate. A public apology 77 | may be requested. 78 | 79 | ### 2. Warning 80 | 81 | **Community Impact**: A violation through a single incident or series of actions. 82 | 83 | **Consequence**: A warning with consequences for continued behavior. No interaction with the people 84 | involved, including unsolicited interaction with those enforcing the Code of Conduct, for a 85 | specified period of time. This includes avoiding interactions in community spaces as well as 86 | external channels like social media. Violating these terms may lead to a temporary or permanent ban. 87 | 88 | ### 3. Temporary Ban 89 | 90 | **Community Impact**: A serious violation of community standards, including sustained inappropriate 91 | behavior. 92 | 93 | **Consequence**: A temporary ban from any sort of interaction or public communication with the 94 | community for a specified period of time. No public or private interaction with the people involved, 95 | including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this 96 | period. Violating these terms may lead to a permanent ban. 97 | 98 | ### 4. Permanent Ban 99 | 100 | **Community Impact**: Demonstrating a pattern of violation of community standards, including 101 | sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement 102 | of classes of individuals. 103 | 104 | **Consequence**: A permanent ban from any sort of public interaction within the community. 105 | 106 | ## Attribution 107 | 108 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available 109 | at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 110 | 111 | Community Impact Guidelines were inspired 112 | by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). 113 | 114 | [homepage]: https://www.contributor-covenant.org 115 | 116 | For answers to common questions about this code of conduct, see the FAQ 117 | at https://www.contributor-covenant.org/faq. Translations are available 118 | at https://www.contributor-covenant.org/translations. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DeepSensor 2 | 3 | 🌍💫 We're excited that you're here and want to contribute. 💫🌍 4 | 5 | By joining our efforts, you will be helping to push the frontiers of environmental sciences. 
6 | 7 | We want to ensure that every user and contributor feels welcome, included and supported to 8 | participate in the DeepSensor community. Whether you're a seasoned developer, a machine learning 9 | researcher, an environmental scientist, or just someone eager to learn and contribute, **you are 10 | welcome here**. We value every contribution, be it big or small, and we appreciate the unique 11 | perspectives you bring to the project. 12 | 13 | We hope that the information provided in this document will make it as easy as possible for you to 14 | get involved. If you find that you have questions that are not discussed below, please let us know 15 | through one of the many ways to [get in touch](#get-in-touch). 16 | 17 | ## Important Resources 18 | 19 | If you'd like to find out more about DeepSensor, make sure to check out: 20 | 21 | 1. **README**: For a high-level overview of the project, please refer to our README. 22 | 2. **Documentation**: For more detailed information about the project, please refer to 23 | our [documentation](https://alan-turing-institute.github.io/deepsensor). 24 | 3. **Project Roadmap**: Familiarize yourself with our direction and goals by checking 25 | out [the project's roadmap](https://alan-turing-institute.github.io/deepsensor/community/roadmap.html). 26 | 27 | ## Get in touch 28 | 29 | The easiest way to get involved with the active development of DeepSensor is to join our regular 30 | community calls. The community calls are currently on a hiatus, but if you are interested in 31 | participating in the forthcoming community calls, which will start in 2024, you should join our 32 | Slack workspace, where conversation about when to hold the community calls in the future will take 33 | place. 34 | 35 | **Slack Workspace**: Join 36 | our DeepSensor Slack channel for 37 | discussions, queries, and community interactions. In order to join, [sign up for the Turing Environment & Sustainability stakeholder community](https://forms.office.com/pages/responsepage.aspx?id=p_SVQ1XklU-Knx-672OE-ZmEJNLHTHVFkqQ97AaCfn9UMTZKT1IwTVhJRE82UjUzMVE2MThSOU5RMC4u). The form includes a question on signing up for the Slack team, where you can find DeepSensor's channel. 38 | 39 | **Email**: If you prefer a more formal communication method or have specific concerns, please reach 40 | us at tomandersson3@gmail.com. 41 | 42 | ## How to Contribute 43 | 44 | We welcome contributions of all kinds, be it code, documentation, raising issues, or community engagement. We 45 | encourage you to read through the following sections to learn more about how you can contribute to the project. 46 | 47 | ### How to Submit Changes 48 | 49 | We follow the same instructions for submitting changes to the project as those developed 50 | by [The Turing Way](https://github.com/the-turing-way/the-turing-way/blob/main/CONTRIBUTING.md#making-a-change-with-a-pull-request). 51 | In short, there are four steps to adding changes to this repository: 52 | 53 | 1. **Fork the Repository**: Start 54 | by [forking the DeepSensor repository](https://github.com/alan-turing-institute/deepsensor/fork). 55 | 2. **Make Changes**: Ensure your code adheres to the style guidelines and passes all tests. 56 | 3. **Commit and Push**: Use clear commit messages. 57 | 4. **Open a Pull Request**: Ensure you describe the changes made and any additional details. 58 | 59 | #### 1. 
Fork the Repository 60 | 61 | Once you have [created a fork of the repository](https://github.com/alan-turing-institute/deepsensor/fork), 62 | you now have your own unique local copy of DeepSensor. Changes here won't affect anyone else's work, 63 | so it's a safe space to explore edits to the code! 64 | 65 | Make sure to [keep your fork up to date](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork) with the main repository, otherwise you 66 | can end up with lots of dreaded [merge conflicts](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/addressing-merge-conflicts/about-merge-conflicts). 67 | 68 | If you prefer working with GitHub in the 69 | browser, [these instructions](https://github.com/KirstieJane/STEMMRoleModels/wiki/Syncing-your-fork-to-the-original-repository-via-the-browser) 70 | describe how to sync your fork to the original repository. 71 | 72 | #### 2. Make Changes 73 | 74 | Try to keep the changes focused. 75 | If you submit a large amount of work all in one go it will be much more work for whoever is 76 | reviewing your pull request. 77 | Help them help you! :wink: 78 | 79 | Are you new to Git and GitHub or just want a detailed guide on getting started with version control? 80 | Check out 81 | our [Version Control chapter](https://the-turing-way.netlify.com/version_control/version_control.html) 82 | in _The Turing Way_ Book! 83 | 84 | #### 3. Commit and Push 85 | 86 | While making your changes, commit often and write good, detailed commit messages. 87 | [This blog](https://chris.beams.io/posts/git-commit/) explains how to write a good Git commit 88 | message and why it matters. 89 | It is also perfectly fine to have a lot of commits - including ones that break code. 90 | A good rule of thumb is to push up to GitHub when you _do_ have passing tests, so that the continuous 91 | integration (CI) has a good chance of passing everything. 😸 92 | 93 | Please do not re-write history! 94 | That is, please do not use the [rebase](https://help.github.com/en/articles/about-git-rebase) 95 | command to edit previous commit messages, combine multiple commits into one, or delete or revert 96 | commits that are no longer necessary. 97 | 98 | Make sure you're using the developer dependencies. 99 | If you're working locally on the source code, *before* committing, please run `pip install .[dev]` to install some useful dependencies just for development. 100 | This includes `pre-commit` and `ruff`, which are used to check and format your code style when you run `git commit`, so that you don't have to. 101 | 102 | Using pre-commit: 103 | + To make this work, just run `pre-commit install` and when you commit, `ruff` will be run to check the style of the files that you've changed. 104 | + Note: if `ruff` needs to edit your files, the pre-commit hook will stop the commit until the changes pass its checks. Because `ruff` will have slightly edited the files you added, you'll need to run `git add ...` and `git commit ...` again to stage and commit them. 105 | 106 | Without pre-commit: 107 | + Alternatively, you can run `ruff` yourself (without `pre-commit`) by installing `ruff` as above and just running `ruff format`. 108 | 109 | You should also run `pytest` and check that your changes don't break any of the existing tests. 110 | If you've made changes to the source code, you may need to add some tests to make sure that they don't get broken in the future. 111 | 112 | #### 4. 
Open a Pull Request 113 | 114 | We encourage you to open a pull request as early in your contributing process as possible. 115 | This allows everyone to see what is currently being worked on. 116 | It also provides you, the contributor, with real-time feedback from both the community and the 117 | continuous integration as you make commits (which will help prevent stuff from breaking). 118 | 119 | GitHub has a [nice introduction](https://guides.github.com/introduction/flow) to the pull request 120 | workflow, but please [get in touch](#get-in-touch) if you have any questions :balloon:. 121 | 122 | ### DeepSensor's documentation 123 | 124 | You don't have to write code to contribute to DeepSensor. 125 | Another highly valuable way of contributing is helping with DeepSensor's [documentation](https://alan-turing-institute.github.io/deepsensor). 126 | See below for information on how to do this. 127 | 128 | #### Background 129 | 130 | We use the Jupyter Book framework to build our documentation in the `docs/` folder. 131 | The documentation is written in 132 | Markdown and Jupyter Notebooks. The documentation is hosted on GitHub Pages and is automatically 133 | built and deployed using GitHub Actions after every commit to the `main` branch. 134 | 135 | DeepSensor requires slightly unique documentation, because demonstrating the package requires 136 | both data and trained models. 137 | This makes it compute- and data-hungry to run some of the notebooks, and they cannot 138 | run on GitHub Actions. 139 | Therefore, all the notebooks are run locally - the code cell outputs are saved in the .ipynb files 140 | and are rendered when the documentation is built. 141 | If DeepSensor is updated, some of the notebooks may become out of date and will need to be re-run. 142 | 143 | Some relevant links for Jupyter Book and MyST: 144 | * https://jupyterbook.org/en/stable/intro.html 145 | * https://jupyterbook.org/en/stable/content/myst.html 146 | * https://jupyterbook.org/en/stable/reference/cheatsheet.html 147 | 148 | #### Contributing to documentation 149 | 150 | One easy way to contribute to the documentation is to provide feedback in [this issue](https://github.com/alan-turing-institute/deepsensor/issues/87) and/or in the DeepSensor Slack channel. 151 | 152 | Another way to contribute is to directly edit or add to the documentation and open a PR: 153 | * Follow all the forking instructions above 154 | * Install the documentation requirements: `pip install deepsensor[docs]` 155 | * Option A: Editing a markdown file 156 | * Simply make your edits! 157 | * Option B: Editing a jupyter notebook file 158 | * This can be more involved... Firstly, reach out on the Slack channel to ask if anyone else is working on the same notebook file locally. Working one-at-a-time can save Jupyter JSON merge conflict headaches later! 159 | * If you are only editing markdown cells, just re-run those cells specifically to compile them 160 | * If you are editing code cells: 161 | * Install `cartopy` using `conda install -c conda-forge cartopy` 162 | * Run all the code cells that the current cell depends on and any subsequent code cells that depend on the current cell (you may need to rerun the whole notebook) 163 | * Note: Some notebooks require a GPU and some assume that previous notebooks have been run 164 | * Please be careful about not clearing any code cell outputs that you don't intend to! 
165 | * Once your changes have been made, regenerate the docs locally with `jupyter-book build docs --all` and check your changes have applied as expected 166 | * Push your changes and open a PR (see above) 167 | 168 | ## First-timers' Corner 169 | 170 | If you're new to the project, we recommend starting with issues labeled 171 | as ["good first issue"](https://github.com/alan-turing-institute/deepsensor/issues?q=is:issue+is:open+label:%22good+first+issue%22). 172 | These are typically simpler tasks that offer a great starting point. Browse these here. 173 | 174 | There's also the 175 | label ["thoughts welcome"](https://github.com/alan-turing-institute/deepsensor/issues?q=is:issue+is:open+label:%22thoughts+welcome%22), 176 | which allows for you to contribute with discussion points in the issues, even if you don't want to 177 | or cannot contribute to the codebase. 178 | 179 | If you feel ready for it, you can also open a new issue. Before you open a new issue, please check 180 | if any of [our open issues](https://github.com/alan-turing-institute/deepsensor/issues) cover your idea 181 | already. If you open a new issue, please follow our basic guidelines laid out in our issue 182 | templates, which you should be able to see if 183 | you [open a new issue](https://github.com/alan-turing-institute/deepsensor/issues/new/choose). 184 | 185 | ## Reporting Bugs 186 | 187 | Found a bug? Please open an issue here on GitHub to report it. We have a template for opening 188 | issues, so make sure you follow the correct format and ensure you include: 189 | 190 | - A clear title. 191 | - A detailed description of the bug. 192 | - Steps to reproduce it. 193 | - Expected versus actual behavior. 194 | 195 | ## Recognising Contributions 196 | 197 | We value and recognize every contribution. All contributors will be acknowledged in the 198 | [contributors](https://github.com/alan-turing-institute/deepsensor/tree/main#contributors) section of the 199 | README. 200 | Notable contributions will also be highlighted in our fortnightly community meetings. 201 | 202 | DeepSensor follows the [all-contributors](https://github.com/kentcdodds/all-contributors#emoji-key) 203 | specifications. The all-contributors bot usage is 204 | described [here](https://allcontributors.org/docs/en/bot/usage). You can see a list of current 205 | contributors here. 206 | 207 | To add yourself or someone else as a contributor, comment on the relevant Issue or Pull Request with 208 | the following: 209 | 210 | > @all-contributors please add username for contribution1, contribution2 211 | 212 | You can see 213 | the [Emoji Key (Contribution Types Reference)](https://allcontributors.org/docs/en/emoji-key) for a 214 | list of valid <contribution> types and examples of how this command can be run 215 | in [this issue](https://github.com/alan-turing-institute/deepsensor/issues/58). The bot will then create a 216 | Pull Request to add the contributor and reply with the pull request details. 217 | 218 | **PLEASE NOTE: Only one contributor can be added with the bot at a time!** Add each contributor in 219 | turn, merge the pull request and delete the branch (`all-contributors/add-<username>`) before adding 220 | another one. Otherwise, you can end up with 221 | dreaded [merge conflicts](https://help.github.com/articles/about-merge-conflicts). 
Therefore, please 222 | check the open pull requests first to make sure there aren't 223 | any [open requests from the bot](https://github.com/alan-turing-institute/deepsensor/pulls/app%2Fallcontributors) 224 | before adding another. 225 | 226 | What happens if you accidentally run the bot before the previous run was merged and you got those 227 | pesky merge conflicts? (Don't feel bad, we have all done it! 🙈) Simply close the pull request and 228 | delete the branch (`all-contributors/add-<username>`). If you are unable to do this for any reason, 229 | please let us know on Slack <link to Slack> or by opening an issue, and one of our core team members 230 | will be very happy to help! 231 | 232 | ## Need Help? 233 | 234 | If you're stuck or need assistance: 235 | 236 | - Check our [FAQ](https://alan-turing-institute.github.io/deepsensor/community/faq.html) section first. 237 | - Reach out on Slack or via email for personalized assistance. (See ["Get in touch"](#get-in-touch) 238 | above for links.) 239 | - Consider pairing up with a another contributor for guidance. You can always find us in the Slack 240 | channel and we're happy to chat! 241 | 242 | **Once again, thank you for considering contributing to DeepSensor! We hope you enjoy your 243 | contributing experience.** 244 | 245 | ## Inclusivity 246 | 247 | We aim to make DeepSensor a collaboratively developed project. We, therefore, require that all our 248 | members and their contributions **adhere to our [Code of Conduct](./CODE_OF_CONDUCT.md)**. Please 249 | familiarize yourself with our Code of Conduct that lists the expected behaviours. 250 | 251 | Every contributor is expected to adhere to our Code of Conduct. It outlines our expectations and 252 | ensures a safe, respectful environment for everyone. 253 | 254 | ---- 255 | 256 | These Contributing Guidelines have been adapted from 257 | the [Contributing Guidelines](https://github.com/the-turing-way/the-turing-way/blob/main/CONTRIBUTING.md#recognising-contributions) 258 | of [The Turing Way](https://github.com/the-turing-way/the-turing-way)! (License: CC-BY) 259 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tom Robin Andersson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /deepsensor/__init__.py: -------------------------------------------------------------------------------- 1 | class Backend: 2 | """Backend for deepsensor. 3 | 4 | This class is used to provide a consistent interface for either tensorflow or 5 | pytorch backends. It is used to assign the backend to the deepsensor module. 6 | 7 | Usage: run `import deepsensor.torch` or `import deepsensor.tensorflow` to assign the corresponding backend before using any backend-dependent functionality. 8 | """ 9 | 10 | def __getattr__(self, attr): 11 | raise AttributeError( 12 | f"Attempting to access Backend.{attr} before {attr} has been assigned. " 13 | f"Please assign a backend with `import deepsensor.tensorflow` " 14 | f"or `import deepsensor.torch` before using backend-dependent functionality." 15 | ) 16 | 17 | 18 | backend = Backend() 19 | 20 | from .data.processor import DataProcessor 21 | from .data.loader import TaskLoader 22 | from .plot import * 23 | -------------------------------------------------------------------------------- /deepsensor/active_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .algorithms import * 2 | from .acquisition_fns import * 3 | -------------------------------------------------------------------------------- /deepsensor/active_learning/acquisition_fns.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional 3 | 4 | import numpy as np 5 | 6 | from scipy.stats import norm 7 | 8 | from deepsensor.model.model import ProbabilisticModel 9 | from deepsensor.data.task import Task 10 | 11 | 12 | class AcquisitionFunction: 13 | """Parent class for acquisition functions.""" 14 | 15 | # Class attribute to indicate whether the acquisition function should be 16 | # minimised or maximised 17 | min_or_max = None 18 | 19 | def __init__( 20 | self, 21 | model: Optional[ProbabilisticModel] = None, 22 | context_set_idx: int = 0, 23 | target_set_idx: int = 0, 24 | ): 25 | """Args: 26 | model (:class:`~.model.model.ProbabilisticModel`): 27 | Probabilistic model used to compute the acquisition function. 28 | context_set_idx (int): 29 | Index of context set to add new observations to when computing 30 | the acquisition function. 31 | target_set_idx (int): 32 | Index of target set to compute acquisition function for. 33 | """ 34 | self.model = model 35 | self.context_set_idx = context_set_idx 36 | self.target_set_idx = target_set_idx 37 | 38 | def __call__(self, task: Task, *args, **kwargs) -> np.ndarray: 39 | """... 40 | 41 | :no-index: 42 | 43 | Args: 44 | task (:class:`~.data.task.Task`): 45 | Task object containing context and target sets. 46 | 47 | Returns: 48 | :class:`numpy:numpy.ndarray`: 49 | Acquisition function value/s. Shape (). 50 | 51 | Raises: 52 | NotImplementedError: 53 | Because this is an abstract method, it must be implemented by 54 | the subclass. 55 | """ 56 | raise NotImplementedError 57 | 58 | 59 | class AcquisitionFunctionOracle(AcquisitionFunction): 60 | """Signifies that the acquisition function is computed using the true 61 | target values. 62 | """ 63 | 64 | 65 | class AcquisitionFunctionParallel(AcquisitionFunction): 66 | """Parent class for acquisition functions that are computed across all search 67 | points in parallel. 68 | """ 69 | 70 | def __call__(self, task: Task, X_s: np.ndarray, **kwargs) -> np.ndarray: 71 | """... 72 | 73 | :param **kwargs: 74 | :no-index: 75 | 76 | Args: 77 | task (:class:`~.data.task.Task`): 78 | Task object containing context and target sets. 
79 | X_s (:class:`numpy:numpy.ndarray`): 80 | Search points. Shape (2, N_search). 81 | 82 | Returns: 83 | :class:`numpy:numpy.ndarray`: 84 | Should return acquisition function value/s. Shape (N_search,). 85 | 86 | Raises: 87 | NotImplementedError: 88 | Because this is an abstract method, it must be implemented by 89 | the subclass. 90 | """ 91 | raise NotImplementedError 92 | 93 | 94 | class MeanStddev(AcquisitionFunction): 95 | """Mean of the marginal variances.""" 96 | 97 | min_or_max = "min" 98 | 99 | def __call__(self, task: Task): 100 | """... 101 | 102 | :no-index: 103 | 104 | Args: 105 | task (:class:`~.data.task.Task`): 106 | [Description of the task parameter.] 107 | 108 | Returns: 109 | [Type of the return value]: 110 | [Description of the return value.] 111 | """ 112 | return np.mean(self.model.stddev(task)[self.target_set_idx]) 113 | 114 | 115 | class MeanVariance(AcquisitionFunction): 116 | """Mean of the marginal variances.""" 117 | 118 | min_or_max = "min" 119 | 120 | def __call__(self, task: Task): 121 | """... 122 | 123 | :no-index: 124 | 125 | Args: 126 | task (:class:`~.data.task.Task`): 127 | [Description of the task parameter.] 128 | 129 | Returns: 130 | [Type of the return value]: 131 | [Description of the return value.] 132 | """ 133 | return np.mean(self.model.variance(task)[self.target_set_idx]) 134 | 135 | 136 | class pNormStddev(AcquisitionFunction): 137 | """p-norm of the vector of marginal standard deviations.""" 138 | 139 | min_or_max = "min" 140 | 141 | def __init__(self, *args, p: int = 1, **kwargs): 142 | """... 143 | 144 | :no-index: 145 | 146 | Args: 147 | p (int, optional): 148 | [Description of the parameter p.], default is 1 149 | """ 150 | super().__init__(*args, **kwargs) 151 | self.p = p 152 | 153 | def __call__(self, task: Task): 154 | """... 155 | 156 | :no-index: 157 | 158 | Args: 159 | task (:class:`~.data.task.Task`): 160 | [Description of the task parameter.] 161 | 162 | Returns: 163 | [Type of the return value]: 164 | [Description of the return value.] 165 | """ 166 | return np.linalg.norm( 167 | self.model.stddev(task)[self.target_set_idx].ravel(), ord=self.p 168 | ) 169 | 170 | 171 | class MeanMarginalEntropy(AcquisitionFunction): 172 | """Mean of the entropies of the marginal predictive distributions.""" 173 | 174 | min_or_max = "min" 175 | 176 | def __call__(self, task): 177 | """... 178 | 179 | :no-index: 180 | 181 | Args: 182 | task (:class:`~.data.task.Task`): 183 | Task object containing context and target sets. 184 | 185 | Returns: 186 | [Type of the return value]: 187 | [Description of the return value.] 188 | """ 189 | marginal_entropy = self.model.mean_marginal_entropy(task) 190 | return marginal_entropy 191 | 192 | 193 | class JointEntropy(AcquisitionFunction): 194 | """Joint entropy of the predictive distribution.""" 195 | 196 | min_or_max = "min" 197 | 198 | def __call__(self, task: Task): 199 | """... 200 | 201 | :no-index: 202 | 203 | Args: 204 | task (:class:`~.data.task.Task`): 205 | Task object containing context and target sets. 206 | 207 | Returns: 208 | [Type of the return value]: 209 | [Description of the return value.] 210 | """ 211 | return self.model.joint_entropy(task) 212 | 213 | 214 | class OracleMAE(AcquisitionFunctionOracle): 215 | """Oracle mean absolute error.""" 216 | 217 | min_or_max = "min" 218 | 219 | def __call__(self, task: Task): 220 | """... 221 | 222 | :no-index: 223 | 224 | Args: 225 | task (:class:`~.data.task.Task`): 226 | Task object containing context and target sets. 
227 | 228 | Returns: 229 | [Type of the return value]: 230 | [Description of the return value.] 231 | """ 232 | pred = self.model.mean(task) 233 | if isinstance(pred, list): 234 | pred = pred[self.target_set_idx] 235 | true = task["Y_t"][self.target_set_idx] 236 | return np.mean(np.abs(pred - true)) 237 | 238 | 239 | class OracleRMSE(AcquisitionFunctionOracle): 240 | """Oracle root mean squared error.""" 241 | 242 | min_or_max = "min" 243 | 244 | def __call__(self, task: Task): 245 | """... 246 | 247 | :no-index: 248 | 249 | Args: 250 | task (:class:`~.data.task.Task`): 251 | Task object containing context and target sets. 252 | 253 | Returns: 254 | [Type of the return value]: 255 | [Description of the return value.] 256 | """ 257 | pred = self.model.mean(task) 258 | if isinstance(pred, list): 259 | pred = pred[self.target_set_idx] 260 | true = task["Y_t"][self.target_set_idx] 261 | return np.sqrt(np.mean((pred - true) ** 2)) 262 | 263 | 264 | class OracleMarginalNLL(AcquisitionFunctionOracle): 265 | """Oracle marginal negative log-likelihood.""" 266 | 267 | min_or_max = "min" 268 | 269 | def __call__(self, task: Task): 270 | """... 271 | 272 | :no-index: 273 | 274 | Args: 275 | task (:class:`~.data.task.Task`): 276 | Task object containing context and target sets. 277 | 278 | Returns: 279 | [Type of the return value]: 280 | [Description of the return value.] 281 | """ 282 | pred = self.model.mean(task) 283 | if isinstance(pred, list): 284 | pred = pred[self.target_set_idx] 285 | true = task["Y_t"][self.target_set_idx] 286 | return -np.mean(norm.logpdf(true, loc=pred, scale=self.model.stddev(task))) 287 | 288 | 289 | class OracleJointNLL(AcquisitionFunctionOracle): 290 | """Oracle joint negative log-likelihood.""" 291 | 292 | min_or_max = "min" 293 | 294 | def __call__(self, task: Task): 295 | """... 296 | 297 | :no-index: 298 | 299 | Args: 300 | task (:class:`~.data.task.Task`): 301 | Task object containing context and target sets. 302 | 303 | Returns: 304 | [Type of the return value]: 305 | [Description of the return value.] 306 | """ 307 | return -self.model.logpdf(task) 308 | 309 | 310 | class Random(AcquisitionFunctionParallel): 311 | """Random acquisition function.""" 312 | 313 | min_or_max = "max" 314 | 315 | def __init__(self, *args, seed: int = 42, **kwargs): 316 | """... 317 | 318 | :no-index: 319 | 320 | Args: 321 | seed (int, optional): 322 | Random seed, defaults to 42. 323 | """ 324 | super().__init__(*args, **kwargs) 325 | self.rng = np.random.default_rng(seed) 326 | 327 | def __call__(self, task: Task, X_s: np.ndarray, **kwargs): 328 | """... 329 | 330 | :param **kwargs: 331 | :no-index: 332 | 333 | Args: 334 | task (:class:`~.data.task.Task`): 335 | [Description of the task parameter.] 336 | X_s (:class:`numpy:numpy.ndarray`): 337 | [Description of the X_s parameter.] 338 | 339 | Returns: 340 | float: 341 | A random acquisition function value. 342 | """ 343 | return self.rng.random(X_s.shape[1]) 344 | 345 | 346 | class ContextDist(AcquisitionFunctionParallel): 347 | """Distance to closest context point.""" 348 | 349 | min_or_max = "max" 350 | 351 | def __call__(self, task: Task, X_s: np.ndarray, **kwargs): 352 | """... 353 | 354 | :param **kwargs: 355 | :no-index: 356 | 357 | Args: 358 | task (:class:`~.data.task.Task`): 359 | [Description of the task parameter.] 360 | X_s (:class:`numpy:numpy.ndarray`): 361 | [Description of the X_s parameter.] 362 | 363 | Returns: 364 | [Type of the return value]: 365 | [Description of the return value.] 
366 | """ 367 | X_c = task["X_c"][self.context_set_idx] 368 | 369 | if X_c.size == 0: 370 | # No sensors placed yet, so arbitrarily choose first query point by setting its 371 | # acquisition fn to non-zero and all others to zero 372 | dist_to_closest_sensor = np.zeros(X_s.shape[-1]) 373 | dist_to_closest_sensor[0] = 1 374 | else: 375 | # Use broadcasting to get matrix of distances from each possible 376 | # new sensor location to each existing sensor location 377 | dists_all = np.linalg.norm( 378 | X_s[..., np.newaxis] - X_c[..., np.newaxis, :], 379 | axis=0, 380 | ) # Shape (n_possible_locs, n_context + n_placed_sensors) 381 | 382 | # Compute distance to nearest sensor 383 | dist_to_closest_sensor = dists_all.min(axis=1) 384 | return dist_to_closest_sensor 385 | 386 | 387 | class Stddev(AcquisitionFunctionParallel): 388 | """Model standard deviation.""" 389 | 390 | min_or_max = "max" 391 | 392 | def __call__(self, task: Task, X_s: np.ndarray, **kwargs): 393 | """... 394 | 395 | :param **kwargs: 396 | :no-index: 397 | 398 | Args: 399 | task (:class:`~.data.task.Task`): 400 | [Description of the task parameter.] 401 | X_s (:class:`numpy:numpy.ndarray`): 402 | [Description of the X_s parameter.] 403 | 404 | Returns: 405 | [Type of the return value]: 406 | [Description of the return value.] 407 | """ 408 | # Set the target points to the search points 409 | task = copy.deepcopy(task) 410 | task["X_t"] = X_s 411 | 412 | return self.model.stddev(task)[self.target_set_idx] 413 | 414 | 415 | class ExpectedImprovement(AcquisitionFunctionParallel): 416 | """Expected improvement acquisition function. 417 | 418 | .. note:: 419 | 420 | The current implementation of this acquisition function is only valid 421 | for maximisation. 422 | """ 423 | 424 | min_or_max = "max" 425 | 426 | def __call__(self, task: Task, X_s: np.ndarray, **kwargs) -> np.ndarray: 427 | """:param **kwargs: 428 | :no-index: 429 | 430 | Args: 431 | task (:class:`~.data.task.Task`): 432 | Task object containing context and target sets. 433 | X_s (:class:`numpy:numpy.ndarray`): 434 | Search points. Shape (2, N_search). 435 | 436 | Returns: 437 | :class:`numpy:numpy.ndarray`: 438 | Acquisition function value/s. Shape (N_search,). 439 | """ 440 | # Set the target points to the search points 441 | task = copy.deepcopy(task) 442 | task["X_t"] = X_s 443 | 444 | # Compute the predictive mean and variance of the target set 445 | mean = self.model.mean(task)[self.target_set_idx] 446 | 447 | if task["Y_c"][self.context_set_idx].size == 0: 448 | # No previous context points, so heuristically use the predictive mean as the 449 | # acquisition function. This will at least select the most positive predicted mean. 
450 | return self.model.mean(task)[self.target_set_idx] 451 | else: 452 | # Determine the best target value seen so far 453 | best_target_value = task["Y_c"][self.context_set_idx].max() 454 | 455 | # Compute the standard deviation of the context set 456 | stddev = self.model.stddev(task)[self.context_set_idx] 457 | 458 | # Compute the expected improvement 459 | Z = (mean - best_target_value) / stddev 460 | ei = stddev * (mean - best_target_value) * norm.cdf(Z) + stddev * norm.pdf(Z) 461 | 462 | return ei 463 | -------------------------------------------------------------------------------- /deepsensor/config.py: -------------------------------------------------------------------------------- 1 | """Configuration file for deepsensor.""" 2 | 3 | DEFAULT_LAB_EPSILON = 1e-6 4 | """ 5 | Magnitude of diagonal to regularise matrices with in ``backends`` library used 6 | by ``neuralprocesses`` 7 | """ 8 | -------------------------------------------------------------------------------- /deepsensor/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import DataProcessor 2 | from .loader import TaskLoader 3 | from .task import Task 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /deepsensor/data/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import scipy 6 | import xarray as xr 7 | 8 | 9 | def construct_x1x2_ds(gridded_ds): 10 | """Construct an :class:`xarray.Dataset` containing two vars, where each var is 11 | a 2D gridded channel whose values contain the x_1 and x_2 coordinate 12 | values, respectively. 13 | 14 | Args: 15 | gridded_ds (:class:`xarray.Dataset`): 16 | ... 17 | 18 | Returns: 19 | :class:`xarray.Dataset` 20 | ... 21 | """ 22 | X1, X2 = np.meshgrid(gridded_ds.x1, gridded_ds.x2, indexing="ij") 23 | ds = xr.Dataset( 24 | coords={"x1": gridded_ds.x1, "x2": gridded_ds.x2}, 25 | data_vars={"x1_arr": (("x1", "x2"), X1), "x2_arr": (("x1", "x2"), X2)}, 26 | ) 27 | return ds 28 | 29 | 30 | def construct_circ_time_ds(dates, freq): 31 | """Return an :class:`xarray.Dataset` containing a circular variable for time. 32 | 33 | The ``freq`` entry dictates the frequency of cycling of the circular 34 | variable. E.g.: 35 | 36 | - ``'H'``: cycles once per day at hourly intervals 37 | - ``'D'``: cycles once per year at daily intervals 38 | - ``'M'``: cycles once per year at monthly intervals 39 | 40 | Args: 41 | dates (...): 42 | ... 43 | freq (...): 44 | ... 45 | 46 | Returns: 47 | :class:`xarray.Dataset` 48 | ... 49 | """ 50 | # Ensure dates are pandas 51 | dates = pd.DatetimeIndex(dates) 52 | if freq == "D": 53 | time_var = dates.dayofyear 54 | mod = 365.25 55 | elif freq == "H": 56 | time_var = dates.hour 57 | mod = 24 58 | elif freq == "M": 59 | time_var = dates.month 60 | mod = 12 61 | else: 62 | raise ValueError("Circular time variable not implemented for this frequency.") 63 | 64 | cos_time = np.cos(2 * np.pi * time_var / mod) 65 | sin_time = np.sin(2 * np.pi * time_var / mod) 66 | 67 | ds = xr.Dataset( 68 | coords={"time": dates}, 69 | data_vars={ 70 | f"cos_{freq}": ("time", cos_time), 71 | f"sin_{freq}": ("time", sin_time), 72 | }, 73 | ) 74 | return ds 75 | 76 | 77 | def compute_xarray_data_resolution(ds: Union[xr.DataArray, xr.Dataset]) -> float: 78 | """Computes the resolution of an xarray object with coordinates x1 and x2. 
79 | 80 | The data resolution is the finer of the two coordinate resolutions (x1 and 81 | x2). For example, if x1 has a resolution of 0.1 degrees and x2 has a 82 | resolution of 0.2 degrees, the data resolution returned will be 0.1 83 | degrees. 84 | 85 | Args: 86 | ds (:class:`xarray.DataArray` | :class:`xarray.Dataset`): 87 | Xarray object with coordinates x1 and x2. 88 | 89 | Returns: 90 | float: Resolution of the data (in spatial units, e.g. 0.1 degrees). 91 | """ 92 | x1_res = np.abs(np.mean(np.diff(ds["x1"]))) 93 | x2_res = np.abs(np.mean(np.diff(ds["x2"]))) 94 | 95 | # ensure float type, since numpy 2, np.mean returns a numpy float32 96 | data_resolution = float(np.min([x1_res, x2_res])) 97 | return data_resolution 98 | 99 | 100 | def compute_pandas_data_resolution( 101 | df: Union[pd.DataFrame, pd.Series], 102 | n_times: int = 1000, 103 | percentile: int = 5, 104 | ) -> float: 105 | """Approximates the resolution of non-gridded pandas data with indexes time, 106 | x1, and x2. 107 | 108 | The resolution is approximated as the Nth percentile of the distances 109 | between neighbouring observations, possibly using a subset of the dates in 110 | the data. The default is to use 1000 dates (or all dates if there are fewer 111 | than 1000) and to use the 5th percentile. This means that the resolution is 112 | the distance between the closest 5% of neighbouring observations. 113 | 114 | Args: 115 | df (:class:`pandas.DataFrame` | :class:`pandas.Series`): 116 | Dataframe or series with indexes time, x1, and x2. 117 | n_times (int, optional): 118 | Number of dates to sample. Defaults to 1000. If "all", all dates 119 | are used. 120 | percentile (int, optional): 121 | Percentile of pairwise distances for computing the resolution. 122 | Defaults to 5. 123 | 124 | Returns: 125 | float: Resolution of the data (in spatial units, e.g. 0.1 degrees). 126 | """ 127 | dates = df.index.get_level_values("time").unique() 128 | 129 | if n_times != "all" and len(dates) > n_times: 130 | rng = np.random.default_rng(42) 131 | dates = rng.choice(dates, size=n_times, replace=False) 132 | 133 | closest_distances = [] 134 | df = df.reset_index().set_index("time") 135 | for time in dates: 136 | df_t = df.loc[[time]] 137 | X = df_t[["x1", "x2"]].values # (N, 2) array of coordinates 138 | if X.shape[0] < 2: 139 | # Skip this time if there are fewer than 2 stationS 140 | continue 141 | X_unique = np.unique(X, axis=0) # (N_unique, 2) array of unique coordinates 142 | 143 | pairwise_distances = scipy.spatial.distance.cdist(X_unique, X_unique) 144 | percentile_distances_without_self = np.ma.masked_equal(pairwise_distances, 0) 145 | 146 | # Compute the closest distance from each station to each other station 147 | closest_distances_t = np.min(percentile_distances_without_self, axis=1) 148 | closest_distances.extend(closest_distances_t) 149 | 150 | data_resolution = np.percentile(closest_distances, percentile) 151 | return data_resolution 152 | -------------------------------------------------------------------------------- /deepsensor/errors.py: -------------------------------------------------------------------------------- 1 | class TaskSetIndexError(Exception): 2 | """Raised when the task context/target set index is out of range.""" 3 | 4 | def __init__(self, index, set_length, context_or_target): 5 | super().__init__( 6 | f"{context_or_target} set index {index} is out of range for task with " 7 | f"{set_length} {context_or_target} sets." 
8 | ) 9 | 10 | 11 | class GriddedDataError(Exception): 12 | """Raised during invalid operation with gridded data.""" 13 | 14 | pass 15 | 16 | 17 | class InvalidSamplingStrategyError(Exception): 18 | """Raised when TaskLoader sampling strategy is invalid.""" 19 | 20 | pass 21 | 22 | 23 | class SamplingTooManyPointsError(ValueError): 24 | """Raised when the number of points to sample is greater than the number of points in the dataset.""" 25 | 26 | def __init__(self, requested: int, available: int): 27 | super().__init__( 28 | f"Requested {requested} points to sample, but only {available} are available." 29 | ) 30 | -------------------------------------------------------------------------------- /deepsensor/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import * 2 | -------------------------------------------------------------------------------- /deepsensor/eval/metrics.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | from deepsensor.model.pred import Prediction 3 | 4 | 5 | def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset: 6 | """Compute errors between predictions and targets. 7 | 8 | Args: 9 | pred: Prediction object. 10 | target: Target data. 11 | 12 | Returns: 13 | xr.Dataset: Dataset of pointwise differences between predictions and targets 14 | at the same valid time in the predictions. Note, the difference is positive 15 | when the prediction is greater than the target. 16 | """ 17 | errors = {} 18 | for var_ID, pred_var in pred.items(): 19 | target_var = target[var_ID] 20 | error = pred_var["mean"] - target_var.sel(time=pred_var.time) 21 | error.name = f"{var_ID}" 22 | errors[var_ID] = error 23 | return xr.Dataset(errors) 24 | -------------------------------------------------------------------------------- /deepsensor/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .convnp import ConvNP 2 | from .model import ProbabilisticModel, DeepSensorModel 3 | -------------------------------------------------------------------------------- /deepsensor/model/defaults.py: -------------------------------------------------------------------------------- 1 | from deepsensor.data.loader import TaskLoader 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import xarray as xr 6 | 7 | from deepsensor.data.utils import ( 8 | compute_xarray_data_resolution, 9 | compute_pandas_data_resolution, 10 | ) 11 | 12 | from typing import List 13 | 14 | 15 | def compute_greatest_data_density(task_loader: TaskLoader) -> int: 16 | """Computes data-informed settings for the model's internal grid density (ppu, 17 | points per unit). 18 | 19 | Loops over all context and target variables in the ``TaskLoader`` and 20 | computes the data resolution for each. The model ppu is then set to the 21 | maximum data ppu. 22 | 23 | Args: 24 | task_loader (:class:`~.data.loader.TaskLoader`): 25 | TaskLoader object containing context and target sets. 26 | 27 | Returns: 28 | max_density (int): 29 | The maximum data density (ppu) across all context and target 30 | variables, where 'density' is the number of points per unit of 31 | input space (in both spatial dimensions). 
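
    Example:
        An illustrative sketch of the per-variable density calculation (the toy
        DataArray below is made up for this example and is not part of DeepSensor)::

            >>> import numpy as np
            >>> import xarray as xr
            >>> from deepsensor.data.utils import compute_xarray_data_resolution
            >>> da = xr.DataArray(
            ...     np.zeros((41, 41)),
            ...     coords={"x1": np.arange(0, 10.25, 0.25), "x2": np.arange(0, 10.25, 0.25)},
            ...     dims=("x1", "x2"),
            ... )
            >>> int(1 / compute_xarray_data_resolution(da))  # 0.25-unit grid spacing -> 4 ppu
            4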
32 | """ 33 | # List of data resolutions for each context/target variable (in points-per-unit) 34 | data_densities = [] 35 | for var in [*task_loader.context, *task_loader.target]: 36 | if isinstance(var, (xr.DataArray, xr.Dataset)): 37 | # Gridded variable: use data resolution 38 | data_resolution = compute_xarray_data_resolution(var) 39 | elif isinstance(var, (pd.DataFrame, pd.Series)): 40 | # Point-based variable: calculate density based on pairwise distances between observations 41 | data_resolution = compute_pandas_data_resolution( 42 | var, n_times=1000, percentile=5 43 | ) 44 | else: 45 | raise ValueError(f"Unknown context input type: {type(var)}") 46 | data_density = int(1 / data_resolution) 47 | data_densities.append(data_density) 48 | max_density = int(max(data_densities)) 49 | return max_density 50 | 51 | 52 | def gen_decoder_scale(model_ppu: int) -> float: 53 | """Computes informed setting for the decoder SetConv scale. 54 | 55 | This sets the length scale of the Gaussian basis functions used interpolate 56 | from the model's internal grid to the target locations. 57 | 58 | The decoder scale should be as small as possible given the model's 59 | internal grid. The value chosen is 1 / model_ppu (i.e. the length scale is 60 | equal to the model's internal grid spacing). 61 | 62 | Args: 63 | model_ppu (int): 64 | Model ppu (points per unit), i.e. the number of points per unit of 65 | input space. 66 | 67 | Returns: 68 | float: Decoder scale. 69 | """ 70 | return 1 / model_ppu 71 | 72 | 73 | def gen_encoder_scales(model_ppu: int, task_loader: TaskLoader) -> List[float]: 74 | """Computes data-informed settings for the encoder SetConv scale for each 75 | context set. 76 | 77 | This sets the length scale of the Gaussian basis functions used to encode 78 | the context sets. 79 | 80 | For off-grid station data, the scale should be as small as possible given 81 | the model's internal grid density (ppu, points per unit). The value chosen 82 | is 0.5 / model_ppu (i.e. half the model's internal resolution). 83 | 84 | For gridded data, the scale should be such that the functional 85 | representation smoothly interpolates the data. This is determined by 86 | computing the *data resolution* (the distance between the nearest two data 87 | points) for each context variable. The encoder scale is then set to 0.5 * 88 | data_resolution. 89 | 90 | Args: 91 | model_ppu (int): 92 | Model ppu (points per unit), i.e. the number of points per unit of 93 | input space. 94 | task_loader (:class:`~.data.loader.TaskLoader`): 95 | TaskLoader object containing context and target sets. 96 | 97 | Returns: 98 | list[float]: List of encoder scales for each context set. 
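
    Example:
        Illustrative arithmetic only (the values below are not defaults)::

            >>> model_ppu = 100
            >>> 0.5 / model_ppu  # off-grid (pandas) context set
            0.005
            >>> 0.5 * 0.25  # gridded context set with a 0.25-unit data resolution
            0.125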
99 | """ 100 | encoder_scales = [] 101 | for var in task_loader.context: 102 | if isinstance(var, (xr.DataArray, xr.Dataset)): 103 | encoder_scale = 0.5 * compute_xarray_data_resolution(var) 104 | elif isinstance(var, (pd.DataFrame, pd.Series)): 105 | encoder_scale = 0.5 / model_ppu 106 | else: 107 | raise ValueError(f"Unknown context input type: {type(var)}") 108 | encoder_scales.append(encoder_scale) 109 | 110 | if task_loader.aux_at_contexts: 111 | # Add encoder scale for the final auxiliary-at-contexts context set: use smallest possible 112 | # scale within model discretisation 113 | encoder_scales.append(0.5 / model_ppu) 114 | 115 | return encoder_scales 116 | -------------------------------------------------------------------------------- /deepsensor/model/nps.py: -------------------------------------------------------------------------------- 1 | from .. import backend 2 | import lab as B 3 | 4 | from deepsensor.data.task import Task 5 | from typing import Tuple, Optional, Literal 6 | 7 | 8 | def convert_task_to_nps_args(task: Task): 9 | """Infer & build model call signature from ``task`` dict. 10 | 11 | .. 12 | TODO move to ConvNP class? 13 | 14 | Args: 15 | task (:class:`~.data.task.Task`): 16 | Task object containing context and target sets. 17 | 18 | Returns: 19 | tuple[list[tuple[numpy.ndarray, numpy.ndarray]], numpy.ndarray, numpy.ndarray, dict]: 20 | ... 21 | """ 22 | context_data = list(zip(task["X_c"], task["Y_c"])) 23 | 24 | if task["X_t"] is None: 25 | raise ValueError( 26 | f"Running `neuralprocesses` model with no target locations (got {task['X_t']}). " 27 | f"Have you not provided a `target_sampling` argument to `TaskLoader`?" 28 | ) 29 | elif len(task["X_t"]) == 1 and task["Y_t"] is None: 30 | xt = task["X_t"][0] 31 | yt = None 32 | elif len(task["X_t"]) > 1 and task["Y_t"] is None: 33 | # Multiple target sets, different target locations 34 | xt = backend.nps.AggregateInput(*[(xt, i) for i, xt in enumerate(task["X_t"])]) 35 | yt = None 36 | elif len(task["X_t"]) == 1 and len(task["Y_t"]) == 1: 37 | # Single target set 38 | xt = task["X_t"][0] 39 | yt = task["Y_t"][0] 40 | elif len(task["X_t"]) > 1 and len(task["Y_t"]) > 1: 41 | # Multiple target sets, different target locations 42 | assert len(task["X_t"]) == len(task["Y_t"]) 43 | xts = [] 44 | yts = [] 45 | target_dims = [yt.shape[1] for yt in task["Y_t"]] 46 | # Map from ND target sets to 1D target sets 47 | dim_counter = 0 48 | for i, (xt, yt) in enumerate(zip(task["X_t"], task["Y_t"])): 49 | # Repeat target locations for each target dimension in target set 50 | xts.extend([(xt, dim_counter + j) for j in range(target_dims[i])]) 51 | yts.extend([yt[:, j : j + 1] for j in range(target_dims[i])]) 52 | dim_counter += target_dims[i] 53 | xt = backend.nps.AggregateInput(*xts) 54 | yt = backend.nps.Aggregate(*yts) 55 | elif len(task["X_t"]) == 1 and len(task["Y_t"]) > 1: 56 | # Multiple target sets, same target locations; `Y_t`s along feature dim 57 | xt = task["X_t"][0] 58 | yt = B.concat(*task["Y_t"], axis=1) 59 | else: 60 | raise ValueError( 61 | f"Incorrect target locations and target observations (got {len(task['X_t'])} and {len(task['Y_t'])})" 62 | ) 63 | 64 | model_kwargs = {} 65 | if "Y_t_aux" in task.keys(): 66 | model_kwargs["aux_t"] = task["Y_t_aux"] 67 | 68 | return context_data, xt, yt, model_kwargs 69 | 70 | 71 | def run_nps_model( 72 | neural_process, 73 | task: Task, 74 | n_samples: Optional[int] = None, 75 | requires_grad: bool = False, 76 | ): 77 | """Run ``neuralprocesses`` model. 
78 | 79 | Args: 80 | neural_process (neuralprocesses.Model): 81 | Neural process model. 82 | task (:class:`~.data.task.Task`): 83 | Task object containing context and target sets. 84 | n_samples (int, optional): 85 | Number of samples to draw from the model. Defaults to ``None`` 86 | (single sample). 87 | requires_grad (bool, optional): 88 | Whether to require gradients. Defaults to ``False``. 89 | 90 | Returns: 91 | neuralprocesses.distributions.Distribution: 92 | Distribution object containing the model's predictions. 93 | """ 94 | context_data, xt, _, model_kwargs = convert_task_to_nps_args(task) 95 | if backend.str == "torch" and not requires_grad: 96 | # turn off grad 97 | import torch 98 | 99 | with torch.no_grad(): 100 | dist = neural_process( 101 | context_data, xt, **model_kwargs, num_samples=n_samples 102 | ) 103 | else: 104 | dist = neural_process(context_data, xt, **model_kwargs, num_samples=n_samples) 105 | return dist 106 | 107 | 108 | def run_nps_model_ar(neural_process, task: Task, num_samples: int = 1): 109 | """Run ``neural_process`` in AR mode. 110 | 111 | Args: 112 | neural_process (neuralprocesses.Model): 113 | Neural process model. 114 | task (:class:`~.data.task.Task`): 115 | Task object containing context and target sets. 116 | num_samples (int, optional): 117 | Number of samples to draw from the model. Defaults to 1. 118 | 119 | Returns: 120 | tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]: 121 | Tuple of mean, variance, noiseless samples, and noisy samples. 122 | """ 123 | context_data, xt, _, _ = convert_task_to_nps_args(task) 124 | 125 | # NOTE can't use `model_kwargs` in AR mode (ie can't use auxiliary MLP at targets) 126 | mean, variance, noiseless_samples, noisy_samples = backend.nps.ar_predict( 127 | neural_process, 128 | context_data, 129 | xt, 130 | num_samples=num_samples, 131 | ) 132 | 133 | return mean, variance, noiseless_samples, noisy_samples 134 | 135 | 136 | def construct_neural_process( 137 | dim_x: int = 2, 138 | dim_yc: int = 1, 139 | dim_yt: int = 1, 140 | dim_aux_t: Optional[int] = None, 141 | dim_lv: int = 0, 142 | conv_arch: str = "unet", 143 | unet_channels: Tuple[int, ...] = (64, 64, 64, 64), 144 | unet_resize_convs: bool = True, 145 | unet_resize_conv_interp_method: Literal["bilinear"] = "bilinear", 146 | aux_t_mlp_layers: Optional[Tuple[int, ...]] = None, 147 | likelihood: Literal["cnp", "gnp", "cnp-spikes-beta"] = "cnp", 148 | unet_kernels: int = 5, 149 | internal_density: int = 100, 150 | encoder_scales: float = 1 / 100, 151 | encoder_scales_learnable: bool = False, 152 | decoder_scale: float = 1 / 100, 153 | decoder_scale_learnable: bool = False, 154 | num_basis_functions: int = 64, 155 | epsilon: float = 1e-2, 156 | ): 157 | """Construct a ``neuralprocesses`` ConvNP model. 158 | 159 | See: https://github.com/wesselb/neuralprocesses/blob/main/neuralprocesses/architectures/convgnp.py 160 | 161 | Docstring below modified from ``neuralprocesses``. If more kwargs are 162 | needed, they must be explicitly passed to ``neuralprocesses`` constructor 163 | (not currently safe to use `**kwargs` here). 164 | 165 | Args: 166 | dim_x (int, optional): 167 | Dimensionality of the inputs. Defaults to 1. 168 | dim_y (int, optional): 169 | Dimensionality of the outputs. Defaults to 1. 170 | dim_yc (int or tuple[int], optional): 171 | Dimensionality of the outputs of the context set. 
You should set 
172 |             this if the dimensionality of the outputs of the context set is not 
173 |             equal to the dimensionality of the outputs of the target set. You 
174 |             should also set this if you want to use multiple context sets. In 
175 |             that case, set this equal to a tuple of integers indicating the 
176 |             respective output dimensionalities. 
177 |         dim_yt (int, optional): 
178 |             Dimensionality of the outputs of the target set. You should set 
179 |             this if the dimensionality of the outputs of the target set is not 
180 |             equal to the dimensionality of the outputs of the context set. 
181 |         dim_aux_t (int, optional): 
182 |             Dimensionality of target-specific auxiliary variables. 
183 |         internal_density (int, optional): 
184 |             Density of the ConvNP's internal grid (in terms of number of points 
185 |             per 1x1 unit square). Defaults to 100. 
186 |         likelihood (str, optional): 
187 |             Likelihood. Must be one of ``"cnp"`` (equivalently ``"het"``), 
188 |             ``"gnp"`` (equivalently ``"lowrank"``), ``"cnp-spikes-beta"`` 
189 |             (equivalently ``"spikes-beta"``), or ``"cnp-bernoulli-gamma"`` (equivalently ``"bernoulli-gamma"``). Defaults to ``"cnp"``. 
190 |         conv_arch (str, optional): 
191 |             Convolutional architecture to use. Must be one of 
192 |             ``"unet[-res][-sep]"`` or ``"conv[-res][-sep]"``. Defaults to 
193 |             ``"unet"``. 
194 |         unet_channels (tuple[int], optional): 
195 |             Channels of every layer of the UNet. Defaults to four layers each 
196 |             with 64 channels. 
197 |         unet_kernels (int or tuple[int], optional): 
198 |             Sizes of the kernels in the UNet. Defaults to 5. 
199 |         unet_resize_convs (bool, optional): 
200 |             Use resize convolutions rather than transposed convolutions in the 
201 |             UNet. Defaults to ``True``. 
202 |         unet_resize_conv_interp_method (str, optional): 
203 |             Interpolation method for the resize convolutions in the UNet. Can 
204 |             be set to ``"bilinear"``. Defaults to ``"bilinear"``. 
205 |         num_basis_functions (int, optional): 
206 |             Number of basis functions for the low-rank likelihood. Defaults to 
207 |             64. 
208 |         dim_lv (int, optional): 
209 |             Dimensionality of the latent variable. Setting to >0 constructs a 
210 |             latent neural process. Defaults to 0. 
211 |         encoder_scales (float or tuple[float], optional): 
212 |             Initial value for the length scales of the set convolutions for the 
213 |             context set embeddings. Set to a tuple equal to the number of 
214 |             context sets to use different values for each set. Set to a single 
215 |             value to use the same value for all context sets. Defaults to 
216 |             ``1 / internal_density``. 
217 |         encoder_scales_learnable (bool, optional): 
218 |             Whether the encoder SetConv length scale(s) are learnable. 
219 |             Defaults to ``False``. 
220 |         decoder_scale (float, optional): 
221 |             Initial value for the length scale of the set convolution in the 
222 |             decoder. Defaults to ``1 / internal_density``. 
223 |         decoder_scale_learnable (bool, optional): 
224 |             Whether the decoder SetConv length scale(s) are learnable. Defaults 
225 |             to ``False``. 
226 |         aux_t_mlp_layers (tuple[int], optional): 
227 |             Widths of the layers of the MLP for the target-specific auxiliary 
228 |             variable. Defaults to three layers of width 128. 
229 |         epsilon (float, optional): 
230 |             Epsilon added by the set convolutions before dividing by the 
231 |             density channel. Defaults to ``1e-2``. 
232 | 
233 |     Returns: 
234 |         tuple[neuralprocesses.Model, dict]: 
235 |             The constructed ``neuralprocesses`` ConvNP model and a dict logging the construction kwargs. 
236 | 
237 |     Raises: 
238 |         NotImplementedError 
239 |             If specified backend has no default dtype. 
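
    Example:
        A minimal sketch of constructing a model (assumes a backend has been selected
        by importing ``deepsensor.torch`` or ``deepsensor.tensorflow`` first; the
        argument values below are illustrative only)::

            import deepsensor.torch  # selects the PyTorch backend
            from deepsensor.model.nps import construct_neural_process

            # Two context sets with one output dimension each, one target output dimension
            neural_process, config = construct_neural_process(
                dim_x=2,
                dim_yc=(1, 1),
                dim_yt=1,
                likelihood="cnp",
            )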
240 | """ 241 | if likelihood == "cnp": 242 | likelihood = "het" 243 | elif likelihood == "gnp": 244 | likelihood = "lowrank" 245 | elif likelihood == "cnp-spikes-beta": 246 | likelihood = "spikes-beta" 247 | elif likelihood == "cnp-bernoulli-gamma": 248 | likelihood = "bernoulli-gamma" 249 | 250 | # Log the call signature for `construct_convgnp` 251 | config = dict(locals()) 252 | 253 | if backend.str == "torch": 254 | import torch 255 | 256 | dtype = torch.float32 257 | elif backend.str == "tf": 258 | import tensorflow as tf 259 | 260 | dtype = tf.float32 261 | else: 262 | raise NotImplementedError(f"Backend {backend.str} has no default dtype.") 263 | 264 | neural_process = backend.nps.construct_convgnp( 265 | dim_x=dim_x, 266 | dim_yc=dim_yc, 267 | dim_yt=dim_yt, 268 | dim_aux_t=dim_aux_t, 269 | dim_lv=dim_lv, 270 | likelihood=likelihood, 271 | conv_arch=conv_arch, 272 | unet_channels=tuple(unet_channels), 273 | unet_resize_convs=unet_resize_convs, 274 | unet_resize_conv_interp_method=unet_resize_conv_interp_method, 275 | aux_t_mlp_layers=aux_t_mlp_layers, 276 | unet_kernels=unet_kernels, 277 | # Use a stride of 1 for the first layer and 2 for all other layers 278 | unet_strides=(1, *(2,) * (len(unet_channels) - 1)), 279 | points_per_unit=internal_density, 280 | encoder_scales=encoder_scales, 281 | encoder_scales_learnable=encoder_scales_learnable, 282 | decoder_scale=decoder_scale, 283 | decoder_scale_learnable=decoder_scale_learnable, 284 | num_basis_functions=num_basis_functions, 285 | epsilon=epsilon, 286 | dtype=dtype, 287 | ) 288 | 289 | return neural_process, config 290 | 291 | 292 | def compute_encoding_tensor(model, task: Task): 293 | """Compute the encoding tensor for a given task. 294 | 295 | Args: 296 | model (...): 297 | Model object. 298 | task (:class:`~.data.task.Task`): 299 | Task object containing context and target sets. 300 | 301 | Returns: 302 | encoding : :class:`numpy:numpy.ndarray` 303 | Encoding tensor? #TODO 304 | """ 305 | neural_process_encoder = backend.nps.Model(model.model.encoder, lambda x: x) 306 | task = model.modify_task(task) 307 | encoding = B.to_numpy(run_nps_model(neural_process_encoder, task)) 308 | return encoding 309 | -------------------------------------------------------------------------------- /deepsensor/model/pred.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Optional 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import xarray as xr 6 | 7 | Timestamp = Union[str, pd.Timestamp, np.datetime64] 8 | 9 | 10 | class Prediction(dict): 11 | """Object to store model predictions in a dictionary-like format. 12 | 13 | Maps from target variable IDs to xarray/pandas objects containing 14 | prediction parameters (depending on the output distribution of the model). 15 | 16 | For example, if the model outputs a Gaussian distribution, then the xarray/pandas 17 | objects in the ``Prediction`` will contain a ``mean`` and ``std``. 18 | 19 | If using a ``Prediction`` to store model samples, there is only a ``samples`` entry, and the 20 | xarray/pandas objects will have an additional ``sample`` dimension. 21 | 22 | Args: 23 | target_var_IDs (List[str]) 24 | List of target variable IDs. 25 | dates (List[Union[str, pd.Timestamp]]) 26 | List of dates corresponding to the predictions. 
27 | X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`) 28 | Target locations to predict at. Can be an xarray object containing 29 | on-grid locations or a pandas object containing off-grid locations. 30 | X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional) 31 | 2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated 32 | to the same grid as ``X_t``. Default None (no mask). 33 | n_samples (int) 34 | Number of joint samples to draw from the model. If 0, will not 35 | draw samples. Default 0. 36 | forecasting_mode (bool) 37 | If True, stored forecast predictions with an init_time and lead_time dimension, 38 | and a valid_time coordinate. If False, stores prediction at t=0 only 39 | (i.e. spatial interpolation), with only a single time dimension. Default False. 40 | lead_times (List[pd.Timedelta], optional) 41 | List of lead times to store in predictions. Must be provided if 42 | forecasting_mode is True. Default None. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | target_var_IDs: List[str], 48 | pred_params: List[str], 49 | dates: List[Timestamp], 50 | X_t: Union[ 51 | xr.Dataset, 52 | xr.DataArray, 53 | pd.DataFrame, 54 | pd.Series, 55 | pd.Index, 56 | np.ndarray, 57 | ], 58 | X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None, 59 | coord_names: dict = None, 60 | n_samples: int = 0, 61 | forecasting_mode: bool = False, 62 | lead_times: Optional[List[pd.Timedelta]] = None, 63 | ): 64 | self.target_var_IDs = target_var_IDs 65 | self.X_t_mask = X_t_mask 66 | if coord_names is None: 67 | coord_names = {"x1": "x1", "x2": "x2"} 68 | self.x1_name = coord_names["x1"] 69 | self.x2_name = coord_names["x2"] 70 | 71 | self.forecasting_mode = forecasting_mode 72 | if forecasting_mode: 73 | assert ( 74 | lead_times is not None 75 | ), "If forecasting_mode is True, lead_times must be provided." 
76 | self.lead_times = lead_times 77 | 78 | self.mode = infer_prediction_modality_from_X_t(X_t) 79 | 80 | self.pred_params = pred_params 81 | if n_samples >= 1: 82 | self.pred_params = [ 83 | *pred_params, 84 | *[f"sample_{i}" for i in range(n_samples)], 85 | ] 86 | 87 | # Create empty xarray/pandas objects to store predictions 88 | if self.mode == "on-grid": 89 | for var_ID in self.target_var_IDs: 90 | if self.forecasting_mode: 91 | prepend_dims = ["lead_time"] 92 | prepend_coords = {"lead_time": lead_times} 93 | else: 94 | prepend_dims = None 95 | prepend_coords = None 96 | self[var_ID] = create_empty_spatiotemporal_xarray( 97 | X_t, 98 | dates, 99 | data_vars=self.pred_params, 100 | coord_names=coord_names, 101 | prepend_dims=prepend_dims, 102 | prepend_coords=prepend_coords, 103 | ) 104 | if self.forecasting_mode: 105 | self[var_ID] = self[var_ID].rename(time="init_time") 106 | if self.X_t_mask is None: 107 | # Create 2D boolean array of True values to simplify indexing 108 | self.X_t_mask = ( 109 | create_empty_spatiotemporal_xarray(X_t, dates[0:1], coord_names) 110 | .to_array() 111 | .isel(time=0, variable=0) 112 | .astype(bool) 113 | ) 114 | elif self.mode == "off-grid": 115 | # Repeat target locs for each date to create multiindex 116 | if self.forecasting_mode: 117 | index_names = ["lead_time", "init_time", *X_t.index.names] 118 | idxs = [ 119 | (lt, date, *idxs) 120 | for lt in lead_times 121 | for date in dates 122 | for idxs in X_t.index 123 | ] 124 | else: 125 | index_names = ["time", *X_t.index.names] 126 | idxs = [(date, *idxs) for date in dates for idxs in X_t.index] 127 | index = pd.MultiIndex.from_tuples(idxs, names=index_names) 128 | for var_ID in self.target_var_IDs: 129 | self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params) 130 | 131 | def __getitem__(self, key): 132 | # Support self[i] syntax 133 | if isinstance(key, int): 134 | key = self.target_var_IDs[key] 135 | return super().__getitem__(key) 136 | 137 | def __str__(self): 138 | dict_repr = {var_ID: self.pred_params for var_ID in self.target_var_IDs} 139 | return f"Prediction({dict_repr}), mode={self.mode}" 140 | 141 | def assign( 142 | self, 143 | prediction_parameter: str, 144 | date: Union[str, pd.Timestamp], 145 | data: np.ndarray, 146 | lead_times: Optional[List[pd.Timedelta]] = None, 147 | ): 148 | """Args: 149 | prediction_parameter (str) 150 | ... 151 | date (Union[str, pd.Timestamp]) 152 | ... 153 | data (np.ndarray) 154 | If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets). 155 | If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2). 156 | lead_time (pd.Timedelta, optional) 157 | Lead time of the forecast. Required if forecasting_mode is True. Default None. 158 | """ 159 | if self.forecasting_mode: 160 | assert ( 161 | lead_times is not None 162 | ), "If forecasting_mode is True, lead_times must be provided." 163 | 164 | msg = f""" 165 | If forecasting_mode is True, lead_times must be of equal length to the number of 166 | variables in the data (the first dimension). Got {lead_times=} of length 167 | {len(lead_times)} lead times and data shape {data.shape}. 
168 | """ 169 | assert len(lead_times) == data.shape[0], msg 170 | 171 | if self.mode == "on-grid": 172 | if prediction_parameter != "samples": 173 | for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)): 174 | if self.forecasting_mode: 175 | index = (lead_times[i], date) 176 | else: 177 | index = date 178 | self[var_ID][prediction_parameter].loc[index].data[ 179 | self.X_t_mask.data 180 | ] = pred.ravel() 181 | elif prediction_parameter == "samples": 182 | assert len(data.shape) == 4, ( 183 | f"If prediction_parameter is 'samples', and mode is 'on-grid', data must" 184 | f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}." 185 | ) 186 | for sample_i, sample in enumerate(data): 187 | for i, (var_ID, pred) in enumerate( 188 | zip(self.target_var_IDs, sample) 189 | ): 190 | if self.forecasting_mode: 191 | index = (lead_times[i], date) 192 | else: 193 | index = date 194 | self[var_ID][f"sample_{sample_i}"].loc[index].data[ 195 | self.X_t_mask.data 196 | ] = pred.ravel() 197 | 198 | elif self.mode == "off-grid": 199 | if prediction_parameter != "samples": 200 | for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)): 201 | if self.forecasting_mode: 202 | index = (lead_times[i], date) 203 | else: 204 | index = date 205 | self[var_ID].loc[index, prediction_parameter] = pred 206 | elif prediction_parameter == "samples": 207 | assert len(data.shape) == 3, ( 208 | f"If prediction_parameter is 'samples', and mode is 'off-grid', data must" 209 | f"have shape (N_samples, N_var, N_targets). Got {data.shape}." 210 | ) 211 | for sample_i, sample in enumerate(data): 212 | for i, (var_ID, pred) in enumerate( 213 | zip(self.target_var_IDs, sample) 214 | ): 215 | if self.forecasting_mode: 216 | index = (lead_times[i], date) 217 | else: 218 | index = date 219 | self[var_ID].loc[index, f"sample_{sample_i}"] = pred 220 | 221 | 222 | def create_empty_spatiotemporal_xarray( 223 | X: Union[xr.Dataset, xr.DataArray], 224 | dates: List[Timestamp], 225 | coord_names: dict = None, 226 | data_vars: List[str] = None, 227 | prepend_dims: Optional[List[str]] = None, 228 | prepend_coords: Optional[dict] = None, 229 | ): 230 | """... 231 | 232 | Args: 233 | X (:class:`xarray.Dataset` | :class:`xarray.DataArray`): 234 | ... 235 | dates (List[...]): 236 | ... 237 | coord_names (dict, optional): 238 | Dict mapping from normalised coord names to raw coord names, 239 | by default {"x1": "x1", "x2": "x2"} 240 | data_vars (List[str], optional): 241 | ..., by default ["var"] 242 | prepend_dims (List[str], optional): 243 | ..., by default None 244 | prepend_coords (dict, optional): 245 | ..., by default None 246 | 247 | Returns: 248 | ... 249 | ... 250 | 251 | Raises: 252 | ValueError 253 | If ``data_vars`` contains duplicate values. 254 | ValueError 255 | If ``coord_names["x1"]`` is not uniformly spaced. 256 | ValueError 257 | If ``coord_names["x2"]`` is not uniformly spaced. 258 | ValueError 259 | If ``prepend_dims`` and ``prepend_coords`` are not the same length. 260 | """ 261 | if coord_names is None: 262 | coord_names = {"x1": "x1", "x2": "x2"} 263 | if data_vars is None: 264 | data_vars = ["var"] 265 | 266 | if prepend_dims is None: 267 | prepend_dims = [] 268 | if prepend_coords is None: 269 | prepend_coords = {} 270 | 271 | # Check for any repeated data_vars 272 | if len(data_vars) != len(set(data_vars)): 273 | raise ValueError( 274 | f"Duplicate data_vars found in data_vars: {data_vars}. " 275 | "This would cause the xarray.Dataset to have fewer variables than expected." 
276 | ) 277 | 278 | x1_predict = X.coords[coord_names["x1"]] 279 | x2_predict = X.coords[coord_names["x2"]] 280 | 281 | if len(prepend_dims) != len(set(prepend_dims)): 282 | # TODO unit test 283 | raise ValueError( 284 | f"Length of prepend_dims ({len(prepend_dims)}) must be equal to length of " 285 | f"prepend_coords ({len(prepend_coords)})." 286 | ) 287 | 288 | dims = [*prepend_dims, "time", coord_names["x1"], coord_names["x2"]] 289 | coords = { 290 | **prepend_coords, 291 | "time": pd.to_datetime(dates), 292 | coord_names["x1"]: x1_predict, 293 | coord_names["x2"]: x2_predict, 294 | } 295 | 296 | pred_ds = xr.Dataset( 297 | {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars} 298 | ).astype("float32") 299 | 300 | # Convert time coord to pandas timestamps 301 | pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values)) 302 | 303 | return pred_ds 304 | 305 | 306 | def increase_spatial_resolution( 307 | X_t_normalised, 308 | resolution_factor, 309 | coord_names: dict = None, 310 | ): 311 | """... 312 | 313 | .. 314 | # TODO wasteful to interpolate X_t_normalised 315 | 316 | Args: 317 | X_t_normalised (...): 318 | ... 319 | resolution_factor (...): 320 | ... 321 | coord_names (dict, optional): 322 | Dict mapping from normalised coord names to raw coord names, 323 | by default {"x1": "x1", "x2": "x2"} 324 | 325 | Returns: 326 | ... 327 | ... 328 | 329 | """ 330 | assert isinstance(resolution_factor, (float, int)) 331 | assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset)) 332 | if coord_names is None: 333 | coord_names = {"x1": "x1", "x2": "x2"} 334 | x1_name, x2_name = coord_names["x1"], coord_names["x2"] 335 | x1, x2 = X_t_normalised.coords[x1_name], X_t_normalised.coords[x2_name] 336 | x1 = np.linspace(x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64") 337 | x2 = np.linspace(x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64") 338 | X_t_normalised = X_t_normalised.interp( 339 | **{x1_name: x1, x2_name: x2}, method="nearest" 340 | ) 341 | return X_t_normalised 342 | 343 | 344 | def infer_prediction_modality_from_X_t( 345 | X_t: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray], 346 | ) -> str: 347 | """Args: 348 | X_t (Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray]): 349 | ... 350 | 351 | Returns: 352 | str: "on-grid" if X_t is an xarray object, "off-grid" if X_t is a pandas or numpy object. 353 | 354 | Raises: 355 | ValueError 356 | If X_t is not an xarray, pandas or numpy object. 357 | """ 358 | if isinstance(X_t, (xr.DataArray, xr.Dataset)): 359 | mode = "on-grid" 360 | elif isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)): 361 | mode = "off-grid" 362 | else: 363 | raise ValueError( 364 | f"X_t must be and xarray, pandas or numpy object. Got {type(X_t)}." 
365 | ) 366 | return mode 367 | -------------------------------------------------------------------------------- /deepsensor/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/deepsensor/py.typed -------------------------------------------------------------------------------- /deepsensor/tensorflow/__init__.py: -------------------------------------------------------------------------------- 1 | # Load the tensorflow extension in lab (only needs to be called once in a session) 2 | import lab.tensorflow as B # noqa 3 | 4 | # Load the TF extension in nps (to assign to deepsensor backend) 5 | import neuralprocesses.tensorflow as nps 6 | 7 | import tensorflow as tf 8 | import tensorflow.keras 9 | 10 | # Necessary for dispatching with TF and PyTorch model types when they have not yet been loaded. 11 | # See https://beartype.github.io/plum/types.html#moduletype 12 | from plum import clear_all_cache 13 | 14 | clear_all_cache() 15 | 16 | from .. import * # noqa 17 | 18 | 19 | def convert_to_tensor(arr): 20 | """Convert `arr` to tensorflow tensor.""" 21 | return tf.convert_to_tensor(arr) 22 | 23 | 24 | from .. import config as deepsensor_config 25 | from .. import backend 26 | 27 | backend.nps = nps 28 | backend.model = tf.keras.Model 29 | backend.convert_to_tensor = convert_to_tensor 30 | backend.str = "tf" 31 | 32 | B.epsilon = deepsensor_config.DEFAULT_LAB_EPSILON 33 | -------------------------------------------------------------------------------- /deepsensor/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Load the torch extension in lab (only needs to be called once in a session) 2 | import lab.torch as B # noqa 3 | 4 | # Load the TF extension in nps (to assign to deepsensor backend) 5 | import neuralprocesses.torch as nps 6 | 7 | import torch 8 | import torch.nn 9 | 10 | # Necessary for dispatching with TF and PyTorch model types when they have not yet been loaded. 11 | # See https://beartype.github.io/plum/types.html#moduletype 12 | from plum import clear_all_cache 13 | 14 | clear_all_cache() 15 | 16 | from .. import * # noqa 17 | 18 | 19 | def convert_to_tensor(arr): 20 | """Convert `arr` to pytorch tensor.""" 21 | return torch.tensor(arr) 22 | 23 | 24 | from .. import config as deepsensor_config 25 | from .. import backend 26 | 27 | backend.nps = nps 28 | backend.model = torch.nn.Module 29 | backend.convert_to_tensor = convert_to_tensor 30 | backend.str = "torch" 31 | 32 | B.epsilon = deepsensor_config.DEFAULT_LAB_EPSILON 33 | -------------------------------------------------------------------------------- /deepsensor/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import * 2 | -------------------------------------------------------------------------------- /deepsensor/train/train.py: -------------------------------------------------------------------------------- 1 | import deepsensor 2 | from deepsensor.data.task import Task, concat_tasks 3 | from deepsensor.model.convnp import ConvNP 4 | 5 | import numpy as np 6 | 7 | import lab as B 8 | 9 | from typing import List 10 | 11 | 12 | def set_gpu_default_device() -> None: 13 | """Set default GPU device for the backend. 14 | 15 | Raises: 16 | RuntimeError 17 | If no GPU is available. 18 | RuntimeError 19 | If backend is not supported. 
20 | NotImplementedError 21 | If backend is not supported. 22 | 23 | Returns: 24 | None. 25 | """ 26 | if deepsensor.backend.str == "torch": 27 | # Run on GPU if available 28 | import torch 29 | 30 | if torch.cuda.is_available(): 31 | # Set default GPU device 32 | torch.set_default_device("cuda") 33 | B.set_global_device("cuda:0") 34 | else: 35 | raise RuntimeError("No GPU available: torch.cuda.is_available() == False") 36 | elif deepsensor.backend.str == "tf": 37 | # Run on GPU if available 38 | import tensorflow as tf 39 | 40 | if tf.test.is_gpu_available(): 41 | # Set default GPU device 42 | tf.config.set_visible_devices( 43 | tf.config.list_physical_devices("GPU")[0], "GPU" 44 | ) 45 | B.set_global_device("GPU:0") 46 | else: 47 | raise RuntimeError("No GPU available: tf.test.is_gpu_available() == False") 48 | 49 | else: 50 | raise NotImplementedError(f"Backend {deepsensor.backend.str} not implemented") 51 | 52 | 53 | def train_epoch( 54 | model: ConvNP, 55 | tasks: List[Task], 56 | lr: float = 5e-5, 57 | batch_size: int = None, 58 | opt=None, 59 | progress_bar=False, 60 | tqdm_notebook=False, 61 | ) -> List[float]: 62 | """Train model for one epoch. 63 | 64 | Args: 65 | model (:class:`~.model.convnp.ConvNP`): 66 | Model to train. 67 | tasks (List[:class:`~.data.task.Task`]): 68 | List of tasks to train on. 69 | lr (float, optional): 70 | Learning rate, by default 5e-5. 71 | batch_size (int, optional): 72 | Batch size. Defaults to None. If None, no batching is performed. 73 | opt (Optimizer, optional): 74 | TF or Torch optimizer. Defaults to None. If None, 75 | :class:`tensorflow:tensorflow.keras.optimizer.Adam` is used. 76 | progress_bar (bool, optional): 77 | Whether to display a progress bar. Defaults to False. 78 | tqdm_notebook (bool, optional): 79 | Whether to use a notebook progress bar. Defaults to False. 80 | 81 | Returns: 82 | List[float]: List of losses for each task/batch. 
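
    Example:
        A minimal sketch of a training loop (assumes ``model`` is a
        :class:`~.model.convnp.ConvNP` and ``tasks`` is a list of
        :class:`~.data.task.Task` objects, e.g. generated by a ``TaskLoader``;
        neither is defined here)::

            import numpy as np

            epoch_losses = []
            for epoch in range(10):
                batch_losses = train_epoch(model, tasks, lr=5e-5)
                epoch_losses.append(np.mean(batch_losses))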
83 | """ 84 | if deepsensor.backend.str == "tf": 85 | import tensorflow as tf 86 | 87 | if opt is None: 88 | opt = tf.keras.optimizers.Adam(lr) 89 | 90 | def train_step(tasks): 91 | if not isinstance(tasks, list): 92 | tasks = [tasks] 93 | with tf.GradientTape() as tape: 94 | task_losses = [] 95 | for task in tasks: 96 | task_losses.append(model.loss_fn(task, normalise=True)) 97 | mean_batch_loss = B.mean(B.stack(*task_losses)) 98 | grads = tape.gradient(mean_batch_loss, model.model.trainable_weights) 99 | opt.apply_gradients(zip(grads, model.model.trainable_weights)) 100 | return mean_batch_loss 101 | 102 | elif deepsensor.backend.str == "torch": 103 | import torch.optim as optim 104 | 105 | if opt is None: 106 | opt = optim.Adam(model.model.parameters(), lr=lr) 107 | 108 | def train_step(tasks): 109 | if not isinstance(tasks, list): 110 | tasks = [tasks] 111 | opt.zero_grad() 112 | task_losses = [] 113 | for task in tasks: 114 | task_losses.append(model.loss_fn(task, normalise=True)) 115 | mean_batch_loss = B.mean(B.stack(*task_losses)) 116 | mean_batch_loss.backward() 117 | opt.step() 118 | return mean_batch_loss.detach().cpu().numpy() 119 | 120 | else: 121 | raise NotImplementedError(f"Backend {deepsensor.backend.str} not implemented") 122 | 123 | tasks = np.random.permutation(tasks) 124 | 125 | if batch_size is not None: 126 | n_batches = len(tasks) // batch_size # Note that this will drop the remainder 127 | else: 128 | n_batches = len(tasks) 129 | 130 | if tqdm_notebook: 131 | from tqdm.notebook import tqdm 132 | else: 133 | from tqdm import tqdm 134 | 135 | batch_losses = [] 136 | for batch_i in tqdm(range(n_batches), disable=not progress_bar): 137 | if batch_size is not None: 138 | task = concat_tasks( 139 | tasks[batch_i * batch_size : (batch_i + 1) * batch_size] 140 | ) 141 | else: 142 | task = tasks[batch_i] 143 | batch_loss = train_step(task) 144 | batch_losses.append(batch_loss) 145 | 146 | return batch_losses 147 | 148 | 149 | class Trainer: 150 | """Class for training ConvNP models with an Adam optimiser. 151 | 152 | Args: 153 | lr (float): Learning rate 154 | """ 155 | 156 | def __init__(self, model: ConvNP, lr: float = 5e-5): 157 | if deepsensor.backend.str == "tf": 158 | import tensorflow as tf 159 | 160 | self.opt = tf.keras.optimizers.Adam(lr) 161 | elif deepsensor.backend.str == "torch": 162 | import torch.optim as optim 163 | 164 | self.opt = optim.Adam(model.model.parameters(), lr=lr) 165 | 166 | self.model = model 167 | 168 | def __call__( 169 | self, 170 | tasks: List[Task], 171 | batch_size: int = None, 172 | progress_bar=False, 173 | tqdm_notebook=False, 174 | ) -> List[float]: 175 | """Train model for one epoch.""" 176 | return train_epoch( 177 | model=self.model, 178 | tasks=tasks, 179 | batch_size=batch_size, 180 | opt=self.opt, 181 | progress_bar=progress_bar, 182 | tqdm_notebook=tqdm_notebook, 183 | ) 184 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | # Book settings 2 | # Learn more at https://jupyterbook.org/customize/config.html 3 | 4 | title: DeepSensor 5 | author: Tom Andersson 6 | logo: ../figs/DeepSensorLogo2.png 7 | email: tomandersson3@gmail.com 8 | 9 | # Force re-execution of notebooks on each build. 
10 | # See https://jupyterbook.org/content/execute.html 11 | execute: 12 | execute_notebooks: off # Don't run notebooks during doc compilation 13 | # execute_notebooks: force 14 | # timeout: 1200 # 20 mins 15 | # # Exclude notebooks with model training 16 | # exclude_patterns: 17 | # - "*quickstart*" 18 | 19 | # Define the name of the latex output file for PDF builds 20 | latex: 21 | latex_documents: 22 | targetname: deepsensor.tex 23 | 24 | only_build_toc_files: true 25 | 26 | # Add a bibtex file so that we can create citations 27 | bibtex_bibfiles: 28 | - references.bib 29 | 30 | # Information about where the book exists on the web 31 | repository: 32 | url: https://github.com/alan-turing-institute/deepsensor # Online location of your book 33 | path_to_book: docs # Optional path to your book, relative to the repository root 34 | branch: main # Which branch of the repository should be used when creating links (optional) 35 | 36 | # Add GitHub buttons to your book 37 | # See https://jupyterbook.org/customize/config.html#add-a-link-to-your-repository 38 | html: 39 | use_issues_button: true 40 | use_repository_button: true 41 | 42 | parse: 43 | myst_enable_extensions: 44 | # don't forget to list any other extensions you want enabled, 45 | # including those that are enabled by default! See here: https://jupyterbook.org/en/stable/customize/config.html 46 | # - amsmath 47 | - colon_fence 48 | # - deflist 49 | - dollarmath 50 | # - html_admonition 51 | # - html_image 52 | - linkify 53 | # - replacements 54 | # - smartquotes 55 | - substitution 56 | - tasklist 57 | - html_image # Added to support HTML images in DeepSensor documentation 58 | myst_url_schemes: [mailto, http, https] # URI schemes that will be recognised as external URLs in Markdown links 59 | myst_dmath_double_inline: true # Allow display math ($$) within an inline context 60 | 61 | 62 | sphinx: 63 | extra_extensions: 64 | - 'sphinx.ext.autodoc' 65 | - 'sphinx.ext.napoleon' 66 | - 'sphinx.ext.viewcode' 67 | - 'sphinx.ext.todo' 68 | config: 69 | add_module_names: False 70 | autodoc_typehints: "none" 71 | autoclass_content: "class" 72 | bibtex_reference_style: author_year 73 | napoleon_use_rtype: False 74 | todo_include_todos: True 75 | intersphinx_mapping: 76 | python: 77 | - https://docs.python.org/3 78 | - null 79 | pandas: 80 | - http://pandas.pydata.org/pandas-docs/stable/ 81 | - null 82 | # tensorflow: 83 | # - http://www.tensorflow.org/api_docs/python 84 | # - https://raw.githubusercontent.com/GPflow/tensorflow-intersphinx/master/tf2_py_objects.inv 85 | numpy: 86 | - https://numpy.org/doc/stable/ 87 | - null 88 | matplotlib: 89 | - http://matplotlib.org/stable/ 90 | - null 91 | xarray: 92 | - http://xarray.pydata.org/en/stable/ 93 | - https://docs.xarray.dev/en/stable/objects.inv 94 | language: en 95 | copybutton_prompt_text: "$" -------------------------------------------------------------------------------- /docs/_static/index_api.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8" standalone="no"?> 2 | <!-- Created with Inkscape (http://www.inkscape.org/) --> 3 | 4 | <svg 5 | xmlns:dc="http://purl.org/dc/elements/1.1/" 6 | xmlns:cc="http://creativecommons.org/ns#" 7 | xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" 8 | xmlns:svg="http://www.w3.org/2000/svg" 9 | xmlns="http://www.w3.org/2000/svg" 10 | xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" 11 | xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" 12 | 
width="99.058548mm" 13 | height="89.967583mm" 14 | viewBox="0 0 99.058554 89.967582" 15 | version="1.1" 16 | id="svg1040" 17 | inkscape:version="0.92.4 (f8dce91, 2019-08-02)" 18 | sodipodi:docname="index_api.svg"> 19 | <defs 20 | id="defs1034" /> 21 | <sodipodi:namedview 22 | id="base" 23 | pagecolor="#ffffff" 24 | bordercolor="#666666" 25 | borderopacity="1.0" 26 | inkscape:pageopacity="0.0" 27 | inkscape:pageshadow="2" 28 | inkscape:zoom="0.35" 29 | inkscape:cx="533.74914" 30 | inkscape:cy="10.90433" 31 | inkscape:document-units="mm" 32 | inkscape:current-layer="layer1" 33 | showgrid="false" 34 | fit-margin-top="0" 35 | fit-margin-left="0" 36 | fit-margin-right="0" 37 | fit-margin-bottom="0" 38 | inkscape:window-width="930" 39 | inkscape:window-height="472" 40 | inkscape:window-x="2349" 41 | inkscape:window-y="267" 42 | inkscape:window-maximized="0" /> 43 | <metadata 44 | id="metadata1037"> 45 | <rdf:RDF> 46 | <cc:Work 47 | rdf:about=""> 48 | <dc:format>image/svg+xml</dc:format> 49 | <dc:type 50 | rdf:resource="http://purl.org/dc/dcmitype/StillImage" /> 51 | <dc:title></dc:title> 52 | </cc:Work> 53 | </rdf:RDF> 54 | </metadata> 55 | <g 56 | inkscape:label="Layer 1" 57 | inkscape:groupmode="layer" 58 | id="layer1" 59 | transform="translate(195.19933,-1.0492759)"> 60 | <g 61 | id="g1008" 62 | transform="matrix(1.094977,0,0,1.094977,-521.5523,-198.34055)"> 63 | <path 64 | inkscape:connector-curvature="0" 65 | id="path899" 66 | d="M 324.96812,187.09499 H 303.0455 v 72.1639 h 22.67969" 67 | style="fill:none;stroke:#5a5a5a;stroke-width:10;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" /> 68 | <path 69 | inkscape:connector-curvature="0" 70 | id="path899-3" 71 | d="m 361.58921,187.09499 h 21.92262 v 72.1639 h -22.67969" 72 | style="fill:none;stroke:#5a5a5a;stroke-width:10;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" /> 73 | <g 74 | transform="translate(415.87139,46.162126)" 75 | id="g944"> 76 | <circle 77 | style="fill:#5a5a5a;fill-opacity:1;stroke:#5a5a5a;stroke-width:4.53704548;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" 78 | id="path918" 79 | cx="-84.40152" 80 | cy="189.84375" 81 | r="2.2293637" /> 82 | <circle 83 | style="fill:#5a5a5a;fill-opacity:1;stroke:#5a5a5a;stroke-width:4.53704548;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" 84 | id="path918-5" 85 | cx="-72.949402" 86 | cy="189.84375" 87 | r="2.2293637" /> 88 | <circle 89 | style="fill:#5a5a5a;fill-opacity:1;stroke:#5a5a5a;stroke-width:4.53704548;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" 90 | id="path918-6" 91 | cx="-61.497284" 92 | cy="189.84375" 93 | r="2.2293637" /> 94 | </g> 95 | </g> 96 | </g> 97 | </svg> 98 | -------------------------------------------------------------------------------- /docs/_static/index_community.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="iso-8859-1"?> 2 | <!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools --> 3 | <svg height="800px" width="800px" version="1.1" id="Capa_1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" 4 | viewBox="0 0 502.648 502.648" xml:space="preserve"> 5 | <g> 6 | <g> 7 | <g> 8 | <circle style="fill:#010002;" cx="250.399" cy="91.549" r="58.694"/> 9 | <path 
style="fill:#010002;" d="M455.861,253.028l-54.703-11.411c-18.637-3.904-37.037,4.638-46.765,19.824 10 | c-9.448-4.853-19.608-9.038-30.415-12.511v-32.529c0.022-24.612-20.126-44.738-44.651-44.738h-55.933 11 | c-24.655,0-44.716,20.126-44.716,44.738v32.701c-10.699,3.408-20.751,7.593-30.264,12.468 12 | c-9.728-15.251-28.15-23.857-46.809-19.953l-54.747,11.411c-24.03,5.026-39.626,28.862-34.6,52.978l13.741,65.64 13 | c4.983,24.051,28.84,39.647,52.892,34.621l17.321-3.624c8.671,12.813,20.665,24.569,36.023,34.621 14 | c31.989,20.967,74.247,32.529,119.092,32.529c68.617,0,127.721-27.589,154.943-67.215l17.602,3.689 15 | c24.03,5.004,47.887-10.57,52.87-34.621l13.762-65.64C495.508,281.89,479.912,258.054,455.861,253.028z M251.305,447.381 16 | c-40.51,0-78.475-10.203-106.797-28.862c-9.707-6.342-17.753-13.395-24.202-20.945l13.266-2.783 17 | c24.073-5.004,39.669-28.84,34.643-52.913l-12.317-59.018c7.183-3.861,14.733-7.248,22.757-10.138v10.764 18 | c0,24.569,20.104,44.695,44.716,44.695h55.933c24.548,0,44.652-20.147,44.652-44.695v-11.325 19 | c8.175,2.912,15.854,6.256,22.973,10.052L334.439,341.9c-4.983,24.073,10.591,47.909,34.664,52.913l13.395,2.804 20 | C357.52,427.191,308.101,447.381,251.305,447.381z"/> 21 | <circle style="fill:#010002;" cx="443.954" cy="168.708" r="58.694"/> 22 | <path style="fill:#010002;" d="M70.736,226.172c31.752-6.644,52.029-37.77,45.471-69.501 23 | c-6.687-31.709-37.749-52.072-69.523-45.428c-31.709,6.622-52.072,37.727-45.428,69.458 24 | C7.879,212.453,38.984,232.795,70.736,226.172z"/> 25 | </g> 26 | </g> 27 | <g> 28 | </g> 29 | <g> 30 | </g> 31 | <g> 32 | </g> 33 | <g> 34 | </g> 35 | <g> 36 | </g> 37 | <g> 38 | </g> 39 | <g> 40 | </g> 41 | <g> 42 | </g> 43 | <g> 44 | </g> 45 | <g> 46 | </g> 47 | <g> 48 | </g> 49 | <g> 50 | </g> 51 | <g> 52 | </g> 53 | <g> 54 | </g> 55 | <g> 56 | </g> 57 | </g> 58 | </svg> -------------------------------------------------------------------------------- /docs/_static/index_community2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/docs/_static/index_community2.pdf -------------------------------------------------------------------------------- /docs/_static/index_community2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/docs/_static/index_community2.png -------------------------------------------------------------------------------- /docs/_static/index_contribute.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8" standalone="no"?> 2 | <!-- Created with Inkscape (http://www.inkscape.org/) --> 3 | 4 | <svg 5 | xmlns:dc="http://purl.org/dc/elements/1.1/" 6 | xmlns:cc="http://creativecommons.org/ns#" 7 | xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" 8 | xmlns:svg="http://www.w3.org/2000/svg" 9 | xmlns="http://www.w3.org/2000/svg" 10 | xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" 11 | xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" 12 | width="89.624855mm" 13 | height="89.96759mm" 14 | viewBox="0 0 89.62486 89.96759" 15 | version="1.1" 16 | id="svg1040" 17 | inkscape:version="0.92.4 (f8dce91, 2019-08-02)" 18 | sodipodi:docname="index_contribute.svg"> 19 | <defs 20 | id="defs1034" /> 21 | <sodipodi:namedview 22 | id="base" 23 | pagecolor="#ffffff" 24 | 
bordercolor="#666666" 25 | borderopacity="1.0" 26 | inkscape:pageopacity="0.0" 27 | inkscape:pageshadow="2" 28 | inkscape:zoom="0.35" 29 | inkscape:cx="683.11893" 30 | inkscape:cy="-59.078181" 31 | inkscape:document-units="mm" 32 | inkscape:current-layer="layer1" 33 | showgrid="false" 34 | fit-margin-top="0" 35 | fit-margin-left="0" 36 | fit-margin-right="0" 37 | fit-margin-bottom="0" 38 | inkscape:window-width="930" 39 | inkscape:window-height="472" 40 | inkscape:window-x="2349" 41 | inkscape:window-y="267" 42 | inkscape:window-maximized="0" /> 43 | <metadata 44 | id="metadata1037"> 45 | <rdf:RDF> 46 | <cc:Work 47 | rdf:about=""> 48 | <dc:format>image/svg+xml</dc:format> 49 | <dc:type 50 | rdf:resource="http://purl.org/dc/dcmitype/StillImage" /> 51 | <dc:title></dc:title> 52 | </cc:Work> 53 | </rdf:RDF> 54 | </metadata> 55 | <g 56 | inkscape:label="Layer 1" 57 | inkscape:groupmode="layer" 58 | id="layer1" 59 | transform="translate(234.72009,17.466935)"> 60 | <g 61 | id="g875" 62 | transform="matrix(0.99300176,0,0,0.99300176,-133.24106,-172.58804)"> 63 | <path 64 | sodipodi:nodetypes="ccc" 65 | inkscape:connector-curvature="0" 66 | id="path869" 67 | d="m -97.139881,161.26069 47.247024,40.25446 -47.247024,40.25446" 68 | style="fill:none;stroke:#5a5a5a;stroke-width:10;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" /> 69 | <path 70 | inkscape:connector-curvature="0" 71 | id="path871" 72 | d="m -49.514879,241.81547 h 32.505951" 73 | style="fill:none;stroke:#5a5a5a;stroke-width:10;stroke-linecap:round;stroke-linejoin:round;stroke-miterlimit:4;stroke-dasharray:none;stroke-opacity:1" /> 74 | </g> 75 | </g> 76 | </svg> 77 | -------------------------------------------------------------------------------- /docs/_static/index_getting_started.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8" standalone="no"?> 2 | <!-- Created with Inkscape (http://www.inkscape.org/) --> 3 | 4 | <svg 5 | xmlns:dc="http://purl.org/dc/elements/1.1/" 6 | xmlns:cc="http://creativecommons.org/ns#" 7 | xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" 8 | xmlns:svg="http://www.w3.org/2000/svg" 9 | xmlns="http://www.w3.org/2000/svg" 10 | xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" 11 | xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" 12 | width="101.09389mm" 13 | height="89.96759mm" 14 | viewBox="0 0 101.09389 89.96759" 15 | version="1.1" 16 | id="svg1040" 17 | inkscape:version="0.92.4 (f8dce91, 2019-08-02)" 18 | sodipodi:docname="index_getting_started.svg"> 19 | <defs 20 | id="defs1034" /> 21 | <sodipodi:namedview 22 | id="base" 23 | pagecolor="#ffffff" 24 | bordercolor="#666666" 25 | borderopacity="1.0" 26 | inkscape:pageopacity="0.0" 27 | inkscape:pageshadow="2" 28 | inkscape:zoom="0.35" 29 | inkscape:cx="-93.242129" 30 | inkscape:cy="-189.9825" 31 | inkscape:document-units="mm" 32 | inkscape:current-layer="layer1" 33 | showgrid="false" 34 | fit-margin-top="0" 35 | fit-margin-left="0" 36 | fit-margin-right="0" 37 | fit-margin-bottom="0" 38 | inkscape:window-width="1875" 39 | inkscape:window-height="1056" 40 | inkscape:window-x="1965" 41 | inkscape:window-y="0" 42 | inkscape:window-maximized="1" /> 43 | <metadata 44 | id="metadata1037"> 45 | <rdf:RDF> 46 | <cc:Work 47 | rdf:about=""> 48 | <dc:format>image/svg+xml</dc:format> 49 | <dc:type 50 | rdf:resource="http://purl.org/dc/dcmitype/StillImage" /> 51 | <dc:title></dc:title> 52 | </cc:Work> 53 | 
</rdf:RDF> 54 | </metadata> 55 | <g 56 | inkscape:label="Layer 1" 57 | inkscape:groupmode="layer" 58 | id="layer1" 59 | transform="translate(2.9219487,-8.5995374)"> 60 | <path 61 | style="fill:#5a5a5a;fill-opacity:1;stroke-width:0.20233451" 62 | d="M 37.270955,98.335591 C 33.358064,97.07991 31.237736,92.52319 32.964256,89.08022 c 0.18139,-0.361738 4.757999,-5.096629 10.17021,-10.521968 l 9.84041,-9.864254 -4.03738,-4.041175 -4.037391,-4.041172 -4.96415,4.916665 c -3.61569,3.581096 -5.238959,5.04997 -5.975818,5.407377 l -1.011682,0.490718 H 17.267525 1.5866055 L 0.65034544,70.96512 C -2.2506745,69.535833 -3.5952145,66.18561 -2.5925745,62.884631 c 0.53525,-1.762217 1.61699004,-3.050074 3.22528014,-3.839847 l 1.15623996,-0.56778 13.2591094,-0.05613 13.259111,-0.05613 11.5262,-11.527539 11.526199,-11.527528 H 40.622647 c -12.145542,0 -12.189222,-0.0046 -13.752801,-1.445851 -2.229871,-2.055423 -2.162799,-5.970551 0.135998,-7.938238 1.475193,-1.262712 1.111351,-1.238469 18.588522,-1.238469 12.899229,0 16.035311,0.05193 16.692589,0.276494 0.641832,0.219264 2.590731,2.051402 9.416301,8.852134 l 8.606941,8.575638 h 6.848168 c 4.837422,0 7.092281,0.07311 7.679571,0.249094 0.48064,0.144008 1.22985,0.634863 1.77578,1.163429 2.383085,2.307333 1.968685,6.539886 -0.804989,8.221882 -0.571871,0.346781 -1.38284,0.687226 -1.80217,0.756523 -0.41933,0.06928 -4.2741,0.127016 -8.56615,0.128238 -6.56998,0.0016 -7.977492,-0.04901 -8.902732,-0.321921 -0.975569,-0.287742 -1.400468,-0.622236 -3.783999,-2.978832 l -2.685021,-2.654679 -5.05411,5.051071 -5.0541,5.051081 3.926292,3.947202 c 2.365399,2.378001 4.114289,4.309171 4.399158,4.857713 0.39266,0.75606 0.47311,1.219412 0.474321,2.731516 0.003,3.083647 0.620779,2.331942 -13.598011,16.531349 -10.273768,10.259761 -12.679778,12.563171 -13.500979,12.92519 -1.267042,0.55857 -3.156169,0.681342 -4.390271,0.285321 z m 40.130741,-65.45839 c -2.212909,-0.579748 -3.782711,-1.498393 -5.51275,-3.226063 -2.522111,-2.518633 -3.633121,-5.181304 -3.633121,-8.707194 0,-3.530699 1.11238,-6.197124 3.631161,-8.704043 4.866751,-4.8438383 12.324781,-4.8550953 17.211791,-0.026 3.908758,3.862461 4.818578,9.377999 2.372188,14.380771 -0.846209,1.730481 -3.39493,4.326384 -5.143839,5.239072 -2.69708,1.407492 -6.042829,1.798628 -8.92543,1.043434 z" 63 | id="path1000" 64 | inkscape:connector-curvature="0" /> 65 | </g> 66 | </svg> 67 | -------------------------------------------------------------------------------- /docs/_static/index_user_guide.svg: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8" standalone="no"?> 2 | <!-- Created with Inkscape (http://www.inkscape.org/) --> 3 | 4 | <svg 5 | xmlns:dc="http://purl.org/dc/elements/1.1/" 6 | xmlns:cc="http://creativecommons.org/ns#" 7 | xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" 8 | xmlns:svg="http://www.w3.org/2000/svg" 9 | xmlns="http://www.w3.org/2000/svg" 10 | xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd" 11 | xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape" 12 | width="123.72241mm" 13 | height="89.96759mm" 14 | viewBox="0 0 123.72242 89.96759" 15 | version="1.1" 16 | id="svg1040" 17 | inkscape:version="0.92.4 (f8dce91, 2019-08-02)" 18 | sodipodi:docname="index_userguide.svg"> 19 | <defs 20 | id="defs1034" /> 21 | <sodipodi:namedview 22 | id="base" 23 | pagecolor="#ffffff" 24 | bordercolor="#666666" 25 | borderopacity="1.0" 26 | inkscape:pageopacity="0.0" 27 | inkscape:pageshadow="2" 28 | inkscape:zoom="0.35" 29 | 
inkscape:cx="332.26618" 30 | inkscape:cy="83.744004" 31 | inkscape:document-units="mm" 32 | inkscape:current-layer="layer1" 33 | showgrid="false" 34 | fit-margin-top="0" 35 | fit-margin-left="0" 36 | fit-margin-right="0" 37 | fit-margin-bottom="0" 38 | inkscape:window-width="930" 39 | inkscape:window-height="472" 40 | inkscape:window-x="2349" 41 | inkscape:window-y="267" 42 | inkscape:window-maximized="0" /> 43 | <metadata 44 | id="metadata1037"> 45 | <rdf:RDF> 46 | <cc:Work 47 | rdf:about=""> 48 | <dc:format>image/svg+xml</dc:format> 49 | <dc:type 50 | rdf:resource="http://purl.org/dc/dcmitype/StillImage" /> 51 | <dc:title></dc:title> 52 | </cc:Work> 53 | </rdf:RDF> 54 | </metadata> 55 | <g 56 | inkscape:label="Layer 1" 57 | inkscape:groupmode="layer" 58 | id="layer1" 59 | transform="translate(141.8903,-20.32143)"> 60 | <path 61 | style="fill:#5a5a5a;fill-opacity:1;stroke-width:0.20483544" 62 | d="m -139.53374,110.1657 c -0.80428,-0.24884 -1.71513,-1.11296 -2.07107,-1.96486 -0.23905,-0.57214 -0.28453,-6.28104 -0.28453,-35.720988 0,-38.274546 -0.079,-35.840728 1.19849,-36.91568 0.58869,-0.495345 4.63766,-2.187548 8.47998,-3.544073 l 1.58749,-0.560453 v -3.309822 c 0,-3.025538 0.0396,-3.388179 0.46086,-4.222122 0.68808,-1.362003 1.38671,-1.714455 4.60319,-2.322195 4.12797,-0.779966 5.13304,-0.912766 8.81544,-1.16476 11.80964,-0.808168 22.80911,2.509277 30.965439,9.3392 1.750401,1.465747 3.840861,3.5635 5.0903,5.108065 l 0.659122,0.814805 0.659109,-0.814805 c 1.249431,-1.544565 3.33988,-3.642318 5.09029,-5.108065 8.156331,-6.829923 19.155791,-10.147368 30.965441,-9.3392 3.682389,0.251994 4.68748,0.384794 8.81544,1.16476 3.21647,0.60774 3.91511,0.960192 4.60318,2.322195 0.4213,0.833943 0.46087,1.196584 0.46087,4.222122 v 3.309822 l 1.58748,0.560453 c 4.10165,1.448077 7.98852,3.072753 8.5259,3.563743 1.22643,1.120567 1.15258,-1.245868 1.15258,36.927177 0,34.567591 -0.005,35.083151 -0.40663,35.903991 -0.22365,0.45804 -0.73729,1.05665 -1.14143,1.33024 -1.22281,0.82783 -2.17721,0.70485 -5.86813,-0.7561 -9.19595,-3.63998 -18.956011,-6.38443 -26.791332,-7.53353 -3.02827,-0.44412 -9.26189,-0.61543 -11.77821,-0.3237 -5.19357,0.60212 -8.736108,2.05527 -11.700039,4.79936 -0.684501,0.63371 -1.466141,1.23646 -1.736979,1.33942 -0.63859,0.2428 -4.236521,0.2428 -4.875112,0 -0.27083,-0.10296 -1.05247,-0.70571 -1.73696,-1.33942 -2.96395,-2.74409 -6.50648,-4.19724 -11.700058,-4.79936 -2.516312,-0.29173 -8.749941,-0.12042 -11.778201,0.3237 -7.78194,1.14127 -17.39965,3.83907 -26.73341,7.49883 -3.38325,1.32658 -4.15525,1.50926 -5.11851,1.21125 z m 4.2107,-5.34052 c 5.86759,-2.29858 14.40398,-4.922695 20.2018,-6.210065 6.31584,-1.402418 8.5236,-1.646248 14.91592,-1.647338 4.68699,-7.94e-4 6.013661,0.0632 7.257809,0.3497 0.837332,0.19286 1.561052,0.312028 1.60828,0.264819 0.147111,-0.147119 -1.803289,-1.307431 -4.154879,-2.471801 -8.12511,-4.023029 -18.27311,-4.986568 -29.0861,-2.761718 -1.09536,0.22538 -2.32708,0.40827 -2.73715,0.406418 -1.12787,-0.005 -2.3054,-0.76382 -2.84516,-1.8332 l -0.46086,-0.913098 V 62.99179 35.97471 l -0.56331,0.138329 c -0.30981,0.07608 -1.89985,0.665075 -3.5334,1.308881 -2.27551,0.896801 -2.96414,1.252878 -2.94452,1.522563 0.014,0.193604 0.0372,15.284513 0.0512,33.535345 0.014,18.250839 0.0538,33.183322 0.0884,33.183322 0.0346,0 1.02543,-0.3771 2.20198,-0.83801 z m 113.006991,-32.697216 -0.0518,-33.535203 -3.17495,-1.272156 c -1.74623,-0.699685 -3.33627,-1.278755 -3.53341,-1.286819 -0.33966,-0.01389 -0.35847,1.401778 -0.35847,26.980216 v 26.994863 l -0.46087,0.913112 c -0.53976,1.06939 
-1.71729,1.828088 -2.84515,1.833189 -0.41008,0.0021 -1.6418,-0.181031 -2.73716,-0.406421 -11.888201,-2.446089 -22.84337,-1.046438 -31.491022,4.02332 -1.68175,0.985941 -2.216748,1.467501 -1.36534,1.228942 1.575181,-0.441362 4.990592,-0.73864 8.524862,-0.742011 5.954408,-0.005 11.43046,0.791951 19.10874,2.78333 3.9516,1.024874 12.1555,3.687454 15.6699,5.085704 1.23926,0.49306 2.36869,0.90517 2.50985,0.9158 0.20489,0.0155 0.2462,-6.745894 0.20483,-33.515866 z m -59.76135,-2.233777 V 40.065438 l -0.95972,-1.357442 c -1.380522,-1.952627 -5.376262,-5.847994 -7.64336,-7.45136 -3.778692,-2.672401 -9.063392,-4.943324 -13.672511,-5.875304 -3.19731,-0.646503 -5.23069,-0.833103 -9.05886,-0.831312 -4.37716,0.0021 -7.70223,0.349169 -11.83461,1.235469 l -1.07538,0.230645 v 31.242342 c 0,26.565778 0.0426,31.226011 0.28429,31.133261 0.15637,-0.06 1.42379,-0.297169 2.81648,-0.527026 12.37657,-2.042634 23.21658,-0.346861 32.521639,5.087596 2.10018,1.226558 5.20202,3.618878 6.880942,5.30692 0.788609,0.792909 1.502978,1.446609 1.587468,1.452679 0.0845,0.006 0.153622,-13.411893 0.153622,-29.817719 z m 5.80221,28.3766 c 6.21476,-6.141601 15.08488,-10.061509 25.025529,-11.05933 4.262419,-0.427849 11.579921,-0.0054 16.017661,0.924912 0.75932,0.15916 1.45259,0.244888 1.54058,0.190498 0.088,-0.05434 0.16003,-14.060382 0.16003,-31.124436 V 26.176883 l -0.52136,-0.198219 c -0.66893,-0.254325 -4.77649,-0.95482 -7.159981,-1.221048 -2.41372,-0.269605 -8.559851,-0.266589 -10.759229,0.0052 -6.458111,0.798299 -12.584091,3.083792 -17.405651,6.49374 -2.267091,1.603366 -6.262831,5.498733 -7.64336,7.45136 l -0.959721,1.357438 v 29.828747 c 0,16.405812 0.0532,29.828746 0.11802,29.828746 0.065,0 0.77928,-0.65347 1.587482,-1.452149 z" 63 | id="path845" 64 | inkscape:connector-curvature="0" 65 | sodipodi:nodetypes="csscccscsssscsssssscscsccsccsccscsscccccccscccccccccsccscscscccscccsccssccsscccscccccsccccsccscsccsscc" /> 66 | </g> 67 | </svg> 68 | -------------------------------------------------------------------------------- /docs/_toc.yml: -------------------------------------------------------------------------------- 1 | # Table of contents 2 | # Learn more at https://jupyterbook.org/customize/toc.html 3 | 4 | format: jb-book 5 | root: index 6 | chapters: 7 | - file: getting-started/index 8 | sections: 9 | - file: getting-started/overview 10 | - file: getting-started/data_requirements 11 | - file: getting-started/installation 12 | 13 | - file: user-guide/index 14 | sections: 15 | - file: user-guide/deepsensor_design 16 | - file: user-guide/data_processor 17 | - file: user-guide/task 18 | - file: user-guide/task_loader 19 | - file: user-guide/convnp 20 | - file: user-guide/training 21 | - file: user-guide/prediction 22 | - file: user-guide/active_learning 23 | - file: user-guide/acquisition_functions 24 | - file: user-guide/extending 25 | 26 | - file: community/index 27 | sections: 28 | - file: community/contributing 29 | - file: community/code_of_conduct 30 | - file: community/roadmap 31 | - file: community/faq 32 | 33 | - file: resources 34 | 35 | - file: research_ideas 36 | 37 | - file: contact 38 | 39 | - file: reference/index 40 | sections: 41 | - file: reference/data/index 42 | sections: 43 | - file: reference/data/sources 44 | - file: reference/data/loader 45 | - file: reference/data/processor 46 | - file: reference/data/task 47 | - file: reference/data/utils 48 | - file: reference/model/index 49 | sections: 50 | - file: reference/model/model 51 | - file: reference/model/pred 52 | - file: reference/model/convnp 53 | - 
file: reference/model/defaults 54 | - file: reference/model/nps 55 | - file: reference/train/index 56 | sections: 57 | - file: reference/train/train 58 | - file: reference/active_learning/index 59 | sections: 60 | - file: reference/active_learning/acquisition_fns 61 | - file: reference/active_learning/algorithms 62 | - file: reference/plot -------------------------------------------------------------------------------- /docs/community/code_of_conduct.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../CODE_OF_CONDUCT.md 2 | ``` 3 | -------------------------------------------------------------------------------- /docs/community/contributing.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../CONTRIBUTING.md 2 | ``` 3 | -------------------------------------------------------------------------------- /docs/community/faq.md: -------------------------------------------------------------------------------- 1 | # Community FAQ 2 | 3 | This FAQ aims to answer common questions about the DeepSensor library. It is our way to streamline the onboarding process and clarify expectations. 4 | 5 | ```{note} 6 | If you have a question you'd like to see answered here, make a request in a issue or in the [Slack channel](https://ai4environment.slack.com/archives/C05NQ76L87R). 7 | ``` 8 | 9 | ## Questions 10 | 11 | **Q: Why doesn't the package name `DeepSensor` mention NPs if it's all about neural processes?** 12 | 13 | **Answer:** 14 | DeepSensor aims to be extensible to models that are not necessarily NPs. 15 | We also wanted to keep the name short and easy to remember. 16 | The name `DeepSensor` is a reference to the fact that the library is about deep learning and sensor data. 17 | 18 | --- 19 | 20 | **Q: How can I contribute?** 21 | 22 | **Answer:** 23 | There are many ways to contribute, from writing code and fixing bugs to improving documentation or translating content. 24 | Check our [](./contributing.md) guide. 25 | 26 | --- 27 | 28 | **Q: Do I need to sign a Contributor License Agreement (CLA)?** 29 | 30 | **Answer:** At the current time, we do not require a CLA from our contributors. 31 | 32 | --- 33 | 34 | **Q: How do I report a bug?** 35 | 36 | **Answer:** Please submit an issue in our GitHub repository. Make sure to provide detailed information, including steps to reproduce the bug and the expected outcome. 37 | 38 | --- 39 | 40 | **Q: How do I request a new feature?** 41 | 42 | **Answer:** Open a new issue on our GitHub repository and label it as a feature request. Describe the feature in detail and its potential benefits. 43 | 44 | --- 45 | 46 | **Q: How can I get in touch with other contributors or maintainers?** 47 | 48 | **Answer:** 49 | Request to join our Slack channel to stay in touch with other contributors and maintainers. You can join by [signing up for the Turing Environment & Sustainability stakeholder community](https://forms.office.com/pages/responsepage.aspx?id=p_SVQ1XklU-Knx-672OE-ZmEJNLHTHVFkqQ97AaCfn9UMTZKT1IwTVhJRE82UjUzMVE2MThSOU5RMC4u). The form includes a question on signing up for the Slack team, where you can find DeepSensor's channel. 50 | 51 | We also have a regular community Zoom call (join the Slack channel or get in touch to find out more). 52 | 53 | --- 54 | 55 | **Q: How do I set up the development environment?** 56 | 57 | **Answer:** Follow the instructions in our developer documentation. 
If you run into issues, ask in our [community chat](https://ai4environment.slack.com/archives/C05NQ76L87R) (on Slack). 58 | 59 | --- 60 | 61 | **Q: Do you have a code of conduct?** 62 | 63 | **Answer:** 64 | Yes, we value a respectful and inclusive community. 65 | Please read our [](./code_of_conduct.md) before contributing. 66 | 67 | --- 68 | 69 | **Q: Can I contribute even if I'm not a coder?** 70 | 71 | **Answer:** Absolutely! Contributions can be made in the form of documentation, design, testing, and more. Everyone's skills are valuable. Join our Slack discussion to learn more. 72 | 73 | --- 74 | 75 | **Q: How do I claim an issue to work on?** 76 | 77 | **Answer:** Comment on the issue expressing your interest to help out. If the issue is unassigned, a maintainer will likely assign it to you. 78 | 79 | --- 80 | 81 | **Q: What's the process for proposing a significant change?** 82 | 83 | **Answer:** For significant changes, it's a good practice to first open a discussion or issue to gather feedback. Once there's a consensus, you can proceed with a pull request. 84 | 85 | --- 86 | 87 | **Q: How can I get my pull request (PR) merged?** 88 | 89 | **Answer:** Ensure your PR follows the contribution guidelines, passes all tests, and has been reviewed by at least one maintainer. Address any feedback provided. 90 | 91 | --- 92 | 93 | **Q: How is credit given to contributors?** 94 | 95 | **Answer:** 96 | Contributors are acknowledged via an `all-contributors` system, which records contributions (code or non-code) at the end of the project's README. 97 | Code contributions are acknowledged in our release notes. 98 | -------------------------------------------------------------------------------- /docs/community/index.md: -------------------------------------------------------------------------------- 1 | # Community 2 | 3 | The DeepSensor community is a group of users and contributors who are interested in the development of DeepSensor. The community is open to anyone who is interested in DeepSensor. The community is a place to ask questions, discuss ideas, and share your work. 4 | 5 | If you are interested in joining the community, please request to join our Slack channel. You can join by [signing up for the Turing Environment & Sustainability stakeholder community](https://forms.office.com/pages/responsepage.aspx?id=p_SVQ1XklU-Knx-672OE-ZmEJNLHTHVFkqQ97AaCfn9UMTZKT1IwTVhJRE82UjUzMVE2MThSOU5RMC4u). The form includes a question on signing up for the Slack team, where you can find DeepSensor's channel. 6 | 7 | We welcome contributions from the community. If you are interested in contributing to DeepSensor, please read the [Contributing Guide](./contributing.md). 8 | -------------------------------------------------------------------------------- /docs/community/roadmap.md: -------------------------------------------------------------------------------- 1 | # DeepSensor Roadmap 2 | 3 | This page contains a list of new features that we would like to add to DeepSensor in the future. 4 | Some of these have been raised as issues on the [GitHub issue tracker](https://github.com/alan-turing-institute/deepsensor/issues) 5 | with further details. 6 | 7 | ```{note} 8 | We will soon create a GitHub project board to track progress on these items, which will provide a more up-to-date view of the roadmap. 9 | ``` 10 | 11 | ```{note} 12 | We are unable to provide a timetable for the roadmap due to maintainer time constraints. 
13 | If you are interested in contributing to the project, check out our [](./contributing.md) page. 14 | ``` 15 | 16 | * Patch-wise training and inference 17 | * Saving a ``TaskLoader`` when instantiated with raw xarray/pandas objects 18 | * Spatial-only modelling 19 | * Continuous time measurements (i.e. not just discrete, uniformly sampled data on the same time grid) 20 | * Test the framework with other models (e.g. GPs) 21 | * Add simple baselines models (e.g. linear interpolation, GPs) 22 | * Test and extend support for using ``dask`` in the ``DataProcessor`` and ``TaskLoader`` 23 | * Infer linked context-target sets from the ``TaskLoader`` entries, don't require user to explicitly specify ``links`` kwarg 24 | * Improve unit test suite, increase coverage, test more edge cases, etc 25 | -------------------------------------------------------------------------------- /docs/contact.md: -------------------------------------------------------------------------------- 1 | # Contact the core team 2 | 3 | If you would like to contact us directly, please loop in everyone on the core team: 4 | 5 | * Lead developer: tomandersson3@gmail.com 6 | * Product manager: kwesterling@turing.ac.uk 7 | -------------------------------------------------------------------------------- /docs/getting-started/index.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | 3 | This first part of the documentation provides an overview of the package in [](overview.md) and 4 | the [](./data_requirements.ipynb) of DeepSensor. 5 | If these align with your use case, move on to the [](installation.md) to get started. 6 | -------------------------------------------------------------------------------- /docs/getting-started/installation.md: -------------------------------------------------------------------------------- 1 | # Installation instructions 2 | 3 | DeepSensor is a Python package that can be installed in a number of ways. In this section we will describe the two main ways to install the package. 4 | 5 | ## Install DeepSensor from [PyPI](https://pypi.org/project/deepsensor/) 6 | 7 | If you want to use the latest stable release of DeepSensor and do not want/need access to the worked examples or the package's source code, we recommend installing from PyPI. 8 | 9 | This is the easiest way to install DeepSensor. 10 | 11 | ```bash 12 | pip install deepsensor 13 | ``` 14 | 15 | ```{note} 16 | We advise installing DeepSensor and its dependencies in a python virtual environment using a tool such as [venv](https://docs.python.org/3/library/venv.html) or [conda](https://conda.io/projects/conda/en/latest/user-guide/getting-started.html#managing-python) (other virtual environment managers are available). 17 | ``` 18 | 19 | ## Install DeepSensor from [source](https://github.com/alan-turing-institute/deepsensor) 20 | 21 | ```{note} 22 | You will want to use this method if you intend on contributing to the source code of DeepSensor. 23 | ``` 24 | 25 | If you want to keep up with the latest changes to DeepSensor, or want/need easy access to the worked examples or the package's source code, we recommend installing from source. 26 | 27 | This method will create a `DeepSensor` directory on your machine which will contain all the source code, docs and worked examples. 28 | 29 | - Clone the repository: 30 | 31 | ```bash 32 | git clone https://github.com/alan-turing-institute/deepsensor 33 | ``` 34 | 35 | - Install `DeepSensor`: 36 | 37 | ```bash 38 | pip install -v -e . 
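# Note: run the editable install from the root of the cloned repository (i.e. `cd deepsensor` first)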
39 | ``` 40 | 41 | ````{note} 42 | If you intend on contributing to the source code of DeepSensor, install DeepSensor with its development dependencies using 43 | ```bash 44 | pip install -v -e .[dev] 45 | ``` 46 | ```` 47 | 48 | ## Install PyTorch or TensorFlow 49 | 50 | The next step, if you intend to use any of DeepSensor's deep learning modelling functionality, 51 | is to install the deep learning backend of your choice. 52 | Currently, DeepSensor supports PyTorch or TensorFlow. 53 | 54 | The quickest way to install these packages is with `pip` (see below), although this doesn't guarantee 55 | GPU functionality will work (assuming you have a GPU). 56 | To access GPU support, you may need to follow the installation instructions of 57 | these libraries (PyTorch: https://pytorch.org/, TensorFlow: https://www.tensorflow.org/install). 58 | 59 | To install `tensorflow` via pip: 60 | 61 | ```bash 62 | pip install tensorflow 63 | pip install tensorflow_probability[tf] 64 | ``` 65 | 66 | To install `pytorch` via pip: 67 | 68 | ```bash 69 | pip install torch 70 | ``` 71 | 72 | To install DeepSensor as well as a deep learning library at the same time, use: 73 | 74 | ```bash 75 | pip install deepsensor[tf] # for tensorflow and tensorflow_probability 76 | # or 77 | pip install deepsensor[torch] # for pytorch 78 | ``` -------------------------------------------------------------------------------- /docs/getting-started/overview.md: -------------------------------------------------------------------------------- 1 | # Overview: Why DeepSensor? 2 | 3 | Machine learning (ML) has made its way from the fringes to the frontiers of environmental science. 4 | DeepSensor aims to accelerate the next generation of research in this growing field. 5 | How? By making it easy and fun to apply advanced ML models to environmental data. 6 | 7 | ## Environmental data 8 | 9 | Environmental data is challenging for conventional ML architectures because 10 | it can be multi-modal, multi-resolution, and have missing data. 11 | The various data modalities (e.g. in-situ weather stations, satellites, and simulators) each provide different kinds of information. 12 | We need to move beyond vanilla CNNs, MLPs, and GPs if we want to fuse these data streams. 13 | 14 | ## Neural processes 15 | 16 | Neural processes have emerged as promising ML architectures for environmental data because they can: 17 | * efficiently fuse multi-modal and multi-resolution data, 18 | * handle missing observations, 19 | * capture prediction uncertainty. 20 | 21 | Early research has shown NPs are capable of tackling diverse spatiotemporal modelling tasks, 22 | such as sensor placement, forecasting, downscaling, and satellite gap-filling. 23 | 24 | ## What DeepSensor does 25 | 26 | The DeepSensor Python package streamlines the application of NPs 27 | to environmental sciences by plugging together the `xarray`, `pandas`, and `neuralprocesses` packages with a user-friendly interface that enables rapid experimentation. 28 | **All figures below visualise outputs from DeepSensor**: 29 | ![DeepSensor applications](../../figs/deepsensor_application_examples.png) 30 | 31 | ```{warning} 32 | NPs are not off-the-shelf ML models like those you might find in `scikit-learn`. 33 | They are novel, data-hungry deep learning models. 34 | Early studies have been very promising, 35 | but more research is needed to understand when NPs work best and how to get the most out of them. 36 | That's where the DeepSensor package and community come in!
37 | ``` 38 | 39 | ## Project goals 40 | 41 | DeepSensor aims to: 42 | * Drastically reduce the effort required to apply NPs to environmental data so users can focus on the science 43 | * Build an open-source software and research community 44 | * Generate a positive feedback loop between research and software 45 | * Stay updated with the latest SOTA models that align with the DeepSensor modelling paradigm 46 | 47 | If this interests you, then let's get started! 48 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to DeepSensor's documentation! 2 | 3 | DeepSensor is a Python package and open-source project for modelling environmental data with 4 | neural processes. 5 | 6 | 7 | **Useful links**: 8 | [Code repository](https://github.com/alan-turing-institute/deepsensor) | 9 | [Issues](https://github.com/alan-turing-institute/deepsensor/issues) | 10 | [Slack join request form](https://forms.office.com/pages/responsepage.aspx?id=p_SVQ1XklU-Knx-672OE-ZmEJNLHTHVFkqQ97AaCfn9UMTZKT1IwTVhJRE82UjUzMVE2MThSOU5RMC4u) | 11 | [Slack channel](https://ai4environment.slack.com/archives/C05NQ76L87R) | 12 | [DeepSensor Gallery](https://github.com/tom-andersson/deepsensor_gallery) 13 | 14 | 15 | ::::{grid} 1 1 2 2 16 | :gutter: 2 17 | 18 | :::{grid-item-card} 19 | :link: getting-started/index 20 | :link-type: doc 21 | ```{image} _static/index_getting_started.svg 22 | :height: 100px 23 | :align: center 24 | ``` 25 | **Getting started**. 26 | 27 | New to *DeepSensor*? Check out the getting started guides, containing an 28 | introduction to *DeepSensor's* main concepts and how to install it. 29 | ::: 30 | 31 | :::{grid-item-card} 32 | :link: user-guide/index 33 | :link-type: doc 34 | ```{image} _static/index_user_guide.svg 35 | :height: 100px 36 | :align: center 37 | ``` 38 | **User guide**. 39 | 40 | The user guide provides a walkthrough of the main features of the 41 | *DeepSensor* package. 42 | ::: 43 | 44 | :::{grid-item-card} 45 | :link: community/index 46 | :link-type: doc 47 | ```{image} _static/index_community2.png 48 | :height: 100px 49 | :align: center 50 | ``` 51 | **Community**. 52 | 53 | The community guide contains information about how to contribute to 54 | *DeepSensor*, how to get in touch with the community, our project 55 | roadmap, and research questions you can contribute to. 56 | ::: 57 | 58 | :::{grid-item-card} 59 | :link: reference/index 60 | :link-type: doc 61 | ```{image} _static/index_api.svg 62 | :height: 100px 63 | :align: center 64 | ``` 65 | **API reference**. 66 | 67 | The reference guide contains a detailed description of the DeepSensor API, 68 | including all the classes and functions. 69 | It assumes that you have an understanding of the key concepts. 70 | ::: 71 | 72 | :::: 73 | 74 | -------------------------------------------------------------------------------- /docs/reference/active_learning/acquisition_fns.rst: -------------------------------------------------------------------------------- 1 | deepsensor.active_learning.acquisition_fns 2 | ============================================== 3 | 4 | ..
automodule:: deepsensor.active_learning.acquisition_fns 5 | :members: 6 | :undoc-members: __init__ 7 | :special-members: __call__ 8 | -------------------------------------------------------------------------------- /docs/reference/active_learning/algorithms.rst: -------------------------------------------------------------------------------- 1 | deepsensor.active_learning.algorithms 2 | ========================================= 3 | 4 | .. automodule:: deepsensor.active_learning.algorithms 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/active_learning/index.md: -------------------------------------------------------------------------------- 1 | # deepsensor.active_learning 2 | -------------------------------------------------------------------------------- /docs/reference/data/index.md: -------------------------------------------------------------------------------- 1 | # deepsensor.data -------------------------------------------------------------------------------- /docs/reference/data/loader.rst: -------------------------------------------------------------------------------- 1 | deepsensor.data.loader 2 | ========================== 3 | 4 | .. 5 | Can not do automodule of deepsensor.data.loader because of 6 | some weird bug in sphinx. 7 | 8 | .. autoclass:: deepsensor.data.loader.TaskLoader 9 | :members: 10 | :undoc-members: __init__ 11 | :special-members: __call__ 12 | -------------------------------------------------------------------------------- /docs/reference/data/processor.rst: -------------------------------------------------------------------------------- 1 | deepsensor.data.processor 2 | ============================= 3 | 4 | .. automodule:: deepsensor.data.processor 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/data/sources.rst: -------------------------------------------------------------------------------- 1 | deepsensor.data.sources 2 | ============================= 3 | 4 | .. automodule:: deepsensor.data.sources 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/data/task.rst: -------------------------------------------------------------------------------- 1 | deepsensor.data.task 2 | ======================== 3 | 4 | .. automodule:: deepsensor.data.task 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/data/utils.rst: -------------------------------------------------------------------------------- 1 | deepsensor.data.utils 2 | ========================= 3 | 4 | .. automodule:: deepsensor.data.utils 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/index.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | This part of the documentation contains the API reference for the package. It is structured by modules, and each module contains its respective classes, functions, and attributes. 
The API is designed to be as simple as possible while still allowing for a lot of flexibility. The API is divided into several submodules, which are described in the following sections. 4 | -------------------------------------------------------------------------------- /docs/reference/model/convnp.rst: -------------------------------------------------------------------------------- 1 | deepsensor.model.convnp 2 | =========================== 3 | 4 | .. automodule:: deepsensor.model.convnp 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/model/defaults.rst: -------------------------------------------------------------------------------- 1 | deepsensor.model.defaults 2 | ============================= 3 | 4 | .. automodule:: deepsensor.model.defaults 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/model/index.md: -------------------------------------------------------------------------------- 1 | # deepsensor.model 2 | -------------------------------------------------------------------------------- /docs/reference/model/model.rst: -------------------------------------------------------------------------------- 1 | deepsensor.model.model 2 | ========================== 3 | 4 | .. automodule:: deepsensor.model.model 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/model/nps.rst: -------------------------------------------------------------------------------- 1 | deepsensor.model.nps 2 | ======================== 3 | 4 | .. automodule:: deepsensor.model.nps 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/model/pred.rst: -------------------------------------------------------------------------------- 1 | deepsensor.model.pred 2 | ========================== 3 | 4 | .. automodule:: deepsensor.model.pred 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/plot.rst: -------------------------------------------------------------------------------- 1 | deepsensor.plot 2 | =================== 3 | 4 | .. automodule:: deepsensor.plot 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/reference/tensorflow/index.rst: -------------------------------------------------------------------------------- 1 | deepsensor.tensorflow 2 | ========================= 3 | 4 | .. 5 | .. automodule:: deepsensor.tensorflow 6 | :members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/reference/torch/index.rst: -------------------------------------------------------------------------------- 1 | deepsensor.torch 2 | ==================== 3 | 4 | .. 5 | .. 
automodule:: deepsensor.torch 6 | :members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/reference/train/index.md: -------------------------------------------------------------------------------- 1 | # deepsensor.train 2 | -------------------------------------------------------------------------------- /docs/reference/train/train.rst: -------------------------------------------------------------------------------- 1 | deepsensor.train.train 2 | ========================== 3 | 4 | .. automodule:: deepsensor.train.train 5 | :members: 6 | :show-inheritance: 7 | :undoc-members: __init__ 8 | :special-members: __call__ 9 | -------------------------------------------------------------------------------- /docs/references.bib: -------------------------------------------------------------------------------- 1 | @article{perez2011python 2 | , title = {Python: an ecosystem for scientific computing} 3 | , author = {Perez, Fernando and Granger, Brian E and Hunter, John D} 4 | , journal = {Computing in Science \\& Engineering} 5 | , volume = {13} 6 | , number = {2} 7 | , pages = {13--21} 8 | , year = {2011} 9 | , publisher = {AIP Publishing} 10 | } -------------------------------------------------------------------------------- /docs/research_ideas.md: -------------------------------------------------------------------------------- 1 | # DeepSensor research ideas 2 | 3 | Are you interested in using DeepSensor for a research project? 4 | Thankfully there are many interesting open questions with ConvNPs and their application 5 | to environmental science. 6 | Below are a non-exhaustive selection of research ideas that you could explore. 7 | It would be helpful to ensure you are familiar with the literature and 8 | resources in the [](resources.md) page before starting. 9 | 10 | Why not join our Slack channel and start a conversation around these ideas or your own? :-) You can join by [signing up for the Turing Environment & Sustainability stakeholder community](https://forms.office.com/pages/responsepage.aspx?id=p_SVQ1XklU-Knx-672OE-ZmEJNLHTHVFkqQ97AaCfn9UMTZKT1IwTVhJRE82UjUzMVE2MThSOU5RMC4u). The form includes a question on signing up for the Slack team, where you can find DeepSensor's channel. 11 | 12 | ## Transfer learning from regions of dense observations to regions of sparse observations 13 | 14 | Since the `ConvNP` is a data-hungry model, it does not perform well if only trained on a 15 | small number of observations, which presents a challenge for modelling variables that 16 | are poorly observed. 17 | But what if a particular variable is well observed in one region and poorly observed in another? 18 | Can we train a model on a region of dense observations and then transfer the model to a region 19 | of sparse observations? 20 | Does the performance improve? 21 | 22 | ## Sensor placement for forecasting 23 | 24 | Previous active learning research with ConvNPs has only considered sensor placement for interpolation. 25 | Do the sensor placements change when the model is trained for forecasting? 26 | 27 | See, e.g., Section 4.2.1 of [Environmental sensor placement with convolutional Gaussian neural processes](https://doi.org/10.1017/eds.2023.22). 28 | 29 | ## U-Net architectural changes 30 | 31 | The `ConvNP` currently uses a vanilla U-Net architecture. 32 | Do any architectural changes improve performance, such as batch normalisation or dropout? 
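As a (purely illustrative) starting point, the sketch below varies U-Net hyperparameters that `ConvNP` already exposes through keyword arguments, which are assumed to be forwarded to `neuralprocesses.construct_convgnp` (the exact keyword names, e.g. `unet_channels`, may differ between versions). The synthetic dataset and values are placeholders. Changes like batch normalisation or dropout are not exposed this way, however.

```python
import numpy as np
import pandas as pd
import xarray as xr

import deepsensor.torch  # noqa: F401  (selects the PyTorch backend)
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP

# Purely synthetic gridded field, just to make the sketch self-contained.
ds_raw = xr.Dataset(
    {"var": (("time", "lat", "lon"), np.random.rand(5, 30, 30).astype(np.float32))},
    coords={
        "time": pd.date_range("2020-01-01", periods=5),
        "lat": np.linspace(50, 60, 30),
        "lon": np.linspace(-10, 5, 30),
    },
)

data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)
task_loader = TaskLoader(context=ds, target=ds)

# Keyword arguments not consumed by ConvNP itself are (assumed to be) passed
# through to `neuralprocesses.construct_convgnp`, so architectural settings
# such as the U-Net channels can be varied here without touching internals.
model = ConvNP(data_processor, task_loader, unet_channels=(32, 32, 32, 32))
```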
33 | 34 | This would require digging into the [`neuralprocesses.construct_convgnp` method](https://github.com/wesselb/neuralprocesses/blob/f20572ba480c1279ad5fb66dbb89cbc73a0171c7/neuralprocesses/architectures/convgnp.py#L97) 35 | and replacing the U-Net module with a custom one. 36 | 37 | ## Extension to continuous time observations 38 | 39 | The `ConvNP` currently assumes that the observations are on a regular time grid. 40 | How can we extend this to continuous time observations, where the observations are not necessarily 41 | on a regular time grid? 42 | Can we do this without a major rework of the code and model? 43 | For example, can we pass a 'time of observation' auxiliary input to the model? 44 | What are the limitations of this approach? 45 | 46 | ## Training with ablations for interpretability 47 | 48 | Since the `ConvNP` operates on sets of observations, it is possible to ablate observations 49 | and see how the model's predictions change. 50 | Thus, the `ConvNP` admits unique interpretability opportunities. 51 | 52 | However, the model would need to be trained with examples of ablated observations so that it 53 | is not out of distribution when it sees ablated observations at test time. 54 | For example, when generating `Task`s with a `TaskLoader`, randomly set some of the 55 | `context_sampling` entries to `0` to remove all observations for those context sets. 56 | Then, at test time, ablate context sets and measure the change in the model's predictions 57 | or performance. 58 | 59 | ## Monte Carlo sensor placement using AR sampling 60 | 61 | The `GreedyAlgorithm` for sensor placement currently uses the model's mean prediction 62 | to infill missing observations at query sites. 63 | However, one could also draw multiple [AR samples](user-guide/prediction.ipynb) 64 | from the model to perform *Monte Carlo sampling* over the acquisition function. 65 | 66 | How does this change the sensor placements and what benefits does it yield? 67 | Do the acquisition functions become more robust (e.g. correlate better with 68 | true performance gains)? 69 | 70 | The [Environmental sensor placement with convolutional Gaussian neural processes](https://doi.org/10.1017/eds.2023.22) 71 | paper will be important background reading for this. 72 | -------------------------------------------------------------------------------- /docs/resources.md: -------------------------------------------------------------------------------- 1 | # Resources 2 | We aim to keep this document updated with the latest resources related to DeepSensor and 3 | neural processes. 
4 | 5 | ## 🎤 Recorded talks 6 | 7 | | Date | Title | Presenter | Length | Video | 8 | |------------|:---------------------------------------------|:----------------|---------|:--------------------------------------------------------------------------------------------------------------------------------------------:| 9 | | August 2023 | Tackling diverse environmental prediction tasks with neural processes | Tom Andersson | 1 hour | [🎥](https://youtu.be/MIHNyKjw204) / [slides](https://github.com/tom-andersson/slides/blob/main/2023_08_04_nerc_cde_webinar.pdf) | 10 | | April 2023 | Environmental Sensor Placement with ConvGNPs | Tom Andersson | 15 mins | [🎥](https://youtu.be/v0pmqh09u1Y) | 11 | | Jul 2022 | Advances in Neural Processes | Richard Turner | 1 hour | [🎥](https://www.youtube.com/watch?v=Eu6rGePXYX8) | 12 | | May 2023 | Autoregressive Conditional Neural Processes | Wessel Bruinsma | 5 mins | [🎥](https://www.youtube.com/watch?v=93ZliHS0qBk) | 13 | 14 | ## 📑 Papers 15 | * Tom Andersson et al. [Environmental Sensor Placement with Convolutional Gaussian Neural Processes](https://doi.org/10.1017/eds.2023.22). *Environmental Data Science* (2023) 16 | * Jonas Scholz et al. [Sim2Real with Environmental Neural Processes](https://arxiv.org/abs/2310.19932). *NeurIPS Tackling Climate Change with Machine Learning Workshop* (2023) 17 | * Wessel Bruinsma et al. [Autoregressive Conditional Neural Processes]( 18 | https://doi.org/10.48550/arXiv.2303.14468). In *Proceedings of the 11th 19 | International Conference on Learning Representations, ICLR* (2023) 20 | * Anna Vaughan et al. [Convolutional conditional neural processes for local climate downscaling](https://doi.org/10.5194/gmd-15-251-2022). *Geoscientific Model Development* (2022) 21 | 22 | ## 🗒️ Posters 23 | * Paolo Pelucchi et al. [Optimal Sensor Placement for Black Carbon AOD with Convolutional Neural Processes](https://zenodo.org/record/8370274) 24 | *iMIRACLI Summer School / FORCeS annual meeting* (2023) 25 | 26 | ## 📖 Other resources 27 | * Yann Dubois' [Neural Process Family website](https://yanndubs.github.io/Neural-Process-Family/text/Intro.html) 28 | -------------------------------------------------------------------------------- /docs/user-guide/deepsensor_design.md: -------------------------------------------------------------------------------- 1 | # DeepSensor design 2 | 3 | Some users will find it useful to understand the design of DeepSensor 4 | before they begin. Others would prefer to just see some examples and 5 | get started right away. 6 | 7 | If you fall into the latter category, 8 | feel free to jump straight to the next page ([](data_processor.ipynb)). 9 | 10 | ## Design overview 11 | 12 | A schematic overview of the core components of DeepSensor is shown below. 13 | This shows how the package's components process data & interact from end-to-end. 14 | 15 | ![DeepSensor design](../../figs/deepsensor_design.png) 16 | 17 | The key classes are: 18 | * `DataProcessor`: Maps `xarray` and `pandas` data from their native units 19 | to a normalised and standardised format (and vice versa). 20 | * `TaskLoader`: Slices and samples normalised `xarray` and `pandas` data to generate `Task` objects for 21 | training and inference. 22 | * `Task`: Container for context and target data. Subclass of `dict` with additional methods 23 | for processing and summarising the data. 
24 | * `DeepSensorModel`: Base class for DeepSensor models, implementing a high-level `.predict` 25 | method for predicting straight to `xarray`/`pandas` in original coordinates and units. 26 | * `ConvNP`: Convolutional neural process (ConvNP) model class (subclass of `DeepSensorModel`). 27 | Uses the `neuralprocesses` library. This is currently the only model provided by DeepSensor. 28 | * `Trainer`: Class for training on `Task` objects using backpropagation and the Adam optimiser. 29 | * `AcquisitionFunction`: Base class for active learning acquisition functions. 30 | * `GreedyAlgorithm`: Greedy search algorithm for active learning. 31 | 32 | In addition, a [`deepsensor.plot`](../reference/plot.rst) module provides useful plotting functions for 33 | visualising: 34 | * `Task` context and target sets, 35 | * ``DeepSensorModel`` predictions, 36 | * ``ConvNP`` internals (encoding and feature maps), 37 | * ``GreedyAlgorithm`` active learning outputs. 38 | 39 | You will see examples of these `deepsensor.plot` visualisation functions 40 | throughout the documentation. 41 | 42 | 43 | ## Design principles 44 | 45 | A few key design principles have guided the development of DeepSensor: 46 | 47 | * **User-friendly interface**: The interface should be simple and intuitive, with the flexibility to 48 | handle a wide range of use cases. 49 | * **Leverage powerful and ubiquitous data science libraries**: Users can stay within the familiar `xarray`/`pandas` 50 | ecosystem from start to finish in their DeepSensor research workflows. 51 | * **Infer sensible defaults**: DeepSensor should leverage information in the data to infer 52 | sensible defaults for hyperparameters, with the option to override these defaults if desired. 53 | * **Extensible**: Extend DeepSensor with new models by sub-classing `DeepSensorModel` and 54 | implementing the low-level prediction methods of `ProbabilisticModel`. 55 | * **Modular**: The `DataProcessor` and `TaskLoader` classes can be used independently of 56 | the downstream modelling and active learning components, and can thus be used outside of 57 | a DeepSensor workflow. 58 | * **Deep learning library agnostic**: DeepSensor is compatible with both 59 | TensorFlow and PyTorch thanks to the [`backends`](https://github.com/wesselb/lab) library - simply `import deepsensor.tensorflow` or 60 | `import deepsensor.torch`. 61 | -------------------------------------------------------------------------------- /docs/user-guide/index.md: -------------------------------------------------------------------------------- 1 | # User Guide 2 | 3 | The DeepSensor user guide will walk you through the core components of the package using 4 | code examples and visualisations. 5 | 6 | The pages of this guide are Jupyter notebooks and are fully reproducible. 7 | Some of the notebooks depend on previous notebooks having been run (e.g. model training). 8 | However, most of the notebooks can be run in a standalone way without any modification. 9 | Click the download button at the top of the pages to download the .ipynb files 10 | if you would like to run them yourself.
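For orientation before diving in, below is a minimal end-to-end sketch of the workflow that the following pages unpack step by step. It uses a small synthetic dataset so that it is self-contained; the class names follow the DeepSensor API described in this guide, but the argument values are purely illustrative and some keyword names may differ slightly between versions.

```python
import numpy as np
import pandas as pd
import xarray as xr

import deepsensor.torch  # noqa: F401  (or `import deepsensor.tensorflow`)
from deepsensor.data import DataProcessor, TaskLoader
from deepsensor.model import ConvNP
from deepsensor.train import Trainer

# Synthetic stand-in for a gridded environmental variable (e.g. a reanalysis).
dates = pd.date_range("2020-01-01", periods=10)
ds_raw = xr.Dataset(
    {"var": (("time", "lat", "lon"), np.random.rand(10, 20, 20).astype(np.float32))},
    coords={"time": dates, "lat": np.linspace(50, 60, 20), "lon": np.linspace(0, 10, 20)},
)

data_processor = DataProcessor(x1_name="lat", x2_name="lon")  # normalise the raw data
ds = data_processor(ds_raw)

task_loader = TaskLoader(context=ds, target=ds)  # slice normalised data into Tasks
train_tasks = [
    task_loader(date, context_sampling=50, target_sampling=50) for date in dates[:8]
]

model = ConvNP(data_processor, task_loader)  # ConvNP with defaults inferred from the data
trainer = Trainer(model, lr=5e-5)  # backprop training loop with the Adam optimiser
for epoch in range(2):
    batch_losses = trainer(train_tasks)

test_task = task_loader(dates[-1], context_sampling=50)
pred = model.predict(test_task, X_t=ds_raw)  # predict back to xarray in original units
```

Each stage (data processing, task generation, modelling, training, and prediction) has a dedicated page in this guide.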
11 | -------------------------------------------------------------------------------- /docs/user-guide/task.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tasks" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## What is a 'task'?\n", 15 | "\n", 16 | "The concept of a *task* is central to DeepSensor.\n", 17 | "It originates from the meta-learning literature in machine learning and has a specific meaning.\n", 18 | "\n", 19 | "Users unfamiliar with the notation and terminology of meta-learning are recommended to expand the section below." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "```{admonition} Click to reveal the meta-learning primer\n", 27 | ":class: dropdown\n", 28 | "\n", 29 | "**Sets of observations**\n", 30 | "\n", 31 | "A *set* of observations is a collection of $M$ input-output pairs $\\{(\\mathbf{x}_1, \\mathbf{y}_1), (\\mathbf{x}_2, \\mathbf{y}_2), \\ldots, (\\mathbf{x}_M, \\mathbf{y}_M)\\}$.\n", 32 | "In DeepSensor $\\mathbf{x}_i \\in \\mathbb{R}^2$ is a 2D spatial location (such as latitude-longitude)\n", 33 | " and $\\mathbf{y}_i \\in \\mathbb{R}^N$ is an $N$-dimensional observation at that location (such as a temperature and precipitation).\n", 34 | "Context sets may lie on scattered, off-grid locations (such as weather stations), or on a regular grid (such as a reanalysis or satellite data).\n", 35 | "A *set* can be compactly written as $(\\mathbf{X}, \\mathbf{Y})$, where $\\mathbf{X} \\in \\mathbb{R}^{2\\times M}$ and $\\mathbf{Y} \\in \\mathbb{R}^{N\\times M}$.\n", 36 | "\n", 37 | "**Context sets**\n", 38 | "\n", 39 | "A *context set* is a set of observations that are used to make predictions for another set of observations. Following our notations above, we denote a context set as $C_j=(\\mathbf{X}^{(c)}, \\mathbf{Y}^{(c)})_j$.\n", 40 | "We may have multiple context sets, denoted as $C = \\{ (\\mathbf{X}^{(c)}, \\mathbf{Y}^{(c)})_j \\}_{j=1}^{N_C}$.\n", 41 | "\n", 42 | "**Target sets**\n", 43 | "\n", 44 | "A *target set* is a set of observations that we wish to predict using the context sets.\n", 45 | "Similarly to context sets, we denote the collection of all target sets as $T = \\{ (\\mathbf{X}^{(t)}, \\mathbf{Y}^{(t)})_j \\}_{j=1}^{N_T}$.\n", 46 | "During training, the target observations $\\mathbf{y}_i$ are known, but at inference time will be unknown latent variables.\n", 47 | "\n", 48 | "**Tasks**\n", 49 | "\n", 50 | "A *task* is a collection of context sets and target sets.\n", 51 | "We denote a task as $\\mathcal{D} = (C, T)$.\n", 52 | "The modelling goal is make probabilistic predictions for the target variables $\\mathbf{Y}^{(t)}_j$ given the context sets $C$ and target prediction locations $\\mathbf{X}^{(t)}_j$.\n", 53 | "```" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## The DeepSensor Task\n", 61 | "\n", 62 | "In DeepSensor, a `Task` is a `dict`-like data structure that contains context sets, target sets, and other metadata.\n", 63 | "Before diving into the [](./task_loader) class which generates `Task` objects from `xarray` and `pandas` objects,\n", 64 | "we will first introduce the `Task` class itself.\n", 65 | "\n", 66 | "First, we will generate a `Task` using DeepSensor. 
These code cells are kept hidden because they includes\n", 67 | "features that are only covered later in the User Guide. Only expand them if you are curious!" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 1, 73 | "metadata": { 74 | "ExecuteTime": { 75 | "start_time": "2023-11-01T14:28:15.732009455Z" 76 | }, 77 | "collapsed": false, 78 | "tags": [ 79 | "hide-cell" 80 | ] 81 | }, 82 | "outputs": [ 83 | { 84 | "name": "stderr", 85 | "output_type": "stream", 86 | "text": [ 87 | "100%|████████████████████████████████████████████████████████████████| 3124/3124 [02:38<00:00, 19.75it/s]\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "import logging\n", 93 | "\n", 94 | "logging.captureWarnings(True)\n", 95 | "\n", 96 | "import deepsensor.torch\n", 97 | "from deepsensor.data import DataProcessor\n", 98 | "from deepsensor.data.sources import get_ghcnd_station_data, get_era5_reanalysis_data, get_earthenv_auxiliary_data, get_gldas_land_mask\n", 99 | "\n", 100 | "import matplotlib.pyplot as plt\n", 101 | "\n", 102 | "# Using the same settings allows use to use pre-downloaded cached data\n", 103 | "data_range = (\"2016-06-25\", \"2016-06-30\")\n", 104 | "extent = \"europe\"\n", 105 | "station_var_IDs = [\"TAVG\", \"PRCP\"]\n", 106 | "era5_var_IDs = [\"2m_temperature\", \"10m_u_component_of_wind\", \"10m_v_component_of_wind\"]\n", 107 | "auxiliary_var_IDs = [\"elevation\", \"tpi\"]\n", 108 | "cache_dir = \"../../.datacache\"\n", 109 | "\n", 110 | "station_raw_df = get_ghcnd_station_data(station_var_IDs, extent, date_range=data_range, cache=True, cache_dir=cache_dir)\n", 111 | "era5_raw_ds = get_era5_reanalysis_data(era5_var_IDs, extent, date_range=data_range, cache=True, cache_dir=cache_dir)\n", 112 | "auxiliary_raw_ds = get_earthenv_auxiliary_data(auxiliary_var_IDs, extent, \"10KM\", cache=True, cache_dir=cache_dir)\n", 113 | "land_mask_raw_ds = get_gldas_land_mask(extent, cache=True, cache_dir=cache_dir)\n", 114 | "\n", 115 | "data_processor = DataProcessor(x1_name=\"lat\", x2_name=\"lon\")\n", 116 | "era5_ds = data_processor(era5_raw_ds)\n", 117 | "aux_ds, land_mask_ds = data_processor([auxiliary_raw_ds, land_mask_raw_ds], method=\"min_max\")\n", 118 | "station_df = data_processor(station_raw_df)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 2, 124 | "metadata": { 125 | "ExecuteTime": { 126 | "end_time": "2023-11-01T14:32:15.553656830Z", 127 | "start_time": "2023-11-01T14:32:15.548454739Z" 128 | }, 129 | "tags": [ 130 | "hide-cell" 131 | ] 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "from deepsensor.data import TaskLoader\n", 136 | "task_loader = TaskLoader(context=[era5_ds, land_mask_ds], target=station_df)\n", 137 | "task = task_loader(\"2016-06-25\", context_sampling=[52, 112], target_sampling=245)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "In the code cell below, `task` is a `Task` object.\n", 145 | "Printing a `Task` will print each of its entries and replace numerical arrays with their shape for convenience." 
146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 3, 151 | "metadata": { 152 | "ExecuteTime": { 153 | "end_time": "2023-11-01T14:32:15.566930620Z", 154 | "start_time": "2023-11-01T14:32:15.553282595Z" 155 | } 156 | }, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "time: 2016-06-25 00:00:00\n", 163 | "ops: []\n", 164 | "X_c: [(2, 52), (2, 112)]\n", 165 | "Y_c: [(3, 52), (1, 112)]\n", 166 | "X_t: [(2, 245)]\n", 167 | "Y_t: [(2, 245)]\n" 168 | ] 169 | } 170 | ], 171 | "source": [ 172 | "print(task)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "metadata": {}, 178 | "source": [ 179 | "## Task structure\n", 180 | "\n", 181 | "A `Task` typically contains at least the following entries:\n", 182 | "- `\"time\"`: timestamp that was used for slicing the spatiotemporal data.\n", 183 | "- `\"ops\"`: list of processing operations that have been applied to the data (more on this shortly).\n", 184 | "- `\"X_c\"` and `\"Y_c\"`: length-$N_C$ lists of context set observations $\\mathbf{X}^{(c)}_i \\in \\mathbb{R}^{2\\times M}$ and $\\mathbf{Y}^{(c)}_i \\in \\mathbb{R}^{N\\times M}$.\n", 185 | "- `\"X_t\"` and `\"Y_t\"`: as above, but for the target sets. In the example above, the target observations are known, so this `Task` may be used for training." 186 | ] 187 | }, 188 | { 189 | "cell_type": "markdown", 190 | "metadata": { 191 | "collapsed": false 192 | }, 193 | "source": [ 194 | "**Exercise:**\n", 195 | "\n", 196 | "For the `task` object above, use the `\"X_c\"`, `\"Y_c\"`, `\"X_t\"`, and `\"Y_t\"` entries to work out the following (answers hidden below):\n", 197 | "- The number of context sets\n", 198 | "- The number of observations in each context set\n", 199 | "- The dimensionality of each context set\n", 200 | "- The number of target sets\n", 201 | "- The number of observations in each target set\n", 202 | "- The dimensionality of each target set\n" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "```{admonition} Click to reveal the answers!\n", 210 | ":class: dropdown\n", 211 | "\n", 212 | "Answers, respectively: 2 context sets, 52 and 112 context observations, 3 and 1 context dimensions, 1 target set, 245 target observations, 2 target dimensions.\n", 213 | "```" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": { 219 | "collapsed": false 220 | }, 221 | "source": [ 222 | "### Gridded data in Tasks\n", 223 | "\n", 224 | "For convenience, data that lies on a regular grid is given a compact tuple representation for the `\"X\"` entries:" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 4, 230 | "metadata": { 231 | "ExecuteTime": { 232 | "end_time": "2023-11-01T14:32:15.620494504Z", 233 | "start_time": "2023-11-01T14:32:15.570462444Z" 234 | }, 235 | "collapsed": false 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "task_with_gridded_data = task_loader(\"2016-06-25\", context_sampling=[\"all\", \"all\"], target_sampling=245)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": 5, 245 | "metadata": { 246 | "ExecuteTime": { 247 | "end_time": "2023-11-01T14:32:15.628949091Z", 248 | "start_time": "2023-11-01T14:32:15.611675646Z" 249 | }, 250 | "collapsed": false 251 | }, 252 | "outputs": [ 253 | { 254 | "name": "stdout", 255 | "output_type": "stream", 256 | "text": [ 257 | "time: 2016-06-25 00:00:00\n", 258 | "ops: []\n", 259 | "X_c: [((1, 141), (1, 221)), ((1, 140), (1, 
220))]\n", 260 | "Y_c: [(3, 141, 221), (1, 140, 220)]\n", 261 | "X_t: [(2, 245)]\n", 262 | "Y_t: [(2, 245)]\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "print(task_with_gridded_data)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": { 273 | "collapsed": false 274 | }, 275 | "source": [ 276 | "In the above example, the first context set lies on a 141 x 221 grid, and the second context set lies on a 140 x 220 grid." 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "### Task methods\n", 284 | "The `Task` class also contains methods for applying processing operations to the data (like removing NaNs, adding batch dimensions, etc.).\n", 285 | "These operations will be recorded, in the order they were applied, in the `\"ops\"` entry of the `Task`.\n", 286 | "Operations can be chained together, for example:" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 6, 292 | "metadata": { 293 | "ExecuteTime": { 294 | "end_time": "2023-11-01T14:32:15.906470888Z", 295 | "start_time": "2023-11-01T14:32:15.633776731Z" 296 | } 297 | }, 298 | "outputs": [ 299 | { 300 | "name": "stdout", 301 | "output_type": "stream", 302 | "text": [ 303 | "time: 2016-06-25 00:00:00\n", 304 | "ops: ['batch_dim', 'tensor']\n", 305 | "X_c: [torch.Size([1, 2, 52]), torch.Size([1, 2, 112])]\n", 306 | "Y_c: [torch.Size([1, 3, 52]), torch.Size([1, 1, 112])]\n", 307 | "X_t: [torch.Size([1, 2, 245])]\n", 308 | "Y_t: [torch.Size([1, 2, 245])]\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "print(task.add_batch_dim().convert_to_tensor())" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "collapsed": false 320 | }, 321 | "source": [ 322 | "Gridded data in a `Task` can be flattened using the `.flatten_gridded_data` method.\n", 323 | "Notice how the `\"X\"` entries are now 2D arrays of shape `(2, M)` rather than tuples of two 1D arrays of shape `(M,)`." 
324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 7, 329 | "metadata": { 330 | "ExecuteTime": { 331 | "end_time": "2023-11-01T14:32:15.970618528Z", 332 | "start_time": "2023-11-01T14:32:15.909066194Z" 333 | }, 334 | "collapsed": false 335 | }, 336 | "outputs": [ 337 | { 338 | "name": "stdout", 339 | "output_type": "stream", 340 | "text": [ 341 | "time: 2016-06-25 00:00:00\n", 342 | "ops: ['gridded_data_flattened']\n", 343 | "X_c: [(2, 31161), (2, 30800)]\n", 344 | "Y_c: [(3, 31161), (1, 30800)]\n", 345 | "X_t: [(2, 245)]\n", 346 | "Y_t: [(2, 245)]\n" 347 | ] 348 | } 349 | ], 350 | "source": [ 351 | "print(task_with_gridded_data.flatten_gridded_data())" 352 | ] 353 | } 354 | ], 355 | "metadata": { 356 | "celltoolbar": "Edit Metadata", 357 | "kernelspec": { 358 | "display_name": "Python 3 (ipykernel)", 359 | "language": "python", 360 | "name": "python3" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 3 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython3", 372 | "version": "3.8.10" 373 | } 374 | }, 375 | "nbformat": 4, 376 | "nbformat_minor": 2 377 | } 378 | -------------------------------------------------------------------------------- /figs/DeepSensorLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/figs/DeepSensorLogo.png -------------------------------------------------------------------------------- /figs/DeepSensorLogo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/figs/DeepSensorLogo2.png -------------------------------------------------------------------------------- /figs/convnp_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/figs/convnp_arch.png -------------------------------------------------------------------------------- /figs/deepsensor_application_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/figs/deepsensor_application_examples.png -------------------------------------------------------------------------------- /figs/deepsensor_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/figs/deepsensor_design.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "deepsensor" 7 | version = "0.4.2" 8 | authors = [ 9 | {name = "Tom R. Andersson", email="tomandersson3@gmail.com"}, 10 | ] 11 | description = "A Python package for modelling xarray and pandas data with neural processes." 
12 | readme = "README.md" 13 | license = {text="MIT"} 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "Operating System :: OS Independent" 17 | ] 18 | requires-python = ">=3.8" 19 | dependencies = [ 20 | "backends>=1.7.0", 21 | "backends-matrix", 22 | "dask", 23 | "distributed", 24 | "gcsfs", 25 | "matplotlib", 26 | "neuralprocesses>=0.2.7", 27 | "numpy", 28 | "pandas", 29 | "pooch", 30 | "pyshp", 31 | "seaborn", 32 | "tqdm", 33 | "xarray", 34 | "zarr" 35 | ] 36 | 37 | [tool.setuptools] 38 | packages = ["deepsensor"] 39 | 40 | [project.urls] 41 | Source = "https://github.com/alan-turing-institute/deepsensor" 42 | Bug_Tracker = "https://github.com/alan-turing-institute/deepsensor/issues" 43 | 44 | [project.optional-dependencies] 45 | torch = ["torch>=2"] 46 | tf = ["tensorflow", "tensorflow_probability[tf]"] 47 | dev = [ 48 | "coveralls", 49 | "parameterized", 50 | "pre-commit", 51 | "pytest", 52 | "ruff", 53 | ] 54 | docs = [ 55 | "jupyter-book", 56 | "matplotlib", 57 | "numpy", 58 | "sphinx", 59 | ] 60 | testing = [ 61 | "mypy", 62 | "parameterized", 63 | "pytest", 64 | "pytest-cov", 65 | "tox", 66 | ] 67 | rioxarray = [ 68 | "rioxarray" 69 | ] 70 | 71 | [tool.setuptools.package-data] 72 | deepsensor = ["py.typed"] 73 | 74 | [tool.pytest.ini_options] 75 | addopts = "--cov=deepsensor" 76 | testpaths = [ 77 | "tests", 78 | ] 79 | 80 | [tool.mypy] 81 | mypy_path = "deepsensor" 82 | check_untyped_defs = true 83 | disallow_any_generics = true 84 | ignore_missing_imports = true 85 | no_implicit_optional = true 86 | show_error_codes = true 87 | strict_equality = true 88 | warn_redundant_casts = true 89 | warn_return_any = true 90 | warn_unreachable = true 91 | warn_unused_configs = true 92 | no_implicit_reexport = true 93 | 94 | [tool.ruff] 95 | exclude = ["tests", "docs", "*.ipynb"] 96 | lint.select = [ 97 | "D", 98 | "NPY201" 99 | ] 100 | lint.ignore = [ 101 | "D100", # Missing docstring in public module 102 | "D104", # Missing docstring in public package 103 | "D105", # Missing docstring in magic method 104 | "D107", # Missing docstring in __init__ 105 | "D205", # 1 blank line required between summary line and description 106 | "D417", # Missing argument description in function docstring 107 | ] 108 | lint.pydocstyle.convention = "google" -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alan-turing-institute/deepsensor/2fb86f7f9ff1bd8933fd3e6e1ce71153753e1b44/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_active_learning.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import xarray as xr 3 | import numpy as np 4 | 5 | import deepsensor.tensorflow as deepsensor 6 | from deepsensor.active_learning.acquisition_fns import ( 7 | AcquisitionFunction, 8 | MeanVariance, 9 | MeanStddev, 10 | pNormStddev, 11 | MeanMarginalEntropy, 12 | JointEntropy, 13 | OracleMAE, 14 | OracleRMSE, 15 | OracleMarginalNLL, 16 | OracleJointNLL, 17 | Stddev, 18 | ExpectedImprovement, 19 | Random, 20 | ContextDist, 21 | ) 22 | from deepsensor.active_learning.algorithms import GreedyAlgorithm 23 | 24 | from deepsensor.data.loader import TaskLoader 25 | from deepsensor.data.processor import DataProcessor, xarray_to_coord_array_normalised 26 | from deepsensor.model.convnp import ConvNP 27 | 28 | 29 | class 
TestActiveLearning(unittest.TestCase): 30 | 31 | @classmethod 32 | def setUpClass(cls): 33 | 34 | # It's safe to share data between tests because the TaskLoader does not modify data 35 | ds_raw = xr.tutorial.open_dataset("air_temperature")["air"] 36 | cls.ds_raw = ds_raw 37 | cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon") 38 | cls.ds = cls.data_processor(ds_raw) 39 | # Set up a model with two context sets and two target sets for generality 40 | cls.task_loader = TaskLoader(context=[cls.ds, cls.ds], target=[cls.ds, cls.ds]) 41 | cls.model = ConvNP( 42 | cls.data_processor, 43 | cls.task_loader, 44 | unet_channels=(5, 5, 5), 45 | verbose=False, 46 | ) 47 | 48 | # Set up model with aux-at-target data 49 | aux_at_targets = cls.ds.isel(time=0).drop_vars("time") 50 | cls.task_loader_with_aux = TaskLoader( 51 | context=cls.ds, target=cls.ds, aux_at_targets=aux_at_targets 52 | ) 53 | cls.model_with_aux = ConvNP( 54 | cls.data_processor, 55 | cls.task_loader_with_aux, 56 | unet_channels=(5, 5, 5), 57 | verbose=False, 58 | ) 59 | 60 | def test_wrong_n_new_sensors(self): 61 | with self.assertRaises(ValueError): 62 | alg = GreedyAlgorithm( 63 | model=self.model, 64 | X_t=self.ds_raw, 65 | X_s=self.ds_raw, 66 | N_new_context=-1, 67 | ) 68 | 69 | with self.assertRaises(ValueError): 70 | alg = GreedyAlgorithm( 71 | model=self.model, 72 | X_t=self.ds_raw, 73 | X_s=self.ds_raw, 74 | N_new_context=10_000, # > number of search points 75 | ) 76 | 77 | def test_acquisition_fns_run(self): 78 | """Run each acquisition function to check that it runs and returns correct shape""" 79 | for context_set_idx, target_set_idx in [(0, 0), (0, 1), (1, 0), (1, 1)]: 80 | sequential_acquisition_fns = [ 81 | MeanStddev(self.model, context_set_idx, target_set_idx), 82 | MeanVariance(self.model, context_set_idx, target_set_idx), 83 | pNormStddev(self.model, context_set_idx, target_set_idx, p=3), 84 | MeanMarginalEntropy(self.model, context_set_idx, target_set_idx), 85 | JointEntropy(self.model, context_set_idx, target_set_idx), 86 | OracleMAE(self.model, context_set_idx, target_set_idx), 87 | OracleRMSE(self.model, context_set_idx, target_set_idx), 88 | OracleMarginalNLL(self.model, context_set_idx, target_set_idx), 89 | OracleJointNLL(self.model, context_set_idx, target_set_idx), 90 | ] 91 | parallel_acquisition_fns = [ 92 | Stddev(self.model, context_set_idx, target_set_idx), 93 | ExpectedImprovement(self.model, context_set_idx, target_set_idx), 94 | ContextDist(self.model, context_set_idx, target_set_idx), 95 | Random(self.model, context_set_idx, target_set_idx), 96 | ] 97 | 98 | # Coarsen search points to speed up computation 99 | X_s = self.ds_raw.coarsen(lat=10, lon=10, boundary="trim").mean() 100 | X_s = self.data_processor.map_coords(X_s) 101 | X_s_arr = xarray_to_coord_array_normalised(X_s) 102 | 103 | task = self.task_loader( 104 | "2014-12-31", context_sampling=10, target_sampling="all" 105 | ) 106 | 107 | for acquisition_fn in sequential_acquisition_fns: 108 | importance = acquisition_fn(task) 109 | assert importance.size == 1 110 | for acquisition_fn in parallel_acquisition_fns: 111 | importances = acquisition_fn(task, X_s_arr) 112 | assert importances.size == X_s_arr.shape[-1] 113 | 114 | def test_greedy_alg_runs(self): 115 | """Run the greedy algorithm to check that it runs without error""" 116 | # Both a sequential and parallel acquisition function 117 | acquisition_fns = [ 118 | MeanStddev(self.model), 119 | Stddev(self.model), 120 | ] 121 | # Coarsen search points to speed up computation 122 | 
X_s = self.ds_raw.coarsen(lat=10, lon=10, boundary="trim").mean() 123 | alg = GreedyAlgorithm( 124 | model=self.model, 125 | X_t=X_s, 126 | X_s=X_s, 127 | N_new_context=2, 128 | ) 129 | task = self.task_loader("2014-12-31", context_sampling=10) 130 | for acquisition_fn in acquisition_fns: 131 | X_new_df, acquisition_fn_ds = alg(acquisition_fn, task) 132 | 133 | def test_greedy_alg_with_aux_at_targets(self): 134 | """Run the greedy algorithm to check that it runs without error""" 135 | # Both a sequential and parallel acquisition function 136 | acquisition_fns = [ 137 | MeanStddev(self.model_with_aux), 138 | Stddev(self.model_with_aux), 139 | ] 140 | # Coarsen search points to speed up computation 141 | X_s = self.ds_raw.coarsen(lat=10, lon=10, boundary="trim").mean() 142 | alg = GreedyAlgorithm( 143 | model=self.model_with_aux, 144 | X_t=X_s, 145 | X_s=X_s, 146 | N_new_context=2, 147 | task_loader=self.task_loader_with_aux, 148 | ) 149 | task = self.task_loader_with_aux("2014-12-31", context_sampling=10) 150 | for acquisition_fn in acquisition_fns: 151 | X_new_df, acquisition_fn_ds = alg(acquisition_fn, task) 152 | 153 | def test_greedy_alg_with_oracle_acquisition_fn(self): 154 | acquisition_fn = OracleMAE(self.model) 155 | # Coarsen search points to speed up computation 156 | X_s = self.ds_raw.coarsen(lat=10, lon=10, boundary="trim").mean() 157 | alg = GreedyAlgorithm( 158 | model=self.model, 159 | X_t=X_s, 160 | X_s=X_s, 161 | N_new_context=2, 162 | task_loader=self.task_loader, 163 | ) 164 | task = self.task_loader("2014-12-31", context_sampling=10) 165 | _ = alg(acquisition_fn, task) 166 | 167 | def test_greedy_alg_with_sequential_acquisition_fn(self): 168 | acquisition_fn = Stddev(self.model) 169 | X_s = self.ds_raw 170 | alg = GreedyAlgorithm( 171 | model=self.model, 172 | X_t=X_s, 173 | X_s=X_s, 174 | N_new_context=1, 175 | task_loader=self.task_loader, 176 | ) 177 | task = self.task_loader("2014-12-31", context_sampling=10) 178 | _ = alg(acquisition_fn, task) 179 | 180 | def test_greedy_algorithm_column_names(self): 181 | # Setup 182 | acquisition_fn = Stddev(self.model) 183 | X_s = self.ds_raw 184 | alg = GreedyAlgorithm( 185 | model=self.model, 186 | X_t=X_s, 187 | X_s=X_s, 188 | N_new_context=1, 189 | task_loader=self.task_loader, 190 | ) 191 | task = self.task_loader("2014-12-31", context_sampling=10) 192 | 193 | # Exercise 194 | X_new_df, acquisition_fn_ds = alg(acquisition_fn, task) 195 | 196 | # Assert 197 | expected_columns = ["lat", "lon"] # Replace with actual expected column names 198 | actual_columns = X_new_df.columns.tolist() 199 | self.assertEqual( 200 | expected_columns, 201 | actual_columns, 202 | "Column names do not match the expected names", 203 | ) 204 | 205 | def test_greedy_alg_with_aux_at_targets_without_task_loader_raises_value_error( 206 | self, 207 | ): 208 | acquisition_fn = MeanStddev(self.model) 209 | X_s = self.ds_raw 210 | alg = GreedyAlgorithm( 211 | model=self.model_with_aux, 212 | X_t=X_s, 213 | X_s=X_s, 214 | N_new_context=1, 215 | task_loader=None, # don't pass task_loader (to raise error) 216 | ) 217 | task = self.task_loader_with_aux("2014-12-31", context_sampling=10) 218 | with self.assertRaises(ValueError): 219 | _ = alg(acquisition_fn, task) 220 | 221 | def test_greedy_alg_with_oracle_acquisition_fn_without_task_loader_raises_value_error( 222 | self, 223 | ): 224 | acquisition_fn = OracleMAE(self.model) 225 | X_s = self.ds_raw 226 | alg = GreedyAlgorithm( 227 | model=self.model, 228 | X_t=X_s, 229 | X_s=X_s, 230 | N_new_context=2, 231 | 
task_loader=None, # don't pass task_loader (to raise error) 232 | ) 233 | task = self.task_loader("2014-12-31", context_sampling=10) 234 | with self.assertRaises(ValueError): 235 | _ = alg(acquisition_fn, task) 236 | 237 | def test_acquisition_fn_without_min_or_max_raises_error( 238 | self, 239 | ): 240 | class DummyAcquisitionFn(AcquisitionFunction): 241 | """Dummy acquisition function that doesn't set min or max""" 242 | 243 | def __call__(self, **kwargs): 244 | return np.zeros(1) 245 | 246 | acquisition_fn = DummyAcquisitionFn(self.model) 247 | 248 | X_s = self.ds_raw 249 | alg = GreedyAlgorithm( 250 | model=self.model, 251 | X_t=X_s, 252 | X_s=X_s, 253 | N_new_context=2, 254 | ) 255 | with self.assertRaises(ValueError): 256 | _ = alg(acquisition_fn, None) 257 | 258 | def test_parallel_acquisition_fn_with_diff_raises_error( 259 | self, 260 | ): 261 | acquisition_fn = Stddev(self.model) 262 | X_s = self.ds_raw 263 | alg = GreedyAlgorithm( 264 | model=self.model, 265 | X_t=X_s, 266 | X_s=X_s, 267 | ) 268 | task = self.task_loader( 269 | "2014-12-31", context_sampling=10, target_sampling="all" 270 | ) 271 | 272 | # This should work 273 | _ = alg(acquisition_fn, task) 274 | # This should raise an error 275 | with self.assertRaises(ValueError): 276 | _ = alg(acquisition_fn, task, diff=True) 277 | -------------------------------------------------------------------------------- /tests/test_data_processor.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy as np 3 | import pandas as pd 4 | import unittest 5 | import tempfile 6 | 7 | from deepsensor.data.processor import DataProcessor 8 | from tests.utils import ( 9 | gen_random_data_xr, 10 | gen_random_data_pandas, 11 | assert_allclose_pd, 12 | assert_allclose_xr, 13 | ) 14 | 15 | 16 | def _gen_data_xr(coords=None, dims=None, data_vars=None): 17 | """Gen random raw data""" 18 | if coords is None: 19 | coords = dict( 20 | time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), 21 | lat=np.linspace(20, 40, 30), 22 | lon=np.linspace(40, 60, 20), 23 | ) 24 | da = gen_random_data_xr(coords, dims, data_vars) 25 | return da 26 | 27 | 28 | def _gen_data_pandas(coords=None, dims=None, cols=None): 29 | """Gen random raw data""" 30 | if coords is None: 31 | coords = dict( 32 | time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), 33 | lat=np.linspace(20, 40, 10), 34 | lon=np.linspace(40, 60, 10), 35 | ) 36 | df = gen_random_data_pandas(coords, dims, cols) 37 | return df 38 | 39 | 40 | class TestDataProcessor(unittest.TestCase): 41 | """Test DataProcessor 42 | 43 | Tests TODO: 44 | - Test different time frequencies 45 | - ... 
46 | """ 47 | 48 | def test_only_passing_one_x_mapping_raises_valueerror(self): 49 | with self.assertRaises(ValueError): 50 | DataProcessor(x1_map=(20, 40), x2_map=None) 51 | 52 | def test_unnorm_restores_data_for_each_method(self): 53 | """Check that the unnormalisation restores the original data for each normalisation method.""" 54 | da_raw = _gen_data_xr() 55 | df_raw = _gen_data_pandas() 56 | 57 | dp_with_x_mappings = DataProcessor( 58 | time_name="time", x1_name="lat", x2_name="lon" 59 | ) 60 | dp_inferred_x_mappings = DataProcessor( 61 | time_name="time", x1_name="lat", x2_name="lon" 62 | ) 63 | dps = [dp_with_x_mappings, dp_inferred_x_mappings] 64 | 65 | for dp in dps: 66 | for method in dp.valid_methods: 67 | da_norm, df_norm = dp([da_raw, df_raw], method=method) 68 | da_unnorm, df_unnorm = dp.unnormalise([da_norm, df_norm]) 69 | self.assertTrue( 70 | assert_allclose_xr(da_unnorm, da_raw), 71 | f"Original {type(da_raw).__name__} not restored for method {method}.", 72 | ) 73 | self.assertTrue( 74 | assert_allclose_pd(df_unnorm, df_raw), 75 | f"Original {type(df_raw).__name__} not restored for method {method}.", 76 | ) 77 | 78 | def test_different_names_xr(self): 79 | """The time, x1 and x2 dimensions can have arbitrary names and these should be restored 80 | after unnormalisation. 81 | """ 82 | da_raw = _gen_data_xr() 83 | da_raw = da_raw.rename( 84 | {"time": "datetime", "lat": "latitude", "lon": "longitude"} 85 | ) 86 | 87 | dp = DataProcessor( 88 | time_name="datetime", x1_name="latitude", x2_name="longitude" 89 | ) 90 | da_norm = dp(da_raw) 91 | self.assertListEqual( 92 | ["time", "x1", "x2"], list(da_norm.dims), "Failed to rename dims." 93 | ) 94 | 95 | da_unnorm = dp.unnormalise(da_norm) 96 | self.assertTrue( 97 | assert_allclose_xr(da_unnorm, da_raw), 98 | f"Original {type(da_raw).__name__} not restored.", 99 | ) 100 | 101 | def test_same_names_xr(self): 102 | """Test edge case when dim names are already in standard form. 103 | """ 104 | da_raw = _gen_data_xr() 105 | da_raw = da_raw.rename({"lat": "x1", "lon": "x2"}) 106 | 107 | dp = DataProcessor() 108 | da_norm = dp(da_raw) 109 | self.assertListEqual( 110 | ["time", "x1", "x2"], list(da_norm.dims), "Failed to rename dims." 
111 | ) 112 | 113 | da_unnorm = dp.unnormalise(da_norm) 114 | self.assertTrue( 115 | assert_allclose_xr(da_unnorm, da_raw), 116 | f"Original {type(da_raw).__name__} not restored.", 117 | ) 118 | 119 | def test_wrong_order_xr_ds(self): 120 | """Order of dimensions in xarray must be: time, x1, x2""" 121 | ds_raw = _gen_data_xr(dims=("time", "lat", "lon"), data_vars=["var1", "var2"]) 122 | ds_raw = ds_raw.transpose("time", "lon", "lat") # Transpose, changing order 123 | 124 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 125 | with self.assertRaises(ValueError): 126 | dp(ds_raw) 127 | 128 | def test_wrong_order_xr_da(self): 129 | """Order of dimensions in xarray must be: time, x1, x2""" 130 | da_raw = _gen_data_xr() 131 | da_raw = da_raw.T # Transpose, changing order 132 | 133 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 134 | with self.assertRaises(ValueError): 135 | dp(da_raw) 136 | 137 | def test_not_passing_method_raises_valuerror(self): 138 | """Must pass a valid method when normalising.""" 139 | da_raw = _gen_data_xr() 140 | dp = DataProcessor() 141 | with self.assertRaises(ValueError): 142 | dp(da_raw, method="not_a_valid_method") 143 | 144 | def test_different_names_pandas(self): 145 | """The time, x1 and x2 dimensions can have arbitrary names and these should be restored 146 | after unnormalisation. 147 | """ 148 | df_raw = _gen_data_pandas() 149 | df_raw.index.names = ["datetime", "lat", "lon"] 150 | 151 | dp = DataProcessor(time_name="datetime", x1_name="lat", x2_name="lon") 152 | 153 | df_norm = dp(df_raw) 154 | 155 | self.assertListEqual(["time", "x1", "x2"], list(df_norm.index.names)) 156 | 157 | df_unnorm = dp.unnormalise(df_norm) 158 | 159 | self.assertTrue( 160 | assert_allclose_pd(df_unnorm, df_raw), 161 | f"Original {type(df_raw).__name__} not restored.", 162 | ) 163 | 164 | def test_same_names_pandas(self): 165 | """Test edge case when dim names are already in standard form. 166 | """ 167 | df_raw = _gen_data_pandas() 168 | df_raw.index.names = ["time", "x1", "x2"] 169 | 170 | dp = DataProcessor() # No name changes 171 | 172 | df_norm = dp(df_raw) 173 | 174 | self.assertListEqual(["time", "x1", "x2"], list(df_norm.index.names)) 175 | 176 | df_unnorm = dp.unnormalise(df_norm) 177 | 178 | self.assertTrue( 179 | assert_allclose_pd(df_unnorm, df_raw), 180 | f"Original {type(df_raw).__name__} not restored.", 181 | ) 182 | 183 | def test_wrong_order_pandas(self): 184 | """Order of dimensions in pandas index must be: time, x1, x2""" 185 | df_raw = _gen_data_pandas() 186 | df_raw = df_raw.swaplevel(0, 2) 187 | 188 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 189 | 190 | with self.assertRaises(ValueError): 191 | dp(df_raw) 192 | 193 | def test_extra_indexes_preserved_pandas(self): 194 | """Other metadata indexes are allowed (only *after* the default dimension indexes of 195 | [time, x1, x2] or just [x1, x2]), and these should be preserved during normalisation. 
196 | """ 197 | coords = dict( 198 | time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), 199 | lat=np.linspace(20, 40, 30), 200 | lon=np.linspace(40, 60, 20), 201 | station=["A", "B"], 202 | ) 203 | df_raw = _gen_data_pandas(coords=coords) 204 | 205 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 206 | 207 | df_norm = dp(df_raw) 208 | df_unnorm = dp.unnormalise(df_norm) 209 | 210 | self.assertListEqual(list(df_raw.index.names), list(df_unnorm.index.names)) 211 | self.assertTrue( 212 | assert_allclose_pd(df_unnorm, df_raw), 213 | f"Original {type(df_raw).__name__} not restored.", 214 | ) 215 | 216 | def test_wrong_extra_indexes_pandas(self): 217 | """Other metadata indexes are allowed but if they are not *after* the default dimension 218 | indexes of [time, x1, x2] or just [x1, x2], then an error should be raised. 219 | """ 220 | coords = dict( 221 | station=["A", "B"], 222 | time=pd.date_range("2020-01-01", "2020-01-31", freq="D"), 223 | lat=np.linspace(20, 40, 30), 224 | lon=np.linspace(40, 60, 20), 225 | ) 226 | df_raw = _gen_data_pandas(coords=coords) 227 | 228 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 229 | 230 | with self.assertRaises(ValueError): 231 | dp(df_raw) 232 | 233 | def test_saving_and_loading(self): 234 | """Test saving and loading DataProcessor""" 235 | with tempfile.TemporaryDirectory() as tmp_dir: 236 | da_raw = _gen_data_xr() 237 | df_raw = _gen_data_pandas() 238 | 239 | dp = DataProcessor(time_name="time", x1_name="lat", x2_name="lon") 240 | # Normalise some data to store normalisation parameters in config 241 | da_norm = dp(da_raw, method="mean_std") 242 | df_norm = dp(df_raw, method="min_max") 243 | 244 | dp.save(tmp_dir) 245 | 246 | dp_loaded = DataProcessor(tmp_dir) 247 | 248 | # Check that the TaskLoader was saved and loaded correctly 249 | self.assertEqual( 250 | dp.config, 251 | dp_loaded.config, 252 | "Config not saved and loaded correctly", 253 | ) 254 | 255 | 256 | if __name__ == "__main__": 257 | unittest.main() 258 | -------------------------------------------------------------------------------- /tests/test_plotting.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | import numpy as np 3 | import pandas as pd 4 | import unittest 5 | 6 | import deepsensor.tensorflow as deepsensor 7 | 8 | from deepsensor.data.processor import DataProcessor 9 | from deepsensor.data.loader import TaskLoader 10 | from deepsensor.model.convnp import ConvNP 11 | 12 | 13 | class TestPlotting(unittest.TestCase): 14 | 15 | @classmethod 16 | def setUpClass(cls): 17 | # It's safe to share data between tests because the TaskLoader does not modify data 18 | ds_raw = xr.tutorial.open_dataset("air_temperature") 19 | cls.ds_raw = ds_raw 20 | cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon") 21 | ds = cls.data_processor(ds_raw) 22 | cls.task_loader = TaskLoader(context=ds, target=ds) 23 | cls.model = ConvNP( 24 | cls.data_processor, 25 | cls.task_loader, 26 | unet_channels=(5, 5, 5), 27 | verbose=False, 28 | ) 29 | # Sample a task with 10 random context points 30 | cls.task = cls.task_loader( 31 | "2014-12-31", context_sampling=10, target_sampling="all" 32 | ) 33 | 34 | def test_context_encoding(self): 35 | fig = deepsensor.plot.context_encoding(self.model, self.task, self.task_loader) 36 | 37 | def test_feature_maps(self): 38 | figs = deepsensor.plot.feature_maps(self.model, self.task) 39 | 40 | def test_offgrid_context(self): 41 | pred = 
self.model.predict(self.task, X_t=self.ds_raw) 42 | fig = pred["air"]["mean"].isel(time=0).plot(cmap="seismic") 43 | deepsensor.plot.offgrid_context( 44 | fig.axes, self.task, self.data_processor, self.task_loader 45 | ) 46 | 47 | def test_offgrid_context_observations(self): 48 | pred = self.model.predict(self.task, X_t=self.ds_raw) 49 | fig = pred["air"]["mean"].isel(time=0).plot(cmap="seismic") 50 | deepsensor.plot.offgrid_context_observations( 51 | fig.axes, 52 | self.task, 53 | self.data_processor, 54 | self.task_loader, 55 | context_set_idx=0, 56 | format_str=None, 57 | extent=None, 58 | color="black", 59 | ) 60 | -------------------------------------------------------------------------------- /tests/test_task.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import xarray as xr 5 | 6 | from deepsensor import DataProcessor, TaskLoader 7 | from deepsensor.data.task import append_obs_to_task 8 | from deepsensor.errors import TaskSetIndexError, GriddedDataError 9 | from deepsensor.model import ConvNP 10 | 11 | 12 | class TestConcatTasks(unittest.TestCase): 13 | 14 | @classmethod 15 | def setUpClass(cls): 16 | # It's safe to share data between tests because the TaskLoader does not modify data 17 | ds_raw = xr.tutorial.open_dataset("air_temperature") 18 | cls.ds_raw = ds_raw 19 | cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon") 20 | ds = cls.data_processor(ds_raw) 21 | cls.task_loader = TaskLoader(context=ds, target=ds) 22 | cls.model = ConvNP( 23 | cls.data_processor, 24 | cls.task_loader, 25 | unet_channels=(5, 5, 5), 26 | verbose=False, 27 | ) 28 | 29 | def test_concat_obs_to_task_shapes(self): 30 | ctx_idx = 0 # Context set index to add new observations to 31 | 32 | # Sample 10 context observations 33 | task = self.task_loader("2014-12-31", context_sampling=10) 34 | 35 | # 1 context observation 36 | X_new = np.random.randn(2, 1) 37 | Y_new = np.random.randn(1, 1) 38 | new_task = append_obs_to_task(task, X_new, Y_new, ctx_idx) 39 | self.assertEqual(new_task["X_c"][ctx_idx].shape, (2, 11)) 40 | self.assertEqual(new_task["Y_c"][ctx_idx].shape, (1, 11)) 41 | 42 | # 1 context observation with flattened obs dim 43 | X_new = np.random.randn(2) 44 | Y_new = np.random.randn(1) 45 | new_task = append_obs_to_task(task, X_new, Y_new, ctx_idx) 46 | self.assertEqual(new_task["X_c"][ctx_idx].shape, (2, 11)) 47 | self.assertEqual(new_task["Y_c"][ctx_idx].shape, (1, 11)) 48 | 49 | # 5 context observations 50 | X_new = np.random.randn(2, 5) 51 | Y_new = np.random.randn(1, 5) 52 | new_task = append_obs_to_task(task, X_new, Y_new, ctx_idx) 53 | self.assertEqual(new_task["X_c"][ctx_idx].shape, (2, 15)) 54 | self.assertEqual(new_task["Y_c"][ctx_idx].shape, (1, 15)) 55 | 56 | def test_concat_obs_to_task_wrong_context_index(self): 57 | # Sample 10 context observations 58 | task = self.task_loader("2014-12-31", context_sampling=10) 59 | 60 | ctx_idx = 1 # Wrong context set index 61 | 62 | # 1 context observation 63 | X_new = np.random.randn(2, 1) 64 | Y_new = np.random.randn(1, 1) 65 | 66 | with self.assertRaises(TaskSetIndexError): 67 | _ = append_obs_to_task(task, X_new, Y_new, ctx_idx) 68 | 69 | def test_concat_obs_to_task_fails_for_gridded_data(self): 70 | ctx_idx = 0 # Context set index to add new observations to 71 | 72 | # Sample context observations on a grid 73 | task = self.task_loader("2014-12-31", context_sampling="all") 74 | 75 | # Confirm that context observations are gridded with tuple for 
coordinates 76 | assert isinstance(task["X_c"][ctx_idx], tuple) 77 | 78 | # 1 context observation 79 | X_new = np.random.randn(2, 1) 80 | Y_new = np.random.randn(1, 1) 81 | 82 | with self.assertRaises(GriddedDataError): 83 | new_task = append_obs_to_task(task, X_new, Y_new, ctx_idx) 84 | -------------------------------------------------------------------------------- /tests/test_training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | 5 | import unittest 6 | 7 | from tqdm import tqdm 8 | 9 | import deepsensor.tensorflow as deepsensor 10 | 11 | from deepsensor.train.train import Trainer 12 | from deepsensor.data.processor import DataProcessor 13 | from deepsensor.data.loader import TaskLoader 14 | from deepsensor.model.convnp import ConvNP 15 | from deepsensor.data.task import concat_tasks 16 | 17 | 18 | class TestTraining(unittest.TestCase): 19 | 20 | @classmethod 21 | def setUpClass(cls): 22 | # It's safe to share data between tests because the TaskLoader does not modify data 23 | ds_raw = xr.tutorial.open_dataset("air_temperature") 24 | 25 | cls.ds_raw = ds_raw 26 | cls.data_processor = DataProcessor(x1_name="lat", x2_name="lon") 27 | 28 | cls.da = cls.data_processor(ds_raw) 29 | 30 | def test_concat_tasks(self): 31 | tl = TaskLoader(context=self.da, target=self.da) 32 | 33 | seed = 42 34 | rng = np.random.default_rng(seed) 35 | 36 | n_tasks = 5 37 | tasks = [] 38 | tasks_different_n_targets = [] 39 | for i in range(n_tasks): 40 | n_context = rng.integers(1, 100) 41 | n_target = rng.integers(1, 100) 42 | date = rng.choice(self.da.time.values) 43 | tasks_different_n_targets.append( 44 | tl(date, n_context, n_target) 45 | ) # Changing number of targets 46 | tasks.append(tl(date, n_context, 42)) # Fixed number of targets 47 | 48 | multiple = 50 49 | with self.assertRaises(ValueError): 50 | merged_task = concat_tasks(tasks_different_n_targets, multiple=multiple) 51 | 52 | # Check that the context and target data are concatenated correctly 53 | merged_task = concat_tasks(tasks, multiple=multiple) 54 | 55 | def test_concat_tasks_with_nans(self): 56 | tl = TaskLoader(context=self.da, target=self.da) 57 | 58 | seed = 42 59 | rng = np.random.default_rng(seed) 60 | 61 | n_tasks = 5 62 | tasks = [] 63 | tasks_different_n_targets = [] 64 | for i in range(n_tasks): 65 | n_context = rng.integers(1, 100) 66 | n_target = rng.integers(1, 100) 67 | date = rng.choice(self.da.time.values) 68 | tasks_different_n_targets.append( 69 | tl(date, n_context, n_target) 70 | ) # Changing number of targets 71 | task = tl(date, n_context, 42) 72 | task["Y_c"][0][:, 0] = np.nan # Add NaN to context 73 | task["Y_t"][0][:, 0] = np.nan # Add NaN to target 74 | tasks.append(task) 75 | 76 | multiple = 50 77 | 78 | # Check that the context and target data are concatenated correctly 79 | merged_task = concat_tasks(tasks, multiple=multiple) 80 | 81 | if np.any(np.isnan(merged_task["Y_c"][0].y)): 82 | raise ValueError("NaNs in the merged context data") 83 | 84 | def test_training(self): 85 | """A basic test of the training loop 86 | 87 | Note: This could be extended into a regression test, e.g. checking the loss decreases, 88 | the model parameters change, the speed of training is reasonable, etc. 
89 | """ 90 | tl = TaskLoader(context=self.da, target=self.da) 91 | model = ConvNP(self.data_processor, tl, unet_channels=(5, 5, 5), verbose=False) 92 | 93 | # Generate training tasks 94 | n_train_tasks = 10 95 | train_tasks = [] 96 | for i in range(n_train_tasks): 97 | date = np.random.choice(self.da.time.values) 98 | task = tl(date, 10, 10) 99 | task["Y_c"][0][:, 0] = np.nan # Add NaN to context 100 | task["Y_t"][0][:, 0] = np.nan # Add NaN to target 101 | print(task) 102 | train_tasks.append(task) 103 | 104 | # Train 105 | trainer = Trainer(model, lr=5e-5) 106 | # batch_size = None 107 | batch_size = 5 108 | n_epochs = 10 109 | epoch_losses = [] 110 | for epoch in tqdm(range(n_epochs)): 111 | batch_losses = trainer(train_tasks, batch_size=batch_size) 112 | epoch_losses.append(np.mean(batch_losses)) 113 | 114 | # Check for NaNs in the loss 115 | loss = np.mean(epoch_losses) 116 | self.assertFalse(np.isnan(loss)) 117 | 118 | def test_training_multidim(self): 119 | """A basic test of the training loop with multidimensional context sets""" 120 | # Load raw data 121 | ds_raw = xr.tutorial.open_dataset("air_temperature") 122 | 123 | # Add extra dim 124 | ds_raw["air2"] = ds_raw["air"].copy() 125 | 126 | # Normalise data 127 | dp = DataProcessor(x1_name="lat", x2_name="lon") 128 | ds = dp(ds_raw) 129 | 130 | # Set up task loader 131 | tl = TaskLoader(context=ds, target=ds) 132 | 133 | # Set up model 134 | model = ConvNP(dp, tl) 135 | 136 | # Generate training tasks 137 | n_train_tasks = 10 138 | train_tasks = [] 139 | for i in range(n_train_tasks): 140 | date = np.random.choice(self.da.time.values) 141 | task = tl(date, 10, 10) 142 | task["Y_c"][0][:, 0] = np.nan # Add NaN to context 143 | task["Y_t"][0][:, 0] = np.nan # Add NaN to target 144 | print(task) 145 | train_tasks.append(task) 146 | 147 | # Train 148 | trainer = Trainer(model, lr=5e-5) 149 | # batch_size = None 150 | batch_size = 5 151 | n_epochs = 10 152 | epoch_losses = [] 153 | for epoch in tqdm(range(n_epochs)): 154 | batch_losses = trainer(train_tasks, batch_size=batch_size) 155 | epoch_losses.append(np.mean(batch_losses)) 156 | 157 | # Check for NaNs in the loss 158 | loss = np.mean(epoch_losses) 159 | self.assertFalse(np.isnan(loss)) 160 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | 5 | from typing import Union, Optional 6 | 7 | 8 | def gen_random_data_xr( 9 | coords: dict, dims: Optional[list] = None, data_vars: Optional[list] = None 10 | ): 11 | """Generate random xarray data. 12 | 13 | Args: 14 | coords (dict): 15 | Coordinates of the data. 16 | dims (list, optional): 17 | Dimensions of the data. Defaults to None. If None, dims is inferred 18 | from coords. This arg can be used to change the order of the 19 | dimensions. 20 | data_vars (list, optional): 21 | Data variables. Defaults to None. If None, variable is an 22 | :class:`xarray.DataArray`. If not None, variable is an 23 | :class:`xarray.Dataset` containing the data_vars. 24 | 25 | Returns: 26 | da (:class:`xarray.DataArray` | :class:`xarray.Dataset`): 27 | Random xarray data. 
28 | """ 29 | if dims is None: 30 | shape = tuple([len(coords[dim]) for dim in coords]) 31 | else: 32 | shape = tuple([len(coords[dim]) for dim in dims]) 33 | data = np.random.rand(*shape) 34 | if data_vars is None: 35 | name = "var" 36 | da = xr.DataArray(data, coords=coords, name=name) 37 | else: 38 | data = {var: xr.DataArray(data, coords=coords) for var in data_vars} 39 | da = xr.Dataset(data, coords=coords) 40 | return da 41 | 42 | 43 | def gen_random_data_pandas(coords: dict, dims: list = None, cols: list = None): 44 | """Generate random pandas data. 45 | 46 | Args: 47 | coords (dict): 48 | Coordinates of the data. This will be used to construct a 49 | MultiIndex using pandas.MultiIndex.from_product. 50 | dims (list, optional): 51 | Dimensions of the data. Defaults to None. If None, dims is inferred 52 | from coords. This arg can be used to change the order of the 53 | MultiIndex. 54 | cols (list, optional): 55 | Columns of the data. Defaults to None. If None, generate a 56 | :class:`pandas.Series` with an arbitrary name. If not None, cols is 57 | used to construct a :class:`pandas.DataFrame`. 58 | 59 | Returns: 60 | :class:`pandas.Series` | :class:`pandas.DataFrame` 61 | Random pandas data. 62 | """ 63 | if dims is None: 64 | dims = list(coords.keys()) 65 | mi = pd.MultiIndex.from_product([coords[dim] for dim in dims], names=dims) 66 | if cols is None: 67 | name = "var" 68 | df = pd.Series(index=mi, name=name) 69 | else: 70 | df = pd.DataFrame(index=mi, columns=cols) 71 | df[:] = np.random.rand(*df.shape) 72 | return df 73 | 74 | 75 | def assert_allclose_pd( 76 | df1: Union[pd.DataFrame, pd.Series], df2: Union[pd.DataFrame, pd.Series] 77 | ): 78 | if isinstance(df1, pd.Series): 79 | df1 = df1.to_frame() 80 | if isinstance(df2, pd.Series): 81 | df2 = df2.to_frame() 82 | try: 83 | pd.testing.assert_frame_equal(df1, df2) 84 | except AssertionError: 85 | return False 86 | return True 87 | 88 | 89 | def assert_allclose_xr( 90 | da1: Union[xr.DataArray, xr.Dataset], da2: Union[xr.DataArray, xr.Dataset] 91 | ): 92 | try: 93 | xr.testing.assert_allclose(da1, da2) 94 | except AssertionError: 95 | return False 96 | return True 97 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 3.8.0 3 | envlist = py3.8, py3.9, py3.10, py3.11, py3.12 4 | isolated_build = true 5 | 6 | [gh-actions] 7 | python = 8 | 3.8: python3.8 9 | 3.9: python3.9 10 | 3.10: python3.10 11 | 3.11: python3.11 12 | 3.12: python3.12 13 | 14 | [testenv] 15 | setenv = 16 | PYTHONPATH = {toxinidir} 17 | deps = 18 | .[tf,torch,testing] 19 | commands = 20 | pytest --basetemp={envtmpdir} --------------------------------------------------------------------------------