├── .cruft.json ├── .github └── workflows │ └── on-push.yml ├── .gitignore ├── .pre-commit-config-cruft.yaml ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── ci ├── environment-ci.yml └── environment-integration.yml ├── deepr ├── __init__.py ├── cli.py ├── data │ ├── __init__.py │ ├── configuration.py │ ├── files.py │ ├── generator.py │ ├── scaler.py │ └── static │ │ ├── __init__.py │ │ └── climatology.py ├── model │ ├── __init__.py │ ├── activations.py │ ├── attention.py │ ├── autoencoder_trainer.py │ ├── conditional_ddpm.py │ ├── configs.py │ ├── conv_baseline.py │ ├── conv_swin2sr.py │ ├── diffusion_trainer.py │ ├── loss.py │ ├── models.py │ ├── nn_trainer.py │ ├── resnet.py │ ├── unet.py │ ├── unet_blocks.py │ └── utils.py ├── utilities │ ├── __init__.py │ ├── logger.py │ └── yml.py ├── validation │ ├── __init__.py │ ├── generate_data.py │ ├── netcdf │ │ ├── __init__.py │ │ ├── metrics.py │ │ ├── validation.py │ │ └── visualize.py │ ├── nn_performance_metrics.py │ ├── sample_predictions.py │ └── validation_nn.py ├── visualizations │ ├── __init__.py │ ├── giffs.py │ ├── plot_maps.py │ ├── plot_rose.py │ └── plot_samples.py └── workflow.py ├── docs ├── Makefile ├── _static │ ├── .gitkeep │ ├── convswin2sr_scheme.png │ ├── dp_scheme.png │ ├── eps-U-Net diagram.svg │ ├── pos_embedding.png │ ├── project_motivation.png │ ├── spatial-domain-small.png │ └── standardization_types.png ├── _templates │ └── .gitkeep ├── conf.py ├── index.md ├── make.bat └── usage │ ├── data.md │ ├── installation.md │ ├── methodology.md │ └── references.md ├── environment.yml ├── environment_CUDA.yml ├── pyproject.toml ├── resources ├── configuration_diffusion.yml ├── configuration_nn_bicubic.yml ├── configuration_nn_evaluation.yml ├── configuration_nn_swin2sr.yml └── configuration_vqvae.yml ├── scripts ├── download │ ├── climate_data_store.py │ └── european_weather_cloud.py ├── modeling │ ├── generate_model_predictions.py │ ├── train_model.py │ └── validate_model_predictions.py └── processing │ └── data_spatial_selection.py ├── setup.cfg └── tests ├── test_00_version.py └── tests_data └── test_files.py /.cruft.json: -------------------------------------------------------------------------------- 1 | { 2 | "template": "https://github.com/ecmwf-projects/cookiecutter-conda-package", 3 | "commit": "5d59b459cd43e5f527eb429e3546b5a22e4e9240", 4 | "checkout": null, 5 | "context": { 6 | "cookiecutter": { 7 | "project_name": "DeepR", 8 | "project_slug": "deepr", 9 | "project_short_description": "DeepR: Deep Reanalysis", 10 | "copyright_holder": "European Union", 11 | "copyright_year": "2023", 12 | "mypy_strict": "False", 13 | "integration_tests": "False", 14 | "_template": "https://github.com/ecmwf-projects/cookiecutter-conda-package" 15 | } 16 | }, 17 | "directory": null 18 | } 19 | -------------------------------------------------------------------------------- /.github/workflows/on-push.yml: -------------------------------------------------------------------------------- 1 | name: on-push 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | tags: 8 | - '*' 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | defaults: 18 | run: 19 | shell: bash -l {0} 20 | 21 | jobs: 22 | pre-commit: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v3 26 | - uses: actions/setup-python@v4 27 | with: 28 | python-version: 3.x 29 | - uses: pre-commit/action@v3.0.0 30 | 31 | 
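  # The combine-environments job below installs conda-merge and merges the package
  # environment.yml with each ci/environment-<suffix>.yml (suffix: ci, integration)
  # into ci/combined-environment-<suffix>.yml, then uploads the merged files as the
  # 'combined-environments' artifact. The unit-tests, type-check, docs-build and
  # integration-tests jobs download that artifact and build their micromamba
  # environments from the combined files.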
combine-environments: 32 | runs-on: ubuntu-latest 33 | 34 | steps: 35 | - uses: actions/checkout@v3 36 | - name: Install conda-merge 37 | run: | 38 | $CONDA/bin/python -m pip install conda-merge 39 | - name: Combine environments 40 | run: | 41 | for SUFFIX in ci integration; do 42 | $CONDA/bin/conda-merge ci/environment-$SUFFIX.yml environment.yml > ci/combined-environment-$SUFFIX.yml || exit 43 | done 44 | - name: Archive combined environments 45 | uses: actions/upload-artifact@v3 46 | with: 47 | name: combined-environments 48 | path: ci/combined-environment-*.yml 49 | 50 | unit-tests: 51 | name: unit-tests 52 | needs: combine-environments 53 | runs-on: ubuntu-latest 54 | strategy: 55 | matrix: 56 | python-version: ['3.10', '3.11'] 57 | 58 | steps: 59 | - uses: actions/checkout@v3 60 | - name: Download combined environments 61 | uses: actions/download-artifact@v3 62 | with: 63 | name: combined-environments 64 | path: ci 65 | - name: Install Conda environment with Micromamba 66 | uses: mamba-org/setup-micromamba@v1 67 | with: 68 | environment-file: ci/combined-environment-ci.yml 69 | environment-name: DEVELOP 70 | cache-environment: true 71 | create-args: >- 72 | python=${{ matrix.python-version }} 73 | - name: Install package 74 | run: | 75 | python -m pip install --no-deps -e . 76 | - name: Run tests 77 | run: | 78 | make unit-tests COV_REPORT=xml 79 | 80 | type-check: 81 | needs: [combine-environments, unit-tests] 82 | runs-on: ubuntu-latest 83 | 84 | steps: 85 | - uses: actions/checkout@v3 86 | - name: Download combined environments 87 | uses: actions/download-artifact@v3 88 | with: 89 | name: combined-environments 90 | path: ci 91 | - name: Install Conda environment with Micromamba 92 | uses: mamba-org/setup-micromamba@v1 93 | with: 94 | environment-file: ci/combined-environment-ci.yml 95 | environment-name: DEVELOP 96 | cache-environment: true 97 | create-args: >- 98 | python=3.10 99 | - name: Install package 100 | run: | 101 | python -m pip install --no-deps -e . 102 | - name: Run code quality checks 103 | run: | 104 | make type-check 105 | 106 | docs-build: 107 | needs: [combine-environments, unit-tests] 108 | runs-on: ubuntu-latest 109 | 110 | steps: 111 | - uses: actions/checkout@v3 112 | - name: Download combined environments 113 | uses: actions/download-artifact@v3 114 | with: 115 | name: combined-environments 116 | path: ci 117 | - name: Install Conda environment with Micromamba 118 | uses: mamba-org/setup-micromamba@v1 119 | with: 120 | environment-file: ci/combined-environment-ci.yml 121 | environment-name: DEVELOP 122 | cache-environment: true 123 | create-args: >- 124 | python=3.10 125 | - name: Install package 126 | run: | 127 | python -m pip install --no-deps -e . 
128 | - name: Build documentation 129 | run: | 130 | make docs-build 131 | 132 | integration-tests: 133 | needs: [combine-environments, unit-tests] 134 | if: | 135 | success() && false 136 | runs-on: ubuntu-latest 137 | 138 | strategy: 139 | matrix: 140 | include: 141 | - python-version: '3.10' 142 | extra: -integration 143 | 144 | steps: 145 | - uses: actions/checkout@v3 146 | - name: Download combined environments 147 | uses: actions/download-artifact@v3 148 | with: 149 | name: combined-environments 150 | path: ci 151 | - name: Install Conda environment with Micromamba 152 | uses: mamba-org/setup-micromamba@v1 153 | with: 154 | environment-file: ci/combined-environment${{ matrix.extra }}.yml 155 | environment-name: DEVELOP${{ matrix.extra }} 156 | cache-environment: true 157 | create-args: >- 158 | python=${{ matrix.python-version }} 159 | - name: Install package 160 | run: | 161 | python -m pip install --no-deps -e . 162 | - name: Run tests 163 | run: | 164 | make unit-tests COV_REPORT=xml 165 | 166 | distribution: 167 | runs-on: ubuntu-latest 168 | needs: [unit-tests, type-check, docs-build, integration-tests] 169 | if: | 170 | always() && 171 | needs.unit-tests.result == 'success' && 172 | needs.type-check.result == 'success' && 173 | needs.docs-build.result == 'success' && 174 | (needs.integration-tests.result == 'success' || needs.integration-tests.result == 'skipped') 175 | 176 | steps: 177 | - uses: actions/checkout@v3 178 | - name: Install packages 179 | run: | 180 | $CONDA/bin/python -m pip install build twine 181 | - name: Build distributions 182 | run: | 183 | $CONDA/bin/python -m build 184 | - name: Check wheels 185 | run: | 186 | cd dist || exit 187 | $CONDA/bin/python -m pip install DeepR*.whl || exit 188 | $CONDA/bin/python -m twine check * || exit 189 | $CONDA/bin/python -c "import deepr" 190 | - name: Publish a Python distribution to PyPI 191 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 192 | uses: pypa/gh-action-pypi-publish@release/v1 193 | with: 194 | user: __token__ 195 | password: ${{ secrets.PYPI_API_TOKEN }} 196 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # setuptools-scm 2 | version.py 3 | 4 | # Sphinx automatic generation of API 5 | docs/_api/ 6 | 7 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vim,visualstudiocode,pycharm,emacs,linux,macos,windows 8 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,vim,visualstudiocode,pycharm,emacs,linux,macos,windows 9 | 10 | ### Emacs ### 11 | # -*- mode: gitignore; -*- 12 | *~ 13 | \#*\# 14 | /.emacs.desktop 15 | /.emacs.desktop.lock 16 | *.elc 17 | auto-save-list 18 | tramp 19 | .\#* 20 | 21 | # Org-mode 22 | .org-id-locations 23 | *_archive 24 | 25 | # flymake-mode 26 | *_flymake.* 27 | 28 | # eshell files 29 | /eshell/history 30 | /eshell/lastdir 31 | 32 | # elpa packages 33 | /elpa/ 34 | 35 | # reftex files 36 | *.rel 37 | 38 | # AUCTeX auto folder 39 | /auto/ 40 | 41 | # cask packages 42 | .cask/ 43 | dist/ 44 | 45 | # Flycheck 46 | flycheck_*.el 47 | 48 | # server auth directory 49 | /server/ 50 | 51 | # projectiles files 52 | .projectile 53 | 54 | # directory configuration 55 | .dir-locals.el 56 | 57 | # network security 58 | /network-security.data 59 | 60 | 61 | ### JupyterNotebooks ### 62 | # gitignore template for Jupyter Notebooks 63 | # website: http://jupyter.org/ 64 | 
65 | .ipynb_checkpoints 66 | */.ipynb_checkpoints/* 67 | 68 | # IPython 69 | profile_default/ 70 | ipython_config.py 71 | 72 | # Remove previous ipynb_checkpoints 73 | # git rm -r .ipynb_checkpoints/ 74 | 75 | ### Linux ### 76 | 77 | # temporary files which can be created if a process still has a handle open of a deleted file 78 | .fuse_hidden* 79 | 80 | # KDE directory preferences 81 | .directory 82 | 83 | # Linux trash folder which might appear on any partition or disk 84 | .Trash-* 85 | 86 | # .nfs files are created when an open file is removed but is still being accessed 87 | .nfs* 88 | 89 | ### macOS ### 90 | # General 91 | .DS_Store 92 | .AppleDouble 93 | .LSOverride 94 | 95 | # Icon must end with two \r 96 | Icon 97 | 98 | # Thumbnails 99 | ._* 100 | 101 | # Files that might appear in the root of a volume 102 | .DocumentRevisions-V100 103 | .fseventsd 104 | .Spotlight-V100 105 | .TemporaryItems 106 | .Trashes 107 | .VolumeIcon.icns 108 | .com.apple.timemachine.donotpresent 109 | 110 | # Directories potentially created on remote AFP share 111 | .AppleDB 112 | .AppleDesktop 113 | Network Trash Folder 114 | Temporary Items 115 | .apdisk 116 | 117 | ### macOS Patch ### 118 | # iCloud generated files 119 | *.icloud 120 | 121 | ### PyCharm ### 122 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 123 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 124 | 125 | # User-specific stuff 126 | .idea/**/workspace.xml 127 | .idea/**/tasks.xml 128 | .idea/**/usage.statistics.xml 129 | .idea/**/dictionaries 130 | .idea/**/shelf 131 | 132 | # AWS User-specific 133 | .idea/**/aws.xml 134 | 135 | # Generated files 136 | .idea/**/contentModel.xml 137 | 138 | # Sensitive or high-churn files 139 | .idea/**/dataSources/ 140 | .idea/**/dataSources.ids 141 | .idea/**/dataSources.local.xml 142 | .idea/**/sqlDataSources.xml 143 | .idea/**/dynamic.xml 144 | .idea/**/uiDesigner.xml 145 | .idea/**/dbnavigator.xml 146 | 147 | # Gradle 148 | .idea/**/gradle.xml 149 | .idea/**/libraries 150 | 151 | # Gradle and Maven with auto-import 152 | # When using Gradle or Maven with auto-import, you should exclude module files, 153 | # since they will be recreated, and may cause churn. Uncomment if using 154 | # auto-import. 
155 | # .idea/artifacts 156 | # .idea/compiler.xml 157 | # .idea/jarRepositories.xml 158 | # .idea/modules.xml 159 | # .idea/*.iml 160 | # .idea/modules 161 | # *.iml 162 | # *.ipr 163 | 164 | # CMake 165 | cmake-build-*/ 166 | 167 | # Mongo Explorer plugin 168 | .idea/**/mongoSettings.xml 169 | 170 | # File-based project format 171 | *.iws 172 | 173 | # IntelliJ 174 | out/ 175 | 176 | # mpeltonen/sbt-idea plugin 177 | .idea_modules/ 178 | 179 | # JIRA plugin 180 | atlassian-ide-plugin.xml 181 | 182 | # Cursive Clojure plugin 183 | .idea/replstate.xml 184 | 185 | # SonarLint plugin 186 | .idea/sonarlint/ 187 | 188 | # Crashlytics plugin (for Android Studio and IntelliJ) 189 | com_crashlytics_export_strings.xml 190 | crashlytics.properties 191 | crashlytics-build.properties 192 | fabric.properties 193 | 194 | # Editor-based Rest Client 195 | .idea/httpRequests 196 | 197 | # Android studio 3.1+ serialized cache file 198 | .idea/caches/build_file_checksums.ser 199 | 200 | ### PyCharm Patch ### 201 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 202 | 203 | # *.iml 204 | # modules.xml 205 | # .idea/misc.xml 206 | # *.ipr 207 | 208 | # Sonarlint plugin 209 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 210 | .idea/**/sonarlint/ 211 | 212 | # SonarQube Plugin 213 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 214 | .idea/**/sonarIssues.xml 215 | 216 | # Markdown Navigator plugin 217 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 218 | .idea/**/markdown-navigator.xml 219 | .idea/**/markdown-navigator-enh.xml 220 | .idea/**/markdown-navigator/ 221 | 222 | # Cache file creation bug 223 | # See https://youtrack.jetbrains.com/issue/JBR-2257 224 | .idea/$CACHE_FILE$ 225 | 226 | # CodeStream plugin 227 | # https://plugins.jetbrains.com/plugin/12206-codestream 228 | .idea/codestream.xml 229 | 230 | # Azure Toolkit for IntelliJ plugin 231 | # https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij 232 | .idea/**/azureSettings.xml 233 | 234 | ### Python ### 235 | # Byte-compiled / optimized / DLL files 236 | __pycache__/ 237 | *.py[cod] 238 | *$py.class 239 | 240 | # C extensions 241 | *.so 242 | 243 | # Distribution / packaging 244 | .Python 245 | build/ 246 | develop-eggs/ 247 | downloads/ 248 | eggs/ 249 | .eggs/ 250 | lib/ 251 | lib64/ 252 | parts/ 253 | sdist/ 254 | var/ 255 | wheels/ 256 | share/python-wheels/ 257 | *.egg-info/ 258 | .installed.cfg 259 | *.egg 260 | MANIFEST 261 | 262 | # PyInstaller 263 | # Usually these files are written by a python script from a template 264 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
265 | *.manifest 266 | *.spec 267 | 268 | # Installer logs 269 | pip-log.txt 270 | pip-delete-this-directory.txt 271 | 272 | # Unit test / coverage reports 273 | htmlcov/ 274 | .tox/ 275 | .nox/ 276 | .coverage 277 | .coverage.* 278 | .cache 279 | nosetests.xml 280 | coverage.xml 281 | *.cover 282 | *.py,cover 283 | .hypothesis/ 284 | .pytest_cache/ 285 | cover/ 286 | 287 | # Translations 288 | *.mo 289 | *.pot 290 | 291 | # Django stuff: 292 | *.log 293 | local_settings.py 294 | db.sqlite3 295 | db.sqlite3-journal 296 | 297 | # Flask stuff: 298 | instance/ 299 | .webassets-cache 300 | 301 | # Scrapy stuff: 302 | .scrapy 303 | 304 | # Sphinx documentation 305 | docs/_build/ 306 | 307 | # PyBuilder 308 | .pybuilder/ 309 | target/ 310 | 311 | # Jupyter Notebook 312 | 313 | # IPython 314 | 315 | # pyenv 316 | # For a library or package, you might want to ignore these files since the code is 317 | # intended to run in multiple environments; otherwise, check them in: 318 | # .python-version 319 | 320 | # pipenv 321 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 322 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 323 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 324 | # install all needed dependencies. 325 | #Pipfile.lock 326 | 327 | # poetry 328 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 329 | # This is especially recommended for binary packages to ensure reproducibility, and is more 330 | # commonly ignored for libraries. 331 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 332 | #poetry.lock 333 | 334 | # pdm 335 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 336 | #pdm.lock 337 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 338 | # in version control. 339 | # https://pdm.fming.dev/#use-with-ide 340 | .pdm.toml 341 | 342 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 343 | __pypackages__/ 344 | 345 | # Celery stuff 346 | celerybeat-schedule 347 | celerybeat.pid 348 | 349 | # SageMath parsed files 350 | *.sage.py 351 | 352 | # Environments 353 | .env 354 | .venv 355 | env/ 356 | venv/ 357 | ENV/ 358 | env.bak/ 359 | venv.bak/ 360 | 361 | # Spyder project settings 362 | .spyderproject 363 | .spyproject 364 | 365 | # Rope project settings 366 | .ropeproject 367 | 368 | # mkdocs documentation 369 | /site 370 | 371 | # mypy 372 | .mypy_cache/ 373 | .dmypy.json 374 | dmypy.json 375 | 376 | # Pyre type checker 377 | .pyre/ 378 | 379 | # pytype static type analyzer 380 | .pytype/ 381 | 382 | # Cython debug symbols 383 | cython_debug/ 384 | 385 | # PyCharm 386 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 387 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 388 | # and can be added to the global gitignore or merged into this file. For a more nuclear 389 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
390 | #.idea/ 391 | 392 | ### Python Patch ### 393 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 394 | poetry.toml 395 | 396 | # ruff 397 | .ruff_cache/ 398 | 399 | # LSP config files 400 | pyrightconfig.json 401 | 402 | ### Vim ### 403 | # Swap 404 | [._]*.s[a-v][a-z] 405 | !*.svg # comment out if you don't need vector files 406 | [._]*.sw[a-p] 407 | [._]s[a-rt-v][a-z] 408 | [._]ss[a-gi-z] 409 | [._]sw[a-p] 410 | 411 | # Session 412 | Session.vim 413 | Sessionx.vim 414 | 415 | # Temporary 416 | .netrwhist 417 | # Auto-generated tag files 418 | tags 419 | # Persistent undo 420 | [._]*.un~ 421 | 422 | ### VisualStudioCode ### 423 | .vscode/ 424 | # .vscode/* 425 | # !.vscode/settings.json 426 | # !.vscode/tasks.json 427 | # !.vscode/launch.json 428 | # !.vscode/extensions.json 429 | # !.vscode/*.code-snippets 430 | 431 | # Local History for Visual Studio Code 432 | .history/ 433 | 434 | # Built Visual Studio Code Extensions 435 | *.vsix 436 | 437 | ### VisualStudioCode Patch ### 438 | # Ignore all local history of files 439 | .history 440 | .ionide 441 | 442 | ### Windows ### 443 | # Windows thumbnail cache files 444 | Thumbs.db 445 | Thumbs.db:encryptable 446 | ehthumbs.db 447 | ehthumbs_vista.db 448 | 449 | # Dump file 450 | *.stackdump 451 | 452 | # Folder config file 453 | [Dd]esktop.ini 454 | 455 | # Recycle Bin used on file shares 456 | $RECYCLE.BIN/ 457 | 458 | # Windows Installer files 459 | *.cab 460 | *.msi 461 | *.msix 462 | *.msm 463 | *.msp 464 | 465 | # Windows shortcuts 466 | *.lnk 467 | 468 | /tests/data/ 469 | resources/*.yml 470 | -------------------------------------------------------------------------------- /.pre-commit-config-cruft.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/cruft/cruft 3 | rev: 2.15.0 4 | hooks: 5 | - id: cruft 6 | entry: cruft update -y 7 | additional_dependencies: [toml] 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-json 8 | - id: check-yaml 9 | - id: check-toml 10 | - id: check-added-large-files 11 | - id: debug-statements 12 | - id: mixed-line-ending 13 | - repo: https://github.com/psf/black 14 | rev: 23.3.0 15 | hooks: 16 | - id: black 17 | - repo: https://github.com/keewis/blackdoc 18 | rev: v0.3.8 19 | hooks: 20 | - id: blackdoc 21 | additional_dependencies: [black==22.3.0] 22 | - repo: https://github.com/charliermarsh/ruff-pre-commit 23 | rev: v0.0.270 24 | hooks: 25 | - id: ruff 26 | args: [--fix, --show-fixes] 27 | - repo: https://github.com/executablebooks/mdformat 28 | rev: 0.7.16 29 | hooks: 30 | - id: mdformat 31 | - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks 32 | rev: v2.9.0 33 | hooks: 34 | - id: pretty-format-yaml 35 | args: [--autofix, --preserve-quotes] 36 | - id: pretty-format-toml 37 | args: [--autofix] 38 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | ARG PYTHON_VERSION=3.10 5 | 6 | WORKDIR /home/deepr 7 | 8 | COPY . . 
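# The layers below install GCC and the pinned Python into the base conda environment,
# update that environment from environment_CUDA.yml, and install deepr itself in
# editable mode without re-resolving dependencies. A typical build/run cycle, as
# wrapped by the Makefile's docker-build and docker-run targets:
#
#   docker build -t deepr .
#   docker run --rm -ti -v $(PWD):/srv deepr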
9 | 10 | RUN conda install -c conda-forge gcc python=${PYTHON_VERSION} && \ 11 | conda env update -n base -f environment_CUDA.yml 12 | 13 | RUN pip install --no-deps -e . 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | PROJECT := deepr 2 | CONDA := conda 3 | MAMBA := mamba 4 | CONDAFLAGS := 5 | MAMBAFLAGS := 6 | COV_REPORT := html 7 | 8 | default: qa unit-tests type-check 9 | 10 | qa: 11 | pre-commit run --all-files 12 | 13 | unit-tests: 14 | python -m pytest -vv --cov=. --cov-report=$(COV_REPORT) --doctest-glob="*.md" --doctest-glob="*.rst" 15 | 16 | type-check: 17 | python -m mypy . 18 | 19 | conda-env-update: 20 | $(CONDA) env update $(CONDAFLAGS) -f ci/environment-ci.yml 21 | $(CONDA) env update $(CONDAFLAGS) -f environment.yml 22 | 23 | mamba-env-update: 24 | $(MAMBA) env update $(MAMBAFLAGS) -f ci/environment-ci.yml 25 | $(MAMBA) env update $(MAMBAFLAGS) -f environment.yml 26 | 27 | mamba-cuda_env-update: 28 | $(MAMBA) env update $(MAMBAFLAGS) -f ci/environment-ci.yml 29 | $(MAMBA) env update $(MAMBAFLAGS) -f environment_CUDA.yml 30 | 31 | docker-build: 32 | docker build -t $(PROJECT) . 33 | 34 | docker-run: 35 | docker run --rm -ti -v $(PWD):/srv $(PROJECT) 36 | 37 | template-update: 38 | pre-commit run --all-files cruft -c .pre-commit-config-cruft.yaml 39 | 40 | docs-build: 41 | cd docs && rm -fr _api && make clean && make html 42 | 43 | # DO NOT EDIT ABOVE THIS LINE, ADD COMMANDS BELOW 44 | -------------------------------------------------------------------------------- /ci/environment-ci.yml: -------------------------------------------------------------------------------- 1 | # environment-ci.yml: Additional dependencies to install in the CI environment. 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - make 7 | - mypy 8 | - myst-parser 9 | - pre-commit 10 | - pydata-sphinx-theme 11 | - pytest 12 | - pytest-cov 13 | - sphinx 14 | - sphinx-autoapi 15 | # DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW 16 | -------------------------------------------------------------------------------- /ci/environment-integration.yml: -------------------------------------------------------------------------------- 1 | # environment-integration.yml: Additional dependencies to install in the integration environment (e.g., pinned dependencies). 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - make 7 | - pytest 8 | - pytest-cov 9 | # DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW 10 | -------------------------------------------------------------------------------- /deepr/__init__.py: -------------------------------------------------------------------------------- 1 | """DeepR: Deep Reanalysis.""" 2 | 3 | # Copyright 2023, European Union. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | try: 18 | # NOTE: the `version.py` file must not be present in the git repository 19 | # as it is generated by setuptools at install time 20 | from .version import __version__ 21 | except ImportError: # pragma: no cover 22 | # Local copy or not installed with setuptools 23 | __version__ = "999" 24 | 25 | __all__ = ["__version__"] 26 | -------------------------------------------------------------------------------- /deepr/cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | 5 | from deepr.utilities.logger import get_logger 6 | from deepr.workflow import MainPipeline 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | @click.group() 12 | def launcher(): 13 | pass 14 | 15 | 16 | @launcher.command("train_model") 17 | @click.option( 18 | "-c", 19 | "--configuration_yaml", 20 | help="File configuring the parameters for training the model", 21 | type=str, 22 | ) 23 | def train_model(configuration_yaml: str) -> None: 24 | """ 25 | Train a model. 26 | 27 | Parameters 28 | ---------- 29 | configuration_yaml : str 30 | File configuring the parameters for training the model. 31 | Default is "{main_folder}/resources/configuration/config.yml". 32 | 33 | Returns 34 | ------- 35 | None 36 | This function does not return any value. 37 | 38 | """ 39 | logger.info( 40 | f"Starting the process of training a model with the " 41 | f"configuration specified at {configuration_yaml}" 42 | ) 43 | MainPipeline(configuration_file=Path(configuration_yaml)).run_pipeline() 44 | logger.info("Process finished") 45 | -------------------------------------------------------------------------------- /deepr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/data/__init__.py -------------------------------------------------------------------------------- /deepr/data/files.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional 3 | 4 | 5 | class DataFile: 6 | """ 7 | Class for generating and manipulating data paths based on a specific structure. 8 | 9 | Attributes 10 | ---------- 11 | base_dir (str): The base directory where the data files are stored. 12 | variable (str): The variable name. 13 | dataset (str): The dataset name. 14 | date (str): The date of the data. 15 | resolution (str): The resolution of the data. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | base_dir: str, 21 | variable: str, 22 | dataset: str, 23 | temporal_coverage: str, 24 | spatial_resolution: str, 25 | spatial_coverage: Optional[dict] = None, 26 | ): 27 | """ 28 | Initialize a DataPath instance. 29 | 30 | Parameters 31 | ---------- 32 | base_dir : str 33 | The base directory where the data files are stored. 34 | variable : str 35 | The variable name. 36 | dataset : str 37 | The dataset name. 38 | temporal_coverage : str 39 | The temporal coverage of the data. 40 | spatial_resolution : str 41 | The temporal resolution of the data. 42 | spatial_coverage: Optional[dict] 43 | The spatial coverage of the data to be selected. 
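            Expected (as used by ``XarrayStandardScaler.get_parameters``) to map
            ``"latitude"`` and ``"longitude"`` to two-element ``(start, stop)``
            pairs that are used to slice the corresponding dataset coordinates.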
44 | """ 45 | self.base_dir = base_dir 46 | self.variable = variable 47 | self.dataset = dataset 48 | self.temporal_coverage = temporal_coverage 49 | self.spatial_resolution = spatial_resolution 50 | self.spatial_coverage = spatial_coverage 51 | 52 | @classmethod 53 | def from_path(cls, file_path): 54 | """ 55 | Create a DataPath instance from a file path. 56 | 57 | Parameters 58 | ---------- 59 | file_path : str 60 | The file path. 61 | 62 | Returns 63 | ------- 64 | DataFile 65 | The DataPath instance. 66 | """ 67 | base_dir, filename = os.path.split(file_path) 68 | variable, dataset, date, resolution = filename[:-3].split("_") 69 | return cls(base_dir, variable, dataset, date, resolution) 70 | 71 | def to_path(self): 72 | """ 73 | Generate the file path based on the class attributes. 74 | 75 | Returns 76 | ------- 77 | str 78 | The complete file path. 79 | """ 80 | filename = ( 81 | f"{self.variable}_{self.dataset}_" 82 | f"{self.temporal_coverage}_{self.spatial_resolution}.nc" 83 | ) 84 | return os.path.join(self.base_dir, filename) 85 | 86 | def exist(self) -> bool: 87 | """ 88 | Indicate whether the file returned by to_path method already exists. 89 | 90 | Returns 91 | ------- 92 | bool 93 | True or False indicating if the file returned by to_path exists. 94 | """ 95 | return os.path.exists(self.to_path()) 96 | 97 | 98 | class DataFileCollection: 99 | def __init__(self, collection: List[DataFile]): 100 | self.collection = collection 101 | 102 | def __len__(self): 103 | """Get the length of the collection list.""" 104 | return len(self.collection) 105 | 106 | def append_data(self, data: DataFile): 107 | """ 108 | Append a new data object to the data list. 109 | 110 | Parameters 111 | ---------- 112 | data: Data 113 | The data object to be appended. 114 | """ 115 | if isinstance(data, DataFile): 116 | self.collection.append(data) 117 | else: 118 | raise ValueError("The input object is not a Data object.") 119 | 120 | def find_data(self, **kwargs): 121 | """ 122 | Find a DataFile object in the data list that matches the specified attributes. 123 | 124 | Parameters 125 | ---------- 126 | **kwargs: dict 127 | A dictionary with attributes to match the data objects. 128 | 129 | Returns 130 | ------- 131 | found_data: Data 132 | The first data object that matches the specified attributes. 133 | """ 134 | found_data = [] 135 | for data in self.collection: 136 | match = True 137 | for key, value in kwargs.items(): 138 | if not hasattr(data, key) or getattr(data, key) != value: 139 | match = False 140 | break 141 | if match: 142 | found_data.append(data) 143 | 144 | if len(found_data) == 0: 145 | return None 146 | else: 147 | return DataFileCollection(collection=found_data) 148 | 149 | def sort_data(self, attribute: str): 150 | """ 151 | Sort the collection list by the specified attribute of the DataFile objects. 152 | 153 | Parameters 154 | ---------- 155 | attribute : str 156 | The attribute name to sort the DataFile objects by. 157 | """ 158 | self.collection.sort(key=lambda x: getattr(x, attribute)) 159 | 160 | def split_data(self, split_coefficient: float): 161 | """ 162 | Split the data collection into two different data collections. 163 | 164 | Parameters 165 | ---------- 166 | split_coefficient : float 167 | The coefficient by which the data is split. 
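            For example, ``split_coefficient=0.2`` keeps the first 80% of the
            collection in the first returned ``DataFileCollection`` and the last
            20% in the second one.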
168 | """ 169 | idx = int((1 - split_coefficient) * len(self.collection)) 170 | split1 = DataFileCollection(collection=self.collection[:idx]) 171 | split2 = DataFileCollection(collection=self.collection[idx:]) 172 | return split1, split2 173 | 174 | def get_variable_list(self) -> List[str]: 175 | """ 176 | Get the list of variables in the data collection. 177 | 178 | Returns 179 | ------- 180 | variables : list 181 | The list of variables that are available in the data collection 182 | """ 183 | variables = {data_file.variable for data_file in self.collection} 184 | return list(variables) 185 | -------------------------------------------------------------------------------- /deepr/data/scaler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import pickle 4 | from typing import Tuple 5 | 6 | import numpy as np 7 | import pandas 8 | import torch 9 | import xarray 10 | 11 | from deepr.data.files import DataFileCollection 12 | 13 | 14 | class XarrayStandardScaler: 15 | def __init__( 16 | self, 17 | scaling_files: DataFileCollection, 18 | scaling_method: str, 19 | cache_file: pathlib.Path, 20 | ): 21 | """ 22 | Initialize the XarrayStandardScaler object. 23 | 24 | Parameters 25 | ---------- 26 | scaling_files : DataFileCollection 27 | Data files from which the XarrayStandardScaler wants to be calculated. 28 | scaling_method: str 29 | Method to perform the scaling (pixel-wise, domain-wise, ...) 30 | cache_file : pathlib.Path 31 | Path to store the pickle file. 32 | """ 33 | self.scaling_files = scaling_files 34 | self.scaling_method = scaling_method 35 | self.cache_file = cache_file 36 | 37 | if os.path.exists(self.cache_file): 38 | self.load() 39 | else: 40 | self.average, self.standard_deviation = self.get_parameters() 41 | self.average.load() 42 | self.standard_deviation.load() 43 | self.save() 44 | 45 | def to_dict(self) -> str: 46 | average = self.average[list(self.average.data_vars)[0]].values 47 | std = self.standard_deviation[list(self.standard_deviation.data_vars)[0]].values 48 | 49 | if len(average) == 1: 50 | average = float(average) 51 | elif average.ndim in [1, 2]: 52 | average = average.tolist() 53 | else: 54 | raise NotImplementedError("Average of scaler must be 1D or 2D.") 55 | 56 | if len(std) == 1: 57 | std = float(std) 58 | elif std.ndim in [1, 2]: 59 | std = std.tolist() 60 | else: 61 | raise NotImplementedError("Standard deviation of scaler must be 1D or 2D.") 62 | 63 | return { 64 | "method": self.scaling_method, 65 | "time-agg": "monthly", 66 | "average": average, 67 | "standard_deviation": std, 68 | } 69 | 70 | def get_parameters(self) -> Tuple[xarray.Dataset, xarray.Dataset]: 71 | """ 72 | Calculate the mean and standard deviation of the dataset parameters. 73 | 74 | Returns 75 | ------- 76 | mean : xarray.Dataset 77 | The dataset containing the mean values of the parameters. 78 | std : xarray.Dataset 79 | The dataset containing the standard deviation values of the parameters. 
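        Notes
        -----
        Statistics are aggregated per calendar month (``groupby("time.month")``).
        For the ``"domain-wise"`` method they are additionally averaged over the
        latitude and longitude dimensions, while ``"pixel-wise"`` keeps one value
        per month and grid point.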
80 | """ 81 | datasets = [] 82 | for file in self.scaling_files.collection: 83 | file_path = file.to_path() 84 | dataset = xarray.open_dataset(file_path, chunks=16) 85 | dataset = dataset.sel( 86 | latitude=slice( 87 | file.spatial_coverage["latitude"][0], 88 | file.spatial_coverage["latitude"][1], 89 | ), 90 | longitude=slice( 91 | file.spatial_coverage["longitude"][0], 92 | file.spatial_coverage["longitude"][1], 93 | ), 94 | ) 95 | datasets.append(dataset) 96 | dataset = xarray.concat(datasets, dim="time") 97 | mean = dataset.groupby("time.month").mean() 98 | std = dataset.groupby("time.month").std() 99 | if self.scaling_method == "pixel-wise": 100 | return mean, std 101 | elif self.scaling_method == "domain-wise": 102 | spatial_dims = ["longitude", "latitude"] 103 | return mean.mean(spatial_dims), std.mean(spatial_dims) 104 | elif self.scaling_method == "landmask-wise": 105 | return mean, std 106 | else: 107 | raise NotImplementedError 108 | 109 | def apply_scaler(self, ds: xarray.Dataset) -> xarray.Dataset: 110 | """ 111 | Apply the standard scaling to the input dataset. 112 | 113 | Parameters 114 | ---------- 115 | ds : xarray.Dataset 116 | Dataset to be scaled. 117 | 118 | Returns 119 | ------- 120 | xarray.Dataset 121 | Scaled dataset with the same dimensions as the input dataset. 122 | 123 | Notes 124 | ----- 125 | This method subtracts the average dataset from the input dataset and divides 126 | it by the standard deviation dataset. 127 | """ 128 | time_month = pandas.to_datetime(ds.time.values).month 129 | ds_scaled = ds - self.average.sel(month=time_month, method="nearest") 130 | ds_scaled = ds_scaled / self.standard_deviation.sel( 131 | month=time_month, method="nearest" 132 | ) 133 | return ds_scaled 134 | 135 | def apply_inverse_scaler( 136 | self, data: torch.Tensor, month: torch.Tensor 137 | ) -> torch.Tensor: 138 | """ 139 | Inverse the standard scaling to the input dataset. 140 | 141 | Parameters 142 | ---------- 143 | data : torch.Tensor 144 | The input dataset that was previously scaled. 145 | month : torch.Tensor 146 | The month tensor used for selecting scaling parameters. 147 | 148 | Returns 149 | ------- 150 | torch.Tensor 151 | The dataset with standard scaling inverted. 
152 | 153 | """ 154 | std_tensor = self.standard_deviation.sel( 155 | month=month, method="nearest" 156 | ).to_array() 157 | std_tensor = torch.from_numpy( 158 | std_tensor.squeeze().values[..., np.newaxis, np.newaxis, np.newaxis] 159 | ) 160 | mean_tensor = self.average.sel(month=month, method="nearest").to_array() 161 | mean_tensor = torch.from_numpy( 162 | mean_tensor.squeeze().values[..., np.newaxis, np.newaxis, np.newaxis] 163 | ) 164 | 165 | return data * std_tensor + mean_tensor 166 | 167 | def load(self): 168 | """Load an XarrayStandardScaler object from a pickle file.""" 169 | with open(self.cache_file, "rb") as f: 170 | scaler = pickle.load(f) 171 | self.average = scaler.average 172 | self.standard_deviation = scaler.standard_deviation 173 | 174 | def save(self): 175 | """Save the XarrayStandardScaler object to a pickle file.""" 176 | with open(self.cache_file, "wb") as f: 177 | pickle.dump(self, f) 178 | -------------------------------------------------------------------------------- /deepr/data/static/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/data/static/__init__.py -------------------------------------------------------------------------------- /deepr/data/static/climatology.py: -------------------------------------------------------------------------------- 1 | import pandas 2 | import xarray 3 | 4 | 5 | def compute_climatology_by_time_group( 6 | data_files: list, group_by: list, output_directory: str 7 | ): 8 | """ 9 | Compute statistics (mean and standard deviation) by time group and save to netCDF. 10 | 11 | Parameters 12 | ---------- 13 | data_files : list 14 | A list of file paths containing the data. 15 | group_by: list 16 | A list of time components to make the aggregation. 17 | output_directory : str 18 | The directory where the output NetCDF files will be saved. 19 | 20 | Returns 21 | ------- 22 | tuple 23 | A tuple containing the mean and standard deviation DataArrays computed by 24 | time group. 
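    Notes
    -----
    One NetCDF file is written per variable and per time group, following the
    pattern ``{output_directory}/{varname}_clim-mean_{group}-{value}[_...].nc``
    (and ``_clim-std_`` for the standard deviation files).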
25 | """ 26 | # Open the multi-file dataset 27 | ds = xarray.open_mfdataset(data_files) 28 | 29 | # Create a MultiIndex for the time components given by group_by 30 | time_group_idx = pandas.MultiIndex.from_arrays( 31 | [ds[f"time.{x}"].values for x in group_by] 32 | ) 33 | 34 | # Assign time_group coordinate to the dataset 35 | ds.coords["time_group"] = ("time", time_group_idx) 36 | 37 | # Compute mean and standard deviation by time_group 38 | mean_by_group = ds.groupby("time_group").mean() 39 | mean_by_group.load() 40 | std_by_group = ds.groupby("time_group").std() 41 | std_by_group.load() 42 | 43 | # Save mean climatology as NetCDF files 44 | for time_group in list(mean_by_group.time_group.values): 45 | time_group_str = "_".join( 46 | [f"{group_by[i]}-{time_group[i]}" for i in range(len(group_by))] 47 | ) 48 | selection = mean_by_group.sel(time_group=time_group) 49 | selection = selection.drop( 50 | ["time_group"] + [f"time_level_{i}" for i in range(len(group_by))] 51 | ) 52 | for varname in list(selection.data_vars): 53 | filename = f"{output_directory}/{varname}_clim-mean_{time_group_str}.nc" 54 | selection.to_netcdf(filename) 55 | 56 | # Save standard deviation climatology as NetCDF files 57 | for time_group in list(std_by_group.time_group.values): 58 | time_group_str = "_".join( 59 | [f"{group_by[i]}-{time_group[i]}" for i in range(len(group_by))] 60 | ) 61 | selection = std_by_group.sel(time_group=time_group) 62 | selection = selection.drop( 63 | ["time_group"] + [f"time_level_{i}" for i in range(len(group_by))] 64 | ) 65 | for varname in list(selection.data_vars): 66 | filename = f"{output_directory}/{varname}_clim-std_{time_group_str}.nc" 67 | selection.to_netcdf(filename) 68 | 69 | # Return mean and standard deviation DataArrays 70 | return mean_by_group, std_by_group 71 | -------------------------------------------------------------------------------- /deepr/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/model/__init__.py -------------------------------------------------------------------------------- /deepr/model/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class GeGLU(nn.Module): 7 | def __init__(self, d_in: int, d_out: int): 8 | super().__init__() 9 | self.proj = nn.Linear(d_in, d_out * 2) 10 | 11 | def forward(self, x: torch.Tensor): 12 | x, gate = self.proj(x).chunk(2, dim=-1) 13 | return x * F.gelu(gate) 14 | 15 | 16 | class Swish(nn.Module): 17 | def forward(self, x): 18 | return x * torch.sigmoid(x) 19 | -------------------------------------------------------------------------------- /deepr/model/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class AttentionBlock(nn.Module): 8 | def __init__( 9 | self, 10 | n_channels: int, 11 | n_heads: int = 1, 12 | d_k: Optional[int] = None, 13 | n_groups: int = 32, 14 | ): 15 | super().__init__() 16 | 17 | if d_k is None: 18 | d_k = n_channels 19 | self.norm = nn.GroupNorm(n_groups, n_channels) 20 | self.projection = nn.Linear(n_channels, n_heads * d_k * 3) 21 | self.output = nn.Linear(n_heads * d_k, n_channels) 22 | self.scale = d_k**-0.5 23 | self.n_heads = n_heads 24 | self.d_k = d_k 25 | 26 | def 
forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None): 27 | # `t` is not used, but it's kept in the arguments because for the attention 28 | # layer function signature to match with `ResidualBlock`. 29 | _ = t 30 | batch_size, n_channels, height, width = x.shape 31 | x = x.view(batch_size, n_channels, -1).permute(0, 2, 1) 32 | # Get query, key, and values 33 | qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k) 34 | q, k, v = torch.chunk(qkv, 3, dim=-1) 35 | attn = torch.einsum("bihd,bjhd->bijh", q, k) * self.scale 36 | attn = attn.softmax(dim=2) 37 | res = torch.einsum("bijh,bjhd->bihd", attn, v) 38 | res = res.view(batch_size, -1, self.n_heads * self.d_k) 39 | res = self.output(res) 40 | res += x 41 | 42 | res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width) 43 | 44 | return res 45 | 46 | 47 | class CrossAttention(nn.Module): 48 | """It falls-back to self-attention when conditional embeddings are not specified.""" 49 | 50 | use_flash_attention: bool = False 51 | 52 | def __init__( 53 | self, 54 | d_model: int, 55 | d_cond: int, 56 | n_heads: int, 57 | d_head: int, 58 | is_inplace: bool = True, 59 | ): 60 | super().__init__() 61 | 62 | self.is_inplace = is_inplace 63 | self.n_heads = n_heads 64 | self.d_head = d_head 65 | 66 | # Attention scaling factor 67 | self.scale = d_head**-0.5 68 | 69 | # Query, key and value mappings 70 | d_attn = d_head * n_heads 71 | self.to_q = nn.Linear(d_model, d_attn, bias=False) 72 | self.to_k = nn.Linear(d_cond, d_attn, bias=False) 73 | self.to_v = nn.Linear(d_cond, d_attn, bias=False) 74 | 75 | self.to_out = nn.Sequential(nn.Linear(d_attn, d_model)) 76 | 77 | # Setup [flash attention](https://github.com/HazyResearch/flash-attention). 78 | # Flash attention is only used if it's installed 79 | # and `CrossAttention.use_flash_attention` is set to `True`. 80 | try: 81 | from flash_attn.flash_attention import FlashAttention 82 | 83 | self.flash = FlashAttention() 84 | # Set the scale for scaled dot-product attention. 85 | self.flash.softmax_scale = self.scale 86 | except ImportError: 87 | self.flash = None 88 | 89 | def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None): 90 | # If `cond` is `None` we perform self attention 91 | has_cond = cond is not None 92 | if not has_cond: 93 | cond = x 94 | 95 | # Get query, key and value vectors 96 | q = self.to_q(x) 97 | k = self.to_k(cond) 98 | v = self.to_v(cond) 99 | 100 | # Use flash attention if it's available and the head size is less than or equal to `128` 101 | if ( 102 | CrossAttention.use_flash_attention 103 | and self.flash is not None 104 | and not has_cond 105 | and self.d_head <= 128 106 | ): 107 | return self.flash_attention(q, k, v) 108 | else: 109 | return self.normal_attention(q, k, v) 110 | 111 | def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): 112 | batch_size, seq_len, _ = q.shape 113 | 114 | qkv = torch.stack((q, k, v), dim=2) 115 | qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) 116 | 117 | # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad 118 | # the heads to fit this size. 
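        # Worked example of the rule implemented below: with d_head = 40 the first
        # branch is skipped (40 > 32), the second matches (40 <= 64), so
        # pad = 64 - 40 = 24 zero-valued features are appended to the last
        # dimension of `qkv`; after the flash-attention call the padded features
        # are dropped again via `out[:, :, :, : self.d_head]`.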
119 | if self.d_head <= 32: 120 | pad = 32 - self.d_head 121 | elif self.d_head <= 64: 122 | pad = 64 - self.d_head 123 | elif self.d_head <= 128: 124 | pad = 128 - self.d_head 125 | else: 126 | raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") 127 | 128 | # Pad the heads 129 | if pad: 130 | qkv = torch.cat( 131 | (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 132 | ) 133 | 134 | # Compute attention 135 | out, _ = self.flash(qkv) 136 | out = out[:, :, :, : self.d_head] 137 | out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) 138 | 139 | return self.to_out(out) 140 | 141 | def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): 142 | # Split them to heads 143 | q = q.view(*q.shape[:2], self.n_heads, -1) 144 | k = k.view(*k.shape[:2], self.n_heads, -1) 145 | v = v.view(*v.shape[:2], self.n_heads, -1) 146 | 147 | attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale 148 | 149 | # Compute softmax 150 | if self.is_inplace: 151 | half = attn.shape[0] // 2 152 | attn[half:] = attn[half:].softmax(dim=-1) 153 | attn[:half] = attn[:half].softmax(dim=-1) 154 | else: 155 | attn = attn.softmax(dim=-1) 156 | 157 | # Compute attention output 158 | out = torch.einsum("bhij,bjhd->bihd", attn, v) 159 | out = out.reshape(*out.shape[:2], -1) 160 | return self.to_out(out) 161 | -------------------------------------------------------------------------------- /deepr/model/autoencoder_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import matplotlib.pyplot 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from accelerate import Accelerator, find_executable_batch_size 9 | from accelerate.utils import LoggerType 10 | from diffusers.optimization import get_cosine_schedule_with_warmup 11 | from huggingface_hub import Repository 12 | from tqdm import tqdm 13 | 14 | from deepr.data.generator import DataGenerator 15 | from deepr.model.configs import TrainingConfig 16 | from deepr.utilities.logger import get_logger 17 | from deepr.visualizations.plot_maps import get_figure_model_samples 18 | 19 | repo_name = "predictia/cerra_tas_vqvae" 20 | 21 | logger = get_logger(__name__) 22 | 23 | 24 | def save_samples( 25 | model, 26 | cerra: torch.Tensor, 27 | output_name: str, 28 | ) -> matplotlib.pyplot.Figure: 29 | """ 30 | Save a set of samples. 31 | 32 | Parameters 33 | ---------- 34 | model : nn.Module 35 | The model used for generating samples. 36 | cerra : torch.Tensor 37 | The CERRA data tensor. 38 | output_name : str 39 | The output file name. 40 | 41 | Returns 42 | ------- 43 | Figure: The figure. 44 | """ 45 | with torch.no_grad(): 46 | cerra_pred = model(cerra, return_dict=False)[0] 47 | 48 | figsize = 3 + 4.5 * cerra.shape[0], 8 49 | return get_figure_model_samples( 50 | cerra.cpu(), cerra_pred.cpu(), filename=output_name, fig_size=figsize 51 | ) 52 | 53 | 54 | def train_autoencoder( 55 | config: TrainingConfig, 56 | model, 57 | train_dataset: DataGenerator, 58 | val_dataset: DataGenerator, 59 | dataset_info: Dict = {}, 60 | ): 61 | """ 62 | Train a neural network model. 63 | 64 | Parameters 65 | ---------- 66 | config : TrainingConfig 67 | The training configuration. 68 | model : nn.Module 69 | The neural network model. 70 | train_dataset : DataGenerator 71 | The training dataset. 72 | val_dataset : DataGenerator 73 | The validation dataset. 
74 | dataset_info : Dict, optional 75 | Additional dataset information, by default {}. 76 | 77 | Returns 78 | ------- 79 | model : nn.Module 80 | The trained model. 81 | repo_name : str 82 | The repository name. 83 | 84 | Notes 85 | ----- 86 | This function performs the training of a neural network model using the provided 87 | datasets and configuration. 88 | """ 89 | hparams = config.__dict__ 90 | number_model_params = int(sum([np.prod(m.size()) for m in model.parameters()])) 91 | if "number_model_params" not in hparams: 92 | hparams["number_model_params"] = number_model_params 93 | 94 | model_name = model.__class__.__name__ 95 | run_name = "Train VQ-VAE NN" 96 | 97 | # aim_tracker = AimTracker(run_name, logging_dir="aim://10.9.64.88:31441") 98 | accelerator = Accelerator( 99 | cpu=config.device == "cpu", 100 | device_placement=True, 101 | mixed_precision=config.mixed_precision, 102 | gradient_accumulation_steps=config.gradient_accumulation_steps, 103 | log_with=[LoggerType.TENSORBOARD], # aim_tracker 104 | project_dir=os.path.join(config.output_dir, "logs"), 105 | ) 106 | 107 | @find_executable_batch_size(starting_batch_size=64) 108 | def inner_training_loop(batch_size: int, model): 109 | nonlocal accelerator # Ensure they can be used in our context 110 | accelerator.free_memory() # Free all lingering references 111 | torch.cuda.empty_cache() 112 | 113 | # Define important objects 114 | dataloader = torch.utils.data.DataLoader( 115 | train_dataset, batch_size, pin_memory=True 116 | ) 117 | dataloader_val = torch.utils.data.DataLoader( 118 | val_dataset, batch_size, pin_memory=True 119 | ) 120 | 121 | optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate) 122 | lr_scheduler = get_cosine_schedule_with_warmup( 123 | optimizer=optimizer, 124 | num_warmup_steps=config.lr_warmup_steps, 125 | num_training_steps=(len(dataloader) * config.num_epochs), 126 | ) 127 | 128 | if accelerator.is_main_process: 129 | if config.push_to_hub: 130 | repo = Repository( 131 | config.output_dir, 132 | clone_from=repo_name.format(model=model_name.lower()), 133 | token=os.getenv("HF_TOKEN"), 134 | ) 135 | repo.git_pull() 136 | elif config.output_dir is not None: 137 | os.makedirs(config.output_dir, exist_ok=True) 138 | accelerator.init_trackers(run_name, config=hparams) 139 | tfboard_tracker = accelerator.get_tracker("tensorboard") 140 | 141 | ( 142 | model, 143 | optimizer, 144 | train_dataloader, 145 | val_dataloader, 146 | lr_scheduler, 147 | ) = accelerator.prepare( 148 | model, optimizer, dataloader, dataloader_val, lr_scheduler 149 | ) 150 | 151 | # Get fixed samples 152 | (val_cerra,) = next(iter(val_dataloader)) 153 | if batch_size > 4: 154 | val_cerra = val_cerra[:4] 155 | 156 | logger.info(f"Number of parameters: {number_model_params}") 157 | global_step = 0 158 | # Now you train the model 159 | for epoch in range(config.num_epochs): 160 | progress_bar = tqdm( 161 | total=len(train_dataloader) + len(val_dataloader), 162 | disable=not accelerator.is_local_main_process, 163 | ) 164 | progress_bar.set_description(f"Epoch {epoch+1}") 165 | 166 | for (cerra,) in train_dataloader: 167 | # Predict the noise residual 168 | with accelerator.accumulate(model): 169 | # Encode, quantize and decode 170 | h = model.encode(cerra).latents 171 | q, emb_loss, _ = model.quantize(h) 172 | q = model.post_quant_conv(q) 173 | cerra_pred = model.decoder(q) 174 | 175 | # Calculate the loss 176 | rec_loss = F.mse_loss(cerra, cerra_pred) 177 | loss = emb_loss + rec_loss 178 | 179 | accelerator.backward(loss) 
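                    # Clip gradients to a maximum norm of 1.0, then step the
                    # optimizer and LR scheduler and reset the gradients; the
                    # surrounding `accelerator.accumulate(model)` context lets
                    # Accelerate handle gradient-accumulation boundaries.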
180 | accelerator.clip_grad_norm_(model.parameters(), 1.0) 181 | optimizer.step() 182 | lr_scheduler.step() 183 | optimizer.zero_grad() 184 | 185 | progress_bar.update(1) 186 | lo = loss.detach().item() 187 | logs = { 188 | "loss_vs_step": lo, 189 | "loss_emb_vs_step": emb_loss.detach().item(), 190 | "loss_recon_vs_step": rec_loss.detach().item(), 191 | "lr_vs_step": lr_scheduler.get_last_lr()[0], 192 | "step": global_step, 193 | "epoch": epoch, 194 | } 195 | progress_bar.set_postfix(**logs) 196 | accelerator.log(logs, step=global_step) 197 | # tfboard_tracker.writer.add_histogram( 198 | # "cerra prediction", cerra_pred, global_step 199 | # ) 200 | # tfboard_tracker.writer.add_histogram("cerra", cerra, global_step) 201 | global_step += 1 202 | 203 | # Evaluate 204 | loss, loss_emb, loss_recs = [], [], [] 205 | for (cerra,) in val_dataloader: 206 | # Predict the noise residual 207 | with torch.no_grad(): 208 | # Encode, quantize and decode 209 | h = model.encode(cerra).latents 210 | quant, emb_loss, _ = model.quantize(h) 211 | quant2 = model.post_quant_conv(quant) 212 | cerra_pred = model.decoder(quant2) 213 | 214 | rec_loss = F.mse_loss(cerra, cerra_pred) 215 | 216 | loss.append(emb_loss + rec_loss) 217 | loss_emb.append(emb_loss) 218 | loss_recs.append(rec_loss) 219 | 220 | progress_bar.update(1) 221 | torch.cuda.empty_cache() 222 | 223 | logs = { 224 | "val_loss_vs_epoch": sum(loss) / len(loss), 225 | "val_loss_emb_vs_epoch": sum(loss_emb) / len(loss_emb), 226 | "val_loss_recon_vs_epoch": sum(loss_recs) / len(loss_recs), 227 | "epoch": epoch, 228 | } 229 | accelerator.log(logs, step=epoch) 230 | progress_bar.close() 231 | 232 | # After each epoch you optionally sample some demo images 233 | if accelerator.is_main_process: 234 | is_last_epoch = epoch == config.num_epochs - 1 235 | 236 | if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch: 237 | logger.info("Saving sample predictions...") 238 | samples_dir = os.path.join(config.output_dir, "samples") 239 | os.makedirs(samples_dir, exist_ok=True) 240 | fig = save_samples( 241 | accelerator.unwrap_model(model), 242 | val_cerra, 243 | output_name=f"{samples_dir}/{model_name}_{epoch+1:04d}.png", 244 | ) 245 | if is_last_epoch: 246 | tfboard_tracker.writer.add_figure( 247 | "Predictions", fig, global_step=epoch 248 | ) 249 | 250 | if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch: 251 | logger.info("Saving model weights...") 252 | model.save_pretrained(config.output_dir) 253 | if config.push_to_hub: 254 | repo.push_to_hub( 255 | commit_message=f"Epoch {epoch+1}", blocking=True 256 | ) 257 | 258 | return model 259 | 260 | trained_model = inner_training_loop(model) 261 | accelerator.end_training() 262 | 263 | return trained_model, repo_name 264 | -------------------------------------------------------------------------------- /deepr/model/conditional_ddpm.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | from typing import List, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput 8 | from diffusers.utils.torch_utils import randn_tensor 9 | 10 | from deepr.model.utils import get_hour_embedding 11 | 12 | 13 | class cDDPMPipeline(DiffusionPipeline): 14 | r""" 15 | DDPM conditioned on images. 16 | 17 | This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation 18 | for the generic methods the library implements for all the pipelines (such as 19 | downloading or saving, running on a particular device, etc.). 20 | 21 | Parameters 22 | ---------- 23 | unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image. 24 | scheduler ([`SchedulerMixin`]): 25 | A scheduler to be used in combination with `unet` to denoise the encoded 26 | image. Can be one of [`DDPMScheduler`], or [`DDIMScheduler`]. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | unet, 32 | scheduler, 33 | obs_model=None, 34 | baseline_interpolation_method: Optional[str] = "bicubic", 35 | learn_residuals: Optional[bool] = False, 36 | hour_embed_type: [Optional] = "class", 37 | hour_embed_dim: Optional[int] = 64, 38 | instance_norm: Optional[bool] = False, 39 | ): 40 | super().__init__() 41 | self.baseline_interpolation_method = baseline_interpolation_method 42 | self.hour_embed_type = hour_embed_type 43 | self.hour_embed_dim = hour_embed_dim 44 | self.instance_norm = instance_norm 45 | self.learn_residuals = learn_residuals 46 | self.register_modules(unet=unet, scheduler=scheduler, obs_model=obs_model) 47 | 48 | @torch.no_grad() 49 | def __call__( 50 | self, 51 | images: Union[torch.Tensor, List[torch.Tensor]], 52 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 53 | num_inference_steps: int = 1000, 54 | eta: Optional[float] = 1.0, 55 | class_labels: Optional[List[int]] = None, 56 | output_type: Optional[str] = "pil", 57 | saving_freq_interm: int = 0, 58 | return_dict: bool = True, 59 | ) -> Union[ImagePipelineOutput, Tuple]: 60 | # Get batch size 61 | if isinstance(images, torch.Tensor): 62 | batch_size = images.shape[0] 63 | elif isinstance(images, list): 64 | batch_size = len(images) 65 | else: 66 | raise ValueError(f"Unsupported type {type(images)} for `images` argument.") 67 | 68 | # Sample gaussian noise to begin loop 69 | if isinstance(self.unet.config.sample_size, int): 70 | image_shape = ( 71 | batch_size, 72 | self.unet.config.out_channels, 73 | self.unet.config.sample_size, 74 | self.unet.config.sample_size, 75 | ) 76 | else: 77 | image_shape = ( 78 | batch_size, 79 | self.unet.config.out_channels, 80 | *self.unet.config.sample_size, 81 | ) 82 | 83 | if self.obs_model is not None: 84 | self.obs_model = self.obs_model.to(self.device) 85 | up_images = self.obs_model(images.to(self.device))[0].to(self.device) 86 | else: 87 | up_images = F.interpolate( 88 | images, scale_factor=5, mode=self.baseline_interpolation_method 89 | ) 90 | l_lat, l_lon = (np.array(up_images.shape[-2:]) - image_shape[-2:]) // 2 91 | r_lat = None if l_lat == 0 else -l_lat 92 | r_lon = None if l_lon == 0 else -l_lon 93 | up_images = up_images[..., l_lat:r_lat, l_lon:r_lon].to(self.device) 94 | 95 | if self.instance_norm: 96 | m = up_images.mean((1, 2, 3))[:, np.newaxis, np.newaxis, np.newaxis] 97 | s = up_images.std((1, 2, 3))[:, np.newaxis, np.newaxis, np.newaxis] 98 | up_images = (up_images - m) / s 99 | 100 | if self.device.type == "mps": 101 | # randn does not work reproducibly on mps 102 | latents = randn_tensor(image_shape, generator=generator) 103 | latents = latents.to(self.device) 104 | else: 105 | latents = randn_tensor(image_shape, generator=generator, device=self.device) 106 | 107 | # support for DDIM scheduler 108 | accepts_eta = "eta" in set(signature(self.scheduler.step).parameters.keys()) 109 | extra_kwargs = {} 110 | if accepts_eta: 111 | extra_kwargs["eta"] = eta 112 | 113 | # Support for LSMDiscreteScheduler 114 | if 
"generator" in set(signature(self.scheduler.step).parameters.keys()): 115 | extra_kwargs["generator"] = generator 116 | 117 | # Hour encoding. Passed to NN as class labels 118 | if class_labels is not None: 119 | class_labels = get_hour_embedding( 120 | class_labels, self.hour_embed_type, self.hour_embed_dim 121 | ) 122 | class_labels = class_labels.to(self.device).squeeze() 123 | 124 | # set step values 125 | self.scheduler.set_timesteps(num_inference_steps, device=self.device) 126 | latents = latents * self.scheduler.init_noise_sigma 127 | 128 | intermediate_images = [] 129 | for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): 130 | if saving_freq_interm > 0 and i % saving_freq_interm == 0: 131 | intermediate_images.append(latents.cpu()) 132 | 133 | latents_input = torch.cat([latents, up_images], axis=1) 134 | latents_input = self.scheduler.scale_model_input(latents_input, t) 135 | 136 | # 1. predict noise model_output 137 | model_output = self.unet(latents_input, t, class_labels=class_labels).sample 138 | 139 | # 2. compute previous image: x_t -> x_t-1 140 | latents = self.scheduler.step( 141 | model_output, t, latents, **extra_kwargs 142 | ).prev_sample 143 | 144 | if saving_freq_interm > 0: 145 | intermediate_images.append(latents.cpu()) 146 | intermediate_images = torch.cat(intermediate_images, dim=1) 147 | 148 | if self.learn_residuals: 149 | latents = latents + up_images 150 | 151 | if self.instance_norm: 152 | latents = latents * s + m 153 | 154 | image = latents.cpu().numpy() 155 | if output_type == "pil": 156 | image = self.numpy_to_pil(image) 157 | elif output_type == "tensor": 158 | image = torch.tensor(image) 159 | 160 | if not return_dict: 161 | return image, intermediate_images if saving_freq_interm > 0 else (image,) 162 | 163 | return ImagePipelineOutput(images=image) 164 | -------------------------------------------------------------------------------- /deepr/model/configs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Optional 4 | 5 | import torch 6 | from pydantic.dataclasses import dataclass 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | @dataclass 12 | class TrainingConfig: 13 | num_epochs: int = 50 14 | batch_size: int = 2 15 | num_workers: int = 0 16 | num_samples: int = 3 17 | gradient_accumulation_steps = 1 18 | learning_rate: float = 1e-4 19 | lr_warmup_steps: int = 500 20 | save_image_epochs: Optional[int] = None 21 | save_model_epochs: Optional[int] = None 22 | instance_norm: Optional[bool] = False 23 | learn_residuals: Optional[bool] = False 24 | hour_embed_type: str = "none" 25 | hour_embed_size: int = 64 26 | device: str = "cuda" 27 | mixed_precision: str = ( 28 | "fp16" # `no` for float32, `fp16` for automatic mixed precision 29 | ) 30 | output_dir: str = "ddpm-probando-128" # the model name locally and on the HF Hub 31 | push_to_hub: bool = False # whether to upload the saved model to the HF Hub 32 | hub_private_repo: bool = False 33 | hf_repo_name: str = "" 34 | overwrite_output_dir: bool = ( 35 | True # overwrite the old model when re-running the notebook 36 | ) 37 | static_covariables: List[str] = None 38 | seed: int = 0 39 | 40 | def __post_init__(self): 41 | if self.device == "cuda" and not torch.cuda.is_available(): 42 | logger.info("CUDA device requested but not available :(") 43 | 44 | if self.output_dir is not None: 45 | os.makedirs(self.output_dir, exist_ok=True) 46 | 47 | def _is_last_epoch(self, epoch: int): 48 | 
return epoch == self.num_epochs - 1 49 | 50 | def is_save_model_time(self, epoch: int): 51 | if self.save_model_epochs is None: 52 | return False or self._is_last_epoch(epoch) 53 | 54 | _epoch_save = (epoch + 1) % self.save_model_epochs == 0 55 | return _epoch_save or self._is_last_epoch(epoch) 56 | 57 | def is_save_images_time(self, epoch: int): 58 | if self.save_image_epochs is None: 59 | return False or self._is_last_epoch(epoch) 60 | 61 | _epoch_save = (epoch + 1) % self.save_image_epochs == 0 62 | return _epoch_save or self._is_last_epoch(epoch) 63 | -------------------------------------------------------------------------------- /deepr/model/conv_baseline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import ceil, log2 3 | from typing import List, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from transformers import PretrainedConfig, PreTrainedModel 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ConvBaselineConfig(PretrainedConfig): 14 | model_type = "convbilinear" 15 | 16 | attribute_map = {"hidden_size": "embed_dim"} 17 | 18 | def __init__( 19 | self, 20 | num_channels: int = 1, 21 | upscale: int = 1, 22 | interpolation_method: str = "bicubic", 23 | input_shape: Tuple[int] = None, 24 | image_size: Tuple[int] = None, 25 | upblock_channels: List[int] = [64, 32], 26 | upblock_kernel_size: List[int] = [5, 3], 27 | **kwargs, 28 | ): 29 | super().__init__(**kwargs) 30 | self.num_channels = num_channels 31 | self.upscale = upscale 32 | self.input_shape = input_shape 33 | if hasattr(self, "sample_size"): 34 | self.image_size = tuple(map(lambda x: x // self.upscale, self.sample_size)) 35 | self.upblock_channels = upblock_channels 36 | self.upblock_kernel_size = upblock_kernel_size 37 | self.interpolation_method = interpolation_method 38 | self.upscale_power2 = int(ceil(log2(upscale))) 39 | 40 | 41 | class UpConvBlock(nn.Module): 42 | def __init__( 43 | self, 44 | in_channel: int, 45 | out_channel: int, 46 | interm_channels: List[int] = [64, 32], 47 | kernel_size: List[int] = [5, 3], 48 | upscale_ratio: int = 2, 49 | ): 50 | super(UpConvBlock, self).__init__() 51 | n_channels = out_channel * (upscale_ratio**2) 52 | 53 | self.conv1 = nn.Conv2d( 54 | in_channel, interm_channels[0], kernel_size=5, padding="same" 55 | ) 56 | self.conv2 = nn.Conv2d( 57 | interm_channels[0], interm_channels[1], kernel_size=3, padding="same" 58 | ) 59 | self.conv3 = nn.Conv2d( 60 | interm_channels[1], n_channels, kernel_size=1, padding="same" 61 | ) 62 | self.conv4 = nn.ConvTranspose2d( 63 | n_channels, out_channel, kernel_size=2, padding=0, stride=upscale_ratio 64 | ) 65 | self.relu = nn.ReLU() 66 | 67 | def forward(self, x: torch.Tensor) -> torch.Tensor: 68 | out = self.conv1(x) 69 | out = self.relu(out) 70 | out = self.conv2(out) 71 | out = self.relu(out) 72 | out = self.conv3(out) 73 | out = self.relu(out) 74 | out = self.conv4(out) 75 | return out 76 | 77 | 78 | class ConvBaseline(PreTrainedModel): 79 | config_class = ConvBaselineConfig 80 | base_model_prefix = "convbaseline" 81 | main_input_name = "pixel_values" 82 | supports_gradient_checkpointing = True 83 | upscale_ratio_upconv = 2 84 | 85 | def __init__(self, config: ConvBaselineConfig): 86 | super().__init__(config) 87 | 88 | self.input_upconv_shape = np.array(config.sample_size) / ( 89 | self.upscale_ratio_upconv**self.config.upscale_power2 90 | ) 91 | kernel_size = ( 92 | np.array(config.input_shape) - 
self.input_upconv_shape + 1 93 | ).astype(int) 94 | extra_pixels = np.array(config.input_shape) - config.image_size 95 | self.from_lat = int(extra_pixels[0] // 2) 96 | self.from_lon = int(extra_pixels[1] // 2) 97 | self.to_lat = -self.from_lat if self.from_lat > 0 else None 98 | self.to_lon = -self.from_lon if self.from_lon > 0 else None 99 | 100 | self.preprocess_model = nn.Conv2d( 101 | config.num_channels, 102 | config.num_channels * self.upscale_ratio_upconv**2, 103 | kernel_size=tuple(kernel_size), 104 | padding=0, 105 | ) 106 | 107 | convs: List[nn.Module] = [] 108 | in_channels = config.num_channels * self.upscale_ratio_upconv**2 109 | for i in range(self.config.upscale_power2): 110 | j = self.config.upscale_power2 - i - 1 111 | out_channel = config.num_channels * (self.upscale_ratio_upconv**2) ** j 112 | convs.append( 113 | UpConvBlock( 114 | in_channel=in_channels, 115 | out_channel=out_channel, 116 | interm_channels=config.upblock_channels, 117 | kernel_size=config.upblock_kernel_size, 118 | upscale_ratio=self.upscale_ratio_upconv, 119 | ) 120 | ) 121 | in_channels = out_channel 122 | self.cnns = nn.ModuleList(convs) 123 | 124 | super().post_init() 125 | 126 | def forward( 127 | self, 128 | pixel_values: Optional[torch.FloatTensor] = None, 129 | return_dict: Optional[bool] = None, 130 | ): 131 | # Baseline interpoletion 132 | out_baseline = torch.nn.functional.interpolate( 133 | pixel_values[..., self.from_lat : self.to_lat, self.from_lon : self.to_lon], 134 | mode=self.config.interpolation_method, 135 | scale_factor=self.config.upscale, 136 | ) 137 | 138 | # Use Transposed Convolutions to generate a target matrix. 139 | h = self.preprocess_model(pixel_values) 140 | for conv in self.cnns: 141 | h = conv(h) 142 | out_upconv = h 143 | 144 | if not return_dict: 145 | return (out_upconv + out_baseline,) 146 | 147 | return out_upconv + out_baseline 148 | -------------------------------------------------------------------------------- /deepr/model/conv_swin2sr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import ceil, log2 3 | from typing import List, Optional 4 | 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | from transformers import PreTrainedModel, Swin2SRConfig, Swin2SRForImageSuperResolution 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ConvSwin2SRConfig(Swin2SRConfig): 14 | model_type = "conv_swin2sr" 15 | 16 | attribute_map = { 17 | "hidden_size": "embed_dim", 18 | "num_attention_heads": "num_heads", 19 | "num_hidden_layers": "num_layers", 20 | } 21 | 22 | def __init__( 23 | self, 24 | image_size: int = 64, 25 | patch_size: int = 1, 26 | num_channels: int = None, 27 | embed_dim: int = 180, 28 | depths: List[int] = [6, 6, 6, 6, 6, 6], 29 | num_heads: List[int] = [6, 6, 6, 6, 6, 6], 30 | window_size: int = 8, 31 | mlp_ratio: float = 2.0, 32 | qkv_bias: bool = True, 33 | hidden_dropout_prob: float = 0.0, 34 | attention_probs_dropout_prob: float = 0.0, 35 | drop_path_rate: float = 0.1, 36 | hidden_act: str = "gelu", 37 | use_absolute_embeddings: bool = False, 38 | initializer_range: bool = 0.02, 39 | layer_norm_eps: float = 1e-5, 40 | upscale: int = 2, 41 | img_range: float = 1.0, 42 | resi_connection: str = "1conv", 43 | upsampler: str = "pixelshuffle", 44 | interpolation_method: str = "bicubic", 45 | num_high_res_covars: int = 0, 46 | **kwargs, 47 | ): 48 | self.interpolation_method = interpolation_method 49 | self.num_high_res_covars = num_high_res_covars 50 | 51 | 
if "real_upscale" in kwargs.keys(): 52 | self.real_upscale = kwargs["real_upscale"] 53 | else: 54 | self.real_upscale = upscale 55 | 56 | if "sample_size" in kwargs.keys(): 57 | self.image_size = tuple( 58 | map(lambda x: x // self.real_upscale, kwargs["sample_size"]) 59 | ) 60 | else: 61 | self.image_size = image_size 62 | 63 | upscale_power2 = int(ceil(log2(self.real_upscale))) 64 | super().__init__( 65 | image_size=self.image_size, 66 | patch_size=patch_size, 67 | num_channels=num_channels, 68 | embed_dim=embed_dim, 69 | depths=depths, 70 | num_heads=num_heads, 71 | window_size=window_size, 72 | mlp_ratio=mlp_ratio, 73 | qkv_bias=qkv_bias, 74 | hidden_dropout_prob=hidden_dropout_prob, 75 | attention_probs_dropout_prob=attention_probs_dropout_prob, 76 | drop_path_rate=drop_path_rate, 77 | hidden_act=hidden_act, 78 | use_absolute_embeddings=use_absolute_embeddings, 79 | initializer_range=initializer_range, 80 | layer_norm_eps=layer_norm_eps, 81 | upscale=2**upscale_power2, 82 | img_range=img_range, 83 | resi_connection=resi_connection, 84 | upsampler=upsampler, 85 | **kwargs, 86 | ) 87 | 88 | def swin2sr_kwargs(self): 89 | logger.info( 90 | f"The Swin2SR(x{self.upscale}) model should receive pixel values of shape" 91 | f"{self.image_size}." 92 | ) 93 | return Swin2SRConfig( 94 | image_size=self.image_size, 95 | patch_size=self.patch_size, 96 | num_channels=self.num_channels, 97 | embed_dim=self.embed_dim, 98 | depths=self.depths, 99 | num_heads=self.num_heads, 100 | window_size=self.window_size, 101 | mlp_ratio=self.mlp_ratio, 102 | qkv_bias=self.qkv_bias, 103 | hidden_dropout_prob=self.hidden_dropout_prob, 104 | attention_probs_dropout_prob=self.attention_probs_dropout_prob, 105 | drop_path_rate=self.drop_path_rate, 106 | hidden_act=self.hidden_act, 107 | use_absolute_embeddings=self.use_absolute_embeddings, 108 | initializer_range=self.initializer_range, 109 | layer_norm_eps=self.layer_norm_eps, 110 | upscale=self.upscale, 111 | img_range=self.img_range, 112 | resi_connection=self.resi_connection, 113 | upsampler=self.upsampler, 114 | ) 115 | 116 | 117 | class ConvSwin2SR(PreTrainedModel): 118 | config_class = ConvSwin2SRConfig 119 | base_model_prefix = "convswin2sr" 120 | main_input_name = "pixel_values" 121 | supports_gradient_checkpointing = True 122 | 123 | def __init__(self, config: ConvSwin2SRConfig): 124 | super().__init__(config) 125 | self.config = config 126 | 127 | # Define center region to upscale 128 | extra_pixels = np.array(config.input_shape) - config.image_size 129 | self.from_lat = int(extra_pixels[0] // 2) 130 | self.from_lon = int(extra_pixels[1] // 2) 131 | self.to_lat = -self.from_lat if self.from_lat > 0 else None 132 | self.to_lon = -self.from_lon if self.from_lon > 0 else None 133 | 134 | # Set preprocess layer to match the output shapes 135 | self.input_upconv_shape = np.array(config.sample_size) / config.upscale 136 | kernel_size = ( 137 | np.array(config.input_shape) - self.input_upconv_shape + 1 138 | ).astype(int) 139 | self.preprocess_model = nn.Conv2d( 140 | config.num_channels, 141 | config.num_channels, 142 | kernel_size=tuple(kernel_size), 143 | padding=0, 144 | ) 145 | 146 | self.swin = Swin2SRForImageSuperResolution(config.swin2sr_kwargs()) 147 | 148 | if self.config.num_high_res_covars > 0: 149 | self.merge_covars_interp = nn.Conv2d( 150 | config.num_channels + config.num_high_res_covars, config.num_channels, 1 151 | ) 152 | self.merge_covars_swin2sr = nn.Conv2d( 153 | config.num_channels + config.num_high_res_covars, config.num_channels, 1 154 | ) 
155 | 156 | super().post_init() 157 | 158 | def forward( 159 | self, 160 | pixel_values: Optional[torch.FloatTensor] = None, 161 | head_mask: Optional[torch.FloatTensor] = None, 162 | labels: Optional[torch.LongTensor] = None, 163 | output_attentions: Optional[bool] = None, 164 | covariables: Optional[torch.FloatTensor] = None, 165 | output_hidden_states: Optional[bool] = None, 166 | return_dict: Optional[bool] = None, 167 | ): 168 | out_baseline = torch.nn.functional.interpolate( 169 | pixel_values[..., self.from_lat : self.to_lat, self.from_lon : self.to_lon], 170 | mode=self.config.interpolation_method, 171 | scale_factor=self.config.real_upscale, 172 | ) 173 | 174 | if self.config.num_high_res_covars > 0 and covariables is not None: 175 | covariables = torch.tile(covariables, (out_baseline.shape[0], 1, 1, 1)) 176 | out_baseline = self.merge_covars_interp( 177 | torch.cat([out_baseline, covariables], dim=1) 178 | ) 179 | 180 | h = self.preprocess_model(pixel_values) 181 | 182 | # Apply Denoising Swin2SR 183 | (out_swin2sr,) = self.swin( 184 | pixel_values=h, 185 | head_mask=head_mask, 186 | labels=labels, 187 | output_attentions=output_attentions, 188 | output_hidden_states=output_hidden_states, 189 | return_dict=False, 190 | ) 191 | 192 | if self.config.num_high_res_covars > 0 and covariables is not None: 193 | out_swin2sr = self.merge_covars_swin2sr( 194 | torch.cat([out_swin2sr, covariables], dim=1) 195 | ) 196 | 197 | if not return_dict: 198 | return (out_baseline + out_swin2sr,) 199 | 200 | return out_baseline + out_swin2sr 201 | -------------------------------------------------------------------------------- /deepr/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision.transforms import GaussianBlur 4 | 5 | blur = GaussianBlur(5) 6 | pooling = torch.nn.AvgPool2d(kernel_size=5) 7 | 8 | 9 | def compute_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 10 | """Loss terms' computation. 11 | 12 | The first loss term is the L1 loss of the predictions with the references. The 2nd 13 | loss term refers to the L1 loss between the downsampled prediction and reference. 14 | Lastly, the 3rd loss term is the L1 loss between the blurred prediction and the 15 | blurred reference. 16 | 17 | Args: 18 | ---- 19 | prediction (torch.Tensor): prediction tensor 20 | target (torch.Tensor): target tensor 21 | 22 | Returns: 23 | ------- 24 | torch.Tensor: the 3 loss terms 25 | """ 26 | l1 = F.l1_loss(prediction, target) 27 | l1_lowres = F.l1_loss(pooling(prediction), pooling(target)) 28 | l1_blur = F.l1_loss(blur(prediction), blur(target)) 29 | return l1, l1_lowres, l1_blur 30 | -------------------------------------------------------------------------------- /deepr/model/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import diffusers 4 | from torch import nn 5 | 6 | from deepr.utilities.logger import get_logger 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def load_trained_model(class_name: str = None, model_dir: str = None) -> nn.Module: 12 | """Load a trained model and return it in evaluation mode. 13 | 14 | Args: 15 | ---- 16 | class_name (str): Name of the model class. Options are 17 | model_dir (str): Directory where the model is stored. 18 | 19 | Returns: 20 | ------- 21 | nn.Module: the model in evaluation mode. 
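Example:
-------
    The calls below are illustrative only; the checkpoint directories are
    hypothetical placeholders for locally saved weights.

    model = load_trained_model("convswin2sr", "/path/to/checkpoints/convswin2sr")
    pipeline = load_trained_model("cddpm", "/path/to/checkpoints/cddpm")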
22 | """ 23 | if class_name is None or model_dir is None: 24 | return None 25 | elif class_name.lower() == "convbaseline": 26 | from deepr.model.conv_baseline import ConvBaseline 27 | 28 | model = ConvBaseline.from_pretrained(model_dir) 29 | elif class_name.lower() == "convswin2sr": 30 | from deepr.model.conv_swin2sr import ConvSwin2SR 31 | 32 | model = ConvSwin2SR.from_pretrained(model_dir) 33 | elif class_name.lower() == "cddpm": 34 | from deepr.model.conditional_ddpm import cDDPMPipeline 35 | 36 | model = cDDPMPipeline.from_pretrained(model_dir) 37 | elif class_name.split(".")[0].lower() == "diffusers": 38 | return getattr(diffusers, class_name.split(".")[1]).from_pretrained(model_dir) 39 | else: 40 | logger.warning( 41 | f"The class_name {class_name} is not implemented. " 42 | f"Options are 'convbaseline', 'convswin2sr' and 'cddpm." 43 | ) 44 | return None 45 | model.eval() 46 | return model 47 | 48 | 49 | def get_neural_network( 50 | class_name: str, 51 | kwargs: dict, 52 | input_shape: Tuple[int] = None, 53 | sample_size: Tuple[int] = None, 54 | out_channels: int = None, 55 | static_covariables: List[str] = None, 56 | ) -> nn.Module: 57 | """Get neural network. 58 | 59 | Given a class name and a dictionary of keyword arguments, returns an instance of a 60 | neural network. Current options are: "UNet". 61 | 62 | Arguments 63 | --------- 64 | class_name : str 65 | The name of the neural network class to use. 66 | kwargs : dict 67 | Dictionary of keyword arguments to pass to the neural network constructor. 68 | input_shape : Optional[tuple] 69 | Sample size of the input samples. 70 | sample_size : Optional[tuple] 71 | Sample size of the target samples. 72 | out_channels : Optional[int] 73 | Output channels of the target samples. 74 | 75 | Returns 76 | ------- 77 | model: nn.Module 78 | An instance of a neural network. 79 | 80 | Raises: 81 | ------ 82 | NotImplementedError: If the specified neural network class is not implemented. 
83 | """ 84 | if "sample_size" in kwargs: 85 | kwargs["sample_size"] = tuple(kwargs["sample_size"]) 86 | elif sample_size is None: 87 | raise ValueError(f"sample_size must be specified for {class_name}") 88 | else: 89 | kwargs["sample_size"] = sample_size 90 | 91 | if "out_channels" not in kwargs and out_channels is not None: 92 | kwargs["out_channels"] = out_channels 93 | 94 | if class_name.lower() == "unet": 95 | from deepr.model.unet import UNet 96 | 97 | return UNet(**kwargs) 98 | elif class_name.lower() == "convswin2sr": 99 | from deepr.model.conv_swin2sr import ConvSwin2SR, ConvSwin2SRConfig 100 | 101 | kwargs["num_channels"] = kwargs.pop("out_channels") 102 | if input_shape is not None: 103 | kwargs["input_shape"] = input_shape 104 | 105 | if static_covariables is not None: 106 | kwargs["num_high_res_covars"] = len(static_covariables) 107 | 108 | cfg = ConvSwin2SRConfig(**kwargs) 109 | return ConvSwin2SR(cfg) 110 | elif class_name.lower() == "convbaseline": 111 | from deepr.model.conv_baseline import ConvBaseline, ConvBaselineConfig 112 | 113 | kwargs["num_channels"] = kwargs.pop("out_channels") 114 | if input_shape is not None: 115 | kwargs["input_shape"] = input_shape 116 | 117 | cfg = ConvBaselineConfig(**kwargs) 118 | return ConvBaseline(cfg) 119 | elif class_name.split(".")[0].lower() == "diffusers": 120 | return getattr(diffusers, class_name.split(".")[1])(**kwargs) 121 | elif class_name.split(".")[0].lower() == "transformers": 122 | import transformers 123 | 124 | return transformers.__dict__[class_name.split(".")[1]](**kwargs) 125 | else: 126 | raise NotImplementedError(f"{class_name} is not implemented") 127 | 128 | 129 | def get_hf_scheduler(class_name: str, kwargs: dict) -> diffusers.SchedulerMixin: 130 | logger.info(f"Loading scheduler {class_name}.") 131 | return getattr(diffusers, class_name)(**kwargs) 132 | -------------------------------------------------------------------------------- /deepr/model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from deepr.model.activations import Swish 5 | 6 | 7 | class ResidualBlock(nn.Module): 8 | """ 9 | Residual block. 10 | 11 | A residual block has two convolution layers with group normalization. 12 | Each resolution is processed with two residual blocks. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | in_channels: int, 18 | out_channels: int, 19 | time_channels: int, 20 | n_groups: int = 8, 21 | dropout: float = 0.1, 22 | ): 23 | """CNN block with Group Normalization, Swish activation, and Conv. layers. 24 | 25 | The block takes input channel values, output channel values, time channels, 26 | number of groups (n_groups), and dropout rate as parameters. 27 | 28 | Parameters 29 | ---------- 30 | in_channels: int 31 | Number of input channels. 32 | out_channels: int 33 | Number of output channels. 34 | time_channels: int 35 | Number of time channels. 36 | n_groups: int, optional (default=`32`) 37 | Number of groups. 38 | dropout: float, optional (default=`0.1`) 39 | Dropout rate. 
40 | """ 41 | super().__init__() 42 | # Group normalization and the first convolution layer 43 | self.norm1 = nn.GroupNorm(n_groups, in_channels) 44 | self.act1 = Swish() 45 | self.conv1 = nn.Conv2d( 46 | in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1) 47 | ) 48 | 49 | # Group normalization and the second convolution layer 50 | self.norm2 = nn.GroupNorm(n_groups, out_channels) 51 | self.act2 = Swish() 52 | self.conv2 = nn.Conv2d( 53 | out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1) 54 | ) 55 | 56 | # If the number of input channels is not equal to the number of output channels 57 | # we have to project the shortcut connection 58 | if in_channels != out_channels: 59 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) 60 | 61 | # Linear layer for time embeddings 62 | self.time_emb = nn.Linear(time_channels, out_channels) 63 | self.time_act = Swish() 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | def forward(self, x: torch.Tensor, t: torch.Tensor): 68 | """Forward pass. 69 | 70 | Parameters 71 | ---------- 72 | x : torch.Tensor 73 | Input vector with shape `[batch_size, in_channels, height, width]`. 74 | t : torch.Tensor 75 | Time vector `[batch_size, time_channels]`. 76 | 77 | Returns 78 | ------- 79 | torch.Tensor: vector with shape `[batch_size, out_channels, height, width]`. 80 | """ 81 | h = self.conv1(self.act1(self.norm1(x))) 82 | h += self.time_emb(self.time_act(t))[:, :, None, None] 83 | h = self.conv2(self.dropout(self.act2(self.norm2(h)))) 84 | if hasattr(self, "shortcut"): 85 | h += self.shortcut(x) 86 | return h 87 | -------------------------------------------------------------------------------- /deepr/model/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers import ConfigMixin, ModelMixin 6 | from diffusers.configuration_utils import register_to_config 7 | from diffusers.models.embeddings import ( 8 | GaussianFourierProjection, 9 | TimestepEmbedding, 10 | Timesteps, 11 | ) 12 | from diffusers.models.unet_2d import UNet2DOutput 13 | 14 | from deepr.model.activations import Swish 15 | from deepr.model.unet_blocks import ( 16 | DownBlock, 17 | Downsample, 18 | MiddleBlock, 19 | UpBlock, 20 | Upsample, 21 | ) 22 | 23 | 24 | class UNet(ModelMixin, ConfigMixin): 25 | @register_to_config 26 | def __init__( 27 | self, 28 | out_channels: int = 1, 29 | in_channels: int = 1, 30 | sample_size: Optional[Union[int, Tuple[int, int]]] = None, 31 | time_embedding_type: str = "positional", 32 | flip_sin_to_cos: bool = True, 33 | freq_shift: int = 0, 34 | block_out_channels: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4), 35 | is_attention: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True), 36 | layers_per_block: int = 2, 37 | ): 38 | """ 39 | U-Net. 40 | 41 | NOTE: The spatial shapes of the input must be divisible by 2^{n_resolutions - 1} 42 | where the number of resolutions is specified by the length of the 43 | 'channel_multipliers' and 'is_attention' arguments. 44 | 45 | Parameters 46 | ---------- 47 | out_channels : int 48 | Number of channels in the output image. 49 | in_channels : int 50 | Number of channels of the input 2D matrix. 51 | sample_size : int | Tuple[int, int] 52 | Spatial dimension of the samples. 53 | time_embedding_type : str 54 | Type of time embedding. Options are: "positional" and "fourier". 
55 | freq_shift : int 56 | Frequency shift of the Fourier time embedding. 57 | block_out_channels : Union[Tuple[int, ...], List[int]] 58 | The output channels for each resolution level of the U-Net. 59 | is_attention : Union[Tuple[bool, ...], List[int]] 60 | Whether to use attention mechanism at each resolution level of the U-Net. 61 | layers_per_block : int 62 | Number of residual blocks at each resolution level of the U-Net. 63 | conditioned_on_input : Union[bool, int] 64 | Whether to use conditioning on other inputs, or the number of conditions. 65 | """ 66 | super().__init__() 67 | self.sample_size = sample_size 68 | n_resolutions = len(block_out_channels) 69 | init_channels = block_out_channels[0] 70 | 71 | # Project input + conditions 72 | self.image_proj = nn.Conv2d( 73 | self.config.in_channels, 74 | init_channels, 75 | kernel_size=(3, 3), 76 | padding=(1, 1), 77 | ) 78 | 79 | # Time Embedding 80 | if time_embedding_type == "fourier": 81 | self.time_proj = GaussianFourierProjection( 82 | embedding_size=init_channels, scale=16 83 | ) 84 | timestep_input_dim = 2 * init_channels 85 | elif time_embedding_type == "positional": 86 | self.time_proj = Timesteps(init_channels, flip_sin_to_cos, freq_shift) 87 | timestep_input_dim = init_channels 88 | 89 | self.time_embedding = TimestepEmbedding(timestep_input_dim, init_channels * 4) 90 | 91 | # First half of U-Net - decreasing resolution 92 | down: List[nn.Module] = [] 93 | in_ch_down = init_channels 94 | for i, out_ch_down in enumerate(block_out_channels): 95 | # Resnet Blocks 96 | for _ in range(layers_per_block): 97 | down.append( 98 | DownBlock( 99 | in_ch_down, out_ch_down, init_channels * 4, is_attention[i] 100 | ) 101 | ) 102 | in_ch_down = out_ch_down 103 | # Down sample at all resolutions except the last 104 | if i < n_resolutions - 1: 105 | down.append(Downsample(in_ch_down)) 106 | 107 | self.down = nn.ModuleList(down) 108 | 109 | # Middle block 110 | self.middle = MiddleBlock(out_ch_down, init_channels * 4) 111 | in_ch_up = out_ch_down 112 | 113 | # Second half of U-Net - increasing resolution 114 | up: List[nn.Module] = [] 115 | for i, out_ch_up in reversed(list(enumerate(block_out_channels))): 116 | for _ in range(layers_per_block): 117 | up.append( 118 | UpBlock(in_ch_up, in_ch_up, init_channels * 4, is_attention[i]) 119 | ) 120 | 121 | # Final block to reduce the number of channels 122 | up.append(UpBlock(in_ch_up, out_ch_up, init_channels * 4, is_attention[i])) 123 | in_ch_up = out_ch_up 124 | 125 | # Up sample at all resolutions except last 126 | if i > 0: 127 | up.append(Upsample(in_ch_up)) 128 | 129 | # Combine the set of modules 130 | self.up = nn.ModuleList(up) 131 | 132 | # Final normalization and convolution layer 133 | self.norm = nn.GroupNorm(8, init_channels) 134 | self.act = Swish() 135 | self.final = nn.Conv2d( 136 | in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1) 137 | ) 138 | 139 | def forward( 140 | self, sample: torch.Tensor, timestep: torch.Tensor, return_dict: bool = True 141 | ): 142 | """ 143 | Forward pass. 144 | 145 | Applies the forward pass of the U-Net model on the given input tensor, `sample`, 146 | and timestep, `timestep`. 147 | 148 | Arguments 149 | --------- 150 | sample : torch.Tensor 151 | The input tensor of the shape (batch_size, num_channels, height, width). 152 | timestep : torch.Tensor 153 | The timestep tensor of the shape (batch_size,) representing the timestep 154 | of each sample. 
155 | return_dict : bool 156 | Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a 157 | plain tuple. 158 | 159 | Returns 160 | ------- 161 | noise: torch.Tensor 162 | The output tensor of the shape (batch_size, num_classes, height, width). 163 | """ 164 | timesteps = timestep 165 | if not torch.is_tensor(timesteps): 166 | timesteps = torch.tensor( 167 | [timesteps], dtype=torch.float, device=sample.device 168 | ) 169 | elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: 170 | timesteps = timesteps[None].to(sample.device, dtype=torch.float) 171 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 172 | timesteps = timesteps * torch.ones( 173 | sample.shape[0], dtype=timesteps.dtype, device=timesteps.device 174 | ) 175 | 176 | t_emb = self.time_proj(timesteps).to(dtype=self.dtype) 177 | t = self.time_embedding(t_emb) 178 | 179 | x = self.image_proj(sample) 180 | 181 | h = [x] 182 | # First half of U-Net 183 | for m in self.down: 184 | x = m(x, t) 185 | h.append(x) 186 | 187 | x = self.middle(x, t) 188 | 189 | # Second half of U-Net 190 | for m in self.up: 191 | if isinstance(m, Upsample): 192 | x = m(x, t) 193 | else: 194 | s = h.pop() 195 | x = torch.cat((x, s), dim=1) 196 | x = m(x, t) 197 | 198 | out = self.final(self.act(self.norm(x))) 199 | 200 | if not return_dict: 201 | return (out,) 202 | 203 | return UNet2DOutput(sample=out) 204 | -------------------------------------------------------------------------------- /deepr/model/unet_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from deepr.model.attention import AttentionBlock 5 | from deepr.model.resnet import ResidualBlock 6 | 7 | 8 | class Upsample(nn.Module): 9 | def __init__(self, n_channels): 10 | super().__init__() 11 | self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1)) 12 | 13 | def forward(self, x: torch.Tensor, t: torch.Tensor): 14 | # `t` is not used, but it's kept in the arguments because for the attention 15 | # layer function signature to match with `ResidualBlock`. 16 | _ = t 17 | return self.conv(x) 18 | 19 | 20 | class Downsample(nn.Module): 21 | def __init__(self, n_channels): 22 | super().__init__() 23 | self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1)) 24 | 25 | def forward(self, x: torch.Tensor, t: torch.Tensor): 26 | _ = t 27 | return self.conv(x) 28 | 29 | 30 | class DownBlock(nn.Module): 31 | """Down Block class. 32 | 33 | It represents a block in the first half of U-Net where the input features are being 34 | encoded. 35 | 36 | Attributes 37 | ---------- 38 | res : ResidualBlock 39 | A residual block. 40 | final_layer : Type[nn.Module] 41 | The final layer after the Residual Block. If has_attn is True, it is 42 | `deepr.model.attention.AttentionBlock`. Otherwise it is `nn.Identity`. 43 | """ 44 | 45 | def __init__( 46 | self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool 47 | ): 48 | """Downsampling block class. 49 | 50 | These are used in the first half of U-Net at each resolution. 51 | 52 | Parameters 53 | ---------- 54 | in_channels : int 55 | The number of input channels. 56 | out_channels : int 57 | The number of output channels. 58 | time_channels : int 59 | The number of time channels. 60 | has_attn : bool 61 | A flag indicating whether to use attention block or not. 
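Notes
-----
    The block keeps the spatial resolution unchanged; downsampling in the
    U-Net is performed by the separate `Downsample` module (a stride-2
    convolution) inserted between resolution levels.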
62 | """ 63 | super().__init__() 64 | self.res = ResidualBlock(in_channels, out_channels, time_channels) 65 | self.final_layer = AttentionBlock(out_channels) if has_attn else nn.Identity() 66 | 67 | def forward(self, x: torch.Tensor, t: torch.Tensor): 68 | x = self.res(x, t) 69 | x = self.final_layer(x) 70 | return x 71 | 72 | 73 | class UpBlock(nn.Module): 74 | """Up Block class. 75 | 76 | It represents a block in the second half of U-Net where the input features are being 77 | decoded. 78 | 79 | Attributes 80 | ---------- 81 | res : ResidualBlock 82 | A residual block. 83 | final_layer : Type[nn.Module] 84 | The final layer after the Residual Block. If has_attn is True, it is 85 | `deepr.model.attention.AttentionBlock`. Otherwise it is `nn.Identity`. 86 | """ 87 | 88 | def __init__( 89 | self, in_channels: int, out_channels: int, time_channels: int, has_attn: bool 90 | ): 91 | """Upsampling block class. 92 | 93 | These are used in the second half of U-Net at each resolution. 94 | 95 | Parameters 96 | ---------- 97 | in_channels : int 98 | The number of input channels. 99 | out_channels : int 100 | The number of output channels. 101 | time_channels : int 102 | The number of time channels. 103 | has_attn : bool 104 | A flag indicating whether to use attention block or not. 105 | """ 106 | super().__init__() 107 | # The input has `in_channels + out_channels` because we concatenate the output 108 | # of the same resolution from the first half of the U-Net 109 | self.res = ResidualBlock( 110 | in_channels + out_channels, out_channels, time_channels 111 | ) 112 | self.final_layer = AttentionBlock(out_channels) if has_attn else nn.Identity() 113 | 114 | def forward(self, x: torch.Tensor, t: torch.Tensor): 115 | x = self.res(x, t) 116 | x = self.final_layer(x) 117 | return x 118 | 119 | 120 | class MiddleBlock(nn.Module): 121 | def __init__(self, n_channels: int, time_channels: int): 122 | super().__init__() 123 | self.res1 = ResidualBlock(n_channels, n_channels, time_channels) 124 | self.attn = AttentionBlock(n_channels) 125 | self.res2 = ResidualBlock(n_channels, n_channels, time_channels) 126 | 127 | def forward(self, x: torch.Tensor, t: torch.Tensor): 128 | x = self.res1(x, t) 129 | x = self.attn(x) 130 | x = self.res2(x, t) 131 | return x 132 | -------------------------------------------------------------------------------- /deepr/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.models.embeddings import get_timestep_embedding 3 | 4 | 5 | def get_hour_embedding( 6 | hours: torch.Tensor, embedding_type: str, emb_size: int = 64 7 | ) -> torch.Tensor: 8 | if embedding_type == "positional": 9 | hour_emb = get_timestep_embedding(hours.squeeze(), emb_size, max_period=24) 10 | elif embedding_type == "cyclical": 11 | hour_emb = torch.stack( 12 | [ 13 | torch.cos(2 * torch.pi * hours / 24), 14 | torch.sin(2 * torch.pi * hours / 24), 15 | ], 16 | dim=1, 17 | ) 18 | elif embedding_type in ("class", "timestep"): 19 | hour_emb = hours 20 | else: 21 | hour_emb = None 22 | 23 | return hour_emb 24 | -------------------------------------------------------------------------------- /deepr/utilities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/utilities/__init__.py -------------------------------------------------------------------------------- /deepr/utilities/logger.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def get_logger(name: str) -> logging.Logger: 6 | """ 7 | Return a logger instance configured with an INFO level, logging to stdout. 8 | 9 | The logger is configured with a handler and a formatter already set up. 10 | The handler sends log records to the standard output (stdout). 11 | 12 | Parameters 13 | ---------- 14 | name : str 15 | Name for the logger. 16 | 17 | Returns 18 | ------- 19 | logger : logging.Logger 20 | The configured logger instance. 21 | """ 22 | logger = logging.getLogger(name) 23 | handler = logging.StreamHandler(sys.stdout) 24 | handler.setFormatter( 25 | logging.Formatter("%(asctime)s — %(name)s — %(levelname)s — %(message)s") 26 | ) 27 | logger.addHandler(handler) 28 | logger.setLevel("INFO") 29 | logger.propagate = False 30 | return logger 31 | -------------------------------------------------------------------------------- /deepr/utilities/yml.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict 3 | 4 | import yaml 5 | 6 | 7 | def read_yaml_file(yaml_file_path: Path) -> Dict: 8 | """ 9 | Read a YAML file and return its contents as a dictionary. 10 | 11 | Parameters 12 | ---------- 13 | yaml_file_path : Path 14 | The path of the YAML file to be read. 15 | 16 | Returns 17 | ------- 18 | configuration : dict 19 | The dictionary containing the YAML file's contents. 20 | 21 | Raises 22 | ------ 23 | FileNotFoundError 24 | If the specified YAML file does not exist. 25 | 26 | yaml.YAMLError 27 | If there's an error while parsing the YAML file. 28 | """ 29 | with open(yaml_file_path) as file: 30 | configuration = yaml.safe_load(file) 31 | configuration = replace_none(configuration) 32 | return configuration 33 | 34 | 35 | def replace_none(dictionary: Dict) -> Dict[Any, Any]: 36 | """ 37 | Recursively replace 'None' string values in a dictionary with None type. 38 | 39 | Parameters 40 | ---------- 41 | dictionary : dict 42 | The dictionary in which 'None' values should be replaced. 43 | 44 | Returns 45 | ------- 46 | new_dictionary : dict 47 | The dictionary with 'None' values replaced by None. 
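Example
-------
    A minimal illustration of the behaviour; note that plain strings inside
    lists are kept as-is, only dictionary values (possibly nested) are
    replaced:

    >>> replace_none({"a": "None", "b": {"c": "None"}, "d": [{"e": "None"}, "None"]})
    {'a': None, 'b': {'c': None}, 'd': [{'e': None}, 'None']}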
48 | """ 49 | new_dictionary = {} 50 | for key, value in dictionary.items(): 51 | if isinstance(value, dict): 52 | new_value = replace_none(value) 53 | elif isinstance(value, list): 54 | new_list = [] 55 | for element in value: 56 | if isinstance(element, dict): 57 | new_list.append(replace_none(element)) 58 | else: 59 | new_list.append(element) 60 | new_value = new_list # type: ignore 61 | elif value == "None": 62 | new_value = None 63 | else: 64 | new_value = value 65 | new_dictionary[key] = new_value 66 | return new_dictionary 67 | -------------------------------------------------------------------------------- /deepr/validation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/validation/__init__.py -------------------------------------------------------------------------------- /deepr/validation/generate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pandas 6 | import torch 7 | import tqdm 8 | import xarray as xr 9 | from statsmodels.nonparametric.smoothers_lowess import lowess 10 | 11 | from deepr.model.conditional_ddpm import cDDPMPipeline 12 | 13 | 14 | def generate_validation_dataset( 15 | data_loader, data_scaler_func, model, config 16 | ) -> xr.Dataset: 17 | """ 18 | Generate validation datasets. 19 | 20 | Parameters 21 | ---------- 22 | data_loader : torch.utils.data.DataLoader 23 | DataLoader containing the validation data batches. 24 | data_scaler_func : callable 25 | Function to scale the model predictions and observed data. It should accept the 26 | model predictions and observed data tensors along with their corresponding 27 | times (timestamps) and return the scaled tensors. 28 | If set to None, no scaling will be applied. 29 | model : torch.nn.Module 30 | The machine learning model to be used for predictions. 31 | The model should either accept images and hour embeddings as inputs 32 | (if it's an instance of cDDPMPipeline) or just images as inputs. 33 | The output of the model should be the predicted images. 34 | config : dict 35 | Configuration settings for the validation. 36 | 37 | Returns 38 | ------- 39 | tuple 40 | pred_nn : xarray.DataArray 41 | Predicted data in xarray format for the machine learning model. 42 | pred_base : xarray.DataArray 43 | Predicted baseline data in xarray format (interpolated using 'baseline') 44 | for comparison with the model predictions. 45 | obs : xarray.DataArray 46 | Observed data in xarray format for validation. 
47 | """ 48 | odir = Path(config["output_dir"]) / config["repo_name"].split("/")[-1] 49 | os.makedirs(odir, exist_ok=True) 50 | 51 | preds = [] 52 | current_month = None 53 | progress_bar = tqdm.tqdm(total=len(data_loader), desc="Batch ") 54 | for i, (era5, cerra, times) in enumerate(data_loader): 55 | pred_da, times = predict_xr( 56 | model, 57 | era5, 58 | times, 59 | config, 60 | # orog_low=orog_low_land, 61 | # orog_high=orog_high_land, 62 | data_scaler_func=data_scaler_func, 63 | latitudes=data_loader.dataset.label_latitudes, 64 | longitudes=data_loader.dataset.label_longitudes, 65 | ) 66 | 67 | # Save data based on specs 68 | if config["save_freq"] == "batch": 69 | filename = "prediction_" + times[0].strftime("%HH_%d-%m-%Y") + ".nc" 70 | pred_da.to_netcdf(odir / filename) 71 | elif config["save_freq"] == "month": 72 | if current_month is None: 73 | current_month = np.datetime64(pred_da.time.min().values, "M") 74 | preds.append(pred_da) 75 | next_month = np.datetime64(pred_da.time.max().values, "M") 76 | if current_month != next_month: 77 | pred_ds = xr.concat(preds, dim="time") 78 | month_pred_ds = pred_ds.where( 79 | pred_ds.time.dt.month == current_month.astype(object).month 80 | ).dropna("time") 81 | month_pred_ds.to_netcdf(odir / f"prediction_{current_month}.nc") 82 | current_month = next_month 83 | next_month_da = pred_ds.where( 84 | pred_ds.time.dt.month == next_month.astype(object).month 85 | ).dropna("time") 86 | preds = [next_month_da] 87 | elif config["save_freq"] == "all": 88 | preds.append(pred_da) 89 | progress_bar.update(1) 90 | progress_bar.close() 91 | if config["save_freq"] == "month": 92 | pred_ds = xr.concat(preds, dim="time") 93 | pred_ds.to_netcdf(odir / f"prediction_{current_month}.nc") 94 | elif config["save_freq"] == "all": 95 | pred_ds = xr.concat(preds, dim="time") 96 | d0, d1 = data_loader.dataset.init_date, data_loader.dataset.end_date 97 | pred_ds.to_netcdf(odir / f"prediction_{d0}-{d1}.nc") 98 | 99 | 100 | def predict_xr( 101 | model, 102 | era5, 103 | times, 104 | config, 105 | orog_low: np.array = None, 106 | orog_high: np.array = None, 107 | data_scaler_func=None, 108 | latitudes: np.array = None, 109 | longitudes: np.array = None, 110 | ): 111 | if isinstance(model, str): 112 | prediction = torch.nn.functional.interpolate( 113 | era5[..., 6:-6, 6:-6], 114 | mode=model, 115 | scale_factor=5, 116 | ) 117 | 118 | xs = np.ravel(orog_low)[~np.isnan(np.ravel(orog_low))] 119 | true_orog = np.ravel(orog_high) 120 | expected_orog = np.ravel( 121 | torch.nn.functional.interpolate( 122 | torch.from_numpy(orog_low[np.newaxis, np.newaxis, ...]), 123 | mode=model, 124 | scale_factor=5, 125 | ).numpy() 126 | ) 127 | 128 | deltas = [] 129 | for i in range(prediction.shape[0]): 130 | vals = np.ravel(era5[i, 0, 6:-6, 6:-6])[~np.isnan(np.ravel(orog_low))] 131 | delta = lowess(vals, xs, xvals=expected_orog) - lowess( 132 | vals, xs, xvals=true_orog 133 | ) 134 | deltas.append(delta.reshape(1, 1, *prediction.shape[-2:])) 135 | d = np.nan_to_num(np.concatenate(deltas, axis=0), 0) 136 | prediction -= d 137 | 138 | attrs = { 139 | "method": f"{model} + orography correction", 140 | "orog_correction": "Difference between LOWESS estimates (fitted at each " 141 | "sample) at low-res & high-res orography.", 142 | } 143 | elif isinstance(model, cDDPMPipeline): 144 | prediction = model( 145 | images=era5, 146 | class_labels=times[:, :1].to(model.device), 147 | eta=config["eta"], 148 | num_inference_steps=config["inference_steps"], 149 | 
generator=torch.manual_seed(config.get("seed", 2023)), 150 | output_type="tensor", 151 | ).images 152 | attrs = { 153 | "ddpm": model.__class__.__name__, 154 | "eta": config["eta"], 155 | "num_inference_steps": config["inference_steps"], 156 | "seed": config.get("seed", 2023), 157 | } 158 | else: 159 | with torch.no_grad(): 160 | prediction = model(era5, return_dict=False)[0] 161 | attrs = {} 162 | attrs["repo"] = config["repo_name"] 163 | attrs["input_inference_scaling"] = config["inference_scaling"].get("input", "") 164 | attrs["output_inference_scaling"] = config["inference_scaling"].get("output", "") 165 | 166 | if data_scaler_func is not None: 167 | prediction = data_scaler_func(prediction, times[:, 2]) 168 | 169 | times = transform_times_to_datetime(times) 170 | pred_da = transform_data_to_xr_format( 171 | prediction, "prediction", latitudes, longitudes, times 172 | ).chunk(chunks={"latitude": 20, "longitude": 40}) 173 | pred_da.prediction.attrs = attrs 174 | 175 | return pred_da, times 176 | 177 | 178 | def transform_data_to_xr_format(data, varname, latitudes, longitudes, times): 179 | """ 180 | Create a xarray dataset from the given variables. 181 | 182 | Parameters 183 | ---------- 184 | data : numpy array 185 | The prediction data with shape 186 | (num_times, num_channels, num_latitudes, num_longitudes). 187 | varname: str 188 | Name for the variable 189 | latitudes : list or numpy array 190 | List of latitude values. 191 | longitudes : list or numpy array 192 | List of longitude values. 193 | times : list or pandas DatetimeIndex 194 | List of timestamps. 195 | 196 | Returns 197 | ------- 198 | xarray.Dataset 199 | The xarray dataset containing the prediction data. 200 | 201 | Example 202 | ------- 203 | # Assuming the given variables have appropriate values 204 | pred_nn = np.random.rand(16, 1, 160, 240) 205 | latitudes = [44.95, 44.9, ...] # List of 160 latitude values 206 | longitudes = [-6.85, -6.8, ...] # List of 240 longitude values 207 | times = pd.date_range('2018-01-01', periods=16, freq='3H') 208 | dataset = create_xarray_dataset(pred_nn, latitudes, longitudes, times) 209 | print(dataset) 210 | """ 211 | # Ensure pred_nn is a numpy array 212 | data = np.asarray(data) 213 | 214 | # Create a dictionary to hold data variables 215 | data_vars = {varname: (["time", "channel", "latitude", "longitude"], data)} 216 | 217 | # Create coordinate variables 218 | coords = {"latitude": latitudes, "longitude": longitudes, "time": times} 219 | 220 | # Create the xarray dataset 221 | dataset = xr.Dataset(data_vars, coords=coords) 222 | 223 | # Remove channel dimension 224 | dataset = dataset.mean("channel") 225 | 226 | return dataset 227 | 228 | 229 | def transform_times_to_datetime(times: torch.tensor): 230 | """ 231 | Transform a tensor of times into a list of datetimes. 232 | 233 | Parameters 234 | ---------- 235 | times (tensor): A tensor containing times in the format [hour, day, month, year]. 236 | 237 | Returns 238 | ------- 239 | list: A list of pandas datetime objects representing the input times. 
240 | 241 | Example: 242 | times = tensor([[0, 1, 1, 2018], 243 | [3, 1, 1, 2018], 244 | [6, 1, 1, 2018]]) 245 | result = transform_times_to_datetime(times) 246 | print(result) 247 | # Output: [Timestamp('2018-01-01 00:00:00'), 248 | # Timestamp('2018-01-01 03:00:00'), 249 | # Timestamp('2018-01-01 06:00:00')] 250 | """ 251 | # Convert each time entry to a pandas datetime object 252 | datetime_list = [ 253 | pandas.to_datetime(f"{time[3]}-{time[2]}-{time[1]} {time[0]}:00") 254 | for time in times 255 | ] 256 | 257 | return datetime_list 258 | -------------------------------------------------------------------------------- /deepr/validation/netcdf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/validation/netcdf/__init__.py -------------------------------------------------------------------------------- /deepr/validation/netcdf/metrics.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import xarray as xr 4 | import xskillscore as xs 5 | 6 | 7 | class Metrics: 8 | def __init__( 9 | self, 10 | model_name: str, 11 | observations: xr.Dataset, 12 | predictions: xr.Dataset, 13 | output_directory: pathlib.Path, 14 | ): 15 | """ 16 | Initialize the Metrics object. 17 | 18 | Parameters 19 | ---------- 20 | model_name: str 21 | The model name used to store the metrics 22 | observations : xr.Dataset 23 | The dataset containing the observations. 24 | predictions : xr.Dataset 25 | The dataset containing the predictions. 26 | output_directory : str or pathlib.Path 27 | The output directory for storing the metrics. 28 | 29 | """ 30 | self.model_name = model_name 31 | self.predictions, self.observations = predictions, observations 32 | self.output_directory = output_directory 33 | 34 | def get_metrics(self): 35 | """ 36 | Compute metrics for the predictions and observations. 37 | 38 | Returns 39 | ------- 40 | metrics_ds : xr.Dataset 41 | The dataset containing the computed metrics. 42 | 43 | Raises 44 | ------ 45 | NotImplementedError 46 | If the problem_type is neither "regression" nor "classification". 47 | """ 48 | metrics_datasets = [] 49 | 50 | regression_metrics = { 51 | "r2": xs.r2, 52 | "mae": xs.mae, 53 | "me": xs.me, 54 | "mse": xs.mse, 55 | "rmse": xs.rmse, 56 | } 57 | 58 | for metric_name, metric_function in regression_metrics.items(): 59 | ds_metric = metric_function(self.observations, self.predictions, dim="time") 60 | ds_metric = ds_metric.rename_vars({"variable": metric_name}) 61 | metrics_datasets.append(ds_metric) 62 | 63 | metrics_ds = xr.merge(metrics_datasets) 64 | metrics_ds["obs_mean"] = self.observations.mean(dim="time").rename_vars( 65 | {"variable": "obs_mean"} 66 | )["obs_mean"] 67 | metrics_ds["pred_mean"] = self.predictions.mean(dim="time").rename_vars( 68 | {"variable": "pred_mean"} 69 | )["pred_mean"] 70 | metrics_ds["obs_std"] = self.observations.std(dim="time").rename_vars( 71 | {"variable": "obs_std"} 72 | )["obs_std"] 73 | metrics_ds["pred_std"] = self.predictions.std(dim="time").rename_vars( 74 | {"variable": "pred_std"} 75 | )["pred_std"] 76 | 77 | metrics_ds.to_netcdf(self.get_output_path()) 78 | 79 | return metrics_ds 80 | 81 | def get_output_path(self): 82 | """ 83 | Retrieve the output path for the metrics dataset. 84 | 85 | Returns 86 | ------- 87 | Path 88 | Output path for the metrics dataset in netcdf format. 
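Example
-------
    With model_name="cddpm" and output_directory=Path("validation/metrics")
    (hypothetical values), the metrics dataset is written to
    "validation/metrics/cddpm.nc".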
89 | """ 90 | output_path = self.output_directory / f"{self.model_name}.nc" 91 | output_path.parent.mkdir(parents=True, exist_ok=True, mode=0o777) 92 | return output_path 93 | -------------------------------------------------------------------------------- /deepr/validation/netcdf/validation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import xarray 4 | 5 | from deepr.utilities.logger import get_logger 6 | from deepr.validation.netcdf.metrics import Metrics 7 | from deepr.validation.netcdf.visualize import Visualization 8 | 9 | logger = get_logger("Validation Configuration") 10 | 11 | 12 | class ValidationConfig: 13 | def __init__(self, validation_configuration: dict): 14 | """ 15 | Initialize the ValidationConfig object. 16 | 17 | Parameters 18 | ---------- 19 | validation_configuration : dict 20 | A dictionary containing the validation configuration parameters. 21 | 22 | """ 23 | logger.info( 24 | f"Initializing model configuration object with parameters: " 25 | f"{validation_configuration}" 26 | ) 27 | self.validation_configuration = validation_configuration 28 | 29 | def run(self): 30 | model_predictions, baseline_predictions = self.open_predictions() 31 | model_predictions.load() 32 | baseline_predictions.load() 33 | 34 | observations = self.open_observations() 35 | observations.load() 36 | 37 | ( 38 | model_predictions, 39 | baseline_predictions, 40 | observations, 41 | ) = self.select_intersecting_times( 42 | model_predictions, baseline_predictions, observations 43 | ) 44 | 45 | self.validate_predictions(observations, model_predictions, baseline_predictions) 46 | 47 | def open_predictions(self): 48 | if self.validation_configuration["model_predictions_location"] is not None: 49 | model_predictions = xarray.open_mfdataset( 50 | f"{self.validation_configuration['model_predictions_location']}/*.nc" 51 | ) 52 | model_predictions = model_predictions - 273.15 53 | model_predictions = model_predictions.dropna("time") 54 | else: 55 | model_predictions = None 56 | 57 | if self.validation_configuration["baseline_predictions_location"] is not None: 58 | baseline_predictions = xarray.open_mfdataset( 59 | f"{self.validation_configuration['baseline_predictions_location']}/*.nc" 60 | ) 61 | baseline_predictions = baseline_predictions - 273.15 62 | baseline_predictions = baseline_predictions.dropna("time") 63 | else: 64 | baseline_predictions = None 65 | 66 | return model_predictions.sortby("time"), baseline_predictions.sortby("time") 67 | 68 | def open_observations(self): 69 | if self.validation_configuration["observations_location"] is not None: 70 | observations = xarray.open_mfdataset( 71 | f"{self.validation_configuration['observations_location']}/*.nc" 72 | ) 73 | observations = observations.rename_vars({"t2m": "observation"}) 74 | observations = observations - 273.15 75 | else: 76 | observations = None 77 | return observations.sortby("time") 78 | 79 | def validate_predictions( 80 | self, 81 | observations: xarray.Dataset, 82 | predictions: xarray.Dataset, 83 | baselines: xarray.Dataset, 84 | ): 85 | """ 86 | Validate predictions against observations and benchmarks. 87 | 88 | Parameters 89 | ---------- 90 | observations : xarray.Dataset 91 | A dataset containing the observations. 92 | predictions : xarray.Dataset 93 | A dataset containing the predictions. 94 | baselines : xarray.Dataset 95 | A dataset containing the benchmarks. 
96 | """ 97 | # Metrics datasets for the different sample types 98 | model_metrics_dataset = Metrics( 99 | model_name=self.validation_configuration["model_name"], 100 | observations=observations.rename_vars({"observation": "variable"}), 101 | predictions=predictions.rename_vars({"prediction": "variable"}), 102 | output_directory=Path( 103 | f'{self.validation_configuration["validation_dir"]}/metrics' 104 | ), 105 | ).get_metrics() 106 | baseline_metrics_dataset = Metrics( 107 | model_name=self.validation_configuration["baseline_name"], 108 | observations=observations.rename_vars({"observation": "variable"}), 109 | predictions=baselines.rename_vars({"prediction": "variable"}), 110 | output_directory=Path( 111 | f'{self.validation_configuration["validation_dir"]}/metrics' 112 | ), 113 | ).get_metrics() 114 | 115 | # Visualizations for the different data types 116 | Visualization( 117 | model_name=self.validation_configuration["model_name"], 118 | baseline_name=self.validation_configuration["baseline_name"], 119 | observations=observations.rename_vars({"observation": "variable"}), 120 | predictions=predictions.rename_vars({"prediction": "variable"}), 121 | baselines=baselines.rename_vars({"prediction": "variable"}), 122 | model_metrics=model_metrics_dataset, 123 | baseline_metrics=baseline_metrics_dataset, 124 | visualization_types=self.validation_configuration["visualization_types"], 125 | output_directory=Path( 126 | f'{self.validation_configuration["validation_dir"]}/figures' 127 | ), 128 | ).get_visualizations() 129 | 130 | return model_metrics_dataset, baseline_metrics_dataset 131 | 132 | @staticmethod 133 | def select_intersecting_times( 134 | model_predictions: xarray.Dataset, 135 | baseline_predictions: xarray.Dataset, 136 | observations: xarray.Dataset, 137 | ): 138 | """ 139 | Select time values that are present in all three datasets. 140 | 141 | Parameters 142 | ---------- 143 | model_predictions : xarray.Dataset 144 | A dataset containing model predictions. 145 | baseline_predictions : xarray.Dataset 146 | A dataset containing baseline predictions. 147 | observations : xarray.Dataset 148 | A dataset containing observations. 149 | 150 | Returns 151 | ------- 152 | model_predictions : xarray.Dataset 153 | Model predictions dataset with time values that are present in 154 | all three datasets. 155 | baseline_predictions : xarray.Dataset 156 | Baseline predictions dataset with time values that are present in 157 | all three datasets. 158 | observations : xarray.Dataset 159 | Observations dataset with time values that are present in 160 | all three datasets. 
161 | """ 162 | # Get the unique time values from each dataset 163 | model_times = set(model_predictions.time.values) 164 | baseline_times = set(baseline_predictions.time.values) 165 | obs_times = set(observations.time.values) 166 | 167 | # Find the intersection of time values 168 | intersecting_times = list(model_times.intersection(baseline_times, obs_times)) 169 | intersecting_times.sort() # Optional: sort the time values 170 | 171 | # Select the datasets with intersecting time values 172 | model_predictions = model_predictions.sel(time=intersecting_times) 173 | baseline_predictions = baseline_predictions.sel(time=intersecting_times) 174 | observations = observations.sel(time=intersecting_times) 175 | 176 | return model_predictions, baseline_predictions, observations 177 | 178 | 179 | if __name__ == "__main__": 180 | from deepr.utilities.yml import read_yaml_file 181 | 182 | configuration = read_yaml_file( 183 | "../../../resources/configuration_validation_netcdf.yml" 184 | )["validation"] 185 | 186 | ValidationConfig(configuration).run() 187 | -------------------------------------------------------------------------------- /deepr/validation/nn_performance_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Type 2 | 3 | import evaluate 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from deepr.model.conditional_ddpm import cDDPMPipeline 8 | from deepr.model.utils import get_hour_embedding 9 | 10 | metric_to_repo = { 11 | "MSE": "mse", 12 | "R2": "r_squared", 13 | "SMAPE": "smape", 14 | "PSNR": "jpxkqx/peak_signal_to_noise_ratio", 15 | "SSIM": "jpxkqx/structural_similarity_index_measure", 16 | "SRE": "jpxkqx/signal_to_reconstruction_error", 17 | } 18 | 19 | 20 | def compute_cerra_mean(dataloader, scaler_func: Callable = None): 21 | total_samples = 0 22 | overall_sum_cerra = torch.zeros(dataloader.dataset.output_shape) 23 | for era5, cerra, times in dataloader: 24 | total_samples += times.shape[0] 25 | if scaler_func is not None: 26 | cerra = scaler_func(cerra, times[:, 2]) 27 | overall_sum_cerra += cerra.sum((0, 1)) 28 | 29 | return overall_sum_cerra / total_samples 30 | 31 | 32 | def compute_and_upload_metrics( 33 | model: Type[torch.nn.Module], 34 | dataloader: torch.utils.data.DataLoader, 35 | hf_repo_name: str = None, 36 | scaler_func: Callable = None, 37 | ): 38 | """Compute and upload a set of metrics. 39 | 40 | The metrics computed in this function are: 41 | - MSE: Mean Squared Error of the predictions. 42 | - R2: Pearson Correlation (R²) coefficient of the predictions. 43 | - SMAPE: Symmetric Mean Absolute Percentage Error of the predictions. 44 | - PSNR: Peak Signal to Noise Ratio of the predictions. 45 | - SSIM: Structural Similarity Index Measure of the predictions. 46 | - SRE: Signal to Reconstruction Error of the predictions. 
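As a reminder of what two of the metrics listed in the docstring above compute — independently of the `evaluate` wrappers used in the implementation that follows — here is a minimal NumPy sketch (illustrative, not the project's code path):

```python
import numpy as np


def mse(pred: np.ndarray, ref: np.ndarray) -> float:
    """Mean squared error over all grid points."""
    return float(np.mean((pred - ref) ** 2))


def psnr(pred: np.ndarray, ref: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio (dB) for a given dynamic range of the data."""
    return float(10.0 * np.log10(data_range**2 / mse(pred, ref)))
```

The `data_range` argument plays the same role as the value computed below from the extreme values of the predictions and references.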
47 | """ 48 | # Load metrics over all dataset 49 | mse = evaluate.load("mse", "multilist") 50 | # r2 = evaluate.load("pearsonr", "multilist") 51 | smape = evaluate.load("smape", "multilist") 52 | psnr = evaluate.load("jpxkqx/peak_signal_to_noise_ratio", "multilist") 53 | ssim = evaluate.load("jpxkqx/structural_similarity_index_measure", "multilist") 54 | sre = evaluate.load("jpxkqx/signal_to_reconstruction_error", "multilist") 55 | 56 | progress_bar = tqdm(total=len(dataloader), desc="Batch ") 57 | max_pred, min_pred, max_true, min_true = -999, 999, -999, 999 58 | for era5, cerra, times in dataloader: 59 | # Predict the noise residual 60 | with torch.no_grad(): 61 | pred = model(era5, return_dict=False)[0] 62 | if scaler_func is not None: 63 | pred = scaler_func(pred, times[:, 2]) 64 | cerra = scaler_func(cerra, times[:, 2]) 65 | 66 | mse.add_batch( 67 | references=cerra.reshape((cerra.shape[0], -1)), 68 | predictions=pred.reshape((pred.shape[0], -1)), 69 | ) 70 | # r2.add_batch( 71 | # references=cerra.reshape((cerra.shape[0], -1)), 72 | # predictions=pred.reshape((pred.shape[0], -1)), 73 | # ) 74 | smape.add_batch( 75 | references=cerra.reshape((cerra.shape[0], -1)), 76 | predictions=pred.reshape((pred.shape[0], -1)), 77 | ) 78 | psnr.add_batch(references=cerra, predictions=pred) 79 | ssim.add_batch(references=cerra, predictions=pred) 80 | sre.add_batch(references=cerra, predictions=pred) 81 | max_pred, min_pred = max(max_pred, pred.max()), min(min_pred, pred.min()) 82 | max_true, min_true = max(max_true, cerra.max()), min(min_true, cerra.min()) 83 | progress_bar.update(1) 84 | progress_bar.close() 85 | 86 | # Compute Metrics 87 | data_range = float(max(max_pred, max_true) - min(min_pred, min_true)) 88 | test_metrics = { 89 | "MSE": mse.compute()["mse"], 90 | # "R2": r2.compute()["pearsonr"], 91 | "SMAPE": smape.compute()["smape"], 92 | "PSNR": psnr.compute(data_range=data_range), 93 | "SSIM": ssim.compute(data_range=data_range, channel_axis=0), # ignore batch dim 94 | "SRE": sre.compute()["Signal-to-Reconstruction Error"], 95 | } 96 | for name, metric_val in test_metrics.items(): 97 | print(f"Test {name}: {metric_val:.2f}") 98 | 99 | if hf_repo_name is not None: 100 | for name, metric_val in test_metrics.items(): 101 | evaluate.push_to_hub( 102 | model_id=hf_repo_name, 103 | metric_type=metric_to_repo[name], 104 | metric_name=name, 105 | metric_value=metric_val, 106 | dataset_type="era5", 107 | dataset_name="ERA5+CERRA", 108 | dataset_split="test", 109 | task_type="image-to-image", 110 | task_name="Super Resolution", 111 | ) 112 | 113 | return test_metrics 114 | 115 | 116 | def compute_model_and_baseline_errors( 117 | model: Type[torch.nn.Module], 118 | dataloader: torch.utils.data.DataLoader, 119 | baseline: str = "bicubic", 120 | scaler_func: Callable = None, 121 | inference_steps: int = 1000, 122 | num_batches: int = None, 123 | ): 124 | """ 125 | Compute the model and baseline errors. 126 | 127 | It makes it by comparing the predictions with the ground truth labels. 128 | 129 | Parameters 130 | ---------- 131 | model : Type[torch.nn.Module] 132 | The neural network model. 133 | dataloader : torch.utils.data.DataLoader 134 | The data loader used to fetch the data. 135 | baseline : str, optional 136 | The mode used for baseline interpolation, by default "bicubic". 137 | scaler_func : Callable, optional 138 | A scaling function to apply on the data, by default None. 139 | inference_steps : int, optional 140 | The number of inference steps in case a Diffusion Process is specified. 
141 | num_batches : int, optional 142 | The number of batches to sample. By default None, meaning that it iterates 143 | over the whole dataset. 144 | 145 | Returns 146 | ------- 147 | mae : torch.Tensor 148 | Mean Absolute Error (MAE) between the model predictions and the 149 | ground truth labels. 150 | mse : torch.Tensor 151 | Mean Squared Error (MSE) between the model predictions and the 152 | ground truth labels. 153 | mae_bi : torch.Tensor 154 | MAE between the baseline predictions and the ground truth labels. 155 | mse_bi : torch.Tensor 156 | MSE between the baseline predictions and the ground truth labels. 157 | improvement : torch.Tensor 158 | Percentage of improvement of the error from the model to the baseline. 159 | 160 | """ 161 | count_hour = {} 162 | keys = [0, 3, 6, 9, 12, 15, 18, 21, "all"] 163 | abs_errors, sq_errors, r2, improvement = {}, {}, {}, {} 164 | abs_errors_base, sq_errors_base, r2_base = {}, {}, {} 165 | for key in keys: 166 | abs_errors[key] = torch.zeros(dataloader.dataset.output_shape) 167 | sq_errors[key] = torch.zeros(dataloader.dataset.output_shape) 168 | r2[key] = torch.zeros(dataloader.dataset.output_shape) 169 | abs_errors_base[key] = torch.zeros(dataloader.dataset.output_shape) 170 | sq_errors_base[key] = torch.zeros(dataloader.dataset.output_shape) 171 | r2_base[key] = torch.zeros(dataloader.dataset.output_shape) 172 | 173 | improvement[key] = torch.zeros(dataloader.dataset.output_shape) 174 | count_hour[key] = 0 175 | 176 | overall_mean_cerra = compute_cerra_mean(dataloader, scaler_func) 177 | 178 | progress_bar = tqdm(total=len(dataloader), desc="Batch ") 179 | for i, (era5, cerra, times) in enumerate(dataloader): 180 | if isinstance(model, cDDPMPipeline): 181 | hour_emb = get_hour_embedding(times[:, :1], "class", 24).to(model.device) 182 | pred_nn = model( 183 | images=era5, 184 | class_labels=hour_emb, 185 | num_inference_steps=inference_steps, 186 | generator=torch.manual_seed(2023), 187 | output_type="tensor", 188 | ).images 189 | else: 190 | with torch.no_grad(): 191 | pred_nn = model(era5, return_dict=False)[0] 192 | 193 | pred_base = torch.nn.functional.interpolate( 194 | era5[..., 6:-6, 6:-6], scale_factor=5, mode=baseline 195 | ) 196 | 197 | if scaler_func is not None: 198 | pred_nn = scaler_func(pred_nn, times[:, 2]) 199 | pred_base = scaler_func(pred_base, times[:, 2]) 200 | cerra = scaler_func(cerra, times[:, 2]) 201 | 202 | batch_size = times.shape[0] 203 | for sample in range(batch_size): 204 | hour = int(times[sample][0]) 205 | count_hour[hour] += 1 206 | error = pred_nn[sample] - cerra[sample] 207 | error_base = pred_base[sample] - cerra[sample] 208 | abs_errors[hour] += torch.sum(torch.abs(error), 0) 209 | sq_errors[hour] += torch.sum(error**2, 0) 210 | r2[hour] += torch.ones(dataloader.dataset.output_shape) - ( 211 | torch.sum(error**2, 0) 212 | / torch.sum((cerra[sample] - overall_mean_cerra) ** 2, 0) 213 | ) 214 | abs_errors_base[hour] += torch.sum(torch.abs(error_base), 0) 215 | sq_errors_base[hour] += torch.sum(error_base**2, 0) 216 | r2_base[hour] += torch.ones(dataloader.dataset.output_shape) - ( 217 | torch.sum(error_base**2, 0) 218 | / torch.sum((cerra[sample] - overall_mean_cerra) ** 2, 0) 219 | ) 220 | improvement[hour] += torch.sum(100 * (error - error_base) / error_base, 0) 221 | 222 | count_hour["all"] += times.shape[0] 223 | error = pred_nn - cerra 224 | error_base = pred_base - cerra 225 | abs_errors["all"] += torch.sum(torch.abs(error), (0, 1)) 226 | sq_errors["all"] += torch.sum(error**2, (0, 1)) 227 | r2["all"] 
+= torch.ones(dataloader.dataset.output_shape) - ( 228 | torch.sum(error**2, (0, 1)) 229 | / torch.sum((cerra - overall_mean_cerra) ** 2, (0, 1)) 230 | ) 231 | abs_errors_base["all"] += torch.sum(torch.abs(error_base), (0, 1)) 232 | sq_errors_base["all"] += torch.sum(error_base**2, (0, 1)) 233 | r2_base["all"] += 1 - torch.sum(error_base**2) / torch.sum( 234 | (cerra - cerra.mean()) ** 2 235 | ) 236 | improvement["all"] += torch.sum(100 * (error - error_base) / error_base, (0, 1)) 237 | progress_bar.update(1) 238 | if num_batches is not None and i >= num_batches - 1: 239 | break 240 | 241 | progress_bar.close() 242 | 243 | mae, mse, mae_base, mse_base = {}, {}, {}, {} 244 | for hour, value in count_hour.items(): 245 | mae[hour] = abs_errors[hour] / count_hour[hour] 246 | mse[hour] = sq_errors[hour] / count_hour[hour] 247 | r2[hour] /= count_hour[hour] 248 | mae_base[hour] = abs_errors_base[hour] / count_hour[hour] 249 | mse_base[hour] = sq_errors_base[hour] / count_hour[hour] 250 | r2_base[hour] /= count_hour[hour] 251 | 252 | return mae, mse, r2, mae_base, mse_base, r2_base, improvement 253 | -------------------------------------------------------------------------------- /deepr/validation/sample_predictions.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | from typing import Callable 4 | 5 | import torch 6 | 7 | from deepr.model.conditional_ddpm import cDDPMPipeline 8 | from deepr.model.utils import get_hour_embedding 9 | from deepr.utilities.logger import get_logger 10 | from deepr.visualizations.giffs import generate_giff 11 | from deepr.visualizations.plot_maps import plot_2_model_comparison 12 | from deepr.visualizations.plot_samples import get_figure_model_samples 13 | 14 | K_to_C = 273.15 15 | logger = get_logger(__name__) 16 | 17 | 18 | def sample_observation_vs_prediction( 19 | model, 20 | dataloader: torch.utils.data.DataLoader, 21 | local_dir: str, 22 | scaler_func: Callable = None, 23 | baseline: str = "bicubic", 24 | num_samples: int = 10, 25 | ): 26 | """ 27 | Generate and save a comparison plot of model predictions and baseline samples. 28 | 29 | Parameters 30 | ---------- 31 | model : object 32 | The neural network model used for predictions. 33 | dataloader : torch.utils.data.DataLoader 34 | The data loader used to fetch the data. 35 | local_dir : str 36 | The directory where the plot will be saved. 37 | scaler_func : Callable, optional 38 | A scaling function to apply on the data, by default None. 39 | baseline : str, optional 40 | The mode used for baseline interpolation, by default "bicubic". 41 | num_samples : int, optional 42 | The number of samples to randomly select and compare, by default 10. 
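The dictionaries returned by `compute_model_and_baseline_errors` above are keyed by the synoptic hours `0, 3, ..., 21` plus the aggregate key `"all"`. A hypothetical call (the `model`, `dataloader` and `scaler` objects are assumed to be configured elsewhere):

```python
# Illustrative only: arguments follow the signature defined above.
mae, mse, r2, mae_base, mse_base, r2_base, improvement = compute_model_and_baseline_errors(
    model, dataloader, baseline="bicubic", scaler_func=scaler.apply_inverse_scaler
)

print(mae["all"].shape)        # spatial MAE map aggregated over every validation time
print(float(mse[12].mean()))   # average MSE of the samples at hour 12
```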
43 | """ 44 | samples_get = 0 45 | for era5, cerra, times in dataloader: 46 | with torch.no_grad(): 47 | pred_nn = model(era5, return_dict=False)[0] 48 | samples_base = torch.nn.functional.interpolate( 49 | era5[..., 6:-6, 6:-6], scale_factor=5, mode=baseline 50 | ) 51 | 52 | if scaler_func is not None: 53 | cerra = scaler_func(cerra, times[:, 2]) 54 | samples_base = scaler_func(samples_base, times[:, 2]) 55 | pred_nn = scaler_func(pred_nn, times[:, 2]) 56 | 57 | for i in range(len(times)): 58 | if random.choice([True, False]): 59 | continue 60 | filename = Path(local_dir) / f"pred_comparison_{samples_get}.png" 61 | t_str = f"{times[i, 0]:d}H {times[i, 1]:d}-{times[i, 2]:d}-{times[i, 3]:d}" 62 | plot_2_model_comparison( 63 | cerra[i, 0], 64 | samples_base[i, 0], 65 | pred_nn[i, 0], 66 | matrix_names=["CERRA", baseline.capitalize(), model.__class__.__name__], 67 | metric_name="ºC", 68 | date=t_str, 69 | filename=filename, 70 | ) 71 | samples_get += 1 72 | if samples_get == num_samples: 73 | return None 74 | 75 | 76 | def sample_diffusion_samples_random( 77 | pipeline: cDDPMPipeline, 78 | dataloader: torch.utils.data.DataLoader, 79 | scaler_func: Callable = None, 80 | baseline: str = "bicubic", 81 | num_samples: int = 10, 82 | num_realizations: int = 3, 83 | inference_steps: int = 1000, 84 | output_dir: str = None, 85 | device: str = "", 86 | ): 87 | n_samples = 0 88 | for i, (era5, cerra, times) in enumerate(dataloader): 89 | # Prepare data 90 | # 1 A) Encode hour 91 | hour_emb = get_hour_embedding(times[:, :1], "class", 24) 92 | 93 | # 1 B) Repeat each sample by number of realizations 94 | era5_repeated = era5.repeat(num_realizations, 1, 1, 1) 95 | 96 | if hour_emb is not None: 97 | hour_emb = hour_emb.to(device) 98 | hour_emb = hour_emb.repeat(num_realizations, 1).squeeze() 99 | 100 | # 1 C) Compute baseline predictions 101 | pred_base = torch.nn.functional.interpolate( 102 | era5[..., 6:-6, 6:-6], scale_factor=5, mode=baseline 103 | ) 104 | 105 | # 2) Run the predictions 106 | pred_nn = pipeline( 107 | images=era5_repeated, 108 | class_labels=hour_emb, 109 | num_inference_steps=inference_steps, 110 | generator=torch.manual_seed(2023), 111 | output_type="tensor", 112 | ).images 113 | 114 | if scaler_func is not None: 115 | cerra = scaler_func(cerra, times[:, 2]) - K_to_C 116 | era5 = scaler_func(era5, times[:, 2]) - K_to_C 117 | pred_nn = ( 118 | scaler_func(pred_nn, times[:, 2].repeat(num_realizations)) - K_to_C 119 | ) 120 | pred_base = scaler_func(pred_base, times[:, 2]) - K_to_C 121 | 122 | # Make a grid out of the images 123 | sample_names = [f"{t[0]:d}H {t[1]:02d}-{t[2]:02d}-{t[3]:04d}" for t in times] 124 | get_figure_model_samples( 125 | cerra.cpu(), 126 | pred_nn.cpu(), 127 | input_image=era5.cpu(), 128 | baseline=pred_base.cpu(), 129 | column_names=sample_names, 130 | filename=output_dir + f"/samples_{i}.png", 131 | ) 132 | n_samples += len(sample_names) 133 | if n_samples >= num_samples: 134 | return None 135 | 136 | 137 | def sample_gif( 138 | pipeline, 139 | dataloader, 140 | scaler_func: Callable = None, 141 | output_dir: str = None, 142 | freq_timesteps_frame: int = 1, 143 | inference_steps: int = 1000, 144 | fps: int = 50, 145 | eta: float = 1, 146 | ): 147 | """ 148 | Generate GIFs of the diffusion process for a given pipeline. 149 | 150 | Args: 151 | ---- 152 | pipeline (callable): The pipeline function to apply to the images. 153 | dataloader (iterable): An iterable containing low-resolution reanalysis. 
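Both sampling functions above (and the error computations earlier) build their baseline by cropping a 6-cell border from the low-resolution ERA5 field and interpolating it by a factor of 5. A standalone sketch of that baseline, with the border width and scale kept as explicit parameters:

```python
import torch


def interpolation_baseline(
    era5_batch: torch.Tensor, scale: int = 5, border: int = 6, mode: str = "bicubic"
) -> torch.Tensor:
    """Crop the padding border of the low-res batch (N, C, H, W) and upscale it."""
    cropped = era5_batch[..., border:-border, border:-border]
    return torch.nn.functional.interpolate(cropped, scale_factor=scale, mode=mode)
```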
154 | scaler_func (callable, optional): A function to un-scale the images. Defaults to None. 155 | output_dir (str, optional): The directory to save the generated GIFs. Defaults to None. 156 | freq_timesteps_frame (int, optional): The frequency of diffusion timesteps to 157 | save as frames in the GIFs. Defaults to 1, which saves latents at all 158 | timesteps as frames. 159 | inference_steps (int, optional): The number of inference timesteps to perform 160 | the diffusion process. Defaults to 1000. 161 | fps (int, optional): The frames per second to show. Maximum value supported for 162 | most of modern browsers is 50fps. 163 | """ 164 | era5, _, times = next(iter(dataloader)) 165 | hr_im, interm = pipeline( 166 | images=era5, 167 | class_labels=times[:, :1], 168 | generator=torch.manual_seed(2023), 169 | eta=eta, 170 | num_inference_steps=inference_steps, 171 | return_dict=False, 172 | saving_freq_interm=freq_timesteps_frame, 173 | output_type="tensor", 174 | ) 175 | for i, time in enumerate(times): 176 | date = f"{time[0]:d}H_{time[1]:02d}-{time[2]:02d}-{time[3]:04d}" 177 | logger.info(f"Generating GIF for time: {date}") 178 | fname = output_dir + f"/diffusion_{date}_{inference_steps}steps" 179 | generate_giff(interm[i], fname + "_scaled", fps=fps) 180 | 181 | scaled_interm = scaler_func( 182 | interm[i].unsqueeze(1), times[i, 2].repeat(interm.shape[1]) 183 | ) 184 | scaled_interm -= K_to_C 185 | generate_giff(scaled_interm.squeeze(), fname, label="Temperature (ºC)", fps=fps) 186 | 187 | 188 | def diffusion_callback( 189 | model, 190 | scheduler, 191 | era5, 192 | cerra, 193 | times, 194 | scaler_func: Callable = None, 195 | output_dir: str = None, 196 | freq_timesteps_frame: int = 1, 197 | inference_steps: int = 1000, 198 | fps: int = 50, 199 | eta: float = 1, 200 | epoch: int = 0, 201 | **ddpm_kwargs, 202 | ): 203 | pipeline = cDDPMPipeline(unet=model, scheduler=scheduler, **ddpm_kwargs) 204 | pipeline.to(era5.device) 205 | hr_im, interm = pipeline( 206 | images=era5[:1], 207 | class_labels=times[:1, :1], 208 | generator=torch.manual_seed(2023), 209 | eta=eta, 210 | num_inference_steps=inference_steps, 211 | return_dict=False, 212 | saving_freq_interm=freq_timesteps_frame, 213 | output_type="tensor", 214 | ) 215 | del era5 216 | times = times.cpu() 217 | date = f"{times[0, 0]:d}H_{times[0, 1]:02d}-{times[0, 2]:02d}-{times[0, 3]:04d}" 218 | logger.info(f"Generating GIFFs for time: {date}") 219 | fname = output_dir + f"/diffusion_{date}_{inference_steps}steps" 220 | if epoch is not None: 221 | fname += f"_{epoch}epoch" 222 | 223 | get_figure_model_samples( 224 | scaler_func(cerra[:1].cpu(), times[0, 2]), 225 | scaler_func(hr_im[:1].cpu(), times[0, 2]), 226 | filename=fname + "_comparison.png", 227 | ) 228 | del cerra, hr_im 229 | 230 | # GIFFs 231 | interm = interm.cpu() 232 | generate_giff(interm[0], f"{fname}_scaled", fps=fps) 233 | 234 | interm = scaler_func( # UNDO scaling of latents 235 | interm[0].unsqueeze(1), times[0, 2].repeat(interm.shape[1]) 236 | ) 237 | interm -= K_to_C 238 | generate_giff(interm.squeeze(), fname, label="Temperature (ºC)", fps=fps) 239 | -------------------------------------------------------------------------------- /deepr/validation/validation_nn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import evaluate 5 | import torch 6 | from huggingface_hub import Repository 7 | 8 | from deepr.data.scaler import XarrayStandardScaler 9 | from deepr.validation.nn_performance_metrics 
import ( 10 | compute_and_upload_metrics, 11 | compute_model_and_baseline_errors, 12 | ) 13 | from deepr.validation.sample_predictions import sample_observation_vs_prediction 14 | from deepr.visualizations.plot_maps import plot_2_maps_comparison 15 | from deepr.visualizations.plot_rose import plot_rose 16 | 17 | tmpdir = tempfile.mkdtemp(prefix="test-") 18 | 19 | experiment_name = "Test Neural Network" 20 | metric_to_repo = { 21 | "MSE": "mse", 22 | "R2": "r_squared", 23 | "SMAPE": "smape", 24 | "PSNR": "jpxkqx/peak_signal_to_noise_ratio", 25 | "SSIM": "jpxkqx/structural_similarity_index_measure", 26 | "SRE": "jpxkqx/signal_to_reconstruction_error", 27 | } 28 | 29 | 30 | def validate_model( 31 | model, 32 | dataset: torch.utils.data.IterableDataset, 33 | config: dict, 34 | hf_repo_name: str = None, 35 | label_scaler: XarrayStandardScaler = None, 36 | ): 37 | """ 38 | Validate the model. 39 | 40 | It makes it by generating evaluation plots, computing error metrics, and uploading 41 | metrics to the Hugging Face Model Hub. 42 | 43 | Parameters 44 | ---------- 45 | model : object 46 | The neural network model to validate. 47 | dataset : torch.utils.data.IterableDataset 48 | The dataset used for validation. 49 | config : dict 50 | Configuration settings for the validation. 51 | batch_size : int, optional 52 | Batch size for data loading, by default 4. 53 | hf_repo_name : str, optional 54 | Hugging Face repository name, by default None. 55 | label_scaler : XarrayStandardScaler, optional 56 | Label scaler object for applying inverse scaling, by default None. 57 | 58 | """ 59 | # Create data loader 60 | dataloader = torch.utils.data.DataLoader( 61 | dataset, config["batch_size"], pin_memory=True 62 | ) 63 | 64 | # Define scaler function if label scaler is provided 65 | scaler_func = None if label_scaler is None else label_scaler.apply_inverse_scaler 66 | 67 | # Define local directory for saving evaluation results 68 | local_dir = f"{config['output_directory']}/hf-{model.__class__.__name__}-evaluation" 69 | os.makedirs(name=local_dir, exist_ok=True) 70 | 71 | # Clone Hugging Face repository if provided 72 | if config["push_to_hub"] and hf_repo_name is not None: 73 | repo = Repository( 74 | local_dir, clone_from=hf_repo_name, token=os.getenv("HF_TOKEN") 75 | ) 76 | repo.git_pull() 77 | 78 | # Show samples compared with other models 79 | samples_cfg = config["visualizations"].get("sample_observation_vs_prediction", None) 80 | if samples_cfg is not None: 81 | visualization_local_dir = f"{local_dir}/sample_observation_vs_prediction" 82 | os.makedirs(visualization_local_dir, exist_ok=True) 83 | sample_observation_vs_prediction( 84 | model, 85 | dataloader, 86 | visualization_local_dir, 87 | scaler_func=scaler_func, 88 | baseline=config["baseline"], 89 | num_samples=samples_cfg["num_samples"], 90 | ) 91 | 92 | # Obtain error tensors by hour of the day and for all times 93 | ( 94 | mae, 95 | mse, 96 | r2, 97 | mae_base, 98 | mse_base, 99 | r2_base, 100 | improvement, 101 | ) = compute_model_and_baseline_errors( 102 | model, dataloader, config["baseline"], scaler_func 103 | ) 104 | 105 | # Compute error maps to compare spatial metric by hour (and for all the hours) 106 | names = [model.__class__.__name__, config["baseline"]] 107 | visualization_local_dir = f"{local_dir}/plot_2_maps_comparison" 108 | os.makedirs(visualization_local_dir, exist_ok=True) 109 | for time_value in [0, 3, 6, 9, 12, 15, 18, 21, "all"]: 110 | plot_2_maps_comparison( 111 | mse[time_value], 112 | mse_base[time_value], 113 | names, 
114 | "MSE (ºC)", 115 | f"{visualization_local_dir}/mse_vs_{config['baseline']}_{time_value}.png", 116 | vmin=0, 117 | ) 118 | plot_2_maps_comparison( 119 | mae[time_value], 120 | mae_base[time_value], 121 | names, 122 | "MAE (ºC)", 123 | f"{visualization_local_dir}/mae_vs_{config['baseline']}_{time_value}.png", 124 | vmin=0, 125 | ) 126 | plot_2_maps_comparison( 127 | r2[time_value], 128 | r2_base[time_value], 129 | names, 130 | "R2", 131 | f"{visualization_local_dir}/r2_vs_{config['baseline']}_{time_value}.png", 132 | vmin=-1, 133 | ) 134 | 135 | # Compute rose plot to compare total metric by hour (and for all the hours) 136 | visualization_local_dir = f"{local_dir}/rose-plot" 137 | os.makedirs(visualization_local_dir, exist_ok=True) 138 | plot_rose( 139 | {key: value for key, value in mae.items() if key != "all"}, 140 | {key: value for key, value in mae_base.items() if key != "all"}, 141 | None, 142 | names=[model.__class__.__name__, config["baseline"]], 143 | custom_colors=["#390099", "#9e0059"], 144 | title="MAE (ºC)", 145 | output_path=f"{visualization_local_dir}/rose-plot_mae.png", 146 | ) 147 | plot_rose( 148 | {key: value for key, value in mse.items() if key != "all"}, 149 | {key: value for key, value in mse_base.items() if key != "all"}, 150 | None, 151 | names=[model.__class__.__name__, config["baseline"]], 152 | custom_colors=["#390099", "#9e0059"], 153 | title="MSE (ºC)", 154 | output_path=f"{visualization_local_dir}/rose-plot_mse.png", 155 | ) 156 | land_mask_array = dataset.add_auxiliary_features["lsm-high"].lsm.as_numpy().values 157 | plot_rose( 158 | {key: value for key, value in mae.items() if key != "all"}, 159 | {key: value for key, value in mae_base.items() if key != "all"}, 160 | ("land", land_mask_array), 161 | names=[model.__class__.__name__, config["baseline"]], 162 | custom_colors=["#390099", "#9e0059"], 163 | title="MAE (ºC) - only land points", 164 | output_path=f"{visualization_local_dir}/rose-plot_mae-on-land.png", 165 | ) 166 | plot_rose( 167 | {key: value for key, value in mse.items() if key != "all"}, 168 | {key: value for key, value in mse_base.items() if key != "all"}, 169 | ("land", land_mask_array), 170 | names=[model.__class__.__name__, config["baseline"]], 171 | custom_colors=["#390099", "#9e0059"], 172 | title="MSE (ºC) - only land points", 173 | output_path=f"{visualization_local_dir}/rose-plot_mse-on-land.png", 174 | ) 175 | plot_rose( 176 | {key: value for key, value in mae.items() if key != "all"}, 177 | {key: value for key, value in mae_base.items() if key != "all"}, 178 | ("sea", land_mask_array), 179 | names=[model.__class__.__name__, config["baseline"]], 180 | custom_colors=["#390099", "#9e0059"], 181 | title="MAE (ºC) - only sea points", 182 | output_path=f"{visualization_local_dir}/rose-plot_mae-on-sea.png", 183 | ) 184 | plot_rose( 185 | {key: value for key, value in mse.items() if key != "all"}, 186 | {key: value for key, value in mse_base.items() if key != "all"}, 187 | ("sea", land_mask_array), 188 | names=[model.__class__.__name__, config["baseline"]], 189 | custom_colors=["#390099", "#9e0059"], 190 | title="MSE (ºC) - only sea points", 191 | output_path=f"{visualization_local_dir}/rose-plot_mse-on-sea.png", 192 | ) 193 | 194 | # Compute and upload metrics to Hugging Face Model Hub 195 | test_metrics = compute_and_upload_metrics( 196 | model, dataloader, hf_repo_name, scaler_func 197 | ) 198 | evaluate.save(tmpdir, experiment=experiment_name, **test_metrics) 199 | 200 | # Push changes to Hugging Face repository if provided 201 | if 
hf_repo_name is not None: 202 | repo.push_to_hub( 203 | repo_id=hf_repo_name, 204 | commit_message=f"Tests on {dataset.init_date}-{dataset.end_date}", 205 | blocking=True, 206 | ) 207 | -------------------------------------------------------------------------------- /deepr/visualizations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/deepr/visualizations/__init__.py -------------------------------------------------------------------------------- /deepr/visualizations/giffs.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | 4 | import torch 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | from deepr.visualizations.plot_maps import plot_simple_map 9 | 10 | 11 | def generate_giff( 12 | latents: torch.Tensor, filename: str, fps: int = 50, label: str = "Temperature" 13 | ): 14 | vmin = torch.min(latents[-1, ...]) 15 | vmax = torch.max(latents[-1, ...]) 16 | 17 | with tempfile.TemporaryDirectory(suffix="-giff-diffusion") as f: 18 | fig_paths = [] 19 | for t in tqdm(range(latents.shape[0]), desc="Plotting frames for GIFF"): 20 | fname = Path(f) / f"latents_{t}.png" 21 | plot_simple_map(latents[t], vmin, vmax, label=label, out_file=fname) 22 | fig_paths.append(fname) 23 | 24 | imgs = [Image.open(f) for f in fig_paths] 25 | if not filename.endswith(".gif"): 26 | filename += ".gif" 27 | 28 | imgs[0].save( 29 | fp=filename, 30 | format="GIF", 31 | append_images=imgs, 32 | save_all=True, 33 | optimize=True, 34 | duration=max(20, int(1e3 / fps)), # 1 frame each 20ms = 50 fps (min value) 35 | loop=3, 36 | ) 37 | return filename 38 | -------------------------------------------------------------------------------- /deepr/visualizations/plot_maps.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from matplotlib.colors import LinearSegmentedColormap 8 | 9 | from deepr.utilities.logger import get_logger 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def get_figure_model_samples( 15 | fine_image: torch.Tensor, 16 | prediction: torch.Tensor, 17 | input_image: torch.Tensor = None, 18 | baseline: torch.Tensor = None, 19 | column_names: List[str] = None, 20 | filename: Optional[str] = None, 21 | fig_size: Optional[Tuple[int, int]] = None, 22 | ) -> matplotlib.pyplot.Figure: 23 | """ 24 | Generate a figure displaying model samples. 25 | 26 | Parameters 27 | ---------- 28 | fine_image : torch.Tensor 29 | Fine-resolution images. 30 | prediction : torch.Tensor 31 | Model predictions. 32 | input_image : torch.Tensor, optional 33 | Low-resolution input images. Default is None. 34 | baseline : torch.Tensor, optional 35 | Baseline images. Default is None. 36 | column_names : List[str], optional 37 | Names for columns in the figure. Default is None. 38 | filename : str, optional 39 | Filename to save the figure. Default is None. 40 | fig_size : Tuple[int, int], optional 41 | Figure size (width, height) in inches. Default is None. 42 | 43 | Returns 44 | ------- 45 | matplotlib.pyplot.Figure 46 | The generated figure. 
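A hypothetical call to `generate_giff` defined above, assuming `interm` holds the intermediate denoising states of one sample with shape `(timesteps, height, width)` as produced by the diffusion sampling utilities:

```python
# Illustrative usage only; `interm` comes from the sampling utilities shown earlier.
gif_path = generate_giff(interm, "diffusion_01-01-2020_00H", fps=25, label="Temperature (ºC)")
print(gif_path)  # the ".gif" suffix is appended automatically if missing
```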
47 | """ 48 | # Concatenating baseline to predictions, if baseline exists 49 | if baseline is not None: 50 | prediction = torch.cat([prediction, baseline], dim=0) 51 | 52 | n_extras = 2 if input_image is not None else 1 53 | 54 | # Defining the maximum and minimum value of the data that will be depicted 55 | vmax = max( 56 | float(torch.max(fine_image)), 57 | float(torch.max(prediction)), 58 | -9999.0 if input_image is None else float(torch.max(input_image)), 59 | -9999.0 if input_image is None else float(torch.max(baseline)), 60 | ) 61 | vmin = min( 62 | float(torch.min(fine_image)), 63 | float(torch.min(prediction)), 64 | 9999.0 if input_image is None else float(torch.min(input_image)), 65 | 9999.0 if input_image is None else float(torch.min(baseline)), 66 | ) 67 | v_kwargs = {"vmax": vmax, "vmin": vmin} 68 | 69 | n_samples = int(fine_image.shape[0]) 70 | 71 | if input_image is not None and n_samples != int(input_image.shape[0]): 72 | raise ValueError("Inconsistent number of samples between images.") 73 | elif int(prediction.shape[0]) % n_samples != 0: 74 | raise ValueError("Inconsistent number of samples between predictions.") 75 | else: 76 | n_realizations = prediction.shape[0] // n_samples 77 | 78 | # Defining the figure size if it is not provided as argument 79 | if fig_size is None: 80 | fig_size = (4.5 * (n_realizations + n_extras), 4.8 * n_samples) 81 | 82 | # Defining the figure and axes 83 | fig, axs = plt.subplots(n_realizations + n_extras, n_samples, figsize=fig_size) 84 | if n_samples == 1: # if only one row, it is necessary to include in the axes 85 | axs = axs[..., np.newaxis] 86 | plt.tight_layout() 87 | 88 | # Loop over the number of columns, which is the same as the number of samples 89 | for i in range(n_samples): 90 | if input_image is not None: 91 | axs[0, i].imshow(input_image[i, 0].numpy()[..., np.newaxis], **v_kwargs) 92 | axs[1, i].imshow(fine_image[i, 0].numpy()[..., np.newaxis], **v_kwargs) 93 | else: 94 | axs[0, i].imshow(fine_image[i, 0].numpy()[..., np.newaxis], **v_kwargs) 95 | 96 | # Loop over the number of rows in the column 97 | for r in range(n_realizations): 98 | # Plot the predictions 99 | im = axs[n_extras + r, i].imshow( 100 | prediction[i + r * n_samples, 0].numpy()[..., np.newaxis], **v_kwargs 101 | ) 102 | 103 | axs[0, i].get_xaxis().set_ticks([]) 104 | axs[0, i].get_yaxis().set_ticks([]) 105 | axs[1, i].get_xaxis().set_ticks([]) 106 | axs[1, i].get_yaxis().set_ticks([]) 107 | 108 | for r in range(n_realizations): 109 | axs[n_extras + r, i].get_xaxis().set_ticks([]) 110 | axs[n_extras + r, i].get_yaxis().set_ticks([]) 111 | 112 | # Title of the rows 113 | if i == 0: 114 | if input_image is not None: 115 | axs[0, i].set_ylabel("ERA5 (Low-res)", fontsize=14) 116 | axs[1, i].set_ylabel("CERRA (High-res)", fontsize=14) 117 | else: 118 | axs[0, i].set_ylabel("CERRA (High-res)", fontsize=14) 119 | 120 | for r in range(n_realizations): 121 | if baseline is not None and r == n_realizations - 1: 122 | label = "Bicubic Int." 
123 | else: 124 | label = "Prediction (High-res)" 125 | axs[n_extras + r, i].set_ylabel(label, fontsize=14) 126 | 127 | # Title of the columns 128 | if column_names is not None: 129 | for c, col_name in enumerate(column_names): 130 | axs[0, c].set_title(col_name, fontsize=14) 131 | 132 | # Include the color bar in the depiction 133 | if n_samples == 1: 134 | fig.subplots_adjust(bottom=0.05) 135 | cbar_ax = fig.add_axes([0.15, 0.05, 0.7, 0.05]) 136 | fig.colorbar(im, cax=cbar_ax, orientation="horizontal") 137 | else: 138 | fig.subplots_adjust(right=0.95) 139 | cbar_ax = fig.add_axes([0.97, 0.15, 0.05, 0.7]) 140 | fig.colorbar(im, cax=cbar_ax, orientation="vertical") 141 | 142 | # Save figure if the name of a file is provided 143 | if filename is not None: 144 | logger.info(f"Samples from model have been saved to {filename}") 145 | plt.savefig(filename, bbox_inches="tight", transparent=True) 146 | plt.close() 147 | 148 | return fig 149 | 150 | 151 | def plot_2_maps_comparison( 152 | matrix1: torch.Tensor, 153 | matrix2: torch.Tensor, 154 | matrix_names: List[str] = None, 155 | metric_name: str = None, 156 | filename: Optional[str] = None, 157 | **kwargs, 158 | ): 159 | if "vmax" not in kwargs: 160 | vmax = max(float(torch.max(matrix1)), float(torch.max(matrix2))) 161 | else: 162 | vmax = kwargs["vmax"] 163 | 164 | if "vmin" not in kwargs: 165 | vmin = min(float(torch.min(matrix1)), float(torch.min(matrix2))) 166 | else: 167 | vmin = kwargs["vmin"] 168 | 169 | v_kwargs = {"vmax": vmax, "vmin": vmin} 170 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) 171 | plt.tight_layout() 172 | 173 | ax1.imshow(matrix1.numpy(), **v_kwargs) 174 | im = ax2.imshow(matrix2.numpy(), **v_kwargs) 175 | 176 | ax1.get_xaxis().set_ticks([]) 177 | ax1.get_yaxis().set_ticks([]) 178 | ax2.get_xaxis().set_ticks([]) 179 | ax2.get_yaxis().set_ticks([]) 180 | 181 | if matrix_names is not None: 182 | ax1.set_title(matrix_names[0].capitalize(), fontsize=18) 183 | ax2.set_title(matrix_names[1].capitalize(), fontsize=18) 184 | 185 | fig.subplots_adjust(bottom=0.05) 186 | cbar_ax = fig.add_axes([0.15, 0.02, 0.7, 0.02]) 187 | fig.colorbar(im, cax=cbar_ax, orientation="horizontal", label=metric_name) 188 | 189 | if filename is not None: 190 | logger.info(f"Samples from model have been saved to {filename}") 191 | plt.savefig(filename, bbox_inches="tight", transparent=True) 192 | plt.close() 193 | 194 | return fig 195 | 196 | 197 | def plot_2_model_comparison( 198 | reference: torch.Tensor, 199 | pred1: torch.Tensor, 200 | pred2: torch.Tensor, 201 | matrix_names: List[str] = None, 202 | metric_name: str = None, 203 | date: str = None, 204 | filename: Optional[str] = None, 205 | **kwargs, 206 | ): 207 | if "vmax" not in kwargs: 208 | vmax = max( 209 | float(torch.max(reference)), 210 | float(torch.max(pred1)), 211 | float(torch.max(pred2)), 212 | ) 213 | else: 214 | vmax = kwargs["vmax"] 215 | 216 | if "vmin" not in kwargs: 217 | vmin = min( 218 | float(torch.min(reference)), 219 | float(torch.min(pred1)), 220 | float(torch.min(pred2)), 221 | ) 222 | else: 223 | vmin = kwargs["vmin"] 224 | 225 | pred_kwargs = {"vmax": vmax, "vmin": vmin, "cmap": "summer"} 226 | fig = plt.figure(layout="constrained", figsize=(14, 5)) 227 | sfigs = fig.subfigures(1, 2, width_ratios=[1, 2]) 228 | ax = sfigs[0].subplots(1, 1) 229 | axs = sfigs[1].subplots(2, 2) 230 | 231 | # Reference Matrix 232 | ax.imshow(reference.numpy(), **pred_kwargs) 233 | ax.get_xaxis().set_ticks([]) 234 | ax.get_yaxis().set_ticks([]) 235 | 236 | # Prediction and errors 
237 | error1 = pred1 - reference 238 | error2 = pred2 - reference 239 | 240 | emax = max(float(torch.max(torch.abs(error1))), float(torch.max(torch.abs(error2)))) 241 | 242 | axs[0, 0].imshow(pred1.numpy(), **pred_kwargs) 243 | axs[1, 0].imshow(error1.numpy(), vmin=-emax, vmax=emax, cmap="RdBu") 244 | im = axs[0, 1].imshow(pred2.numpy(), **pred_kwargs) 245 | im_e = axs[1, 1].imshow(error2.numpy(), vmin=-emax, vmax=emax, cmap="RdBu") 246 | 247 | for ax_unraveled in axs.ravel(): 248 | ax_unraveled.get_xaxis().set_ticks([]) 249 | ax_unraveled.get_yaxis().set_ticks([]) 250 | 251 | if date is not None: 252 | ax.set_xlabel(date, fontsize=18) 253 | 254 | if matrix_names is not None: 255 | ax.set_title(matrix_names[0], fontsize=18) 256 | axs[0, 0].set_title(matrix_names[1], fontsize=18) 257 | axs[0, 1].set_title(matrix_names[2], fontsize=18) 258 | 259 | sfigs[1].colorbar( 260 | im, ax=axs[0, 1], orientation="vertical", label=f"Prediction ({metric_name})" 261 | ) 262 | sfigs[1].colorbar( 263 | im_e, ax=axs[1, 1], orientation="vertical", label=f"Abs. Error ({metric_name})" 264 | ) 265 | 266 | if filename is not None: 267 | logger.info(f"Samples from model have been saved to {filename}") 268 | plt.savefig(filename, bbox_inches="tight", transparent=False) 269 | plt.close() 270 | 271 | return fig 272 | 273 | 274 | def plot_simple_map( 275 | data, 276 | vmin=None, 277 | vmax=None, 278 | cmap: str = None, 279 | label: str = "Temperature (ºC)", 280 | out_file: str = None, 281 | ): 282 | if vmin * vmax < 0: # opposite signs 283 | colors = [(0, "blue"), (-vmin / (vmax - vmin), "white"), (1, "red")] 284 | elif min(vmin, vmax) >= 0: # both possitive 285 | colors = [(0, "white"), (1, "red")] 286 | elif max(vmin, vmax) <= 0: # both negative 287 | colors = [(0, "blue"), (1, "white")] 288 | 289 | if cmap is None: 290 | cmap = LinearSegmentedColormap.from_list("temp", colors) 291 | 292 | plt.imshow(data, vmin=vmin, vmax=vmax, cmap=cmap) 293 | plt.axis("off") 294 | plt.colorbar(shrink=0.65, label=label) 295 | if out_file is not None: 296 | plt.savefig(out_file, transparent=True, bbox_inches="tight", dpi=200) 297 | plt.close() 298 | else: 299 | return plt.clf() 300 | -------------------------------------------------------------------------------- /deepr/visualizations/plot_rose.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def calculate_metric_mean(tensors: Dict, land_mask: Tuple[bool, np.ndarray]) -> list: 8 | """ 9 | Calculate the mean value of a metric for each tensor in the given dictionary. 10 | 11 | Parameters 12 | ---------- 13 | tensors (dict): A dictionary where each key corresponds to a tensor. 14 | The values of the dictionary are tensors (2D arrays). 15 | land_mask (numpy.array): A 2D numpy array representing the land mask. 16 | 17 | Returns 18 | ------- 19 | list: A list containing the mean values of the metric for each tensor. 
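The `land_mask` argument expected by `calculate_metric_mean` is a tuple of a region name (`"land"` or `"sea"`) and a 0/1 land-sea mask array, as prepared in `validation_nn.py`. A hypothetical use with a per-hour error dictionary `mae` and a mask `lsm` (both assumed to exist):

```python
# Illustrative only: `mae` maps hours to 2D error tensors, `lsm` is a 0/1 numpy mask.
land_means = calculate_metric_mean(mae, ("land", lsm))  # averages over land points only
sea_means = calculate_metric_mean(mae, ("sea", lsm))    # averages over sea points only
all_means = calculate_metric_mean(mae, None)            # no masking applied
```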
20 | """ 21 | # Convert the dictionary values to a list of tensors 22 | tensor_list = list(tensors.values()) 23 | 24 | # Filter out the values where the land mask is 0 25 | if land_mask is not None: 26 | if land_mask[0] == "land": 27 | tensor_list = [ 28 | np.where(land_mask[1] == 1, tensor.numpy(), np.nan) 29 | for tensor in tensor_list 30 | ] 31 | elif land_mask[0] == "sea": 32 | tensor_list = [ 33 | np.where(land_mask[1] == 0, tensor.numpy(), np.nan) 34 | for tensor in tensor_list 35 | ] 36 | else: 37 | raise NotImplementedError 38 | 39 | # Calculate the mean value of the metric for each tensor and store it in a list 40 | mean_values = [np.nanmean(np.abs(tensor)) for tensor in tensor_list] 41 | 42 | return mean_values 43 | 44 | 45 | def plot_rose( 46 | metric: dict, 47 | metric_baseline: dict, 48 | land_mask: Tuple[bool, np.array], 49 | names: list, 50 | custom_colors: list, 51 | title: str, 52 | output_path: str, 53 | ) -> None: 54 | """ 55 | Create two rose plots side by side for the given metric dictionaries. 56 | 57 | Parameters 58 | ---------- 59 | metric (dict): A dictionary where each key corresponds to a tensor. 60 | The values of the dictionary are tensors (2D arrays). 61 | metric_baseline (dict): Another dictionary with the same structure as 'metric' 62 | representing the baseline metric values. 63 | land_mask (numpy.array): A 2D numpy array representing the land mask. 64 | title (str): The title of the rose plots. 65 | output_path (str): The file path where the plot will be saved. 66 | """ 67 | # Get the keys from the dictionaries and sort them in ascending order 68 | keys_metric = sorted(list(metric.keys()), reverse=True) 69 | if metric_baseline is not None: 70 | keys_baseline = sorted(list(metric_baseline.keys()), reverse=True) 71 | 72 | # Calculate the angle step for each key (clockwise direction) 73 | angle_step = 360.0 / len(keys_metric) 74 | 75 | # Assign unique angles to each key, starting at 90 degrees and moving clockwise 76 | angles_metric = ((np.arange(len(keys_metric)) * angle_step) + 90 + angle_step) % 360 77 | if metric_baseline is not None: 78 | angles_baseline = ( 79 | (np.arange(len(keys_baseline)) * angle_step) + 90 + angle_step 80 | ) % 360 81 | 82 | # Calculate the mean metric values for the 'metric' dictionary 83 | mean_values_metric = calculate_metric_mean(metric, land_mask) 84 | 85 | # Calculate the mean metric values for the 'metric_baseline' dictionary 86 | if metric_baseline is not None: 87 | mean_values_baseline = calculate_metric_mean(metric_baseline, land_mask) 88 | 89 | # Set the width of the bars (adjust the value as needed) 90 | bar_width = 10.0 91 | 92 | # Set the larger tick sizes 93 | plt.rcParams["xtick.labelsize"] = 16 94 | plt.rcParams["ytick.labelsize"] = 16 95 | 96 | # Create the rose plots side by side 97 | plt.figure(figsize=(18, 12)) 98 | 99 | ax = plt.subplot(polar=True) 100 | 101 | # Calculate the positions for the bars 102 | positions_metric = np.radians(angles_metric) - np.radians(bar_width / 2) 103 | if metric_baseline is not None: 104 | positions_baseline = np.radians(angles_baseline) + np.radians(bar_width / 2) 105 | 106 | if metric_baseline is not None: 107 | ax.bar( 108 | positions_baseline, 109 | mean_values_baseline, 110 | width=np.radians(bar_width), 111 | align="center", 112 | alpha=0.8, 113 | label=names[1].capitalize(), 114 | color=custom_colors[1], 115 | ) 116 | 117 | ax.bar( 118 | positions_metric, 119 | mean_values_metric, 120 | width=np.radians(bar_width), 121 | align="center", 122 | alpha=0.8, 123 | 
label=names[0].capitalize(), 124 | color=custom_colors[0], 125 | ) 126 | 127 | # Remove the last circumference (circular axis line) 128 | ax.spines["polar"].set_visible(False) 129 | 130 | plt.title(title, fontsize=20) 131 | plt.thetagrids(angles_baseline, labels=keys_baseline) 132 | plt.legend(fontsize=14) 133 | plt.tight_layout() 134 | plt.savefig(output_path) 135 | -------------------------------------------------------------------------------- /deepr/visualizations/plot_samples.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy 5 | import torch 6 | 7 | from deepr.utilities.logger import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def get_figure_model_samples( 13 | fine_image: torch.Tensor, 14 | prediction: torch.Tensor, 15 | input_image: torch.Tensor = None, 16 | baseline: torch.Tensor = None, 17 | column_names: List[str] = None, 18 | filename: Optional[str] = None, 19 | fig_size: Optional[Tuple[int, int]] = None, 20 | ) -> plt.Figure: 21 | """ 22 | Generate a figure displaying model samples. 23 | 24 | Parameters 25 | ---------- 26 | fine_image : torch.Tensor 27 | Fine-resolution images. 28 | prediction : torch.Tensor 29 | Model predictions. 30 | input_image : torch.Tensor, optional 31 | Low-resolution input images. Default is None. 32 | baseline : torch.Tensor, optional 33 | Baseline images. Default is None. 34 | column_names : List[str], optional 35 | Names for columns in the figure. Default is None. 36 | filename : str, optional 37 | Filename to save the figure. Default is None. 38 | fig_size : Tuple[int, int], optional 39 | Figure size (width, height) in inches. Default is None. 40 | 41 | Returns 42 | ------- 43 | plt.Figure 44 | The generated figure. 
45 | """ 46 | # Concatenating baseline to predictions, if baseline exists 47 | if baseline is not None: 48 | prediction = torch.cat([prediction, baseline], dim=0) 49 | 50 | n_extras = 2 if input_image is not None else 1 51 | 52 | # Defining the maximum and minimum value of the data that will be depicted 53 | vmax = max( 54 | float(torch.max(fine_image)), 55 | float(torch.max(prediction)), 56 | -9999.0 if input_image is None else float(torch.max(input_image)), 57 | -9999.0 if baseline is None else float(torch.max(baseline)), 58 | ) 59 | vmin = min( 60 | float(torch.min(fine_image)), 61 | float(torch.min(prediction)), 62 | 9999.0 if input_image is None else float(torch.min(input_image)), 63 | 9999.0 if baseline is None else float(torch.min(baseline)), 64 | ) 65 | v_kwargs = {"vmax": vmax, "vmin": vmin} 66 | 67 | n_samples = int(fine_image.shape[0]) 68 | 69 | if input_image is not None and n_samples != int(input_image.shape[0]): 70 | raise ValueError("Inconsistent number of samples between images.") 71 | elif int(prediction.shape[0]) % n_samples != 0: 72 | raise ValueError("Inconsistent number of samples between predictions.") 73 | else: 74 | n_realizations = prediction.shape[0] // n_samples 75 | 76 | # Defining the figure size if it is not provided as argument 77 | if fig_size is None: 78 | fig_size = (4.5 * (n_realizations + n_extras), 4.8 * n_samples) 79 | 80 | # Defining the figure and axes 81 | if n_samples > 1: 82 | fig, axs = plt.subplots(n_realizations + n_extras, n_samples, figsize=fig_size) 83 | func_name = "set_ylabel" 84 | elif n_samples == 1: # if only one row, it is necessary to include in the axes 85 | fig, axs = plt.subplots(1, n_realizations + n_extras, figsize=fig_size) 86 | axs = axs[..., numpy.newaxis] 87 | func_name = "set_title" 88 | plt.tight_layout() 89 | 90 | # Loop over the number of columns, which is the same as the number of samples 91 | for i in range(n_samples): 92 | if input_image is not None: 93 | axs[0, i].imshow(input_image[i, 0].numpy()[..., numpy.newaxis], **v_kwargs) 94 | axs[1, i].imshow(fine_image[i, 0].numpy()[..., numpy.newaxis], **v_kwargs) 95 | else: 96 | axs[0, i].imshow(fine_image[i, 0].numpy()[..., numpy.newaxis], **v_kwargs) 97 | 98 | # Loop over the number of rows in the column 99 | for r in range(n_realizations): 100 | # Plot the predictions 101 | im = axs[n_extras + r, i].imshow( 102 | prediction[i + r * n_samples, 0].numpy()[..., numpy.newaxis], **v_kwargs 103 | ) 104 | 105 | axs[0, i].get_xaxis().set_ticks([]) 106 | axs[0, i].get_yaxis().set_ticks([]) 107 | axs[1, i].get_xaxis().set_ticks([]) 108 | axs[1, i].get_yaxis().set_ticks([]) 109 | 110 | for r in range(n_realizations): 111 | axs[n_extras + r, i].get_xaxis().set_ticks([]) 112 | axs[n_extras + r, i].get_yaxis().set_ticks([]) 113 | 114 | # Title of the rows 115 | if i == 0: 116 | if input_image is not None: 117 | getattr(axs[0, i], func_name)("ERA5 (Low-res)", fontsize=14) 118 | getattr(axs[1, i], func_name)("CERRA (High-res)", fontsize=14) 119 | else: 120 | getattr(axs[0, i], func_name)("CERRA (High-res)", fontsize=14) 121 | 122 | for r in range(n_realizations): 123 | if baseline is not None and r == n_realizations - 1: 124 | label = "Bicubic Int." 
125 | else: 126 | label = "Prediction (High-res)" 127 | getattr(axs[n_extras + r, i], func_name)(label, fontsize=14) 128 | 129 | # Title of the columns 130 | if column_names is not None: 131 | for c, col_name in enumerate(column_names): 132 | axs[0, c].set_title(col_name, fontsize=14) 133 | 134 | # Include the color bar in the depiction 135 | if n_samples == 1: 136 | fig.subplots_adjust(bottom=0.05) 137 | cbar_ax = fig.add_axes([0.15, 0.02, 0.7, 0.05]) 138 | fig.colorbar(im, cax=cbar_ax, orientation="horizontal") 139 | else: 140 | fig.subplots_adjust(right=0.95) 141 | cbar_ax = fig.add_axes([0.97, 0.15, 0.05, 0.7]) 142 | fig.colorbar(im, cax=cbar_ax, orientation="vertical") 143 | 144 | # Save figure if the name of a file is provided 145 | if filename is not None: 146 | logger.info(f"Samples from model have been saved to {filename}") 147 | plt.savefig(filename, bbox_inches="tight", transparent=True) 148 | plt.close() 149 | 150 | return fig 151 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/_static/convswin2sr_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/convswin2sr_scheme.png -------------------------------------------------------------------------------- /docs/_static/dp_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/dp_scheme.png -------------------------------------------------------------------------------- /docs/_static/pos_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/pos_embedding.png -------------------------------------------------------------------------------- /docs/_static/project_motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/project_motivation.png -------------------------------------------------------------------------------- 
/docs/_static/spatial-domain-small.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/spatial-domain-small.png -------------------------------------------------------------------------------- /docs/_static/standardization_types.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_static/standardization_types.png -------------------------------------------------------------------------------- /docs/_templates/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/_templates/.gitkeep -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Import and path setup --------------------------------------------------- 8 | 9 | import os 10 | import sys 11 | 12 | import deepr 13 | 14 | sys.path.insert(0, os.path.abspath("../")) 15 | 16 | # -- Project information ----------------------------------------------------- 17 | 18 | project = "DeepR" 19 | copyright = "2023, European Union" 20 | author = "European Union" 21 | version = deepr.__version__ 22 | release = deepr.__version__ 23 | 24 | # -- General configuration --------------------------------------------------- 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be 27 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 28 | # ones. 29 | extensions = [ 30 | "autoapi.extension", 31 | "myst_parser", 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.mathjax", 34 | "sphinx.ext.napoleon", 35 | ] 36 | 37 | # autodoc configuration 38 | autodoc_typehints = "none" 39 | 40 | # autoapi configuration 41 | autoapi_dirs = ["../deepr"] 42 | autoapi_ignore = ["*/version.py"] 43 | autoapi_options = [ 44 | "members", 45 | "inherited-members", 46 | "undoc-members", 47 | # "show-inheritance", 48 | "show-module-summary", 49 | "imported-members", 50 | ] 51 | autoapi_root = "_api" 52 | 53 | # napoleon configuration 54 | napoleon_google_docstring = False 55 | napoleon_numpy_docstring = True 56 | napoleon_preprocess_types = True 57 | 58 | # Add any paths that contain templates here, relative to this directory. 59 | templates_path = ["_templates"] 60 | 61 | # List of patterns, relative to source directory, that match files and 62 | # directories to ignore when looking for source files. 63 | # This pattern also affects html_static_path and html_extra_path. 64 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 65 | 66 | 67 | # -- Options for HTML output ------------------------------------------------- 68 | 69 | # The theme to use for HTML and HTML Help pages. See the documentation for 70 | # a list of builtin themes. 71 | # 72 | html_theme = "pydata_sphinx_theme" 73 | 74 | # Add any paths that contain custom static files (such as style sheets) here, 75 | # relative to this directory. 
They are copied after the builtin static files, 76 | # so a file named "default.css" will overwrite the builtin "default.css". 77 | html_static_path = ["_static"] 78 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to DeepR's documentation! 2 | 3 | DeepR: Global reanalysis downscaling to regional scales using Deep Diffusion models. 4 | 5 | This project has been developed under Code For Earth initiative, an innovation programme run by the European Centre for Medium-Range Weather Forecasts (ECMWF). 6 | 7 | ```{toctree} 8 | :caption: 'Contents:' 9 | :maxdepth: 2 10 | 11 | usage/installation 12 | usage/data 13 | usage/methodology 14 | usage/references 15 | ``` 16 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/usage/data.md: -------------------------------------------------------------------------------- 1 | ### Data specifications 2 | 3 | The spatial coverage of the datasets provided is described below: 4 | 5 | Features: (240, 150) ------- Label: (800, 500) 6 | 7 | ```complete-spatial-coverage.yml 8 | data_configuration: 9 | features_configuration: 10 | spatial_coverage: 11 | longitude: [ -20.5, 39.25 ] 12 | latitude: [ 66.25, 29 ] 13 | label_configuration: 14 | spatial_coverage: 15 | longitude: [ -10, 29.95] 16 | latitude: [ 60, 35.05 ] 17 | ``` 18 | 19 | During the development stage, a subset of the data is used to validate the implementation of the model: 20 | 21 | Features: (32, 20)------- Label: (32, 20) 22 | 23 | ```reduce-spatial-coverage.yml 24 | data_configuration: 25 | features_configuration: 26 | spatial_coverage: 27 | longitude: [ 6.0, 13.75 ] 28 | latitude: [ 50, 45.25 ] 29 | label_configuration: 30 | spatial_coverage: 31 | longitude: [ 9.2, 10.75] 32 | latitude: [ 48, 47.05 ] 33 | ``` 34 | -------------------------------------------------------------------------------- /docs/usage/installation.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECMWFCode4Earth/DeepR/761cc3bc710197ce42c97b211cc27bb743b17601/docs/usage/installation.md -------------------------------------------------------------------------------- /docs/usage/methodology.md: -------------------------------------------------------------------------------- 1 | ## Methodology 2 | 3 | The main 
purpose of this library is to test the capabilities of deep diffusion models for reanalysis super-resolution tasks. 4 | 5 | The objectives of this challenge focus on: 6 | 7 | - Exploring the capabilities of Deep Diffusion models to represent high-resolution reanalysis datasets. 8 | 9 | - Evaluating the impact of including several covariables in the model: 10 | 11 | - Conditioning on time stamps 12 | - Conditioning on meteorological covariables 13 | - Conditioning on in-situ observations 14 | 15 | ### Super Resolution Diffusion model 16 | 17 | A diffusion model defines a forward process that progressively adds Gaussian noise to a high-resolution grid over a fixed number of timesteps, and a learned reverse process that removes this noise step by step. 18 | 19 | Here, a deep learning (DL) model is used to predict the noise $\\epsilon_t$ sampled at each step, given $x\_{t+1}$ and conditioned on the LR image. 20 | 21 | #### Training 22 | 23 | During training, for each batch of data, we sample random timesteps $t$ and noise $\\epsilon\_{t}$ and derive the corresponding noisy values $x\_{t+1}$. Then, we train our DL model to minimize the following loss function: 24 | 25 | $$ \\mathcal{L} (x) = || \\epsilon\_{t} - \\Phi \\left(x\_{t+1}, t \\right) ||^2$$ 26 | 27 | which is the mean squared error (MSE) between: 28 | 29 | - the noise, $\\epsilon\_{t}$, added at timestep $t$ 30 | 31 | - the prediction of the DL model, $\\Phi$, taking as input the timestep $t$ and the noisy matrix $x\_{t+1}$. 32 | 33 | #### Inference 34 | 35 | During inference, we can sample random noise and run the reverse process conditioned on input ERA5 grids to obtain high-resolution reanalysis grids. Another major benefit of this approach is the possibility of generating an ensemble of grids to represent the uncertainty of the prediction, while avoiding the mode collapse that is common in GANs. 36 | 37 | ### U-Net 38 | 39 | In particular, a tailored U-Net architecture with 2D convolutions, residual connections and attention layers is used. 40 | 41 | ![U-Net Architecture Diagram](../_static/eps-U-Net%20diagram.svg) 42 | 43 | The parameters of this model, implemented in [deepr/model/unet.py](deepr/model/unet.py), are listed below (a usage sketch follows the notes after this list): 44 | 45 | - `image_channels`: The number of channels of the high-resolution image we want to generate, which matches the number of channels of the U-Net output. Default value is `1`, as we plan to sample one variable at a time. 46 | 47 | - `n_channels`: The number of output channels of the initial convolution. Defaults to `16`. 48 | 49 | - `channel_multipliers`: The multiplying factor applied to the number of channels at each down/upsampling level of the U-Net. Defaults to `[1, 2, 2, 4]`. 50 | 51 | - `is_attention`: Whether attention is used at each down/upsampling level of the U-Net. Defaults to `[False, False, True, True]`. 52 | 53 | - `n_blocks`: The number of residual blocks considered in each level. Defaults to `2`. 54 | 55 | - `conditioned_on_input`: The number of channels of the conditioning inputs. 56 | 57 | *NOTE I*: The lengths of `channel_multipliers` and `is_attention` should match, as they set the number of resolutions of the U-Net architecture. 58 | 59 | *NOTE II*: Spatial tensors fed to the diffusion model must have spatial dimensions that are multiples of $2^{\\text{num resolutions} - 1}$.
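The snippet below sketches how such a model could be instantiated with the documented defaults. The import path and the exact class/constructor signature are assumptions to be checked against [deepr/model/unet.py](deepr/model/unet.py); only the keyword names come from the list above, and the `conditioned_on_input` value shown (a single low-resolution field) is illustrative.

```python
# Hypothetical instantiation of the eps-U-Net; class name and import path are assumed.
from deepr.model.unet import UNet

model = UNet(
    image_channels=1,                         # one output variable sampled at a time
    n_channels=16,                            # channels after the initial convolution
    channel_multipliers=[1, 2, 2, 4],         # 4 resolutions in the encoder/decoder
    is_attention=[False, False, True, True],  # attention only at the two coarsest levels
    n_blocks=2,                               # residual blocks per level
    conditioned_on_input=1,                   # channels of the conditioning (low-resolution) input
)
```

With four resolutions, *NOTE II* implies that the height and width of the input grids must be multiples of $2^{3} = 8$.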
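Returning to the training objective of the *Training* subsection above, the sketch below illustrates one optimization step of the $\\epsilon$-prediction loss. It is a generic DDPM-style example rather than the trainer implemented in this repository: the linear noise schedule, the omission of the low-resolution conditioning, and the `model(x_noisy, t)` call signature are simplifying assumptions.

```python
# Generic sketch of one epsilon-prediction training step (not the repository's trainer).
import torch
import torch.nn.functional as F


def training_step(model, x0, betas):
    """x0: standardized high-resolution grids of shape (batch, channels, height, width)."""
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)         # cumulative product of (1 - beta_t)
    t = torch.randint(0, betas.shape[0], (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)                                 # noise added at timestep t
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    x_noisy = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps   # noisy sample fed to the model
    eps_hat = model(x_noisy, t)                                # Phi(x_{t+1}, t)
    return F.mse_loss(eps_hat, eps)                            # || eps_t - Phi(x_{t+1}, t) ||^2


# Example usage with a linear beta schedule of 1000 steps:
# loss = training_step(model, batch, torch.linspace(1e-4, 2e-2, 1000))
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```

In the actual workflow, the model is additionally conditioned on the low-resolution ERA5 grid, and at inference time the learned reverse process can be run from several noise samples to produce an ensemble of high-resolution grids, as described in the *Inference* subsection.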
60 | 61 | #### Downsampling 62 | 63 | #### Upsampling 64 | 65 | #### Down Block 66 | 67 | #### Up Block 68 | 69 | #### Residual Block 70 | 71 | #### Final Block 72 | -------------------------------------------------------------------------------- /docs/usage/references.md: -------------------------------------------------------------------------------- 1 | - [Annotated Deep Learning Paper implementations](https://github.com/labmlai/annotated_deep_learning_paper_implementations) 2 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # environment.yml: Mandatory dependencies only. 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | 6 | dependencies: 7 | # Libraries to manipulate data 8 | - dask 9 | - xarray 10 | - netcdf4 11 | # Deep learning and machine learning libraries 12 | - pytorch 13 | - tensorboard 14 | - accelerate 15 | - diffusers 16 | - transformers 17 | - huggingface_hub 18 | # Visualization libraries 19 | - cartopy 20 | - matplotlib 21 | - seaborn 22 | # Metrics libraries 23 | - evaluate 24 | - scikit-image 25 | # Libraries for utilities 26 | - mlflow 27 | - pydantic 28 | - pyyaml 29 | - tqdm 30 | # Types for mypy 31 | - types-requests 32 | - types-pyyaml 33 | -------------------------------------------------------------------------------- /environment_CUDA.yml: -------------------------------------------------------------------------------- 1 | # environment.yml: Mandatory dependencies only. 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | - nodefaults 7 | 8 | dependencies: 9 | # Libraries to manipulate data 10 | - dask 11 | - xarray 12 | - netcdf4 13 | # Deep learning and machine learning libraries 14 | - pytorch 15 | - torchvision 16 | - torchaudio 17 | - pytorch-cuda=11.8 18 | - tensorboard 19 | - accelerate 20 | - diffusers 21 | - transformers 22 | - huggingface_hub 23 | - aim 24 | # Visualization libraries 25 | - cartopy 26 | - matplotlib 27 | - seaborn 28 | # Metrics libraries 29 | - evaluate 30 | - scikit-image 31 | # Libraries for utilities 32 | - mlflow 33 | - pydantic 34 | - pyyaml 35 | - tqdm 36 | # Types for mypy 37 | - types-requests 38 | - types-pyyaml 39 | # Software 40 | - git-lfs 41 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 3 | 4 | [project] 5 | classifiers = [ 6 | "Development Status :: 2 - Pre-Alpha", 7 | "Intended Audience :: Science/Research", 8 | "License :: OSI Approved :: Apache Software License", 9 | "Operating System :: OS Independent", 10 | "Programming Language :: Python", 11 | "Programming Language :: Python :: 3", 12 | "Programming Language :: Python :: 3.10", 13 | "Programming Language :: Python :: 3.11", 14 | "Topic :: Scientific/Engineering" 15 | ] 16 | description = "DeepR: Deep Reanalysis" 17 | dynamic = ["version"] 18 | license = {file = "LICENSE"} 19 | name = "DeepR" 20 | readme = "README.md" 21 | 22 | [tool.coverage.run] 23 | branch = true 24 | 25 | [tool.mypy] 26 | explicit_package_bases = true 27 | follow_imports = "normal" 28 | ignore_missing_imports = true 29 | strict = false 30 | 31 | [tool.ruff] 32 | ignore = [ 33 | # pydocstyle: Missing Docstrings 34 | "D1", 35 | # pydocstyle: numpy convention 36 | "D107", 37 | "D203", 38 | "D212", 39 | "D213", 40 | "D402", 41 | "D413", 
42 | "D415", 43 | "D416", 44 | "D417" 45 | ] 46 | # Black line length is 88, but black does not format comments. 47 | line-length = 110 48 | select = [ 49 | # pyflakes 50 | "F", 51 | # pycodestyle 52 | "E", 53 | "W", 54 | # isort 55 | "I", 56 | # pydocstyle 57 | "D" 58 | ] 59 | 60 | [tool.setuptools] 61 | packages = ["deepr"] 62 | 63 | [tool.setuptools_scm] 64 | write_to = "deepr/version.py" 65 | write_to_template = ''' 66 | # Do not change! Do not track in version control! 67 | __version__ = "{version}" 68 | ''' 69 | -------------------------------------------------------------------------------- /resources/configuration_diffusion.yml: -------------------------------------------------------------------------------- 1 | # Data configuration 2 | data_configuration: 3 | # Features configuration 4 | features_configuration: 5 | variables: 6 | - t2m 7 | data_name: era5 8 | spatial_resolution: "025deg" 9 | add_auxiliary: 10 | time: true 11 | lsm-low: true 12 | orog-low: true 13 | lsm-high: true 14 | orog-high: true 15 | spatial_coverage: 16 | longitude: [-8.35, 6.6] 17 | latitude: [46.45, 35.50] 18 | standardization: 19 | to_do: true 20 | cache_folder: /PATH/TO/.cache_reanalysis_scales 21 | method: domain-wise 22 | data_location: /PATH/TO/features/ 23 | land_mask_location: /PATH/TO/static/land-mask_ERA5.nc 24 | orography_location: /PATH/TOstatic/orography_ERA5.nc 25 | # Label configuration 26 | label_configuration: 27 | variable: t2m 28 | data_name: cerra 29 | spatial_resolution: "005deg" 30 | standardization: 31 | to_do: true 32 | cache_folder: /PATH/TO/.cache_reanalysis_scales 33 | method: domain-wise 34 | spatial_coverage: 35 | longitude: [-6.85, 5.1] 36 | latitude: [44.95, 37] 37 | data_location: /PATH/TO/labels/ 38 | land_mask_location: /PATH/TO/static/land-mask_CERRA.nc 39 | orography_location: /PATH/TO/static/orography_CERRA.nc 40 | # Common data configuration 41 | common_configuration: 42 | temporal_coverage: 43 | start: 1981-01 44 | end: 2018-12 45 | frequency: MS 46 | data_split: 47 | test: 0.0 48 | validation: 0.2 49 | # Training configuration 50 | training_configuration: 51 | type: diffusion 52 | model_configuration: 53 | eps_model: 54 | class_name: diffusers.UNet2DModel 55 | kwargs: 56 | block_out_channels: [16, 24, 32] 57 | down_block_types: [DownBlock2D, AttnDownBlock2D, AttnDownBlock2D] 58 | up_block_types: [AttnUpBlock2D, AttnUpBlock2D, UpBlock2D] 59 | layers_per_block: 2 60 | time_embedding_type: positional 61 | #class_embed_type: "none" # timestep 62 | num_class_embeds: 24 # Encode hours as table of 24 embeddings 63 | in_channels: 2 64 | norm_num_groups: 4 65 | #trained_obs_model: 66 | # class_name: ConvBilinear 67 | # model_dir: predictia/europe_reanalysis_downscaler_convbaseline 68 | scheduler: 69 | class_name: LMSDiscreteScheduler 70 | kwargs: 71 | num_train_timesteps: 1000 72 | beta_start: 0.0001 73 | beta_end: 0.02 74 | beta_schedule: linear 75 | prediction_type: epsilon 76 | rescale_betas_zero_snr: true 77 | timestep_spacing: trailing 78 | training_parameters: 79 | num_epochs: 30 80 | batch_size: 4 81 | gradient_accumulation_steps: 4 82 | learning_rate: 0.001 83 | lr_warmup_steps: 500 84 | mixed_precision: "fp16" 85 | hour_embed_type: class # none, timestep, positional, cyclical, class 86 | hf_repo_name: "predictia/MODEL_REPO_NAME" 87 | output_dir: "cddpm-probando-tiny" 88 | device: cuda 89 | push_to_hub: true 90 | seed: 2023 91 | -------------------------------------------------------------------------------- /resources/configuration_nn_bicubic.yml: 
-------------------------------------------------------------------------------- 1 | # Data configuration 2 | data_configuration: 3 | # Features configuration 4 | features_configuration: 5 | variables: 6 | - t2m 7 | data_name: era5 8 | add_auxiliary: false 9 | spatial_resolution: "025deg" 10 | spatial_coverage: 11 | longitude: [-8.35, 6.6] 12 | latitude: [46.45, 35.50] 13 | apply_standardization: true 14 | data_dir: /PATH/TO/features/ 15 | # Label configuration 16 | label_configuration: 17 | variable: t2m 18 | data_name: cerra 19 | spatial_resolution: "005deg" 20 | apply_standardization: true 21 | spatial_coverage: 22 | longitude: [-6.85, 5.1] 23 | latitude: [44.95, 37] 24 | data_dir: /PATH/TO/labels/ 25 | # Common data configuration 26 | common_configuration: 27 | temporal_coverage: 28 | start: 1985-01 29 | end: 2020-12 30 | frequency: MS 31 | data_split: 32 | test: 0.2 33 | validation: 0.3 34 | # Training configuration 35 | training_configuration: 36 | type: end2end 37 | model_configuration: 38 | neural_network: 39 | class_name: ConvBaseline 40 | kwargs: 41 | interpolation_method: bicubic 42 | num_channels: 1 43 | upblock_kernel_size: [5, 3] 44 | upblock_channels: [32, 16] 45 | upscale: 5 46 | training_parameters: 47 | num_epochs: 200 48 | gradient_accumulation_steps: 1 49 | learning_rate: 0.0001 50 | lr_warmup_steps: 500 51 | mixed_precision: "fp16" 52 | output_dir: "convbaseline-1985_2020" 53 | device: cpu 54 | push_to_hub: true 55 | seed: 2023 56 | save_image_epochs: 5 57 | save_model_epochs: 10 58 | -------------------------------------------------------------------------------- /resources/configuration_nn_evaluation.yml: -------------------------------------------------------------------------------- 1 | # Data configuration 2 | data_configuration: 3 | # Features configuration 4 | features_configuration: 5 | variables: 6 | - t2m 7 | data_name: era5 8 | add_auxiliary: false 9 | spatial_resolution: "025deg" 10 | spatial_coverage: 11 | longitude: [-8.35, 6.6] 12 | latitude: [46.45, 35.50] 13 | apply_standardization: true 14 | data_dir: /PATH/TO/features/ 15 | # Label configuration 16 | label_configuration: 17 | variable: t2m 18 | data_name: cerra 19 | spatial_resolution: "005deg" 20 | apply_standardization: true 21 | spatial_coverage: 22 | longitude: [-6.85, 5.1] 23 | latitude: [44.95, 37] 24 | data_dir: /PATH/TO/labels/ 25 | # Common data configuration 26 | common_configuration: 27 | temporal_coverage: 28 | start: 2018-01 29 | end: 2019-12 30 | frequency: MS 31 | data_split: 32 | test: 1 33 | # Training configuration 34 | training_configuration: 35 | type: end2end 36 | model_configuration: 37 | neural_network: 38 | class_name: ConvSwin2SR 39 | trained_model_dir: predictia/europe_reanalysis_downscaler_convswin2sr 40 | -------------------------------------------------------------------------------- /resources/configuration_nn_swin2sr.yml: -------------------------------------------------------------------------------- 1 | # Data configuration 2 | data_configuration: 3 | # Features configuration 4 | features_configuration: 5 | variables: 6 | - t2m 7 | data_name: era5 8 | spatial_resolution: "025deg" 9 | add_auxiliary: 10 | time: true 11 | lsm-low: true 12 | orog-low: true 13 | lsm-high: true 14 | orog-high: true 15 | spatial_coverage: 16 | longitude: [-8.35, 6.6] 17 | latitude: [46.45, 35.50] 18 | standardization: 19 | to_do: true 20 | cache_folder: /PATH/TO/.cache_reanalysis_scales 21 | method: domain-wise # pixel-wise, domain-wise, landmask-wise 22 | data_location: /PATH/TO/features/ 
23 | land_mask_location: /PATH/TO/static/land-mask_ERA5.nc 24 | orography_location: /PATH/TO/static/orography_ERA5.nc 25 | # Label configuration 26 | label_configuration: 27 | variable: t2m 28 | data_name: cerra 29 | spatial_resolution: "005deg" 30 | spatial_coverage: 31 | longitude: [-6.85, 5.1] 32 | latitude: [44.95, 37] 33 | standardization: 34 | to_do: true 35 | cache_folder: /PATH/TO/.cache_reanalysis_scales 36 | method: domain-wise # pixel-wise, domain-wise, landmask-wise 37 | data_location: /PATH/TO/labels/ 38 | land_mask_location: /PATH/TO/static/land-mask_CERRA.nc 39 | orography_location: /PATH/TO/static/orography_CERRA.nc 40 | # Common data configuration 41 | split_coverages: 42 | train: 43 | start: 1981-01 44 | end: 2013-12 45 | frequency: MS 46 | validation: 47 | start: 2014-01 48 | end: 2017-12 49 | frequency: MS 50 | # Training configuration 51 | training_configuration: 52 | type: end2end 53 | model_configuration: 54 | neural_network: 55 | class_name: ConvSwin2SR 56 | kwargs: 57 | embed_dim: 128 58 | depths: [4, 4, 4, 4] 59 | num_heads: [4, 4, 4, 4] 60 | patch_size: 1 61 | window_size: 5 # divisor of input dims (1, 2 and 5 for images (20, 30)) 62 | num_channels: 1 63 | img_range: 1 64 | resi_connection: "1conv" 65 | upsampler: "pixelshuffle" 66 | interpolation_method: "bicubic" 67 | hidden_dropout_prob: 0.0 68 | upscale: 5 # For this method, must be power of 2. 69 | training_parameters: 70 | num_epochs: 100 71 | gradient_accumulation_steps: 4 72 | learning_rate: 0.0001 73 | lr_warmup_steps: 500 74 | mixed_precision: "fp16" 75 | hf_repo_name: predictia/europe_reanalysis_downscaler_convswin2sr 76 | output_dir: "swin2sr-1985_2020" 77 | device: cpu 78 | push_to_hub: false 79 | seed: 2023 80 | save_image_epochs: 5 81 | save_model_epochs: 10 82 | -------------------------------------------------------------------------------- /resources/configuration_vqvae.yml: -------------------------------------------------------------------------------- 1 | # Data configuration 2 | data_configuration: 3 | label_configuration: 4 | variable: t2m 5 | data_name: cerra 6 | spatial_resolution: "005deg" 7 | spatial_coverage: 8 | longitude: [-6.85, 5.1] 9 | latitude: [44.95, 37] 10 | apply_standardization: true 11 | data_dir: /PATH/TO/labels/ 12 | common_configuration: 13 | temporal_coverage: 14 | start: 1985-01 15 | end: 2020-12 16 | frequency: MS 17 | data_split: 18 | test: 0.2 19 | validation: 0.3 20 | # Training configuration 21 | training_configuration: 22 | type: autoencoder 23 | model_configuration: 24 | neural_network: 25 | class_name: diffusers.VQModel 26 | kwargs: 27 | latent_channels: 1 28 | norm_num_groups: 8 29 | num_vq_embeddings: 256 30 | vq_embed_dim: 32 31 | down_block_types: ["DownEncoderBlock2D", "AttnDownEncoderBlock2D", "AttnDownEncoderBlock2D"] 32 | up_block_types: ["UpDecoderBlock2D", "AttnUpDecoderBlock2D", "AttnUpDecoderBlock2D"] 33 | block_out_channels: [8, 16, 16] 34 | training_parameters: 35 | num_epochs: 200 36 | gradient_accumulation_steps: 1 37 | learning_rate: 0.0001 38 | lr_warmup_steps: 500 39 | mixed_precision: "fp16" 40 | hour_embed_type: class # none, timestep, positional, cyclical, class 41 | output_dir: "vqvae-tiny-small" 42 | device: cuda 43 | push_to_hub: true 44 | seed: 2023 45 | save_image_epochs: 5 46 | save_model_epochs: 10 47 | -------------------------------------------------------------------------------- /scripts/download/climate_data_store.py: -------------------------------------------------------------------------------- 1 | import calendar 2 | import 
os 3 | import pathlib 4 | 5 | import cdsapi 6 | import click 7 | import numpy 8 | import pandas 9 | import xarray 10 | 11 | from deepr.utilities.logger import get_logger 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | def get_number_of_days(year: int, month: int) -> int: 17 | """ 18 | Get the number of days in a specific month and year. 19 | 20 | Parameters 21 | ---------- 22 | year : int 23 | The year. 24 | month : int 25 | The month (1 to 12). 26 | 27 | Returns 28 | ------- 29 | int 30 | The number of days in the specified month and year, or None if the month or 31 | year is invalid. 32 | """ 33 | try: 34 | _, num_days = calendar.monthrange(year, month) 35 | return num_days 36 | except ValueError: 37 | return None # Invalid month or year 38 | 39 | 40 | def download_cds_data( 41 | output_directory: str, variable: str, month: int, year: int 42 | ) -> tuple: 43 | """ 44 | Download data from CDS (Climate Data Store) and save it to the specified directory. 45 | 46 | Parameters 47 | ---------- 48 | output_directory : str 49 | The path to the output directory. 50 | variable : str 51 | The variable name. 52 | month : int 53 | The month (1 to 12). 54 | year : int 55 | The year. 56 | 57 | Returns 58 | ------- 59 | tuple 60 | A tuple containing the paths to the downloaded ERA5 and CERRA data files. 61 | """ 62 | c = cdsapi.Client() 63 | cds_variable = { 64 | "t2m": "2m_temperature", 65 | } 66 | 67 | # Create directories if they don't exist 68 | path_cerra = pathlib.Path( 69 | f"{output_directory}/labels/{variable}/" 70 | f"{variable}_cerra_{year}{'{:02d}'.format(month)}_005deg.nc" 71 | ) 72 | if not path_cerra.parent.exists(): 73 | path_cerra.parent.mkdir(exist_ok=True, parents=True) 74 | 75 | logger.info(f"Downloading CERRA data to {path_cerra}") 76 | c.retrieve( 77 | "reanalysis-cerra-single-levels", 78 | { 79 | "format": "netcdf", 80 | "variable": cds_variable[variable], 81 | "level_type": "surface_or_atmosphere", 82 | "data_type": "reanalysis", 83 | "product_type": "analysis", 84 | "year": year, 85 | "month": "{:02d}".format(month), 86 | "day": [ 87 | "{:02d}".format(x) 88 | for x in range(1, get_number_of_days(year, month) + 1) 89 | ], 90 | "time": ["{:02d}:00".format(x) for x in range(0, 24, 3)], 91 | }, 92 | path_cerra, 93 | ) 94 | 95 | path_era5 = pathlib.Path( 96 | f"{output_directory}/features/{variable}/" 97 | f"{variable}_era5_{year}{'{:02d}'.format(month)}_025deg.nc" 98 | ) 99 | if not path_era5.parent.exists(): 100 | path_era5.parent.mkdir(exist_ok=True, parents=True) 101 | 102 | logger.info(f"Downloading ERA5 data to {path_era5}") 103 | c.retrieve( 104 | "reanalysis-era5-single-levels", 105 | { 106 | "product_type": "reanalysis", 107 | "variable": cds_variable[variable], 108 | "year": year, 109 | "month": "{:02d}".format(month), 110 | "day": [ 111 | "{:02d}".format(x) 112 | for x in range(1, get_number_of_days(year, month) + 1) 113 | ], 114 | "time": ["{:02d}:00".format(x) for x in range(0, 24, 3)], 115 | "format": "netcdf", 116 | }, 117 | path_era5, 118 | ) 119 | return path_era5, path_cerra 120 | 121 | 122 | def process_cds_data(file_path: pathlib.Path, varname: str) -> pathlib.Path: 123 | """ 124 | Process CDS data file by adjusting longitudes and data type. 125 | 126 | Parameters 127 | ---------- 128 | file_path : pathlib.Path 129 | The path to the CDS data file. 130 | varname : str 131 | The variable name. 132 | 133 | Returns 134 | ------- 135 | pathlib.Path 136 | The path to the processed CDS data file. 
137 | """ 138 | ds = xarray.open_dataset(file_path) 139 | lon = ds["longitude"] 140 | 141 | # Adjust longitudes 142 | ds["longitude"] = ds["longitude"].where(lon <= 180, other=lon - 360) 143 | ds = ds.reindex(**{"longitude": sorted(ds["longitude"])}) 144 | 145 | # Modify encoding attributes 146 | encoding_attrs = ds[varname].encoding.copy() 147 | del encoding_attrs["scale_factor"] 148 | del encoding_attrs["add_offset"] 149 | encoding_attrs["dtype"] = numpy.float64 150 | ds[varname].encoding = encoding_attrs 151 | 152 | # Save the processed data to a new file 153 | new_file = file_path.with_suffix("_new.nc") 154 | ds.to_netcdf(new_file) 155 | 156 | # Remove the original file and rename the new file 157 | os.remove(file_path) 158 | os.rename(new_file, file_path) 159 | 160 | return file_path 161 | 162 | 163 | @click.command() 164 | @click.argument("output_directory", type=click.Path(file_okay=False, resolve_path=True)) 165 | @click.argument("variable", type=click.Choice(["t2m"])) 166 | @click.argument("start_date", type=click.DateTime(formats=["%Y-%m-%d"])) 167 | @click.argument("end_date", type=click.DateTime(formats=["%Y-%m-%d"])) 168 | def main(output_directory, variable, start_date, end_date): 169 | """ 170 | Download CDS data for a specific variable within the given date range and save it to the output directory. 171 | 172 | Parameters 173 | ---------- 174 | output_directory : str 175 | The path to the output directory. 176 | variable : str 177 | The variable name (e.g., "t2m"). 178 | start_date : datetime 179 | The start date in YYYY-MM-DD format. 180 | end_date : datetime 181 | The end date in YYYY-MM-DD format. 182 | """ 183 | dates = pandas.date_range(start_date, end_date, freq="MS") 184 | features, labels = [], [] 185 | for date in dates: 186 | data_feature, data_label = download_cds_data( 187 | output_directory, variable, date.year, date.month 188 | ) 189 | data_feature = process_cds_data(data_feature, variable) 190 | data_label = process_cds_data(data_label, variable) 191 | features.append(data_feature) 192 | labels.append(data_label) 193 | print(features) 194 | print(labels) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /scripts/download/european_weather_cloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | 4 | import click 5 | import pandas as pd 6 | import requests 7 | 8 | from deepr.utilities.logger import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | def download_data( 14 | variable: str, 15 | date: str, 16 | project: str, 17 | spatial_resolution: str, 18 | output_directory: str, 19 | ) -> None: 20 | """ 21 | Download data from a given URL to the specified output directory. 22 | 23 | Parameters 24 | ---------- 25 | variable : str 26 | The variable name. 27 | 28 | date : str 29 | The date in the format YYYYMM. 30 | 31 | project : str 32 | The project name. 33 | 34 | spatial_resolution : str 35 | The spatial resolution. 36 | 37 | output_directory : str 38 | The path to the output directory. 39 | 40 | Returns 41 | ------- 42 | None 43 | 44 | Raises 45 | ------ 46 | ValueError 47 | If the input date is outside the supported ranges. 
48 | """ 49 | date_time = datetime.strptime(date, "%Y%m") 50 | 51 | start_date1 = datetime(1985, 1, 1) 52 | end_date1 = datetime(2018, 12, 31) 53 | 54 | start_date2 = datetime(2019, 1, 1) 55 | end_date2 = datetime(2021, 12, 31) 56 | 57 | if start_date1 <= date_time <= end_date1: 58 | date_range = "1985_2018" 59 | elif start_date2 <= date_time <= end_date2: 60 | date_range = "2019_2021" 61 | else: 62 | raise ValueError( 63 | "Invalid date range. " 64 | "Supported ranges are between 1985 and 2018, or between 2019 and 2021." 65 | ) 66 | 67 | cloud_url = "https://storage.ecmwf.europeanweather.cloud/Code4Earth" 68 | project_dir = f"netCDF_{project}_{date_range}" 69 | filename = f"{variable}_{project}_{date}_{spatial_resolution}.nc" 70 | output_path = os.path.join(output_directory, filename) 71 | 72 | if os.path.exists(output_path): 73 | logger.info(f"File {output_path} already exists!") 74 | else: 75 | response = requests.get(f"{cloud_url}/{project_dir}/{filename}") 76 | if response.status_code == 200: 77 | with open(output_path, "wb") as file: 78 | file.write(response.content) 79 | logger.info(f"Data downloaded successfully to: {output_path}") 80 | else: 81 | logger.info(f"Failed to download data. Status code: {response.status_code}") 82 | 83 | 84 | @click.command() 85 | @click.argument( 86 | "output_directory", type=click.Path(exists=True, file_okay=False, resolve_path=True) 87 | ) 88 | @click.argument("variable") 89 | @click.argument("project") 90 | @click.argument("spatial_resolution") 91 | @click.argument("start_date") 92 | @click.argument("end_date") 93 | def main(output_directory, variable, project, spatial_resolution, start_date, end_date): 94 | # Convert start_date and end_date to datetime objects 95 | start_date = datetime.strptime(start_date, "%Y-%m-%d") 96 | end_date = datetime.strptime(end_date, "%Y-%m-%d") 97 | 98 | dates = pd.date_range(start_date, end_date, freq="MS") 99 | 100 | logger.info( 101 | f"Downloading data to {output_directory} for variable {variable}, " 102 | f"project {project}, spatial resolution {spatial_resolution} from " 103 | f"{start_date} to {end_date}" 104 | ) 105 | 106 | for date in dates: 107 | logger.info(f"Downloading data for date: {date}") 108 | date_str = date.strftime("%Y%m") 109 | download_data(variable, date_str, project, spatial_resolution, output_directory) 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /scripts/modeling/generate_model_predictions.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | from deepr.workflow import MainPipeline 5 | 6 | 7 | def main(cfg_path: Path): 8 | main_pipeline = MainPipeline(cfg_path) 9 | main_pipeline.generate_predictions() 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser( 14 | prog="generate_model_predictions.py", 15 | description="Dump predictions from Super Resolution model", 16 | ) 17 | parser.add_argument( 18 | "--cfg_path", 19 | default="../resources/configuration_predictions.yml", 20 | type=Path, 21 | help="Path to the configuration file.", 22 | ) 23 | 24 | args = parser.parse_args() 25 | main(cfg_path=Path(args.cfg_path)) 26 | -------------------------------------------------------------------------------- /scripts/modeling/train_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib 
import Path 3 | 4 | from deepr.workflow import MainPipeline 5 | 6 | 7 | def main(cfg_path: Path): 8 | main_pipeline = MainPipeline(cfg_path) 9 | main_pipeline.run_pipeline() 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser( 14 | prog="train_model.py", 15 | description="Train a Super Resolution model and validate its results", 16 | ) 17 | parser.add_argument( 18 | "--cfg_path", 19 | default="../resources/configuration_vqvae.yml", 20 | type=Path, 21 | help="Path to the configuration file.", 22 | ) 23 | 24 | args = parser.parse_args() 25 | main(cfg_path=Path(args.cfg_path)) 26 | -------------------------------------------------------------------------------- /scripts/modeling/validate_model_predictions.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | from deepr.workflow import MainPipeline 5 | 6 | 7 | def main(cfg_path: Path): 8 | main_pipeline = MainPipeline(cfg_path) 9 | main_pipeline.run_validation() 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser( 14 | prog="validate_model_predictions.py", 15 | description="Validate the predictions of a trained Super Resolution model", 16 | ) 17 | parser.add_argument( 18 | "--cfg_path", 19 | default="../resources/configuration_nn_evaluation.yml", 20 | type=Path, 21 | help="Path to the configuration file.", 22 | ) 23 | 24 | args = parser.parse_args() 25 | main(cfg_path=Path(args.cfg_path)) 26 | -------------------------------------------------------------------------------- /scripts/processing/data_spatial_selection.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import click 5 | import xarray 6 | 7 | from deepr.utilities.logger import get_logger 8 | 9 | logger = get_logger(__name__) 10 | 11 | 12 | def adjust_latitudes(lat_min, lat_max, lat_values): 13 | # Check if latitude values are reversed (e.g., from 90 to -90) 14 | if lat_values[0] > lat_values[-1]: 15 | lat_min, lat_max = lat_max, lat_min 16 | return lat_min, lat_max 17 | 18 | 19 | @click.command() 20 | @click.argument( 21 | "input_directory", type=click.Path(exists=True, file_okay=False, resolve_path=True) 22 | ) 23 | @click.argument("output_directory", type=click.Path(file_okay=False, resolve_path=True)) 24 | @click.argument("lon_min", type=float) 25 | @click.argument("lon_max", type=float) 26 | @click.argument("lat_min", type=float) 27 | @click.argument("lat_max", type=float) 28 | def main(input_directory, output_directory, lon_min, lon_max, lat_min, lat_max): 29 | # Create the output directory if it doesn't exist 30 | os.makedirs(output_directory, exist_ok=True) 31 | 32 | input_files = glob.glob(f"{input_directory}/*.nc") 33 | input_files.sort() 34 | 35 | logger.info(f"A list of {len(input_files)} files will be transformed.") 36 | 37 | for num, input_file in enumerate(input_files): 38 | logger.info(f"Processing input_file ({num}): {input_file}") 39 | input_data = xarray.open_dataset(input_file) 40 | 41 | # Adjust latitude values if needed, without overwriting the CLI arguments 42 | lat_start, lat_stop = adjust_latitudes( 43 | lat_min, lat_max, input_data.latitude.values 44 | ) 45 | 46 | input_data_sel = input_data.sel( 47 | latitude=slice(lat_start, lat_stop), 48 | longitude=slice(lon_min, lon_max), 49 | ) 50 | output_file = input_file.replace(input_directory, output_directory) 51 | logger.info(f"Writing processed data to {output_file}") 52 | input_data_sel.to_netcdf(output_file) 53 | 54 | 55 | if __name__ ==
"__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = deepr 3 | license = Apache License 2.0 4 | description = DeepR: Deep Reanalysis 5 | classifiers = 6 | Development Status :: 3 - Alpha 7 | Intended Audience :: Science/Research 8 | License :: OSI Approved :: Apache Software License 9 | Operating System :: OS Independent 10 | Programming Language :: Python 11 | Programming Language :: Python :: 3 12 | Programming Language :: Python :: 3.10 13 | Topic :: Scientific/Engineering 14 | long_description_content_type=text/markdown 15 | long_description = file: README.md 16 | 17 | [options] 18 | packages = find: 19 | include_package_data = True 20 | [flake8] 21 | max-line-length = 110 22 | extend-ignore = E203, W503 23 | 24 | [mypy] 25 | strict = False 26 | 27 | [mypy-cartopy.*] 28 | ignore_missing_imports = True 29 | 30 | [mypy-matplotlib.*] 31 | ignore_missing_imports = True 32 | 33 | [mypy-xskillscore.*] 34 | ignore_missing_imports = True 35 | 36 | [options.entry_points] 37 | console_scripts = 38 | train_model = deepr.cli:train_model 39 | -------------------------------------------------------------------------------- /tests/test_00_version.py: -------------------------------------------------------------------------------- 1 | import deepr 2 | 3 | 4 | def test_version() -> None: 5 | assert deepr.__version__ != "999" 6 | -------------------------------------------------------------------------------- /tests/tests_data/test_files.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from deepr.data.files import DataFile 6 | 7 | 8 | @pytest.fixture 9 | def base_dir(): 10 | return "/path/to/data" 11 | 12 | 13 | @pytest.fixture 14 | def data_file(base_dir): 15 | return DataFile(base_dir, "t2m", "era5", "201801", "025deg", None) 16 | 17 | 18 | def test_data_path_attributes(data_file, base_dir): 19 | """ 20 | Test the attributes of the DataPath instance. 21 | 22 | Parameters 23 | ---------- 24 | data_file : DataFile 25 | The DataPath instance to test. 26 | base_dir : str 27 | The expected base directory. 28 | 29 | Returns 30 | ------- 31 | None 32 | """ 33 | assert data_file.base_dir == base_dir 34 | assert data_file.variable == "t2m" 35 | assert data_file.dataset == "era5" 36 | assert data_file.temporal_coverage == "201801" 37 | assert data_file.spatial_resolution == "025deg" 38 | 39 | 40 | def test_data_path_to_path(data_file, base_dir): 41 | """ 42 | Test the to_path method of the DataPath instance. 43 | 44 | Parameters 45 | ---------- 46 | data_file : DataFile 47 | The DataPath instance to test. 48 | base_dir : str 49 | The expected base directory. 50 | 51 | Returns 52 | ------- 53 | None 54 | """ 55 | expected_path = os.path.join(base_dir, "t2m_era5_201801_025deg.nc") 56 | assert data_file.to_path() == expected_path 57 | 58 | 59 | def test_data_path_from_path(data_file, base_dir): 60 | """ 61 | Test the from_path class method of the DataPath class. 62 | 63 | Parameters 64 | ---------- 65 | data_file : DataFile 66 | The DataPath instance for comparison. 67 | base_dir : str 68 | The expected base directory. 
69 | 70 | Returns 71 | ------- 72 | None 73 | """ 74 | file_path = os.path.join(base_dir, "t2m_era5_201801_025deg.nc") 75 | new_data_path = DataFile.from_path(file_path) 76 | assert new_data_path.base_dir == base_dir 77 | assert new_data_path.variable == "t2m" 78 | assert new_data_path.dataset == "era5" 79 | assert new_data_path.temporal_coverage == "201801" 80 | assert new_data_path.spatial_resolution == "025deg" 81 | --------------------------------------------------------------------------------