├── .devcontainer.json
├── .github
│   └── workflows
│       └── main.yml
├── .gitignore
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── conda
│   └── eclipse_pytorch
│       └── meta.yaml
├── docker-compose.yml
├── docs
│   ├── .gitignore
│   ├── Gemfile
│   ├── _config.yml
│   ├── _data
│   │   ├── sidebars
│   │   │   └── home_sidebar.yml
│   │   └── topnav.yml
│   ├── feed.xml
│   ├── images
│   │   └── eclipse_diagram.png
│   ├── index.html
│   ├── layers.html
│   ├── model.html
│   ├── sidebar.json
│   └── sitemap.xml
├── eclipse_pytorch
│   ├── __init__.py
│   ├── _nbdev.py
│   ├── imports.py
│   ├── layers.py
│   └── model.py
├── nbs
│   ├── 00_model.ipynb
│   ├── 01_layers.ipynb
│   ├── images
│   │   └── eclipse_diagram.png
│   └── index.ipynb
├── settings.ini
└── setup.py
/.devcontainer.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "nbdev_template-codespaces",
 3 |     "dockerComposeFile": "docker-compose.yml",
 4 |     "service": "watcher",
 5 |     "settings": {"terminal.integrated.shell.linux": "/bin/bash"},
 6 |     "mounts": [ "source=/var/run/docker.sock,target=/var/run/docker.sock,type=bind" ],
 7 |     "forwardPorts": [4000, 8080],
 8 |     "appPort": [4000, 8080],
 9 |     "extensions": ["ms-python.python",
10 |                    "ms-azuretools.vscode-docker"],
11 |     "runServices": ["notebook", "jekyll", "watcher"],
12 |     "postStartCommand": "pip install -e ."
13 | }
14 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
 1 | name: CI
 2 | on: [push, pull_request]
 3 | jobs:
 4 |   build:
 5 |     runs-on: ubuntu-latest
 6 |     steps:
 7 |     - uses: actions/checkout@v1
 8 |     - uses: actions/setup-python@v1
 9 |       with:
10 |         python-version: '3.8'
11 |         architecture: 'x64'
12 |     - name: Install the library
13 |       run: |
14 |         pip install nbdev jupyter
15 |         pip install -e .
16 |     - name: Read all notebooks
17 |       run: |
18 |         nbdev_read_nbs
19 |     - name: Check if all notebooks are cleaned
20 |       run: |
21 |         echo "Check we are starting with clean git checkout"
22 |         if [ -n "$(git status -uno -s)" ]; then echo "git status is not clean"; false; fi
23 |         echo "Trying to strip out notebooks"
24 |         nbdev_clean_nbs
25 |         echo "Check that strip out was unnecessary"
26 |         git status -s # display the status to see which nbs need cleaning up
27 |         if [ -n "$(git status -uno -s)" ]; then echo -e "!!! Detected unstripped notebooks\n!!! Remember to run nbdev_install_git_hooks"; false; fi
28 |     - name: Check if there is no diff library/notebooks
29 |       run: |
30 |         if [ -n "$(nbdev_diff_nbs)" ]; then echo -e "!!! Detected difference between the notebooks and the library"; false; fi
31 |     - name: Run tests
32 |       run: |
33 |         nbdev_test_nbs
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .jekyll-cache/
2 | Gemfile.lock
3 | *.bak
4 | .gitattributes
5 | .last_checked
 6 |
 7 |
8 | *.log
9 | *~
10 | ~*
11 | _tmp*
12 | tmp*
13 | tags
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | *.egg-info/
39 | .installed.cfg
40 | *.egg
41 |
42 | # PyInstaller
43 | # Usually these files are written by a python script from a template
44 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
45 | *.manifest
46 | *.spec
47 |
48 | # Installer logs
49 | pip-log.txt
50 | pip-delete-this-directory.txt
51 |
52 | # Unit test / coverage reports
53 | htmlcov/
54 | .tox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | .hypothesis/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # celery beat schedule file
91 | celerybeat-schedule
92 |
93 | # SageMath parsed files
94 | *.sage.py
95 |
96 | # dotenv
97 | .env
98 |
99 | # virtualenv
100 | .venv
101 | venv/
102 | ENV/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
117 | .vscode
118 | *.swp
119 |
120 | # osx generated files
121 | .DS_Store
122 | .DS_Store?
123 | .Trashes
124 | ehthumbs.db
125 | Thumbs.db
126 | .idea
127 |
128 | # pytest
129 | .pytest_cache
130 |
131 | # tools/trust-doc-nbs
132 | docs_src/.last_checked
133 |
134 | # symlinks to fastai
135 | docs_src/fastai
136 | tools/fastai
137 |
138 | # link checker
139 | checklink/cookies.txt
140 |
141 | # .gitconfig is now autogenerated
142 | .gitconfig
143 |
144 | token
145 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
 1 | # Release notes
 2 |
 3 | ## 0.0.8
 4 |
 5 | ## 0.0.7
 6 |
 7 | ## 0.0.6
 8 |
 9 | ## 0.0.5
10 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | ## How to get started
4 |
5 | Before anything else, please install the git hooks that run automatic scripts during each commit and merge to strip the notebooks of superfluous metadata (and avoid merge conflicts). After cloning the repository, run the following command inside it:
6 | ```
7 | nbdev_install_git_hooks
8 | ```
9 |
10 | ## Did you find a bug?
11 |
12 | * Ensure the bug was not already reported by searching on GitHub under Issues.
13 | * If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
14 | * Be sure to add the complete error messages.
15 |
16 | #### Did you write a patch that fixes a bug?
17 |
18 | * Open a new GitHub pull request with the patch.
19 | * Ensure that your PR includes a test that fails without your patch, and passes with it.
20 | * Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
21 |
22 | ## PR submission guidelines
23 |
24 | * Keep each PR focused. While it may be more convenient, do not combine several unrelated fixes in one PR. Create as many branches as needed to keep each PR focused.
25 | * Do not mix style changes/fixes with "functional" changes. Such PRs are very difficult to review and will most likely be rejected.
26 | * Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
27 | * Do not turn an already submitted PR into your development playground. If, after you have submitted a PR, you discover that more work is needed, close the PR, do the required work, and then submit a new PR. Otherwise each of your commits requires attention from the maintainers of the project.
28 | * If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exceptional case where you realize it will take many commits to complete the requests, it's probably best to close the PR, do the work, and then submit it again. Use common sense to choose one way over the other.
29 |
30 | ## Do you want to contribute to the documentation?
31 |
32 | * Docs are automatically created from the notebooks in the nbs folder.
33 |
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include settings.ini
2 | include LICENSE
3 | include CONTRIBUTING.md
4 | include README.md
5 | recursive-exclude * __pycache__
6 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | .ONESHELL:
 2 | SHELL := /bin/bash
 3 | SRC = $(wildcard nbs/*.ipynb)
 4 |
 5 | all: eclipse_pytorch docs
 6 |
 7 | eclipse_pytorch: $(SRC)
 8 | 	nbdev_build_lib
 9 | 	touch eclipse_pytorch
10 |
11 | sync:
12 | 	nbdev_update_lib
13 |
14 | docs_serve: docs
15 | 	cd docs && bundle exec jekyll serve
16 |
17 | docs: $(SRC)
18 | 	nbdev_build_docs
19 | 	touch docs
20 |
21 | test:
22 | 	nbdev_test_nbs
23 |
24 | release: pypi conda_release
25 | 	nbdev_bump_version
26 |
27 | conda_release:
28 | 	fastrelease_conda_package
29 |
30 | pypi: dist
31 | 	twine upload --repository pypi dist/*
32 |
33 | dist: clean
34 | 	python setup.py sdist bdist_wheel
35 |
36 | clean:
37 | 	rm -rf dist
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Eclipse
 2 | > Implementing Paletta et al. in PyTorch
3 |
4 |
5 | Most of the codebase comes from [Fiery](https://github.com/wayveai/fiery)
6 |
7 | 
8 |
9 | ## Install
10 |
11 | ```bash
12 | pip install eclipse_pytorch
13 | ```
14 |
15 | ## How to use
16 |
17 | ```python
18 | import torch
19 |
20 | from eclipse_pytorch.model import Eclipse
21 | ```
22 |
23 | ```python
24 | eclipse = Eclipse(horizon=5)
25 | ```
26 |
27 | let's simulate some input images:
28 |
29 | ```python
30 | images = [torch.rand(2, 3, 128, 128) for _ in range(4)]
31 | ```
32 |
33 | ```python
34 | preds = eclipse(images)
35 | ```
36 |
37 | you get a dict with forecasted masks and irradiances:
38 |
39 | ```python
40 | len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape
41 | ```
42 |
43 |
44 |
45 |
46 | (6, torch.Size([2, 4, 128, 128]), torch.Size([2, 6]))
47 |
48 |
49 |
50 | ## Citation
51 |
52 | ```latex
53 | @article{paletta2021eclipse,
54 | title = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},
55 | author = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},
56 | year = {2021},
57 | eprinttype = {arXiv},
58 | eprint = {2104.12419}
59 | }
60 | ```
61 |
62 | ## Contribute
63 |
64 | This repo is made with [nbdev](https://github.com/fastai/nbdev); please read its documentation to learn how to contribute.
65 |
--------------------------------------------------------------------------------
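
A minimal (hypothetical) training-step sketch building on the README example above. The target tensors, the cross-entropy/L1 pairing, and the 0.1 loss weighting are illustrative assumptions, not something prescribed by the library or the paper:

```python
import torch
import torch.nn.functional as F

from eclipse_pytorch.model import Eclipse

eclipse = Eclipse(horizon=5)

# 4 past frames, batch of 2, 3-channel 128x128 images (as in the README)
images = [torch.rand(2, 3, 128, 128) for _ in range(4)]
preds = eclipse(images)

# Hypothetical targets: one class-index mask per forecasted step
# (len(preds['masks']) == horizon + 1 == 6) and one irradiance value per step
target_masks = [torch.randint(0, 4, (2, 128, 128)) for _ in range(6)]
target_irradiance = torch.rand(2, 6)

seg_loss = sum(F.cross_entropy(p, t) for p, t in zip(preds['masks'], target_masks))
irr_loss = F.l1_loss(preds['irradiances'], target_irradiance)
loss = seg_loss + 0.1 * irr_loss  # the 0.1 weighting is an arbitrary choice
loss.backward()
```
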
/conda/eclipse_pytorch/meta.yaml:
--------------------------------------------------------------------------------
 1 | package:
 2 |   name: eclipse_pytorch
 3 |   version: 0.0.8
 4 | source:
 5 |   sha256: 3dc73af281153764e70df6e2153f4118d49d012267d9d66317f57e12db566876
 6 |   url: https://files.pythonhosted.org/packages/71/ae/edf126362a552f73948c94ed09d1d3e653c3500e6ef9447cde863f7fa525/eclipse_pytorch-0.0.8.tar.gz
 7 | about:
 8 |   dev_url: https://tcapelle.github.io
 9 |   doc_url: https://tcapelle.github.io
10 |   home: https://tcapelle.github.io
11 |   license: Apache Software
12 |   license_family: APACHE
13 |   summary: A pytorch implementation of Eclipse
14 | build:
15 |   noarch: python
16 |   number: '0'
17 |   script: '{{ PYTHON }} -m pip install . -vv'
18 | extra:
19 |   recipe-maintainers:
20 |     - tcapelle
21 | requirements:
22 |   host:
23 |     - pip
24 |     - python
25 |     - packaging
26 |     - torch>1.6
27 |     - fastcore
28 |   run:
29 |     - pip
30 |     - python
31 |     - packaging
32 |     - torch>1.6
33 |     - fastcore
34 | test:
35 |   imports:
36 |     - eclipse_pytorch
37 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
 1 | version: "3"
 2 | services:
 3 |   fastai: &fastai
 4 |     restart: unless-stopped
 5 |     working_dir: /data
 6 |     image: fastai/codespaces
 7 |     logging:
 8 |       driver: json-file
 9 |       options:
10 |         max-size: 50m
11 |     stdin_open: true
12 |     tty: true
13 |     volumes:
14 |       - .:/data/
15 |
16 |   notebook:
17 |     <<: *fastai
18 |     command: bash -c "pip install -e . && jupyter notebook --allow-root --no-browser --ip=0.0.0.0 --port=8080 --NotebookApp.token='' --NotebookApp.password=''"
19 |     ports:
20 |       - "8080:8080"
21 |
22 |   watcher:
23 |     <<: *fastai
24 |     command: watchmedo shell-command --command nbdev_build_docs --pattern *.ipynb --recursive --drop
25 |     network_mode: host # for GitHub Codespaces https://github.com/features/codespaces/
26 |
27 |   jekyll:
28 |     <<: *fastai
29 |     ports:
30 |       - "4000:4000"
31 |     command: >
32 |       bash -c "pip install .
33 |       && nbdev_build_docs && cd docs
34 |       && bundle i
35 |       && chmod -R u+rwx . && bundle exec jekyll serve --host 0.0.0.0"
36 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _site/
2 |
--------------------------------------------------------------------------------
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 |
3 | gem "jekyll", ">= 3.7"
4 | gem "jekyll-remote-theme"
5 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
 1 | repository: tcapelle/eclipse_pytorch
 2 | output: web
 3 | topnav_title: eclipse_pytorch
 4 | site_title: eclipse_pytorch
 5 | company_name: bla blabla
 6 | description: A pytorch implementation of Eclipse
 7 | # Set to false to disable KaTeX math
 8 | use_math: true
 9 | # Add Google analytics id if you have one and want to use it here
10 | google_analytics:
11 | # See http://nbdev.fast.ai/search for help with adding Search
12 | google_search:
13 |
14 | host: 127.0.0.1
15 | # the preview server used. Leave as is.
16 | port: 4000
17 | # the port where the preview is rendered.
18 |
19 | exclude:
20 |   - .idea/
21 |   - .gitignore
22 |   - vendor
23 |
24 | highlighter: rouge
25 | markdown: kramdown
26 | kramdown:
27 |   input: GFM
28 |   auto_ids: true
29 |   hard_wrap: false
30 |   syntax_highlighter: rouge
31 |
32 | collections:
33 |   tooltips:
34 |     output: false
35 |
36 | defaults:
37 |   -
38 |     scope:
39 |       path: ""
40 |       type: "pages"
41 |     values:
42 |       layout: "page"
43 |       comments: true
44 |       search: true
45 |       sidebar: home_sidebar
46 |       topnav: topnav
47 |   -
48 |     scope:
49 |       path: ""
50 |       type: "tooltips"
51 |     values:
52 |       layout: "page"
53 |       comments: true
54 |       search: true
55 |       tooltip: true
56 |
57 | sidebars:
58 | - home_sidebar
59 |
60 | plugins:
61 |   - jekyll-remote-theme
62 |
63 | remote_theme: fastai/nbdev-jekyll-theme
64 | baseurl: /eclipse_pytorch/
--------------------------------------------------------------------------------
/docs/_data/sidebars/home_sidebar.yml:
--------------------------------------------------------------------------------
 1 |
 2 | #################################################
 3 | ### THIS FILE WAS AUTOGENERATED! DO NOT EDIT! ###
 4 | #################################################
 5 | # Instead edit ../../sidebar.json
 6 | entries:
 7 | - folders:
 8 |   - folderitems:
 9 |     - output: web,pdf
10 |       title: Overview
11 |       url: /
12 |     - output: web,pdf
13 |       title: The model
14 |       url: model.html
15 |     - output: web,pdf
16 |       title: Layers
17 |       url: layers.html
18 |     output: web
19 |     title: eclipse_pytorch
20 |   output: web
21 |   title: Sidebar
22 |
--------------------------------------------------------------------------------
/docs/_data/topnav.yml:
--------------------------------------------------------------------------------
 1 | topnav:
 2 | - title: Topnav
 3 |   items:
 4 |   - title: github
 5 |     external_url: https://github.com/tcapelle/eclipse_pytorch/tree/master/
 6 |
 7 | #Topnav dropdowns
 8 | topnav_dropdowns:
 9 | - title: Topnav dropdowns
10 |   folders:
--------------------------------------------------------------------------------
/docs/feed.xml:
--------------------------------------------------------------------------------
1 | ---
2 | search: exclude
3 | layout: none
4 | ---
 5 |
 6 | <?xml version="1.0" encoding="UTF-8"?>
 7 | <rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom">
 8 |   <channel>
 9 |     <title>{{ site.title | xml_escape }}</title>
10 |     <description>{{ site.description | xml_escape }}</description>
11 |     <link>{{ site.url }}/</link>
12 |     <atom:link href="{{ site.url }}/feed.xml" rel="self" type="application/rss+xml"/>
13 |     <pubDate>{{ site.time | date_to_rfc822 }}</pubDate>
14 |     <lastBuildDate>{{ site.time | date_to_rfc822 }}</lastBuildDate>
15 |     <generator>Jekyll v{{ jekyll.version }}</generator>
16 |     {% for post in site.posts limit:10 %}
17 |     <item>
18 |       <title>{{ post.title | xml_escape }}</title>
19 |       <description>{{ post.content | xml_escape }}</description>
20 |       <pubDate>{{ post.date | date_to_rfc822 }}</pubDate>
21 |       <link>{{ post.url | prepend: site.url }}</link>
22 |       <guid isPermaLink="true">{{ post.url | prepend: site.url }}</guid>
23 |       {% for tag in post.tags %}
24 |       <category>{{ tag | xml_escape }}</category>
25 |       {% endfor %}
26 |       {% for cat in page.categories %}
27 |       <category>{{ cat | xml_escape }}</category>
28 |       {% endfor %}
29 |     </item>
30 |     {% endfor %}
31 |   </channel>
32 | </rss>
33 |
--------------------------------------------------------------------------------
/docs/images/eclipse_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tcapelle/eclipse_pytorch/13b1c41b076a535eb01abf37f777f9fa8309ee48/docs/images/eclipse_diagram.png
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Eclipse
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
 9 | summary: "Implementing Paletta et al. in PyTorch"
10 | description: "Implementing Paletta et al. in PyTorch"
11 | nb_path: "nbs/index.ipynb"
12 | ---
13 |
14 | <!--
15 |   Auto-generated nbdev rendering of nbs/index.ipynb; the HTML markup was not
16 |   preserved in this dump. Recoverable content mirrors README.md: the note that
17 |   most of the codebase comes from Fiery, the pip install instructions, the usage
18 |   example whose output is (6, torch.Size([2, 4, 128, 128]), torch.Size([2, 6])),
19 |   the BibTeX citation for Paletta et al. (arXiv:2104.12419), and the nbdev
20 |   contribution note.
21 | -->
--------------------------------------------------------------------------------
/docs/layers.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: Layers
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
 9 | summary: "Most of them come from Fiery"
10 | description: "Most of them come from Fiery"
11 | nb_path: "nbs/01_layers.ipynb"
12 | ---
13 |
14 | <!--
15 |   Auto-generated nbdev rendering of nbs/01_layers.ipynb; the HTML markup was not
16 |   preserved in this dump. The page documents the following API (signatures and
17 |   docstrings recovered from the residue):
18 |
19 |   get_activation(activation)
20 |   get_norm(norm, out_channels)
21 |   init_linear(m, act_func=None, init='auto', bias_std=0.01)
22 |   ConvBlock(in_channels, out_channels=None, kernel_size=3, stride=1, norm='bn',
23 |             activation='relu', bias=False, transpose=False, init='auto') :: Sequential
24 |       2D convolution followed by an optional normalisation (batch norm or
25 |       instance norm) and an optional activation (ReLU, LeakyReLU, or tanh).
26 |   Bottleneck(in_channels, out_channels=None, kernel_size=3, dilation=1, groups=1,
27 |              upsample=False, downsample=False, dropout=0.0) :: Module
28 |       Defines a bottleneck module with a residual connection.
29 |   Interpolate(scale_factor:int=2) :: Module
30 |   UpsamplingConcat(in_channels, out_channels, scale_factor=2) :: Module
31 |   UpsamplingAdd(in_channels, out_channels, scale_factor=2) :: Module
32 |   ResBlock(in_channels, out_channels, kernel_size=3, stride=1, norm='bn',
33 |            activation='relu') :: Module
34 |       A simple resnet block.
35 |   SpatialGRU(input_size, hidden_size, gru_bias_init=0.0, norm='bn',
36 |              activation='relu') :: Module
37 |       A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous
38 |       state and passes a convolutional gated recurrent unit over the data.
39 |   CausalConv3d(in_channels, out_channels, kernel_size=(2, 3, 3),
40 |                dilation=(1, 1, 1), bias=False) :: Module
41 |   conv_1x1x1_norm_activated(in_channels, out_channels)
42 |       1x1x1 3D convolution, normalization and activation layer.
43 |   Bottleneck3D(in_channels, out_channels=None, kernel_size=(2, 3, 3),
44 |                dilation=(1, 1, 1)) :: Module
45 |       Defines a bottleneck module with a residual connection.
46 |   PyramidSpatioTemporalPooling(in_channels, reduction_channels, pool_sizes) :: Module
47 |       Spatio-temporal pyramid pooling: 3D average pooling followed by 1x1x1
48 |       convolution to reduce the number of channels, then upsampling. pool_sizes is
49 |       a list of kernel sizes, usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)].
50 |   TemporalBlock(in_channels, out_channels=None, use_pyramid_pooling=False,
51 |                 pool_sizes=None) :: Module
52 |       Temporal block with 2x3x3 and 1x3x3 convolutions, optional spatio-temporal
53 |       pyramid pooling, dropout, and a skip connection.
54 |   FuturePrediction(in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3) :: Module
55 | -->
--------------------------------------------------------------------------------
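
The recovered signatures above can be exercised directly. Here is a small sketch for the two building blocks whose source appears in eclipse_pytorch/layers.py further down; the printed shapes are expectations inferred from that source, not documented guarantees:

```python
import torch
from eclipse_pytorch.layers import ConvBlock, Bottleneck

# ConvBlock is conv -> activation -> norm; with kernel_size=3, stride=2 and
# padding=(kernel_size-1)//2 the spatial size is halved
block = ConvBlock(in_channels=3, out_channels=16, stride=2)
x = torch.rand(1, 3, 64, 64)
y = block(x)
print(y.shape)  # expected: torch.Size([1, 16, 32, 32])

# Bottleneck without upsample/downsample should preserve the input shape
# (residual connection; out_channels defaults to in_channels)
bneck = Bottleneck(in_channels=16)
print(bneck(y).shape)  # expected: torch.Size([1, 16, 32, 32])
```
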
/docs/model.html:
--------------------------------------------------------------------------------
1 | ---
2 |
3 | title: The model
4 |
5 |
6 | keywords: fastai
7 | sidebar: home_sidebar
8 |
 9 | summary: "From the paper by Paletta et al."
10 | description: "From the paper by Paletta et al."
11 | nb_path: "nbs/00_model.ipynb"
12 | ---
13 |
14 | <!--
15 |   Auto-generated nbdev rendering of nbs/00_model.ipynb; the HTML markup was not
16 |   preserved in this dump. Recoverable content:
17 |
18 |   "We will try to implement as close as possible the architecture from the paper
19 |   ECLIPSE: Envisioning Cloud Induced Perturbations in Solar Energy."
20 |
21 |   1. Spatial Downsampler: a resnet encoder to get image features. Any spatial
22 |      downsampler could be used, but the paper states a simple resnet architecture.
23 |      SpatialDownsampler(in_channels=3) :: Module
24 |   2. Temporal Encoder (TemporalModel); example output shape (1, 128, 4, 8, 8).
25 |   3. Future State Predictions; example output shape (1, 4, 128, 8, 8).
26 |   4A. Segmentation Decoder, built from Bottleneck blocks and
27 |       Upsampler(sizes=[128, 128, 64], n_out=3) :: Module
28 |   4B. Irradiance Module: IrradianceModule() :: Module; example output shape (2, 1).
29 |   Everything together: Eclipse(n_in=3, n_out=4, horizon=5, debug=False) :: Module
30 |       ("Not very parametric")
31 |
32 |   Debug shapes from the notebook: states (2, 4, 128, 16, 16); present_state
33 |   (2, 1, 128, 16, 16); hidden_state (2, 128, 16, 16); future_states
34 |   (2, 6, 128, 16, 16). Final outputs: (torch.Size([2, 4, 128, 128]),
35 |   torch.Size([2, 6])).
36 | -->
--------------------------------------------------------------------------------
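
Following the page's narrative (spatial downsampling, temporal encoding, future-state prediction, then the two decoder heads), the quickest way to watch the stages is the model's debug flag, which, judging from the rendered output above, prints the intermediate state shapes. A sketch, assuming the same toy inputs as the notebook:

```python
import torch
from eclipse_pytorch.model import Eclipse

model = Eclipse(n_in=3, n_out=4, horizon=5, debug=True)  # debug appears to print state shapes
images = [torch.rand(2, 3, 128, 128) for _ in range(4)]
preds = model(images)
# Shapes reported in the rendered notebook:
#   states:        (2, 4, 128, 16, 16)
#   present_state: (2, 1, 128, 16, 16)
#   hidden_state:  (2, 128, 16, 16)
#   future_states: (2, 6, 128, 16, 16)
print(preds['masks'][0].shape, preds['irradiances'].shape)
# torch.Size([2, 4, 128, 128]) torch.Size([2, 6])
```
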
/docs/sidebar.json:
--------------------------------------------------------------------------------
1 | {
2 |   "eclipse_pytorch": {
3 |     "Overview": "/",
4 |     "The model": "model.html",
5 |     "Layers": "layers.html"
6 |   }
7 | }
--------------------------------------------------------------------------------
/docs/sitemap.xml:
--------------------------------------------------------------------------------
1 | ---
2 | layout: none
3 | search: exclude
4 | ---
 5 |
 6 | <?xml version="1.0" encoding="UTF-8"?>
 7 | <urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
 8 |   {% for post in site.posts %}
 9 |   {% unless post.search == "exclude" %}
10 |   <url>
11 |     <loc>{{site.url}}{{post.url}}</loc>
12 |   </url>
13 |   {% endunless %}
14 |   {% endfor %}
15 |   {% for page in site.pages %}
16 |   {% unless page.search == "exclude" %}
17 |   <url>
18 |     <loc>{{site.url}}{{page.url}}</loc>
19 |   </url>
20 |   {% endunless %}
21 |   {% endfor %}
22 | </urlset>
23 |
--------------------------------------------------------------------------------
/eclipse_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.8"
2 |
--------------------------------------------------------------------------------
/eclipse_pytorch/_nbdev.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED BY NBDEV! DO NOT EDIT!
2 |
3 | __all__ = ["index", "modules", "custom_doc_links", "git_url"]
4 |
 5 | index = {"SpatialDownsampler": "00_model.ipynb",
 6 |          "TemporalModel": "00_model.ipynb",
 7 |          "Upsampler": "00_model.ipynb",
 8 |          "IrradianceModule": "00_model.ipynb",
 9 |          "Eclipse": "00_model.ipynb",
10 |          "get_activation": "01_layers.ipynb",
11 |          "get_norm": "01_layers.ipynb",
12 |          "init_linear": "01_layers.ipynb",
13 |          "ConvBlock": "01_layers.ipynb",
14 |          "Bottleneck": "01_layers.ipynb",
15 |          "Interpolate": "01_layers.ipynb",
16 |          "UpsamplingConcat": "01_layers.ipynb",
17 |          "UpsamplingAdd": "01_layers.ipynb",
18 |          "ResBlock": "01_layers.ipynb",
19 |          "SpatialGRU": "01_layers.ipynb",
20 |          "CausalConv3d": "01_layers.ipynb",
21 |          "conv_1x1x1_norm_activated": "01_layers.ipynb",
22 |          "Bottleneck3D": "01_layers.ipynb",
23 |          "PyramidSpatioTemporalPooling": "01_layers.ipynb",
24 |          "TemporalBlock": "01_layers.ipynb",
25 |          "FuturePrediction": "01_layers.ipynb"}
26 |
27 | modules = ["model.py",
28 |            "layers.py"]
29 |
30 | doc_url = "https://tcapelle.github.io/eclipse_pytorch/"
31 |
32 | git_url = "https://github.com/tcapelle/eclipse_pytorch/tree/master/"
33 |
34 | def custom_doc_links(name): return None
35 |
--------------------------------------------------------------------------------
/eclipse_pytorch/imports.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.init import normal_, kaiming_uniform_  # needed by layers.init_linear
7 |
8 | from fastcore.all import *
--------------------------------------------------------------------------------
/eclipse_pytorch/layers.py:
--------------------------------------------------------------------------------
  1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_layers.ipynb (unless otherwise specified).
  2 |
  3 | __all__ = ['get_activation', 'get_norm', 'init_linear', 'ConvBlock', 'Bottleneck', 'Interpolate', 'UpsamplingConcat',
  4 |            'UpsamplingAdd', 'ResBlock', 'SpatialGRU', 'CausalConv3d', 'conv_1x1x1_norm_activated', 'Bottleneck3D',
  5 |            'PyramidSpatioTemporalPooling', 'TemporalBlock', 'FuturePrediction']
  6 |
  7 | # Cell
  8 | from .imports import *
  9 |
 10 | # Cell
 11 | def get_activation(activation):
 12 |     if activation == 'relu':
 13 |         return nn.ReLU(inplace=True)
 14 |     elif activation == 'lrelu':
 15 |         return nn.LeakyReLU(0.1, inplace=True)
 16 |     elif activation == 'elu':
 17 |         return nn.ELU(inplace=True)
 18 |     elif activation == 'tanh':
 19 |         return nn.Tanh()  # nn.Tanh takes no `inplace` argument
 20 |     else:
 21 |         raise ValueError('Invalid activation {}'.format(activation))
 22 |
 23 | def get_norm(norm, out_channels):
 24 |     if norm == 'bn':
 25 |         return nn.BatchNorm2d(out_channels)
 26 |     elif norm == 'in':
 27 |         return nn.InstanceNorm2d(out_channels)
 28 |     else:
 29 |         raise ValueError('Invalid norm {}'.format(norm))
30 |
31 | # Cell
32 |
33 | def init_linear(m, act_func=None, init='auto', bias_std=0.01):
34 | if getattr(m,'bias',None) is not None and bias_std is not None:
35 | if bias_std != 0: normal_(m.bias, 0, bias_std)
36 | else: m.bias.data.zero_()
37 | if init=='auto':
38 | if act_func in (F.relu_,F.leaky_relu_): init = kaiming_uniform_
39 | else: init = getattr(act_func.__class__, '__default_init__', None)
40 | if init is None: init = getattr(act_func, '__default_init__', None)
41 | if init is not None: init(m.weight)
42 |
43 |
44 | class ConvBlock(nn.Sequential):
45 | """2D convolution followed by
46 | - an optional normalisation (batch norm or instance norm)
47 | - an optional activation (ReLU, LeakyReLU, or tanh)
48 | """
49 |
50 | def __init__(
51 | self,
52 | in_channels,
53 | out_channels=None,
54 | kernel_size=3,
55 | stride=1,
56 | norm='bn',
57 | activation='relu',
58 | bias=False,
59 | transpose=False,
60 | init='auto'
61 | ):
62 |
63 | out_channels = out_channels or in_channels
64 | padding = (kernel_size-1)//2
65 | conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
66 | conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
67 | if activation is not None: activation = get_activation(activation)
68 | init_linear(conv, activation, init=init)
69 | layers = [conv]
70 | if activation is not None: layers.append(activation)
71 | if norm is not None: layers.append(get_norm(norm, out_channels))
72 | super().__init__(*layers)
73 |
74 | # Cell
75 | class Bottleneck(nn.Module):
76 | """
77 | Defines a bottleneck module with a residual connection
78 | """
79 |
80 | def __init__(
81 | self,
82 | in_channels,
83 | out_channels=None,
84 | kernel_size=3,
85 | dilation=1,
86 | groups=1,
87 | upsample=False,
88 | downsample=False,
89 | dropout=0.0,
90 | ):
91 | super().__init__()
92 | self._downsample = downsample
93 |         bottleneck_channels = in_channels // 2
94 | out_channels = out_channels or in_channels
95 | padding_size = ((kernel_size - 1) * dilation + 1) // 2
96 |
97 | # Define the main conv operation
98 | assert dilation == 1
99 | if upsample:
100 | assert not downsample, 'downsample and upsample not possible simultaneously.'
101 | bottleneck_conv = nn.ConvTranspose2d(
102 | bottleneck_channels,
103 | bottleneck_channels,
104 | kernel_size=kernel_size,
105 | bias=False,
106 | dilation=1,
107 | stride=2,
108 | output_padding=padding_size,
109 | padding=padding_size,
110 | groups=groups,
111 | )
112 | elif downsample:
113 | bottleneck_conv = nn.Conv2d(
114 | bottleneck_channels,
115 | bottleneck_channels,
116 | kernel_size=kernel_size,
117 | bias=False,
118 | dilation=dilation,
119 | stride=2,
120 | padding=padding_size,
121 | groups=groups,
122 | )
123 | else:
124 | bottleneck_conv = nn.Conv2d(
125 | bottleneck_channels,
126 | bottleneck_channels,
127 | kernel_size=kernel_size,
128 | bias=False,
129 | dilation=dilation,
130 | padding=padding_size,
131 | groups=groups,
132 | )
133 |
134 | self.layers = nn.Sequential(
135 | OrderedDict(
136 | [
137 | # First projection with 1x1 kernel
138 | ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
139 | ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),
140 | nn.ReLU(inplace=True))),
141 | # Second conv block
142 | ('conv', bottleneck_conv),
143 | ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),
144 | # Final projection with 1x1 kernel
145 | ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),
146 | ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),
147 | nn.ReLU(inplace=True))),
148 | # Regulariser
149 | ('dropout', nn.Dropout2d(p=dropout)),
150 | ]
151 | )
152 | )
153 |
154 | if out_channels == in_channels and not downsample and not upsample:
155 | self.projection = None
156 | else:
157 | projection = OrderedDict()
158 | if upsample:
159 | projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})
160 | elif downsample:
161 | projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})
162 | projection.update(
163 | {
164 | 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
165 | 'bn_skip_proj': nn.BatchNorm2d(out_channels),
166 | }
167 | )
168 | self.projection = nn.Sequential(projection)
169 |
170 | # pylint: disable=arguments-differ
171 | def forward(self, *args):
172 | (x,) = args
173 | x_residual = self.layers(x)
174 | if self.projection is not None:
175 | if self._downsample:
176 | # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer
177 | x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)
178 | return x_residual + self.projection(x)
179 | return x_residual + x
180 |
181 | # Cell
182 | class Interpolate(nn.Module):
183 | def __init__(self, scale_factor: int = 2):
184 | super().__init__()
185 | self._interpolate = nn.functional.interpolate
186 | self._scale_factor = scale_factor
187 |
188 | # pylint: disable=arguments-differ
189 | def forward(self, x):
190 | return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)
191 |
192 |
193 | class UpsamplingConcat(nn.Module):
194 | def __init__(self, in_channels, out_channels, scale_factor=2):
195 | super().__init__()
196 |
197 | self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
198 |
199 | self.conv = nn.Sequential(
200 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
201 | nn.BatchNorm2d(out_channels),
202 | nn.ReLU(inplace=True),
203 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
204 | nn.BatchNorm2d(out_channels),
205 | nn.ReLU(inplace=True),
206 | )
207 |
208 | def forward(self, x_to_upsample, x):
209 | x_to_upsample = self.upsample(x_to_upsample)
210 | x_to_upsample = torch.cat([x, x_to_upsample], dim=1)
211 | return self.conv(x_to_upsample)
212 |
213 |
214 | class UpsamplingAdd(nn.Module):
215 | def __init__(self, in_channels, out_channels, scale_factor=2):
216 | super().__init__()
217 | self.upsample_layer = nn.Sequential(
218 | nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
219 | nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
220 | nn.BatchNorm2d(out_channels),
221 | )
222 |
223 | def forward(self, x, x_skip):
224 | x = self.upsample_layer(x)
225 | return x + x_skip
226 |
227 | # Cell
228 | class ResBlock(nn.Module):
229 |     "A simple ResNet block"
230 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'):
231 | super().__init__()
232 | self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation),
233 | ConvBlock(out_channels, out_channels, norm=norm, activation=activation)
234 | )
235 | id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)]
236 | self.activation = get_activation(activation)
237 | if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True))
238 | self.id_path = nn.Sequential(*id_path)
239 |
240 | def forward(self, x):
241 | return self.activation(self.convs(x) + self.id_path(x))
242 |
243 | # Cell
244 | class SpatialGRU(nn.Module):
245 | """A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a
246 | convolutional gated recurrent unit over the data"""
247 |
248 | def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):
249 | super().__init__()
250 | self.input_size = input_size
251 | self.hidden_size = hidden_size
252 | self.gru_bias_init = gru_bias_init
253 |
254 | self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
255 | self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)
256 |
257 | self.conv_state_tilde = ConvBlock(
258 | input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation
259 | )
260 |
261 | def forward(self, x, state=None, flow=None, mode='bilinear'):
262 | # pylint: disable=unused-argument, arguments-differ
263 | # Check size
264 | assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'
265 | b, timesteps, c, h, w = x.size()
266 | assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'
267 |
268 | # recurrent layers
269 | rnn_output = []
270 | rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state
271 | for t in range(timesteps):
272 | x_t = x[:, t]
273 | if flow is not None:
274 |                 rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)  # warp_features is not defined in this package; it must be supplied externally (e.g. from fiery) when flow is used
275 |
276 | # propagate rnn state
277 | rnn_state = self.gru_cell(x_t, rnn_state)
278 | rnn_output.append(rnn_state)
279 |
280 | # reshape rnn output to batch tensor
281 | return torch.stack(rnn_output, dim=1)
282 |
283 | def gru_cell(self, x, state):
284 | # Compute gates
285 | x_and_state = torch.cat([x, state], dim=1)
286 | update_gate = self.conv_update(x_and_state)
287 | reset_gate = self.conv_reset(x_and_state)
288 | # Add bias to initialise gate as close to identity function
289 | update_gate = torch.sigmoid(update_gate + self.gru_bias_init)
290 | reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)
291 |
292 | # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)
293 | state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))
294 |
295 | output = (1.0 - update_gate) * state + update_gate * state_tilde
296 | return output
297 |
298 | # Cell
299 | class CausalConv3d(nn.Module):
300 | def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):
301 | super().__init__()
302 | assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'
303 | time_pad = (kernel_size[0] - 1) * dilation[0]
304 | height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2
305 | width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2
306 |
307 | # Pad temporally on the left
308 | self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)
309 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)
310 | self.norm = nn.BatchNorm3d(out_channels)
311 | self.activation = nn.ReLU(inplace=True)
312 |
313 | def forward(self, *inputs):
314 | (x,) = inputs
315 | x = self.pad(x)
316 | x = self.conv(x)
317 | x = self.norm(x)
318 | x = self.activation(x)
319 | return x
320 |
321 | # Cell
322 | def conv_1x1x1_norm_activated(in_channels, out_channels):
323 | """1x1x1 3D convolution, normalization and activation layer."""
324 | return nn.Sequential(
325 | OrderedDict(
326 | [
327 | ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),
328 | ('norm', nn.BatchNorm3d(out_channels)),
329 | ('activation', nn.ReLU(inplace=True)),
330 | ]
331 | )
332 | )
333 |
334 | # Cell
335 | class Bottleneck3D(nn.Module):
336 | """
337 | Defines a bottleneck module with a residual connection
338 | """
339 |
340 | def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):
341 | super().__init__()
342 | bottleneck_channels = in_channels // 2
343 | out_channels = out_channels or in_channels
344 |
345 | self.layers = nn.Sequential(
346 | OrderedDict(
347 | [
348 | # First projection with 1x1 kernel
349 | ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),
350 | # Second conv block
351 | (
352 | 'conv',
353 | CausalConv3d(
354 | bottleneck_channels,
355 | bottleneck_channels,
356 | kernel_size=kernel_size,
357 | dilation=dilation,
358 | bias=False,
359 | ),
360 | ),
361 | # Final projection with 1x1 kernel
362 | ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),
363 | ]
364 | )
365 | )
366 |
367 | if out_channels != in_channels:
368 | self.projection = nn.Sequential(
369 | nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
370 | nn.BatchNorm3d(out_channels),
371 | )
372 | else:
373 | self.projection = None
374 |
375 | def forward(self, *args):
376 | (x,) = args
377 | x_residual = self.layers(x)
378 | x_features = self.projection(x) if self.projection is not None else x
379 | return x_residual + x_features
380 |
381 | # Cell
382 | class PyramidSpatioTemporalPooling(nn.Module):
383 |     """Spatio-temporal pyramid pooling.
384 |     Performs 3D average pooling, then a 1x1x1 convolution to reduce the number of channels, then upsampling.
385 |     `pool_sizes` is a list of kernel sizes, usually [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)].
386 |     """
387 |
388 | def __init__(self, in_channels, reduction_channels, pool_sizes):
389 | super().__init__()
390 | self.features = []
391 | for pool_size in pool_sizes:
392 | assert pool_size[0] == 2, (
393 |                 "Time kernel should be 2, as PyTorch raises an error when padding with more than half the kernel size"
394 | )
395 | stride = (1, *pool_size[1:])
396 | padding = (pool_size[0] - 1, 0, 0)
397 | self.features.append(
398 | nn.Sequential(
399 | OrderedDict(
400 | [
401 | # Pad the input tensor but do not take into account zero padding into the average.
402 | (
403 | 'avgpool',
404 | torch.nn.AvgPool3d(
405 | kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False
406 | ),
407 | ),
408 | ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),
409 | ]
410 | )
411 | )
412 | )
413 | self.features = nn.ModuleList(self.features)
414 |
415 | def forward(self, *inputs):
416 | (x,) = inputs
417 | b, _, t, h, w = x.shape
418 |         # Concatenate only the pooled features; the input tensor itself is not included
419 | out = []
420 | for f in self.features:
421 | # Remove unnecessary padded values (time dimension) on the right
422 | x_pool = f(x)[:, :, :-1].contiguous()
423 | c = x_pool.shape[1]
424 | x_pool = nn.functional.interpolate(
425 | x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False
426 | )
427 | x_pool = x_pool.view(b, c, t, h, w)
428 | out.append(x_pool)
429 | out = torch.cat(out, 1)
430 | return out
431 |
432 | # Cell
433 | class TemporalBlock(nn.Module):
434 |     """Temporal block with the following layers:
435 |     - parallel 2x3x3, 1x3x3 and 1x1x1 convolution paths
436 |     - optional spatio-temporal pyramid pooling
437 |     - skip connection.
438 | """
439 |
440 | def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):
441 | super().__init__()
442 | self.in_channels = in_channels
443 | self.half_channels = in_channels // 2
444 | self.out_channels = out_channels or self.in_channels
445 | self.kernels = [(2, 3, 3), (1, 3, 3)]
446 |
447 | # Flag for spatio-temporal pyramid pooling
448 | self.use_pyramid_pooling = use_pyramid_pooling
449 |
450 | # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1
451 | self.convolution_paths = []
452 | for kernel_size in self.kernels:
453 | self.convolution_paths.append(
454 | nn.Sequential(
455 | conv_1x1x1_norm_activated(self.in_channels, self.half_channels),
456 | CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),
457 | )
458 | )
459 | self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))
460 | self.convolution_paths = nn.ModuleList(self.convolution_paths)
461 |
462 | agg_in_channels = len(self.convolution_paths) * self.half_channels
463 |
464 | if self.use_pyramid_pooling:
465 |             assert pool_sizes is not None, "pool_sizes must contain a list of kernel sizes, but is None."
466 | reduction_channels = self.in_channels // 3
467 | self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)
468 | agg_in_channels += len(pool_sizes) * reduction_channels
469 |
470 | # Feature aggregation
471 | self.aggregation = nn.Sequential(
472 |             conv_1x1x1_norm_activated(agg_in_channels, self.out_channels))
473 |
474 | if self.out_channels != self.in_channels:
475 | self.projection = nn.Sequential(
476 | nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),
477 | nn.BatchNorm3d(self.out_channels),
478 | )
479 | else:
480 | self.projection = None
481 |
482 | def forward(self, *inputs):
483 | (x,) = inputs
484 | x_paths = []
485 | for conv in self.convolution_paths:
486 | x_paths.append(conv(x))
487 | x_residual = torch.cat(x_paths, dim=1)
488 | if self.use_pyramid_pooling:
489 | x_pool = self.pyramid_pooling(x)
490 | x_residual = torch.cat([x_residual, x_pool], dim=1)
491 | x_residual = self.aggregation(x_residual)
492 |
493 | if self.out_channels != self.in_channels:
494 | x = self.projection(x)
495 | x = x + x_residual
496 | return x
497 |
498 | # Cell
499 | class FuturePrediction(torch.nn.Module):
500 | def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):
501 | super().__init__()
502 | self.n_gru_blocks = n_gru_blocks
503 |
504 |         # Convolutional recurrent model: the encoded present state serves as the initial hidden
505 |         # state, and the inputs are zeros in ECLIPSE. The architecture of the model is:
506 | # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks
507 | self.spatial_grus = []
508 | self.res_blocks = []
509 |
510 | for i in range(self.n_gru_blocks):
511 | gru_in_channels = latent_dim if i == 0 else in_channels
512 | self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))
513 | self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)
514 | for _ in range(n_res_layers)]))
515 |
516 | self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)
517 | self.res_blocks = torch.nn.ModuleList(self.res_blocks)
518 |
519 | def forward(self, x, hidden_state):
520 | # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)
521 | for i in range(self.n_gru_blocks):
522 | x = self.spatial_grus[i](x, hidden_state, flow=None)
523 | b, n_future, c, h, w = x.shape
524 |
525 | x = self.res_blocks[i](x.view(b * n_future, c, h, w))
526 | x = x.view(b, n_future, c, h, w)
527 |
528 | return x
--------------------------------------------------------------------------------
/eclipse_pytorch/model.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_model.ipynb (unless otherwise specified).
2 |
3 | __all__ = ['SpatialDownsampler', 'TemporalModel', 'Upsampler', 'IrradianceModule', 'Eclipse']
4 |
5 | # Cell
6 | from .imports import *
7 | from .layers import *
8 |
9 | # Cell
10 | class SpatialDownsampler(nn.Module):
11 |
12 | def __init__(self, in_channels=3):
13 | super().__init__()
14 | self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)
15 | self.blocks = nn.Sequential(ResBlock(64, 64, kernel_size=3, stride=2),
16 | ResBlock(64, 128, kernel_size=3, stride=2),
17 | ResBlock(128,256, kernel_size=3, stride=2))
18 |
19 | def forward(self, x):
20 | return self.blocks(self.conv1(x))
21 |
22 | # Cell
23 | class TemporalModel(nn.Module):
24 | def __init__(
25 | self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,
26 | n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):
27 | super().__init__()
28 | self.receptive_field = receptive_field
29 | n_temporal_layers = receptive_field - 1
30 |
31 | h, w = input_shape
32 | modules = []
33 |
34 | block_in_channels = in_channels
35 | block_out_channels = start_out_channels
36 |
37 | for _ in range(n_temporal_layers):
38 |             # pool over the full spatial extent when pyramid pooling is enabled
39 |             if use_pyramid_pooling:
40 |                 pool_sizes = [(2, h, w)]
41 |             else:
42 |                 pool_sizes = None
43 |
44 | temporal = TemporalBlock(
45 | block_in_channels,
46 | block_out_channels,
47 | use_pyramid_pooling=use_pyramid_pooling,
48 | pool_sizes=pool_sizes,
49 | )
50 | spatial = [
51 | Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))
52 | for _ in range(n_spatial_layers_between_temporal_layers)
53 | ]
54 | temporal_spatial_layers = nn.Sequential(temporal, *spatial)
55 | modules.extend(temporal_spatial_layers)
56 |
57 | block_in_channels = block_out_channels
58 | block_out_channels += extra_in_channels
59 |
60 | self.out_channels = block_in_channels
61 |
62 | self.model = nn.Sequential(*modules)
63 |
64 | def forward(self, x):
65 | # Reshape input tensor to (batch, C, time, H, W)
66 | x = x.permute(0, 2, 1, 3, 4)
67 | x = self.model(x)
68 | x = x.permute(0, 2, 1, 3, 4).contiguous()
69 | return x[:, (self.receptive_field - 1):]
70 |
71 | # Cell
72 | class Upsampler(nn.Module):
73 | def __init__(self, sizes=[128,128,64], n_out=3):
74 | super().__init__()
75 | zsizes = zip(sizes[:-1], sizes[1:])
76 | self.convs = nn.Sequential(*[Bottleneck(si, sf, upsample=True) for si,sf in zsizes],
77 | Bottleneck(sizes[-1], sizes[-1], upsample=True),
78 | ConvBlock(sizes[-1], n_out, kernel_size=1, activation=None))
79 |
80 | def forward(self, x):
81 | return self.convs(x)
82 |
83 | # Cell
84 | class IrradianceModule(nn.Module):
85 | def __init__(self):
86 | super().__init__()
87 | self.convs = nn.Sequential(ConvBlock(128, 64, stride=2),
88 | ConvBlock(64, 64),
89 | nn.AdaptiveMaxPool2d(1)
90 | )
91 | self.linear = nn.Sequential(nn.Flatten(),
92 | nn.Linear(64, 1)
93 | )
94 | def forward(self, x):
95 | return self.linear(self.convs(x))
96 |
97 | # Cell
98 | class Eclipse(nn.Module):
99 |     """ECLIPSE model. Several channel sizes are hard-coded, so it is not very parametric."""
100 | def __init__(self, n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False):
101 | super().__init__()
102 | store_attr()
103 | self.spatial_downsampler = SpatialDownsampler(n_in)
104 | self.temporal_model = TemporalModel(256, 3, input_shape=(img_size[0]//8, img_size[1]//8), start_out_channels=128)
105 | self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=n_gru_layers, n_res_layers=n_res_layers)
106 | self.upsampler = Upsampler(n_out=n_out)
107 | self.irradiance = IrradianceModule()
108 |
109 | def zero_hidden(self, x, horizon):
110 | bs, ch, h, w = x.shape
111 | return x.new_zeros(bs, horizon, ch, h, w)
112 |
113 | def forward(self, imgs):
114 | x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=1)
115 |
116 |         # encode inputs with the temporal model
117 | states = self.temporal_model(x)
118 | if self.debug: print(f'{states.shape=}')
119 |
120 |         # get the present hidden state
121 | present_state = states[:, -1:]
122 | if self.debug: print(f'{present_state.shape=}')
123 |
124 |
125 | # Prepare future prediction input
126 |         hidden_state = present_state.squeeze(1)  # squeeze only the time dim so batch size 1 survives
127 | if self.debug: print(f'{hidden_state.shape=}')
128 |
129 | future_prediction_input = self.zero_hidden(hidden_state, self.horizon)
130 |
131 | # Recursively predict future states
132 | future_states = self.future_prediction(future_prediction_input, hidden_state)
133 |
134 | # Concatenate present state
135 | future_states = torch.cat([present_state, future_states], dim=1)
136 | if self.debug: print(f'{future_states.shape=}')
137 |
138 |         # decode outputs
139 | masks, irradiances = [], []
140 |
141 | for state in future_states.unbind(dim=1):
142 | masks.append(self.upsampler(state))
143 | irradiances.append(self.irradiance(state))
144 | return {'masks': masks, 'irradiances': torch.cat(irradiances, dim=-1)}
145 |
--------------------------------------------------------------------------------
/nbs/00_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# default_exp model"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "# The model\n",
17 | "\n",
18 | "> from the paper by [Paletta et al](https://arxiv.org/pdf/2104.12419v1.pdf)"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "We will implement, as closely as possible, the architecture from the paper `ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy`."
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "metadata": {},
31 | "source": [
32 | "![eclipse diagram](images/eclipse_diagram.png)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "#export\n",
42 | "from eclipse_pytorch.imports import *\n",
43 | "from eclipse_pytorch.layers import *"
44 | ]
45 | },
46 | {
47 | "cell_type": "markdown",
48 | "metadata": {},
49 | "source": [
50 | "## 1. Spatial Downsampler\n",
51 | "> A ResNet encoder that extracts image features"
52 | ]
53 | },
54 | {
55 | "cell_type": "markdown",
56 | "metadata": {},
57 | "source": [
58 | "You could use any spatial downsampler you want, but the paper specifies a simple ResNet architecture."
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "metadata": {},
65 | "outputs": [],
66 | "source": [
67 | "#export\n",
68 | "class SpatialDownsampler(nn.Module):\n",
69 | " \n",
70 | " def __init__(self, in_channels=3):\n",
71 | " super().__init__()\n",
72 | " self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)\n",
73 | " self.blocks = nn.Sequential(ResBlock(64, 64, kernel_size=3, stride=2), \n",
74 | " ResBlock(64, 128, kernel_size=3, stride=2), \n",
75 | " ResBlock(128,256, kernel_size=3, stride=2))\n",
76 | " \n",
77 | " def forward(self, x):\n",
78 | " return self.blocks(self.conv1(x))"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "sd = SpatialDownsampler()"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "metadata": {},
94 | "outputs": [
95 | {
96 | "data": {
97 | "text/plain": [
98 | "torch.Size([1, 4, 256, 8, 8])"
99 | ]
100 | },
101 | "execution_count": null,
102 | "metadata": {},
103 | "output_type": "execute_result"
104 | }
105 | ],
106 | "source": [
107 | "images = [torch.rand(1, 3, 64, 64) for _ in range(4)]\n",
108 | "features = torch.stack([sd(image) for image in images], dim=1)\n",
109 | "features.shape"
110 | ]
111 | },
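{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each of the three `ResBlock`s halves the spatial resolution, so the encoder downsamples by a factor of 8: the 64x64 inputs above give 8x8 feature maps. `Eclipse` relies on this factor when it sets `input_shape=(img_size[0]//8, img_size[1]//8)` for the temporal model."
]
},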
112 | {
113 | "cell_type": "markdown",
114 | "metadata": {},
115 | "source": [
116 | "## 2. Temporal Encoder"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "#export\n",
126 | "class TemporalModel(nn.Module):\n",
127 | " def __init__(\n",
128 | " self, in_channels, receptive_field, input_shape, start_out_channels=64, extra_in_channels=0,\n",
129 | " n_spatial_layers_between_temporal_layers=0, use_pyramid_pooling=True):\n",
130 | " super().__init__()\n",
131 | " self.receptive_field = receptive_field\n",
132 | " n_temporal_layers = receptive_field - 1\n",
133 | "\n",
134 | " h, w = input_shape\n",
135 | " modules = []\n",
136 | "\n",
137 | " block_in_channels = in_channels\n",
138 | " block_out_channels = start_out_channels\n",
139 | "\n",
140 | " for _ in range(n_temporal_layers):\n",
141 | "            # pool over the full spatial extent when pyramid pooling is enabled\n",
142 | "            if use_pyramid_pooling:\n",
143 | "                pool_sizes = [(2, h, w)]\n",
144 | "            else:\n",
145 | "                pool_sizes = None\n",
146 | "\n",
147 | " temporal = TemporalBlock(\n",
148 | " block_in_channels,\n",
149 | " block_out_channels,\n",
150 | " use_pyramid_pooling=use_pyramid_pooling,\n",
151 | " pool_sizes=pool_sizes,\n",
152 | " )\n",
153 | " spatial = [\n",
154 | " Bottleneck3D(block_out_channels, block_out_channels, kernel_size=(1, 3, 3))\n",
155 | " for _ in range(n_spatial_layers_between_temporal_layers)\n",
156 | " ]\n",
157 | " temporal_spatial_layers = nn.Sequential(temporal, *spatial)\n",
158 | " modules.extend(temporal_spatial_layers)\n",
159 | "\n",
160 | " block_in_channels = block_out_channels\n",
161 | " block_out_channels += extra_in_channels\n",
162 | "\n",
163 | " self.out_channels = block_in_channels\n",
164 | "\n",
165 | " self.model = nn.Sequential(*modules)\n",
166 | "\n",
167 | " def forward(self, x):\n",
168 | " # Reshape input tensor to (batch, C, time, H, W)\n",
169 | " x = x.permute(0, 2, 1, 3, 4)\n",
170 | " x = self.model(x)\n",
171 | " x = x.permute(0, 2, 1, 3, 4).contiguous()\n",
172 | " return x[:, (self.receptive_field - 1):]"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "metadata": {},
179 | "outputs": [],
180 | "source": [
181 | "tm = TemporalModel(256, 3, (8,8), 128)"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "metadata": {},
188 | "outputs": [
189 | {
190 | "data": {
191 | "text/plain": [
192 | "torch.Size([1, 2, 128, 8, 8])"
193 | ]
194 | },
195 | "execution_count": null,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "temp_encoded = tm(features)\n",
202 | "temp_encoded.shape"
203 | ]
204 | },
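{
"cell_type": "markdown",
"metadata": {},
"source": [
"The forward pass keeps only the last `T - (receptive_field - 1)` time steps: with `receptive_field=3`, the 4 input frames above yield 2 output steps, matching the shape `[1, 2, 128, 8, 8]`."
]
},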
205 | {
206 | "cell_type": "markdown",
207 | "metadata": {},
208 | "source": [
209 | "## 3. Future State Predictions"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "data": {
228 | "text/plain": [
229 | "torch.Size([1, 4, 128, 8, 8])"
230 | ]
231 | },
232 | "execution_count": null,
233 | "metadata": {},
234 | "output_type": "execute_result"
235 | }
236 | ],
237 | "source": [
238 | "hidden = torch.rand(1, 128, 8, 8)\n",
239 | "x = torch.rand(1, 4, 128, 8, 8)\n",
240 | "fp(x, hidden).shape"
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "metadata": {},
246 | "source": [
247 | "## 4A. Segmentation Decoder"
248 | ]
249 | },
250 | {
251 | "cell_type": "code",
252 | "execution_count": null,
253 | "metadata": {},
254 | "outputs": [
255 | {
256 | "data": {
257 | "text/plain": [
258 | "Bottleneck(\n",
259 | " (layers): Sequential(\n",
260 | " (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
261 | " (abn_down_project): Sequential(\n",
262 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
263 | " (1): ReLU(inplace=True)\n",
264 | " )\n",
265 | " (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)\n",
266 | " (abn): Sequential(\n",
267 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
268 | " (1): ReLU(inplace=True)\n",
269 | " )\n",
270 | " (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
271 | " (abn_up_project): Sequential(\n",
272 | " (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
273 | " (1): ReLU(inplace=True)\n",
274 | " )\n",
275 | " (dropout): Dropout2d(p=0.0, inplace=False)\n",
276 | " )\n",
277 | " (projection): Sequential(\n",
278 | " (upsample_skip_proj): Interpolate()\n",
279 | " (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
280 | " (bn_skip_proj): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
281 | " )\n",
282 | ")"
283 | ]
284 | },
285 | "execution_count": null,
286 | "metadata": {},
287 | "output_type": "execute_result"
288 | }
289 | ],
290 | "source": [
291 | "bn = Bottleneck(256, 128, upsample=True)\n",
292 | "bn"
293 | ]
294 | },
295 | {
296 | "cell_type": "code",
297 | "execution_count": null,
298 | "metadata": {},
299 | "outputs": [
300 | {
301 | "data": {
302 | "text/plain": [
303 | "torch.Size([4, 256, 8, 8])"
304 | ]
305 | },
306 | "execution_count": null,
307 | "metadata": {},
308 | "output_type": "execute_result"
309 | }
310 | ],
311 | "source": [
312 | "features[0].shape"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": null,
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "data": {
322 | "text/plain": [
323 | "torch.Size([1, 128, 64, 64])"
324 | ]
325 | },
326 | "execution_count": null,
327 | "metadata": {},
328 | "output_type": "execute_result"
329 | }
330 | ],
331 | "source": [
332 | "x = torch.rand(1,256,32,32)\n",
333 | "bn(x).shape"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": null,
339 | "metadata": {},
340 | "outputs": [],
341 | "source": [
342 | "#export\n",
343 | "class Upsampler(nn.Module):\n",
344 | " def __init__(self, sizes=[128,128,64], n_out=3):\n",
345 | " super().__init__()\n",
346 | " zsizes = zip(sizes[:-1], sizes[1:])\n",
347 | " self.convs = nn.Sequential(*[Bottleneck(si, sf, upsample=True) for si,sf in zsizes], \n",
348 | " Bottleneck(sizes[-1], sizes[-1], upsample=True), \n",
349 | " ConvBlock(sizes[-1], n_out, kernel_size=1, activation=None))\n",
350 | " \n",
351 | " def forward(self, x):\n",
352 | " return self.convs(x)"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "execution_count": null,
358 | "metadata": {},
359 | "outputs": [
360 | {
361 | "data": {
362 | "text/plain": [
363 | "torch.Size([1, 3, 256, 256])"
364 | ]
365 | },
366 | "execution_count": null,
367 | "metadata": {},
368 | "output_type": "execute_result"
369 | }
370 | ],
371 | "source": [
372 | "us = Upsampler()\n",
373 | "\n",
374 | "x = torch.rand(1,128,32,32)\n",
375 | "us(x).shape"
376 | ]
377 | },
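{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `sizes=[128,128,64]` the `Upsampler` stacks two upsampling `Bottleneck`s from the consecutive size pairs plus a final one, i.e. three x2 upsamplings, so 32x32 features come back at 256x256, undoing the x8 reduction of the `SpatialDownsampler`."
]
},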
378 | {
379 | "cell_type": "markdown",
380 | "metadata": {},
381 | "source": [
382 | "## 4B. Irradiance Module"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {},
389 | "outputs": [],
390 | "source": [
391 | "#export\n",
392 | "class IrradianceModule(nn.Module):\n",
393 | " def __init__(self):\n",
394 | " super().__init__()\n",
395 | " self.convs = nn.Sequential(ConvBlock(128, 64, stride=2), \n",
396 | " ConvBlock(64, 64),\n",
397 | " nn.AdaptiveMaxPool2d(1)\n",
398 | " )\n",
399 | " self.linear = nn.Sequential(nn.Flatten(), \n",
400 | " nn.Linear(64, 1)\n",
401 | " )\n",
402 | " def forward(self, x):\n",
403 | " return self.linear(self.convs(x))"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": null,
409 | "metadata": {},
410 | "outputs": [],
411 | "source": [
412 | "im = IrradianceModule()"
413 | ]
414 | },
415 | {
416 | "cell_type": "code",
417 | "execution_count": null,
418 | "metadata": {},
419 | "outputs": [
420 | {
421 | "data": {
422 | "text/plain": [
423 | "IrradianceModule(\n",
424 | " (convs): Sequential(\n",
425 | " (0): ConvBlock(\n",
426 | " (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
427 | " (1): ReLU(inplace=True)\n",
428 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
429 | " )\n",
430 | " (1): ConvBlock(\n",
431 | " (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
432 | " (1): ReLU(inplace=True)\n",
433 | " (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
434 | " )\n",
435 | " (2): AdaptiveMaxPool2d(output_size=1)\n",
436 | " )\n",
437 | " (linear): Sequential(\n",
438 | " (0): Flatten(start_dim=1, end_dim=-1)\n",
439 | " (1): Linear(in_features=64, out_features=1, bias=True)\n",
440 | " )\n",
441 | ")"
442 | ]
443 | },
444 | "execution_count": null,
445 | "metadata": {},
446 | "output_type": "execute_result"
447 | }
448 | ],
449 | "source": [
450 | "im"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": null,
456 | "metadata": {},
457 | "outputs": [
458 | {
459 | "data": {
460 | "text/plain": [
461 | "torch.Size([2, 1])"
462 | ]
463 | },
464 | "execution_count": null,
465 | "metadata": {},
466 | "output_type": "execute_result"
467 | }
468 | ],
469 | "source": [
470 | "x = torch.rand(2, 128, 32, 32)\n",
471 | "im(x).shape"
472 | ]
473 | },
474 | {
475 | "cell_type": "markdown",
476 | "metadata": {},
477 | "source": [
478 | "## Everything Together..."
479 | ]
480 | },
481 | {
482 | "cell_type": "code",
483 | "execution_count": null,
484 | "metadata": {},
485 | "outputs": [],
486 | "source": [
487 | "#export\n",
488 | "class Eclipse(nn.Module):\n",
489 | "    \"\"\"ECLIPSE model. Several channel sizes are hard-coded, so it is not very parametric.\"\"\"\n",
490 | " def __init__(self, n_in=3, n_out=4, horizon=5, img_size=(128, 128), n_gru_layers=4, n_res_layers=4, debug=False):\n",
491 | " super().__init__()\n",
492 | " store_attr()\n",
493 | " self.spatial_downsampler = SpatialDownsampler(n_in)\n",
494 | " self.temporal_model = TemporalModel(256, 3, input_shape=(img_size[0]//8, img_size[1]//8), start_out_channels=128)\n",
495 | " self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=n_gru_layers, n_res_layers=n_res_layers)\n",
496 | " self.upsampler = Upsampler(n_out=n_out)\n",
497 | " self.irradiance = IrradianceModule()\n",
498 | " \n",
499 | " def zero_hidden(self, x, horizon):\n",
500 | " bs, ch, h, w = x.shape\n",
501 | " return x.new_zeros(bs, horizon, ch, h, w)\n",
502 | " \n",
503 | " def forward(self, imgs):\n",
504 | " x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=1)\n",
505 | " \n",
506 | "        # encode inputs with the temporal model\n",
507 | " states = self.temporal_model(x)\n",
508 | " if self.debug: print(f'{states.shape=}')\n",
509 | " \n",
510 | "        # get the present hidden state\n",
511 | " present_state = states[:, -1:]\n",
512 | " if self.debug: print(f'{present_state.shape=}')\n",
513 | " \n",
514 | " \n",
515 | " # Prepare future prediction input\n",
516 | "        hidden_state = present_state.squeeze(1)  # squeeze only the time dim so batch size 1 survives\n",
517 | " if self.debug: print(f'{hidden_state.shape=}')\n",
518 | " \n",
519 | " future_prediction_input = self.zero_hidden(hidden_state, self.horizon)\n",
520 | " \n",
521 | " # Recursively predict future states\n",
522 | " future_states = self.future_prediction(future_prediction_input, hidden_state)\n",
523 | "\n",
524 | " # Concatenate present state\n",
525 | " future_states = torch.cat([present_state, future_states], dim=1)\n",
526 | " if self.debug: print(f'{future_states.shape=}')\n",
527 | " \n",
528 | "        # decode outputs\n",
529 | " masks, irradiances = [], []\n",
530 | "\n",
531 | " for state in future_states.unbind(dim=1):\n",
532 | " masks.append(self.upsampler(state))\n",
533 | " irradiances.append(self.irradiance(state))\n",
534 | " return {'masks': masks, 'irradiances': torch.cat(irradiances, dim=-1)}\n",
535 | " "
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": null,
541 | "metadata": {},
542 | "outputs": [],
543 | "source": [
544 | "eclipse = Eclipse(img_size=(256, 192), debug=True)"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [
552 | {
553 | "name": "stdout",
554 | "output_type": "stream",
555 | "text": [
556 | "states.shape=torch.Size([2, 2, 128, 32, 24])\n",
557 | "present_state.shape=torch.Size([2, 1, 128, 32, 24])\n",
558 | "hidden_state.shape=torch.Size([2, 128, 32, 24])\n",
559 | "future_states.shape=torch.Size([2, 6, 128, 32, 24])\n"
560 | ]
561 | },
562 | {
563 | "data": {
564 | "text/plain": [
565 | "(torch.Size([2, 4, 256, 192]), torch.Size([2, 6]))"
566 | ]
567 | },
568 | "execution_count": null,
569 | "metadata": {},
570 | "output_type": "execute_result"
571 | }
572 | ],
573 | "source": [
574 | "preds = eclipse([torch.rand(2, 3, 256, 192) for _ in range(4)])\n",
575 | "\n",
576 | "preds['masks'][0].shape, preds['irradiances'].shape"
577 | ]
578 | },
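{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model returns `horizon + 1` predictions: the present state plus `horizon` future steps. With the default `horizon=5` this explains the 6 time steps above (a list of 6 masks and an irradiance tensor of shape `(2, 6)`)."
]
},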
579 | {
580 | "cell_type": "markdown",
581 | "metadata": {},
582 | "source": [
583 | "## Export"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {},
590 | "outputs": [
591 | {
592 | "name": "stdout",
593 | "output_type": "stream",
594 | "text": [
595 | "Converted 00_model.ipynb.\n",
596 | "Converted 01_layers.ipynb.\n",
597 | "Converted index.ipynb.\n"
598 | ]
599 | }
600 | ],
601 | "source": [
602 | "# hide\n",
603 | "from nbdev.export import *\n",
604 | "notebook2script()"
605 | ]
606 | }
607 | ],
608 | "metadata": {
609 | "kernelspec": {
610 | "display_name": "Python 3",
611 | "language": "python",
612 | "name": "python3"
613 | }
614 | },
615 | "nbformat": 4,
616 | "nbformat_minor": 4
617 | }
618 |
--------------------------------------------------------------------------------
/nbs/01_layers.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "94f61f73-8683-4638-b730-6d9174eb212d",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "# default_exp layers"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "id": "0ea22c98-3207-4673-8ee8-7396637f45cd",
16 | "metadata": {},
17 | "source": [
18 | "# Layers\n",
19 | "> Most of these layers come from [fiery](https://github.com/wayveai/fiery)."
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "113f4b7d-ace4-4416-a067-9d7025ef7ae6",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "#export\n",
30 | "from eclipse_pytorch.imports import *"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "id": "e8f80af7-bcc4-43a8-a6bd-283991737d8f",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "#export\n",
41 | "def get_activation(activation):\n",
42 | " if activation == 'relu':\n",
43 | " return nn.ReLU(inplace=True)\n",
44 | " elif activation == 'lrelu':\n",
45 | " return nn.LeakyReLU(0.1, inplace=True)\n",
46 | " elif activation == 'elu':\n",
47 | " return nn.ELU(inplace=True)\n",
48 | " elif activation == 'tanh':\n",
49 | "        return nn.Tanh()  # nn.Tanh does not take an inplace argument\n",
50 | " else:\n",
51 | " raise ValueError('Invalid activation {}'.format(activation))\n",
52 | " \n",
53 | "def get_norm(norm, out_channels):\n",
54 | " if norm == 'bn':\n",
55 | " return nn.BatchNorm2d(out_channels)\n",
56 | " elif norm == 'in':\n",
57 | " return nn.InstanceNorm2d(out_channels)\n",
58 | " else:\n",
59 | " raise ValueError('Invalid norm {}'.format(norm))"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "f158184b-c6d0-4a22-8625-f8514b1b524f",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "#export\n",
70 | "\n",
71 | "def init_linear(m, act_func=None, init='auto', bias_std=0.01):\n",
72 | " if getattr(m,'bias',None) is not None and bias_std is not None:\n",
73 | "        if bias_std != 0: nn.init.normal_(m.bias, 0, bias_std)\n",
74 | " else: m.bias.data.zero_()\n",
75 | " if init=='auto':\n",
76 | "        if act_func in (F.relu_,F.leaky_relu_): init = nn.init.kaiming_uniform_\n",
77 | " else: init = getattr(act_func.__class__, '__default_init__', None)\n",
78 | " if init is None: init = getattr(act_func, '__default_init__', None)\n",
79 | " if init is not None: init(m.weight)\n",
80 | "\n",
81 | "\n",
82 | "class ConvBlock(nn.Sequential):\n",
83 | " \"\"\"2D convolution followed by\n",
84 | " - an optional normalisation (batch norm or instance norm)\n",
85 | " - an optional activation (ReLU, LeakyReLU, or tanh)\n",
86 | " \"\"\"\n",
87 | "\n",
88 | " def __init__(\n",
89 | " self,\n",
90 | " in_channels,\n",
91 | " out_channels=None,\n",
92 | " kernel_size=3,\n",
93 | " stride=1,\n",
94 | " norm='bn',\n",
95 | " activation='relu',\n",
96 | " bias=False,\n",
97 | " transpose=False,\n",
98 | " init='auto'\n",
99 | " ):\n",
100 | " \n",
101 | " out_channels = out_channels or in_channels\n",
102 | " padding = (kernel_size-1)//2\n",
103 | " conv_cls = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)\n",
104 | " conv = conv_cls(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)\n",
105 | " if activation is not None: activation = get_activation(activation)\n",
106 | " init_linear(conv, activation, init=init)\n",
107 | " layers = [conv]\n",
108 | " if activation is not None: layers.append(activation)\n",
109 | " if norm is not None: layers.append(get_norm(norm, out_channels))\n",
110 | " super().__init__(*layers)"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "id": "cd4f1349-6d9d-4a3c-8be3-92f2a9d5c173",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "#export\n",
121 | "class Bottleneck(nn.Module):\n",
122 | " \"\"\"\n",
123 | " Defines a bottleneck module with a residual connection\n",
124 | " \"\"\"\n",
125 | "\n",
126 | " def __init__(\n",
127 | " self,\n",
128 | " in_channels,\n",
129 | " out_channels=None,\n",
130 | " kernel_size=3,\n",
131 | " dilation=1,\n",
132 | " groups=1,\n",
133 | " upsample=False,\n",
134 | " downsample=False,\n",
135 | " dropout=0.0,\n",
136 | " ):\n",
137 | " super().__init__()\n",
138 | " self._downsample = downsample\n",
139 | "        bottleneck_channels = in_channels // 2\n",
140 | " out_channels = out_channels or in_channels\n",
141 | " padding_size = ((kernel_size - 1) * dilation + 1) // 2\n",
142 | "\n",
143 | " # Define the main conv operation\n",
144 | " assert dilation == 1\n",
145 | " if upsample:\n",
146 | " assert not downsample, 'downsample and upsample not possible simultaneously.'\n",
147 | " bottleneck_conv = nn.ConvTranspose2d(\n",
148 | " bottleneck_channels,\n",
149 | " bottleneck_channels,\n",
150 | " kernel_size=kernel_size,\n",
151 | " bias=False,\n",
152 | " dilation=1,\n",
153 | " stride=2,\n",
154 | " output_padding=padding_size,\n",
155 | " padding=padding_size,\n",
156 | " groups=groups,\n",
157 | " )\n",
158 | " elif downsample:\n",
159 | " bottleneck_conv = nn.Conv2d(\n",
160 | " bottleneck_channels,\n",
161 | " bottleneck_channels,\n",
162 | " kernel_size=kernel_size,\n",
163 | " bias=False,\n",
164 | " dilation=dilation,\n",
165 | " stride=2,\n",
166 | " padding=padding_size,\n",
167 | " groups=groups,\n",
168 | " )\n",
169 | " else:\n",
170 | " bottleneck_conv = nn.Conv2d(\n",
171 | " bottleneck_channels,\n",
172 | " bottleneck_channels,\n",
173 | " kernel_size=kernel_size,\n",
174 | " bias=False,\n",
175 | " dilation=dilation,\n",
176 | " padding=padding_size,\n",
177 | " groups=groups,\n",
178 | " )\n",
179 | "\n",
180 | " self.layers = nn.Sequential(\n",
181 | " OrderedDict(\n",
182 | " [\n",
183 | " # First projection with 1x1 kernel\n",
184 | " ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),\n",
185 | " ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),\n",
186 | " nn.ReLU(inplace=True))),\n",
187 | " # Second conv block\n",
188 | " ('conv', bottleneck_conv),\n",
189 | " ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),\n",
190 | " # Final projection with 1x1 kernel\n",
191 | " ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),\n",
192 | " ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),\n",
193 | " nn.ReLU(inplace=True))),\n",
194 | " # Regulariser\n",
195 | " ('dropout', nn.Dropout2d(p=dropout)),\n",
196 | " ]\n",
197 | " )\n",
198 | " )\n",
199 | "\n",
200 | " if out_channels == in_channels and not downsample and not upsample:\n",
201 | " self.projection = None\n",
202 | " else:\n",
203 | " projection = OrderedDict()\n",
204 | " if upsample:\n",
205 | " projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})\n",
206 | " elif downsample:\n",
207 | " projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})\n",
208 | " projection.update(\n",
209 | " {\n",
210 | " 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),\n",
211 | " 'bn_skip_proj': nn.BatchNorm2d(out_channels),\n",
212 | " }\n",
213 | " )\n",
214 | " self.projection = nn.Sequential(projection)\n",
215 | "\n",
216 | " # pylint: disable=arguments-differ\n",
217 | " def forward(self, *args):\n",
218 | " (x,) = args\n",
219 | " x_residual = self.layers(x)\n",
220 | " if self.projection is not None:\n",
221 | " if self._downsample:\n",
222 | " # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer\n",
223 | " x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)\n",
224 | " return x_residual + self.projection(x)\n",
225 | " return x_residual + x"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "id": "1bd71c90-9228-4a9f-9821-20328c57ba4d",
232 | "metadata": {},
233 | "outputs": [],
234 | "source": [
235 | "#export\n",
236 | "class Interpolate(nn.Module):\n",
237 | " def __init__(self, scale_factor: int = 2):\n",
238 | " super().__init__()\n",
239 | " self._interpolate = nn.functional.interpolate\n",
240 | " self._scale_factor = scale_factor\n",
241 | "\n",
242 | " # pylint: disable=arguments-differ\n",
243 | " def forward(self, x):\n",
244 | " return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)\n",
245 | "\n",
246 | "\n",
247 | "class UpsamplingConcat(nn.Module):\n",
248 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n",
249 | " super().__init__()\n",
250 | "\n",
251 | " self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)\n",
252 | "\n",
253 | " self.conv = nn.Sequential(\n",
254 | " nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
255 | " nn.BatchNorm2d(out_channels),\n",
256 | " nn.ReLU(inplace=True),\n",
257 | " nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),\n",
258 | " nn.BatchNorm2d(out_channels),\n",
259 | " nn.ReLU(inplace=True),\n",
260 | " )\n",
261 | "\n",
262 | " def forward(self, x_to_upsample, x):\n",
263 | " x_to_upsample = self.upsample(x_to_upsample)\n",
264 | " x_to_upsample = torch.cat([x, x_to_upsample], dim=1)\n",
265 | " return self.conv(x_to_upsample)\n",
266 | "\n",
267 | "\n",
268 | "class UpsamplingAdd(nn.Module):\n",
269 | " def __init__(self, in_channels, out_channels, scale_factor=2):\n",
270 | " super().__init__()\n",
271 | " self.upsample_layer = nn.Sequential(\n",
272 | " nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),\n",
273 | " nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),\n",
274 | " nn.BatchNorm2d(out_channels),\n",
275 | " )\n",
276 | "\n",
277 | " def forward(self, x, x_skip):\n",
278 | " x = self.upsample_layer(x)\n",
279 | " return x + x_skip"
280 | ]
281 | },
282 | {
283 | "cell_type": "code",
284 | "execution_count": null,
285 | "id": "b46ad863-f4b8-400b-a62e-b1d593ff28ca",
286 | "metadata": {},
287 | "outputs": [],
288 | "source": [
289 | "#export\n",
290 | "class ResBlock(nn.Module):\n",
291 | "    \"A simple ResNet block\"\n",
292 | " def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, norm='bn', activation='relu'):\n",
293 | " super().__init__()\n",
294 | " self.convs = nn.Sequential(ConvBlock(in_channels, out_channels, kernel_size, stride, norm=norm, activation=activation),\n",
295 | " ConvBlock(out_channels, out_channels, norm=norm, activation=activation)\n",
296 | " )\n",
297 | " id_path = [ConvBlock(in_channels, out_channels, kernel_size=1, activation=None, norm=None)]\n",
298 | " self.activation = get_activation(activation)\n",
299 | " if stride!=1: id_path.insert(1, nn.AvgPool2d(2, stride, ceil_mode=True))\n",
300 | " self.id_path = nn.Sequential(*id_path)\n",
301 | " \n",
302 | " def forward(self, x):\n",
303 | " return self.activation(self.convs(x) + self.id_path(x))"
304 | ]
305 | },
306 | {
307 | "cell_type": "code",
308 | "execution_count": null,
309 | "id": "87cd945e-757e-4d81-abee-8ba5fce5e02b",
310 | "metadata": {},
311 | "outputs": [],
312 | "source": [
313 | "res_block = ResBlock(64, 128, stride=2)"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": null,
319 | "id": "cc730d95-89a2-454b-8949-fa45879ccf41",
320 | "metadata": {},
321 | "outputs": [
322 | {
323 | "data": {
324 | "text/plain": [
325 | "torch.Size([1, 128, 10, 10])"
326 | ]
327 | },
328 | "execution_count": null,
329 | "metadata": {},
330 | "output_type": "execute_result"
331 | }
332 | ],
333 | "source": [
334 | "x = torch.rand(1,64, 20,20)\n",
335 | "res_block(x).shape"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": null,
341 | "id": "4568af38-4494-4fbe-aacb-6d3c5fa79e51",
342 | "metadata": {},
343 | "outputs": [],
344 | "source": [
345 | "#export\n",
346 | "class SpatialGRU(nn.Module):\n",
347 | " \"\"\"A GRU cell that takes an input tensor [BxTxCxHxW] and an optional previous state and passes a\n",
348 | " convolutional gated recurrent unit over the data\"\"\"\n",
349 | "\n",
350 | " def __init__(self, input_size, hidden_size, gru_bias_init=0.0, norm='bn', activation='relu'):\n",
351 | " super().__init__()\n",
352 | " self.input_size = input_size\n",
353 | " self.hidden_size = hidden_size\n",
354 | " self.gru_bias_init = gru_bias_init\n",
355 | "\n",
356 | " self.conv_update = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n",
357 | " self.conv_reset = nn.Conv2d(input_size + hidden_size, hidden_size, kernel_size=3, bias=True, padding=1)\n",
358 | "\n",
359 | " self.conv_state_tilde = ConvBlock(\n",
360 | " input_size + hidden_size, hidden_size, kernel_size=3, bias=False, norm=norm, activation=activation\n",
361 | " )\n",
362 | "\n",
363 | " def forward(self, x, state=None, flow=None, mode='bilinear'):\n",
364 | " # pylint: disable=unused-argument, arguments-differ\n",
365 | " # Check size\n",
366 | " assert len(x.size()) == 5, 'Input tensor must be BxTxCxHxW.'\n",
367 | " b, timesteps, c, h, w = x.size()\n",
368 | " assert c == self.input_size, f'feature sizes must match, got input {c} for layer with size {self.input_size}'\n",
369 | "\n",
370 | " # recurrent layers\n",
371 | " rnn_output = []\n",
372 | " rnn_state = torch.zeros(b, self.hidden_size, h, w, device=x.device) if state is None else state\n",
373 | " for t in range(timesteps):\n",
374 | " x_t = x[:, t]\n",
375 | " if flow is not None:\n",
376 | "                rnn_state = warp_features(rnn_state, flow[:, t], mode=mode)  # warp_features is not defined in this package; it must be supplied externally (e.g. from fiery) when flow is used\n",
377 | "\n",
378 | " # propagate rnn state\n",
379 | " rnn_state = self.gru_cell(x_t, rnn_state)\n",
380 | " rnn_output.append(rnn_state)\n",
381 | "\n",
382 | " # reshape rnn output to batch tensor\n",
383 | " return torch.stack(rnn_output, dim=1)\n",
384 | "\n",
385 | " def gru_cell(self, x, state):\n",
386 | " # Compute gates\n",
387 | " x_and_state = torch.cat([x, state], dim=1)\n",
388 | " update_gate = self.conv_update(x_and_state)\n",
389 | " reset_gate = self.conv_reset(x_and_state)\n",
390 | " # Add bias to initialise gate as close to identity function\n",
391 | " update_gate = torch.sigmoid(update_gate + self.gru_bias_init)\n",
392 | " reset_gate = torch.sigmoid(reset_gate + self.gru_bias_init)\n",
393 | "\n",
394 | " # Compute proposal state, activation is defined in norm_act_config (can be tanh, ReLU etc)\n",
395 | " state_tilde = self.conv_state_tilde(torch.cat([x, (1.0 - reset_gate) * state], dim=1))\n",
396 | "\n",
397 | " output = (1.0 - update_gate) * state + update_gate * state_tilde\n",
398 | " return output"
399 | ]
400 | },
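{
"cell_type": "markdown",
"id": "9e3a7c42-4b5d-4e6f-b071-c2d3e4f5a6b7",
"metadata": {},
"source": [
"A sketch of what `gru_cell` computes: `z = sigmoid(conv_update([x, h]))`, `r = sigmoid(conv_reset([x, h]))`, `h_tilde = ConvBlock([x, (1 - r) * h])`, and the new state is `h' = (1 - z) * h + z * h_tilde`. Note the `(1 - r) * h` term, a slight variant of the textbook GRU's `r * h`."
]
},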
401 | {
402 | "cell_type": "code",
403 | "execution_count": null,
404 | "id": "25fdb367-7794-48bc-94aa-1ba421a7c8b0",
405 | "metadata": {},
406 | "outputs": [],
407 | "source": [
408 | "sgru = SpatialGRU(input_size=32, hidden_size=64)"
409 | ]
410 | },
411 | {
412 | "cell_type": "markdown",
413 | "id": "a55cd120-755b-4783-b356-03044cb5abe3",
414 | "metadata": {},
415 | "source": [
416 | "Without a hidden state"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": null,
422 | "id": "8207186f-761b-4744-8e37-48f47101525d",
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "x = torch.rand(1,3,32,8, 8)\n",
427 | "test_eq(sgru(x).shape, (1,3,64,8,8))"
428 | ]
429 | },
430 | {
431 | "cell_type": "markdown",
432 | "id": "bdcaa79f-6d57-4b74-b956-f9dcc3df0589",
433 | "metadata": {},
434 | "source": [
435 | "With a hidden state\n",
436 | "> hidden.shape = `(bs, hidden_size, h, w)`"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "id": "53e47adc-2ce8-4cd9-86b4-433cbd47dbd1",
443 | "metadata": {},
444 | "outputs": [
445 | {
446 | "data": {
447 | "text/plain": [
448 | "torch.Size([1, 3, 64, 8, 8])"
449 | ]
450 | },
451 | "execution_count": null,
452 | "metadata": {},
453 | "output_type": "execute_result"
454 | }
455 | ],
456 | "source": [
457 | "x = torch.rand(1,3,32,8, 8)\n",
458 | "hidden = torch.rand(1,64,8,8)\n",
459 | "\n",
461 | "sgru(x, hidden).shape"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": null,
467 | "id": "4f4a2fe0-535a-4d51-ab22-81bcefac5c7c",
468 | "metadata": {},
469 | "outputs": [],
470 | "source": [
471 | "#export\n",
472 | "class CausalConv3d(nn.Module):\n",
473 | " def __init__(self, in_channels, out_channels, kernel_size=(2, 3, 3), dilation=(1, 1, 1), bias=False):\n",
474 | " super().__init__()\n",
475 | " assert len(kernel_size) == 3, 'kernel_size must be a 3-tuple.'\n",
476 | " time_pad = (kernel_size[0] - 1) * dilation[0]\n",
477 | " height_pad = ((kernel_size[1] - 1) * dilation[1]) // 2\n",
478 | " width_pad = ((kernel_size[2] - 1) * dilation[2]) // 2\n",
479 | "\n",
480 | " # Pad temporally on the left\n",
481 | " self.pad = nn.ConstantPad3d(padding=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), value=0)\n",
482 | " self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, dilation=dilation, stride=1, padding=0, bias=bias)\n",
483 | " self.norm = nn.BatchNorm3d(out_channels)\n",
484 | " self.activation = nn.ReLU(inplace=True)\n",
485 | "\n",
486 | " def forward(self, *inputs):\n",
487 | " (x,) = inputs\n",
488 | " x = self.pad(x)\n",
489 | " x = self.conv(x)\n",
490 | " x = self.norm(x)\n",
491 | " x = self.activation(x)\n",
492 | " return x"
493 | ]
494 | },
495 | {
496 | "cell_type": "code",
497 | "execution_count": null,
498 | "id": "f97c1b07-dcb1-444a-8372-bdd3ff9dd95a",
499 | "metadata": {},
500 | "outputs": [],
501 | "source": [
502 | "cc3d = CausalConv3d(64, 128)"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": null,
508 | "id": "c00cb8c1-7513-40e2-85ef-646ddd20d585",
509 | "metadata": {},
510 | "outputs": [
511 | {
512 | "data": {
513 | "text/plain": [
514 | "torch.Size([1, 128, 4, 8, 8])"
515 | ]
516 | },
517 | "execution_count": null,
518 | "metadata": {},
519 | "output_type": "execute_result"
520 | }
521 | ],
522 | "source": [
523 | "x = torch.rand(1,64,4,8,8)\n",
524 | "cc3d(x).shape"
525 | ]
526 | },
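527 | {
528 | "cell_type": "markdown",
529 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b02",
530 | "metadata": {},
531 | "source": [
532 | "Since padding is applied only on the temporal left, the output at time `t` never sees later frames. A minimal causality check (a sketch; `eval()` matters because in train mode the batch-norm statistics mix all timesteps):"
533 | ]
534 | },
535 | {
536 | "cell_type": "code",
537 | "execution_count": null,
538 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b03",
539 | "metadata": {},
540 | "outputs": [],
541 | "source": [
542 | "x = torch.rand(1, 64, 4, 8, 8)\n",
543 | "x_perturbed = x.clone()\n",
544 | "x_perturbed[:, :, -1] += 1.  # change only the last frame\n",
545 | "\n",
546 | "cc3d.eval()\n",
547 | "# all earlier timesteps are unaffected by the perturbation\n",
548 | "test_eq(cc3d(x)[:, :, :-1], cc3d(x_perturbed)[:, :, :-1])"
549 | ]
550 | },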
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "id": "6ac267c7-a4f3-47dd-a15a-04c61c82d434",
531 | "metadata": {},
532 | "outputs": [],
533 | "source": [
534 | "#export\n",
535 | "def conv_1x1x1_norm_activated(in_channels, out_channels):\n",
536 | " \"\"\"1x1x1 3D convolution, normalization and activation layer.\"\"\"\n",
537 | " return nn.Sequential(\n",
538 | " OrderedDict(\n",
539 | " [\n",
540 | " ('conv', nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)),\n",
541 | " ('norm', nn.BatchNorm3d(out_channels)),\n",
542 | " ('activation', nn.ReLU(inplace=True)),\n",
543 | " ]\n",
544 | " )\n",
545 | " )"
546 | ]
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": null,
551 | "id": "e318b3fd-d944-4d98-8965-88a1dd2d8d8e",
552 | "metadata": {},
553 | "outputs": [
554 | {
555 | "data": {
556 | "text/plain": [
557 | "Sequential(\n",
558 | " (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
559 | " (norm): BatchNorm3d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
560 | " (activation): ReLU(inplace=True)\n",
561 | ")"
562 | ]
563 | },
564 | "execution_count": null,
565 | "metadata": {},
566 | "output_type": "execute_result"
567 | }
568 | ],
569 | "source": [
570 | "conv_1x1x1_norm_activated(2, 4)"
571 | ]
572 | },
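573 | {
574 | "cell_type": "markdown",
575 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b04",
576 | "metadata": {},
577 | "source": [
578 | "A quick shape check: the 1x1x1 kernel only remaps channels, leaving the temporal and spatial dimensions untouched."
579 | ]
580 | },
581 | {
582 | "cell_type": "code",
583 | "execution_count": null,
584 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b05",
585 | "metadata": {},
586 | "outputs": [],
587 | "source": [
588 | "test_eq(conv_1x1x1_norm_activated(2, 4)(torch.rand(1, 2, 3, 5, 5)).shape, (1, 4, 3, 5, 5))"
589 | ]
590 | },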
573 | {
574 | "cell_type": "code",
575 | "execution_count": null,
576 | "id": "5574dfd4-3891-464d-b0d6-a67466fed549",
577 | "metadata": {},
578 | "outputs": [],
579 | "source": [
580 | "#export\n",
581 | "class Bottleneck3D(nn.Module):\n",
582 | " \"\"\"\n",
583 | " Defines a bottleneck module with a residual connection\n",
584 | " \"\"\"\n",
585 | "\n",
586 | " def __init__(self, in_channels, out_channels=None, kernel_size=(2, 3, 3), dilation=(1, 1, 1)):\n",
587 | " super().__init__()\n",
588 | " bottleneck_channels = in_channels // 2\n",
589 | " out_channels = out_channels or in_channels\n",
590 | "\n",
591 | " self.layers = nn.Sequential(\n",
592 | " OrderedDict(\n",
593 | " [\n",
594 | " # First projection with 1x1 kernel\n",
595 | " ('conv_down_project', conv_1x1x1_norm_activated(in_channels, bottleneck_channels)),\n",
596 | " # Second conv block\n",
597 | " (\n",
598 | " 'conv',\n",
599 | " CausalConv3d(\n",
600 | " bottleneck_channels,\n",
601 | " bottleneck_channels,\n",
602 | " kernel_size=kernel_size,\n",
603 | " dilation=dilation,\n",
604 | " bias=False,\n",
605 | " ),\n",
606 | " ),\n",
607 | " # Final projection with 1x1 kernel\n",
608 | " ('conv_up_project', conv_1x1x1_norm_activated(bottleneck_channels, out_channels)),\n",
609 | " ]\n",
610 | " )\n",
611 | " )\n",
612 | "\n",
613 | " if out_channels != in_channels:\n",
614 | " self.projection = nn.Sequential(\n",
615 | " nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),\n",
616 | " nn.BatchNorm3d(out_channels),\n",
617 | " )\n",
618 | " else:\n",
619 | " self.projection = None\n",
620 | "\n",
621 | " def forward(self, *args):\n",
622 | " (x,) = args\n",
623 | " x_residual = self.layers(x)\n",
624 | " x_features = self.projection(x) if self.projection is not None else x\n",
625 | " return x_residual + x_features"
626 | ]
627 | },
628 | {
629 | "cell_type": "code",
630 | "execution_count": null,
631 | "id": "3dbcf0d0-1cc2-4914-9ba0-30b99f03a4c3",
632 | "metadata": {},
633 | "outputs": [],
634 | "source": [
635 | "bn3d = Bottleneck3D(8, 12)"
636 | ]
637 | },
638 | {
639 | "cell_type": "code",
640 | "execution_count": null,
641 | "id": "12917923-c26f-4c40-b2ee-2836d2ffb4f6",
642 | "metadata": {},
643 | "outputs": [
644 | {
645 | "data": {
646 | "text/plain": [
647 | "torch.Size([1, 12, 4, 8, 8])"
648 | ]
649 | },
650 | "execution_count": null,
651 | "metadata": {},
652 | "output_type": "execute_result"
653 | }
654 | ],
655 | "source": [
656 | "x = torch.rand(1,8,4,8,8)\n",
657 | "bn3d(x).shape"
658 | ]
659 | },
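660 | {
661 | "cell_type": "markdown",
662 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b06",
663 | "metadata": {},
664 | "source": [
665 | "When `out_channels` is left as `None`, the block keeps the channel count and the skip connection is a plain identity (no 1x1x1 projection):"
666 | ]
667 | },
668 | {
669 | "cell_type": "code",
670 | "execution_count": null,
671 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b07",
672 | "metadata": {},
673 | "outputs": [],
674 | "source": [
675 | "bn3d_same = Bottleneck3D(8)\n",
676 | "assert bn3d_same.projection is None\n",
677 | "test_eq(bn3d_same(x).shape, x.shape)"
678 | ]
679 | },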
660 | {
661 | "cell_type": "code",
662 | "execution_count": null,
663 | "id": "1bd3b9ac-7faa-4da6-b59b-86792e649590",
664 | "metadata": {},
665 | "outputs": [],
666 | "source": [
667 | "#export\n",
668 | "class PyramidSpatioTemporalPooling(nn.Module):\n",
669 | " \"\"\" Spatio-temporal pyramid pooling.\n",
670 | " Performs 3D average pooling followed by 1x1x1 convolution to reduce the number of channels and upsampling.\n",
671 | " Setting contains a list of kernel_size: usually it is [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)]\n",
672 | " \"\"\"\n",
673 | "\n",
674 | " def __init__(self, in_channels, reduction_channels, pool_sizes):\n",
675 | " super().__init__()\n",
676 | " self.features = []\n",
677 | " for pool_size in pool_sizes:\n",
678 | " assert pool_size[0] == 2, (\n",
679 | " \"Time kernel should be 2 as PyTorch raises an error when\" \"padding with more than half the kernel size\"\n",
680 | " )\n",
681 | " stride = (1, *pool_size[1:])\n",
682 | " padding = (pool_size[0] - 1, 0, 0)\n",
683 | " self.features.append(\n",
684 | " nn.Sequential(\n",
685 | " OrderedDict(\n",
686 | " [\n",
687 | " # Pad the input tensor but do not take into account zero padding into the average.\n",
688 | " (\n",
689 | " 'avgpool',\n",
690 | " torch.nn.AvgPool3d(\n",
691 | " kernel_size=pool_size, stride=stride, padding=padding, count_include_pad=False\n",
692 | " ),\n",
693 | " ),\n",
694 | " ('conv_bn_relu', conv_1x1x1_norm_activated(in_channels, reduction_channels)),\n",
695 | " ]\n",
696 | " )\n",
697 | " )\n",
698 | " )\n",
699 | " self.features = nn.ModuleList(self.features)\n",
700 | "\n",
701 | " def forward(self, *inputs):\n",
702 | " (x,) = inputs\n",
703 | " b, _, t, h, w = x.shape\n",
704 | " # Do not include current tensor when concatenating\n",
705 | " out = []\n",
706 | " for f in self.features:\n",
707 | " # Remove unnecessary padded values (time dimension) on the right\n",
708 | " x_pool = f(x)[:, :, :-1].contiguous()\n",
709 | " c = x_pool.shape[1]\n",
710 | " x_pool = nn.functional.interpolate(\n",
711 | " x_pool.view(b * t, c, *x_pool.shape[-2:]), (h, w), mode='bilinear', align_corners=False\n",
712 | " )\n",
713 | " x_pool = x_pool.view(b, c, t, h, w)\n",
714 | " out.append(x_pool)\n",
715 | " out = torch.cat(out, 1)\n",
716 | " return out"
717 | ]
718 | },
719 | {
720 | "cell_type": "code",
721 | "execution_count": null,
722 | "id": "45902f68-7d7f-4935-ac97-00e332ef9d48",
723 | "metadata": {},
724 | "outputs": [],
725 | "source": [
726 | "h,w = (64,64)\n",
727 | "ptp = PyramidSpatioTemporalPooling(12, 12, [(2, h, w), (2, h//2, w//2), (2, h//4, w//4)])"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": null,
733 | "id": "0e0db71a-dfef-4389-bc9c-483adc060008",
734 | "metadata": {},
735 | "outputs": [
736 | {
737 | "data": {
738 | "text/plain": [
739 | "torch.Size([1, 36, 4, 64, 64])"
740 | ]
741 | },
742 | "execution_count": null,
743 | "metadata": {},
744 | "output_type": "execute_result"
745 | }
746 | ],
747 | "source": [
748 | "x = torch.rand(1,12,4,64,64)\n",
749 | "\n",
750 | "#the output is concatenated...\n",
751 | "ptp(x).shape"
752 | ]
753 | },
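754 | {
755 | "cell_type": "markdown",
756 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b08",
757 | "metadata": {},
758 | "source": [
759 | "Each pyramid level contributes `reduction_channels` channels, so the concatenated output has `len(pool_sizes) * reduction_channels` channels: here $3 \\times 12 = 36$."
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": null,
765 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b09",
766 | "metadata": {},
767 | "outputs": [],
768 | "source": [
769 | "# 3 pyramid levels, 12 reduction channels each\n",
770 | "test_eq(ptp(x).shape[1], 3 * 12)"
771 | ]
772 | },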
754 | {
755 | "cell_type": "code",
756 | "execution_count": null,
757 | "id": "7575bf79-6718-4a70-90ed-985b79fea483",
758 | "metadata": {},
759 | "outputs": [],
760 | "source": [
761 | "#export\n",
762 | "class TemporalBlock(nn.Module):\n",
763 | " \"\"\" Temporal block with the following layers:\n",
764 | " - 2x3x3, 1x3x3, spatio-temporal pyramid pooling\n",
765 | " - dropout\n",
766 | " - skip connection.\n",
767 | " \"\"\"\n",
768 | "\n",
769 | " def __init__(self, in_channels, out_channels=None, use_pyramid_pooling=False, pool_sizes=None):\n",
770 | " super().__init__()\n",
771 | " self.in_channels = in_channels\n",
772 | " self.half_channels = in_channels // 2\n",
773 | " self.out_channels = out_channels or self.in_channels\n",
774 | " self.kernels = [(2, 3, 3), (1, 3, 3)]\n",
775 | "\n",
776 | " # Flag for spatio-temporal pyramid pooling\n",
777 | " self.use_pyramid_pooling = use_pyramid_pooling\n",
778 | "\n",
779 | " # 3 convolution paths: 2x3x3, 1x3x3, 1x1x1\n",
780 | " self.convolution_paths = []\n",
781 | " for kernel_size in self.kernels:\n",
782 | " self.convolution_paths.append(\n",
783 | " nn.Sequential(\n",
784 | " conv_1x1x1_norm_activated(self.in_channels, self.half_channels),\n",
785 | " CausalConv3d(self.half_channels, self.half_channels, kernel_size=kernel_size),\n",
786 | " )\n",
787 | " )\n",
788 | " self.convolution_paths.append(conv_1x1x1_norm_activated(self.in_channels, self.half_channels))\n",
789 | " self.convolution_paths = nn.ModuleList(self.convolution_paths)\n",
790 | "\n",
791 | " agg_in_channels = len(self.convolution_paths) * self.half_channels\n",
792 | "\n",
793 | " if self.use_pyramid_pooling:\n",
794 | " assert pool_sizes is not None, \"setting must contain the list of kernel_size, but is None.\"\n",
795 | " reduction_channels = self.in_channels // 3\n",
796 | " self.pyramid_pooling = PyramidSpatioTemporalPooling(self.in_channels, reduction_channels, pool_sizes)\n",
797 | " agg_in_channels += len(pool_sizes) * reduction_channels\n",
798 | "\n",
799 | " # Feature aggregation\n",
800 | " self.aggregation = nn.Sequential(\n",
801 | " conv_1x1x1_norm_activated(agg_in_channels, self.out_channels),)\n",
802 | "\n",
803 | " if self.out_channels != self.in_channels:\n",
804 | " self.projection = nn.Sequential(\n",
805 | " nn.Conv3d(self.in_channels, self.out_channels, kernel_size=1, bias=False),\n",
806 | " nn.BatchNorm3d(self.out_channels),\n",
807 | " )\n",
808 | " else:\n",
809 | " self.projection = None\n",
810 | "\n",
811 | " def forward(self, *inputs):\n",
812 | " (x,) = inputs\n",
813 | " x_paths = []\n",
814 | " for conv in self.convolution_paths:\n",
815 | " x_paths.append(conv(x))\n",
816 | " x_residual = torch.cat(x_paths, dim=1)\n",
817 | " if self.use_pyramid_pooling:\n",
818 | " x_pool = self.pyramid_pooling(x)\n",
819 | " x_residual = torch.cat([x_residual, x_pool], dim=1)\n",
820 | " x_residual = self.aggregation(x_residual)\n",
821 | "\n",
822 | " if self.out_channels != self.in_channels:\n",
823 | " x = self.projection(x)\n",
824 | " x = x + x_residual\n",
825 | " return x"
826 | ]
827 | },
828 | {
829 | "cell_type": "code",
830 | "execution_count": null,
831 | "id": "385f0475-0c44-4b90-83bf-6869c24e105e",
832 | "metadata": {},
833 | "outputs": [
834 | {
835 | "data": {
836 | "text/plain": [
837 | "torch.Size([1, 8, 4, 64, 64])"
838 | ]
839 | },
840 | "execution_count": null,
841 | "metadata": {},
842 | "output_type": "execute_result"
843 | }
844 | ],
845 | "source": [
846 | "tb = TemporalBlock(4,8)\n",
847 | "\n",
848 | "x = torch.rand(1,4,4,64,64)\n",
849 | "\n",
850 | "tb(x).shape"
851 | ]
852 | },
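853 | {
854 | "cell_type": "markdown",
855 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b10",
856 | "metadata": {},
857 | "source": [
858 | "The same block with spatio-temporal pyramid pooling enabled (a sketch: the `pool_sizes` below just follow the `(2, h, w)` convention described in `PyramidSpatioTemporalPooling`):"
859 | ]
860 | },
861 | {
862 | "cell_type": "code",
863 | "execution_count": null,
864 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b11",
865 | "metadata": {},
866 | "outputs": [],
867 | "source": [
868 | "tb_pp = TemporalBlock(4, 8, use_pyramid_pooling=True, pool_sizes=[(2, 64, 64), (2, 32, 32)])\n",
869 | "test_eq(tb_pp(torch.rand(1, 4, 4, 64, 64)).shape, (1, 8, 4, 64, 64))"
870 | ]
871 | },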
853 | {
854 | "cell_type": "code",
855 | "execution_count": null,
856 | "id": "65e50f4d-40f1-48fc-b390-66d0fe7a4f16",
857 | "metadata": {},
858 | "outputs": [
859 | {
860 | "data": {
861 | "text/plain": [
862 | "TemporalBlock(\n",
863 | " (convolution_paths): ModuleList(\n",
864 | " (0): Sequential(\n",
865 | " (0): Sequential(\n",
866 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
867 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
868 | " (activation): ReLU(inplace=True)\n",
869 | " )\n",
870 | " (1): CausalConv3d(\n",
871 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 1, 0), value=0)\n",
872 | " (conv): Conv3d(2, 2, kernel_size=(2, 3, 3), stride=(1, 1, 1), bias=False)\n",
873 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
874 | " (activation): ReLU(inplace=True)\n",
875 | " )\n",
876 | " )\n",
877 | " (1): Sequential(\n",
878 | " (0): Sequential(\n",
879 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
880 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
881 | " (activation): ReLU(inplace=True)\n",
882 | " )\n",
883 | " (1): CausalConv3d(\n",
884 | " (pad): ConstantPad3d(padding=(1, 1, 1, 1, 0, 0), value=0)\n",
885 | " (conv): Conv3d(2, 2, kernel_size=(1, 3, 3), stride=(1, 1, 1), bias=False)\n",
886 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
887 | " (activation): ReLU(inplace=True)\n",
888 | " )\n",
889 | " )\n",
890 | " (2): Sequential(\n",
891 | " (conv): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
892 | " (norm): BatchNorm3d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
893 | " (activation): ReLU(inplace=True)\n",
894 | " )\n",
895 | " )\n",
896 | " (aggregation): Sequential(\n",
897 | " (0): Sequential(\n",
898 | " (conv): Conv3d(6, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
899 | " (norm): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
900 | " (activation): ReLU(inplace=True)\n",
901 | " )\n",
902 | " )\n",
903 | " (projection): Sequential(\n",
904 | " (0): Conv3d(4, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)\n",
905 | " (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
906 | " )\n",
907 | ")"
908 | ]
909 | },
910 | "execution_count": null,
911 | "metadata": {},
912 | "output_type": "execute_result"
913 | }
914 | ],
915 | "source": [
916 | "tb"
917 | ]
918 | },
919 | {
920 | "cell_type": "code",
921 | "execution_count": null,
922 | "id": "5f671904-4256-48d5-8787-72556d225456",
923 | "metadata": {},
924 | "outputs": [],
925 | "source": [
926 | "#export\n",
927 | "class FuturePrediction(torch.nn.Module):\n",
928 | " def __init__(self, in_channels, latent_dim, n_gru_blocks=3, n_res_layers=3):\n",
929 | " super().__init__()\n",
930 | " self.n_gru_blocks = n_gru_blocks\n",
931 | "\n",
932 | " # Convolutional recurrent model with z_t as an initial hidden state and inputs the sample\n",
933 | " # from the probabilistic model. The architecture of the model is:\n",
934 | " # [Spatial GRU - [Bottleneck] x n_res_layers] x n_gru_blocks\n",
935 | " self.spatial_grus = []\n",
936 | " self.res_blocks = []\n",
937 | "\n",
938 | " for i in range(self.n_gru_blocks):\n",
939 | " gru_in_channels = latent_dim if i == 0 else in_channels\n",
940 | " self.spatial_grus.append(SpatialGRU(gru_in_channels, in_channels))\n",
941 | " self.res_blocks.append(torch.nn.Sequential(*[Bottleneck(in_channels)\n",
942 | " for _ in range(n_res_layers)]))\n",
943 | "\n",
944 | " self.spatial_grus = torch.nn.ModuleList(self.spatial_grus)\n",
945 | " self.res_blocks = torch.nn.ModuleList(self.res_blocks)\n",
946 | "\n",
947 | " def forward(self, x, hidden_state):\n",
948 | " # x has shape (b, n_future, c, h, w), hidden_state (b, c, h, w)\n",
949 | " for i in range(self.n_gru_blocks):\n",
950 | " x = self.spatial_grus[i](x, hidden_state, flow=None)\n",
951 | " b, n_future, c, h, w = x.shape\n",
952 | "\n",
953 | " x = self.res_blocks[i](x.view(b * n_future, c, h, w))\n",
954 | " x = x.view(b, n_future, c, h, w)\n",
955 | "\n",
956 | " return x"
957 | ]
958 | },
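959 | {
960 | "cell_type": "markdown",
961 | "id": "c1f2a3b4-1111-4aaa-9bbb-0d4e5f6a7b12",
962 | "metadata": {},
963 | "source": [
964 | "Note that the same `hidden_state` seeds every `SpatialGRU` block, so its channel count must equal `in_channels` (the GRU hidden size); `latent_dim` only sets the input channels of the first block."
965 | ]
966 | },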
959 | {
960 | "cell_type": "code",
961 | "execution_count": null,
962 | "id": "7ce0e2ce-a862-4631-b5f7-a6b2c9928415",
963 | "metadata": {},
964 | "outputs": [],
965 | "source": [
966 | "fp = FuturePrediction(32, 32, 4, 4)\n",
967 | "\n",
968 | "x = torch.rand(1,4, 32, 64, 64)\n",
969 | "hidden = torch.rand(1,32,64,64)\n",
970 | "test_eq(fp(x, hidden).shape, x.shape)"
971 | ]
972 | },
973 | {
974 | "cell_type": "markdown",
975 | "id": "542aa109-21e9-48b8-a87a-e4bbd5f59b8d",
976 | "metadata": {},
977 | "source": [
978 | "## Export"
979 | ]
980 | },
981 | {
982 | "cell_type": "code",
983 | "execution_count": null,
984 | "id": "d4c520b5-d666-4ebb-a9c0-02992b822b8d",
985 | "metadata": {},
986 | "outputs": [
987 | {
988 | "name": "stdout",
989 | "output_type": "stream",
990 | "text": [
991 | "Converted 00_model.ipynb.\n",
992 | "Converted 01_layers.ipynb.\n",
993 | "Converted index.ipynb.\n"
994 | ]
995 | }
996 | ],
997 | "source": [
998 | "# hide\n",
999 | "from nbdev.export import *\n",
1000 | "notebook2script()"
1001 | ]
1002 | }
1003 | ],
1004 | "metadata": {
1005 | "kernelspec": {
1006 | "display_name": "Python 3",
1007 | "language": "python",
1008 | "name": "python3"
1009 | }
1010 | },
1011 | "nbformat": 4,
1012 | "nbformat_minor": 5
1013 | }
1014 |
--------------------------------------------------------------------------------
/nbs/images/eclipse_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tcapelle/eclipse_pytorch/13b1c41b076a535eb01abf37f777f9fa8309ee48/nbs/images/eclipse_diagram.png
--------------------------------------------------------------------------------
/nbs/index.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Eclipse\n",
8 | "\n",
9 | "> Implementing [Paletta et al](https://arxiv.org/pdf/2104.12419v1.pdf) in Pytorch"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "Most of the codebase comes from [Fiery](https://github.com/wayveai/fiery)"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | ""
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "## Install"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "```bash\n",
38 | "pip install eclipse_pytorch\n",
39 | "```"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "## How to use"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "import torch\n",
56 | "\n",
57 | "from eclipse_pytorch.model import Eclipse"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "ename": "NameError",
67 | "evalue": "name 'TemporalModel' is not defined",
68 | "output_type": "error",
69 | "traceback": [
70 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
71 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
72 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0meclipse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mEclipse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhorizon\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
73 | "\u001b[0;32m~/Documents/eclipse/eclipse_pytorch/model.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, n_in, n_out, horizon, img_size, debug)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0mstore_attr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspatial_downsampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSpatialDownsampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporal_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTemporalModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m256\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_shape\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_size\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimg_size\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m//\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstart_out_channels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuture_prediction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mFuturePrediction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_gru_blocks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_res_layers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupsampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mUpsampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
74 | "\u001b[0;31mNameError\u001b[0m: name 'TemporalModel' is not defined"
75 | ]
76 | }
77 | ],
78 | "source": [
79 | "eclipse = Eclipse(horizon=5)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "let's simulte some input images:"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": null,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "images = [torch.rand(2, 3, 128, 128) for _ in range(4)]"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "preds = eclipse(images)"
105 | ]
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "metadata": {},
110 | "source": [
111 | "you get a dict with forecasted masks and irradiances:"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "len(preds['masks']), preds['masks'][0].shape, preds['irradiances'].shape"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "## Citation"
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "metadata": {},
133 | "source": [
134 | "```latex\n",
135 | "@article{paletta2021eclipse,\n",
136 | " title = {{ECLIPSE} : Envisioning Cloud Induced Perturbations in Solar Energy},\n",
137 | " author = {Quentin Paletta and Anthony Hu and Guillaume Arbod and Joan Lasenby},\n",
138 | " year = {2021},\n",
139 | " eprinttype = {arXiv},\n",
140 | " eprint = {2104.12419}\n",
141 | "}\n",
142 | "```"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "## Contribute\n",
150 | "\n",
151 | "This repo is made with [nbdev](https://github.com/fastai/nbdev), please read the documentation to contribute"
152 | ]
153 | }
154 | ],
155 | "metadata": {
156 | "kernelspec": {
157 | "display_name": "Python 3",
158 | "language": "python",
159 | "name": "python3"
160 | }
161 | },
162 | "nbformat": 4,
163 | "nbformat_minor": 4
164 | }
165 |
--------------------------------------------------------------------------------
/settings.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | host = github
3 | lib_name = eclipse_pytorch
4 | user = tcapelle
5 | description = A PyTorch implementation of ECLIPSE
6 | keywords = pytorch nowcasting solar energy
7 | author = Thomas Capelle
8 | author_email = tcapelle@pm.me
9 | copyright = Thomas Capelle
10 | branch = master
11 | version = 0.0.8
12 | min_python = 3.8
13 | audience = Developers
14 | language = English
15 | custom_sidebar = False
16 | license = apache2
17 | status = 2
18 | requirements = torch>1.6 fastcore
19 | nbs_path = nbs
20 | doc_path = docs
21 | recursive = False
22 | doc_host = https://tcapelle.github.io
23 | doc_baseurl = /eclipse_pytorch/
24 | git_url = https://github.com/tcapelle/eclipse_pytorch/tree/master/
25 | lib_path = eclipse_pytorch
26 | title = eclipse_pytorch
27 |
28 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pkg_resources import parse_version
2 | from configparser import ConfigParser
3 | import setuptools,re,sys
4 | assert parse_version(setuptools.__version__)>=parse_version('36.2')
5 |
6 | # note: all settings are in settings.ini; edit there, not here
7 | config = ConfigParser(delimiters=['='])
8 | config.read('settings.ini')
9 | cfg = config['DEFAULT']
10 |
11 | cfg_keys = 'version description keywords author author_email'.split()
12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o)
14 | setup_cfg = {o:cfg[o] for o in cfg_keys}
15 |
16 | if len(sys.argv)>1 and sys.argv[1]=='version':
17 | print(setup_cfg['version'])
18 | exit()
19 |
20 | licenses = {
21 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'),
22 | 'mit': ('MIT License', 'OSI Approved :: MIT License'),
23 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'),
24 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'),
25 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'),
26 | }
27 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
28 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ]
29 | py_versions = '2.0 2.1 2.2 2.3 2.4 2.5 2.6 2.7 3.0 3.1 3.2 3.3 3.4 3.5 3.6 3.7 3.8'.split()
30 |
31 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None))
32 | min_python = cfg['min_python']
33 |
34 | requirements = ['pip', 'packaging']
35 | if cfg.get('requirements'): requirements += cfg.get('requirements','').split()
36 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
37 | dev_requirements = (cfg.get('dev_requirements') or '').split()
38 |
39 | long_description = open('README.md').read()
40 | # 
41 | for ext in ['png', 'svg']:
42 | long_description = re.sub(r'!\['+ext+'\]\((.*)\)', '+'/'+cfg['branch']+'/\\1)', long_description)
43 | long_description = re.sub(r'src=\"(.*)\.'+ext+'\"', 'src=\"https://raw.githubusercontent.com/{}/{}'.format(cfg['user'],cfg['lib_name'])+'/'+cfg['branch']+'/\\1.'+ext+'\"', long_description)
44 |
45 | setuptools.setup(
46 | name = cfg['lib_name'],
47 | license = lic[0],
48 | classifiers = [
49 | 'Development Status :: ' + statuses[int(cfg['status'])],
50 | 'Intended Audience :: ' + cfg['audience'].title(),
51 | 'Natural Language :: ' + cfg['language'].title(),
52 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []),
53 | url = cfg['git_url'],
54 | packages = setuptools.find_packages(),
55 | include_package_data = True,
56 | install_requires = requirements,
57 | extras_require={ 'dev': dev_requirements },
58 | python_requires = '>=' + cfg['min_python'],
59 | long_description = long_description,
60 | long_description_content_type = 'text/markdown',
61 | zip_safe = False,
62 | entry_points = { 'console_scripts': cfg.get('console_scripts','').split() },
63 | **setup_cfg)
64 |
65 |
--------------------------------------------------------------------------------