├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── levircd-results.png
│   └── whucd-results.png
├── pyproject.toml
├── requirements.txt
├── scripts
│   ├── preprocess_levircd.ipynb
│   └── preprocess_whucd.ipynb
├── src
│   ├── __init__.py
│   ├── change_detection.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── levircd.py
│   │   └── whucd.py
│   └── models
│       ├── __init__.py
│       ├── bit
│       │   ├── __init__.py
│       │   ├── help_funcs.py
│       │   ├── networks.py
│       │   └── resnet.py
│       ├── changeformer
│       │   ├── ChangeFormer.py
│       │   ├── ChangeFormerBaseNetworks.py
│       │   └── __init__.py
│       └── tiny_cd
│           ├── __init__.py
│           ├── change_classifier.py
│           └── layers.py
├── test_levircd.py
├── test_whucd.py
├── train_levircd.py
└── train_whucd.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | logs*
3 | *.csv
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # poetry
102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
106 | #poetry.lock
107 |
108 | # pdm
109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
110 | #pdm.lock
111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
112 | # in version control.
113 | # https://pdm.fming.dev/#use-with-ide
114 | .pdm.toml
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Isaac Corley
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Change Detection Reality Check
2 |
3 | [**Isaac Corley**](https://isaacc.dev/)<sup>1</sup> · [**Caleb Robinson**](https://www.microsoft.com/en-us/research/people/davrob/)<sup>2</sup> · [**Anthony Ortiz**](https://www.microsoft.com/en-us/research/people/anort/)<sup>2</sup>
4 |
5 | <sup>1</sup>University of Texas at San Antonio · <sup>2</sup>Microsoft AI for Good Research Lab
6 |
7 |
8 |
9 |
10 |
11 |
12 | Code and experiments for the paper ["A Change Detection Reality Check", Isaac Corley, Caleb Robinson, Anthony Ortiz](https://arxiv.org/abs/2402.06994), presented at the [ICLR 2024 Machine Learning for Remote Sensing (ML4RS) Workshop](https://ml-for-rs.github.io/iclr2024/).
13 |
14 | ### Summary
15 |
16 | The remote sensing literature of the past several years has exploded with proposed deep learning architectures claiming to be the latest state-of-the-art on standard change detection benchmark datasets. However, has the field truly made significant progress? In this paper we perform experiments showing that a simple U-Net segmentation baseline, without training tricks or complicated architectural changes, is still a top performer for the task of change detection.
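
The baseline referenced above is a plain `segmentation_models_pytorch` U-Net that takes the two image dates concatenated along the channel dimension (see `src/change_detection.py`). A minimal sketch of that setup, assuming 3-band 256x256 image pairs:

```python
import segmentation_models_pytorch as smp
import torch

# Two 3-band images concatenated channel-wise -> 6 input channels, 2 output classes (change / no change)
model = smp.Unet(encoder_name="resnet50", encoder_weights="imagenet", in_channels=3 * 2, classes=2)

image1 = torch.rand(1, 3, 256, 256)  # pre-change image
image2 = torch.rand(1, 3, 256, 256)  # post-change image
logits = model(torch.cat([image1, image2], dim=1))  # (1, 2, 256, 256)
```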
17 |
18 | ### Results
19 |
20 | We find that U-Net is still a top performer on the LEVIR-CD and WHU-CD benchmark datasets. See the tables below for comparisons with state-of-the-art methods.
21 |
22 |
23 | ![LEVIR-CD results](assets/levircd-results.png)
24 | Table 1. Comparison of state-of-the-art change detection architectures to a U-Net baseline on the LEVIR-CD dataset. We report the test set precision, recall, and F1 metrics of the positive change class. For the baseline experiments we perform 10 runs while varying the random seed and report metrics from the highest performing run. All other metrics are taken from their respective papers. The top performing methods are highlighted in bold. Gray rows indicate our baseline U-Net and siamese encoder variants.
25 |
26 |
27 |
28 | ![WHU-CD results](assets/whucd-results.png)
29 | Table 2. Experimental results on the WHU-CD dataset. We retrain several state-of-the-art methods using the original dataset’s train/test splits instead of the commonly used, randomly split preprocessed version created by Bandara & Patel (2022a). We find that these state-of-the-art methods are outperformed by a U-Net baseline. We report the test set precision, recall, F1, and IoU metrics of the positive change class. For each run we select the model checkpoint with the lowest validation set loss. We provide metrics averaged over 10 runs with varying random seeds as well as the best seed. Gray rows indicate our baseline U-Net and siamese encoder variants.
30 |
31 |
32 | ### Model Checkpoints
33 |
34 | **Model checkpoints are uploaded to HuggingFace [here](https://huggingface.co/isaaccorley/a-change-detection-reality-check)!** A minimal loading sketch is shown after the tables below.
35 |
36 |
37 | #### LEVIR-CD
38 |
39 | | **Model** | **Backbone** | **Precision** | **Recall** | **F1** | **IoU** | **Checkpoint** |
40 | |:--------------: |:---------------: |:---------: |:------: |:------: |:------: |:----------: |
41 | | U-Net | ResNet-50 | 0.9197 | 0.8795 | 0.8991 | 0.8167 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_resnet50.ckpt) |
42 | | U-Net | EfficientNet-B4 | 0.9269 | 0.8588 | 0.8915 | 0.8044 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_efficientnetb4.ckpt) |
43 | | U-Net SiamConc | ResNet-50 | 0.9287 | 0.8749 | 0.9010 | 0.8199 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_siamconc_resnet50.ckpt) |
44 | | U-Net SiamDiff | ResNet-50 | 0.9321 | 0.8730 | 0.9015 | 0.8207 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_siamdiff_resnet50.ckpt) |
45 |
46 | #### WHU-CD (using official train/test splits)
47 |
48 | | **Model** | **Backbone** | **Precision** | **Recall** | **F1** | **IoU** | **Checkpoint** |
49 | |:--------------: |:---------: |:---------: |:------: |:------: |:------: |:----------: |
50 | | U-Net SiamConc | ResNet-50 | 0.8369 | 0.8130 | 0.8217 | 0.7054 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_siamconc_resnet50.ckpt) |
51 | | U-Net SiamDiff | ResNet-50 | 0.8856 | 0.7741 | 0.8248 | 0.7086 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_siamdiff_resnet50.ckpt) |
52 | | U-Net | ResNet-50 | 0.8865 | 0.7663 | 0.8200 | 0.7020 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_resnet50.ckpt) |
53 |
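The checkpoints above were trained with the `ChangeDetectionTask` in `src/change_detection.py`, which is a PyTorch Lightning module, so a downloaded file can be restored with Lightning's `load_from_checkpoint`. A minimal sketch (the local checkpoint path is a placeholder, and real inputs should be normalized as in the datamodules):

```python
import torch
from src.change_detection import ChangeDetectionTask

# Restore a trained U-Net baseline from a downloaded checkpoint file
task = ChangeDetectionTask.load_from_checkpoint("unet_resnet50.ckpt")
task.eval()

image1 = torch.rand(1, 3, 256, 256)  # pre-change image
image2 = torch.rand(1, 3, 256, 256)  # post-change image
with torch.no_grad():
    # The U-Net variants expect the two images concatenated along the channel dimension
    logits = task.model(torch.cat([image1, image2], dim=1))
    prediction = logits.argmax(dim=1)
```
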
54 | ### Reproducing Results
55 |
56 | Download the [LEVIR-CD](https://chenhao.in/LEVIR/) and [WHU-CD](http://gpcv.whu.edu.cn/data/building_dataset.html) datasets and then use the following notebooks to chip the datasets into non-overlapping 256x256 patches.
57 |
58 | ```bash
59 | scripts/preprocess_levircd.ipynb
60 | scripts/preprocess_whucd.ipynb
61 | ```
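
For reference, the chipping in those notebooks uses `image_bbox_slicer` roughly as follows (directory paths here are placeholders; see the notebooks for the exact layout of each dataset):

```python
import os
from image_bbox_slicer import Slicer

# Slice each LEVIR-CD split directory (A, B, label) into non-overlapping 256x256 tiles
for directory in ["A", "B", "label"]:
    slicer = Slicer()
    slicer.config_image_dirs(
        img_src=os.path.join("data/train", directory), img_dst=os.path.join("data/train-chipped", directory)
    )
    slicer.slice_images_by_size(tile_size=(256, 256), tile_overlap=0)
```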
62 |
63 | To train U-Net on both datasets over 10 random seeds, run:
64 |
65 | ```bash
66 | python train_levircd.py --train-root /path/to/preprocessed-dataset/ --model unet --backbone resnet50 --num_seeds 10
67 | python train_whucd.py --train-root /path/to/preprocessed-dataset/ --model unet --backbone resnet50 --num_seeds 10
68 | ```
69 |
70 | To evaluate a set of checkpoints and save results to a .csv file, run:
71 |
72 | ```bash
73 | python test_levircd.py --root /path/to/preprocessed-dataset/ --ckpt-root lightning_logs/ --output-filename metrics.csv
74 | python test_whucd.py --root /path/to/preprocessed-dataset/ --ckpt-root lightning_logs/ --output-filename metrics.csv
75 | ```
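
If you then want to summarize the resulting metrics.csv across seeds, a small pandas sketch (the `model` and `test_F1Score_1` column names below are hypothetical; check the actual CSV header written by the test scripts):

```python
import pandas as pd

df = pd.read_csv("metrics.csv")
# Hypothetical column names -- mean, spread, and best positive-class F1 per model across seeds
print(df.groupby("model")["test_F1Score_1"].agg(["mean", "std", "max"]))
```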
76 |
77 | ### Citation
78 |
79 | If this work inspired your change detection research, please consider citing our paper:
80 |
81 | ```
82 | @article{corley2024change,
83 | title={A Change Detection Reality Check},
84 | author={Corley, Isaac and Robinson, Caleb and Ortiz, Anthony},
85 | journal={arXiv preprint arXiv:2402.06994},
86 | year={2024}
87 | }
88 | ```
89 |
--------------------------------------------------------------------------------
/assets/levircd-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/assets/levircd-results.png
--------------------------------------------------------------------------------
/assets/whucd-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/assets/whucd-results.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | target-version = "py311"
3 | line-length = 120
4 | src = ["src", "notebooks"]
5 | force-exclude = true
6 | fix = true
7 | show-fixes = true
8 |
9 | [tool.ruff.format]
10 | skip-magic-trailing-comma = true
11 |
12 | [tool.ruff.lint]
13 | extend-select = ["B", "Q", "I", "UP"]
14 | ignore = [
15 | "E203",
16 | "E402",
17 | "F821",
18 | "F405",
19 | "F403",
20 | "E731",
21 | "B006",
22 | "B008",
23 | "B904",
24 | "E741",
25 | "F401",
26 | ]
27 |
28 | [tool.ruff.lint.pylint]
29 | max-returns = 5
30 | max-args = 25
31 |
32 | [tool.ruff.lint.isort]
33 | split-on-trailing-comma = false
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchgeo[all]==0.6.0
2 | image_bbox_slicer
3 | einops
--------------------------------------------------------------------------------
/scripts/preprocess_levircd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Chipping directory A\n"
13 | ]
14 | },
15 | {
16 | "name": "stderr",
17 | "output_type": "stream",
18 | "text": [
19 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/A directory does not exist so creating it now\n",
20 | " warnings.warn(\n"
21 | ]
22 | },
23 | {
24 | "name": "stdout",
25 | "output_type": "stream",
26 | "text": [
27 | "Obtained 7120 image slices!\n",
28 | "Chipping directory B\n"
29 | ]
30 | },
31 | {
32 | "name": "stderr",
33 | "output_type": "stream",
34 | "text": [
35 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/B directory does not exist so creating it now\n",
36 | " warnings.warn(\n"
37 | ]
38 | },
39 | {
40 | "name": "stdout",
41 | "output_type": "stream",
42 | "text": [
43 | "Obtained 7120 image slices!\n",
44 | "Chipping directory label\n"
45 | ]
46 | },
47 | {
48 | "name": "stderr",
49 | "output_type": "stream",
50 | "text": [
51 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/label directory does not exist so creating it now\n",
52 | " warnings.warn(\n"
53 | ]
54 | },
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "Obtained 7120 image slices!\n"
60 | ]
61 | }
62 | ],
63 | "source": [
64 | "import os\n",
65 | "from pathlib import Path\n",
66 | "\n",
67 | "from image_bbox_slicer import Slicer\n",
68 | "from tqdm import tqdm\n",
69 | "\n",
70 | "directories = [\"A\", \"B\", \"label\"]\n",
71 | "root = \"../data/train\"\n",
72 | "output = \"../data/train-chipped/\"\n",
73 | "\n",
74 | "for directory in directories:\n",
75 | " print(f\"Chipping directory {directory}\")\n",
76 | " slicer = Slicer()\n",
77 | " src = os.path.join(root, directory)\n",
78 | " dst = os.path.join(output, directory)\n",
79 | " slicer.config_image_dirs(img_src=src, img_dst=dst)\n",
80 | " slicer.slice_images_by_size(tile_size=(256, 256), tile_overlap=0)"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": 3,
86 | "metadata": {},
87 | "outputs": [
88 | {
89 | "data": {
90 | "text/plain": [
91 | "['A', 'B', 'label']"
92 | ]
93 | },
94 | "execution_count": 3,
95 | "metadata": {},
96 | "output_type": "execute_result"
97 | }
98 | ],
99 | "source": [
100 | "directories"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 4,
106 | "metadata": {},
107 | "outputs": [
108 | {
109 | "name": "stderr",
110 | "output_type": "stream",
111 | "text": [
112 | " 0%| | 0/7120 [00:00, ?it/s]"
113 | ]
114 | },
115 | {
116 | "name": "stderr",
117 | "output_type": "stream",
118 | "text": [
119 | "100%|██████████| 7120/7120 [00:00<00:00, 24150.23it/s]\n",
120 | "100%|██████████| 7120/7120 [00:00<00:00, 21067.20it/s]\n",
121 | "100%|██████████| 7120/7120 [00:00<00:00, 18317.66it/s]\n"
122 | ]
123 | }
124 | ],
125 | "source": [
126 | "root = Path(\"../data/train-chipped/\")\n",
127 | "\n",
128 | "for directory in directories:\n",
129 | " path = root / directory\n",
130 | " files = list(path.glob(\"*.png\"))\n",
131 | " for f in tqdm(files):\n",
132 | " dst = path / f\"train_{f.name}\"\n",
133 | " _ = f.rename(dst)"
134 | ]
135 | }
136 | ],
137 | "metadata": {
138 | "kernelspec": {
139 | "display_name": "torchgeo",
140 | "language": "python",
141 | "name": "python3"
142 | },
143 | "language_info": {
144 | "codemirror_mode": {
145 | "name": "ipython",
146 | "version": 3
147 | },
148 | "file_extension": ".py",
149 | "mimetype": "text/x-python",
150 | "name": "python",
151 | "nbconvert_exporter": "python",
152 | "pygments_lexer": "ipython3",
153 | "version": "3.11.11"
154 | }
155 | },
156 | "nbformat": 4,
157 | "nbformat_minor": 2
158 | }
159 |
--------------------------------------------------------------------------------
/scripts/preprocess_whucd.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 3,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "Chipping directory 2012\n",
13 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/2012/whole_image/train/image /workspace/storage/data/whucd-chipped/2012/train\n",
14 | "Obtained 4838 image slices!\n",
15 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/2012/whole_image/test/image /workspace/storage/data/whucd-chipped/2012/test\n",
16 | "Obtained 2596 image slices!\n",
17 | "Chipping directory 2016\n",
18 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/2016/whole_image/train/image /workspace/storage/data/whucd-chipped/2016/train\n",
19 | "Obtained 4838 image slices!\n",
20 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/2016/whole_image/test/image /workspace/storage/data/whucd-chipped/2016/test\n",
21 | "Obtained 2596 image slices!\n",
22 | "Chipping directory change_label\n",
23 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/change_label/train /workspace/storage/data/whucd-chipped/change_label/train\n",
24 | "Obtained 4838 image slices!\n",
25 | "/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/change_label/test /workspace/storage/data/whucd-chipped/change_label/test\n",
26 | "Obtained 2596 image slices!\n"
27 | ]
28 | }
29 | ],
30 | "source": [
31 | "import os\n",
32 | "\n",
33 | "import image_bbox_slicer.helpers\n",
34 | "from image_bbox_slicer import Slicer\n",
35 | "from PIL import Image\n",
36 | "\n",
37 | "image_bbox_slicer.helpers.IMG_FORMAT_LIST.append(\"tif\")\n",
38 | "Image.MAX_IMAGE_PIXELS = None\n",
39 | "\n",
40 | "files = {\n",
41 | " \"2012\": {\n",
42 | " \"train\": os.path.join(\"2012\", \"whole_image\", \"train\", \"image\", \"2012_train.tif\"),\n",
43 | " \"test\": os.path.join(\"2012\", \"whole_image\", \"test\", \"image\", \"2012_test.tif\"),\n",
44 | " },\n",
45 | " \"2016\": {\n",
46 | " \"train\": os.path.join(\"2016\", \"whole_image\", \"train\", \"image\", \"2016_train.tif\"),\n",
47 | " \"test\": os.path.join(\"2016\", \"whole_image\", \"test\", \"image\", \"2016_test.tif\"),\n",
48 | " },\n",
49 | " \"change_label\": {\n",
50 | " \"train\": os.path.join(\"change_label\", \"train\", \"change_label.tif\"),\n",
51 | " \"test\": os.path.join(\"change_label\", \"test\", \"change_label.tif\"),\n",
52 | " },\n",
53 | "}\n",
54 | "patch_size = (256, 256)\n",
55 | "root = \"/workspace/storage/data/whucd/Building change detection dataset_add/1. The two-period image data/\"\n",
56 | "output = \"/workspace/storage/data/whucd-chipped/\"\n",
57 | "\n",
58 | "for directory in files:\n",
59 | " print(f\"Chipping directory {directory}\")\n",
60 | " for split in files[directory]:\n",
61 | " src = os.path.join(root, os.path.dirname(files[directory][split]))\n",
62 | " dst = os.path.join(output, directory, split)\n",
63 | " print(src, dst)\n",
64 | " slicer = Slicer()\n",
65 | " slicer.config_image_dirs(img_src=src, img_dst=dst)\n",
66 | " slicer.slice_images_by_size(tile_size=patch_size, tile_overlap=0)"
67 | ]
68 | }
69 | ],
70 | "metadata": {
71 | "kernelspec": {
72 | "display_name": "torchenv",
73 | "language": "python",
74 | "name": "python3"
75 | },
76 | "language_info": {
77 | "codemirror_mode": {
78 | "name": "ipython",
79 | "version": 3
80 | },
81 | "file_extension": ".py",
82 | "mimetype": "text/x-python",
83 | "name": "python",
84 | "nbconvert_exporter": "python",
85 | "pygments_lexer": "ipython3",
86 | "version": "3.11.5"
87 | }
88 | },
89 | "nbformat": 4,
90 | "nbformat_minor": 2
91 | }
92 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/__init__.py
--------------------------------------------------------------------------------
/src/change_detection.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation. All rights reserved.
2 | # Licensed under the MIT License.
3 |
4 | """Trainers for change detection."""
5 |
6 | import os
7 | import warnings
8 | from typing import Any
9 |
10 | import matplotlib.pyplot as plt
11 | import segmentation_models_pytorch as smp
12 | import torch
13 | import torch.nn as nn
14 | from matplotlib.figure import Figure
15 | from torch import Tensor
16 | from torchgeo.datasets import RGBBandsMissingError, unbind_samples
17 | from torchgeo.models import FCSiamConc, FCSiamDiff, get_weight
18 | from torchgeo.trainers import utils
19 | from torchgeo.trainers.base import BaseTask
20 | from torchmetrics import MetricCollection
21 | from torchmetrics.classification import Accuracy, FBetaScore, JaccardIndex, Precision, Recall
22 | from torchmetrics.wrappers import ClasswiseWrapper
23 | from torchvision.models._api import WeightsEnum
24 |
25 | from .models import BIT, ChangeFormerV6, TinyCD
26 |
27 |
28 | class ChangeDetectionTask(BaseTask):
29 | """Change Detection."""
30 |
31 | def __init__(
32 | self,
33 | model: str = "unet",
34 | backbone: str = "resnet50",
35 | weights: WeightsEnum | str | bool | None = None,
36 | in_channels: int = 3,
37 | num_classes: int = 2,
38 | class_weights: Tensor | None = None,
39 | labels: list[str] | None = None,
40 | loss: str = "ce",
41 | ignore_index: int | None = None,
42 | lr: float = 1e-3,
43 | patience: int = 10,
44 | freeze_backbone: bool = False,
45 | freeze_decoder: bool = False,
46 | ) -> None:
47 | if ignore_index is not None and loss == "jaccard":
48 | warnings.warn("ignore_index has no effect on training when loss='jaccard'", UserWarning, stacklevel=2)
49 |
50 | self.weights = weights
51 | super().__init__(ignore="weights")
52 |
53 | def configure_losses(self) -> None:
54 | """Initialize the loss criterion.
55 | Raises:
56 | ValueError: If *loss* is invalid.
57 | """
58 | loss: str = self.hparams["loss"]
59 | ignore_index = self.hparams["ignore_index"]
60 | if loss == "ce":
61 | ignore_value = -1000 if ignore_index is None else ignore_index
62 | self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=self.hparams["class_weights"])
63 | elif loss == "jaccard":
64 | self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"])
65 | elif loss == "focal":
66 | self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
67 | else:
68 | raise ValueError(f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss.")
69 |
70 | def configure_metrics(self) -> None:
71 | """Initialize the performance metrics."""
72 | num_classes: int = self.hparams["num_classes"]
73 | ignore_index: int | None = self.hparams["ignore_index"]
74 | labels: list[str] | None = self.hparams["labels"]
75 |
76 | self.train_metrics = MetricCollection(
77 | {
78 | "OverallAccuracy": Accuracy(
79 | task="multiclass", num_classes=num_classes, average="micro", multidim_average="global"
80 | ),
81 | "OverallF1Score": FBetaScore(
82 | task="multiclass", num_classes=num_classes, beta=1.0, average="micro", multidim_average="global"
83 | ),
84 | "OverallIoU": JaccardIndex(
85 | task="multiclass", num_classes=num_classes, ignore_index=ignore_index, average="micro"
86 | ),
87 | "AverageAccuracy": Accuracy(
88 | task="multiclass", num_classes=num_classes, average="macro", multidim_average="global"
89 | ),
90 | "AverageF1Score": FBetaScore(
91 | task="multiclass", num_classes=num_classes, beta=1.0, average="macro", multidim_average="global"
92 | ),
93 | "AverageIoU": JaccardIndex(
94 | task="multiclass", num_classes=num_classes, ignore_index=ignore_index, average="macro"
95 | ),
96 | "Accuracy": ClasswiseWrapper(
97 | Accuracy(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"),
98 | labels=labels,
99 | ),
100 | "Precision": ClasswiseWrapper(
101 | Precision(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"),
102 | labels=labels,
103 | ),
104 | "Recall": ClasswiseWrapper(
105 | Recall(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"),
106 | labels=labels,
107 | ),
108 | "F1Score": ClasswiseWrapper(
109 | FBetaScore(
110 | task="multiclass", num_classes=num_classes, beta=1.0, average="none", multidim_average="global"
111 | ),
112 | labels=labels,
113 | ),
114 | "IoU": ClasswiseWrapper(
115 | JaccardIndex(task="multiclass", num_classes=num_classes, average="none"), labels=labels
116 | ),
117 | },
118 | prefix="train_",
119 | )
120 | self.val_metrics = self.train_metrics.clone(prefix="val_")
121 | self.test_metrics = self.train_metrics.clone(prefix="test_")
122 |
123 | def configure_optimizers(self):
124 | def lambda_rule(epoch):
125 | lr_l = 1.0 - epoch / float(self.trainer.max_epochs + 1)
126 | return lr_l
127 |
128 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams["lr"], momentum=0.9, weight_decay=5e-4)
129 |         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
130 |
131 | return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}
132 |
133 | def configure_models(self) -> None:
134 | """Initialize the model.
135 | Raises:
136 | ValueError: If *model* is invalid.
137 | """
138 | model: str = self.hparams["model"]
139 | backbone: str = self.hparams["backbone"]
140 | weights = self.weights
141 | in_channels: int = self.hparams["in_channels"]
142 | num_classes: int = self.hparams["num_classes"]
143 |
144 | if model == "unet":
145 | self.model = smp.Unet(
146 | encoder_name=backbone,
147 | encoder_weights="imagenet" if weights is True else None,
148 | in_channels=in_channels * 2, # images are concatenated
149 | classes=num_classes,
150 | )
151 | elif model == "fcsiamdiff":
152 | self.model = FCSiamDiff(
153 | encoder_name=backbone,
154 | in_channels=in_channels,
155 | classes=num_classes,
156 | encoder_weights="imagenet" if weights is True else None,
157 | )
158 | elif model == "fcsiamconc":
159 | self.model = FCSiamConc(
160 | encoder_name=backbone,
161 | in_channels=in_channels,
162 | classes=num_classes,
163 | encoder_weights="imagenet" if weights is True else None,
164 | )
165 | elif model == "bit":
166 | self.model = BIT(arch="base_transformer_pos_s4_dd8")
167 | elif model == "changeformer":
168 | self.model = ChangeFormerV6(
169 | input_nc=in_channels, output_nc=num_classes, decoder_softmax=False, embed_dim=256
170 | )
171 | elif model == "tinycd":
172 | self.model = TinyCD(
173 | bkbn_name="efficientnet_b4", pretrained=True, output_layer_bkbn="3", freeze_backbone=False
174 | )
175 | else:
176 | raise ValueError(f"Model type '{model}' is not valid.")
177 |
178 | if weights and weights is not True:
179 | if isinstance(weights, WeightsEnum):
180 | state_dict = weights.get_state_dict(progress=True)
181 | elif os.path.exists(weights):
182 | _, state_dict = utils.extract_backbone(weights)
183 | else:
184 | state_dict = get_weight(weights).get_state_dict(progress=True)
185 | self.model.encoder.load_state_dict(state_dict)
186 |
187 | # Freeze backbone
188 | if self.hparams["freeze_backbone"] and model in ["unet"]:
189 | for param in self.model.encoder.parameters():
190 | param.requires_grad = False
191 |
192 | # Freeze decoder
193 | if self.hparams["freeze_decoder"] and model in ["unet"]:
194 | for param in self.model.decoder.parameters():
195 | param.requires_grad = False
196 |
197 | def training_step(self, batch: Any, batch_idx: int) -> Tensor:
198 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"]
199 |
200 | model: str = self.hparams["model"]
201 | if model == "unet":
202 | x = torch.cat([image1, image2], dim=1)
203 | elif model in ["fcsiamdiff", "fcsiamconc"]:
204 | x = torch.stack((image1, image2), dim=1)
205 |
206 | if model in ["bit", "changeformer", "tinycd"]:
207 | y_hat = self(image1, image2)
208 | else:
209 | y_hat = self(x)
210 |
211 | loss: Tensor = self.criterion(y_hat, y)
212 |
213 | self.log("train_loss", loss)
214 |
215 | y_hat = torch.softmax(y_hat, dim=1)
216 | y_hat_hard = y_hat.argmax(dim=1)
217 |
218 | self.train_metrics(y_hat_hard, y)
219 | self.log_dict({f"{k}": v for k, v in self.train_metrics.compute().items()})
220 |
221 | return loss
222 |
223 | def validation_step(self, batch: Any, batch_idx: int) -> None:
224 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"]
225 |
226 | model: str = self.hparams["model"]
227 | if model == "unet":
228 | x = torch.cat([image1, image2], dim=1)
229 | elif model in ["fcsiamdiff", "fcsiamconc"]:
230 | x = torch.stack((image1, image2), dim=1)
231 |
232 | if model in ["bit", "changeformer", "tinycd"]:
233 | y_hat = self(image1, image2)
234 | else:
235 | y_hat = self(x)
236 |
237 | loss: Tensor = self.criterion(y_hat, y)
238 |
239 | self.log("val_loss", loss, on_epoch=True)
240 |
241 | y_hat = torch.softmax(y_hat, dim=1)
242 | y_hat_hard = y_hat.argmax(dim=1)
243 |
244 | self.val_metrics(y_hat_hard, y)
245 | self.log_dict({f"{k}": v for k, v in self.val_metrics.compute().items()}, on_epoch=True)
246 |
247 | if (
248 | batch_idx < 10
249 | and hasattr(self.trainer, "datamodule")
250 | and hasattr(self.trainer.datamodule, "plot")
251 | and self.logger
252 | and hasattr(self.logger, "experiment")
253 | and hasattr(self.logger.experiment, "add_figure")
254 | ):
255 | datamodule = self.trainer.datamodule
256 | batch["prediction"] = y_hat_hard
257 | for key in ["image1", "image2", "mask", "prediction"]:
258 | batch[key] = batch[key].cpu()
259 | sample = unbind_samples(batch)[0]
260 |
261 | fig: Figure | None = None
262 | try:
263 | fig = datamodule.plot(sample)
264 | except RGBBandsMissingError:
265 | pass
266 |
267 | if fig:
268 | summary_writer = self.logger.experiment
269 | summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step)
270 | plt.close()
271 |
272 | def test_step(self, batch: Any, batch_idx: int) -> None:
273 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"]
274 |
275 | model: str = self.hparams["model"]
276 | if model == "unet":
277 | x = torch.cat([image1, image2], dim=1)
278 | elif model in ["fcsiamdiff", "fcsiamconc"]:
279 | x = torch.stack((image1, image2), dim=1)
280 |
281 | if model in ["bit", "changeformer", "tinycd"]:
282 | y_hat = self(image1, image2)
283 | else:
284 | y_hat = self(x)
285 |
286 | loss: Tensor = self.criterion(y_hat, y)
287 |
288 | self.log("test_loss", loss, on_epoch=True)
289 |
290 | y_hat = torch.softmax(y_hat, dim=1)
291 | y_hat_hard = y_hat.argmax(dim=1)
292 |
293 | self.test_metrics(y_hat_hard, y)
294 | self.log_dict({f"{k}": v for k, v in self.test_metrics.compute().items()}, on_epoch=True)
295 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/datasets/__init__.py
--------------------------------------------------------------------------------
/src/datasets/levircd.py:
--------------------------------------------------------------------------------
1 | import kornia.augmentation as K
2 | import torchgeo.datamodules
3 | import torchgeo.datasets
4 | from torchgeo.transforms import AugmentationSequential
5 | from torchgeo.transforms.transforms import _ExtractPatches
6 |
7 |
8 | class LEVIRCDDataModule(torchgeo.datamodules.LEVIRCDDataModule):
9 | train_root = ""
10 |
11 | def __init__(self, *args, **kwargs):
12 | super().__init__(*args, **kwargs)
13 | self.train_aug = AugmentationSequential(
14 | K.RandomHorizontalFlip(p=0.5),
15 | K.RandomVerticalFlip(p=0.5),
16 | K.RandomResizedCrop(size=self.patch_size, scale=(0.8, 1.0), ratio=(1, 1), p=1.0),
17 | K.Normalize(mean=0.0, std=255.0),
18 | K.Normalize(mean=0.5, std=0.5),
19 | data_keys=["image1", "image2", "mask"],
20 | )
21 | self.val_aug = AugmentationSequential(
22 | K.Normalize(mean=0.0, std=255.0),
23 | K.Normalize(mean=0.5, std=0.5),
24 | _ExtractPatches(window_size=self.patch_size),
25 | data_keys=["image1", "image2", "mask"],
26 | same_on_batch=True,
27 | )
28 | self.test_aug = AugmentationSequential(
29 | K.Normalize(mean=0.0, std=255.0),
30 | K.Normalize(mean=0.5, std=0.5),
31 | _ExtractPatches(window_size=self.patch_size),
32 | data_keys=["image1", "image2", "mask"],
33 | same_on_batch=True,
34 | )
35 |
36 | def setup(self, stage: str) -> None:
37 | if stage in ["fit"]:
38 | self.train_dataset = torchgeo.datasets.LEVIRCD(root=self.train_root, split="train")
39 | if stage in ["fit", "validate"]:
40 | self.val_dataset = torchgeo.datasets.LEVIRCD(split="val", **self.kwargs)
41 | if stage == "test":
42 | self.test_dataset = torchgeo.datasets.LEVIRCD(split="test", **self.kwargs)
43 |
--------------------------------------------------------------------------------
/src/datasets/whucd.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | from collections.abc import Callable
4 | from typing import Optional
5 |
6 | import kornia.augmentation as K
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 | import torchgeo
11 | from matplotlib.figure import Figure
12 | from PIL import Image
13 | from torch import Tensor
14 | from torchgeo.datamodules.geo import NonGeoDataModule
15 | from torchgeo.datamodules.utils import dataset_split
16 | from torchgeo.datasets.utils import percentile_normalization
17 | from torchgeo.transforms import AugmentationSequential
18 | from torchgeo.transforms.transforms import _ExtractPatches
19 |
20 |
21 | class WHUCD(torch.utils.data.Dataset):
22 | splits = ["train", "test"]
23 |
24 | def __init__(
25 | self,
26 | root: str = "data",
27 | split: str = "train",
28 | transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
29 | ) -> None:
30 | assert split in self.splits
31 |
32 | self.root = root
33 | self.split = split
34 | self.transforms = transforms
35 | self.files = self._load_files(self.root, self.split)
36 |
37 | def __getitem__(self, index: int) -> dict[str, Tensor]:
38 | files = self.files[index]
39 | image1 = self._load_image(files["image1"])
40 | image2 = self._load_image(files["image2"])
41 | mask = self._load_target(files["mask"])
42 | sample = {"image1": image1, "image2": image2, "mask": mask}
43 |
44 | if self.transforms is not None:
45 | sample = self.transforms(sample)
46 |
47 | return sample
48 |
49 | def __len__(self) -> int:
50 | return len(self.files)
51 |
52 | def _load_image(self, path: str) -> Tensor:
53 | filename = os.path.join(path)
54 | with Image.open(filename) as img:
55 | array = np.array(img.convert("RGB"))
56 | tensor = torch.from_numpy(array)
57 | tensor = tensor.float()
58 | tensor = tensor.permute((2, 0, 1))
59 | return tensor
60 |
61 | def _load_target(self, path: str) -> Tensor:
62 | filename = os.path.join(path)
63 | with Image.open(filename) as img:
64 | array = np.array(img.convert("L"))
65 | tensor = torch.from_numpy(array)
66 | tensor = torch.clamp(tensor, min=0, max=1)
67 | tensor = tensor.to(torch.long)
68 | return tensor
69 |
70 | def plot(self, sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None) -> Figure:
71 | ncols = 3
72 |
73 | image1 = sample["image1"].permute(1, 2, 0).numpy()
74 | image1 = percentile_normalization(image1, lower=0, upper=98, axis=(0, 1))
75 |
76 | image2 = sample["image2"].permute(1, 2, 0).numpy()
77 | image2 = percentile_normalization(image2, lower=0, upper=98, axis=(0, 1))
78 |
79 | if "prediction" in sample:
80 | ncols += 1
81 |
82 | fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5))
83 |
84 | axs[0].imshow(image1)
85 | axs[0].axis("off")
86 | axs[1].imshow(image2)
87 | axs[1].axis("off")
88 | axs[2].imshow(sample["mask"], cmap="gray", interpolation="none")
89 | axs[2].axis("off")
90 |
91 | if "prediction" in sample:
92 | axs[3].imshow(sample["prediction"], cmap="gray", interpolation="none")
93 | axs[3].axis("off")
94 | if show_titles:
95 | axs[3].set_title("Prediction")
96 |
97 | if show_titles:
98 | axs[0].set_title("Image 1")
99 | axs[1].set_title("Image 2")
100 | axs[2].set_title("Mask")
101 |
102 | if suptitle is not None:
103 | plt.suptitle(suptitle)
104 |
105 | return fig
106 |
107 | def _load_files(self, root: str, split: str) -> list[dict[str, str]]:
108 | images1 = sorted(glob.glob(os.path.join(root, "2012", split, "*.tif")))
109 | images2 = sorted(glob.glob(os.path.join(root, "2016", split, "*.tif")))
110 | masks = sorted(glob.glob(os.path.join(root, "change_label", split, "*.tif")))
111 |
112 | files = []
113 | for image1, image2, mask in zip(images1, images2, masks, strict=False):
114 | files.append(dict(image1=image1, image2=image2, mask=mask))
115 | return files
116 |
117 |
118 | class WHUCDDataModule(NonGeoDataModule):
119 | def __init__(self, patch_size: int = 256, val_split_pct: float = 0.1, *args, **kwargs):
120 | super().__init__(WHUCD, *args, **kwargs)
121 |
122 | self.patch_size = (patch_size, patch_size)
123 | self.val_split_pct = val_split_pct
124 |
125 | self.train_aug = AugmentationSequential(
126 | K.RandomHorizontalFlip(p=0.5),
127 | K.RandomVerticalFlip(p=0.5),
128 | K.RandomResizedCrop(size=self.patch_size, scale=(0.8, 1.0), ratio=(1, 1), p=1.0),
129 | K.Normalize(mean=0.0, std=255.0),
130 | K.Normalize(mean=0.5, std=0.5),
131 | data_keys=["image1", "image2", "mask"],
132 | )
133 | self.val_aug = AugmentationSequential(
134 | K.Normalize(mean=0.0, std=255.0),
135 | K.Normalize(mean=0.5, std=0.5),
136 | _ExtractPatches(window_size=self.patch_size),
137 | data_keys=["image1", "image2", "mask"],
138 | same_on_batch=True,
139 | )
140 | self.test_aug = AugmentationSequential(
141 | K.Normalize(mean=0.0, std=255.0),
142 | K.Normalize(mean=0.5, std=0.5),
143 | _ExtractPatches(window_size=self.patch_size),
144 | data_keys=["image1", "image2", "mask"],
145 | same_on_batch=True,
146 | )
147 |
148 | def setup(self, stage: str) -> None:
149 | if stage in ["fit", "validate"]:
150 | self.dataset = WHUCD(split="train", **self.kwargs)
151 | self.train_dataset, self.val_dataset = dataset_split(self.dataset, val_pct=self.val_split_pct)
152 | if stage == "test":
153 | self.test_dataset = WHUCD(split="test", **self.kwargs)
154 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .bit.networks import define_G as BIT
2 | from .changeformer.ChangeFormer import ChangeFormerV6
3 | from .tiny_cd.change_classifier import TinyCD
4 |
--------------------------------------------------------------------------------
/src/models/bit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/bit/__init__.py
--------------------------------------------------------------------------------
/src/models/bit/help_funcs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange
4 | from torch import nn
5 |
6 |
7 | class TwoLayerConv2d(nn.Sequential):
8 | def __init__(self, in_channels, out_channels, kernel_size=3):
9 | super().__init__(
10 | nn.Conv2d(
11 | in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2, stride=1, bias=False
12 | ),
13 | nn.BatchNorm2d(in_channels),
14 | nn.ReLU(),
15 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, stride=1),
16 | )
17 |
18 |
19 | class Residual(nn.Module):
20 | def __init__(self, fn):
21 | super().__init__()
22 | self.fn = fn
23 |
24 | def forward(self, x, **kwargs):
25 | return self.fn(x, **kwargs) + x
26 |
27 |
28 | class Residual2(nn.Module):
29 | def __init__(self, fn):
30 | super().__init__()
31 | self.fn = fn
32 |
33 | def forward(self, x, x2, **kwargs):
34 | return self.fn(x, x2, **kwargs) + x
35 |
36 |
37 | class PreNorm(nn.Module):
38 | def __init__(self, dim, fn):
39 | super().__init__()
40 | self.norm = nn.LayerNorm(dim)
41 | self.fn = fn
42 |
43 | def forward(self, x, **kwargs):
44 | return self.fn(self.norm(x), **kwargs)
45 |
46 |
47 | class PreNorm2(nn.Module):
48 | def __init__(self, dim, fn):
49 | super().__init__()
50 | self.norm = nn.LayerNorm(dim)
51 | self.fn = fn
52 |
53 | def forward(self, x, x2, **kwargs):
54 | return self.fn(self.norm(x), self.norm(x2), **kwargs)
55 |
56 |
57 | class FeedForward(nn.Module):
58 | def __init__(self, dim, hidden_dim, dropout=0.0):
59 | super().__init__()
60 | self.net = nn.Sequential(
61 | nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout)
62 | )
63 |
64 | def forward(self, x):
65 | return self.net(x)
66 |
67 |
68 | class Cross_Attention(nn.Module):
69 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, softmax=True):
70 | super().__init__()
71 | inner_dim = dim_head * heads
72 | self.heads = heads
73 | self.scale = dim**-0.5
74 |
75 | self.softmax = softmax
76 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
77 | self.to_k = nn.Linear(dim, inner_dim, bias=False)
78 | self.to_v = nn.Linear(dim, inner_dim, bias=False)
79 |
80 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
81 |
82 | def forward(self, x, m, mask=None):
83 | _b, _n, _, h = *x.shape, self.heads
84 | q = self.to_q(x)
85 | k = self.to_k(m)
86 | v = self.to_v(m)
87 |
88 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), [q, k, v])
89 |
90 | dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
91 | mask_value = -torch.finfo(dots.dtype).max
92 |
93 | if mask is not None:
94 | mask = F.pad(mask.flatten(1), (1, 0), value=True)
95 | assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions"
96 | mask = mask[:, None, :] * mask[:, :, None]
97 | dots.masked_fill_(~mask, mask_value)
98 | del mask
99 |
100 | if self.softmax:
101 | attn = dots.softmax(dim=-1)
102 | else:
103 | attn = dots
104 | # attn = dots
105 | # vis_tmp(dots)
106 |
107 | out = torch.einsum("bhij,bhjd->bhid", attn, v)
108 | out = rearrange(out, "b h n d -> b n (h d)")
109 | out = self.to_out(out)
110 | # vis_tmp2(out)
111 |
112 | return out
113 |
114 |
115 | class Attention(nn.Module):
116 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
117 | super().__init__()
118 | inner_dim = dim_head * heads
119 | self.heads = heads
120 | self.scale = dim**-0.5
121 |
122 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
123 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
124 |
125 | def forward(self, x, mask=None):
126 | _b, _n, _, h = *x.shape, self.heads
127 | qkv = self.to_qkv(x).chunk(3, dim=-1)
128 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
129 |
130 | dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
131 | mask_value = -torch.finfo(dots.dtype).max
132 |
133 | if mask is not None:
134 | mask = F.pad(mask.flatten(1), (1, 0), value=True)
135 | assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions"
136 | mask = mask[:, None, :] * mask[:, :, None]
137 | dots.masked_fill_(~mask, mask_value)
138 | del mask
139 |
140 | attn = dots.softmax(dim=-1)
141 |
142 | out = torch.einsum("bhij,bhjd->bhid", attn, v)
143 | out = rearrange(out, "b h n d -> b n (h d)")
144 | out = self.to_out(out)
145 | return out
146 |
147 |
148 | class Transformer(nn.Module):
149 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
150 | super().__init__()
151 | self.layers = nn.ModuleList([])
152 | for _ in range(depth):
153 | self.layers.append(
154 | nn.ModuleList(
155 | [
156 | Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
157 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
158 | ]
159 | )
160 | )
161 |
162 | def forward(self, x, mask=None):
163 | for attn, ff in self.layers:
164 | x = attn(x, mask=mask)
165 | x = ff(x)
166 | return x
167 |
168 |
169 | class TransformerDecoder(nn.Module):
170 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True):
171 | super().__init__()
172 | self.layers = nn.ModuleList([])
173 | for _ in range(depth):
174 | self.layers.append(
175 | nn.ModuleList(
176 | [
177 | Residual2(
178 | PreNorm2(
179 | dim,
180 | Cross_Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, softmax=softmax),
181 | )
182 | ),
183 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
184 | ]
185 | )
186 | )
187 |
188 | def forward(self, x, m, mask=None):
189 | """target(query), memory"""
190 | for attn, ff in self.layers:
191 | x = attn(x, m, mask=mask)
192 | x = ff(x)
193 | return x
194 |
--------------------------------------------------------------------------------
/src/models/bit/networks.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | from torch.nn import init
8 | from torch.optim import lr_scheduler
9 |
10 | from .help_funcs import Transformer, TransformerDecoder, TwoLayerConv2d
11 | from .resnet import resnet18, resnet34, resnet50
12 |
13 | ###############################################################################
14 | # Helper Functions
15 | ###############################################################################
16 |
17 |
18 | def get_scheduler(optimizer, args):
19 | """Return a learning rate scheduler
20 |
21 | Parameters:
22 | optimizer -- the optimizer of the network
23 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
24 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
25 |
26 | For 'linear', we keep the same learning rate for the first epochs
27 | and linearly decay the rate to zero over the next epochs.
28 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
29 | See https://pytorch.org/docs/stable/optim.html for more details.
30 | """
31 | if args.lr_policy == "linear":
32 |
33 | def lambda_rule(epoch):
34 | lr_l = 1.0 - epoch / float(args.max_epochs + 1)
35 | return lr_l
36 |
37 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
38 | elif args.lr_policy == "step":
39 | step_size = args.max_epochs // 3
40 | # args.lr_decay_iters
41 | scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
42 | else:
43 |         raise NotImplementedError(f"learning rate policy [{args.lr_policy}] is not implemented")
44 | return scheduler
45 |
46 |
47 | class Identity(nn.Module):
48 | def forward(self, x):
49 | return x
50 |
51 |
52 | def get_norm_layer(norm_type="instance"):
53 | """Return a normalization layer
54 |
55 | Parameters:
56 | norm_type (str) -- the name of the normalization layer: batch | instance | none
57 |
58 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
59 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
60 | """
61 | if norm_type == "batch":
62 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
63 | elif norm_type == "instance":
64 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
65 | elif norm_type == "none":
66 | norm_layer = lambda x: Identity()
67 | else:
68 | raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
69 | return norm_layer
70 |
71 |
72 | def init_weights(net, init_type="normal", init_gain=0.02):
73 | """Initialize network weights.
74 |
75 | Parameters:
76 | net (network) -- network to be initialized
77 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
78 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
79 |
80 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
81 | work better for some applications. Feel free to try yourself.
82 | """
83 |
84 | def init_func(m): # define the initialization function
85 | classname = m.__class__.__name__
86 | if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
87 | if init_type == "normal":
88 | init.normal_(m.weight.data, 0.0, init_gain)
89 | elif init_type == "xavier":
90 | init.xavier_normal_(m.weight.data, gain=init_gain)
91 | elif init_type == "kaiming":
92 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
93 | elif init_type == "orthogonal":
94 | init.orthogonal_(m.weight.data, gain=init_gain)
95 | else:
96 | raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
97 | if hasattr(m, "bias") and m.bias is not None:
98 | init.constant_(m.bias.data, 0.0)
99 | elif (
100 | classname.find("BatchNorm2d") != -1
101 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
102 | init.normal_(m.weight.data, 1.0, init_gain)
103 | init.constant_(m.bias.data, 0.0)
104 |
105 | print(f"initialize network with {init_type}")
106 | net.apply(init_func) # apply the initialization function
107 |
108 |
109 | def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=[]):
110 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
111 | Parameters:
112 | net (network) -- the network to be initialized
113 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
114 | gain (float) -- scaling factor for normal, xavier and orthogonal.
115 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
116 |
117 | Return an initialized network.
118 | """
119 | """
120 | if len(gpu_ids) > 0:
121 | assert torch.cuda.is_available()
122 | net.to(gpu_ids[0])
123 | if len(gpu_ids) > 1:
124 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
125 | """
126 | init_weights(net, init_type, init_gain=init_gain)
127 | return net
128 |
129 |
130 | def define_G(arch, init_type="normal", init_gain=0.02, gpu_ids=[]):
131 | if arch == "base_resnet18":
132 | net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False)
133 |
134 | elif arch == "base_transformer_pos_s4":
135 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, with_pos="learned")
136 |
137 | elif arch == "base_transformer_pos_s4_dd8":
138 | net = BASE_Transformer(
139 | input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, with_pos="learned", enc_depth=1, dec_depth=8
140 | )
141 |
142 | elif arch == "base_transformer_pos_s4_dd8_dedim8":
143 | net = BASE_Transformer(
144 | input_nc=3,
145 | output_nc=2,
146 | token_len=4,
147 | resnet_stages_num=4,
148 | with_pos="learned",
149 | enc_depth=1,
150 | dec_depth=8,
151 | decoder_dim_head=8,
152 | )
153 |
154 | else:
155 | raise NotImplementedError(f"Generator model name [{arch}] is not recognized")
156 | return init_net(net, init_type, init_gain, gpu_ids)
157 |
158 |
159 | ###############################################################################
160 | # main Functions
161 | ###############################################################################
162 |
163 |
164 | class ResNet(torch.nn.Module):
165 | def __init__(
166 | self, input_nc, output_nc, resnet_stages_num=5, backbone="resnet18", output_sigmoid=False, if_upsample_2x=True
167 | ):
168 | """
169 | In the constructor we instantiate two nn.Linear modules and assign them as
170 | member variables.
171 | """
172 | super().__init__()
173 | expand = 1
174 | if backbone == "resnet18":
175 | self.resnet = resnet18(pretrained=True, replace_stride_with_dilation=[False, True, True])
176 | elif backbone == "resnet34":
177 | self.resnet = resnet34(pretrained=True, replace_stride_with_dilation=[False, True, True])
178 | elif backbone == "resnet50":
179 | self.resnet = resnet50(pretrained=True, replace_stride_with_dilation=[False, True, True])
180 | expand = 4
181 | else:
182 | raise NotImplementedError
183 | self.relu = nn.ReLU()
184 | self.upsamplex2 = nn.Upsample(scale_factor=2)
185 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode="bilinear")
186 |
187 | self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc)
188 |
189 | self.resnet_stages_num = resnet_stages_num
190 |
191 | self.if_upsample_2x = if_upsample_2x
192 | if self.resnet_stages_num == 5:
193 | layers = 512 * expand
194 | elif self.resnet_stages_num == 4:
195 | layers = 256 * expand
196 | elif self.resnet_stages_num == 3:
197 | layers = 128 * expand
198 | else:
199 | raise NotImplementedError
200 | self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1)
201 |
202 | self.output_sigmoid = output_sigmoid
203 | self.sigmoid = nn.Sigmoid()
204 |
205 | def forward(self, x1, x2):
206 | x1 = self.forward_single(x1)
207 | x2 = self.forward_single(x2)
208 | x = torch.abs(x1 - x2)
209 | if not self.if_upsample_2x:
210 | x = self.upsamplex2(x)
211 | x = self.upsamplex4(x)
212 | x = self.classifier(x)
213 |
214 | if self.output_sigmoid:
215 | x = self.sigmoid(x)
216 | return x
217 |
218 | def forward_single(self, x):
219 | # resnet layers
220 | x = self.resnet.conv1(x)
221 | x = self.resnet.bn1(x)
222 | x = self.resnet.relu(x)
223 | x = self.resnet.maxpool(x)
224 |
225 | x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64
226 | x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128
227 |
228 | if self.resnet_stages_num > 3:
229 | x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256
230 |
231 | if self.resnet_stages_num == 5:
232 | x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512
233 | elif self.resnet_stages_num > 5:
234 | raise NotImplementedError
235 |
236 | if self.if_upsample_2x:
237 | x = self.upsamplex2(x_8)
238 | else:
239 | x = x_8
240 | # output layers
241 | x = self.conv_pred(x)
242 | return x
243 |
244 |
245 | class BASE_Transformer(ResNet):
246 | """
247 |     ResNet backbone (8x downsampling) + BIT transformer on semantic tokens + bitemporal feature differencing + a small CNN classifier
248 | """
249 |
250 | def __init__(
251 | self,
252 | input_nc,
253 | output_nc,
254 | with_pos,
255 | resnet_stages_num=5,
256 | token_len=4,
257 | token_trans=True,
258 | enc_depth=1,
259 | dec_depth=1,
260 | dim_head=64,
261 | decoder_dim_head=64,
262 | tokenizer=True,
263 | if_upsample_2x=True,
264 | pool_mode="max",
265 | pool_size=2,
266 | backbone="resnet18",
267 | decoder_softmax=True,
268 | with_decoder_pos=None,
269 | with_decoder=True,
270 | ):
271 | super().__init__(
272 | input_nc, output_nc, backbone=backbone, resnet_stages_num=resnet_stages_num, if_upsample_2x=if_upsample_2x
273 | )
274 | self.token_len = token_len
275 | self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, padding=0, bias=False)
276 | self.tokenizer = tokenizer
277 | if not self.tokenizer:
278 |             # if the tokenizer is not used, downsample the feature map to a fixed pooled size
279 | self.pooling_size = pool_size
280 | self.pool_mode = pool_mode
281 | self.token_len = self.pooling_size * self.pooling_size
282 |
283 | self.token_trans = token_trans
284 | self.with_decoder = with_decoder
285 | dim = 32
286 | mlp_dim = 2 * dim
287 |
288 | self.with_pos = with_pos
289 | if with_pos == "learned":
290 | self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len * 2, 32))
291 | decoder_pos_size = 256 // 4
292 | self.with_decoder_pos = with_decoder_pos
293 | if self.with_decoder_pos == "learned":
294 | self.pos_embedding_decoder = nn.Parameter(torch.randn(1, 32, decoder_pos_size, decoder_pos_size))
295 | self.enc_depth = enc_depth
296 | self.dec_depth = dec_depth
297 | self.dim_head = dim_head
298 | self.decoder_dim_head = decoder_dim_head
299 | self.transformer = Transformer(
300 | dim=dim, depth=self.enc_depth, heads=8, dim_head=self.dim_head, mlp_dim=mlp_dim, dropout=0
301 | )
302 | self.transformer_decoder = TransformerDecoder(
303 | dim=dim,
304 | depth=self.dec_depth,
305 | heads=8,
306 | dim_head=self.decoder_dim_head,
307 | mlp_dim=mlp_dim,
308 | dropout=0,
309 | softmax=decoder_softmax,
310 | )
311 |
312 | def _forward_semantic_tokens(self, x):
313 | b, c, h, w = x.shape
314 | spatial_attention = self.conv_a(x)
315 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous()
316 | spatial_attention = torch.softmax(spatial_attention, dim=-1)
317 | x = x.view([b, c, -1]).contiguous()
318 | tokens = torch.einsum("bln,bcn->blc", spatial_attention, x)
319 |
320 | return tokens
321 |
322 | def _forward_reshape_tokens(self, x):
323 | # b,c,h,w = x.shape
324 | if self.pool_mode == "max":
325 | x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size])
326 | elif self.pool_mode == "ave":
327 | x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size])
328 | else:
329 | x = x
330 | tokens = rearrange(x, "b c h w -> b (h w) c")
331 | return tokens
332 |
333 | def _forward_transformer(self, x):
334 | if self.with_pos:
335 | x += self.pos_embedding
336 | x = self.transformer(x)
337 | return x
338 |
339 | def _forward_transformer_decoder(self, x, m):
340 | b, c, h, w = x.shape
341 | if self.with_decoder_pos == "fix":
342 | x = x + self.pos_embedding_decoder
343 | elif self.with_decoder_pos == "learned":
344 | x = x + self.pos_embedding_decoder
345 | x = rearrange(x, "b c h w -> b (h w) c")
346 | x = self.transformer_decoder(x, m)
347 | x = rearrange(x, "b (h w) c -> b c h w", h=h)
348 | return x
349 |
350 | def _forward_simple_decoder(self, x, m):
351 | b, c, h, w = x.shape
352 | b, l, c = m.shape
353 | m = m.expand([h, w, b, l, c])
354 | m = rearrange(m, "h w b l c -> l b c h w")
355 | m = m.sum(0)
356 | x = x + m
357 | return x
358 |
359 | def forward(self, x1, x2):
360 | # forward backbone resnet
361 | x1 = self.forward_single(x1)
362 | x2 = self.forward_single(x2)
363 |
364 |         # forward tokenizer
365 | if self.tokenizer:
366 | token1 = self._forward_semantic_tokens(x1)
367 | token2 = self._forward_semantic_tokens(x2)
368 | else:
369 | token1 = self._forward_reshape_tokens(x1)
370 | token2 = self._forward_reshape_tokens(x2)
371 | # forward transformer encoder
372 | if self.token_trans:
373 | self.tokens_ = torch.cat([token1, token2], dim=1)
374 | self.tokens = self._forward_transformer(self.tokens_)
375 | token1, token2 = self.tokens.chunk(2, dim=1)
376 | # forward transformer decoder
377 | if self.with_decoder:
378 | x1 = self._forward_transformer_decoder(x1, token1)
379 | x2 = self._forward_transformer_decoder(x2, token2)
380 | else:
381 | x1 = self._forward_simple_decoder(x1, token1)
382 | x2 = self._forward_simple_decoder(x2, token2)
383 | # feature differencing
384 | x = torch.abs(x1 - x2)
385 | if not self.if_upsample_2x:
386 | x = self.upsamplex2(x)
387 | x = self.upsamplex4(x)
388 | # forward small cnn
389 | x = self.classifier(x)
390 | if self.output_sigmoid:
391 | x = self.sigmoid(x)
392 | return x
393 |
--------------------------------------------------------------------------------
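Usage sketch (not a file from this repository): the snippet below instantiates the BIT model exactly as the "base_transformer_pos_s4" branch of the factory above does and runs one bitemporal forward pass. It assumes the repository root is on PYTHONPATH so that the file above is importable as src.models.bit.networks; note that constructing the model fetches ImageNet weights for the ResNet-18 backbone, because the ResNet wrapper hardcodes pretrained=True.

import torch
from src.models.bit.networks import BASE_Transformer  # assumed import path

# Same arguments as the "base_transformer_pos_s4" branch above.
net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, with_pos="learned")
net.eval()

# Two co-registered 256x256 RGB patches (pre-change and post-change).
t1 = torch.randn(1, 3, 256, 256)
t2 = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    logits = net(t1, t2)  # (1, 2, 256, 256): per-pixel change / no-change logits
change_mask = logits.argmax(dim=1)  # (1, 256, 256) binary change map
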
/src/models/bit/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 |
5 | __all__ = [
6 | "ResNet",
7 | "resnet18",
8 | "resnet34",
9 | "resnet50",
10 | "resnet101",
11 | "resnet152",
12 | "resnext50_32x4d",
13 | "resnext101_32x8d",
14 | "wide_resnet50_2",
15 | "wide_resnet101_2",
16 | ]
17 |
18 |
19 | model_urls = {
20 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
21 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
22 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
23 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
24 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
25 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
26 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
27 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
28 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
29 | }
30 |
31 |
32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
33 | """3x3 convolution with padding"""
34 | return nn.Conv2d(
35 | in_planes,
36 | out_planes,
37 | kernel_size=3,
38 | stride=stride,
39 | padding=dilation,
40 | groups=groups,
41 | bias=False,
42 | dilation=dilation,
43 | )
44 |
45 |
46 | def conv1x1(in_planes, out_planes, stride=1):
47 | """1x1 convolution"""
48 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
49 |
50 |
51 | class BasicBlock(nn.Module):
52 | expansion = 1
53 |
54 | def __init__(
55 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
56 | ):
57 | super().__init__()
58 | if norm_layer is None:
59 | norm_layer = nn.BatchNorm2d
60 | if groups != 1 or base_width != 64:
61 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
62 | if dilation > 1:
63 | dilation = 1
64 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
65 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
66 | self.conv1 = conv3x3(inplanes, planes, stride)
67 | self.bn1 = norm_layer(planes)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.conv2 = conv3x3(planes, planes)
70 | self.bn2 = norm_layer(planes)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | identity = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 |
84 | if self.downsample is not None:
85 | identity = self.downsample(x)
86 |
87 | out += identity
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class Bottleneck(nn.Module):
94 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
95 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
96 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
97 | # This variant is also known as ResNet V1.5 and improves accuracy according to
98 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
99 |
100 | expansion = 4
101 |
102 | def __init__(
103 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None
104 | ):
105 | super().__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | width = int(planes * (base_width / 64.0)) * groups
109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
110 | self.conv1 = conv1x1(inplanes, width)
111 | self.bn1 = norm_layer(width)
112 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
113 | self.bn2 = norm_layer(width)
114 | self.conv3 = conv1x1(width, planes * self.expansion)
115 | self.bn3 = norm_layer(planes * self.expansion)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.downsample = downsample
118 | self.stride = stride
119 |
120 | def forward(self, x):
121 | identity = x
122 |
123 | out = self.conv1(x)
124 | out = self.bn1(out)
125 | out = self.relu(out)
126 |
127 | out = self.conv2(out)
128 | out = self.bn2(out)
129 | out = self.relu(out)
130 |
131 | out = self.conv3(out)
132 | out = self.bn3(out)
133 |
134 | if self.downsample is not None:
135 | identity = self.downsample(x)
136 |
137 | out += identity
138 | out = self.relu(out)
139 |
140 | return out
141 |
142 |
143 | class ResNet(nn.Module):
144 | def __init__(
145 | self,
146 | block,
147 | layers,
148 | num_classes=1000,
149 | zero_init_residual=False,
150 | groups=1,
151 | width_per_group=64,
152 | replace_stride_with_dilation=None,
153 | norm_layer=None,
154 | strides=None,
155 | ):
156 | super().__init__()
157 | if norm_layer is None:
158 | norm_layer = nn.BatchNorm2d
159 | self._norm_layer = norm_layer
160 |
161 | self.strides = strides
162 | if self.strides is None:
163 | self.strides = [2, 2, 2, 2, 2]
164 |
165 | self.inplanes = 64
166 | self.dilation = 1
167 | if replace_stride_with_dilation is None:
168 | # each element in the tuple indicates if we should replace
169 | # the 2x2 stride with a dilated convolution instead
170 | replace_stride_with_dilation = [False, False, False]
171 | if len(replace_stride_with_dilation) != 3:
172 | raise ValueError(
173 | f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}"
174 | )
175 | self.groups = groups
176 | self.base_width = width_per_group
177 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, bias=False)
178 | self.bn1 = norm_layer(self.inplanes)
179 | self.relu = nn.ReLU(inplace=True)
180 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
181 | self.layer1 = self._make_layer(block, 64, layers[0])
182 | self.layer2 = self._make_layer(
183 | block, 128, layers[1], stride=self.strides[2], dilate=replace_stride_with_dilation[0]
184 | )
185 | self.layer3 = self._make_layer(
186 | block, 256, layers[2], stride=self.strides[3], dilate=replace_stride_with_dilation[1]
187 | )
188 | self.layer4 = self._make_layer(
189 | block, 512, layers[3], stride=self.strides[4], dilate=replace_stride_with_dilation[2]
190 | )
191 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
192 | self.fc = nn.Linear(512 * block.expansion, num_classes)
193 |
194 | for m in self.modules():
195 | if isinstance(m, nn.Conv2d):
196 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
197 | elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm):
198 | nn.init.constant_(m.weight, 1)
199 | nn.init.constant_(m.bias, 0)
200 |
201 | # Zero-initialize the last BN in each residual branch,
202 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
203 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
204 | if zero_init_residual:
205 | for m in self.modules():
206 | if isinstance(m, Bottleneck):
207 | nn.init.constant_(m.bn3.weight, 0)
208 | elif isinstance(m, BasicBlock):
209 | nn.init.constant_(m.bn2.weight, 0)
210 |
211 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
212 | norm_layer = self._norm_layer
213 | downsample = None
214 | previous_dilation = self.dilation
215 | if dilate:
216 | self.dilation *= stride
217 | stride = 1
218 | if stride != 1 or self.inplanes != planes * block.expansion:
219 | downsample = nn.Sequential(
220 | conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion)
221 | )
222 |
223 | layers = []
224 | layers.append(
225 | block(
226 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
227 | )
228 | )
229 | self.inplanes = planes * block.expansion
230 | for _ in range(1, blocks):
231 | layers.append(
232 | block(
233 | self.inplanes,
234 | planes,
235 | groups=self.groups,
236 | base_width=self.base_width,
237 | dilation=self.dilation,
238 | norm_layer=norm_layer,
239 | )
240 | )
241 |
242 | return nn.Sequential(*layers)
243 |
244 | def _forward_impl(self, x):
245 | # See note [TorchScript super()]
246 | x = self.conv1(x)
247 | x = self.bn1(x)
248 | x = self.relu(x)
249 | x = self.maxpool(x)
250 |
251 | x = self.layer1(x)
252 | x = self.layer2(x)
253 | x = self.layer3(x)
254 | x = self.layer4(x)
255 |
256 | x = self.avgpool(x)
257 | x = torch.flatten(x, 1)
258 | x = self.fc(x)
259 |
260 | return x
261 |
262 | def forward(self, x):
263 | return self._forward_impl(x)
264 |
265 |
266 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
267 | model = ResNet(block, layers, **kwargs)
268 | if pretrained:
269 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
270 | model.load_state_dict(state_dict)
271 | return model
272 |
273 |
274 | def resnet18(pretrained=False, progress=True, **kwargs):
275 | r"""ResNet-18 model from
276 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
277 |
278 | Args:
279 | pretrained (bool): If True, returns a model pre-trained on ImageNet
280 | progress (bool): If True, displays a progress bar of the download to stderr
281 | """
282 | return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
283 |
284 |
285 | def resnet34(pretrained=False, progress=True, **kwargs):
286 | r"""ResNet-34 model from
287 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
288 |
289 | Args:
290 | pretrained (bool): If True, returns a model pre-trained on ImageNet
291 | progress (bool): If True, displays a progress bar of the download to stderr
292 | """
293 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
294 |
295 |
296 | def resnet50(pretrained=False, progress=True, **kwargs):
297 | r"""ResNet-50 model from
298 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
299 |
300 | Args:
301 | pretrained (bool): If True, returns a model pre-trained on ImageNet
302 | progress (bool): If True, displays a progress bar of the download to stderr
303 | """
304 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
305 |
306 |
307 | def resnet101(pretrained=False, progress=True, **kwargs):
308 | r"""ResNet-101 model from
309 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
310 |
311 | Args:
312 | pretrained (bool): If True, returns a model pre-trained on ImageNet
313 | progress (bool): If True, displays a progress bar of the download to stderr
314 | """
315 | return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
316 |
317 |
318 | def resnet152(pretrained=False, progress=True, **kwargs):
319 | r"""ResNet-152 model from
320 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
321 |
322 | Args:
323 | pretrained (bool): If True, returns a model pre-trained on ImageNet
324 | progress (bool): If True, displays a progress bar of the download to stderr
325 | """
326 | return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
327 |
328 |
329 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
330 | r"""ResNeXt-50 32x4d model from
331 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
332 |
333 | Args:
334 | pretrained (bool): If True, returns a model pre-trained on ImageNet
335 | progress (bool): If True, displays a progress bar of the download to stderr
336 | """
337 | kwargs["groups"] = 32
338 | kwargs["width_per_group"] = 4
339 | return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
340 |
341 |
342 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
343 | r"""ResNeXt-101 32x8d model from
344 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
345 |
346 | Args:
347 | pretrained (bool): If True, returns a model pre-trained on ImageNet
348 | progress (bool): If True, displays a progress bar of the download to stderr
349 | """
350 | kwargs["groups"] = 32
351 | kwargs["width_per_group"] = 8
352 | return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
353 |
354 |
355 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
356 | r"""Wide ResNet-50-2 model from
357 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
358 |
359 | The model is the same as ResNet except for the bottleneck number of channels
360 | which is twice larger in every block. The number of channels in outer 1x1
361 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
362 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
363 |
364 | Args:
365 | pretrained (bool): If True, returns a model pre-trained on ImageNet
366 | progress (bool): If True, displays a progress bar of the download to stderr
367 | """
368 | kwargs["width_per_group"] = 64 * 2
369 | return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
370 |
371 |
372 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
373 | r"""Wide ResNet-101-2 model from
374 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
375 |
376 | The model is the same as ResNet except for the bottleneck number of channels
377 | which is twice larger in every block. The number of channels in outer 1x1
378 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
379 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
380 |
381 | Args:
382 | pretrained (bool): If True, returns a model pre-trained on ImageNet
383 | progress (bool): If True, displays a progress bar of the download to stderr
384 | """
385 | kwargs["width_per_group"] = 64 * 2
386 | return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
387 |
--------------------------------------------------------------------------------
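A small sketch (not a repository file) of how the modified ResNet above behaves with replace_stride_with_dilation=[False, True, True], which is how the BIT backbone instantiates it: layer3 and layer4 keep the 1/8 spatial resolution instead of striding down to 1/16 and 1/32. It assumes the repository root is on PYTHONPATH so the file is importable as src.models.bit.resnet; also note that BasicBlock silently resets dilation > 1 back to 1, so for ResNet-18/34 the flag effectively removes the stride rather than truly dilating the 3x3 convolutions.

import torch
from src.models.bit.resnet import resnet18  # assumed import path

backbone = resnet18(pretrained=False, replace_stride_with_dilation=[False, True, True])
backbone.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))  # (1, 64, 64, 64)  -> 1/4
    f1 = backbone.layer1(x)   # (1,  64, 64, 64)  -> 1/4
    f2 = backbone.layer2(f1)  # (1, 128, 32, 32)  -> 1/8
    f3 = backbone.layer3(f2)  # (1, 256, 32, 32)  -> stays 1/8 (dilation replaces the stride)
    f4 = backbone.layer4(f3)  # (1, 512, 32, 32)  -> stays 1/8
print(f4.shape)  # torch.Size([1, 512, 32, 32])
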
/src/models/changeformer/ChangeFormer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from functools import partial
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional
8 | import torch.nn.functional as F
9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
10 |
11 | from .ChangeFormerBaseNetworks import ConvLayer, ResidualBlock, UpsampleConvLayer
12 |
13 |
14 | class EncoderTransformer(nn.Module):
15 | def __init__(
16 | self,
17 | img_size=256,
18 | patch_size=16,
19 | in_chans=3,
20 | num_classes=2,
21 | embed_dims=[64, 128, 256, 512],
22 | num_heads=[1, 2, 4, 8],
23 | mlp_ratios=[4, 4, 4, 4],
24 | qkv_bias=False,
25 | qk_scale=None,
26 | drop_rate=0.0,
27 | attn_drop_rate=0.0,
28 | drop_path_rate=0.0,
29 | norm_layer=nn.LayerNorm,
30 | depths=[3, 4, 6, 3],
31 | sr_ratios=[8, 4, 2, 1],
32 | ):
33 | super().__init__()
34 | self.num_classes = num_classes
35 | self.depths = depths
36 |
37 | # patch embedding definitions
38 | self.patch_embed1 = OverlapPatchEmbed(
39 | img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0]
40 | )
41 | self.patch_embed2 = OverlapPatchEmbed(
42 | img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
43 | )
44 | self.patch_embed3 = OverlapPatchEmbed(
45 | img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
46 | )
47 | self.patch_embed4 = OverlapPatchEmbed(
48 | img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]
49 | )
50 |
51 | # main encoder
52 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
53 | cur = 0
54 | self.block1 = nn.ModuleList(
55 | [
56 | Block(
57 | dim=embed_dims[0],
58 | num_heads=num_heads[0],
59 | mlp_ratio=mlp_ratios[0],
60 | qkv_bias=qkv_bias,
61 | qk_scale=qk_scale,
62 | drop=drop_rate,
63 | attn_drop=attn_drop_rate,
64 | drop_path=dpr[cur + i],
65 | norm_layer=norm_layer,
66 | sr_ratio=sr_ratios[0],
67 | )
68 | for i in range(depths[0])
69 | ]
70 | )
71 | self.norm1 = norm_layer(embed_dims[0])
72 | # intra-patch encoder
73 | self.patch_block1 = nn.ModuleList(
74 | [
75 | Block(
76 | dim=embed_dims[1],
77 | num_heads=num_heads[0],
78 | mlp_ratio=mlp_ratios[0],
79 | qkv_bias=qkv_bias,
80 | qk_scale=qk_scale,
81 | drop=drop_rate,
82 | attn_drop=attn_drop_rate,
83 | drop_path=dpr[cur + i],
84 | norm_layer=norm_layer,
85 | sr_ratio=sr_ratios[0],
86 | )
87 | for i in range(1)
88 | ]
89 | )
90 | self.pnorm1 = norm_layer(embed_dims[1])
91 | # main encoder
92 | cur += depths[0]
93 | self.block2 = nn.ModuleList(
94 | [
95 | Block(
96 | dim=embed_dims[1],
97 | num_heads=num_heads[1],
98 | mlp_ratio=mlp_ratios[1],
99 | qkv_bias=qkv_bias,
100 | qk_scale=qk_scale,
101 | drop=drop_rate,
102 | attn_drop=attn_drop_rate,
103 | drop_path=dpr[cur + i],
104 | norm_layer=norm_layer,
105 | sr_ratio=sr_ratios[1],
106 | )
107 | for i in range(depths[1])
108 | ]
109 | )
110 | self.norm2 = norm_layer(embed_dims[1])
111 | # intra-patch encoder
112 | self.patch_block2 = nn.ModuleList(
113 | [
114 | Block(
115 | dim=embed_dims[2],
116 | num_heads=num_heads[1],
117 | mlp_ratio=mlp_ratios[1],
118 | qkv_bias=qkv_bias,
119 | qk_scale=qk_scale,
120 | drop=drop_rate,
121 | attn_drop=attn_drop_rate,
122 | drop_path=dpr[cur + i],
123 | norm_layer=norm_layer,
124 | sr_ratio=sr_ratios[1],
125 | )
126 | for i in range(1)
127 | ]
128 | )
129 | self.pnorm2 = norm_layer(embed_dims[2])
130 | # main encoder
131 | cur += depths[1]
132 | self.block3 = nn.ModuleList(
133 | [
134 | Block(
135 | dim=embed_dims[2],
136 | num_heads=num_heads[2],
137 | mlp_ratio=mlp_ratios[2],
138 | qkv_bias=qkv_bias,
139 | qk_scale=qk_scale,
140 | drop=drop_rate,
141 | attn_drop=attn_drop_rate,
142 | drop_path=dpr[cur + i],
143 | norm_layer=norm_layer,
144 | sr_ratio=sr_ratios[2],
145 | )
146 | for i in range(depths[2])
147 | ]
148 | )
149 | self.norm3 = norm_layer(embed_dims[2])
150 | # intra-patch encoder
151 | self.patch_block3 = nn.ModuleList(
152 | [
153 | Block(
154 | dim=embed_dims[3],
155 | num_heads=num_heads[1],
156 | mlp_ratio=mlp_ratios[2],
157 | qkv_bias=qkv_bias,
158 | qk_scale=qk_scale,
159 | drop=drop_rate,
160 | attn_drop=attn_drop_rate,
161 | drop_path=dpr[cur + i],
162 | norm_layer=norm_layer,
163 | sr_ratio=sr_ratios[2],
164 | )
165 | for i in range(1)
166 | ]
167 | )
168 | self.pnorm3 = norm_layer(embed_dims[3])
169 | # main encoder
170 | cur += depths[2]
171 | self.block4 = nn.ModuleList(
172 | [
173 | Block(
174 | dim=embed_dims[3],
175 | num_heads=num_heads[3],
176 | mlp_ratio=mlp_ratios[3],
177 | qkv_bias=qkv_bias,
178 | qk_scale=qk_scale,
179 | drop=drop_rate,
180 | attn_drop=attn_drop_rate,
181 | drop_path=dpr[cur + i],
182 | norm_layer=norm_layer,
183 | sr_ratio=sr_ratios[3],
184 | )
185 | for i in range(depths[3])
186 | ]
187 | )
188 | self.norm4 = norm_layer(embed_dims[3])
189 |
190 | self.apply(self._init_weights)
191 |
192 | def _init_weights(self, m):
193 | if isinstance(m, nn.Linear):
194 | trunc_normal_(m.weight, std=0.02)
195 | if isinstance(m, nn.Linear) and m.bias is not None:
196 | nn.init.constant_(m.bias, 0)
197 | elif isinstance(m, nn.LayerNorm):
198 | nn.init.constant_(m.bias, 0)
199 | nn.init.constant_(m.weight, 1.0)
200 | elif isinstance(m, nn.Conv2d):
201 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
202 | fan_out //= m.groups
203 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
204 | if m.bias is not None:
205 | m.bias.data.zero_()
206 |
207 | def reset_drop_path(self, drop_path_rate):
208 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
209 | cur = 0
210 | for i in range(self.depths[0]):
211 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
212 |
213 | cur += self.depths[0]
214 | for i in range(self.depths[1]):
215 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
216 |
217 | cur += self.depths[1]
218 | for i in range(self.depths[2]):
219 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
220 |
221 | cur += self.depths[2]
222 | for i in range(self.depths[3]):
223 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
224 |
225 | def forward_features(self, x):
226 | B = x.shape[0]
227 | outs = []
228 |         embed_dims = [64, 128, 320, 512]  # NOTE: hardcoded to the Tenc configuration; the embed_dims passed to __init__ are not used here
229 | # stage 1
230 | x1, H1, W1 = self.patch_embed1(x)
231 |
232 | for _i, blk in enumerate(self.block1):
233 | x1 = blk(x1, H1, W1)
234 | x1 = self.norm1(x1)
235 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
236 |
237 | outs.append(x1)
238 |
239 | # stage 2
240 | x1, H1, W1 = self.patch_embed2(x1)
241 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[1], H1, W1)
242 |
243 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1)
244 |
245 | for _i, blk in enumerate(self.block2):
246 | x1 = blk(x1, H1, W1)
247 | x1 = self.norm2(x1)
248 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
249 | outs.append(x1)
250 |
251 | # stage 3
252 | x1, H1, W1 = self.patch_embed3(x1)
253 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[2], H1, W1)
254 |
255 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1)
256 |
257 | for _i, blk in enumerate(self.block3):
258 | x1 = blk(x1, H1, W1)
259 | x1 = self.norm3(x1)
260 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
261 | outs.append(x1)
262 |
263 | # stage 4
264 | x1, H1, W1 = self.patch_embed4(x1)
265 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[3], H1, W1) # +x2
266 |
267 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1)
268 |
269 | for _i, blk in enumerate(self.block4):
270 | x1 = blk(x1, H1, W1)
271 | x1 = self.norm4(x1)
272 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
273 | outs.append(x1)
274 |
275 | return outs
276 |
277 | def forward(self, x):
278 | x = self.forward_features(x)
279 |
280 | return x
281 |
282 |
283 | class OverlapPatchEmbed(nn.Module):
284 | """Image to Patch Embedding"""
285 |
286 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
287 | super().__init__()
288 | img_size = to_2tuple(img_size)
289 | patch_size = to_2tuple(patch_size)
290 |
291 | self.img_size = img_size
292 | self.patch_size = patch_size
293 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
294 | self.num_patches = self.H * self.W
295 | self.proj = nn.Conv2d(
296 | in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)
297 | )
298 | self.norm = nn.LayerNorm(embed_dim)
299 |
300 | self.apply(self._init_weights)
301 |
302 | def _init_weights(self, m):
303 | if isinstance(m, nn.Linear):
304 | trunc_normal_(m.weight, std=0.02)
305 | if isinstance(m, nn.Linear) and m.bias is not None:
306 | nn.init.constant_(m.bias, 0)
307 | elif isinstance(m, nn.LayerNorm):
308 | nn.init.constant_(m.bias, 0)
309 | nn.init.constant_(m.weight, 1.0)
310 | elif isinstance(m, nn.Conv2d):
311 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
312 | fan_out //= m.groups
313 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
314 | if m.bias is not None:
315 | m.bias.data.zero_()
316 |
317 | def forward(self, x):
318 | # pdb.set_trace()
319 | x = self.proj(x)
320 | _, _, H, W = x.shape
321 | x = x.flatten(2).transpose(1, 2)
322 | x = self.norm(x)
323 |
324 | return x, H, W
325 |
326 |
327 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=True):
328 | if warning:
329 | if size is not None and align_corners:
330 | input_h, input_w = tuple(int(x) for x in input.shape[2:])
331 | output_h, output_w = tuple(int(x) for x in size)
332 |             if output_h > input_h or output_w > input_w:
333 | if (
334 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
335 | and (output_h - 1) % (input_h - 1)
336 | and (output_w - 1) % (input_w - 1)
337 | ):
338 | warnings.warn(
339 | f"When align_corners={align_corners}, "
340 |                         "the output would be more aligned if "
341 | f"input size {(input_h, input_w)} is `x+1` and "
342 | f"out size {(output_h, output_w)} is `nx+1`",
343 | stacklevel=2,
344 | )
345 | return F.interpolate(input, size, scale_factor, mode, align_corners)
346 |
347 |
348 | class Mlp(nn.Module):
349 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
350 | super().__init__()
351 | out_features = out_features or in_features
352 | hidden_features = hidden_features or in_features
353 | self.fc1 = nn.Linear(in_features, hidden_features)
354 | self.dwconv = DWConv(hidden_features)
355 | self.act = act_layer()
356 | self.fc2 = nn.Linear(hidden_features, out_features)
357 | self.drop = nn.Dropout(drop)
358 |
359 | self.apply(self._init_weights)
360 |
361 | def _init_weights(self, m):
362 | if isinstance(m, nn.Linear):
363 | trunc_normal_(m.weight, std=0.02)
364 | if isinstance(m, nn.Linear) and m.bias is not None:
365 | nn.init.constant_(m.bias, 0)
366 | elif isinstance(m, nn.LayerNorm):
367 | nn.init.constant_(m.bias, 0)
368 | nn.init.constant_(m.weight, 1.0)
369 | elif isinstance(m, nn.Conv2d):
370 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
371 | fan_out //= m.groups
372 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
373 | if m.bias is not None:
374 | m.bias.data.zero_()
375 |
376 | def forward(self, x, H, W):
377 | x = self.fc1(x)
378 | x = self.dwconv(x, H, W)
379 | x = self.act(x)
380 | x = self.drop(x)
381 | x = self.fc2(x)
382 | x = self.drop(x)
383 | return x
384 |
385 |
386 | class Attention(nn.Module):
387 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1):
388 | super().__init__()
389 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
390 |
391 | self.dim = dim
392 | self.num_heads = num_heads
393 | head_dim = dim // num_heads
394 | self.scale = qk_scale or head_dim**-0.5
395 |
396 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
397 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
398 | self.attn_drop = nn.Dropout(attn_drop)
399 | self.proj = nn.Linear(dim, dim)
400 | self.proj_drop = nn.Dropout(proj_drop)
401 |
402 | self.sr_ratio = sr_ratio
403 | if sr_ratio > 1:
404 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
405 | self.norm = nn.LayerNorm(dim)
406 |
407 | self.apply(self._init_weights)
408 |
409 | def _init_weights(self, m):
410 | if isinstance(m, nn.Linear):
411 | trunc_normal_(m.weight, std=0.02)
412 | if isinstance(m, nn.Linear) and m.bias is not None:
413 | nn.init.constant_(m.bias, 0)
414 | elif isinstance(m, nn.LayerNorm):
415 | nn.init.constant_(m.bias, 0)
416 | nn.init.constant_(m.weight, 1.0)
417 | elif isinstance(m, nn.Conv2d):
418 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
419 | fan_out //= m.groups
420 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
421 | if m.bias is not None:
422 | m.bias.data.zero_()
423 |
424 | def forward(self, x, H, W):
425 | B, N, C = x.shape
426 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
427 |
428 | if self.sr_ratio > 1:
429 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
430 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
431 | x_ = self.norm(x_)
432 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
433 | else:
434 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
435 | k, v = kv[0], kv[1]
436 |
437 | attn = (q @ k.transpose(-2, -1)) * self.scale
438 | attn = attn.softmax(dim=-1)
439 | attn = self.attn_drop(attn)
440 |
441 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
442 | x = self.proj(x)
443 | x = self.proj_drop(x)
444 |
445 | return x
446 |
447 |
448 | class Attention_dec(nn.Module):
449 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1):
450 | super().__init__()
451 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
452 |
453 | self.dim = dim
454 | self.num_heads = num_heads
455 | head_dim = dim // num_heads
456 | self.scale = qk_scale or head_dim**-0.5
457 |
458 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
459 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
460 | self.attn_drop = nn.Dropout(attn_drop)
461 | self.proj = nn.Linear(dim, dim)
462 | self.proj_drop = nn.Dropout(proj_drop)
463 |
464 | self.task_query = nn.Parameter(torch.randn(1, 48, dim))
465 | self.sr_ratio = sr_ratio
466 | if sr_ratio > 1:
467 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
468 | self.norm = nn.LayerNorm(dim)
469 |
470 | self.apply(self._init_weights)
471 |
472 | def _init_weights(self, m):
473 | if isinstance(m, nn.Linear):
474 | trunc_normal_(m.weight, std=0.02)
475 | if isinstance(m, nn.Linear) and m.bias is not None:
476 | nn.init.constant_(m.bias, 0)
477 | elif isinstance(m, nn.LayerNorm):
478 | nn.init.constant_(m.bias, 0)
479 | nn.init.constant_(m.weight, 1.0)
480 | elif isinstance(m, nn.Conv2d):
481 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
482 | fan_out //= m.groups
483 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
484 | if m.bias is not None:
485 | m.bias.data.zero_()
486 |
487 | def forward(self, x, H, W):
488 | B, N, C = x.shape
489 | task_q = self.task_query
490 |
491 |         # The task query has a fixed number of tokens, so we repeat it along the batch dimension to cover the whole batch
492 | if B > 1:
493 | task_q = task_q.unsqueeze(0).repeat(B, 1, 1, 1)
494 | task_q = task_q.squeeze(1)
495 |
496 | q = self.q(task_q).reshape(B, task_q.shape[1], self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
497 |
498 | if self.sr_ratio > 1:
499 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
500 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
501 | x_ = self.norm(x_)
502 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
503 | else:
504 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
505 | k, v = kv[0], kv[1]
506 | q = torch.nn.functional.interpolate(q, size=(v.shape[2], v.shape[3]))
507 | attn = (q @ k.transpose(-2, -1)) * self.scale
508 | attn = attn.softmax(dim=-1)
509 | attn = self.attn_drop(attn)
510 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
511 | x = self.proj(x)
512 | x = self.proj_drop(x)
513 |
514 | return x
515 |
516 |
517 | class Block_dec(nn.Module):
518 | def __init__(
519 | self,
520 | dim,
521 | num_heads,
522 | mlp_ratio=4.0,
523 | qkv_bias=False,
524 | qk_scale=None,
525 | drop=0.0,
526 | attn_drop=0.0,
527 | drop_path=0.0,
528 | act_layer=nn.GELU,
529 | norm_layer=nn.LayerNorm,
530 | sr_ratio=1,
531 | ):
532 | super().__init__()
533 | self.norm1 = norm_layer(dim)
534 | self.attn = Attention_dec(
535 | dim,
536 | num_heads=num_heads,
537 | qkv_bias=qkv_bias,
538 | qk_scale=qk_scale,
539 | attn_drop=attn_drop,
540 | proj_drop=drop,
541 | sr_ratio=sr_ratio,
542 | )
543 |
544 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
545 | self.norm2 = norm_layer(dim)
546 | mlp_hidden_dim = int(dim * mlp_ratio)
547 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
548 |
549 | self.apply(self._init_weights)
550 |
551 | def _init_weights(self, m):
552 | if isinstance(m, nn.Linear):
553 | trunc_normal_(m.weight, std=0.02)
554 | if isinstance(m, nn.Linear) and m.bias is not None:
555 | nn.init.constant_(m.bias, 0)
556 | elif isinstance(m, nn.LayerNorm):
557 | nn.init.constant_(m.bias, 0)
558 | nn.init.constant_(m.weight, 1.0)
559 | elif isinstance(m, nn.Conv2d):
560 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
561 | fan_out //= m.groups
562 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
563 | if m.bias is not None:
564 | m.bias.data.zero_()
565 |
566 | def forward(self, x, H, W):
567 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
568 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
569 |
570 | return x
571 |
572 |
573 | class Block(nn.Module):
574 | def __init__(
575 | self,
576 | dim,
577 | num_heads,
578 | mlp_ratio=4.0,
579 | qkv_bias=False,
580 | qk_scale=None,
581 | drop=0.0,
582 | attn_drop=0.0,
583 | drop_path=0.0,
584 | act_layer=nn.GELU,
585 | norm_layer=nn.LayerNorm,
586 | sr_ratio=1,
587 | ):
588 | super().__init__()
589 | self.norm1 = norm_layer(dim)
590 | self.attn = Attention(
591 | dim,
592 | num_heads=num_heads,
593 | qkv_bias=qkv_bias,
594 | qk_scale=qk_scale,
595 | attn_drop=attn_drop,
596 | proj_drop=drop,
597 | sr_ratio=sr_ratio,
598 | )
599 |
600 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
601 | self.norm2 = norm_layer(dim)
602 | mlp_hidden_dim = int(dim * mlp_ratio)
603 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
604 |
605 | self.apply(self._init_weights)
606 |
607 | def _init_weights(self, m):
608 | if isinstance(m, nn.Linear):
609 | trunc_normal_(m.weight, std=0.02)
610 | if isinstance(m, nn.Linear) and m.bias is not None:
611 | nn.init.constant_(m.bias, 0)
612 | elif isinstance(m, nn.LayerNorm):
613 | nn.init.constant_(m.bias, 0)
614 | nn.init.constant_(m.weight, 1.0)
615 | elif isinstance(m, nn.Conv2d):
616 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
617 | fan_out //= m.groups
618 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
619 | if m.bias is not None:
620 | m.bias.data.zero_()
621 |
622 | def forward(self, x, H, W):
623 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
624 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
625 | return x
626 |
627 |
628 | class DWConv(nn.Module):
629 | def __init__(self, dim=768):
630 | super().__init__()
631 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
632 |
633 | def forward(self, x, H, W):
634 | B, N, C = x.shape
635 | x = x.transpose(1, 2).view(B, C, H, W)
636 | x = self.dwconv(x)
637 | x = x.flatten(2).transpose(1, 2)
638 |
639 | return x
640 |
641 |
642 | class Tenc(EncoderTransformer):
643 | def __init__(self, **kwargs):
644 | super().__init__(
645 | patch_size=16,
646 | embed_dims=[64, 128, 320, 512],
647 | num_heads=[1, 2, 4, 8],
648 | mlp_ratios=[4, 4, 4, 4],
649 | qkv_bias=True,
650 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
651 | depths=[3, 4, 6, 3],
652 | sr_ratios=[8, 4, 2, 1],
653 | drop_rate=0.0,
654 | drop_path_rate=0.1,
655 | )
656 |
657 |
658 | class convprojection(nn.Module):
659 | def __init__(self, path=None, **kwargs):
660 | super().__init__()
661 |
662 | self.convd32x = UpsampleConvLayer(512, 512, kernel_size=4, stride=2)
663 | self.convd16x = UpsampleConvLayer(512, 320, kernel_size=4, stride=2)
664 | self.dense_4 = nn.Sequential(ResidualBlock(320))
665 | self.convd8x = UpsampleConvLayer(320, 128, kernel_size=4, stride=2)
666 | self.dense_3 = nn.Sequential(ResidualBlock(128))
667 | self.convd4x = UpsampleConvLayer(128, 64, kernel_size=4, stride=2)
668 | self.dense_2 = nn.Sequential(ResidualBlock(64))
669 | self.convd2x = UpsampleConvLayer(64, 16, kernel_size=4, stride=2)
670 | self.dense_1 = nn.Sequential(ResidualBlock(16))
671 | self.convd1x = UpsampleConvLayer(16, 8, kernel_size=4, stride=2)
672 | self.conv_output = ConvLayer(8, 2, kernel_size=3, stride=1, padding=1)
673 |
674 | self.active = nn.Tanh()
675 |
676 | def forward(self, x1, x2):
677 | res32x = self.convd32x(x2[0])
678 |
679 | if x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]:
680 | p2d = (0, -1, 0, -1)
681 | res32x = F.pad(res32x, p2d, "constant", 0)
682 |
683 | elif x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] == res32x.shape[2]:
684 | p2d = (0, -1, 0, 0)
685 | res32x = F.pad(res32x, p2d, "constant", 0)
686 | elif x1[3].shape[3] == res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]:
687 | p2d = (0, 0, 0, -1)
688 | res32x = F.pad(res32x, p2d, "constant", 0)
689 |
690 | res16x = res32x + x1[3]
691 | res16x = self.convd16x(res16x)
692 |
693 | if x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]:
694 | p2d = (0, -1, 0, -1)
695 | res16x = F.pad(res16x, p2d, "constant", 0)
696 | elif x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] == res16x.shape[2]:
697 | p2d = (0, -1, 0, 0)
698 | res16x = F.pad(res16x, p2d, "constant", 0)
699 | elif x1[2].shape[3] == res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]:
700 | p2d = (0, 0, 0, -1)
701 | res16x = F.pad(res16x, p2d, "constant", 0)
702 |
703 | res8x = self.dense_4(res16x) + x1[2]
704 | res8x = self.convd8x(res8x)
705 | res4x = self.dense_3(res8x) + x1[1]
706 | res4x = self.convd4x(res4x)
707 | res2x = self.dense_2(res4x) + x1[0]
708 | res2x = self.convd2x(res2x)
709 | x = res2x
710 | x = self.dense_1(x)
711 | x = self.convd1x(x)
712 |
713 | return x
714 |
715 |
716 | class convprojection_base(nn.Module):
717 | def __init__(self, path=None, **kwargs):
718 | super().__init__()
719 |
720 | # self.convd32x = UpsampleConvLayer(512, 512, kernel_size=4, stride=2)
721 | self.convd16x = UpsampleConvLayer(512, 320, kernel_size=4, stride=2)
722 | self.dense_4 = nn.Sequential(ResidualBlock(320))
723 | self.convd8x = UpsampleConvLayer(320, 128, kernel_size=4, stride=2)
724 | self.dense_3 = nn.Sequential(ResidualBlock(128))
725 | self.convd4x = UpsampleConvLayer(128, 64, kernel_size=4, stride=2)
726 | self.dense_2 = nn.Sequential(ResidualBlock(64))
727 | self.convd2x = UpsampleConvLayer(64, 16, kernel_size=4, stride=2)
728 | self.dense_1 = nn.Sequential(ResidualBlock(16))
729 | self.convd1x = UpsampleConvLayer(16, 8, kernel_size=4, stride=2)
730 |
731 | def forward(self, x1):
732 | # if x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]:
733 | # p2d = (0,-1,0,-1)
734 | # res32x = F.pad(res32x,p2d,"constant",0)
735 |
736 | # elif x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] == res32x.shape[2]:
737 | # p2d = (0,-1,0,0)
738 | # res32x = F.pad(res32x,p2d,"constant",0)
739 | # elif x1[3].shape[3] == res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]:
740 | # p2d = (0,0,0,-1)
741 | # res32x = F.pad(res32x,p2d,"constant",0)
742 |
743 | # res16x = res32x + x1[3]
744 | res16x = self.convd16x(x1[3])
745 |
746 | if x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]:
747 | p2d = (0, -1, 0, -1)
748 | res16x = F.pad(res16x, p2d, "constant", 0)
749 | elif x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] == res16x.shape[2]:
750 | p2d = (0, -1, 0, 0)
751 | res16x = F.pad(res16x, p2d, "constant", 0)
752 | elif x1[2].shape[3] == res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]:
753 | p2d = (0, 0, 0, -1)
754 | res16x = F.pad(res16x, p2d, "constant", 0)
755 |
756 | res8x = self.dense_4(res16x) + x1[2]
757 | res8x = self.convd8x(res8x)
758 | res4x = self.dense_3(res8x) + x1[1]
759 | res4x = self.convd4x(res4x)
760 | res2x = self.dense_2(res4x) + x1[0]
761 | res2x = self.convd2x(res2x)
762 | x = res2x
763 | x = self.dense_1(x)
764 | x = self.convd1x(x)
765 | return x
766 |
767 |
768 | ### This is the basic ChangeFormer module
769 | class ChangeFormerV1(nn.Module):
770 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False):
771 | super().__init__()
772 |
773 | self.Tenc = Tenc()
774 |
775 | self.convproj = convprojection_base()
776 |
777 | self.change_probability = ConvLayer(8, output_nc, kernel_size=3, stride=1, padding=1)
778 |
779 | self.output_softmax = decoder_softmax
780 | self.active = torch.nn.Softmax(dim=1)
781 |
782 | def forward(self, x1, x2):
783 | fx1 = self.Tenc(x1)
784 | fx2 = self.Tenc(x2)
785 |
786 | DI = []
787 | for i in range(0, 4):
788 | DI.append(torch.abs(fx1[i] - fx2[i]))
789 |
790 | cp = self.convproj(DI)
791 |
792 | cp = self.change_probability(cp)
793 |
794 | if self.output_softmax:
795 | cp = self.active(cp)
796 |
797 | return cp
798 |
799 |
800 | # Transformer Decoder
801 | class MLP(nn.Module):
802 | """
803 | Linear Embedding
804 | """
805 |
806 | def __init__(self, input_dim=2048, embed_dim=768):
807 | super().__init__()
808 | self.proj = nn.Linear(input_dim, embed_dim)
809 |
810 | def forward(self, x):
811 | x = x.flatten(2).transpose(1, 2)
812 | x = self.proj(x)
813 | return x
814 |
815 |
816 | class TDec(nn.Module):
817 | """
818 | Transformer Decoder
819 | """
820 |
821 | def __init__(
822 | self,
823 | input_transform="multiple_select",
824 | in_index=[0, 1, 2, 3],
825 | align_corners=True,
826 | in_channels=[64, 128, 256, 512],
827 | embedding_dim=256,
828 | output_nc=2,
829 | decoder_softmax=False,
830 | feature_strides=[4, 8, 16, 32],
831 | ):
832 | super().__init__()
833 | assert len(feature_strides) == len(in_channels)
834 | assert min(feature_strides) == feature_strides[0]
835 | self.feature_strides = feature_strides
836 |
837 | # input transforms
838 | self.input_transform = input_transform
839 | self.in_index = in_index
840 | self.align_corners = align_corners
841 |
842 | # MLP
843 | self.in_channels = in_channels
844 | self.embedding_dim = embedding_dim
845 |
846 | # Final prediction
847 | self.output_nc = output_nc
848 |
849 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels
850 |
851 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
852 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
853 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
854 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
855 |
856 | self.linear_fuse = nn.Conv2d(in_channels=self.embedding_dim * 4, out_channels=self.embedding_dim, kernel_size=1)
857 |
858 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
859 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
860 | self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
861 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
862 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
863 |
864 | # Final prediction
865 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)
866 | self.output_softmax = decoder_softmax
867 | self.active = nn.Softmax(dim=1)
868 |
869 | def _transform_inputs(self, inputs):
870 | """Transform inputs for decoder.
871 | Args:
872 | inputs (list[Tensor]): List of multi-level img features.
873 | Returns:
874 | Tensor: The transformed inputs
875 | """
876 |
877 | if self.input_transform == "resize_concat":
878 | inputs = [inputs[i] for i in self.in_index]
879 | upsampled_inputs = [
880 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
881 | for x in inputs
882 | ]
883 | inputs = torch.cat(upsampled_inputs, dim=1)
884 | elif self.input_transform == "multiple_select":
885 | inputs = [inputs[i] for i in self.in_index]
886 | else:
887 | inputs = inputs[self.in_index]
888 |
889 | return inputs
890 |
891 | def forward(self, inputs):
892 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
893 | c1, c2, c3, c4 = x
894 |
895 | ############## MLP decoder on C1-C4 ###########
896 | n, _, h, w = c4.shape
897 |
898 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
899 | _c4 = resize(_c4, size=c1.size()[2:], mode="bilinear", align_corners=False)
900 |
901 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
902 | _c3 = resize(_c3, size=c1.size()[2:], mode="bilinear", align_corners=False)
903 |
904 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
905 | _c2 = resize(_c2, size=c1.size()[2:], mode="bilinear", align_corners=False)
906 |
907 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
908 |
909 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
910 |
911 | x = self.convd2x(_c)
912 | x = self.dense_2x(x)
913 | x = self.convd1x(x)
914 | x = self.dense_1x(x)
915 |
916 | cp = self.change_probability(x)
917 | if self.output_softmax:
918 | cp = self.active(cp)
919 |
920 | return cp
921 |
922 |
923 | class TDecV2(nn.Module):
924 | """
925 | Transformer Decoder
926 | """
927 |
928 | def __init__(
929 | self,
930 | input_transform="multiple_select",
931 | in_index=[0, 1, 2, 3],
932 | align_corners=True,
933 | in_channels=[64, 128, 256, 512],
934 | embedding_dim=256,
935 | output_nc=2,
936 | decoder_softmax=False,
937 | feature_strides=[4, 8, 16, 32],
938 | ):
939 | super().__init__()
940 | assert len(feature_strides) == len(in_channels)
941 | assert min(feature_strides) == feature_strides[0]
942 | self.feature_strides = feature_strides
943 |
944 | # input transforms
945 | self.input_transform = input_transform
946 | self.in_index = in_index
947 | self.align_corners = align_corners
948 |
949 | # MLP
950 | self.in_channels = in_channels
951 | self.embedding_dim = embedding_dim
952 |
953 | # Final prediction
954 | self.output_nc = output_nc
955 |
956 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels
957 |
958 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
959 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
960 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
961 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
962 |
963 | self.linear_fuse = nn.Conv2d(in_channels=self.embedding_dim * 4, out_channels=self.embedding_dim, kernel_size=1)
964 |
965 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
966 | # self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
967 | # self.dense_2x = nn.Sequential( ResidualBlock(self.embedding_dim))
968 | # self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
969 | # self.dense_1x = nn.Sequential( ResidualBlock(self.embedding_dim))
970 |
971 |         # Pixel Shuffle
972 | self.pix_shuffle_conv = nn.Conv2d(
973 | in_channels=self.embedding_dim, out_channels=16 * output_nc, kernel_size=3, stride=1, padding=1
974 | )
975 | self.relu = nn.ReLU()
976 | self.pix_shuffle = nn.PixelShuffle(4)
977 |
978 | # Final prediction
979 | # self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)
980 |
981 | # Final activation
982 | self.output_softmax = decoder_softmax
983 | self.active = nn.Softmax(dim=1)
984 |
985 | def _transform_inputs(self, inputs):
986 | """Transform inputs for decoder.
987 | Args:
988 | inputs (list[Tensor]): List of multi-level img features.
989 | Returns:
990 | Tensor: The transformed inputs
991 | """
992 |
993 | if self.input_transform == "resize_concat":
994 | inputs = [inputs[i] for i in self.in_index]
995 | upsampled_inputs = [
996 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
997 | for x in inputs
998 | ]
999 | inputs = torch.cat(upsampled_inputs, dim=1)
1000 | elif self.input_transform == "multiple_select":
1001 | inputs = [inputs[i] for i in self.in_index]
1002 | else:
1003 | inputs = inputs[self.in_index]
1004 |
1005 | return inputs
1006 |
1007 | def forward(self, inputs1, inputs2):
1008 | x_1 = self._transform_inputs(inputs1) # len=4, 1/4,1/8,1/16,1/32
1009 | x_2 = self._transform_inputs(inputs2) # len=4, 1/4,1/8,1/16,1/32
1010 |
1011 | c1_1, c2_1, c3_1, c4_1 = x_1
1012 | c1_2, c2_2, c3_2, c4_2 = x_2
1013 |
1014 | ############## MLP decoder on C1-C4 ###########
1015 | n, _, h, w = c4_1.shape
1016 |
1017 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3])
1018 | _c4_1 = resize(_c4_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False)
1019 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3])
1020 | _c4_2 = resize(_c4_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1021 |
1022 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3])
1023 | _c3_1 = resize(_c3_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False)
1024 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3])
1025 | _c3_2 = resize(_c3_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1026 |
1027 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3])
1028 | _c2_1 = resize(_c2_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False)
1029 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3])
1030 | _c2_2 = resize(_c2_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1031 |
1032 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3])
1033 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3])
1034 |
1035 | _c = self.linear_fuse(
1036 | torch.cat(
1037 | [
1038 | torch.abs(_c4_1 - _c4_2),
1039 | torch.abs(_c3_1 - _c3_2),
1040 | torch.abs(_c2_1 - _c2_2),
1041 | torch.abs(_c1_1 - _c1_2),
1042 | ],
1043 | dim=1,
1044 | )
1045 | )
1046 |
1047 | # x = self.dense_2x(x)
1048 | # x = self.convd1x(x)
1049 | # x = self.dense_1x(x)
1050 |
1051 | # cp = self.change_probability(x)
1052 |
1053 | # cp = F.interpolate(_c, scale_factor=4, mode="nearest")
1054 | x = self.relu(self.pix_shuffle_conv(_c))
1055 | cp = self.pix_shuffle(x)
1056 |
1057 | if self.output_softmax:
1058 | cp = self.active(cp)
1059 |
1060 | return cp
1061 |
1062 |
1063 | # ChangeFormerV2:
1064 | # Transformer encoder to extract features
1065 | # Multi-scale features are differenced, then the difference features are passed through the Transformer decoder
1066 | class ChangeFormerV2(nn.Module):
1067 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False):
1068 | super().__init__()
1069 | # Transformer Encoder
1070 | self.Tenc = Tenc()
1071 |
1072 | # Transformer Decoder
1073 | self.TDec = TDec(
1074 | input_transform="multiple_select",
1075 | in_index=[0, 1, 2, 3],
1076 | align_corners=True,
1077 | in_channels=[64, 128, 320, 512],
1078 | embedding_dim=32,
1079 | output_nc=output_nc,
1080 | decoder_softmax=decoder_softmax,
1081 | feature_strides=[4, 8, 16, 32],
1082 | )
1083 | # Final activation
1084 | self.decoder_softmax = decoder_softmax
1085 | self.output_activation = torch.nn.Softmax(dim=1)
1086 |
1087 | def forward(self, x1, x2):
1088 | fx1 = self.Tenc(x1)
1089 | fx2 = self.Tenc(x2)
1090 |
1091 | DI = []
1092 | for i in range(0, 4):
1093 | DI.append(torch.abs(fx1[i] - fx2[i]))
1094 |
1095 | cp = self.TDec(DI)
1096 |
1097 | if self.decoder_softmax:
1098 | cp = self.output_activation(cp)
1099 |
1100 | return cp
1101 |
1102 |
1103 | # ChangeFormerV3:
1104 | # Multi-scale encoder features are passed to the Transformer decoder (TDecV2), which performs the feature differencing internally
1105 | class ChangeFormerV3(nn.Module):
1106 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False):
1107 | super().__init__()
1108 | # Transformer Encoder
1109 | self.Tenc = Tenc(
1110 | patch_size=16,
1111 | embed_dims=[64, 128, 320, 512],
1112 | num_heads=[1, 2, 4, 8],
1113 | mlp_ratios=[4, 4, 4, 4],
1114 | qkv_bias=True,
1115 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
1116 | depths=[3, 4, 6, 3],
1117 | sr_ratios=[8, 4, 2, 1],
1118 | drop_rate=0.0,
1119 | drop_path_rate=0.1,
1120 | )
1121 |
1122 | # Transformer Decoder
1123 | self.TDec = TDecV2(
1124 | input_transform="multiple_select",
1125 | in_index=[0, 1, 2, 3],
1126 | align_corners=True,
1127 | in_channels=[64, 128, 320, 512],
1128 | embedding_dim=64,
1129 | output_nc=output_nc,
1130 | decoder_softmax=decoder_softmax,
1131 | feature_strides=[4, 8, 16, 32],
1132 | )
1133 |
1134 | def forward(self, x1, x2):
1135 | fx1 = self.Tenc(x1)
1136 | fx2 = self.Tenc(x2)
1137 |
1138 | cp = self.TDec(fx1, fx2)
1139 |
1140 | return cp
1141 |
1142 |
1143 | # Transformer Encoder with x2, x4, x8, x16, x32 scales
1144 | class EncoderTransformer_x2(nn.Module):
1145 | def __init__(
1146 | self,
1147 | img_size=256,
1148 | patch_size=3,
1149 | in_chans=3,
1150 | num_classes=2,
1151 | embed_dims=[32, 64, 128, 256, 512],
1152 | num_heads=[2, 2, 4, 8, 16],
1153 | mlp_ratios=[4, 4, 4, 4, 4],
1154 | qkv_bias=False,
1155 | qk_scale=None,
1156 | drop_rate=0.0,
1157 | attn_drop_rate=0.0,
1158 | drop_path_rate=0.0,
1159 | norm_layer=nn.LayerNorm,
1160 | depths=[3, 3, 6, 18, 3],
1161 | sr_ratios=[8, 4, 2, 1, 1],
1162 | ):
1163 | super().__init__()
1164 | self.num_classes = num_classes
1165 | self.depths = depths
1166 | self.embed_dims = embed_dims
1167 |
1168 | # patch embedding definitions
1169 | self.patch_embed1 = OverlapPatchEmbed(
1170 | img_size=img_size, patch_size=7, stride=2, in_chans=in_chans, embed_dim=embed_dims[0]
1171 | )
1172 | self.patch_embed2 = OverlapPatchEmbed(
1173 | img_size=img_size // 2, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
1174 | )
1175 | self.patch_embed3 = OverlapPatchEmbed(
1176 | img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
1177 | )
1178 | self.patch_embed4 = OverlapPatchEmbed(
1179 | img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]
1180 | )
1181 | self.patch_embed5 = OverlapPatchEmbed(
1182 | img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[3], embed_dim=embed_dims[4]
1183 | )
1184 |
1185 | # Stage-1 (x1/2 scale)
1186 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1187 | cur = 0
1188 | self.block1 = nn.ModuleList(
1189 | [
1190 | Block(
1191 | dim=embed_dims[0],
1192 | num_heads=num_heads[0],
1193 | mlp_ratio=mlp_ratios[0],
1194 | qkv_bias=qkv_bias,
1195 | qk_scale=qk_scale,
1196 | drop=drop_rate,
1197 | attn_drop=attn_drop_rate,
1198 | drop_path=dpr[cur + i],
1199 | norm_layer=norm_layer,
1200 | sr_ratio=sr_ratios[0],
1201 | )
1202 | for i in range(depths[0])
1203 | ]
1204 | )
1205 | self.norm1 = norm_layer(embed_dims[0])
1206 |
1207 | # Stage-2 (x1/4 scale)
1208 | cur += depths[0]
1209 | self.block2 = nn.ModuleList(
1210 | [
1211 | Block(
1212 | dim=embed_dims[1],
1213 | num_heads=num_heads[1],
1214 | mlp_ratio=mlp_ratios[1],
1215 | qkv_bias=qkv_bias,
1216 | qk_scale=qk_scale,
1217 | drop=drop_rate,
1218 | attn_drop=attn_drop_rate,
1219 | drop_path=dpr[cur + i],
1220 | norm_layer=norm_layer,
1221 | sr_ratio=sr_ratios[1],
1222 | )
1223 | for i in range(depths[1])
1224 | ]
1225 | )
1226 | self.norm2 = norm_layer(embed_dims[1])
1227 |
1228 | # Stage-3 (x1/8 scale)
1229 | cur += depths[1]
1230 | self.block3 = nn.ModuleList(
1231 | [
1232 | Block(
1233 | dim=embed_dims[2],
1234 | num_heads=num_heads[2],
1235 | mlp_ratio=mlp_ratios[2],
1236 | qkv_bias=qkv_bias,
1237 | qk_scale=qk_scale,
1238 | drop=drop_rate,
1239 | attn_drop=attn_drop_rate,
1240 | drop_path=dpr[cur + i],
1241 | norm_layer=norm_layer,
1242 | sr_ratio=sr_ratios[2],
1243 | )
1244 | for i in range(depths[2])
1245 | ]
1246 | )
1247 | self.norm3 = norm_layer(embed_dims[2])
1248 |
1249 | # Stage-4 (x1/16 scale)
1250 | cur += depths[2]
1251 | self.block4 = nn.ModuleList(
1252 | [
1253 | Block(
1254 | dim=embed_dims[3],
1255 | num_heads=num_heads[3],
1256 | mlp_ratio=mlp_ratios[3],
1257 | qkv_bias=qkv_bias,
1258 | qk_scale=qk_scale,
1259 | drop=drop_rate,
1260 | attn_drop=attn_drop_rate,
1261 | drop_path=dpr[cur + i],
1262 | norm_layer=norm_layer,
1263 | sr_ratio=sr_ratios[3],
1264 | )
1265 | for i in range(depths[3])
1266 | ]
1267 | )
1268 | self.norm4 = norm_layer(embed_dims[3])
1269 |
1270 | # Stage-5 (x1/32 scale)
1271 | cur += depths[3]
1272 | self.block5 = nn.ModuleList(
1273 | [
1274 | Block(
1275 | dim=embed_dims[4],
1276 | num_heads=num_heads[4],
1277 | mlp_ratio=mlp_ratios[4],
1278 | qkv_bias=qkv_bias,
1279 | qk_scale=qk_scale,
1280 | drop=drop_rate,
1281 | attn_drop=attn_drop_rate,
1282 | drop_path=dpr[cur + i],
1283 | norm_layer=norm_layer,
1284 | sr_ratio=sr_ratios[4],
1285 | )
1286 | for i in range(depths[4])
1287 | ]
1288 | )
1289 | self.norm5 = norm_layer(embed_dims[4])
1290 |
1291 | self.apply(self._init_weights)
1292 |
1293 | def _init_weights(self, m):
1294 | if isinstance(m, nn.Linear):
1295 | trunc_normal_(m.weight, std=0.02)
1296 | if isinstance(m, nn.Linear) and m.bias is not None:
1297 | nn.init.constant_(m.bias, 0)
1298 | elif isinstance(m, nn.LayerNorm):
1299 | nn.init.constant_(m.bias, 0)
1300 | nn.init.constant_(m.weight, 1.0)
1301 | elif isinstance(m, nn.Conv2d):
1302 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
1303 | fan_out //= m.groups
1304 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
1305 | if m.bias is not None:
1306 | m.bias.data.zero_()
1307 |
1308 | def reset_drop_path(self, drop_path_rate):
1309 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
1310 | cur = 0
1311 | for i in range(self.depths[0]):
1312 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
1313 |
1314 | cur += self.depths[0]
1315 | for i in range(self.depths[1]):
1316 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
1317 |
1318 | cur += self.depths[1]
1319 | for i in range(self.depths[2]):
1320 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
1321 |
1322 | cur += self.depths[2]
1323 | for i in range(self.depths[3]):
1324 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
1325 |
1326 | def forward_features(self, x):
1327 | B = x.shape[0]
1328 | outs = []
1329 |
1330 | # stage 1
1331 | x1, H1, W1 = self.patch_embed1(x)
1332 | for _i, blk in enumerate(self.block1):
1333 | x1 = blk(x1, H1, W1)
1334 | x1 = self.norm1(x1)
1335 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1336 | outs.append(x1)
1337 |
1338 | # stage 2
1339 | x1, H1, W1 = self.patch_embed2(x1)
1340 | for _i, blk in enumerate(self.block2):
1341 | x1 = blk(x1, H1, W1)
1342 | x1 = self.norm2(x1)
1343 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1344 | outs.append(x1)
1345 |
1346 | # stage 3
1347 | x1, H1, W1 = self.patch_embed3(x1)
1348 | for _i, blk in enumerate(self.block3):
1349 | x1 = blk(x1, H1, W1)
1350 | x1 = self.norm3(x1)
1351 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1352 | outs.append(x1)
1353 |
1354 | # stage 4
1355 | x1, H1, W1 = self.patch_embed4(x1)
1356 | for _i, blk in enumerate(self.block4):
1357 | x1 = blk(x1, H1, W1)
1358 | x1 = self.norm4(x1)
1359 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1360 | outs.append(x1)
1361 |
1362 | # stage 5
1363 | x1, H1, W1 = self.patch_embed5(x1)
1364 | for _i, blk in enumerate(self.block5):
1365 | x1 = blk(x1, H1, W1)
1366 | x1 = self.norm5(x1)
1367 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1368 | outs.append(x1)
1369 |
1370 | return outs
1371 |
1372 | def forward(self, x):
1373 | x = self.forward_features(x)
1374 | return x
1375 |
1376 |
1377 | # Difference module
1378 | def conv_diff(in_channels, out_channels):
1379 | return nn.Sequential(
1380 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
1381 | nn.ReLU(),
1382 | nn.BatchNorm2d(out_channels),
1383 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
1384 | nn.ReLU(),
1385 | )
1386 |
1387 |
1388 | # Intermediate prediction module
1389 | def make_prediction(in_channels, out_channels):
1390 | return nn.Sequential(
1391 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
1392 | nn.ReLU(),
1393 | nn.BatchNorm2d(out_channels),
1394 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
1395 | )
1396 |
1397 |
1398 | class DecoderTransformer_x2(nn.Module):
1399 | """
1400 | Transformer Decoder
1401 | """
1402 |
1403 | def __init__(
1404 | self,
1405 | input_transform="multiple_select",
1406 | in_index=[0, 1, 2, 3, 4],
1407 | align_corners=True,
1408 | in_channels=[32, 64, 128, 256, 512],
1409 | embedding_dim=64,
1410 | output_nc=2,
1411 | decoder_softmax=False,
1412 | feature_strides=[2, 4, 8, 16, 32],
1413 | ):
1414 | super().__init__()
1415 | assert len(feature_strides) == len(in_channels)
1416 | assert min(feature_strides) == feature_strides[0]
1417 | self.feature_strides = feature_strides
1418 |
1419 | # input transforms
1420 | self.input_transform = input_transform
1421 | self.in_index = in_index
1422 | self.align_corners = align_corners
1423 |
1424 | # MLP
1425 | self.in_channels = in_channels
1426 | self.embedding_dim = embedding_dim
1427 |
1428 | # Final prediction
1429 | self.output_nc = output_nc
1430 |
1431 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels, c5_in_channels) = self.in_channels
1432 |
1433 | self.linear_c5 = MLP(input_dim=c5_in_channels, embed_dim=self.embedding_dim)
1434 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
1435 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
1436 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
1437 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
1438 |
1439 | # Convolutional Difference Modules
1440 | self.diff_c5 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
1441 | self.diff_c4 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim)
1442 | self.diff_c3 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim)
1443 | self.diff_c2 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim)
1444 | self.diff_c1 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim)
1445 |
1446 | # Intermediate prediction heads on the multi-scale difference features
1447 | self.make_pred_c5 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1448 | self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1449 | self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1450 | self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1451 | self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1452 |
1453 | self.linear_fuse = nn.Conv2d(
1454 | in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim, kernel_size=1
1455 | )
1456 |
1457 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
1458 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
1459 | self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
1460 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
1461 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
1462 |
1463 | # Final prediction
1464 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)
1465 |
1466 | # Final activation
1467 | self.output_softmax = decoder_softmax
1468 | self.active = nn.Sigmoid()
1469 |
1470 | def _transform_inputs(self, inputs):
1471 | """Transform inputs for decoder.
1472 | Args:
1473 | inputs (list[Tensor]): List of multi-level img features.
1474 | Returns:
1475 | Tensor: The transformed inputs
1476 | """
1477 |
1478 | if self.input_transform == "resize_concat":
1479 | inputs = [inputs[i] for i in self.in_index]
1480 | upsampled_inputs = [
1481 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
1482 | for x in inputs
1483 | ]
1484 | inputs = torch.cat(upsampled_inputs, dim=1)
1485 | elif self.input_transform == "multiple_select":
1486 | inputs = [inputs[i] for i in self.in_index]
1487 | else:
1488 | inputs = inputs[self.in_index]
1489 |
1490 | return inputs
1491 |
1492 | def forward(self, inputs1, inputs2):
1493 | x_1 = self._transform_inputs(inputs1)  # len=5, 1/2, 1/4, 1/8, 1/16, 1/32
1494 | x_2 = self._transform_inputs(inputs2)  # len=5, 1/2, 1/4, 1/8, 1/16, 1/32
1495 |
1496 | c1_1, c2_1, c3_1, c4_1, c5_1 = x_1
1497 | c1_2, c2_2, c3_2, c4_2, c5_2 = x_2
1498 |
1499 | ############## MLP decoder on C1-C5 ###########
1500 | n, _, h, w = c5_1.shape
1501 |
1502 | outputs = []  # Multi-scale change-map outputs are collected here
1503 |
1504 | _c5_1 = self.linear_c5(c5_1).permute(0, 2, 1).reshape(n, -1, c5_1.shape[2], c5_1.shape[3])
1505 | _c5_2 = self.linear_c5(c5_2).permute(0, 2, 1).reshape(n, -1, c5_2.shape[2], c5_2.shape[3])
1506 | _c5 = self.diff_c5(torch.cat((_c5_1, _c5_2), dim=1)) # Difference of features at x1/32 scale
1507 | p_c5 = self.make_pred_c5(_c5) # Predicted change map at x1/32 scale
1508 | outputs.append(p_c5) # x1/32 scale
1509 | _c5_up = resize(_c5, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1510 |
1511 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3])
1512 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3])
1513 | _c4 = self.diff_c4(
1514 | torch.cat((F.interpolate(_c5, scale_factor=2, mode="bilinear"), _c4_1, _c4_2), dim=1)
1515 | ) # Difference of features at x1/16 scale
1516 | p_c4 = self.make_pred_c4(_c4) # Predicted change map at x1/16 scale
1517 | outputs.append(p_c4) # x1/16 scale
1518 | _c4_up = resize(_c4, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1519 |
1520 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3])
1521 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3])
1522 | _c3 = self.diff_c3(
1523 | torch.cat((F.interpolate(_c4, scale_factor=2, mode="bilinear"), _c3_1, _c3_2), dim=1)
1524 | ) # Difference of features at x1/8 scale
1525 | p_c3 = self.make_pred_c3(_c3) # Predicted change map at x1/8 scale
1526 | outputs.append(p_c3) # x1/8 scale
1527 | _c3_up = resize(_c3, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1528 |
1529 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3])
1530 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3])
1531 | _c2 = self.diff_c2(
1532 | torch.cat((F.interpolate(_c3, scale_factor=2, mode="bilinear"), _c2_1, _c2_2), dim=1)
1533 | ) # Difference of features at x1/4 scale
1534 | p_c2 = self.make_pred_c2(_c2) # Predicted change map at x1/4 scale
1535 | outputs.append(p_c2) # x1/4 scale
1536 | _c2_up = resize(_c2, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1537 |
1538 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3])
1539 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3])
1540 | _c1 = self.diff_c1(
1541 | torch.cat((F.interpolate(_c2, scale_factor=2, mode="bilinear"), _c1_1, _c1_2), dim=1)
1542 | ) # Difference of features at x1/2 scale
1543 | p_c1 = self.make_pred_c1(_c1) # Predicted change map at x1/2 scale
1544 | outputs.append(p_c1) # x1/2 scale
1545 |
1546 | _c = self.linear_fuse(torch.cat((_c5_up, _c4_up, _c3_up, _c2_up, _c1), dim=1))
1547 |
1548 | x = self.convd2x(_c)
1549 | x = self.dense_2x(x)
1550 | cp = self.change_probability(x)
1551 | outputs.append(cp)
1552 |
1553 | if self.output_softmax:
1554 | temp = outputs
1555 | outputs = []
1556 | for pred in temp:
1557 | outputs.append(self.active(pred))
1558 |
1559 | return outputs
1560 |
1561 |
1562 | # ChangeFormerV4:
1563 | class ChangeFormerV4(nn.Module):
1564 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False):
1565 | super().__init__()
1566 | # Transformer Encoder
1567 | self.embed_dims = [32, 64, 128, 320, 512]
1568 | self.depths = [3, 3, 4, 12, 3] # [3, 3, 6, 18, 3]
1569 | self.embedding_dim = 256
1570 |
1571 | self.Tenc_x2 = EncoderTransformer_x2(
1572 | img_size=256,
1573 | patch_size=3,
1574 | in_chans=input_nc,
1575 | num_classes=output_nc,
1576 | embed_dims=self.embed_dims,
1577 | num_heads=[2, 2, 4, 8, 16],
1578 | mlp_ratios=[2, 2, 2, 2, 2],
1579 | qkv_bias=False,
1580 | qk_scale=None,
1581 | drop_rate=0.0,
1582 | attn_drop_rate=0.0,
1583 | drop_path_rate=0.0,
1584 | norm_layer=nn.LayerNorm,
1585 | depths=self.depths,
1586 | sr_ratios=[8, 4, 2, 1, 1],
1587 | )
1588 |
1589 | # Transformer Decoder
1590 | self.TDec_x2 = DecoderTransformer_x2(
1591 | input_transform="multiple_select",
1592 | in_index=[0, 1, 2, 3, 4],
1593 | align_corners=True,
1594 | in_channels=self.embed_dims,
1595 | embedding_dim=256,
1596 | output_nc=output_nc,
1597 | decoder_softmax=decoder_softmax,
1598 | feature_strides=[2, 4, 8, 16, 32],
1599 | )
1600 |
1601 | def forward(self, x1, x2):
1602 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
1603 |
1604 | cp = self.TDec_x2(fx1, fx2)
1605 |
1606 | # # Save to mat
1607 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4")
1608 |
1609 | # exit()
1610 | return cp
1611 |
1612 |
1613 | # Transformer Encoder with x4, x8, x16, x32 scales
1614 | class EncoderTransformer_v3(nn.Module):
1615 | def __init__(
1616 | self,
1617 | img_size=256,
1618 | patch_size=3,
1619 | in_chans=3,
1620 | num_classes=2,
1621 | embed_dims=[32, 64, 128, 256],
1622 | num_heads=[2, 2, 4, 8],
1623 | mlp_ratios=[4, 4, 4, 4],
1624 | qkv_bias=True,
1625 | qk_scale=None,
1626 | drop_rate=0.0,
1627 | attn_drop_rate=0.0,
1628 | drop_path_rate=0.0,
1629 | norm_layer=nn.LayerNorm,
1630 | depths=[3, 3, 6, 18],
1631 | sr_ratios=[8, 4, 2, 1],
1632 | ):
1633 | super().__init__()
1634 | self.num_classes = num_classes
1635 | self.depths = depths
1636 | self.embed_dims = embed_dims
1637 |
1638 | # patch embedding definitions
1639 | self.patch_embed1 = OverlapPatchEmbed(
1640 | img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0]
1641 | )
1642 | self.patch_embed2 = OverlapPatchEmbed(
1643 | img_size=img_size // 4, patch_size=patch_size, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]
1644 | )
1645 | self.patch_embed3 = OverlapPatchEmbed(
1646 | img_size=img_size // 8, patch_size=patch_size, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]
1647 | )
1648 | self.patch_embed4 = OverlapPatchEmbed(
1649 | img_size=img_size // 16, patch_size=patch_size, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]
1650 | )
1651 |
1652 | # Stage-1 (x1/4 scale)
1653 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
1654 | cur = 0
1655 | self.block1 = nn.ModuleList(
1656 | [
1657 | Block(
1658 | dim=embed_dims[0],
1659 | num_heads=num_heads[0],
1660 | mlp_ratio=mlp_ratios[0],
1661 | qkv_bias=qkv_bias,
1662 | qk_scale=qk_scale,
1663 | drop=drop_rate,
1664 | attn_drop=attn_drop_rate,
1665 | drop_path=dpr[cur + i],
1666 | norm_layer=norm_layer,
1667 | sr_ratio=sr_ratios[0],
1668 | )
1669 | for i in range(depths[0])
1670 | ]
1671 | )
1672 | self.norm1 = norm_layer(embed_dims[0])
1673 |
1674 | # Stage-2 (x1/8 scale)
1675 | cur += depths[0]
1676 | self.block2 = nn.ModuleList(
1677 | [
1678 | Block(
1679 | dim=embed_dims[1],
1680 | num_heads=num_heads[1],
1681 | mlp_ratio=mlp_ratios[1],
1682 | qkv_bias=qkv_bias,
1683 | qk_scale=qk_scale,
1684 | drop=drop_rate,
1685 | attn_drop=attn_drop_rate,
1686 | drop_path=dpr[cur + i],
1687 | norm_layer=norm_layer,
1688 | sr_ratio=sr_ratios[1],
1689 | )
1690 | for i in range(depths[1])
1691 | ]
1692 | )
1693 | self.norm2 = norm_layer(embed_dims[1])
1694 |
1695 | # Stage-3 (x1/16 scale)
1696 | cur += depths[1]
1697 | self.block3 = nn.ModuleList(
1698 | [
1699 | Block(
1700 | dim=embed_dims[2],
1701 | num_heads=num_heads[2],
1702 | mlp_ratio=mlp_ratios[2],
1703 | qkv_bias=qkv_bias,
1704 | qk_scale=qk_scale,
1705 | drop=drop_rate,
1706 | attn_drop=attn_drop_rate,
1707 | drop_path=dpr[cur + i],
1708 | norm_layer=norm_layer,
1709 | sr_ratio=sr_ratios[2],
1710 | )
1711 | for i in range(depths[2])
1712 | ]
1713 | )
1714 | self.norm3 = norm_layer(embed_dims[2])
1715 |
1716 | # Stage-4 (x1/32 scale)
1717 | cur += depths[2]
1718 | self.block4 = nn.ModuleList(
1719 | [
1720 | Block(
1721 | dim=embed_dims[3],
1722 | num_heads=num_heads[3],
1723 | mlp_ratio=mlp_ratios[3],
1724 | qkv_bias=qkv_bias,
1725 | qk_scale=qk_scale,
1726 | drop=drop_rate,
1727 | attn_drop=attn_drop_rate,
1728 | drop_path=dpr[cur + i],
1729 | norm_layer=norm_layer,
1730 | sr_ratio=sr_ratios[3],
1731 | )
1732 | for i in range(depths[3])
1733 | ]
1734 | )
1735 | self.norm4 = norm_layer(embed_dims[3])
1736 |
1737 | self.apply(self._init_weights)
1738 |
1739 | def _init_weights(self, m):
1740 | if isinstance(m, nn.Linear):
1741 | trunc_normal_(m.weight, std=0.02)
1742 | if isinstance(m, nn.Linear) and m.bias is not None:
1743 | nn.init.constant_(m.bias, 0)
1744 | elif isinstance(m, nn.LayerNorm):
1745 | nn.init.constant_(m.bias, 0)
1746 | nn.init.constant_(m.weight, 1.0)
1747 | elif isinstance(m, nn.Conv2d):
1748 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
1749 | fan_out //= m.groups
1750 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
1751 | if m.bias is not None:
1752 | m.bias.data.zero_()
1753 |
1754 | def reset_drop_path(self, drop_path_rate):
1755 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
1756 | cur = 0
1757 | for i in range(self.depths[0]):
1758 | self.block1[i].drop_path.drop_prob = dpr[cur + i]
1759 |
1760 | cur += self.depths[0]
1761 | for i in range(self.depths[1]):
1762 | self.block2[i].drop_path.drop_prob = dpr[cur + i]
1763 |
1764 | cur += self.depths[1]
1765 | for i in range(self.depths[2]):
1766 | self.block3[i].drop_path.drop_prob = dpr[cur + i]
1767 |
1768 | cur += self.depths[2]
1769 | for i in range(self.depths[3]):
1770 | self.block4[i].drop_path.drop_prob = dpr[cur + i]
1771 |
1772 | def forward_features(self, x):
1773 | B = x.shape[0]
1774 | outs = []
1775 |
1776 | # stage 1
1777 | x1, H1, W1 = self.patch_embed1(x)
1778 | for _i, blk in enumerate(self.block1):
1779 | x1 = blk(x1, H1, W1)
1780 | x1 = self.norm1(x1)
1781 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1782 | outs.append(x1)
1783 |
1784 | # stage 2
1785 | x1, H1, W1 = self.patch_embed2(x1)
1786 | for _i, blk in enumerate(self.block2):
1787 | x1 = blk(x1, H1, W1)
1788 | x1 = self.norm2(x1)
1789 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1790 | outs.append(x1)
1791 |
1792 | # stage 3
1793 | x1, H1, W1 = self.patch_embed3(x1)
1794 | for _i, blk in enumerate(self.block3):
1795 | x1 = blk(x1, H1, W1)
1796 | x1 = self.norm3(x1)
1797 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1798 | outs.append(x1)
1799 |
1800 | # stage 4
1801 | x1, H1, W1 = self.patch_embed4(x1)
1802 | for _i, blk in enumerate(self.block4):
1803 | x1 = blk(x1, H1, W1)
1804 | x1 = self.norm4(x1)
1805 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
1806 | outs.append(x1)
1807 | return outs
1808 |
1809 | def forward(self, x):
1810 | x = self.forward_features(x)
1811 | return x
1812 |
1813 |
1814 | class DecoderTransformer_v3(nn.Module):
1815 | """
1816 | Transformer Decoder
1817 | """
1818 |
1819 | def __init__(
1820 | self,
1821 | input_transform="multiple_select",
1822 | in_index=[0, 1, 2, 3],
1823 | align_corners=True,
1824 | in_channels=[32, 64, 128, 256],
1825 | embedding_dim=64,
1826 | output_nc=2,
1827 | decoder_softmax=False,
1828 | feature_strides=[2, 4, 8, 16],
1829 | ):
1830 | super().__init__()
1831 | # assert
1832 | assert len(feature_strides) == len(in_channels)
1833 | assert min(feature_strides) == feature_strides[0]
1834 |
1835 | # settings
1836 | self.feature_strides = feature_strides
1837 | self.input_transform = input_transform
1838 | self.in_index = in_index
1839 | self.align_corners = align_corners
1840 | self.in_channels = in_channels
1841 | self.embedding_dim = embedding_dim
1842 | self.output_nc = output_nc
1843 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels
1844 |
1845 | # MLP decoder heads
1846 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
1847 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
1848 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
1849 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
1850 |
1851 | # convolutional Difference Modules
1852 | self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
1853 | self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
1854 | self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
1855 | self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim)
1856 |
1857 | # intermediate prediction heads on the multi-scale difference features
1858 | self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1859 | self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1860 | self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1861 | self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc)
1862 |
1863 | # Final linear fusion layer
1864 | self.linear_fuse = nn.Sequential(
1865 | nn.Conv2d(
1866 | in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim, kernel_size=1
1867 | ),
1868 | nn.BatchNorm2d(self.embedding_dim),
1869 | )
1870 |
1871 | # Final prediction head
1872 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
1873 | self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim))
1874 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2)
1875 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim))
1876 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1)
1877 |
1878 | # Final activation
1879 | self.output_softmax = decoder_softmax
1880 | self.active = nn.Sigmoid()
1881 |
1882 | def _transform_inputs(self, inputs):
1883 | """Transform inputs for decoder.
1884 | Args:
1885 | inputs (list[Tensor]): List of multi-level img features.
1886 | Returns:
1887 | Tensor: The transformed inputs
1888 | """
1889 |
1890 | if self.input_transform == "resize_concat":
1891 | inputs = [inputs[i] for i in self.in_index]
1892 | upsampled_inputs = [
1893 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners)
1894 | for x in inputs
1895 | ]
1896 | inputs = torch.cat(upsampled_inputs, dim=1)
1897 | elif self.input_transform == "multiple_select":
1898 | inputs = [inputs[i] for i in self.in_index]
1899 | else:
1900 | inputs = inputs[self.in_index]
1901 |
1902 | return inputs
1903 |
1904 | def forward(self, inputs1, inputs2):
1905 | # Transforming encoder features (select layers)
1906 | x_1 = self._transform_inputs(inputs1)  # len=4, 1/4, 1/8, 1/16, 1/32
1907 | x_2 = self._transform_inputs(inputs2)  # len=4, 1/4, 1/8, 1/16, 1/32
1908 |
1909 | # img1 and img2 features
1910 | c1_1, c2_1, c3_1, c4_1 = x_1
1911 | c1_2, c2_2, c3_2, c4_2 = x_2
1912 |
1913 | ############## MLP decoder on C1-C4 ###########
1914 | n, _, h, w = c4_1.shape
1915 |
1916 | outputs = []
1917 | # Stage 4: x1/32 scale
1918 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3])
1919 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3])
1920 | _c4 = self.diff_c4(torch.cat((_c4_1, _c4_2), dim=1))
1921 | p_c4 = self.make_pred_c4(_c4)
1922 | outputs.append(p_c4)
1923 | _c4_up = resize(_c4, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1924 |
1925 | # Stage 3: x1/16 scale
1926 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3])
1927 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3])
1928 | _c3 = self.diff_c3(torch.cat((_c3_1, _c3_2), dim=1)) + F.interpolate(_c4, scale_factor=2, mode="bilinear")
1929 | p_c3 = self.make_pred_c3(_c3)
1930 | outputs.append(p_c3)
1931 | _c3_up = resize(_c3, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1932 |
1933 | # Stage 2: x1/8 scale
1934 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3])
1935 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3])
1936 | _c2 = self.diff_c2(torch.cat((_c2_1, _c2_2), dim=1)) + F.interpolate(_c3, scale_factor=2, mode="bilinear")
1937 | p_c2 = self.make_pred_c2(_c2)
1938 | outputs.append(p_c2)
1939 | _c2_up = resize(_c2, size=c1_2.size()[2:], mode="bilinear", align_corners=False)
1940 |
1941 | # Stage 1: x1/4 scale
1942 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3])
1943 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3])
1944 | _c1 = self.diff_c1(torch.cat((_c1_1, _c1_2), dim=1)) + F.interpolate(_c2, scale_factor=2, mode="bilinear")
1945 | p_c1 = self.make_pred_c1(_c1)
1946 | outputs.append(p_c1)
1947 |
1948 | # Linear Fusion of difference image from all scales
1949 | _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1))
1950 |
1951 | # #Dropout
1952 | # if dropout_ratio > 0:
1953 | # self.dropout = nn.Dropout2d(dropout_ratio)
1954 | # else:
1955 | # self.dropout = None
1956 |
1957 | # Upsampling x2 (x1/2 scale)
1958 | x = self.convd2x(_c)
1959 | # Residual block
1960 | x = self.dense_2x(x)
1961 | # Upsampling x2 (x1 scale)
1962 | x = self.convd1x(x)
1963 | # Residual block
1964 | x = self.dense_1x(x)
1965 |
1966 | # Final prediction
1967 | cp = self.change_probability(x)
1968 |
1969 | outputs.append(cp)
1970 |
1971 | if self.output_softmax:
1972 | temp = outputs
1973 | outputs = []
1974 | for pred in temp:
1975 | outputs.append(self.active(pred))
1976 |
1977 | return outputs
1978 |
1979 |
1980 | # ChangeFormerV5:
1981 | class ChangeFormerV5(nn.Module):
1982 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False, embed_dim=256):
1983 | super().__init__()
1984 | # Transformer Encoder
1985 | self.embed_dims = [64, 128, 320, 512]
1986 | self.depths = [3, 6, 16, 3] # [3, 3, 6, 18, 3]
1987 | self.embedding_dim = embed_dim
1988 | self.drop_rate = 0.0
1989 | self.attn_drop = 0.0
1990 | self.drop_path_rate = 0.1
1991 |
1992 | self.Tenc_x2 = EncoderTransformer_v3(
1993 | img_size=256,
1994 | patch_size=4,
1995 | in_chans=input_nc,
1996 | num_classes=output_nc,
1997 | embed_dims=self.embed_dims,
1998 | num_heads=[1, 2, 5, 8],
1999 | mlp_ratios=[4, 4, 4, 4],
2000 | qkv_bias=True,
2001 | qk_scale=None,
2002 | drop_rate=self.drop_rate,
2003 | attn_drop_rate=self.attn_drop,
2004 | drop_path_rate=self.drop_path_rate,
2005 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
2006 | depths=self.depths,
2007 | sr_ratios=[8, 4, 2, 1],
2008 | )
2009 |
2010 | # Transformer Decoder
2011 | self.TDec_x2 = DecoderTransformer_v3(
2012 | input_transform="multiple_select",
2013 | in_index=[0, 1, 2, 3],
2014 | align_corners=False,
2015 | in_channels=self.embed_dims,
2016 | embedding_dim=self.embedding_dim,
2017 | output_nc=output_nc,
2018 | decoder_softmax=decoder_softmax,
2019 | feature_strides=[2, 4, 8, 16],
2020 | )
2021 |
2022 | def forward(self, x1, x2):
2023 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
2024 |
2025 | cp = self.TDec_x2(fx1, fx2)
2026 |
2027 | # # Save to mat
2028 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4")
2029 |
2030 | # exit()
2031 | return cp
2032 |
2033 |
2034 | # ChangeFormerV6:
2035 | class ChangeFormerV6(nn.Module):
2036 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False, embed_dim=256):
2037 | super().__init__()
2038 | # Transformer Encoder
2039 | self.embed_dims = [64, 128, 320, 512]
2040 | self.depths = [3, 3, 4, 3] # [3, 3, 6, 18, 3]
2041 | self.embedding_dim = embed_dim
2042 | self.drop_rate = 0.1
2043 | self.attn_drop = 0.1
2044 | self.drop_path_rate = 0.1
2045 |
2046 | self.Tenc_x2 = EncoderTransformer_v3(
2047 | img_size=256,
2048 | patch_size=7,
2049 | in_chans=input_nc,
2050 | num_classes=output_nc,
2051 | embed_dims=self.embed_dims,
2052 | num_heads=[1, 2, 4, 8],
2053 | mlp_ratios=[4, 4, 4, 4],
2054 | qkv_bias=True,
2055 | qk_scale=None,
2056 | drop_rate=self.drop_rate,
2057 | attn_drop_rate=self.attn_drop,
2058 | drop_path_rate=self.drop_path_rate,
2059 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
2060 | depths=self.depths,
2061 | sr_ratios=[8, 4, 2, 1],
2062 | )
2063 |
2064 | # Transformer Decoder
2065 | self.TDec_x2 = DecoderTransformer_v3(
2066 | input_transform="multiple_select",
2067 | in_index=[0, 1, 2, 3],
2068 | align_corners=False,
2069 | in_channels=self.embed_dims,
2070 | embedding_dim=self.embedding_dim,
2071 | output_nc=output_nc,
2072 | decoder_softmax=decoder_softmax,
2073 | feature_strides=[2, 4, 8, 16],
2074 | )
2075 |
2076 | def forward(self, x1, x2):
2077 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)]
2078 |
2079 | cp = self.TDec_x2(fx1, fx2)[-1]
2080 |
2081 | # # Save to mat
2082 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4")
2083 |
2084 | # exit()
2085 | return cp
2086 |
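
A quick shape-check sketch for the ChangeFormerV6 variant defined above (not part of the repository; the import path is an assumption based on this repo's layout). It runs a forward pass on random 256x256 bitemporal inputs; since V6 indexes the decoder output list with [-1], only the final full-resolution change map is returned.

import torch

from src.models.changeformer.ChangeFormer import ChangeFormerV6  # assumed import path

model = ChangeFormerV6(input_nc=3, output_nc=2, decoder_softmax=False, embed_dim=256)
model.eval()

# Random "pre" and "post" RGB images, batch of 2.
x1 = torch.randn(2, 3, 256, 256)
x2 = torch.randn(2, 3, 256, 256)

with torch.no_grad():
    logits = model(x1, x2)

# V6 returns only the last (full-resolution) prediction from the decoder.
print(logits.shape)  # expected: torch.Size([2, 2, 256, 256])

ChangeFormerV5, by contrast, returns the whole list of multi-scale predictions from DecoderTransformer_v3, with the full-resolution map last.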
--------------------------------------------------------------------------------
/src/models/changeformer/ChangeFormerBaseNetworks.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 |
7 |
8 | class ConvBlock(torch.nn.Module):
9 | def __init__(
10 | self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation="prelu", norm=None
11 | ):
12 | super().__init__()
13 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
14 |
15 | self.norm = norm
16 | if self.norm == "batch":
17 | self.bn = torch.nn.BatchNorm2d(output_size)
18 | elif self.norm == "instance":
19 | self.bn = torch.nn.InstanceNorm2d(output_size)
20 |
21 | self.activation = activation
22 | if self.activation == "relu":
23 | self.act = torch.nn.ReLU(True)
24 | elif self.activation == "prelu":
25 | self.act = torch.nn.PReLU()
26 | elif self.activation == "lrelu":
27 | self.act = torch.nn.LeakyReLU(0.2, True)
28 | elif self.activation == "tanh":
29 | self.act = torch.nn.Tanh()
30 | elif self.activation == "sigmoid":
31 | self.act = torch.nn.Sigmoid()
32 |
33 | def forward(self, x):
34 | if self.norm is not None:
35 | out = self.bn(self.conv(x))
36 | else:
37 | out = self.conv(x)
38 |
39 | if self.activation != "no":
40 | return self.act(out)
41 | else:
42 | return out
43 |
44 |
45 | class DeconvBlock(torch.nn.Module):
46 | def __init__(
47 | self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation="prelu", norm=None
48 | ):
49 | super().__init__()
50 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias)
51 |
52 | self.norm = norm
53 | if self.norm == "batch":
54 | self.bn = torch.nn.BatchNorm2d(output_size)
55 | elif self.norm == "instance":
56 | self.bn = torch.nn.InstanceNorm2d(output_size)
57 |
58 | self.activation = activation
59 | if self.activation == "relu":
60 | self.act = torch.nn.ReLU(True)
61 | elif self.activation == "prelu":
62 | self.act = torch.nn.PReLU()
63 | elif self.activation == "lrelu":
64 | self.act = torch.nn.LeakyReLU(0.2, True)
65 | elif self.activation == "tanh":
66 | self.act = torch.nn.Tanh()
67 | elif self.activation == "sigmoid":
68 | self.act = torch.nn.Sigmoid()
69 |
70 | def forward(self, x):
71 | if self.norm is not None:
72 | out = self.bn(self.deconv(x))
73 | else:
74 | out = self.deconv(x)
75 |
76 | if self.activation is not None:
77 | return self.act(out)
78 | else:
79 | return out
80 |
81 |
82 | class ConvLayer(nn.Module):
83 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
84 | super().__init__()
85 | # reflection_padding = kernel_size // 2
86 | # self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
87 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
88 |
89 | def forward(self, x):
90 | # out = self.reflection_pad(x)
91 | out = self.conv2d(x)
92 | return out
93 |
94 |
95 | class UpsampleConvLayer(torch.nn.Module):
96 | def __init__(self, in_channels, out_channels, kernel_size, stride):
97 | super().__init__()
98 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=1)
99 |
100 | def forward(self, x):
101 | out = self.conv2d(x)
102 | return out
103 |
104 |
105 | class ResidualBlock(torch.nn.Module):
106 | def __init__(self, channels):
107 | super().__init__()
108 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
109 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1)
110 | self.relu = nn.ReLU()
111 |
112 | def forward(self, x):
113 | residual = x
114 | out = self.relu(self.conv1(x))
115 | out = self.conv2(out) * 0.1
116 | out = torch.add(out, residual)
117 | return out
118 |
119 |
120 | def init_linear(linear):
121 | init.xavier_normal_(linear.weight)
122 | linear.bias.data.zero_()
123 |
124 |
125 | def init_conv(conv, glu=True):
126 | init.kaiming_normal_(conv.weight)
127 | if conv.bias is not None:
128 | conv.bias.data.zero_()
129 |
130 |
131 | class EqualLR:
132 | def __init__(self, name):
133 | self.name = name
134 |
135 | def compute_weight(self, module):
136 | weight = getattr(module, self.name + "_orig")
137 | fan_in = weight.data.size(1) * weight.data[0][0].numel()
138 |
139 | return weight * sqrt(2 / fan_in)
140 |
141 | @staticmethod
142 | def apply(module, name):
143 | fn = EqualLR(name)
144 |
145 | weight = getattr(module, name)
146 | del module._parameters[name]
147 | module.register_parameter(name + "_orig", nn.Parameter(weight.data))
148 | module.register_forward_pre_hook(fn)
149 |
150 | return fn
151 |
152 | def __call__(self, module, input):
153 | weight = self.compute_weight(module)
154 | setattr(module, self.name, weight)
155 |
156 |
157 | def equal_lr(module, name="weight"):
158 | EqualLR.apply(module, name)
159 |
160 | return module
161 |
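
A small shape sanity check for the decoder building blocks above (a sketch, not part of the repository): UpsampleConvLayer wraps a transposed convolution with kernel_size=4, stride=2 and a fixed padding of 1, so it exactly doubles the spatial resolution, while ResidualBlock preserves it. This is what the ChangeFormer decoders rely on when chaining convd2x / dense_2x / convd1x / dense_1x.

import torch

# Assumed import path based on this repository's layout.
from src.models.changeformer.ChangeFormerBaseNetworks import ResidualBlock, UpsampleConvLayer

x = torch.randn(1, 64, 64, 64)  # (B, C, H, W): a 1/4-scale feature map for a 256px input

up = UpsampleConvLayer(64, 64, kernel_size=4, stride=2)
res = ResidualBlock(64)

y = up(x)   # transposed conv doubles H and W: 64 -> 128
z = res(y)  # residual block keeps the shape unchanged

print(y.shape)  # torch.Size([1, 64, 128, 128])
print(z.shape)  # torch.Size([1, 64, 128, 128])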
--------------------------------------------------------------------------------
/src/models/changeformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/changeformer/__init__.py
--------------------------------------------------------------------------------
/src/models/tiny_cd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/tiny_cd/__init__.py
--------------------------------------------------------------------------------
/src/models/tiny_cd/change_classifier.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch import Tensor
3 | from torch.nn import Identity, Module, ModuleList
4 |
5 | from .layers import MixingBlock, MixingMaskAttentionBlock, PixelwiseLinear, UpMask
6 |
7 |
8 | class TinyCD(Module):
9 | def __init__(self, bkbn_name="efficientnet_b4", pretrained=True, output_layer_bkbn="3", freeze_backbone=False):
10 | super().__init__()
11 |
12 | # Load the pretrained backbone according to parameters:
13 | self._backbone = _get_backbone(bkbn_name, pretrained, output_layer_bkbn, freeze_backbone)
14 |
15 | # Initialize mixing blocks:
16 | self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1])
17 | self._mixing_mask = ModuleList(
18 | [
19 | MixingMaskAttentionBlock(48, 24, [24, 12, 6], [12, 6, 1]),
20 | MixingMaskAttentionBlock(64, 32, [32, 16, 8], [16, 8, 1]),
21 | MixingBlock(112, 56),
22 | ]
23 | )
24 |
25 | # Initialize Upsampling blocks:
26 | self._up = ModuleList([UpMask(2, 56, 64), UpMask(2, 64, 64), UpMask(2, 64, 32)])
27 |
28 | # Final classification layer:
29 | self._classify = PixelwiseLinear([32, 16, 8], [16, 8, 2], Identity())
30 |
31 | def forward(self, ref: Tensor, test: Tensor) -> Tensor:
32 | features = self._encode(ref, test)
33 | latents = self._decode(features)
34 | return self._classify(latents)
35 |
36 | def _encode(self, ref, test) -> list[Tensor]:
37 | features = [self._first_mix(ref, test)]
38 | for num, layer in enumerate(self._backbone):
39 | ref, test = layer(ref), layer(test)
40 | if num != 0:
41 | features.append(self._mixing_mask[num - 1](ref, test))
42 | return features
43 |
44 | def _decode(self, features) -> Tensor:
45 | upping = features[-1]
46 | for i, j in enumerate(range(-2, -5, -1)):
47 | upping = self._up[i](upping, features[j])
48 | return upping
49 |
50 |
51 | def _get_backbone(bkbn_name, pretrained, output_layer_bkbn, freeze_backbone) -> ModuleList:
52 | # The whole model:
53 | entire_model = getattr(torchvision.models, bkbn_name)(pretrained=pretrained).features
54 |
55 | # Slicing it:
56 | derived_model = ModuleList([])
57 | for name, layer in entire_model.named_children():
58 | derived_model.append(layer)
59 | if name == output_layer_bkbn:
60 | break
61 |
62 | # Freezing the backbone weights:
63 | if freeze_backbone:
64 | for param in derived_model.parameters():
65 | param.requires_grad = False
66 | return derived_model
67 |
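
A minimal forward-pass sketch for TinyCD as defined above (not part of the repository; pretrained=False is used here only to avoid downloading EfficientNet-B4 weights, whereas the class defaults to a pretrained backbone). The three x2 UpMask stages bring the coarsest mixed feature back to full resolution, so the classifier returns 2-channel logits at the input size.

import torch

from src.models.tiny_cd.change_classifier import TinyCD  # assumed import path

model = TinyCD(bkbn_name="efficientnet_b4", pretrained=False, output_layer_bkbn="3")
model.eval()

ref = torch.randn(1, 3, 256, 256)   # pre-change image
test = torch.randn(1, 3, 256, 256)  # post-change image

with torch.no_grad():
    logits = model(ref, test)

print(logits.shape)  # expected: torch.Size([1, 2, 256, 256])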
--------------------------------------------------------------------------------
/src/models/tiny_cd/layers.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor, reshape, stack
2 | from torch.nn import Conv2d, InstanceNorm2d, Module, PReLU, Sequential, Upsample
3 |
4 |
5 | class PixelwiseLinear(Module):
6 | def __init__(self, fin: list[int], fout: list[int], last_activation: Module | None = None) -> None:
7 | assert len(fout) == len(fin)
8 | super().__init__()
9 |
10 | n = len(fin)
11 | self._linears = Sequential(
12 | *[
13 | Sequential(
14 | Conv2d(fin[i], fout[i], kernel_size=1, bias=True),
15 | PReLU() if i < n - 1 or last_activation is None else last_activation,
16 | )
17 | for i in range(n)
18 | ]
19 | )
20 |
21 | def forward(self, x: Tensor) -> Tensor:
22 | # Processing the tensor:
23 | return self._linears(x)
24 |
25 |
26 | class MixingBlock(Module):
27 | def __init__(self, ch_in: int, ch_out: int):
28 | super().__init__()
29 | self._convmix = Sequential(Conv2d(ch_in, ch_out, 3, groups=ch_out, padding=1), PReLU(), InstanceNorm2d(ch_out))
30 |
31 | def forward(self, x: Tensor, y: Tensor) -> Tensor:
32 | # Packing the tensors and interleaving the channels:
33 | mixed = stack((x, y), dim=2)
34 | mixed = reshape(mixed, (x.shape[0], -1, x.shape[2], x.shape[3]))
35 |
36 | # Mixing:
37 | return self._convmix(mixed)
38 |
39 |
40 | class MixingMaskAttentionBlock(Module):
41 | """use the grouped convolution to make a sort of attention"""
42 |
43 | def __init__(self, ch_in: int, ch_out: int, fin: list[int], fout: list[int], generate_masked: bool = False):
44 | super().__init__()
45 | self._mixing = MixingBlock(ch_in, ch_out)
46 | self._linear = PixelwiseLinear(fin, fout)
47 | self._final_normalization = InstanceNorm2d(ch_out) if generate_masked else None
48 | self._mixing_out = MixingBlock(ch_in, ch_out) if generate_masked else None
49 |
50 | def forward(self, x: Tensor, y: Tensor) -> Tensor:
51 | z_mix = self._mixing(x, y)
52 | z = self._linear(z_mix)
53 | z_mix_out = 0 if self._mixing_out is None else self._mixing_out(x, y)
54 |
55 | return z if self._final_normalization is None else self._final_normalization(z_mix_out * z)
56 |
57 |
58 | class UpMask(Module):
59 | def __init__(self, scale_factor: float, nin: int, nout: int):
60 | super().__init__()
61 | self._upsample = Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True)
62 | self._convolution = Sequential(
63 | Conv2d(nin, nin, 3, 1, groups=nin, padding=1),
64 | PReLU(),
65 | InstanceNorm2d(nin),
66 | Conv2d(nin, nout, kernel_size=1, stride=1),
67 | PReLU(),
68 | InstanceNorm2d(nout),
69 | )
70 |
71 | def forward(self, x: Tensor, y: Tensor | None = None) -> Tensor:
72 | x = self._upsample(x)
73 | if y is not None:
74 | x = x * y
75 | return self._convolution(x)
76 |
--------------------------------------------------------------------------------
/test_levircd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import lightning
6 | import pandas as pd
7 | from src.change_detection import ChangeDetectionTask
8 | from src.datasets.levircd import LEVIRCDDataModule
9 | from tqdm import tqdm
10 |
11 |
12 | def main(args):
13 | lightning.seed_everything(0)
14 | checkpoints = glob.glob(f"{args.ckpt_root}/**/checkpoints/epoch*.ckpt")
15 | runs = [ckpt.split(os.sep)[-3] for ckpt in checkpoints]
16 |
17 | metrics = {}
18 | for run, ckpt in tqdm(zip(runs, checkpoints, strict=False), total=len(runs)):
19 | datamodule = LEVIRCDDataModule(
20 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers
21 | )
22 | module = ChangeDetectionTask.load_from_checkpoint(ckpt, map_location="cpu")
23 | trainer = lightning.Trainer(
24 | accelerator=args.accelerator, devices=[args.device], logger=False, precision="16-mixed"
25 | )
26 | metrics[run] = trainer.test(model=module, datamodule=datamodule)[0]
27 | metrics[run]["model"] = module.hparams.model
28 |
29 | metrics = pd.DataFrame.from_dict(metrics, orient="index")
30 | metrics.to_csv(args.output_filename)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--root", type=str, default="./data/levircd")
36 | parser.add_argument("--ckpt-root", type=str, default="lightning_logs")
37 | parser.add_argument("--batch-size", type=int, default=8)
38 | parser.add_argument("--workers", type=int, default=16)
39 | parser.add_argument("--accelerator", type=str, default="gpu")
40 | parser.add_argument("--device", type=int, default=0)
41 | parser.add_argument("--output-filename", type=str, default="metrics.csv")
42 | args = parser.parse_args()
43 | main(args)
44 |
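
The script writes one row per checkpoint to a CSV, with the run name as the index, the logged test metrics as columns, and the model name added as an extra "model" column. A short follow-up sketch (not part of the repository) for summarizing those rows per model across seeds:

import pandas as pd

# metrics.csv as produced by test_levircd.py / test_whucd.py.
metrics = pd.read_csv("metrics.csv", index_col=0)

# The remaining columns are numeric test metrics, so mean/std per model works directly.
summary = metrics.groupby("model").agg(["mean", "std"])
print(summary)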
--------------------------------------------------------------------------------
/test_whucd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 |
5 | import lightning
6 | import pandas as pd
7 | from src.change_detection import ChangeDetectionTask
8 | from src.datasets.whucd import WHUCDDataModule
9 | from tqdm import tqdm
10 |
11 |
12 | def main(args):
13 | lightning.seed_everything(0)
14 | checkpoints = glob.glob(f"{args.ckpt_root}/**/checkpoints/epoch*.ckpt")
15 | runs = [ckpt.split(os.sep)[-3] for ckpt in checkpoints]
16 |
17 | metrics = {}
18 | for run, ckpt in tqdm(zip(runs, checkpoints, strict=False), total=len(runs)):
19 | datamodule = WHUCDDataModule(
20 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers
21 | )
22 | module = ChangeDetectionTask.load_from_checkpoint(ckpt, map_location="cpu")
23 | trainer = lightning.Trainer(
24 | accelerator=args.accelerator, devices=[args.device], logger=False, precision="16-mixed"
25 | )
26 | metrics[run] = trainer.test(model=module, datamodule=datamodule)[0]
27 | metrics[run]["model"] = module.hparams.model
28 |
29 | metrics = pd.DataFrame.from_dict(metrics, orient="index")
30 | metrics.to_csv(args.output_filename)
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--root", type=str, default="/workspace/storage/data/whucd-chipped")
36 | parser.add_argument("--ckpt-root", type=str, default="lightning_logs")
37 | parser.add_argument("--batch-size", type=int, default=8)
38 | parser.add_argument("--workers", type=int, default=16)
39 | parser.add_argument("--accelerator", type=str, default="gpu")
40 | parser.add_argument("--device", type=int, default=0)
41 | parser.add_argument("--output-filename", type=str, default="metrics.csv")
42 | args = parser.parse_args()
43 | main(args)
44 |
--------------------------------------------------------------------------------
/train_levircd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import lightning
4 | from lightning.pytorch.callbacks import ModelCheckpoint
5 | from src.change_detection import ChangeDetectionTask
6 | from src.datasets.levircd import LEVIRCDDataModule
7 |
8 |
9 | def main(args):
10 | for seed in range(args.num_seeds):
11 | lightning.seed_everything(seed)
12 | datamodule = LEVIRCDDataModule(
13 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers
14 | )
15 | datamodule.train_root = args.train_root
16 | module = ChangeDetectionTask(
17 | model=args.model, backbone=args.backbone, weights=True, in_channels=3, num_classes=2, loss="ce", lr=args.lr
18 | )
19 |
20 | callbacks = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=1)
21 | trainer = lightning.Trainer(
22 | accelerator=args.accelerator,
23 | devices=[args.device],
24 | logger=True,
25 | precision="16-mixed",
26 | max_epochs=args.epochs,
27 | log_every_n_steps=10,
28 | default_root_dir=f"logs-levircd-{args.model}",
29 | callbacks=[callbacks],
30 | )
31 | trainer.fit(model=module, datamodule=datamodule)
32 | trainer.test(datamodule=datamodule, ckpt_path="best")
33 |
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("--root", type=str, default="./data/levircd")
38 | parser.add_argument("--train-root", type=str, default="./data/levircd-train-chipped")
39 | parser.add_argument(
40 | "--model",
41 | type=str,
42 | default="unet",
43 | choices=["unet", "fcsiamconc", "fcsiamdiff", "changeformer", "tinycd", "bit"],
44 | )
45 | parser.add_argument(
46 | "--backbone", type=str, default="resnet50", help="only works with unet, fcsiamdiff, or fcsiamconc"
47 | )
48 | parser.add_argument("--epochs", type=int, default=200)
49 | parser.add_argument("--batch-size", type=int, default=8)
50 | parser.add_argument("--workers", type=int, default=8)
51 | parser.add_argument("--lr", type=float, default=0.01)
52 | parser.add_argument("--accelerator", type=str, default="gpu")
53 | parser.add_argument("--device", type=int, default=0)
54 | parser.add_argument("--num_seeds", type=int, default=10)
55 | args = parser.parse_args()
56 | main(args)
57 |
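
A programmatic smoke-test sketch for the training entry point above (not part of the repository). It calls main() with an argparse.Namespace mirroring the script's own flags, shortened to a single seed and epoch; the data paths are the script's defaults and must exist locally, and accelerator should be set to "cpu" if no GPU is available.

from argparse import Namespace

import train_levircd  # assumes the repository root is the working directory

args = Namespace(
    root="./data/levircd",
    train_root="./data/levircd-train-chipped",
    model="unet",
    backbone="resnet50",
    epochs=1,
    batch_size=8,
    workers=8,
    lr=0.01,
    accelerator="gpu",
    device=0,
    num_seeds=1,
)
train_levircd.main(args)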
--------------------------------------------------------------------------------
/train_whucd.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import lightning
4 | from lightning.pytorch.callbacks import ModelCheckpoint
5 | from src.change_detection import ChangeDetectionTask
6 | from src.datasets.whucd import WHUCDDataModule
7 |
8 |
9 | def main(args):
10 | for seed in range(args.num_seeds):
11 | lightning.seed_everything(seed)
12 | datamodule = WHUCDDataModule(
13 | val_split_pct=0.1, root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers
14 | )
15 | module = ChangeDetectionTask(
16 | model=args.model, backbone=args.backbone, weights=True, in_channels=3, num_classes=2, loss="ce", lr=args.lr
17 | )
18 |
19 | callbacks = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=1)
20 | trainer = lightning.Trainer(
21 | accelerator=args.accelerator,
22 | devices=[args.device],
23 | logger=True,
24 | precision="16-mixed",
25 | max_epochs=args.epochs,
26 | log_every_n_steps=10,
27 | default_root_dir=f"logs-whucd-{args.model}",
28 | callbacks=[callbacks],
29 | )
30 | trainer.fit(model=module, datamodule=datamodule)
31 | trainer.test(datamodule=datamodule, ckpt_path="best")
32 |
33 |
34 | if __name__ == "__main__":
35 | parser = argparse.ArgumentParser()
36 | parser.add_argument("--root", type=str, default="/workspace/storage/data/whucd-chipped")
37 | parser.add_argument(
38 | "--model",
39 | type=str,
40 | default="unet",
41 | choices=["unet", "fcsiamconc", "fcsiamdiff", "changeformer", "tinycd", "bit"],
42 | )
43 | parser.add_argument(
44 | "--backbone", type=str, default="resnet50", help="only works with unet, fcsiamdiff, or fcsiamconc"
45 | )
46 | parser.add_argument("--epochs", type=int, default=200)
47 | parser.add_argument("--batch-size", type=int, default=8)
48 | parser.add_argument("--workers", type=int, default=8)
49 | parser.add_argument("--lr", type=float, default=0.01)
50 | parser.add_argument("--accelerator", type=str, default="gpu")
51 | parser.add_argument("--device", type=int, default=0)
52 | parser.add_argument("--num_seeds", type=int, default=10)
53 | args = parser.parse_args()
54 | main(args)
55 |
--------------------------------------------------------------------------------