├── .devcontainer.json ├── .github └── workflows │ └── main.yml ├── .gitignore ├── CHANGELOG.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── conda └── eclipse_pytorch │ └── meta.yaml ├── docker-compose.yml ├── docs ├── .gitignore ├── Gemfile ├── _config.yml ├── _data │ ├── sidebars │ │ └── home_sidebar.yml │ └── topnav.yml ├── feed.xml ├── images │ └── eclipse_diagram.png ├── index.html ├── layers.html ├── model.html ├── sidebar.json └── sitemap.xml ├── eclipse_pytorch ├── __init__.py ├── _nbdev.py ├── imports.py ├── layers.py └── model.py ├── nbs ├── 00_model.ipynb ├── 01_layers.ipynb ├── images │ └── eclipse_diagram.png └── index.ipynb ├── settings.ini └── setup.py /.devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "nbdev_template-codespaces", 3 | "dockerComposeFile": "docker-compose.yml", 4 | "service": "watcher", 5 | "settings": {"terminal.integrated.shell.linux": "/bin/bash"}, 6 | "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ], 7 | "forwardPorts": [4000, 8080], 8 | "appPort": [4000, 8080], 9 | "extensions": ["ms-python.python", 10 | "ms-azuretools.vscode-docker"], 11 | "runServices": ["notebook", "jekyll", "watcher"], 12 | "postStartCommand": "pip install -e ." 13 | } 14 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v1 8 | - uses: actions/setup-python@v1 9 | with: 10 | python-version: '3.8' 11 | architecture: 'x64' 12 | - name: Install the library 13 | run: | 14 | pip install nbdev jupyter 15 | pip install -e . 16 | - name: Read all notebooks 17 | run: | 18 | nbdev_read_nbs 19 | - name: Check if all notebooks are cleaned 20 | run: | 21 | echo "Check we are starting with clean git checkout" 22 | if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi 23 | echo "Trying to strip out notebooks" 24 | nbdev_clean_nbs 25 | echo "Check that strip out was unnecessary" 26 | git status -s # display the status to see which nbs need cleaning up 27 | if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped out notebooks\n!!!Remember to run nbdev_install_git_hooks"; false; fi 28 | - name: Check if there is no diff library/notebooks 29 | run: | 30 | if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! 
Detected difference between the notebooks and the library"; false; fi 31 | - name: Run tests 32 | run: | 33 | nbdev_test_nbs 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .jekyll-cache/ 2 | Gemfile.lock 3 | *.bak 4 | .gitattributes 5 | .last_checked 6 | .gitconfig 7 | *.bak 8 | *.log 9 | *~ 10 | ~* 11 | _tmp* 12 | tmp* 13 | tags 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # dotenv 97 | .env 98 | 99 | # virtualenv 100 | .venv 101 | venv/ 102 | ENV/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | 117 | .vscode 118 | *.swp 119 | 120 | # osx generated files 121 | .DS_Store 122 | .DS_Store? 123 | .Trashes 124 | ehthumbs.db 125 | Thumbs.db 126 | .idea 127 | 128 | # pytest 129 | .pytest_cache 130 | 131 | # tools/trust-doc-nbs 132 | docs_src/.last_checked 133 | 134 | # symlinks to fastai 135 | docs_src/fastai 136 | tools/fastai 137 | 138 | # link checker 139 | checklink/cookies.txt 140 | 141 | # .gitconfig is now autogenerated 142 | .gitconfig 143 | 144 | token 145 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Release notes 2 | 3 | 4 | 5 | ## 0.0.8 6 | 7 | 8 | 9 | 10 | ## 0.0.7 11 | 12 | 13 | 14 | 15 | ## 0.0.7 16 | 17 | 18 | 19 | 20 | ## 0.0.7 21 | 22 | 23 | 24 | 25 | ## 0.0.6 26 | 27 | 28 | 29 | 30 | ## 0.0.6 31 | 32 | 33 | 34 | 35 | ## 0.0.5 36 | 37 | 38 | 39 | 40 | ## 0.0.5 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | ## How to get started 4 | 5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). 
After cloning the repository, run the following command inside it:
6 | ```
7 | nbdev_install_git_hooks
8 | ```
9 | 
10 | ## Did you find a bug?
11 | 
12 | * Ensure the bug was not already reported by searching on GitHub under Issues.
13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
14 | * Be sure to add the complete error messages.
15 | 
16 | #### Did you write a patch that fixes a bug?
17 | 
18 | * Open a new GitHub pull request with the patch.
19 | * Ensure that your PR includes a test that fails without your patch, and passes with it.
20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
21 | 
22 | ## PR submission guidelines
23 | 
24 | * Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needed to keep each PR focused.
25 | * Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs, and they will most likely be rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
27 | * Do not turn an already submitted PR into your development playground. If, after you have submitted a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requested changes, it's probably best to close the PR, do the work, and then submit it again. Use common sense when choosing one way over the other.
29 | 
30 | ## Do you want to contribute to the documentation?
31 | 
32 | * Docs are automatically created from the notebooks in the nbs folder.
33 | 
34 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include settings.ini
2 | include LICENSE
3 | include CONTRIBUTING.md
4 | include README.md
5 | recursive-exclude * __pycache__
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .ONESHELL:
2 | SHELL := /bin/bash
3 | SRC = $(wildcard nbs/*.ipynb)
4 | 
5 | all: eclipse_pytorch docs
6 | 
7 | eclipse_pytorch: $(SRC)
8 | 	nbdev_build_lib
9 | 	touch eclipse_pytorch
10 | 
11 | sync:
12 | 	nbdev_update_lib
13 | 
14 | docs_serve: docs
15 | 	cd docs && bundle exec jekyll serve
16 | 
17 | docs: $(SRC)
18 | 	nbdev_build_docs
19 | 	touch docs
20 | 
21 | test:
22 | 	nbdev_test_nbs
23 | 
24 | release: pypi conda_release
25 | 	nbdev_bump_version
26 | 
27 | conda_release:
28 | 	fastrelease_conda_package
29 | 
30 | pypi: dist
31 | 	twine upload --repository pypi dist/*
32 | 
33 | dist: clean
34 | 	python setup.py sdist bdist_wheel
35 | 
36 | clean:
37 | 	rm -rf dist
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Eclipse
2 | > Implementing Paletta et al. in PyTorch
3 | 
4 | 
5 | Most of the codebase comes from [Fiery](https://github.com/wayveai/fiery)
6 | 
7 | ![Image](nbs/images/eclipse_diagram.png)
8 | 
9 | ## Install
10 | 
11 | ```bash
12 | pip install eclipse_pytorch
13 | ```
14 | 
15 | ## How to use
16 | 
17 | ```python
18 | import torch
19 | 
20 | from eclipse_pytorch.model import Eclipse
21 | ```
22 | 
23 | ```python
24 | eclipse = Eclipse(horizon=5)
25 | ```
26 | 
27 | Let's simulate some input images:
28 | 
29 | ```python
30 | images = [torch.rand(2, 3, 128, 128) for _ in range(4)]
31 | ```
32 | 
33 | ```python
34 | preds = eclipse(images)
35 | ```
36 | 
37 | You get a dict with forecasted masks and irradiances:
38 | 
39 | ```python
40 | len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape
41 | ```
42 | 
43 | 
44 | 
45 | 
46 |     (6, torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
47 | 
48 | 
49 | 
50 | ## Citation
51 | 
52 | ```latex
53 | @article{paletta2021eclipse,
54 |   title     = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},
55 |   author    = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},
56 |   year      = {2021},
57 |   eprinttype = {arXiv},
58 |   eprint    = {2104.12419}
59 | }
60 | ```
61 | 
62 | ## Contribute
63 | 
64 | This repo is made with [nbdev](https://github.com/fastai/nbdev); please read its documentation if you want to contribute.
--------------------------------------------------------------------------------
/conda/eclipse_pytorch/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 |   name: eclipse_pytorch
3 |   version: 0.0.8
4 | source:
5 |   sha256: 3dc73af281153764e70df6e2153f4118d49d012267d9d66317f57e12db566876
6 |   url: https://files.pythonhosted.org/packages/71/ae/edf126362a552f73948c94ed09d1d3e653c3500e6ef9447cde863f7fa525/eclipse_pytorch-0.0.8.tar.gz
7 | about:
8 |   dev_url: https://tcapelle.github.io
9 |   doc_url: https://tcapelle.github.io
10 |   home: https://tcapelle.github.io
11 |   license: Apache Software
12 |   license_family: APACHE
13 |   summary: A pytorch implementation of Eclipse
14 | build:
15 |   noarch: python
16 |   number: '0'
17 |   script: '{{ PYTHON }} -m pip install . 
-vv' 18 | extra: 19 | recipe-maintainers: 20 | - tcapelle 21 | requirements: 22 | host: 23 | - pip 24 | - python 25 | - packaging 26 | - torch>1.6 27 | - fastcore 28 | run: 29 | - pip 30 | - python 31 | - packaging 32 | - torch>1.6 33 | - fastcore 34 | test: 35 | imports: 36 | - eclipse_pytorch 37 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | fastai: &fastai 4 | restart: unless-stopped 5 | working_dir: /data 6 | image: fastai/codespaces 7 | logging: 8 | driver: json-file 9 | options: 10 | max-size: 50m 11 | stdin_open: true 12 | tty: true 13 | volumes: 14 | - .:/data/ 15 | 16 | notebook: 17 | <<: *fastai 18 | command: bash -c "pip install -e . && jupyter notebook --allow-root --no-browser --ip=0.0.0.0 --port=8080 --NotebookApp.token='' --NotebookApp.password=''" 19 | ports: 20 | - "8080:8080" 21 | 22 | watcher: 23 | <<: *fastai 24 | command: watchmedo shell-command --command nbdev_build_docs --pattern *.ipynb --recursive --drop 25 | network_mode: host # for GitHub Codespaces https://github.com/features/codespaces/ 26 | 27 | jekyll: 28 | <<: *fastai 29 | ports: 30 | - "4000:4000" 31 | command: > 32 | bash -c "pip install . 33 | && nbdev_build_docs && cd docs 34 | && bundle i 35 | && chmod -R u+rwx . && bundle exec jekyll serve --host 0.0.0.0" 36 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _site/ 2 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gem "jekyll", ">= 3.7" 4 | gem "jekyll-remote-theme" 5 | -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | repository: tcapelle/eclipse_pytorch 2 | output: web 3 | topnav_title: eclipse_pytorch 4 | site_title: eclipse_pytorch 5 | company_name: bla blabla 6 | description: A pytorch implementation of Eclipse 7 | # Set to false to disable KaTeX math 8 | use_math: true 9 | # Add Google analytics id if you have one and want to use it here 10 | google_analytics: 11 | # See http://nbdev.fast.ai/search for help with adding Search 12 | google_search: 13 | 14 | host: 127.0.0.1 15 | # the preview server used. Leave as is. 16 | port: 4000 17 | # the port where the preview is rendered. 
18 | 19 | exclude: 20 | - .idea/ 21 | - .gitignore 22 | - vendor 23 | 24 | exclude: [vendor] 25 | 26 | highlighter: rouge 27 | markdown: kramdown 28 | kramdown: 29 | input: GFM 30 | auto_ids: true 31 | hard_wrap: false 32 | syntax_highlighter: rouge 33 | 34 | collections: 35 | tooltips: 36 | output: false 37 | 38 | defaults: 39 | - 40 | scope: 41 | path: "" 42 | type: "pages" 43 | values: 44 | layout: "page" 45 | comments: true 46 | search: true 47 | sidebar: home_sidebar 48 | topnav: topnav 49 | - 50 | scope: 51 | path: "" 52 | type: "tooltips" 53 | values: 54 | layout: "page" 55 | comments: true 56 | search: true 57 | tooltip: true 58 | 59 | sidebars: 60 | - home_sidebar 61 | 62 | plugins: 63 | - jekyll-remote-theme 64 | 65 | remote_theme: fastai/nbdev-jekyll-theme 66 | baseurl: /eclipse_pytorch/ -------------------------------------------------------------------------------- /docs/_data/sidebars/home_sidebar.yml: -------------------------------------------------------------------------------- 1 | 2 | ################################################# 3 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ### 4 | ################################################# 5 | # Instead edit ../../sidebar.json 6 | entries: 7 | - folders: 8 | - folderitems: 9 | - output: web,pdf 10 | title: Overview 11 | url: / 12 | - output: web,pdf 13 | title: The model 14 | url: model.html 15 | - output: web,pdf 16 | title: Layers 17 | url: layers.html 18 | output: web 19 | title: eclipse_pytorch 20 | output: web 21 | title: Sidebar 22 | -------------------------------------------------------------------------------- /docs/_data/topnav.yml: -------------------------------------------------------------------------------- 1 | topnav: 2 | - title: Topnav 3 | items: 4 | - title: github 5 | external_url: https://github.com/tcapelle/eclipse_pytorch/tree/master/ 6 | 7 | #Topnav dropdowns 8 | topnav_dropdowns: 9 | - title: Topnav dropdowns 10 | folders: -------------------------------------------------------------------------------- /docs/feed.xml: -------------------------------------------------------------------------------- 1 | --- 2 | search: exclude 3 | layout: none 4 | --- 5 | 6 | 7 | 8 | 9 | {{ site.title | xml_escape }} 10 | {{ site.description | xml_escape }} 11 | {{ site.url }}/ 12 | 13 | {{ site.time | date_to_rfc822 }} 14 | {{ site.time | date_to_rfc822 }} 15 | Jekyll v{{ jekyll.version }} 16 | {% for post in site.posts limit:10 %} 17 | 18 | {{ post.title | xml_escape }} 19 | {{ post.content | xml_escape }} 20 | {{ post.date | date_to_rfc822 }} 21 | {{ post.url | prepend: site.url }} 22 | {{ post.url | prepend: site.url }} 23 | {% for tag in post.tags %} 24 | {{ tag | xml_escape }} 25 | {% endfor %} 26 | {% for tag in page.tags %} 27 | {{ cat | xml_escape }} 28 | {% endfor %} 29 | 30 | {% endfor %} 31 | 32 | 33 | -------------------------------------------------------------------------------- /docs/images/eclipse_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tcapelle/eclipse_pytorch/13b1c41b076a535eb01abf37f777f9fa8309ee48/docs/images/eclipse_diagram.png -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | title: Eclipse 4 | 5 | 6 | keywords: fastai 7 | sidebar: home_sidebar 8 | 9 | summary: "Implementing Paletta et al in Pytorch" 10 | description: "Implementing Paletta et al in Pytorch" 11 | 
nb_path: "nbs/index.ipynb" 12 | --- 13 | 22 | 23 |
24 | 25 | {% raw %} 26 | 27 |
28 | 29 |
30 | {% endraw %} 31 | 32 |
33 |
34 |

Most of the codebase comes from Fiery

35 | 36 |
37 |
38 |
39 |
40 |
41 |

Image

42 | 43 |
44 |
45 |
46 |
47 |
48 |

Install

49 |
50 |
51 |
52 |
53 |
54 |
pip install eclipse_pytorch
 55 | 
56 | 57 |
58 |
59 |
60 |
61 |
62 |

How to use

63 |
64 |
65 |
66 | {% raw %} 67 | 68 |
69 |
70 | 71 |
72 |
73 |
import torch
 74 | 
 75 | from eclipse_pytorch.model import Eclipse
 76 | 
77 | 78 |
79 |
80 |
81 | 82 |
83 | {% endraw %} 84 | 85 | {% raw %} 86 | 87 |
88 |
89 | 90 |
91 |
92 |
eclipse = Eclipse(horizon=5)
 93 | 
94 | 95 |
96 |
97 |
98 | 99 |
100 | {% endraw %} 101 | 102 |
103 |
104 |

Let's simulate some input images:

105 | 106 |
107 |
108 |
109 | {% raw %} 110 | 111 |
112 |
113 | 114 |
115 |
116 |
images = [torch.rand(2, 3, 128, 128) for _ in range(4)]
117 | 
118 | 119 |
120 |
121 |
122 | 123 |
124 | {% endraw %} 125 | 126 | {% raw %} 127 | 128 |
129 |
130 | 131 |
132 |
133 |
preds = eclipse(images)
134 | 
135 | 136 |
137 |
138 |
139 | 140 |
141 | {% endraw %} 142 | 143 |
144 |
145 |

You get a dict with forecasted masks and irradiances:

146 | 147 |
148 |
149 |
150 | {% raw %} 151 | 152 |
153 |
154 | 155 |
156 |
157 |
len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape
158 | 
159 | 160 |
161 |
162 |
163 | 164 |
165 |
166 | 167 |
168 | 169 | 170 | 171 |
172 |
(6, torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
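With horizon=5 the model returns the current step plus 5 future steps, hence the 6 masks and 6 irradiance values per sample. A quick sketch of consuming these outputs (assuming the 4 mask channels are per-pixel class scores):

class_maps = [m.argmax(dim=1) for m in preds['masks']]  # 6 tensors of shape (2, 128, 128)
next_irr = preds['irradiances'][:, 1]                   # irradiance forecast one step ahead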
173 |
174 | 175 |
176 | 177 |
178 |
179 | 180 |
181 | {% endraw %} 182 | 183 |
184 |
185 |

Citation

186 |
187 |
188 |
189 |
190 |
191 |
@article{paletta2021eclipse,
192 |   title     = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},
193 |   author    = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},
194 |   year      = {2021},
195 |   eprinttype = {arXiv},
196 |   eprint    = {2104.12419}
197 | }
198 | 
199 | 200 |
201 |
202 |
203 |
204 |
205 |

Contribute

This repo is made with nbdev; please read its documentation if you want to contribute.

206 | 207 |
208 |
209 |
210 |
211 | 212 | 213 | -------------------------------------------------------------------------------- /docs/layers.html: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | title: Layers 4 | 5 | 6 | keywords: fastai 7 | sidebar: home_sidebar 8 | 9 | summary: "most of them come from fiery" 10 | description: "most of them come from fiery" 11 | nb_path: "nbs/01_layers.ipynb" 12 | --- 13 | 22 | 23 |
24 | 25 | {% raw %} 26 | 27 |
28 | 29 |
30 | {% endraw %} 31 | 32 | {% raw %} 33 | 34 |
35 | 36 |
37 | {% endraw %} 38 | 39 | {% raw %} 40 | 41 |
42 | 43 |
44 |
45 | 46 |
47 | 48 | 49 |
50 |

get_activation[source]

get_activation(activation)

51 |
52 | 53 |
54 | 55 |
56 | 57 |
58 |
59 | 60 |
61 | {% endraw %} 62 | 63 | {% raw %} 64 | 65 |
66 | 67 |
68 |
69 | 70 |
71 | 72 | 73 |
74 |

get_norm[source]

get_norm(norm, out_channels)
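A quick sketch of these two helpers; the string names are inferred from their use in ConvBlock below (activation='relu', norm='bn'):

act = get_activation('relu')  # an nn.ReLU-style activation layer
norm = get_norm('bn', 16)     # batch normalisation over 16 channels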

75 |
76 | 77 |
78 | 79 |
80 | 81 |
82 |
83 | 84 |
85 | {% endraw %} 86 | 87 | {% raw %} 88 | 89 |
90 | 91 |
92 | {% endraw %} 93 | 94 | {% raw %} 95 | 96 |
97 | 98 |
99 |
100 | 101 |
102 | 103 | 104 |
105 |

init_linear[source]

init_linear(m, act_func=None, init='auto', bias_std=0.01)

106 |
107 | 108 |
109 | 110 |
111 | 112 |
113 |
114 | 115 |
116 | {% endraw %} 117 | 118 | {% raw %} 119 | 120 |
121 | 122 |
123 |
124 | 125 |
126 | 127 | 128 |
129 |

class ConvBlock[source]

ConvBlock(in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn', activation='relu', bias=False, transpose=False, init='auto') :: Sequential

130 |
131 |

2D convolution followed by

132 |
    133 |
  • an optional normalisation (batch norm or instance norm)
  • 134 |
  • an optional activation (ReLU, LeakyReLU, or tanh)
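A usage sketch (padding is assumed to be 'same'-style, so stride=2 halves the spatial size):

cb = ConvBlock(3, 16, stride=2)
x = torch.rand(1, 3, 64, 64)
cb(x).shape  # expected: torch.Size([1, 16, 32, 32])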
  • 135 |
136 | 137 |
138 | 139 |
140 | 141 |
142 |
143 | 144 |
145 | {% endraw %} 146 | 147 | {% raw %} 148 | 149 |
150 | 151 |
152 | {% endraw %} 153 | 154 | {% raw %} 155 | 156 |
157 | 158 |
159 |
160 | 161 |
162 | 163 | 164 |
165 |

class Bottleneck[source]

Bottleneck(in_channels, out_channels=None, kernel_size=3, dilation=1, groups=1, upsample=False, downsample=False, dropout=0.0) :: Module

166 |
167 |

Defines a bottleneck module with a residual connection
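A usage sketch; with the default out_channels=None the block presumably keeps the input channel count (upsample/downsample change the resolution instead, as shown on the model page):

bn = Bottleneck(16)
x = torch.rand(1, 16, 32, 32)
bn(x).shape  # expected: torch.Size([1, 16, 32, 32])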

168 | 169 |
170 | 171 |
172 | 173 |
174 |
175 | 176 |
177 | {% endraw %} 178 | 179 | {% raw %} 180 | 181 |
182 | 183 |
184 | {% endraw %} 185 | 186 | {% raw %} 187 | 188 |
189 | 190 |
191 |
192 | 193 |
194 | 195 | 196 |
197 |

class Interpolate[source]

Interpolate(scale_factor:int=2) :: Module

198 |
199 |

Base class for all neural network modules.

200 |

Your models should also subclass this class.

201 |

Modules can also contain other Modules, allowing to nest them in 202 | a tree structure. You can assign the submodules as regular attributes::

203 | 204 |
import torch.nn as nn
 205 | import torch.nn.functional as F
 206 | 
 207 | class Model(nn.Module):
 208 |     def __init__(self):
 209 |         super(Model, self).__init__()
 210 |         self.conv1 = nn.Conv2d(1, 20, 5)
 211 |         self.conv2 = nn.Conv2d(20, 20, 5)
 212 | 
 213 |     def forward(self, x):
 214 |         x = F.relu(self.conv1(x))
 215 |         return F.relu(self.conv2(x))
 216 | 
 217 | 
218 |

Submodules assigned in this way will be registered, and will have their 219 | parameters converted too when you call :meth:to, etc.

220 |

:ivar training: Boolean represents whether this module is in training or 221 | evaluation mode. 222 | :vartype training: bool

223 | 224 |
225 | 226 |
227 | 228 |
229 |
230 | 231 |
232 | {% endraw %} 233 | 234 | {% raw %} 235 | 236 |
237 | 238 |
239 |
240 | 241 |
242 | 243 | 244 |
245 |

class UpsamplingConcat[source]

UpsamplingConcat(in_channels, out_channels, scale_factor=2) :: Module

246 |
247 |

Base class for all neural network modules.

248 |

Your models should also subclass this class.

249 |

Modules can also contain other Modules, allowing to nest them in 250 | a tree structure. You can assign the submodules as regular attributes::

251 | 252 |
import torch.nn as nn
 253 | import torch.nn.functional as F
 254 | 
 255 | class Model(nn.Module):
 256 |     def __init__(self):
 257 |         super(Model, self).__init__()
 258 |         self.conv1 = nn.Conv2d(1, 20, 5)
 259 |         self.conv2 = nn.Conv2d(20, 20, 5)
 260 | 
 261 |     def forward(self, x):
 262 |         x = F.relu(self.conv1(x))
 263 |         return F.relu(self.conv2(x))
 264 | 
 265 | 
266 |

Submodules assigned in this way will be registered, and will have their 267 | parameters converted too when you call :meth:to, etc.

268 |

:ivar training: Boolean represents whether this module is in training or 269 | evaluation mode. 270 | :vartype training: bool

271 | 272 |
273 | 274 |
275 | 276 |
277 |
278 | 279 |
280 | {% endraw %} 281 | 282 | {% raw %} 283 | 284 |
285 | 286 |
287 |
288 | 289 |
290 | 291 | 292 |
293 |

class UpsamplingAdd[source]

UpsamplingAdd(in_channels, out_channels, scale_factor=2) :: Module

294 |
295 |

Base class for all neural network modules.

296 |

Your models should also subclass this class.

297 |

Modules can also contain other Modules, allowing to nest them in 298 | a tree structure. You can assign the submodules as regular attributes::

299 | 300 |
import torch.nn as nn
 301 | import torch.nn.functional as F
 302 | 
 303 | class Model(nn.Module):
 304 |     def __init__(self):
 305 |         super(Model, self).__init__()
 306 |         self.conv1 = nn.Conv2d(1, 20, 5)
 307 |         self.conv2 = nn.Conv2d(20, 20, 5)
 308 | 
 309 |     def forward(self, x):
 310 |         x = F.relu(self.conv1(x))
 311 |         return F.relu(self.conv2(x))
 312 | 
 313 | 
314 |

Submodules assigned in this way will be registered, and will have their 315 | parameters converted too when you call :meth:to, etc.

316 |

:ivar training: Boolean represents whether this module is in training or 317 | evaluation mode. 318 | :vartype training: bool
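The docstrings above are generic nn.Module boilerplate; the names describe the actual behaviour: Interpolate upsamples by scale_factor, UpsamplingConcat upsamples and concatenates a skip tensor before convolving, and UpsamplingAdd upsamples and adds one. A sketch for UpsamplingAdd (the (x, skip) call signature is an assumption based on the Fiery code these layers come from):

up = UpsamplingAdd(128, 64)
x, skip = torch.rand(1, 128, 16, 16), torch.rand(1, 64, 32, 32)
up(x, skip).shape  # expected: torch.Size([1, 64, 32, 32])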

319 | 320 |
321 | 322 |
323 | 324 |
325 |
326 | 327 |
328 | {% endraw %} 329 | 330 | {% raw %} 331 | 332 |
333 | 334 |
335 | {% endraw %} 336 | 337 | {% raw %} 338 | 339 |
340 | 341 |
342 |
343 | 344 |
345 | 346 | 347 |
348 |

class ResBlock[source]

ResBlock(in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu') :: Module

349 |
350 |

A simple resnet Block

351 | 352 |
353 | 354 |
355 | 356 |
357 |
358 | 359 |
360 | {% endraw %} 361 | 362 | {% raw %} 363 | 364 |
365 | 366 |
367 | {% endraw %} 368 | 369 | {% raw %} 370 | 371 |
372 |
373 | 374 |
375 |
376 |
res_block = ResBlock(64, 128, stride=2)
 377 | 
378 | 379 |
380 |
381 |
382 | 383 |
384 | {% endraw %} 385 | 386 | {% raw %} 387 | 388 |
389 |
390 | 391 |
392 |
393 |
x = torch.rand(1,64, 20,20)
 394 | res_block(x).shape
 395 | 
396 | 397 |
398 |
399 |
400 | 401 |
402 |
403 | 404 |
405 | 406 | 407 | 408 |
409 |
torch.Size([1, 128, 10, 10])
410 |
411 | 412 |
413 | 414 |
415 |
416 | 417 |
418 | {% endraw %} 419 | 420 | {% raw %} 421 | 422 |
423 | 424 |
425 |
426 | 427 |
428 | 429 | 430 |
431 |

class SpatialGRU[source]

SpatialGRU(input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu') :: Module

432 |
433 |

A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a 434 | convolutional gated recurrent unit over the data

435 | 436 |
437 | 438 |
439 | 440 |
441 |
442 | 443 |
444 | {% endraw %} 445 | 446 | {% raw %} 447 | 448 |
449 | 450 |
451 | {% endraw %} 452 | 453 | {% raw %} 454 | 455 |
456 |
457 | 458 |
459 |
460 |
sgru = SpatialGRU(input_size=32, hidden_size=64)
 461 | 
462 | 463 |
464 |
465 |
466 | 467 |
468 | {% endraw %} 469 | 470 |
471 |
472 |

Without a hidden state:

473 | 474 |
475 |
476 |
477 | {% raw %} 478 | 479 |
480 |
481 | 482 |
483 |
484 |
x = torch.rand(1,3,32,8, 8)
 485 | test_eq(sgru(x).shape, (1,3,64,8,8))
 486 | 
487 | 488 |
489 |
490 |
491 | 492 |
493 | {% endraw %} 494 | 495 |
496 |
497 |

With a hidden state:

498 |

hidden.shape = (bs, hidden_size, h, w)

499 |
500 | 501 |
502 |
503 |
504 | {% raw %} 505 | 506 |
507 |
508 | 509 |
510 |
511 |
x = torch.rand(1,3,32,8, 8)
 512 | hidden = torch.rand(1,64,8,8)
 513 | # test_eq(sgru(x).shape, (1,3,64,8,8))
 514 | 
 515 | sgru(x, hidden).shape
 516 | 
517 | 518 |
519 |
520 |
521 | 522 |
523 |
524 | 525 |
526 | 527 | 528 | 529 |
530 |
torch.Size([1, 3, 64, 8, 8])
531 |
532 | 533 |
534 | 535 |
536 |
537 | 538 |
539 | {% endraw %} 540 | 541 | {% raw %} 542 | 543 |
544 | 545 |
546 |
547 | 548 |
549 | 550 | 551 |
552 |

class CausalConv3d[source]

CausalConv3d(in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False) :: Module

553 |
554 |

Base class for all neural network modules.

555 |

Your models should also subclass this class.

556 |

Modules can also contain other Modules, allowing to nest them in 557 | a tree structure. You can assign the submodules as regular attributes::

558 | 559 |
import torch.nn as nn
 560 | import torch.nn.functional as F
 561 | 
 562 | class Model(nn.Module):
 563 |     def __init__(self):
 564 |         super(Model, self).__init__()
 565 |         self.conv1 = nn.Conv2d(1, 20, 5)
 566 |         self.conv2 = nn.Conv2d(20, 20, 5)
 567 | 
 568 |     def forward(self, x):
 569 |         x = F.relu(self.conv1(x))
 570 |         return F.relu(self.conv2(x))
 571 | 
 572 | 
573 |

Submodules assigned in this way will be registered, and will have their 574 | parameters converted too when you call :meth:to, etc.

575 |

:ivar training: Boolean represents whether this module is in training or 576 | evaluation mode. 577 | :vartype training: bool

578 | 579 |
580 | 581 |
582 | 583 |
584 |
585 | 586 |
587 | {% endraw %} 588 | 589 | {% raw %} 590 | 591 |
592 | 593 |
594 | {% endraw %} 595 | 596 | {% raw %} 597 | 598 |
599 |
600 | 601 |
602 |
603 |
cc3d = CausalConv3d(64, 128)
 604 | 
605 | 606 |
607 |
608 |
609 | 610 |
611 | {% endraw %} 612 | 613 | {% raw %} 614 | 615 |
616 |
617 | 618 |
619 |
620 |
x = torch.rand(1,64,4,8,8)
 621 | cc3d(x).shape
 622 | 
623 | 624 |
625 |
626 |
627 | 628 |
629 |
630 | 631 |
632 | 633 | 634 | 635 |
636 |
torch.Size([1, 128, 4, 8, 8])
637 |
638 | 639 |
640 | 641 |
642 |
643 | 644 |
645 | {% endraw %} 646 | 647 | {% raw %} 648 | 649 |
650 | 651 |
652 |
653 | 654 |
655 | 656 | 657 |
658 |

conv_1x1x1_norm_activated[source]

conv_1x1x1_norm_activated(in_channels, out_channels)

659 |
660 |

1x1x1 3D convolution, normalization and activation layer.

661 | 662 |
663 | 664 |
665 | 666 |
667 |
668 | 669 |
670 | {% endraw %} 671 | 672 | {% raw %} 673 | 674 |
675 | 676 |
677 | {% endraw %} 678 | 679 | {% raw %} 680 | 681 |
682 |
683 | 684 |
685 |
686 |
conv_1x1x1_norm_activated(2, 4)
 687 | 
688 | 689 |
690 |
691 |
692 | 693 |
694 |
695 | 696 |
697 | 698 | 699 | 700 |
701 |
Sequential(
 702 |   (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
 703 |   (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 704 |   (activation): ReLU(inplace=True)
 705 | )
706 |
707 | 708 |
709 | 710 |
711 |
712 | 713 |
714 | {% endraw %} 715 | 716 | {% raw %} 717 | 718 |
719 | 720 |
721 |
722 | 723 |
724 | 725 | 726 |
727 |

class Bottleneck3D[source]

Bottleneck3D(in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)) :: Module

728 |
729 |

Defines a bottleneck module with a residual connection

730 | 731 |
732 | 733 |
734 | 735 |
736 |
737 | 738 |
739 | {% endraw %} 740 | 741 | {% raw %} 742 | 743 |
744 | 745 |
746 | {% endraw %} 747 | 748 | {% raw %} 749 | 750 |
751 |
752 | 753 |
754 |
755 |
bn3d = Bottleneck3D(8, 12)
 756 | 
757 | 758 |
759 |
760 |
761 | 762 |
763 | {% endraw %} 764 | 765 | {% raw %} 766 | 767 |
768 |
769 | 770 |
771 |
772 |
x = torch.rand(1,8,4,8,8)
 773 | bn3d(x).shape
 774 | 
775 | 776 |
777 |
778 |
779 | 780 |
781 |
782 | 783 |
784 | 785 | 786 | 787 |
788 |
torch.Size([1, 12, 4, 8, 8])
789 |
790 | 791 |
792 | 793 |
794 |
795 | 796 |
797 | {% endraw %} 798 | 799 | {% raw %} 800 | 801 |
802 | 803 |
804 |
805 | 806 |
807 | 808 | 809 |
810 |

class PyramidSpatioTemporalPooling[source]

PyramidSpatioTemporalPooling(in_channels, reduction_channels, pool_sizes) :: Module

811 |
812 |

Spatio-temporal pyramid pooling. 813 | Performs 3D average pooling, then a 1x1x1 convolution to reduce the number of channels, then upsampling. 814 | pool_sizes is a list of kernel sizes: usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]

815 | 816 |
817 | 818 |
819 | 820 |
821 |
822 | 823 |
824 | {% endraw %} 825 | 826 | {% raw %} 827 | 828 |
829 | 830 |
831 | {% endraw %} 832 | 833 | {% raw %} 834 | 835 |
836 |
837 | 838 |
839 |
840 |
h,w = (64,64)
 841 | ptp = PyramidSpatioTemporalPooling(12, 12, [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)])
 842 | 
843 | 844 |
845 |
846 |
847 | 848 |
849 | {% endraw %} 850 | 851 | {% raw %} 852 | 853 |
854 |
855 | 856 |
857 |
858 |
x = torch.rand(1,12,4,64,64)
 859 | 
 860 | #the output is concatenated...
 861 | ptp(x).shape
 862 | 
863 | 864 |
865 |
866 |
867 | 868 |
869 |
870 | 871 |
872 | 873 | 874 | 875 |
876 |
torch.Size([1, 36, 4, 64, 64])
877 |
878 | 879 |
880 | 881 |
882 |
883 | 884 |
885 | {% endraw %} 886 | 887 | {% raw %} 888 | 889 |
890 | 891 |
892 |
893 | 894 |
895 | 896 | 897 |
898 |

class TemporalBlock[source]

TemporalBlock(in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None) :: Module

899 |
900 |

Temporal block with the following layers:

901 |
    902 |
  • 2x3x3, 1x3x3, spatio-temporal pyramid pooling
  • 903 |
  • dropout
  • 904 |
  • skip connection.
  • 905 |
906 | 907 |
908 | 909 |
910 | 911 |
912 |
913 | 914 |
915 | {% endraw %} 916 | 917 | {% raw %} 918 | 919 |
920 | 921 |
922 | {% endraw %} 923 | 924 | {% raw %} 925 | 926 |
927 |
928 | 929 |
930 |
931 |
tb = TemporalBlock(4,8)
 932 | 
 933 | x = torch.rand(1,4,4,64,64)
 934 | 
 935 | tb(x).shape
 936 | 
937 | 938 |
939 |
940 |
941 | 942 |
943 |
944 | 945 |
946 | 947 | 948 | 949 |
950 |
torch.Size([1, 8, 4, 64, 64])
951 |
952 | 953 |
954 | 955 |
956 |
957 | 958 |
959 | {% endraw %} 960 | 961 | {% raw %} 962 | 963 |
964 | 965 |
966 |
967 | 968 |
969 | 970 | 971 |
972 |

class FuturePrediction[source]

FuturePrediction(in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3) :: Module

973 |
974 |

Base class for all neural network modules.

975 |

Your models should also subclass this class.

976 |

Modules can also contain other Modules, allowing to nest them in 977 | a tree structure. You can assign the submodules as regular attributes::

978 | 979 |
import torch.nn as nn
 980 | import torch.nn.functional as F
 981 | 
 982 | class Model(nn.Module):
 983 |     def __init__(self):
 984 |         super(Model, self).__init__()
 985 |         self.conv1 = nn.Conv2d(1, 20, 5)
 986 |         self.conv2 = nn.Conv2d(20, 20, 5)
 987 | 
 988 |     def forward(self, x):
 989 |         x = F.relu(self.conv1(x))
 990 |         return F.relu(self.conv2(x))
 991 | 
 992 | 
993 |

Submodules assigned in this way will be registered, and will have their 994 | parameters converted too when you call :meth:to, etc.

995 |

:ivar training: Boolean represents whether this module is in training or 996 | evaluation mode. 997 | :vartype training: bool

998 | 999 |
1000 | 1001 |
1002 | 1003 |
1004 |
1005 | 1006 |
1007 | {% endraw %} 1008 | 1009 | {% raw %} 1010 | 1011 |
1012 | 1013 |
1014 | {% endraw %} 1015 | 1016 | {% raw %} 1017 | 1018 |
1019 |
1020 | 1021 |
1022 |
1023 |
fp = FuturePrediction(32, 32, 4, 4)
1024 | 
1025 | x = torch.rand(1,4, 32, 64, 64)
1026 | hidden = torch.rand(1,32,64,64)
1027 | test_eq(fp(x, hidden).shape, x.shape)
1028 | 
1029 | 1030 |
1031 |
1032 |
1033 | 1034 |
1035 | {% endraw %} 1036 | 1037 |
1038 |
1039 |

Export

1040 |
1041 |
1042 |
1043 |
1044 | 1045 | 1046 | -------------------------------------------------------------------------------- /docs/model.html: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | title: The model 4 | 5 | 6 | keywords: fastai 7 | sidebar: home_sidebar 8 | 9 | summary: "from the paper from Paletta et al" 10 | description: "from the paper from Paletta et al" 11 | nb_path: "nbs/00_model.ipynb" 12 | --- 13 | 22 | 23 |
24 | 25 | {% raw %} 26 | 27 |
28 | 29 |
30 | {% endraw %} 31 | 32 |
33 |
34 |

We will implement the architecture from the paper ECLIPSE: Envisioning Cloud Induced Perturbations in Solar Energy as closely as possible

35 | 36 |
37 |
38 |
39 |
40 |
41 |

Image

42 | 43 |
44 |
45 |
46 | {% raw %} 47 | 48 |
49 | 50 |
51 | {% endraw %} 52 | 53 |
54 |
55 |

1. Spatial Downsampler

A ResNet encoder that extracts image features

56 |
57 | 58 |
59 |
60 |
61 |
62 |
63 |

You can use any spatial downsampler you want, but the paper specifies a simple ResNet architecture...

64 | 65 |
66 |
67 |
68 | {% raw %} 69 | 70 |
71 | 72 |
73 |
74 | 75 |
76 | 77 | 78 |
79 |

class SpatialDownsampler[source]

SpatialDownsampler(in_channels=3) :: Module

80 |
81 |

Base class for all neural network modules.

82 |

Your models should also subclass this class.

83 |

Modules can also contain other Modules, allowing to nest them in 84 | a tree structure. You can assign the submodules as regular attributes::

85 | 86 |
import torch.nn as nn
 87 | import torch.nn.functional as F
 88 | 
 89 | class Model(nn.Module):
 90 |     def __init__(self):
 91 |         super(Model, self).__init__()
 92 |         self.conv1 = nn.Conv2d(1, 20, 5)
 93 |         self.conv2 = nn.Conv2d(20, 20, 5)
 94 | 
 95 |     def forward(self, x):
 96 |         x = F.relu(self.conv1(x))
 97 |         return F.relu(self.conv2(x))
 98 | 
 99 | 
100 |

Submodules assigned in this way will be registered, and will have their 101 | parameters converted too when you call :meth:to, etc.

102 |

:ivar training: Boolean represents whether this module is in training or 103 | evaluation mode. 104 | :vartype training: bool

105 | 106 |
107 | 108 |
109 | 110 |
111 |
112 | 113 |
114 | {% endraw %} 115 | 116 | {% raw %} 117 | 118 |
119 | 120 |
121 | {% endraw %} 122 | 123 | {% raw %} 124 | 125 |
126 |
127 | 128 |
129 |
130 |
sd = SpatialDownsampler()
131 | 
132 | 133 |
134 |
135 |
136 | 137 |
138 | {% endraw %} 139 | 140 | {% raw %} 141 | 142 |
143 |
144 | 145 |
146 |
147 |
images = [torch.rand(1, 3, 64, 64) for _ in range(4)]
148 | features = torch.stack([sd(image) for image in images], dim=2)
149 | features.shape
150 | 
151 | 152 |
153 |
154 |
155 | 156 |
157 |
158 | 159 |
160 | 161 | 162 | 163 |
164 |
torch.Size([1, 256, 4, 8, 8])
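Each frame is encoded to 256 channels at 1/8 of the input resolution, and torch.stack(..., dim=2) arranges the 4 frames along a new time axis, giving a (batch, channels, time, height, width) tensor:

test_eq(features.shape, (1, 256, 4, 8, 8))  # (B, C, T, H, W): 4 frames, 64/8 = 8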
165 |
166 | 167 |
168 | 169 |
170 |
171 | 172 |
173 | {% endraw %} 174 | 175 |
176 |
177 |

2. Temporal Encoder

178 |
179 |
180 |
181 | {% raw %} 182 | 183 |
184 |
185 | 186 |
187 |
188 |
te = TemporalBlock(256, 128)
189 | 
190 | 191 |
192 |
193 |
194 | 195 |
196 | {% endraw %} 197 | 198 | {% raw %} 199 | 200 |
201 |
202 | 203 |
204 |
205 |
temp_encoded = te(features)
206 | temp_encoded.shape
207 | 
208 | 209 |
210 |
211 |
212 | 213 |
214 |
215 | 216 |
217 | 218 | 219 | 220 |
221 |
torch.Size([1, 128, 4, 8, 8])
222 |
223 | 224 |
225 | 226 |
227 |
228 | 229 |
230 | {% endraw %} 231 | 232 |
233 |
234 |

3. Future State Predictions

235 |
236 |
237 |
238 | {% raw %} 239 | 240 |
241 |
242 | 243 |
244 |
245 |
fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)
246 | 
247 | 248 |
249 |
250 |
251 | 252 |
253 | {% endraw %} 254 | 255 | {% raw %} 256 | 257 |
258 |
259 | 260 |
261 |
262 |
hidden = torch.rand(1, 128, 8, 8)
263 | x = torch.rand(1, 4, 128, 8, 8)
264 | fp(x, hidden).shape
265 | 
266 | 267 |
268 |
269 |
270 | 271 |
272 |
273 | 274 |
275 | 276 | 277 | 278 |
279 |
torch.Size([1, 4, 128, 8, 8])
280 |
281 | 282 |
283 | 284 |
285 |
286 | 287 |
288 | {% endraw %} 289 | 290 |
291 |
292 |

4A. Segmentation Decoder

293 |
294 |
295 |
296 | {% raw %} 297 | 298 |
299 |
300 | 301 |
302 |
303 |
bn = Bottleneck(256, 128, upsample=True)
304 | bn
305 | 
306 | 307 |
308 |
309 |
310 | 311 |
312 |
313 | 314 |
315 | 316 | 317 | 318 |
319 |
Bottleneck(
320 |   (layers): Sequential(
321 |     (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
322 |     (abn_down_project): Sequential(
323 |       (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
324 |       (1): ReLU(inplace=True)
325 |     )
326 |     (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
327 |     (abn): Sequential(
328 |       (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
329 |       (1): ReLU(inplace=True)
330 |     )
331 |     (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
332 |     (abn_up_project): Sequential(
333 |       (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
334 |       (1): ReLU(inplace=True)
335 |     )
336 |     (dropout): Dropout2d(p=0.0, inplace=False)
337 |   )
338 |   (projection): Sequential(
339 |     (upsample_skip_proj): Interpolate()
340 |     (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
341 |     (bn_skip_proj): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
342 |   )
343 | )
344 |
345 | 346 |
347 | 348 |
349 |
350 | 351 |
352 | {% endraw %} 353 | 354 | {% raw %} 355 | 356 |
357 |
358 | 359 |
360 |
361 |
features[0].shape
362 | 
363 | 364 |
365 |
366 |
367 | 368 |
369 |
370 | 371 |
372 | 373 | 374 | 375 |
376 |
torch.Size([256, 4, 8, 8])
377 |
378 | 379 |
380 | 381 |
382 |
383 | 384 |
385 | {% endraw %} 386 | 387 | {% raw %} 388 | 389 |
390 |
391 | 392 |
393 |
394 |
x = torch.rand(1,256,32,32)
395 | bn(x).shape
396 | 
397 | 398 |
399 |
400 |
401 | 402 |
403 |
404 | 405 |
406 | 407 | 408 | 409 |
410 |
torch.Size([1, 128, 64, 64])
411 |
412 | 413 |
414 | 415 |
416 |
417 | 418 |
419 | {% endraw %} 420 | 421 | {% raw %} 422 | 423 |
424 | 425 |
426 |
427 | 428 |
429 | 430 | 431 |
432 |

class Upsampler[source]

Upsampler(sizes=[128, 128, 64], n_out=3) :: Module

433 |
434 |

Base class for all neural network modules.

435 |

Your models should also subclass this class.

436 |

Modules can also contain other Modules, allowing to nest them in 437 | a tree structure. You can assign the submodules as regular attributes::

438 | 439 |
import torch.nn as nn
440 | import torch.nn.functional as F
441 | 
442 | class Model(nn.Module):
443 |     def __init__(self):
444 |         super(Model, self).__init__()
445 |         self.conv1 = nn.Conv2d(1, 20, 5)
446 |         self.conv2 = nn.Conv2d(20, 20, 5)
447 | 
448 |     def forward(self, x):
449 |         x = F.relu(self.conv1(x))
450 |         return F.relu(self.conv2(x))
451 | 
452 | 
453 |

Submodules assigned in this way will be registered, and will have their 454 | parameters converted too when you call :meth:to, etc.

455 |

:ivar training: Boolean represents whether this module is in training or 456 | evaluation mode. 457 | :vartype training: bool

458 | 459 |
460 | 461 |
462 | 463 |
464 |
465 | 466 |
467 | {% endraw %} 468 | 469 | {% raw %} 470 | 471 |
472 | 473 |
474 | {% endraw %} 475 | 476 | {% raw %} 477 | 478 |
479 |
480 | 481 |
482 |
483 |
us = Upsampler()
484 | 
485 | x = torch.rand(1,128,32,32)
486 | us(x).shape
487 | 
488 | 489 |
490 |
491 |
492 | 493 |
494 |
495 | 496 |
497 | 498 | 499 | 500 |
501 |
torch.Size([1, 3, 256, 256])
502 |
503 | 504 |
505 | 506 |
507 |
508 | 509 |
510 | {% endraw %} 511 | 512 |
513 |
514 |

4B. Irradiance Module

515 |
516 |
517 |
518 | {% raw %} 519 | 520 |
521 | 522 |
523 |
524 | 525 |
526 | 527 | 528 |
529 |

class IrradianceModule[source]

IrradianceModule() :: Module

530 |
531 |

Base class for all neural network modules.

532 |

Your models should also subclass this class.

533 |

Modules can also contain other Modules, allowing to nest them in 534 | a tree structure. You can assign the submodules as regular attributes::

535 | 536 |
import torch.nn as nn
537 | import torch.nn.functional as F
538 | 
539 | class Model(nn.Module):
540 |     def __init__(self):
541 |         super(Model, self).__init__()
542 |         self.conv1 = nn.Conv2d(1, 20, 5)
543 |         self.conv2 = nn.Conv2d(20, 20, 5)
544 | 
545 |     def forward(self, x):
546 |         x = F.relu(self.conv1(x))
547 |         return F.relu(self.conv2(x))
548 | 
549 | 
550 |

Submodules assigned in this way will be registered, and will have their 551 | parameters converted too when you call :meth:to, etc.

552 |

:ivar training: Boolean represents whether this module is in training or 553 | evaluation mode. 554 | :vartype training: bool

555 | 556 |
557 | 558 |
559 | 560 |
561 |
562 | 563 |
564 | {% endraw %} 565 | 566 | {% raw %} 567 | 568 |
569 | 570 |
571 | {% endraw %} 572 | 573 | {% raw %} 574 | 575 |
576 |
577 | 578 |
579 |
580 |
im = IrradianceModule()
581 | 
582 | 583 |
584 |
585 |
586 | 587 |
588 | {% endraw %} 589 | 590 | {% raw %} 591 | 592 |
593 |
594 | 595 |
596 |
597 |
x = torch.rand(2, 128, 32, 32)
598 | im(x).shape
599 | 
600 | 601 |
602 |
603 |
604 | 605 |
606 |
607 | 608 |
609 | 610 | 611 | 612 |
613 |
torch.Size([2, 1])
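The head maps a (batch, channels, h, w) feature map to one scalar per sample. Presumably Eclipse applies it to each of the horizon+1 predicted states and concatenates the results, which would yield the (batch, 6) irradiances of the full model below — a sketch of that idea:

irradiances = torch.cat([im(torch.rand(2, 128, 32, 32)) for _ in range(6)], dim=1)
irradiances.shape  # torch.Size([2, 6])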
614 |
615 | 616 |
617 | 618 |
619 |
620 | 621 |
622 | {% endraw %} 623 | 624 |
625 |
626 |

Everything Together...

627 |
628 |
629 |
class Eclipse[source]

Eclipse(n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False) :: Module

Not very parametric
eclipse = Eclipse(debug=True)
preds = eclipse([torch.rand(2, 3, 128, 128) for _ in range(4)])

preds['masks'][0].shape, preds['irradiances'].shape

states.shape=torch.Size([2, 2, 128, 16, 16])
present_state.shape=torch.Size([2, 1, 128, 16, 16])
hidden_state.shape=torch.Size([2, 128, 16, 16])
future_states.shape=torch.Size([2, 6, 128, 16, 16])

(torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
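The returned dict is all a training loop needs. The targets, loss choices, and weighting below are illustrative assumptions, not something this repo or the paper prescribes; a sketch of how the two heads could be supervised jointly:

import torch.nn.functional as F

def eclipse_loss(preds, target_masks, target_irradiances, lambda_irr=0.1):
    # preds['masks']: list of horizon+1 tensors (B, n_out, H, W); target_masks: list of (B, H, W) class maps
    seg_loss = sum(F.cross_entropy(m, t) for m, t in zip(preds['masks'], target_masks)) / len(preds['masks'])
    # preds['irradiances']: (B, horizon+1) scalar forecasts, one per predicted step
    irr_loss = F.mse_loss(preds['irradiances'], target_irradiances)
    return seg_loss + lambda_irr * irr_loss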
Export
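As in the other notebooks of this repo, this section only runs the nbdev export cell that regenerates the library modules from the cells tagged #export:

# hide
from nbdev.export import *
notebook2script()

Converted 00_model.ipynb.
Converted 01_layers.ipynb.
Converted index.ipynb.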
--------------------------------------------------------------------------------
/docs/sidebar.json:
--------------------------------------------------------------------------------
1 | {
2 |   "eclipse_pytorch": {
3 |     "Overview": "/",
4 |     "The model": "model.html",
5 |     "Layers": "layers.html"
6 |   }
7 | }
--------------------------------------------------------------------------------
/docs/sitemap.xml:
--------------------------------------------------------------------------------
1 | ---
2 | layout: none
3 | search: exclude
4 | ---
5 | 
6 | <?xml version="1.0" encoding="UTF-8"?>
7 | <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
8 | {% for post in site.posts %}
9 |   {% unless post.search == "exclude" %}
10 |   <url>
11 |     <loc>{{site.url}}{{post.url}}</loc>
12 |   </url>
13 |   {% endunless %}
14 | {% endfor %}
15 | 
16 | 
17 | {% for page in site.pages %}
18 |   {% unless page.search == "exclude" %}
19 |   <url>
20 |     <loc>{{site.url}}{{ page.url}}</loc>
21 |   </url>
22 |   {% endunless %}
23 | {% endfor %}
24 | </urlset>
--------------------------------------------------------------------------------
/eclipse_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.8"
2 | 
--------------------------------------------------------------------------------
/eclipse_pytorch/_nbdev.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED BY NBDEV! DO NOT EDIT!
2 | 
3 | __all__ = ["index", "modules", "custom_doc_links", "git_url"]
4 | 
5 | index = {"SpatialDownsampler": "00_model.ipynb",
6 |          "TemporalModel": "00_model.ipynb",
7 |          "Upsampler": "00_model.ipynb",
8 |          "IrradianceModule": "00_model.ipynb",
9 |          "Eclipse": "00_model.ipynb",
10 |         "get_activation": "01_layers.ipynb",
11 |         "get_norm": "01_layers.ipynb",
12 |         "init_linear": "01_layers.ipynb",
13 |         "ConvBlock": "01_layers.ipynb",
14 |         "Bottleneck": "01_layers.ipynb",
15 |         "Interpolate": "01_layers.ipynb",
16 |         "UpsamplingConcat": "01_layers.ipynb",
17 |         "UpsamplingAdd": "01_layers.ipynb",
18 |         "ResBlock": "01_layers.ipynb",
19 |         "SpatialGRU": "01_layers.ipynb",
20 |         "CausalConv3d": "01_layers.ipynb",
21 |         "conv_1x1x1_norm_activated": "01_layers.ipynb",
22 |         "Bottleneck3D": "01_layers.ipynb",
23 |         "PyramidSpatioTemporalPooling": "01_layers.ipynb",
24 |         "TemporalBlock": "01_layers.ipynb",
25 |         "FuturePrediction": "01_layers.ipynb"}
26 | 
27 | modules = ["model.py",
28 |            "layers.py"]
29 | 
30 | doc_url = "https://tcapelle.github.io/eclipse_pytorch/"
31 | 
32 | git_url = "https://github.com/tcapelle/eclipse_pytorch/tree/master/"
33 | 
34 | def custom_doc_links(name): return None
35 | 
--------------------------------------------------------------------------------
/eclipse_pytorch/imports.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from fastcore.all import *
--------------------------------------------------------------------------------
/eclipse_pytorch/layers.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_layers.ipynb (unless otherwise specified).
2 | 
3 | __all__ = ['get_activation', 'get_norm', 'init_linear', 'ConvBlock', 'Bottleneck', 'Interpolate', 'UpsamplingConcat',
4 |            'UpsamplingAdd', 'ResBlock', 'SpatialGRU', 'CausalConv3d', 'conv_1x1x1_norm_activated', 'Bottleneck3D',
5 |            'PyramidSpatioTemporalPooling', 'TemporalBlock', 'FuturePrediction']
6 | 
7 | # Cell
8 | from .imports import *
9 | 
10 | # Cell
11 | def get_activation(activation):
12 |     if activation == 'relu':
13 |         return nn.ReLU(inplace=True)
14 |     elif activation == 'lrelu':
15 |         return nn.LeakyReLU(0.1, inplace=True)
16 |     elif activation == 'elu':
17 |         return nn.ELU(inplace=True)
18 |     elif activation == 'tanh':
19 |         return nn.Tanh()  # nn.Tanh takes no inplace argument
20 |     else:
21 |         raise ValueError('Invalid activation {}'.format(activation))
22 | 
23 | def get_norm(norm, out_channels):
24 |     if norm == 'bn':
25 |         return nn.BatchNorm2d(out_channels)
26 |     elif norm == 'in':
27 |         return nn.InstanceNorm2d(out_channels)
28 |     else:
29 |         raise ValueError('Invalid norm {}'.format(norm))
30 | 
31 | # Cell
32 | 
33 | def init_linear(m, act_func=None, init='auto', bias_std=0.01):
34 |     if getattr(m,'bias',None) is not None and bias_std is not None:
35 |         if bias_std != 0: nn.init.normal_(m.bias, 0, bias_std)  # qualify with nn.init: bare normal_ is not imported
36 |         else: m.bias.data.zero_()
37 |     if init=='auto':
38 |         if act_func in (F.relu_,F.leaky_relu_): init = nn.init.kaiming_uniform_
39 |         else: init = getattr(act_func.__class__, '__default_init__', None)
40 |     if init is None: init = getattr(act_func, '__default_init__', None)
41 |     if init is not None: init(m.weight)
42 | 
43 | 
44 | class ConvBlock(nn.Sequential):
45 |     """2D convolution followed by
46 |     - an optional normalisation (batch norm or instance norm)
47 |     - an optional activation (ReLU, LeakyReLU, or tanh)
48 |     """
49 | 
50 |     def __init__(
51 |         self,
52 |         in_channels,
53 |         out_channels=None,
54 |         kernel_size=3,
55 |         stride=1,
56 |         norm='bn',
57 |         activation='relu',
58 |         bias=False,
59 |         transpose=False,
60 |         init='auto'
61 |     ):
62 | 
63 |         out_channels = out_channels or in_channels
64 |         padding = (kernel_size-1)//2
65 |         conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
66 |         conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
67 |         if activation is not None: activation = get_activation(activation)
68 |         init_linear(conv, activation, init=init)
69 |         layers = [conv]
70 |         if activation is not None: layers.append(activation)
71 |         if norm is not None: layers.append(get_norm(norm, out_channels))
72 |         super().__init__(*layers)
73 | 
74 | # Cell
75 | class Bottleneck(nn.Module):
76 |     """
77 |     Defines a bottleneck module with a residual connection
78 |     """
79 | 
80 |     def __init__(
81 |         self,
82 |         in_channels,
83 |         out_channels=None,
84 |         kernel_size=3,
85 |         dilation=1,
86 |         groups=1,
87 |         upsample=False,
88 |         downsample=False,
89 |         dropout=0.0,
90 |     ):
91 |         super().__init__()
92 |         self._downsample = downsample
93 |         bottleneck_channels = int(in_channels / 2)
94 |         out_channels = out_channels or in_channels
95 |         padding_size = ((kernel_size - 1) * dilation + 1) // 2
96 | 
97 |         # Define the main conv operation
98 |         assert dilation == 1
99 |         if upsample:
100 |             assert not downsample, 'downsample and upsample not possible simultaneously.'
101 | bottleneck_conv = nn.ConvTranspose2d( 102 | bottleneck_channels, 103 | bottleneck_channels, 104 | kernel_size=kernel_size, 105 | bias=False, 106 | dilation=1, 107 | stride=2, 108 | output_padding=padding_size, 109 | padding=padding_size, 110 | groups=groups, 111 | ) 112 | elif downsample: 113 | bottleneck_conv = nn.Conv2d( 114 | bottleneck_channels, 115 | bottleneck_channels, 116 | kernel_size=kernel_size, 117 | bias=False, 118 | dilation=dilation, 119 | stride=2, 120 | padding=padding_size, 121 | groups=groups, 122 | ) 123 | else: 124 | bottleneck_conv = nn.Conv2d( 125 | bottleneck_channels, 126 | bottleneck_channels, 127 | kernel_size=kernel_size, 128 | bias=False, 129 | dilation=dilation, 130 | padding=padding_size, 131 | groups=groups, 132 | ) 133 | 134 | self.layers = nn.Sequential( 135 | OrderedDict( 136 | [ 137 | # First projection with 1x1 kernel 138 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)), 139 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), 140 | nn.ReLU(inplace=True))), 141 | # Second conv block 142 | ('conv', bottleneck_conv), 143 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))), 144 | # Final projection with 1x1 kernel 145 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)), 146 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels), 147 | nn.ReLU(inplace=True))), 148 | # Regulariser 149 | ('dropout', nn.Dropout2d(p=dropout)), 150 | ] 151 | ) 152 | ) 153 | 154 | if out_channels == in_channels and not downsample and not upsample: 155 | self.projection = None 156 | else: 157 | projection = OrderedDict() 158 | if upsample: 159 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)}) 160 | elif downsample: 161 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)}) 162 | projection.update( 163 | { 164 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 165 | 'bn_skip_proj': nn.BatchNorm2d(out_channels), 166 | } 167 | ) 168 | self.projection = nn.Sequential(projection) 169 | 170 | # pylint: disable=arguments-differ 171 | def forward(self, *args): 172 | (x,) = args 173 | x_residual = self.layers(x) 174 | if self.projection is not None: 175 | if self._downsample: 176 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer 177 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0) 178 | return x_residual + self.projection(x) 179 | return x_residual + x 180 | 181 | # Cell 182 | class Interpolate(nn.Module): 183 | def __init__(self, scale_factor: int = 2): 184 | super().__init__() 185 | self._interpolate = nn.functional.interpolate 186 | self._scale_factor = scale_factor 187 | 188 | # pylint: disable=arguments-differ 189 | def forward(self, x): 190 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False) 191 | 192 | 193 | class UpsamplingConcat(nn.Module): 194 | def __init__(self, in_channels, out_channels, scale_factor=2): 195 | super().__init__() 196 | 197 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False) 198 | 199 | self.conv = nn.Sequential( 200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 201 | nn.BatchNorm2d(out_channels), 202 | nn.ReLU(inplace=True), 203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 204 | 
nn.BatchNorm2d(out_channels), 205 | nn.ReLU(inplace=True), 206 | ) 207 | 208 | def forward(self, x_to_upsample, x): 209 | x_to_upsample = self.upsample(x_to_upsample) 210 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1) 211 | return self.conv(x_to_upsample) 212 | 213 | 214 | class UpsamplingAdd(nn.Module): 215 | def __init__(self, in_channels, out_channels, scale_factor=2): 216 | super().__init__() 217 | self.upsample_layer = nn.Sequential( 218 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False), 219 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False), 220 | nn.BatchNorm2d(out_channels), 221 | ) 222 | 223 | def forward(self, x, x_skip): 224 | x = self.upsample_layer(x) 225 | return x + x_skip 226 | 227 | # Cell 228 | class ResBlock(nn.Module): 229 | " A simple resnet Block" 230 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'): 231 | super().__init__() 232 | self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation), 233 | ConvBlock(out_channels, out_channels, norm=norm, activation=activation) 234 | ) 235 | id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)] 236 | self.activation = get_activation(activation) 237 | if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True)) 238 | self.id_path = nn.Sequential(*id_path) 239 | 240 | def forward(self, x): 241 | return self.activation(self.convs(x) + self.id_path(x)) 242 | 243 | # Cell 244 | class SpatialGRU(nn.Module): 245 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a 246 | convolutional gated recurrent unit over the data""" 247 | 248 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'): 249 | super().__init__() 250 | self.input_size = input_size 251 | self.hidden_size = hidden_size 252 | self.gru_bias_init = gru_bias_init 253 | 254 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1) 255 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1) 256 | 257 | self.conv_state_tilde = ConvBlock( 258 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation 259 | ) 260 | 261 | def forward(self, x, state=None, flow=None, mode='bilinear'): 262 | # pylint: disable=unused-argument, arguments-differ 263 | # Check size 264 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.' 
265 | b, timesteps, c, h, w = x.size() 266 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}' 267 | 268 | # recurrent layers 269 | rnn_output = [] 270 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state 271 | for t in range(timesteps): 272 | x_t = x[:, t] 273 | if flow is not None: 274 | rnn_state = warp_features(rnn_state, flow[:, t], mode=mode) 275 | 276 | # propagate rnn state 277 | rnn_state = self.gru_cell(x_t, rnn_state) 278 | rnn_output.append(rnn_state) 279 | 280 | # reshape rnn output to batch tensor 281 | return torch.stack(rnn_output, dim=1) 282 | 283 | def gru_cell(self, x, state): 284 | # Compute gates 285 | x_and_state = torch.cat([x, state], dim=1) 286 | update_gate = self.conv_update(x_and_state) 287 | reset_gate = self.conv_reset(x_and_state) 288 | # Add bias to initialise gate as close to identity function 289 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init) 290 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init) 291 | 292 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc) 293 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1)) 294 | 295 | output = (1.0 - update_gate) * state + update_gate * state_tilde 296 | return output 297 | 298 | # Cell 299 | class CausalConv3d(nn.Module): 300 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False): 301 | super().__init__() 302 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.' 303 | time_pad = (kernel_size[0] - 1) * dilation[0] 304 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2 305 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2 306 | 307 | # Pad temporally on the left 308 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0) 309 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias) 310 | self.norm = nn.BatchNorm3d(out_channels) 311 | self.activation = nn.ReLU(inplace=True) 312 | 313 | def forward(self, *inputs): 314 | (x,) = inputs 315 | x = self.pad(x) 316 | x = self.conv(x) 317 | x = self.norm(x) 318 | x = self.activation(x) 319 | return x 320 | 321 | # Cell 322 | def conv_1x1x1_norm_activated(in_channels, out_channels): 323 | """1x1x1 3D convolution, normalization and activation layer.""" 324 | return nn.Sequential( 325 | OrderedDict( 326 | [ 327 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)), 328 | ('norm', nn.BatchNorm3d(out_channels)), 329 | ('activation', nn.ReLU(inplace=True)), 330 | ] 331 | ) 332 | ) 333 | 334 | # Cell 335 | class Bottleneck3D(nn.Module): 336 | """ 337 | Defines a bottleneck module with a residual connection 338 | """ 339 | 340 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)): 341 | super().__init__() 342 | bottleneck_channels = in_channels // 2 343 | out_channels = out_channels or in_channels 344 | 345 | self.layers = nn.Sequential( 346 | OrderedDict( 347 | [ 348 | # First projection with 1x1 kernel 349 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)), 350 | # Second conv block 351 | ( 352 | 'conv', 353 | CausalConv3d( 354 | bottleneck_channels, 355 | bottleneck_channels, 356 | kernel_size=kernel_size, 357 | dilation=dilation, 358 | bias=False, 359 | ), 360 | ), 361 | # Final projection with 1x1 
kernel 362 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)), 363 | ] 364 | ) 365 | ) 366 | 367 | if out_channels != in_channels: 368 | self.projection = nn.Sequential( 369 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False), 370 | nn.BatchNorm3d(out_channels), 371 | ) 372 | else: 373 | self.projection = None 374 | 375 | def forward(self, *args): 376 | (x,) = args 377 | x_residual = self.layers(x) 378 | x_features = self.projection(x) if self.projection is not None else x 379 | return x_residual + x_features 380 | 381 | # Cell 382 | class PyramidSpatioTemporalPooling(nn.Module): 383 | """ Spatio-temporal pyramid pooling. 384 | Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling. 385 | Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)] 386 | """ 387 | 388 | def __init__(self, in_channels, reduction_channels, pool_sizes): 389 | super().__init__() 390 | self.features = [] 391 | for pool_size in pool_sizes: 392 | assert pool_size[0] == 2, ( 393 | "Time kernel should be 2 as PyTorch raises an error when" "padding with more than half the kernel size" 394 | ) 395 | stride = (1, *pool_size[1:]) 396 | padding = (pool_size[0] - 1, 0, 0) 397 | self.features.append( 398 | nn.Sequential( 399 | OrderedDict( 400 | [ 401 | # Pad the input tensor but do not take into account zero padding into the average. 402 | ( 403 | 'avgpool', 404 | torch.nn.AvgPool3d( 405 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False 406 | ), 407 | ), 408 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)), 409 | ] 410 | ) 411 | ) 412 | ) 413 | self.features = nn.ModuleList(self.features) 414 | 415 | def forward(self, *inputs): 416 | (x,) = inputs 417 | b, _, t, h, w = x.shape 418 | # Do not include current tensor when concatenating 419 | out = [] 420 | for f in self.features: 421 | # Remove unnecessary padded values (time dimension) on the right 422 | x_pool = f(x)[:, :, :-1].contiguous() 423 | c = x_pool.shape[1] 424 | x_pool = nn.functional.interpolate( 425 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False 426 | ) 427 | x_pool = x_pool.view(b, c, t, h, w) 428 | out.append(x_pool) 429 | out = torch.cat(out, 1) 430 | return out 431 | 432 | # Cell 433 | class TemporalBlock(nn.Module): 434 | """ Temporal block with the following layers: 435 | - 2x3x3, 1x3x3, spatio-temporal pyramid pooling 436 | - dropout 437 | - skip connection. 
438 | """ 439 | 440 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None): 441 | super().__init__() 442 | self.in_channels = in_channels 443 | self.half_channels = in_channels // 2 444 | self.out_channels = out_channels or self.in_channels 445 | self.kernels = [(2, 3, 3), (1, 3, 3)] 446 | 447 | # Flag for spatio-temporal pyramid pooling 448 | self.use_pyramid_pooling = use_pyramid_pooling 449 | 450 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1 451 | self.convolution_paths = [] 452 | for kernel_size in self.kernels: 453 | self.convolution_paths.append( 454 | nn.Sequential( 455 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels), 456 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size), 457 | ) 458 | ) 459 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels)) 460 | self.convolution_paths = nn.ModuleList(self.convolution_paths) 461 | 462 | agg_in_channels = len(self.convolution_paths) * self.half_channels 463 | 464 | if self.use_pyramid_pooling: 465 | assert pool_sizes is not None, "setting must contain the list of kernel_size, but is None." 466 | reduction_channels = self.in_channels // 3 467 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes) 468 | agg_in_channels += len(pool_sizes) * reduction_channels 469 | 470 | # Feature aggregation 471 | self.aggregation = nn.Sequential( 472 | conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),) 473 | 474 | if self.out_channels != self.in_channels: 475 | self.projection = nn.Sequential( 476 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False), 477 | nn.BatchNorm3d(self.out_channels), 478 | ) 479 | else: 480 | self.projection = None 481 | 482 | def forward(self, *inputs): 483 | (x,) = inputs 484 | x_paths = [] 485 | for conv in self.convolution_paths: 486 | x_paths.append(conv(x)) 487 | x_residual = torch.cat(x_paths, dim=1) 488 | if self.use_pyramid_pooling: 489 | x_pool = self.pyramid_pooling(x) 490 | x_residual = torch.cat([x_residual, x_pool], dim=1) 491 | x_residual = self.aggregation(x_residual) 492 | 493 | if self.out_channels != self.in_channels: 494 | x = self.projection(x) 495 | x = x + x_residual 496 | return x 497 | 498 | # Cell 499 | class FuturePrediction(torch.nn.Module): 500 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3): 501 | super().__init__() 502 | self.n_gru_blocks = n_gru_blocks 503 | 504 | # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample 505 | # from the probabilistic model. 
The architecture of the model is: 506 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks 507 | self.spatial_grus = [] 508 | self.res_blocks = [] 509 | 510 | for i in range(self.n_gru_blocks): 511 | gru_in_channels = latent_dim if i == 0 else in_channels 512 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels)) 513 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels) 514 | for _ in range(n_res_layers)])) 515 | 516 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus) 517 | self.res_blocks = torch.nn.ModuleList(self.res_blocks) 518 | 519 | def forward(self, x, hidden_state): 520 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w) 521 | for i in range(self.n_gru_blocks): 522 | x = self.spatial_grus[i](x, hidden_state, flow=None) 523 | b, n_future, c, h, w = x.shape 524 | 525 | x = self.res_blocks[i](x.view(b * n_future, c, h, w)) 526 | x = x.view(b, n_future, c, h, w) 527 | 528 | return x -------------------------------------------------------------------------------- /eclipse_pytorch/model.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_model.ipynb (unless otherwise specified). 2 | 3 | __all__ = ['SpatialDownsampler', 'TemporalModel', 'Upsampler', 'IrradianceModule', 'Eclipse'] 4 | 5 | # Cell 6 | from .imports import * 7 | from .layers import * 8 | 9 | # Cell 10 | class SpatialDownsampler(nn.Module): 11 | 12 | def __init__(self, in_channels=3): 13 | super().__init__() 14 | self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1) 15 | self.blocks = nn.Sequential(ResBlock(64, 64, kernel_size=3, stride=2), 16 | ResBlock(64, 128, kernel_size=3, stride=2), 17 | ResBlock(128,256, kernel_size=3, stride=2)) 18 | 19 | def forward(self, x): 20 | return self.blocks(self.conv1(x)) 21 | 22 | # Cell 23 | class TemporalModel(nn.Module): 24 | def __init__( 25 | self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0, 26 | n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True): 27 | super().__init__() 28 | self.receptive_field = receptive_field 29 | n_temporal_layers = receptive_field - 1 30 | 31 | h, w = input_shape 32 | modules = [] 33 | 34 | block_in_channels = in_channels 35 | block_out_channels = start_out_channels 36 | 37 | for _ in range(n_temporal_layers): 38 | if use_pyramid_pooling: 39 | use_pyramid_pooling = True 40 | pool_sizes = [(2, h, w)] 41 | else: 42 | use_pyramid_pooling = False 43 | pool_sizes = None 44 | temporal = TemporalBlock( 45 | block_in_channels, 46 | block_out_channels, 47 | use_pyramid_pooling=use_pyramid_pooling, 48 | pool_sizes=pool_sizes, 49 | ) 50 | spatial = [ 51 | Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3)) 52 | for _ in range(n_spatial_layers_between_temporal_layers) 53 | ] 54 | temporal_spatial_layers = nn.Sequential(temporal, *spatial) 55 | modules.extend(temporal_spatial_layers) 56 | 57 | block_in_channels = block_out_channels 58 | block_out_channels += extra_in_channels 59 | 60 | self.out_channels = block_in_channels 61 | 62 | self.model = nn.Sequential(*modules) 63 | 64 | def forward(self, x): 65 | # Reshape input tensor to (batch, C, time, H, W) 66 | x = x.permute(0, 2, 1, 3, 4) 67 | x = self.model(x) 68 | x = x.permute(0, 2, 1, 3, 4).contiguous() 69 | return x[:, (self.receptive_field - 1):] 70 | 71 | # Cell 72 | class Upsampler(nn.Module): 73 | def __init__(self, sizes=[128,128,64], n_out=3): 74 | super().__init__() 75 | 
zsizes = zip(sizes[:-1], sizes[1:]) 76 | self.convs = nn.Sequential(*[Bottleneck(si, sf, upsample=True) for si,sf in zsizes], 77 | Bottleneck(sizes[-1], sizes[-1], upsample=True), 78 | ConvBlock(sizes[-1], n_out, kernel_size=1, activation=None)) 79 | 80 | def forward(self, x): 81 | return self.convs(x) 82 | 83 | # Cell 84 | class IrradianceModule(nn.Module): 85 | def __init__(self): 86 | super().__init__() 87 | self.convs = nn.Sequential(ConvBlock(128, 64, stride=2), 88 | ConvBlock(64, 64), 89 | nn.AdaptiveMaxPool2d(1) 90 | ) 91 | self.linear = nn.Sequential(nn.Flatten(), 92 | nn.Linear(64, 1) 93 | ) 94 | def forward(self, x): 95 | return self.linear(self.convs(x)) 96 | 97 | # Cell 98 | class Eclipse(nn.Module): 99 | """Not very parametric""" 100 | def __init__(self, n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False): 101 | super().__init__() 102 | store_attr() 103 | self.spatial_downsampler = SpatialDownsampler(n_in) 104 | self.temporal_model = TemporalModel(256, 3, input_shape=(img_size[0]//8, img_size[1]//8), start_out_channels=128) 105 | self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=n_gru_layers, n_res_layers=n_res_layers) 106 | self.upsampler = Upsampler(n_out=n_out) 107 | self.irradiance = IrradianceModule() 108 | 109 | def zero_hidden(self, x, horizon): 110 | bs, ch, h, w = x.shape 111 | return x.new_zeros(bs, horizon, ch, h, w) 112 | 113 | def forward(self, imgs): 114 | x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=1) 115 | 116 | #encode temporal model 117 | states = self.temporal_model(x) 118 | if self.debug: print(f'{states.shape=}') 119 | 120 | #get hidden state 121 | present_state = states[:, -1:] 122 | if self.debug: print(f'{present_state.shape=}') 123 | 124 | 125 | # Prepare future prediction input 126 | hidden_state = present_state.squeeze() 127 | if self.debug: print(f'{hidden_state.shape=}') 128 | 129 | future_prediction_input = self.zero_hidden(hidden_state, self.horizon) 130 | 131 | # Recursively predict future states 132 | future_states = self.future_prediction(future_prediction_input, hidden_state) 133 | 134 | # Concatenate present state 135 | future_states = torch.cat([present_state, future_states], dim=1) 136 | if self.debug: print(f'{future_states.shape=}') 137 | 138 | #decode outputs 139 | masks, irradiances = [], [] 140 | 141 | for state in future_states.unbind(dim=1): 142 | masks.append(self.upsampler(state)) 143 | irradiances.append(self.irradiance(state)) 144 | return {'masks': masks, 'irradiances': torch.cat(irradiances, dim=-1)} 145 | -------------------------------------------------------------------------------- /nbs/00_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# default_exp model" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# The model\n", 17 | "\n", 18 | "> from the paper from [Paletta et al](https://arxiv.org/pdf/2104.12419v1.pdf)" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "We will try to implement as close as possible the architecture from the paper `ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy`" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "![Image](images/eclipse_diagram.png)" 33 | ] 34 | }, 35 | { 36 | 
"cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "#export\n", 42 | "from eclipse_pytorch.imports import *\n", 43 | "from eclipse_pytorch.layers import *" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "## 1. Spatial Downsampler\n", 51 | "> A resnet encoder to get image features" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "You could use any spatial downsampler as you want, but the paper states a simple resnet arch..." 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "#export\n", 68 | "class SpatialDownsampler(nn.Module):\n", 69 | " \n", 70 | " def __init__(self, in_channels=3):\n", 71 | " super().__init__()\n", 72 | " self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)\n", 73 | " self.blocks = nn.Sequential(ResBlock(64, 64, kernel_size=3, stride=2), \n", 74 | " ResBlock(64, 128, kernel_size=3, stride=2), \n", 75 | " ResBlock(128,256, kernel_size=3, stride=2))\n", 76 | " \n", 77 | " def forward(self, x):\n", 78 | " return self.blocks(self.conv1(x))" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "sd = SpatialDownsampler()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "torch.Size([1, 4, 256, 8, 8])" 99 | ] 100 | }, 101 | "execution_count": null, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | } 105 | ], 106 | "source": [ 107 | "images = [torch.rand(1, 3, 64, 64) for _ in range(4)]\n", 108 | "features = torch.stack([sd(image) for image in images], dim=1)\n", 109 | "features.shape" 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "## 2. 
Temporal Encoder" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "#export\n", 126 | "class TemporalModel(nn.Module):\n", 127 | " def __init__(\n", 128 | " self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,\n", 129 | " n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):\n", 130 | " super().__init__()\n", 131 | " self.receptive_field = receptive_field\n", 132 | " n_temporal_layers = receptive_field - 1\n", 133 | "\n", 134 | " h, w = input_shape\n", 135 | " modules = []\n", 136 | "\n", 137 | " block_in_channels = in_channels\n", 138 | " block_out_channels = start_out_channels\n", 139 | "\n", 140 | " for _ in range(n_temporal_layers):\n", 141 | " if use_pyramid_pooling:\n", 142 | " use_pyramid_pooling = True\n", 143 | " pool_sizes = [(2, h, w)]\n", 144 | " else:\n", 145 | " use_pyramid_pooling = False\n", 146 | " pool_sizes = None\n", 147 | " temporal = TemporalBlock(\n", 148 | " block_in_channels,\n", 149 | " block_out_channels,\n", 150 | " use_pyramid_pooling=use_pyramid_pooling,\n", 151 | " pool_sizes=pool_sizes,\n", 152 | " )\n", 153 | " spatial = [\n", 154 | " Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))\n", 155 | " for _ in range(n_spatial_layers_between_temporal_layers)\n", 156 | " ]\n", 157 | " temporal_spatial_layers = nn.Sequential(temporal, *spatial)\n", 158 | " modules.extend(temporal_spatial_layers)\n", 159 | "\n", 160 | " block_in_channels = block_out_channels\n", 161 | " block_out_channels += extra_in_channels\n", 162 | "\n", 163 | " self.out_channels = block_in_channels\n", 164 | "\n", 165 | " self.model = nn.Sequential(*modules)\n", 166 | "\n", 167 | " def forward(self, x):\n", 168 | " # Reshape input tensor to (batch, C, time, H, W)\n", 169 | " x = x.permute(0, 2, 1, 3, 4)\n", 170 | " x = self.model(x)\n", 171 | " x = x.permute(0, 2, 1, 3, 4).contiguous()\n", 172 | " return x[:, (self.receptive_field - 1):]" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "tm = TemporalModel(256, 3, (8,8), 128)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [ 189 | { 190 | "data": { 191 | "text/plain": [ 192 | "torch.Size([1, 2, 128, 8, 8])" 193 | ] 194 | }, 195 | "execution_count": null, 196 | "metadata": {}, 197 | "output_type": "execute_result" 198 | } 199 | ], 200 | "source": [ 201 | "temp_encoded = tm(features)\n", 202 | "temp_encoded.shape" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": {}, 208 | "source": [ 209 | "## 3. 
Future State Predictions" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "data": { 228 | "text/plain": [ 229 | "torch.Size([1, 4, 128, 8, 8])" 230 | ] 231 | }, 232 | "execution_count": null, 233 | "metadata": {}, 234 | "output_type": "execute_result" 235 | } 236 | ], 237 | "source": [ 238 | "hidden = torch.rand(1, 128, 8, 8)\n", 239 | "x = torch.rand(1, 4, 128, 8, 8)\n", 240 | "fp(x, hidden).shape" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "metadata": {}, 246 | "source": [ 247 | "## 4A. Segmentation Decoder" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "data": { 257 | "text/plain": [ 258 | "Bottleneck(\n", 259 | " (layers): Sequential(\n", 260 | " (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 261 | " (abn_down_project): Sequential(\n", 262 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 263 | " (1): ReLU(inplace=True)\n", 264 | " )\n", 265 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)\n", 266 | " (abn): Sequential(\n", 267 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 268 | " (1): ReLU(inplace=True)\n", 269 | " )\n", 270 | " (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 271 | " (abn_up_project): Sequential(\n", 272 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 273 | " (1): ReLU(inplace=True)\n", 274 | " )\n", 275 | " (dropout): Dropout2d(p=0.0, inplace=False)\n", 276 | " )\n", 277 | " (projection): Sequential(\n", 278 | " (upsample_skip_proj): Interpolate()\n", 279 | " (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", 280 | " (bn_skip_proj): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 281 | " )\n", 282 | ")" 283 | ] 284 | }, 285 | "execution_count": null, 286 | "metadata": {}, 287 | "output_type": "execute_result" 288 | } 289 | ], 290 | "source": [ 291 | "bn = Bottleneck(256, 128, upsample=True)\n", 292 | "bn" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "data": { 302 | "text/plain": [ 303 | "torch.Size([4, 256, 8, 8])" 304 | ] 305 | }, 306 | "execution_count": null, 307 | "metadata": {}, 308 | "output_type": "execute_result" 309 | } 310 | ], 311 | "source": [ 312 | "features[0].shape" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [ 320 | { 321 | "data": { 322 | "text/plain": [ 323 | "torch.Size([1, 128, 64, 64])" 324 | ] 325 | }, 326 | "execution_count": null, 327 | "metadata": {}, 328 | "output_type": "execute_result" 329 | } 330 | ], 331 | "source": [ 332 | "x = torch.rand(1,256,32,32)\n", 333 | "bn(x).shape" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [ 342 | "#export\n", 343 | "class Upsampler(nn.Module):\n", 344 | " def __init__(self, 
sizes=[128,128,64], n_out=3):\n", 345 | " super().__init__()\n", 346 | " zsizes = zip(sizes[:-1], sizes[1:])\n", 347 | " self.convs = nn.Sequential(*[Bottleneck(si, sf, upsample=True) for si,sf in zsizes], \n", 348 | " Bottleneck(sizes[-1], sizes[-1], upsample=True), \n", 349 | " ConvBlock(sizes[-1], n_out, kernel_size=1, activation=None))\n", 350 | " \n", 351 | " def forward(self, x):\n", 352 | " return self.convs(x)" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": null, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "data": { 362 | "text/plain": [ 363 | "torch.Size([1, 3, 256, 256])" 364 | ] 365 | }, 366 | "execution_count": null, 367 | "metadata": {}, 368 | "output_type": "execute_result" 369 | } 370 | ], 371 | "source": [ 372 | "us = Upsampler()\n", 373 | "\n", 374 | "x = torch.rand(1,128,32,32)\n", 375 | "us(x).shape" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "## 4B. Irradiance Module" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "#export\n", 392 | "class IrradianceModule(nn.Module):\n", 393 | " def __init__(self):\n", 394 | " super().__init__()\n", 395 | " self.convs = nn.Sequential(ConvBlock(128, 64, stride=2), \n", 396 | " ConvBlock(64, 64),\n", 397 | " nn.AdaptiveMaxPool2d(1)\n", 398 | " )\n", 399 | " self.linear = nn.Sequential(nn.Flatten(), \n", 400 | " nn.Linear(64, 1)\n", 401 | " )\n", 402 | " def forward(self, x):\n", 403 | " return self.linear(self.convs(x))" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "im = IrradianceModule()" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "data": { 422 | "text/plain": [ 423 | "IrradianceModule(\n", 424 | " (convs): Sequential(\n", 425 | " (0): ConvBlock(\n", 426 | " (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 427 | " (1): ReLU(inplace=True)\n", 428 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 429 | " )\n", 430 | " (1): ConvBlock(\n", 431 | " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 432 | " (1): ReLU(inplace=True)\n", 433 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 434 | " )\n", 435 | " (2): AdaptiveMaxPool2d(output_size=1)\n", 436 | " )\n", 437 | " (linear): Sequential(\n", 438 | " (0): Flatten(start_dim=1, end_dim=-1)\n", 439 | " (1): Linear(in_features=64, out_features=1, bias=True)\n", 440 | " )\n", 441 | ")" 442 | ] 443 | }, 444 | "execution_count": null, 445 | "metadata": {}, 446 | "output_type": "execute_result" 447 | } 448 | ], 449 | "source": [ 450 | "im" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": {}, 457 | "outputs": [ 458 | { 459 | "data": { 460 | "text/plain": [ 461 | "torch.Size([2, 1])" 462 | ] 463 | }, 464 | "execution_count": null, 465 | "metadata": {}, 466 | "output_type": "execute_result" 467 | } 468 | ], 469 | "source": [ 470 | "x = torch.rand(2, 128, 32, 32)\n", 471 | "im(x).shape" 472 | ] 473 | }, 474 | { 475 | "cell_type": "markdown", 476 | "metadata": {}, 477 | "source": [ 478 | "## Everything Together..." 
479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "execution_count": null, 484 | "metadata": {}, 485 | "outputs": [], 486 | "source": [ 487 | "#export\n", 488 | "class Eclipse(nn.Module):\n", 489 | " \"\"\"Not very parametric\"\"\"\n", 490 | " def __init__(self, n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False):\n", 491 | " super().__init__()\n", 492 | " store_attr()\n", 493 | " self.spatial_downsampler = SpatialDownsampler(n_in)\n", 494 | " self.temporal_model = TemporalModel(256, 3, input_shape=(img_size[0]//8, img_size[1]//8), start_out_channels=128)\n", 495 | " self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=n_gru_layers, n_res_layers=n_res_layers)\n", 496 | " self.upsampler = Upsampler(n_out=n_out)\n", 497 | " self.irradiance = IrradianceModule()\n", 498 | " \n", 499 | " def zero_hidden(self, x, horizon):\n", 500 | " bs, ch, h, w = x.shape\n", 501 | " return x.new_zeros(bs, horizon, ch, h, w)\n", 502 | " \n", 503 | " def forward(self, imgs):\n", 504 | " x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=1)\n", 505 | " \n", 506 | " #encode temporal model\n", 507 | " states = self.temporal_model(x)\n", 508 | " if self.debug: print(f'{states.shape=}')\n", 509 | " \n", 510 | " #get hidden state\n", 511 | " present_state = states[:, -1:]\n", 512 | " if self.debug: print(f'{present_state.shape=}')\n", 513 | " \n", 514 | " \n", 515 | " # Prepare future prediction input\n", 516 | " hidden_state = present_state.squeeze()\n", 517 | " if self.debug: print(f'{hidden_state.shape=}')\n", 518 | " \n", 519 | " future_prediction_input = self.zero_hidden(hidden_state, self.horizon)\n", 520 | " \n", 521 | " # Recursively predict future states\n", 522 | " future_states = self.future_prediction(future_prediction_input, hidden_state)\n", 523 | "\n", 524 | " # Concatenate present state\n", 525 | " future_states = torch.cat([present_state, future_states], dim=1)\n", 526 | " if self.debug: print(f'{future_states.shape=}')\n", 527 | " \n", 528 | " #decode outputs\n", 529 | " masks, irradiances = [], []\n", 530 | "\n", 531 | " for state in future_states.unbind(dim=1):\n", 532 | " masks.append(self.upsampler(state))\n", 533 | " irradiances.append(self.irradiance(state))\n", 534 | " return {'masks': masks, 'irradiances': torch.cat(irradiances, dim=-1)}\n", 535 | " " 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "eclipse = Eclipse(img_size=(256, 192), debug=True)" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "name": "stdout", 554 | "output_type": "stream", 555 | "text": [ 556 | "states.shape=torch.Size([2, 2, 128, 32, 24])\n", 557 | "present_state.shape=torch.Size([2, 1, 128, 32, 24])\n", 558 | "hidden_state.shape=torch.Size([2, 128, 32, 24])\n", 559 | "future_states.shape=torch.Size([2, 6, 128, 32, 24])\n" 560 | ] 561 | }, 562 | { 563 | "data": { 564 | "text/plain": [ 565 | "(torch.Size([2, 4, 256, 192]), torch.Size([2, 6]))" 566 | ] 567 | }, 568 | "execution_count": null, 569 | "metadata": {}, 570 | "output_type": "execute_result" 571 | } 572 | ], 573 | "source": [ 574 | "preds = eclipse([torch.rand(2, 3, 256, 192) for _ in range(4)])\n", 575 | "\n", 576 | "preds['masks'][0].shape, preds['irradiances'].shape" 577 | ] 578 | }, 579 | { 580 | "cell_type": "markdown", 581 | "metadata": {}, 582 | "source": [ 583 | "## Export" 584 | ] 585 
| },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {},
590 | "outputs": [
591 | {
592 | "name": "stdout",
593 | "output_type": "stream",
594 | "text": [
595 | "Converted 00_model.ipynb.\n",
596 | "Converted 01_layers.ipynb.\n",
597 | "Converted index.ipynb.\n"
598 | ]
599 | }
600 | ],
601 | "source": [
602 | "# hide\n",
603 | "from nbdev.export import *\n",
604 | "notebook2script()"
605 | ]
606 | }
607 | ],
608 | "metadata": {
609 | "kernelspec": {
610 | "display_name": "Python 3",
611 | "language": "python",
612 | "name": "python3"
613 | }
614 | },
615 | "nbformat": 4,
616 | "nbformat_minor": 4
617 | }
618 | 
--------------------------------------------------------------------------------
/nbs/01_layers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "94f61f73-8683-4638-b730-6d9174eb212d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp layers"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "0ea22c98-3207-4673-8ee8-7396637f45cd",
16 | "metadata": {},
17 | "source": [
18 | "# Layers\n",
19 | "> most of them come from [fiery](https://github.com/wayveai/fiery)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "113f4b7d-ace4-4416-a067-9d7025ef7ae6",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "#export\n",
30 | "from eclipse_pytorch.imports import *"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "e8f80af7-bcc4-43a8-a6bd-283991737d8f",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "#export\n",
41 | "def get_activation(activation):\n",
42 | "    if activation == 'relu':\n",
43 | "        return nn.ReLU(inplace=True)\n",
44 | "    elif activation == 'lrelu':\n",
45 | "        return nn.LeakyReLU(0.1, inplace=True)\n",
46 | "    elif activation == 'elu':\n",
47 | "        return nn.ELU(inplace=True)\n",
48 | "    elif activation == 'tanh':\n",
49 | "        return nn.Tanh()  # nn.Tanh takes no inplace argument\n",
50 | "    else:\n",
51 | "        raise ValueError('Invalid activation {}'.format(activation))\n",
52 | "    \n",
53 | "def get_norm(norm, out_channels):\n",
54 | "    if norm == 'bn':\n",
55 | "        return nn.BatchNorm2d(out_channels)\n",
56 | "    elif norm == 'in':\n",
57 | "        return nn.InstanceNorm2d(out_channels)\n",
58 | "    else:\n",
59 | "        raise ValueError('Invalid norm {}'.format(norm))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "f158184b-c6d0-4a22-8625-f8514b1b524f",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "#export\n",
70 | "\n",
71 | "def init_linear(m, act_func=None, init='auto', bias_std=0.01):\n",
72 | "    if getattr(m,'bias',None) is not None and bias_std is not None:\n",
73 | "        if bias_std != 0: nn.init.normal_(m.bias, 0, bias_std)  # qualify with nn.init: bare normal_ is not imported\n",
74 | "        else: m.bias.data.zero_()\n",
75 | "    if init=='auto':\n",
76 | "        if act_func in (F.relu_,F.leaky_relu_): init = nn.init.kaiming_uniform_\n",
77 | "        else: init = getattr(act_func.__class__, '__default_init__', None)\n",
78 | "    if init is None: init = getattr(act_func, '__default_init__', None)\n",
79 | "    if init is not None: init(m.weight)\n",
80 | "\n",
81 | "\n",
82 | "class ConvBlock(nn.Sequential):\n",
83 | "    \"\"\"2D convolution followed by\n",
84 | "    - an optional normalisation (batch norm or instance norm)\n",
85 | "    - an optional activation (ReLU, LeakyReLU, or tanh)\n",
86 | "    \"\"\"\n",
87 | "\n",
88 | "    def __init__(\n",
89 | "        self,\n",
90 | " in_channels,\n", 91 | " out_channels=None,\n", 92 | " kernel_size=3,\n", 93 | " stride=1,\n", 94 | " norm='bn',\n", 95 | " activation='relu',\n", 96 | " bias=False,\n", 97 | " transpose=False,\n", 98 | " init='auto'\n", 99 | " ):\n", 100 | " \n", 101 | " out_channels = out_channels or in_channels\n", 102 | " padding = (kernel_size-1)//2\n", 103 | " conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)\n", 104 | " conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)\n", 105 | " if activation is not None: activation = get_activation(activation)\n", 106 | " init_linear(conv, activation, init=init)\n", 107 | " layers = [conv]\n", 108 | " if activation is not None: layers.append(activation)\n", 109 | " if norm is not None: layers.append(get_norm(norm, out_channels))\n", 110 | " super().__init__(*layers)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "cd4f1349-6d9d-4a3c-8be3-92f2a9d5c173", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "#export\n", 121 | "class Bottleneck(nn.Module):\n", 122 | " \"\"\"\n", 123 | " Defines a bottleneck module with a residual connection\n", 124 | " \"\"\"\n", 125 | "\n", 126 | " def __init__(\n", 127 | " self,\n", 128 | " in_channels,\n", 129 | " out_channels=None,\n", 130 | " kernel_size=3,\n", 131 | " dilation=1,\n", 132 | " groups=1,\n", 133 | " upsample=False,\n", 134 | " downsample=False,\n", 135 | " dropout=0.0,\n", 136 | " ):\n", 137 | " super().__init__()\n", 138 | " self._downsample = downsample\n", 139 | " bottleneck_channels = int(in_channels / 2)\n", 140 | " out_channels = out_channels or in_channels\n", 141 | " padding_size = ((kernel_size - 1) * dilation + 1) // 2\n", 142 | "\n", 143 | " # Define the main conv operation\n", 144 | " assert dilation == 1\n", 145 | " if upsample:\n", 146 | " assert not downsample, 'downsample and upsample not possible simultaneously.'\n", 147 | " bottleneck_conv = nn.ConvTranspose2d(\n", 148 | " bottleneck_channels,\n", 149 | " bottleneck_channels,\n", 150 | " kernel_size=kernel_size,\n", 151 | " bias=False,\n", 152 | " dilation=1,\n", 153 | " stride=2,\n", 154 | " output_padding=padding_size,\n", 155 | " padding=padding_size,\n", 156 | " groups=groups,\n", 157 | " )\n", 158 | " elif downsample:\n", 159 | " bottleneck_conv = nn.Conv2d(\n", 160 | " bottleneck_channels,\n", 161 | " bottleneck_channels,\n", 162 | " kernel_size=kernel_size,\n", 163 | " bias=False,\n", 164 | " dilation=dilation,\n", 165 | " stride=2,\n", 166 | " padding=padding_size,\n", 167 | " groups=groups,\n", 168 | " )\n", 169 | " else:\n", 170 | " bottleneck_conv = nn.Conv2d(\n", 171 | " bottleneck_channels,\n", 172 | " bottleneck_channels,\n", 173 | " kernel_size=kernel_size,\n", 174 | " bias=False,\n", 175 | " dilation=dilation,\n", 176 | " padding=padding_size,\n", 177 | " groups=groups,\n", 178 | " )\n", 179 | "\n", 180 | " self.layers = nn.Sequential(\n", 181 | " OrderedDict(\n", 182 | " [\n", 183 | " # First projection with 1x1 kernel\n", 184 | " ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),\n", 185 | " ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),\n", 186 | " nn.ReLU(inplace=True))),\n", 187 | " # Second conv block\n", 188 | " ('conv', bottleneck_conv),\n", 189 | " ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),\n", 190 | " # Final projection with 1x1 kernel\n", 191 | " ('conv_up_project', 
nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),\n", 192 | " ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),\n", 193 | " nn.ReLU(inplace=True))),\n", 194 | " # Regulariser\n", 195 | " ('dropout', nn.Dropout2d(p=dropout)),\n", 196 | " ]\n", 197 | " )\n", 198 | " )\n", 199 | "\n", 200 | " if out_channels == in_channels and not downsample and not upsample:\n", 201 | " self.projection = None\n", 202 | " else:\n", 203 | " projection = OrderedDict()\n", 204 | " if upsample:\n", 205 | " projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})\n", 206 | " elif downsample:\n", 207 | " projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})\n", 208 | " projection.update(\n", 209 | " {\n", 210 | " 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),\n", 211 | " 'bn_skip_proj': nn.BatchNorm2d(out_channels),\n", 212 | " }\n", 213 | " )\n", 214 | " self.projection = nn.Sequential(projection)\n", 215 | "\n", 216 | " # pylint: disable=arguments-differ\n", 217 | " def forward(self, *args):\n", 218 | " (x,) = args\n", 219 | " x_residual = self.layers(x)\n", 220 | " if self.projection is not None:\n", 221 | " if self._downsample:\n", 222 | " # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer\n", 223 | " x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)\n", 224 | " return x_residual + self.projection(x)\n", 225 | " return x_residual + x" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "id": "1bd71c90-9228-4a9f-9821-20328c57ba4d", 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "#export\n", 236 | "class Interpolate(nn.Module):\n", 237 | " def __init__(self, scale_factor: int = 2):\n", 238 | " super().__init__()\n", 239 | " self._interpolate = nn.functional.interpolate\n", 240 | " self._scale_factor = scale_factor\n", 241 | "\n", 242 | " # pylint: disable=arguments-differ\n", 243 | " def forward(self, x):\n", 244 | " return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)\n", 245 | "\n", 246 | "\n", 247 | "class UpsamplingConcat(nn.Module):\n", 248 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n", 249 | " super().__init__()\n", 250 | "\n", 251 | " self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)\n", 252 | "\n", 253 | " self.conv = nn.Sequential(\n", 254 | " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n", 255 | " nn.BatchNorm2d(out_channels),\n", 256 | " nn.ReLU(inplace=True),\n", 257 | " nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),\n", 258 | " nn.BatchNorm2d(out_channels),\n", 259 | " nn.ReLU(inplace=True),\n", 260 | " )\n", 261 | "\n", 262 | " def forward(self, x_to_upsample, x):\n", 263 | " x_to_upsample = self.upsample(x_to_upsample)\n", 264 | " x_to_upsample = torch.cat([x, x_to_upsample], dim=1)\n", 265 | " return self.conv(x_to_upsample)\n", 266 | "\n", 267 | "\n", 268 | "class UpsamplingAdd(nn.Module):\n", 269 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n", 270 | " super().__init__()\n", 271 | " self.upsample_layer = nn.Sequential(\n", 272 | " nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),\n", 273 | " nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),\n", 274 | " nn.BatchNorm2d(out_channels),\n", 275 | " )\n", 276 | "\n", 277 | " def 
forward(self, x, x_skip):\n", 278 | " x = self.upsample_layer(x)\n", 279 | " return x + x_skip" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "b46ad863-f4b8-400b-a62e-b1d593ff28ca", 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "#export\n", 290 | "class ResBlock(nn.Module):\n", 291 | " \" A simple resnet Block\"\n", 292 | " def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'):\n", 293 | " super().__init__()\n", 294 | " self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation),\n", 295 | " ConvBlock(out_channels, out_channels, norm=norm, activation=activation)\n", 296 | " )\n", 297 | " id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)]\n", 298 | " self.activation = get_activation(activation)\n", 299 | " if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True))\n", 300 | " self.id_path = nn.Sequential(*id_path)\n", 301 | " \n", 302 | " def forward(self, x):\n", 303 | " return self.activation(self.convs(x) + self.id_path(x))" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "id": "87cd945e-757e-4d81-abee-8ba5fce5e02b", 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "res_block = ResBlock(64, 128, stride=2)" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "id": "cc730d95-89a2-454b-8949-fa45879ccf41", 320 | "metadata": {}, 321 | "outputs": [ 322 | { 323 | "data": { 324 | "text/plain": [ 325 | "torch.Size([1, 128, 10, 10])" 326 | ] 327 | }, 328 | "execution_count": null, 329 | "metadata": {}, 330 | "output_type": "execute_result" 331 | } 332 | ], 333 | "source": [ 334 | "x = torch.rand(1,64, 20,20)\n", 335 | "res_block(x).shape" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "id": "4568af38-4494-4fbe-aacb-6d3c5fa79e51", 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "#export\n", 346 | "class SpatialGRU(nn.Module):\n", 347 | " \"\"\"A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a\n", 348 | " convolutional gated recurrent unit over the data\"\"\"\n", 349 | "\n", 350 | " def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):\n", 351 | " super().__init__()\n", 352 | " self.input_size = input_size\n", 353 | " self.hidden_size = hidden_size\n", 354 | " self.gru_bias_init = gru_bias_init\n", 355 | "\n", 356 | " self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n", 357 | " self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n", 358 | "\n", 359 | " self.conv_state_tilde = ConvBlock(\n", 360 | " input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation\n", 361 | " )\n", 362 | "\n", 363 | " def forward(self, x, state=None, flow=None, mode='bilinear'):\n", 364 | " # pylint: disable=unused-argument, arguments-differ\n", 365 | " # Check size\n", 366 | " assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'\n", 367 | " b, timesteps, c, h, w = x.size()\n", 368 | " assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'\n", 369 | "\n", 370 | " # recurrent layers\n", 371 | " rnn_output = []\n", 372 | " rnn_state = torch.zeros(b, 
self.hidden_size, h, w, device=x.device) if state is None else state\n", 373 | " for t in range(timesteps):\n", 374 | " x_t = x[:, t]\n", 375 | " if flow is not None:\n", 376 | " rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)\n", 377 | "\n", 378 | " # propagate rnn state\n", 379 | " rnn_state = self.gru_cell(x_t, rnn_state)\n", 380 | " rnn_output.append(rnn_state)\n", 381 | "\n", 382 | " # reshape rnn output to batch tensor\n", 383 | " return torch.stack(rnn_output, dim=1)\n", 384 | "\n", 385 | " def gru_cell(self, x, state):\n", 386 | " # Compute gates\n", 387 | " x_and_state = torch.cat([x, state], dim=1)\n", 388 | " update_gate = self.conv_update(x_and_state)\n", 389 | " reset_gate = self.conv_reset(x_and_state)\n", 390 | " # Add bias to initialise gate as close to identity function\n", 391 | " update_gate = torch.sigmoid(update_gate + self.gru_bias_init)\n", 392 | " reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)\n", 393 | "\n", 394 | " # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)\n", 395 | " state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))\n", 396 | "\n", 397 | " output = (1.0 - update_gate) * state + update_gate * state_tilde\n", 398 | " return output" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "id": "25fdb367-7794-48bc-94aa-1ba421a7c8b0", 405 | "metadata": {}, 406 | "outputs": [], 407 | "source": [ 408 | "sgru = SpatialGRU(input_size=32, hidden_size=64)" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "id": "a55cd120-755b-4783-b356-03044cb5abe3", 414 | "metadata": {}, 415 | "source": [ 416 | "without hidden state" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "id": "8207186f-761b-4744-8e37-48f47101525d", 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [ 426 | "x = torch.rand(1,3,32,8, 8)\n", 427 | "test_eq(sgru(x).shape, (1,3,64,8,8))" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "id": "bdcaa79f-6d57-4b74-b956-f9dcc3df0589", 433 | "metadata": {}, 434 | "source": [ 435 | "with hidden\n", 436 | "> hidden.shape = `(bs, hidden_size, h, w)`" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "id": "53e47adc-2ce8-4cd9-86b4-433cbd47dbd1", 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "data": { 447 | "text/plain": [ 448 | "torch.Size([1, 3, 64, 8, 8])" 449 | ] 450 | }, 451 | "execution_count": null, 452 | "metadata": {}, 453 | "output_type": "execute_result" 454 | } 455 | ], 456 | "source": [ 457 | "x = torch.rand(1,3,32,8, 8)\n", 458 | "hidden = torch.rand(1,64,8,8)\n", 459 | "# test_eq(sgru(x).shape, (1,3,64,8,8))\n", 460 | "\n", 461 | "sgru(x, hidden).shape" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": null, 467 | "id": "4f4a2fe0-535a-4d51-ab22-81bcefac5c7c", 468 | "metadata": {}, 469 | "outputs": [], 470 | "source": [ 471 | "#export\n", 472 | "class CausalConv3d(nn.Module):\n", 473 | " def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):\n", 474 | " super().__init__()\n", 475 | " assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'\n", 476 | " time_pad = (kernel_size[0] - 1) * dilation[0]\n", 477 | " height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2\n", 478 | " width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2\n", 479 | "\n", 480 | " # Pad temporally on the left\n", 481 | " self.pad = 
nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)\n", 482 | " self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)\n", 483 | " self.norm = nn.BatchNorm3d(out_channels)\n", 484 | " self.activation = nn.ReLU(inplace=True)\n", 485 | "\n", 486 | " def forward(self, *inputs):\n", 487 | " (x,) = inputs\n", 488 | " x = self.pad(x)\n", 489 | " x = self.conv(x)\n", 490 | " x = self.norm(x)\n", 491 | " x = self.activation(x)\n", 492 | " return x" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "id": "f97c1b07-dcb1-444a-8372-bdd3ff9dd95a", 499 | "metadata": {}, 500 | "outputs": [], 501 | "source": [ 502 | "cc3d = CausalConv3d(64, 128)" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": null, 508 | "id": "c00cb8c1-7513-40e2-85ef-646ddd20d585", 509 | "metadata": {}, 510 | "outputs": [ 511 | { 512 | "data": { 513 | "text/plain": [ 514 | "torch.Size([1, 128, 4, 8, 8])" 515 | ] 516 | }, 517 | "execution_count": null, 518 | "metadata": {}, 519 | "output_type": "execute_result" 520 | } 521 | ], 522 | "source": [ 523 | "x = torch.rand(1,64,4,8,8)\n", 524 | "cc3d(x).shape" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "id": "6ac267c7-a4f3-47dd-a15a-04c61c82d434", 531 | "metadata": {}, 532 | "outputs": [], 533 | "source": [ 534 | "#export\n", 535 | "def conv_1x1x1_norm_activated(in_channels, out_channels):\n", 536 | " \"\"\"1x1x1 3D convolution, normalization and activation layer.\"\"\"\n", 537 | " return nn.Sequential(\n", 538 | " OrderedDict(\n", 539 | " [\n", 540 | " ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),\n", 541 | " ('norm', nn.BatchNorm3d(out_channels)),\n", 542 | " ('activation', nn.ReLU(inplace=True)),\n", 543 | " ]\n", 544 | " )\n", 545 | " )" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": null, 551 | "id": "e318b3fd-d944-4d98-8965-88a1dd2d8d8e", 552 | "metadata": {}, 553 | "outputs": [ 554 | { 555 | "data": { 556 | "text/plain": [ 557 | "Sequential(\n", 558 | " (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 559 | " (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 560 | " (activation): ReLU(inplace=True)\n", 561 | ")" 562 | ] 563 | }, 564 | "execution_count": null, 565 | "metadata": {}, 566 | "output_type": "execute_result" 567 | } 568 | ], 569 | "source": [ 570 | "conv_1x1x1_norm_activated(2, 4)" 571 | ] 572 | }, 573 | { 574 | "cell_type": "code", 575 | "execution_count": null, 576 | "id": "5574dfd4-3891-464d-b0d6-a67466fed549", 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "#export\n", 581 | "class Bottleneck3D(nn.Module):\n", 582 | " \"\"\"\n", 583 | " Defines a bottleneck module with a residual connection\n", 584 | " \"\"\"\n", 585 | "\n", 586 | " def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):\n", 587 | " super().__init__()\n", 588 | " bottleneck_channels = in_channels // 2\n", 589 | " out_channels = out_channels or in_channels\n", 590 | "\n", 591 | " self.layers = nn.Sequential(\n", 592 | " OrderedDict(\n", 593 | " [\n", 594 | " # First projection with 1x1 kernel\n", 595 | " ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),\n", 596 | " # Second conv block\n", 597 | " (\n", 598 | " 'conv',\n", 599 | " CausalConv3d(\n", 600 | " 
bottleneck_channels,\n", 601 | " bottleneck_channels,\n", 602 | " kernel_size=kernel_size,\n", 603 | " dilation=dilation,\n", 604 | " bias=False,\n", 605 | " ),\n", 606 | " ),\n", 607 | " # Final projection with 1x1 kernel\n", 608 | " ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),\n", 609 | " ]\n", 610 | " )\n", 611 | " )\n", 612 | "\n", 613 | " if out_channels != in_channels:\n", 614 | " self.projection = nn.Sequential(\n", 615 | " nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),\n", 616 | " nn.BatchNorm3d(out_channels),\n", 617 | " )\n", 618 | " else:\n", 619 | " self.projection = None\n", 620 | "\n", 621 | " def forward(self, *args):\n", 622 | " (x,) = args\n", 623 | " x_residual = self.layers(x)\n", 624 | " x_features = self.projection(x) if self.projection is not None else x\n", 625 | " return x_residual + x_features" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": null, 631 | "id": "3dbcf0d0-1cc2-4914-9ba0-30b99f03a4c3", 632 | "metadata": {}, 633 | "outputs": [], 634 | "source": [ 635 | "bn3d = Bottleneck3D(8, 12)" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "id": "12917923-c26f-4c40-b2ee-2836d2ffb4f6", 642 | "metadata": {}, 643 | "outputs": [ 644 | { 645 | "data": { 646 | "text/plain": [ 647 | "torch.Size([1, 12, 4, 8, 8])" 648 | ] 649 | }, 650 | "execution_count": null, 651 | "metadata": {}, 652 | "output_type": "execute_result" 653 | } 654 | ], 655 | "source": [ 656 | "x = torch.rand(1,8,4,8,8)\n", 657 | "bn3d(x).shape" 658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "id": "1bd3b9ac-7faa-4da6-b59b-86792e649590", 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "#export\n", 668 | "class PyramidSpatioTemporalPooling(nn.Module):\n", 669 | " \"\"\" Spatio-temporal pyramid pooling.\n", 670 | " Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.\n", 671 | " Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]\n", 672 | " \"\"\"\n", 673 | "\n", 674 | " def __init__(self, in_channels, reduction_channels, pool_sizes):\n", 675 | " super().__init__()\n", 676 | " self.features = []\n", 677 | " for pool_size in pool_sizes:\n", 678 | " assert pool_size[0] == 2, (\n", 679 | " \"Time kernel should be 2 as PyTorch raises an error when\" \"padding with more than half the kernel size\"\n", 680 | " )\n", 681 | " stride = (1, *pool_size[1:])\n", 682 | " padding = (pool_size[0] - 1, 0, 0)\n", 683 | " self.features.append(\n", 684 | " nn.Sequential(\n", 685 | " OrderedDict(\n", 686 | " [\n", 687 | " # Pad the input tensor but do not take into account zero padding into the average.\n", 688 | " (\n", 689 | " 'avgpool',\n", 690 | " torch.nn.AvgPool3d(\n", 691 | " kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False\n", 692 | " ),\n", 693 | " ),\n", 694 | " ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),\n", 695 | " ]\n", 696 | " )\n", 697 | " )\n", 698 | " )\n", 699 | " self.features = nn.ModuleList(self.features)\n", 700 | "\n", 701 | " def forward(self, *inputs):\n", 702 | " (x,) = inputs\n", 703 | " b, _, t, h, w = x.shape\n", 704 | " # Do not include current tensor when concatenating\n", 705 | " out = []\n", 706 | " for f in self.features:\n", 707 | " # Remove unnecessary padded values (time dimension) on the right\n", 708 | " x_pool = f(x)[:, :, 
:-1].contiguous()\n", 709 | " c = x_pool.shape[1]\n", 710 | " x_pool = nn.functional.interpolate(\n", 711 | " x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False\n", 712 | " )\n", 713 | " x_pool = x_pool.view(b, c, t, h, w)\n", 714 | " out.append(x_pool)\n", 715 | " out = torch.cat(out, 1)\n", 716 | " return out" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": null, 722 | "id": "45902f68-7d7f-4935-ac97-00e332ef9d48", 723 | "metadata": {}, 724 | "outputs": [], 725 | "source": [ 726 | "h,w = (64,64)\n", 727 | "ptp = PyramidSpatioTemporalPooling(12, 12, [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)])" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": null, 733 | "id": "0e0db71a-dfef-4389-bc9c-483adc060008", 734 | "metadata": {}, 735 | "outputs": [ 736 | { 737 | "data": { 738 | "text/plain": [ 739 | "torch.Size([1, 36, 4, 64, 64])" 740 | ] 741 | }, 742 | "execution_count": null, 743 | "metadata": {}, 744 | "output_type": "execute_result" 745 | } 746 | ], 747 | "source": [ 748 | "x = torch.rand(1,12,4,64,64)\n", 749 | "\n", 750 | "#the output is concatenated...\n", 751 | "ptp(x).shape" 752 | ] 753 | }, 754 | { 755 | "cell_type": "code", 756 | "execution_count": null, 757 | "id": "7575bf79-6718-4a70-90ed-985b79fea483", 758 | "metadata": {}, 759 | "outputs": [], 760 | "source": [ 761 | "#export\n", 762 | "class TemporalBlock(nn.Module):\n", 763 | " \"\"\" Temporal block with the following layers:\n", 764 | " - 2x3x3, 1x3x3, spatio-temporal pyramid pooling\n", 765 | " - dropout\n", 766 | " - skip connection.\n", 767 | " \"\"\"\n", 768 | "\n", 769 | " def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):\n", 770 | " super().__init__()\n", 771 | " self.in_channels = in_channels\n", 772 | " self.half_channels = in_channels // 2\n", 773 | " self.out_channels = out_channels or self.in_channels\n", 774 | " self.kernels = [(2, 3, 3), (1, 3, 3)]\n", 775 | "\n", 776 | " # Flag for spatio-temporal pyramid pooling\n", 777 | " self.use_pyramid_pooling = use_pyramid_pooling\n", 778 | "\n", 779 | " # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1\n", 780 | " self.convolution_paths = []\n", 781 | " for kernel_size in self.kernels:\n", 782 | " self.convolution_paths.append(\n", 783 | " nn.Sequential(\n", 784 | " conv_1x1x1_norm_activated(self.in_channels, self.half_channels),\n", 785 | " CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),\n", 786 | " )\n", 787 | " )\n", 788 | " self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))\n", 789 | " self.convolution_paths = nn.ModuleList(self.convolution_paths)\n", 790 | "\n", 791 | " agg_in_channels = len(self.convolution_paths) * self.half_channels\n", 792 | "\n", 793 | " if self.use_pyramid_pooling:\n", 794 | " assert pool_sizes is not None, \"setting must contain the list of kernel_size, but is None.\"\n", 795 | " reduction_channels = self.in_channels // 3\n", 796 | " self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)\n", 797 | " agg_in_channels += len(pool_sizes) * reduction_channels\n", 798 | "\n", 799 | " # Feature aggregation\n", 800 | " self.aggregation = nn.Sequential(\n", 801 | " conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)\n", 802 | "\n", 803 | " if self.out_channels != self.in_channels:\n", 804 | " self.projection = nn.Sequential(\n", 805 | " nn.Conv3d(self.in_channels, self.out_channels, 
kernel_size=1, bias=False),\n", 806 | " nn.BatchNorm3d(self.out_channels),\n", 807 | " )\n", 808 | " else:\n", 809 | " self.projection = None\n", 810 | "\n", 811 | " def forward(self, *inputs):\n", 812 | " (x,) = inputs\n", 813 | " x_paths = []\n", 814 | " for conv in self.convolution_paths:\n", 815 | " x_paths.append(conv(x))\n", 816 | " x_residual = torch.cat(x_paths, dim=1)\n", 817 | " if self.use_pyramid_pooling:\n", 818 | " x_pool = self.pyramid_pooling(x)\n", 819 | " x_residual = torch.cat([x_residual, x_pool], dim=1)\n", 820 | " x_residual = self.aggregation(x_residual)\n", 821 | "\n", 822 | " if self.out_channels != self.in_channels:\n", 823 | " x = self.projection(x)\n", 824 | " x = x + x_residual\n", 825 | " return x" 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": null, 831 | "id": "385f0475-0c44-4b90-83bf-6869c24e105e", 832 | "metadata": {}, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "text/plain": [ 837 | "torch.Size([1, 8, 4, 64, 64])" 838 | ] 839 | }, 840 | "execution_count": null, 841 | "metadata": {}, 842 | "output_type": "execute_result" 843 | } 844 | ], 845 | "source": [ 846 | "tb = TemporalBlock(4,8)\n", 847 | "\n", 848 | "x = torch.rand(1,4,4,64,64)\n", 849 | "\n", 850 | "tb(x).shape" 851 | ] 852 | }, 853 | { 854 | "cell_type": "code", 855 | "execution_count": null, 856 | "id": "65e50f4d-40f1-48fc-b390-66d0fe7a4f16", 857 | "metadata": {}, 858 | "outputs": [ 859 | { 860 | "data": { 861 | "text/plain": [ 862 | "TemporalBlock(\n", 863 | " (convolution_paths): ModuleList(\n", 864 | " (0): Sequential(\n", 865 | " (0): Sequential(\n", 866 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 867 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 868 | " (activation): ReLU(inplace=True)\n", 869 | " )\n", 870 | " (1): CausalConv3d(\n", 871 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 1, 0), value=0)\n", 872 | " (conv): Conv3d(2, 2, kernel_size=(2, 3, 3), stride=(1, 1, 1), bias=False)\n", 873 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 874 | " (activation): ReLU(inplace=True)\n", 875 | " )\n", 876 | " )\n", 877 | " (1): Sequential(\n", 878 | " (0): Sequential(\n", 879 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 880 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 881 | " (activation): ReLU(inplace=True)\n", 882 | " )\n", 883 | " (1): CausalConv3d(\n", 884 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 0, 0), value=0)\n", 885 | " (conv): Conv3d(2, 2, kernel_size=(1, 3, 3), stride=(1, 1, 1), bias=False)\n", 886 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 887 | " (activation): ReLU(inplace=True)\n", 888 | " )\n", 889 | " )\n", 890 | " (2): Sequential(\n", 891 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 892 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 893 | " (activation): ReLU(inplace=True)\n", 894 | " )\n", 895 | " )\n", 896 | " (aggregation): Sequential(\n", 897 | " (0): Sequential(\n", 898 | " (conv): Conv3d(6, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 899 | " (norm): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 900 | " (activation): ReLU(inplace=True)\n", 901 | " )\n", 902 | " )\n", 903 | " (projection): Sequential(\n", 904 | " (0): Conv3d(4, 8, 
kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n", 905 | " (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 906 | " )\n", 907 | ")" 908 | ] 909 | }, 910 | "execution_count": null, 911 | "metadata": {}, 912 | "output_type": "execute_result" 913 | } 914 | ], 915 | "source": [ 916 | "tb" 917 | ] 918 | }, 919 | { 920 | "cell_type": "code", 921 | "execution_count": null, 922 | "id": "5f671904-4256-48d5-8787-72556d225456", 923 | "metadata": {}, 924 | "outputs": [], 925 | "source": [ 926 | "#export\n", 927 | "class FuturePrediction(torch.nn.Module):\n", 928 | " def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):\n", 929 | " super().__init__()\n", 930 | " self.n_gru_blocks = n_gru_blocks\n", 931 | "\n", 932 | " # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample\n", 933 | " # from the probabilistic model. The architecture of the model is:\n", 934 | " # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks\n", 935 | " self.spatial_grus = []\n", 936 | " self.res_blocks = []\n", 937 | "\n", 938 | " for i in range(self.n_gru_blocks):\n", 939 | " gru_in_channels = latent_dim if i == 0 else in_channels\n", 940 | " self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))\n", 941 | " self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)\n", 942 | " for _ in range(n_res_layers)]))\n", 943 | "\n", 944 | " self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)\n", 945 | " self.res_blocks = torch.nn.ModuleList(self.res_blocks)\n", 946 | "\n", 947 | " def forward(self, x, hidden_state):\n", 948 | " # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)\n", 949 | " for i in range(self.n_gru_blocks):\n", 950 | " x = self.spatial_grus[i](x, hidden_state, flow=None)\n", 951 | " b, n_future, c, h, w = x.shape\n", 952 | "\n", 953 | " x = self.res_blocks[i](x.view(b * n_future, c, h, w))\n", 954 | " x = x.view(b, n_future, c, h, w)\n", 955 | "\n", 956 | " return x" 957 | ] 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": null, 962 | "id": "7ce0e2ce-a862-4631-b5f7-a6b2c9928415", 963 | "metadata": {}, 964 | "outputs": [], 965 | "source": [ 966 | "fp = FuturePrediction(32, 32, 4, 4)\n", 967 | "\n", 968 | "x = torch.rand(1,4, 32, 64, 64)\n", 969 | "hidden = torch.rand(1,32,64,64)\n", 970 | "test_eq(fp(x, hidden).shape, x.shape)" 971 | ] 972 | }, 973 | { 974 | "cell_type": "markdown", 975 | "id": "542aa109-21e9-48b8-a87a-e4bbd5f59b8d", 976 | "metadata": {}, 977 | "source": [ 978 | "## Export" 979 | ] 980 | }, 981 | { 982 | "cell_type": "code", 983 | "execution_count": null, 984 | "id": "d4c520b5-d666-4ebb-a9c0-02992b822b8d", 985 | "metadata": {}, 986 | "outputs": [ 987 | { 988 | "name": "stdout", 989 | "output_type": "stream", 990 | "text": [ 991 | "Converted 00_model.ipynb.\n", 992 | "Converted 01_layers.ipynb.\n", 993 | "Converted index.ipynb.\n" 994 | ] 995 | } 996 | ], 997 | "source": [ 998 | "# hide\n", 999 | "from nbdev.export import *\n", 1000 | "notebook2script()" 1001 | ] 1002 | } 1003 | ], 1004 | "metadata": { 1005 | "kernelspec": { 1006 | "display_name": "Python 3", 1007 | "language": "python", 1008 | "name": "python3" 1009 | } 1010 | }, 1011 | "nbformat": 4, 1012 | "nbformat_minor": 5 1013 | } 1014 | -------------------------------------------------------------------------------- /nbs/images/eclipse_diagram.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tcapelle/eclipse_pytorch/13b1c41b076a535eb01abf37f777f9fa8309ee48/nbs/images/eclipse_diagram.png -------------------------------------------------------------------------------- /nbs/index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Eclipse\n", 8 | "\n", 9 | "> Implementing [Paletta et al](https://arxiv.org/pdf/2104.12419v1.pdf) in Pytorch" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Most of the codebase comes from [Fiery](https://github.com/wayveai/fiery)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "![Image](images/eclipse_diagram.png)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Install" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "```bash\n", 38 | "pip install eclipse_pytorch\n", 39 | "```" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## How to use" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "import torch\n", 56 | "\n", 57 | "from eclipse_pytorch.model import Eclipse" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "ename": "NameError", 67 | "evalue": "name 'TemporalModel' is not defined", 68 | "output_type": "error", 69 | "traceback": [ 70 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 71 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 72 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0meclipse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEclipse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhorizon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 73 | "\u001b[0;32m~/Documents/eclipse/eclipse_pytorch/model.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, n_in, n_out, horizon, img_size, debug)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mstore_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspatial_downsampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSpatialDownsampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporal_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTemporalModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_size\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg_size\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mstart_out_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuture_prediction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mFuturePrediction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_gru_blocks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_res_layers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupsampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mUpsampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 74 | "\u001b[0;31mNameError\u001b[0m: name 'TemporalModel' is not defined" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "eclipse = Eclipse(horizon=5)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "let's simulate some input images:" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "images = [torch.rand(2, 3, 128, 128) for _ in range(4)]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "preds = eclipse(images)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "metadata": {}, 110 | "source": [ 111 | "you get a dict with forecasted masks and irradiances:" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Citation" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "```latex\n", 135 | "@article{paletta2021eclipse,\n", 136 | "  title = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},\n", 137 | "  author = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},\n", 138 | "  year = {2021},\n", 139 | "  eprinttype = {arXiv},\n", 140 | "  eprint = {2104.12419}\n", 141 | "}\n", 142 | "```" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "## Contribute\n", 150 | "\n", 151 | "This repo is made with [nbdev](https://github.com/fastai/nbdev); please read its documentation before contributing." 152 | ] 153 | } 154 | ], 155 | "metadata": { 156 | "kernelspec": { 157 | "display_name": "Python 3", 158 | "language": "python", 159 | "name": "python3" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 4 164 | } 165 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | host = github 3 | lib_name = eclipse_pytorch 4 | user = tcapelle 5 | description = A PyTorch implementation of Eclipse 6 | keywords = pytorch nowcasting solar energy 7 | author = Thomas Capelle 8 | author_email = tcapelle@pm.me 9 | copyright = bla blabla 10 | 
branch = master 11 | version = 0.0.8 12 | min_python = 3.8 13 | audience = Developers 14 | language = English 15 | custom_sidebar = False 16 | license = apache2 17 | status = 2 18 | requirements = torch>1.6 fastcore 19 | nbs_path = nbs 20 | doc_path = docs 21 | recursive = False 22 | doc_host = https://tcapelle.github.io 23 | doc_baseurl = /eclipse_pytorch/ 24 | git_url = https://github.com/tcapelle/eclipse_pytorch/tree/master/ 25 | lib_path = eclipse_pytorch 26 | title = eclipse_pytorch 27 | 28 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools,re,sys 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | if len(sys.argv)>1 and sys.argv[1]=='version': 17 | print(setup_cfg['version']) 18 | exit() 19 | 20 | licenses = { 21 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 22 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 23 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 24 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 25 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 26 | } 27 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 28 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 29 | py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8'.split() 30 | 31 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 32 | min_python = cfg['min_python'] 33 | 34 | requirements = ['pip', 'packaging'] 35 | if cfg.get('requirements'): requirements += cfg.get('requirements','').split() 36 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split() 37 | dev_requirements = (cfg.get('dev_requirements') or '').split() 38 | 39 | long_description = open('README.md').read() 40 | # ![png](docs/images/output_13_0.png) 41 | for ext in ['png', 'svg']: 42 | long_description = re.sub(r'!\['+ext+'\]\((.*)\)', '!['+ext+']('+'https://raw.githubusercontent.com/{}/{}'.format(cfg['user'],cfg['lib_name'])+'/'+cfg['branch']+'/\\1)', long_description) 43 | long_description = re.sub(r'src=\"(.*)\.'+ext+'\"', 'src=\"https://raw.githubusercontent.com/{}/{}'.format(cfg['user'],cfg['lib_name'])+'/'+cfg['branch']+'/\\1.'+ext+'\"', long_description) 44 | 45 | setuptools.setup( 46 | name = cfg['lib_name'], 47 | license = lic[0], 48 | classifiers = [ 49 | 'Development Status :: ' + statuses[int(cfg['status'])], 50 | 'Intended Audience :: ' + cfg['audience'].title(), 51 | 'Natural Language :: ' + cfg['language'].title(), 52 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 53 | url = cfg['git_url'], 54 | packages = setuptools.find_packages(), 
55 | include_package_data = True, 56 | install_requires = requirements, 57 | extras_require={ 'dev': dev_requirements }, 58 | python_requires = '>=' + cfg['min_python'], 59 | long_description = long_description, 60 | long_description_content_type = 'text/markdown', 61 | zip_safe = False, 62 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() }, 63 | **setup_cfg) 64 | 65 | --------------------------------------------------------------------------------