├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── sen2sr.gif ├── srimg01.png ├── srimg02.png ├── srimg03.png ├── srimg04.png ├── srimg05.png └── style.css ├── codecov.yaml ├── mkdocs.yml ├── poetry.lock ├── poetry.toml ├── pyproject.toml ├── requirements.txt ├── sen2sr ├── __init__.py ├── models │ ├── __init__.py │ ├── opensr_baseline │ │ ├── __init__.py │ │ ├── cnn.py │ │ ├── mamba.py │ │ └── swin.py │ └── tricks.py ├── nonreference.py ├── referencex2.py ├── referencex4.py ├── utils.py └── xai │ ├── __init__.py │ ├── lam.py │ └── utils.py └── tests └── test_foo.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | demo.py 177 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | 9 | ### Added 10 | - Initial release of **SuperS2** with core functionalities for Sentinel-2 data processing. 11 | - Detailed documentation for installation and basic usage examples. 12 | 13 | ### Changed 14 | - Updated README to include new badges and links. 15 | 16 | ### Fixed 17 | - Fixed minor bugs in the data processing module related to edge cases. 18 | 19 | ## [0.1.0] - 2024-05-03 20 | ### Added 21 | - First public release with support for enhancing Sentinel-2 spatial resolution to 2.5 meters. -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # **Contributor covenant code of conduct** 📜 2 | 3 | ## **Our pledge** 🤝 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 🌎🤗 7 | 8 | ## **Our standards** 📏 9 | 10 | Examples of behavior that contributes to creating a positive environment 11 | include: 12 | 13 | - Using welcoming and inclusive language. 😊 14 | - Being respectful of differing viewpoints and experiences. 
🤔👂 15 | - Gracefully accepting constructive criticism. 🛠️ 16 | - Focusing on what is best for the community. 🤲 17 | - Showing empathy towards other community members. 🥺❤️ 18 | 19 | Examples of unacceptable behavior by participants include: 20 | 21 | - The use of sexualized language or imagery and unwelcome sexual attention or advances. 🚫💬 22 | - Trolling, insulting/derogatory comments, and personal or political attacks. 🚫😠 23 | - Public or private harassment. 🚫👥 24 | - Publishing others' private information, such as a physical or electronic 25 | address, without explicit permission. 🚫🏡 26 | - Other conduct which could reasonably be considered inappropriate in a 27 | professional setting. 🚫👔 28 | 29 | ## **Our responsibilities** 🛡️ 30 | 31 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 32 | 33 | Project maintainers have the right and responsibility to remove, edit, or 34 | reject comments, commits, code, wiki edits, issues, and other contributions 35 | that are not aligned to this Code of Conduct, or to ban temporarily or 36 | permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 37 | 38 | ## **Scope** 🌐 39 | 40 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 41 | 42 | ## **Enforcement** 🚨 43 | 44 | All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 45 | 46 | ## **Attribution** 👏 47 | 48 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at 49 | 50 | [homepage]: 51 | 52 | For answers to common questions about this code of conduct, see 53 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # **Contributing** 🤝 2 | 3 | We welcome contributions from the community! Every contribution, no matter how small, is appreciated and credited. Here’s how you can get involved: 4 | 5 | ## **How to contribute** 🛠️ 6 | 7 | 1. **Fork the repository:** Start by forking the [SuperS2](https://github.com/IPL-UV/supers2) repository to your GitHub account. 🍴 8 | 2. **Clone your fork locally:** 9 | ```bash 10 | cd 11 | git clone https://github.com/YOUR_GITHUB_USERNAME/supers2.git 12 | cd supers2 13 | ``` 14 | 3. **Create a branch:** Create a new branch for your feature or bug fix: 15 | ```bash 16 | git checkout -b name-of-your-bugfix-or-feature 17 | ``` 18 | 4. 
**Set up the environment:** 🌱 19 | - If you're using `pyenv`, select a Python version: 20 | ```bash 21 | pyenv local 22 | ``` 23 | - Install dependencies and activate the environment: 24 | ```bash 25 | poetry install 26 | poetry shell 27 | ``` 28 | - Install pre-commit hooks: 29 | ```bash 30 | poetry run pre-commit install 31 | ``` 32 | 5. **Make your changes:** 🖋️ Develop your feature or fix, ensuring you write clear, concise commit messages and include any necessary tests. 33 | 6. **Check your changes:** ✅ 34 | - Run formatting checks: 35 | ```bash 36 | make check 37 | ``` 38 | - Run unit tests: 39 | ```bash 40 | make test 41 | ``` 42 | - Optionally, run tests across different Python versions using tox: 43 | ```bash 44 | tox 45 | ``` 46 | 7. **Submit a pull request:** 🚀 Push your branch to GitHub and submit a pull request to the `develop` branch of the SuperS2 repository. Ensure your pull request meets these guidelines: 47 | - Include tests. 48 | - Update the documentation if your pull request adds functionality. 49 | - Provide a detailed description of your changes. 50 | 51 | ## **Types of contributions** 📦 52 | 53 | - **Report bugs:** 🐛 54 | - Report bugs by creating an issue on the [SuperS2 GitHub repository](https://github.com/IPL-UV/supers2/issues). Please include your operating system, setup details, and steps to reproduce the bug. 55 | - **Fix bugs:** 🛠️ Look for issues tagged with "bug" and "help wanted" in the repository to start fixing. 56 | - **Implement features:** ✨ Contribute by implementing features tagged with "enhancement" and "help wanted." 57 | - **Write documentation:** 📚 Contribute to the documentation in the official docs, docstrings, or through blog posts and articles. 58 | - **Submit feedback:** 💬 Propose new features or give feedback by filing an issue on GitHub. 59 | - Use the [SuperS2 GitHub issues page](https://github.com/IPL-UV/supers2/issues) for feedback. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 
27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. 
Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SEN2SR 2 | ### *A radiometrically and spatially consistent super-resolution framework for Sentinel-2* 3 |



<!-- badges: PyPI, License, Black, isort -->

27 | 28 | 29 | 30 | 31 | --- 32 | 33 | **GitHub**: [https://github.com/ESAOpenSR/sen2sr](https://github.com/ESAOpenSR/sen2sr) 🌐 34 | 35 | **PyPI**: [https://pypi.org/project/sen2sr/](https://pypi.org/project/sen2sr/) 🛠️ 36 | 37 | **Preprint**: [https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5247739](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=5247739) 📄 38 | 39 | --- 40 | 41 | ## **Table of Contents** 42 | 43 | - [**Overview**](#overview-) 44 | - [**Installation**](#installation-) 45 | - [**From 10m and 20m Sentinel-2 bands to 2.5m**](#from-10m-and-20m-sentinel-2-bands-to-25m) 46 | - [**From 10m Sentinel-2 bands to 2.5m**](#from-10m-sentinel-2-bands-to-25m) 47 | - [**From 20m Sentinel-2 bands to 10m**](#from-20m-sentinel-2-bands-to-10m) 48 | - [**Predict on large images**](#predict-on-large-images) 49 | - [**Estimate the Local Attention Map**](#estimate-the-local-attention-map) 50 | 51 | ## **Overview** 52 | 53 | **sen2sr** is a Python package designed to enhance the spatial resolution of Sentinel-2 satellite images to 2.5 meters using a set of neural network models. 54 | 55 | | Model | Description | Run Link | 56 | |--------|-------------|---------| 57 | | **Run SEN2SRLite** | A lightweight SR model optimized for fast inference. | Open In Colab | 58 | | **Run SEN2SR** | Our most accurate SR model. | Open In Colab | 59 | 60 | 61 | ## **Installation** 62 | 63 | Install the **SEN2SRLite** version using pip: 64 | 65 | ```bash 66 | pip install sen2sr mlstac git+https://github.com/ESDS-Leipzig/cubo.git 67 | ``` 68 | 69 | To use the full version of **SEN2SR**, which employs the Mamba architecture, install as follows: 70 | 71 | > ⚠️ **Warning** 72 | > `mamba-ssm` **requires a GPU runtime** and **CUDA 12 or later** to function properly. 73 | > Installation may take a significant amount of time. **Please avoid interrupting the process.** 74 | 75 | 76 | 1. Create a fresh Conda environment: 77 | 78 | ```bash 79 | conda create -n test_env python=3.11 80 | conda activate test_env 81 | ``` 82 | 83 | 2. Install PyTorch with CUDA support: 84 | 85 | ```bash 86 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121 87 | ``` 88 | 89 | 3. Install `mamba-ssm` with `--no-build-isolation`: 90 | 91 | ```bash 92 | pip install mamba-ssm --no-build-isolation 93 | ``` 94 | 95 | 4. Install the remaining dependencies: 96 | ```bash 97 | pip install sen2sr mlstac git+https://github.com/ESDS-Leipzig/cubo.git 98 | ``` 99 | 100 | Adapted from this [state-spaces/mamba issue](https://github.com/state-spaces/mamba/issues/662). 101 | 102 | 103 | ## From 10m and 20m Sentinel-2 bands to 2.5m 104 | 105 | 106 | This example demonstrates the use of the `SEN2SRLite` model to enhance the spatial resolution of Sentinel-2 imagery. A 107 | Sentinel-2 L2A data cube is created over a specified region and time range using the cubo library, including both 10 m 108 | and 20 m bands. The pretrained model, downloaded via mlstac, takes a single normalized sample as input and predicts an 109 | HR output. The visualization compares the original RGB composite to the super-resolved result.
110 | 111 | 112 | ```python 113 | import mlstac 114 | import torch 115 | import cubo 116 | 117 | # Download the model 118 | mlstac.download( 119 | file="https://huggingface.co/tacofoundation/sen2sr/resolve/main/SEN2SRLite/main/mlm.json", 120 | output_dir="model/SEN2SRLite", 121 | ) 122 | 123 | # Load the model 124 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 125 | model = mlstac.load("model/SEN2SRLite").compiled_model(device=device) 126 | model = model.to(device) 127 | 128 | # Create a Sentinel-2 L2A data cube for a specific location and date range 129 | da = cubo.create( 130 | lat=39.49152740347753, 131 | lon=-0.4308725142800361, 132 | collection="sentinel-2-l2a", 133 | bands=["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"], 134 | start_date="2023-01-01", 135 | end_date="2023-12-31", 136 | edge_size=128, 137 | resolution=10 138 | ) 139 | 140 | # Prepare the data to be used in the model, select just one sample 141 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 142 | original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32") 143 | X = torch.from_numpy(original_s2_numpy).float().to(device) 144 | 145 | # Apply model 146 | superX = model(X[None]).squeeze(0) 147 | ``` 148 | 149 |
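The following is a minimal visualization sketch (not part of the package) for the comparison mentioned above. It assumes the band order of the `cubo.create` call (indices 2, 1, 0 correspond to B04, B03, B02) and that `superX` keeps that band order; the ×3 factor is only a crude brightness stretch for display.

```python
import matplotlib.pyplot as plt

# RGB composites (channel-first -> channel-last); band order is an assumption, see note above
rgb_lr = X[[2, 1, 0]].cpu().numpy().transpose(1, 2, 0)                # original 10 m composite
rgb_sr = superX[[2, 1, 0]].detach().cpu().numpy().transpose(1, 2, 0)  # super-resolved 2.5 m composite

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow((rgb_lr * 3).clip(0, 1))
ax[0].set_title("Original RGB (10 m)")
ax[1].imshow((rgb_sr * 3).clip(0, 1))
ax[1].set_title("Super-resolved RGB (2.5 m)")
plt.show()
```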


152 | 153 | 154 | ## From 10m Sentinel-2 bands to 2.5m 155 | 156 | 157 | This example demonstrates the use of the `SEN2SRLite NonReference_RGBN_x4` model variant to enhance the spatial resolution 158 | of only the 10 m Sentinel-2 bands: red (B04), green (B03), blue (B02), and near-infrared (B08). A Sentinel-2 L2A data cube is created using the cubo library for a specific location and date range. The input is normalized and passed to a pretrained non-reference model optimized for RGB+NIR inputs. 159 | 160 | 161 | ```python 162 | import mlstac 163 | import torch 164 | import cubo 165 | 166 | # Download the model 167 | mlstac.download( 168 | file="https://huggingface.co/tacofoundation/sen2sr/resolve/main/SEN2SRLite/NonReference_RGBN_x4/mlm.json", 169 | output_dir="model/SEN2SRLite_RGBN", 170 | ) 171 | 172 | # Create a Sentinel-2 L2A data cube for a specific location and date range 173 | da = cubo.create( 174 | lat=39.49152740347753, 175 | lon=-0.4308725142800361, 176 | collection="sentinel-2-l2a", 177 | bands=["B04", "B03", "B02", "B08"], 178 | start_date="2023-01-01", 179 | end_date="2023-12-31", 180 | edge_size=128, 181 | resolution=10 182 | ) 183 | 184 | 185 | # Prepare the data to be used in the model 186 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 187 | original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32") 188 | X = torch.from_numpy(original_s2_numpy).float().to(device) 189 | X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) 190 | 191 | # Load the model 192 | model = mlstac.load("model/SEN2SRLite_RGBN").compiled_model(device=device) 193 | 194 | # Apply model 195 | superX = model(X[None]).squeeze(0) 196 | ``` 197 | 198 |
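As a quick, optional sanity check, the forward pass can be wrapped in `torch.no_grad()` and the shapes inspected. The 4× spatial factor is taken from the model name (`NonReference_RGBN_x4`); the exact output shape should be verified against the model card.

```python
# Usage sketch: inference without gradient tracking, plus a shape check
with torch.no_grad():
    superX = model(X[None]).squeeze(0)

print("input :", tuple(X.shape))       # e.g. (4, 128, 128)  -> 10 m RGBN patch
print("output:", tuple(superX.shape))  # expected ~4x larger spatially, e.g. (4, 512, 512)
```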


201 | 202 | 203 | ## From 20m Sentinel-2 bands to 10m 204 | 205 | This example demonstrates the use of the `SEN2SRLite Reference_RSWIR_x2` model variant to enhance the spatial resolution of the 20 m Sentinel-2 bands: red-edge (B05, B06, B07), shortwave infrared (B11, B12), and near-infrared (B8A) to 10 m. 206 | 207 | 208 | ```python 209 | import mlstac 210 | import torch 211 | import cubo 212 | 213 | # Download the model 214 | mlstac.download( 215 | file="https://huggingface.co/tacofoundation/sen2sr/resolve/main/SEN2SRLite/Reference_RSWIR_x2/mlm.json", 216 | output_dir="model/SEN2SRLite_Reference_RSWIR_x2", 217 | ) 218 | 219 | # Create a Sentinel-2 L2A data cube for a specific location and date range 220 | da = cubo.create( 221 | lat=39.49152740347753, 222 | lon=-0.4308725142800361, 223 | collection="sentinel-2-l2a", 224 | bands=["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"], 225 | start_date="2023-01-01", 226 | end_date="2023-12-31", 227 | edge_size=128, 228 | resolution=10 229 | ) 230 | 231 | # Prepare the data to be used in the model 232 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 233 | original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32") 234 | X = torch.from_numpy(original_s2_numpy).float().to(device) 235 | 236 | # Load the model 237 | model = mlstac.load("model/SEN2SRLite_Reference_RSWIR_x2").compiled_model(device=device) 238 | model = model.to(device) 239 | 240 | # Apply model 241 | superX = model(X[None]).squeeze(0) 242 | ``` 243 | 244 |
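The sketch below gives a before/after look at one 20 m band. It assumes (verify against the model card) that the output keeps the band order of the input cube, so index 8 corresponds to B11 in both tensors; note that `X[8]` is already resampled to the 10 m grid by the data cube, while the model restores the missing spatial detail.

```python
import matplotlib.pyplot as plt

b11_before = X[8].cpu().numpy()                # B11 as delivered on the 10 m grid (assumed index)
b11_after = superX[8].detach().cpu().numpy()   # B11 after reference-based sharpening (assumed index)

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(b11_before, cmap="gray")
ax[0].set_title("B11 (native 20 m, resampled)")
ax[1].imshow(b11_after, cmap="gray")
ax[1].set_title("B11 (sharpened to 10 m)")
plt.show()
```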


247 | 248 | 249 | ## **Predict on large images** 250 | 251 | This example demonstrates the use of `SEN2SRLite NonReference_RGBN_x4` for super-resolving large Sentinel-2 RGB+NIR images by chunking the 252 | input into smaller overlapping tiles. Although the model is trained to operate on fixed-size 128×128 patches, the `sen2sr.predict_large` utility automatically segments larger inputs into these tiles, applies the model to each tile independently, and then reconstructs the full image. An overlap margin (e.g., 32 pixels) is introduced between tiles to minimize edge artifacts and ensure continuity across tile boundaries. 253 | 254 | ```python 255 | import mlstac 256 | import sen2sr 257 | import torch 258 | import cubo 259 | 260 | # Create a Sentinel-2 L2A data cube for a specific location and date range 261 | da = cubo.create( 262 | lat=39.49152740347753, 263 | lon=-0.4308725142800361, 264 | collection="sentinel-2-l2a", 265 | bands=["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"], 266 | start_date="2023-01-01", 267 | end_date="2023-12-31", 268 | edge_size=1024, 269 | resolution=10 270 | ) 271 | 272 | # Prepare the data to be used in the model 273 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 274 | original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32") 275 | X = torch.from_numpy(original_s2_numpy).float().to(device) 276 | X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) 277 | 278 | # Load the model 279 | model = mlstac.load("model/SEN2SRLite").compiled_model(device=device) 280 | 281 | 282 | # Apply model 283 | superX = sen2sr.predict_large( 284 | model=model, 285 | X=X, # The input tensor 286 | overlap=32, # The overlap between the patches 287 | ) 288 | ``` 289 | 290 |
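For intuition only, here is a deliberately simplified sketch of the overlapping-tile strategy described above. It is not the implementation of `sen2sr.predict_large`; it assumes the model maps a `C×128×128` patch to a `C×512×512` patch (same bands, 4× scale) and that the input is at least one tile in size.

```python
import torch

def tiled_sr_sketch(model, X, tile=128, overlap=32, scale=4):
    """Toy overlapping-tile inference (illustration of the idea only)."""
    C, H, W = X.shape
    out = torch.zeros((C, H * scale, W * scale), device=X.device)
    step = tile - overlap
    for i in range(0, H, step):
        for j in range(0, W, step):
            i0 = min(i, H - tile)  # clamp so the last tile stays inside the image
            j0 = min(j, W - tile)
            patch = X[:, i0:i0 + tile, j0:j0 + tile]
            with torch.no_grad():
                sr = model(patch[None]).squeeze(0)
            # Later tiles simply overwrite the shared margin of earlier ones here;
            # the real utility treats the overlap region more carefully.
            out[:, i0 * scale:(i0 + tile) * scale, j0 * scale:(j0 + tile) * scale] = sr
    return out
```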


293 | 294 | 295 | ### Estimate the Local Attention Map 296 | 297 | 298 | This example computes the Local Attention Map (LAM) to analyze the model's spatial sensitivity 299 | and robustness. The input image is scanned with a sliding window, and the model's attention is 300 | estimated across multiple upscaling factors. The resulting KDE map highlights regions where 301 | the model focuses more strongly, while the robustness vector quantifies the model's stability 302 | to spatial perturbations. 303 | 304 | 305 | ```python 306 | import mlstac 307 | import sen2sr 308 | import torch 309 | import cubo 310 | 311 | # Create a Sentinel-2 L2A data cube for a specific location and date range 312 | da = cubo.create( 313 | lat=39.49152740347753, 314 | lon=-0.4308725142800361, 315 | collection="sentinel-2-l2a", 316 | bands=["B04", "B03", "B02", "B08"], 317 | start_date="2023-01-01", 318 | end_date="2023-12-31", 319 | edge_size=128, 320 | resolution=10 321 | ) 322 | 323 | 324 | # Prepare the data to be used in the model 325 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 326 | original_s2_numpy = (da[11].compute().to_numpy() / 10_000).astype("float32") 327 | X = torch.from_numpy(original_s2_numpy).float().to(device) 328 | X = torch.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0) 329 | 330 | # Load the model 331 | #mlstac.download( 332 | # file="https://huggingface.co/tacofoundation/sen2sr/resolve/main/SEN2SRLite/NonReference_RGBN_x4/mlm.json", 333 | # output_dir="model/SEN2SRLite_RGBN", 334 | #) 335 | model = mlstac.load("model/SEN2SRLite_RGBN").compiled_model(device=device) 336 | 337 | # Apply model 338 | kde_map, complexity_metric, robustness_metric, robustness_vector = sen2sr.lam( 339 | X=X, # The input tensor 340 | model=model, # The SR model 341 | h=240, # The height of the window 342 | w=240, # The width of the window 343 | window=32, # The window size 344 | scales = ["2x", "3x", "4x", "5x", "6x"] 345 | ) 346 | 347 | 348 | import matplotlib.pyplot as plt 349 | fig, ax = plt.subplots(1, 2, figsize=(12, 6)) 350 | ax[0].imshow(kde_map) 351 | ax[0].set_title("Kernel Density Estimation") 352 | ax[1].plot(robustness_vector) 353 | ax[1].set_title("Robustness Vector") 354 | plt.show() 355 | ``` 356 | 357 |
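The call also returns two scalar summaries that the plot above does not use; printing them makes it easy to compare runs. Loosely, and following the description above, the complexity metric reflects how widely the model's attention spreads, while the robustness metric summarizes the stability captured by the robustness vector.

```python
# Scalar summaries returned by sen2sr.lam
print("Complexity metric:", complexity_metric)
print("Robustness metric:", robustness_metric)
```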


360 | 361 | -------------------------------------------------------------------------------- /assets/sen2sr.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/sen2sr.gif -------------------------------------------------------------------------------- /assets/srimg01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/srimg01.png -------------------------------------------------------------------------------- /assets/srimg02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/srimg02.png -------------------------------------------------------------------------------- /assets/srimg03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/srimg03.png -------------------------------------------------------------------------------- /assets/srimg04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/srimg04.png -------------------------------------------------------------------------------- /assets/srimg05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/assets/srimg05.png -------------------------------------------------------------------------------- /assets/style.css: -------------------------------------------------------------------------------- 1 | /* nav */ 2 | .md-tabs__list { 3 | display: flex; 4 | justify-content: space-evenly; 5 | } 6 | 7 | .md-header, .md-footer { 8 | background-color: #377c2f; 9 | } 10 | 11 | .md-tabs { 12 | background-color: #1a4e7e; 13 | } -------------------------------------------------------------------------------- /codecov.yaml: -------------------------------------------------------------------------------- 1 | coverage: 2 | range: 70..100 3 | round: down 4 | precision: 1 5 | status: 6 | project: 7 | default: 8 | target: 90% 9 | threshold: 0.5% 10 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | docs_dir: . 2 | 3 | # Project information 4 | site_name: SuperS2 5 | site_url: https://ipl-uv.github.io/supers2/ 6 | site_author: IPL-UV 7 | site_description: A Python package for enhancing the spatial resolution of Sentinel-2 satellite images to 2.5 meters. 
8 | 9 | 10 | # Repository 11 | repo_url: https://github.com/IPL-UV/supers2 12 | repo_name: supers2 13 | use_directory_urls: false 14 | 15 | # Configuration 16 | theme: 17 | name: material 18 | language: en 19 | palette: 20 | - scheme: default 21 | primary: '#d49f0c' 22 | accent: '#d49f0c' 23 | toggle: 24 | icon: material/toggle-switch-off-outline 25 | name: Switch to dark mode 26 | - scheme: slate 27 | primary: '#201357' 28 | accent: white 29 | toggle: 30 | icon: material/toggle-switch 31 | name: Switch to light mode 32 | font: 33 | text: Roboto 34 | code: Roboto Mono 35 | logo: assets/images/logo_ss2.png 36 | favicon: assets/images/logo_ss2.png 37 | features: 38 | - navigation.instant 39 | - navigation.tabs 40 | - navigation.top 41 | - navigation.expand 42 | - navigation.indexes 43 | - header.autohide 44 | 45 | 46 | nav: 47 | - Home: 48 | - README.md 49 | - Contributing: CONTRIBUTING.md 50 | - Changelog: CHANGELOG.md 51 | - Code of conduct: CODE_OF_CONDUCT.md 52 | 53 | # Plugins 54 | plugins: 55 | - search 56 | - same-dir 57 | - mkdocstrings 58 | - awesome-pages 59 | 60 | markdown_extensions: 61 | - meta 62 | - admonition 63 | - pymdownx.highlight 64 | - pymdownx.superfences 65 | - pymdownx.pathconverter 66 | - pymdownx.tabbed 67 | - mdx_truly_sane_lists 68 | - pymdownx.tasklist 69 | 70 | extra_css: 71 | - assets/style.css -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | in-project = true 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "sen2sr" 3 | version = "0.8.5" 4 | description = "A Python package to super-resolve Sentinel-2 satellite imagery up to 2.5 meters." 
5 | authors = ["Cesar Aybar ", "Julio Contreras "] 6 | repository = "https://github.com/ESAOpenSR/sen2sr" 7 | documentation = "https://esaopensr.github.io/sen2sr/" 8 | readme = "README.md" 9 | packages = [ 10 | {include = "sen2sr"}, 11 | ] 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.10,<4.0" 15 | tqdm = ">=4.67.1" 16 | numpy = ">=2.0.2" 17 | einops = ">=0.8.1" 18 | 19 | 20 | [tool.poetry.group.dev.dependencies] 21 | pytest = "^7.2.0" 22 | pytest-cov = "^4.0.0" 23 | deptry = "^0.16.2" 24 | mypy = "^1.5.1" 25 | pre-commit = "^3.4.0" 26 | tox = "^4.11.1" 27 | 28 | [tool.poetry.group.docs.dependencies] 29 | mkdocs = "^1.4.2" 30 | mkdocs-material = "^9.2.7" 31 | mkdocstrings = {extras = ["python"], version = "^0.26.1"} 32 | 33 | [build-system] 34 | requires = ["poetry-core>=1.0.0"] 35 | build-backend = "poetry.core.masonry.api" 36 | 37 | [tool.mypy] 38 | files = ["sen2sr"] 39 | disallow_untyped_defs = "True" 40 | disallow_any_unimported = "True" 41 | no_implicit_optional = "True" 42 | check_untyped_defs = "True" 43 | warn_return_any = "True" 44 | warn_unused_ignores = "True" 45 | show_error_codes = "True" 46 | 47 | 48 | [tool.pytest.ini_options] 49 | testpaths = ["tests"] 50 | 51 | [tool.ruff] 52 | target-version = "py39" 53 | line-length = 120 54 | fix = true 55 | select = [ 56 | # flake8-2020 57 | "YTT", 58 | # flake8-bandit 59 | "S", 60 | # flake8-bugbear 61 | "B", 62 | # flake8-builtins 63 | "A", 64 | # flake8-comprehensions 65 | "C4", 66 | # flake8-debugger 67 | "T10", 68 | # flake8-simplify 69 | "SIM", 70 | # isort 71 | "I", 72 | # mccabe 73 | "C90", 74 | # pycodestyle 75 | "E", "W", 76 | # pyflakes 77 | "F", 78 | # pygrep-hooks 79 | "PGH", 80 | # pyupgrade 81 | "UP", 82 | # ruff 83 | "RUF", 84 | # tryceratops 85 | "TRY", 86 | ] 87 | ignore = [ 88 | # LineTooLong 89 | "E501", 90 | # DoNotAssignLambda 91 | "E731", 92 | ] 93 | 94 | [tool.ruff.format] 95 | preview = true 96 | 97 | [tool.coverage.report] 98 | skip_empty = true 99 | 100 | [tool.coverage.run] 101 | branch = true 102 | source = ["sen2sr"] 103 | 104 | 105 | [tool.ruff.per-file-ignores] 106 | "tests/*" = ["S101"] 107 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mkdocs==1.3.* 2 | mkdocs-material==8.2.* 3 | pymdown-extensions>=10.0 4 | markdown==3.3.* 5 | mdx_truly_sane_lists==1.2.* 6 | mkdocs-git-revision-date-localized-plugin==1.0.* 7 | mkdocs-awesome-pages-plugin==2.9.1 8 | mkdocstrings 9 | mkdocs-same-dir 10 | mkdocs-autorefs==0.3.0 11 | jinja2==3.0.3 -------------------------------------------------------------------------------- /sen2sr/__init__.py: -------------------------------------------------------------------------------- 1 | # dynamic versioning 2 | from importlib.metadata import version, PackageNotFoundError 3 | from sen2sr.utils import predict_large 4 | from sen2sr.xai.lam import lam 5 | 6 | try: 7 | __version__ = version("sen2sr") 8 | except PackageNotFoundError: 9 | __version__ = "unknown" 10 | 11 | 12 | try: 13 | import torch 14 | import timm 15 | except ImportError: 16 | raise ImportError( 17 | "sen2sr requires torch and timm. Please install them." 
18 | ) 19 | 20 | __all__ = [ 21 | "__version__", 22 | "predict_large", 23 | "lam" 24 | ] -------------------------------------------------------------------------------- /sen2sr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/sen2sr/models/__init__.py -------------------------------------------------------------------------------- /sen2sr/models/opensr_baseline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/sen2sr/models/opensr_baseline/__init__.py -------------------------------------------------------------------------------- /sen2sr/models/opensr_baseline/cnn.py: -------------------------------------------------------------------------------- 1 | # I stole the code from here: https://github.com/hongyuanyu/SPAN 2 | # The author of the code deserves all the credit. I just make 3 | # basic modifications to make it work with my codebase. 4 | 5 | 6 | from collections import OrderedDict 7 | from typing import List, Optional, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn as nn 12 | 13 | 14 | def _make_pair(value: int) -> tuple: 15 | """ 16 | Converts a single integer into a tuple of the same integer repeated twice. 17 | 18 | Args: 19 | value (int): Integer value to be converted. 20 | 21 | Returns: 22 | tuple: Tuple containing the integer repeated twice. 23 | """ 24 | if isinstance(value, int): 25 | value = (value,) * 2 26 | return value 27 | 28 | 29 | def conv_layer( 30 | in_channels: int, out_channels: int, kernel_size: int, bias: bool = True 31 | ) -> nn.Conv2d: 32 | """ 33 | Creates a 2D convolutional layer with adaptive padding. 34 | 35 | Args: 36 | in_channels (int): Number of input channels. 37 | out_channels (int): Number of output channels. 38 | kernel_size (int): Size of the convolution kernel. 39 | bias (bool, optional): Whether to include a bias term. Defaults to True. 40 | 41 | Returns: 42 | nn.Conv2d: 2D convolutional layer with calculated padding. 43 | """ 44 | kernel_size = _make_pair(kernel_size) 45 | padding = (int((kernel_size[0] - 1) / 2), int((kernel_size[1] - 1) / 2)) 46 | return nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) 47 | 48 | 49 | def activation( 50 | act_type: str, inplace: bool = True, neg_slope: float = 0.05, n_prelu: int = 1 51 | ) -> nn.Module: 52 | """ 53 | Returns an activation layer based on the specified type. 54 | 55 | Args: 56 | act_type (str): Type of activation ('relu', 'lrelu', 'prelu'). 57 | inplace (bool, optional): If True, performs the operation in-place. Defaults to True. 58 | neg_slope (float, optional): Negative slope for 'lrelu' and 'prelu'. Defaults to 0.05. 59 | n_prelu (int, optional): Number of parameters for 'prelu'. Defaults to 1. 60 | 61 | Returns: 62 | nn.Module: Activation layer. 
63 | """ 64 | act_type = act_type.lower() 65 | if act_type == "relu": 66 | layer = nn.ReLU(inplace) 67 | elif act_type == "lrelu": 68 | layer = nn.LeakyReLU(neg_slope, inplace) 69 | elif act_type == "prelu": 70 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 71 | else: 72 | raise NotImplementedError( 73 | "activation layer [{:s}] is not found".format(act_type) 74 | ) 75 | return layer 76 | 77 | 78 | def sequential(*args) -> nn.Sequential: 79 | """ 80 | Constructs a sequential container for the provided modules. 81 | 82 | Args: 83 | args: Modules in order of execution. 84 | 85 | Returns: 86 | nn.Sequential: A Sequential container. 87 | """ 88 | if len(args) == 1: 89 | if isinstance(args[0], OrderedDict): 90 | raise NotImplementedError("sequential does not support OrderedDict input.") 91 | return args[0] 92 | modules = [] 93 | for module in args: 94 | if isinstance(module, nn.Sequential): 95 | for submodule in module.children(): 96 | modules.append(submodule) 97 | elif isinstance(module, nn.Module): 98 | modules.append(module) 99 | return nn.Sequential(*modules) 100 | 101 | 102 | def pixelshuffle_block( 103 | in_channels: int, out_channels: int, upscale_factor: int = 2, kernel_size: int = 3 104 | ) -> nn.Sequential: 105 | """ 106 | Creates an upsampling block using pixel shuffle. 107 | 108 | Args: 109 | in_channels (int): Number of input channels. 110 | out_channels (int): Number of output channels. 111 | upscale_factor (int, optional): Factor by which to upscale. Defaults to 2. 112 | kernel_size (int, optional): Size of the convolution kernel. Defaults to 3. 113 | 114 | Returns: 115 | nn.Sequential: Sequential block for upsampling. 116 | """ 117 | conv = conv_layer(in_channels, out_channels * (upscale_factor**2), kernel_size) 118 | pixel_shuffle = nn.PixelShuffle(upscale_factor) 119 | return sequential(conv, pixel_shuffle) 120 | 121 | 122 | class Conv3XC(nn.Module): 123 | def __init__( 124 | self, 125 | c_in: int, 126 | c_out: int, 127 | gain1: int = 1, 128 | s: int = 1, 129 | bias: bool = True, 130 | relu: bool = False, 131 | train_mode: bool = True, 132 | ): 133 | """ 134 | Custom 3-stage convolutional block with optional ReLU activation and train/evaluation mode support. 135 | 136 | Args: 137 | c_in (int): Number of input channels. 138 | c_out (int): Number of output channels. 139 | gain1 (int, optional): Gain multiplier for intermediate layers. Defaults to 1. 140 | s (int, optional): Stride value for the convolutions. Defaults to 1. 141 | bias (bool, optional): Whether to include a bias term in the convolutions. Defaults to True. 142 | relu (bool, optional): If True, apply a LeakyReLU activation after the convolution. Defaults to False. 143 | train_mode (bool, optional): If True, use training mode with learnable parameters. Defaults to True. 
144 | """ 145 | super(Conv3XC, self).__init__() 146 | self.train_mode = train_mode 147 | self.weight_concat = None 148 | self.bias_concat = None 149 | self.update_params_flag = False 150 | self.stride = s 151 | self.has_relu = relu 152 | gain = gain1 153 | 154 | self.sk = nn.Conv2d( 155 | in_channels=c_in, 156 | out_channels=c_out, 157 | kernel_size=1, 158 | padding=0, 159 | stride=s, 160 | bias=bias, 161 | ) 162 | self.conv = nn.Sequential( 163 | nn.Conv2d( 164 | in_channels=c_in, 165 | out_channels=c_in * gain, 166 | kernel_size=1, 167 | padding=0, 168 | bias=bias, 169 | ), 170 | nn.Conv2d( 171 | in_channels=c_in * gain, 172 | out_channels=c_out * gain, 173 | kernel_size=3, 174 | stride=s, 175 | padding=0, 176 | bias=bias, 177 | ), 178 | nn.Conv2d( 179 | in_channels=c_out * gain, 180 | out_channels=c_out, 181 | kernel_size=1, 182 | padding=0, 183 | bias=bias, 184 | ), 185 | ) 186 | 187 | self.eval_conv = nn.Conv2d( 188 | in_channels=c_in, 189 | out_channels=c_out, 190 | kernel_size=3, 191 | padding=1, 192 | stride=s, 193 | bias=bias, 194 | ) 195 | self.eval_conv.weight.requires_grad = False 196 | self.eval_conv.bias.requires_grad = False 197 | self.update_params() 198 | 199 | def update_params(self): 200 | """ 201 | Updates the parameters for evaluation mode by combining weights from the convolution layers. 202 | """ 203 | w1 = self.conv[0].weight.data.clone().detach() 204 | b1 = self.conv[0].bias.data.clone().detach() 205 | w2 = self.conv[1].weight.data.clone().detach() 206 | b2 = self.conv[1].bias.data.clone().detach() 207 | w3 = self.conv[2].weight.data.clone().detach() 208 | b3 = self.conv[2].bias.data.clone().detach() 209 | 210 | w = ( 211 | F.conv2d(w1.flip(2, 3).permute(1, 0, 2, 3), w2, padding=2, stride=1) 212 | .flip(2, 3) 213 | .permute(1, 0, 2, 3) 214 | ) 215 | b = (w2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b2 216 | 217 | self.weight_concat = ( 218 | F.conv2d(w.flip(2, 3).permute(1, 0, 2, 3), w3, padding=0, stride=1) 219 | .flip(2, 3) 220 | .permute(1, 0, 2, 3) 221 | ) 222 | self.bias_concat = (w3 * b.reshape(1, -1, 1, 1)).sum((1, 2, 3)) + b3 223 | 224 | sk_w = self.sk.weight.data.clone().detach() 225 | sk_b = self.sk.bias.data.clone().detach() 226 | target_kernel_size = 3 227 | 228 | H_pixels_to_pad = (target_kernel_size - 1) // 2 229 | W_pixels_to_pad = (target_kernel_size - 1) // 2 230 | sk_w = F.pad( 231 | sk_w, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad] 232 | ) 233 | 234 | self.weight_concat = self.weight_concat + sk_w 235 | self.bias_concat = self.bias_concat + sk_b 236 | 237 | self.eval_conv.weight.data = self.weight_concat 238 | self.eval_conv.bias.data = self.bias_concat 239 | 240 | def forward(self, x: torch.Tensor) -> torch.Tensor: 241 | """ 242 | Forward pass of the convolution block. 243 | 244 | Args: 245 | x (torch.Tensor): Input tensor. 246 | 247 | Returns: 248 | torch.Tensor: Output tensor after convolution and optional activation. 
249 | """ 250 | if self.train_mode: 251 | pad = 1 252 | x_pad = F.pad(x, (pad, pad, pad, pad), "constant", 0) 253 | out = self.conv(x_pad) + self.sk(x) 254 | else: 255 | self.update_params() 256 | out = self.eval_conv(x) 257 | 258 | if self.has_relu: 259 | out = F.leaky_relu(out, negative_slope=0.05) 260 | return out 261 | 262 | 263 | class SPAB(nn.Module): 264 | def __init__( 265 | self, 266 | in_channels: int, 267 | mid_channels: Optional[int] = None, 268 | out_channels: Optional[int] = None, 269 | train_mode: bool = True, 270 | bias: bool = False, 271 | ): 272 | """ 273 | Self-parameterized attention block (SPAB) with multiple convolution layers. 274 | 275 | Args: 276 | in_channels (int): Number of input channels. 277 | mid_channels (Optional[int], optional): Number of middle channels. Defaults to in_channels. 278 | out_channels (Optional[int], optional): Number of output channels. Defaults to in_channels. 279 | train_mode (bool, optional): Indicates if the block is in training mode. Defaults to True. 280 | bias (bool, optional): Include bias in convolutions. Defaults to False. 281 | """ 282 | super(SPAB, self).__init__() 283 | if mid_channels is None: 284 | mid_channels = in_channels 285 | if out_channels is None: 286 | out_channels = in_channels 287 | 288 | self.in_channels = in_channels 289 | self.c1_r = Conv3XC( 290 | in_channels, mid_channels, gain1=2, s=1, train_mode=train_mode 291 | ) 292 | self.c2_r = Conv3XC( 293 | mid_channels, mid_channels, gain1=2, s=1, train_mode=train_mode 294 | ) 295 | self.c3_r = Conv3XC( 296 | mid_channels, out_channels, gain1=2, s=1, train_mode=train_mode 297 | ) 298 | self.act1 = torch.nn.SiLU(inplace=True) 299 | self.act2 = activation("lrelu", neg_slope=0.1, inplace=True) 300 | 301 | def forward(self, x: torch.Tensor) -> tuple: 302 | """ 303 | Forward pass of the SPAB block. 304 | 305 | Args: 306 | x (torch.Tensor): Input tensor. 307 | 308 | Returns: 309 | tuple: (Output tensor, intermediate tensor, attention map). 310 | """ 311 | out1 = self.c1_r(x) 312 | out1_act = self.act1(out1) 313 | 314 | out2 = self.c2_r(out1_act) 315 | out2_act = self.act1(out2) 316 | 317 | out3 = self.c3_r(out2_act) 318 | 319 | sim_att = torch.sigmoid(out3) - 0.5 320 | out = (out3 + x) * sim_att 321 | 322 | return out, out1, sim_att 323 | 324 | 325 | class CNNSR(nn.Module): 326 | """ 327 | Swift Parameter-free Attention Network (SPAN) for efficient super-resolution 328 | with deeper layers and channel attention. 329 | """ 330 | 331 | def __init__( 332 | self, 333 | in_channels: int, 334 | out_channels: int, 335 | feature_channels: int = 48, 336 | upscale: int = 4, 337 | bias: bool = True, 338 | train_mode: bool = True, 339 | num_blocks: int = 10, 340 | **kwargs, 341 | ): 342 | """ 343 | Initializes the CNNSR model. 344 | 345 | Args: 346 | in_channels (int): Number of input channels. 347 | out_channels (int): Number of output channels. 348 | feature_channels (int, optional): Number of feature channels. Defaults to 48. 349 | upscale (int, optional): Upscaling factor. Defaults to 4. 350 | bias (bool, optional): Whether to include a bias term. Defaults to True. 351 | train_mode (bool, optional): If True, the model is in training mode. Defaults to True. 352 | num_blocks (int, optional): Number of attention blocks in the network. Defaults to 10. 
353 | """ 354 | super(CNNSR, self).__init__() 355 | 356 | # Initial Convolution 357 | self.conv_1 = Conv3XC( 358 | in_channels, feature_channels, gain1=2, s=1, train_mode=train_mode 359 | ) 360 | 361 | # Deeper Blocks 362 | self.blocks = nn.ModuleList( 363 | [ 364 | SPAB(feature_channels, bias=bias, train_mode=train_mode) 365 | for _ in range(num_blocks) 366 | ] 367 | ) 368 | 369 | # Convolution after attention blocks 370 | self.conv_cat = conv_layer( 371 | feature_channels * 4, feature_channels, kernel_size=1, bias=True 372 | ) 373 | self.conv_2 = Conv3XC( 374 | feature_channels, feature_channels, gain1=2, s=1, train_mode=train_mode 375 | ) 376 | 377 | # Upsampling 378 | self.upsampler = pixelshuffle_block( 379 | feature_channels, out_channels, upscale_factor=upscale 380 | ) 381 | 382 | def forward( 383 | self, x: torch.Tensor, save_attentions: Optional[List[int]] = None 384 | ) -> Union[torch.Tensor, tuple]: 385 | """ 386 | Forward pass of the CNNSR model. 387 | 388 | Args: 389 | x (torch.Tensor): Input tensor. 390 | save_attentions (Optional[List[int]], optional): List of block indices from which to save attention maps. 391 | 392 | Returns: 393 | torch.Tensor: Super-resolved output. 394 | tuple: If save_attentions is specified, returns (output tensor, attention maps). 395 | """ 396 | # Initial Convolution 397 | out_feature = self.conv_1(x) 398 | 399 | # Pass through all blocks, accumulating attention outputs 400 | attentions = [] 401 | for index, block in enumerate(self.blocks): 402 | out, out2, att = block(out_feature) 403 | 404 | # Save the first residual block output 405 | if index == 0: 406 | out_b1 = out 407 | 408 | # Save the last residual block output 409 | if index == len(self.blocks) - 1: 410 | out_blast = out2 411 | 412 | # Save attention if needed 413 | if save_attentions is not None and index in save_attentions: 414 | attentions.append(att) 415 | 416 | # Final Convolution and concatenation 417 | out_bn = self.conv_2(out) 418 | out = self.conv_cat(torch.cat([out_feature, out_bn, out_b1, out_blast], 1)) 419 | 420 | # Upsample 421 | output = self.upsampler(out) 422 | 423 | if save_attentions is not None: 424 | return output, attentions 425 | return output 426 | -------------------------------------------------------------------------------- /sen2sr/models/opensr_baseline/mamba.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Callable, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.utils.checkpoint as checkpoint 9 | from einops import repeat 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | 12 | try: 13 | from mamba_ssm.ops.selective_scan_interface import selective_scan_fn 14 | except ImportError: 15 | # Handle the missing dependency here, for example: 16 | raise ImportError( 17 | "Please install the mamba_ssm package before using MambaSR model." 18 | ) 19 | 20 | 21 | class ChannelAttention(nn.Module): 22 | """ 23 | Implements channel-wise attention mechanism to recalibrate feature responses. 24 | 25 | Args: 26 | num_feat (int): Number of input feature channels. 27 | squeeze_factor (int, optional): Factor by which the feature channels are reduced in the squeeze operation. Default is 16. 
28 | """ 29 | 30 | def __init__(self, num_feat, squeeze_factor=16): 31 | super(ChannelAttention, self).__init__() 32 | self.attention = nn.Sequential( 33 | nn.AdaptiveAvgPool2d(1), 34 | nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), 37 | nn.Sigmoid(), 38 | ) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | """ 42 | Forward pass of the channel attention mechanism. 43 | 44 | Args: 45 | x (torch.Tensor): Input feature map of shape (B, C, H, W), where B is batch size, C is number of channels, H is height, and W is width. 46 | 47 | Returns: 48 | torch.Tensor: Output feature map with recalibrated channel responses, same shape as input. 49 | """ 50 | y = self.attention(x) 51 | return x * y 52 | 53 | 54 | class CAB(nn.Module): 55 | """ 56 | Convolutional Attention Block (CAB) that combines convolution layers and channel attention for feature enhancement. 57 | 58 | Args: 59 | num_feat (int): Number of input feature channels. 60 | is_light_sr (bool, optional): If True, applies a higher compression ratio for lightweight super-resolution models. Default is False. 61 | compress_ratio (int, optional): Compression ratio for reducing channels in the convolution layers. Default is 3. 62 | squeeze_factor (int, optional): Factor used in the channel attention for squeezing feature maps. Default is 30. 63 | """ 64 | 65 | def __init__( 66 | self, num_feat, is_light_sr=False, compress_ratio=3, squeeze_factor=30 67 | ): 68 | super(CAB, self).__init__() 69 | if is_light_sr: # a larger compression ratio is used for light-SR 70 | compress_ratio = 6 71 | self.cab = nn.Sequential( 72 | nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), 73 | nn.GELU(), 74 | nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), 75 | ChannelAttention(num_feat, squeeze_factor), 76 | ) 77 | 78 | def forward(self, x: torch.Tensor) -> torch.Tensor: 79 | """ 80 | Forward pass of the CAB. 81 | 82 | Args: 83 | x (torch.Tensor): Input feature map of shape (B, C, H, W). 84 | 85 | Returns: 86 | torch.Tensor: Enhanced feature map after convolution and attention, with the same shape as input. 87 | """ 88 | return self.cab(x) 89 | 90 | 91 | class Mlp(nn.Module): 92 | """ 93 | Multi-layer perceptron used for transforming feature representations in transformer models. 94 | 95 | Args: 96 | in_features (int): Number of input features. 97 | hidden_features (int, optional): Number of hidden layer features. Default is in_features. 98 | out_features (int, optional): Number of output features. Default is in_features. 99 | act_layer (nn.Module, optional): Activation function applied after the first fully connected layer. Default is GELU. 100 | drop (float, optional): Dropout rate applied after the activation and second fully connected layer. Default is 0.0. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | in_features, 106 | hidden_features=None, 107 | out_features=None, 108 | act_layer=nn.GELU, 109 | drop=0.0, 110 | ): 111 | super().__init__() 112 | out_features = out_features or in_features 113 | hidden_features = hidden_features or in_features 114 | self.fc1 = nn.Linear(in_features, hidden_features) 115 | self.act = act_layer() 116 | self.fc2 = nn.Linear(hidden_features, out_features) 117 | self.drop = nn.Dropout(drop) 118 | 119 | def forward(self, x: torch.Tensor) -> torch.Tensor: 120 | """ 121 | Forward pass of the MLP. 
122 | 123 | Args: 124 | x (torch.Tensor): Input tensor of shape (B, N, C), where B is batch size, N is sequence length, and C is number of channels. 125 | 126 | Returns: 127 | torch.Tensor: Output tensor with transformed feature representations, same shape as input. 128 | """ 129 | x = self.fc1(x) 130 | x = self.act(x) 131 | x = self.drop(x) 132 | x = self.fc2(x) 133 | x = self.drop(x) 134 | return x 135 | 136 | 137 | class DynamicPosBias(nn.Module): 138 | """ 139 | Dynamic positional bias generator for adding spatial positional encoding in attention mechanisms. 140 | 141 | Args: 142 | dim (int): Number of input channels. 143 | num_heads (int): Number of attention heads. 144 | """ 145 | 146 | def __init__(self, dim, num_heads): 147 | super().__init__() 148 | self.num_heads = num_heads 149 | self.pos_dim = dim // 4 150 | self.pos_proj = nn.Linear(2, self.pos_dim) 151 | self.pos1 = nn.Sequential( 152 | nn.LayerNorm(self.pos_dim), 153 | nn.ReLU(inplace=True), 154 | nn.Linear(self.pos_dim, self.pos_dim), 155 | ) 156 | self.pos2 = nn.Sequential( 157 | nn.LayerNorm(self.pos_dim), 158 | nn.ReLU(inplace=True), 159 | nn.Linear(self.pos_dim, self.pos_dim), 160 | ) 161 | self.pos3 = nn.Sequential( 162 | nn.LayerNorm(self.pos_dim), 163 | nn.ReLU(inplace=True), 164 | nn.Linear(self.pos_dim, self.num_heads), 165 | ) 166 | 167 | def forward(self, biases: torch.Tensor) -> torch.Tensor: 168 | """ 169 | Forward pass to compute positional biases. 170 | 171 | Args: 172 | biases (torch.Tensor): Input tensor representing spatial biases of shape (2,). 173 | 174 | Returns: 175 | torch.Tensor: Output tensor with dynamically generated positional biases for attention heads. 176 | """ 177 | pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases)))) 178 | return pos 179 | 180 | def flops(self, N: int) -> int: 181 | """ 182 | Calculate the number of floating point operations (FLOPs) required by the positional bias mechanism. 183 | 184 | Args: 185 | N (int): Number of tokens (e.g., the product of height and width of the feature map). 186 | 187 | Returns: 188 | int: Total number of FLOPs required for the positional bias computations. 189 | """ 190 | flops = N * 2 * self.pos_dim 191 | flops += N * self.pos_dim * self.pos_dim 192 | flops += N * self.pos_dim * self.pos_dim 193 | flops += N * self.pos_dim * self.num_heads 194 | return flops 195 | 196 | 197 | class Attention(nn.Module): 198 | """ 199 | Multi-head self-attention with dynamic positional bias for transformer models. 200 | 201 | Args: 202 | dim (int): Number of input channels. 203 | num_heads (int): Number of attention heads. 204 | qkv_bias (bool, optional): Whether to include bias in query, key, and value projections. Default is True. 205 | qk_scale (float, optional): Custom scale factor for query-key dot product. Default is None (uses head_dim**-0.5). 206 | attn_drop (float, optional): Dropout rate for attention weights. Default is 0.0. 207 | proj_drop (float, optional): Dropout rate for output projection. Default is 0.0. 208 | position_bias (bool, optional): Whether to include dynamic positional bias in the attention mechanism. Default is True. 
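Example:
    A minimal shape sketch with illustrative sizes (dim must be divisible by
    num_heads for the head reshape to work; assumes the sen2sr package and its
    mamba_ssm dependency are importable):

    >>> import torch
    >>> from sen2sr.models.opensr_baseline.mamba import Attention
    >>> attn = Attention(dim=64, num_heads=4)
    >>> x = torch.randn(2, 8 * 8, 64)    # tokens of an 8x8 feature map, (B, N, C)
    >>> tuple(attn(x, H=8, W=8).shape)
    (2, 64, 64)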
209 | """ 210 | 211 | def __init__( 212 | self, 213 | dim, 214 | num_heads, 215 | qkv_bias=True, 216 | qk_scale=None, 217 | attn_drop=0.0, 218 | proj_drop=0.0, 219 | position_bias=True, 220 | ): 221 | super().__init__() 222 | self.dim = dim 223 | self.num_heads = num_heads 224 | head_dim = dim // num_heads 225 | self.scale = qk_scale or head_dim**-0.5 226 | self.position_bias = position_bias 227 | if self.position_bias: 228 | self.pos = DynamicPosBias(self.dim // 4, self.num_heads) 229 | 230 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 231 | self.attn_drop = nn.Dropout(attn_drop) 232 | self.proj = nn.Linear(dim, dim) 233 | self.proj_drop = nn.Dropout(proj_drop) 234 | 235 | self.softmax = nn.Softmax(dim=-1) 236 | 237 | def forward( 238 | self, x: torch.Tensor, H: int, W: int, mask: Optional[torch.Tensor] = None 239 | ) -> torch.Tensor: 240 | """ 241 | Forward pass of the multi-head attention mechanism. 242 | 243 | Args: 244 | x (torch.Tensor): Input tensor of shape (B, N, C), where B is batch size, N is number of tokens (H*W), and C is number of channels. 245 | H (int): Height of the feature map. 246 | W (int): Width of the feature map. 247 | mask (Optional[torch.Tensor]): Optional mask for blocking certain attention connections. Default is None. 248 | 249 | Returns: 250 | torch.Tensor: Output tensor after applying attention mechanism, same shape as input. 251 | """ 252 | group_size = (H, W) 253 | B_, N, C = x.shape 254 | assert H * W == N 255 | qkv = ( 256 | self.qkv(x) 257 | .reshape(B_, N, 3, self.num_heads, C // self.num_heads) 258 | .permute(2, 0, 3, 1, 4) 259 | .contiguous() 260 | ) 261 | q, k, v = qkv[0], qkv[1], qkv[2] 262 | 263 | q = q * self.scale 264 | attn = q @ k.transpose(-2, -1) # (B_, self.num_heads, N, N), N = H*W 265 | 266 | if self.position_bias: 267 | # generate mother-set 268 | position_bias_h = torch.arange( 269 | 1 - group_size[0], group_size[0], device=attn.device 270 | ) 271 | position_bias_w = torch.arange( 272 | 1 - group_size[1], group_size[1], device=attn.device 273 | ) 274 | biases = torch.stack( 275 | torch.meshgrid([position_bias_h, position_bias_w], indexing="ij") 276 | ) # 2, 2Gh-1, 2W2-1 277 | biases = ( 278 | biases.flatten(1).transpose(0, 1).contiguous().float() 279 | ) # (2h-1)*(2w-1) 2 280 | 281 | # get pair-wise relative position index for each token inside the window 282 | coords_h = torch.arange(group_size[0], device=attn.device) 283 | coords_w = torch.arange(group_size[1], device=attn.device) 284 | coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) 285 | coords_flatten = torch.flatten(coords, 1) # 2, Gh*Gw 286 | relative_coords = ( 287 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 288 | ) # 2, Gh*Gw, Gh*Gw 289 | relative_coords = relative_coords.permute( 290 | 1, 2, 0 291 | ).contiguous() # Gh*Gw, Gh*Gw, 2 292 | relative_coords[:, :, 0] += group_size[0] - 1 # shift to start from 0 293 | relative_coords[:, :, 1] += group_size[1] - 1 294 | relative_coords[:, :, 0] *= 2 * group_size[1] - 1 295 | relative_position_index = relative_coords.sum(-1) # Gh*Gw, Gh*Gw 296 | 297 | pos = self.pos(biases) # 2Gh-1 * 2Gw-1, heads 298 | # select position bias 299 | relative_position_bias = pos[relative_position_index.view(-1)].view( 300 | group_size[0] * group_size[1], group_size[0] * group_size[1], -1 301 | ) # Gh*Gw,Gh*Gw,nH 302 | relative_position_bias = relative_position_bias.permute( 303 | 2, 0, 1 304 | ).contiguous() # nH, Gh*Gw, Gh*Gw 305 | attn = attn + relative_position_bias.unsqueeze(0) 306 | 307 | if mask is not 
None: 308 | nP = mask.shape[0] 309 | attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze( 310 | 1 311 | ).unsqueeze( 312 | 0 313 | ) # (B, nP, nHead, N, N) 314 | attn = attn.view(-1, self.num_heads, N, N) 315 | attn = self.softmax(attn) 316 | else: 317 | attn = self.softmax(attn) 318 | 319 | attn = self.attn_drop(attn) 320 | 321 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 322 | x = self.proj(x) 323 | x = self.proj_drop(x) 324 | return x 325 | 326 | 327 | class SS2D(nn.Module): 328 | """ 329 | Implements the 2D state-space model for attention mechanisms, incorporating convolutions and learned dynamic projections. 330 | 331 | Args: 332 | d_model (int): Dimension of the input model. 333 | d_state (int, optional): Number of states in the state-space model. Default is 16. 334 | d_conv (int, optional): Convolution kernel size. Default is 3. 335 | expand (float, optional): Expansion factor for the inner dimensions. Default is 2.0. 336 | dt_rank (str or int, optional): Rank of the time-step projection. If "auto", it is determined based on d_model. Default is "auto". 337 | dt_min (float, optional): Minimum value for the time-step bias initialization. Default is 0.001. 338 | dt_max (float, optional): Maximum value for the time-step bias initialization. Default is 0.1. 339 | dt_init (str, optional): Initialization strategy for time-step bias, either "random" or "constant". Default is "random". 340 | dt_scale (float, optional): Scaling factor for time-step projection initialization. Default is 1.0. 341 | dt_init_floor (float, optional): Minimum floor value for time-step bias. Default is 1e-4. 342 | dropout (float, optional): Dropout rate applied to the output. Default is 0.0. 343 | conv_bias (bool, optional): If True, adds a bias to the convolution layer. Default is True. 344 | bias (bool, optional): If True, adds a bias to the linear layers. Default is False. 345 | device (torch.device, optional): Device on which to create parameters. Default is None. 346 | dtype (torch.dtype, optional): Data type for parameters. Default is None. 
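Example:
    A minimal shape sketch with illustrative sizes. The selective-scan kernel comes
    from the mamba_ssm package, which is generally built for CUDA, so this sketch
    assumes a GPU is available:

    >>> import torch
    >>> from sen2sr.models.opensr_baseline.mamba import SS2D
    >>> ss2d = SS2D(d_model=96, d_state=16, expand=2.0).cuda()
    >>> x = torch.randn(1, 32, 32, 96, device="cuda")   # channels-last: (B, H, W, C)
    >>> tuple(ss2d(x).shape)                             # same layout on output
    (1, 32, 32, 96)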
347 | """ 348 | 349 | def __init__( 350 | self, 351 | d_model, 352 | d_state=16, 353 | d_conv=3, 354 | expand=2.0, 355 | dt_rank="auto", 356 | dt_min=0.001, 357 | dt_max=0.1, 358 | dt_init="random", 359 | dt_scale=1.0, 360 | dt_init_floor=1e-4, 361 | dropout=0.0, 362 | conv_bias=True, 363 | bias=False, 364 | device=None, 365 | dtype=None, 366 | **kwargs, 367 | ): 368 | 369 | factory_kwargs = {"device": device, "dtype": dtype} 370 | super().__init__() 371 | self.d_model = d_model 372 | self.d_state = d_state 373 | self.d_conv = d_conv 374 | self.expand = expand 375 | self.d_inner = int(self.expand * self.d_model) 376 | self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank 377 | 378 | self.in_proj = nn.Linear( 379 | self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs 380 | ) 381 | self.conv2d = nn.Conv2d( 382 | in_channels=self.d_inner, 383 | out_channels=self.d_inner, 384 | groups=self.d_inner, 385 | bias=conv_bias, 386 | kernel_size=d_conv, 387 | padding=(d_conv - 1) // 2, 388 | **factory_kwargs, 389 | ) 390 | self.act = nn.SiLU() 391 | 392 | self.x_proj = ( 393 | nn.Linear( 394 | self.d_inner, 395 | (self.dt_rank + self.d_state * 2), 396 | bias=False, 397 | **factory_kwargs, 398 | ), 399 | nn.Linear( 400 | self.d_inner, 401 | (self.dt_rank + self.d_state * 2), 402 | bias=False, 403 | **factory_kwargs, 404 | ), 405 | nn.Linear( 406 | self.d_inner, 407 | (self.dt_rank + self.d_state * 2), 408 | bias=False, 409 | **factory_kwargs, 410 | ), 411 | nn.Linear( 412 | self.d_inner, 413 | (self.dt_rank + self.d_state * 2), 414 | bias=False, 415 | **factory_kwargs, 416 | ), 417 | ) 418 | self.x_proj_weight = nn.Parameter( 419 | torch.stack([t.weight for t in self.x_proj], dim=0) 420 | ) # (K=4, N, inner) 421 | del self.x_proj 422 | 423 | self.dt_projs = ( 424 | self.dt_init( 425 | self.dt_rank, 426 | self.d_inner, 427 | dt_scale, 428 | dt_init, 429 | dt_min, 430 | dt_max, 431 | dt_init_floor, 432 | **factory_kwargs, 433 | ), 434 | self.dt_init( 435 | self.dt_rank, 436 | self.d_inner, 437 | dt_scale, 438 | dt_init, 439 | dt_min, 440 | dt_max, 441 | dt_init_floor, 442 | **factory_kwargs, 443 | ), 444 | self.dt_init( 445 | self.dt_rank, 446 | self.d_inner, 447 | dt_scale, 448 | dt_init, 449 | dt_min, 450 | dt_max, 451 | dt_init_floor, 452 | **factory_kwargs, 453 | ), 454 | self.dt_init( 455 | self.dt_rank, 456 | self.d_inner, 457 | dt_scale, 458 | dt_init, 459 | dt_min, 460 | dt_max, 461 | dt_init_floor, 462 | **factory_kwargs, 463 | ), 464 | ) 465 | self.dt_projs_weight = nn.Parameter( 466 | torch.stack([t.weight for t in self.dt_projs], dim=0) 467 | ) # (K=4, inner, rank) 468 | self.dt_projs_bias = nn.Parameter( 469 | torch.stack([t.bias for t in self.dt_projs], dim=0) 470 | ) # (K=4, inner) 471 | del self.dt_projs 472 | 473 | self.A_logs = self.A_log_init( 474 | self.d_state, self.d_inner, copies=4, merge=True 475 | ) # (K=4, D, N) 476 | self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N) 477 | 478 | self.selective_scan = selective_scan_fn 479 | 480 | self.out_norm = nn.LayerNorm(self.d_inner) 481 | self.out_proj = nn.Linear( 482 | self.d_inner, self.d_model, bias=bias, **factory_kwargs 483 | ) 484 | self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None 485 | 486 | @staticmethod 487 | def dt_init( 488 | dt_rank, 489 | d_inner, 490 | dt_scale=1.0, 491 | dt_init="random", 492 | dt_min=0.001, 493 | dt_max=0.1, 494 | dt_init_floor=1e-4, 495 | **factory_kwargs, 496 | ): 497 | dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) 498 
| 499 | # Initialize special dt projection to preserve variance at initialization 500 | dt_init_std = dt_rank**-0.5 * dt_scale 501 | if dt_init == "constant": 502 | nn.init.constant_(dt_proj.weight, dt_init_std) 503 | elif dt_init == "random": 504 | nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) 505 | else: 506 | raise NotImplementedError 507 | 508 | # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max 509 | dt = torch.exp( 510 | torch.rand(d_inner, **factory_kwargs) 511 | * (math.log(dt_max) - math.log(dt_min)) 512 | + math.log(dt_min) 513 | ).clamp(min=dt_init_floor) 514 | # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 515 | inv_dt = dt + torch.log(-torch.expm1(-dt)) 516 | with torch.no_grad(): 517 | dt_proj.bias.copy_(inv_dt) 518 | # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit 519 | dt_proj.bias._no_reinit = True 520 | 521 | return dt_proj 522 | 523 | @staticmethod 524 | def A_log_init(d_state, d_inner, copies=1, device=None, merge=True): 525 | # S4D real initialization 526 | A = repeat( 527 | torch.arange(1, d_state + 1, dtype=torch.float32, device=device), 528 | "n -> d n", 529 | d=d_inner, 530 | ).contiguous() 531 | A_log = torch.log(A) # Keep A_log in fp32 532 | if copies > 1: 533 | A_log = repeat(A_log, "d n -> r d n", r=copies) 534 | if merge: 535 | A_log = A_log.flatten(0, 1) 536 | A_log = nn.Parameter(A_log) 537 | A_log._no_weight_decay = True 538 | return A_log 539 | 540 | @staticmethod 541 | def D_init(d_inner, copies=1, device=None, merge=True): 542 | # D "skip" parameter 543 | D = torch.ones(d_inner, device=device) 544 | if copies > 1: 545 | D = repeat(D, "n1 -> r n1", r=copies) 546 | if merge: 547 | D = D.flatten(0, 1) 548 | D = nn.Parameter(D) # Keep in fp32 549 | D._no_weight_decay = True 550 | return D 551 | 552 | def forward_core(self, x: torch.Tensor): 553 | B, C, H, W = x.shape 554 | L = H * W 555 | K = 4 556 | x_hwwh = torch.stack( 557 | [ 558 | x.view(B, -1, L), 559 | torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L), 560 | ], 561 | dim=1, 562 | ).view(B, 2, -1, L) 563 | xs = torch.cat( 564 | [x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1 565 | ) # (1, 4, 192, 3136) 566 | 567 | x_dbl = torch.einsum( 568 | "b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight 569 | ) 570 | dts, Bs, Cs = torch.split( 571 | x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2 572 | ) 573 | dts = torch.einsum( 574 | "b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight 575 | ) 576 | xs = xs.float().view(B, -1, L) 577 | dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) 578 | Bs = Bs.float().view(B, K, -1, L) 579 | Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l) 580 | Ds = self.Ds.float().view(-1) 581 | As = -torch.exp(self.A_logs.float()).view(-1, self.d_state) 582 | dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) 583 | out_y = self.selective_scan( 584 | xs, 585 | dts, 586 | As, 587 | Bs, 588 | Cs, 589 | Ds, 590 | z=None, 591 | delta_bias=dt_projs_bias, 592 | delta_softplus=True, 593 | return_last_state=False, 594 | ).view(B, K, -1, L) 595 | assert out_y.dtype == torch.float 596 | 597 | inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) 598 | wh_y = ( 599 | torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3) 600 | .contiguous() 601 | .view(B, -1, L) 602 | ) 603 | invwh_y = ( 604 | torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3) 605 | .contiguous() 606 | 
.view(B, -1, L) 607 | ) 608 | 609 | return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y 610 | 611 | def forward(self, x: torch.Tensor, **kwargs): 612 | B, H, W, C = x.shape 613 | 614 | xz = self.in_proj(x) 615 | x, z = xz.chunk(2, dim=-1) 616 | 617 | x = x.permute(0, 3, 1, 2).contiguous() 618 | x = self.act(self.conv2d(x)) 619 | y1, y2, y3, y4 = self.forward_core(x) 620 | assert y1.dtype == torch.float32 621 | y = y1 + y2 + y3 + y4 622 | y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1) 623 | y = self.out_norm(y) 624 | y = y * F.silu(z) 625 | out = self.out_proj(y) 626 | if self.dropout is not None: 627 | out = self.dropout(out) 628 | return out 629 | 630 | 631 | class VSSBlock(nn.Module): 632 | """ 633 | Vision State-Space Block (VSSBlock) that combines self-attention with state-space models and convolutional layers. 634 | 635 | Args: 636 | hidden_dim (int, optional): Dimensionality of the hidden layers. Default is 0. 637 | drop_path (float, optional): Dropout rate for the drop path. Default is 0. 638 | norm_layer (Callable[..., nn.Module], optional): Normalization layer to apply. Default is LayerNorm. 639 | attn_drop_rate (float, optional): Dropout rate for attention layers. Default is 0. 640 | d_state (int, optional): Number of states in the state-space model. Default is 16. 641 | expand (float, optional): Expansion factor for the inner dimensions in the attention block. Default is 2.0. 642 | is_light_sr (bool, optional): If True, applies lightweight super-resolution optimizations. Default is False. 643 | """ 644 | 645 | def __init__( 646 | self, 647 | hidden_dim: int = 0, 648 | drop_path: float = 0, 649 | norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), 650 | attn_drop_rate: float = 0, 651 | d_state: int = 16, 652 | expand: float = 2.0, 653 | is_light_sr: bool = False, 654 | **kwargs, 655 | ): 656 | super().__init__() 657 | self.ln_1 = norm_layer(hidden_dim) 658 | self.self_attention = SS2D( 659 | d_model=hidden_dim, 660 | d_state=d_state, 661 | expand=expand, 662 | dropout=attn_drop_rate, 663 | **kwargs, 664 | ) 665 | self.drop_path = DropPath(drop_path) 666 | self.skip_scale = nn.Parameter(torch.ones(hidden_dim)) 667 | self.conv_blk = CAB(hidden_dim, is_light_sr) 668 | self.ln_2 = nn.LayerNorm(hidden_dim) 669 | self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim)) 670 | 671 | def forward(self, input: torch.Tensor, x_size: tuple) -> torch.Tensor: 672 | """ 673 | Forward pass of the VSSBlock. 674 | 675 | Args: 676 | input (torch.Tensor): Input tensor of shape (B, L, C), where B is batch size, L is sequence length, and C is the number of channels. 677 | x_size (tuple): Tuple representing the spatial dimensions (H, W) of the input. 678 | 679 | Returns: 680 | torch.Tensor: Output tensor after applying the VSSBlock, same shape as input. 681 | """ 682 | # x [B,HW,C] 683 | B, L, C = input.shape 684 | input = input.view(B, *x_size, C).contiguous() # [B,H,W,C] 685 | x = self.ln_1(input) 686 | x = input * self.skip_scale + self.drop_path(self.self_attention(x)) 687 | x = ( 688 | x * self.skip_scale2 689 | + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()) 690 | .permute(0, 2, 3, 1) 691 | .contiguous() 692 | ) 693 | x = x.view(B, -1, C).contiguous() 694 | return x 695 | 696 | 697 | class BasicLayer(nn.Module): 698 | """The Basic MambaIR Layer in one Residual State Space Group 699 | Args: 700 | dim (int): Number of input channels. 701 | input_resolution (tuple[int]): Input resolution. 702 | depth (int): Number of blocks. 
703 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 704 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 705 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 706 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 707 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 708 | """ 709 | 710 | def __init__( 711 | self, 712 | dim, 713 | input_resolution, 714 | depth, 715 | drop_path=0.0, 716 | d_state=16, 717 | mlp_ratio=2.0, 718 | norm_layer=nn.LayerNorm, 719 | downsample=None, 720 | use_checkpoint=False, 721 | is_light_sr=False, 722 | ): 723 | super().__init__() 724 | self.dim = dim 725 | self.input_resolution = input_resolution 726 | self.depth = depth 727 | self.mlp_ratio = mlp_ratio 728 | self.use_checkpoint = use_checkpoint 729 | 730 | # build blocks 731 | self.blocks = nn.ModuleList() 732 | for i in range(depth): 733 | self.blocks.append( 734 | VSSBlock( 735 | hidden_dim=dim, 736 | drop_path=( 737 | drop_path[i] if isinstance(drop_path, list) else drop_path 738 | ), 739 | norm_layer=nn.LayerNorm, 740 | attn_drop_rate=0, 741 | d_state=d_state, 742 | expand=self.mlp_ratio, 743 | input_resolution=input_resolution, 744 | is_light_sr=is_light_sr, 745 | ) 746 | ) 747 | 748 | # patch merging layer 749 | if downsample is not None: 750 | self.downsample = downsample( 751 | input_resolution, dim=dim, norm_layer=norm_layer 752 | ) 753 | else: 754 | self.downsample = None 755 | 756 | def forward(self, x, x_size): 757 | for blk in self.blocks: 758 | if self.use_checkpoint: 759 | x = checkpoint.checkpoint(blk, x) 760 | else: 761 | x = blk(x, x_size) 762 | if self.downsample is not None: 763 | x = self.downsample(x) 764 | return x 765 | 766 | def extra_repr(self) -> str: 767 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 768 | 769 | def flops(self): 770 | flops = 0 771 | for blk in self.blocks: 772 | flops += blk.flops() 773 | if self.downsample is not None: 774 | flops += self.downsample.flops() 775 | return flops 776 | 777 | 778 | class MambaSR(nn.Module): 779 | r"""MambaIR Model 780 | A PyTorch impl of : `A Simple Baseline for Image Restoration with State Space Model `. 781 | 782 | Args: 783 | img_size (int | tuple(int)): Input image size. Default 64 784 | patch_size (int | tuple(int)): Patch size. Default: 1 785 | in_chans (int): Number of input image channels. Default: 3 786 | embed_dim (int): Patch embedding dimension. Default: 96 787 | d_state (int): num of hidden state in the state space model. Default: 16 788 | depths (tuple(int)): Depth of each RSSG 789 | drop_rate (float): Dropout rate. Default: 0 790 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 791 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 792 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 793 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False 794 | upscale: Upscale factor. 2/3/4 for image SR, 1 for denoising 795 | img_range: Image range. 1. or 255. 796 | upsampler: The reconstruction reconstruction module. 'pixelshuffle'/None 797 | resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' 798 | """ 799 | 800 | def __init__( 801 | self, 802 | img_size=64, 803 | patch_size=1, 804 | in_channels=3, 805 | out_channels=3, 806 | embed_dim=96, 807 | depths=(6, 6, 6, 6), 808 | drop_rate=0.0, 809 | d_state=16, 810 | mlp_ratio=2.0, 811 | drop_path_rate=0.1, 812 | norm_layer=nn.LayerNorm, 813 | patch_norm=True, 814 | use_checkpoint=False, 815 | upscale=2, 816 | upsampler="", 817 | resi_connection="1conv", 818 | **kwargs, 819 | ): 820 | super(MambaSR, self).__init__() 821 | num_in_ch = in_channels 822 | num_out_ch = out_channels 823 | num_feat = 64 824 | self.upscale = upscale 825 | self.upsampler = upsampler 826 | self.mlp_ratio = mlp_ratio 827 | # ------------------------- 1, shallow feature extraction ------------------------- # 828 | self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) 829 | 830 | # ------------------------- 2, deep feature extraction ------------------------- # 831 | self.num_layers = len(depths) 832 | self.embed_dim = embed_dim 833 | self.patch_norm = patch_norm 834 | self.num_features = embed_dim 835 | 836 | # transfer 2D feature map into 1D token sequence, pay attention to whether using normalization 837 | self.patch_embed = PatchEmbed( 838 | img_size=img_size, 839 | patch_size=patch_size, 840 | in_chans=embed_dim, 841 | embed_dim=embed_dim, 842 | norm_layer=norm_layer if self.patch_norm else None, 843 | ) 844 | num_patches = self.patch_embed.num_patches 845 | patches_resolution = self.patch_embed.patches_resolution 846 | self.patches_resolution = patches_resolution 847 | 848 | # return 2D feature map from 1D token sequence 849 | self.patch_unembed = PatchUnEmbed( 850 | img_size=img_size, 851 | patch_size=patch_size, 852 | in_chans=embed_dim, 853 | embed_dim=embed_dim, 854 | norm_layer=norm_layer if self.patch_norm else None, 855 | ) 856 | 857 | self.pos_drop = nn.Dropout(p=drop_rate) 858 | self.is_light_sr = True if self.upsampler == "pixelshuffledirect" else False 859 | # stochastic depth 860 | dpr = [ 861 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 862 | ] # stochastic depth decay rule 863 | 864 | # build Residual State Space Group (RSSG) 865 | self.layers = nn.ModuleList() 866 | for i_layer in range(self.num_layers): # 6-layer 867 | layer = ResidualGroup( 868 | dim=embed_dim, 869 | input_resolution=(patches_resolution[0], patches_resolution[1]), 870 | depth=depths[i_layer], 871 | d_state=d_state, 872 | mlp_ratio=self.mlp_ratio, 873 | drop_path=dpr[ 874 | sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) 875 | ], # no impact on SR results 876 | norm_layer=norm_layer, 877 | downsample=None, 878 | use_checkpoint=use_checkpoint, 879 | img_size=img_size, 880 | patch_size=patch_size, 881 | resi_connection=resi_connection, 882 | is_light_sr=self.is_light_sr, 883 | ) 884 | self.layers.append(layer) 885 | self.norm = norm_layer(self.num_features) 886 | 887 | # build the last conv layer in the end of all residual groups 888 | if resi_connection == "1conv": 889 | self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 890 | elif resi_connection == "3conv": 891 | # to save parameters and memory 892 | self.conv_after_body = nn.Sequential( 893 | nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), 894 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 895 | nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), 896 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 897 | nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1), 898 | ) 899 | 900 | # -------------------------3. 
high-quality image reconstruction ------------------------ # 901 | if self.upsampler == "pixelshuffle": 902 | # for classical SR 903 | self.conv_before_upsample = nn.Sequential( 904 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 905 | ) 906 | self.upsample = Upsample(upscale, num_feat) 907 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 908 | elif self.upsampler == "pixelshuffledirect": 909 | # for lightweight SR (to save parameters) 910 | self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch) 911 | 912 | else: 913 | # for image denoising 914 | self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) 915 | 916 | self.apply(self._init_weights) 917 | 918 | def _init_weights(self, m): 919 | if isinstance(m, nn.Linear): 920 | trunc_normal_(m.weight, std=0.02) 921 | if isinstance(m, nn.Linear) and m.bias is not None: 922 | nn.init.constant_(m.bias, 0) 923 | elif isinstance(m, nn.LayerNorm): 924 | nn.init.constant_(m.bias, 0) 925 | nn.init.constant_(m.weight, 1.0) 926 | 927 | @torch.jit.ignore 928 | def no_weight_decay(self): 929 | return {"absolute_pos_embed"} 930 | 931 | @torch.jit.ignore 932 | def no_weight_decay_keywords(self): 933 | return {"relative_position_bias_table"} 934 | 935 | def forward_features(self, x): 936 | x_size = (x.shape[2], x.shape[3]) 937 | x = self.patch_embed(x) # N,L,C 938 | 939 | x = self.pos_drop(x) 940 | 941 | for layer in self.layers: 942 | x = layer(x, x_size) 943 | 944 | x = self.norm(x) # b seq_len c 945 | x = self.patch_unembed(x, x_size) 946 | 947 | return x 948 | 949 | def forward(self, x): 950 | if self.upsampler == "pixelshuffle": 951 | # for classical SR 952 | x = self.conv_first(x) 953 | x = self.conv_after_body(self.forward_features(x)) + x 954 | x = self.conv_before_upsample(x) 955 | x = self.conv_last(self.upsample(x)) 956 | 957 | elif self.upsampler == "pixelshuffledirect": 958 | # for lightweight SR 959 | x = self.conv_first(x) 960 | x = self.conv_after_body(self.forward_features(x)) + x 961 | x = self.upsample(x) 962 | 963 | else: 964 | # for image denoising 965 | x_first = self.conv_first(x) 966 | res = self.conv_after_body(self.forward_features(x_first)) + x_first 967 | x = x + self.conv_last(res) 968 | 969 | return x 970 | 971 | def flops(self): 972 | flops = 0 973 | h, w = self.patches_resolution 974 | flops += h * w * 3 * self.embed_dim * 9 975 | flops += self.patch_embed.flops() 976 | for layer in self.layers: 977 | flops += layer.flops() 978 | flops += h * w * 3 * self.embed_dim * self.embed_dim 979 | flops += self.upsample.flops() 980 | return flops 981 | 982 | 983 | class ResidualGroup(nn.Module): 984 | """Residual State Space Group (RSSG). 985 | 986 | Args: 987 | dim (int): Number of input channels. 988 | input_resolution (tuple[int]): Input resolution. 989 | depth (int): Number of blocks. 990 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 991 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 992 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 993 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 994 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 995 | img_size: Input image size. 996 | patch_size: Patch size. 997 | resi_connection: The convolutional block before residual connection. 
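Example:
    A minimal shape sketch with illustrative sizes. img_size and patch_size have to be
    supplied so the internal PatchEmbed/PatchUnEmbed can be constructed, and the SS2D
    blocks rely on the CUDA selective-scan kernel from mamba_ssm, so this sketch
    assumes a GPU:

    >>> import torch
    >>> from sen2sr.models.opensr_baseline.mamba import ResidualGroup
    >>> rg = ResidualGroup(dim=96, input_resolution=(32, 32), depth=2,
    ...                    img_size=32, patch_size=1).cuda()
    >>> x = torch.randn(1, 32 * 32, 96, device="cuda")   # token sequence (B, L, C)
    >>> tuple(rg(x, (32, 32)).shape)
    (1, 1024, 96)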
998 | """ 999 | 1000 | def __init__( 1001 | self, 1002 | dim, 1003 | input_resolution, 1004 | depth, 1005 | d_state=16, 1006 | mlp_ratio=4.0, 1007 | drop_path=0.0, 1008 | norm_layer=nn.LayerNorm, 1009 | downsample=None, 1010 | use_checkpoint=False, 1011 | img_size=None, 1012 | patch_size=None, 1013 | resi_connection="1conv", 1014 | is_light_sr=False, 1015 | ): 1016 | super(ResidualGroup, self).__init__() 1017 | 1018 | self.dim = dim 1019 | self.input_resolution = input_resolution # [64, 64] 1020 | 1021 | self.residual_group = BasicLayer( 1022 | dim=dim, 1023 | input_resolution=input_resolution, 1024 | depth=depth, 1025 | d_state=d_state, 1026 | mlp_ratio=mlp_ratio, 1027 | drop_path=drop_path, 1028 | norm_layer=norm_layer, 1029 | downsample=downsample, 1030 | use_checkpoint=use_checkpoint, 1031 | is_light_sr=is_light_sr, 1032 | ) 1033 | 1034 | # build the last conv layer in each residual state space group 1035 | if resi_connection == "1conv": 1036 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 1037 | elif resi_connection == "3conv": 1038 | # to save parameters and memory 1039 | self.conv = nn.Sequential( 1040 | nn.Conv2d(dim, dim // 4, 3, 1, 1), 1041 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1042 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), 1043 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1044 | nn.Conv2d(dim // 4, dim, 3, 1, 1), 1045 | ) 1046 | 1047 | self.patch_embed = PatchEmbed( 1048 | img_size=img_size, 1049 | patch_size=patch_size, 1050 | in_chans=0, 1051 | embed_dim=dim, 1052 | norm_layer=None, 1053 | ) 1054 | 1055 | self.patch_unembed = PatchUnEmbed( 1056 | img_size=img_size, 1057 | patch_size=patch_size, 1058 | in_chans=0, 1059 | embed_dim=dim, 1060 | norm_layer=None, 1061 | ) 1062 | 1063 | def forward(self, x, x_size): 1064 | return ( 1065 | self.patch_embed( 1066 | self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) 1067 | ) 1068 | + x 1069 | ) 1070 | 1071 | def flops(self): 1072 | flops = 0 1073 | flops += self.residual_group.flops() 1074 | h, w = self.input_resolution 1075 | flops += h * w * self.dim * self.dim * 9 1076 | flops += self.patch_embed.flops() 1077 | flops += self.patch_unembed.flops() 1078 | 1079 | return flops 1080 | 1081 | 1082 | class PatchEmbed(nn.Module): 1083 | r"""transfer 2D feature map into 1D token sequence 1084 | 1085 | Args: 1086 | img_size (int): Image size. Default: None. 1087 | patch_size (int): Patch token size. Default: None. 1088 | in_chans (int): Number of input image channels. Default: 3. 1089 | embed_dim (int): Number of linear projection output channels. Default: 96. 1090 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 1091 | """ 1092 | 1093 | def __init__( 1094 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 1095 | ): 1096 | super().__init__() 1097 | img_size = to_2tuple(img_size) 1098 | patch_size = to_2tuple(patch_size) 1099 | patches_resolution = [ 1100 | img_size[0] // patch_size[0], 1101 | img_size[1] // patch_size[1], 1102 | ] 1103 | self.img_size = img_size 1104 | self.patch_size = patch_size 1105 | self.patches_resolution = patches_resolution 1106 | self.num_patches = patches_resolution[0] * patches_resolution[1] 1107 | 1108 | self.in_chans = in_chans 1109 | self.embed_dim = embed_dim 1110 | 1111 | if norm_layer is not None: 1112 | self.norm = norm_layer(embed_dim) 1113 | else: 1114 | self.norm = None 1115 | 1116 | def forward(self, x): 1117 | x = x.flatten(2).transpose(1, 2) # b Ph*Pw c 1118 | if self.norm is not None: 1119 | x = self.norm(x) 1120 | return x 1121 | 1122 | def flops(self): 1123 | flops = 0 1124 | h, w = self.img_size 1125 | if self.norm is not None: 1126 | flops += h * w * self.embed_dim 1127 | return flops 1128 | 1129 | 1130 | class PatchUnEmbed(nn.Module): 1131 | r"""return 2D feature map from 1D token sequence 1132 | 1133 | Args: 1134 | img_size (int): Image size. Default: None. 1135 | patch_size (int): Patch token size. Default: None. 1136 | in_chans (int): Number of input image channels. Default: 3. 1137 | embed_dim (int): Number of linear projection output channels. Default: 96. 1138 | norm_layer (nn.Module, optional): Normalization layer. Default: None 1139 | """ 1140 | 1141 | def __init__( 1142 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 1143 | ): 1144 | super().__init__() 1145 | img_size = to_2tuple(img_size) 1146 | patch_size = to_2tuple(patch_size) 1147 | patches_resolution = [ 1148 | img_size[0] // patch_size[0], 1149 | img_size[1] // patch_size[1], 1150 | ] 1151 | self.img_size = img_size 1152 | self.patch_size = patch_size 1153 | self.patches_resolution = patches_resolution 1154 | self.num_patches = patches_resolution[0] * patches_resolution[1] 1155 | 1156 | self.in_chans = in_chans 1157 | self.embed_dim = embed_dim 1158 | 1159 | def forward(self, x, x_size): 1160 | x = x.transpose(1, 2).view( 1161 | x.shape[0], self.embed_dim, x_size[0], x_size[1] 1162 | ) # b Ph*Pw c 1163 | return x 1164 | 1165 | def flops(self): 1166 | flops = 0 1167 | return flops 1168 | 1169 | 1170 | class UpsampleOneStep(nn.Sequential): 1171 | """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) 1172 | Used in lightweight SR to save parameters. 1173 | 1174 | Args: 1175 | scale (int): Scale factor. Supported scales: 2^n and 3. 1176 | num_feat (int): Channel number of intermediate features. 1177 | 1178 | """ 1179 | 1180 | def __init__(self, scale, num_feat, num_out_ch): 1181 | self.num_feat = num_feat 1182 | m = [] 1183 | m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) 1184 | m.append(nn.PixelShuffle(scale)) 1185 | super(UpsampleOneStep, self).__init__(*m) 1186 | 1187 | 1188 | class Upsample(nn.Sequential): 1189 | """Upsample module. 1190 | 1191 | Args: 1192 | scale (int): Scale factor. Supported scales: 2^n and 3. 1193 | num_feat (int): Channel number of intermediate features. 
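Example:
    A minimal shape sketch with illustrative sizes (assumes the sen2sr package and its
    mamba_ssm dependency are importable):

    >>> import torch
    >>> from sen2sr.models.opensr_baseline.mamba import Upsample
    >>> up = Upsample(scale=4, num_feat=64)
    >>> x = torch.randn(1, 64, 16, 16)   # (B, num_feat, H, W)
    >>> tuple(up(x).shape)               # spatial size grows by the scale factor
    (1, 64, 64, 64)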
1194 | """ 1195 | 1196 | def __init__(self, scale, num_feat): 1197 | m = [] 1198 | if (scale & (scale - 1)) == 0: # scale = 2^n 1199 | for _ in range(int(math.log(scale, 2))): 1200 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 1201 | m.append(nn.PixelShuffle(2)) 1202 | elif scale == 3: 1203 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 1204 | m.append(nn.PixelShuffle(3)) 1205 | else: 1206 | raise ValueError( 1207 | f"scale {scale} is not supported. Supported scales: 2^n and 3." 1208 | ) 1209 | super(Upsample, self).__init__(*m) 1210 | -------------------------------------------------------------------------------- /sen2sr/models/opensr_baseline/swin.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------------- 2 | # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed 3 | # Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345 4 | # Written by Conde and Choi et al. 5 | # ----------------------------------------------------------------------------------- 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.utils.checkpoint as checkpoint 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features, 21 | hidden_features=None, 22 | out_features=None, 23 | act_layer=nn.GELU, 24 | drop=0.0, 25 | ): 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x): 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | 42 | 43 | def window_partition(x, window_size): 44 | """ 45 | Args: 46 | x: (B, H, W, C) 47 | window_size (int): window size 48 | Returns: 49 | windows: (num_windows*B, window_size, window_size, C) 50 | """ 51 | B, H, W, C = x.shape 52 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 53 | windows = ( 54 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 55 | ) 56 | return windows 57 | 58 | 59 | def window_reverse(windows, window_size, H, W): 60 | """ 61 | Args: 62 | windows: (num_windows*B, window_size, window_size, C) 63 | window_size (int): Window size 64 | H (int): Height of image 65 | W (int): Width of image 66 | Returns: 67 | x: (B, H, W, C) 68 | """ 69 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 70 | x = windows.view( 71 | B, H // window_size, W // window_size, window_size, window_size, -1 72 | ) 73 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 74 | return x 75 | 76 | 77 | class WindowAttention(nn.Module): 78 | r"""Window based multi-head self attention (W-MSA) module with relative position bias. 79 | It supports both of shifted and non-shifted window. 80 | Args: 81 | dim (int): Number of input channels. 82 | window_size (tuple[int]): The height and width of the window. 83 | num_heads (int): Number of attention heads. 84 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 85 | attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 86 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 87 | pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 88 | """ 89 | 90 | def __init__( 91 | self, 92 | dim, 93 | window_size, 94 | num_heads, 95 | qkv_bias=True, 96 | attn_drop=0.0, 97 | proj_drop=0.0, 98 | pretrained_window_size=[0, 0], 99 | ): 100 | super().__init__() 101 | self.dim = dim 102 | self.window_size = window_size # Wh, Ww 103 | self.pretrained_window_size = pretrained_window_size 104 | self.num_heads = num_heads 105 | 106 | self.logit_scale = nn.Parameter( 107 | torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True 108 | ) 109 | 110 | # mlp to generate continuous relative position bias 111 | self.cpb_mlp = nn.Sequential( 112 | nn.Linear(2, 512, bias=True), 113 | nn.ReLU(inplace=True), 114 | nn.Linear(512, num_heads, bias=False), 115 | ) 116 | 117 | # get relative_coords_table 118 | relative_coords_h = torch.arange( 119 | -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32 120 | ) 121 | relative_coords_w = torch.arange( 122 | -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32 123 | ) 124 | relative_coords_table = ( 125 | torch.stack( 126 | torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij") 127 | ) 128 | .permute(1, 2, 0) 129 | .contiguous() 130 | .unsqueeze(0) 131 | ) # 1, 2*Wh-1, 2*Ww-1, 2 132 | if pretrained_window_size[0] > 0: 133 | relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 134 | relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 135 | else: 136 | relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 137 | relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 138 | relative_coords_table *= 8 # normalize to -8, 8 139 | relative_coords_table = ( 140 | torch.sign(relative_coords_table) 141 | * torch.log2(torch.abs(relative_coords_table) + 1.0) 142 | / np.log2(8) 143 | ) 144 | 145 | self.register_buffer("relative_coords_table", relative_coords_table) 146 | 147 | # get pair-wise relative position index for each token inside the window 148 | coords_h = torch.arange(self.window_size[0]) 149 | coords_w = torch.arange(self.window_size[1]) 150 | coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij")) 151 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 152 | relative_coords = ( 153 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 154 | ) # 2, Wh*Ww, Wh*Ww 155 | relative_coords = relative_coords.permute( 156 | 1, 2, 0 157 | ).contiguous() # Wh*Ww, Wh*Ww, 2 158 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 159 | relative_coords[:, :, 1] += self.window_size[1] - 1 160 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 161 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 162 | self.register_buffer("relative_position_index", relative_position_index) 163 | 164 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 165 | if qkv_bias: 166 | self.q_bias = nn.Parameter(torch.zeros(dim)) 167 | self.v_bias = nn.Parameter(torch.zeros(dim)) 168 | else: 169 | self.q_bias = None 170 | self.v_bias = None 171 | self.attn_drop = nn.Dropout(attn_drop) 172 | self.proj = nn.Linear(dim, dim) 173 | self.proj_drop = nn.Dropout(proj_drop) 174 | self.softmax = nn.Softmax(dim=-1) 175 | 176 | def forward(self, x, mask=None): 177 | """ 178 | Args: 179 | x: input features with shape of (num_windows*B, N, C) 180 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 181 | 
""" 182 | B_, N, C = x.shape 183 | qkv_bias = None 184 | if self.q_bias is not None: 185 | qkv_bias = torch.cat( 186 | ( 187 | self.q_bias, 188 | torch.zeros_like(self.v_bias, requires_grad=False), 189 | self.v_bias, 190 | ) 191 | ) 192 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 193 | qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 194 | q, k, v = ( 195 | qkv[0], 196 | qkv[1], 197 | qkv[2], 198 | ) # make torchscript happy (cannot use tensor as tuple) 199 | 200 | # cosine attention 201 | attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) 202 | logit_scale = torch.clamp( 203 | self.logit_scale, 204 | max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device), 205 | ).exp() 206 | attn = attn * logit_scale 207 | 208 | relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view( 209 | -1, self.num_heads 210 | ) 211 | relative_position_bias = relative_position_bias_table[ 212 | self.relative_position_index.view(-1) 213 | ].view( 214 | self.window_size[0] * self.window_size[1], 215 | self.window_size[0] * self.window_size[1], 216 | -1, 217 | ) # Wh*Ww,Wh*Ww,nH 218 | relative_position_bias = relative_position_bias.permute( 219 | 2, 0, 1 220 | ).contiguous() # nH, Wh*Ww, Wh*Ww 221 | relative_position_bias = 16 * torch.sigmoid(relative_position_bias) 222 | attn = attn + relative_position_bias.unsqueeze(0) 223 | 224 | if mask is not None: 225 | nW = mask.shape[0] 226 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( 227 | 1 228 | ).unsqueeze(0) 229 | attn = attn.view(-1, self.num_heads, N, N) 230 | attn = self.softmax(attn) 231 | else: 232 | attn = self.softmax(attn) 233 | 234 | attn = self.attn_drop(attn) 235 | 236 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 237 | x = self.proj(x) 238 | x = self.proj_drop(x) 239 | return x 240 | 241 | def extra_repr(self) -> str: 242 | return ( 243 | f"dim={self.dim}, window_size={self.window_size}, " 244 | f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" 245 | ) 246 | 247 | def flops(self, N): 248 | # calculate flops for 1 window with token length of N 249 | flops = 0 250 | # qkv = self.qkv(x) 251 | flops += N * self.dim * 3 * self.dim 252 | # attn = (q @ k.transpose(-2, -1)) 253 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 254 | # x = (attn @ v) 255 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 256 | # x = self.proj(x) 257 | flops += N * self.dim * self.dim 258 | return flops 259 | 260 | 261 | class SwinTransformerBlock(nn.Module): 262 | r"""Swin Transformer Block. 263 | Args: 264 | dim (int): Number of input channels. 265 | input_resolution (tuple[int]): Input resulotion. 266 | num_heads (int): Number of attention heads. 267 | window_size (int): Window size. 268 | shift_size (int): Shift size for SW-MSA. 269 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 270 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 271 | drop (float, optional): Dropout rate. Default: 0.0 272 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 273 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 274 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 275 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 276 | pretrained_window_size (int): Window size in pre-training. 
277 | """ 278 | 279 | def __init__( 280 | self, 281 | dim, 282 | input_resolution, 283 | num_heads, 284 | window_size=7, 285 | shift_size=0, 286 | mlp_ratio=4.0, 287 | qkv_bias=True, 288 | drop=0.0, 289 | attn_drop=0.0, 290 | drop_path=0.0, 291 | act_layer=nn.GELU, 292 | norm_layer=nn.LayerNorm, 293 | pretrained_window_size=0, 294 | ): 295 | super().__init__() 296 | self.dim = dim 297 | self.input_resolution = input_resolution 298 | self.num_heads = num_heads 299 | self.window_size = window_size 300 | self.shift_size = shift_size 301 | self.mlp_ratio = mlp_ratio 302 | if min(self.input_resolution) <= self.window_size: 303 | # if window size is larger than input resolution, we don't partition windows 304 | self.shift_size = 0 305 | self.window_size = min(self.input_resolution) 306 | assert ( 307 | 0 <= self.shift_size < self.window_size 308 | ), "shift_size must in 0-window_size" 309 | 310 | self.norm1 = norm_layer(dim) 311 | self.attn = WindowAttention( 312 | dim, 313 | window_size=to_2tuple(self.window_size), 314 | num_heads=num_heads, 315 | qkv_bias=qkv_bias, 316 | attn_drop=attn_drop, 317 | proj_drop=drop, 318 | pretrained_window_size=to_2tuple(pretrained_window_size), 319 | ) 320 | 321 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 322 | self.norm2 = norm_layer(dim) 323 | mlp_hidden_dim = int(dim * mlp_ratio) 324 | self.mlp = Mlp( 325 | in_features=dim, 326 | hidden_features=mlp_hidden_dim, 327 | act_layer=act_layer, 328 | drop=drop, 329 | ) 330 | 331 | if self.shift_size > 0: 332 | attn_mask = self.calculate_mask(self.input_resolution) 333 | else: 334 | attn_mask = None 335 | 336 | self.register_buffer("attn_mask", attn_mask) 337 | 338 | def calculate_mask(self, x_size): 339 | # calculate attention mask for SW-MSA 340 | H, W = x_size 341 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 342 | h_slices = ( 343 | slice(0, -self.window_size), 344 | slice(-self.window_size, -self.shift_size), 345 | slice(-self.shift_size, None), 346 | ) 347 | w_slices = ( 348 | slice(0, -self.window_size), 349 | slice(-self.window_size, -self.shift_size), 350 | slice(-self.shift_size, None), 351 | ) 352 | cnt = 0 353 | for h in h_slices: 354 | for w in w_slices: 355 | img_mask[:, h, w, :] = cnt 356 | cnt += 1 357 | 358 | mask_windows = window_partition( 359 | img_mask, self.window_size 360 | ) # nW, window_size, window_size, 1 361 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 362 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 363 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( 364 | attn_mask == 0, float(0.0) 365 | ) 366 | 367 | return attn_mask 368 | 369 | def forward(self, x, x_size): 370 | H, W = x_size 371 | B, L, C = x.shape 372 | # assert L == H * W, "input feature has wrong size" 373 | 374 | shortcut = x 375 | x = x.view(B, H, W, C) 376 | 377 | # cyclic shift 378 | if self.shift_size > 0: 379 | shifted_x = torch.roll( 380 | x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2) 381 | ) 382 | else: 383 | shifted_x = x 384 | 385 | # partition windows 386 | x_windows = window_partition( 387 | shifted_x, self.window_size 388 | ) # nW*B, window_size, window_size, C 389 | x_windows = x_windows.view( 390 | -1, self.window_size * self.window_size, C 391 | ) # nW*B, window_size*window_size, C 392 | 393 | # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size 394 | if self.input_resolution == x_size: 395 | attn_windows = self.attn( 396 | x_windows, 
mask=self.attn_mask 397 | ) # nW*B, window_size*window_size, C 398 | else: 399 | attn_windows = self.attn( 400 | x_windows, mask=self.calculate_mask(x_size).to(x.device) 401 | ) 402 | 403 | # merge windows 404 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 405 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 406 | 407 | # reverse cyclic shift 408 | if self.shift_size > 0: 409 | x = torch.roll( 410 | shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2) 411 | ) 412 | else: 413 | x = shifted_x 414 | x = x.view(B, H * W, C) 415 | x = shortcut + self.drop_path(self.norm1(x)) 416 | 417 | # FFN 418 | x = x + self.drop_path(self.norm2(self.mlp(x))) 419 | 420 | return x 421 | 422 | def extra_repr(self) -> str: 423 | return ( 424 | f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " 425 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 426 | ) 427 | 428 | def flops(self): 429 | flops = 0 430 | H, W = self.input_resolution 431 | # norm1 432 | flops += self.dim * H * W 433 | # W-MSA/SW-MSA 434 | nW = H * W / self.window_size / self.window_size 435 | flops += nW * self.attn.flops(self.window_size * self.window_size) 436 | # mlp 437 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 438 | # norm2 439 | flops += self.dim * H * W 440 | return flops 441 | 442 | 443 | class PatchMerging(nn.Module): 444 | r"""Patch Merging Layer. 445 | Args: 446 | input_resolution (tuple[int]): Resolution of input feature. 447 | dim (int): Number of input channels. 448 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 449 | """ 450 | 451 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 452 | super().__init__() 453 | self.input_resolution = input_resolution 454 | self.dim = dim 455 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 456 | self.norm = norm_layer(2 * dim) 457 | 458 | def forward(self, x): 459 | """ 460 | x: B, H*W, C 461 | """ 462 | H, W = self.input_resolution 463 | B, L, C = x.shape 464 | assert L == H * W, "input feature has wrong size" 465 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 466 | 467 | x = x.view(B, H, W, C) 468 | 469 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 470 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 471 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 472 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 473 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 474 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 475 | 476 | x = self.reduction(x) 477 | x = self.norm(x) 478 | 479 | return x 480 | 481 | def extra_repr(self) -> str: 482 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 483 | 484 | def flops(self): 485 | H, W = self.input_resolution 486 | flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 487 | flops += H * W * self.dim // 2 488 | return flops 489 | 490 | 491 | class BasicLayer(nn.Module): 492 | """A basic Swin Transformer layer for one stage. 493 | Args: 494 | dim (int): Number of input channels. 495 | input_resolution (tuple[int]): Input resolution. 496 | depth (int): Number of blocks. 497 | num_heads (int): Number of attention heads. 498 | window_size (int): Local window size. 499 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 500 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 501 | drop (float, optional): Dropout rate. 
Default: 0.0 502 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 503 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 504 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 505 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 506 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 507 | pretrained_window_size (int): Local window size in pre-training. 508 | """ 509 | 510 | def __init__( 511 | self, 512 | dim, 513 | input_resolution, 514 | depth, 515 | num_heads, 516 | window_size, 517 | mlp_ratio=4.0, 518 | qkv_bias=True, 519 | drop=0.0, 520 | attn_drop=0.0, 521 | drop_path=0.0, 522 | norm_layer=nn.LayerNorm, 523 | downsample=None, 524 | use_checkpoint=False, 525 | pretrained_window_size=0, 526 | ): 527 | super().__init__() 528 | self.dim = dim 529 | self.input_resolution = input_resolution 530 | self.depth = depth 531 | self.use_checkpoint = use_checkpoint 532 | 533 | # build blocks 534 | self.blocks = nn.ModuleList( 535 | [ 536 | SwinTransformerBlock( 537 | dim=dim, 538 | input_resolution=input_resolution, 539 | num_heads=num_heads, 540 | window_size=window_size, 541 | shift_size=0 if (i % 2 == 0) else window_size // 2, 542 | mlp_ratio=mlp_ratio, 543 | qkv_bias=qkv_bias, 544 | drop=drop, 545 | attn_drop=attn_drop, 546 | drop_path=( 547 | drop_path[i] if isinstance(drop_path, list) else drop_path 548 | ), 549 | norm_layer=norm_layer, 550 | pretrained_window_size=pretrained_window_size, 551 | ) 552 | for i in range(depth) 553 | ] 554 | ) 555 | 556 | # patch merging layer 557 | if downsample is not None: 558 | self.downsample = downsample( 559 | input_resolution, dim=dim, norm_layer=norm_layer 560 | ) 561 | else: 562 | self.downsample = None 563 | 564 | def forward(self, x, x_size): 565 | for blk in self.blocks: 566 | if self.use_checkpoint: 567 | x = checkpoint.checkpoint(blk, x, x_size) 568 | else: 569 | x = blk(x, x_size) 570 | if self.downsample is not None: 571 | x = self.downsample(x) 572 | return x 573 | 574 | def extra_repr(self) -> str: 575 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 576 | 577 | def flops(self): 578 | flops = 0 579 | for blk in self.blocks: 580 | flops += blk.flops() 581 | if self.downsample is not None: 582 | flops += self.downsample.flops() 583 | return flops 584 | 585 | def _init_respostnorm(self): 586 | for blk in self.blocks: 587 | nn.init.constant_(blk.norm1.bias, 0) 588 | nn.init.constant_(blk.norm1.weight, 0) 589 | nn.init.constant_(blk.norm2.bias, 0) 590 | nn.init.constant_(blk.norm2.weight, 0) 591 | 592 | 593 | class PatchEmbed(nn.Module): 594 | r"""Image to Patch Embedding 595 | Args: 596 | img_size (int): Image size. Default: 224. 597 | patch_size (int): Patch token size. Default: 4. 598 | in_chans (int): Number of input image channels. Default: 3. 599 | embed_dim (int): Number of linear projection output channels. Default: 96. 600 | norm_layer (nn.Module, optional): Normalization layer. 
Default: None 601 | """ 602 | 603 | def __init__( 604 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 605 | ): 606 | super().__init__() 607 | img_size = to_2tuple(img_size) 608 | patch_size = to_2tuple(patch_size) 609 | patches_resolution = [ 610 | img_size[0] // patch_size[0], 611 | img_size[1] // patch_size[1], 612 | ] 613 | self.img_size = img_size 614 | self.patch_size = patch_size 615 | self.patches_resolution = patches_resolution 616 | self.num_patches = patches_resolution[0] * patches_resolution[1] 617 | 618 | self.in_chans = in_chans 619 | self.embed_dim = embed_dim 620 | 621 | self.proj = nn.Conv2d( 622 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size 623 | ) 624 | if norm_layer is not None: 625 | self.norm = norm_layer(embed_dim) 626 | else: 627 | self.norm = None 628 | 629 | def forward(self, x): 630 | B, C, H, W = x.shape 631 | # FIXME look at relaxing size constraints 632 | # assert H == self.img_size[0] and W == self.img_size[1], 633 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 634 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 635 | if self.norm is not None: 636 | x = self.norm(x) 637 | return x 638 | 639 | def flops(self): 640 | Ho, Wo = self.patches_resolution 641 | flops = ( 642 | Ho 643 | * Wo 644 | * self.embed_dim 645 | * self.in_chans 646 | * (self.patch_size[0] * self.patch_size[1]) 647 | ) 648 | if self.norm is not None: 649 | flops += Ho * Wo * self.embed_dim 650 | return flops 651 | 652 | 653 | class RSTB(nn.Module): 654 | """Residual Swin Transformer Block (RSTB). 655 | 656 | Args: 657 | dim (int): Number of input channels. 658 | input_resolution (tuple[int]): Input resolution. 659 | depth (int): Number of blocks. 660 | num_heads (int): Number of attention heads. 661 | window_size (int): Local window size. 662 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 663 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 664 | drop (float, optional): Dropout rate. Default: 0.0 665 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 666 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 667 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 668 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 669 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 670 | img_size: Input image size. 671 | patch_size: Patch size. 672 | resi_connection: The convolutional block before residual connection. 
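Example:
    A minimal shape sketch with illustrative sizes. patch_size=1 keeps the internal
    patch projection at a 1x1 convolution with stride 1, so the token count is
    preserved and the residual addition lines up:

    >>> import torch
    >>> from sen2sr.models.opensr_baseline.swin import RSTB
    >>> rstb = RSTB(dim=60, input_resolution=(64, 64), depth=2, num_heads=6,
    ...             window_size=8, img_size=64, patch_size=1)
    >>> x = torch.randn(1, 64 * 64, 60)   # token sequence (B, H*W, C)
    >>> tuple(rstb(x, (64, 64)).shape)
    (1, 4096, 60)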
673 | """ 674 | 675 | def __init__( 676 | self, 677 | dim, 678 | input_resolution, 679 | depth, 680 | num_heads, 681 | window_size, 682 | mlp_ratio=4.0, 683 | qkv_bias=True, 684 | drop=0.0, 685 | attn_drop=0.0, 686 | drop_path=0.0, 687 | norm_layer=nn.LayerNorm, 688 | downsample=None, 689 | use_checkpoint=False, 690 | img_size=224, 691 | patch_size=4, 692 | resi_connection="1conv", 693 | ): 694 | super(RSTB, self).__init__() 695 | 696 | self.dim = dim 697 | self.input_resolution = input_resolution 698 | 699 | self.residual_group = BasicLayer( 700 | dim=dim, 701 | input_resolution=input_resolution, 702 | depth=depth, 703 | num_heads=num_heads, 704 | window_size=window_size, 705 | mlp_ratio=mlp_ratio, 706 | qkv_bias=qkv_bias, 707 | drop=drop, 708 | attn_drop=attn_drop, 709 | drop_path=drop_path, 710 | norm_layer=norm_layer, 711 | downsample=downsample, 712 | use_checkpoint=use_checkpoint, 713 | ) 714 | 715 | if resi_connection == "1conv": 716 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) 717 | elif resi_connection == "3conv": 718 | # to save parameters and memory 719 | self.conv = nn.Sequential( 720 | nn.Conv2d(dim, dim // 4, 3, 1, 1), 721 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 722 | nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), 723 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 724 | nn.Conv2d(dim // 4, dim, 3, 1, 1), 725 | ) 726 | 727 | self.patch_embed = PatchEmbed( 728 | img_size=img_size, 729 | patch_size=patch_size, 730 | in_chans=dim, 731 | embed_dim=dim, 732 | norm_layer=None, 733 | ) 734 | 735 | self.patch_unembed = PatchUnEmbed( 736 | img_size=img_size, 737 | patch_size=patch_size, 738 | in_chans=dim, 739 | embed_dim=dim, 740 | norm_layer=None, 741 | ) 742 | 743 | def forward(self, x, x_size): 744 | return ( 745 | self.patch_embed( 746 | self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size)) 747 | ) 748 | + x 749 | ) 750 | 751 | def flops(self): 752 | flops = 0 753 | flops += self.residual_group.flops() 754 | H, W = self.input_resolution 755 | flops += H * W * self.dim * self.dim * 9 756 | flops += self.patch_embed.flops() 757 | flops += self.patch_unembed.flops() 758 | 759 | return flops 760 | 761 | 762 | class PatchUnEmbed(nn.Module): 763 | r"""Image to Patch Unembedding 764 | 765 | Args: 766 | img_size (int): Image size. Default: 224. 767 | patch_size (int): Patch token size. Default: 4. 768 | in_chans (int): Number of input image channels. Default: 3. 769 | embed_dim (int): Number of linear projection output channels. Default: 96. 770 | norm_layer (nn.Module, optional): Normalization layer. Default: None 771 | """ 772 | 773 | def __init__( 774 | self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None 775 | ): 776 | super().__init__() 777 | img_size = to_2tuple(img_size) 778 | patch_size = to_2tuple(patch_size) 779 | patches_resolution = [ 780 | img_size[0] // patch_size[0], 781 | img_size[1] // patch_size[1], 782 | ] 783 | self.img_size = img_size 784 | self.patch_size = patch_size 785 | self.patches_resolution = patches_resolution 786 | self.num_patches = patches_resolution[0] * patches_resolution[1] 787 | 788 | self.in_chans = in_chans 789 | self.embed_dim = embed_dim 790 | 791 | def forward(self, x, x_size): 792 | B, HW, C = x.shape 793 | x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C 794 | return x 795 | 796 | def flops(self): 797 | flops = 0 798 | return flops 799 | 800 | 801 | class Upsample(nn.Sequential): 802 | """Upsample module. 803 | 804 | Args: 805 | scale (int): Scale factor. 
Supported scales: 2^n and 3. 806 | num_feat (int): Channel number of intermediate features. 807 | """ 808 | 809 | def __init__(self, scale, num_feat): 810 | m = [] 811 | if (scale & (scale - 1)) == 0: # scale = 2^n 812 | for _ in range(int(math.log(scale, 2))): 813 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 814 | m.append(nn.PixelShuffle(2)) 815 | elif scale == 3: 816 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 817 | m.append(nn.PixelShuffle(3)) 818 | else: 819 | raise ValueError( 820 | f"scale {scale} is not supported. " "Supported scales: 2^n and 3." 821 | ) 822 | super(Upsample, self).__init__(*m) 823 | 824 | 825 | class Upsample_hf(nn.Sequential): 826 | """Upsample module. 827 | 828 | Args: 829 | scale (int): Scale factor. Supported scales: 2^n and 3. 830 | num_feat (int): Channel number of intermediate features. 831 | """ 832 | 833 | def __init__(self, scale, num_feat): 834 | m = [] 835 | if (scale & (scale - 1)) == 0: # scale = 2^n 836 | for _ in range(int(math.log(scale, 2))): 837 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 838 | m.append(nn.PixelShuffle(2)) 839 | elif scale == 3: 840 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 841 | m.append(nn.PixelShuffle(3)) 842 | else: 843 | raise ValueError( 844 | f"scale {scale} is not supported. " "Supported scales: 2^n and 3." 845 | ) 846 | super(Upsample_hf, self).__init__(*m) 847 | 848 | 849 | class UpsampleOneStep(nn.Sequential): 850 | """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) 851 | Used in lightweight SR to save parameters. 852 | 853 | Args: 854 | scale (int): Scale factor. Supported scales: 2^n and 3. 855 | num_feat (int): Channel number of intermediate features. 856 | 857 | """ 858 | 859 | def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): 860 | self.num_feat = num_feat 861 | self.input_resolution = input_resolution 862 | m = [] 863 | m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) 864 | m.append(nn.PixelShuffle(scale)) 865 | super(UpsampleOneStep, self).__init__(*m) 866 | 867 | def flops(self): 868 | H, W = self.input_resolution 869 | flops = H * W * self.num_feat * 3 * 9 870 | return flops 871 | 872 | 873 | class Swin2SR(nn.Module): 874 | r"""Swin2SR 875 | A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. 876 | 877 | Args: 878 | img_size (int | tuple(int)): Input image size. Default 64 879 | patch_size (int | tuple(int)): Patch size. Default: 1 880 | in_chans (int): Number of input image channels. Default: 3 881 | embed_dim (int): Patch embedding dimension. Default: 96 882 | depths (tuple(int)): Depth of each Swin Transformer layer. 883 | num_heads (tuple(int)): Number of attention heads in different layers. 884 | window_size (int): Window size. Default: 7 885 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 886 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 887 | drop_rate (float): Dropout rate. Default: 0 888 | attn_drop_rate (float): Attention dropout rate. Default: 0 889 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 890 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 891 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 892 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 893 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 894 | upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction 895 | upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None 896 | resi_connection: The convolutional block before residual connection. '1conv'/'3conv' 897 | """ 898 | 899 | def __init__( 900 | self, 901 | img_size=64, 902 | patch_size=1, 903 | in_channels=3, 904 | out_channels=3, 905 | embed_dim=96, 906 | depths=[6, 6, 6, 6], 907 | num_heads=[6, 6, 6, 6], 908 | window_size=7, 909 | mlp_ratio=4.0, 910 | qkv_bias=True, 911 | drop_rate=0.0, 912 | attn_drop_rate=0.0, 913 | drop_path_rate=0.1, 914 | norm_layer=nn.LayerNorm, 915 | ape=False, 916 | patch_norm=True, 917 | use_checkpoint=False, 918 | upscale=2, 919 | upsampler="", 920 | resi_connection="1conv", 921 | **kwargs, 922 | ): 923 | super(Swin2SR, self).__init__() 924 | num_in_ch = in_channels 925 | num_out_ch = out_channels 926 | num_feat = 64 927 | self.upscale = upscale 928 | self.upsampler = upsampler 929 | self.window_size = window_size 930 | 931 | ##################################################################################################### 932 | ################################### 1, shallow feature extraction ################################### 933 | self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) 934 | 935 | ##################################################################################################### 936 | ################################### 2, deep feature extraction ###################################### 937 | self.num_layers = len(depths) 938 | self.embed_dim = embed_dim 939 | self.ape = ape 940 | self.patch_norm = patch_norm 941 | self.num_features = embed_dim 942 | self.mlp_ratio = mlp_ratio 943 | 944 | # split image into non-overlapping patches 945 | self.patch_embed = PatchEmbed( 946 | img_size=img_size, 947 | patch_size=patch_size, 948 | in_chans=embed_dim, 949 | embed_dim=embed_dim, 950 | norm_layer=norm_layer if self.patch_norm else None, 951 | ) 952 | num_patches = self.patch_embed.num_patches 953 | patches_resolution = self.patch_embed.patches_resolution 954 | self.patches_resolution = patches_resolution 955 | 956 | # merge non-overlapping patches into image 957 | self.patch_unembed = PatchUnEmbed( 958 | img_size=img_size, 959 | patch_size=patch_size, 960 | in_chans=embed_dim, 961 | embed_dim=embed_dim, 962 | norm_layer=norm_layer if self.patch_norm else None, 963 | ) 964 | 965 | # absolute position embedding 966 | if self.ape: 967 | self.absolute_pos_embed = nn.Parameter( 968 | torch.zeros(1, num_patches, embed_dim) 969 | ) 970 | trunc_normal_(self.absolute_pos_embed, std=0.02) 971 | 972 | self.pos_drop = nn.Dropout(p=drop_rate) 973 | 974 | # stochastic depth 975 | dpr = [ 976 | x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) 977 | ] # stochastic depth decay rule 978 | 979 | # build Residual Swin Transformer blocks (RSTB) 980 | self.layers = nn.ModuleList() 981 | for i_layer in range(self.num_layers): 982 | layer = RSTB( 983 | dim=embed_dim, 984 | input_resolution=(patches_resolution[0], patches_resolution[1]), 985 | depth=depths[i_layer], 986 | num_heads=num_heads[i_layer], 987 | window_size=window_size, 988 | mlp_ratio=self.mlp_ratio, 989 | qkv_bias=qkv_bias, 990 | drop=drop_rate, 991 | attn_drop=attn_drop_rate, 992 | drop_path=dpr[ 993 | sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) 994 | ], # no impact on SR results 995 | norm_layer=norm_layer, 996 | downsample=None, 997 | 
use_checkpoint=use_checkpoint, 998 | img_size=img_size, 999 | patch_size=patch_size, 1000 | resi_connection=resi_connection, 1001 | ) 1002 | self.layers.append(layer) 1003 | 1004 | if self.upsampler == "pixelshuffle_hf": 1005 | self.layers_hf = nn.ModuleList() 1006 | for i_layer in range(self.num_layers): 1007 | layer = RSTB( 1008 | dim=embed_dim, 1009 | input_resolution=(patches_resolution[0], patches_resolution[1]), 1010 | depth=depths[i_layer], 1011 | num_heads=num_heads[i_layer], 1012 | window_size=window_size, 1013 | mlp_ratio=self.mlp_ratio, 1014 | qkv_bias=qkv_bias, 1015 | drop=drop_rate, 1016 | attn_drop=attn_drop_rate, 1017 | drop_path=dpr[ 1018 | sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) 1019 | ], # no impact on SR results 1020 | norm_layer=norm_layer, 1021 | downsample=None, 1022 | use_checkpoint=use_checkpoint, 1023 | img_size=img_size, 1024 | patch_size=patch_size, 1025 | resi_connection=resi_connection, 1026 | ) 1027 | self.layers_hf.append(layer) 1028 | 1029 | self.norm = norm_layer(self.num_features) 1030 | 1031 | # build the last conv layer in deep feature extraction 1032 | if resi_connection == "1conv": 1033 | self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 1034 | elif resi_connection == "3conv": 1035 | # to save parameters and memory 1036 | self.conv_after_body = nn.Sequential( 1037 | nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), 1038 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1039 | nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), 1040 | nn.LeakyReLU(negative_slope=0.2, inplace=True), 1041 | nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1), 1042 | ) 1043 | 1044 | ##################################################################################################### 1045 | ################################ 3, high quality image reconstruction ################################ 1046 | if self.upsampler == "pixelshuffle": 1047 | # for classical SR 1048 | self.conv_before_upsample = nn.Sequential( 1049 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1050 | ) 1051 | self.upsample = Upsample(upscale, num_feat) 1052 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1053 | elif self.upsampler == "pixelshuffle_aux": 1054 | self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 1055 | self.conv_before_upsample = nn.Sequential( 1056 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1057 | ) 1058 | self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1059 | self.conv_after_aux = nn.Sequential( 1060 | nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1061 | ) 1062 | self.upsample = Upsample(upscale, num_feat) 1063 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1064 | 1065 | elif self.upsampler == "pixelshuffle_hf": 1066 | self.conv_before_upsample = nn.Sequential( 1067 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1068 | ) 1069 | self.upsample = Upsample(upscale, num_feat) 1070 | self.upsample_hf = Upsample_hf(upscale, num_feat) 1071 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1072 | self.conv_first_hf = nn.Sequential( 1073 | nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True) 1074 | ) 1075 | self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) 1076 | self.conv_before_upsample_hf = nn.Sequential( 1077 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1078 | ) 1079 | self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1080 | 1081 | elif self.upsampler == "pixelshuffledirect": 
1082 | # for lightweight SR (to save parameters) 1083 | self.upsample = UpsampleOneStep( 1084 | upscale, 1085 | embed_dim, 1086 | num_out_ch, 1087 | (patches_resolution[0], patches_resolution[1]), 1088 | ) 1089 | elif self.upsampler == "nearest+conv": 1090 | # for real-world SR (less artifacts) 1091 | assert self.upscale == 4, "only support x4 now." 1092 | self.conv_before_upsample = nn.Sequential( 1093 | nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True) 1094 | ) 1095 | self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1096 | self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1097 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 1098 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 1099 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 1100 | else: 1101 | # for image denoising and JPEG compression artifact reduction 1102 | self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) 1103 | 1104 | self.apply(self._init_weights) 1105 | 1106 | def _init_weights(self, m): 1107 | if isinstance(m, nn.Linear): 1108 | trunc_normal_(m.weight, std=0.02) 1109 | if isinstance(m, nn.Linear) and m.bias is not None: 1110 | nn.init.constant_(m.bias, 0) 1111 | elif isinstance(m, nn.LayerNorm): 1112 | nn.init.constant_(m.bias, 0) 1113 | nn.init.constant_(m.weight, 1.0) 1114 | 1115 | @torch.jit.ignore 1116 | def no_weight_decay(self): 1117 | return {"absolute_pos_embed"} 1118 | 1119 | @torch.jit.ignore 1120 | def no_weight_decay_keywords(self): 1121 | return {"relative_position_bias_table"} 1122 | 1123 | def check_image_size(self, x): 1124 | _, _, h, w = x.size() 1125 | mod_pad_h = (self.window_size - h % self.window_size) % self.window_size 1126 | mod_pad_w = (self.window_size - w % self.window_size) % self.window_size 1127 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") 1128 | return x 1129 | 1130 | def forward_features(self, x): 1131 | x_size = (x.shape[2], x.shape[3]) 1132 | x = self.patch_embed(x) 1133 | if self.ape: 1134 | x = x + self.absolute_pos_embed 1135 | x = self.pos_drop(x) 1136 | 1137 | for layer in self.layers: 1138 | x = layer(x, x_size) 1139 | 1140 | x = self.norm(x) # B L C 1141 | x = self.patch_unembed(x, x_size) 1142 | 1143 | return x 1144 | 1145 | def forward_features_hf(self, x): 1146 | x_size = (x.shape[2], x.shape[3]) 1147 | x = self.patch_embed(x) 1148 | if self.ape: 1149 | x = x + self.absolute_pos_embed 1150 | x = self.pos_drop(x) 1151 | 1152 | for layer in self.layers_hf: 1153 | x = layer(x, x_size) 1154 | 1155 | x = self.norm(x) # B L C 1156 | x = self.patch_unembed(x, x_size) 1157 | 1158 | return x 1159 | 1160 | def forward(self, x): 1161 | H, W = x.shape[2:] 1162 | x = self.check_image_size(x) 1163 | 1164 | if self.upsampler == "pixelshuffle": 1165 | # for classical SR 1166 | x = self.conv_first(x) 1167 | x = self.conv_after_body(self.forward_features(x)) + x 1168 | x = self.conv_before_upsample(x) 1169 | x = self.conv_last(self.upsample(x)) 1170 | elif self.upsampler == "pixelshuffle_aux": 1171 | bicubic = F.interpolate( 1172 | x, 1173 | size=(H * self.upscale, W * self.upscale), 1174 | mode="bicubic", 1175 | align_corners=False, 1176 | ) 1177 | bicubic = self.conv_bicubic(bicubic) 1178 | x = self.conv_first(x) 1179 | x = self.conv_after_body(self.forward_features(x)) + x 1180 | x = self.conv_before_upsample(x) 1181 | aux = self.conv_aux(x) # b, 3, LR_H, LR_W 1182 | x = self.conv_after_aux(aux) 1183 | x = ( 1184 | self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale] 1185 | + bicubic[:, :, : H * 
self.upscale, : W * self.upscale] 1186 | ) 1187 | x = self.conv_last(x) 1188 | elif self.upsampler == "pixelshuffle_hf": 1189 | # for classical SR with HF 1190 | x = self.conv_first(x) 1191 | x = self.conv_after_body(self.forward_features(x)) + x 1192 | x_before = self.conv_before_upsample(x) 1193 | x_out = self.conv_last(self.upsample(x_before)) 1194 | 1195 | x_hf = self.conv_first_hf(x_before) 1196 | x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf 1197 | x_hf = self.conv_before_upsample_hf(x_hf) 1198 | x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) 1199 | x = x_out + x_hf 1200 | 1201 | elif self.upsampler == "pixelshuffledirect": 1202 | # for lightweight SR 1203 | x = self.conv_first(x) 1204 | x = self.conv_after_body(self.forward_features(x)) + x 1205 | x = self.upsample(x) 1206 | elif self.upsampler == "nearest+conv": 1207 | # for real-world SR 1208 | x = self.conv_first(x) 1209 | x = self.conv_after_body(self.forward_features(x)) + x 1210 | x = self.conv_before_upsample(x) 1211 | x = self.lrelu( 1212 | self.conv_up1( 1213 | torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 1214 | ) 1215 | ) 1216 | x = self.lrelu( 1217 | self.conv_up2( 1218 | torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") 1219 | ) 1220 | ) 1221 | x = self.conv_last(self.lrelu(self.conv_hr(x))) 1222 | else: 1223 | # for image denoising and JPEG compression artifact reduction 1224 | x_first = self.conv_first(x) 1225 | res = self.conv_after_body(self.forward_features(x_first)) + x_first 1226 | x = x + self.conv_last(res) 1227 | 1228 | if self.upsampler == "pixelshuffle_aux": 1229 | return x[:, :, : H * self.upscale, : W * self.upscale], aux 1230 | 1231 | elif self.upsampler == "pixelshuffle_hf": 1232 | return ( 1233 | x_out[:, :, : H * self.upscale, : W * self.upscale], 1234 | x[:, :, : H * self.upscale, : W * self.upscale], 1235 | x_hf[:, :, : H * self.upscale, : W * self.upscale], 1236 | ) 1237 | 1238 | else: 1239 | return x[:, :, : H * self.upscale, : W * self.upscale] 1240 | 1241 | def flops(self): 1242 | flops = 0 1243 | H, W = self.patches_resolution 1244 | flops += H * W * 3 * self.embed_dim * 9 1245 | flops += self.patch_embed.flops() 1246 | for i, layer in enumerate(self.layers): 1247 | flops += layer.flops() 1248 | flops += H * W * 3 * self.embed_dim * self.embed_dim 1249 | flops += self.upsample.flops() 1250 | return flops 1251 | -------------------------------------------------------------------------------- /sen2sr/models/tricks.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Tuple 2 | import torch 3 | 4 | 5 | def ideal_filter(shape: Tuple[int, int], cutoff: int) -> torch.Tensor: 6 | """ 7 | Creates an ideal low-pass filter. 8 | 9 | Args: 10 | shape: (rows, cols) of the filter. 11 | cutoff: Cutoff radius for the filter. 12 | 13 | Returns: 14 | torch.Tensor: Normalized ideal filter. 15 | """ 16 | rows, cols = shape 17 | crow, ccol = rows // 2, cols // 2 18 | filter = torch.zeros((rows, cols), dtype=torch.float32) 19 | for u in range(rows): 20 | for v in range(cols): 21 | distance = ((u - crow) ** 2 + (v - ccol) ** 2) ** 0.5 22 | if distance <= cutoff: 23 | filter[u, v] = 1 24 | return filter 25 | 26 | 27 | def butterworth_filter(shape: Tuple[int, int], cutoff: int, order: int) -> torch.Tensor: 28 | """ 29 | Creates a Butterworth low-pass filter. 30 | 31 | Args: 32 | shape: (rows, cols) of the filter. 33 | cutoff: Cutoff frequency. 34 | order: Order of the Butterworth filter. 
35 | 36 | Returns: 37 | torch.Tensor: Normalized Butterworth filter. 38 | """ 39 | rows, cols = shape 40 | crow, ccol = rows // 2, cols // 2 41 | filter = torch.zeros((rows, cols), dtype=torch.float32) 42 | for u in range(rows): 43 | for v in range(cols): 44 | distance = ((u - crow) ** 2 + (v - ccol) ** 2) ** 0.5 45 | filter[u, v] = 1 / (1 + (distance / cutoff) ** (2 * order)) 46 | return filter 47 | 48 | 49 | def gaussian_filter(shape: Tuple[int, int], cutoff: int) -> torch.Tensor: 50 | """ 51 | Creates a Gaussian low-pass filter. 52 | 53 | Args: 54 | shape: (rows, cols) of the filter. 55 | cutoff: Standard deviation for the Gaussian filter. 56 | 57 | Returns: 58 | torch.Tensor: Normalized Gaussian filter. 59 | """ 60 | rows, cols = shape 61 | crow, ccol = rows // 2, cols // 2 62 | filter = torch.zeros((rows, cols), dtype=torch.float32) 63 | for u in range(rows): 64 | for v in range(cols): 65 | distance = (u - crow) ** 2 + (v - ccol) ** 2 66 | filter[u, v] = torch.exp(-distance / (2 * (cutoff**2))) 67 | return filter 68 | 69 | 70 | def sigmoid_filter( 71 | shape: Tuple[int, int], cutoff: int, sharpness: float 72 | ) -> torch.Tensor: 73 | """ 74 | Creates a Sigmoid-based low-pass filter. 75 | 76 | Args: 77 | shape: (rows, cols) of the filter. 78 | cutoff: Cutoff frequency. 79 | sharpness: Sharpness of the transition in the filter. 80 | 81 | Returns: 82 | torch.Tensor: Normalized Sigmoid filter. 83 | """ 84 | rows, cols = shape 85 | crow, ccol = rows // 2, cols // 2 86 | filter = torch.zeros((rows, cols), dtype=torch.float32) 87 | for u in range(rows): 88 | for v in range(cols): 89 | distance = ((u - crow) ** 2 + (v - ccol) ** 2) ** 0.5 90 | filter[u, v] = 1 / (1 + torch.exp((distance - cutoff) / sharpness)) 91 | return filter 92 | 93 | 94 | class FourierHardConstraint(torch.nn.Module): 95 | """ 96 | Applies a low-pass Fourier constraint using different filter methods. 97 | 98 | Args: 99 | filter_method: Filter type ('ideal', 'butterworth', 'sigmoid', 'gaussian'). 100 | filter_hyperparameters: Hyperparameters for the chosen filter method. 101 | sr_image_size: Size of the super-resolution image (height, width). 102 | scale_factor: Scale factor for the super-resolution task. 103 | device: Device where the filters and tensors are located. Default is "cpu". 
104 | """ 105 | 106 | def __init__( 107 | self, 108 | filter_method: Literal["ideal", "butterworth", "sigmoid", "gaussian"], 109 | filter_hyperparameters: dict, 110 | sr_image_size: tuple, 111 | scale_factor: int, 112 | device: str = "cpu", 113 | low_pass_mask: torch.Tensor = None, 114 | ): 115 | super().__init__() 116 | h, w = sr_image_size 117 | center_h, center_w = h // 2, w // 2 118 | 119 | # Calculate the radius for the low-pass filter based on the scale factor 120 | if filter_method == "ideal": 121 | radius = min(center_h, center_w) // scale_factor 122 | low_pass_mask = ideal_filter((h, w), radius) 123 | elif filter_method == "butterworth": 124 | radius = min(center_h, center_w) // scale_factor 125 | low_pass_mask = butterworth_filter( 126 | (h, w), radius, order=filter_hyperparameters["order"] 127 | ) 128 | elif filter_method == "gaussian": 129 | radius = min(center_h, center_w) // scale_factor 130 | low_pass_mask = gaussian_filter((h, w), radius) 131 | elif filter_method == "sigmoid": 132 | radius = min(center_h, center_w) // scale_factor 133 | low_pass_mask = sigmoid_filter( 134 | (h, w), radius, sharpness=filter_hyperparameters["sharpness"] 135 | ) 136 | else: 137 | raise ValueError(f"Unsupported fourier_method: {filter_method}") 138 | self.low_pass_mask = low_pass_mask.to(device) 139 | self.scale_factor = scale_factor 140 | 141 | def forward(self, lr: torch.Tensor, sr: torch.Tensor) -> torch.Tensor: 142 | """ 143 | Applies the Fourier constraint on the super-resolution image. 144 | 145 | Args: 146 | lr: Low-resolution input tensor. 147 | sr: Super-resolution output tensor. 148 | 149 | Returns: 150 | torch.Tensor: Hybrid image after applying Fourier constraint. 151 | """ 152 | # Upsample the LR image to the HR size 153 | lr_up = torch.nn.functional.interpolate( 154 | lr, size=sr.shape[-2:], mode="bicubic", antialias=True 155 | ) 156 | 157 | # Apply the low-pass filter to the HR image 158 | sr_fft = torch.fft.fftn(sr, dim=(-2, -1)) 159 | lr_fft = torch.fft.fftn(lr_up, dim=(-2, -1)) 160 | 161 | # Shift the zero-frequency component to the center 162 | sr_fft_shifted = torch.fft.fftshift(sr_fft) 163 | lr_fft_shifted = torch.fft.fftshift(lr_fft) 164 | 165 | # High-pass filter is the complement of the low-pass filter 166 | high_pass_mask = 1 - self.low_pass_mask 167 | 168 | # Apply the high-pass filter to the SR image 169 | f1_low = lr_fft_shifted * self.low_pass_mask 170 | f1_high = sr_fft_shifted * high_pass_mask 171 | 172 | # Combine the low-pass and high-pass components 173 | sr_fft_filtered = f1_low + f1_high 174 | 175 | # Inverse FFT to get the filtered SR image 176 | combined_ishift = torch.fft.ifftshift(sr_fft_filtered) 177 | hybrid_image = torch.real(torch.fft.ifft2(combined_ishift)) 178 | 179 | return hybrid_image 180 | 181 | 182 | 183 | class HardConstraint(torch.nn.Module): 184 | """ 185 | Applies a low-pass constraint 186 | 187 | Args: 188 | device: Device where the filters and tensors are located. Default is "cpu". 189 | """ 190 | 191 | def __init__( 192 | self, 193 | low_pass_mask: torch.Tensor, 194 | bands: int | str = "all", 195 | device: str = "cpu", 196 | ): 197 | super().__init__() 198 | self.low_pass_mask = low_pass_mask.to(device) 199 | self.bands = bands 200 | 201 | def forward(self, lr: torch.Tensor, sr: torch.Tensor) -> torch.Tensor: 202 | """ 203 | Applies the Fourier constraint on the super-resolution image. 204 | 205 | Args: 206 | lr: Low-resolution input tensor. 207 | sr: Super-resolution output tensor. 
208 | 209 | Returns: 210 | torch.Tensor: Hybrid image after applying Fourier constraint. 211 | """ 212 | # Upsample the LR image to the HR size 213 | if self.bands == "all": 214 | lr_up = torch.nn.functional.interpolate( 215 | lr, size=sr.shape[-2:], mode="bicubic", antialias=True 216 | ) 217 | else: 218 | lr_up = torch.nn.functional.interpolate( 219 | lr[:, self.bands], size=sr.shape[-2:], mode="bicubic", antialias=True 220 | ) 221 | 222 | # Apply the low-pass filter to the HR image 223 | sr_fft = torch.fft.fftn(sr, dim=(-2, -1)) 224 | lr_fft = torch.fft.fftn(lr_up, dim=(-2, -1)) 225 | 226 | # Shift the zero-frequency component to the center 227 | sr_fft_shifted = torch.fft.fftshift(sr_fft) 228 | lr_fft_shifted = torch.fft.fftshift(lr_fft) 229 | 230 | # High-pass filter is the complement of the low-pass filter 231 | high_pass_mask = 1 - self.low_pass_mask 232 | 233 | # Apply the high-pass filter to the SR image 234 | f1_low = lr_fft_shifted * self.low_pass_mask 235 | f1_high = sr_fft_shifted * high_pass_mask 236 | 237 | # Combine the low-pass and high-pass components 238 | sr_fft_filtered = f1_low + f1_high 239 | 240 | # Inverse FFT to get the filtered SR image 241 | combined_ishift = torch.fft.ifftshift(sr_fft_filtered) 242 | hybrid_image = torch.real(torch.fft.ifft2(combined_ishift)) 243 | 244 | return hybrid_image 245 | 246 | -------------------------------------------------------------------------------- /sen2sr/nonreference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def srmodel( 6 | sr_model: nn.Module, 7 | hard_constraint: nn.Module, 8 | device: str = "cpu" 9 | ) -> nn.Module: 10 | """ 11 | Wraps a super-resolution (SR) model with a hard constraint module to enforce 12 | physical consistency. 13 | 14 | Parameters 15 | ---------- 16 | sr_model : nn.Module 17 | The base super-resolution model to be applied on the input tensor. 18 | 19 | hard_constraint : nn.Module 20 | A non-trainable constraint module that adjusts the output of the SR model 21 | based on prior knowledge or application-specific rules. 22 | 23 | device : str, optional 24 | Target device for model execution (e.g., "cpu" or "cuda"), by default "cpu". 25 | 26 | Returns 27 | ------- 28 | nn.Module 29 | A composite model that applies the SR model and enforces the hard constraint during the forward pass. 30 | """ 31 | 32 | # Move the SR model to the target device 33 | sr_model = sr_model.to(device) 34 | 35 | # Prepare the hard constraint module: evaluation mode, no gradients, moved to device 36 | hard_constraint = hard_constraint.eval() 37 | for param in hard_constraint.parameters(): 38 | param.requires_grad = False 39 | hard_constraint = hard_constraint.to(device) 40 | 41 | class SRModelWithConstraint(nn.Module): 42 | """ 43 | Composite model applying SR followed by a hard constraint module. 44 | """ 45 | def __init__(self, sr_model: nn.Module, hard_constraint: nn.Module): 46 | super().__init__() 47 | self.sr_model = sr_model 48 | self.hard_constraint = hard_constraint 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Forward pass: apply SR model, then enforce hard constraint. 53 | 54 | Parameters 55 | ---------- 56 | x : torch.Tensor 57 | Input tensor representing low-resolution imagery. 58 | 59 | Returns 60 | ------- 61 | torch.Tensor 62 | Super-resolved and constraint-corrected output. 
63 | """ 64 | 65 | # Apply SR model 66 | sr = self.sr_model(x) 67 | 68 | # Results must be always positive 69 | sr = torch.clamp(sr, min=0.0) 70 | 71 | return self.hard_constraint(x, sr) 72 | 73 | return SRModelWithConstraint(sr_model, hard_constraint) 74 | -------------------------------------------------------------------------------- /sen2sr/referencex2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def resample_sentinel2_bands(X: torch.Tensor) -> torch.Tensor: 7 | """ 8 | Resamples 20m Sentinel-2 bands to 10m resolution. 9 | 10 | Sentinel-2 provides bands at different spatial resolutions. 11 | This function first upsamples the 20m bands to 10m resolution using a 12 | two-step process: nearest-neighbor to 20m (to match spatial alignment), 13 | followed by bilinear interpolation to 10m. 14 | 15 | Args: 16 | X (torch.Tensor): Input tensor of shape (B, C, H, W), where C includes 17 | all Sentinel-2 bands in a specific order. 18 | 19 | Returns: 20 | torch.Tensor: Resampled tensor of shape (B, 10, H, W), combining 10m-native 21 | and upsampled 20m bands. 22 | """ 23 | # Indices of 20m and 10m bands 24 | indices_20m = [3, 4, 5, 7, 8, 9] # B5, B6, B7, B8A, B11, B12 25 | indices_10m = [0, 1, 2, 6] # B2, B3, B4, B8 26 | 27 | # Separate bands by resolution 28 | bands_20m = X[:, indices_20m] 29 | bands_10m = X[:, indices_10m] 30 | 31 | # Step 1: Downsample 20m bands to 10m pixel count (for alignment) 32 | bands_20m_down = F.interpolate(bands_20m, scale_factor=0.5, mode="nearest") 33 | 34 | # Step 2: Upsample to 10m using bilinear interpolation for smoothness 35 | bands_20m_up = F.interpolate( 36 | bands_20m_down, scale_factor=2, mode="bilinear", antialias=True 37 | ) 38 | 39 | # Concatenate upsampled 20m bands with native 10m bands 40 | return torch.cat([bands_20m_up, bands_10m], dim=1) 41 | 42 | 43 | def reconstruct_sentinel2_stack( 44 | b10m: torch.Tensor, b20m: torch.Tensor 45 | ) -> torch.Tensor: 46 | """ 47 | Reconstructs a 10-band Sentinel-2-like stack from separate 10m and 20m sources. 48 | 49 | Args: 50 | b10m (torch.Tensor): Tensor of shape (B, 4, H, W) containing B2, B3, B4, B8. 51 | b20m (torch.Tensor): Tensor of shape (B, 6, H, W) containing B5, B6, B7, B8A, B11, B12. 52 | 53 | Returns: 54 | torch.Tensor: Tensor of shape (B, 10, H, W), mimicking the original band order. 55 | """ 56 | return torch.stack( 57 | [ 58 | b10m[:, 0], # B2 (Blue) 59 | b10m[:, 1], # B3 (Green) 60 | b10m[:, 2], # B4 (Red) 61 | b20m[:, 0], # B5 (Red Edge 1) 62 | b20m[:, 1], # B6 (Red Edge 2) 63 | b20m[:, 2], # B7 (Red Edge 3) 64 | b10m[:, 3], # B8 (NIR) 65 | b20m[:, 3], # B8A (Narrow NIR) 66 | b20m[:, 4], # B11 (SWIR 1) 67 | b20m[:, 5], # B12 (SWIR 2) 68 | ], 69 | dim=1, 70 | ) 71 | 72 | 73 | def srmodel( 74 | sr_model: nn.Module, 75 | hard_constraint: nn.Module, 76 | device: str = "cpu", 77 | ) -> nn.Module: 78 | """ 79 | Wraps a super-resolution model with band-specific preprocessing and postprocessing. 80 | 81 | The function returns a composite model that: 82 | 1. Resamples the input Sentinel-2 bands to a uniform 10m resolution. 83 | 2. Applies a super-resolution model. 84 | 3. Applies a hard constraint or refinement module. 85 | 4. Reconstructs the final 10-band Sentinel-2 output stack. 86 | 87 | Args: 88 | sr_model (nn.Module): A super-resolution model that takes a 10-band input. 89 | hard_constraint (nn.Module): A postprocessing constraint module applied after SR. 
90 | device (str): The device to move models to ("cpu" or "cuda"). 91 | 92 | Returns: 93 | nn.Module: A callable model ready for inference. 94 | """ 95 | sr_model.to(device) 96 | hard_constraint.to(device) 97 | hard_constraint.eval() 98 | for param in hard_constraint.parameters(): 99 | param.requires_grad = False 100 | 101 | class SRModel(nn.Module): 102 | def __init__(self, sr_model: nn.Module, hard_constraint: nn.Module): 103 | super().__init__() 104 | self.sr_model = sr_model 105 | self.hard_constraint = hard_constraint 106 | 107 | def forward(self, x: torch.Tensor) -> torch.Tensor: 108 | """ 109 | Forward pass of the composite super-resolution model. 110 | 111 | Args: 112 | x (torch.Tensor): Input Sentinel-2 tensor of shape (B, 10, H, W). 113 | 114 | Returns: 115 | torch.Tensor: Super-resolved tensor of shape (B, 10, H, W). 116 | """ 117 | # Extract original RGB + NIR (10m bands) for reconstruction 118 | rgbn = x[:, [0, 1, 2, 6]].clone() 119 | 120 | # Resample full input to 10m uniform resolution 121 | x_resampled = resample_sentinel2_bands(x) 122 | 123 | # Apply SR model 124 | sr_out = self.sr_model(x_resampled) 125 | 126 | # Results must be always positive 127 | sr_out = torch.clamp(sr_out, min=0.0) 128 | 129 | # Apply hard constraint/refinement 130 | sr_out = self.hard_constraint(x_resampled, sr_out) 131 | 132 | # Reconstruct full 10-band stack 133 | return reconstruct_sentinel2_stack(rgbn, sr_out) 134 | 135 | return SRModel(sr_model, hard_constraint) 136 | -------------------------------------------------------------------------------- /sen2sr/referencex4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def srmodel( 6 | sr_model: nn.Module, 7 | f2_model: nn.Module, 8 | reference_model_x4: nn.Module, 9 | reference_model_hard_constraint_x4: nn.Module, 10 | device: str = "cpu", 11 | ) -> nn.Module: 12 | """ 13 | Wraps a multi-stage super-resolution (SR) model for Sentinel-2 imagery. 14 | 15 | This wrapper performs a staged upsampling pipeline: 16 | 1. Uses a 2x SR model (`f2_model`) to bring 20m bands to 10m. 17 | 2. Applies a 4x SR model (`sr_model`) to upsample 10m RGBN bands to 2.5m. 18 | 3. Uses another 4x fusion model (`reference_model_x4`) to enhance SWIR bands. 19 | 4. A hard constraint model (`reference_model_hard_constraint_x4`) ensures spectral consistency. 20 | 21 | Args: 22 | sr_model (nn.Module): 4x SR model for RGBN bands. 23 | f2_model (nn.Module): 2x SR model to upsample 20m bands to 10m. 24 | reference_model_x4 (nn.Module): 4x SR model for SWIR fusion. 25 | reference_model_hard_constraint_x4 (nn.Module): Spectral constraint module for SWIR bands. 26 | device (str): Device to place models on ('cpu' or 'cuda'). 27 | 28 | Returns: 29 | nn.Module: A callable module performing the full multi-band super-resolution. 30 | """ 31 | 32 | # Move fusion models to the correct device 33 | reference_model_x4 = reference_model_x4.to(device) 34 | reference_model_hard_constraint_x4 = reference_model_hard_constraint_x4.to(device) 35 | reference_model_hard_constraint_x4.eval() 36 | for param in reference_model_hard_constraint_x4.parameters(): 37 | param.requires_grad = False 38 | 39 | class SRModel(nn.Module): 40 | """ 41 | Full Sentinel-2 super-resolution pipeline. 
42 | """ 43 | def __init__( 44 | self, 45 | sr_model: nn.Module, 46 | f2_model: nn.Module, 47 | f4_model: nn.Module, 48 | f4_hard_constraint: nn.Module, 49 | ): 50 | super().__init__() 51 | self.sr_model = sr_model # SR model for RGBN bands 52 | self.f2_model = f2_model # 2x SR model for initial upsampling 53 | self.f4_model = f4_model # Fusion model for SWIR bands 54 | self.f4_hard_constraint = f4_hard_constraint # Spectral constraint for SWIR 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | """ 58 | Args: 59 | x (torch.Tensor): Input image tensor of shape (B, 10, H, W), 60 | where 10 Sentinel-2 bands are ordered as: 61 | [B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12] 62 | 63 | Returns: 64 | torch.Tensor: Super-resolved image of shape (B, 10, 4*H, 4*W) 65 | """ 66 | # Define band indices 67 | bands_20m = [3, 4, 5, 7, 8, 9] # RSWIR: B5, B6, B7, B8A, B11, B12 68 | bands_10m = [2, 1, 0, 6] # RGBN: B4, B3, B2, B8 69 | 70 | # Step 1: Upsample RSWIR bands from 20m to 10m using reference SR model 71 | allbands10m = self.f2_model(x) 72 | 73 | # Extract and upsample RSWIR (10m → 2.5m) using bilinear interpolation 74 | rsiwr_10m = allbands10m[:, bands_20m] 75 | rsiwr_2dot5m_bilinear = nn.functional.interpolate( 76 | rsiwr_10m, scale_factor=4, mode="bilinear", antialias=True 77 | ) 78 | 79 | # Step 2: Super-resolve RGBN bands (10m → 2.5m) with learned model 80 | rgbn_input = x[:, bands_10m] 81 | rgbn_2dot5m = self.sr_model(rgbn_input) 82 | 83 | # Reorder from RGBN → BGRN (e.g., for consistency with downstream expectations) 84 | rgbn_2dot5m = rgbn_2dot5m[:, [2, 1, 0, 3]] 85 | 86 | # Step 3: Apply fusion model to enhance RSWIR bands (10m → 2.5m) 87 | fusion_input = torch.cat([rsiwr_2dot5m_bilinear, rgbn_2dot5m], dim=1) 88 | rswirs_2dot5 = self.f4_model(fusion_input) 89 | 90 | # Results must be always positive 91 | rswirs_2dot5 = torch.clamp(rswirs_2dot5, min=0.0) 92 | 93 | # Step 4: Apply hard constraint model to ensure spectral consistency 94 | rswirs_2dot5 = self.f4_hard_constraint(rsiwr_10m, rswirs_2dot5) 95 | 96 | # Final step: Reconstruct full band stack in Sentinel-2 order 97 | return torch.stack([ 98 | rgbn_2dot5m[:, 0], # B2 (Blue) 99 | rgbn_2dot5m[:, 1], # B3 (Green) 100 | rgbn_2dot5m[:, 2], # B4 (Red) 101 | rswirs_2dot5[:, 0], # B5 (Red Edge 1) 102 | rswirs_2dot5[:, 1], # B6 (Red Edge 2) 103 | rswirs_2dot5[:, 2], # B7 (Red Edge 3) 104 | rgbn_2dot5m[:, 3], # B8 (NIR) 105 | rswirs_2dot5[:, 3], # B8A (Narrow NIR) 106 | rswirs_2dot5[:, 4], # B11 (SWIR 1) 107 | rswirs_2dot5[:, 5], # B12 (SWIR 2) 108 | ], dim=1) 109 | 110 | return SRModel(sr_model, f2_model, reference_model_x4, reference_model_hard_constraint_x4) 111 | -------------------------------------------------------------------------------- /sen2sr/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | from tqdm import tqdm 4 | 5 | 6 | def define_iteration(dimension: tuple, chunk_size: int, overlap: int = 0): 7 | """ 8 | Define the iteration strategy to walk through the image with an overlap. 9 | 10 | Args: 11 | dimension (tuple): Dimension of the S2 image. 12 | chunk_size (int): Size of the chunks. 13 | overlap (int): Size of the overlap between chunks. 14 | 15 | Returns: 16 | list: List of chunk coordinates. 
17 | """ 18 | dimy, dimx = dimension 19 | 20 | if chunk_size > max(dimx, dimy): 21 | return [(0, 0)] 22 | 23 | # Adjust step to create overlap 24 | y_step = chunk_size - overlap 25 | x_step = chunk_size - overlap 26 | 27 | # Generate initial chunk positions 28 | iterchunks = list(itertools.product(range(0, dimy, y_step), range(0, dimx, x_step))) 29 | 30 | # Fix chunks at the edges to stay within bounds 31 | iterchunks_fixed = fix_lastchunk( 32 | iterchunks=iterchunks, s2dim=dimension, chunk_size=chunk_size 33 | ) 34 | 35 | return iterchunks_fixed 36 | 37 | 38 | def fix_lastchunk(iterchunks, s2dim, chunk_size): 39 | """ 40 | Fix the last chunk of the overlay to ensure it aligns with image boundaries. 41 | 42 | Args: 43 | iterchunks (list): List of chunks created by itertools.product. 44 | s2dim (tuple): Dimension of the S2 images. 45 | chunk_size (int): Size of the chunks. 46 | 47 | Returns: 48 | list: List of adjusted chunk coordinates. 49 | """ 50 | itercontainer = [] 51 | 52 | for index_i, index_j in iterchunks: 53 | # Adjust if the chunk extends beyond bounds 54 | if index_i + chunk_size > s2dim[0]: 55 | index_i = max(s2dim[0] - chunk_size, 0) 56 | if index_j + chunk_size > s2dim[1]: 57 | index_j = max(s2dim[1] - chunk_size, 0) 58 | 59 | itercontainer.append((index_i, index_j)) 60 | 61 | return itercontainer 62 | 63 | 64 | def predict_large( 65 | X: torch.Tensor, 66 | model: torch.nn.Module, 67 | overlap: int = 32, 68 | ) -> torch.Tensor: 69 | 70 | # Run always in patches of 128x128 with 32 of overlap 71 | nruns = define_iteration( 72 | dimension=(X.shape[1], X.shape[2]), 73 | chunk_size=128, 74 | overlap=overlap, 75 | ) 76 | 77 | # Define the output metadata 78 | for index, point in enumerate(tqdm(nruns)): 79 | 80 | # Read a block of the image 81 | Xchunk = X[:, point[1] : (point[1] + 128), point[0] : (point[0] + 128)] 82 | 83 | # Predict the SR 84 | result = model(Xchunk[None]).squeeze(0) 85 | 86 | # If index is 0, create the output image 87 | if index == 0: 88 | res_n = result.shape[1] // 128 89 | output = torch.zeros( 90 | (result.shape[0], X.shape[1] * res_n, X.shape[1] * res_n), 91 | dtype=result.dtype, 92 | device="cpu", 93 | ) 94 | 95 | 96 | # Define the offset in the output space 97 | # If the point is at the border, the offset is 0 98 | # otherwise consider the overlap 99 | offset_x = point[0] * res_n + overlap * res_n // 2 100 | offset_y = point[1] * res_n + overlap * res_n // 2 101 | if point[0] == 0: 102 | offset_x = 0 103 | if point[1] == 0: 104 | offset_y = 0 105 | 106 | 107 | 108 | # Our output is always 128*res_n x 128*res_n, 109 | # Crop this batch output in order to fit in the 110 | # output image 111 | # 112 | # There is three conditions: 113 | # - The patch is at the initial borders 114 | # - The patch is at the final borders 115 | # - The patch is in the middle of the image 116 | skip = overlap * res_n // 2 117 | 118 | # Work in the X axis 119 | if offset_x == 0: # Initial border 120 | length_x = 128 * res_n - skip 121 | result = result[:, :, :length_x] 122 | elif (offset_x + 128) == X.shape[1]: 123 | length_x = 128 * res_n 124 | result = result[:, :, :length_x] 125 | else: 126 | skip = overlap * res_n // 2 127 | length_x = 128 * res_n - skip 128 | result = result[:, :, skip:(128 * res_n)] 129 | 130 | # Work in the Y axis 131 | if offset_y == 0: 132 | length_y = 128 * res_n - skip 133 | result = result[:, :length_y, :] 134 | elif (offset_y + 128) == X.shape[2]: 135 | length_y = 128 * res_n 136 | result = result[:, :length_y, :] 137 | else: 138 | length_y = 128 * 
res_n - skip 139 | result = result[:, skip:(128 * res_n), :] 140 | 141 | # Write the result in the output image 142 | output[:, offset_y:(offset_y + length_y), offset_x:(offset_x + length_x)] = result.detach().cpu() 143 | 144 | return output 145 | 146 | -------------------------------------------------------------------------------- /sen2sr/xai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ESAOpenSR/SEN2SR/a8e108c9e3d5974b7167a8c79417cb610bc73c5c/sen2sr/xai/__init__.py -------------------------------------------------------------------------------- /sen2sr/xai/lam.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from sen2sr.xai.utils import gini, vis_saliency_kde 8 | 9 | 10 | def attribution_objective(attr_func, h: int, w: int, window: int = 16): 11 | """ 12 | Creates an objective function to calculate attribution within a specified window 13 | at given coordinates using an attribution function. 14 | 15 | Args: 16 | attr_func (Callable): A function that calculates attributions for an image. 17 | h (int): The top coordinate of the window within the image. 18 | w (int): The left coordinate of the window within the image. 19 | window (int, optional): The size of the square window. Defaults to 16. 20 | 21 | Returns: 22 | Callable: A function that takes an image as input and computes the attribution 23 | at the specified window location. 24 | """ 25 | 26 | def calculate_objective(image): 27 | """ 28 | Computes the attribution for a specified window within the given image. 29 | 30 | Args: 31 | image (torch.Tensor): A tensor representing the input image. 32 | 33 | Returns: 34 | torch.Tensor: The calculated attribution value within the specified window. 35 | """ 36 | return attr_func(image, h, w, window=window) 37 | 38 | return calculate_objective 39 | 40 | 41 | def attr_grad( 42 | tensor: torch.Tensor, 43 | h: int, 44 | w: int, 45 | window: int = 8, 46 | reduce: str = "sum", 47 | scale: float = 1.0, 48 | ) -> torch.Tensor: 49 | """ 50 | Computes the gradient magnitude within a specified window of a 4D tensor and reduces the result. 51 | 52 | Args: 53 | tensor (torch.Tensor): A 4D tensor of shape (batch_size, channels, height, width). 54 | h (int): Starting height position of the window within the tensor. 55 | w (int): Starting width position of the window within the tensor. 56 | window (int, optional): The size of the square window to extract. Defaults to 8. 57 | reduce (str, optional): The reduction operation to apply to the window ('sum' or 'mean'). Defaults to 'sum'. 58 | scale (float, optional): Scaling factor to apply to the gradient magnitude. Defaults to 1.0. 59 | 60 | Returns: 61 | torch.Tensor: The reduced gradient magnitude for the specified window. 
62 | """ 63 | 64 | # Get tensor dimensions 65 | height = tensor.size(2) 66 | width = tensor.size(3) 67 | 68 | # Compute horizontal gradients by taking the difference between adjacent rows 69 | h_grad = torch.pow(tensor[:, :, : height - 1, :] - tensor[:, :, 1:, :], 2) 70 | 71 | # Compute vertical gradients by taking the difference between adjacent columns 72 | w_grad = torch.pow(tensor[:, :, :, : width - 1] - tensor[:, :, :, 1:], 2) 73 | 74 | # Calculate gradient magnitude by summing squares of gradients and taking the square root 75 | grad_magnitude = torch.sqrt(h_grad[:, :, :, :-1] + w_grad[:, :, :-1, :]) 76 | 77 | # Crop the gradient magnitude tensor to the specified window 78 | windowed_grad = grad_magnitude[:, :, h : h + window, w : w + window] 79 | 80 | # Apply reduction (sum or mean) to the cropped window 81 | if reduce == "sum": 82 | return torch.sum(windowed_grad) 83 | elif reduce == "mean": 84 | return torch.mean(windowed_grad) 85 | else: 86 | raise ValueError(f"Invalid reduction type: {reduce}. Use 'sum' or 'mean'.") 87 | 88 | 89 | def down_up(X: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor: 90 | """Downsample and upsample an image using bilinear interpolation. 91 | 92 | Args: 93 | X (torch.Tensor): The input tensor (Bands x Height x Width). 94 | scale_factor (float, optional): The scaling factor. Defaults to 0.5. 95 | 96 | Returns: 97 | torch.Tensor: The downsampled and upsampled image. 98 | """ 99 | shape_init = X.shape 100 | return torch.nn.functional.interpolate( 101 | input=torch.nn.functional.interpolate( 102 | input=X, scale_factor=1 / scale_factor, mode="bilinear", antialias=True 103 | ), 104 | size=shape_init[2:], 105 | mode="bilinear", 106 | antialias=True, 107 | ) 108 | 109 | 110 | def create_blur_cube(X: torch.Tensor, scales: list) -> torch.Tensor: 111 | """Create a cube of blurred images at different scales. 112 | 113 | Args: 114 | X (torch.Tensor): The input tensor (Bands x Height x Width). 115 | scales (list): The scales to evaluate. 116 | 117 | Returns: 118 | torch.Tensor: The cube of blurred images. 119 | """ 120 | scales_int = [float(scale[:-1]) for scale in scales] 121 | return torch.stack([down_up(X[None], scale) for scale in scales_int]).squeeze() 122 | 123 | 124 | def create_lam_inputs( 125 | X: torch.Tensor, scales: list 126 | ) -> Tuple[torch.Tensor, torch.Tensor, list]: 127 | """Create the inputs for the Local Attribution Map (LAM). 128 | 129 | Args: 130 | X (torch.Tensor): The input tensor (Bands x Height x Width). 131 | scales (list): The scales to evaluate. 132 | 133 | Returns: 134 | Tuple[torch.Tensor, torch.Tensor, list]: The cube of blurred 135 | images, the difference between the input and the cube, 136 | and the scales. 137 | """ 138 | cube = create_blur_cube(X, scales) 139 | diff = torch.abs(X[None] - cube) 140 | return cube[1:], diff[1:], scales[1:] 141 | 142 | 143 | def lam( 144 | X: torch.Tensor, 145 | model: torch.nn.Module, 146 | h: int = 240, 147 | w: int = 240, 148 | window: int = 32, 149 | scales: list = ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"], 150 | ) -> Tuple[np.ndarray, float, float, np.ndarray]: 151 | """Estimate the Local Attribution Map (LAM) 152 | 153 | Args: 154 | X (torch.Tensor): The input tensor (Bands x Height x Width). 155 | model (torch.nn.Module): The model to evaluate. 156 | model_scale (float, optional): The scale of the model. Defaults to 4. 157 | h (int, optional): The height of the window to evaluate. Defaults to 240. 158 | w (int, optional): The width of the window to evaluate. Defaults to 240. 
159 | window (int, optional): The window size. Defaults to 32. 160 | scales (list, optional): The scales to evaluate. Defaults to 161 | ["1x", "2x", "3x", "4x", "5x", "6x", "7x", "8x"]. 162 | 163 | Returns: 164 | Tuple[np.ndarray, float, float, np.ndarray]: _description_ 165 | """ 166 | 167 | # Create the LAM inputs 168 | cube, diff, scales = create_lam_inputs(X, scales) 169 | 170 | # Create the attribution objective function 171 | attr_objective = attribution_objective(attr_grad, h, w, window=window) 172 | 173 | # Initialize the gradient accumulation list 174 | grad_accumulate_list = torch.zeros_like(cube).cpu().numpy() 175 | 176 | # Compute gradient for each interpolated image 177 | for i in tqdm(range(cube.shape[0]), desc="Computing gradients"): 178 | 179 | # Convert interpolated image to tensor and set requires_grad for backpropagation 180 | img_tensor = cube[i].float()[None] 181 | img_tensor.requires_grad_(True) 182 | 183 | # Forward pass through the model and compute attribution objective 184 | result = model(img_tensor) 185 | target = attr_objective(result) 186 | target.backward() # Compute gradients 187 | 188 | # determine the scale of the model 189 | if i == 0: 190 | scale_factor = result.shape[2] / img_tensor.shape[2] 191 | 192 | # Extract gradient, handling NaNs if present 193 | grad = img_tensor.grad.cpu().numpy() 194 | grad = np.nan_to_num(grad) # Replace NaNs with 0 195 | 196 | # Accumulate gradients adjusted by lambda derivatives 197 | grad_accumulate_list[i] = grad * diff[i].cpu().numpy() 198 | 199 | # Sum the accumulated gradients across all bands 200 | lam_results = torch.sum(torch.from_numpy(np.abs(grad_accumulate_list)), dim=0) 201 | grad_2d = np.abs(lam_results.sum(axis=0)) 202 | grad_max = grad_2d.max() 203 | grad_norm = grad_2d / grad_max 204 | 205 | # Estimate gini index 206 | gini_index = gini(grad_norm.flatten()) 207 | 208 | ## window to image size 209 | # ratio_img_to_window = (X.shape[1] * model_scale) // window 210 | 211 | # KDE estimation 212 | kde_map = np.log1p(vis_saliency_kde(grad_norm, scale=scale_factor, bandwidth=1.0)) 213 | complexity_metric = (1 - gini_index) * 100 # / ratio_img_to_window 214 | 215 | # Estimate blurriness sensitivity 216 | robustness_vector = np.abs(grad_accumulate_list).mean(axis=(1, 2, 3)) 217 | robustness_metric = np.trapz(robustness_vector) 218 | 219 | # Return the LAM results 220 | return kde_map, complexity_metric, robustness_metric, robustness_vector 221 | -------------------------------------------------------------------------------- /sen2sr/xai/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def torch_gaussian_kde( 8 | points: torch.Tensor, weights: torch.Tensor, grid: torch.Tensor, bandwidth: float 9 | ) -> torch.Tensor: 10 | """ 11 | Perform Kernel Density Estimation (KDE) using a Gaussian kernel with PyTorch. 12 | 13 | Args: 14 | points (torch.Tensor): A 2D tensor of shape (2, n_points) with [x, y] positions. 15 | weights (torch.Tensor): A 1D tensor of shape (n_points,) with weights for each point. 16 | grid (torch.Tensor): A 2D tensor of shape (2, n_grid_points) with [x, y] positions for the evaluation grid. 17 | bandwidth (float): The bandwidth (standard deviation) for the Gaussian kernel. 18 | 19 | Returns: 20 | torch.Tensor: A tensor representing KDE values evaluated at the grid positions. 
21 | """ 22 | # Compute pairwise squared distances between grid and data points 23 | distances = torch.cdist(grid.T, points.T, p=2) ** 2 24 | 25 | # Apply Gaussian kernel 26 | kernel_values = torch.exp(-distances / (2 * bandwidth**2)) 27 | 28 | # Weight and sum the kernel values to get the KDE 29 | kde_values = (kernel_values * weights).sum(dim=1) 30 | 31 | return kde_values 32 | 33 | 34 | def vis_saliency_kde( 35 | map: torch.Tensor, scale: int = 4, bandwidth: float = 1.0 36 | ) -> torch.Tensor: 37 | """ 38 | Visualize saliency map KDE using a Gaussian kernel. 39 | 40 | Args: 41 | map (torch.Tensor): A 2D tensor representing the saliency map. 42 | scale (int): Scaling factor for the output density map. 43 | bandwidth (float): Bandwidth for the KDE Gaussian kernel. 44 | 45 | Returns: 46 | torch.Tensor: A normalized 2D tensor representing the KDE of the saliency map. 47 | """ 48 | # Flatten the saliency map and prepare coordinates 49 | grad_flat = map.flatten() 50 | datapoint_y, datapoint_x = torch.meshgrid( 51 | torch.arange(map.shape[0], dtype=torch.float32), 52 | torch.arange(map.shape[1], dtype=torch.float32), 53 | ) 54 | pixels = torch.vstack([datapoint_x.flatten(), datapoint_y.flatten()]) 55 | 56 | # Generate grid for KDE evaluation 57 | Y, X = torch.meshgrid( 58 | torch.arange(0, map.shape[0], dtype=torch.float32), 59 | torch.arange(0, map.shape[1], dtype=torch.float32), 60 | indexing="ij", 61 | ) 62 | grid_positions = torch.vstack([X.flatten(), Y.flatten()]) 63 | 64 | # Perform KDE on the grid 65 | kde_values = torch_gaussian_kde( 66 | pixels, grad_flat, grid_positions, bandwidth=bandwidth 67 | ) 68 | 69 | # Reshape and normalize KDE output 70 | Z = kde_values.reshape(map.shape) 71 | Z = Z / Z.max() 72 | 73 | # Reshape to the SR scale 74 | Z = torch.nn.functional.interpolate( 75 | Z[None, None], scale_factor=scale, mode="bicubic", antialias=True 76 | ).squeeze() 77 | 78 | return Z 79 | 80 | 81 | def gini(array: Union[np.ndarray, list]) -> float: 82 | """ 83 | Calculate the Gini coefficient of a 1-dimensional array. The Gini coefficient is a measure of inequality 84 | where 0 represents perfect equality and 1 represents maximal inequality. 85 | 86 | Args: 87 | array (Union[np.ndarray, list]): A 1D array or list of numerical values for which the Gini coefficient is calculated. 88 | 89 | Returns: 90 | float: The Gini coefficient, a value between 0 and 1. 91 | 92 | Notes: 93 | - This implementation is based on the formula for the Gini coefficient described here: 94 | http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm 95 | - The input array is treated as a 1-dimensional array. 96 | - All values in the array must be non-negative. Negative values are shifted to zero if present. 97 | - Zero values are adjusted slightly to avoid division by zero. 
98 | 99 | """ 100 | # Ensure array is a flattened 1D numpy array 101 | array = np.asarray(array).flatten() 102 | 103 | # Shift values if there are any negative elements, as Gini requires non-negative values 104 | if np.amin(array) < 0: 105 | array -= np.amin(array) 106 | 107 | # Avoid division by zero by slightly adjusting zero values 108 | array += 1e-7 109 | 110 | # Sort array values in ascending order for the Gini calculation 111 | array = np.sort(array) 112 | 113 | # Create an index array (1-based) for each element in the sorted array 114 | index = np.arange(1, array.shape[0] + 1) 115 | 116 | # Calculate the number of elements in the array 117 | n = array.shape[0] 118 | 119 | # Compute the Gini coefficient using the sorted values and index-based formula 120 | gini_coefficient = (np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)) 121 | 122 | return gini_coefficient 123 | -------------------------------------------------------------------------------- /tests/test_foo.py: -------------------------------------------------------------------------------- 1 | def test_foo(): 2 | assert 1 3 | --------------------------------------------------------------------------------
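
The files above define all the building blocks of the package (the Swin2SR backbone, the Fourier hard constraint, the `srmodel` wrappers, and the tiled `predict_large` helper), but no end-to-end example. The sketch below, which is not part of the repository, shows one way these pieces could be wired together. The Swin2SR hyperparameters, the random input tensor, and the untrained weights are all invented for illustration; actual use of SEN2SR would load the released pretrained checkpoints rather than construct a fresh backbone like this.

```python
# Illustrative end-to-end inference sketch (assumptions: made-up Swin2SR
# configuration, untrained weights, random input data).
import torch

from sen2sr.models.opensr_baseline.swin import Swin2SR
from sen2sr.models.tricks import FourierHardConstraint
from sen2sr.nonreference import srmodel
from sen2sr.utils import predict_large

device = "cpu"

# Small x2 Swin2SR backbone over 10 Sentinel-2 bands (hypothetical config).
backbone = Swin2SR(
    img_size=128,
    patch_size=1,
    in_channels=10,
    out_channels=10,
    embed_dim=48,
    depths=[2],
    num_heads=[4],
    window_size=8,
    upscale=2,
    upsampler="pixelshuffledirect",
)

# Hard constraint defined on the 256x256 output of a 128x128 tile (x2 scale).
constraint = FourierHardConstraint(
    filter_method="ideal",
    filter_hyperparameters={},
    sr_image_size=(256, 256),
    scale_factor=2,
    device=device,
)

# Wrap backbone + constraint into a single inference module and set eval mode
# so stochastic depth in the backbone is disabled.
model = srmodel(backbone, constraint, device=device)
model.eval()

# Random stand-in for a (bands, height, width) Sentinel-2 tensor.
x = torch.rand(10, 256, 256)

# predict_large tiles the image into 128x128 patches with overlap and stitches
# the super-resolved patches back together; autograd is disabled for inference
# since the helper itself does not do so.
with torch.no_grad():
    sr = predict_large(x, model, overlap=32)

print(sr.shape)  # expected for this square input: torch.Size([10, 512, 512])
```

The same pattern applies to the reference pipelines: `sen2sr.referencex2.srmodel` and `sen2sr.referencex4.srmodel` only differ in which backbones and constraints they expect, and the resulting composite module can be passed to `predict_large` in exactly the same way.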