├── .github └── workflows │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── advance.png ├── aid.png ├── brazilian_coffee.jpg ├── dubai.jpg ├── etci2021.jpg ├── eurosat.jpg ├── fair1m.jpg ├── fc_cd.png ├── gid15.png ├── hkh_glacier.png ├── hrscd.png ├── inria_ail.png ├── levircd_plus.png ├── oscd.png ├── patternnet.png ├── proba-v.jpg ├── rams.png ├── resisc45.png ├── rsicd.png ├── rsscn7.png ├── rsvqa_hr.png ├── rsvqa_lr.png ├── rsvqaxben.png ├── s2looking.png ├── s2mtcp.png ├── sat.png ├── sydney_captions.png ├── tiselac.png ├── ucm_captions.png ├── whu_rs19.png └── zuericrop.png ├── examples ├── levir-cd+.ipynb └── probav.ipynb ├── requirements.txt ├── scripts ├── download_advance.sh ├── download_aid.sh ├── download_brazilian_coffee.sh ├── download_dubai_segmentation.sh ├── download_etci2021.sh ├── download_eurosat_ms.sh ├── download_eurosat_rgb.sh ├── download_fair1m.sh ├── download_gid15.sh ├── download_hkh_glacier.sh ├── download_hrscd.sh ├── download_inria_ail.sh ├── download_levircd_plus.sh ├── download_oscd.sh ├── download_patternnet.sh ├── download_probav.sh ├── download_resisc45.sh ├── download_rsicd.sh ├── download_rsscn7.sh ├── download_rsvqa_hr.sh ├── download_rsvqa_lr.sh ├── download_rsvqaxben.sh ├── download_s2looking.sh ├── download_s2mtcp.sh ├── download_sat.sh ├── download_sydney_captions.sh ├── download_tiselac.sh ├── download_ucm.sh ├── download_ucm_captions.sh ├── download_whu_rs19.sh └── download_zuericrop.sh ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── _test_datamodules.py ├── _test_datasets.py ├── models │ ├── __init__.py │ └── test_rams.py └── test_transforms.py └── torchrs ├── __init__.py ├── datasets ├── __init__.py ├── advance.py ├── aid.py ├── brazilian_coffee.py ├── dubai_segmentation.py ├── etci2021.py ├── eurosat.py ├── fair1m.py ├── gid15.py ├── hkh_glacier.py ├── hrscd.py ├── inria_ail.py ├── levircd.py ├── oscd.py ├── patternnet.py ├── probav.py ├── resisc45.py ├── rsicd.py ├── rsscn7.py ├── rsvqa.py ├── s2looking.py ├── s2mtcp.py ├── sat.py ├── sydney_captions.py ├── tiselac.py ├── ucm.py ├── ucm_captions.py ├── utils.py ├── whu_rs19.py └── zuericrop.py ├── models ├── __init__.py ├── fc_cd.py ├── oscd.py └── rams.py ├── train ├── __init__.py ├── datamodules │ ├── __init__.py │ ├── advance.py │ ├── aid.py │ ├── base.py │ ├── brazilian_coffee.py │ ├── dubai_segmentation.py │ ├── etci2021.py │ ├── eurosat.py │ ├── fair1m.py │ ├── gid15.py │ ├── hkh_glacier.py │ ├── hrscd.py │ ├── inria_ail.py │ ├── levircd.py │ ├── oscd.py │ ├── patternnet.py │ ├── probav.py │ ├── resisc45.py │ ├── rsicd.py │ ├── rsscn7.py │ ├── rsvqa.py │ ├── s2looking.py │ ├── s2mtcp.py │ ├── sat.py │ ├── sydney_captions.py │ ├── tiselac.py │ ├── ucm.py │ ├── ucm_captions.py │ ├── whu_rs19.py │ └── zuericrop.py └── modules │ ├── __init__.py │ ├── fc_cd.py │ └── rams.py └── transforms.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 2 | 3 | name: Build 4 | 5 | on: 6 | push: 7 | branches: [ main ] 8 | pull_request: 9 | branches: [ main ] 10 | 11 | jobs: 12 | build: 13 | 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: [3.7, 3.8, 3.9] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | 
with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | lightning_logs/ 2 | .vscode/ 3 | .data/ 4 | .data 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 isaac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/advance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/advance.png -------------------------------------------------------------------------------- /assets/aid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/aid.png -------------------------------------------------------------------------------- /assets/brazilian_coffee.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/brazilian_coffee.jpg -------------------------------------------------------------------------------- /assets/dubai.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/dubai.jpg -------------------------------------------------------------------------------- /assets/etci2021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/etci2021.jpg -------------------------------------------------------------------------------- /assets/eurosat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/eurosat.jpg -------------------------------------------------------------------------------- /assets/fair1m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/fair1m.jpg -------------------------------------------------------------------------------- /assets/fc_cd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/fc_cd.png -------------------------------------------------------------------------------- /assets/gid15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/gid15.png -------------------------------------------------------------------------------- /assets/hkh_glacier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/hkh_glacier.png -------------------------------------------------------------------------------- /assets/hrscd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/hrscd.png -------------------------------------------------------------------------------- /assets/inria_ail.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/inria_ail.png -------------------------------------------------------------------------------- /assets/levircd_plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/levircd_plus.png -------------------------------------------------------------------------------- /assets/oscd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/oscd.png -------------------------------------------------------------------------------- /assets/patternnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/patternnet.png -------------------------------------------------------------------------------- /assets/proba-v.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/proba-v.jpg -------------------------------------------------------------------------------- /assets/rams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rams.png -------------------------------------------------------------------------------- /assets/resisc45.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/resisc45.png -------------------------------------------------------------------------------- /assets/rsicd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rsicd.png -------------------------------------------------------------------------------- /assets/rsscn7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rsscn7.png -------------------------------------------------------------------------------- /assets/rsvqa_hr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rsvqa_hr.png -------------------------------------------------------------------------------- /assets/rsvqa_lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rsvqa_lr.png -------------------------------------------------------------------------------- /assets/rsvqaxben.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/rsvqaxben.png -------------------------------------------------------------------------------- /assets/s2looking.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/s2looking.png -------------------------------------------------------------------------------- /assets/s2mtcp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/s2mtcp.png -------------------------------------------------------------------------------- /assets/sat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/sat.png -------------------------------------------------------------------------------- /assets/sydney_captions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/sydney_captions.png -------------------------------------------------------------------------------- /assets/tiselac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/tiselac.png -------------------------------------------------------------------------------- /assets/ucm_captions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/ucm_captions.png -------------------------------------------------------------------------------- /assets/whu_rs19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/whu_rs19.png -------------------------------------------------------------------------------- /assets/zuericrop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/assets/zuericrop.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | torchvision>=0.10.0 3 | torchaudio>=0.9.0 4 | einops>=0.3.0 5 | numpy>=1.21.0 6 | pillow>=8.3.1 7 | tifffile>=2021.7.2 8 | h5py>=3.3.0 9 | imagecodecs>=2021.7.30 -------------------------------------------------------------------------------- /scripts/download_advance.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/advance 2 | wget --no-check-certificate https://zenodo.org/record/3828124/files/ADVANCE_vision.zip?download=1 -O ADVANCE_vision.zip 3 | wget --no-check-certificate https://zenodo.org/record/3828124/files/ADVANCE_sound.zip?download=1 -O ADVANCE_sound.zip 4 | unzip ADVANCE_vision.zip -d .data/advance/ 5 | rm ADVANCE_vision.zip 6 | unzip ADVANCE_sound.zip -d .data/advance/ 7 | rm ADVANCE_sound.zip 8 | -------------------------------------------------------------------------------- /scripts/download_aid.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data 3 | gdown --id 1cvjfe_MZJI9HXwkRgoQbSCH55qQd6Esm 4
| unzip AID.zip -d .data/ 5 | rm AID.zip 6 | -------------------------------------------------------------------------------- /scripts/download_brazilian_coffee.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | wget http://www.patreo.dcc.ufmg.br/wp-content/uploads/2017/11/brazilian_coffee_dataset.zip -O brazilian_coffee_dataset.zip 3 | unzip brazilian_coffee_dataset.zip -d .data/ 4 | rm brazilian_coffee_dataset.zip 5 | -------------------------------------------------------------------------------- /scripts/download_dubai_segmentation.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data 2 | wget https://semantic-segmentation-uae.s3.eu-central-1.amazonaws.com/Semantic-segmentation-dataset-1.zip 3 | unzip Semantic-segmentation-dataset-1.zip 4 | mv "Semantic segmentation dataset" dubai-segmentation 5 | mv dubai-segmentation .data/ 6 | rm Semantic-segmentation-dataset-1.zip 7 | -------------------------------------------------------------------------------- /scripts/download_etci2021.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/etci2021 3 | 4 | # train dataset (3GB) 5 | gdown --id 14HqNW5uWLS92n7KrxKgDwUTsSEST6LCr 6 | unzip train.zip -d .data/etci2021 7 | rm -r .data/etci2021/__MACOSX 8 | rm train.zip 9 | 10 | # val dataset (0.85GB) 11 | gdown --id 19sriKPHCZLfJn_Jmk3Z_0b3VaCBVRVyn 12 | unzip val_with_ref_labels.zip -d .data/etci2021 13 | rm -r .data/etci2021/__MACOSX 14 | rm val_with_ref_labels.zip 15 | 16 | # test dataset (1.2GB) (no labels) 17 | gdown --id 1rpMVluASnSHBfm2FhpPDio0GyCPOqg7E 18 | unzip test_without_ref_labels.zip -d .data/etci2021 19 | rm -r .data/etci2021/__MACOSX 20 | rm test_without_ref_labels.zip 21 | -------------------------------------------------------------------------------- /scripts/download_eurosat_ms.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/eurosat-ms 2 | wget http://madm.dfki.de/files/sentinel/EuroSATallBands.zip 3 | unzip EuroSATallBands.zip -d .data/eurosat-ms 4 | rm EuroSATallBands.zip 5 | -------------------------------------------------------------------------------- /scripts/download_eurosat_rgb.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/eurosat-rgb 2 | wget http://madm.dfki.de/files/sentinel/EuroSAT.zip 3 | unzip EuroSAT.zip -d .data/eurosat-rgb 4 | rm EuroSAT.zip 5 | -------------------------------------------------------------------------------- /scripts/download_fair1m.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/fair1m 3 | gdown --id 1maeEBvno6BhXWXyYqPCug9YVR1yyE-xF 4 | unzip FAIR1M.zip -d .data/fair1m 5 | unzip .data/fair1m/images.zip -d .data/fair1m 6 | unzip .data/fair1m/labelXmls.zip -d .data/fair1m 7 | rm FAIR1M.zip 8 | rm .data/fair1m/images.zip 9 | rm .data/fair1m/labelXmls.zip 10 | -------------------------------------------------------------------------------- /scripts/download_gid15.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/ 3 | gdown --id 1zbkCEXPEKEV6gq19OKmIbaT8bXXfWW6u -O gid-15.zip 4 | unzip gid-15.zip -d .data/ 5 | rm gid-15.zip 6 | -------------------------------------------------------------------------------- /scripts/download_hkh_glacier.sh:
-------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data 3 | gdown --id 1Tk32aGadb7YBfI5JxLbX6kXUGdJNiC8p 4 | unzip hkh_glacier_mapping.zip -d .data/ 5 | rm hkh_glacier_mapping.zip -------------------------------------------------------------------------------- /scripts/download_hrscd.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/ 3 | gdown --id 1hJYP3Rn2lyFTvrzLYrn-Q6lmma6BRSE7 4 | unzip HRSCD.zip -d .data/ 5 | rm HRSCD.zip 6 | -------------------------------------------------------------------------------- /scripts/download_inria_ail.sh: -------------------------------------------------------------------------------- 1 | sudo apt install p7zip-full 2 | mkdir -p .data/inria_ail 3 | wget --no-check-certificate https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.001 4 | wget --no-check-certificate https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.002 5 | wget --no-check-certificate https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.003 6 | wget --no-check-certificate https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.004 7 | wget --no-check-certificate https://files.inria.fr/aerialimagelabeling/aerialimagelabeling.7z.005 8 | 7z x aerialimagelabeling.7z.001 9 | unzip NEW2-AerialImageDataset.zip -d .data/ 10 | rm -i aerialimagelabeling.7z.* 11 | rm -i NEW2-AerialImageDataset.zip 12 | -------------------------------------------------------------------------------- /scripts/download_levircd_plus.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/ 3 | gdown --id 1JamSsxiytXdzAIk6VDVWfc-OsX-81U81 4 | unzip LEVIR-CD+.zip -d .data/ 5 | mv .data/LEVIR-CD+ .data/levircd_plus 6 | rm -r .data/__MACOSX 7 | find .data/levircd_plus/ -name '.DS_Store' -type f -delete 8 | rm LEVIR-CD+.zip 9 | -------------------------------------------------------------------------------- /scripts/download_oscd.sh: -------------------------------------------------------------------------------- 1 | # Manually download here but requires IEEE login https://ieee-dataport.org/open-access/oscd-onera-satellite-change-detection#files 2 | # rehosted to Google Drive to download in script 3 | pip install gdown 4 | mkdir -p .data/oscd 5 | gdown --id 1jidN0DKEIybOrP0j7Bos8bGDDq3Varj3 6 | unzip OSCD.zip -d .data/oscd 7 | rm OSCD.zip 8 | mv '.data/oscd/Onera Satellite Change Detection dataset - Images.zip' .data/oscd/Images.zip 9 | mv '.data/oscd/Onera Satellite Change Detection dataset - Train Labels.zip' .data/oscd/TrainLabels.zip 10 | mv '.data/oscd/Onera Satellite Change Detection dataset - Test Labels.zip' .data/oscd/TestLabels.zip 11 | unzip .data/oscd/Images.zip -d .data/oscd/ 12 | unzip .data/oscd/TrainLabels.zip -d .data/oscd/ 13 | unzip .data/oscd/TestLabels.zip -d .data/oscd/ 14 | mv '.data/oscd/Onera Satellite Change Detection dataset - Images' .data/oscd/images 15 | mv '.data/oscd/Onera Satellite Change Detection dataset - Train Labels' .data/oscd/train_labels 16 | mv '.data/oscd/Onera Satellite Change Detection dataset - Test Labels' .data/oscd/test_labels 17 | rm .data/oscd/Images.zip 18 | rm .data/oscd/TrainLabels.zip 19 | rm .data/oscd/TestLabels.zip 20 | -------------------------------------------------------------------------------- /scripts/download_patternnet.sh: -------------------------------------------------------------------------------- 
1 | pip install gdown 2 | mkdir -p .data/ 3 | gdown --id 127lxXYqzO6Bd0yZhvEbgIfz95HaEnr9K 4 | unzip PatternNet.zip -d .data/ 5 | rm PatternNet.zip 6 | -------------------------------------------------------------------------------- /scripts/download_probav.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/probav 2 | wget --no-check-certificate https://kelvins.esa.int/media/competitions/proba-v-super-resolution/probav_data.zip 3 | unzip probav_data.zip -d .data/probav 4 | rm probav_data.zip 5 | -------------------------------------------------------------------------------- /scripts/download_resisc45.sh: -------------------------------------------------------------------------------- 1 | # can be downloaded manually here https://1drv.ms/u/s!AmgKYzARBl5ca3HNaHIlzp_IXjs 2 | # uploaded to gdrive for downloading in a script 3 | pip install gdown 4 | apt-get install unrar 5 | gdown --id 1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv 6 | unrar x NWPU-RESISC45.rar .data/ 7 | rm NWPU-RESISC45.rar 8 | -------------------------------------------------------------------------------- /scripts/download_rsicd.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/rsicd 2 | 3 | # Download images 4 | pip install gdown 5 | apt-get install unrar 6 | gdown --id 0B1jt7lJDEXy3SmZEdDd0aWpXcWc 7 | unrar x RSICD_images.rar .data/rsicd 8 | rm RSICD_images.rar 9 | 10 | # Download annotations 11 | gdown --id 1q8EcBWuCbvtTnMILE60WOhd9S0C4rsfT 12 | mv dataset_rsicd.json .data/rsicd/ 13 | -------------------------------------------------------------------------------- /scripts/download_rsscn7.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/RSSCN7 2 | git clone https://github.com/palewithout/RSSCN7.git .data/RSSCN7 3 | rm -r .data/RSSCN7/.git/ 4 | -------------------------------------------------------------------------------- /scripts/download_rsvqa_hr.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | wget https://cloud.sylvainlobry.com/s/f7NpYQKqx4bZStx/download -O RSVQA_HR.zip 3 | unzip RSVQA_HR.zip -d .data/ 4 | rm RSVQA_HR.zip 5 | -------------------------------------------------------------------------------- /scripts/download_rsvqa_lr.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | wget https://cloud.sylvainlobry.com/s/4Qg5AXX8YfCswmX/download -O RSVQA_LR.zip 3 | unzip RSVQA_LR.zip -d .data/ 4 | unzip .data/RSVQA_LR/Images_LR.zip -d .data/RSVQA_LR 5 | rm RSVQA_LR.zip 6 | rm .data/RSVQA_LR/Images_LR.zip 7 | -------------------------------------------------------------------------------- /scripts/download_rsvqaxben.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/rsvqaxben/ 2 | wget https://zenodo.org/record/5084904/files/LRBENanswers.json?download=1 -O .data/rsvqaxben/LRBENanswers.json 3 | wget https://zenodo.org/record/5084904/files/LRBENimages.json?download=1 -O .data/rsvqaxben/LRBENimages.json 4 | wget https://zenodo.org/record/5084904/files/LRBENpeople.json?download=1 -O .data/rsvqaxben/LRBENpeople.json 5 | wget https://zenodo.org/record/5084904/files/LRBENquestions.json?download=1 -O .data/rsvqaxben/LRBENquestions.json 6 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_test_answers.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_test_answers.json 7 | wget
https://zenodo.org/record/5084904/files/RSVQAxBEN_split_test_images.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_test_images.json 8 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_test_questions.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_test_questions.json 9 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_train_answers.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_train_answers.json 10 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_train_images.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_train_images.json 11 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_train_questions.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_train_questions.json 12 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_val_answers.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_val_answers.json 13 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_val_images.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_val_images.json 14 | wget https://zenodo.org/record/5084904/files/RSVQAxBEN_split_val_questions.json?download=1 -O .data/rsvqaxben/RSVQAxBEN_split_val_questions.json 15 | wget https://zenodo.org/record/5084904/files/Images.zip?download=1 -O .data/rsvqaxben/Images.zip 16 | unzip .data/rsvqaxben/Images.zip -d .data/rsvqaxben/ 17 | rm .data/rsvqaxben/Images.zip 18 | -------------------------------------------------------------------------------- /scripts/download_s2looking.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | gdown --id 1M4w3acW0wGVuGFfxGxxZlw8lvIbAEa1I 3 | unzip S2Looking.zip -------------------------------------------------------------------------------- /scripts/download_s2mtcp.sh: -------------------------------------------------------------------------------- 1 | apt-get install -y p7zip-full 2 | mkdir -p .data/s2mtcp 3 | wget --no-check-certificate https://zenodo.org/record/4280482/files/S2MTCP_data.7z?download=1 -O S2MTCP_data.7z 4 | wget --no-check-certificate https://zenodo.org/record/4280482/files/S2MTCP_metadata.csv?download=1 -O .data/s2mtcp/S2MTCP_metadata.csv 5 | wget https://zenodo.org/record/4280482/files/README.txt?download=1 -O .data/s2mtcp/README.txt 6 | 7z x S2MTCP_data.7z -o.data/s2mtcp 7 | rm S2MTCP_data.7z -------------------------------------------------------------------------------- /scripts/download_sat.sh: -------------------------------------------------------------------------------- 1 | # Converted from mat to hdf5 format and moved N dim to front 2 | mkdir -p .data/sat 3 | pip install gdown 4 | gdown --id 1q4Xpi67DQtbLx1tnA9XZGfvBRIXYr8dG 5 | unzip Sat.zip -d .data/sat 6 | rm Sat.zip 7 | -------------------------------------------------------------------------------- /scripts/download_sydney_captions.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/sydney_captions 3 | gdown --id 1zAA8qG8FZLCj2t4t9JPgNS48sA2iltaU 4 | unzip Sydney_captions.zip -d .data/sydney_captions 5 | rm Sydney_captions.zip 6 | -------------------------------------------------------------------------------- /scripts/download_tiselac.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/tiselac/ 3 | gdown --id 0B383FlgU32evbm5Fa2JZcHRpVzg -O .data/tiselac/train.txt 4 | gdown --id 0B383FlgU32evY2RJamExSFI2QlE -O .data/tiselac/test.txt 5 | gdown --id 0B383FlgU32evTUdpWHBNZDlHR1E -O
.data/tiselac/train_labels.txt 6 | gdown --id 0B383FlgU32evU3hsQlBfX0wtTFE -O .data/tiselac/test_labels.txt 7 | -------------------------------------------------------------------------------- /scripts/download_ucm.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | wget http://weegee.vision.ucmerced.edu/datasets/UCMerced_LandUse.zip -O UCMerced_LandUse.zip 3 | unzip UCMerced_LandUse.zip -d .data/ 4 | rm UCMerced_LandUse.zip 5 | -------------------------------------------------------------------------------- /scripts/download_ucm_captions.sh: -------------------------------------------------------------------------------- 1 | pip install gdown 2 | mkdir -p .data/ucm_captions 3 | gdown --id 1WNAQgdgghz0cz4rV_KmTNyNGF_m8cNkt 4 | unzip UCM_captions.zip -d .data/ucm_captions 5 | rm UCM_captions.zip 6 | -------------------------------------------------------------------------------- /scripts/download_whu_rs19.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/ 2 | wget https://github.com/CAPTAIN-WHU/BED4RS/raw/main/datasets/WHU-RS19.zip 3 | unzip WHU-RS19.zip -d .data/ 4 | rm WHU-RS19.zip 5 | -------------------------------------------------------------------------------- /scripts/download_zuericrop.sh: -------------------------------------------------------------------------------- 1 | mkdir -p .data/zuericrop 2 | wget https://polybox.ethz.ch/index.php/s/uXfdr2AcXE3QNB6/download -O .data/zuericrop/ZueriCrop.hdf5 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = torch-rs 3 | version = attr: torchrs.__version__ 4 | description = PyTorch Library for Remote Sensing 5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | license = MIT License 8 | url = https://github.com/isaaccorley/torchrs 9 | author = Isaac Corley 10 | author_email = isaac.corley@my.utsa.edu 11 | classifiers = 12 | Programming Language :: Python :: 3 13 | License :: OSI Approved :: MIT License 14 | Operating System :: OS Independent 15 | keywords = 16 | pytorch 17 | remote-sensing 18 | computer-vision 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from setuptools.config import read_configuration 3 | 4 | extras = { 5 | "train": ["pytorch-lightning>=1.4.0", "torchmetrics>=0.4.1"], 6 | } 7 | install_requires = [ 8 | "torch>=1.9.0", 9 | "torchvision>=0.10.0", 10 | "torchaudio>=0.9.0", 11 | "einops>=0.3.0", 12 | "numpy>=1.21.0", 13 | "pillow>=8.3.1", 14 | "tifffile>=2021.7.2", 15 | "h5py>=3.3.0", 16 | "imagecodecs>=2021.7.30" 17 | ] 18 | setup_requires = ["pytest-runner"] 19 | tests_require = ["pytest", "pytest-cov", "mock", "mypy", "black", "pylint"] 20 | 21 | cfg = read_configuration("setup.cfg") 22 | 23 | setup( 24 | download_url='{}/archive/{}.tar.gz'.format(cfg["metadata"]["url"], cfg["metadata"]["version"]), 25 | project_urls={"Bug Tracker": cfg["metadata"]["url"] + "/issues"}, 26 | install_requires=install_requires, 27 | extras_require=extras, 28 | setup_requires=setup_requires, 29 | tests_require=tests_require, 30 | packages=find_packages(), 31 | python_requires=">=3.7", 32 | ) 33 | 
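34 | # Installation note (usage sketch): "pip install torch-rs" installs only the core dependencies listed in install_requires above, while "pip install torch-rs[train]" additionally pulls in the optional pytorch-lightning and torchmetrics dependencies declared in the "train" extra.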
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/tests/__init__.py -------------------------------------------------------------------------------- /tests/_test_datamodules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import pytorch_lightning as pl 4 | 5 | from torchrs.train import datamodules 6 | 7 | 8 | skip = ["BaseDataModule", "RSVQAxBENDataModule", "S2MTCPDataModule"] 9 | 10 | 11 | @torch.no_grad() 12 | @pytest.mark.parametrize("datamodule", reversed(datamodules.__all__)) 13 | def test_datamodules(datamodule: pl.LightningDataModule): 14 | 15 | if datamodule in skip: 16 | return 17 | 18 | dm = getattr(datamodules, datamodule)() 19 | dm.setup() 20 | batch = next(iter(dm.train_dataloader())) 21 | batch = next(iter(dm.val_dataloader())) 22 | batch = next(iter(dm.test_dataloader())) 23 | -------------------------------------------------------------------------------- /tests/_test_datasets.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pytest 3 | from torch.utils.data import Dataset 4 | 5 | from torchrs import datasets 6 | 7 | 8 | skip = ["RSVQAxBEN", "S2MTCP"] 9 | 10 | 11 | @pytest.mark.parametrize("dataset", reversed(datasets.__all__)) 12 | def test_datamodules(dataset: Dataset): 13 | 14 | if dataset in skip: 15 | return 16 | 17 | dataclass = getattr(datasets, dataset) 18 | 19 | if "split" in inspect.getfullargspec(dataclass).args: 20 | for split in dataclass.splits: 21 | ds = dataclass(split=split) 22 | length = len(ds) 23 | sample = ds[0] 24 | sample = ds[-1] 25 | else: 26 | ds = dataclass() 27 | length = len(ds) 28 | sample = ds[0] 29 | sample = ds[-1] 30 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/torchrs/baafb7e639859c74f076ad27a104f6b5bf7cee91/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_rams.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import pytest 3 | import torch 4 | 5 | from torchrs.models import RAMS 6 | 7 | 8 | DTYPE = torch.float32 9 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 10 | IMAGE_SIZE = 32 11 | T = [9, 11, 13, 15] 12 | SCALE_FACTOR = [2, 3, 4] 13 | CHANNELS = [1, 3] 14 | BATCH_SIZE = [1, 2] 15 | 16 | params = list(itertools.product(SCALE_FACTOR, T, CHANNELS, BATCH_SIZE)) 17 | 18 | 19 | @torch.no_grad() 20 | @pytest.mark.parametrize("scale_factor, t, channels, batch_size", params) 21 | def test_rams(scale_factor, t, channels, batch_size): 22 | model = RAMS(scale_factor, t, channels, num_feature_attn_blocks=3) 23 | model = model.to(DEVICE) 24 | model = model.eval() 25 | lr = torch.ones(batch_size, t, channels, IMAGE_SIZE, IMAGE_SIZE) 26 | lr = lr.to(DTYPE) 27 | lr = lr.to(DEVICE) 28 | sr = model(lr) 29 | assert sr.shape == (batch_size, channels, IMAGE_SIZE*scale_factor, IMAGE_SIZE*scale_factor) 30 | assert sr.dtype == torch.float32 31 | -------------------------------------------------------------------------------- /tests/test_transforms.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from torchrs.transforms import ExtractChips 5 | 6 | 7 | def test_extract_chips(): 8 | x = torch.ones(3, 128, 128) 9 | f = ExtractChips((32, 32)) 10 | z = f(x) 11 | assert z.shape == (16, 3, 32, 32) 12 | -------------------------------------------------------------------------------- /torchrs/__init__.py: -------------------------------------------------------------------------------- 1 | from . import transforms 2 | from . import datasets 3 | from . import models 4 | 5 | __version__ = "0.0.4" 6 | -------------------------------------------------------------------------------- /torchrs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | from .probav import PROBAV 3 | from .etci2021 import ETCI2021 4 | from .rsvqa import RSVQALR, RSVQAHR, RSVQAxBEN 5 | from .eurosat import EuroSATRGB, EuroSATMS 6 | from .resisc45 import RESISC45 7 | from .rsicd import RSICD 8 | from .oscd import OSCD 9 | from .s2looking import S2Looking 10 | from .levircd import LEVIRCDPlus 11 | from .fair1m import FAIR1M 12 | from .sydney_captions import SydneyCaptions 13 | from .ucm_captions import UCMCaptions 14 | from .s2mtcp import S2MTCP 15 | from .advance import ADVANCE 16 | from .sat import SAT4, SAT6 17 | from .hrscd import HRSCD 18 | from .inria_ail import InriaAIL 19 | from .tiselac import Tiselac 20 | from .gid15 import GID15 21 | from .zuericrop import ZueriCrop 22 | from .aid import AID 23 | from .dubai_segmentation import DubaiSegmentation 24 | from .hkh_glacier import HKHGlacierMapping 25 | from .ucm import UCM 26 | from .patternnet import PatternNet 27 | from .whu_rs19 import WHURS19 28 | from .rsscn7 import RSSCN7 29 | from .brazilian_coffee import BrazilianCoffeeScenes 30 | 31 | 32 | __all__ = [ 33 | "PROBAV", "ETCI2021", "RSVQALR", "RSVQAxBEN", "EuroSATRGB", "EuroSATMS", 34 | "RESISC45", "RSICD", "OSCD", "S2Looking", "LEVIRCDPlus", "FAIR1M", 35 | "SydneyCaptions", "UCMCaptions", "S2MTCP", "ADVANCE", "SAT4", "SAT6", 36 | "HRSCD", "InriaAIL", "Tiselac", "GID15", "ZueriCrop", "AID", "DubaiSegmentation", 37 | "HKHGlacierMapping", "UCM", "PatternNet", "RSVQAHR", "WHURS19", "RSSCN7", 38 | "BrazilianCoffeeScenes" 39 | ] 40 | -------------------------------------------------------------------------------- /torchrs/datasets/advance.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import torchaudio 7 | import numpy as np 8 | import torchvision.transforms as T 9 | from PIL import Image 10 | 11 | 12 | class ADVANCE(torch.utils.data.Dataset): 13 | """ AuDio Visual Aerial sceNe reCognition datasEt (ADVANCE) from 14 | 'Cross-Task Transfer for Geotagged Audiovisual Aerial Scene Recognition', Hu et al. (2020) 15 | https://arxiv.org/abs/2005.08449 16 | 17 | 'We create an annotated dataset consisting of 5075 geotagged aerial image-sound pairs 18 | involving 13 scene classes.
This dataset covers a large variety of scenes from across 19 | the world' 20 | """ 21 | def __init__( 22 | self, 23 | root: str = ".data/advance", 24 | image_transform: T.Compose = T.Compose([T.ToTensor()]), 25 | audio_transform: T.Compose = T.Compose([]), 26 | ): 27 | self.root = root 28 | self.image_transform = image_transform 29 | self.audio_transform = audio_transform 30 | self.files = self.load_files(root) 31 | self.classes = sorted(set(f["cls"] for f in self.files)) 32 | 33 | @staticmethod 34 | def load_files(root: str) -> List[Dict]: 35 | images = sorted(glob(os.path.join(root, "vision", "**", "*.jpg"))) 36 | wavs = sorted(glob(os.path.join(root, "sound", "**", "*.wav"))) 37 | labels = [image.split(os.sep)[-2] for image in images] 38 | files = [dict(image=image, audio=wav, cls=label) for image, wav, label in zip(images, wavs, labels)] 39 | return files 40 | 41 | def __len__(self) -> int: 42 | return len(self.files) 43 | 44 | def __getitem__(self, idx: int) -> Dict: 45 | """ Returns a dict containing image, audio, and class label 46 | image: (3, 512, 512) 47 | audio: (1, 220500) 48 | cls: int 49 | """ 50 | files = self.files[idx] 51 | image = np.array(Image.open(files["image"]).convert("RGB")) 52 | audio, fs = torchaudio.load(files["audio"]) 53 | image = self.image_transform(image) 54 | audio = self.audio_transform(audio) 55 | return dict(image=image, audio=audio, cls=files["cls"]) 56 | -------------------------------------------------------------------------------- /torchrs/datasets/aid.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torchvision.datasets import ImageFolder 3 | 4 | 5 | class AID(ImageFolder): 6 | """ Aerial Image Dataset (AID) from 'AID: A Benchmark Dataset for Performance 7 | Evaluation of Aerial Scene Classification', Xia et al. (2017) 8 | https://arxiv.org/abs/1608.05167 9 | 10 | 'The AID dataset has a number of 10000 images within 30 classes.' 11 | """ 12 | def __init__( 13 | self, 14 | root: str = ".data/AID", 15 | transform: T.Compose = T.Compose([T.ToTensor()]) 16 | ): 17 | super().__init__( 18 | root=root, 19 | transform=transform 20 | ) 21 | -------------------------------------------------------------------------------- /torchrs/datasets/brazilian_coffee.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Tuple, List 4 | 5 | import torch 6 | import numpy as np 7 | import torchvision.transforms as T 8 | from PIL import Image 9 | 10 | 11 | class BrazilianCoffeeScenes(torch.utils.data.Dataset): 12 | """ Brazilian Coffee Scenes dataset from 'Do Deep Features Generalize from 13 | Everyday Objects to Remote Sensing and Aerial Scenes Domains?', Penatti et al.
(2015) 14 | https://arxiv.org/abs/1703.00121 15 | """ 16 | bands = ["Green", "Red", "NIR"] 17 | classes = ["non-coffee", "coffee"] 18 | 19 | def __init__( 20 | self, 21 | root: str = ".data/brazilian_coffee_scenes", 22 | transform: T.Compose = T.Compose([T.ToTensor()]) 23 | ): 24 | super().__init__() 25 | self.transform = transform 26 | self.images, self.labels = self.load_images(root) 27 | 28 | @staticmethod 29 | def load_images(path: str) -> Tuple[List[str], List[int]]: 30 | folds = glob(os.path.join(path, "*.txt")) 31 | images, labels = [], [] 32 | for fold in folds: 33 | fold_dir = os.path.join(path, os.path.splitext(os.path.basename(fold))[0]) 34 | with open(fold, "r") as f: 35 | lines = f.read().strip().splitlines() 36 | 37 | for line in lines: 38 | label, image = line.split(".", 1) 39 | images.append(os.path.join(fold_dir, image + ".jpg")) 40 | labels.append(0 if label == "noncoffee" else 1) 41 | 42 | return images, labels 43 | 44 | def __len__(self) -> int: 45 | return len(self.images) 46 | 47 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]: 48 | image, y = self.images[idx], self.labels[idx] 49 | x = np.array(Image.open(image).convert("RGB")) 50 | x = self.transform(x) 51 | return x, y 52 | -------------------------------------------------------------------------------- /torchrs/datasets/dubai_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Tuple, Dict 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from torchrs.transforms import Compose, ToTensor 10 | 11 | 12 | class DubaiSegmentation(torch.utils.data.Dataset): 13 | """ Semantic segmentation dataset of Dubai imagery taken by MBRSC satellites 14 | https://humansintheloop.org/resources/datasets/semantic-segmentation-dataset/ 15 | 16 | """ 17 | classes = { 18 | "Unlabeled": {"rgb": (155, 155, 155), "color": "#9B9B9B"}, 19 | "Water": {"rgb": (226, 169, 41), "color": "#E2A929"}, 20 | "Land (unpaved area)": {"rgb": (132, 41, 246), "color": "#8429F6"}, 21 | "Road": {"rgb": (110, 193, 228), "color": "#6EC1E4"}, 22 | "Building": {"rgb": (60, 16, 152), "color": "#3C1098"}, 23 | "Vegetation": {"rgb": (254, 221, 58), "color": "#FEDD3A"} 24 | } 25 | colors = [v["rgb"] for k, v in classes.items()] 26 | 27 | def __init__( 28 | self, 29 | root: str = ".data/dubai-segmentation", 30 | transform: Compose = Compose([ToTensor()]), 31 | ): 32 | self.transform = transform 33 | self.images = self.load_images(root) 34 | self.regions = list(set([image["region"] for image in self.images])) 35 | 36 | @staticmethod 37 | def load_images(path: str) -> List[Dict]: 38 | images = sorted(glob(os.path.join(path, "**", "images", "*.jpg"), recursive=True)) 39 | masks = sorted(glob(os.path.join(path, "**", "masks", "*.png"), recursive=True)) 40 | regions = [image.split(os.sep)[-3] for image in images] 41 | files = [ 42 | dict(image=image, mask=mask, region=region) 43 | for image, mask, region in zip(images, masks, regions) 44 | ] 45 | return files 46 | 47 | @staticmethod 48 | def rgb_to_mask(rgb: np.ndarray, colors: List[Tuple[int, int, int]]) -> np.ndarray: 49 | h, w = rgb.shape[:2] 50 | mask = np.zeros(shape=(h, w), dtype=np.uint8) 51 | for i, c in enumerate(colors): 52 | cmask = (rgb == c) 53 | if isinstance(cmask, np.ndarray): 54 | mask[cmask.all(axis=-1)] = i 55 | 56 | return mask 57 | 58 | def __len__(self) -> int: 59 | return len(self.images) 60 | 61 | def __getitem__(self, idx: int) -> Dict: 62 | 
image_path, target_path = self.images[idx]["image"], self.images[idx]["mask"] 63 | x = np.array(Image.open(image_path).convert("RGB")) 64 | y = np.array(Image.open(target_path).convert("RGB")) 65 | y = self.rgb_to_mask(y, self.colors) 66 | x, y = self.transform([x, y]) 67 | return dict(x=x, mask=y, region=self.images[idx]["region"]) 68 | -------------------------------------------------------------------------------- /torchrs/datasets/etci2021.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from torchrs.transforms import Compose, ToTensor 10 | 11 | 12 | class ETCI2021(torch.utils.data.Dataset): 13 | """Sentinel-1 Synthetic Aperture Radar (SAR) segmentation dataset from the ETCI 2021 Competition on Flood Detection 14 | https://nasa-impact.github.io/etci2021/ 15 | https://competitions.codalab.org/competitions/30440 16 | 17 | 'The contest dataset is composed of 66,810 (33,405 x 2 VV & VH polarization) tiles of 256 x 256 pixels, 18 | distributed respectively across the training, validation and test sets as follows: 33,405, 10,400, 19 | and 12,348 tiles for each polarization. Each tile includes 3 RGB channels which have been converted 20 | by tiling 54 labeled GeoTiff files generated from Sentinel-1 C-band synthetic aperture radar (SAR) 21 | imagery data using Hybrid Pluggable Processing Pipeline hyp3.' 22 | 23 | Note that hyp3 preprocessing generates a 3-band image for each polarization so VV and VH are both of shape (256, 256, 3) 24 | """ 25 | bands = ["VV", "VH"] 26 | splits = ["train", "val", "test"] 27 | split_to_folder = dict(train="train", val="test", test="test_internal") 28 | 29 | def __init__( 30 | self, 31 | root: str = ".data/etci2021", 32 | split: str = "train", 33 | transform: Compose = Compose([ToTensor()]), 34 | ): 35 | assert split in self.splits 36 | self.split = split 37 | self.transform = transform 38 | self.images = self.load_files(root, self.split_to_folder[split]) 39 | 40 | @staticmethod 41 | def load_files(root: str, split: str) -> List[Dict]: 42 | images = [] 43 | folders = sorted(glob(os.path.join(root, split, "*"))) 44 | folders = [f + "/tiles" for f in folders] 45 | for folder in folders: 46 | vvs = glob(os.path.join(folder, "vv", "*.png")) 47 | vhs = glob(os.path.join(folder, "vh", "*.png")) 48 | water_masks = glob(os.path.join(folder, "water_body_label", "*.png")) 49 | 50 | if split == "test_internal": 51 | flood_masks = [None] * len(water_masks) 52 | else: 53 | flood_masks = glob(os.path.join(folder, "flood_label", "*.png")) 54 | 55 | for vv, vh, flood_mask, water_mask in zip(vvs, vhs, flood_masks, water_masks): 56 | images.append(dict(vv=vv, vh=vh, flood_mask=flood_mask, water_mask=water_mask)) 57 | return images 58 | 59 | def __len__(self) -> int: 60 | return len(self.images) 61 | 62 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 63 | """ Returns a dict containing vv, vh, flood mask, water mask 64 | vv: (3, h, w) 65 | vh: (3, h, w) 66 | flood mask: (1, h, w) flood mask 67 | water mask: (1, h, w) water mask 68 | """ 69 | images = self.images[idx] 70 | vv = np.array(Image.open(images["vv"]), dtype="uint8") 71 | vh = np.array(Image.open(images["vh"]), dtype="uint8") 72 | water_mask = np.array(Image.open(images["water_mask"]).convert("L"), dtype="bool") 73 | 74 | if self.split == "test": 75 | vv, vh, water_mask = self.transform([vv, vh, water_mask]) 76 | output = dict(vv=vv,
vh=vh, water_mask=water_mask) 77 | else: 78 | flood_mask = np.array(Image.open(images["flood_mask"]).convert("L"), dtype="bool") 79 | vv, vh, flood_mask, water_mask = self.transform([vv, vh, flood_mask, water_mask]) 80 | output = dict(vv=vv, vh=vh, flood_mask=flood_mask, water_mask=water_mask) 81 | 82 | return output 83 | -------------------------------------------------------------------------------- /torchrs/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tifffile 4 | import torchvision.transforms as T 5 | from torchvision.datasets import ImageFolder 6 | 7 | from torchrs.transforms import ToTensor 8 | 9 | 10 | class EuroSATRGB(ImageFolder): 11 | """ Sentinel-2 RGB Land Cover Classification dataset from 'EuroSAT: A Novel Dataset 12 | and Deep Learning Benchmark for Land Use and Land Cover Classification', Helber et al. (2017) 13 | https://arxiv.org/abs/1709.00029 14 | 15 | 'We present a novel dataset based on Sentinel-2 satellite images covering 13 spectral 16 | bands and consisting out of 10 classes with in total 27,000 labeled and geo-referenced images.' 17 | 18 | Note: RGB bands only 19 | """ 20 | def __init__( 21 | self, 22 | root: str = ".data/eurosat-rgb", 23 | transform: T.Compose = T.Compose([T.ToTensor()]) 24 | ): 25 | super().__init__( 26 | root=os.path.join(root, "2750"), 27 | transform=transform 28 | ) 29 | 30 | 31 | class EuroSATMS(ImageFolder): 32 | """ Sentinel-2 multispectral Land Cover Classification dataset from 'EuroSAT: A Novel Dataset 33 | and Deep Learning Benchmark for Land Use and Land Cover Classification', Helber et al. (2017) 34 | https://arxiv.org/abs/1709.00029 35 | 36 | 'We present a novel dataset based on Sentinel-2 satellite images covering 13 spectral 37 | bands and consisting out of 10 classes with in total 27,000 labeled and geo-referenced images.' 38 | 39 | Note: all 13 multispectral (MS) bands 40 | """ 41 | def __init__( 42 | self, 43 | root: str = ".data/eurosat-ms", 44 | transform: T.Compose = T.Compose([ToTensor()]) 45 | ): 46 | super().__init__( 47 | root=os.path.join(root, "ds/images/remote_sensing/otherDatasets/sentinel_2/tif"), 48 | transform=transform, 49 | loader=tifffile.imread 50 | ) 51 | -------------------------------------------------------------------------------- /torchrs/datasets/fair1m.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from xml.etree import ElementTree 4 | from typing import List, Dict 5 | 6 | import torch 7 | import numpy as np 8 | import torchvision.transforms as T 9 | from PIL import Image 10 | 11 | 12 | def parse_pascal_voc(path: str) -> Dict: 13 | et = ElementTree.parse(path) 14 | element = et.getroot() 15 | image = element.find("source").find("filename").text 16 | classes, points = [], [] 17 | for obj in element.find("objects").findall("object"): 18 | obj_points = [p.text.split(",") for p in obj.find("points").findall("point")] 19 | obj_points = [(float(p1), float(p2)) for p1, p2 in obj_points] 20 | cls = obj.find("possibleresult").find("name").text 21 | classes.append(cls) 22 | points.append(obj_points) 23 | return dict(image=image, points=points, classes=classes) 24 | 25 | 26 | class FAIR1M(torch.utils.data.Dataset): 27 | """ FAIR1M dataset from 'FAIR1M: A Benchmark Dataset 28 | for Fine-grained Object Recognition in High-Resolution Remote Sensing Imagery', Sun et al.
(2021) 29 | https://arxiv.org/abs/2103.05569 30 | 31 | 'We propose a novel benchmark dataset with more than 1 million instances and more 32 | than 15,000 images for Fine-grained object recognition in high-resolution remote 33 | sensing imagery which is named as FAIR1M. We collected remote sensing images with 34 | a resolution of 0.3m to 0.8m from different platforms, which are spread across many 35 | countries and regions. All objects in the FAIR1M dataset are annotated with respect 36 | to 5 categories and 37 sub-categories by oriented bounding boxes.' 37 | """ 38 | classes = { 39 | "Passenger Ship": {"id": 0, "category": "Ship"}, 40 | "Motorboat": {"id": 1, "category": "Ship"}, 41 | "Fishing Boat": {"id": 2, "category": "Ship"}, 42 | "Tugboat": {"id": 3, "category": "Ship"}, 43 | "other-ship": {"id": 4, "category": "Ship"}, 44 | "Engineering Ship": {"id": 5, "category": "Ship"}, 45 | "Liquid Cargo Ship": {"id": 6, "category": "Ship"}, 46 | "Dry Cargo Ship": {"id": 7, "category": "Ship"}, 47 | "Warship": {"id": 8, "category": "Ship"}, 48 | "Small Car": {"id": 9, "category": "Vehicle"}, 49 | "Bus": {"id": 10, "category": "Vehicle"}, 50 | "Cargo Truck": {"id": 11, "category": "Vehicle"}, 51 | "Dump Truck": {"id": 12, "category": "Vehicle"}, 52 | "other-vehicle": {"id": 13, "category": "Vehicle"}, 53 | "Van": {"id": 14, "category": "Vehicle"}, 54 | "Trailer": {"id": 15, "category": "Vehicle"}, 55 | "Tractor": {"id": 16, "category": "Vehicle"}, 56 | "Excavator": {"id": 17, "category": "Vehicle"}, 57 | "Truck Tractor": {"id": 18, "category": "Vehicle"}, 58 | "Boeing737": {"id": 19, "category": "Airplane"}, 59 | "Boeing747": {"id": 20, "category": "Airplane"}, 60 | "Boeing777": {"id": 21, "category": "Airplane"}, 61 | "Boeing787": {"id": 22, "category": "Airplane"}, 62 | "ARJ21": {"id": 23, "category": "Airplane"}, 63 | "C919": {"id": 24, "category": "Airplane"}, 64 | "A220": {"id": 25, "category": "Airplane"}, 65 | "A321": {"id": 26, "category": "Airplane"}, 66 | "A330": {"id": 27, "category": "Airplane"}, 67 | "A350": {"id": 28, "category": "Airplane"}, 68 | "other-airplane": {"id": 29, "category": "Airplane"}, 69 | "Baseball Field": {"id": 30, "category": "Court"}, 70 | "Basketball Court": {"id": 31, "category": "Court"}, 71 | "Football Field": {"id": 32, "category": "Court"}, 72 | "Tennis Court": {"id": 33, "category": "Court"}, 73 | "Roundabout": {"id": 34, "category": "Road"}, 74 | "Intersection": {"id": 35, "category": "Road"}, 75 | "Bridge": {"id": 36, "category": "Road"} 76 | } 77 | 78 | def __init__( 79 | self, 80 | root: str = ".data/fair1m", 81 | transform: T.Compose = T.Compose([T.ToTensor()]), 82 | ): 83 | split = "train" 84 | self.root = root 85 | self.image_root = os.path.join(root, split, "part1", "images") 86 | self.transform = transform 87 | self.images = self.load_files(root, split) 88 | self.idx2cls = {i: c for i, c in enumerate(self.classes)} 89 | self.cls2idx = {c: i for i, c in self.idx2cls.items()} 90 | 91 | @staticmethod 92 | def load_files(root: str, split: str) -> List[Dict]: 93 | files = sorted(glob(os.path.join(root, split, "part1", "labelXmls", "*.xml"))) 94 | return [parse_pascal_voc(f) for f in files] 95 | 96 | def __len__(self) -> int: 97 | return len(self.images) 98 | 99 | def __getitem__(self, idx: int) -> Dict: 100 | """ Returns a dict containing x, y, points where points is the x,y coords of the rotated bbox 101 | x: (3, h, w) 102 | y: (N,) 103 | points: (N, 5, 2) 104 | """ 105 | image = self.images[idx] 106 | x = 
np.array(Image.open(os.path.join(self.image_root, image["image"]))) 107 | x = x[..., :3] 108 | x = self.transform(x) 109 | y = torch.tensor([self.cls2idx[c] for c in image["classes"]]) 110 | points = torch.tensor(image["points"]) 111 | return dict(x=x, y=y, points=points) 112 | -------------------------------------------------------------------------------- /torchrs/datasets/gid15.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from torchrs.transforms import Compose, ToTensor 10 | 11 | 12 | class GID15(torch.utils.data.Dataset): 13 | """ Gaofen Image Dataset (GID-15) from 'Land-Cover Classification with High-Resolution 14 | Remote Sensing Images Using Transferable Deep Models', Tong et al. (2018) 15 | https://arxiv.org/abs/1807.05713 16 | 17 | 'We construct a new large-scale land-cover dataset with Gaofen-2 (GF-2) satellite 18 | images. This new dataset, which is named as Gaofen Image Dataset with 15 categories 19 | (GID-15), has superiorities over the existing land-cover dataset because of its 20 | large coverage, wide distribution, and high spatial resolution. The large-scale 21 | remote sensing semantic segmentation set contains 150 pixel-level annotated GF-2 22 | images, which is labeled in 15 categories.' 23 | """ 24 | classes = [ 25 | "background", 26 | "industrial_land", 27 | "urban_residential", 28 | "rural_residential", 29 | "traffic_land", 30 | "paddy_field", 31 | "irrigated_land", 32 | "dry_cropland", 33 | "garden_plot", 34 | "arbor_woodland", 35 | "shrub_land", 36 | "natural_grassland", 37 | "artificial_grassland", 38 | "river", 39 | "lake", 40 | "pond" 41 | ] 42 | splits = ["train", "val", "test"] 43 | 44 | def __init__( 45 | self, 46 | root: str = ".data/gid-15", 47 | split: str = "train", 48 | transform: Compose = Compose([ToTensor()]), 49 | ): 50 | self.split = split 51 | self.transform = transform 52 | self.images = self.load_images(os.path.join(root, "GID"), split) 53 | 54 | @staticmethod 55 | def load_images(path: str, split: str) -> List[Dict]: 56 | images = sorted(glob(os.path.join(path, "img_dir", split, "*.tif"))) 57 | if split in ["train", "val"]: 58 | masks = [ 59 | image.replace("img_dir", "ann_dir").replace(".tif", "_15label.png") 60 | for image in images 61 | ] 62 | else: 63 | masks = [None] * len(images) 64 | 65 | files = [dict(image=image, mask=mask) for image, mask in zip(images, masks)] 66 | return files 67 | 68 | def __len__(self) -> int: 69 | return len(self.images) 70 | 71 | def __getitem__(self, idx: int) -> Dict: 72 | image_path, mask_path = self.images[idx]["image"], self.images[idx]["mask"] 73 | x = np.array(Image.open(image_path)) 74 | 75 | if self.split in ["train", "val"]: 76 | y = np.array(Image.open(mask_path)) 77 | x, y = self.transform([x, y]) 78 | output = dict(x=x, mask=y) 79 | else: 80 | x = self.transform(x) 81 | output = dict(x=x) 82 | 83 | return output 84 | -------------------------------------------------------------------------------- /torchrs/datasets/hkh_glacier.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from torchrs.transforms import Compose, ToTensor 9 | 10 | 11 | class HKHGlacierMapping(torch.utils.data.Dataset): 12 | """ Hindu Kush Himalayas (HKH) Glacier Mapping dataset 13 | 
https://lila.science/datasets/hkh-glacier-mapping 14 | 15 | 'We also provide 14190 numpy patches. The numpy patches are all of size 512x512x15 and 16 | corresponding 512x512x2 pixel-wise mask labels; the two channels in the pixel-wise masks 17 | correspond to clean-iced and debris-covered glaciers. Patches' geolocation information, 18 | time stamps, source Landsat IDs, and glacier density are available in a geojson metadata file.' 19 | 20 | """ 21 | bands = [ 22 | "LE7 B1 (blue)", 23 | "LE7 B2 (green)", 24 | "LE7 B3 (red)", 25 | "LE7 B4 (near infrared)", 26 | "LE7 B5 (shortwave infrared 1)", 27 | "LE7 B6_VCID_1 (low-gain thermal infrared)", 28 | "LE7 B6_VCID_2 (high-gain thermal infrared)", 29 | "LE7 B7 (shortwave infrared 2)", 30 | "LE7 B8 (panchromatic)", 31 | "LE7 BQA (quality bitmask)", 32 | "NDVI (vegetation index)", 33 | "NDSI (snow index)", 34 | "NDWI (water index)", 35 | "SRTM 90 elevation", 36 | "SRTM 90 slope" 37 | ] 38 | 39 | def __init__( 40 | self, 41 | root: str = ".data/hkh_glacier_mapping", 42 | transform: Compose = Compose([ToTensor()]), 43 | ): 44 | self.transform = transform 45 | self.images = self.load_images(root) 46 | 47 | @staticmethod 48 | def load_images(path: str) -> List[Dict]: 49 | images = sorted(glob(os.path.join(path, "images", "*.npy"))) 50 | masks = sorted(glob(os.path.join(path, "masks", "*.npy"))) 51 | return [dict(image=image, mask=mask) for image, mask in zip(images, masks)] 52 | 53 | def __len__(self) -> int: 54 | return len(self.images) 55 | 56 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 57 | image_path, target_path = self.images[idx]["image"], self.images[idx]["mask"] 58 | x, y = np.load(image_path), np.load(target_path) 59 | y0, y1 = y[..., 0], y[..., 1] 60 | x, y0, y1 = self.transform([x, y0, y1]) 61 | return dict(x=x, clean_ice_mask=y0, debris_covered_mask=y1) 62 | -------------------------------------------------------------------------------- /torchrs/datasets/hrscd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import tifffile 7 | 8 | from torchrs.transforms import Compose, ToTensor 9 | 10 | 11 | class HRSCD(torch.utils.data.Dataset): 12 | """ The High Resolution Semantic Change Detection (HRSCD) dataset from 'Multitask Learning 13 | for Large-scale Semantic Change Detection', Daudt at al. (2018) 14 | https://arxiv.org/abs/1810.08452 15 | 16 | 'This dataset contains 291 coregistered image pairs of RGB aerial images from IGS's 17 | BD ORTHO database. Pixel-level change and land cover annotations are provided, generated 18 | by rasterizing Urban Atlas 2006, Urban Atlas 2012, and Urban Atlas Change 2006-2012 maps.' 
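Example usage (a minimal sketch assuming the dataset has already been downloaded and extracted to the default root below; the shapes in the comments follow the __getitem__ documentation):

    >>> from torchrs.datasets.hrscd import HRSCD
    >>> dataset = HRSCD(root=".data/HRSCD")
    >>> sample = dataset[0]  # dict with keys "x", "lc" and "mask"
    >>> x, lc, mask = sample["x"], sample["lc"], sample["mask"]
    >>> # x: (2, 3, 1000, 1000) image pair, lc: (2, 1000, 1000) land cover, mask: (1, 1000, 1000) change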
19 | """ 20 | classes = [ 21 | "No Information", 22 | "Artificial surfaces", 23 | "Agricultural areas", 24 | "Forests", 25 | "Wetlands", 26 | "Water" 27 | ] 28 | 29 | def __init__( 30 | self, 31 | root: str = ".data/HRSCD", 32 | transform: Compose = Compose([ToTensor()]) 33 | ): 34 | self.root = root 35 | self.transform = transform 36 | self.files = self.load_files(root) 37 | 38 | @staticmethod 39 | def load_files(root: str) -> List[Dict]: 40 | images1 = sorted(glob(os.path.join(root, "images", "2006", "**", "*.tif"), recursive=True)) 41 | images2 = sorted(glob(os.path.join(root, "images", "2012", "**", "*.tif"), recursive=True)) 42 | lcs1 = sorted(glob(os.path.join(root, "labels", "2006", "**", "*.tif"), recursive=True)) 43 | lcs2 = sorted(glob(os.path.join(root, "labels", "2012", "**", "*.tif"), recursive=True)) 44 | changes = sorted(glob(os.path.join(root, "labels", "change", "**", "*.tif"), recursive=True)) 45 | files = [] 46 | for image1, image2, lc1, lc2, change in zip(images1, images2, lcs1, lcs2, changes): 47 | region = image1.split(os.sep)[-2] 48 | files.append(dict(image1=image1, image2=image2, lc1=lc1, lc2=lc2, mask=change, region=region)) 49 | return files 50 | 51 | def __len__(self) -> int: 52 | return len(self.files) 53 | 54 | def __getitem__(self, idx: int) -> Dict: 55 | """ Returns a dict containing x, land cover mask, change mask 56 | x: (2, 3, 1000, 1000) 57 | lc: (2, 1000, 1000) 58 | mask: (1, 1000, 1000) 59 | """ 60 | files = self.files[idx] 61 | image1 = tifffile.imread(files["image1"]) 62 | image2 = tifffile.imread(files["image2"]) 63 | lc1 = tifffile.imread(files["lc1"]) 64 | lc2 = tifffile.imread(files["lc2"]) 65 | mask = tifffile.imread(files["mask"]) 66 | image1, image2, lc1, lc2, mask = self.transform([image1, image2, lc1, lc2, mask]) 67 | x = torch.stack([image1, image2], dim=0) 68 | lc = torch.cat([lc1, lc2], dim=0) 69 | return dict(x=x, lc=lc, mask=mask) 70 | -------------------------------------------------------------------------------- /torchrs/datasets/inria_ail.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from glob import glob 4 | from typing import List, Dict 5 | 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from torchrs.transforms import Compose, ToTensor 11 | 12 | 13 | class InriaAIL(torch.utils.data.Dataset): 14 | """ Inria Aerial Image Labeling dataset from 'Can semantic labeling methods 15 | generalize to any city? the inria aerial image labeling benchmark', Maggiori et al. (2017) 16 | https://ieeexplore.ieee.org/document/8127684 17 | 18 | 'The training set contains 180 color image tiles of size 5000x5000, covering a surface of 1500mx1500m 19 | each (at a 30 cm resolution). There are 36 tiles for each of the following regions (Austin, Chicago, Kitsap County, Western Tyrol, Vienna) 20 | The format is GeoTIFF. Files are named by a prefix associated to the region (e.g., austin- or vienna-) 21 | followed by the tile number (1-36). The reference data is in a different folder and the file names 22 | correspond exactly to those of the color images. In the case of the reference data, the tiles are 23 | single-channel images with values 255 for the building class and 0 for the not building class.' 
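Example usage (a minimal sketch assuming the dataset has been extracted to the default root below; the shapes noted follow the 5000x5000 tile description above):

    >>> from torchrs.datasets.inria_ail import InriaAIL
    >>> dataset = InriaAIL(root=".data/AerialImageDataset", split="train")
    >>> sample = dataset[0]                    # dict with keys "x", "mask" and "region"
    >>> x, mask = sample["x"], sample["mask"]  # x: (3, 5000, 5000), mask: (1, 5000, 5000)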
24 | """ 25 | splits = ["train", "test"] 26 | 27 | def __init__( 28 | self, 29 | root: str = ".data/AerialImageDataset", 30 | split: str = "train", 31 | transform: Compose = Compose([ToTensor()]), 32 | ): 33 | self.split = split 34 | self.transform = transform 35 | self.images = self.load_images(root, split) 36 | self.regions = sorted(list(set(image["region"] for image in self.images))) 37 | 38 | @staticmethod 39 | def load_images(path: str, split: str) -> List[Dict]: 40 | images = sorted(glob(os.path.join(path, split, "images", "*.tif"))) 41 | pattern = re.compile("[a-zA-Z]+") 42 | regions = [re.findall(pattern, os.path.basename(image))[0] for image in images] 43 | 44 | if split == "train": 45 | targets = sorted(glob(os.path.join(path, split, "gt", "*.tif"))) 46 | else: 47 | targets = [None] * len(images) 48 | 49 | files = [ 50 | dict(image=image, target=target, region=region) 51 | for image, target, region in zip(images, targets, regions) 52 | ] 53 | return files 54 | 55 | def __len__(self) -> int: 56 | return len(self.images) 57 | 58 | def __getitem__(self, idx: int) -> Dict: 59 | image_path, target_path = self.images[idx]["image"], self.images[idx]["target"] 60 | x = np.array(Image.open(image_path)) 61 | 62 | if self.split == "train": 63 | y = np.array(Image.open(target_path)) 64 | y = np.clip(y, a_min=0, a_max=1) 65 | x, y = self.transform([x, y]) 66 | output = dict(x=x, mask=y, region=self.images[idx]["region"]) 67 | else: 68 | x = self.transform(x) 69 | output = dict(x=x) 70 | 71 | return output 72 | -------------------------------------------------------------------------------- /torchrs/datasets/levircd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Dict 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from torchrs.transforms import Compose, ToTensor 10 | 11 | 12 | class LEVIRCDPlus(torch.utils.data.Dataset): 13 | """ LEVIR-CD+ dataset from 'S2Looking: A Satellite Side-Looking 14 | Dataset for Building Change Detection', Shen at al. (2021) 15 | https://arxiv.org/abs/2107.09244 16 | 17 | 'LEVIR-CD+ contains more than 985 VHR (0.5m/pixel) bitemporal Google 18 | Earth images with dimensions of 1024x1024 pixels. These bitemporal images 19 | are from 20 different regions located in several cities in the state of 20 | Texas in the USA. The capture times of the image data vary from 2002 to 21 | 2020. Images of different regions were taken at different times. The 22 | bitemporal images have a time span of 5 years.' 
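Example usage (a minimal sketch assuming the dataset has been extracted to the default root below; the shapes noted follow the __getitem__ documentation):

    >>> from torchrs.datasets.levircd import LEVIRCDPlus
    >>> dataset = LEVIRCDPlus(root=".data/levircd_plus", split="train")
    >>> sample = dataset[0]                    # dict with keys "x" and "mask"
    >>> x, mask = sample["x"], sample["mask"]  # x: (2, 3, 1024, 1024) image pair, mask: (1, 1024, 1024) change mask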
23 | """ 24 | splits = ["train", "test"] 25 | 26 | def __init__( 27 | self, 28 | root: str = ".data/levircd_plus", 29 | split: str = "train", 30 | transform: Compose = Compose([ToTensor()]), 31 | ): 32 | assert split in self.splits 33 | self.root = root 34 | self.transform = transform 35 | self.files = self.load_files(root, split) 36 | 37 | @staticmethod 38 | def load_files(root: str, split: str): 39 | files = [] 40 | images = glob(os.path.join(root, split, "A", "*.png")) 41 | images = sorted([os.path.basename(image) for image in images]) 42 | for image in images: 43 | image1 = os.path.join(root, split, "A", image) 44 | image2 = os.path.join(root, split, "B", image) 45 | mask = os.path.join(root, split, "label", image) 46 | files.append(dict(image1=image1, image2=image2, mask=mask)) 47 | return files 48 | 49 | def __len__(self) -> int: 50 | return len(self.files) 51 | 52 | def __getitem__(self, idx: int) -> Dict: 53 | """ Returns a dict containing x, mask 54 | x: (2, 3, h, w) 55 | mask: (1, h, w) 56 | """ 57 | files = self.files[idx] 58 | mask = np.array(Image.open(files["mask"])) 59 | mask = np.clip(mask, 0, 1) 60 | image1 = np.array(Image.open(files["image1"])) 61 | image2 = np.array(Image.open(files["image2"])) 62 | image1, image2, mask = self.transform([image1, image2, mask]) 63 | x = torch.stack([image1, image2], dim=0) 64 | return dict(x=x, mask=mask) 65 | -------------------------------------------------------------------------------- /torchrs/datasets/oscd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Dict 4 | 5 | import torch 6 | import tifffile 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from torchrs.transforms import Compose, ToTensor 11 | 12 | 13 | def sort(x): 14 | band = os.path.splitext(x.split(os.sep)[-1])[0] 15 | if band == "B8A": 16 | band = "B08A" 17 | return band 18 | 19 | 20 | class OSCD(torch.utils.data.Dataset): 21 | """ The Onera Satellite Change Detection (OSCD) dataset from 'Urban Change Detection for 22 | Multispectral Earth Observation Using Convolutional Neural Networks', Daudt et al. (2018) 23 | https://arxiv.org/abs/1810.08468 24 | 25 | 'The Onera Satellite Change Detection dataset addresses the issue of detecting changes between 26 | satellite images from different dates. It comprises 24 pairs of multispectral images taken 27 | from the Sentinel-2 satellites between 2015 and 2018. Locations are picked all over the world, 28 | in Brazil, USA, Europe, Middle-East and Asia. For each location, registered pairs of 13-band 29 | multispectral satellite images obtained by the Sentinel-2 satellites are provided. Images vary 30 | in spatial resolution between 10m, 20m and 60m. Pixel-level change ground truth is provided for 31 | all 14 training and 10 test image pairs. The annotated changes focus on urban changes, such as 32 | new buildings or new roads. These data can be used for training and setting parameters of change 33 | detection algorithms.'
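Example usage (a minimal sketch assuming the dataset has been downloaded to the default root below; the shapes noted follow the __getitem__ documentation):

    >>> from torchrs.datasets.oscd import OSCD
    >>> dataset = OSCD(root=".data/oscd", split="train")
    >>> sample = dataset[0]                    # dict with keys "x" and "mask"
    >>> x, mask = sample["x"], sample["mask"]  # x: (2, 13, h, w) Sentinel-2 pair, mask: (1, h, w) change mask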
34 | """ 35 | splits = ["train", "test"] 36 | 37 | def __init__( 38 | self, 39 | root: str = ".data/oscd", 40 | split: str = "train", 41 | transform: Compose = Compose([ToTensor(permute_dims=False)]), 42 | ): 43 | assert split in self.splits 44 | self.root = root 45 | self.transform = transform 46 | self.regions = self.load_files(root, split) 47 | 48 | @staticmethod 49 | def load_files(root: str, split: str): 50 | regions = [] 51 | labels_root = os.path.join(root, f"{split}_labels") 52 | images_root = os.path.join(root, "images") 53 | folders = glob(os.path.join(labels_root, "*/")) 54 | for folder in folders: 55 | region = folder.split(os.sep)[-2] 56 | mask = os.path.join(labels_root, region, "cm", "cm.png") 57 | images1 = glob(os.path.join(images_root, region, "imgs_1_rect", "*.tif")) 58 | images2 = glob(os.path.join(images_root, region, "imgs_2_rect", "*.tif")) 59 | images1 = sorted(images1, key=sort) 60 | images2 = sorted(images2, key=sort) 61 | with open(os.path.join(images_root, region, "dates.txt")) as f: 62 | dates = tuple([line.split()[-1] for line in f.read().strip().splitlines()]) 63 | 64 | regions.append(dict(region=region, images1=images1, images2=images2, mask=mask, dates=dates)) 65 | 66 | return regions 67 | 68 | def __len__(self) -> int: 69 | return len(self.regions) 70 | 71 | def __getitem__(self, idx: int) -> Dict: 72 | """ Returns a dict containing x, mask 73 | x: (2, 13, h, w) 74 | mask: (1, h, w) 75 | """ 76 | region = self.regions[idx] 77 | mask = np.array(Image.open(region["mask"])) 78 | mask[mask == 255] = 1 79 | image1 = np.stack([tifffile.imread(path) for path in region["images1"]], axis=0) 80 | image2 = np.stack([tifffile.imread(path) for path in region["images2"]], axis=0) 81 | image1, image2, mask = self.transform([image1, image2, mask]) 82 | x = torch.stack([image1, image2], dim=0) 83 | return dict(x=x, mask=mask) 84 | -------------------------------------------------------------------------------- /torchrs/datasets/patternnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.transforms as T 4 | from torchvision.datasets import ImageFolder 5 | 6 | 7 | class PatternNet(ImageFolder): 8 | """ PatternNet dataset from 'PatternNet: A benchmark dataset for performance 9 | evaluation of remote sensing image retrieval', Zhou at al. 
(2018) 10 | https://arxiv.org/abs/1706.03424 11 | 12 | """ 13 | def __init__( 14 | self, 15 | root: str = ".data/PatternNet", 16 | transform: T.Compose = T.Compose([T.ToTensor()]) 17 | ): 18 | super().__init__( 19 | root=os.path.join(root, "images"), 20 | transform=transform 21 | ) 22 | -------------------------------------------------------------------------------- /torchrs/datasets/probav.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import numpy as np 7 | import torchvision.transforms as T 8 | from PIL import Image 9 | 10 | from torchrs.transforms import ToTensor, ToDtype 11 | 12 | 13 | STATS = { 14 | "RED": {"min": 864.0, "max": 15312.8, "mean": 0.0587, "std": 0.6386}, 15 | "NIR": {"min": 628.8, "max": 17464.0, "mean": 0.0618, "std": 0.7428} 16 | } 17 | 18 | 19 | def collate_fn(batch: List[Dict], t: int = 9, shuffle: bool = False) -> Dict[str, torch.Tensor]: 20 | lrs = [x["lr"] for x in batch] 21 | qms = [x["qm"] for x in batch] 22 | 23 | # Shuffle lr temporal order 24 | if shuffle: 25 | perms = [torch.randperm(lr.shape[0]) for lr in lrs] 26 | lrs = [lr[perm] for perm, lr in zip(perms, lrs)] 27 | qms = [qm[perm] for perm, qm in zip(perms, qms)] 28 | 29 | # Repeat some images if less than temporal length t 30 | for i in range(len(lrs)): 31 | if lrs[i].shape[0] < t: 32 | n = t - lrs[i].shape[0] 33 | lrs[i] = torch.cat([lrs[i], lrs[i][:n]], dim=0) 34 | qms[i] = torch.cat([qms[i], qms[i][:n]], dim=0) 35 | 36 | # Select t lr images from each set 37 | lrs = torch.stack([lr[:t] for lr in lrs]) 38 | qms = torch.stack([qm[:t] for qm in qms]) 39 | 40 | hr = torch.stack([x["hr"] for x in batch]) 41 | sm = torch.stack([x["sm"] for x in batch]) 42 | 43 | return dict(lr=lrs, hr=hr, qm=qms, sm=sm) 44 | 45 | 46 | class PROBAV(torch.utils.data.Dataset): 47 | """Multi-image super resolution (MISR) dataset from the PROBA-V Super Resolution Competition 48 | https://kelvins.esa.int/proba-v-super-resolution/home/ 49 | 50 | 'We collected satellite data from the PROBA-V mission of the European Space Agency from 74 hand-selected 51 | regions around the globe at different points in time. The data is composed of radiometrically and geometrically 52 | corrected Top-Of-Atmosphere (TOA) reflectances for the RED and NIR spectral bands at 300m and 100m resolution 53 | in Plate Carrée projection. The 300m resolution data is delivered as 128x128 grey-scale pixel images, the 100m 54 | resolution data as 384x384 grey-scale pixel images. The bit-depth of the images is 14, but they are saved in a 55 | 16-bit .png-format (which makes them look relatively dark if opened in typical image viewers). 56 | 57 | Each image comes with a quality map, indicating which pixels in the image are concealed 58 | (i.e. clouds, cloud shadows, ice, water, missing, etc) and which should be considered clear. For an image to be 59 | included in the dataset, at least 75% of its pixels have to be clear for 100m resolution images, and 60% for 60 | 300m resolution images. Each data-point consists of exactly one 100m resolution image and several 300m resolution 61 | images from the same scene. In total, the dataset contains 1450 scenes, which are split into 1160 scenes for 62 | training and 290 scenes for testing. On average, each scene comes with 19 different low resolution images and 63 | always with at least 9. 
We expect you to submit a 384x384 image for each of the 290 test-scenes, for which we 64 | will not provide a high resolution image.' 65 | """ 66 | splits = ["train", "test"] 67 | 68 | def __init__( 69 | self, 70 | root: str = ".data/probav", 71 | split: str = "train", 72 | band: str = "RED", 73 | lr_transform: T.Compose = T.Compose([ToTensor(), ToDtype(torch.float32)]), 74 | hr_transform: T.Compose = T.Compose([ToTensor(), ToDtype(torch.float32)]), 75 | ): 76 | assert split in self.splits 77 | assert band in ["RED", "NIR"] 78 | self.split = split 79 | self.lr_transform = lr_transform 80 | self.hr_transform = hr_transform 81 | self.imgsets = self.load_files(root, split, band) 82 | 83 | @staticmethod 84 | def load_files(root: str, split: str, band: str) -> List[Dict]: 85 | imgsets = [] 86 | folders = sorted(glob(os.path.join(root, split, band, "imgset*"))) 87 | for folder in folders: 88 | lr = sorted(glob(os.path.join(folder, "LR*.png"))) 89 | qm = sorted(glob(os.path.join(folder, "QM*.png"))) 90 | sm = glob(os.path.join(folder, "SM.png"))[0] 91 | 92 | if split == "train": 93 | hr = glob(os.path.join(folder, "HR.png"))[0] 94 | else: 95 | hr = None 96 | 97 | imgsets.append(dict(lr=lr, qm=qm, hr=hr, sm=sm)) 98 | return imgsets 99 | 100 | def __len__(self) -> int: 101 | return len(self.imgsets) 102 | 103 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 104 | """ Returns a dict containing lr, qm, hr, sm 105 | lr: (t, 1, h, w) low resolution images 106 | qm: (t, 1, h, w) low resolution image quality masks 107 | hr: (1, h, w) high resolution image 108 | sm: (1, h, w) high resolution image status mask 109 | 110 | Note: 111 | lr/qm original size is (128, 128), 112 | hr/sm original size is (384, 384) (scale factor = 3) 113 | t is the number of lr images for an image set (min = 9) 114 | """ 115 | imgset = self.imgsets[idx] 116 | 117 | # Load 118 | lrs = [np.array(Image.open(lr), dtype="int32") for lr in imgset["lr"]] 119 | qms = [np.array(Image.open(qm), dtype="bool") for qm in imgset["qm"]] 120 | sm = np.array(Image.open(imgset["sm"]), dtype="bool") 121 | 122 | # Transform 123 | lrs = torch.stack([self.lr_transform(lr) for lr in lrs]) 124 | qms = torch.stack([torch.from_numpy(qm) for qm in qms]).unsqueeze(1) 125 | sm = torch.from_numpy(sm).unsqueeze(0) 126 | 127 | if self.split == "train": 128 | hr = np.array(Image.open(imgset["hr"]), dtype="int32") 129 | hr = self.hr_transform(hr) 130 | output = dict(lr=lrs, qm=qms, hr=hr, sm=sm) 131 | else: 132 | output = dict(lr=lrs, qm=qms, sm=sm) 133 | 134 | return output 135 | -------------------------------------------------------------------------------- /torchrs/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torchvision.datasets import ImageFolder 3 | 4 | 5 | class RESISC45(ImageFolder): 6 | """ Image Scene Classification dataset from 'Remote Sensing Image 7 | Scene Classification: Benchmark and State of the Art', Cheng at al. (2017) 8 | https://arxiv.org/abs/1703.00121 9 | 10 | 'We propose a large-scale dataset, termed "NWPU-RESISC45", which is a publicly 11 | available benchmark for REmote Sensing Image Scene Classification (RESISC), created 12 | by Northwestern Polytechnical University (NWPU). This dataset contains 31,500 images, 13 | covering 45 scene classes with 700 images in each class. 
The proposed NWPU-RESISC45 (i) 14 | is large-scale on the scene classes and the total image number, (ii) holds big variations 15 | in translation, spatial resolution, viewpoint, object pose, illumination, background, and 16 | occlusion, and (iii) has high within-class diversity and between-class similarity.' 17 | """ 18 | def __init__( 19 | self, 20 | root: str = ".data/NWPU-RESISC45", 21 | transform: T.Compose = T.Compose([T.ToTensor()]) 22 | ): 23 | super().__init__( 24 | root=root, 25 | transform=transform 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/datasets/rsicd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Dict 4 | 5 | import torch 6 | import torchvision.transforms as T 7 | from PIL import Image 8 | 9 | 10 | class RSICD(torch.utils.data.Dataset): 11 | """ Image Captioning Dataset from 'Exploring Models and Data for 12 | Remote Sensing Image Caption Generation', Lu et al. (2017) 13 | https://arxiv.org/abs/1712.07835 14 | 15 | 'RSICD is used for remote sensing image captioning task. more than ten thousands 16 | remote sensing images are collected from Google Earth, Baidu Map, MapABC, Tianditu. 17 | The images are fixed to 224X224 pixels with various resolutions. The total number of 18 | remote sensing images are 10921, with five sentences descriptions per image.' 19 | """ 20 | splits = ["train", "val", "test"] 21 | 22 | def __init__( 23 | self, 24 | root: str = ".data/rsicd", 25 | split: str = "train", 26 | transform: T.Compose = T.Compose([T.ToTensor()]) 27 | ): 28 | assert split in self.splits 29 | self.root = root 30 | self.transform = transform 31 | self.captions = self.load_captions(os.path.join(root, "dataset_rsicd.json"), split) 32 | self.image_root = "RSICD_images" 33 | 34 | @staticmethod 35 | def load_captions(path: str, split: str) -> List[Dict]: 36 | with open(path) as f: 37 | captions = json.load(f)["images"] 38 | return [c for c in captions if c["split"] == split] 39 | 40 | def __len__(self) -> int: 41 | return len(self.captions) 42 | 43 | def __getitem__(self, idx: int) -> Dict: 44 | captions = self.captions[idx] 45 | path = os.path.join(self.root, self.image_root, captions["filename"]) 46 | x = Image.open(path).convert("RGB") 47 | x = self.transform(x) 48 | sentences = [sentence["raw"] for sentence in captions["sentences"]] 49 | return dict(x=x, captions=sentences) 50 | -------------------------------------------------------------------------------- /torchrs/datasets/rsscn7.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torchvision.datasets import ImageFolder 3 | 4 | 5 | class RSSCN7(ImageFolder): 6 | """ RSSCN7 dataset from 'Deep Learning Based Feature Selection for 7 | Remote Sensing Scene Classification', Zou at al. (2015) 8 | https://ieeexplore.ieee.org/abstract/document/7272047 9 | 10 | 'The data set RSSCN7 contains 2800 remote sensing scene images, which 11 | are from seven typical scene categories, namely, the grassland, forest, 12 | farmland, parking lot, residential region, industrial region, and river 13 | and lake. For each category, there are 400 images collected from the 14 | Google Earth, which are sampled on four different scales with 100 images 15 | per scale. Each image has a size of 400x400 pixels. 
This data set is 16 | rather challenging due to the wide diversity of the scene images that 17 | are captured under changing seasons and varying weathers and sampled 18 | on different scales' 19 | """ 20 | def __init__( 21 | self, 22 | root: str = ".data/RSSCN7", 23 | transform: T.Compose = T.Compose([T.ToTensor()]) 24 | ): 25 | super().__init__( 26 | root=root, 27 | transform=transform 28 | ) 29 | -------------------------------------------------------------------------------- /torchrs/datasets/rsvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from glob import glob 4 | from typing import List, Dict, Tuple 5 | 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from torchrs.transforms import Compose, ToTensor 11 | 12 | 13 | def sort(x): 14 | x = os.path.basename(x) 15 | x = os.path.splitext(x)[0] 16 | return int(x) 17 | 18 | 19 | class RSVQA(torch.utils.data.Dataset): 20 | """ Base RSVQA dataset """ 21 | splits = ["train", "val", "test"] 22 | prefix = "" 23 | image_root = "" 24 | 25 | def __init__( 26 | self, 27 | root: str = "", 28 | split: str = "train", 29 | image_transform: Compose = Compose([ToTensor()]), 30 | text_transform: Compose = Compose([]) 31 | ): 32 | assert split in self.splits 33 | self.root = root 34 | self.split = split 35 | self.image_transform = image_transform 36 | self.text_transform = text_transform 37 | self.image_root = os.path.join(root, self.image_root) 38 | self.ids, self.paths, self.images, self.questions, self.answers = self.load_files( 39 | self.root, self.image_root, self.split, self.prefix 40 | ) 41 | 42 | @staticmethod 43 | def load_files(root: str, image_root: str, split: str, prefix: str) -> Tuple[List[int], List[str], List[Dict], List[Dict], List[Dict]]: 44 | paths = glob(os.path.join(image_root, "*.tif")) 45 | paths = sorted(paths, key=sort) 46 | with open(os.path.join(root, f"{prefix}_split_{split}_questions.json")) as f: 47 | questions = json.load(f)["questions"] 48 | with open(os.path.join(root, f"{prefix}_split_{split}_answers.json")) as f: 49 | answers = json.load(f)["answers"] 50 | with open(os.path.join(root, f"{prefix}_split_{split}_images.json")) as f: 51 | images = json.load(f)["images"] 52 | ids = [x["id"] for x in images if x["active"]] 53 | return ids, paths, images, questions, answers 54 | 55 | def __len__(self) -> int: 56 | return len(self.ids) 57 | 58 | def __getitem__(self, idx: int) -> Dict: 59 | """ Returns a dict containing x, questions, answers, q/a category 60 | x: (3, h, w) 61 | questions: List[str] 62 | answers: List[str] 63 | types: List[str] 64 | """ 65 | id = self.ids[idx] 66 | x = np.array(Image.open(os.path.join(self.image_root, f"{id}.tif"))) 67 | x = self.image_transform(x) 68 | questions = [self.questions[i] for i in self.images[id]["questions_ids"]] 69 | answers = [self.answers[q["answers_ids"][0]]["answer"] for q in questions] 70 | types = [q["type"] for q in questions] 71 | questions = [q["question"] for q in questions] 72 | questions = self.text_transform(questions) 73 | answers = self.text_transform(answers) 74 | output = dict(x=x, questions=questions, answers=answers, types=types) 75 | return output 76 | 77 | 78 | class RSVQALR(RSVQA): 79 | """Remote Sensing Visual Question Answering Low Resolution (RSVQA LR) dataset from 80 | 'RSVQA: Visual Question Answering for Remote Sensing Data', Lobry et al (2020) 81 | https://arxiv.org/abs/2003.07333 82 | 83 | 'This dataset is based on Sentinel-2 images acquired over the 
Netherlands. Sentinel-2 satellites provide 10m resolution 84 | (for the visible bands used in this dataset) images with frequent updates (around 5 days) at a global scale. These images 85 | are openly available through ESA’s Copernicus Open Access Hub. To generate the dataset, we selected 9 Sentinel-2 tiles 86 | covering the Netherlands with a low cloud cover (selected tiles are shown in Figure 3). These tiles were divided in 772 87 | images of size 256x256 (covering 6.55km^2) retaining the RGB bands. From these, we constructed 770,232 questions and 88 | answers following the methodology presented in subsection II-A. We split the data in a training set (77.8% of the original tiles), 89 | a validation set (11.1%) and a test set (11.1%) at the tile level (the spatial split is shown in Figure 3). This allows 90 | to limit spatial correlation between the different splits.' 91 | """ 92 | image_root = "Images_LR" 93 | prefix = "LR" 94 | 95 | def __init__(self, root: str = ".data/RSVQA_LR", *args, **kwargs): 96 | super().__init__(root, *args, **kwargs) 97 | 98 | 99 | class RSVQAHR(RSVQA): 100 | """Remote Sensing Visual Question Answering High Resolution (RSVQA HR) dataset from 101 | 'RSVQA: Visual Question Answering for Remote Sensing Data', Lobry et al (2020) 102 | https://arxiv.org/abs/2003.07333 103 | 104 | 'This dataset uses 15cm resolution aerial RGB images extracted from the High Resolution 105 | Orthoimagery (HRO) data collection of the USGS. This collection covers most urban areas of the 106 | USA, along with a few areas of interest (e.g. national parks). For most areas covered by the dataset, 107 | only one tile is available with acquisition dates ranging from year 2000 to 2016, with various sensors. 108 | The tiles are openly accessible through USGS' EarthExplorer tool. 109 | 110 | From this collection, we extracted 161 tiles belonging to the North-East coast of the USA 111 | that were split into 100659 images of size 512x512 (each covering 5898m^2).We constructed 100,660,316 112 | questions and answers following the methodology presented in subsection II-A. We split the data in 113 | a training set (61.5% of the tiles), a validation set (11.2%), and test sets (20.5% for test set 1, 114 | 6.8% for test set 2). As it can be seen in Figure 4, test set 1 covers similar regions as the training 115 | and validation sets, while test set 2 covers the city of Philadelphia, which is not seen during the 116 | training. Note that this second test set also uses another sensor (marked as unknown on the USGS 117 | data catalog), not seen during training. 118 | """ 119 | image_root = "Data" 120 | prefix = "USGS" 121 | 122 | def __init__(self, root: str = ".data/RSVQA_HR", *args, **kwargs): 123 | super().__init__(root, *args, **kwargs) 124 | 125 | 126 | class RSVQAxBEN(RSVQA): 127 | """Remote Sensing Visual Question Answering BigEarthNet (RSVQAxBEN) dataset from 128 | 'RSVQA Meets BigEarthNet: A New, Large-Scale, Visual Question Answering Dataset for Remote Sensing', Lobry et al (2021) 129 | https://rsvqa.sylvainlobry.com/IGARSS21.pdf 130 | 131 | 'We introduce a new dataset to tackle the task of visual question answering on remote 132 | sensing images: this largescale, open access dataset extracts image/question/answer triplets 133 | from the BigEarthNet dataset. This new dataset contains close to 15 millions samples and is openly 134 | available.' 
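Example usage (a minimal sketch assuming the dataset has been downloaded to the default root below; the returned keys follow the base RSVQA __getitem__ documentation above):

    >>> from torchrs.datasets.rsvqa import RSVQAxBEN
    >>> dataset = RSVQAxBEN(root=".data/rsvqaxben", split="train")
    >>> sample = dataset[0]  # dict with keys "x", "questions", "answers" and "types"
    >>> x, questions, answers = sample["x"], sample["questions"], sample["answers"]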
135 | """ 136 | image_root = "Images" 137 | prefix = "RSVQAxBEN" 138 | 139 | def __init__(self, root: str = ".data/rsvqaxben", *args, **kwargs): 140 | super().__init__(root, *args, **kwargs) 141 | -------------------------------------------------------------------------------- /torchrs/datasets/s2looking.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Dict 4 | 5 | import torch 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from torchrs.transforms import Compose, ToTensor 10 | 11 | 12 | class S2Looking(torch.utils.data.Dataset): 13 | """ The Satellite Side-Looking (S2Looking) dataset from 'S2Looking: A Satellite Side-Looking 14 | Dataset for Building Change Detection', Shen et al. (2021) 15 | https://arxiv.org/abs/2107.09244 16 | 17 | 'S2Looking is a building change detection dataset that contains large-scale side-looking 18 | satellite images captured at varying off-nadir angles. The S2Looking dataset consists of 19 | 5,000 registered bitemporal image pairs (size of 1024*1024, 0.5 ~ 0.8 m/pixel) of rural 20 | areas throughout the world and more than 65,920 annotated change instances. We provide 21 | two label maps to separately indicate the newly built and demolished building regions 22 | for each sample in the dataset.' 23 | """ 24 | splits = ["train", "val", "test"] 25 | 26 | def __init__( 27 | self, 28 | root: str = ".data/s2looking", 29 | split: str = "train", 30 | transform: Compose = Compose([ToTensor()]), 31 | ): 32 | assert split in self.splits 33 | self.root = root 34 | self.transform = transform 35 | self.files = self.load_files(root, split) 36 | 37 | @staticmethod 38 | def load_files(root: str, split: str): 39 | files = [] 40 | images = glob(os.path.join(root, split, "Image1", "*.png")) 41 | images = sorted([os.path.basename(image) for image in images]) 42 | for image in images: 43 | image1 = os.path.join(root, split, "Image1", image) 44 | image2 = os.path.join(root, split, "Image2", image) 45 | build_mask = os.path.join(root, split, "label1", image) 46 | demo_mask = os.path.join(root, split, "label2", image) 47 | files.append(dict(image1=image1, image2=image2, build_mask=build_mask, demolish_mask=demo_mask)) 48 | return files 49 | 50 | def __len__(self) -> int: 51 | return len(self.files) 52 | 53 | def __getitem__(self, idx: int) -> Dict: 54 | """ Returns a dict containing x, build_mask, demolish_mask 55 | x: (2, 3, h, w) 56 | build_mask: (1, h, w) 57 | demolish_mask: (1, h, w) 58 | """ 59 | files = self.files[idx] 60 | build_mask = np.array(Image.open(files["build_mask"])) 61 | demo_mask = np.array(Image.open(files["demolish_mask"])) 62 | build_mask = np.clip(build_mask.mean(axis=-1), 0, 1).astype("uint8") 63 | demo_mask = np.clip(demo_mask.mean(axis=-1), 0, 1).astype("uint8") 64 | image1 = np.array(Image.open(files["image1"])) 65 | image2 = np.array(Image.open(files["image2"])) 66 | image1, image2, build_mask, demo_mask = self.transform([image1, image2, build_mask, demo_mask]) 67 | x = torch.stack([image1, image2], dim=0) 68 | return dict(x=x, build_mask=build_mask, demolish_mask=demo_mask) 69 | -------------------------------------------------------------------------------- /torchrs/datasets/s2mtcp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import List, Dict 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from torchrs.transforms import Compose, ToTensor 9 | 10 | 11 | class
S2MTCP(torch.utils.data.Dataset): 12 | """ Sentinel-2 Multitemporal Cities Pairs (S2MTCP) dataset from 'Self-supervised 13 | pre-training enhances change detection in Sentinel-2 imagery', Leenstra at al. (2021) 14 | https://arxiv.org/abs/2101.08122 15 | 16 | 'A dataset of S2 level 1C image pairs was created ... Image pairs are selected randomly 17 | from available S2 images of each location with less than one percent cloud cover. Bands 18 | with a spatial resolution smaller than 10 m are resampled to 10 m and images are cropped 19 | to approximately 600x600 pixels centered on the selected coordinates ... The S2MTCP dataset 20 | contains N = 1520 image pairs, spread over all inhabited continents, with the highest 21 | concentration of image pairs in North-America, Europe and Asia' 22 | 23 | Note this dataset doesn't contain change masks as it was created for use in self-supervised pretraining 24 | """ 25 | def __init__( 26 | self, 27 | root: str = ".data/s2mtcp", 28 | transform: Compose = Compose([ToTensor()]), 29 | ): 30 | self.root = os.path.join(root, "data_S21C") 31 | self.transform = transform 32 | self.files = self.load_files(self.root) 33 | 34 | @staticmethod 35 | def load_files(root: str) -> List[Dict]: 36 | files = glob(os.path.join(root, "*.npy")) 37 | files = [os.path.basename(f).split("_")[0] for f in files] 38 | files = sorted(set(files), key=int) 39 | files = [dict(image1=f"{num}_a.npy", image2=f"{num}_b.npy") for num in files] 40 | return files 41 | 42 | def __len__(self) -> int: 43 | return len(self.files) 44 | 45 | def __getitem__(self, idx: int) -> torch.Tensor: 46 | """ Returns x: (2, 14, h, w) """ 47 | files = self.files[idx] 48 | image1 = np.load(os.path.join(self.root, files["image1"])) 49 | image2 = np.load(os.path.join(self.root, files["image2"])) 50 | image1, image2 = self.transform([image1, image2]) 51 | x = torch.stack([image1, image2], dim=0) 52 | return x 53 | -------------------------------------------------------------------------------- /torchrs/datasets/sat.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import h5py 4 | import torch 5 | import torchvision.transforms as T 6 | 7 | 8 | class SAT(torch.utils.data.Dataset): 9 | """ Base SAT dataset """ 10 | splits = ["train", "test"] 11 | 12 | def __init__( 13 | self, 14 | root: str = "", 15 | split: str = "train", 16 | transform: T.Compose = T.Compose([T.ToTensor()]) 17 | ): 18 | assert split in self.splits 19 | self.root = root 20 | self.split = split 21 | self.transform = transform 22 | 23 | def __len__(self) -> int: 24 | with h5py.File(self.root, "r") as f: 25 | length = f[f"{self.split}_y"].shape[0] 26 | return length 27 | 28 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 29 | with h5py.File(self.root, "r") as f: 30 | x = f[f"{self.split}_x"][idx] 31 | y = f[f"{self.split}_y"][idx] 32 | x = self.transform(x) 33 | return x, y 34 | 35 | 36 | class SAT4(SAT): 37 | """ SAT-4 land cover classification dataset from "DeepSat - A Learning framework 38 | for Satellite Imagery", Basu et al (2015) 39 | https://arxiv.org/abs/1509.03602 40 | 41 | 'SAT-4 consists of a total of 500,000 image patches covering four broad land cover classes. 42 | These include - barren land, trees, grassland and a class that consists of all land cover 43 | classes other than the above three. 
400,000 patches (comprising of four-fifths of the total 44 | dataset) were chosen for training and the remaining 100,000 (one-fifths) were chosen as the testing 45 | dataset. We ensured that the training and test datasets belong to disjoint set of image tiles. 46 | Each image patch is size normalized to 28x28 pixels. Once generated, both the training and testing 47 | datasets were randomized using a pseudo-random number generator.' 48 | """ 49 | classes = [ 50 | "barren land", 51 | "trees", 52 | "grassland", 53 | "other" 54 | ] 55 | 56 | def __init__(self, root: str = ".data/sat/sat4.h5", *args, **kwargs): 57 | super().__init__(root, *args, **kwargs) 58 | 59 | 60 | class SAT6(SAT): 61 | """ SAT-6 land cover classification dataset from "DeepSat - A Learning framework 62 | for Satellite Imagery", Basu et al (2015) 63 | https://arxiv.org/abs/1509.03602 64 | 65 | 'SAT-6 consists of a total of 405,000 image patches each of size 28x28 and covering 6 66 | landcover classes - barren land, trees, grassland, roads, buildings and water bodies. 67 | 324,000 images (comprising of four-fifths of the total dataset) were chosen as the training 68 | dataset and 81,000 (one fifths) were chosen as the testing dataset. Similar to SAT-4, 69 | the training and test sets were selected from disjoint NAIP tiles. Once generated, the 70 | images in the dataset were randomized in the same way as that for SAT-4. The specifications 71 | for the various landcover classes of SAT-4 and SAT-6 were adopted from those used in the 72 | National Land Cover Data (NLCD) algorithm.' 73 | """ 74 | classes = [ 75 | "barren land", 76 | "trees", 77 | "grassland", 78 | "roads", 79 | "buildings", 80 | "water" 81 | ] 82 | 83 | def __init__(self, root: str = ".data/sat/sat6.h5", *args, **kwargs): 84 | super().__init__(root, *args, **kwargs) 85 | -------------------------------------------------------------------------------- /torchrs/datasets/sydney_captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets import RSICD 6 | 7 | 8 | class SydneyCaptions(RSICD): 9 | """ Sydney Captions dataset from 'Deep semantic understanding of 10 | high resolution remote sensing image', Qu et al (2016) 11 | https://ieeexplore.ieee.org/document/7546397 12 | 13 | 'The Sydney dataset contains 7 different scene categories and totally has 613 HSR images 14 | ... Then every HSR image is annotated with 5 reference sentences 15 | ... 
has 613 images with 3065 captions' 16 | """ 17 | splits = ["train", "val", "test"] 18 | 19 | def __init__( 20 | self, 21 | root: str = ".data/sydney_captions", 22 | split: str = "train", 23 | transform: T.Compose = T.Compose([T.ToTensor()]) 24 | ): 25 | assert split in self.splits 26 | self.root = root 27 | self.transform = transform 28 | self.captions = self.load_captions(os.path.join(root, "dataset.json"), split) 29 | self.image_root = "images" 30 | -------------------------------------------------------------------------------- /torchrs/datasets/tiselac.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple 3 | 4 | import torch 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | from torchrs.transforms import Compose, ToTensor 9 | 10 | 11 | class Tiselac(torch.utils.data.Dataset): 12 | """ TiSeLac dataset from the Time Series Land Cover Classification Challenge (2017) 13 | https://sites.google.com/site/dinoienco/tiselac-time-series-land-cover-classification-challenge 14 | 15 | 'A MSTC Land Cover classification problem for data taken from the Reunion island. 16 | A case is a pixel. Measurements are taken over 23 time points (days), with 17 | 10 dimensions: 7 surface reflectances (Ultra Blue, Blue, Green, Red, NIR, SWIR1 and SWIR2) 18 | plus 3 indices (NDVI, NDWI and BI). Class values relate to one of 9 land cover types class values.' 19 | """ 20 | classes = [ 21 | "Urban Areas", 22 | "Other built-up surfaces", 23 | "Forests", 24 | "Sparse Vegetation", 25 | "Rocks and bare soil", 26 | "Grassland", 27 | "Sugarcane crops", 28 | "Other crops", 29 | "Water" 30 | ] 31 | splits = ["train", "test"] 32 | 33 | def __init__( 34 | self, 35 | root: str = ".data/tiselac", 36 | split: str = "train", 37 | transform: Compose = Compose([ToTensor(permute_dims=False)]) 38 | ): 39 | assert split in self.splits 40 | self.root = root 41 | self.transform = transform 42 | self.series, self.labels = self.load_file(root, split) 43 | 44 | @staticmethod 45 | def load_file(path: str, split: str) -> Tuple[np.ndarray, np.ndarray]: 46 | x = np.loadtxt(os.path.join(path, f"{split}.txt"), dtype=np.int16, delimiter=",") 47 | y = np.loadtxt(os.path.join(path, f"{split}_labels.txt"), dtype=np.uint8) 48 | x = rearrange(x, "n (t c) -> n t c", c=10) 49 | return x, y 50 | 51 | def __len__(self) -> int: 52 | return len(self.series) 53 | 54 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 55 | x, y = self.series[idx], self.labels[idx] - 1 56 | x, y = self.transform(x).squeeze(dim=0), torch.tensor(y).to(torch.long) 57 | return x, y 58 | -------------------------------------------------------------------------------- /torchrs/datasets/ucm.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.transforms as T 4 | from torchvision.datasets import ImageFolder 5 | 6 | 7 | class UCM(ImageFolder): 8 | """ UC Merced Land Use dataset from 'Bag-Of-Visual-Words and 9 | Spatial Extensions for Land-Use Classification', Yang at al. 
(2010) 10 | https://faculty.ucmerced.edu/snewsam/papers/Yang_ACMGIS10_BagOfVisualWords.pdf 11 | 12 | """ 13 | def __init__( 14 | self, 15 | root: str = ".data/UCMerced_LandUse", 16 | transform: T.Compose = T.Compose([T.ToTensor()]) 17 | ): 18 | super().__init__( 19 | root=os.path.join(root, "Images"), 20 | transform=transform 21 | ) 22 | -------------------------------------------------------------------------------- /torchrs/datasets/ucm_captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets import RSICD 6 | 7 | 8 | class UCMCaptions(RSICD): 9 | """ UC Merced (UCM) Captions dataset from 'Deep semantic understanding of 10 | high resolution remote sensing image', Qu et al (2016) 11 | https://ieeexplore.ieee.org/document/7546397 12 | 13 | 'The UCM dataset totally has 2100 HSR images which are divided into 21 challenging 14 | scene categories ... Then every HSR image is annotated with 5 reference sentences 15 | ... one has totally 2100 HSR remote sensing images with 10500 descriptions' 16 | """ 17 | splits = ["train", "val", "test"] 18 | 19 | def __init__( 20 | self, 21 | root: str = ".data/ucm_captions", 22 | split: str = "train", 23 | transform: T.Compose = T.Compose([T.ToTensor()]) 24 | ): 25 | assert split in self.splits 26 | self.root = root 27 | self.transform = transform 28 | self.captions = self.load_captions(os.path.join(root, "dataset.json"), split) 29 | self.image_root = "images" 30 | -------------------------------------------------------------------------------- /torchrs/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from torch.utils.data import Dataset, random_split 4 | 5 | 6 | def dataset_split(dataset: Dataset, val_pct: float, test_pct: Optional[float] = None) -> List[Dataset]: 7 | """ Split a torch Dataset into train/val/test sets """ 8 | if test_pct is None: 9 | val_length = int(len(dataset) * val_pct) 10 | train_length = len(dataset) - val_length 11 | return random_split(dataset, [train_length, val_length]) 12 | else: 13 | val_length = int(len(dataset) * val_pct) 14 | test_length = int(len(dataset) * test_pct) 15 | train_length = len(dataset) - (val_length + test_length) 16 | return random_split(dataset, [train_length, val_length, test_length]) 17 | -------------------------------------------------------------------------------- /torchrs/datasets/whu_rs19.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | from torchvision.datasets import ImageFolder 3 | 4 | 5 | class WHURS19(ImageFolder): 6 | """ WHU-RS19 dataset from'Structural High-resolution Satellite Image Indexing', Xia at al. 
(2010) 7 | https://hal.archives-ouvertes.fr/file/index/docid/458685/filename/structural_satellite_indexing_XYDG.pdf 8 | 9 | """ 10 | def __init__( 11 | self, 12 | root: str = ".data/WHU-RS19", 13 | transform: T.Compose = T.Compose([T.ToTensor()]) 14 | ): 15 | super().__init__( 16 | root=root, 17 | transform=transform 18 | ) 19 | -------------------------------------------------------------------------------- /torchrs/datasets/zuericrop.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | from typing import Dict 4 | 5 | import h5py 6 | import torch 7 | 8 | from torchrs.transforms import Compose, ToTensor 9 | 10 | 11 | class ZueriCrop(torch.utils.data.Dataset): 12 | """ ZueriCrop dataset from 'Crop mapping from image time series: 13 | deep learning with multi-scale label hierarchies', Turkoglu et al. (2021) 14 | https://arxiv.org/abs/2102.08820 15 | 16 | 'We provide a new, publicly available crop classification dataset ZueriCrop, equipped 17 | with a tree-structured label hierarchy. ZueriCrop covers a 50 km × 48 km area 18 | in the Swiss cantons of Zurich and Thurgau. It contains 28,000 Sentinel-2 image 19 | patches of size 24 pixels × 24 pixels, each observed 71 times over a period of 20 | 52 weeks; 48 agricultural land cover classes; and 116,000 individual agricultural fields.' 21 | 22 | """ 23 | ZueriCropClass = namedtuple("ZueriCropClass", ["id", "label", "level1", "level2", "level3", "description"]) 24 | classes = [ 25 | ZueriCropClass(0, "Unknown", "Unknown", "Unknown", "Unknown", "Unknown"), 26 | ZueriCropClass(1, "SummerBarley", "Vegetation", "Field crops", "SmallGrainCereal", "summer barley"), 27 | ZueriCropClass(2, "WinterBarley", "Vegetation", "Field crops", "SmallGrainCereal", "Winter barley"), 28 | ZueriCropClass(3, "Oat", "Vegetation", "Field crops", "SmallGrainCereal", "oats"), 29 | ZueriCropClass(4, "Wheat", "Vegetation", "Field crops", "SmallGrainCereal", "triticale"), 30 | ZueriCropClass(5, "Grain", "Vegetation", "Field crops", "SmallGrainCereal", "mixed fodder cereals"), 31 | ZueriCropClass(6, "Wheat", "Vegetation", "Field crops", "SmallGrainCereal", "Feed wheat according to the list of varieties swiss granum"), 32 | ZueriCropClass(7, "Maize", "Vegetation", "Field crops", "LargeGrainCereal", "grain maize"), 33 | ZueriCropClass(8, "EinkornWheat", "Vegetation", "Field crops", "SmallGrainCereal", "wild emmer, einkorn wheat"), 34 | ZueriCropClass(9, "SummerWheat", "Vegetation", "Field crops", "SmallGrainCereal", "summer wheat"), 35 | ZueriCropClass(10, "WinterWheat", "Vegetation", "Field crops", "SmallGrainCereal", "Winter wheat (without forage wheat of the variety list swiss granum),"), 36 | ZueriCropClass(11, "Rye", "Vegetation", "Field crops", "SmallGrainCereal", "rye"), 37 | ZueriCropClass(12, "Grain", "Vegetation", "Field crops", "SmallGrainCereal", "mixed cereals for bread production"), 38 | ZueriCropClass(13, "Spelt", "Vegetation", "Field crops", "SmallGrainCereal", "Spelt"), 39 | ZueriCropClass(14, "Maize", "Vegetation", "Field crops", "LargeGrainCereal", "maize for seed production (by contract),"), 40 | ZueriCropClass(15, "Maize", "Vegetation", "Field crops", "LargeGrainCereal", "Mads d'ensilage et mads vert"), 41 | ZueriCropClass(16, "Sugar_beets", "Vegetation", "Field crops", "BroadLeafRowCrop", "Sugar beets"), 42 | ZueriCropClass(17, "Beets", "Vegetation", "Field crops", "BroadLeafRowCrop", "fodder beet"), 43 | ZueriCropClass(18, "Potatoes", "Vegetation", "Field crops", 
"BroadLeafRowCrop", "potatoes"), 44 | ZueriCropClass(19, "Potatoes", "Vegetation", "Field crops", "BroadLeafRowCrop", "potatoes for planting (by contract),"), 45 | ZueriCropClass(20, "SummerRapeseed", "Vegetation", "Field crops", "BroadLeafRowCrop", "summer rapeseed for oil production"), 46 | ZueriCropClass(21, "WinterRapeseed", "Vegetation", "Field crops", "BroadLeafRowCrop", "Winter rape for cooking oil"), 47 | ZueriCropClass(22, "Soy", "Vegetation", "Field crops", "BroadLeafRowCrop", "soy"), 48 | ZueriCropClass(23, "Sunflowers", "Vegetation", "Field crops", "BroadLeafRowCrop", "Sunflowers"), 49 | ZueriCropClass(24, "Linen", "Vegetation", "Field crops", "BroadLeafRowCrop", "flax"), 50 | ZueriCropClass(25, "Hemp", "Vegetation", "Field crops", "BroadLeafRowCrop", "hemp"), 51 | ZueriCropClass(26, "Field bean", "Vegetation", "Field crops", "BroadLeafRowCrop", "Field beans for animal feed"), 52 | ZueriCropClass(27, "Peas", "Vegetation", "Field crops", "BroadLeafRowCrop", "protein peas for animal fodder"), 53 | ZueriCropClass(28, "Lupine", "Vegetation", "Field crops", "BroadLeafRowCrop", "lupine for fodder"), 54 | ZueriCropClass(29, "Pumpkin", "Vegetation", "Field crops", "VegetableCrop", "oil pumpkins"), 55 | ZueriCropClass(30, "Tobacco", "Vegetation", "Field crops", "BroadLeafRowCrop", "tobacco"), 56 | ZueriCropClass(31, "Sorghum", "Vegetation", "Field crops", "LargeGrainCereal", "millet"), 57 | ZueriCropClass(32, "Grain", "Vegetation", "Field crops", "SmallGrainCereal", "Ensiled grain"), 58 | ZueriCropClass(33, "Linen", "Vegetation", "Field crops", "BroadLeafRowCrop", "false flax (camelina sativa),"), 59 | ZueriCropClass(34, "Vegetables", "Vegetation", "Field crops", "VegetableCrop", "Annual free-range vegetables, without canned vegetables"), 60 | ZueriCropClass(35, "Vegetables", "Vegetation", "Field crops", "VegetableCrop", "Ground-canned vegetables"), 61 | ZueriCropClass(36, "Chicory", "Vegetation", "Field crops", "VegetableCrop", "chicory roots"), 62 | ZueriCropClass(37, "Buckwheat", "Vegetation", "Field crops", "SmallGrainCereal", "buckwheat"), 63 | ZueriCropClass(38, "Sorghum", "Vegetation", "Field crops", "LargeGrainCereal", "sorghum"), 64 | ZueriCropClass(39, "Berries", "Vegetation", "Special crops", "Berries", "Annual berries (e.g., strawberries),"), 65 | ZueriCropClass(40, "Unknown", "Vegetation", "Special crops", "Unknown", "single-year regrowing ressources (Kenaf and others),"), 66 | ZueriCropClass(41, "Unknown", "Vegetation", "Special crops", "Unknown", "Annual spice and medicinal plants"), 67 | ZueriCropClass(42, "Unknown", "Vegetation", "Special crops", "Unknown", "Annual horticultural outdoor crops (flowers, turf etc.),"), 68 | ZueriCropClass(43, "Biodiversity encouragement area", "Vegetation", "Grassland", "BiodiversityArea", "Conservation headlands"), 69 | ZueriCropClass(44, "Fallow", "Vegetation", "Special crops", "Fallow", "fallow"), 70 | ZueriCropClass(45, "Fallow", "Vegetation", "Special crops", "Fallow", "rotational set-aside"), 71 | ZueriCropClass(46, "Unknown", "Vegetation", "Grassland", "Meadow", "Hem on arable land"), 72 | ZueriCropClass(47, "Unknown", "Vegetation", "Special crops", "Unknown", "Poppy"), 73 | ZueriCropClass(48, "Unknown", "Vegetation", "Special crops", "Unknown", "safflower"), 74 | ZueriCropClass(49, "Unknown", "Vegetation", "Field crops", "CropMix", "lentils"), 75 | ZueriCropClass(50, "MixedCrop", "Vegetation", "Field crops", "CropMix", "Mixtures of field beans, protein peas and lupines for animal feed with cereals, at least 30% legume content at 
harvest"), 76 | ZueriCropClass(51, "Biodiversity encouragement area", "Vegetation", "Grassland", "BiodiversityArea", "Bloom strips for pollinators and other beneficials"), 77 | ZueriCropClass(52, "Mustard", "Vegetation", "Field crops", "BroadLeafRowCrop", "mustard"), 78 | ZueriCropClass(53, "WinterRapeseed", "Vegetation", "Field crops", "BroadLeafRowCrop", "Winter rape as a renewable raw material"), 79 | ZueriCropClass(54, "Sunflowers", "Vegetation", "Field crops", "BroadLeafRowCrop", "Sunflowers as regrowing ressource"), 80 | ZueriCropClass(55, "Unknown", "Bare soil", "Unknown", "Unknown", "open arable land, eligible for subsidies (region-specific biodiversity area),"), 81 | ZueriCropClass(56, "Unknown", "Bare soil", "Unknown", "Unknown", "Other open arable land, eligible"), 82 | ZueriCropClass(57, "Unknown", "Bare soil", "Unknown", "Unknown", "Other open arable land, not eligible"), 83 | ZueriCropClass(58, "Meadow", "Vegetation", "Grassland", "Meadow", "Art meadows (without pastures),"), 84 | ZueriCropClass(59, "Meadow", "Vegetation", "Grassland", "Meadow", "Other artificial meadow, eligible (eg pork pasture, poultry pasture),"), 85 | ZueriCropClass(60, "Meadow", "Vegetation", "Grassland", "Meadow", "Extensively used meadows (without pastures),"), 86 | ZueriCropClass(61, "Meadow", "Vegetation", "Grassland", "Meadow", "Little intensively used meadows (without pastures),"), 87 | ZueriCropClass(62, "Meadow", "Vegetation", "Grassland", "Meadow", "Other permanent pastures (without pastures),"), 88 | ZueriCropClass(63, "Pasture", "Vegetation", "Grassland", "Pasture", "Pastures (pastures, other pastures without summer pastures),"), 89 | ZueriCropClass(64, "Pasture", "Vegetation", "Grassland", "Pasture", "Extensively used pastures"), 90 | ZueriCropClass(65, "Pasture", "Vegetation", "Grassland", "Pasture", "Forest pastures (without wooded area),"), 91 | ZueriCropClass(66, "Meadow", "Vegetation", "Grassland", "Meadow", "Hay meadows in the summering area, other meadows"), 92 | ZueriCropClass(67, "Meadow", "Vegetation", "Grassland", "Meadow", "Hay meadows in the summering area, type extensively used meadow"), 93 | ZueriCropClass(68, "Pasture", "Vegetation", "Grassland", "Pasture", "Forest pastures (without wooded area),"), 94 | ZueriCropClass(69, "Legumes", "Vegetation", "Field crops", "BroadLeafRowCrop", "fodder legumes (Fabaceae), for seed production (by contract),"), 95 | ZueriCropClass(70, "Unknown", "Vegetation", "Grassland", "Unknown", "fodder grasses for seed production (by contract),"), 96 | ZueriCropClass(71, "Meadow", "Vegetation", "Grassland", "Meadow", "Riverside meadows along rivers (without pastures),"), 97 | ZueriCropClass(72, "Unknown", "Vegetation", "Grassland", "BiodiversityArea", "Other green area (permanent green area),, entitled to contributions"), 98 | ZueriCropClass(73, "Unknown", "Vegetation", "Grassland", "Unknown", "Remaining green area (permanent green areas),, not eligible"), 99 | ZueriCropClass(74, "Vines", "Vegetation", "Orchards", "OrchardCrop", "vines"), 100 | ZueriCropClass(75, "Apples", "Vegetation", "Orchards", "OrchardCrop", "Fruit plants (apples),"), 101 | ZueriCropClass(76, "Pears", "Vegetation", "Orchards", "OrchardCrop", "Fruit plants (pears),"), 102 | ZueriCropClass(77, "StoneFruit", "Vegetation", "Orchards", "OrchardCrop", "Fruit plants (Steinobs),"), 103 | ZueriCropClass(78, "Berries", "Vegetation", "Special crops", "Berries", "Perennial berries"), 104 | ZueriCropClass(79, "Unknown", "Vegetation", "Special crops", "Unknown", "Perennial spice and medicinal 
plants"), 105 | ZueriCropClass(80, "Unknown", "Vegetation", "Special crops", "Unknown", "Perennial renewable resources (miscanthus, etc.),"), 106 | ZueriCropClass(81, "Hops", "Vegetation", "Orchards", "OrchardCrop", "hop"), 107 | ZueriCropClass(82, "Unknown", "Vegetation", "Special crops", "Unknown", "rhubarb"), 108 | ZueriCropClass(83, "Unknown", "Vegetation", "Special crops", "Unknown", "asparagus"), 109 | ZueriCropClass(84, "TreeCrop", "Vegetation", "Orchards", "TreeCrop", "Christmas trees"), 110 | ZueriCropClass(85, "TreeCrop", "Vegetation", "Orchards", "TreeCrop", "Nursery of forest plants outside the forest zone"), 111 | ZueriCropClass(86, "Unknown", "Vegetation", "Special crops", "Unknown", "Ornamental shrubs, ornamental shrubs and ornamental shrubs"), 112 | ZueriCropClass(87, "Unknown", "Vegetation", "Special crops", "Unknown", "Other nurseries (roses, fruits, etc.),"), 113 | ZueriCropClass(88, "Vines", "Vegetation", "Orchards", "OrchardCrop", "Vineyards with natural biodiversity"), 114 | ZueriCropClass(89, "Unknown", "Vegetation", "Special crops", "Unknown", "Truffle plants (in production),"), 115 | ZueriCropClass(90, "Unknown", "Vegetation", "Special crops", "Unknown", "Mulberry trees (feeding silkworms),"), 116 | ZueriCropClass(91, "Chestnut", "Vegetation", "Orchards", "OrchardCrop", "Cultivated selven (chestnut trees),"), 117 | ZueriCropClass(92, "Unknown", "Vegetation", "Special crops", "Unknown", "Perennial horticultural outdoor crops (not in the greenhouse),"), 118 | ZueriCropClass(93, "Vines", "Vegetation", "Orchards", "OrchardCrop", "vines nursery"), 119 | ZueriCropClass(94, "Unknown", "Vegetation", "Special crops", "Unknown", "Other fruit plants (kiwis, elderberries, etc.),"), 120 | ZueriCropClass(95, "Vines", "Vegetation", "Orchards", "OrchardCrop", "vines (region-specific biodiversity area),"), 121 | ZueriCropClass(96, "Unknown", "Vegetation", "Special crops", "BiodiversityArea", "Other areas with permanent crops, eligible"), 122 | ZueriCropClass(97, "Unknown", "Vegetation", "Special crops", "Unknown", "Other areas with permanent crops, not eligible"), 123 | ZueriCropClass(98, "Vegetables", "Infrastructure", "Unknown", "Greenhouse", "Vegetable crops in greenhouses with solid foundations"), 124 | ZueriCropClass(99, "Special cultures", "Infrastructure", "Unknown", "Greenhouse", "Other specialized crops in greenhouses with solid foundations"), 125 | ZueriCropClass(100, "Special cultures", "Infrastructure", "Unknown", "Greenhouse", "Horticultural crops in greenhouses with solid foundations"), 126 | ZueriCropClass(101, "Vegetables", "Unknown", "Unknown", "ProtectedCultivation", "Vegetable crops in protected cultivation without firm foundations"), 127 | ZueriCropClass(102, "Special cultures", "Unknown", "Unknown", "ProtectedCultivation", "Other special crops in protected cultivation without firm foundations"), 128 | ZueriCropClass(103, "Special cultures", "Unknown", "Unknown", "ProtectedCultivation", "Horticultural crops in protected cultivation without firm foundations"), 129 | ZueriCropClass(104, "Special cultures", "Unknown", "Unknown", "ProtectedCultivation", "Other crops in protected cultivation without a firm foundation"), 130 | ZueriCropClass(105, "Unknown", "Infrastructure", "Unknown", "Unknown", "other cultures in protected cultivation with solid foundation),"), 131 | ZueriCropClass(106, "Special cultures", "Unknown", "Unknown", "ProtectedCultivation", "Other crops in protected cultivation without firm foundations, not eligible"), 132 | ZueriCropClass(107, "Unknown", 
"Vegetation", "Special crops", "Unknown", "Scattering areas in the LN"), 133 | ZueriCropClass(108, "Hedge", "Vegetation", "Special crops", "Hedge", "Hedge, field and bank shrubs (with herbaceous area),"), 134 | ZueriCropClass(109, "Hedge", "Vegetation", "Special crops", "Hedge", "Hedgerow, field and bank shrubs (with buffer strips),"), 135 | ZueriCropClass(110, "Hedge", "Vegetation", "Special crops", "Hedge", "Hedgerow, field and bank shrubs (with buffer strips), (region-specific biodiversity production area),"), 136 | ZueriCropClass(111, "Multiple", "Undefined", "Unknown", "Unknown", "Other areas within the LN, entitled to contribute"), 137 | ZueriCropClass(112, "Multiple", "Undefined", "Unknown", "Unknown", "Other areas within the LN, not eligible"), 138 | ZueriCropClass(113, "Forest", "Vegetation", "Forest", "Forest", "Forest"), 139 | ZueriCropClass(114, "Multiple", "Vegetation", "Special crops", "Multiple", "Other unproductive areas (eg mulched areas, heavily weedy areas, hedges without buffer strips),"), 140 | ZueriCropClass(115, "Non agriculture", "Undefined", "Unknown", "Undefined", "Areas without main agricultural purpose (developed building land, playground, riding, camping, golf, air and military spaces"), 141 | ZueriCropClass(116, "Waters", "Undefined", "Unknown", "Unknown", "Ditches, ponds, ponds"), 142 | ZueriCropClass(117, "Non agriculture", "Infrastructure", "Unknown", "Unknown", "Ruderal areas, cairns and ramparts"), 143 | ZueriCropClass(118, "Multiple", "Infrastructure", "Unknown", "Unknown", "dry stone walls"), 144 | ZueriCropClass(119, "Unknown", "Bare soil", "Unknown", "Unknown", "non-asphalted, natural paths"), 145 | ZueriCropClass(120, "Biodiversity encouragement area", "Vegetation", "Grassland", "BiodiversityArea", "Region-specific biodiversity promotion areas"), 146 | ZueriCropClass(121, "Gardens", "Vegetation", "Special crops", "Gardens", "home gardens"), 147 | ZueriCropClass(122, "Unknown", "Infrastructure", "Unknown", "Unknown", "agricultural production in buildings (e.g. champignons, brussel sprouts),"), 148 | ZueriCropClass(123, "Pasture", "Vegetation", "Grassland", "Pasture", "Summer pastures"), 149 | ZueriCropClass(124, "Non agriculture", "Undefined", "Unknown", "Undefined", "Other areas outside the LN and SF") 150 | ] 151 | 152 | def __init__( 153 | self, 154 | root: str = ".data/zuericrop", 155 | transform: Compose = Compose([ToTensor()]), 156 | ): 157 | self.transform = transform 158 | self.f = h5py.File(os.path.join(root, "ZueriCrop.hdf5"), "r") 159 | 160 | def __len__(self) -> int: 161 | return self.f["data"].shape[0] 162 | 163 | def __getitem__(self, idx: int) -> Dict: 164 | x = self.f["data"][idx, ...] 165 | mask = self.f["gt"][idx, ...] 166 | instance_mask = self.f["gt_instance"][idx, ...] 
167 | x, mask, instance_mask = self.transform([x, mask, instance_mask]) 168 | return dict(x=x, mask=mask, instance_mask=instance_mask) 169 | -------------------------------------------------------------------------------- /torchrs/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rams import RAMS 2 | from .oscd import EarlyFusion, Siam 3 | from .fc_cd import FCEF, FCSiamConc, FCSiamDiff 4 | 5 | 6 | __all__ = [ 7 | "RAMS", "EarlyFusion", "Siam", "FCEF", "FCSiamConc", "FCSiamDiff" 8 | ] 9 | -------------------------------------------------------------------------------- /torchrs/models/fc_cd.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | 10 | def __init__(self, filters: List[int], kernel_size: int = 3, dropout: float = 0.2, pool: bool = True): 11 | super().__init__() 12 | layers = [] 13 | for i in range(1, len(filters)): 14 | layers.extend([ 15 | nn.Conv2d(filters[i - 1], filters[i], kernel_size, stride=1, padding=kernel_size//2), 16 | nn.BatchNorm2d(filters[i]), 17 | nn.ReLU(), 18 | nn.Dropout(dropout), 19 | ]) 20 | self.model = nn.Sequential(*layers) 21 | self.pool = nn.MaxPool2d(kernel_size=2) if pool else nn.Identity() 22 | 23 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 24 | x = self.model(x) 25 | return self.pool(x), x 26 | 27 | 28 | class DeConvBlock(nn.Sequential): 29 | 30 | def __init__(self, filters: List[int], kernel_size: int = 3, dropout: float = 0.2): 31 | super().__init__( 32 | *[nn.Sequential( 33 | nn.ConvTranspose2d(filters[i - 1], filters[i], kernel_size, padding=kernel_size//2), 34 | nn.BatchNorm2d(filters[i]), 35 | nn.ReLU(), 36 | nn.Dropout(dropout) 37 | ) for i in range(1, len(filters))] 38 | ) 39 | 40 | 41 | class UpsampleBlock(nn.Sequential): 42 | 43 | def __init__(self, channels: int, kernel_size: int = 3): 44 | super().__init__( 45 | nn.ConvTranspose2d(channels, channels, kernel_size, padding=kernel_size//2, stride=2, output_padding=1) 46 | ) 47 | 48 | 49 | class Encoder(nn.ModuleList): 50 | 51 | def __init__(self, in_channels: int = 3): 52 | super().__init__([ 53 | ConvBlock([in_channels, 16, 16]), 54 | ConvBlock([16, 32, 32]), 55 | ConvBlock([32, 64, 64, 64]), 56 | ConvBlock([64, 128, 128, 128]) 57 | ]) 58 | 59 | 60 | class Decoder(nn.ModuleList): 61 | 62 | def __init__(self, num_classes: int = 2): 63 | super().__init__([ 64 | DeConvBlock([256, 128, 128, 64]), 65 | DeConvBlock([128, 64, 64, 32]), 66 | DeConvBlock([64, 32, 16]), 67 | DeConvBlock([32, 16, num_classes]) 68 | ]) 69 | 70 | 71 | class SiamEncoder(nn.ModuleList): 72 | 73 | def __init__(self, in_channels: int = 3): 74 | super().__init__([ 75 | ConvBlock([in_channels, 16, 16]), 76 | ConvBlock([16, 32, 32]), 77 | ConvBlock([32, 64, 64, 64]), 78 | ConvBlock([64, 128, 128, 128], pool=False) 79 | ]) 80 | 81 | 82 | class ConcatDecoder(nn.ModuleList): 83 | 84 | def __init__(self, t: int = 2, num_classes: int = 2): 85 | scale = 0.5 * (t + 1) 86 | super().__init__([ 87 | DeConvBlock([int(256 * scale), 128, 128, 64]), 88 | DeConvBlock([int(128 * scale), 64, 64, 32]), 89 | DeConvBlock([int(64 * scale), 32, 16]), 90 | DeConvBlock([int(32 * scale), 16, num_classes]) 91 | ]) 92 | 93 | 94 | class Upsample(nn.ModuleList): 95 | 96 | def __init__(self): 97 | super().__init__([ 98 | UpsampleBlock(128), 99 | UpsampleBlock(64), 100 | 
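            # (Editor's note, added comment; not in the original file.) Each UpsampleBlock is a
            # stride-2 transposed convolution that doubles spatial resolution without changing
            # the channel count; the four blocks (128, 64, 32 and 16 channels) are consumed one
            # per decoder stage in the forward passes below, mirroring the four encoder stages.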
UpsampleBlock(32), 101 | UpsampleBlock(16) 102 | ]) 103 | 104 | 105 | class FCEF(nn.Module): 106 | """ Fully-convolutional Early Fusion (FC-EF) from 107 | 'Fully Convolutional Siamese Networks for Change Detection', Daudt et al. (2018) 108 | https://arxiv.org/abs/1810.08462 109 | """ 110 | def __init__(self, channels: int = 3, t: int = 2, num_classes: int = 2): 111 | super().__init__() 112 | self.encoder = Encoder(channels * t) 113 | self.decoder = Decoder(num_classes) 114 | self.upsample = Upsample() 115 | 116 | def forward(self, x: torch.Tensor) -> torch.Tensor: 117 | b, t, c, h, w = x.shape 118 | x = rearrange(x, "b t c h w -> b (t c) h w") 119 | 120 | skips = [] 121 | for block in self.encoder: 122 | x, skip = block(x) 123 | skips.append(skip) 124 | 125 | for block, upsample, skip in zip(self.decoder, self.upsample, reversed(skips)): 126 | x = upsample(x) 127 | x = rearrange([x, skip], "t b c h w -> b (t c) h w") 128 | x = block(x) 129 | 130 | return x 131 | 132 | 133 | class FCSiamConc(nn.Module): 134 | """ Fully-convolutional Siamese Concatenation (FC-Siam-conc) from 135 | 'Fully Convolutional Siamese Networks for Change Detection', Daudt et al. (2018) 136 | https://arxiv.org/abs/1810.08462 137 | """ 138 | def __init__(self, channels: int = 3, t: int = 2, num_classes: int = 2): 139 | super().__init__() 140 | self.encoder = SiamEncoder(channels) 141 | self.decoder = ConcatDecoder(t, num_classes) 142 | self.upsample = Upsample() 143 | self.pool = nn.MaxPool2d(kernel_size=2) 144 | 145 | def forward(self, x: torch.Tensor) -> torch.Tensor: 146 | b, t, c, h, w = x.shape 147 | x = rearrange(x, "b t c h w -> (b t) c h w") 148 | 149 | skips = [] 150 | for block in self.encoder: 151 | x, skip = block(x) 152 | skips.append(skip) 153 | 154 | # Concat skips 155 | skips = [rearrange(skip, "(b t) c h w -> b (t c) h w", t=t) for skip in skips] 156 | 157 | # Only first input encoding is passed directly to decoder 158 | x = rearrange(x, "(b t) c h w -> b t c h w", t=t) 159 | x = x[:, 0, ...] 160 | x = self.pool(x) 161 | 162 | for block, upsample, skip in zip(self.decoder, self.upsample, reversed(skips)): 163 | x = upsample(x) 164 | x = torch.cat([x, skip], dim=1) 165 | x = block(x) 166 | 167 | return x 168 | 169 | 170 | class FCSiamDiff(nn.Module): 171 | """ Fully-convolutional Siamese Difference (FC-Siam-diff) from 172 | 'Fully Convolutional Siamese Networks for Change Detection', Daudt et al. (2018) 173 | https://arxiv.org/abs/1810.08462 174 | """ 175 | def __init__(self, channels: int = 3, t: int = 2, num_classes: int = 2): 176 | super().__init__() 177 | self.encoder = SiamEncoder(channels) 178 | self.decoder = Decoder(num_classes) 179 | self.upsample = Upsample() 180 | self.pool = nn.MaxPool2d(kernel_size=2) 181 | 182 | def forward(self, x: torch.Tensor) -> torch.Tensor: 183 | b, t, c, h, w = x.shape 184 | x = rearrange(x, "b t c h w -> (b t) c h w") 185 | 186 | skips = [] 187 | for block in self.encoder: 188 | x, skip = block(x) 189 | skips.append(skip) 190 | 191 | # Diff skips 192 | skips = [rearrange(skip, "(b t) c h w -> b t c h w", t=t) for skip in skips] 193 | diffs = [] 194 | for skip in skips: 195 | diff, xt = skip[:, 0, ...], skip[:, 1:, ...] 196 | for i in range(t - 1): 197 | diff = torch.abs(diff - xt[:, i, ...]) 198 | diffs.append(diff) 199 | 200 | # Only first input encoding is passed directly to decoder 201 | x = rearrange(x, "(b t) c h w -> b t c h w", t=t) 202 | x = x[:, 0, ...] 
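        # (Editor's note, added comments; not in the original file.) Only the deepest encoding
        # of the first image is decoded; the skip connections used below are the absolute
        # differences of the encoder features across time, which is what distinguishes
        # FC-Siam-diff from FC-Siam-conc (feature concatenation) and FC-EF (channel-wise
        # fusion of the raw inputs).
        #
        # A minimal usage sketch (editor's addition; the batch and patch sizes are only
        # illustrative, and H and W must be divisible by 16 to survive the four 2x
        # down/upsampling stages):
        #
        #   model = FCSiamDiff(channels=3, t=2, num_classes=2)
        #   x = torch.randn(4, 2, 3, 32, 32)   # (batch, time, channels, H, W)
        #   out = model(x)                     # (4, 2, 32, 32) per-pixel change logits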
203 | x = self.pool(x) 204 | 205 | for block, upsample, skip in zip(self.decoder, self.upsample, reversed(diffs)): 206 | x = upsample(x) 207 | x = torch.cat([x, skip], dim=1) 208 | x = block(x) 209 | 210 | return x 211 | -------------------------------------------------------------------------------- /torchrs/models/oscd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | 5 | 6 | class ConvBlock(nn.Sequential): 7 | 8 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, dropout: float = 0.2, norm: bool = True): 9 | super().__init__( 10 | nn.Conv2d(in_channels, out_channels, kernel_size), 11 | nn.BatchNorm2d(out_channels) if norm else nn.Identity(), 12 | nn.ReLU(), 13 | nn.Dropout(dropout) 14 | ) 15 | 16 | 17 | class EarlyFusion(nn.Module): 18 | """ Early Fusion (EF) from 'Urban Change Detection for Multispectral Earth Observation 19 | Using Convolutional Neural Networks', Daudt et al. (2018) 20 | https://arxiv.org/abs/1810.08468 21 | 22 | This model takes as input the concatenated image pairs (T*C, 15, 15) 23 | and is essentially a simple CNN classifier of the central pixel in an input patch. 24 | Assumes (T*Cx15x15) patch size input 25 | """ 26 | def __init__(self, channels: int = 3, t: int = 2, num_classes: int = 2): 27 | super().__init__() 28 | filters = [channels * t, 32, 32, 64, 64, 128, 128] 29 | dropout = 0.2 30 | self.encoder = nn.Sequential( 31 | *[ConvBlock(filters[i-1], filters[i]) for i in range(1, len(filters))], 32 | ConvBlock(filters[-1], 128, dropout=0.0, norm=False), 33 | nn.Flatten() 34 | ) 35 | self.mlp = nn.Sequential( 36 | nn.Linear(128, 8), 37 | nn.BatchNorm1d(8), 38 | nn.ReLU(), 39 | nn.Dropout2d(dropout), 40 | nn.Linear(8, num_classes) 41 | ) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | x = rearrange(x, "b t c h w -> b (t c) h w") 45 | x = self.encoder(x) 46 | x = self.mlp(x) 47 | return x 48 | 49 | 50 | class Siam(nn.Module): 51 | """ Siamese (Siam) from 'Urban Change Detection for Multispectral Earth Observation 52 | Using Convolutional Neural Networks', Daudt et al. (2018) 53 | https://arxiv.org/abs/1810.08468 54 | 55 | This model encodes each image of the input pair with a shared CNN encoder, concatenates the features 56 | and is essentially a simple CNN classifier of the central pixel in an input patch.
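    Editor's example (a minimal sketch added for clarity, not part of the original docstring;
    the fixed 15 x 15 patch size follows from the unpadded convolutions and the linear layer
    dimensions below):

        model = Siam(channels=3, t=2, num_classes=2)
        x = torch.randn(8, 2, 3, 15, 15)   # (batch, time, channels, 15, 15)
        logits = model(x)                  # (8, num_classes) central-pixel change logits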
57 | Assumes (T*Cx15x15) patch size input 58 | """ 59 | def __init__(self, channels: int = 3, t: int = 2, num_classes: int = 2): 60 | super().__init__() 61 | filters = [channels, 64, 64, 128] 62 | dropout = 0.2 63 | self.encoder = nn.Sequential( 64 | *[ConvBlock(filters[i-1], filters[i]) for i in range(1, len(filters))], 65 | ConvBlock(filters[-1], 128, dropout=0.0), 66 | ) 67 | self.mlp = nn.Sequential( 68 | nn.Linear(t*128*7*7, 64), 69 | nn.BatchNorm1d(64), 70 | nn.ReLU(), 71 | nn.Dropout2d(dropout), 72 | nn.Linear(64, num_classes) 73 | ) 74 | 75 | def forward(self, x: torch.Tensor) -> torch.Tensor: 76 | b, t, c, h, w = x.shape 77 | x = rearrange(x, "b t c h w -> (b t) c h w") 78 | x = self.encoder(x) 79 | x = rearrange(x, "(b t) c h w -> b (t c h w)", b=b) 80 | x = self.mlp(x) 81 | return x 82 | -------------------------------------------------------------------------------- /torchrs/models/rams.py: -------------------------------------------------------------------------------- 1 | """ Referenced from official TF implementation https://github.com/EscVM/RAMS/blob/master/utils/network.py """ 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | from einops.layers.torch import Rearrange, Reduce 6 | 7 | 8 | class ReflectionPad3d(nn.Module): 9 | """ Custom 3D reflection padding for only h, w dims """ 10 | def __init__(self, padding): 11 | super().__init__() 12 | self.pad = nn.ReflectionPad2d(padding) 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | b, t, c, h, w = x.shape 16 | x = rearrange(x, "b t c h w -> b (t c) h w") 17 | x = self.pad(x) 18 | x = rearrange(x, "b (t c) h w -> b t c h w", t=t, c=c) 19 | return x 20 | 21 | 22 | class TemporalAttention(nn.Module): 23 | """ Temporal Attention Block """ 24 | def __init__(self, channels: int, kernel_size: int, r: int): 25 | super().__init__() 26 | self.model = nn.Sequential( 27 | Reduce("b c h w -> b c () ()", "mean"), 28 | nn.Conv2d(channels, channels//r, kernel_size, stride=1, padding=kernel_size//2), 29 | nn.ReLU(), 30 | nn.Conv2d(channels//r, channels, kernel_size, stride=1, padding=kernel_size//2), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x: torch.Tensor) -> torch.Tensor: 35 | return x * self.model(x) 36 | 37 | 38 | class FeatureAttention(nn.Module): 39 | """ Feature Attention Block """ 40 | def __init__(self, channels: int, kernel_size: int, r: int): 41 | super().__init__() 42 | self.model = nn.Sequential( 43 | Reduce("b c t h w -> b c () () ()", "mean"), 44 | nn.Conv3d(channels, channels//r, kernel_size, stride=1, padding=kernel_size//2), 45 | nn.ReLU(), 46 | nn.Conv3d(channels//r, channels, kernel_size, stride=1, padding=kernel_size//2), 47 | nn.Sigmoid() 48 | ) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | return x * self.model(x) 52 | 53 | 54 | class RTAB(nn.Module): 55 | """ Residual Temporal Attention Block """ 56 | def __init__(self, channels: int, kernel_size: int, r: int): 57 | super().__init__() 58 | self.model = nn.Sequential( 59 | nn.Conv2d(channels, channels, kernel_size, stride=1, padding=kernel_size // 2), 60 | nn.ReLU(), 61 | nn.Conv2d(channels, channels, kernel_size, stride=1, padding=kernel_size // 2), 62 | TemporalAttention(channels, kernel_size, r) 63 | ) 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | return x + self.model(x) 67 | 68 | 69 | class RFAB(nn.Module): 70 | """ Residual Feature Attention Block """ 71 | def __init__(self, channels: int, kernel_size: int, r: int): 72 | super().__init__() 73 | self.model = 
nn.Sequential( 74 | nn.Conv3d(channels, channels, kernel_size, stride=1, padding=kernel_size // 2), 75 | nn.ReLU(), 76 | nn.Conv3d(channels, channels, kernel_size, stride=1, padding=kernel_size // 2), 77 | FeatureAttention(channels, kernel_size, r) 78 | ) 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | return x + self.model(x) 82 | 83 | 84 | class TemporalReductionBlock(nn.Module): 85 | """ Temporal Reduction Block """ 86 | def __init__(self, channels: int, kernel_size: int, r: int): 87 | super().__init__() 88 | self.model = nn.Sequential( 89 | ReflectionPad3d(1), 90 | RFAB(channels, kernel_size, r), 91 | nn.Conv3d(channels, channels, kernel_size, stride=1, padding=0), 92 | nn.ReLU() 93 | ) 94 | 95 | def forward(self, x: torch.Tensor) -> torch.Tensor: 96 | return self.model(x) 97 | 98 | 99 | class RAMS(nn.Module): 100 | """ 101 | Residual Attention Multi-image Super-resolution Network (RAMS) 102 | 'Multi-Image Super Resolution of Remotely Sensed Images Using Residual Attention Deep Neural Networks' 103 | Salvetti et al. (2020) 104 | https://www.mdpi.com/2072-4292/12/14/2207 105 | 106 | Note this model was built to work with t=9 input images and kernel_size=3. Other values may not work. 107 | t must satisfy the constraints of ((t-1)/(kernel_size-1) - 1) % 1 == 0 where kernel_size=3 and t >= 9. 108 | Some valid t's are [9, 11, 13, ...] 109 | """ 110 | def __init__( 111 | self, 112 | scale_factor: int = 3, 113 | t: int = 9, 114 | c: int = 1, 115 | num_feature_attn_blocks: int = 12, 116 | ): 117 | super().__init__() 118 | filters = 32 119 | kernel_size = 3 120 | r = 8 121 | num_temporal_redn_blocks = ((t-1)/(kernel_size-1) - 1) 122 | err = """t must satisfy the ((t-1)/(kernel_size-1) - 1) % 1 == 0 where kernel_size=3 123 | and t >= 9. Some valid t's are [9, 11, 13, 15, ...]
""" 124 | assert num_temporal_redn_blocks % 1 == 0 and t >= 9, err 125 | 126 | self.temporal_attn = nn.Sequential( 127 | Rearrange("b t c h w -> b (t c) h w"), 128 | nn.ReflectionPad2d(1), 129 | RTAB(t * c, kernel_size, r), 130 | ) 131 | self.residual_upsample = nn.Sequential( 132 | nn.Conv2d(t * c, c * scale_factor ** 2, kernel_size, stride=1, padding=0), 133 | nn.PixelShuffle(scale_factor) 134 | ) 135 | self.head = nn.Sequential( 136 | Rearrange("b t c h w -> b c t h w"), 137 | ReflectionPad3d(1), 138 | nn.Conv3d(c, filters, kernel_size, stride=1, padding=kernel_size//2) 139 | ) 140 | self.feature_attn = nn.Sequential( 141 | *[RFAB(filters, kernel_size, r) for _ in range(num_feature_attn_blocks)], 142 | nn.Conv3d(filters, filters, kernel_size, stride=1, padding=kernel_size//2) 143 | ) 144 | self.temporal_redn = nn.Sequential( 145 | *[TemporalReductionBlock(filters, kernel_size, r) for _ in range(int(num_temporal_redn_blocks))] 146 | ) 147 | self.main_upsample = nn.Sequential( 148 | nn.Conv3d(filters, c * scale_factor ** 2, kernel_size, stride=1, padding=0), 149 | Rearrange("b c t h w -> b (c t) h w"), 150 | nn.PixelShuffle(scale_factor), 151 | ) 152 | 153 | def forward(self, x: torch.Tensor) -> torch.Tensor: 154 | # Main branch 155 | h = self.head(x) 156 | feature_attn = h + self.feature_attn(h) 157 | temporal_redn = self.temporal_redn(feature_attn) 158 | temporal_redn = self.main_upsample(temporal_redn) 159 | 160 | # Global residual branch 161 | temporal_attn = self.temporal_attn(x) 162 | temporal_attn = self.residual_upsample(temporal_attn) 163 | 164 | return temporal_attn + temporal_redn 165 | -------------------------------------------------------------------------------- /torchrs/train/__init__.py: -------------------------------------------------------------------------------- 1 | from . import datamodules 2 | from . 
import modules 3 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataModule 2 | from .probav import PROBAVDataModule 3 | from .etci2021 import ETCI2021DataModule 4 | from .rsvqa import RSVQALRDataModule, RSVQAHRDataModule, RSVQAxBENDataModule 5 | from .eurosat import EuroSATRGBDataModule, EuroSATMSDataModule 6 | from .resisc45 import RESISC45DataModule 7 | from .rsicd import RSICDDataModule 8 | from .oscd import OSCDDataModule 9 | from .s2looking import S2LookingDataModule 10 | from .levircd import LEVIRCDPlusDataModule 11 | from .fair1m import FAIR1MDataModule 12 | from .sydney_captions import SydneyCaptionsDataModule 13 | from .ucm_captions import UCMCaptionsDataModule 14 | from .s2mtcp import S2MTCPDataModule 15 | from .advance import ADVANCEDataModule 16 | from .sat import SAT4DataModule, SAT6DataModule 17 | from .hrscd import HRSCDDataModule 18 | from .inria_ail import InriaAILDataModule 19 | from .tiselac import TiselacDataModule 20 | from .gid15 import GID15DataModule 21 | from .zuericrop import ZueriCropDataModule 22 | from .aid import AIDDataModule 23 | from .dubai_segmentation import DubaiSegmentationDataModule 24 | from .hkh_glacier import HKHGlacierMappingDataModule 25 | from .ucm import UCMDataModule 26 | from .patternnet import PatternNetDataModule 27 | from .whu_rs19 import WHURS19DataModule 28 | from .rsscn7 import RSSCN7DataModule 29 | from .brazilian_coffee import BrazilianCoffeeScenesDataModule 30 | 31 | 32 | __all__ = [ 33 | "BaseDataModule", "PROBAVDataModule", "ETCI2021DataModule", "RSVQALRDataModule", 34 | "RSVQAxBENDataModule", "EuroSATRGBDataModule", "EuroSATMSDataModule", "RESISC45DataModule", 35 | "RSICDDataModule", "OSCDDataModule", "S2LookingDataModule", "LEVIRCDPlusDataModule", 36 | "FAIR1MDataModule", "SydneyCaptionsDataModule", "UCMCaptionsDataModule", "S2MTCPDataModule", 37 | "ADVANCEDataModule", "SAT4DataModule", "SAT6DataModule", "HRSCDDataModule", "InriaAILDataModule", 38 | "TiselacDataModule", "GID15DataModule", "ZueriCropDataModule", "AIDDataModule", 39 | "DubaiSegmentationDataModule", "HKHGlacierMappingDataModule", "UCMDataModule", "PatternNetDataModule", 40 | "RSVQAHRDataModule", "WHURS19DataModule", "RSSCN7DataModule", "BrazilianCoffeeScenesDataModule" 41 | ] 42 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/advance.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import ADVANCE 8 | 9 | 10 | class ADVANCEDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/advance", 15 | image_transform: T.Compose = T.Compose([T.ToTensor()]), 16 | audio_transform: T.Compose = T.Compose([]), 17 | *args, **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | self.root = root 21 | self.image_transform = image_transform 22 | self.audio_transform = audio_transform 23 | 24 | def setup(self, stage: Optional[str] = None): 25 | dataset = ADVANCE( 26 | root=self.root, image_transform=self.image_transform, audio_transform=self.audio_transform 27 | ) 28 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 29 | dataset, 
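            # (Editor's note, added comment; not in the original file.) dataset_split, defined in
            # torchrs/datasets/utils.py above, wraps torch's random_split: with the BaseDataModule
            # defaults of val_split=0.1 and test_split=0.25, a 1,000-sample dataset would be split
            # into roughly 650 train / 100 val / 250 test samples.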
val_pct=self.val_split, test_pct=self.test_split 30 | ) 31 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/aid.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import AID 8 | 9 | 10 | class AIDDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/AID", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = AID(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/base.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable 2 | 3 | import pytorch_lightning as pl 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class BaseDataModule(pl.LightningDataModule): 8 | 9 | def __init__( 10 | self, 11 | batch_size: int = 1, 12 | num_workers: int = 0, 13 | prefetch_factor: int = 2, 14 | pin_memory: bool = False, 15 | collate_fn: Optional[Callable] = None, 16 | test_collate_fn: Optional[Callable] = None, 17 | val_split: float = 0.1, 18 | test_split: float = 0.25 19 | ): 20 | super().__init__() 21 | self.batch_size = batch_size 22 | self.num_workers = num_workers 23 | self.prefetch_factor = prefetch_factor 24 | self.pin_memory = pin_memory 25 | self.collate_fn = collate_fn 26 | self.test_collate_fn = test_collate_fn 27 | self.val_split = val_split 28 | self.test_split = test_split 29 | 30 | def train_dataloader(self) -> DataLoader: 31 | return DataLoader( 32 | self.train_dataset, 33 | shuffle=True, 34 | batch_size=self.batch_size, 35 | num_workers=self.num_workers, 36 | prefetch_factor=self.prefetch_factor, 37 | pin_memory=self.pin_memory, 38 | collate_fn=self.collate_fn 39 | ) 40 | 41 | def val_dataloader(self) -> DataLoader: 42 | return DataLoader( 43 | self.val_dataset, 44 | batch_size=self.batch_size, 45 | num_workers=self.num_workers, 46 | prefetch_factor=self.prefetch_factor, 47 | pin_memory=self.pin_memory, 48 | collate_fn=self.test_collate_fn 49 | ) 50 | 51 | def test_dataloader(self) -> DataLoader: 52 | return DataLoader( 53 | self.test_dataset, 54 | batch_size=self.batch_size, 55 | num_workers=self.num_workers, 56 | prefetch_factor=self.prefetch_factor, 57 | pin_memory=self.pin_memory, 58 | collate_fn=self.test_collate_fn 59 | ) 60 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/brazilian_coffee.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import BrazilianCoffeeScenes 8 | 9 | 10 | class BrazilianCoffeeScenesDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/brazilian_coffee_scenes", 15 |
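            # (Editor's sketch, added; not part of the original file.) All of the datamodules in
            # this package share the BaseDataModule interface above, so usage is uniform, e.g.:
            #
            #   import pytorch_lightning as pl
            #   dm = BrazilianCoffeeScenesDataModule(batch_size=32, num_workers=4)
            #   dm.setup()
            #   batch = next(iter(dm.train_dataloader()))
            #   # or pass it to a Trainer along with a LightningModule (assumed to exist here):
            #   # pl.Trainer(max_epochs=10).fit(model, datamodule=dm)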
transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = BrazilianCoffeeScenes(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/dubai_segmentation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import DubaiSegmentation 7 | 8 | 9 | class DubaiSegmentationDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/dubai-segmentation", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | dataset = DubaiSegmentation(root=self.root, transform=self.transform) 23 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 24 | dataset, val_pct=self.val_split, test_pct=self.test_split 25 | ) 26 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/etci2021.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.train.datamodules import BaseDataModule 5 | from torchrs.datasets import ETCI2021 6 | 7 | 8 | class ETCI2021DataModule(BaseDataModule): 9 | 10 | def __init__( 11 | self, 12 | root: str = ".data/etci2021", 13 | transform: Compose = Compose([ToTensor()]), 14 | *args, **kwargs 15 | ): 16 | super().__init__(*args, **kwargs) 17 | self.root = root 18 | self.transform = transform 19 | 20 | def setup(self, stage: Optional[str] = None): 21 | self.train_dataset = ETCI2021(root=self.root, split="train", transform=self.transform) 22 | self.val_dataset = ETCI2021(root=self.root, split="val", transform=self.transform) 23 | self.test_dataset = ETCI2021(root=self.root, split="test", transform=self.transform) 24 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/eurosat.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.transforms import ToTensor 6 | from torchrs.datasets.utils import dataset_split 7 | from torchrs.train.datamodules import BaseDataModule 8 | from torchrs.datasets import EuroSATRGB, EuroSATMS 9 | 10 | 11 | class EuroSATRGBDataModule(BaseDataModule): 12 | 13 | def __init__( 14 | self, 15 | root: str = ".data/eurosat-rgb", 16 | transform: T.Compose = T.Compose([T.ToTensor()]), 17 | *args, **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | self.root = root 21 | self.transform = transform 22 | 23 | def setup(self, stage: Optional[str] = None): 24 | dataset = EuroSATRGB(root=self.root, transform=self.transform) 25 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 26 | dataset, 
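            # (Editor's note, added comment; not in the original file.) This RGB datamodule uses
            # torchvision's T.ToTensor for the 3-band JPEG version of EuroSAT, while the
            # EuroSATMSDataModule below defaults to torchrs.transforms.ToTensor, presumably because
            # the 13-band multispectral GeoTIFFs are outside what torchvision's transform handles.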
val_pct=self.val_split, test_pct=self.test_split 27 | ) 28 | 29 | 30 | class EuroSATMSDataModule(BaseDataModule): 31 | 32 | def __init__( 33 | self, 34 | root: str = ".data/eurosat-ms", 35 | transform: T.Compose = T.Compose([ToTensor()]), 36 | *args, **kwargs 37 | ): 38 | super().__init__(*args, **kwargs) 39 | self.root = root 40 | self.transform = transform 41 | 42 | def setup(self, stage: Optional[str] = None): 43 | dataset = EuroSATMS(root=self.root, transform=self.transform) 44 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 45 | dataset, val_pct=self.val_split, test_pct=self.test_split 46 | ) 47 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/fair1m.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import FAIR1M 8 | 9 | 10 | class FAIR1MDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/fair1m", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = FAIR1M(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/gid15.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.train.datamodules import BaseDataModule 5 | from torchrs.datasets import GID15 6 | 7 | 8 | class GID15DataModule(BaseDataModule): 9 | 10 | def __init__( 11 | self, 12 | root: str = ".data/gid-15", 13 | transform: Compose = Compose([ToTensor()]), 14 | *args, **kwargs 15 | ): 16 | super().__init__(*args, **kwargs) 17 | self.root = root 18 | self.transform = transform 19 | 20 | def setup(self, stage: Optional[str] = None): 21 | self.train_dataset = GID15(root=self.root, split="train", transform=self.transform) 22 | self.val_dataset = GID15(root=self.root, split="val", transform=self.transform) 23 | self.test_dataset = GID15(root=self.root, split="test", transform=self.transform) 24 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/hkh_glacier.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import HKHGlacierMapping 7 | 8 | 9 | class HKHGlacierMappingDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/hkh_glacier_mapping", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | dataset = HKHGlacierMapping(root=self.root, transform=self.transform) 23 | 
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 24 | dataset, val_pct=self.val_split, test_pct=self.test_split 25 | ) 26 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/hrscd.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import HRSCD 7 | 8 | 9 | class HRSCDDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/HRSCD", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | dataset = HRSCD(root=self.root, transform=self.transform) 23 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 24 | dataset, val_pct=self.val_split, test_pct=self.test_split 25 | ) 26 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/inria_ail.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import InriaAIL 7 | 8 | 9 | class InriaAILDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/AerialImageDataset", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | train_dataset = InriaAIL(root=self.root, split="train", transform=self.transform) 23 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 24 | self.test_dataset = InriaAIL(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/levircd.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import LEVIRCDPlus 7 | 8 | 9 | class LEVIRCDPlusDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/levircd_plus", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | train_dataset = LEVIRCDPlus(root=self.root, split="train", transform=self.transform) 23 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 24 | self.test_dataset = LEVIRCDPlus(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/oscd.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 
| from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import OSCD 7 | 8 | 9 | class OSCDDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/oscd", 14 | transform: Compose = Compose([ToTensor(permute_dims=False)]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | train_dataset = OSCD(root=self.root, split="train", transform=self.transform) 23 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 24 | self.test_dataset = OSCD(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/patternnet.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import PatternNet 8 | 9 | 10 | class PatternNetDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/PatternNet", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = PatternNet(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/probav.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torchvision.transforms as T 5 | from einops import rearrange 6 | 7 | from torchrs.datasets.utils import dataset_split 8 | from torchrs.train.datamodules import BaseDataModule 9 | from torchrs.transforms import ToTensor, ToDtype 10 | from torchrs.datasets import PROBAV 11 | 12 | 13 | class PROBAVDataModule(BaseDataModule): 14 | 15 | def __init__( 16 | self, 17 | root: str = ".data/probav", 18 | band: str = "RED", 19 | lr_transform: T.Compose = T.Compose([ToTensor(), ToDtype(torch.float32)]), 20 | hr_transform: T.Compose = T.Compose([ToTensor(), ToDtype(torch.float32)]), 21 | *args, **kwargs 22 | ): 23 | super().__init__(*args, **kwargs) 24 | self.root = root 25 | self.band = band 26 | self.lr_transform = lr_transform 27 | self.hr_transform = hr_transform 28 | 29 | def setup(self, stage: Optional[str] = None): 30 | train_dataset = PROBAV( 31 | root=self.root, split="train", band=self.band, lr_transform=self.lr_transform, hr_transform=self.hr_transform 32 | ) 33 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 34 | self.test_dataset = PROBAV( 35 | root=self.root, split="test", band=self.band, lr_transform=self.lr_transform, hr_transform=self.hr_transform 36 | ) 37 | 38 | def on_before_batch_transfer(self, batch, dataloader_idx): 39 | """ Handle if lr and hr are chipped by ExtractChips transform """ 40 | if batch["lr"].ndim == 6: 41 | batch["lr"] = rearrange(batch["lr"], "b t 
d c h w -> (b d) t c h w") 42 | if batch["hr"].ndim == 5: 43 | batch["hr"] = rearrange(batch["hr"], "b d c h w -> (b d) c h w") 44 | return batch 45 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/resisc45.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import RESISC45 8 | 9 | 10 | class RESISC45DataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/NWPU-RESISC45", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = RESISC45(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/rsicd.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import RSICD 7 | 8 | 9 | class RSICDDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/rsicd", 14 | transform: T.Compose = T.Compose([T.ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | self.train_dataset = RSICD(root=self.root, split="train", transform=self.transform) 23 | self.val_dataset = RSICD(root=self.root, split="val", transform=self.transform) 24 | self.test_dataset = RSICD(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/rsscn7.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import RSSCN7 8 | 9 | 10 | class RSSCN7DataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/RSSCN7", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = RSSCN7(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/rsvqa.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.train.datamodules import BaseDataModule 4 | from torchrs.datasets import RSVQALR, RSVQAHR, RSVQAxBEN 5 | from torchrs.transforms import Compose, 
ToTensor 6 | 7 | 8 | class RSVQALRDataModule(BaseDataModule): 9 | 10 | def __init__( 11 | self, 12 | root: str = ".data/RSVQA_LR", 13 | image_transform: Compose = Compose([ToTensor()]), 14 | text_transform: Compose = Compose([]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.image_transform = image_transform 20 | self.text_transform = text_transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | self.train_dataset = RSVQALR(root=self.root, split="train", image_transform=self.image_transform, text_transform=self.text_transform) 24 | self.val_dataset = RSVQALR(root=self.root, split="val", image_transform=self.image_transform, text_transform=self.text_transform) 25 | self.test_dataset = RSVQALR(root=self.root, split="test", image_transform=self.image_transform, text_transform=self.text_transform) 26 | 27 | 28 | class RSVQAHRDataModule(BaseDataModule): 29 | 30 | def __init__( 31 | self, 32 | root: str = ".data/RSVQA_HR", 33 | image_transform: Compose = Compose([ToTensor()]), 34 | text_transform: Compose = Compose([]), 35 | *args, **kwargs 36 | ): 37 | super().__init__(*args, **kwargs) 38 | self.root = root 39 | self.image_transform = image_transform 40 | self.text_transform = text_transform 41 | 42 | def setup(self, stage: Optional[str] = None): 43 | self.train_dataset = RSVQAHR(root=self.root, split="train", image_transform=self.image_transform, text_transform=self.text_transform) 44 | self.val_dataset = RSVQAHR(root=self.root, split="val", image_transform=self.image_transform, text_transform=self.text_transform) 45 | self.test_dataset = RSVQAHR(root=self.root, split="test", image_transform=self.image_transform, text_transform=self.text_transform) 46 | 47 | 48 | class RSVQAxBENDataModule(BaseDataModule): 49 | 50 | def __init__( 51 | self, 52 | root: str = ".data/rsvqaxben", 53 | image_transform: Compose = Compose([ToTensor()]), 54 | text_transform: Compose = Compose([]), 55 | *args, **kwargs 56 | ): 57 | super().__init__(*args, **kwargs) 58 | self.root = root 59 | self.image_transform = image_transform 60 | self.text_transform = text_transform 61 | 62 | def setup(self, stage: Optional[str] = None): 63 | self.train_dataset = RSVQAxBEN(root=self.root, split="train", image_transform=self.image_transform, text_transform=self.text_transform) 64 | self.val_dataset = RSVQAxBEN(root=self.root, split="val", image_transform=self.image_transform, text_transform=self.text_transform) 65 | self.test_dataset = RSVQAxBEN(root=self.root, split="test", image_transform=self.image_transform, text_transform=self.text_transform) 66 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/s2looking.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.train.datamodules import BaseDataModule 4 | from torchrs.datasets import S2Looking 5 | from torchrs.transforms import Compose, ToTensor 6 | 7 | 8 | class S2LookingDataModule(BaseDataModule): 9 | 10 | def __init__( 11 | self, 12 | root: str = ".data/s2looking", 13 | transform: Compose = Compose([ToTensor()]), 14 | *args, **kwargs 15 | ): 16 | super().__init__(*args, **kwargs) 17 | self.root = root 18 | self.transform = transform 19 | 20 | def setup(self, stage: Optional[str] = None): 21 | self.train_dataset = S2Looking(root=self.root, split="train", transform=self.transform) 22 | self.val_dataset = S2Looking(root=self.root, split="val", transform=self.transform) 23 | 
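        # (Editor's note, added comment; not in the original file.) Datasets that ship official
        # splits (e.g. ETCI2021, GID15, RSICD, the RSVQA variants above, S2Looking) are
        # instantiated once per split like this, whereas datasets without predefined splits are
        # divided with dataset_split using the BaseDataModule val_split/test_split percentages.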
self.test_dataset = S2Looking(root=self.root, split="test", transform=self.transform) 24 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/s2mtcp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import S2MTCP 7 | 8 | 9 | class S2MTCPDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/s2mtcp", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | dataset = S2MTCP(root=self.root, transform=self.transform) 23 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 24 | dataset, val_pct=self.val_split, test_pct=self.test_split 25 | ) 26 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/sat.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import SAT4, SAT6 8 | 9 | 10 | class SAT4DataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/sat/sat4.h5", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | train_dataset = SAT4(root=self.root, split="train", transform=self.transform) 24 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 25 | self.test_dataset = SAT4(root=self.root, split="test", transform=self.transform) 26 | 27 | 28 | class SAT6DataModule(BaseDataModule): 29 | 30 | def __init__( 31 | self, 32 | root: str = ".data/sat/sat6.h5", 33 | transform: T.Compose = T.Compose([T.ToTensor()]), 34 | *args, **kwargs 35 | ): 36 | super().__init__(*args, **kwargs) 37 | self.root = root 38 | self.transform = transform 39 | 40 | def setup(self, stage: Optional[str] = None): 41 | train_dataset = SAT6(root=self.root, split="train", transform=self.transform) 42 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 43 | self.test_dataset = SAT6(root=self.root, split="test", transform=self.transform) 44 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/sydney_captions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import SydneyCaptions 7 | 8 | 9 | class SydneyCaptionsDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/sydney_captions", 14 | transform: T.Compose = T.Compose([T.ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | 
self.train_dataset = SydneyCaptions(root=self.root, split="train", transform=self.transform) 23 | self.val_dataset = SydneyCaptions(root=self.root, split="val", transform=self.transform) 24 | self.test_dataset = SydneyCaptions(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/tiselac.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import Tiselac 7 | 8 | 9 | class TiselacDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/tiselac", 14 | transform: Compose = Compose([ToTensor(permute_dims=False)]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | train_dataset = Tiselac(root=self.root, split="train", transform=self.transform) 23 | self.train_dataset, self.val_dataset = dataset_split(train_dataset, val_pct=self.val_split) 24 | self.test_dataset = Tiselac(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/ucm.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import UCM 8 | 9 | 10 | class UCMDataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/UCMerced_LandUse", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = UCM(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/ucm_captions.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import UCMCaptions 7 | 8 | 9 | class UCMCaptionsDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/ucm_captions", 14 | transform: T.Compose = T.Compose([T.ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | self.train_dataset = UCMCaptions(root=self.root, split="train", transform=self.transform) 23 | self.val_dataset = UCMCaptions(root=self.root, split="val", transform=self.transform) 24 | self.test_dataset = UCMCaptions(root=self.root, split="test", transform=self.transform) 25 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/whu_rs19.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torchvision.transforms as T 4 | 5 | from torchrs.datasets.utils import dataset_split 6 | from torchrs.train.datamodules import BaseDataModule 7 | from torchrs.datasets import WHURS19 8 | 9 | 10 | class WHURS19DataModule(BaseDataModule): 11 | 12 | def __init__( 13 | self, 14 | root: str = ".data/WHU-RS19", 15 | transform: T.Compose = T.Compose([T.ToTensor()]), 16 | *args, **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | self.root = root 20 | self.transform = transform 21 | 22 | def setup(self, stage: Optional[str] = None): 23 | dataset = WHURS19(root=self.root, transform=self.transform) 24 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 25 | dataset, val_pct=self.val_split, test_pct=self.test_split 26 | ) 27 | -------------------------------------------------------------------------------- /torchrs/train/datamodules/zuericrop.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torchrs.transforms import Compose, ToTensor 4 | from torchrs.datasets.utils import dataset_split 5 | from torchrs.train.datamodules import BaseDataModule 6 | from torchrs.datasets import ZueriCrop 7 | 8 | 9 | class ZueriCropDataModule(BaseDataModule): 10 | 11 | def __init__( 12 | self, 13 | root: str = ".data/ZueriCrop", 14 | transform: Compose = Compose([ToTensor()]), 15 | *args, **kwargs 16 | ): 17 | super().__init__(*args, **kwargs) 18 | self.root = root 19 | self.transform = transform 20 | 21 | def setup(self, stage: Optional[str] = None): 22 | dataset = ZueriCrop(root=self.root, transform=self.transform) 23 | self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( 24 | dataset, val_pct=self.val_split, test_pct=self.test_split 25 | ) 26 | -------------------------------------------------------------------------------- /torchrs/train/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .rams import RAMSModule 2 | from .fc_cd import FCEFModule, FCSiamConcModule, FCSiamDiffModule 3 | 4 | __all__ = ["RAMSModule", "FCEFModule", "FCSiamConcModule", "FCSiamDiffModule"] 5 | -------------------------------------------------------------------------------- /torchrs/train/modules/fc_cd.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchmetrics 6 | import pytorch_lightning as pl 7 | 8 | from torchrs.models import FCEF, FCSiamDiff, FCSiamConc 9 | 10 | 11 | class BaseFCCDModule(pl.LightningModule): 12 | 13 | def __init__( 14 | self, 15 | channels: int = 3, 16 | t: int = 2, 17 | num_classes: int = 2, 18 | loss_fn: nn.Module = nn.CrossEntropyLoss(), 19 | opt: torch.optim.Optimizer = torch.optim.Adam, 20 | lr: float = 3E-4 21 | ): 22 | super().__init__() 23 | self.loss_fn = loss_fn 24 | self.opt = opt 25 | self.lr = lr 26 | 27 | metrics = torchmetrics.MetricCollection([ 28 | torchmetrics.Accuracy(threshold=0.5, num_classes=num_classes, average="micro", mdmc_average="global"), 29 | torchmetrics.Precision(num_classes=num_classes, threshold=0.5, average="micro", mdmc_average="global"), 30 | torchmetrics.Recall(num_classes=num_classes, threshold=0.5, average="micro", mdmc_average="global"), 31 | torchmetrics.F1(num_classes=num_classes, threshold=0.5, average="micro", mdmc_average="global"), 32 | 
torchmetrics.IoU(threshold=0.5, num_classes=num_classes), 33 | ]) 34 | self.train_metrics = metrics.clone(prefix='train_') 35 | self.val_metrics = metrics.clone(prefix='val_') 36 | self.test_metrics = metrics.clone(prefix='test_') 37 | 38 | def configure_optimizers(self): 39 | return self.opt(self.parameters(), lr=self.lr) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return self.model(x) 43 | 44 | def training_step(self, batch: Dict, batch_idx: int): 45 | x, y = batch 46 | y_pred = self(x) 47 | loss = self.loss_fn(y_pred, y) 48 | metrics = self.train_metrics(y_pred.softmax(dim=1), y) 49 | metrics["train_loss"] = loss 50 | self.log_dict(metrics) 51 | return loss 52 | 53 | def validation_step(self, batch: Dict, batch_idx: int): 54 | x, y = batch 55 | y_pred = self(x) 56 | loss = self.loss_fn(y_pred, y) 57 | metrics = self.val_metrics(y_pred.softmax(dim=1), y) 58 | metrics["val_loss"] = loss 59 | self.log_dict(metrics) 60 | 61 | def test_step(self, batch: Dict, batch_idx: int): 62 | x, y = batch 63 | y_pred = self(x) 64 | loss = self.loss_fn(y_pred, y) 65 | metrics = self.test_metrics(y_pred.softmax(dim=1), y) 66 | metrics["test_loss"] = loss 67 | self.log_dict(metrics) 68 | 69 | 70 | class FCEFModule(BaseFCCDModule): 71 | 72 | def __init__( 73 | self, 74 | channels: int = 3, 75 | t: int = 2, 76 | num_classes: int = 2, 77 | *args, **kwargs 78 | ): 79 | super().__init__(channels, t, num_classes, *args, **kwargs) 80 | self.model = FCEF(channels, t, num_classes) 81 | 82 | 83 | class FCSiamConcModule(BaseFCCDModule): 84 | 85 | def __init__( 86 | self, 87 | channels: int = 3, 88 | t: int = 2, 89 | num_classes: int = 2, 90 | *args, **kwargs 91 | ): 92 | super().__init__(channels, t, num_classes, *args, **kwargs) 93 | self.model = FCSiamConc(channels, t, num_classes) 94 | 95 | 96 | class FCSiamDiffModule(BaseFCCDModule): 97 | 98 | def __init__( 99 | self, 100 | channels: int = 3, 101 | t: int = 2, 102 | num_classes: int = 2, 103 | *args, **kwargs 104 | ): 105 | super().__init__(channels, t, num_classes, *args, **kwargs) 106 | self.model = FCSiamDiff(channels, t, num_classes) 107 | -------------------------------------------------------------------------------- /torchrs/train/modules/rams.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchmetrics 6 | import pytorch_lightning as pl 7 | 8 | from torchrs.models import RAMS 9 | 10 | 11 | class RAMSModule(pl.LightningModule): 12 | 13 | def __init__( 14 | self, 15 | scale_factor: int = 3, 16 | t: int = 9, 17 | c: int = 1, 18 | num_feature_attn_blocks: int = 12, 19 | loss_fn: nn.Module = nn.MSELoss(), 20 | opt: torch.optim.Optimizer = torch.optim.Adam, 21 | lr: float = 3E-4 22 | ): 23 | super(RAMSModule, self).__init__() 24 | self.loss_fn = loss_fn 25 | self.opt = opt 26 | self.lr = lr 27 | self.model = RAMS(scale_factor, t, c, num_feature_attn_blocks) 28 | 29 | metrics = torchmetrics.MetricCollection([ 30 | torchmetrics.MeanSquaredError(), 31 | torchmetrics.MeanAbsoluteError(), 32 | torchmetrics.MeanAbsolutePercentageError(), 33 | torchmetrics.PSNR(), 34 | torchmetrics.SSIM() 35 | ]) 36 | self.train_metrics = metrics.clone(prefix='train_') 37 | self.val_metrics = metrics.clone(prefix='val_') 38 | self.test_metrics = metrics.clone(prefix='test_') 39 | 40 | def configure_optimizers(self): 41 | return self.opt(self.parameters(), lr=self.lr) 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | return 
self.model(x) 45 | 46 | def training_step(self, batch: Dict, batch_idx: int): 47 | lr, hr = batch["lr"], batch["hr"] 48 | sr = self(lr) 49 | loss = self.loss_fn(sr, hr) 50 | metrics = self.train_metrics(sr.to(torch.float32), hr) 51 | metrics["train_loss"] = loss 52 | self.log_dict(metrics) 53 | return loss 54 | 55 | def validation_step(self, batch: Dict, batch_idx: int): 56 | lr, hr = batch["lr"], batch["hr"] 57 | sr = self(lr) 58 | loss = self.loss_fn(sr, hr) 59 | metrics = self.val_metrics(sr.to(torch.float32), hr) 60 | metrics["val_loss"] = loss 61 | self.log_dict(metrics) 62 | 63 | def test_step(self, batch: Dict, batch_idx: int): 64 | lr, hr = batch["lr"], batch["hr"] 65 | sr = self(lr) 66 | loss = self.loss_fn(sr, hr) 67 | metrics = self.test_metrics(sr.to(torch.float32), hr) 68 | metrics["test_loss"] = loss 69 | self.log_dict(metrics) 70 | -------------------------------------------------------------------------------- /torchrs/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Any, Sequence, Callable, Tuple, List 2 | 3 | import torch 4 | import einops 5 | import numpy as np 6 | import torchvision.transforms as T 7 | 8 | 9 | __all__ = ["Compose", "ToTensor", "ToDtype", "ExtractChips", "MinMaxNormalize", "Clip"] 10 | 11 | 12 | class Compose(T.Compose): 13 | """ Custom Compose which processes a list of inputs """ 14 | def __init__(self, transforms: Sequence[Callable]): 15 | self.transforms = transforms 16 | 17 | def __call__(self, x: Union[Any, Sequence]): 18 | if isinstance(x, Sequence): 19 | for t in self.transforms: 20 | x = [t(i) for i in x] 21 | else: 22 | for t in self.transforms: 23 | x = t(x) 24 | return x 25 | 26 | 27 | class ToTensor(object): 28 | """ Custom ToTensor op which doesn't perform min-max normalization """ 29 | def __init__(self, permute_dims: bool = True): 30 | self.permute_dims = permute_dims 31 | 32 | def __call__(self, x: np.ndarray) -> torch.Tensor: 33 | 34 | if x.dtype == "uint16": 35 | x = x.astype("int32") 36 | 37 | if isinstance(x, np.ndarray): 38 | x = torch.from_numpy(x) 39 | 40 | if x.ndim == 2: 41 | if self.permute_dims: 42 | x = x[:, :, None] 43 | else: 44 | x = x[None, :, :] 45 | 46 | # Convert HWC->CHW 47 | if self.permute_dims: 48 | if x.ndim == 4: 49 | x = x.permute((0, 3, 1, 2)).contiguous() 50 | else: 51 | x = x.permute((2, 0, 1)).contiguous() 52 | 53 | return x 54 | 55 | 56 | class ToDtype(object): 57 | """ Convert input tensor to specified dtype """ 58 | def __init__(self, dtype: torch.dtype): 59 | self.dtype = dtype 60 | 61 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 62 | return x.to(self.dtype) 63 | 64 | 65 | class ExtractChips(object): 66 | """ Convert a tensor or ndarray into patches """ 67 | def __init__(self, shape: Tuple[int, int]): 68 | self.h, self.w = shape 69 | 70 | def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]: 71 | return einops.rearrange(x, "c (h p1) (w p2) -> (h w) c p1 p2", p1=self.h, p2=self.w) 72 | 73 | 74 | class MinMaxNormalize(object): 75 | """ Normalize channels to the range [0, 1] using min/max values """ 76 | def __init__(self, min: List[float], max: List[float]): 77 | self.min = torch.tensor(min)[:, None, None] 78 | self.max = torch.tensor(max)[:, None, None] 79 | self.denominator = (self.max - self.min) 80 | 81 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 82 | return (x - self.min) / self.denominator 83 | 84 | 85 | class Clip(object): 86 | """ Clip channels to the range [min,
max] """ 87 | def __init__(self, min: List[float], max: List[float]): 88 | self.min = torch.tensor(min)[:, None, None] 89 | self.max = torch.tensor(max)[:, None, None] 90 | 91 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 92 | x = torch.where(x < self.min, self.min, x) 93 | x = torch.where(x > self.max, self.max, x) 94 | return x 95 | --------------------------------------------------------------------------------