├── .github └── workflows │ └── publish-to-pypi.yml ├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── opensr_model ├── .editorconfig ├── .github │ ├── .stale.yml │ ├── ISSUE_TEMPLATE │ │ ├── bug_report.md │ │ ├── config.yml │ │ ├── feature_request.md │ │ └── question.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── dependabot.yml │ └── workflows │ │ ├── build.yml │ │ ├── greetings.yml │ │ └── release-drafter.yml ├── .gitignore ├── .pre-commit-config.yaml ├── __init__.py ├── autoencoder │ ├── __init__.py │ ├── autoencoder.py │ └── utils.py ├── configs │ └── config_10m.yaml ├── denoiser │ ├── __init__.py │ ├── unet.py │ └── utils.py ├── diffusion │ ├── __init__.py │ ├── latentdiffusion.py │ └── utils.py ├── srmodel.py └── utils.py ├── opensr_utils_demo.py ├── requirements.txt ├── resources ├── example.png ├── example2.png ├── example3.png ├── sr_example.png └── uncertainty_map.png └── setup.py /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | build-and-publish: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine build 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: __token__ 23 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 24 | run: | 25 | python -m build 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .vscode/* 3 | *.swp 4 | other_uncertainty_2.py 5 | sr_of_s2_tile.py 6 | creating_usecase_data/ 7 | other_stuff/ 8 | austria.py 9 | upload_to_drive.py 10 | client_secret.json 11 | *.DS_Store 12 | example_images/ 13 | client_secret.json 14 | 15 | *.svg 16 | old/ 17 | old/* 18 | *.safetensors 19 | *.ckpt 20 | *.pth 21 | *.pt 22 | *.ipynb 23 | *.JSON 24 | *.json 25 | ablation.py 26 | images/ 27 | images/* 28 | images_comparison 29 | images_comparison/* 30 | comp_s2_s2naip.py 31 | 32 | # Byte-compiled / optimized / DLL files 33 | __pycache__/ 34 | *.py[cod] 35 | *$py.class 36 | 37 | # C extensions 38 | *.so 39 | 40 | # Distribution / packaging 41 | .Python 42 | build/ 43 | develop-eggs/ 44 | dist/ 45 | downloads/ 46 | eggs/ 47 | .eggs/ 48 | lib/ 49 | lib64/ 50 | parts/ 51 | sdist/ 52 | var/ 53 | wheels/ 54 | share/python-wheels/ 55 | *.egg-info/ 56 | .installed.cfg 57 | *.egg 58 | MANIFEST 59 | 60 | # PyInstaller 61 | # Usually these files are written by a python script from a template 62 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
63 | *.manifest 64 | *.spec 65 | 66 | # Installer logs 67 | pip-log.txt 68 | pip-delete-this-directory.txt 69 | 70 | # Unit test / coverage reports 71 | htmlcov/ 72 | .tox/ 73 | .nox/ 74 | .coverage 75 | .coverage.* 76 | .cache 77 | nosetests.xml 78 | coverage.xml 79 | *.cover 80 | *.py,cover 81 | .hypothesis/ 82 | .pytest_cache/ 83 | cover/ 84 | 85 | # Translations 86 | *.mo 87 | *.pot 88 | 89 | # Django stuff: 90 | *.log 91 | local_settings.py 92 | db.sqlite3 93 | db.sqlite3-journal 94 | 95 | # Flask stuff: 96 | instance/ 97 | .webassets-cache 98 | 99 | # Scrapy stuff: 100 | .scrapy 101 | 102 | # Sphinx documentation 103 | docs/_build/ 104 | 105 | # PyBuilder 106 | .pybuilder/ 107 | target/ 108 | 109 | # Jupyter Notebook 110 | .ipynb_checkpoints 111 | 112 | # IPython 113 | profile_default/ 114 | ipython_config.py 115 | 116 | # pyenv 117 | # For a library or package, you might want to ignore these files since the code is 118 | # intended to run in multiple environments; otherwise, check them in: 119 | # .python-version 120 | 121 | # pipenv 122 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 123 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 124 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 125 | # install all needed dependencies. 126 | #Pipfile.lock 127 | 128 | # poetry 129 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 130 | # This is especially recommended for binary packages to ensure reproducibility, and is more 131 | # commonly ignored for libraries. 132 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 133 | #poetry.lock 134 | 135 | # pdm 136 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 137 | #pdm.lock 138 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 139 | # in version control. 140 | # https://pdm.fming.dev/#use-with-ide 141 | .pdm.toml 142 | 143 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 144 | __pypackages__/ 145 | 146 | # Celery stuff 147 | celerybeat-schedule 148 | celerybeat.pid 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # Environments 154 | .env 155 | .venv 156 | env/ 157 | venv/ 158 | ENV/ 159 | env.bak/ 160 | venv.bak/ 161 | 162 | # Spyder project settings 163 | .spyderproject 164 | .spyproject 165 | 166 | # Rope project settings 167 | .ropeproject 168 | 169 | # mkdocs documentation 170 | /site 171 | 172 | # mypy 173 | .mypy_cache/ 174 | .dmypy.json 175 | dmypy.json 176 | 177 | # Pyre type checker 178 | .pyre/ 179 | 180 | # pytype static type analyzer 181 | .pytype/ 182 | 183 | # Cython debug symbols 184 | cython_debug/ 185 | 186 | # PyCharm 187 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 188 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 189 | # and can be added to the global gitignore or merged into this file. For a more nuclear 190 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
191 | #.idea/
192 | 
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | (a): License of Base Repository
2 | (b): opensr-model/opensr_model/diffusion/latentdiffusion.py: Adaptation from CompVis, LMU Munich
3 | 
4 | 
5 | 
6 | ------------------ (a) base repository ----------------------------------------
7 | 
8 | The MIT License (MIT)
9 | Copyright (c) 2023 opensr-model
10 | 
11 | Permission is hereby granted, free of charge, to any person obtaining a copy
12 | of this software and associated documentation files (the "Software"), to deal
13 | in the Software without restriction, including without limitation the rights
14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | copies of the Software, and to permit persons to whom the Software is
16 | furnished to do so, subject to the following conditions:
17 | 
18 | The above copyright notice and this permission notice shall be included in all
19 | copies or substantial portions of the Software.
20 | 
21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
22 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
24 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
25 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
26 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
27 | OR OTHER DEALINGS IN THE SOFTWARE.
28 | 
29 | ------------------ (b) LDM CompVis ----------------------------------------
30 | 
31 | MIT License
32 | 
33 | Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
34 | 
35 | Permission is hereby granted, free of charge, to any person obtaining a copy
36 | of this software and associated documentation files (the "Software"), to deal
37 | in the Software without restriction, including without limitation the rights
38 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
39 | copies of the Software, and to permit persons to whom the Software is
40 | furnished to do so, subject to the following conditions:
41 | 
42 | The above copyright notice and this permission notice shall be included in all
43 | copies or substantial portions of the Software.
44 | 
45 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
46 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
47 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
48 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
49 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
50 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
51 | SOFTWARE.
52 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | 
2 | 
3 | # Latent Diffusion Super-Resolution - Sentinel 2 (LDSR-S2)
4 | This repository contains the code of the paper [Trustworthy Super-Resolution of Multispectral Sentinel-2 Imagery with Latent Diffusion](https://ieeexplore.ieee.org/abstract/document/10887321).
5 | 
6 | **PLEASE NOTE**:
7 | - This model is currently research-grade code; more user-friendly adaptations are planned for the future.
8 | - This repository will leave the experimental stage with v1.0.0.
9 | 
10 | ## Citation
11 | If you use this model in your work, please cite
12 | ```tex
13 | @ARTICLE{10887321,
14 | author={Donike, Simon and Aybar, Cesar and Gómez-Chova, Luis and Kalaitzis, Freddie},
15 | journal={IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing},
16 | title={Trustworthy Super-Resolution of Multispectral Sentinel-2 Imagery With Latent Diffusion},
17 | year={2025},
18 | volume={18},
19 | number={},
20 | pages={6940-6952},
21 | doi={10.1109/JSTARS.2025.3542220}}
22 | ```
23 | 
24 | ## Install and Usage
25 | ```bash
26 | pip install opensr-model
27 | ```
28 | 
29 | Minimal Example
30 | ```python
31 | import torch, opensr_model # import packages
32 | model = opensr_model.SRLatentDiffusion(config, device=device) # create model (config and device as set up in demo.py)
33 | model.load_pretrained(config.ckpt_version) # load checkpoint
34 | lr = torch.rand(1,4,128,128) # create data
35 | sr = model.forward(lr, custom_steps=100) # run SR
36 | ```
37 | 
38 | Run the 'demo.py' file to gain an understanding of how the package works. It will super-resolve an example tensor and save the corresponding uncertainty map.
39 | Output of the demo.py file:
40 | ![example](resources/sr_example.png)
41 | ![example](resources/uncertainty_map.png)
42 | 
43 | ## Weights and Checkpoints
44 | The model should load automatically with the model.load_pretrained command. Alternatively, the checkpoints can be found on [HuggingFace](https://huggingface.co/simon-donike/RS-SR-LTDF/tree/main).
45 | 
46 | ## Description
47 | This package contains the latent-diffusion model to super-resolve the 10 m and 20 m bands of Sentinel-2. This repository contains the bare model; it can be embedded in the "opensr-utils" package in order to be applied to Sentinel-2 imagery.
48 | 
49 | ## S2 Examples
50 | Example on a real S2 image
51 | ![example2](resources/example2.png)
52 | 
53 | Examples on the S2NAIP training dataset
54 | ![example](resources/example.png)
55 | 
56 | 
57 | ## Status
58 | This is a work in progress, published explicitly as a research preview. This repository will leave the experimental stage with the publication of v1.0.0.
59 | 
60 | ## Results Preview
61 | Some example SR scenes can be found as [super-resolved TIFFs](https://drive.google.com/drive/folders/1OBgYS6c8Kpe_JuGzWOQwOK6UYwhm-3Vh?usp=drive_link) on Google Drive. Scenes available:
62 | - Buenos Aires, Argentina
63 | - Blue Mountains, Australia
64 | - Louisville, USA
65 | - Kütahya, Türkiye
66 | - Catalunya, Spain
67 | 
68 | [![PyPI Downloads](https://static.pepy.tech/badge/opensr-model)](https://pepy.tech/projects/opensr-model)
69 | 
-------------------------------------------------------------------------------- /demo.py: --------------------------------------------------------------------------------
1 | import torch
2 | import opensr_model
3 | from omegaconf import OmegaConf
4 | 
5 | # -------------------------------------------------------------------------------------
6 | # 0.0 Create Model
7 | device = "cuda" if torch.cuda.is_available() else "cpu" # try to run on GPU
8 | config = OmegaConf.load("opensr_model/configs/config_10m.yaml") # load config
9 | model = opensr_model.SRLatentDiffusion(config, device=device) # create model
10 | model.load_pretrained(config.ckpt_version) # load checkpoint
11 | assert model.training == False, "Model has to be in eval mode."
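# Illustrative alternative (assumption, not from the original script): as in the README's
# minimal example, a random reflectance-like tensor of the expected shape can stand in for
# the downloaded example below, e.g.
# lr = torch.rand(1, 4, 128, 128).to(device)  # B x 4 bands x 128 x 128, values assumed in [0, 1]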
12 | 
13 | # -------------------------------------------------------------------------------------
14 | # 0.1 - Get Data Example
15 | from opensr_model.utils import download_from_HF
16 | lr = download_from_HF(file_name="example_lr.pt")
17 | lr = (lr/10_000).to(torch.float32).to(device)
18 | 
19 | # -------------------------------------------------------------------------------------
20 | # 1. Run Super-Resolution
21 | sr = model.forward(lr, custom_steps=100)
22 | 
23 | # -------------------------------------------------------------------------------------
24 | # 2. Run Uncertainty Map Generation
25 | uncertainty_map = model.uncertainty_map(lr,n_variations=25,custom_steps=100) # create uncertainty map
26 | 
27 | # -------------------------------------------------------------------------------------
28 | # 3. Plot Examples
29 | from opensr_model.utils import plot_example,plot_uncertainty
30 | plot_example(lr,sr,out_file="sr_example.png")
31 | plot_uncertainty(uncertainty_map,out_file="uncertainty_map.png",normalize=True)
32 | 
33 | 
-------------------------------------------------------------------------------- /opensr_model/.editorconfig: --------------------------------------------------------------------------------
1 | # Check http://editorconfig.org for more information
2 | # This is the main config file for this project:
3 | root = true
4 | 
5 | [*]
6 | charset = utf-8
7 | end_of_line = lf
8 | insert_final_newline = true
9 | indent_style = tab
10 | indent_size = 2
11 | trim_trailing_whitespace = true
12 | 
13 | [*.{py, pyi}]
14 | indent_style = space
15 | indent_size = 4
16 | 
17 | [Makefile]
18 | indent_style = tab
19 | 
20 | [*.md]
21 | trim_trailing_whitespace = false
22 | 
23 | [*.{diff,patch}]
24 | trim_trailing_whitespace = false
25 | 
-------------------------------------------------------------------------------- /opensr_model/.github/.stale.yml: --------------------------------------------------------------------------------
1 | # Number of days of inactivity before an issue becomes stale
2 | daysUntilStale: 60
3 | # Number of days of inactivity before a stale issue is closed
4 | daysUntilClose: 7
5 | # Issues with these labels will never be considered stale
6 | exemptLabels:
7 |   - pinned
8 |   - security
9 | # Label to use when marking an issue as stale
10 | staleLabel: wontfix
11 | # Comment to post when marking an issue as stale. Set to `false` to disable
12 | markComment: >
13 |   This issue has been automatically marked as stale because it has not had
14 |   recent activity. It will be closed if no further activity occurs. Thank you
15 |   for your contributions.
16 | # Comment to post when closing a stale issue. Set to `false` to disable
17 | closeComment: false
18 | 
-------------------------------------------------------------------------------- /opensr_model/.github/ISSUE_TEMPLATE/bug_report.md: --------------------------------------------------------------------------------
1 | ---
2 | name: 🐛 Bug report
3 | about: If something isn't working 🔧
4 | title: ''
5 | labels: bug
6 | assignees:
7 | ---
8 | 
9 | ## 🐛 Bug Report
10 | 
11 | 
12 | 
13 | ## 🔬 How To Reproduce
14 | 
15 | Steps to reproduce the behavior:
16 | 
17 | 1. ...
18 | 
19 | ### Code sample
20 | 
21 | 
22 | 
23 | ### Environment
24 | 
25 | * OS: [e.g.
Linux / Windows / macOS] 26 | * Python version, get it with: 27 | 28 | ```bash 29 | python --version 30 | ``` 31 | 32 | ### Screenshots 33 | 34 | 35 | 36 | ## 📈 Expected behavior 37 | 38 | 39 | 40 | ## 📎 Additional context 41 | 42 | 43 | -------------------------------------------------------------------------------- /opensr_model/.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | # Configuration: https://help.github.com/en/github/building-a-strong-community/configuring-issue-templates-for-your-repository 2 | 3 | blank_issues_enabled: false 4 | -------------------------------------------------------------------------------- /opensr_model/.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature request 3 | about: Suggest an idea for this project 🏖 4 | title: '' 5 | labels: enhancement 6 | assignees: 7 | --- 8 | 9 | ## 🚀 Feature Request 10 | 11 | 12 | 13 | ## 🔈 Motivation 14 | 15 | 16 | 17 | ## 🛰 Alternatives 18 | 19 | 20 | 21 | ## 📎 Additional context 22 | 23 | 24 | -------------------------------------------------------------------------------- /opensr_model/.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ❓ Question 3 | about: Ask a question about this project 🎓 4 | title: '' 5 | labels: question 6 | assignees: 7 | --- 8 | 9 | ## Checklist 10 | 11 | 12 | 13 | - [ ] I've searched the project's [`issues`](https://github.com/ESAOpenSR/opensr-model/issues?q=is%3Aissue). 14 | 15 | ## ❓ Question 16 | 17 | 18 | 19 | How can I [...]? 20 | 21 | Is it possible to [...]? 22 | 23 | ## 📎 Additional context 24 | 25 | 26 | -------------------------------------------------------------------------------- /opensr_model/.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | 5 | ## Related Issue 6 | 7 | 8 | 9 | ## Type of Change 10 | 11 | 12 | 13 | - [ ] 📚 Examples / docs / tutorials / dependencies update 14 | - [ ] 🔧 Bug fix (non-breaking change which fixes an issue) 15 | - [ ] 🥂 Improvement (non-breaking change which improves an existing feature) 16 | - [ ] 🚀 New feature (non-breaking change which adds functionality) 17 | - [ ] 💥 Breaking change (fix or feature that would cause existing functionality to change) 18 | - [ ] 🔐 Security fix 19 | 20 | ## Checklist 21 | 22 | 23 | 24 | - [ ] I've read the [`CODE_OF_CONDUCT.md`](https://github.com/ESAOpenSR/opensr-model/blob/master/CODE_OF_CONDUCT.md) document. 25 | - [ ] I've read the [`CONTRIBUTING.md`](https://github.com/ESAOpenSR/opensr-model/blob/master/CONTRIBUTING.md) guide. 26 | - [ ] I've updated the code style using `make codestyle`. 27 | - [ ] I've written tests for all new methods and classes that I created. 28 | - [ ] I've written the docstring in Google format for all the methods and classes that I used. 
29 | -------------------------------------------------------------------------------- /opensr_model/.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Configuration: https://dependabot.com/docs/config-file/ 2 | # Docs: https://docs.github.com/en/github/administering-a-repository/keeping-your-dependencies-updated-automatically 3 | 4 | version: 2 5 | 6 | updates: 7 | - package-ecosystem: "pip" 8 | directory: "/" 9 | schedule: 10 | interval: "daily" 11 | allow: 12 | - dependency-type: "all" 13 | commit-message: 14 | prefix: ":arrow_up:" 15 | open-pull-requests-limit: 50 16 | 17 | - package-ecosystem: "github-actions" 18 | directory: "/" 19 | schedule: 20 | interval: "daily" 21 | allow: 22 | - dependency-type: "all" 23 | commit-message: 24 | prefix: ":arrow_up:" 25 | open-pull-requests-limit: 50 26 | 27 | - package-ecosystem: "docker" 28 | directory: "/docker" 29 | schedule: 30 | interval: "weekly" 31 | allow: 32 | - dependency-type: "all" 33 | commit-message: 34 | prefix: ":arrow_up:" 35 | open-pull-requests-limit: 50 36 | -------------------------------------------------------------------------------- /opensr_model/.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.7", "3.9"] 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2.2.2 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | 19 | - name: Install poetry 20 | run: make poetry-download 21 | 22 | - name: Set up cache 23 | uses: actions/cache@v2.1.6 24 | with: 25 | path: .venv 26 | key: venv-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }} 27 | - name: Install dependencies 28 | run: | 29 | poetry config virtualenvs.in-project true 30 | poetry install 31 | 32 | - name: Run style checks 33 | run: | 34 | make check-codestyle 35 | 36 | - name: Run tests 37 | run: | 38 | make test 39 | 40 | - name: Run safety checks 41 | run: | 42 | make check-safety 43 | -------------------------------------------------------------------------------- /opensr_model/.github/workflows/greetings.yml: -------------------------------------------------------------------------------- 1 | name: Greetings 2 | 3 | on: [pull_request, issues] 4 | 5 | jobs: 6 | greeting: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/first-interaction@v1 10 | with: 11 | repo-token: ${{ secrets.GITHUB_TOKEN }} 12 | pr-message: 'Hello @${{ github.actor }}, thank you for submitting a PR! We will respond as soon as possible.' 13 | issue-message: | 14 | Hello @${{ github.actor }}, thank you for your interest in our work! 15 | 16 | If this is a bug report, please provide screenshots and **minimum viable code to reproduce your issue**, otherwise we can not help you. 
17 | -------------------------------------------------------------------------------- /opensr_model/.github/workflows/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name: Release Drafter 2 | 3 | on: 4 | push: 5 | # branches to consider in the event; optional, defaults to all 6 | branches: 7 | - master 8 | 9 | jobs: 10 | update_release_draft: 11 | runs-on: ubuntu-latest 12 | steps: 13 | # Drafts your next Release notes as Pull Requests are merged into "master" 14 | - uses: release-drafter/release-drafter@v5.15.0 15 | env: 16 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 17 | -------------------------------------------------------------------------------- /opensr_model/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/osx,python,pycharm,windows,visualstudio,visualstudiocode 3 | # Edit at https://www.gitignore.io/?templates=osx,python,pycharm,windows,visualstudio,visualstudiocode 4 | 5 | ### OSX ### 6 | # General 7 | .DS_Store 8 | .AppleDouble 9 | .LSOverride 10 | 11 | # Icon must end with two \r 12 | Icon 13 | 14 | # Thumbnails 15 | ._* 16 | 17 | # Files that might appear in the root of a volume 18 | .DocumentRevisions-V100 19 | .fseventsd 20 | .Spotlight-V100 21 | .TemporaryItems 22 | .Trashes 23 | .VolumeIcon.icns 24 | .com.apple.timemachine.donotpresent 25 | 26 | # Directories potentially created on remote AFP share 27 | .AppleDB 28 | .AppleDesktop 29 | Network Trash Folder 30 | Temporary Items 31 | .apdisk 32 | 33 | ### PyCharm ### 34 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 35 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 36 | 37 | # User-specific stuff 38 | .idea/**/workspace.xml 39 | .idea/**/tasks.xml 40 | .idea/**/usage.statistics.xml 41 | .idea/**/dictionaries 42 | .idea/**/shelf 43 | 44 | # Generated files 45 | .idea/**/contentModel.xml 46 | 47 | # Sensitive or high-churn files 48 | .idea/**/dataSources/ 49 | .idea/**/dataSources.ids 50 | .idea/**/dataSources.local.xml 51 | .idea/**/sqlDataSources.xml 52 | .idea/**/dynamic.xml 53 | .idea/**/uiDesigner.xml 54 | .idea/**/dbnavigator.xml 55 | 56 | # Gradle 57 | .idea/**/gradle.xml 58 | .idea/**/libraries 59 | 60 | # Gradle and Maven with auto-import 61 | # When using Gradle or Maven with auto-import, you should exclude module files, 62 | # since they will be recreated, and may cause churn. Uncomment if using 63 | # auto-import. 
64 | # .idea/modules.xml 65 | # .idea/*.iml 66 | # .idea/modules 67 | # *.iml 68 | # *.ipr 69 | 70 | # CMake 71 | cmake-build-*/ 72 | 73 | # Mongo Explorer plugin 74 | .idea/**/mongoSettings.xml 75 | 76 | # File-based project format 77 | *.iws 78 | 79 | # IntelliJ 80 | out/ 81 | 82 | # mpeltonen/sbt-idea plugin 83 | .idea_modules/ 84 | 85 | # JIRA plugin 86 | atlassian-ide-plugin.xml 87 | 88 | # Cursive Clojure plugin 89 | .idea/replstate.xml 90 | 91 | # Crashlytics plugin (for Android Studio and IntelliJ) 92 | com_crashlytics_export_strings.xml 93 | crashlytics.properties 94 | crashlytics-build.properties 95 | fabric.properties 96 | 97 | # Editor-based Rest Client 98 | .idea/httpRequests 99 | 100 | # Android studio 3.1+ serialized cache file 101 | .idea/caches/build_file_checksums.ser 102 | 103 | ### PyCharm Patch ### 104 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 105 | 106 | # *.iml 107 | # modules.xml 108 | # .idea/misc.xml 109 | # *.ipr 110 | 111 | # Sonarlint plugin 112 | .idea/**/sonarlint/ 113 | 114 | # SonarQube Plugin 115 | .idea/**/sonarIssues.xml 116 | 117 | # Markdown Navigator plugin 118 | .idea/**/markdown-navigator.xml 119 | .idea/**/markdown-navigator/ 120 | 121 | ### Python ### 122 | # Byte-compiled / optimized / DLL files 123 | __pycache__/ 124 | *.py[cod] 125 | *$py.class 126 | 127 | # C extensions 128 | *.so 129 | 130 | # Distribution / packaging 131 | .Python 132 | build/ 133 | develop-eggs/ 134 | dist/ 135 | downloads/ 136 | eggs/ 137 | .eggs/ 138 | lib/ 139 | lib64/ 140 | parts/ 141 | sdist/ 142 | var/ 143 | wheels/ 144 | pip-wheel-metadata/ 145 | share/python-wheels/ 146 | *.egg-info/ 147 | .installed.cfg 148 | *.egg 149 | MANIFEST 150 | 151 | # PyInstaller 152 | # Usually these files are written by a python script from a template 153 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 154 | *.manifest 155 | *.spec 156 | 157 | # Installer logs 158 | pip-log.txt 159 | pip-delete-this-directory.txt 160 | 161 | # Unit test / coverage reports 162 | htmlcov/ 163 | .tox/ 164 | .nox/ 165 | .coverage 166 | .coverage.* 167 | .cache 168 | nosetests.xml 169 | coverage.xml 170 | *.cover 171 | .hypothesis/ 172 | .pytest_cache/ 173 | 174 | # Translations 175 | *.mo 176 | *.pot 177 | 178 | # Scrapy stuff: 179 | .scrapy 180 | 181 | # Sphinx documentation 182 | docs/_build/ 183 | 184 | # PyBuilder 185 | target/ 186 | 187 | # pyenv 188 | .python-version 189 | 190 | # poetry 191 | .venv 192 | 193 | # pipenv 194 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 195 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 196 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 197 | # install all needed dependencies. 
198 | #Pipfile.lock 199 | 200 | # celery beat schedule file 201 | celerybeat-schedule 202 | 203 | # SageMath parsed files 204 | *.sage.py 205 | 206 | # Spyder project settings 207 | .spyderproject 208 | .spyproject 209 | 210 | # Rope project settings 211 | .ropeproject 212 | 213 | # Mr Developer 214 | .mr.developer.cfg 215 | .project 216 | .pydevproject 217 | 218 | # mkdocs documentation 219 | /site 220 | 221 | # mypy 222 | .mypy_cache/ 223 | .dmypy.json 224 | dmypy.json 225 | 226 | # Pyre type checker 227 | .pyre/ 228 | 229 | # Plugins 230 | .secrets.baseline 231 | 232 | ### VisualStudioCode ### 233 | .vscode/* 234 | !.vscode/tasks.json 235 | !.vscode/launch.json 236 | !.vscode/extensions.json 237 | 238 | ### VisualStudioCode Patch ### 239 | # Ignore all local history of files 240 | .history 241 | 242 | ### Windows ### 243 | # Windows thumbnail cache files 244 | Thumbs.db 245 | Thumbs.db:encryptable 246 | ehthumbs.db 247 | ehthumbs_vista.db 248 | 249 | # Dump file 250 | *.stackdump 251 | 252 | # Folder config file 253 | [Dd]esktop.ini 254 | 255 | # Recycle Bin used on file shares 256 | $RECYCLE.BIN/ 257 | 258 | # Windows Installer files 259 | *.cab 260 | *.msi 261 | *.msix 262 | *.msm 263 | *.msp 264 | 265 | # Windows shortcuts 266 | *.lnk 267 | 268 | ### VisualStudio ### 269 | ## Ignore Visual Studio temporary files, build results, and 270 | ## files generated by popular Visual Studio add-ons. 271 | ## 272 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 273 | 274 | # User-specific files 275 | *.rsuser 276 | *.suo 277 | *.user 278 | *.userosscache 279 | *.sln.docstates 280 | 281 | # User-specific files (MonoDevelop/Xamarin Studio) 282 | *.userprefs 283 | 284 | # Mono auto generated files 285 | mono_crash.* 286 | 287 | # Build results 288 | [Dd]ebug/ 289 | [Dd]ebugPublic/ 290 | [Rr]elease/ 291 | [Rr]eleases/ 292 | x64/ 293 | x86/ 294 | [Aa][Rr][Mm]/ 295 | [Aa][Rr][Mm]64/ 296 | bld/ 297 | [Bb]in/ 298 | [Oo]bj/ 299 | [Ll]og/ 300 | 301 | # Visual Studio 2015/2017 cache/options directory 302 | .vs/ 303 | # Uncomment if you have tasks that create the project's static files in wwwroot 304 | #wwwroot/ 305 | 306 | # Visual Studio 2017 auto generated files 307 | Generated\ Files/ 308 | 309 | # MSTest test Results 310 | [Tt]est[Rr]esult*/ 311 | [Bb]uild[Ll]og.* 312 | 313 | # NUnit 314 | *.VisualState.xml 315 | TestResult.xml 316 | nunit-*.xml 317 | 318 | # Build Results of an ATL Project 319 | [Dd]ebugPS/ 320 | [Rr]eleasePS/ 321 | dlldata.c 322 | 323 | # Benchmark Results 324 | BenchmarkDotNet.Artifacts/ 325 | 326 | # .NET Core 327 | project.lock.json 328 | project.fragment.lock.json 329 | artifacts/ 330 | 331 | # StyleCop 332 | StyleCopReport.xml 333 | 334 | # Files built by Visual Studio 335 | *_i.c 336 | *_p.c 337 | *_h.h 338 | *.ilk 339 | *.obj 340 | *.iobj 341 | *.pch 342 | *.pdb 343 | *.ipdb 344 | *.pgc 345 | *.pgd 346 | *.rsp 347 | *.sbr 348 | *.tlb 349 | *.tli 350 | *.tlh 351 | *.tmp 352 | *.tmp_proj 353 | *_wpftmp.csproj 354 | *.log 355 | *.vspscc 356 | *.vssscc 357 | .builds 358 | *.pidb 359 | *.svclog 360 | *.scc 361 | 362 | # Chutzpah Test files 363 | _Chutzpah* 364 | 365 | # Visual C++ cache files 366 | ipch/ 367 | *.aps 368 | *.ncb 369 | *.opendb 370 | *.opensdf 371 | *.sdf 372 | *.cachefile 373 | *.VC.db 374 | *.VC.VC.opendb 375 | 376 | # Visual Studio profiler 377 | *.psess 378 | *.vsp 379 | *.vspx 380 | *.sap 381 | 382 | # Visual Studio Trace Files 383 | *.e2e 384 | 385 | # TFS 2012 Local Workspace 386 | $tf/ 387 | 388 | # Guidance 
Automation Toolkit 389 | *.gpState 390 | 391 | # ReSharper is a .NET coding add-in 392 | _ReSharper*/ 393 | *.[Rr]e[Ss]harper 394 | *.DotSettings.user 395 | 396 | # JustCode is a .NET coding add-in 397 | .JustCode 398 | 399 | # TeamCity is a build add-in 400 | _TeamCity* 401 | 402 | # DotCover is a Code Coverage Tool 403 | *.dotCover 404 | 405 | # AxoCover is a Code Coverage Tool 406 | .axoCover/* 407 | !.axoCover/settings.json 408 | 409 | # Visual Studio code coverage results 410 | *.coverage 411 | *.coveragexml 412 | 413 | # NCrunch 414 | _NCrunch_* 415 | .*crunch*.local.xml 416 | nCrunchTemp_* 417 | 418 | # MightyMoose 419 | *.mm.* 420 | AutoTest.Net/ 421 | 422 | # Web workbench (sass) 423 | .sass-cache/ 424 | 425 | # Installshield output folder 426 | [Ee]xpress/ 427 | 428 | # DocProject is a documentation generator add-in 429 | DocProject/buildhelp/ 430 | DocProject/Help/*.HxT 431 | DocProject/Help/*.HxC 432 | DocProject/Help/*.hhc 433 | DocProject/Help/*.hhk 434 | DocProject/Help/*.hhp 435 | DocProject/Help/Html2 436 | DocProject/Help/html 437 | 438 | # Click-Once directory 439 | publish/ 440 | 441 | # Publish Web Output 442 | *.[Pp]ublish.xml 443 | *.azurePubxml 444 | # Note: Comment the next line if you want to checkin your web deploy settings, 445 | # but database connection strings (with potential passwords) will be unencrypted 446 | *.pubxml 447 | *.publishproj 448 | 449 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 450 | # checkin your Azure Web App publish settings, but sensitive information contained 451 | # in these scripts will be unencrypted 452 | PublishScripts/ 453 | 454 | # NuGet Packages 455 | *.nupkg 456 | # NuGet Symbol Packages 457 | *.snupkg 458 | # The packages folder can be ignored because of Package Restore 459 | **/[Pp]ackages/* 460 | # except build/, which is used as an MSBuild target. 461 | !**/[Pp]ackages/build/ 462 | # Uncomment if necessary however generally it will be regenerated when needed 463 | #!**/[Pp]ackages/repositories.config 464 | # NuGet v3's project.json files produces more ignorable files 465 | *.nuget.props 466 | *.nuget.targets 467 | 468 | # Microsoft Azure Build Output 469 | csx/ 470 | *.build.csdef 471 | 472 | # Microsoft Azure Emulator 473 | ecf/ 474 | rcf/ 475 | 476 | # Windows Store app package directories and files 477 | AppPackages/ 478 | BundleArtifacts/ 479 | Package.StoreAssociation.xml 480 | _pkginfo.txt 481 | *.appx 482 | *.appxbundle 483 | *.appxupload 484 | 485 | # Visual Studio cache files 486 | # files ending in .cache can be ignored 487 | *.[Cc]ache 488 | # but keep track of directories ending in .cache 489 | !?*.[Cc]ache/ 490 | 491 | # Others 492 | ClientBin/ 493 | ~$* 494 | *~ 495 | *.dbmdl 496 | *.dbproj.schemaview 497 | *.jfm 498 | *.pfx 499 | *.publishsettings 500 | orleans.codegen.cs 501 | 502 | # Including strong name files can present a security risk 503 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 504 | #*.snk 505 | 506 | # Since there are multiple workflows, uncomment next line to ignore bower_components 507 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 508 | #bower_components/ 509 | 510 | # RIA/Silverlight projects 511 | Generated_Code/ 512 | 513 | # Backup & report files from converting an old project file 514 | # to a newer Visual Studio version. 
Backup files are not needed, 515 | # because we have git ;-) 516 | _UpgradeReport_Files/ 517 | Backup*/ 518 | UpgradeLog*.XML 519 | UpgradeLog*.htm 520 | ServiceFabricBackup/ 521 | *.rptproj.bak 522 | 523 | # SQL Server files 524 | *.mdf 525 | *.ldf 526 | *.ndf 527 | 528 | # Business Intelligence projects 529 | *.rdl.data 530 | *.bim.layout 531 | *.bim_*.settings 532 | *.rptproj.rsuser 533 | *- [Bb]ackup.rdl 534 | *- [Bb]ackup ([0-9]).rdl 535 | *- [Bb]ackup ([0-9][0-9]).rdl 536 | 537 | # Microsoft Fakes 538 | FakesAssemblies/ 539 | 540 | # GhostDoc plugin setting file 541 | *.GhostDoc.xml 542 | 543 | # Node.js Tools for Visual Studio 544 | .ntvs_analysis.dat 545 | node_modules/ 546 | 547 | # Visual Studio 6 build log 548 | *.plg 549 | 550 | # Visual Studio 6 workspace options file 551 | *.opt 552 | 553 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 554 | *.vbw 555 | 556 | # Visual Studio LightSwitch build output 557 | **/*.HTMLClient/GeneratedArtifacts 558 | **/*.DesktopClient/GeneratedArtifacts 559 | **/*.DesktopClient/ModelManifest.xml 560 | **/*.Server/GeneratedArtifacts 561 | **/*.Server/ModelManifest.xml 562 | _Pvt_Extensions 563 | 564 | # Paket dependency manager 565 | .paket/paket.exe 566 | paket-files/ 567 | 568 | # FAKE - F# Make 569 | .fake/ 570 | 571 | # CodeRush personal settings 572 | .cr/personal 573 | 574 | # Python Tools for Visual Studio (PTVS) 575 | *.pyc 576 | 577 | # Cake - Uncomment if you are using it 578 | # tools/** 579 | # !tools/packages.config 580 | 581 | # Tabs Studio 582 | *.tss 583 | 584 | # Telerik's JustMock configuration file 585 | *.jmconfig 586 | 587 | # BizTalk build output 588 | *.btp.cs 589 | *.btm.cs 590 | *.odx.cs 591 | *.xsd.cs 592 | 593 | # OpenCover UI analysis results 594 | OpenCover/ 595 | 596 | # Azure Stream Analytics local run output 597 | ASALocalRun/ 598 | 599 | # MSBuild Binary and Structured Log 600 | *.binlog 601 | 602 | # NVidia Nsight GPU debugger configuration file 603 | *.nvuser 604 | 605 | # MFractors (Xamarin productivity tool) working folder 606 | .mfractor/ 607 | 608 | # Local History for Visual Studio 609 | .localhistory/ 610 | 611 | # BeatPulse healthcheck temp database 612 | healthchecksdb 613 | 614 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 615 | MigrationBackup/ 616 | 617 | # End of https://www.gitignore.io/api/osx,python,pycharm,windows,visualstudio,visualstudiocode 618 | -------------------------------------------------------------------------------- /opensr_model/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.7 3 | 4 | default_stages: [commit, push] 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v2.5.0 9 | hooks: 10 | - id: check-yaml 11 | - id: end-of-file-fixer 12 | exclude: LICENSE 13 | 14 | - repo: local 15 | hooks: 16 | - id: pyupgrade 17 | name: pyupgrade 18 | entry: poetry run pyupgrade --py37-plus 19 | types: [python] 20 | language: system 21 | 22 | - repo: local 23 | hooks: 24 | - id: isort 25 | name: isort 26 | entry: poetry run isort --settings-path pyproject.toml 27 | types: [python] 28 | language: system 29 | 30 | - repo: local 31 | hooks: 32 | - id: black 33 | name: black 34 | entry: poetry run black --config pyproject.toml 35 | types: [python] 36 | language: system 37 | -------------------------------------------------------------------------------- /opensr_model/__init__.py: 
-------------------------------------------------------------------------------- 1 | # type: ignore[attr-defined] 2 | """Latent diffusion model trained in RGBN optical remote sensing imagery""" 3 | 4 | import sys 5 | from opensr_model.srmodel import * 6 | from opensr_model import * 7 | 8 | if sys.version_info >= (3, 8): 9 | from importlib import metadata as importlib_metadata 10 | else: 11 | import importlib_metadata 12 | 13 | 14 | def get_version() -> str: 15 | try: 16 | return importlib_metadata.version(__name__) 17 | except importlib_metadata.PackageNotFoundError: # pragma: no cover 18 | return "unknown" 19 | 20 | 21 | version: str = get_version() 22 | -------------------------------------------------------------------------------- /opensr_model/autoencoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/opensr_model/autoencoder/__init__.py -------------------------------------------------------------------------------- /opensr_model/autoencoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from opensr_model.autoencoder.utils import (Downsample, Normalize, ResnetBlock, 6 | Upsample, make_attn, nonlinearity) 7 | from torch import nn 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__( 12 | self, 13 | *, 14 | ch: int, 15 | ch_mult: Tuple[int, int, int, int] = (1, 2, 4, 8), 16 | num_res_blocks: int, 17 | attn_resolutions: Tuple[int, ...], 18 | dropout: float = 0.0, 19 | resamp_with_conv: bool = True, 20 | in_channels: int, 21 | resolution: int, 22 | z_channels: int, 23 | double_z: bool = True, 24 | use_linear_attn: bool = False, 25 | attn_type: str = "vanilla", 26 | **ignorekwargs: dict, 27 | ): 28 | """ 29 | Encoder module responsible for downsampling and transforming an input image tensor. 30 | 31 | Args: 32 | ch (int): Base number of channels in the model. 33 | num_res_blocks (int): Number of residual blocks per resolution. 34 | attn_resolutions (tuple of int): Resolutions at which attention should be applied. 35 | in_channels (int): Number of channels in the input data. 36 | resolution (int): The resolution of the input data. 37 | z_channels (int): Number of channels for the latent variable 'z'. 38 | ch_mult (tuple of int, optional): Multipliers for the channels in different blocks. Defaults to (1, 2, 4, 8). 39 | dropout (float, optional): Dropout rate to use in ResNet blocks. Defaults to 0.0. 40 | resamp_with_conv (bool, optional): Whether to use convolution for downsampling. Defaults to True. 41 | double_z (bool, optional): If True, output channels will be doubled for 'z'. Defaults to True. 42 | use_linear_attn (bool, optional): If True, linear attention will be used. Overrides 'attn_type'. Defaults to False. 43 | attn_type (str, optional): Type of attention mechanism. Options are "vanilla" or "linear". Defaults to "vanilla". 44 | ignorekwargs (dict): Ignore extra keyword arguments. 45 | 46 | Examples: 47 | >>> encoder = Encoder(in_channels=3, z_channels=64, ch=32, resolution=64, num_res_blocks=2, attn_resolutions=(16, 8)) 48 | >>> x = torch.randn(1, 3, 64, 64) 49 | >>> z = encoder(x) 50 | """ 51 | super().__init__() 52 | 53 | # If linear attention is used, override the attention type. 54 | if use_linear_attn: 55 | attn_type = "linear" 56 | 57 | # Setting global attributes to create the encoder. 
58 | self.ch = ch 59 | self.temb_ch = 0 60 | self.num_resolutions = len(ch_mult) 61 | self.num_res_blocks = num_res_blocks 62 | self.resolution = resolution 63 | self.in_channels = in_channels 64 | 65 | # Initial convolution for spectral reduction. 66 | self.conv_in = torch.nn.Conv2d( 67 | in_channels, self.ch, kernel_size=3, stride=1, padding=1 68 | ) 69 | 70 | # Downsampling with residual blocks and optionally attention 71 | curr_res = resolution 72 | in_ch_mult = (1,) + tuple(ch_mult) 73 | self.in_ch_mult = in_ch_mult 74 | self.down = nn.ModuleList() 75 | for i_level in range(self.num_resolutions): 76 | block = nn.ModuleList() 77 | attn = nn.ModuleList() 78 | block_in = ch * in_ch_mult[i_level] 79 | block_out = ch * ch_mult[i_level] 80 | for i_block in range(self.num_res_blocks): 81 | block.append( 82 | ResnetBlock( 83 | in_channels=block_in, 84 | out_channels=block_out, 85 | temb_channels=self.temb_ch, 86 | dropout=dropout, 87 | ) 88 | ) 89 | block_in = block_out 90 | if curr_res in attn_resolutions: 91 | attn.append(make_attn(block_in, attn_type=attn_type)) 92 | down = nn.Module() 93 | down.block = block 94 | down.attn = attn 95 | if i_level != self.num_resolutions - 1: 96 | down.downsample = Downsample(block_in, resamp_with_conv) 97 | curr_res = curr_res // 2 98 | self.down.append(down) 99 | 100 | # Upsampling with residual blocks and optionally attention 101 | self.mid = nn.Module() 102 | self.mid.block_1 = ResnetBlock( 103 | in_channels=block_in, 104 | out_channels=block_in, 105 | temb_channels=self.temb_ch, 106 | dropout=dropout, 107 | ) 108 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 109 | self.mid.block_2 = ResnetBlock( 110 | in_channels=block_in, 111 | out_channels=block_in, 112 | temb_channels=self.temb_ch, 113 | dropout=dropout, 114 | ) 115 | 116 | # Final convolution to get the latent variable 'z' 117 | self.norm_out = Normalize(block_in) 118 | self.conv_out = torch.nn.Conv2d( 119 | block_in, 120 | 2 * z_channels if double_z else z_channels, 121 | kernel_size=3, 122 | stride=1, 123 | padding=1, 124 | ) 125 | 126 | def forward(self, x: torch.Tensor) -> torch.Tensor: 127 | """ 128 | Forward pass of the Encoder. 129 | 130 | Args: 131 | x: Input tensor. 132 | 133 | Returns: 134 | Transformed tensor after passing through the Encoder. 135 | """ 136 | 137 | # timestep embedding (if needed in the next Diffusion runs!) 
138 | temb = None 139 | 140 | # Initial downsampling 141 | hs = [self.conv_in(x)] 142 | 143 | # Downsampling through the layers 144 | for i_level in range(self.num_resolutions): 145 | for i_block in range(self.num_res_blocks): 146 | h = self.down[i_level].block[i_block](hs[-1], temb) 147 | if len(self.down[i_level].attn) > 0: 148 | h = self.down[i_level].attn[i_block](h) 149 | hs.append(h) 150 | if i_level != self.num_resolutions - 1: 151 | hs.append(self.down[i_level].downsample(hs[-1])) 152 | 153 | # Middle processing with blocks and attention 154 | h = hs[-1] 155 | h = self.mid.block_1(h, temb) 156 | h = self.mid.attn_1(h) 157 | h = self.mid.block_2(h, temb) 158 | 159 | # Final transformation to produce the output 160 | h = self.norm_out(h) 161 | h = nonlinearity(h) 162 | h = self.conv_out(h) 163 | return h 164 | 165 | 166 | class Decoder(nn.Module): 167 | def __init__( 168 | self, 169 | *, 170 | ch: int, 171 | out_ch: int, 172 | ch_mult: Tuple[int, int, int, int] = (1, 2, 4, 8), 173 | num_res_blocks: int, 174 | attn_resolutions: Tuple[int, ...], 175 | dropout: float = 0.0, 176 | resamp_with_conv: bool = True, 177 | in_channels: int, 178 | resolution: int, 179 | z_channels: int, 180 | give_pre_end: bool = False, 181 | tanh_out: bool = False, 182 | use_linear_attn: bool = False, 183 | attn_type: str = "vanilla", 184 | **ignorekwargs: dict, 185 | ): 186 | """ 187 | A Decoder class that converts a given encoded data 'z' back to its original state. 188 | 189 | Args: 190 | ch (int): Number of channels in the input data. 191 | out_ch (int): Number of channels in the output data. 192 | num_res_blocks (int): Number of residual blocks in the network. 193 | attn_resolutions (Tuple[int, ...]): Resolutions at which attention mechanisms are applied. 194 | in_channels (int): Number of channels in the encoded data 'z'. 195 | resolution (int): The resolution of the output image. 196 | z_channels (int): Number of channels in the latent space representation. 197 | ch_mult (Tuple[int, int, int, int], optional): Multiplier for channels at different resolution 198 | levels. Defaults to (1, 2, 4, 8). 199 | dropout (float, optional): Dropout rate for regularization. Defaults to 0.0. 200 | resamp_with_conv (bool, optional): Whether to use convolutional layers for upsampling. Defaults to True. 201 | give_pre_end (bool, optional): If set to True, returns the output before the last layer. Useful for further 202 | processing. Defaults to False. 203 | tanh_out (bool, optional): If set to True, applies tanh activation function to the output. Defaults to False. 204 | use_linear_attn (bool, optional): If set to True, uses linear attention mechanism. Defaults to False. 205 | attn_type (str, optional): Type of attention mechanism used ("vanilla" or "linear"). Defaults to "vanilla". 206 | ignorekwargs (dict): Ignore extra keyword arguments. 
207 | 208 | Examples: 209 | >>> decoder = Decoder( 210 | ch=32, out_ch=3, z_channels=64, resolution=64, 211 | in_channels=64, num_res_blocks=2, 212 | attn_resolutions=(16, 8) 213 | ) 214 | >>> z = torch.randn(1, 64, 8, 8) 215 | >>> x_reconstructed = decoder(z) 216 | """ 217 | 218 | super().__init__() 219 | 220 | # If linear attention is required, set attention type as 'linear' 221 | if use_linear_attn: 222 | attn_type = "linear" 223 | 224 | # Initialize basic attributes for Decoding 225 | self.ch = ch 226 | self.temb_ch = 0 # Temporal embedding channel 227 | self.num_resolutions = len(ch_mult) 228 | self.num_res_blocks = num_res_blocks 229 | self.resolution = resolution 230 | self.in_channels = in_channels 231 | self.give_pre_end = give_pre_end # Controls the final output 232 | self.tanh_out = tanh_out # Apply tanh activation at the end 233 | 234 | # Compute input channel multiplier, initial block input channel and current resolution 235 | block_in = ch * ch_mult[self.num_resolutions - 1] 236 | curr_res = resolution // 2 ** (self.num_resolutions - 1) 237 | self.z_shape = (1, z_channels, curr_res, curr_res) 238 | 239 | # Display z-shape details 240 | print( 241 | "Working with z of shape {} = {} dimensions.".format( 242 | self.z_shape, np.prod(self.z_shape) 243 | ) 244 | ) 245 | 246 | # Conversion layer: From z dimension to block input channels 247 | self.conv_in = torch.nn.Conv2d( 248 | z_channels, block_in, kernel_size=3, stride=1, padding=1 249 | ) 250 | 251 | # Middle processing blocks 252 | self.mid = nn.Module() 253 | self.mid.block_1 = ResnetBlock( 254 | in_channels=block_in, 255 | out_channels=block_in, 256 | temb_channels=self.temb_ch, 257 | dropout=dropout, 258 | ) 259 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 260 | self.mid.block_2 = ResnetBlock( 261 | in_channels=block_in, 262 | out_channels=block_in, 263 | temb_channels=self.temb_ch, 264 | dropout=dropout, 265 | ) 266 | 267 | # Upsampling layers 268 | self.up = nn.ModuleList() 269 | for i_level in reversed(range(self.num_resolutions)): 270 | block = nn.ModuleList() 271 | attn = nn.ModuleList() 272 | 273 | # Apply ResNet blocks and attention at each resolution 274 | block_out = ch * ch_mult[i_level] 275 | for i_block in range(self.num_res_blocks + 1): 276 | block.append( 277 | ResnetBlock( 278 | in_channels=block_in, 279 | out_channels=block_out, 280 | temb_channels=self.temb_ch, 281 | dropout=dropout, 282 | ) 283 | ) 284 | block_in = block_out 285 | if curr_res in attn_resolutions: 286 | attn.append(make_attn(block_in, attn_type=attn_type)) 287 | 288 | up = nn.Module() 289 | up.block = block 290 | up.attn = attn 291 | 292 | # Upsampling operations 293 | if i_level != 0: 294 | up.upsample = Upsample(block_in, resamp_with_conv) 295 | curr_res = curr_res * 2 296 | 297 | # Keep the order consistent with original resolutions 298 | self.up.insert(0, up) 299 | 300 | # Final normalization and conversion layers 301 | self.norm_out = Normalize(block_in) 302 | self.conv_out = torch.nn.Conv2d( 303 | block_in, out_ch, kernel_size=3, stride=1, padding=1 304 | ) 305 | 306 | def forward(self, z: torch.Tensor) -> torch.Tensor: 307 | """Forward pass of the Decoder. 308 | 309 | Args: 310 | z (torch.Tensor): The latent variable 'z' to be decoded. 311 | 312 | Returns: 313 | torch.Tensor: Transformed tensor after passing through the Decoder. 
314 | """ 315 | 316 | self.last_z_shape = z.shape 317 | 318 | # Time-step embedding (not used, in the Decoder part) 319 | temb = None 320 | 321 | # Convert z to block input 322 | h = self.conv_in(z) 323 | 324 | # Middle processing blocks 325 | h = self.mid.block_1(h, temb) 326 | h = self.mid.attn_1(h) 327 | h = self.mid.block_2(h, temb) 328 | 329 | # Upsampling steps 330 | for i_level in reversed(range(self.num_resolutions)): 331 | for i_block in range(self.num_res_blocks + 1): 332 | h = self.up[i_level].block[i_block](h, temb) 333 | if len(self.up[i_level].attn) > 0: 334 | h = self.up[i_level].attn[i_block](h) 335 | if i_level != 0: 336 | h = self.up[i_level].upsample(h) 337 | 338 | # Final output steps 339 | if self.give_pre_end: 340 | return h 341 | 342 | h = self.norm_out(h) 343 | h = nonlinearity(h) 344 | h = self.conv_out(h) 345 | 346 | # Apply tanh activation if required 347 | if self.tanh_out: 348 | h = torch.tanh(h) 349 | 350 | return h 351 | 352 | 353 | class DiagonalGaussianDistribution(object): 354 | """ 355 | Represents a multi-dimensional diagonal Gaussian distribution. 356 | 357 | The distribution is parameterized by means and diagonal variances 358 | (or standard deviations) for each dimension. This means that the 359 | covariance matrix of this Gaussian distribution is diagonal 360 | (i.e., non-diagonal elements are zero). 361 | 362 | Attributes: 363 | parameters (torch.Tensor): A tensor containing concatenated means and log-variances. 364 | mean (torch.Tensor): The mean of the Gaussian distribution. 365 | logvar (torch.Tensor): The logarithm of variances of the Gaussian distribution. 366 | deterministic (bool): If true, the variance is set to zero, making the distribution 367 | deterministic. 368 | std (torch.Tensor): The standard deviation of the Gaussian distribution. 369 | var (torch.Tensor): The variance of the Gaussian distribution. 370 | 371 | Examples: 372 | >>> params = torch.randn((1, 10)) # Assuming 5 for mean and 5 for log variance 373 | >>> dist = DiagonalGaussianDistribution(params) 374 | >>> sample = dist.sample() # Sample from the distribution 375 | """ 376 | 377 | def __init__(self, parameters: torch.Tensor, deterministic: bool = False): 378 | """ 379 | Initializes the DiagonalGaussianDistribution. 380 | 381 | Args: 382 | parameters (torch.Tensor): A tensor containing concatenated means and log-variances. 383 | deterministic (bool, optional): If set to true, this distribution becomes 384 | deterministic (i.e., has zero variance). 385 | """ 386 | self.parameters = parameters 387 | self.deterministic = deterministic 388 | 389 | # Split the parameters into means and log-variances 390 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 391 | 392 | # Limit the log variance values 393 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 394 | 395 | # Calculate standard deviation & variance from log variance 396 | self.std = torch.exp(0.5 * self.logvar) 397 | self.var = torch.exp(self.logvar) 398 | 399 | # If deterministic, set variance and standard deviation to zero 400 | if self.deterministic: 401 | self.var = self.std = torch.zeros_like(self.mean).to( 402 | device=self.parameters.device 403 | ) 404 | 405 | def sample(self) -> torch.Tensor: 406 | """ 407 | Sample from the Gaussian distribution. 408 | 409 | Returns: 410 | torch.Tensor: Sampled tensor. 
411 | """ 412 | 413 | # Sample from a standard Gaussian distribution 414 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 415 | device=self.parameters.device 416 | ) 417 | 418 | return x 419 | 420 | def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: 421 | """ 422 | Compute the KL divergence between this Gaussian distribution and another. 423 | 424 | Args: 425 | other (DiagonalGaussianDistribution, optional): The other Gaussian 426 | distribution. If None, computes the KL divergence with a standard 427 | Gaussian (mean 0, variance 1). 428 | 429 | Returns: 430 | torch.Tensor: KL divergence values. 431 | """ 432 | if self.deterministic: 433 | return torch.Tensor([0.0]).to(device=self.parameters.device) 434 | else: 435 | if other is None: 436 | return 0.5 * torch.sum( 437 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 438 | dim=[1, 2, 3], 439 | ) 440 | else: 441 | return 0.5 * torch.sum( 442 | torch.pow(self.mean - other.mean, 2) / other.var 443 | + self.var / other.var 444 | - 1.0 445 | - self.logvar 446 | + other.logvar, 447 | dim=[1, 2, 3], 448 | ) 449 | 450 | def nll(self, sample: torch.Tensor, dims: list = [1, 2, 3]) -> torch.Tensor: 451 | """ 452 | Compute the negative log likelihood of a sample under this Gaussian distribution. 453 | 454 | Args: 455 | sample (torch.Tensor): The input sample tensor. 456 | dims (list, optional): The dimensions over which the sum is performed. Defaults 457 | to [1, 2, 3]. 458 | 459 | Returns: 460 | torch.Tensor: Negative log likelihood values. 461 | """ 462 | if self.deterministic: 463 | return torch.Tensor([0.0]).to(device=self.parameters.device) 464 | logtwopi = np.log(2.0 * np.pi) 465 | return 0.5 * torch.sum( 466 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 467 | dim=dims, 468 | ) 469 | 470 | def mode(self) -> torch.Tensor: 471 | """ 472 | Get the mode of the Gaussian distribution (which is equal to its mean). 473 | 474 | Returns: 475 | torch.Tensor: The mode (mean) of the Gaussian distribution. 476 | """ 477 | return self.mean 478 | 479 | 480 | class AutoencoderKL(nn.Module): 481 | """ 482 | Autoencoder with KL divergence regularization. 483 | 484 | This class implements an autoencoder model where the encoder outputs parameters of a 485 | Gaussian distribution, from which the latent representation can be sampled or its 486 | mode can be taken. The decoder then reconstructs the input from the latent 487 | representation. 488 | 489 | Attributes: 490 | encoder (Encoder): Encoder module. 491 | decoder (Decoder): Decoder module. 492 | quant_conv (torch.nn.Conv2d): Convolutional layer used to process encoder outputs 493 | into Gaussian parameters. 494 | post_quant_conv (torch.nn.Conv2d): Convolutional layer used after sampling/mode 495 | from the Gaussian distribution. 496 | embed_dim (int): Embedding dimension of the latent space. 497 | 498 | Examples: 499 | 500 | >>> ddconfig = { 501 | "z_channels": 16, "ch": 32, 502 | "out_ch": 3, "ch_mult": (1, 2, 4, 8), 503 | "resolution": 64, "in_channels": 3, 504 | "double_z": True, "num_res_blocks": 2, 505 | "attn_resolutions": (16, 8) 506 | } 507 | >>> embed_dim = 8 508 | >>> ae_model = AutoencoderKL(ddconfig, embed_dim) 509 | >>> data = torch.randn((1, 3, 64, 64)) 510 | >>> recon_data, posterior = ae_model(data) 511 | """ 512 | 513 | def __init__(self, ddconfig: dict, embed_dim: int): 514 | """ 515 | Initialize the AutoencoderKL. 516 | 517 | Args: 518 | ddconfig (dict): Configuration dictionary for the encoder and decoder. 
519 | embed_dim (int): Embedding dimension of the latent space. 520 | """ 521 | super().__init__() 522 | 523 | # Initialize the encoder and decoder with provided configurations 524 | self.encoder = Encoder(**ddconfig) 525 | self.decoder = Decoder(**ddconfig) 526 | 527 | # Check if the configuration expects double the z_channels 528 | assert ddconfig["double_z"], "ddconfig must have 'double_z' set to True." 529 | 530 | # Define convolutional layers to transform between the latent space and Gaussian parameters 531 | self.quant_conv = nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 532 | self.post_quant_conv = nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 533 | 534 | self.embed_dim = embed_dim 535 | 536 | def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution: 537 | """ 538 | Pass the input through the encoder and return the posterior Gaussian 539 | distribution. 540 | 541 | Args: 542 | x (torch.Tensor): Input tensor. 543 | 544 | Returns: 545 | DiagonalGaussianDistribution: Gaussian distribution parameters from the 546 | encoded input. 547 | """ 548 | # Encoder's output 549 | h = self.encoder(x) 550 | 551 | # Convert encoder's output to Gaussian parameters 552 | moments = self.quant_conv(h) 553 | 554 | # Create a DiagonalGaussianDistribution using the moments 555 | posterior = DiagonalGaussianDistribution(moments) 556 | 557 | return posterior 558 | 559 | def decode(self, z: torch.Tensor) -> torch.Tensor: 560 | """ 561 | Decode the latent representation to reconstruct the input. 562 | 563 | Args: 564 | z (torch.Tensor): Latent representation. 565 | 566 | Returns: 567 | torch.Tensor: Reconstructed tensor. 568 | """ 569 | # Process latent representation through a convolutional layer 570 | z = self.post_quant_conv(z) 571 | 572 | # Decoder's output 573 | dec = self.decoder(z) 574 | 575 | return dec 576 | 577 | def forward(self, input: torch.Tensor, sample_posterior: bool = True) -> tuple: 578 | """ 579 | Forward pass of the autoencoder. 580 | 581 | Encodes the input, samples/modes from the resulting Gaussian distribution, 582 | and then decodes to get the reconstructed input. 583 | 584 | Args: 585 | input (torch.Tensor): Input tensor. 586 | sample_posterior (bool, optional): If True, sample from the posterior Gaussian 587 | distribution. If False, use its mode. Defaults to True. 588 | 589 | Returns: 590 | tuple: Reconstructed tensor and the posterior Gaussian distribution. 591 | """ 592 | 593 | # Encode the input to get the Gaussian distribution parameters 594 | posterior = self.encode(input) 595 | 596 | # Sample from the Gaussian distribution or take its mode 597 | z = posterior.sample() if sample_posterior else posterior.mode() 598 | 599 | # Decode the sampled/mode latent representation 600 | dec = self.decode(z) 601 | 602 | return dec, posterior 603 | -------------------------------------------------------------------------------- /opensr_model/autoencoder/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | 6 | def Normalize(in_channels: int, num_groups: int = 32) -> torch.nn.GroupNorm: 7 | """ 8 | Returns a GroupNorm layer that normalizes the input tensor along the channel dimension. 9 | 10 | Args: 11 | in_channels (int): Number of channels in the input tensor. 12 | num_groups (int): Number of groups to separate the channels into. Default is 32. 
13 | 14 | Returns: 15 | torch.nn.GroupNorm: A GroupNorm layer that normalizes the input tensor along the 16 | channel dimension. 17 | 18 | Example: 19 | >>> input_tensor = torch.randn(1, 64, 32, 32) 20 | >>> norm_layer = Normalize(in_channels=64, num_groups=16) 21 | >>> output_tensor = norm_layer(input_tensor) 22 | """ 23 | # Create a GroupNorm layer with the specified number of groups and input channels 24 | # Set eps to a small value to avoid division by zero 25 | # Set affine to True to learn scaling and shifting parameters 26 | return torch.nn.GroupNorm( 27 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 28 | ) 29 | 30 | 31 | def nonlinearity(x: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Applies a non-linear activation function to the input tensor x. 34 | 35 | Args: 36 | x (torch.Tensor): Input tensor. 37 | 38 | Returns: 39 | torch.Tensor: Output tensor with the same shape as the input tensor. 40 | 41 | Example: 42 | >>> input_tensor = torch.randn(10, 20) 43 | >>> output_tensor = nonlinearity(input_tensor) 44 | """ 45 | # Apply the sigmoid function to the input tensor 46 | sigmoid_x = torch.sigmoid(x) 47 | 48 | # Multiply the input tensor by the sigmoid of the input tensor 49 | output_tensor = x * sigmoid_x 50 | 51 | return output_tensor 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, in_channels: int, with_conv: bool): 56 | """ 57 | Initializes a Downsample module that reduces the spatial dimensions 58 | of the input tensor. 59 | 60 | Args: 61 | in_channels (int): Number of channels in the input tensor. 62 | with_conv (bool): Whether to use a convolutional layer for downsampling. 63 | 64 | Attributes: 65 | conv (torch.nn.Conv2d): Convolutional layer for downsampling. Only used 66 | if with_conv is True. 67 | 68 | Example: 69 | >>> input_tensor = torch.randn(1, 64, 32, 32) 70 | >>> downsample_module = Downsample(in_channels=64, with_conv=True) 71 | >>> output_tensor = downsample_module(input_tensor) 72 | """ 73 | super().__init__() 74 | self.with_conv = with_conv 75 | if self.with_conv: 76 | # Create a convolutional layer for downsampling 77 | # Use kernel size 3, stride 2, and no padding 78 | self.conv = torch.nn.Conv2d( 79 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 80 | ) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | """ 84 | Applies the Downsample module to the input tensor x. 85 | 86 | Args: 87 | x (torch.Tensor): Input tensor with shape (batch_size, in_channels, height, width). 88 | 89 | Returns: 90 | torch.Tensor: Output tensor with shape (batch_size, in_channels, height/2, width/2) 91 | if with_conv is False, or (batch_size, in_channels, (height+1)/2, (width+1)/2) if 92 | with_conv is True. 93 | """ 94 | if self.with_conv: 95 | # Apply asymmetric padding to the input tensor 96 | pad = (0, 1, 0, 1) 97 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 98 | 99 | # Apply the convolutional layer to the padded input tensor 100 | x = self.conv(x) 101 | else: 102 | # Apply average pooling to the input tensor 103 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 104 | 105 | # Return the output tensor 106 | return x 107 | 108 | 109 | class Upsample(nn.Module): 110 | def __init__(self, in_channels: int, with_conv: bool): 111 | """ 112 | Initializes an Upsample module that increases the spatial dimensions of 113 | the input tensor. 114 | 115 | Args: 116 | in_channels (int): Number of channels in the input tensor. 
117 | with_conv (bool): Whether to use a convolutional layer for upsampling. 118 | 119 | Attributes: 120 | with_conv (bool): Whether to use a convolutional layer for upsampling. 121 | conv (torch.nn.Conv2d): Convolutional layer for upsampling. Only used 122 | if with_conv is True. 123 | 124 | Example: 125 | >>> input_tensor = torch.randn(1, 64, 32, 32) 126 | >>> upsample_module = Upsample(in_channels=64, with_conv=True) 127 | >>> output_tensor = upsample_module(input_tensor) 128 | """ 129 | super().__init__() 130 | self.with_conv = with_conv 131 | if self.with_conv: 132 | # Create a convolutional layer for upsampling 133 | # Use kernel size 3, stride 1, and padding 1 134 | self.conv = torch.nn.Conv2d( 135 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 136 | ) 137 | 138 | def forward(self, x: torch.Tensor) -> torch.Tensor: 139 | """ 140 | Applies the Upsample module to the input tensor x. 141 | 142 | Args: 143 | x (torch.Tensor): Input tensor with shape (batch_size, in_channels, 144 | height, width). 145 | 146 | Returns: 147 | torch.Tensor: Output tensor with shape (batch_size, in_channels, height*2, width*2). 148 | The optional 3x3 convolution uses stride 1 and padding 1, so the doubled spatial 149 | size is preserved whether or not with_conv is True. 150 | """ 151 | # Apply nearest interpolation to the input tensor to double its spatial dimensions 152 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 153 | 154 | if self.with_conv: 155 | # Apply the convolutional layer to the upsampled input tensor 156 | x = self.conv(x) 157 | 158 | return x 159 | 160 | 161 | class ResnetBlock(nn.Module): 162 | def __init__( 163 | self, 164 | *, 165 | in_channels: int, 166 | out_channels: int = None, 167 | conv_shortcut: bool = False, 168 | dropout: float, 169 | temb_channels: int = 512, 170 | ): 171 | """ 172 | Initializes a ResnetBlock module that consists of two convolutional layers with group 173 | normalization and a residual connection. 174 | 175 | Args: 176 | in_channels (int): Number of channels in the input tensor. 177 | out_channels (int, optional): Number of channels in the output tensor. If None, 178 | defaults to in_channels. 179 | conv_shortcut (bool): Whether to use a convolutional layer for the residual connection. 180 | If False, uses a 1x1 convolution. 181 | dropout (float): Dropout probability. 182 | temb_channels (int): Number of channels in the conditioning tensor. If 0, no conditioning 183 | is used. 184 | 185 | Attributes: 186 | in_channels (int): Number of channels in the input tensor. 187 | out_channels (int): Number of channels in the output tensor. 188 | use_conv_shortcut (bool): Whether to use a convolutional layer for the residual connection. 189 | norm1 (utils.Normalize): Group normalization layer for the first convolutional layer. 190 | conv1 (torch.nn.Conv2d): First convolutional layer. 191 | temb_proj (torch.nn.Linear): Linear projection layer for the conditioning tensor. Only used 192 | if temb_channels > 0. 193 | norm2 (utils.Normalize): Group normalization layer for the second convolutional layer. 194 | dropout (torch.nn.Dropout): Dropout layer. 195 | conv2 (torch.nn.Conv2d): Second convolutional layer. 196 | conv_shortcut (torch.nn.Conv2d): Convolutional layer for the residual connection. Only 197 | used if use_conv_shortcut is True. 198 | nin_shortcut (torch.nn.Conv2d): 1x1 convolutional layer for the residual connection. Only 199 | used if use_conv_shortcut is False.
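Example (illustrative sketch; the sizes are arbitrary and the random temb only demonstrates the conditioning path, which the example in forward() leaves out):
    >>> block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.0, temb_channels=512)
    >>> x = torch.randn(1, 64, 32, 32)
    >>> temb = torch.randn(1, 512)
    >>> block(x, temb).shape   # temb is projected to 128 channels and added after conv1
    torch.Size([1, 128, 32, 32])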
200 | """ 201 | super().__init__() 202 | 203 | # Set the number of input and output channels 204 | self.in_channels = in_channels 205 | out_channels = in_channels if out_channels is None else out_channels 206 | self.out_channels = out_channels 207 | self.use_conv_shortcut = conv_shortcut 208 | 209 | # Batch normalization layer for the first convolutional layer 210 | self.norm1 = Normalize(in_channels) 211 | 212 | # First convolutional layer 213 | self.conv1 = torch.nn.Conv2d( 214 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 215 | ) 216 | 217 | # Linear projection layer for the conditioning tensor 218 | if temb_channels > 0: 219 | self.temb_proj = torch.nn.Linear(temb_channels, out_channels) 220 | 221 | # BN+Dropout+Conv layer for the last convolutional layer 222 | self.norm2 = Normalize(out_channels) 223 | self.dropout = torch.nn.Dropout(dropout) 224 | self.conv2 = torch.nn.Conv2d( 225 | out_channels, out_channels, kernel_size=3, stride=1, padding=1 226 | ) 227 | 228 | if self.in_channels != self.out_channels: 229 | if self.use_conv_shortcut: 230 | # 3x3 conv for the residual connection 231 | self.conv_shortcut = torch.nn.Conv2d( 232 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 233 | ) 234 | else: 235 | # 1x1 conv for the residual connection 236 | self.nin_shortcut = torch.nn.Conv2d( 237 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 238 | ) 239 | 240 | def forward(self, x, temb): 241 | """ 242 | Applies the ResnetBlock module to the input tensor x. 243 | 244 | Args: 245 | x (torch.Tensor): Input tensor with shape (batch_size, in_channels, height, width). 246 | temb (torch.Tensor): Conditioning tensor with shape (batch_size, temb_channels). 247 | 248 | Returns: 249 | torch.Tensor: Output tensor with the same shape as the input tensor. 250 | 251 | Example: 252 | >>> input_tensor = torch.randn(1, 64, 32, 32) 253 | >>> resnet_block = ResnetBlock(in_channels=64, out_channels=128, dropout=0.5) 254 | >>> output_tensor = resnet_block(input_tensor, temb=None) 255 | """ 256 | 257 | # BN+Sigmoid+Conv 258 | h = x 259 | h = self.norm1(h) 260 | h = nonlinearity(h) 261 | h = self.conv1(h) 262 | 263 | # Linear projection layer for the conditioning tensor 264 | if temb is not None: 265 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 266 | 267 | # BN+Sigmoid+Dropout+Conv 268 | h = self.norm2(h) 269 | h = nonlinearity(h) 270 | h = self.dropout(h) 271 | h = self.conv2(h) 272 | 273 | if self.in_channels != self.out_channels: 274 | if self.use_conv_shortcut: 275 | # 3x3 conv for the residual connection 276 | x = self.conv_shortcut(x) 277 | else: 278 | # 1x1 conv for the residual connection 279 | x = self.nin_shortcut(x) 280 | 281 | # Add the residual connection to the output tensor 282 | return x + h 283 | 284 | 285 | def make_attn(in_channels, attn_type="vanilla"): 286 | assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" 287 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 288 | if attn_type == "vanilla": 289 | return AttnBlock(in_channels) 290 | elif attn_type == "none": 291 | return nn.Identity(in_channels) 292 | else: 293 | return LinAttnBlock(in_channels) 294 | 295 | 296 | def make_attn(in_channels: int, attn_type: str = "vanilla") -> nn.Module: 297 | """ 298 | Creates an attention module of the specified type. 299 | 300 | Args: 301 | in_channels (int): Number of channels in the input tensor. 302 | attn_type (str): Type of attention module to create. 
Must be one of "vanilla", 303 | "linear", or "none". Defaults to "vanilla". 304 | 305 | Returns: 306 | nn.Module: Attention module. 307 | 308 | Raises: 309 | AssertionError: If attn_type is not one of "vanilla", "linear", or "none". 310 | 311 | Example: 312 | >>> input_tensor = torch.randn(1, 64, 32, 32) 313 | >>> attn_module = make_attn(in_channels=64, attn_type="vanilla") 314 | >>> output_tensor = attn_module(input_tensor) 315 | """ 316 | assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" 317 | print(f"making attention of type '{attn_type}' with {in_channels} in_channels") 318 | if attn_type == "vanilla": 319 | # Create a vanilla attention module 320 | return AttnBlock(in_channels) 321 | elif attn_type == "none": 322 | # Create an identity module 323 | return nn.Identity(in_channels) 324 | else: 325 | # Create a linear attention module 326 | return LinAttnBlock(in_channels) 327 | 328 | 329 | class AttnBlock(nn.Module): 330 | """ 331 | An attention module that computes attention weights for each spatial location in the input tensor. 332 | 333 | Args: 334 | in_channels (int): Number of channels in the input tensor. 335 | 336 | Attributes: 337 | in_channels (int): Number of channels in the input tensor. 338 | norm (Normalize): Normalization layer for the input tensor. 339 | q (torch.nn.Conv2d): Convolutional layer for computing the query tensor. 340 | k (torch.nn.Conv2d): Convolutional layer for computing the key tensor. 341 | v (torch.nn.Conv2d): Convolutional layer for computing the value tensor. 342 | proj_out (torch.nn.Conv2d): Convolutional layer for projecting the attended tensor. 343 | 344 | Example: 345 | >>> input_tensor = torch.randn(1, 64, 32, 32) 346 | >>> attn_module = AttnBlock(in_channels=64) 347 | >>> output_tensor = attn_module(input_tensor) 348 | """ 349 | 350 | def __init__(self, in_channels: int): 351 | super().__init__() 352 | self.in_channels = in_channels 353 | 354 | self.norm = Normalize(in_channels) 355 | self.q = torch.nn.Conv2d( 356 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 357 | ) 358 | self.k = torch.nn.Conv2d( 359 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 360 | ) 361 | self.v = torch.nn.Conv2d( 362 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 363 | ) 364 | self.proj_out = torch.nn.Conv2d( 365 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 366 | ) 367 | 368 | def forward(self, x: torch.Tensor) -> torch.Tensor: 369 | """ 370 | Computes the output tensor of the attention module. 371 | 372 | Args: 373 | x (torch.Tensor): Input tensor. 374 | 375 | Returns: 376 | torch.Tensor: Output tensor with the same shape as the input tensor. 
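Example (illustrative sketch; note that in_channels must be divisible by the 32 groups used by the default Normalize):
    >>> attn = AttnBlock(in_channels=32)
    >>> x = torch.randn(1, 32, 8, 8)
    >>> attn(x).shape   # full (h*w) x (h*w) attention plus a residual connection
    torch.Size([1, 32, 8, 8])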
377 | """ 378 | h_ = x 379 | h_ = self.norm(h_) 380 | q = self.q(h_) 381 | k = self.k(h_) 382 | v = self.v(h_) 383 | 384 | # compute attention 385 | b, c, h, w = q.shape 386 | q = q.reshape(b, c, h * w) 387 | q = q.permute(0, 2, 1) # reshape q to b,hw,c and transpose to b,c,hw 388 | k = k.reshape(b, c, h * w) # reshape k to b,c,hw 389 | w_ = torch.bmm( 390 | q, k 391 | ) # compute attention weights w[b,i,j] = sum_c q[b,i,c]k[b,c,j] 392 | w_ = w_ * (int(c) ** (-0.5)) # scale the attention weights 393 | w_ = torch.nn.functional.softmax( 394 | w_, dim=2 395 | ) # apply softmax to get the attention probabilities 396 | 397 | # attend to values 398 | v = v.reshape(b, c, h * w) 399 | w_ = w_.permute(0, 2, 1) # transpose w to b,hw,hw (first hw of k, second of q) 400 | h_ = torch.bmm( 401 | v, w_ 402 | ) # compute the attended values h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 403 | h_ = h_.reshape(b, c, h, w) # reshape h_ to b,c,h,w 404 | h_ = self.proj_out(h_) # project the attended values to the output space 405 | 406 | return x + h_ 407 | 408 | 409 | class LinearAttention(nn.Module): 410 | """ 411 | A linear attention module that computes attention weights for each spatial 412 | location in the input tensor. 413 | 414 | Args: 415 | dim (int): Number of channels in the input tensor. 416 | heads (int): Number of attention heads. Defaults to 4. 417 | dim_head (int): Number of channels per attention head. Defaults to 32. 418 | 419 | Example: 420 | >>> input_tensor = torch.randn(1, 64, 32, 32) 421 | >>> attn_module = LinearAttention(dim=64, heads=8, dim_head=16) 422 | >>> output_tensor = attn_module(input_tensor) 423 | """ 424 | 425 | def __init__(self, dim: int, heads: int = 4, dim_head: int = 32): 426 | super().__init__() 427 | self.heads = heads 428 | hidden_dim = dim_head * heads 429 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 430 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 431 | 432 | def forward(self, x: torch.Tensor) -> torch.Tensor: 433 | """ 434 | Computes the output tensor of the attention module. 435 | 436 | Args: 437 | x (torch.Tensor): Input tensor. 438 | 439 | Returns: 440 | torch.Tensor: Output tensor with the same shape as the input tensor. 
441 | """ 442 | b, c, h, w = x.shape 443 | qkv = self.to_qkv(x) 444 | q, k, v = rearrange( 445 | qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 446 | ) 447 | k = k.softmax(dim=-1) 448 | context = torch.einsum("bhdn,bhen->bhde", k, v) 449 | out = torch.einsum("bhde,bhdn->bhen", context, q) 450 | out = rearrange( 451 | out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w 452 | ) 453 | return self.to_out(out) 454 | 455 | 456 | class LinAttnBlock(LinearAttention): 457 | """to match AttnBlock usage""" 458 | 459 | def __init__(self, in_channels): 460 | super().__init__(dim=in_channels, heads=1, dim_head=in_channels) 461 | -------------------------------------------------------------------------------- /opensr_model/configs/config_10m.yaml: -------------------------------------------------------------------------------- 1 | # INFO 2 | # For ckpts >= opensr-model version 1.0, disable 'apply_normalization' 3 | # For ckpts >= opensr-model version 1.0, set linear_start to 0.0001 and linear_end to 0.01 4 | # For ckpts < opensr-model version 1.0, set linear_start to and linear_end to 0.0015 and 0.0155 respectively 5 | 6 | # General Settings 7 | apply_normalization: False 8 | ckpt_version: "opensr_10m_v4_v6.ckpt" 9 | encode_conditioning: True 10 | 11 | denoiser_settings: # noise settings 12 | linear_start: 0.0015 13 | linear_end: 0.0155 14 | timesteps: 1000 # Timesteps from training 15 | 16 | 17 | # AE Settings 18 | first_stage_config: 19 | embed_dim: 4 20 | double_z: true 21 | z_channels: 4 22 | resolution: 256 23 | in_channels: 4 24 | out_ch: 4 25 | ch: 128 26 | ch_mult: [1, 2, 4] 27 | num_res_blocks: 2 28 | attn_resolutions: [] 29 | dropout: 0.0 30 | 31 | # Denoiser Settings 32 | cond_stage_config: 33 | image_size: 64 34 | in_channels: 8 35 | model_channels: 160 36 | out_channels: 4 37 | num_res_blocks: 2 38 | attention_resolutions: [16, 8] 39 | channel_mult: [1, 2, 2, 4] 40 | num_head_channels: 32 41 | 42 | # other settings 43 | other: 44 | concat_mode: True 45 | cond_stage_trainable: False 46 | first_stage_key: "image" 47 | cond_stage_key: "LR_image" 48 | 49 | -------------------------------------------------------------------------------- /opensr_model/denoiser/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/opensr_model/denoiser/__init__.py -------------------------------------------------------------------------------- /opensr_model/denoiser/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Set, Tuple, Union 2 | 3 | import torch 4 | import torch as th 5 | from einops import rearrange 6 | from opensr_model.denoiser.utils import (BasicTransformerBlock, Downsample, 7 | Normalize, QKVAttention, 8 | QKVAttentionLegacy, TimestepBlock, 9 | Upsample, checkpoint, conv_nd, 10 | convert_module_to_f16, 11 | convert_module_to_f32, linear, 12 | normalization, timestep_embedding, 13 | zero_module) 14 | from torch import nn 15 | 16 | 17 | class ResBlock(TimestepBlock): 18 | def __init__( 19 | self, 20 | channels: int, 21 | emb_channels: int, 22 | dropout: float, 23 | out_channels: Optional[int] = None, 24 | use_conv: bool = False, 25 | use_scale_shift_norm: bool = False, 26 | dims: int = 2, 27 | use_checkpoint: bool = False, 28 | up: bool = False, 29 | down: bool = False, 30 | ): 31 | """ 32 | A residual block with optional timestep conditioning. 
33 | 34 | Args: 35 | channels (int): The number of input channels. 36 | emb_channels (int): The number of timestep embedding channels. 37 | dropout (float): The dropout probability. 38 | out_channels (int, optional): The number of output channels. 39 | Defaults to None (same as input channels). 40 | use_conv (bool, optional): Whether to use a convolutional skip connection. 41 | Defaults to False. 42 | use_scale_shift_norm (bool, optional): Whether to use scale-shift normalization. 43 | Defaults to False. 44 | dims (int, optional): The number of dimensions in the input tensor. 45 | Defaults to 2. 46 | use_checkpoint (bool, optional): Whether to use checkpointing to save memory. 47 | Defaults to False. 48 | up (bool, optional): Whether to use upsampling in the skip connection. Defaults to 49 | False. 50 | down (bool, optional): Whether to use downsampling in the skip connection. Defaults to 51 | False. 52 | 53 | Example: 54 | >>> resblock = ResBlock(channels=64, emb_channels=32, dropout=0.1) 55 | >>> x = torch.randn(1, 64, 32, 32) 56 | >>> emb = torch.randn(1, 32) 57 | >>> out = resblock(x, emb) 58 | >>> print(out.shape) 59 | """ 60 | super().__init__() 61 | self.channels = channels 62 | self.emb_channels = emb_channels 63 | self.dropout = dropout 64 | self.out_channels = out_channels or channels 65 | self.use_conv = use_conv 66 | self.use_checkpoint = use_checkpoint 67 | self.use_scale_shift_norm = use_scale_shift_norm 68 | 69 | # input layers 70 | self.in_layers = nn.Sequential( 71 | normalization(channels), 72 | nn.SiLU(), 73 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 74 | ) 75 | 76 | # skip connection 77 | self.updown = up or down 78 | if up: 79 | self.h_upd = Upsample(channels, False, dims) 80 | self.x_upd = Upsample(channels, False, dims) 81 | elif down: 82 | self.h_upd = Downsample(channels, False, dims) 83 | self.x_upd = Downsample(channels, False, dims) 84 | else: 85 | self.h_upd = self.x_upd = nn.Identity() 86 | 87 | # timestep embedding layers 88 | self.emb_layers = nn.Sequential( 89 | nn.SiLU(), 90 | linear( 91 | emb_channels, 92 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 93 | ), 94 | ) 95 | 96 | # output layers 97 | self.out_layers = nn.Sequential( 98 | normalization(self.out_channels), 99 | nn.SiLU(), 100 | nn.Dropout(p=dropout), 101 | zero_module( 102 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 103 | ), 104 | ) 105 | 106 | # Skip connection 107 | if self.out_channels == channels: 108 | self.skip_connection = nn.Identity() 109 | elif use_conv: 110 | self.skip_connection = conv_nd( 111 | dims, channels, self.out_channels, 3, padding=1 112 | ) 113 | else: 114 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 115 | 116 | def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: 117 | """ 118 | Apply the block to a Tensor, conditioned on a timestep embedding. 119 | 120 | Args: 121 | x (torch.Tensor): An [N x C x ...] Tensor of features. 122 | emb (torch.Tensor): An [N x emb_channels] Tensor of timestep embeddings. 123 | 124 | Returns: 125 | torch.Tensor: An [N x C x ...] Tensor of outputs. 
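Example (illustrative sketch of the scale-shift, FiLM-style conditioning path; sizes are arbitrary):
    >>> block = ResBlock(channels=64, emb_channels=256, dropout=0.0, use_scale_shift_norm=True)
    >>> x = torch.randn(1, 64, 16, 16)
    >>> emb = torch.randn(1, 256)
    >>> block(x, emb).shape   # emb is split into a scale and a shift applied after the output norm
    torch.Size([1, 64, 16, 16])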
126 | """ 127 | return checkpoint( 128 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 129 | ) 130 | 131 | def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: 132 | if self.updown: 133 | # up/downsampling in skip connection 134 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 135 | h = in_rest(x) 136 | h = self.h_upd(h) 137 | x = self.x_upd(x) 138 | h = in_conv(h) 139 | else: 140 | h = self.in_layers(x) 141 | 142 | # timestep embedding 143 | emb_out = self.emb_layers(emb).type(h.dtype) 144 | while len(emb_out.shape) < len(h.shape): 145 | emb_out = emb_out[..., None] 146 | 147 | # scale-shift normalization 148 | if self.use_scale_shift_norm: 149 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 150 | scale, shift = th.chunk(emb_out, 2, dim=1) 151 | h = out_norm(h) * (1 + scale) + shift 152 | h = out_rest(h) 153 | else: 154 | h = h + emb_out 155 | h = self.out_layers(h) 156 | 157 | # skip connection 158 | return self.skip_connection(x) + h 159 | 160 | 161 | class AttentionBlock(nn.Module): 162 | """ 163 | An attention block that allows spatial positions to attend to each other. 164 | 165 | Args: 166 | channels (int): The number of input channels. 167 | num_heads (int, optional): The number of attention heads. Defaults to 1. 168 | num_head_channels (int, optional): The number of channels per attention head. 169 | If not specified, the input channels will be divided equally among the heads. 170 | Defaults to -1. 171 | use_checkpoint (bool, optional): Whether to use checkpointing to save memory. 172 | Defaults to False. 173 | use_new_attention_order (bool, optional): Whether to split the qkv tensor before 174 | splitting the heads. If False, the heads will be split before the qkv tensor. 175 | Defaults to False. 176 | 177 | Example: 178 | >>> attention_block = AttentionBlock(channels=64, num_heads=4) 179 | >>> x = torch.randn(1, 64, 32, 32) 180 | >>> out = attention_block(x) 181 | """ 182 | 183 | def __init__( 184 | self, 185 | channels: int, 186 | num_heads: Optional[int] = 1, 187 | num_head_channels: Optional[int] = -1, 188 | use_checkpoint: Optional[bool] = False, 189 | use_new_attention_order: Optional[bool] = False, 190 | ) -> None: 191 | super().__init__() 192 | 193 | # Set the number of input channels and attention heads 194 | self.channels = channels 195 | if num_head_channels == -1: 196 | self.num_heads = num_heads 197 | else: 198 | assert ( 199 | channels % num_head_channels == 0 200 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 201 | self.num_heads = channels // num_head_channels 202 | 203 | # Set whether to use checkpointing and create normalization layer 204 | self.use_checkpoint = use_checkpoint 205 | self.norm = normalization(channels) 206 | 207 | # Create convolutional layer for qkv tensor and attention module 208 | self.qkv = conv_nd(1, channels, channels * 3, 1) 209 | if use_new_attention_order: 210 | # split qkv before split heads 211 | self.attention = QKVAttention(self.num_heads) 212 | else: 213 | # split heads before split qkv 214 | self.attention = QKVAttentionLegacy(self.num_heads) 215 | 216 | # Create convolutional layer for output projection 217 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 218 | 219 | def forward(self, x: torch.Tensor) -> torch.Tensor: 220 | """ 221 | Apply the attention block to the input tensor. 222 | 223 | Args: 224 | x (torch.Tensor): The input tensor. 225 | 226 | Returns: 227 | torch.Tensor: The output tensor. 
228 | """ 229 | return checkpoint(self._forward, (x,), self.parameters(), False) 230 | 231 | def _forward(self, x: torch.Tensor) -> torch.Tensor: 232 | """ 233 | Apply the attention block to the input tensor. 234 | 235 | Args: 236 | x (torch.Tensor): The input tensor. 237 | 238 | Returns: 239 | torch.Tensor: The output tensor. 240 | """ 241 | b, c, *spatial = x.shape 242 | x = x.reshape(b, c, -1) 243 | 244 | # Apply normalization and convolutional layer to qkv tensor 245 | qkv = self.qkv(self.norm(x)) 246 | 247 | # Apply attention module and convolutional layer to output 248 | h = self.attention(qkv) 249 | h = self.proj_out(h) 250 | 251 | # Add input tensor to output and reshape 252 | return (x + h).reshape(b, c, *spatial) 253 | 254 | 255 | class SpatialTransformer(nn.Module): 256 | """ 257 | Transformer block for image-like data. 258 | First, project the input (aka embedding) 259 | and reshape to b, t, d. 260 | Then apply standard transformer action. 261 | Finally, reshape to image. 262 | 263 | Args: 264 | in_channels (int): The number of input channels. 265 | n_heads (int): The number of attention heads. 266 | d_head (int): The number of channels per attention head. 267 | depth (int, optional): The number of transformer blocks. Defaults to 1. 268 | dropout (float, optional): The dropout probability. Defaults to 0. 269 | context_dim (int, optional): The dimension of the context tensor. 270 | If not specified, cross-attention defaults to self-attention. 271 | Defaults to None. 272 | """ 273 | 274 | def __init__( 275 | self, 276 | in_channels: int, 277 | n_heads: int, 278 | d_head: int, 279 | depth: Optional[int] = 1, 280 | dropout: Optional[float] = 0.0, 281 | context_dim: Optional[int] = None, 282 | ) -> None: 283 | super().__init__() 284 | 285 | # Set the number of input channels and attention heads 286 | self.in_channels = in_channels 287 | inner_dim = n_heads * d_head 288 | self.norm = Normalize(in_channels) 289 | 290 | # Create convolutional layer for input projection 291 | self.proj_in = nn.Conv2d( 292 | in_channels, inner_dim, kernel_size=1, stride=1, padding=0 293 | ) 294 | 295 | # Create list of transformer blocks 296 | self.transformer_blocks = nn.ModuleList( 297 | [ 298 | BasicTransformerBlock( 299 | inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim 300 | ) 301 | for d in range(depth) 302 | ] 303 | ) 304 | 305 | # Create convolutional layer for output projection 306 | self.proj_out = zero_module( 307 | nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 308 | ) 309 | 310 | def forward( 311 | self, x: torch.Tensor, context: Optional[torch.Tensor] = None 312 | ) -> torch.Tensor: 313 | """ 314 | Apply the spatial transformer block to the input tensor. 315 | 316 | Args: 317 | x (torch.Tensor): The input tensor. 318 | context (torch.Tensor, optional): The context tensor. If not specified, 319 | cross-attention defaults to self-attention. Defaults to None. 320 | 321 | Returns: 322 | torch.Tensor: The output tensor. 
323 | """ 324 | b, c, h, w = x.shape 325 | x_in = x 326 | x = self.norm(x) 327 | 328 | # Apply input projection and reshape 329 | x = self.proj_in(x) 330 | x = rearrange(x, "b c h w -> b (h w) c") 331 | 332 | # Apply transformer blocks 333 | for block in self.transformer_blocks: 334 | x = block(x, context=context) 335 | 336 | # Reshape and apply output projection 337 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 338 | x = self.proj_out(x) 339 | 340 | # Add input tensor to output 341 | return x + x_in 342 | 343 | 344 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 345 | """ 346 | A sequential module that passes timestep embeddings to the children that 347 | support it as an extra input. 348 | 349 | Args: 350 | nn.Sequential: The sequential module. 351 | TimestepBlock: The timestep block module. 352 | 353 | Example: 354 | >>> model = TimestepEmbedSequential( 355 | ResBlock(channels=64, emb_channels=32, dropout=0.1), 356 | ResBlock(channels=64, emb_channels=32, dropout=0.1) 357 | ) 358 | >>> x = torch.randn(1, 64, 32, 32) 359 | >>> emb = torch.randn(1, 32) 360 | >>> out = model(x, emb) 361 | >>> print(out.shape) 362 | 363 | """ 364 | 365 | def forward( 366 | self, x: torch.Tensor, emb: torch.Tensor, context: Optional[torch.Tensor] = None 367 | ) -> torch.Tensor: 368 | """ 369 | Apply the sequential module to the input tensor. 370 | 371 | Args: 372 | x (torch.Tensor): The input tensor. 373 | emb (torch.Tensor): The timestep embedding tensor. 374 | context (torch.Tensor, optional): The context tensor. Defaults to None. 375 | 376 | Returns: 377 | torch.Tensor: The output tensor. 378 | """ 379 | for layer in self: 380 | if isinstance(layer, TimestepBlock): 381 | x = layer(x, emb) 382 | elif isinstance(layer, SpatialTransformer): 383 | x = layer(x, context) 384 | else: 385 | x = layer(x) 386 | return x 387 | 388 | 389 | class UNetModel(nn.Module): 390 | """ 391 | The full UNet model with attention and timestep embedding. 392 | 393 | Args: 394 | in_channels (int): The number of channels in the input tensor. 395 | model_channels (int): The base channel count for the model. 396 | out_channels (int): The number of channels in the output tensor. 397 | num_res_blocks (int): The number of residual blocks per downsample. 398 | attention_resolutions (Union[Set[int], List[int], Tuple[int]]): A collection 399 | of downsample rates at which attention will take place. For example, if 400 | this contains 4, then at 4x downsampling, attention will be used. 401 | dropout (float, optional): The dropout probability. Defaults to 0. 402 | channel_mult (Tuple[int], optional): The channel multiplier for each level 403 | of the UNet. Defaults to (1, 2, 4, 8). 404 | conv_resample (bool, optional): If True, use learned convolutions for upsampling 405 | and downsampling. Defaults to True. 406 | dims (int, optional): Determines if the signal is 1D, 2D, or 3D. Defaults to 2. 407 | num_classes (int, optional): If specified, then this model will be class-conditional 408 | with `num_classes` classes. Defaults to None. 409 | use_checkpoint (bool, optional): Use gradient checkpointing to reduce memory usage. 410 | Defaults to False. 411 | use_fp16 (bool, optional): Use half-precision floating point. Defaults to False. 412 | num_heads (int, optional): The number of attention heads in each attention layer. 413 | Defaults to -1. 414 | num_head_channels (int, optional): If specified, ignore num_heads and instead use 415 | a fixed channel width per attention head. Defaults to -1. 
416 | num_heads_upsample (int, optional): Works with num_heads to set a different number 417 | of heads for upsampling. Deprecated. Defaults to -1. 418 | use_scale_shift_norm (bool, optional): Use a FiLM-like conditioning mechanism. Defaults 419 | to False. 420 | resblock_updown (bool, optional): Use residual blocks for up/downsampling. Defaults to False. 421 | use_new_attention_order (bool, optional): Use a different attention pattern for 422 | potentially increased efficiency. Defaults to False. 423 | use_spatial_transformer (bool, optional): Use a custom transformer support. Defaults to 424 | False. 425 | transformer_depth (int, optional): The depth of the custom transformer support. Defaults 426 | to 1. 427 | context_dim (int, optional): The dimension of the context tensor. Defaults to None. 428 | n_embed (int, optional): Custom support for prediction of discrete ids into codebook 429 | of first stage vq model. Defaults to None. 430 | legacy (bool, optional): Use legacy mode. Defaults to True. 431 | ignorekwargs (dict, optional): Ignore extra keyword arguments. 432 | Example: 433 | >>> cond_stage_config = { 434 | "image_size": 64, 435 | "in_channels": 8, 436 | "model_channels": 160, 437 | "out_channels": 4, 438 | "num_res_blocks": 2, 439 | "attention_resolutions": [16, 8], 440 | "channel_mult": [1, 2, 2, 4], 441 | "num_head_channels": 32 442 | } 443 | 444 | >>> model = UNetModel(**cond_stage_config) 445 | >>> x = torch.randn(2, 8, 128, 128) 446 | >>> emb = torch.randn(2) 447 | >>> out = model(x, emb) 448 | >>> print(out.shape) 449 | """ 450 | 451 | def __init__( 452 | self, 453 | in_channels: int, 454 | model_channels: int, 455 | out_channels: int, 456 | num_res_blocks: int, 457 | attention_resolutions: Union[Set[int], List[int], Tuple[int]], 458 | dropout: float = 0, 459 | channel_mult: Tuple[int] = (1, 2, 4, 8), 460 | conv_resample: bool = True, 461 | dims: int = 2, 462 | num_classes: Optional[int] = None, 463 | use_checkpoint: bool = False, 464 | use_fp16: bool = False, 465 | num_heads: int = -1, 466 | num_head_channels: int = -1, 467 | num_heads_upsample: int = -1, 468 | use_scale_shift_norm: bool = False, 469 | resblock_updown: bool = False, 470 | use_new_attention_order: bool = False, 471 | use_spatial_transformer: bool = False, 472 | transformer_depth: int = 1, 473 | context_dim: Optional[int] = None, 474 | n_embed: Optional[int] = None, 475 | legacy: bool = True, 476 | **ignorekwargs: dict, 477 | ): 478 | super().__init__() 479 | 480 | # If num_heads_upsample is not set, set it to num_heads 481 | if num_heads_upsample == -1: 482 | num_heads_upsample = num_heads 483 | 484 | # If num_heads is not set, raise an error if num_head_channels is not set 485 | if num_heads == -1: 486 | assert ( 487 | num_head_channels != -1 488 | ), "Either num_heads or num_head_channels has to be set" 489 | 490 | # If num_head_channels is not set, raise an error if num_heads is not set 491 | if num_head_channels == -1: 492 | assert ( 493 | num_heads != -1 494 | ), "Either num_heads or num_head_channels has to be set" 495 | 496 | # Set the instance variables 497 | self.in_channels = in_channels 498 | self.model_channels = model_channels 499 | self.out_channels = out_channels 500 | self.num_res_blocks = num_res_blocks 501 | self.attention_resolutions = attention_resolutions 502 | self.dropout = dropout 503 | self.channel_mult = channel_mult 504 | self.conv_resample = conv_resample 505 | self.num_classes = num_classes 506 | self.use_checkpoint = use_checkpoint 507 | self.dtype = th.float16 if use_fp16 
else th.float32 508 | self.num_heads = num_heads 509 | self.num_head_channels = num_head_channels 510 | self.num_heads_upsample = num_heads_upsample 511 | self.predict_codebook_ids = n_embed is not None 512 | 513 | # Set up the time embedding layers 514 | time_embed_dim = model_channels * 4 515 | self.time_embed = nn.Sequential( 516 | linear(model_channels, time_embed_dim), 517 | nn.SiLU(), 518 | linear(time_embed_dim, time_embed_dim), 519 | ) 520 | 521 | # If num_classes is not None, set up the label embedding layer 522 | if self.num_classes is not None: 523 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 524 | 525 | # Set up the input blocks 526 | self.input_blocks = nn.ModuleList( 527 | [ 528 | TimestepEmbedSequential( 529 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 530 | ) 531 | ] 532 | ) 533 | 534 | # parameters for the block attention 535 | self._feature_size = model_channels 536 | input_block_chans = [model_channels] 537 | ch = model_channels 538 | ds = 1 539 | 540 | # Set up the attention blocks 541 | for level, mult in enumerate(channel_mult): 542 | for _ in range(num_res_blocks): 543 | layers = [ 544 | ResBlock( 545 | ch, 546 | time_embed_dim, 547 | dropout, 548 | out_channels=mult * model_channels, 549 | dims=dims, 550 | use_checkpoint=use_checkpoint, 551 | use_scale_shift_norm=use_scale_shift_norm, 552 | ) 553 | ] 554 | 555 | # If the downsample rate is in the attention resolutions, add an attention block 556 | ch = mult * model_channels 557 | if ds in attention_resolutions: 558 | if num_head_channels == -1: 559 | dim_head = ch // num_heads 560 | else: 561 | num_heads = ch // num_head_channels 562 | dim_head = num_head_channels 563 | if legacy: 564 | # num_heads = 1 565 | dim_head = ( 566 | ch // num_heads 567 | if use_spatial_transformer 568 | else num_head_channels 569 | ) 570 | layers.append( 571 | AttentionBlock( 572 | ch, 573 | use_checkpoint=use_checkpoint, 574 | num_heads=num_heads, 575 | num_head_channels=dim_head, 576 | use_new_attention_order=use_new_attention_order, 577 | ) 578 | if not use_spatial_transformer 579 | else SpatialTransformer( 580 | ch, 581 | num_heads, 582 | dim_head, 583 | depth=transformer_depth, 584 | context_dim=context_dim, 585 | ) 586 | ) 587 | 588 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 589 | self._feature_size += ch 590 | input_block_chans.append(ch) 591 | 592 | # If the downsample rate is not the last one, add a downsample block 593 | if level != len(channel_mult) - 1: 594 | out_ch = ch 595 | self.input_blocks.append( 596 | TimestepEmbedSequential( 597 | ResBlock( 598 | ch, 599 | time_embed_dim, 600 | dropout, 601 | out_channels=out_ch, 602 | dims=dims, 603 | use_checkpoint=use_checkpoint, 604 | use_scale_shift_norm=use_scale_shift_norm, 605 | down=True, 606 | ) 607 | if resblock_updown 608 | else Downsample( 609 | ch, conv_resample, dims=dims, out_channels=out_ch 610 | ) 611 | ) 612 | ) 613 | ch = out_ch 614 | input_block_chans.append(ch) 615 | ds *= 2 616 | self._feature_size += ch 617 | 618 | # Set up the middle block parameters 619 | if num_head_channels == -1: 620 | dim_head = ch // num_heads 621 | else: 622 | num_heads = ch // num_head_channels 623 | dim_head = num_head_channels 624 | if legacy: 625 | # num_heads = 1 626 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 627 | 628 | # If use_spatial_transformer is True, set up the spatial transformer 629 | self.middle_block = TimestepEmbedSequential( 630 | ResBlock( 631 | ch, 632 | time_embed_dim, 633 | dropout, 
634 | dims=dims, 635 | use_checkpoint=use_checkpoint, 636 | use_scale_shift_norm=use_scale_shift_norm, 637 | ), 638 | AttentionBlock( 639 | ch, 640 | use_checkpoint=use_checkpoint, 641 | num_heads=num_heads, 642 | num_head_channels=dim_head, 643 | use_new_attention_order=use_new_attention_order, 644 | ) 645 | if not use_spatial_transformer 646 | else SpatialTransformer( 647 | ch, 648 | num_heads, 649 | dim_head, 650 | depth=transformer_depth, 651 | context_dim=context_dim, 652 | ), 653 | ResBlock( 654 | ch, 655 | time_embed_dim, 656 | dropout, 657 | dims=dims, 658 | use_checkpoint=use_checkpoint, 659 | use_scale_shift_norm=use_scale_shift_norm, 660 | ), 661 | ) 662 | self._feature_size += ch 663 | 664 | # Set up the output blocks 665 | self.output_blocks = nn.ModuleList([]) 666 | for level, mult in list(enumerate(channel_mult))[::-1]: 667 | for i in range(num_res_blocks + 1): 668 | # If the downsample rate is in the attention resolutions, add an attention block 669 | ich = input_block_chans.pop() 670 | layers = [ 671 | ResBlock( 672 | ch + ich, 673 | time_embed_dim, 674 | dropout, 675 | out_channels=model_channels * mult, 676 | dims=dims, 677 | use_checkpoint=use_checkpoint, 678 | use_scale_shift_norm=use_scale_shift_norm, 679 | ) 680 | ] 681 | ch = model_channels * mult 682 | 683 | # If the downsample rate is in the attention resolutions, add an attention block 684 | if ds in attention_resolutions: 685 | if num_head_channels == -1: 686 | dim_head = ch // num_heads 687 | else: 688 | num_heads = ch // num_head_channels 689 | dim_head = num_head_channels 690 | if legacy: 691 | # num_heads = 1 692 | dim_head = ( 693 | ch // num_heads 694 | if use_spatial_transformer 695 | else num_head_channels 696 | ) 697 | layers.append( 698 | AttentionBlock( 699 | ch, 700 | use_checkpoint=use_checkpoint, 701 | num_heads=num_heads_upsample, 702 | num_head_channels=dim_head, 703 | use_new_attention_order=use_new_attention_order, 704 | ) 705 | if not use_spatial_transformer 706 | else SpatialTransformer( 707 | ch, 708 | num_heads, 709 | dim_head, 710 | depth=transformer_depth, 711 | context_dim=context_dim, 712 | ) 713 | ) 714 | 715 | # If the downsample rate is the last one, add an upsample block 716 | if level and i == num_res_blocks: 717 | out_ch = ch 718 | layers.append( 719 | ResBlock( 720 | ch, 721 | time_embed_dim, 722 | dropout, 723 | out_channels=out_ch, 724 | dims=dims, 725 | use_checkpoint=use_checkpoint, 726 | use_scale_shift_norm=use_scale_shift_norm, 727 | up=True, 728 | ) 729 | if resblock_updown 730 | else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) 731 | ) 732 | ds //= 2 733 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 734 | self._feature_size += ch 735 | 736 | # Set up the output layer 737 | self.out = nn.Sequential( 738 | normalization(ch), 739 | nn.SiLU(), 740 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 741 | ) 742 | 743 | # Set up the codebook id predictor layer 744 | if self.predict_codebook_ids: 745 | self.id_predictor = nn.Sequential( 746 | normalization(ch), conv_nd(dims, model_channels, n_embed, 1) 747 | ) 748 | 749 | def convert_to_fp16(self): 750 | """ 751 | Convert the torso of the model to float16. 752 | """ 753 | self.input_blocks.apply(convert_module_to_f16) 754 | self.middle_block.apply(convert_module_to_f16) 755 | self.output_blocks.apply(convert_module_to_f16) 756 | 757 | def convert_to_fp32(self): 758 | """ 759 | Convert the torso of the model to float32. 
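Example (illustrative sketch; assumes the model was built with use_fp16=True so that self.dtype matches the converted weights):
    >>> model.convert_to_fp16()   # run the input/middle/output blocks in half precision
    >>> model.convert_to_fp32()   # restore full precision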
760 | """ 761 | self.input_blocks.apply(convert_module_to_f32) 762 | self.middle_block.apply(convert_module_to_f32) 763 | self.output_blocks.apply(convert_module_to_f32) 764 | 765 | def forward(self, x, timesteps=None, context=None, y=None, **kwargs): 766 | """ 767 | Apply the model to an input batch. 768 | 769 | Args: 770 | x (torch.Tensor): An [N x C x ...] Tensor of inputs. 771 | timesteps (torch.Tensor, optional): A 1-D batch of timesteps. 772 | Defaults to None. 773 | context (torch.Tensor, optional): Conditioning plugged in via crossattn. 774 | Defaults to None. 775 | y (torch.Tensor, optional): An [N] Tensor of labels, if class-conditional. 776 | Defaults to None. 777 | 778 | Returns: 779 | torch.Tensor: An [N x C x ...] Tensor of outputs. 780 | """ 781 | # print("aaa") 782 | # print(x.shape) 783 | # print(timesteps.shape) 784 | # print("aaa") 785 | # 1 + "a" 786 | 787 | # Check if y is specified only if the model is class-conditional 788 | assert (y is not None) == ( 789 | self.num_classes is not None 790 | ), "must specify y if and only if the model is class-conditional" 791 | 792 | # Initialize a list to store the hidden states of the input blocks 793 | hs = [] 794 | 795 | # Compute the timestep embeddings and time embeddings 796 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 797 | emb = self.time_embed(t_emb) 798 | 799 | # Add label embeddings if the model is class-conditional 800 | if self.num_classes is not None: 801 | assert y.shape == (x.shape[0],) 802 | emb = emb + self.label_emb(y) 803 | 804 | # Convert the input tensor to the specified data type 805 | h = x.type(self.dtype) 806 | 807 | # Pass the input tensor through the input blocks and store the hidden states 808 | for module in self.input_blocks: 809 | h = module(h, emb, context) 810 | hs.append(h) 811 | 812 | # Pass the output of the input blocks through the middle block 813 | h = self.middle_block(h, emb, context) 814 | 815 | # Pass the output of the middle block through the output blocks in reverse order 816 | for module in self.output_blocks: 817 | # Concatenate the output of the current output block with the corresponding 818 | # hidden state from the input blocks 819 | h = th.cat([h, hs.pop()], dim=1) 820 | 821 | # Pass the concatenated tensor through the current output block 822 | h = module(h, emb, context) 823 | 824 | # Convert the output tensor to the same data type as the input tensor 825 | h = h.type(x.dtype) 826 | 827 | # Return the output tensor or the codebook ID predictions if specified 828 | if self.predict_codebook_ids: 829 | return self.id_predictor(h) 830 | else: 831 | return self.out(h) 832 | -------------------------------------------------------------------------------- /opensr_model/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/opensr_model/diffusion/__init__.py -------------------------------------------------------------------------------- /opensr_model/diffusion/latentdiffusion.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from functools import partial 3 | from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from opensr_model.autoencoder.autoencoder import (AutoencoderKL, 9 | DiagonalGaussianDistribution) 10 | from 
opensr_model.denoiser.unet import UNetModel 11 | from opensr_model.diffusion.utils import (LitEma, count_params, default, 12 | disabled_train, exists, 13 | extract_into_tensor, 14 | make_beta_schedule, 15 | make_convolutional_sample) 16 | 17 | __conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} 18 | 19 | 20 | class DiffusionWrapper(nn.Module): 21 | """ 22 | A wrapper around a UNetModel that supports different types of conditioning. 23 | 24 | Args: 25 | diff_model_config (dict): A dictionary of configuration options for the UNetModel. 26 | conditioning_key (str, optional): The type of conditioning to use 27 | (None, 'concat', 'crossattn', 'hybrid', or 'adm'). Defaults to None. 28 | 29 | Raises: 30 | AssertionError: If the conditioning key is not one of the supported values. 31 | 32 | Example: 33 | >>> diff_model_config = {'in_channels': 3, 'out_channels': 3, 'num_filters': 32} 34 | >>> wrapper = DiffusionWrapper(diff_model_config, conditioning_key='concat') 35 | >>> x = torch.randn(1, 3, 256, 256) 36 | >>> t = torch.randn(1) 37 | >>> c_concat = [torch.randn(1, 32, 256, 256)] 38 | >>> y = wrapper(x, t, c_concat=c_concat) 39 | """ 40 | 41 | def __init__(self, diff_model_config: dict, conditioning_key: Optional[str] = None): 42 | super().__init__() 43 | self.diffusion_model = UNetModel(**diff_model_config) 44 | self.conditioning_key = conditioning_key 45 | 46 | ckey_options = [None, "concat", "crossattn", "hybrid", "adm"] 47 | assert ( 48 | self.conditioning_key in ckey_options 49 | ), f"Unsupported conditioning key: {self.conditioning_key}" 50 | 51 | def forward( 52 | self, 53 | x: torch.Tensor, 54 | t: torch.Tensor, 55 | c_concat: Optional[List[torch.Tensor]] = None, 56 | c_crossattn: Optional[List[torch.Tensor]] = None, 57 | ) -> torch.Tensor: 58 | """ 59 | Apply the diffusion model to the input tensor. 60 | 61 | Args: 62 | x (torch.Tensor): The input tensor. 63 | t (torch.Tensor): The diffusion time. 64 | c_concat (List[torch.Tensor], optional): A list of tensors to concatenate with the input tensor. 65 | Used when conditioning_key is 'concat'. Defaults to None. 66 | c_crossattn (List[torch.Tensor], optional): A list of tensors to use for cross-attention. 67 | Used when conditioning_key is 'crossattn', 'hybrid', or 'adm'. Defaults to None. 68 | 69 | Returns: 70 | torch.Tensor: The output tensor. 71 | 72 | Raises: 73 | NotImplementedError: If the conditioning key is not one of the supported values. 74 | """ 75 | xc = torch.cat([x] + c_concat, dim=1) 76 | out = self.diffusion_model(xc, t) 77 | return out 78 | 79 | 80 | class DDPM(nn.Module): 81 | """This class implements the classic DDPM (Diffusion Models) with Gaussian diffusion 82 | in image space. 83 | 84 | Args: 85 | unet_config (dict): A dictionary of configuration options for the UNetModel. 86 | timesteps (int): The number of diffusion timesteps to use. 87 | beta_schedule (str): The type of beta schedule to use (linear, cosine, or fixed). 88 | use_ema (bool): Whether to use exponential moving averages (EMAs) of the model weights during training. 89 | first_stage_key (str): The key to use for the first stage of the model (either "image" or "noise"). 90 | linear_start (float): The starting value for the linear beta schedule. 91 | linear_end (float): The ending value for the linear beta schedule. 92 | cosine_s (float): The scaling factor for the cosine beta schedule. 93 | given_betas (list): A list of beta values to use for the fixed beta schedule. 
94 | v_posterior (float): The weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta. 95 | conditioning_key (str): The type of conditioning to use (None, 'concat', 'crossattn', 'hybrid', or 'adm'). 96 | parameterization (str): The type of parameterization to use for the diffusion process (either "eps" or "x0"). 97 | use_positional_encodings (bool): Whether to use positional encodings for the input. 98 | 99 | Methods: 100 | register_schedule: Registers the schedule for the betas and alphas. 101 | get_input: Gets the input from the DataLoader and rearranges it. 102 | decode_first_stage: Decodes the first stage of the model. 103 | ema_scope: Switches to EMA weights during training. 104 | 105 | Attributes: 106 | parameterization (str): The type of parameterization used for the diffusion process. 107 | cond_stage_model (None): The conditioning stage model (not used in this implementation). 108 | first_stage_key (str): The key used for the first stage of the model. 109 | use_positional_encodings (bool): Whether positional encodings are used for the input. 110 | model (DiffusionWrapper): The diffusion model. 111 | use_ema (bool): Whether EMAs of the model weights are used during training. 112 | model_ema (LitEma): The EMA of the model weights. 113 | v_posterior (float): The weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta. 114 | 115 | Example: 116 | >>> unet_config = { 117 | 'in_channels': 3, 118 | 'model_channels': 160, 119 | 'num_res_blocks': 2, 120 | 'attention_resolutions': [16, 8], 121 | 'channel_mult': [1, 2, 2, 4], 122 | 'num_head_channels': 32 123 | } 124 | >>> model = DDPM( 125 | unet_config, timesteps=1000, beta_schedule='linear', 126 | use_ema=True, first_stage_key='image' 127 | ) 128 | """ 129 | 130 | def __init__( 131 | self, 132 | unet_config: Dict[str, Any], 133 | timesteps: int = 1000, 134 | beta_schedule: str = "linear", 135 | use_ema: bool = True, 136 | first_stage_key: str = "image", 137 | linear_start: float = 1e-4, 138 | linear_end: float = 2e-2, 139 | cosine_s: float = 8e-3, 140 | given_betas: Optional[List[float]] = None, 141 | v_posterior: float = 0.0, 142 | conditioning_key: Optional[str] = None, 143 | parameterization: str = "eps", 144 | use_positional_encodings: bool = False, 145 | ) -> None: 146 | super().__init__() 147 | assert parameterization in [ 148 | "eps", 149 | "x0", 150 | ], 'currently only supporting "eps" and "x0"' 151 | self.parameterization = parameterization 152 | 153 | print( 154 | f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode" 155 | ) 156 | 157 | self.cond_stage_model = None 158 | self.first_stage_key = first_stage_key 159 | self.use_positional_encodings = use_positional_encodings 160 | self.model = DiffusionWrapper(unet_config, conditioning_key) 161 | 162 | count_params(self.model, verbose=True) 163 | 164 | self.use_ema = use_ema 165 | if self.use_ema: 166 | self.model_ema = LitEma(self.model) 167 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 168 | 169 | self.v_posterior = v_posterior 170 | 171 | self.register_schedule( 172 | given_betas=given_betas, 173 | beta_schedule=beta_schedule, 174 | timesteps=timesteps, 175 | linear_start=linear_start, 176 | linear_end=linear_end, 177 | cosine_s=cosine_s, 178 | ) 179 | 180 | def register_schedule( 181 | self, 182 | given_betas: Optional[List[float]] = None, 183 | beta_schedule: str = "linear", 184 | timesteps: int = 1000, 185 | linear_start: float = 1e-4, 186 | linear_end: float = 2e-2, 187 | cosine_s: 
float = 8e-3, 188 | ) -> None: 189 | """ 190 | Registers the schedule for the betas and alphas. 191 | 192 | Args: 193 | given_betas (list, optional): A list of beta values to use for the fixed beta schedule. 194 | Defaults to None. 195 | beta_schedule (str, optional): The type of beta schedule to use (linear, cosine, or fixed). 196 | Defaults to "linear". 197 | timesteps (int, optional): The number of diffusion timesteps to use. Defaults to 1000. 198 | linear_start (float, optional): The starting value for the linear beta schedule. Defaults to 1e-4. 199 | linear_end (float, optional): The ending value for the linear beta schedule. Defaults to 2e-2. 200 | cosine_s (float, optional): The scaling factor for the cosine beta schedule. Defaults to 8e-3. 201 | """ 202 | if exists(given_betas): 203 | betas = given_betas 204 | else: 205 | betas = make_beta_schedule( 206 | beta_schedule, 207 | timesteps, 208 | linear_start=linear_start, 209 | linear_end=linear_end, 210 | cosine_s=cosine_s, 211 | ) 212 | alphas = 1.0 - betas 213 | alphas_cumprod = np.cumprod(alphas, axis=0) 214 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 215 | 216 | (timesteps,) = betas.shape 217 | self.num_timesteps = int(timesteps) 218 | self.linear_start = linear_start 219 | self.linear_end = linear_end 220 | assert ( 221 | alphas_cumprod.shape[0] == self.num_timesteps 222 | ), "alphas have to be defined for each timestep" 223 | 224 | to_torch = partial(torch.tensor, dtype=torch.float32) 225 | 226 | self.register_buffer("betas", to_torch(betas)) 227 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 228 | self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) 229 | 230 | # calculations for diffusion q(x_t | x_{t-1}) and others 231 | self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) 232 | self.register_buffer( 233 | "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) 234 | ) 235 | self.register_buffer( 236 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) 237 | ) 238 | self.register_buffer( 239 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) 240 | ) 241 | self.register_buffer( 242 | "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) 243 | ) 244 | 245 | # calculations for posterior q(x_{t-1} | x_t, x_0) 246 | posterior_variance = (1 - self.v_posterior) * betas * ( 247 | 1.0 - alphas_cumprod_prev 248 | ) / (1.0 - alphas_cumprod) + self.v_posterior * betas 249 | # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) 250 | self.register_buffer("posterior_variance", to_torch(posterior_variance)) 251 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 252 | self.register_buffer( 253 | "posterior_log_variance_clipped", 254 | to_torch(np.log(np.maximum(posterior_variance, 1e-20))), 255 | ) 256 | self.register_buffer( 257 | "posterior_mean_coef1", 258 | to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), 259 | ) 260 | self.register_buffer( 261 | "posterior_mean_coef2", 262 | to_torch( 263 | (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) 264 | ), 265 | ) 266 | 267 | if self.parameterization == "eps": 268 | lvlb_weights = self.betas**2 / ( 269 | 2 270 | * self.posterior_variance 271 | * to_torch(alphas) 272 | * (1 - self.alphas_cumprod) 273 | ) 274 | elif self.parameterization == "x0": 275 | lvlb_weights = ( 276 | 0.5 277 | * np.sqrt(torch.Tensor(alphas_cumprod)) 278 | / (2.0 * 1 - torch.Tensor(alphas_cumprod)) 279 | ) 280 | else: 281 | raise NotImplementedError("mu not supported") 282 | # TODO how to choose this term 283 | lvlb_weights[0] = lvlb_weights[1] 284 | self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) 285 | assert not torch.isnan(self.lvlb_weights).all() 286 | 287 | def get_input(self, batch: Dict[str, torch.Tensor], k: str) -> torch.Tensor: 288 | """ 289 | Gets the input from the DataLoader and rearranges it. 290 | 291 | Args: 292 | batch (Dict[str, torch.Tensor]): The batch of data from the DataLoader. 293 | k (str): The key for the input tensor in the batch. 294 | 295 | Returns: 296 | torch.Tensor: The input tensor, rearranged and converted to float. 297 | """ 298 | 299 | x = batch[k] 300 | if len(x.shape) == 3: 301 | x = x[..., None] 302 | 303 | x = x.to(memory_format=torch.contiguous_format).float() 304 | 305 | return x 306 | 307 | @contextmanager 308 | def ema_scope(self, context: Optional[str] = None) -> Generator[None, None, None]: 309 | """ 310 | A context manager that switches to EMA weights during training. 311 | 312 | Args: 313 | context (Optional[str]): A string to print when switching to and from EMA weights. 314 | 315 | Yields: 316 | None 317 | """ 318 | if self.use_ema: 319 | self.model_ema.store(self.model.parameters()) 320 | self.model_ema.copy_to(self.model) 321 | if context is not None: 322 | print(f"{context}: Switched to EMA weights") 323 | try: 324 | yield None 325 | finally: 326 | if self.use_ema: 327 | self.model_ema.restore(self.model.parameters()) 328 | if context is not None: 329 | print(f"{context}: Restored training weights") 330 | 331 | 332 | def decode_first_stage(self, z: torch.Tensor) -> torch.Tensor: 333 | """ 334 | Decodes the first stage of the model. 335 | 336 | Args: 337 | z (torch.Tensor): The input tensor. 338 | 339 | Returns: 340 | torch.Tensor: The decoded output tensor. 341 | """ 342 | 343 | z = 1.0 / self.scale_factor * z 344 | 345 | if hasattr(self, "split_input_params"): 346 | if self.split_input_params["patch_distributed_vq"]: 347 | ks = self.split_input_params["ks"] # eg. (128, 128) 348 | stride = self.split_input_params["stride"] # eg. 
(64, 64) 349 | uf = self.split_input_params["vqf"] 350 | bs, nc, h, w = z.shape 351 | if ks[0] > h or ks[1] > w: 352 | ks = (min(ks[0], h), min(ks[1], w)) 353 | print("reducing Kernel") 354 | 355 | if stride[0] > h or stride[1] > w: 356 | stride = (min(stride[0], h), min(stride[1], w)) 357 | print("reducing stride") 358 | 359 | fold, unfold, normalization, weighting = self.get_fold_unfold( 360 | z, ks, stride, uf=uf 361 | ) 362 | 363 | z = unfold(z) # (bn, nc * prod(**ks), L) 364 | # 1. Reshape to img shape 365 | z = z.view( 366 | (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) 367 | ) # (bn, nc, ks[0], ks[1], L ) 368 | 369 | # 2. apply model loop over last dim 370 | output_list = [ 371 | self.first_stage_model.decode(z[:, :, :, :, i]) 372 | for i in range(z.shape[-1]) 373 | ] 374 | 375 | o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) 376 | o = o * weighting 377 | # Reverse 1. reshape to img shape 378 | o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) 379 | # stitch crops together 380 | decoded = fold(o) 381 | decoded = decoded / normalization # norm is shape (1, 1, h, w) 382 | return decoded 383 | 384 | else: 385 | return self.first_stage_model.decode(z) 386 | 387 | else: 388 | return self.first_stage_model.decode(z) 389 | 390 | 391 | class LatentDiffusion(DDPM): 392 | """ 393 | LatentDiffusion is a class that extends the DDPM class and implements a diffusion 394 | model with a latent variable. The model consists of two stages: a first stage that 395 | encodes the input tensor into a latent tensor, and a second stage that decodes the 396 | latent tensor into the output tensor. The model also has a conditional stage that 397 | takes a conditioning tensor as input and produces a learned conditioning tensor 398 | that is used to condition the first and second stages of the model. The class 399 | provides methods for encoding and decoding tensors, computing the output tensor 400 | and loss, and sampling from the distribution at a given latent tensor and timestep. 401 | The class also provides methods for registering and applying schedules, and for 402 | getting and setting the scale factor and conditioning key. 403 | 404 | Methods: 405 | register_schedule(self, schedule: Schedule) -> None: Registers the given schedule 406 | with the model. 407 | make_cond_schedule(self, schedule: Schedule) -> Schedule: Returns a new schedule 408 | with the given schedule applied to the conditional stage of the model. 409 | encode_first_stage(self, x: torch.Tensor, t: int) -> torch.Tensor: Encodes the given 410 | input tensor with the first stage of the model for the given timestep. 411 | get_first_stage_encoding(self, x: torch.Tensor, t: int) -> torch.Tensor: Returns the 412 | encoding of the given input tensor with the first stage of the model for the 413 | given timestep. 414 | get_learned_conditioning(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: 415 | Returns the learned conditioning tensor for the given input 416 | tensor, timestep, and conditioning tensor. 417 | get_input(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: 418 | Returns the input tensor for the given input tensor, timestep, and 419 | conditioning tensor. 420 | compute(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 421 | Computes the output tensor and loss for the given input tensor, 422 | timestep, and conditioning tensor. 
423 | apply_model(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: Applies 424 | the model to the given input tensor, timestep, and conditioning tensor. 425 | get_fold_unfold(self, ks: int, stride: int, vqf: int) -> Tuple[Callable, Callable]: Returns the fold 426 | and unfold functions for the given kernel size, stride, and vector quantization factor. 427 | forward(self, x: torch.Tensor, t: int, y: Optional[torch.Tensor] = None) -> torch.Tensor: Computes the 428 | output tensor for the given input tensor, timestep, and conditioning tensor. 429 | q_sample(self, z: torch.Tensor, t: int, eps: Optional[torch.Tensor] = None) -> torch.Tensor: Samples 430 | from the distribution at the given latent tensor and timestep. 431 | """ 432 | 433 | def __init__( 434 | self, 435 | first_stage_config: Dict[str, Any], 436 | cond_stage_config: Union[str, Dict[str, Any]], 437 | num_timesteps_cond: Optional[int] = None, 438 | cond_stage_key: str = "image", 439 | cond_stage_trainable: bool = False, 440 | concat_mode: bool = True, 441 | cond_stage_forward: Optional[Callable] = None, 442 | conditioning_key: Optional[str] = None, 443 | scale_factor: float = 1.0, 444 | scale_by_std: bool = False, 445 | *args: Any, 446 | **kwargs: Any, 447 | ): 448 | """ 449 | Initializes the LatentDiffusion model. 450 | 451 | Args: 452 | first_stage_config (Dict[str, Any]): The configuration for the first stage of the model. 453 | cond_stage_config (Union[str, Dict[str, Any]]): The configuration for the conditional stage of the model. 454 | num_timesteps_cond (Optional[int]): The number of timesteps for the conditional stage of the model. 455 | cond_stage_key (str): The key for the conditional stage of the model. 456 | cond_stage_trainable (bool): Whether the conditional stage of the model is trainable. 457 | concat_mode (bool): Whether to use concatenation or cross-attention for the conditioning. 458 | cond_stage_forward (Optional[Callable]): A function to apply to the output of the conditional stage of the model. 459 | conditioning_key (Optional[str]): The key for the conditioning. 460 | scale_factor (float): The scale factor for the input tensor. 461 | scale_by_std (bool): Whether to scale the input tensor by its standard deviation. 462 | *args (Any): Additional arguments. 463 | **kwargs (Any): Additional keyword arguments. 
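        Example (illustrative sketch only; the config dictionaries and key names are
        assumptions and must match the project's YAML settings, mirroring how
        SRLatentDiffusion in srmodel.py builds this class):
            >>> model = LatentDiffusion(
            ...     first_stage_config, cond_stage_config,
            ...     timesteps=1000, unet_config=unet_config,
            ...     linear_start=1e-4, linear_end=2e-2,
            ...     concat_mode=True, cond_stage_trainable=False,
            ...     first_stage_key='image', cond_stage_key='LR_image',
            ... )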
464 | """ 465 | 466 | self.num_timesteps_cond = default(num_timesteps_cond, 1) 467 | self.scale_by_std = scale_by_std 468 | assert self.num_timesteps_cond <= kwargs["timesteps"] 469 | 470 | # for backwards compatibility after implementation of DiffusionWrapper 471 | if conditioning_key is None: 472 | conditioning_key = "concat" if concat_mode else "crossattn" 473 | if cond_stage_config == "__is_unconditional__": 474 | conditioning_key = None 475 | 476 | super().__init__(conditioning_key=conditioning_key, *args, **kwargs) 477 | self.concat_mode = concat_mode 478 | self.cond_stage_trainable = cond_stage_trainable 479 | self.cond_stage_key = cond_stage_key 480 | if not scale_by_std: 481 | self.scale_factor = scale_factor 482 | else: 483 | self.register_buffer("scale_factor", torch.tensor(scale_factor)) 484 | 485 | self.cond_stage_forward = cond_stage_forward 486 | 487 | # Set Fusion parameters (SIMON) 488 | # TODO: We only have SISR parameters 489 | self.sr_type = "SISR" 490 | 491 | # Setup the AutoencoderKL model 492 | embed_dim = first_stage_config["embed_dim"] # extract embedded dim fro first stage config 493 | self.first_stage_model = AutoencoderKL(first_stage_config, embed_dim=embed_dim) 494 | self.first_stage_model.eval() 495 | self.first_stage_model.train = disabled_train 496 | for param in self.first_stage_model.parameters(): 497 | param.requires_grad = False 498 | 499 | # Setup the Unet model 500 | self.cond_stage_model = torch.nn.Identity() # Unet 501 | self.cond_stage_model.eval() 502 | self.cond_stage_model.train = disabled_train 503 | for param in self.cond_stage_model.parameters(): 504 | param.requires_grad = False 505 | 506 | def register_schedule( 507 | self, 508 | given_betas: Optional[Union[float, torch.Tensor]] = None, 509 | beta_schedule: str = "linear", 510 | timesteps: int = 1000, 511 | linear_start: float = 1e-4, 512 | linear_end: float = 2e-2, 513 | cosine_s: float = 8e-3, 514 | ) -> None: 515 | """ 516 | Registers the given schedule with the model. 517 | 518 | Args: 519 | given_betas (Optional[Union[float, torch.Tensor]]): The betas for the schedule. 520 | beta_schedule (str): The type of beta schedule to use. 521 | timesteps (int): The number of timesteps for the schedule. 522 | linear_start (float): The start value for the linear schedule. 523 | linear_end (float): The end value for the linear schedule. 524 | cosine_s (float): The scale factor for the cosine schedule. 525 | """ 526 | super().register_schedule( 527 | given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s 528 | ) 529 | 530 | self.shorten_cond_schedule = self.num_timesteps_cond > 1 531 | if self.shorten_cond_schedule: 532 | self.make_cond_schedule() 533 | 534 | def make_cond_schedule(self) -> None: 535 | """ 536 | Shortens the schedule for the conditional stage of the model. 537 | """ 538 | self.cond_ids = torch.full( 539 | size=(self.num_timesteps,), 540 | fill_value=self.num_timesteps - 1, 541 | dtype=torch.long, 542 | ) 543 | ids = torch.round( 544 | torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) 545 | ).long() 546 | self.cond_ids[: self.num_timesteps_cond] = ids 547 | 548 | 549 | def encode_first_stage(self, x: torch.Tensor) -> torch.Tensor: 550 | """ 551 | Encodes the given input tensor with the first stage of the model. 552 | 553 | Args: 554 | x (torch.Tensor): The input tensor. 555 | 556 | Returns: 557 | torch.Tensor: The encoded output tensor. 
558 | """ 559 | return self.first_stage_model.encode(x) 560 | 561 | 562 | def get_first_stage_encoding( 563 | self, encoder_posterior: Union[DiagonalGaussianDistribution, torch.Tensor] 564 | ) -> torch.Tensor: 565 | """ 566 | Returns the encoding of the given input tensor with the first stage of the 567 | model for the given timestep. 568 | 569 | Args: 570 | encoder_posterior (Union[DiagonalGaussianDistribution, torch.Tensor]): The 571 | encoder posterior. 572 | 573 | Returns: 574 | torch.Tensor: The encoding of the input tensor. 575 | """ 576 | if isinstance(encoder_posterior, DiagonalGaussianDistribution): 577 | z = encoder_posterior.sample() 578 | elif isinstance(encoder_posterior, torch.Tensor): 579 | z = encoder_posterior 580 | else: 581 | raise NotImplementedError( 582 | f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" 583 | ) 584 | return self.scale_factor * z 585 | 586 | def get_learned_conditioning(self, c: torch.Tensor) -> torch.Tensor: 587 | """ 588 | Returns the learned conditioning tensor for the given input tensor. 589 | 590 | Args: 591 | c (torch.Tensor): The input tensor. 592 | 593 | Returns: 594 | torch.Tensor: The learned conditioning tensor. 595 | """ 596 | if self.cond_stage_forward is None: 597 | if hasattr(self.cond_stage_model, "encode") and callable( 598 | self.cond_stage_model.encode 599 | ): 600 | c = self.cond_stage_model.encode(c) 601 | if isinstance(c, DiagonalGaussianDistribution): 602 | c = c.mode() 603 | else: 604 | c = self.cond_stage_model(c) 605 | # cond stage model is identity 606 | else: 607 | assert hasattr(self.cond_stage_model, self.cond_stage_forward) 608 | c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) 609 | return c 610 | 611 | def get_input( 612 | self, 613 | batch: torch.Tensor, 614 | k: int, 615 | return_first_stage_outputs: bool = False, 616 | force_c_encode: bool = False, 617 | cond_key: Optional[str] = None, 618 | return_original_cond: bool = False, 619 | bs: Optional[int] = None, 620 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: 621 | """ 622 | Returns the input tensor for the given batch and timestep. 623 | 624 | Args: 625 | batch (torch.Tensor): The input batch tensor. 626 | k (int): The timestep. 627 | return_first_stage_outputs (bool): Whether to return the outputs of the first stage of the model. 628 | force_c_encode (bool): Whether to force encoding of the conditioning tensor. 629 | cond_key (Optional[str]): The key for the conditioning tensor. 630 | return_original_cond (bool): Whether to return the original conditioning tensor. 631 | bs (Optional[int]): The batch size. 632 | 633 | Returns: 634 | Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: The input tensor, the outputs of the 635 | first stage of the model (if `return_first_stage_outputs` is `True`), and the encoded conditioning tensor 636 | (if `force_c_encode` is `True` and `cond_key` is not `None`). 
637 | """ 638 | 639 | # k = first_stage_key on this SR example 640 | x = super().get_input(batch, k) # line 333 641 | 642 | if bs is not None: 643 | x = x[:bs] 644 | x = x.to(self.device) 645 | 646 | # perform always for HR and for HR only of SISR 647 | if self.sr_type == "SISR" or k == "image": 648 | encoder_posterior = self.encode_first_stage(x) 649 | z = self.get_first_stage_encoding(encoder_posterior).detach() 650 | 651 | if self.model.conditioning_key is not None: 652 | # self.model.conditioning_key = "image" in SR example 653 | 654 | if cond_key is None: 655 | cond_key = self.cond_stage_key 656 | 657 | if cond_key != self.first_stage_key: 658 | if cond_key in ["caption", "coordinates_bbox"]: 659 | xc = batch[cond_key] 660 | elif cond_key == "class_label": 661 | xc = batch 662 | else: 663 | xc = super().get_input(batch, cond_key).to(self.device) 664 | else: 665 | xc = x 666 | if not self.cond_stage_trainable or force_c_encode: 667 | if isinstance(xc, dict) or isinstance(xc, list): 668 | # import pudb; pudb.set_trace() 669 | c = self.get_learned_conditioning(xc) 670 | else: 671 | c = self.get_learned_conditioning(xc.to(self.device)) 672 | else: 673 | c = xc 674 | if bs is not None: 675 | c = c[:bs] 676 | 677 | # BUG if use_positional_encodings is True 678 | if self.use_positional_encodings: 679 | pos_x, pos_y = self.compute_latent_shifts(batch) 680 | ckey = __conditioning_keys__[self.model.conditioning_key] 681 | c = {ckey: c, "pos_x": pos_x, "pos_y": pos_y} 682 | 683 | else: 684 | c = None 685 | xc = None 686 | if self.use_positional_encodings: 687 | pos_x, pos_y = self.compute_latent_shifts(batch) 688 | c = {"pos_x": pos_x, "pos_y": pos_y} 689 | out = [z, c] 690 | if return_first_stage_outputs: 691 | xrec = self.decode_first_stage(z) 692 | out.extend([x, xrec]) 693 | if return_original_cond: 694 | out.append(xc) 695 | 696 | """ 697 | # overwrite LR original with encoded LR if wanted 698 | self.encode_conditioning = True 699 | if self.encode_conditioning==True and self.sr_type=="SISR": 700 | #print("Encoding conditioning!") 701 | # try to upsample->encode conditioning 702 | c = torch.nn.functional.interpolate(out[1], size=(512,512), mode='bilinear', align_corners=False) 703 | # encode c 704 | c = self.encode_first_stage(c).sample() 705 | out[1] = c 706 | """ 707 | 708 | 709 | return out 710 | 711 | def compute( 712 | self, example: torch.Tensor, custom_steps: int = 200, temperature: float = 1.0 713 | ) -> torch.Tensor: 714 | """ 715 | Performs inference on the given example tensor. 716 | 717 | Args: 718 | example (torch.Tensor): The example tensor. 719 | custom_steps (int): The number of steps to perform. 720 | temperature (float): The temperature to use. 721 | 722 | Returns: 723 | torch.Tensor: The output tensor. 
724 | """ 725 | guider = None 726 | ckwargs = None 727 | ddim_use_x0_pred = False 728 | temperature = temperature 729 | eta = 1.0 730 | custom_shape = None 731 | 732 | if hasattr(self, "split_input_params"): 733 | delattr(self, "split_input_params") 734 | 735 | logs = make_convolutional_sample( 736 | example, 737 | self, 738 | custom_steps=custom_steps, 739 | eta=eta, 740 | quantize_x0=False, 741 | custom_shape=custom_shape, 742 | temperature=temperature, 743 | noise_dropout=0.0, 744 | corrector=guider, 745 | corrector_kwargs=ckwargs, 746 | x_T=None, 747 | ddim_use_x0_pred=ddim_use_x0_pred, 748 | ) 749 | 750 | return logs["sample"] 751 | 752 | def apply_model( 753 | self, 754 | x_noisy: torch.Tensor, 755 | t: int, 756 | cond: Optional[torch.Tensor] = None, 757 | return_ids: bool = False, 758 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 759 | """ 760 | Applies the model to the given noisy input tensor. 761 | 762 | Args: 763 | x_noisy (torch.Tensor): The noisy input tensor. 764 | t (int): The timestep. 765 | cond (Optional[torch.Tensor]): The conditioning tensor. 766 | return_ids (bool): Whether to return the IDs of the diffusion process. 767 | 768 | Returns: 769 | Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The output tensor, and optionally the IDs of the 770 | diffusion process. 771 | """ 772 | 773 | if isinstance(cond, dict): 774 | pass 775 | else: 776 | if not isinstance(cond, list): 777 | cond = [cond] 778 | key = ( 779 | "c_concat" if self.model.conditioning_key == "concat" else "c_crossattn" 780 | ) 781 | cond = {key: cond} 782 | 783 | x_recon = self.model(x_noisy, t, **cond) 784 | 785 | if isinstance(x_recon, tuple) and not return_ids: 786 | return x_recon[0] 787 | else: 788 | return x_recon 789 | 790 | def get_fold_unfold( 791 | self, x: torch.Tensor, kernel_size: int, stride: int, uf: int = 1, df: int = 1 792 | ) -> Tuple[nn.Conv2d, nn.ConvTranspose2d]: 793 | """ 794 | Returns the fold and unfold convolutional layers for the given input tensor. 795 | 796 | Args: 797 | x (torch.Tensor): The input tensor. 798 | kernel_size (int): The kernel size. 799 | stride (int): The stride. 800 | uf (int): The unfold factor. 801 | df (int): The fold factor. 802 | 803 | Returns: 804 | Tuple[nn.Conv2d, nn.ConvTranspose2d]: The fold and unfold convolutional layers. 
805 | """ 806 | bs, nc, h, w = x.shape 807 | 808 | # number of crops in image 809 | Ly = (h - kernel_size[0]) // stride[0] + 1 810 | Lx = (w - kernel_size[1]) // stride[1] + 1 811 | 812 | if uf == 1 and df == 1: 813 | fold_params = dict( 814 | kernel_size=kernel_size, dilation=1, padding=0, stride=stride 815 | ) 816 | unfold = torch.nn.Unfold(**fold_params) 817 | 818 | fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) 819 | 820 | weighting = self.get_weighting( 821 | kernel_size[0], kernel_size[1], Ly, Lx, x.device 822 | ).to(x.dtype) 823 | normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap 824 | weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) 825 | 826 | elif uf > 1 and df == 1: 827 | fold_params = dict( 828 | kernel_size=kernel_size, dilation=1, padding=0, stride=stride 829 | ) 830 | unfold = torch.nn.Unfold(**fold_params) 831 | 832 | fold_params2 = dict( 833 | kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), 834 | dilation=1, 835 | padding=0, 836 | stride=(stride[0] * uf, stride[1] * uf), 837 | ) 838 | fold = torch.nn.Fold( 839 | output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 840 | ) 841 | 842 | weighting = self.get_weighting( 843 | kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device 844 | ).to(x.dtype) 845 | normalization = fold(weighting).view( 846 | 1, 1, h * uf, w * uf 847 | ) # normalizes the overlap 848 | weighting = weighting.view( 849 | (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) 850 | ) 851 | 852 | elif df > 1 and uf == 1: 853 | fold_params = dict( 854 | kernel_size=kernel_size, dilation=1, padding=0, stride=stride 855 | ) 856 | unfold = torch.nn.Unfold(**fold_params) 857 | 858 | fold_params2 = dict( 859 | kernel_size=(kernel_size[0] // df, kernel_size[0] // df), 860 | dilation=1, 861 | padding=0, 862 | stride=(stride[0] // df, stride[1] // df), 863 | ) 864 | fold = torch.nn.Fold( 865 | output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2 866 | ) 867 | 868 | weighting = self.get_weighting( 869 | kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device 870 | ).to(x.dtype) 871 | normalization = fold(weighting).view( 872 | 1, 1, h // df, w // df 873 | ) # normalizes the overlap 874 | weighting = weighting.view( 875 | (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) 876 | ) 877 | 878 | else: 879 | raise NotImplementedError 880 | 881 | return fold, unfold, normalization, weighting 882 | 883 | def forward( 884 | self, x: torch.Tensor, c: torch.Tensor, *args, **kwargs 885 | ) -> torch.Tensor: 886 | """ 887 | Computes the forward pass of the model. 888 | 889 | Args: 890 | x (torch.Tensor): The input tensor. 891 | c (torch.Tensor): The conditioning tensor. 892 | *args: Additional positional arguments. 893 | **kwargs: Additional keyword arguments. 894 | 895 | Returns: 896 | torch.Tensor: The output tensor. 
897 | """ 898 | t = torch.randint( 899 | 0, self.num_timesteps, (x.shape[0],), device=self.device 900 | ).long() 901 | if self.model.conditioning_key is not None: 902 | assert c is not None 903 | if self.cond_stage_trainable: # This is FALSE in our case 904 | c = self.get_learned_conditioning(c) 905 | if self.shorten_cond_schedule: # TODO: drop this option # TRUE in our case 906 | tc = self.cond_ids[t].to(self.device) 907 | c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) 908 | 909 | return self.p_losses(x, c, t, *args, **kwargs) 910 | 911 | def q_sample( 912 | self, x_start: torch.Tensor, t: int, noise: Optional[torch.Tensor] = None 913 | ) -> torch.Tensor: 914 | """ 915 | Samples from the posterior distribution at the given timestep. 916 | 917 | Args: 918 | x_start (torch.Tensor): The starting tensor. 919 | t (int): The timestep. 920 | noise (Optional[torch.Tensor]): The noise tensor. 921 | 922 | Returns: 923 | torch.Tensor: The sampled tensor. 924 | """ 925 | noise = default(noise, lambda: torch.randn_like(x_start)) 926 | return ( 927 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 928 | + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) 929 | * noise 930 | ) 931 | -------------------------------------------------------------------------------- /opensr_model/diffusion/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | from torch import nn as nn 9 | from tqdm import tqdm 10 | 11 | 12 | def exists(val: Any) -> bool: 13 | """ 14 | Returns whether the given value exists (i.e., is not None). 15 | 16 | Args: 17 | val (Any): The value to check. 18 | 19 | Returns: 20 | bool: Whether the value exists. 21 | """ 22 | return val is not None 23 | 24 | 25 | def default(val: Any, d: Callable) -> Any: 26 | """ 27 | Returns the given value if it exists, otherwise returns the default value. 28 | 29 | Args: 30 | val (Any): The value to check. 31 | d (Callable): The default value or function to generate the default value. 32 | 33 | Returns: 34 | Any: The given value or the default value. 35 | """ 36 | if exists(val): 37 | return val 38 | return d() if callable(d) else d 39 | 40 | 41 | def count_params(model: nn.Module, verbose: bool = False) -> int: 42 | """ 43 | Returns the total number of parameters in the given model. 44 | 45 | Args: 46 | model (nn.Module): The model. 47 | verbose (bool): Whether to print the number of parameters. 48 | 49 | Returns: 50 | int: The total number of parameters. 51 | """ 52 | total_params = sum(p.numel() for p in model.parameters()) 53 | if verbose: 54 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 55 | return total_params 56 | 57 | 58 | def disabled_train(self, mode: bool = True) -> nn.Module: 59 | """ 60 | Overwrites the `train` method of the model to disable changing the mode. 61 | 62 | Args: 63 | mode (bool): Whether to enable or disable training mode. 64 | 65 | Returns: 66 | nn.Module: The model. 
67 | """ 68 | return self 69 | 70 | 71 | def make_convolutional_sample( 72 | batch: Tensor, 73 | model: nn.Module, 74 | custom_steps: Optional[Union[int, Tuple[int, int]]] = None, 75 | eta: float = 1.0, 76 | quantize_x0: bool = False, 77 | custom_shape: Optional[Tuple[int, int]] = None, 78 | temperature: float = 1.0, 79 | noise_dropout: float = 0.0, 80 | corrector: Optional[nn.Module] = None, 81 | corrector_kwargs: Optional[dict] = None, 82 | x_T: Optional[Tensor] = None, 83 | ddim_use_x0_pred: bool = False, 84 | ) -> Tuple[Tensor, Optional[Tensor]]: 85 | """ 86 | Generates a convolutional sample using the given model. 87 | 88 | Args: 89 | batch (Tensor): The input batch tensor. 90 | model (nn.Module): The model to use for sampling. 91 | custom_steps (Optional[Union[int, Tuple[int, int]]]): The custom number of steps. 92 | eta (float): The eta value. 93 | quantize_x0 (bool): Whether to quantize the initial sample. 94 | custom_shape (Optional[Tuple[int, int]]): The custom shape. 95 | temperature (float): The temperature value. 96 | noise_dropout (float): The noise dropout value. 97 | corrector (Optional[nn.Module]): The corrector module. 98 | corrector_kwargs (Optional[dict]): The corrector module keyword arguments. 99 | x_T (Optional[Tensor]): The target tensor. 100 | ddim_use_x0_pred (bool): Whether to use x0 prediction for DDim. 101 | 102 | Returns: 103 | Tuple[Tensor, Optional[Tensor]]: The generated sample tensor and the 104 | target tensor (if provided). 105 | """ 106 | # create an empty dictionary to store the log 107 | log = dict() 108 | 109 | # get the input data and conditioning from the model 110 | z, c, x, xrec, xc = model.get_input( 111 | batch, 112 | model.first_stage_key, 113 | return_first_stage_outputs=True, 114 | force_c_encode=not ( 115 | hasattr(model, "split_input_params") 116 | and model.cond_stage_key == "coordinates_bbox" 117 | ), 118 | return_original_cond=True, 119 | ) 120 | 121 | # if custom_shape is not None, generate random noise of the specified shape 122 | if custom_shape is not None: 123 | z = torch.randn(custom_shape) 124 | print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") 125 | 126 | # store the input and reconstruction in the log 127 | log["input"] = x 128 | log["reconstruction"] = xrec 129 | 130 | # sample from the model using convsample_ddim 131 | with model.ema_scope("Plotting"): 132 | t0 = time.time() 133 | sample, intermediates = convsample_ddim( 134 | model=model, 135 | cond=c, 136 | steps=custom_steps, 137 | shape=z.shape, 138 | eta=eta, 139 | quantize_x0=quantize_x0, 140 | noise_dropout=noise_dropout, 141 | mask=None, 142 | x0=None, 143 | temperature=temperature, 144 | score_corrector=corrector, 145 | corrector_kwargs=corrector_kwargs, 146 | x_T=x_T, 147 | ) 148 | t1 = time.time() 149 | 150 | # if ddim_use_x0_pred is True, use the predicted x0 from the intermediates 151 | if ddim_use_x0_pred: 152 | sample = intermediates["pred_x0"][-1] 153 | 154 | # decode the sample to get the generated image 155 | x_sample = model.decode_first_stage(sample) 156 | 157 | # try to decode the sample without quantization to get the unquantized image 158 | try: 159 | x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) 160 | log["sample_noquant"] = x_sample_noquant 161 | log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) 162 | except: 163 | pass 164 | 165 | # store the generated image, time taken, and other information in the log 166 | log["sample"] = x_sample 167 | log["time"] = t1 - t0 168 | 169 | # return the 
log 170 | return log 171 | 172 | 173 | def disabled_train(self: nn.Module, mode: bool = True) -> nn.Module: 174 | """ 175 | Overwrites the `train` method of the model to disable changing the mode. 176 | 177 | Args: 178 | mode (bool): Whether to enable or disable training mode. 179 | 180 | Returns: 181 | nn.Module: The model. 182 | """ 183 | return self 184 | 185 | 186 | def convsample_ddim( 187 | model: nn.Module, 188 | cond: Tensor, 189 | steps: int, 190 | shape: Tuple[int, int], 191 | eta: float = 1.0, 192 | callback: Optional[callable] = None, 193 | noise_dropout: float = 0.0, 194 | normals_sequence: Optional[Tensor] = None, 195 | mask: Optional[Tensor] = None, 196 | x0: Optional[Tensor] = None, 197 | quantize_x0: bool = False, 198 | temperature: float = 1.0, 199 | score_corrector: Optional[nn.Module] = None, 200 | corrector_kwargs: Optional[dict] = None, 201 | x_T: Optional[Tensor] = None, 202 | ) -> Tuple[Tensor, Optional[Tensor]]: 203 | """ 204 | Generates a convolutional sample using the given model and conditioning tensor. 205 | 206 | Args: 207 | model (nn.Module): The model to use for sampling. 208 | cond (Tensor): The conditioning tensor. 209 | steps (int): The number of steps. 210 | shape (Tuple[int, int]): The shape of the sample. 211 | eta (float): The eta value. 212 | callback (Optional[callable]): The callback function. 213 | normals_sequence (Optional[Tensor]): The normals sequence tensor. 214 | noise_dropout (float): The noise dropout value. 215 | mask (Optional[Tensor]): The mask tensor. 216 | x0 (Optional[Tensor]): The initial sample tensor. 217 | quantize_x0 (bool): Whether to quantize the initial sample. 218 | temperature (float): The temperature value. 219 | score_corrector (Optional[nn.Module]): The score corrector module. 220 | corrector_kwargs (Optional[dict]): The score corrector module keyword arguments. 221 | x_T (Optional[Tensor]): The target tensor. 222 | 223 | Returns: 224 | Tuple[Tensor, Optional[Tensor]]: The generated sample tensor and the target tensor (if provided). 225 | """ 226 | ddim = DDIMSampler(model) 227 | bs = shape[0] # dont know where this comes from but wayne 228 | shape = shape[1:] # cut batch dim 229 | print(f"Sampling with eta = {eta}; steps: {steps}") 230 | samples, intermediates = ddim.sample( 231 | steps, 232 | batch_size=bs, 233 | shape=shape, 234 | conditioning=cond, 235 | callback=callback, 236 | normals_sequence=normals_sequence, 237 | quantize_x0=quantize_x0, 238 | eta=eta, 239 | mask=mask, 240 | x0=x0, 241 | temperature=temperature, 242 | verbose=False, 243 | score_corrector=score_corrector, 244 | noise_dropout=noise_dropout, 245 | corrector_kwargs=corrector_kwargs, 246 | x_T=x_T, 247 | ) 248 | 249 | return samples, intermediates 250 | 251 | 252 | def make_ddim_sampling_parameters( 253 | alphacums: np.ndarray, ddim_timesteps: np.ndarray, eta: float, verbose: bool = True 254 | ) -> tuple: 255 | """ 256 | Computes the variance schedule for the ddim sampler, based on the given parameters. 257 | 258 | Args: 259 | alphacums (np.ndarray): Array of cumulative alpha values. 260 | ddim_timesteps (np.ndarray): Array of timesteps to use for computing alphas. 261 | eta (float): Scaling factor for computing sigmas. 262 | verbose (bool, optional): Whether to print out the selected alphas and sigmas. Defaults to True. 263 | 264 | Returns: 265 | tuple: A tuple containing three arrays: sigmas, alphas, and alphas_prev. 266 | sigmas (np.ndarray): Array of sigma values for each timestep. 
267 | alphas (np.ndarray): Array of alpha values for each timestep. 268 | alphas_prev (np.ndarray): Array of alpha values for the previous timestep. 269 | """ 270 | # select alphas for computing the variance schedule 271 | alphas = alphacums[ddim_timesteps] 272 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 273 | 274 | # according the the formula provided in https://arxiv.org/abs/2010.02502 275 | sigmas = eta * np.sqrt( 276 | (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) 277 | ) 278 | if verbose: 279 | print( 280 | f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" 281 | ) 282 | print( 283 | f"For the chosen value of eta, which is {eta}, " 284 | f"this results in the following sigma_t schedule for ddim sampler {sigmas}" 285 | ) 286 | return sigmas, alphas, alphas_prev 287 | 288 | 289 | def make_ddim_timesteps( 290 | ddim_discr_method: str, 291 | num_ddim_timesteps: int, 292 | num_ddpm_timesteps: int, 293 | verbose: bool = True, 294 | ) -> np.ndarray: 295 | """ 296 | Computes the timesteps to use for computing alphas in the ddim sampler. 297 | 298 | Args: 299 | ddim_discr_method (str): The method to use for discretizing the timesteps. 300 | Must be either 'uniform' or 'quad'. 301 | num_ddim_timesteps (int): The number of timesteps to use for computing alphas. 302 | num_ddpm_timesteps (int): The total number of timesteps in the DDPM model. 303 | verbose (bool, optional): Whether to print out the selected timesteps. Defaults to True. 304 | 305 | Returns: 306 | np.ndarray: An array of timesteps to use for computing alphas in the ddim sampler. 307 | """ 308 | if ddim_discr_method == "uniform": 309 | c = num_ddpm_timesteps // num_ddim_timesteps 310 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 311 | elif ddim_discr_method == "quad": 312 | ddim_timesteps = ( 313 | (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 314 | ).astype(int) 315 | else: 316 | raise NotImplementedError( 317 | f'There is no ddim discretization method called "{ddim_discr_method}"' 318 | ) 319 | 320 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 321 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 322 | steps_out = ddim_timesteps + 1 323 | if verbose: 324 | print(f"Selected timesteps for ddim sampler: {steps_out}") 325 | return steps_out 326 | 327 | 328 | def noise_like(shape: tuple, device: str, repeat: bool = False) -> torch.Tensor: 329 | """ 330 | Generates noise with the same shape as the given tensor. 331 | 332 | Args: 333 | shape (tuple): The shape of the tensor to generate noise for. 334 | device (str): The device to place the noise tensor on. 335 | repeat (bool, optional): Whether to repeat the same noise across the batch dimension. Defaults to False. 336 | 337 | Returns: 338 | torch.Tensor: A tensor of noise with the same shape as the input tensor. 
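    Example (minimal, self-contained sketch):
        >>> n = noise_like((4, 3, 64, 64), device="cpu", repeat=True)
        >>> n.shape                         # torch.Size([4, 3, 64, 64])
        >>> bool(torch.equal(n[0], n[1]))   # True: the same noise is repeated across the batch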
339 | """ 340 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 341 | shape[0], *((1,) * (len(shape) - 1)) 342 | ) 343 | noise = lambda: torch.randn(shape, device=device) 344 | return repeat_noise() if repeat else noise() 345 | 346 | 347 | class DDIMSampler(object): 348 | def __init__(self, model: object, schedule: str = "linear", **kwargs: dict) -> None: 349 | super().__init__() 350 | self.model = model 351 | self.ddpm_num_timesteps = model.num_timesteps 352 | self.schedule = schedule 353 | self.device = model.device 354 | 355 | def register_buffer(self, name: str, attr: torch.Tensor) -> None: 356 | if type(attr) == torch.Tensor: 357 | if attr.device != torch.device(self.device): 358 | attr = attr.to(torch.device(self.device)) 359 | setattr(self, name, attr) 360 | 361 | def make_schedule( 362 | self, 363 | ddim_num_steps: int, 364 | ddim_discretize: str = "uniform", 365 | ddim_eta: float = 0.0, 366 | verbose: bool = True, 367 | ) -> None: 368 | # make ddim timesteps. these are the timesteps at which we compute alphas 369 | self.ddim_timesteps = make_ddim_timesteps( 370 | ddim_discr_method=ddim_discretize, 371 | num_ddim_timesteps=ddim_num_steps, 372 | num_ddpm_timesteps=self.ddpm_num_timesteps, 373 | verbose=verbose, 374 | ) 375 | 376 | # get alphas_cumprod from the model 377 | alphas_cumprod = self.model.alphas_cumprod 378 | 379 | # check if alphas_cumprod is defined for each timestep 380 | assert ( 381 | alphas_cumprod.shape[0] == self.ddpm_num_timesteps 382 | ), "alphas have to be defined for each timestep" 383 | 384 | # define a function to convert tensor to torch tensor 385 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 386 | 387 | # register buffers for betas, alphas_cumprod, and alphas_cumprod_prev 388 | self.register_buffer("betas", to_torch(self.model.betas)) 389 | self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) 390 | self.register_buffer( 391 | "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) 392 | ) 393 | 394 | # calculations for diffusion q(x_t | x_{t-1}) and others 395 | self.register_buffer( 396 | "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) 397 | ) 398 | self.register_buffer( 399 | "sqrt_one_minus_alphas_cumprod", 400 | to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), 401 | ) 402 | self.register_buffer( 403 | "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) 404 | ) 405 | self.register_buffer( 406 | "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) 407 | ) 408 | self.register_buffer( 409 | "sqrt_recipm1_alphas_cumprod", 410 | to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), 411 | ) 412 | 413 | # ddim sampling parameters 414 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( 415 | alphacums=alphas_cumprod.cpu(), 416 | ddim_timesteps=self.ddim_timesteps, 417 | eta=ddim_eta, 418 | verbose=verbose, 419 | ) 420 | self.register_buffer("ddim_sigmas", ddim_sigmas) 421 | self.register_buffer("ddim_alphas", ddim_alphas) 422 | self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) 423 | self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) 424 | 425 | # calculate sigmas for original sampling steps 426 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 427 | (1 - self.alphas_cumprod_prev) 428 | / (1 - self.alphas_cumprod) 429 | * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) 430 | ) 431 | self.register_buffer( 432 | 
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps 433 | ) 434 | 435 | def sample( 436 | self, 437 | S: int, 438 | batch_size: int, 439 | shape: Tuple[int, int, int], 440 | conditioning: Optional[torch.Tensor] = None, 441 | callback: Optional[callable] = None, 442 | img_callback: Optional[callable] = None, 443 | quantize_x0: bool = False, 444 | eta: float = 0.0, 445 | mask: Optional[torch.Tensor] = None, 446 | x0: Optional[torch.Tensor] = None, 447 | temperature: float = 1.0, 448 | noise_dropout: float = 0.0, 449 | score_corrector: Optional[callable] = None, 450 | corrector_kwargs: Optional[dict] = None, 451 | verbose: bool = True, 452 | x_T: Optional[torch.Tensor] = None, 453 | log_every_t: int = 100, 454 | unconditional_guidance_scale: float = 1.0, 455 | unconditional_conditioning: Optional[torch.Tensor] = None, 456 | **kwargs, 457 | ) -> Tuple[torch.Tensor, dict]: 458 | """ 459 | Samples from the model using DDIM sampling. 460 | 461 | Args: 462 | S (int): Number of DDIM steps. 463 | batch_size (int): Batch size. 464 | shape (Tuple[int, int, int]): Shape of the output tensor. 465 | conditioning (Optional[torch.Tensor], optional): Conditioning tensor. Defaults to None. 466 | callback (Optional[callable], optional): Callback function. Defaults to None. 467 | img_callback (Optional[callable], optional): Image callback function. Defaults to None. 468 | quantize_x0 (bool, optional): Whether to quantize the denoised image. Defaults to False. 469 | eta (float, optional): Learning rate for DDIM. Defaults to 0.. 470 | mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None. 471 | x0 (Optional[torch.Tensor], optional): Initial tensor. Defaults to None. 472 | temperature (float, optional): Sampling temperature. Defaults to 1.. 473 | noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. 474 | score_corrector (Optional[callable], optional): Score corrector function. Defaults to None. 475 | corrector_kwargs (Optional[dict], optional): Keyword arguments for the score corrector function. 476 | Defaults to None. 477 | verbose (bool, optional): Whether to print verbose output. Defaults to True. 478 | x_T (Optional[torch.Tensor], optional): Target tensor. Defaults to None. 479 | log_every_t (int, optional): Log every t steps. Defaults to 100. 480 | unconditional_guidance_scale (float, optional): Scale for unconditional guidance. Defaults to 1.. 481 | unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional conditioning tensor. 482 | Defaults to None. 483 | 484 | Returns: 485 | Tuple[torch.Tensor, dict]: Tuple containing the generated samples and intermediate results. 
486 | """ 487 | # check if conditioning is None 488 | if conditioning is not None: 489 | if isinstance(conditioning, dict): 490 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 491 | if cbs != batch_size: 492 | print( 493 | f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" 494 | ) 495 | else: 496 | if conditioning.shape[0] != batch_size: 497 | print( 498 | f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" 499 | ) 500 | 501 | # make schedule to compute alphas and sigmas 502 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 503 | 504 | # parameters for sampling 505 | C, H, W = shape 506 | size = (batch_size, C, H, W) 507 | print(f"Data shape for DDIM sampling is {size}, eta {eta}") 508 | 509 | # sample from the model using ddim_sampling 510 | samples, intermediates = self.ddim_sampling( 511 | cond=conditioning, 512 | shape=size, 513 | callback=callback, 514 | img_callback=img_callback, 515 | quantize_denoised=quantize_x0, 516 | mask=mask, 517 | x0=x0, 518 | ddim_use_original_steps=False, 519 | noise_dropout=noise_dropout, 520 | temperature=temperature, 521 | score_corrector=score_corrector, 522 | corrector_kwargs=corrector_kwargs, 523 | x_T=x_T, 524 | log_every_t=log_every_t, 525 | unconditional_guidance_scale=unconditional_guidance_scale, 526 | unconditional_conditioning=unconditional_conditioning, 527 | ) 528 | 529 | return samples, intermediates 530 | 531 | def ddim_sampling( 532 | self, 533 | cond: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]], 534 | shape: Tuple[int, int, int], 535 | x_T: Optional[torch.Tensor] = None, 536 | ddim_use_original_steps: bool = False, 537 | callback: Optional[callable] = None, 538 | timesteps: Optional[List[int]] = None, 539 | quantize_denoised: bool = False, 540 | mask: Optional[torch.Tensor] = None, 541 | x0: Optional[torch.Tensor] = None, 542 | img_callback: Optional[callable] = None, 543 | log_every_t: int = 100, 544 | temperature: float = 1.0, 545 | noise_dropout: float = 0.0, 546 | score_corrector: Optional[callable] = None, 547 | corrector_kwargs: Optional[Dict[str, Any]] = None, 548 | unconditional_guidance_scale: float = 1.0, 549 | unconditional_conditioning: Optional[torch.Tensor] = None, 550 | ) -> Tuple[torch.Tensor, Dict[str, Any]]: 551 | """ 552 | Samples from the model using DDIM sampling. 553 | 554 | Args: 555 | cond (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]): Conditioning 556 | tensor. Defaults to None. 557 | shape (Tuple[int, int, int]): Shape of the output tensor. 558 | x_T (Optional[torch.Tensor], optional): Target tensor. Defaults to None. 559 | ddim_use_original_steps (bool, optional): Whether to use original DDIM steps. Defaults to False. 560 | callback (Optional[callable], optional): Callback function. Defaults to None. 561 | timesteps (Optional[List[int]], optional): List of timesteps. Defaults to None. 562 | quantize_denoised (bool, optional): Whether to quantize the denoised image. Defaults to False. 563 | mask (Optional[torch.Tensor], optional): Mask tensor. Defaults to None. 564 | x0 (Optional[torch.Tensor], optional): Initial tensor. Defaults to None. 565 | img_callback (Optional[callable], optional): Image callback function. Defaults to None. 566 | log_every_t (int, optional): Log every t steps. Defaults to 100. 567 | temperature (float, optional): Sampling temperature. Defaults to 1.. 568 | noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. 
569 | score_corrector (Optional[callable], optional): Score corrector function. Defaults to None. 570 | corrector_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments for the score corrector 571 | function. Defaults to None. 572 | unconditional_guidance_scale (float, optional): Scale for unconditional guidance. Defaults to 1. 573 | unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional conditioning tensor. 574 | Defaults to None. 575 | 576 | Returns: 577 | Tuple[torch.Tensor, Dict[str, Any]]: Tuple containing the generated samples and intermediate results. 578 | """ 579 | # Get the device and batch size 580 | device = self.model.betas.device 581 | b = shape[0] 582 | 583 | # Initialize the image tensor 584 | if x_T is None: 585 | img = torch.randn(shape, device=device) 586 | else: 587 | img = x_T 588 | 589 | # Get the timesteps 590 | if timesteps is None: 591 | timesteps = ( 592 | self.ddpm_num_timesteps 593 | if ddim_use_original_steps 594 | else self.ddim_timesteps 595 | ) 596 | elif timesteps is not None and not ddim_use_original_steps: 597 | subset_end = ( 598 | int( 599 | min(timesteps / self.ddim_timesteps.shape[0], 1) 600 | * self.ddim_timesteps.shape[0] 601 | ) 602 | - 1 603 | ) 604 | timesteps = self.ddim_timesteps[:subset_end] 605 | 606 | # Initialize the intermediates dictionary 607 | intermediates = {"x_inter": [img], "pred_x0": [img]} 608 | 609 | # Set the time range and total steps 610 | time_range = ( 611 | reversed(range(0, timesteps)) 612 | if ddim_use_original_steps 613 | else np.flip(timesteps) 614 | ) 615 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 616 | print(f"Running DDIM Sampling with {total_steps} timesteps") 617 | 618 | # Initialize the progress bar iterator 619 | iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) 620 | 621 | # Loop over the timesteps 622 | for i, step in enumerate(iterator): 623 | index = total_steps - i - 1 624 | ts = torch.full((b,), step, device=device, dtype=torch.long) 625 | 626 | # Sample from the model using DDIM 627 | outs = self.p_sample_ddim( 628 | img, 629 | cond, 630 | ts, 631 | index=index, 632 | use_original_steps=ddim_use_original_steps, 633 | temperature=temperature, 634 | noise_dropout=noise_dropout, 635 | unconditional_guidance_scale=unconditional_guidance_scale, 636 | unconditional_conditioning=unconditional_conditioning, 637 | ) 638 | img, pred_x0 = outs 639 | 640 | 641 | # Append the intermediate results to the intermediates dictionary 642 | if index % log_every_t == 0 or index == total_steps - 1: 643 | intermediates["x_inter"].append(img) 644 | intermediates["pred_x0"].append(pred_x0) 645 | 646 | return img, intermediates 647 | 648 | def p_sample_ddim( 649 | self, 650 | x: torch.Tensor, 651 | c: torch.Tensor, 652 | t: int, 653 | index: int, 654 | repeat_noise: bool = False, 655 | use_original_steps: bool = False, 656 | temperature: float = 1.0 657 | ) -> Tuple[torch.Tensor, torch.Tensor]: 658 | """ 659 | Samples from the model using DDIM sampling. 660 | 661 | Args: 662 | x (torch.Tensor): Input tensor. 663 | c (torch.Tensor): Conditioning tensor. 664 | t (int): Current timestep. 665 | index (int): Index of the current timestep. 666 | repeat_noise (bool, optional): Whether to repeat noise. Defaults to False. 667 | use_original_steps (bool, optional): Whether to use original DDIM steps. 668 | Defaults to False. 669 | quantize_denoised (bool, optional): Whether to quantize the denoised image. 670 | Defaults to False. 
671 | temperature (float, optional): Sampling temperature. Defaults to 1.. 672 | noise_dropout (float, optional): Noise dropout rate. Defaults to 0.. 673 | score_corrector (Optional[callable], optional): Score corrector function. 674 | Defaults to None. 675 | corrector_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments 676 | for the score corrector function. Defaults to None. 677 | unconditional_guidance_scale (float, optional): Scale for unconditional 678 | guidance. Defaults to 1.. 679 | unconditional_conditioning (Optional[torch.Tensor], optional): Unconditional 680 | conditioning tensor. Defaults to None. 681 | 682 | Returns: 683 | Tuple[torch.Tensor, torch.Tensor]: Tuple containing the generated samples and intermediate results. 684 | """ 685 | t = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long) 686 | 687 | # get batch size and device 688 | b, *_, device = *x.shape, x.device 689 | 690 | # apply model with or without unconditional conditioning 691 | e_t = self.model.apply_model(x, t, c) 692 | 693 | # get alphas, alphas_prev, sqrt_one_minus_alphas, and sigmas 694 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 695 | alphas_prev = ( 696 | self.model.alphas_cumprod_prev 697 | if use_original_steps 698 | else self.ddim_alphas_prev 699 | ) 700 | sqrt_one_minus_alphas = ( 701 | self.model.sqrt_one_minus_alphas_cumprod 702 | if use_original_steps 703 | else self.ddim_sqrt_one_minus_alphas 704 | ) 705 | sigmas = ( 706 | self.model.ddim_sigmas_for_original_num_steps 707 | if use_original_steps 708 | else self.ddim_sigmas 709 | ) 710 | 711 | # select parameters corresponding to the currently considered timestep 712 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 713 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 714 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 715 | sqrt_one_minus_at = torch.full( 716 | (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device 717 | ) 718 | 719 | # current prediction for x_0 720 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 721 | 722 | # direction pointing to x_t 723 | dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t 724 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 725 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 726 | 727 | return x_prev, pred_x0 728 | 729 | 730 | def make_beta_schedule( 731 | schedule: str, 732 | n_timestep: int, 733 | linear_start: float = 1e-4, 734 | linear_end: float = 2e-2, 735 | cosine_s: float = 8e-3, 736 | ) -> np.ndarray: 737 | """ 738 | Creates a beta schedule for the diffusion process. 739 | 740 | Args: 741 | schedule (str): Type of schedule to use. Can be "linear", "cosine", "sqrt_linear", or "sqrt". 742 | n_timestep (int): Number of timesteps. 743 | linear_start (float, optional): Starting value for linear schedule. Defaults to 1e-4. 744 | linear_end (float, optional): Ending value for linear schedule. Defaults to 2e-2. 745 | cosine_s (float, optional): Scaling factor for cosine schedule. Defaults to 8e-3. 746 | 747 | Returns: 748 | np.ndarray: Array of beta values. 
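    Example (linear schedule; the endpoint values follow from the implementation below):
        >>> betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
        >>> betas.shape                        # (1000,)
        >>> float(betas[0]), float(betas[-1])  # (0.0001, 0.02) up to floating-point error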
749 | """ 750 | if schedule == "linear": 751 | betas = ( 752 | torch.linspace( 753 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 754 | ) 755 | ** 2 756 | ) 757 | 758 | elif schedule == "cosine": 759 | timesteps = ( 760 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 761 | ) 762 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 763 | alphas = torch.cos(alphas).pow(2) 764 | alphas = alphas / alphas[0] 765 | betas = 1 - alphas[1:] / alphas[:-1] 766 | betas = np.clip(betas, a_min=0, a_max=0.999) 767 | 768 | elif schedule == "sqrt_linear": 769 | betas = torch.linspace( 770 | linear_start, linear_end, n_timestep, dtype=torch.float64 771 | ) 772 | elif schedule == "sqrt": 773 | betas = ( 774 | torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 775 | ** 0.5 776 | ) 777 | else: 778 | raise ValueError(f"schedule '{schedule}' unknown.") 779 | return betas.numpy() 780 | 781 | 782 | def extract_into_tensor( 783 | a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...] 784 | ) -> torch.Tensor: 785 | """ 786 | Extracts values from a tensor into a new tensor based on indices. 787 | 788 | Args: 789 | a (torch.Tensor): Input tensor. 790 | t (torch.Tensor): Indices tensor. 791 | x_shape (Tuple[int, ...]): Shape of the output tensor. 792 | 793 | Returns: 794 | torch.Tensor: Output tensor. 795 | """ 796 | b, *_ = t.shape 797 | out = a.gather(-1, t) 798 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 799 | 800 | 801 | class LitEma(nn.Module): 802 | def __init__( 803 | self, model: nn.Module, decay: float = 0.9999, use_num_upates: bool = True 804 | ) -> None: 805 | """ 806 | Initializes the LitEma class. 807 | 808 | Args: 809 | model (nn.Module): The model to apply EMA to. 810 | decay (float, optional): The decay rate for EMA. Must be between 0 and 1. Defaults to 0.9999. 811 | use_num_upates (bool, optional): Whether to use the number of updates to adjust decay. Defaults to True. 812 | """ 813 | super().__init__() 814 | if decay < 0.0 or decay > 1.0: 815 | raise ValueError("Decay must be between 0 and 1") 816 | 817 | self.m_name2s_name = {} 818 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 819 | self.register_buffer( 820 | "num_updates", 821 | torch.tensor(0, dtype=torch.int) 822 | if use_num_upates 823 | else torch.tensor(-1, dtype=torch.int), 824 | ) 825 | 826 | for name, p in model.named_parameters(): 827 | if p.requires_grad: 828 | # remove as '.'-character is not allowed in buffers 829 | s_name = name.replace(".", "") 830 | self.m_name2s_name.update({name: s_name}) 831 | self.register_buffer(s_name, p.clone().detach().data) 832 | 833 | self.collected_params = [] 834 | 835 | def forward(self, model: nn.Module) -> None: 836 | """ 837 | Applies EMA to the model. 838 | 839 | Args: 840 | model (nn.Module): The model to apply EMA to. 
841 | """ 842 | decay = self.decay 843 | 844 | if self.num_updates >= 0: 845 | self.num_updates += 1 846 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 847 | 848 | one_minus_decay = 1.0 - decay 849 | 850 | with True: 851 | m_param = dict(model.named_parameters()) 852 | shadow_params = dict(self.named_buffers()) 853 | 854 | for key in m_param: 855 | if m_param[key].requires_grad: 856 | sname = self.m_name2s_name[key] 857 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 858 | shadow_params[sname].sub_( 859 | one_minus_decay * (shadow_params[sname] - m_param[key]) 860 | ) 861 | else: 862 | assert not key in self.m_name2s_name 863 | 864 | def copy_to(self, model: nn.Module) -> None: 865 | """ 866 | Copies the EMA parameters to the model. 867 | 868 | Args: 869 | model (nn.Module): The model to copy the EMA parameters to. 870 | """ 871 | m_param = dict(model.named_parameters()) 872 | shadow_params = dict(self.named_buffers()) 873 | for key in m_param: 874 | if m_param[key].requires_grad: 875 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 876 | else: 877 | assert not key in self.m_name2s_name 878 | 879 | def store(self, parameters: Iterable[nn.Parameter]) -> None: 880 | """ 881 | Saves the current parameters for restoring later. 882 | 883 | Args: 884 | parameters (Iterable[nn.Parameter]): The parameters to be temporarily stored. 885 | """ 886 | self.collected_params = [param.clone() for param in parameters] 887 | 888 | def restore(self, parameters: Iterable[nn.Parameter]) -> None: 889 | """ 890 | Restores the parameters stored with the `store` method. 891 | 892 | Useful to validate the model with EMA parameters without affecting the 893 | original optimization process. Store the parameters before the 894 | `copy_to` method. After validation (or model saving), use this to 895 | restore the former parameters. 896 | 897 | Args: 898 | parameters (Iterable[nn.Parameter]): The parameters to be updated with the stored parameters. 
899 | """ 900 | for c_param, param in zip(self.collected_params, parameters): 901 | param.data.copy_(c_param.data) 902 | -------------------------------------------------------------------------------- /opensr_model/srmodel.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Union 3 | import requests 4 | import torch 5 | from opensr_model.diffusion.latentdiffusion import LatentDiffusion 6 | from skimage.exposure import match_histograms 7 | import torch.utils.checkpoint as checkpoint 8 | from opensr_model.diffusion.utils import DDIMSampler 9 | from tqdm import tqdm 10 | import numpy as np 11 | from typing import Literal 12 | import shutil 13 | import random 14 | import numpy as np 15 | from einops import rearrange 16 | from opensr_model.utils import suppress_stdout 17 | from opensr_model.utils import assert_tensor_validity 18 | from opensr_model.utils import revert_padding 19 | from opensr_model.utils import create_no_data_mask 20 | from opensr_model.utils import apply_no_data_mask 21 | 22 | 23 | 24 | class SRLatentDiffusion(torch.nn.Module): 25 | def __init__(self,config, device: Union[str, torch.device] = "cpu"): 26 | super().__init__() 27 | 28 | # Set up the model 29 | self.config = config 30 | self.model = LatentDiffusion( 31 | config.first_stage_config, 32 | config.cond_stage_config, 33 | timesteps=config.denoiser_settings.timesteps, 34 | unet_config=config.cond_stage_config, 35 | linear_start=config.denoiser_settings.linear_start, 36 | linear_end=config.denoiser_settings.linear_end, 37 | concat_mode=config.other.concat_mode, 38 | cond_stage_trainable=config.other.cond_stage_trainable, 39 | first_stage_key=config.other.first_stage_key, 40 | cond_stage_key=config.other.cond_stage_key, 41 | ) 42 | 43 | 44 | # Set up the model for inference 45 | self.set_normalization() # decide wether to use norm 46 | self.device = device # set self device 47 | self.model.device = device # set model device as selected 48 | self.model = self.model.to(device) # move model to device 49 | self.model.eval() # set model state 50 | self = self.eval() # set main model state 51 | self._X = None # placeholder for LR image 52 | self.encode_conditioning = config.encode_conditioning # encode LR images before dif? 
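# Illustrative construction sketch, assuming the config shipped in
# opensr_model/configs/config_10m.yaml (which supplies the first_stage_config,
# denoiser_settings, apply_normalization and encode_conditioning fields read above):
#
#   from omegaconf import OmegaConf
#   import opensr_model
#
#   config = OmegaConf.load("opensr_model/configs/config_10m.yaml")
#   model = opensr_model.SRLatentDiffusion(config, device="cpu")  # or "cuda"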
53 | 54 | def set_normalization(self): 55 | if self.config.apply_normalization==True: 56 | from opensr_model.utils import linear_transform_4b 57 | self.linear_transform = linear_transform_4b 58 | else: 59 | from opensr_model.utils import linear_transform_placeholder 60 | self.linear_transform = linear_transform_placeholder 61 | print("Normalization disabled.") 62 | 63 | def _tensor_encode(self, X: torch.Tensor): 64 | # set copy to model 65 | self._X = X.clone() 66 | # normalize image 67 | X_enc = self.linear_transform(X, stage="norm") 68 | # encode LR images 69 | if self.encode_conditioning==True: 70 | # try to upsample->encode conditioning 71 | X_int = torch.nn.functional.interpolate(X, size=(X.shape[-2]*4, X.shape[-1]*4), mode='bilinear', align_corners=False) 72 | # encode conditioning 73 | X_enc = self.model.first_stage_model.encode(X_int).sample() 74 | # move to same device as the model 75 | X_enc = X_enc.to(self.device) 76 | return X_enc 77 | 78 | def _tensor_decode(self, X_enc: torch.Tensor, spe_cor: bool = True): 79 | # Decode 80 | X_dec = self.model.decode_first_stage(X_enc) 81 | X_dec = self.linear_transform(X_dec, stage="denorm") 82 | # Apply spectral correction 83 | if spe_cor: 84 | for i in range(X_dec.shape[1]): 85 | X_dec[:, i] = self.hq_histogram_matching(X_dec[:, i], self._X[:, i]) 86 | # If the value is negative, set it to 0 87 | X_dec[X_dec < 0] = 0 88 | return X_dec 89 | 90 | def _prepare_model( 91 | self, 92 | X: torch.Tensor, 93 | eta: float = 1.0, 94 | custom_steps: int = 100, 95 | verbose: bool = False 96 | ): 97 | # Create the DDIM sampler 98 | ddim = DDIMSampler(self.model) 99 | 100 | # make schedule to compute alphas and sigmas 101 | ddim.make_schedule(ddim_num_steps=custom_steps, ddim_eta=eta, verbose=verbose) 102 | 103 | # Create the HR latent image 104 | latent = torch.randn(X.shape, device=X.device) 105 | 106 | # Create the vector with the timesteps 107 | timesteps = ddim.ddim_timesteps 108 | time_range = np.flip(timesteps) 109 | 110 | return ddim, latent, time_range 111 | 112 | @torch.no_grad() 113 | def forward( 114 | self, 115 | X: torch.Tensor, 116 | eta: float = 1.0, 117 | custom_steps: int = 100, 118 | temperature: float = 1.0, 119 | histogram_matching: bool = True, 120 | save_iterations: bool = False, 121 | verbose: bool = False 122 | ): 123 | """Obtain the super resolution of the given image. 124 | 125 | Args: 126 | X (torch.Tensor): If a Sentinel-2 L2A image with reflectance values 127 | in the range [0, 1] and shape CxWxH, the super resolution of the image 128 | is returned. If a batch of images with shape BxCxWxH is given, a batch 129 | of super resolved images is returned. 130 | eta (float, optional): DDIM eta controlling sampler stochasticity. Defaults to 1.0. 131 | custom_steps (int, optional): Number of DDIM denoising steps. Defaults to 100. 132 | temperature (float, optional): Temperature to use in the denoiser. 133 | Defaults to 1.0. The higher the temperature, the more stochastic 134 | the denoiser is (random noise gets multiplied by this). 135 | histogram_matching (bool, optional): Apply per-band histogram matching (spectral 136 | correction) to the SR image, using the LR image as reference. Defaults to True. 137 | 138 | Returns: 139 | torch.Tensor: The super resolved image or batch of images with a shape of 140 | Cx(Wx4)x(Hx4) or BxCx(Wx4)x(Hx4).
141 | """ 142 | 143 | # Assert shape, size, dimensionality 144 | X,padding = assert_tensor_validity(X) 145 | 146 | # create no_data_mask 147 | no_data_mask = create_no_data_mask(X, target_size= X.shape[-1]*4) 148 | 149 | # Normalize the image 150 | X = X.clone() 151 | Xnorm = self._tensor_encode(X) 152 | 153 | # ddim, latent and time_range 154 | ddim, latent, time_range = self._prepare_model( 155 | X=Xnorm, eta=eta, custom_steps=custom_steps, verbose=verbose 156 | ) 157 | iterator = tqdm(time_range, desc="DDIM Sampler", total=custom_steps,disable=True) 158 | 159 | # Iterate over the timesteps 160 | if save_iterations: 161 | save_iters = [] 162 | 163 | for i, step in enumerate(iterator): 164 | outs = ddim.p_sample_ddim( 165 | x=latent, 166 | c=Xnorm, 167 | t=step, 168 | index=custom_steps - i - 1, 169 | use_original_steps=False, 170 | temperature=temperature 171 | ) 172 | latent, _ = outs 173 | 174 | if save_iterations: 175 | save_iters.append( 176 | self._tensor_decode(latent, spe_cor=histogram_matching) 177 | ) 178 | 179 | if save_iterations: 180 | return save_iters 181 | 182 | sr = self._tensor_decode(latent, spe_cor=histogram_matching) # decode the latent image 183 | sr = apply_no_data_mask(sr, no_data_mask) # apply no data mask as in LR image 184 | sr = revert_padding(sr,padding) # remove padding from the SR image if there was any 185 | return sr 186 | 187 | 188 | def hq_histogram_matching( 189 | self, image1: torch.Tensor, image2: torch.Tensor 190 | ) -> torch.Tensor: 191 | """ 192 | Applies histogram matching to align the color distribution of image1 to image2. 193 | 194 | This function adjusts the pixel intensity distribution of `image1` (typically the 195 | low-resolution or degraded image) to match that of `image2` (typically the 196 | high-resolution or reference image). The operation is done per channel and 197 | assumes both images are in (C, H, W) format. 198 | 199 | Args: 200 | image1 (torch.Tensor): The source image whose histogram will be modified (C, H, W). 201 | image2 (torch.Tensor): The reference image whose histogram will be matched (C, H, W). 202 | 203 | Returns: 204 | torch.Tensor: A new tensor with the same shape as `image1`, but with pixel 205 | intensities adjusted to match the histogram of `image2`. 206 | 207 | Raises: 208 | ValueError: If input tensors are not 2D or 3D. 209 | """ 210 | 211 | # Go to numpy 212 | np_image1 = image1.detach().cpu().numpy() 213 | np_image2 = image2.detach().cpu().numpy() 214 | 215 | if np_image1.ndim == 3: 216 | np_image1_hat = match_histograms(np_image1, np_image2, channel_axis=0) 217 | elif np_image1.ndim == 2: 218 | np_image1_hat = match_histograms(np_image1, np_image2, channel_axis=None) 219 | else: 220 | raise ValueError("The input image must have 2 or 3 dimensions.") 221 | 222 | # Go back to torch 223 | image1_hat = torch.from_numpy(np_image1_hat).to(image1.device) 224 | 225 | return image1_hat 226 | 227 | def load_pretrained(self, weights_file: str): 228 | """ 229 | Loads pretrained model weights from a local file or downloads them from Hugging Face if not present. 230 | 231 | If the specified `weights_file` does not exist locally, it is automatically downloaded from the 232 | Hugging Face model hub under `simon-donike/RS-SR-LTDF`. A progress bar is shown during download. 233 | 234 | After loading, the method removes any perceptual loss-related weights from the state dict and 235 | loads the remaining weights into the model. 236 | 237 | Args: 238 | weights_file (str): Path to the local weights file. 
If the file is not found, it will be downloaded 239 | using this name from the Hugging Face repository. 240 | 241 | Raises: 242 | RuntimeError: If the weights cannot be loaded or parsed correctly. 243 | 244 | Example: 245 | self.load_pretrained("model_weights.ckpt") 246 | """ 247 | 248 | # download pretrained model 249 | # create download link based on input 250 | hf_model = "https://huggingface.co/simon-donike/RS-SR-LTDF/resolve/main/" + str(weights_file) 251 | 252 | # Download the checkpoint if it is not present locally 253 | if not pathlib.Path(weights_file).exists(): 254 | print("Downloading pretrained weights from: ", hf_model) 255 | response = requests.get(hf_model, stream=True) 256 | total_size = int(response.headers.get('content-length', 0)) 257 | block_size = 1024 # 1 Kibibyte 258 | 259 | # Open the file to write as binary - write bytes to a file 260 | with open(weights_file, "wb") as f: 261 | # Setup the progress bar 262 | with tqdm(total=total_size, unit='iB', unit_scale=True, desc=weights_file) as bar: 263 | for data in response.iter_content(block_size): 264 | bar.update(len(data)) 265 | f.write(data) 266 | 267 | weights = torch.load(weights_file, map_location=self.device)["state_dict"] 268 | 269 | # Remove perceptual-loss tensors from the weights 270 | for key in list(weights.keys()): 271 | if "loss" in key: 272 | del weights[key] 273 | 274 | self.model.load_state_dict(weights, strict=True) 275 | print("Loaded pretrained weights from: ", weights_file) 276 | 277 | 278 | def uncertainty_map(self, x, n_variations=15, custom_steps=100): 279 | """ 280 | Estimates uncertainty maps for each sample in the input batch using repeated stochastic forward passes. 281 | 282 | For each input sample, the method generates multiple super-resolved outputs by varying the random seed. 283 | It then computes the per-pixel standard deviation across these outputs as a proxy for uncertainty. 284 | The returned uncertainty map represents the average width of the confidence interval per pixel. 285 | 286 | Args: 287 | x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size. 288 | n_variations (int): Number of stochastic forward passes per input sample. 289 | custom_steps (int): Custom inference steps passed to the forward method. 290 | 291 | Returns: 292 | torch.Tensor: Uncertainty maps of shape (B, 1, H, W), where each value indicates pixel-wise uncertainty. 293 | """ 294 | assert n_variations > 3, "n_variations must be greater than 3 to compute uncertainty."
295 | 296 | 297 | batch_size = x.shape[0] 298 | rand_seed_list = random.sample(range(1, 9999), n_variations) 299 | 300 | all_variations = [] 301 | for b in range(batch_size): 302 | variations = [] 303 | x_b = x[b].unsqueeze(0) # shape (1, 4, 512, 512) 304 | for seed in rand_seed_list: 305 | with suppress_stdout(): 306 | np.random.seed(seed) 307 | torch.manual_seed(seed) 308 | random.seed(seed) 309 | #pytorch_lightning.utilities.seed.seed_everything(seed=seed, workers=True) 310 | 311 | sr = self.forward(x_b, custom_steps=custom_steps) # shape (1, C, H, W) 312 | variations.append(sr.detach().cpu()) 313 | 314 | variations = torch.stack(variations) # (n_variations, 1, C, H, W) 315 | srs_mean = variations.mean(dim=0) 316 | srs_stdev = variations.std(dim=0) 317 | interval_size = (srs_stdev * 2).mean(dim=1) # mean over channels 318 | 319 | all_variations.append(interval_size) # each is (1, H, W) 320 | 321 | result = torch.stack(all_variations) # (B, 1, H, W) 322 | return result 323 | 324 | 325 | 326 | def _attribution_methods( 327 | self, 328 | X: torch.Tensor, 329 | grads: torch.Tensor, 330 | attribution_method: Literal[ 331 | "grad_x_input", "max_grad", "mean_grad", "min_grad" 332 | ], 333 | ): 334 | """ 335 | DEPRECIATED; SUBJECT TO REMOVAL 336 | """ 337 | if attribution_method == "grad_x_input": 338 | return torch.norm(grads * X, dim=(0, 1)) 339 | elif attribution_method == "max_grad": 340 | return grads.abs().max(dim=0).max(dim=0) 341 | elif attribution_method == "mean_grad": 342 | return grads.abs().mean(dim=0).mean(dim=0) 343 | elif attribution_method == "min_grad": 344 | return grads.abs().min(dim=0).min(dim=0) 345 | else: 346 | raise ValueError( 347 | "The attribution method must be one of: grad_x_input, max_grad, mean_grad, min_grad" 348 | ) 349 | 350 | def explainer( 351 | self, 352 | X: torch.Tensor, 353 | mask: torch.Tensor, 354 | eta: float = 1.0, 355 | temperature: float = 1.0, 356 | custom_steps: int = 100, 357 | steps_to_consider_for_attributions: list = list(range(100)), 358 | attribution_method: Literal[ 359 | "grad_x_input", "max_grad", "mean_grad", "min_grad" 360 | ] = "grad_x_input", 361 | verbose: bool = False, 362 | enable_checkpoint = True, 363 | histogram_matching=True 364 | ): 365 | """ 366 | DEPRECIATED; SUBJECT TO REMOVAL 367 | """ 368 | # Normalize and encode the LR image 369 | X = X.clone() 370 | Xnorm = self._tensor_encode(X) 371 | 372 | # ddim, latent and time_range 373 | ddim, latent, time_range = self._prepare_model( 374 | X=Xnorm, eta=eta, custom_steps=custom_steps, verbose=verbose 375 | ) 376 | 377 | # Iterate over the timesteps 378 | container = [] 379 | iterator = tqdm(time_range, desc="DDIM Sampler", total=custom_steps,disable=True) 380 | for i, step in enumerate(iterator): 381 | 382 | # Activate or deactivate gradient tracking 383 | if i in steps_to_consider_for_attributions: 384 | torch.set_grad_enabled(True) 385 | else: 386 | torch.set_grad_enabled(False) 387 | 388 | # Compute the latent image 389 | if enable_checkpoint: 390 | outs = checkpoint.checkpoint( 391 | ddim.p_sample_ddim, 392 | latent, 393 | Xnorm, 394 | step, 395 | custom_steps - i - 1, 396 | temperature, 397 | use_reentrant=False, 398 | ) 399 | else: 400 | outs = ddim.p_sample_ddim( 401 | x=latent, 402 | c=Xnorm, 403 | t=step, 404 | index=custom_steps - i - 1, 405 | temperature=temperature 406 | ) 407 | latent, _ = outs 408 | 409 | 410 | if i not in steps_to_consider_for_attributions: 411 | continue 412 | 413 | # Apply the mask 414 | output_graph = (latent*mask).mean() 415 | 416 | # Compute 
the gradients 417 | grads = torch.autograd.grad(output_graph, Xnorm, retain_graph=True)[0] 418 | 419 | # Compute the attribution and save it 420 | with torch.no_grad(): 421 | to_save = { 422 | "latent": self._tensor_decode(latent, spe_cor=histogram_matching), 423 | "attribution": self._attribution_methods( 424 | Xnorm, grads, attribution_method 425 | ) 426 | } 427 | container.append(to_save) 428 | 429 | return container 430 | 431 | 432 | 433 | # ----------------------------------------------------------------------------- 434 | # Logic to create PyTorch Lightning Model from dif model 435 | # Logic to handle outputs from PL model and save them 436 | # ----------------------------------------------------------------------------- 437 | 438 | import torch 439 | import pytorch_lightning 440 | from pytorch_lightning import LightningModule 441 | 442 | class SRLatentDiffusionLightning(LightningModule): 443 | """ 444 | This Pytorch Lightning Class wraps around the torch model to 445 | aid in distrubuted GPU processing and optimized dataloaders 446 | provided by PL. ToDo: implement demo showcase 447 | """ 448 | def __init__(self,config, device: Union[str, torch.device] = "cpu"): 449 | super().__init__() 450 | self.model = SRLatentDiffusion(config,device=device) 451 | self.model = self.model.eval() 452 | 453 | @torch.no_grad() 454 | def forward(self, x,**kwargs): 455 | #print("Dont call 'forward' on the PL model, instead use 'predict'") 456 | return self.model(x) 457 | 458 | def load_pretrained(self, weights_file: str): 459 | self.model.load_pretrained(weights_file) 460 | print("PL Model: Model loaded from ", weights_file) 461 | 462 | @torch.no_grad() 463 | def predict_step(self, x, idx: int = 0,**kwargs): 464 | # perform SR 465 | assert self.model.training == False, "Model in Training mode. Abort." 
# make sure we're in eval 466 | p = self.model.forward(x, custom_steps=getattr(self, "custom_steps", 100), temperature=getattr(self, "temperature", 1.0)) # fall back to forward() defaults if not set on the wrapper 467 | return p 468 | 469 | @torch.no_grad() 470 | def uncertainty_map(self, x, n_variations=15, custom_steps=100): 471 | uncertainty_map = self.model.uncertainty_map(x, n_variations, custom_steps) 472 | return uncertainty_map 473 | 474 | -------------------------------------------------------------------------------- /opensr_model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from matplotlib import pyplot as plt 4 | 5 | def linear_transform_placeholder(t_input, stage="norm"): 6 | return t_input 7 | 8 | def linear_transform_4b(t_input,stage="norm"): 9 | assert stage in ["norm","denorm"] 10 | # get the shape of the tensor 11 | shape = t_input.shape 12 | 13 | # if 5D tensor, norm/denorm individually 14 | if len(shape)==5: 15 | stack = [] 16 | for batch in t_input: 17 | stack2 = [] 18 | for i in range(0, t_input.size(1), 4): 19 | slice_tensor = batch[i:i+4, :, :, :] 20 | slice_denorm = linear_transform_4b(slice_tensor,stage=stage) 21 | stack2.append(slice_denorm) 22 | stack2 = torch.stack(stack2) 23 | stack2 = stack2.reshape(shape[1], shape[2], shape[3], shape[4]) 24 | stack.append(stack2) 25 | stack = torch.stack(stack) 26 | return stack 27 | 28 | # here only if len(shape) == 4 29 | squeeze_needed = False 30 | if len( shape ) == 3: 31 | squeeze_needed = True 32 | t_input = t_input.unsqueeze(0) 33 | shape = t_input.shape 34 | 35 | assert len(shape)==4 or len(shape)==5,"Input tensor must have 4 dimensions (B,C,H,W) - or 5D for MISR" 36 | transpose_needed = False 37 | if shape[-1]>shape[1]: 38 | transpose_needed = True 39 | t_input = rearrange(t_input,"b c h w -> b w h c") 40 | 41 | # define constants 42 | rgb_c = 3. 43 | nir_c = 5. 44 | 45 | # iterate over batches 46 | return_ls = [] 47 | for t in t_input: 48 | if stage == "norm": 49 | # divide according to conventions 50 | t[:,:,0] = t[:,:,0] * (10.0 / rgb_c) # R 51 | t[:,:,1] = t[:,:,1] * (10.0 / rgb_c) # G 52 | t[:,:,2] = t[:,:,2] * (10.0 / rgb_c) # B 53 | t[:,:,3] = t[:,:,3] * (10.0 / nir_c) # NIR 54 | # clamp to get rid of outlier pixels 55 | t = t.clamp(0,1) 56 | # bring to -1..+1 57 | t = (t*2)-1 58 | if stage == "denorm": 59 | # bring to 0..1 60 | t = (t+1)/2 61 | # divide according to conventions 62 | t[:,:,0] = t[:,:,0] * (rgb_c / 10.0) # R 63 | t[:,:,1] = t[:,:,1] * (rgb_c / 10.0) # G 64 | t[:,:,2] = t[:,:,2] * (rgb_c / 10.0) # B 65 | t[:,:,3] = t[:,:,3] * (nir_c / 10.0) # NIR 66 | # clamp to get rid of outlier pixels 67 | t = t.clamp(0,1) 68 | 69 | # append result to list 70 | return_ls.append(t) 71 | 72 | # after loop, stack image 73 | t_output = torch.stack(return_ls) 74 | #print("stacked",t_output.shape) 75 | 76 | if transpose_needed==True: 77 | t_output = rearrange(t_output,"b w h c -> b c h w") 78 | if squeeze_needed: 79 | t_output = t_output.squeeze(0) 80 | 81 | return t_output 82 | 83 | 84 | def linear_transform_6b(t_input,stage="norm"): 85 | # iterate over batches 86 | assert stage in ["norm","denorm"] 87 | bands_c = 5.
88 | return_ls = [] 89 | clamp = False 90 | for t in t_input: 91 | if stage == "norm": 92 | # divide according to conventions 93 | t[:,:,0] = t[:,:,0] * (10.0 / bands_c) 94 | t[:,:,1] = t[:,:,1] * (10.0 / bands_c) 95 | t[:,:,2] = t[:,:,2] * (10.0 / bands_c) 96 | t[:,:,3] = t[:,:,3] * (10.0 / bands_c) 97 | t[:,:,4] = t[:,:,4] * (10.0 / bands_c) 98 | t[:,:,5] = t[:,:,5] * (10.0 / bands_c) 99 | # clamp to get rif of outlier pixels 100 | if clamp: 101 | t = t.clamp(0,1) 102 | # bring to -1..+1 103 | t = (t*2)-1 104 | if stage == "denorm": 105 | # bring to 0..1 106 | t = (t+1)/2 107 | # divide according to conventions 108 | t[:,:,0] = t[:,:,0] * (bands_c / 10.0) 109 | t[:,:,1] = t[:,:,1] * (bands_c / 10.0) 110 | t[:,:,2] = t[:,:,2] * (bands_c / 10.0) 111 | t[:,:,3] = t[:,:,3] * (bands_c / 10.0) 112 | t[:,:,4] = t[:,:,4] * (bands_c / 10.0) 113 | t[:,:,5] = t[:,:,5] * (bands_c / 10.0) 114 | # clamp to get rif of outlier pixels 115 | if clamp: 116 | t = t.clamp(0,1) 117 | 118 | # append result to list 119 | return_ls.append(t) 120 | 121 | # after loop, stack image 122 | t_output = torch.stack(return_ls) 123 | 124 | return t_output 125 | 126 | def assert_tensor_validity(tensor): 127 | 128 | # ASSERT BATCH DIMENSION 129 | # if unbatched, add batch dimension 130 | if len(tensor.shape)==3: 131 | tensor = tensor.unsqueeze(0) 132 | 133 | # ASSERT BxCxHxW ORDER 134 | # Check the size of the input tensor 135 | if tensor.shape[-1]<10: 136 | tensor = rearrange(tensor,"b w h c -> b c h w") 137 | 138 | 139 | height, width = tensor.shape[-2],tensor.shape[-1] 140 | # Calculate how much padding is needed for height and width 141 | if height < 128 or width < 128: 142 | pad_height = max(0, 128 - height) # Amount to pad on height 143 | pad_width = max(0, 128 - width) # Amount to pad on width 144 | 145 | # Padding for height and width needs to be added to both sides of the dimension 146 | # The pad has the format (left, right, top, bottom) 147 | padding = (pad_width // 2, pad_width - pad_width // 2, pad_height // 2, pad_height - pad_height // 2) 148 | padding = padding 149 | 150 | # Apply symmetric padding 151 | tensor = torch.nn.functional.pad(tensor, padding, mode='reflect') 152 | 153 | else: # save padding with 0s 154 | padding = (0,0,0,0) 155 | padding = padding 156 | 157 | return tensor,padding 158 | 159 | 160 | 161 | def revert_padding(tensor,padding): 162 | left, right, top, bottom = padding 163 | # account for 4x upsampling Factor 164 | left, right, top, bottom = left*4, right*4, top*4, bottom*4 165 | # Calculate the indices to slice from the padded tensor 166 | start_height = top 167 | end_height = tensor.size(-2) - bottom 168 | start_width = left 169 | end_width = tensor.size(-1) - right 170 | 171 | # Slice the tensor to remove padding 172 | unpadded_tensor = tensor[:,:, start_height:end_height, start_width:end_width] 173 | return unpadded_tensor 174 | 175 | 176 | 177 | def create_no_data_mask(X,target_size = 512): 178 | """ 179 | Create a mask for no-data values in the input tensor. 180 | No-data values are defined as those equal to 0. 181 | 182 | Args: 183 | X (torch.Tensor): Input tensor of shape (B, C, H, W). 184 | 185 | Returns: 186 | torch.Tensor: Mask of shape (B, C, H, W) where no-data values are marked as True. 
187 | """ 188 | # Create a mask where no-data values are True 189 | interpolated_X = torch.nn.functional.interpolate(X, size=(target_size,target_size), mode='nearest') 190 | mask = (interpolated_X == 0).float() 191 | return mask 192 | 193 | def apply_no_data_mask(X, mask): 194 | """ 195 | Apply a no-data mask to the input tensor. 196 | 197 | Args: 198 | X (torch.Tensor): Input tensor of shape (B, C, H, W). 199 | mask (torch.Tensor): Mask of shape (B, C, H, W) where no-data values are marked as True. 200 | 201 | Returns: 202 | torch.Tensor: Tensor with no-data values set to 0. 203 | """ 204 | # Apply the mask to set no-data values to 0 205 | X_masked = X * (1 - mask) 206 | return X_masked 207 | 208 | 209 | def plot_example(lr, sr, out_file="example.png"): 210 | # Assumes input is the LR example tensor and the SR from the demo.py 211 | sr = sr.cpu()*3.5 212 | sr = sr.clamp(0,1) # Ensure values are in [0, 1] range 213 | lr = lr.cpu()*3.5 214 | lr = lr.clamp(0,1) # Ensure values are in [0, 1] range 215 | # note: an additional 1.5x brightness boost is applied below for display purposes only 216 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 217 | ax[0].imshow(rearrange(lr[0,:3,:,:].cpu()*1.5, 'c h w -> h w c').numpy()) 218 | ax[0].set_title("LR") 219 | ax[1].imshow(rearrange(sr[0,:3,:,:].cpu()*1.5, 'c h w -> h w c').numpy()) 220 | ax[1].set_title("SR") 221 | plt.savefig(out_file) 222 | plt.close() 223 | 224 | def plot_uncertainty(uncertainty_map, out_file="uncertainty_map.png", normalize=True): 225 | uncertainty_map = uncertainty_map.cpu() 226 | img = rearrange(uncertainty_map[0, :, :, :], 'c h w -> h w c').numpy() 227 | 228 | # Convert to grayscale if single-channel 229 | if img.shape[2] == 1: 230 | img = img[:, :, 0] 231 | 232 | if normalize: 233 | # Stretch to [0, 1] 234 | img_min = img.min() 235 | img_max = img.max() 236 | img = (img - img_min) / (img_max - img_min + 1e-8) 237 | 238 | fig, ax = plt.subplots(figsize=(5, 5)) 239 | im = ax.imshow(img, cmap='viridis') 240 | ax.set_title("Uncertainty Map") 241 | label = "Uncertainty (Normalized)" if normalize else "Uncertainty (Absolute)" 242 | plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label=label) 243 | plt.tight_layout() 244 | plt.savefig(out_file) 245 | plt.close() 246 | 247 | 248 | def download_from_HF(file_name="example_lr.pt"): 249 | """ 250 | Downloads an example low-resolution tensor from the Hugging Face Hub. 251 | This function uses the `hf_hub_download` function to fetch a sample tensor 252 | for testing purposes; `file_name` selects the file within the repository. 253 | 254 | Returns: 255 | torch.Tensor: The downloaded low-resolution tensor. 256 | """ 257 | from huggingface_hub import hf_hub_download 258 | import torch 259 | 260 | # Download the file from your HF model repo 261 | file_path = hf_hub_download( 262 | repo_id="simon-donike/RS-SR-LTDF", 263 | filename=file_name, 264 | repo_type="model" # or "dataset" if it's in a dataset repo 265 | ) 266 | 267 | # Load the tensor 268 | obj = torch.load(file_path) 269 | return obj 270 | 271 | import contextlib 272 | import os 273 | import sys 274 | @contextlib.contextmanager 275 | def suppress_stdout(): 276 | with open(os.devnull, "w") as devnull: 277 | old_stdout = sys.stdout 278 | sys.stdout = devnull 279 | try: 280 | yield 281 | finally: 282 | sys.stdout = old_stdout 283 | 284 | 285 | -------------------------------------------------------------------------------- /opensr_utils_demo.py: -------------------------------------------------------------------------------- 1 | # This script is an example of how opensr-utils can be used in unison with opensr-model 2 | # in order to SR a whole S2 tile.
3 | 4 | # Import and Instanciate SR Model 5 | import opensr_model 6 | import torch 7 | from omegaconf import OmegaConf 8 | 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | config = OmegaConf.load("opensr_model/configs/config_10m.yaml") # load config 11 | model = opensr_model.SRLatentDiffusion(config, device=device) # create model 12 | model.load_pretrained(config.ckpt_version) # load checkpoint 13 | 14 | 15 | # perform SR with opensr-utils on whole tile (.SAFE format) 16 | from opensr_utils.main import windowed_SR_and_saving 17 | path = "/data3/inf_data/S2A_MSIL2A_20241026T105131_N0511_R051_T30SYJ_20241026T150453.SAFE/" 18 | sr_obj = windowed_SR_and_saving(folder_path=path, window_size=(128, 128), factor=4, keep_lr_stack=True) 19 | sr_obj.start_super_resolution(band_selection="10m",model=model,forward_call="forward",custom_steps=100,overlap=20, eliminate_border_px=0) # start 20 | 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # ESA OpenSR packages 2 | opensr_model 3 | 4 | # Tensors 5 | torch==1.13.1 6 | pytorch-lightning==1.9.0 7 | transformers==4.26.1 8 | 9 | # Geo 10 | geopandas==0.12.2 11 | rasterio==1.3.6 12 | pyproj==3.4.1 13 | affine==2.4.0 14 | shapely==2.0.1 15 | 16 | # CV 17 | albumentations==1.3.1 18 | opencv-python==4.7.0.68 19 | scikit-image==0.19.3 20 | scikit-learn==1.3.0 21 | lpips==0.1.4 22 | kornia==0.6.9 23 | imageio==2.25.0 24 | Pillow==9.4.0 25 | 26 | # General 27 | torch==1.13.1 28 | torchaudio==0.13.1 29 | torchmetrics==0.11.1 30 | torchvision==0.14.1 31 | omegaconf==2.3.0 32 | oauthlib==3.2.2 33 | numpy==1.23.5 34 | tqdm==4.64.1 35 | einops==0.6.0 36 | typing-extensions==4.7.1 37 | wandb==0.13.9 38 | pathtools==0.1.2 39 | 40 | # Diffusion 41 | diffusers==0.12.1 42 | huggingface-hub==0.12.0 43 | taming-transformers==0.0.1 -------------------------------------------------------------------------------- /resources/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/resources/example.png -------------------------------------------------------------------------------- /resources/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/resources/example2.png -------------------------------------------------------------------------------- /resources/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/resources/example3.png -------------------------------------------------------------------------------- /resources/sr_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/resources/sr_example.png -------------------------------------------------------------------------------- /resources/uncertainty_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/opensr-model/85bbf2dcc7937c8f38596ed5c26b1e75bd46c85f/resources/uncertainty_map.png 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | # read the contents of the README file 4 | with open('README.md', encoding='utf-8') as f: 5 | long_description = f.read() 6 | 7 | setup( 8 | name='opensr-model', 9 | version='0.3.0', 10 | author = "Simon Donike, Cesar Aybar, Luis Gomez Chova, Freddie Kalaitzis", 11 | author_email = "accounts@donike.net", 12 | description = "ESA OpenSR Diffusion model package for Super-Resolution of Sentinel-2 Imagery", 13 | url = "https://isp.uv.es/opensr/", 14 | project_urls={'Source Code': 'https://github.com/ESAopenSR/opensr-model'}, 15 | license='MIT', 16 | packages=find_packages(), 17 | long_description=long_description, 18 | long_description_content_type='text/markdown', 19 | install_requires=[ 20 | 'numpy', 21 | 'einops', 22 | 'rasterio', 23 | 'tqdm', 24 | 'torch', 25 | 'scikit-image', 26 | 'pytorch-lightning', 27 | 'requests', 28 | 'omegaconf', 'matplotlib'], 29 | ) 30 | --------------------------------------------------------------------------------
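# Minimal end-to-end sketch tying the modules above together. Assumptions: the package is
# installed as "opensr-model" (per setup.py), the example tensor "example_lr.pt" referenced in
# opensr_model/utils.py is available in the Hugging Face repo, and config.ckpt_version names
# the checkpoint as in opensr_utils_demo.py; this is an illustration, not a repository file.

import torch
from omegaconf import OmegaConf
import opensr_model
from opensr_model.utils import download_from_HF, plot_example, plot_uncertainty

config = OmegaConf.load("opensr_model/configs/config_10m.yaml")
model = opensr_model.SRLatentDiffusion(config, device="cuda" if torch.cuda.is_available() else "cpu")
model.load_pretrained(config.ckpt_version)

lr = download_from_HF("example_lr.pt").to(model.device)      # LR tensor, reflectance in [0, 1]
sr = model(lr, custom_steps=100)                             # 4x super-resolved output
plot_example(lr, sr, out_file="sr_example.png")              # side-by-side LR/SR figure

unc = model.uncertainty_map(lr, n_variations=15, custom_steps=100)
plot_uncertainty(unc, out_file="uncertainty_map.png")        # per-pixel uncertainty figure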