├── .gitignore ├── LICENSE ├── README.md ├── assets ├── levircd-results.png └── whucd-results.png ├── pyproject.toml ├── requirements.txt ├── scripts ├── preprocess_levircd.ipynb └── preprocess_whucd.ipynb ├── src ├── __init__.py ├── change_detection.py ├── datasets │ ├── __init__.py │ ├── levircd.py │ └── whucd.py └── models │ ├── __init__.py │ ├── bit │ ├── __init__.py │ ├── help_funcs.py │ ├── networks.py │ └── resnet.py │ ├── changeformer │ ├── ChangeFormer.py │ ├── ChangeFormerBaseNetworks.py │ └── __init__.py │ └── tiny_cd │ ├── __init__.py │ ├── change_classifier.py │ └── layers.py ├── test_levircd.py ├── test_whucd.py ├── train_levircd.py └── train_whucd.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | logs* 3 | *.csv 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Isaac Corley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # A Change Detection Reality Check
3 | 
4 | [**Isaac Corley**](https://isaacc.dev/)¹ · [**Caleb Robinson**](https://www.microsoft.com/en-us/research/people/davrob/)² · [**Anthony Ortiz**](https://www.microsoft.com/en-us/research/people/anort/)²
5 | 
6 | ¹University of Texas at San Antonio    ²Microsoft AI for Good Research Lab
7 | 
8 | [Paper PDF](https://arxiv.org/abs/2402.06994)
9 | 
10 | 
11 | 
12 | Code and experiments for the paper ["A Change Detection Reality Check", Isaac Corley, Caleb Robinson, Anthony Ortiz](https://arxiv.org/abs/2402.06994), presented at the [ICLR 2024 Machine Learning for Remote Sensing (ML4RS) Workshop](https://ml-for-rs.github.io/iclr2024/). 
13 | 
14 | ### Summary 
15 | 
16 | The remote sensing literature of the past several years has exploded with proposed deep learning architectures claiming to be the latest state-of-the-art on standard change detection benchmark datasets. However, has the field truly made significant progress? In this paper we perform experiments which show that a simple U-Net segmentation baseline, without training tricks or complicated architectural changes, is still a top performer for the task of change detection. 
17 | 
18 | ### Results 
19 | 
20 | We find that U-Net is still a top performer on the LEVIR-CD and WHU-CD benchmark datasets. See the tables below for comparisons with state-of-the-art methods. 
21 | 22 | 

23 | ![LEVIR-CD results table](assets/levircd-results.png)
24 | Table 1. Comparison of state-of-the-art change detection architectures to a U-Net baseline on the LEVIR-CD dataset. We report the test set precision, recall, and F1 metrics of the positive change class. For the baseline experiments we perform 10 runs while varying the random seed and report metrics from the highest performing run. All other metrics are taken from their respective papers. The top performing methods are highlighted in bold. Gray rows indicate our baseline U-Net and siamese encoder variants. 
26 | 27 | 

28 | ![WHU-CD results table](assets/whucd-results.png)
29 | Table 2. Experimental results on the WHU-CD dataset. We retrain several state-of-the-art methods using the original dataset’s train/test splits instead of the commonly used randomly split preprocessed version created in Bandara & Patel (2022a). We find that these state-of-the-art methods are outperformed by a U-Net baseline. We report the test set precision, recall, F1, and IoU metrics of the positive change class. For each run we select the model checkpoint with the lowest validation set loss. We provide metrics averaged over 10 runs with varying random seeds as well as the best seed. Gray rows indicate our baseline U-Net and siamese encoder variants. 

31 | 
32 | ### Model Checkpoints 
33 | 
34 | **Model checkpoints are uploaded to HuggingFace [here](https://huggingface.co/isaaccorley/a-change-detection-reality-check)!** 
35 | 
36 | 
37 | #### LEVIR-CD 
38 | 
39 | | **Model** | **Backbone** | **Precision** | **Recall** | **F1** | **IoU** | **Checkpoint** | 
40 | |:--------------: |:---------------: |:---------: |:------: |:------: |:------: |:----------: | 
41 | | U-Net | ResNet-50 | 0.9197 | 0.8795 | 0.8991 | 0.8167 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_resnet50.ckpt) | 
42 | | U-Net | EfficientNet-B4 | 0.9269 | 0.8588 | 0.8915 | 0.8044 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_efficientnetb4.ckpt) | 
43 | | U-Net SiamConc | ResNet-50 | 0.9287 | 0.8749 | 0.9010 | 0.8199 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_siamconc_resnet50.ckpt) | 
44 | | U-Net SiamDiff | ResNet-50 | 0.9321 | 0.8730 | 0.9015 | 0.8207 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/levir-cd/unet_siamdiff_resnet50.ckpt) | 
45 | 
46 | #### WHU-CD (using official train/test splits) 
47 | 
48 | | **Model** | **Backbone** | **Precision** | **Recall** | **F1** | **IoU** | **Checkpoint** | 
49 | |:--------------: |:---------: |:---------: |:------: |:------: |:------: |:----------: | 
50 | | U-Net SiamConc | ResNet-50 | 0.8369 | 0.8130 | 0.8217 | 0.7054 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_siamconc_resnet50.ckpt) | 
51 | | U-Net SiamDiff | ResNet-50 | 0.8856 | 0.7741 | 0.8248 | 0.7086 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_siamdiff_resnet50.ckpt) | 
52 | | U-Net | ResNet-50 | 0.8865 | 0.7663 | 0.8200 | 0.7020 | [Checkpoint](https://huggingface.co/isaaccorley/a-change-detection-reality-check/resolve/main/whu-cd/unet_resnet50.ckpt) | 
53 | 
54 | ### Reproducing Results 
55 | 
56 | Download the [LEVIR-CD](https://chenhao.in/LEVIR/) and [WHU-CD](http://gpcv.whu.edu.cn/data/building_dataset.html) datasets and then use the notebooks below (listed after the inference sketch) to chip the datasets into non-overlapping 256x256 patches. 
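To run inference with one of the released checkpoints from the tables above, note that the U-Net variant is a plain `segmentation_models_pytorch` model under the hood. Here is a minimal sketch (hypothetical usage: it assumes the Lightning checkpoint stores the network's parameters under a `model.` prefix, and that the bitemporal pair is concatenated channel-wise as in `src/change_detection.py`):

```python
import torch
import segmentation_models_pytorch as smp

# Rebuild the architecture used for --model unet --backbone resnet50:
# the two RGB images are concatenated channel-wise, hence in_channels=6.
model = smp.Unet(encoder_name="resnet50", in_channels=6, classes=2, encoder_weights=None)

# Load the released Lightning checkpoint (path illustrative) and strip the
# "model." prefix the task wrapper adds to parameter names (assumed layout).
ckpt = torch.load("unet_resnet50.ckpt", map_location="cpu")
state_dict = {k.removeprefix("model."): v for k, v in ckpt["state_dict"].items() if k.startswith("model.")}
model.load_state_dict(state_dict)
model.eval()

# A bitemporal pair, normalized as in the datamodules: (x / 255 - 0.5) / 0.5.
image1 = torch.rand(1, 3, 256, 256)  # pre-change image
image2 = torch.rand(1, 3, 256, 256)  # post-change image
with torch.no_grad():
    logits = model(torch.cat([image1, image2], dim=1))
change_mask = logits.argmax(dim=1)  # 1 = change, 0 = no change
```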
57 | 58 | ```bash 59 | scripts/preprocess_levircd.ipynb 60 | scripts/preprocess_whucd.ipynb 61 | ``` 62 | 63 | To train UNet on both datasets over 10 random seeds run 64 | 65 | ```bash 66 | python train_levircd.py --train-root /path/to/preprocessed-dataset/ --model unet --backbone resnet50 --num_seeds 10 67 | python train_whucd.py --train-root /path/to/preprocessed-dataset/ --model unet --backbone resnet50 --num_seeds 10 68 | ``` 69 | 70 | To evaluate a set of checkpoints and save results to a .csv file run: 71 | 72 | ```bash 73 | python test_levircd.py --root /path/to/preprocessed-dataset/ --ckpt-root lightning_logs/ --output-filename metrics.csv 74 | python test_whucd.py --root /path/to/preprocessed-dataset/ --ckpt-root lightning_logs/ --output-filename metrics.csv 75 | ``` 76 | 77 | ### Citation 78 | 79 | If this work inspired your change detection research, please consider citing our paper: 80 | 81 | ``` 82 | @article{corley2024change, 83 | title={A Change Detection Reality Check}, 84 | author={Corley, Isaac and Robinson, Caleb and Ortiz, Anthony}, 85 | journal={arXiv preprint arXiv:2402.06994}, 86 | year={2024} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /assets/levircd-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/assets/levircd-results.png -------------------------------------------------------------------------------- /assets/whucd-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/assets/whucd-results.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | target-version = "py311" 3 | line-length = 120 4 | src = ["src", "notebooks"] 5 | force-exclude = true 6 | fix = true 7 | show-fixes = true 8 | 9 | [tool.ruff.format] 10 | skip-magic-trailing-comma = true 11 | 12 | [tool.ruff.lint] 13 | extend-select = ["B", "Q", "I", "UP"] 14 | ignore = [ 15 | "E203", 16 | "E402", 17 | "F821", 18 | "F405", 19 | "F403", 20 | "E731", 21 | "B006", 22 | "B008", 23 | "B904", 24 | "E741", 25 | "F401", 26 | ] 27 | 28 | [tool.ruff.lint.pylint] 29 | max-returns = 5 30 | max-args = 25 31 | 32 | [tool.ruff.lint.isort] 33 | split-on-trailing-comma = false -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchgeo[all]==0.6.0 2 | image_bbox_slicer 3 | einops -------------------------------------------------------------------------------- /scripts/preprocess_levircd.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "Chipping directory A\n" 13 | ] 14 | }, 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/A directory does not exist so creating it 
now\n", 20 | " warnings.warn(\n" 21 | ] 22 | }, 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Obtained 7120 image slices!\n", 28 | "Chipping directory B\n" 29 | ] 30 | }, 31 | { 32 | "name": "stderr", 33 | "output_type": "stream", 34 | "text": [ 35 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/B directory does not exist so creating it now\n", 36 | " warnings.warn(\n" 37 | ] 38 | }, 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "Obtained 7120 image slices!\n", 44 | "Chipping directory label\n" 45 | ] 46 | }, 47 | { 48 | "name": "stderr", 49 | "output_type": "stream", 50 | "text": [ 51 | "/home/ubuntu/miniconda3/envs/torchgeo/lib/python3.11/site-packages/image_bbox_slicer/helpers.py:113: UserWarning: Destination ../data/train-chipped/label directory does not exist so creating it now\n", 52 | " warnings.warn(\n" 53 | ] 54 | }, 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Obtained 7120 image slices!\n" 60 | ] 61 | } 62 | ], 63 | "source": [ 64 | "import os\n", 65 | "from pathlib import Path\n", 66 | "\n", 67 | "from image_bbox_slicer import Slicer\n", 68 | "from tqdm import tqdm\n", 69 | "\n", 70 | "directories = [\"A\", \"B\", \"label\"]\n", 71 | "root = \"../data/train\"\n", 72 | "output = \"../data/train-chipped/\"\n", 73 | "\n", 74 | "for directory in directories:\n", 75 | " print(f\"Chipping directory {directory}\")\n", 76 | " slicer = Slicer()\n", 77 | " src = os.path.join(root, directory)\n", 78 | " dst = os.path.join(output, directory)\n", 79 | " slicer.config_image_dirs(img_src=src, img_dst=dst)\n", 80 | " slicer.slice_images_by_size(tile_size=(256, 256), tile_overlap=0)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 3, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "data": { 90 | "text/plain": [ 91 | "['A', 'B', 'label']" 92 | ] 93 | }, 94 | "execution_count": 3, 95 | "metadata": {}, 96 | "output_type": "execute_result" 97 | } 98 | ], 99 | "source": [ 100 | "directories" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "metadata": {}, 107 | "outputs": [ 108 | { 109 | "name": "stderr", 110 | "output_type": "stream", 111 | "text": [ 112 | " 0%| | 0/7120 [00:00 None: 47 | if ignore_index is not None and loss == "jaccard": 48 | warnings.warn("ignore_index has no effect on training when loss='jaccard'", UserWarning, stacklevel=2) 49 | 50 | self.weights = weights 51 | super().__init__(ignore="weights") 52 | 53 | def configure_losses(self) -> None: 54 | """Initialize the loss criterion. 55 | Raises: 56 | ValueError: If *loss* is invalid. 57 | """ 58 | loss: str = self.hparams["loss"] 59 | ignore_index = self.hparams["ignore_index"] 60 | if loss == "ce": 61 | ignore_value = -1000 if ignore_index is None else ignore_index 62 | self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_value, weight=self.hparams["class_weights"]) 63 | elif loss == "jaccard": 64 | self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"]) 65 | elif loss == "focal": 66 | self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True) 67 | else: 68 | raise ValueError(f"Loss type '{loss}' is not valid. 
Currently, supports 'ce', 'jaccard' or 'focal' loss.") 69 | 70 | def configure_metrics(self) -> None: 71 | """Initialize the performance metrics.""" 72 | num_classes: int = self.hparams["num_classes"] 73 | ignore_index: int | None = self.hparams["ignore_index"] 74 | labels: list[str] | None = self.hparams["labels"] 75 | 76 | self.train_metrics = MetricCollection( 77 | { 78 | "OverallAccuracy": Accuracy( 79 | task="multiclass", num_classes=num_classes, average="micro", multidim_average="global" 80 | ), 81 | "OverallF1Score": FBetaScore( 82 | task="multiclass", num_classes=num_classes, beta=1.0, average="micro", multidim_average="global" 83 | ), 84 | "OverallIoU": JaccardIndex( 85 | task="multiclass", num_classes=num_classes, ignore_index=ignore_index, average="micro" 86 | ), 87 | "AverageAccuracy": Accuracy( 88 | task="multiclass", num_classes=num_classes, average="macro", multidim_average="global" 89 | ), 90 | "AverageF1Score": FBetaScore( 91 | task="multiclass", num_classes=num_classes, beta=1.0, average="macro", multidim_average="global" 92 | ), 93 | "AverageIoU": JaccardIndex( 94 | task="multiclass", num_classes=num_classes, ignore_index=ignore_index, average="macro" 95 | ), 96 | "Accuracy": ClasswiseWrapper( 97 | Accuracy(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"), 98 | labels=labels, 99 | ), 100 | "Precision": ClasswiseWrapper( 101 | Precision(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"), 102 | labels=labels, 103 | ), 104 | "Recall": ClasswiseWrapper( 105 | Recall(task="multiclass", num_classes=num_classes, average="none", multidim_average="global"), 106 | labels=labels, 107 | ), 108 | "F1Score": ClasswiseWrapper( 109 | FBetaScore( 110 | task="multiclass", num_classes=num_classes, beta=1.0, average="none", multidim_average="global" 111 | ), 112 | labels=labels, 113 | ), 114 | "IoU": ClasswiseWrapper( 115 | JaccardIndex(task="multiclass", num_classes=num_classes, average="none"), labels=labels 116 | ), 117 | }, 118 | prefix="train_", 119 | ) 120 | self.val_metrics = self.train_metrics.clone(prefix="val_") 121 | self.test_metrics = self.train_metrics.clone(prefix="test_") 122 | 123 | def configure_optimizers(self): 124 | def lambda_rule(epoch): 125 | lr_l = 1.0 - epoch / float(self.trainer.max_epochs + 1) 126 | return lr_l 127 | 128 | optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams["lr"], momentum=0.9, weight_decay=5e-4) 129 | scheduler = scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 130 | 131 | return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}} 132 | 133 | def configure_models(self) -> None: 134 | """Initialize the model. 135 | Raises: 136 | ValueError: If *model* is invalid. 
137 | """ 138 | model: str = self.hparams["model"] 139 | backbone: str = self.hparams["backbone"] 140 | weights = self.weights 141 | in_channels: int = self.hparams["in_channels"] 142 | num_classes: int = self.hparams["num_classes"] 143 | 144 | if model == "unet": 145 | self.model = smp.Unet( 146 | encoder_name=backbone, 147 | encoder_weights="imagenet" if weights is True else None, 148 | in_channels=in_channels * 2, # images are concatenated 149 | classes=num_classes, 150 | ) 151 | elif model == "fcsiamdiff": 152 | self.model = FCSiamDiff( 153 | encoder_name=backbone, 154 | in_channels=in_channels, 155 | classes=num_classes, 156 | encoder_weights="imagenet" if weights is True else None, 157 | ) 158 | elif model == "fcsiamconc": 159 | self.model = FCSiamConc( 160 | encoder_name=backbone, 161 | in_channels=in_channels, 162 | classes=num_classes, 163 | encoder_weights="imagenet" if weights is True else None, 164 | ) 165 | elif model == "bit": 166 | self.model = BIT(arch="base_transformer_pos_s4_dd8") 167 | elif model == "changeformer": 168 | self.model = ChangeFormerV6( 169 | input_nc=in_channels, output_nc=num_classes, decoder_softmax=False, embed_dim=256 170 | ) 171 | elif model == "tinycd": 172 | self.model = TinyCD( 173 | bkbn_name="efficientnet_b4", pretrained=True, output_layer_bkbn="3", freeze_backbone=False 174 | ) 175 | else: 176 | raise ValueError(f"Model type '{model}' is not valid.") 177 | 178 | if weights and weights is not True: 179 | if isinstance(weights, WeightsEnum): 180 | state_dict = weights.get_state_dict(progress=True) 181 | elif os.path.exists(weights): 182 | _, state_dict = utils.extract_backbone(weights) 183 | else: 184 | state_dict = get_weight(weights).get_state_dict(progress=True) 185 | self.model.encoder.load_state_dict(state_dict) 186 | 187 | # Freeze backbone 188 | if self.hparams["freeze_backbone"] and model in ["unet"]: 189 | for param in self.model.encoder.parameters(): 190 | param.requires_grad = False 191 | 192 | # Freeze decoder 193 | if self.hparams["freeze_decoder"] and model in ["unet"]: 194 | for param in self.model.decoder.parameters(): 195 | param.requires_grad = False 196 | 197 | def training_step(self, batch: Any, batch_idx: int) -> Tensor: 198 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"] 199 | 200 | model: str = self.hparams["model"] 201 | if model == "unet": 202 | x = torch.cat([image1, image2], dim=1) 203 | elif model in ["fcsiamdiff", "fcsiamconc"]: 204 | x = torch.stack((image1, image2), dim=1) 205 | 206 | if model in ["bit", "changeformer", "tinycd"]: 207 | y_hat = self(image1, image2) 208 | else: 209 | y_hat = self(x) 210 | 211 | loss: Tensor = self.criterion(y_hat, y) 212 | 213 | self.log("train_loss", loss) 214 | 215 | y_hat = torch.softmax(y_hat, dim=1) 216 | y_hat_hard = y_hat.argmax(dim=1) 217 | 218 | self.train_metrics(y_hat_hard, y) 219 | self.log_dict({f"{k}": v for k, v in self.train_metrics.compute().items()}) 220 | 221 | return loss 222 | 223 | def validation_step(self, batch: Any, batch_idx: int) -> None: 224 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"] 225 | 226 | model: str = self.hparams["model"] 227 | if model == "unet": 228 | x = torch.cat([image1, image2], dim=1) 229 | elif model in ["fcsiamdiff", "fcsiamconc"]: 230 | x = torch.stack((image1, image2), dim=1) 231 | 232 | if model in ["bit", "changeformer", "tinycd"]: 233 | y_hat = self(image1, image2) 234 | else: 235 | y_hat = self(x) 236 | 237 | loss: Tensor = self.criterion(y_hat, y) 238 | 239 | self.log("val_loss", loss, 
on_epoch=True) 240 | 241 | y_hat = torch.softmax(y_hat, dim=1) 242 | y_hat_hard = y_hat.argmax(dim=1) 243 | 244 | self.val_metrics(y_hat_hard, y) 245 | self.log_dict({f"{k}": v for k, v in self.val_metrics.compute().items()}, on_epoch=True) 246 | 247 | if ( 248 | batch_idx < 10 249 | and hasattr(self.trainer, "datamodule") 250 | and hasattr(self.trainer.datamodule, "plot") 251 | and self.logger 252 | and hasattr(self.logger, "experiment") 253 | and hasattr(self.logger.experiment, "add_figure") 254 | ): 255 | datamodule = self.trainer.datamodule 256 | batch["prediction"] = y_hat_hard 257 | for key in ["image1", "image2", "mask", "prediction"]: 258 | batch[key] = batch[key].cpu() 259 | sample = unbind_samples(batch)[0] 260 | 261 | fig: Figure | None = None 262 | try: 263 | fig = datamodule.plot(sample) 264 | except RGBBandsMissingError: 265 | pass 266 | 267 | if fig: 268 | summary_writer = self.logger.experiment 269 | summary_writer.add_figure(f"image/{batch_idx}", fig, global_step=self.global_step) 270 | plt.close() 271 | 272 | def test_step(self, batch: Any, batch_idx: int) -> None: 273 | image1, image2, y = batch["image1"], batch["image2"], batch["mask"] 274 | 275 | model: str = self.hparams["model"] 276 | if model == "unet": 277 | x = torch.cat([image1, image2], dim=1) 278 | elif model in ["fcsiamdiff", "fcsiamconc"]: 279 | x = torch.stack((image1, image2), dim=1) 280 | 281 | if model in ["bit", "changeformer", "tinycd"]: 282 | y_hat = self(image1, image2) 283 | else: 284 | y_hat = self(x) 285 | 286 | loss: Tensor = self.criterion(y_hat, y) 287 | 288 | self.log("test_loss", loss, on_epoch=True) 289 | 290 | y_hat = torch.softmax(y_hat, dim=1) 291 | y_hat_hard = y_hat.argmax(dim=1) 292 | 293 | self.test_metrics(y_hat_hard, y) 294 | self.log_dict({f"{k}": v for k, v in self.test_metrics.compute().items()}, on_epoch=True) 295 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/datasets/__init__.py -------------------------------------------------------------------------------- /src/datasets/levircd.py: -------------------------------------------------------------------------------- 1 | import kornia.augmentation as K 2 | import torchgeo.datamodules 3 | import torchgeo.datasets 4 | from torchgeo.transforms import AugmentationSequential 5 | from torchgeo.transforms.transforms import _ExtractPatches 6 | 7 | 8 | class LEVIRCDDataModule(torchgeo.datamodules.LEVIRCDDataModule): 9 | train_root = "" 10 | 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | self.train_aug = AugmentationSequential( 14 | K.RandomHorizontalFlip(p=0.5), 15 | K.RandomVerticalFlip(p=0.5), 16 | K.RandomResizedCrop(size=self.patch_size, scale=(0.8, 1.0), ratio=(1, 1), p=1.0), 17 | K.Normalize(mean=0.0, std=255.0), 18 | K.Normalize(mean=0.5, std=0.5), 19 | data_keys=["image1", "image2", "mask"], 20 | ) 21 | self.val_aug = AugmentationSequential( 22 | K.Normalize(mean=0.0, std=255.0), 23 | K.Normalize(mean=0.5, std=0.5), 24 | _ExtractPatches(window_size=self.patch_size), 25 | data_keys=["image1", "image2", "mask"], 26 | same_on_batch=True, 27 | ) 28 | self.test_aug = AugmentationSequential( 29 | K.Normalize(mean=0.0, std=255.0), 30 | K.Normalize(mean=0.5, std=0.5), 31 | _ExtractPatches(window_size=self.patch_size), 32 | 
data_keys=["image1", "image2", "mask"], 33 | same_on_batch=True, 34 | ) 35 | 36 | def setup(self, stage: str) -> None: 37 | if stage in ["fit"]: 38 | self.train_dataset = torchgeo.datasets.LEVIRCD(root=self.train_root, split="train") 39 | if stage in ["fit", "validate"]: 40 | self.val_dataset = torchgeo.datasets.LEVIRCD(split="val", **self.kwargs) 41 | if stage == "test": 42 | self.test_dataset = torchgeo.datasets.LEVIRCD(split="test", **self.kwargs) 43 | -------------------------------------------------------------------------------- /src/datasets/whucd.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections.abc import Callable 4 | from typing import Optional 5 | 6 | import kornia.augmentation as K 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | import torchgeo 11 | from matplotlib.figure import Figure 12 | from PIL import Image 13 | from torch import Tensor 14 | from torchgeo.datamodules.geo import NonGeoDataModule 15 | from torchgeo.datamodules.utils import dataset_split 16 | from torchgeo.datasets.utils import percentile_normalization 17 | from torchgeo.transforms import AugmentationSequential 18 | from torchgeo.transforms.transforms import _ExtractPatches 19 | 20 | 21 | class WHUCD(torch.utils.data.Dataset): 22 | splits = ["train", "test"] 23 | 24 | def __init__( 25 | self, 26 | root: str = "data", 27 | split: str = "train", 28 | transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, 29 | ) -> None: 30 | assert split in self.splits 31 | 32 | self.root = root 33 | self.split = split 34 | self.transforms = transforms 35 | self.files = self._load_files(self.root, self.split) 36 | 37 | def __getitem__(self, index: int) -> dict[str, Tensor]: 38 | files = self.files[index] 39 | image1 = self._load_image(files["image1"]) 40 | image2 = self._load_image(files["image2"]) 41 | mask = self._load_target(files["mask"]) 42 | sample = {"image1": image1, "image2": image2, "mask": mask} 43 | 44 | if self.transforms is not None: 45 | sample = self.transforms(sample) 46 | 47 | return sample 48 | 49 | def __len__(self) -> int: 50 | return len(self.files) 51 | 52 | def _load_image(self, path: str) -> Tensor: 53 | filename = os.path.join(path) 54 | with Image.open(filename) as img: 55 | array = np.array(img.convert("RGB")) 56 | tensor = torch.from_numpy(array) 57 | tensor = tensor.float() 58 | tensor = tensor.permute((2, 0, 1)) 59 | return tensor 60 | 61 | def _load_target(self, path: str) -> Tensor: 62 | filename = os.path.join(path) 63 | with Image.open(filename) as img: 64 | array = np.array(img.convert("L")) 65 | tensor = torch.from_numpy(array) 66 | tensor = torch.clamp(tensor, min=0, max=1) 67 | tensor = tensor.to(torch.long) 68 | return tensor 69 | 70 | def plot(self, sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None) -> Figure: 71 | ncols = 3 72 | 73 | image1 = sample["image1"].permute(1, 2, 0).numpy() 74 | image1 = percentile_normalization(image1, lower=0, upper=98, axis=(0, 1)) 75 | 76 | image2 = sample["image2"].permute(1, 2, 0).numpy() 77 | image2 = percentile_normalization(image2, lower=0, upper=98, axis=(0, 1)) 78 | 79 | if "prediction" in sample: 80 | ncols += 1 81 | 82 | fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(10, ncols * 5)) 83 | 84 | axs[0].imshow(image1) 85 | axs[0].axis("off") 86 | axs[1].imshow(image2) 87 | axs[1].axis("off") 88 | axs[2].imshow(sample["mask"], cmap="gray", interpolation="none") 89 | 
axs[2].axis("off") 90 | 91 | if "prediction" in sample: 92 | axs[3].imshow(sample["prediction"], cmap="gray", interpolation="none") 93 | axs[3].axis("off") 94 | if show_titles: 95 | axs[3].set_title("Prediction") 96 | 97 | if show_titles: 98 | axs[0].set_title("Image 1") 99 | axs[1].set_title("Image 2") 100 | axs[2].set_title("Mask") 101 | 102 | if suptitle is not None: 103 | plt.suptitle(suptitle) 104 | 105 | return fig 106 | 107 | def _load_files(self, root: str, split: str) -> list[dict[str, str]]: 108 | images1 = sorted(glob.glob(os.path.join(root, "2012", split, "*.tif"))) 109 | images2 = sorted(glob.glob(os.path.join(root, "2016", split, "*.tif"))) 110 | masks = sorted(glob.glob(os.path.join(root, "change_label", split, "*.tif"))) 111 | 112 | files = [] 113 | for image1, image2, mask in zip(images1, images2, masks, strict=False): 114 | files.append(dict(image1=image1, image2=image2, mask=mask)) 115 | return files 116 | 117 | 118 | class WHUCDDataModule(NonGeoDataModule): 119 | def __init__(self, patch_size: int = 256, val_split_pct: float = 0.1, *args, **kwargs): 120 | super().__init__(WHUCD, *args, **kwargs) 121 | 122 | self.patch_size = (patch_size, patch_size) 123 | self.val_split_pct = val_split_pct 124 | 125 | self.train_aug = AugmentationSequential( 126 | K.RandomHorizontalFlip(p=0.5), 127 | K.RandomVerticalFlip(p=0.5), 128 | K.RandomResizedCrop(size=self.patch_size, scale=(0.8, 1.0), ratio=(1, 1), p=1.0), 129 | K.Normalize(mean=0.0, std=255.0), 130 | K.Normalize(mean=0.5, std=0.5), 131 | data_keys=["image1", "image2", "mask"], 132 | ) 133 | self.val_aug = AugmentationSequential( 134 | K.Normalize(mean=0.0, std=255.0), 135 | K.Normalize(mean=0.5, std=0.5), 136 | _ExtractPatches(window_size=self.patch_size), 137 | data_keys=["image1", "image2", "mask"], 138 | same_on_batch=True, 139 | ) 140 | self.test_aug = AugmentationSequential( 141 | K.Normalize(mean=0.0, std=255.0), 142 | K.Normalize(mean=0.5, std=0.5), 143 | _ExtractPatches(window_size=self.patch_size), 144 | data_keys=["image1", "image2", "mask"], 145 | same_on_batch=True, 146 | ) 147 | 148 | def setup(self, stage: str) -> None: 149 | if stage in ["fit", "validate"]: 150 | self.dataset = WHUCD(split="train", **self.kwargs) 151 | self.train_dataset, self.val_dataset = dataset_split(self.dataset, val_pct=self.val_split_pct) 152 | if stage == "test": 153 | self.test_dataset = WHUCD(split="test", **self.kwargs) 154 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .bit.networks import define_G as BIT 2 | from .changeformer.ChangeFormer import ChangeFormerV6 3 | from .tiny_cd.change_classifier import TinyCD 4 | -------------------------------------------------------------------------------- /src/models/bit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/bit/__init__.py -------------------------------------------------------------------------------- /src/models/bit/help_funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | 7 | class TwoLayerConv2d(nn.Sequential): 8 | def __init__(self, in_channels, out_channels, kernel_size=3): 9 | super().__init__( 10 | 
nn.Conv2d( 11 | in_channels, in_channels, kernel_size=kernel_size, padding=kernel_size // 2, stride=1, bias=False 12 | ), 13 | nn.BatchNorm2d(in_channels), 14 | nn.ReLU(), 15 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2, stride=1), 16 | ) 17 | 18 | 19 | class Residual(nn.Module): 20 | def __init__(self, fn): 21 | super().__init__() 22 | self.fn = fn 23 | 24 | def forward(self, x, **kwargs): 25 | return self.fn(x, **kwargs) + x 26 | 27 | 28 | class Residual2(nn.Module): 29 | def __init__(self, fn): 30 | super().__init__() 31 | self.fn = fn 32 | 33 | def forward(self, x, x2, **kwargs): 34 | return self.fn(x, x2, **kwargs) + x 35 | 36 | 37 | class PreNorm(nn.Module): 38 | def __init__(self, dim, fn): 39 | super().__init__() 40 | self.norm = nn.LayerNorm(dim) 41 | self.fn = fn 42 | 43 | def forward(self, x, **kwargs): 44 | return self.fn(self.norm(x), **kwargs) 45 | 46 | 47 | class PreNorm2(nn.Module): 48 | def __init__(self, dim, fn): 49 | super().__init__() 50 | self.norm = nn.LayerNorm(dim) 51 | self.fn = fn 52 | 53 | def forward(self, x, x2, **kwargs): 54 | return self.fn(self.norm(x), self.norm(x2), **kwargs) 55 | 56 | 57 | class FeedForward(nn.Module): 58 | def __init__(self, dim, hidden_dim, dropout=0.0): 59 | super().__init__() 60 | self.net = nn.Sequential( 61 | nn.Linear(dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout) 62 | ) 63 | 64 | def forward(self, x): 65 | return self.net(x) 66 | 67 | 68 | class Cross_Attention(nn.Module): 69 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, softmax=True): 70 | super().__init__() 71 | inner_dim = dim_head * heads 72 | self.heads = heads 73 | self.scale = dim**-0.5 74 | 75 | self.softmax = softmax 76 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 77 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 78 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 79 | 80 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 81 | 82 | def forward(self, x, m, mask=None): 83 | _b, _n, _, h = *x.shape, self.heads 84 | q = self.to_q(x) 85 | k = self.to_k(m) 86 | v = self.to_v(m) 87 | 88 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), [q, k, v]) 89 | 90 | dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale 91 | mask_value = -torch.finfo(dots.dtype).max 92 | 93 | if mask is not None: 94 | mask = F.pad(mask.flatten(1), (1, 0), value=True) 95 | assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions" 96 | mask = mask[:, None, :] * mask[:, :, None] 97 | dots.masked_fill_(~mask, mask_value) 98 | del mask 99 | 100 | if self.softmax: 101 | attn = dots.softmax(dim=-1) 102 | else: 103 | attn = dots 104 | # attn = dots 105 | # vis_tmp(dots) 106 | 107 | out = torch.einsum("bhij,bhjd->bhid", attn, v) 108 | out = rearrange(out, "b h n d -> b n (h d)") 109 | out = self.to_out(out) 110 | # vis_tmp2(out) 111 | 112 | return out 113 | 114 | 115 | class Attention(nn.Module): 116 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 117 | super().__init__() 118 | inner_dim = dim_head * heads 119 | self.heads = heads 120 | self.scale = dim**-0.5 121 | 122 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 123 | self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 124 | 125 | def forward(self, x, mask=None): 126 | _b, _n, _, h = *x.shape, self.heads 127 | qkv = self.to_qkv(x).chunk(3, dim=-1) 128 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) 
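        # Shapes at this point: q, k, v are (batch, heads, seq_len, dim_head).
        # The einsum below computes scaled dot-product attention scores of shape
        # (batch, heads, seq_len, seq_len); masked positions are filled with the
        # dtype's most negative value so they contribute ~0 weight after softmax.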
129 | 130 | dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale 131 | mask_value = -torch.finfo(dots.dtype).max 132 | 133 | if mask is not None: 134 | mask = F.pad(mask.flatten(1), (1, 0), value=True) 135 | assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions" 136 | mask = mask[:, None, :] * mask[:, :, None] 137 | dots.masked_fill_(~mask, mask_value) 138 | del mask 139 | 140 | attn = dots.softmax(dim=-1) 141 | 142 | out = torch.einsum("bhij,bhjd->bhid", attn, v) 143 | out = rearrange(out, "b h n d -> b n (h d)") 144 | out = self.to_out(out) 145 | return out 146 | 147 | 148 | class Transformer(nn.Module): 149 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout): 150 | super().__init__() 151 | self.layers = nn.ModuleList([]) 152 | for _ in range(depth): 153 | self.layers.append( 154 | nn.ModuleList( 155 | [ 156 | Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))), 157 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))), 158 | ] 159 | ) 160 | ) 161 | 162 | def forward(self, x, mask=None): 163 | for attn, ff in self.layers: 164 | x = attn(x, mask=mask) 165 | x = ff(x) 166 | return x 167 | 168 | 169 | class TransformerDecoder(nn.Module): 170 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout, softmax=True): 171 | super().__init__() 172 | self.layers = nn.ModuleList([]) 173 | for _ in range(depth): 174 | self.layers.append( 175 | nn.ModuleList( 176 | [ 177 | Residual2( 178 | PreNorm2( 179 | dim, 180 | Cross_Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout, softmax=softmax), 181 | ) 182 | ), 183 | Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))), 184 | ] 185 | ) 186 | ) 187 | 188 | def forward(self, x, m, mask=None): 189 | """target(query), memory""" 190 | for attn, ff in self.layers: 191 | x = attn(x, m, mask=mask) 192 | x = ff(x) 193 | return x 194 | -------------------------------------------------------------------------------- /src/models/bit/networks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | from torch.nn import init 8 | from torch.optim import lr_scheduler 9 | 10 | from .help_funcs import Transformer, TransformerDecoder, TwoLayerConv2d 11 | from .resnet import resnet18, resnet34, resnet50 12 | 13 | ############################################################################### 14 | # Helper Functions 15 | ############################################################################### 16 | 17 | 18 | def get_scheduler(optimizer, args): 19 | """Return a learning rate scheduler 20 | 21 | Parameters: 22 | optimizer -- the optimizer of the network 23 | args (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 24 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 25 | 26 | For 'linear', we keep the same learning rate for the first epochs 27 | and linearly decay the rate to zero over the next epochs. 28 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 29 | See https://pytorch.org/docs/stable/optim.html for more details. 
30 |     """
31 |     if args.lr_policy == "linear":
32 | 
33 |         def lambda_rule(epoch):
34 |             lr_l = 1.0 - epoch / float(args.max_epochs + 1)
35 |             return lr_l
36 | 
37 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
38 |     elif args.lr_policy == "step":
39 |         step_size = args.max_epochs // 3
40 |         # args.lr_decay_iters
41 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
42 |     else:
43 |         raise NotImplementedError(f"learning rate policy [{args.lr_policy}] is not implemented")
44 |     return scheduler
45 | 
46 | 
47 | class Identity(nn.Module):
48 |     def forward(self, x):
49 |         return x
50 | 
51 | 
52 | def get_norm_layer(norm_type="instance"):
53 |     """Return a normalization layer
54 | 
55 |     Parameters:
56 |         norm_type (str) -- the name of the normalization layer: batch | instance | none
57 | 
58 |     For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
59 |     For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
60 |     """
61 |     if norm_type == "batch":
62 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
63 |     elif norm_type == "instance":
64 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
65 |     elif norm_type == "none":
66 |         norm_layer = lambda x: Identity()
67 |     else:
68 |         raise NotImplementedError(f"normalization layer [{norm_type}] is not found")
69 |     return norm_layer
70 | 
71 | 
72 | def init_weights(net, init_type="normal", init_gain=0.02):
73 |     """Initialize network weights.
74 | 
75 |     Parameters:
76 |         net (network) -- network to be initialized
77 |         init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
78 |         init_gain (float) -- scaling factor for normal, xavier and orthogonal.
79 | 
80 |     We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
81 |     work better for some applications. Feel free to try yourself.
82 |     """
83 | 
84 |     def init_func(m):  # define the initialization function
85 |         classname = m.__class__.__name__
86 |         if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
87 |             if init_type == "normal":
88 |                 init.normal_(m.weight.data, 0.0, init_gain)
89 |             elif init_type == "xavier":
90 |                 init.xavier_normal_(m.weight.data, gain=init_gain)
91 |             elif init_type == "kaiming":
92 |                 init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
93 |             elif init_type == "orthogonal":
94 |                 init.orthogonal_(m.weight.data, gain=init_gain)
95 |             else:
96 |                 raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
97 |             if hasattr(m, "bias") and m.bias is not None:
98 |                 init.constant_(m.bias.data, 0.0)
99 |         elif (
100 |             classname.find("BatchNorm2d") != -1
101 |         ):  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
102 |             init.normal_(m.weight.data, 1.0, init_gain)
103 |             init.constant_(m.bias.data, 0.0)
104 | 
105 |     print(f"initialize network with {init_type}")
106 |     net.apply(init_func)  # apply the initialization function
107 | 
108 | 
109 | def init_net(net, init_type="normal", init_gain=0.02, gpu_ids=[]):
110 |     """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
111 |     Parameters:
112 |         net (network) -- the network to be initialized
113 |         init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
114 |         gain (float) -- scaling factor for normal, xavier and orthogonal. 
115 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 116 | 117 | Return an initialized network. 118 | """ 119 | """ 120 | if len(gpu_ids) > 0: 121 | assert torch.cuda.is_available() 122 | net.to(gpu_ids[0]) 123 | if len(gpu_ids) > 1: 124 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 125 | """ 126 | init_weights(net, init_type, init_gain=init_gain) 127 | return net 128 | 129 | 130 | def define_G(arch, init_type="normal", init_gain=0.02, gpu_ids=[]): 131 | if arch == "base_resnet18": 132 | net = ResNet(input_nc=3, output_nc=2, output_sigmoid=False) 133 | 134 | elif arch == "base_transformer_pos_s4": 135 | net = BASE_Transformer(input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, with_pos="learned") 136 | 137 | elif arch == "base_transformer_pos_s4_dd8": 138 | net = BASE_Transformer( 139 | input_nc=3, output_nc=2, token_len=4, resnet_stages_num=4, with_pos="learned", enc_depth=1, dec_depth=8 140 | ) 141 | 142 | elif arch == "base_transformer_pos_s4_dd8_dedim8": 143 | net = BASE_Transformer( 144 | input_nc=3, 145 | output_nc=2, 146 | token_len=4, 147 | resnet_stages_num=4, 148 | with_pos="learned", 149 | enc_depth=1, 150 | dec_depth=8, 151 | decoder_dim_head=8, 152 | ) 153 | 154 | else: 155 | raise NotImplementedError(f"Generator model name [{arch}] is not recognized") 156 | return init_net(net, init_type, init_gain, gpu_ids) 157 | 158 | 159 | ############################################################################### 160 | # main Functions 161 | ############################################################################### 162 | 163 | 164 | class ResNet(torch.nn.Module): 165 | def __init__( 166 | self, input_nc, output_nc, resnet_stages_num=5, backbone="resnet18", output_sigmoid=False, if_upsample_2x=True 167 | ): 168 | """ 169 | In the constructor we instantiate two nn.Linear modules and assign them as 170 | member variables. 
171 | """ 172 | super().__init__() 173 | expand = 1 174 | if backbone == "resnet18": 175 | self.resnet = resnet18(pretrained=True, replace_stride_with_dilation=[False, True, True]) 176 | elif backbone == "resnet34": 177 | self.resnet = resnet34(pretrained=True, replace_stride_with_dilation=[False, True, True]) 178 | elif backbone == "resnet50": 179 | self.resnet = resnet50(pretrained=True, replace_stride_with_dilation=[False, True, True]) 180 | expand = 4 181 | else: 182 | raise NotImplementedError 183 | self.relu = nn.ReLU() 184 | self.upsamplex2 = nn.Upsample(scale_factor=2) 185 | self.upsamplex4 = nn.Upsample(scale_factor=4, mode="bilinear") 186 | 187 | self.classifier = TwoLayerConv2d(in_channels=32, out_channels=output_nc) 188 | 189 | self.resnet_stages_num = resnet_stages_num 190 | 191 | self.if_upsample_2x = if_upsample_2x 192 | if self.resnet_stages_num == 5: 193 | layers = 512 * expand 194 | elif self.resnet_stages_num == 4: 195 | layers = 256 * expand 196 | elif self.resnet_stages_num == 3: 197 | layers = 128 * expand 198 | else: 199 | raise NotImplementedError 200 | self.conv_pred = nn.Conv2d(layers, 32, kernel_size=3, padding=1) 201 | 202 | self.output_sigmoid = output_sigmoid 203 | self.sigmoid = nn.Sigmoid() 204 | 205 | def forward(self, x1, x2): 206 | x1 = self.forward_single(x1) 207 | x2 = self.forward_single(x2) 208 | x = torch.abs(x1 - x2) 209 | if not self.if_upsample_2x: 210 | x = self.upsamplex2(x) 211 | x = self.upsamplex4(x) 212 | x = self.classifier(x) 213 | 214 | if self.output_sigmoid: 215 | x = self.sigmoid(x) 216 | return x 217 | 218 | def forward_single(self, x): 219 | # resnet layers 220 | x = self.resnet.conv1(x) 221 | x = self.resnet.bn1(x) 222 | x = self.resnet.relu(x) 223 | x = self.resnet.maxpool(x) 224 | 225 | x_4 = self.resnet.layer1(x) # 1/4, in=64, out=64 226 | x_8 = self.resnet.layer2(x_4) # 1/8, in=64, out=128 227 | 228 | if self.resnet_stages_num > 3: 229 | x_8 = self.resnet.layer3(x_8) # 1/8, in=128, out=256 230 | 231 | if self.resnet_stages_num == 5: 232 | x_8 = self.resnet.layer4(x_8) # 1/32, in=256, out=512 233 | elif self.resnet_stages_num > 5: 234 | raise NotImplementedError 235 | 236 | if self.if_upsample_2x: 237 | x = self.upsamplex2(x_8) 238 | else: 239 | x = x_8 240 | # output layers 241 | x = self.conv_pred(x) 242 | return x 243 | 244 | 245 | class BASE_Transformer(ResNet): 246 | """ 247 | Resnet of 8 downsampling + BIT + bitemporal feature Differencing + a small CNN 248 | """ 249 | 250 | def __init__( 251 | self, 252 | input_nc, 253 | output_nc, 254 | with_pos, 255 | resnet_stages_num=5, 256 | token_len=4, 257 | token_trans=True, 258 | enc_depth=1, 259 | dec_depth=1, 260 | dim_head=64, 261 | decoder_dim_head=64, 262 | tokenizer=True, 263 | if_upsample_2x=True, 264 | pool_mode="max", 265 | pool_size=2, 266 | backbone="resnet18", 267 | decoder_softmax=True, 268 | with_decoder_pos=None, 269 | with_decoder=True, 270 | ): 271 | super().__init__( 272 | input_nc, output_nc, backbone=backbone, resnet_stages_num=resnet_stages_num, if_upsample_2x=if_upsample_2x 273 | ) 274 | self.token_len = token_len 275 | self.conv_a = nn.Conv2d(32, self.token_len, kernel_size=1, padding=0, bias=False) 276 | self.tokenizer = tokenizer 277 | if not self.tokenizer: 278 | # if not use tokenzier,then downsample the feature map into a certain size 279 | self.pooling_size = pool_size 280 | self.pool_mode = pool_mode 281 | self.token_len = self.pooling_size * self.pooling_size 282 | 283 | self.token_trans = token_trans 284 | self.with_decoder = with_decoder 285 | dim 
= 32 286 | mlp_dim = 2 * dim 287 | 288 | self.with_pos = with_pos 289 | if with_pos == "learned": 290 | self.pos_embedding = nn.Parameter(torch.randn(1, self.token_len * 2, 32)) 291 | decoder_pos_size = 256 // 4 292 | self.with_decoder_pos = with_decoder_pos 293 | if self.with_decoder_pos == "learned": 294 | self.pos_embedding_decoder = nn.Parameter(torch.randn(1, 32, decoder_pos_size, decoder_pos_size)) 295 | self.enc_depth = enc_depth 296 | self.dec_depth = dec_depth 297 | self.dim_head = dim_head 298 | self.decoder_dim_head = decoder_dim_head 299 | self.transformer = Transformer( 300 | dim=dim, depth=self.enc_depth, heads=8, dim_head=self.dim_head, mlp_dim=mlp_dim, dropout=0 301 | ) 302 | self.transformer_decoder = TransformerDecoder( 303 | dim=dim, 304 | depth=self.dec_depth, 305 | heads=8, 306 | dim_head=self.decoder_dim_head, 307 | mlp_dim=mlp_dim, 308 | dropout=0, 309 | softmax=decoder_softmax, 310 | ) 311 | 312 | def _forward_semantic_tokens(self, x): 313 | b, c, h, w = x.shape 314 | spatial_attention = self.conv_a(x) 315 | spatial_attention = spatial_attention.view([b, self.token_len, -1]).contiguous() 316 | spatial_attention = torch.softmax(spatial_attention, dim=-1) 317 | x = x.view([b, c, -1]).contiguous() 318 | tokens = torch.einsum("bln,bcn->blc", spatial_attention, x) 319 | 320 | return tokens 321 | 322 | def _forward_reshape_tokens(self, x): 323 | # b,c,h,w = x.shape 324 | if self.pool_mode == "max": 325 | x = F.adaptive_max_pool2d(x, [self.pooling_size, self.pooling_size]) 326 | elif self.pool_mode == "ave": 327 | x = F.adaptive_avg_pool2d(x, [self.pooling_size, self.pooling_size]) 328 | else: 329 | x = x 330 | tokens = rearrange(x, "b c h w -> b (h w) c") 331 | return tokens 332 | 333 | def _forward_transformer(self, x): 334 | if self.with_pos: 335 | x += self.pos_embedding 336 | x = self.transformer(x) 337 | return x 338 | 339 | def _forward_transformer_decoder(self, x, m): 340 | b, c, h, w = x.shape 341 | if self.with_decoder_pos == "fix": 342 | x = x + self.pos_embedding_decoder 343 | elif self.with_decoder_pos == "learned": 344 | x = x + self.pos_embedding_decoder 345 | x = rearrange(x, "b c h w -> b (h w) c") 346 | x = self.transformer_decoder(x, m) 347 | x = rearrange(x, "b (h w) c -> b c h w", h=h) 348 | return x 349 | 350 | def _forward_simple_decoder(self, x, m): 351 | b, c, h, w = x.shape 352 | b, l, c = m.shape 353 | m = m.expand([h, w, b, l, c]) 354 | m = rearrange(m, "h w b l c -> l b c h w") 355 | m = m.sum(0) 356 | x = x + m 357 | return x 358 | 359 | def forward(self, x1, x2): 360 | # forward backbone resnet 361 | x1 = self.forward_single(x1) 362 | x2 = self.forward_single(x2) 363 | 364 | # forward tokenzier 365 | if self.tokenizer: 366 | token1 = self._forward_semantic_tokens(x1) 367 | token2 = self._forward_semantic_tokens(x2) 368 | else: 369 | token1 = self._forward_reshape_tokens(x1) 370 | token2 = self._forward_reshape_tokens(x2) 371 | # forward transformer encoder 372 | if self.token_trans: 373 | self.tokens_ = torch.cat([token1, token2], dim=1) 374 | self.tokens = self._forward_transformer(self.tokens_) 375 | token1, token2 = self.tokens.chunk(2, dim=1) 376 | # forward transformer decoder 377 | if self.with_decoder: 378 | x1 = self._forward_transformer_decoder(x1, token1) 379 | x2 = self._forward_transformer_decoder(x2, token2) 380 | else: 381 | x1 = self._forward_simple_decoder(x1, token1) 382 | x2 = self._forward_simple_decoder(x2, token2) 383 | # feature differencing 384 | x = torch.abs(x1 - x2) 385 | if not self.if_upsample_2x: 386 | x = 
self.upsamplex2(x) 387 | x = self.upsamplex4(x) 388 | # forward small cnn 389 | x = self.classifier(x) 390 | if self.output_sigmoid: 391 | x = self.sigmoid(x) 392 | return x 393 | -------------------------------------------------------------------------------- /src/models/bit/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | __all__ = [ 6 | "ResNet", 7 | "resnet18", 8 | "resnet34", 9 | "resnet50", 10 | "resnet101", 11 | "resnet152", 12 | "resnext50_32x4d", 13 | "resnext101_32x8d", 14 | "wide_resnet50_2", 15 | "wide_resnet101_2", 16 | ] 17 | 18 | 19 | model_urls = { 20 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 21 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 22 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 23 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 24 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 25 | "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", 26 | "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", 27 | "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", 28 | "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", 29 | } 30 | 31 | 32 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d( 35 | in_planes, 36 | out_planes, 37 | kernel_size=3, 38 | stride=stride, 39 | padding=dilation, 40 | groups=groups, 41 | bias=False, 42 | dilation=dilation, 43 | ) 44 | 45 | 46 | def conv1x1(in_planes, out_planes, stride=1): 47 | """1x1 convolution""" 48 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 49 | 50 | 51 | class BasicBlock(nn.Module): 52 | expansion = 1 53 | 54 | def __init__( 55 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 56 | ): 57 | super().__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 62 | if dilation > 1: 63 | dilation = 1 64 | # raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 65 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 66 | self.conv1 = conv3x3(inplanes, planes, stride) 67 | self.bn1 = norm_layer(planes) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.conv2 = conv3x3(planes, planes) 70 | self.bn2 = norm_layer(planes) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | identity = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | 84 | if self.downsample is not None: 85 | identity = self.downsample(x) 86 | 87 | out += identity 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class Bottleneck(nn.Module): 94 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 95 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 96 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
97 | # This variant is also known as ResNet V1.5 and improves accuracy according to 98 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 99 | 100 | expansion = 4 101 | 102 | def __init__( 103 | self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None 104 | ): 105 | super().__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | width = int(planes * (base_width / 64.0)) * groups 109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 110 | self.conv1 = conv1x1(inplanes, width) 111 | self.bn1 = norm_layer(width) 112 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 113 | self.bn2 = norm_layer(width) 114 | self.conv3 = conv1x1(width, planes * self.expansion) 115 | self.bn3 = norm_layer(planes * self.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.downsample = downsample 118 | self.stride = stride 119 | 120 | def forward(self, x): 121 | identity = x 122 | 123 | out = self.conv1(x) 124 | out = self.bn1(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv2(out) 128 | out = self.bn2(out) 129 | out = self.relu(out) 130 | 131 | out = self.conv3(out) 132 | out = self.bn3(out) 133 | 134 | if self.downsample is not None: 135 | identity = self.downsample(x) 136 | 137 | out += identity 138 | out = self.relu(out) 139 | 140 | return out 141 | 142 | 143 | class ResNet(nn.Module): 144 | def __init__( 145 | self, 146 | block, 147 | layers, 148 | num_classes=1000, 149 | zero_init_residual=False, 150 | groups=1, 151 | width_per_group=64, 152 | replace_stride_with_dilation=None, 153 | norm_layer=None, 154 | strides=None, 155 | ): 156 | super().__init__() 157 | if norm_layer is None: 158 | norm_layer = nn.BatchNorm2d 159 | self._norm_layer = norm_layer 160 | 161 | self.strides = strides 162 | if self.strides is None: 163 | self.strides = [2, 2, 2, 2, 2] 164 | 165 | self.inplanes = 64 166 | self.dilation = 1 167 | if replace_stride_with_dilation is None: 168 | # each element in the tuple indicates if we should replace 169 | # the 2x2 stride with a dilated convolution instead 170 | replace_stride_with_dilation = [False, False, False] 171 | if len(replace_stride_with_dilation) != 3: 172 | raise ValueError( 173 | f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}" 174 | ) 175 | self.groups = groups 176 | self.base_width = width_per_group 177 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, bias=False) 178 | self.bn1 = norm_layer(self.inplanes) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1) 181 | self.layer1 = self._make_layer(block, 64, layers[0]) 182 | self.layer2 = self._make_layer( 183 | block, 128, layers[1], stride=self.strides[2], dilate=replace_stride_with_dilation[0] 184 | ) 185 | self.layer3 = self._make_layer( 186 | block, 256, layers[2], stride=self.strides[3], dilate=replace_stride_with_dilation[1] 187 | ) 188 | self.layer4 = self._make_layer( 189 | block, 512, layers[3], stride=self.strides[4], dilate=replace_stride_with_dilation[2] 190 | ) 191 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 192 | self.fc = nn.Linear(512 * block.expansion, num_classes) 193 | 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 197 | elif isinstance(m, nn.BatchNorm2d | nn.GroupNorm): 198 | 
nn.init.constant_(m.weight, 1) 199 | nn.init.constant_(m.bias, 0) 200 | 201 | # Zero-initialize the last BN in each residual branch, 202 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 203 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 204 | if zero_init_residual: 205 | for m in self.modules(): 206 | if isinstance(m, Bottleneck): 207 | nn.init.constant_(m.bn3.weight, 0) 208 | elif isinstance(m, BasicBlock): 209 | nn.init.constant_(m.bn2.weight, 0) 210 | 211 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 212 | norm_layer = self._norm_layer 213 | downsample = None 214 | previous_dilation = self.dilation 215 | if dilate: 216 | self.dilation *= stride 217 | stride = 1 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion) 221 | ) 222 | 223 | layers = [] 224 | layers.append( 225 | block( 226 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 227 | ) 228 | ) 229 | self.inplanes = planes * block.expansion 230 | for _ in range(1, blocks): 231 | layers.append( 232 | block( 233 | self.inplanes, 234 | planes, 235 | groups=self.groups, 236 | base_width=self.base_width, 237 | dilation=self.dilation, 238 | norm_layer=norm_layer, 239 | ) 240 | ) 241 | 242 | return nn.Sequential(*layers) 243 | 244 | def _forward_impl(self, x): 245 | # See note [TorchScript super()] 246 | x = self.conv1(x) 247 | x = self.bn1(x) 248 | x = self.relu(x) 249 | x = self.maxpool(x) 250 | 251 | x = self.layer1(x) 252 | x = self.layer2(x) 253 | x = self.layer3(x) 254 | x = self.layer4(x) 255 | 256 | x = self.avgpool(x) 257 | x = torch.flatten(x, 1) 258 | x = self.fc(x) 259 | 260 | return x 261 | 262 | def forward(self, x): 263 | return self._forward_impl(x) 264 | 265 | 266 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 267 | model = ResNet(block, layers, **kwargs) 268 | if pretrained: 269 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 270 | model.load_state_dict(state_dict) 271 | return model 272 | 273 | 274 | def resnet18(pretrained=False, progress=True, **kwargs): 275 | r"""ResNet-18 model from 276 | `"Deep Residual Learning for Image Recognition" `_ 277 | 278 | Args: 279 | pretrained (bool): If True, returns a model pre-trained on ImageNet 280 | progress (bool): If True, displays a progress bar of the download to stderr 281 | """ 282 | return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) 283 | 284 | 285 | def resnet34(pretrained=False, progress=True, **kwargs): 286 | r"""ResNet-34 model from 287 | `"Deep Residual Learning for Image Recognition" `_ 288 | 289 | Args: 290 | pretrained (bool): If True, returns a model pre-trained on ImageNet 291 | progress (bool): If True, displays a progress bar of the download to stderr 292 | """ 293 | return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) 294 | 295 | 296 | def resnet50(pretrained=False, progress=True, **kwargs): 297 | r"""ResNet-50 model from 298 | `"Deep Residual Learning for Image Recognition" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, 
progress, **kwargs) 305 | 306 | 307 | def resnet101(pretrained=False, progress=True, **kwargs): 308 | r"""ResNet-101 model from 309 | `"Deep Residual Learning for Image Recognition" `_ 310 | 311 | Args: 312 | pretrained (bool): If True, returns a model pre-trained on ImageNet 313 | progress (bool): If True, displays a progress bar of the download to stderr 314 | """ 315 | return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 316 | 317 | 318 | def resnet152(pretrained=False, progress=True, **kwargs): 319 | r"""ResNet-152 model from 320 | `"Deep Residual Learning for Image Recognition" `_ 321 | 322 | Args: 323 | pretrained (bool): If True, returns a model pre-trained on ImageNet 324 | progress (bool): If True, displays a progress bar of the download to stderr 325 | """ 326 | return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) 327 | 328 | 329 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 330 | r"""ResNeXt-50 32x4d model from 331 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 332 | 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | kwargs["groups"] = 32 338 | kwargs["width_per_group"] = 4 339 | return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 340 | 341 | 342 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 343 | r"""ResNeXt-101 32x8d model from 344 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 345 | 346 | Args: 347 | pretrained (bool): If True, returns a model pre-trained on ImageNet 348 | progress (bool): If True, displays a progress bar of the download to stderr 349 | """ 350 | kwargs["groups"] = 32 351 | kwargs["width_per_group"] = 8 352 | return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 353 | 354 | 355 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 356 | r"""Wide ResNet-50-2 model from 357 | `"Wide Residual Networks" `_ 358 | 359 | The model is the same as ResNet except for the bottleneck number of channels 360 | which is twice larger in every block. The number of channels in outer 1x1 361 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 362 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 363 | 364 | Args: 365 | pretrained (bool): If True, returns a model pre-trained on ImageNet 366 | progress (bool): If True, displays a progress bar of the download to stderr 367 | """ 368 | kwargs["width_per_group"] = 64 * 2 369 | return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) 370 | 371 | 372 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 373 | r"""Wide ResNet-101-2 model from 374 | `"Wide Residual Networks" `_ 375 | 376 | The model is the same as ResNet except for the bottleneck number of channels 377 | which is twice larger in every block. The number of channels in outer 1x1 378 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 379 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
380 | 381 | Args: 382 | pretrained (bool): If True, returns a model pre-trained on ImageNet 383 | progress (bool): If True, displays a progress bar of the download to stderr 384 | """ 385 | kwargs["width_per_group"] = 64 * 2 386 | return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) 387 | -------------------------------------------------------------------------------- /src/models/changeformer/ChangeFormer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional 8 | import torch.nn.functional as F 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | 11 | from .ChangeFormerBaseNetworks import ConvLayer, ResidualBlock, UpsampleConvLayer 12 | 13 | 14 | class EncoderTransformer(nn.Module): 15 | def __init__( 16 | self, 17 | img_size=256, 18 | patch_size=16, 19 | in_chans=3, 20 | num_classes=2, 21 | embed_dims=[64, 128, 256, 512], 22 | num_heads=[1, 2, 4, 8], 23 | mlp_ratios=[4, 4, 4, 4], 24 | qkv_bias=False, 25 | qk_scale=None, 26 | drop_rate=0.0, 27 | attn_drop_rate=0.0, 28 | drop_path_rate=0.0, 29 | norm_layer=nn.LayerNorm, 30 | depths=[3, 4, 6, 3], 31 | sr_ratios=[8, 4, 2, 1], 32 | ): 33 | super().__init__() 34 | self.num_classes = num_classes 35 | self.depths = depths 36 | 37 | # patch embedding definitions 38 | self.patch_embed1 = OverlapPatchEmbed( 39 | img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0] 40 | ) 41 | self.patch_embed2 = OverlapPatchEmbed( 42 | img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1] 43 | ) 44 | self.patch_embed3 = OverlapPatchEmbed( 45 | img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2] 46 | ) 47 | self.patch_embed4 = OverlapPatchEmbed( 48 | img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3] 49 | ) 50 | 51 | # main encoder 52 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 53 | cur = 0 54 | self.block1 = nn.ModuleList( 55 | [ 56 | Block( 57 | dim=embed_dims[0], 58 | num_heads=num_heads[0], 59 | mlp_ratio=mlp_ratios[0], 60 | qkv_bias=qkv_bias, 61 | qk_scale=qk_scale, 62 | drop=drop_rate, 63 | attn_drop=attn_drop_rate, 64 | drop_path=dpr[cur + i], 65 | norm_layer=norm_layer, 66 | sr_ratio=sr_ratios[0], 67 | ) 68 | for i in range(depths[0]) 69 | ] 70 | ) 71 | self.norm1 = norm_layer(embed_dims[0]) 72 | # intra-patch encoder 73 | self.patch_block1 = nn.ModuleList( 74 | [ 75 | Block( 76 | dim=embed_dims[1], 77 | num_heads=num_heads[0], 78 | mlp_ratio=mlp_ratios[0], 79 | qkv_bias=qkv_bias, 80 | qk_scale=qk_scale, 81 | drop=drop_rate, 82 | attn_drop=attn_drop_rate, 83 | drop_path=dpr[cur + i], 84 | norm_layer=norm_layer, 85 | sr_ratio=sr_ratios[0], 86 | ) 87 | for i in range(1) 88 | ] 89 | ) 90 | self.pnorm1 = norm_layer(embed_dims[1]) 91 | # main encoder 92 | cur += depths[0] 93 | self.block2 = nn.ModuleList( 94 | [ 95 | Block( 96 | dim=embed_dims[1], 97 | num_heads=num_heads[1], 98 | mlp_ratio=mlp_ratios[1], 99 | qkv_bias=qkv_bias, 100 | qk_scale=qk_scale, 101 | drop=drop_rate, 102 | attn_drop=attn_drop_rate, 103 | drop_path=dpr[cur + i], 104 | norm_layer=norm_layer, 105 | sr_ratio=sr_ratios[1], 106 | ) 107 | for i in range(depths[1]) 108 | ] 109 | ) 110 | self.norm2 = norm_layer(embed_dims[1]) 111 | # 
intra-patch encoder 112 | self.patch_block2 = nn.ModuleList( 113 | [ 114 | Block( 115 | dim=embed_dims[2], 116 | num_heads=num_heads[1], 117 | mlp_ratio=mlp_ratios[1], 118 | qkv_bias=qkv_bias, 119 | qk_scale=qk_scale, 120 | drop=drop_rate, 121 | attn_drop=attn_drop_rate, 122 | drop_path=dpr[cur + i], 123 | norm_layer=norm_layer, 124 | sr_ratio=sr_ratios[1], 125 | ) 126 | for i in range(1) 127 | ] 128 | ) 129 | self.pnorm2 = norm_layer(embed_dims[2]) 130 | # main encoder 131 | cur += depths[1] 132 | self.block3 = nn.ModuleList( 133 | [ 134 | Block( 135 | dim=embed_dims[2], 136 | num_heads=num_heads[2], 137 | mlp_ratio=mlp_ratios[2], 138 | qkv_bias=qkv_bias, 139 | qk_scale=qk_scale, 140 | drop=drop_rate, 141 | attn_drop=attn_drop_rate, 142 | drop_path=dpr[cur + i], 143 | norm_layer=norm_layer, 144 | sr_ratio=sr_ratios[2], 145 | ) 146 | for i in range(depths[2]) 147 | ] 148 | ) 149 | self.norm3 = norm_layer(embed_dims[2]) 150 | # intra-patch encoder 151 | self.patch_block3 = nn.ModuleList( 152 | [ 153 | Block( 154 | dim=embed_dims[3], 155 | num_heads=num_heads[1], 156 | mlp_ratio=mlp_ratios[2], 157 | qkv_bias=qkv_bias, 158 | qk_scale=qk_scale, 159 | drop=drop_rate, 160 | attn_drop=attn_drop_rate, 161 | drop_path=dpr[cur + i], 162 | norm_layer=norm_layer, 163 | sr_ratio=sr_ratios[2], 164 | ) 165 | for i in range(1) 166 | ] 167 | ) 168 | self.pnorm3 = norm_layer(embed_dims[3]) 169 | # main encoder 170 | cur += depths[2] 171 | self.block4 = nn.ModuleList( 172 | [ 173 | Block( 174 | dim=embed_dims[3], 175 | num_heads=num_heads[3], 176 | mlp_ratio=mlp_ratios[3], 177 | qkv_bias=qkv_bias, 178 | qk_scale=qk_scale, 179 | drop=drop_rate, 180 | attn_drop=attn_drop_rate, 181 | drop_path=dpr[cur + i], 182 | norm_layer=norm_layer, 183 | sr_ratio=sr_ratios[3], 184 | ) 185 | for i in range(depths[3]) 186 | ] 187 | ) 188 | self.norm4 = norm_layer(embed_dims[3]) 189 | 190 | self.apply(self._init_weights) 191 | 192 | def _init_weights(self, m): 193 | if isinstance(m, nn.Linear): 194 | trunc_normal_(m.weight, std=0.02) 195 | if isinstance(m, nn.Linear) and m.bias is not None: 196 | nn.init.constant_(m.bias, 0) 197 | elif isinstance(m, nn.LayerNorm): 198 | nn.init.constant_(m.bias, 0) 199 | nn.init.constant_(m.weight, 1.0) 200 | elif isinstance(m, nn.Conv2d): 201 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 202 | fan_out //= m.groups 203 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 204 | if m.bias is not None: 205 | m.bias.data.zero_() 206 | 207 | def reset_drop_path(self, drop_path_rate): 208 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 209 | cur = 0 210 | for i in range(self.depths[0]): 211 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 212 | 213 | cur += self.depths[0] 214 | for i in range(self.depths[1]): 215 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 216 | 217 | cur += self.depths[1] 218 | for i in range(self.depths[2]): 219 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 220 | 221 | cur += self.depths[2] 222 | for i in range(self.depths[3]): 223 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 224 | 225 | def forward_features(self, x): 226 | B = x.shape[0] 227 | outs = [] 228 | embed_dims = [64, 128, 320, 512] 229 | # stage 1 230 | x1, H1, W1 = self.patch_embed1(x) 231 | 232 | for _i, blk in enumerate(self.block1): 233 | x1 = blk(x1, H1, W1) 234 | x1 = self.norm1(x1) 235 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 236 | 237 | outs.append(x1) 238 | 239 | # stage 2 240 | x1, H1, W1 = self.patch_embed2(x1) 
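# (Note added for clarity: patch_embed2 returns tokens of shape (B, H1*W1, C). The next
# two lines reshape them into a (B, C, H1, W1) feature map and immediately flatten back
# to (B, H1*W1, C), a shape round-trip that leaves the token values unchanged.)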
241 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[1], H1, W1) 242 | 243 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1) 244 | 245 | for _i, blk in enumerate(self.block2): 246 | x1 = blk(x1, H1, W1) 247 | x1 = self.norm2(x1) 248 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 249 | outs.append(x1) 250 | 251 | # stage 3 252 | x1, H1, W1 = self.patch_embed3(x1) 253 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[2], H1, W1) 254 | 255 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1) 256 | 257 | for _i, blk in enumerate(self.block3): 258 | x1 = blk(x1, H1, W1) 259 | x1 = self.norm3(x1) 260 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 261 | outs.append(x1) 262 | 263 | # stage 4 264 | x1, H1, W1 = self.patch_embed4(x1) 265 | x1 = x1.permute(0, 2, 1).reshape(B, embed_dims[3], H1, W1) # +x2 266 | 267 | x1 = x1.view(x1.shape[0], x1.shape[1], -1).permute(0, 2, 1) 268 | 269 | for _i, blk in enumerate(self.block4): 270 | x1 = blk(x1, H1, W1) 271 | x1 = self.norm4(x1) 272 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 273 | outs.append(x1) 274 | 275 | return outs 276 | 277 | def forward(self, x): 278 | x = self.forward_features(x) 279 | 280 | return x 281 | 282 | 283 | class OverlapPatchEmbed(nn.Module): 284 | """Image to Patch Embedding""" 285 | 286 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 287 | super().__init__() 288 | img_size = to_2tuple(img_size) 289 | patch_size = to_2tuple(patch_size) 290 | 291 | self.img_size = img_size 292 | self.patch_size = patch_size 293 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 294 | self.num_patches = self.H * self.W 295 | self.proj = nn.Conv2d( 296 | in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2) 297 | ) 298 | self.norm = nn.LayerNorm(embed_dim) 299 | 300 | self.apply(self._init_weights) 301 | 302 | def _init_weights(self, m): 303 | if isinstance(m, nn.Linear): 304 | trunc_normal_(m.weight, std=0.02) 305 | if isinstance(m, nn.Linear) and m.bias is not None: 306 | nn.init.constant_(m.bias, 0) 307 | elif isinstance(m, nn.LayerNorm): 308 | nn.init.constant_(m.bias, 0) 309 | nn.init.constant_(m.weight, 1.0) 310 | elif isinstance(m, nn.Conv2d): 311 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 312 | fan_out //= m.groups 313 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 314 | if m.bias is not None: 315 | m.bias.data.zero_() 316 | 317 | def forward(self, x): 318 | # pdb.set_trace() 319 | x = self.proj(x) 320 | _, _, H, W = x.shape 321 | x = x.flatten(2).transpose(1, 2) 322 | x = self.norm(x) 323 | 324 | return x, H, W 325 | 326 | 327 | def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=True): 328 | if warning: 329 | if size is not None and align_corners: 330 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 331 | output_h, output_w = tuple(int(x) for x in size) 332 | if output_h > input_h or output_w > input_w: 333 | if ( 334 | (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1) 335 | and (output_h - 1) % (input_h - 1) 336 | and (output_w - 1) % (input_w - 1) 337 | ): 338 | warnings.warn( 339 | f"When align_corners={align_corners}, " 340 | "the output would be more aligned if " 341 | f"input size {(input_h, input_w)} is `x+1` and " 342 | f"out size {(output_h, output_w)} is `nx+1`", 343 | stacklevel=2, 344 | ) 345 | return F.interpolate(input, size, scale_factor, 
mode, align_corners) 346 | 347 | 348 | class Mlp(nn.Module): 349 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0): 350 | super().__init__() 351 | out_features = out_features or in_features 352 | hidden_features = hidden_features or in_features 353 | self.fc1 = nn.Linear(in_features, hidden_features) 354 | self.dwconv = DWConv(hidden_features) 355 | self.act = act_layer() 356 | self.fc2 = nn.Linear(hidden_features, out_features) 357 | self.drop = nn.Dropout(drop) 358 | 359 | self.apply(self._init_weights) 360 | 361 | def _init_weights(self, m): 362 | if isinstance(m, nn.Linear): 363 | trunc_normal_(m.weight, std=0.02) 364 | if isinstance(m, nn.Linear) and m.bias is not None: 365 | nn.init.constant_(m.bias, 0) 366 | elif isinstance(m, nn.LayerNorm): 367 | nn.init.constant_(m.bias, 0) 368 | nn.init.constant_(m.weight, 1.0) 369 | elif isinstance(m, nn.Conv2d): 370 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 371 | fan_out //= m.groups 372 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 373 | if m.bias is not None: 374 | m.bias.data.zero_() 375 | 376 | def forward(self, x, H, W): 377 | x = self.fc1(x) 378 | x = self.dwconv(x, H, W) 379 | x = self.act(x) 380 | x = self.drop(x) 381 | x = self.fc2(x) 382 | x = self.drop(x) 383 | return x 384 | 385 | 386 | class Attention(nn.Module): 387 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): 388 | super().__init__() 389 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 390 | 391 | self.dim = dim 392 | self.num_heads = num_heads 393 | head_dim = dim // num_heads 394 | self.scale = qk_scale or head_dim**-0.5 395 | 396 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 397 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 398 | self.attn_drop = nn.Dropout(attn_drop) 399 | self.proj = nn.Linear(dim, dim) 400 | self.proj_drop = nn.Dropout(proj_drop) 401 | 402 | self.sr_ratio = sr_ratio 403 | if sr_ratio > 1: 404 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 405 | self.norm = nn.LayerNorm(dim) 406 | 407 | self.apply(self._init_weights) 408 | 409 | def _init_weights(self, m): 410 | if isinstance(m, nn.Linear): 411 | trunc_normal_(m.weight, std=0.02) 412 | if isinstance(m, nn.Linear) and m.bias is not None: 413 | nn.init.constant_(m.bias, 0) 414 | elif isinstance(m, nn.LayerNorm): 415 | nn.init.constant_(m.bias, 0) 416 | nn.init.constant_(m.weight, 1.0) 417 | elif isinstance(m, nn.Conv2d): 418 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 419 | fan_out //= m.groups 420 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 421 | if m.bias is not None: 422 | m.bias.data.zero_() 423 | 424 | def forward(self, x, H, W): 425 | B, N, C = x.shape 426 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 427 | 428 | if self.sr_ratio > 1: 429 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 430 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 431 | x_ = self.norm(x_) 432 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 433 | else: 434 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 435 | k, v = kv[0], kv[1] 436 | 437 | attn = (q @ k.transpose(-2, -1)) * self.scale 438 | attn = attn.softmax(dim=-1) 439 | attn = self.attn_drop(attn) 440 | 441 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 442 | x = self.proj(x) 443 | x = 
self.proj_drop(x) 444 | 445 | return x 446 | 447 | 448 | class Attention_dec(nn.Module): 449 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1): 450 | super().__init__() 451 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 452 | 453 | self.dim = dim 454 | self.num_heads = num_heads 455 | head_dim = dim // num_heads 456 | self.scale = qk_scale or head_dim**-0.5 457 | 458 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 459 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 460 | self.attn_drop = nn.Dropout(attn_drop) 461 | self.proj = nn.Linear(dim, dim) 462 | self.proj_drop = nn.Dropout(proj_drop) 463 | 464 | self.task_query = nn.Parameter(torch.randn(1, 48, dim)) 465 | self.sr_ratio = sr_ratio 466 | if sr_ratio > 1: 467 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 468 | self.norm = nn.LayerNorm(dim) 469 | 470 | self.apply(self._init_weights) 471 | 472 | def _init_weights(self, m): 473 | if isinstance(m, nn.Linear): 474 | trunc_normal_(m.weight, std=0.02) 475 | if isinstance(m, nn.Linear) and m.bias is not None: 476 | nn.init.constant_(m.bias, 0) 477 | elif isinstance(m, nn.LayerNorm): 478 | nn.init.constant_(m.bias, 0) 479 | nn.init.constant_(m.weight, 1.0) 480 | elif isinstance(m, nn.Conv2d): 481 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 482 | fan_out //= m.groups 483 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 484 | if m.bias is not None: 485 | m.bias.data.zero_() 486 | 487 | def forward(self, x, H, W): 488 | B, N, C = x.shape 489 | task_q = self.task_query 490 | 491 | # This is because we fix the task parameters to be of a certain dimension, so with varying batch size, we just stack up the same queries to operate on the entire batch 492 | if B > 1: 493 | task_q = task_q.unsqueeze(0).repeat(B, 1, 1, 1) 494 | task_q = task_q.squeeze(1) 495 | 496 | q = self.q(task_q).reshape(B, task_q.shape[1], self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 497 | 498 | if self.sr_ratio > 1: 499 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 500 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 501 | x_ = self.norm(x_) 502 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 503 | else: 504 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 505 | k, v = kv[0], kv[1] 506 | q = torch.nn.functional.interpolate(q, size=(v.shape[2], v.shape[3])) 507 | attn = (q @ k.transpose(-2, -1)) * self.scale 508 | attn = attn.softmax(dim=-1) 509 | attn = self.attn_drop(attn) 510 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 511 | x = self.proj(x) 512 | x = self.proj_drop(x) 513 | 514 | return x 515 | 516 | 517 | class Block_dec(nn.Module): 518 | def __init__( 519 | self, 520 | dim, 521 | num_heads, 522 | mlp_ratio=4.0, 523 | qkv_bias=False, 524 | qk_scale=None, 525 | drop=0.0, 526 | attn_drop=0.0, 527 | drop_path=0.0, 528 | act_layer=nn.GELU, 529 | norm_layer=nn.LayerNorm, 530 | sr_ratio=1, 531 | ): 532 | super().__init__() 533 | self.norm1 = norm_layer(dim) 534 | self.attn = Attention_dec( 535 | dim, 536 | num_heads=num_heads, 537 | qkv_bias=qkv_bias, 538 | qk_scale=qk_scale, 539 | attn_drop=attn_drop, 540 | proj_drop=drop, 541 | sr_ratio=sr_ratio, 542 | ) 543 | 544 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 545 | self.norm2 = norm_layer(dim) 546 | mlp_hidden_dim = int(dim * mlp_ratio) 547 | self.mlp = Mlp(in_features=dim, 
hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 548 | 549 | self.apply(self._init_weights) 550 | 551 | def _init_weights(self, m): 552 | if isinstance(m, nn.Linear): 553 | trunc_normal_(m.weight, std=0.02) 554 | if isinstance(m, nn.Linear) and m.bias is not None: 555 | nn.init.constant_(m.bias, 0) 556 | elif isinstance(m, nn.LayerNorm): 557 | nn.init.constant_(m.bias, 0) 558 | nn.init.constant_(m.weight, 1.0) 559 | elif isinstance(m, nn.Conv2d): 560 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 561 | fan_out //= m.groups 562 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 563 | if m.bias is not None: 564 | m.bias.data.zero_() 565 | 566 | def forward(self, x, H, W): 567 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 568 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 569 | 570 | return x 571 | 572 | 573 | class Block(nn.Module): 574 | def __init__( 575 | self, 576 | dim, 577 | num_heads, 578 | mlp_ratio=4.0, 579 | qkv_bias=False, 580 | qk_scale=None, 581 | drop=0.0, 582 | attn_drop=0.0, 583 | drop_path=0.0, 584 | act_layer=nn.GELU, 585 | norm_layer=nn.LayerNorm, 586 | sr_ratio=1, 587 | ): 588 | super().__init__() 589 | self.norm1 = norm_layer(dim) 590 | self.attn = Attention( 591 | dim, 592 | num_heads=num_heads, 593 | qkv_bias=qkv_bias, 594 | qk_scale=qk_scale, 595 | attn_drop=attn_drop, 596 | proj_drop=drop, 597 | sr_ratio=sr_ratio, 598 | ) 599 | 600 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 601 | self.norm2 = norm_layer(dim) 602 | mlp_hidden_dim = int(dim * mlp_ratio) 603 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 604 | 605 | self.apply(self._init_weights) 606 | 607 | def _init_weights(self, m): 608 | if isinstance(m, nn.Linear): 609 | trunc_normal_(m.weight, std=0.02) 610 | if isinstance(m, nn.Linear) and m.bias is not None: 611 | nn.init.constant_(m.bias, 0) 612 | elif isinstance(m, nn.LayerNorm): 613 | nn.init.constant_(m.bias, 0) 614 | nn.init.constant_(m.weight, 1.0) 615 | elif isinstance(m, nn.Conv2d): 616 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 617 | fan_out //= m.groups 618 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 619 | if m.bias is not None: 620 | m.bias.data.zero_() 621 | 622 | def forward(self, x, H, W): 623 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 624 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 625 | return x 626 | 627 | 628 | class DWConv(nn.Module): 629 | def __init__(self, dim=768): 630 | super().__init__() 631 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 632 | 633 | def forward(self, x, H, W): 634 | B, N, C = x.shape 635 | x = x.transpose(1, 2).view(B, C, H, W) 636 | x = self.dwconv(x) 637 | x = x.flatten(2).transpose(1, 2) 638 | 639 | return x 640 | 641 | 642 | class Tenc(EncoderTransformer): 643 | def __init__(self, **kwargs): 644 | super().__init__( 645 | patch_size=16, 646 | embed_dims=[64, 128, 320, 512], 647 | num_heads=[1, 2, 4, 8], 648 | mlp_ratios=[4, 4, 4, 4], 649 | qkv_bias=True, 650 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 651 | depths=[3, 4, 6, 3], 652 | sr_ratios=[8, 4, 2, 1], 653 | drop_rate=0.0, 654 | drop_path_rate=0.1, 655 | ) 656 | 657 | 658 | class convprojection(nn.Module): 659 | def __init__(self, path=None, **kwargs): 660 | super().__init__() 661 | 662 | self.convd32x = UpsampleConvLayer(512, 512, kernel_size=4, stride=2) 663 | self.convd16x = UpsampleConvLayer(512, 320, kernel_size=4, stride=2) 664 | 
self.dense_4 = nn.Sequential(ResidualBlock(320)) 665 | self.convd8x = UpsampleConvLayer(320, 128, kernel_size=4, stride=2) 666 | self.dense_3 = nn.Sequential(ResidualBlock(128)) 667 | self.convd4x = UpsampleConvLayer(128, 64, kernel_size=4, stride=2) 668 | self.dense_2 = nn.Sequential(ResidualBlock(64)) 669 | self.convd2x = UpsampleConvLayer(64, 16, kernel_size=4, stride=2) 670 | self.dense_1 = nn.Sequential(ResidualBlock(16)) 671 | self.convd1x = UpsampleConvLayer(16, 8, kernel_size=4, stride=2) 672 | self.conv_output = ConvLayer(8, 2, kernel_size=3, stride=1, padding=1) 673 | 674 | self.active = nn.Tanh() 675 | 676 | def forward(self, x1, x2): 677 | res32x = self.convd32x(x2[0]) 678 | 679 | if x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]: 680 | p2d = (0, -1, 0, -1) 681 | res32x = F.pad(res32x, p2d, "constant", 0) 682 | 683 | elif x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] == res32x.shape[2]: 684 | p2d = (0, -1, 0, 0) 685 | res32x = F.pad(res32x, p2d, "constant", 0) 686 | elif x1[3].shape[3] == res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]: 687 | p2d = (0, 0, 0, -1) 688 | res32x = F.pad(res32x, p2d, "constant", 0) 689 | 690 | res16x = res32x + x1[3] 691 | res16x = self.convd16x(res16x) 692 | 693 | if x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]: 694 | p2d = (0, -1, 0, -1) 695 | res16x = F.pad(res16x, p2d, "constant", 0) 696 | elif x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] == res16x.shape[2]: 697 | p2d = (0, -1, 0, 0) 698 | res16x = F.pad(res16x, p2d, "constant", 0) 699 | elif x1[2].shape[3] == res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]: 700 | p2d = (0, 0, 0, -1) 701 | res16x = F.pad(res16x, p2d, "constant", 0) 702 | 703 | res8x = self.dense_4(res16x) + x1[2] 704 | res8x = self.convd8x(res8x) 705 | res4x = self.dense_3(res8x) + x1[1] 706 | res4x = self.convd4x(res4x) 707 | res2x = self.dense_2(res4x) + x1[0] 708 | res2x = self.convd2x(res2x) 709 | x = res2x 710 | x = self.dense_1(x) 711 | x = self.convd1x(x) 712 | 713 | return x 714 | 715 | 716 | class convprojection_base(nn.Module): 717 | def __init__(self, path=None, **kwargs): 718 | super().__init__() 719 | 720 | # self.convd32x = UpsampleConvLayer(512, 512, kernel_size=4, stride=2) 721 | self.convd16x = UpsampleConvLayer(512, 320, kernel_size=4, stride=2) 722 | self.dense_4 = nn.Sequential(ResidualBlock(320)) 723 | self.convd8x = UpsampleConvLayer(320, 128, kernel_size=4, stride=2) 724 | self.dense_3 = nn.Sequential(ResidualBlock(128)) 725 | self.convd4x = UpsampleConvLayer(128, 64, kernel_size=4, stride=2) 726 | self.dense_2 = nn.Sequential(ResidualBlock(64)) 727 | self.convd2x = UpsampleConvLayer(64, 16, kernel_size=4, stride=2) 728 | self.dense_1 = nn.Sequential(ResidualBlock(16)) 729 | self.convd1x = UpsampleConvLayer(16, 8, kernel_size=4, stride=2) 730 | 731 | def forward(self, x1): 732 | # if x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]: 733 | # p2d = (0,-1,0,-1) 734 | # res32x = F.pad(res32x,p2d,"constant",0) 735 | 736 | # elif x1[3].shape[3] != res32x.shape[3] and x1[3].shape[2] == res32x.shape[2]: 737 | # p2d = (0,-1,0,0) 738 | # res32x = F.pad(res32x,p2d,"constant",0) 739 | # elif x1[3].shape[3] == res32x.shape[3] and x1[3].shape[2] != res32x.shape[2]: 740 | # p2d = (0,0,0,-1) 741 | # res32x = F.pad(res32x,p2d,"constant",0) 742 | 743 | # res16x = res32x + x1[3] 744 | res16x = self.convd16x(x1[3]) 745 | 746 | if x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]: 747 | p2d = (0, -1, 0, 
-1) 748 | res16x = F.pad(res16x, p2d, "constant", 0) 749 | elif x1[2].shape[3] != res16x.shape[3] and x1[2].shape[2] == res16x.shape[2]: 750 | p2d = (0, -1, 0, 0) 751 | res16x = F.pad(res16x, p2d, "constant", 0) 752 | elif x1[2].shape[3] == res16x.shape[3] and x1[2].shape[2] != res16x.shape[2]: 753 | p2d = (0, 0, 0, -1) 754 | res16x = F.pad(res16x, p2d, "constant", 0) 755 | 756 | res8x = self.dense_4(res16x) + x1[2] 757 | res8x = self.convd8x(res8x) 758 | res4x = self.dense_3(res8x) + x1[1] 759 | res4x = self.convd4x(res4x) 760 | res2x = self.dense_2(res4x) + x1[0] 761 | res2x = self.convd2x(res2x) 762 | x = res2x 763 | x = self.dense_1(x) 764 | x = self.convd1x(x) 765 | return x 766 | 767 | 768 | ### This is the basic ChangeFormer module 769 | class ChangeFormerV1(nn.Module): 770 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False): 771 | super().__init__() 772 | 773 | self.Tenc = Tenc() 774 | 775 | self.convproj = convprojection_base() 776 | 777 | self.change_probability = ConvLayer(8, output_nc, kernel_size=3, stride=1, padding=1) 778 | 779 | self.output_softmax = decoder_softmax 780 | self.active = torch.nn.Softmax(dim=1) 781 | 782 | def forward(self, x1, x2): 783 | fx1 = self.Tenc(x1) 784 | fx2 = self.Tenc(x2) 785 | 786 | DI = [] 787 | for i in range(0, 4): 788 | DI.append(torch.abs(fx1[i] - fx2[i])) 789 | 790 | cp = self.convproj(DI) 791 | 792 | cp = self.change_probability(cp) 793 | 794 | if self.output_softmax: 795 | cp = self.active(cp) 796 | 797 | return cp 798 | 799 | 800 | # Transformer Decoder 801 | class MLP(nn.Module): 802 | """ 803 | Linear Embedding 804 | """ 805 | 806 | def __init__(self, input_dim=2048, embed_dim=768): 807 | super().__init__() 808 | self.proj = nn.Linear(input_dim, embed_dim) 809 | 810 | def forward(self, x): 811 | x = x.flatten(2).transpose(1, 2) 812 | x = self.proj(x) 813 | return x 814 | 815 | 816 | class TDec(nn.Module): 817 | """ 818 | Transformer Decoder 819 | """ 820 | 821 | def __init__( 822 | self, 823 | input_transform="multiple_select", 824 | in_index=[0, 1, 2, 3], 825 | align_corners=True, 826 | in_channels=[64, 128, 256, 512], 827 | embedding_dim=256, 828 | output_nc=2, 829 | decoder_softmax=False, 830 | feature_strides=[4, 8, 16, 32], 831 | ): 832 | super().__init__() 833 | assert len(feature_strides) == len(in_channels) 834 | assert min(feature_strides) == feature_strides[0] 835 | self.feature_strides = feature_strides 836 | 837 | # input transforms 838 | self.input_transform = input_transform 839 | self.in_index = in_index 840 | self.align_corners = align_corners 841 | 842 | # MLP 843 | self.in_channels = in_channels 844 | self.embedding_dim = embedding_dim 845 | 846 | # Final prediction 847 | self.output_nc = output_nc 848 | 849 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels 850 | 851 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) 852 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) 853 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) 854 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) 855 | 856 | self.linear_fuse = nn.Conv2d(in_channels=self.embedding_dim * 4, out_channels=self.embedding_dim, kernel_size=1) 857 | 858 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 859 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 860 | self.dense_2x = 
nn.Sequential(ResidualBlock(self.embedding_dim)) 861 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 862 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim)) 863 | 864 | # Final prediction 865 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1) 866 | self.output_softmax = decoder_softmax 867 | self.active = nn.Softmax(dim=1) 868 | 869 | def _transform_inputs(self, inputs): 870 | """Transform inputs for decoder. 871 | Args: 872 | inputs (list[Tensor]): List of multi-level img features. 873 | Returns: 874 | Tensor: The transformed inputs 875 | """ 876 | 877 | if self.input_transform == "resize_concat": 878 | inputs = [inputs[i] for i in self.in_index] 879 | upsampled_inputs = [ 880 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) 881 | for x in inputs 882 | ] 883 | inputs = torch.cat(upsampled_inputs, dim=1) 884 | elif self.input_transform == "multiple_select": 885 | inputs = [inputs[i] for i in self.in_index] 886 | else: 887 | inputs = inputs[self.in_index] 888 | 889 | return inputs 890 | 891 | def forward(self, inputs): 892 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 893 | c1, c2, c3, c4 = x 894 | 895 | ############## MLP decoder on C1-C4 ########### 896 | n, _, h, w = c4.shape 897 | 898 | _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3]) 899 | _c4 = resize(_c4, size=c1.size()[2:], mode="bilinear", align_corners=False) 900 | 901 | _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3]) 902 | _c3 = resize(_c3, size=c1.size()[2:], mode="bilinear", align_corners=False) 903 | 904 | _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3]) 905 | _c2 = resize(_c2, size=c1.size()[2:], mode="bilinear", align_corners=False) 906 | 907 | _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3]) 908 | 909 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 910 | 911 | x = self.convd2x(_c) 912 | x = self.dense_2x(x) 913 | x = self.convd1x(x) 914 | x = self.dense_1x(x) 915 | 916 | cp = self.change_probability(x) 917 | if self.output_softmax: 918 | cp = self.active(cp) 919 | 920 | return cp 921 | 922 | 923 | class TDecV2(nn.Module): 924 | """ 925 | Transformer Decoder 926 | """ 927 | 928 | def __init__( 929 | self, 930 | input_transform="multiple_select", 931 | in_index=[0, 1, 2, 3], 932 | align_corners=True, 933 | in_channels=[64, 128, 256, 512], 934 | embedding_dim=256, 935 | output_nc=2, 936 | decoder_softmax=False, 937 | feature_strides=[4, 8, 16, 32], 938 | ): 939 | super().__init__() 940 | assert len(feature_strides) == len(in_channels) 941 | assert min(feature_strides) == feature_strides[0] 942 | self.feature_strides = feature_strides 943 | 944 | # input transforms 945 | self.input_transform = input_transform 946 | self.in_index = in_index 947 | self.align_corners = align_corners 948 | 949 | # MLP 950 | self.in_channels = in_channels 951 | self.embedding_dim = embedding_dim 952 | 953 | # Final prediction 954 | self.output_nc = output_nc 955 | 956 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels 957 | 958 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) 959 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) 960 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) 961 | 
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) 962 | 963 | self.linear_fuse = nn.Conv2d(in_channels=self.embedding_dim * 4, out_channels=self.embedding_dim, kernel_size=1) 964 | 965 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 966 | # self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 967 | # self.dense_2x = nn.Sequential( ResidualBlock(self.embedding_dim)) 968 | # self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 969 | # self.dense_1x = nn.Sequential( ResidualBlock(self.embedding_dim)) 970 | 971 | # Pixel Shuffle 972 | self.pix_shuffle_conv = nn.Conv2d( 973 | in_channels=self.embedding_dim, out_channels=16 * output_nc, kernel_size=3, stride=1, padding=1 974 | ) 975 | self.relu = nn.ReLU() 976 | self.pix_shuffle = nn.PixelShuffle(4) 977 | 978 | # Final prediction 979 | # self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1) 980 | 981 | # Final activation 982 | self.output_softmax = decoder_softmax 983 | self.active = nn.Softmax(dim=1) 984 | 985 | def _transform_inputs(self, inputs): 986 | """Transform inputs for decoder. 987 | Args: 988 | inputs (list[Tensor]): List of multi-level img features. 989 | Returns: 990 | Tensor: The transformed inputs 991 | """ 992 | 993 | if self.input_transform == "resize_concat": 994 | inputs = [inputs[i] for i in self.in_index] 995 | upsampled_inputs = [ 996 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) 997 | for x in inputs 998 | ] 999 | inputs = torch.cat(upsampled_inputs, dim=1) 1000 | elif self.input_transform == "multiple_select": 1001 | inputs = [inputs[i] for i in self.in_index] 1002 | else: 1003 | inputs = inputs[self.in_index] 1004 | 1005 | return inputs 1006 | 1007 | def forward(self, inputs1, inputs2): 1008 | x_1 = self._transform_inputs(inputs1) # len=4, 1/4,1/8,1/16,1/32 1009 | x_2 = self._transform_inputs(inputs2) # len=4, 1/4,1/8,1/16,1/32 1010 | 1011 | c1_1, c2_1, c3_1, c4_1 = x_1 1012 | c1_2, c2_2, c3_2, c4_2 = x_2 1013 | 1014 | ############## MLP decoder on C1-C4 ########### 1015 | n, _, h, w = c4_1.shape 1016 | 1017 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3]) 1018 | _c4_1 = resize(_c4_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False) 1019 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3]) 1020 | _c4_2 = resize(_c4_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1021 | 1022 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3]) 1023 | _c3_1 = resize(_c3_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False) 1024 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3]) 1025 | _c3_2 = resize(_c3_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1026 | 1027 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3]) 1028 | _c2_1 = resize(_c2_1, size=c1_1.size()[2:], mode="bilinear", align_corners=False) 1029 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3]) 1030 | _c2_2 = resize(_c2_2, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1031 | 1032 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3]) 1033 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 
1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3]) 1034 | 1035 | _c = self.linear_fuse( 1036 | torch.cat( 1037 | [ 1038 | torch.abs(_c4_1 - _c4_2), 1039 | torch.abs(_c3_1 - _c3_2), 1040 | torch.abs(_c2_1 - _c2_2), 1041 | torch.abs(_c1_1 - _c1_2), 1042 | ], 1043 | dim=1, 1044 | ) 1045 | ) 1046 | 1047 | # x = self.dense_2x(x) 1048 | # x = self.convd1x(x) 1049 | # x = self.dense_1x(x) 1050 | 1051 | # cp = self.change_probability(x) 1052 | 1053 | # cp = F.interpolate(_c, scale_factor=4, mode="nearest") 1054 | x = self.relu(self.pix_shuffle_conv(_c)) 1055 | cp = self.pix_shuffle(x) 1056 | 1057 | if self.output_softmax: 1058 | cp = self.active(cp) 1059 | 1060 | return cp 1061 | 1062 | 1063 | # ChangeFormerV2: 1064 | # Transformer encoder to extract features 1065 | # Features are differenced and passed through a Transformer decoder 1066 | class ChangeFormerV2(nn.Module): 1067 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False): 1068 | super().__init__() 1069 | # Transformer Encoder 1070 | self.Tenc = Tenc() 1071 | 1072 | # Transformer Decoder 1073 | self.TDec = TDec( 1074 | input_transform="multiple_select", 1075 | in_index=[0, 1, 2, 3], 1076 | align_corners=True, 1077 | in_channels=[64, 128, 320, 512], 1078 | embedding_dim=32, 1079 | output_nc=output_nc, 1080 | decoder_softmax=decoder_softmax, 1081 | feature_strides=[4, 8, 16, 32], 1082 | ) 1083 | # Final activation 1084 | self.decoder_softmax = decoder_softmax 1085 | self.output_activation = torch.nn.Softmax(dim=1) 1086 | 1087 | def forward(self, x1, x2): 1088 | fx1 = self.Tenc(x1) 1089 | fx2 = self.Tenc(x2) 1090 | 1091 | DI = [] 1092 | for i in range(0, 4): 1093 | DI.append(torch.abs(fx1[i] - fx2[i])) 1094 | 1095 | cp = self.TDec(DI) 1096 | 1097 | if self.decoder_softmax: 1098 | cp = self.output_activation(cp) 1099 | 1100 | return cp 1101 | 1102 | 1103 | # ChangeFormerV3: 1104 | # Features are differenced and passed through a Transformer decoder 1105 | class ChangeFormerV3(nn.Module): 1106 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False): 1107 | super().__init__() 1108 | # Transformer Encoder 1109 | self.Tenc = Tenc( 1110 | patch_size=16, 1111 | embed_dims=[64, 128, 320, 512], 1112 | num_heads=[1, 2, 4, 8], 1113 | mlp_ratios=[4, 4, 4, 4], 1114 | qkv_bias=True, 1115 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 1116 | depths=[3, 4, 6, 3], 1117 | sr_ratios=[8, 4, 2, 1], 1118 | drop_rate=0.0, 1119 | drop_path_rate=0.1, 1120 | ) 1121 | 1122 | # Transformer Decoder 1123 | self.TDec = TDecV2( 1124 | input_transform="multiple_select", 1125 | in_index=[0, 1, 2, 3], 1126 | align_corners=True, 1127 | in_channels=[64, 128, 320, 512], 1128 | embedding_dim=64, 1129 | output_nc=output_nc, 1130 | decoder_softmax=decoder_softmax, 1131 | feature_strides=[4, 8, 16, 32], 1132 | ) 1133 | 1134 | def forward(self, x1, x2): 1135 | fx1 = self.Tenc(x1) 1136 | fx2 = self.Tenc(x2) 1137 | 1138 | cp = self.TDec(fx1, fx2) 1139 | 1140 | return cp 1141 | 1142 | 1143 | # Transformer encoder with x1/2, x1/4, x1/8, x1/16, x1/32 scales 1144 | class EncoderTransformer_x2(nn.Module): 1145 | def __init__( 1146 | self, 1147 | img_size=256, 1148 | patch_size=3, 1149 | in_chans=3, 1150 | num_classes=2, 1151 | embed_dims=[32, 64, 128, 256, 512], 1152 | num_heads=[2, 2, 4, 8, 16], 1153 | mlp_ratios=[4, 4, 4, 4, 4], 1154 | qkv_bias=False, 1155 | qk_scale=None, 1156 | drop_rate=0.0, 1157 | attn_drop_rate=0.0, 1158 | drop_path_rate=0.0, 1159 | norm_layer=nn.LayerNorm, 1160 | depths=[3, 3, 6, 18, 3], 1161 | sr_ratios=[8, 4, 2, 1, 1], 1162 | ): 1163 | super().__init__() 1164 | 
self.num_classes = num_classes 1165 | self.depths = depths 1166 | self.embed_dims = embed_dims 1167 | 1168 | # patch embedding definitions 1169 | self.patch_embed1 = OverlapPatchEmbed( 1170 | img_size=img_size, patch_size=7, stride=2, in_chans=in_chans, embed_dim=embed_dims[0] 1171 | ) 1172 | self.patch_embed2 = OverlapPatchEmbed( 1173 | img_size=img_size // 2, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1] 1174 | ) 1175 | self.patch_embed3 = OverlapPatchEmbed( 1176 | img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2] 1177 | ) 1178 | self.patch_embed4 = OverlapPatchEmbed( 1179 | img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3] 1180 | ) 1181 | self.patch_embed5 = OverlapPatchEmbed( 1182 | img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[3], embed_dim=embed_dims[4] 1183 | ) 1184 | 1185 | # Stage-1 (x1/2 scale) 1186 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 1187 | cur = 0 1188 | self.block1 = nn.ModuleList( 1189 | [ 1190 | Block( 1191 | dim=embed_dims[0], 1192 | num_heads=num_heads[0], 1193 | mlp_ratio=mlp_ratios[0], 1194 | qkv_bias=qkv_bias, 1195 | qk_scale=qk_scale, 1196 | drop=drop_rate, 1197 | attn_drop=attn_drop_rate, 1198 | drop_path=dpr[cur + i], 1199 | norm_layer=norm_layer, 1200 | sr_ratio=sr_ratios[0], 1201 | ) 1202 | for i in range(depths[0]) 1203 | ] 1204 | ) 1205 | self.norm1 = norm_layer(embed_dims[0]) 1206 | 1207 | # Stage-2 (x1/4 scale) 1208 | cur += depths[0] 1209 | self.block2 = nn.ModuleList( 1210 | [ 1211 | Block( 1212 | dim=embed_dims[1], 1213 | num_heads=num_heads[1], 1214 | mlp_ratio=mlp_ratios[1], 1215 | qkv_bias=qkv_bias, 1216 | qk_scale=qk_scale, 1217 | drop=drop_rate, 1218 | attn_drop=attn_drop_rate, 1219 | drop_path=dpr[cur + i], 1220 | norm_layer=norm_layer, 1221 | sr_ratio=sr_ratios[1], 1222 | ) 1223 | for i in range(depths[1]) 1224 | ] 1225 | ) 1226 | self.norm2 = norm_layer(embed_dims[1]) 1227 | 1228 | # Stage-3 (x1/8 scale) 1229 | cur += depths[1] 1230 | self.block3 = nn.ModuleList( 1231 | [ 1232 | Block( 1233 | dim=embed_dims[2], 1234 | num_heads=num_heads[2], 1235 | mlp_ratio=mlp_ratios[2], 1236 | qkv_bias=qkv_bias, 1237 | qk_scale=qk_scale, 1238 | drop=drop_rate, 1239 | attn_drop=attn_drop_rate, 1240 | drop_path=dpr[cur + i], 1241 | norm_layer=norm_layer, 1242 | sr_ratio=sr_ratios[2], 1243 | ) 1244 | for i in range(depths[2]) 1245 | ] 1246 | ) 1247 | self.norm3 = norm_layer(embed_dims[2]) 1248 | 1249 | # Stage-4 (x1/16 scale) 1250 | cur += depths[2] 1251 | self.block4 = nn.ModuleList( 1252 | [ 1253 | Block( 1254 | dim=embed_dims[3], 1255 | num_heads=num_heads[3], 1256 | mlp_ratio=mlp_ratios[3], 1257 | qkv_bias=qkv_bias, 1258 | qk_scale=qk_scale, 1259 | drop=drop_rate, 1260 | attn_drop=attn_drop_rate, 1261 | drop_path=dpr[cur + i], 1262 | norm_layer=norm_layer, 1263 | sr_ratio=sr_ratios[3], 1264 | ) 1265 | for i in range(depths[3]) 1266 | ] 1267 | ) 1268 | self.norm4 = norm_layer(embed_dims[3]) 1269 | 1270 | # Stage-5 (x1/32 scale) 1271 | cur += depths[3] 1272 | self.block5 = nn.ModuleList( 1273 | [ 1274 | Block( 1275 | dim=embed_dims[4], 1276 | num_heads=num_heads[4], 1277 | mlp_ratio=mlp_ratios[4], 1278 | qkv_bias=qkv_bias, 1279 | qk_scale=qk_scale, 1280 | drop=drop_rate, 1281 | attn_drop=attn_drop_rate, 1282 | drop_path=dpr[cur + i], 1283 | norm_layer=norm_layer, 1284 | sr_ratio=sr_ratios[4], 1285 | ) 1286 | for i in range(depths[4]) 1287 | ] 1288 | ) 1289 | self.norm5 = 
norm_layer(embed_dims[4]) 1290 | 1291 | self.apply(self._init_weights) 1292 | 1293 | def _init_weights(self, m): 1294 | if isinstance(m, nn.Linear): 1295 | trunc_normal_(m.weight, std=0.02) 1296 | if isinstance(m, nn.Linear) and m.bias is not None: 1297 | nn.init.constant_(m.bias, 0) 1298 | elif isinstance(m, nn.LayerNorm): 1299 | nn.init.constant_(m.bias, 0) 1300 | nn.init.constant_(m.weight, 1.0) 1301 | elif isinstance(m, nn.Conv2d): 1302 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1303 | fan_out //= m.groups 1304 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 1305 | if m.bias is not None: 1306 | m.bias.data.zero_() 1307 | 1308 | def reset_drop_path(self, drop_path_rate): 1309 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 1310 | cur = 0 1311 | for i in range(self.depths[0]): 1312 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 1313 | 1314 | cur += self.depths[0] 1315 | for i in range(self.depths[1]): 1316 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 1317 | 1318 | cur += self.depths[1] 1319 | for i in range(self.depths[2]): 1320 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 1321 | 1322 | cur += self.depths[2] 1323 | for i in range(self.depths[3]): 1324 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 1325 | 1326 | def forward_features(self, x): 1327 | B = x.shape[0] 1328 | outs = [] 1329 | 1330 | # stage 1 1331 | x1, H1, W1 = self.patch_embed1(x) 1332 | for _i, blk in enumerate(self.block1): 1333 | x1 = blk(x1, H1, W1) 1334 | x1 = self.norm1(x1) 1335 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1336 | outs.append(x1) 1337 | 1338 | # stage 2 1339 | x1, H1, W1 = self.patch_embed2(x1) 1340 | for _i, blk in enumerate(self.block2): 1341 | x1 = blk(x1, H1, W1) 1342 | x1 = self.norm2(x1) 1343 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1344 | outs.append(x1) 1345 | 1346 | # stage 3 1347 | x1, H1, W1 = self.patch_embed3(x1) 1348 | for _i, blk in enumerate(self.block3): 1349 | x1 = blk(x1, H1, W1) 1350 | x1 = self.norm3(x1) 1351 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1352 | outs.append(x1) 1353 | 1354 | # stage 4 1355 | x1, H1, W1 = self.patch_embed4(x1) 1356 | for _i, blk in enumerate(self.block4): 1357 | x1 = blk(x1, H1, W1) 1358 | x1 = self.norm4(x1) 1359 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1360 | outs.append(x1) 1361 | 1362 | # stage 5 1363 | x1, H1, W1 = self.patch_embed5(x1) 1364 | for _i, blk in enumerate(self.block5): 1365 | x1 = blk(x1, H1, W1) 1366 | x1 = self.norm5(x1) 1367 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1368 | outs.append(x1) 1369 | 1370 | return outs 1371 | 1372 | def forward(self, x): 1373 | x = self.forward_features(x) 1374 | return x 1375 | 1376 | 1377 | # Difference module 1378 | def conv_diff(in_channels, out_channels): 1379 | return nn.Sequential( 1380 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 1381 | nn.ReLU(), 1382 | nn.BatchNorm2d(out_channels), 1383 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 1384 | nn.ReLU(), 1385 | ) 1386 | 1387 | 1388 | # Intermediate prediction module 1389 | def make_prediction(in_channels, out_channels): 1390 | return nn.Sequential( 1391 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), 1392 | nn.ReLU(), 1393 | nn.BatchNorm2d(out_channels), 1394 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 1395 | ) 1396 | 1397 | 1398 | class DecoderTransformer_x2(nn.Module): 
1399 | """ 1400 | Transformer Decoder 1401 | """ 1402 | 1403 | def __init__( 1404 | self, 1405 | input_transform="multiple_select", 1406 | in_index=[0, 1, 2, 3, 4], 1407 | align_corners=True, 1408 | in_channels=[32, 64, 128, 256, 512], 1409 | embedding_dim=64, 1410 | output_nc=2, 1411 | decoder_softmax=False, 1412 | feature_strides=[2, 4, 8, 16, 32], 1413 | ): 1414 | super().__init__() 1415 | assert len(feature_strides) == len(in_channels) 1416 | assert min(feature_strides) == feature_strides[0] 1417 | self.feature_strides = feature_strides 1418 | 1419 | # input transforms 1420 | self.input_transform = input_transform 1421 | self.in_index = in_index 1422 | self.align_corners = align_corners 1423 | 1424 | # MLP 1425 | self.in_channels = in_channels 1426 | self.embedding_dim = embedding_dim 1427 | 1428 | # Final prediction 1429 | self.output_nc = output_nc 1430 | 1431 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels, c5_in_channels) = self.in_channels 1432 | 1433 | self.linear_c5 = MLP(input_dim=c5_in_channels, embed_dim=self.embedding_dim) 1434 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) 1435 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) 1436 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) 1437 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) 1438 | 1439 | # Convolutional Difference Modules 1440 | self.diff_c5 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) 1441 | self.diff_c4 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim) 1442 | self.diff_c3 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim) 1443 | self.diff_c2 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim) 1444 | self.diff_c1 = conv_diff(in_channels=3 * self.embedding_dim, out_channels=self.embedding_dim) 1445 | 1446 | # Taking outputs from middle of the encoder 1447 | self.make_pred_c5 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1448 | self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1449 | self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1450 | self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1451 | self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1452 | 1453 | self.linear_fuse = nn.Conv2d( 1454 | in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim, kernel_size=1 1455 | ) 1456 | 1457 | # self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 1458 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 1459 | self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim)) 1460 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 1461 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim)) 1462 | 1463 | # Final prediction 1464 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1) 1465 | 1466 | # Final activation 1467 | self.output_softmax = decoder_softmax 1468 | self.active = nn.Sigmoid() 1469 | 1470 | def _transform_inputs(self, inputs): 1471 | """Transform inputs for decoder. 
1472 | Args: 1473 | inputs (list[Tensor]): List of multi-level img features. 1474 | Returns: 1475 | Tensor: The transformed inputs 1476 | """ 1477 | 1478 | if self.input_transform == "resize_concat": 1479 | inputs = [inputs[i] for i in self.in_index] 1480 | upsampled_inputs = [ 1481 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) 1482 | for x in inputs 1483 | ] 1484 | inputs = torch.cat(upsampled_inputs, dim=1) 1485 | elif self.input_transform == "multiple_select": 1486 | inputs = [inputs[i] for i in self.in_index] 1487 | else: 1488 | inputs = inputs[self.in_index] 1489 | 1490 | return inputs 1491 | 1492 | def forward(self, inputs1, inputs2): 1493 | x_1 = self._transform_inputs(inputs1) # len=5, 1/2, 1/4, 1/8, 1/16, 1/32 1494 | x_2 = self._transform_inputs(inputs2) # len=5, 1/2, 1/4, 1/8, 1/16, 1/32 1495 | 1496 | c1_1, c2_1, c3_1, c4_1, c5_1 = x_1 1497 | c1_2, c2_2, c3_2, c4_2, c5_2 = x_2 1498 | 1499 | ############## MLP decoder on C1-C5 ########### 1500 | n, _, h, w = c5_1.shape 1501 | 1502 | outputs = [] # Multi-scale predictions are collected here 1503 | 1504 | _c5_1 = self.linear_c5(c5_1).permute(0, 2, 1).reshape(n, -1, c5_1.shape[2], c5_1.shape[3]) 1505 | _c5_2 = self.linear_c5(c5_2).permute(0, 2, 1).reshape(n, -1, c5_2.shape[2], c5_2.shape[3]) 1506 | _c5 = self.diff_c5(torch.cat((_c5_1, _c5_2), dim=1)) # Difference of features at x1/32 scale 1507 | p_c5 = self.make_pred_c5(_c5) # Predicted change map at x1/32 scale 1508 | outputs.append(p_c5) # x1/32 scale 1509 | _c5_up = resize(_c5, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1510 | 1511 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3]) 1512 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3]) 1513 | _c4 = self.diff_c4( 1514 | torch.cat((F.interpolate(_c5, scale_factor=2, mode="bilinear"), _c4_1, _c4_2), dim=1) 1515 | ) # Difference of features at x1/16 scale 1516 | p_c4 = self.make_pred_c4(_c4) # Predicted change map at x1/16 scale 1517 | outputs.append(p_c4) # x1/16 scale 1518 | _c4_up = resize(_c4, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1519 | 1520 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3]) 1521 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3]) 1522 | _c3 = self.diff_c3( 1523 | torch.cat((F.interpolate(_c4, scale_factor=2, mode="bilinear"), _c3_1, _c3_2), dim=1) 1524 | ) # Difference of features at x1/8 scale 1525 | p_c3 = self.make_pred_c3(_c3) # Predicted change map at x1/8 scale 1526 | outputs.append(p_c3) # x1/8 scale 1527 | _c3_up = resize(_c3, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1528 | 1529 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3]) 1530 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3]) 1531 | _c2 = self.diff_c2( 1532 | torch.cat((F.interpolate(_c3, scale_factor=2, mode="bilinear"), _c2_1, _c2_2), dim=1) 1533 | ) # Difference of features at x1/4 scale 1534 | p_c2 = self.make_pred_c2(_c2) # Predicted change map at x1/4 scale 1535 | outputs.append(p_c2) # x1/4 scale 1536 | _c2_up = resize(_c2, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1537 | 1538 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3]) 1539 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3]) 1540 | _c1 = 
self.diff_c1( 1541 | torch.cat((F.interpolate(_c2, scale_factor=2, mode="bilinear"), _c1_1, _c1_2), dim=1) 1542 | ) # Difference of features at x1/2 scale 1543 | p_c1 = self.make_pred_c1(_c1) # Predicted change map at x1/2 scale 1544 | outputs.append(p_c1) # x1/2 scale 1545 | 1546 | _c = self.linear_fuse(torch.cat((_c5_up, _c4_up, _c3_up, _c2_up, _c1), dim=1)) 1547 | 1548 | x = self.convd2x(_c)  # single x2 upsample; convd1x/dense_1x are defined but unused in this variant 1549 | x = self.dense_2x(x) 1550 | cp = self.change_probability(x) 1551 | outputs.append(cp) 1552 | 1553 | if self.output_softmax: 1554 | temp = outputs 1555 | outputs = [] 1556 | for pred in temp: 1557 | outputs.append(self.active(pred)) 1558 | 1559 | return outputs 1560 | 1561 | 1562 | # ChangeFormerV4: 1563 | class ChangeFormerV4(nn.Module): 1564 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False): 1565 | super().__init__() 1566 | # Transformer Encoder 1567 | self.embed_dims = [32, 64, 128, 320, 512] 1568 | self.depths = [3, 3, 4, 12, 3] # [3, 3, 6, 18, 3] 1569 | self.embedding_dim = 256 1570 | 1571 | self.Tenc_x2 = EncoderTransformer_x2( 1572 | img_size=256, 1573 | patch_size=3, 1574 | in_chans=input_nc, 1575 | num_classes=output_nc, 1576 | embed_dims=self.embed_dims, 1577 | num_heads=[2, 2, 4, 8, 16], 1578 | mlp_ratios=[2, 2, 2, 2, 2], 1579 | qkv_bias=False, 1580 | qk_scale=None, 1581 | drop_rate=0.0, 1582 | attn_drop_rate=0.0, 1583 | drop_path_rate=0.0, 1584 | norm_layer=nn.LayerNorm, 1585 | depths=self.depths, 1586 | sr_ratios=[8, 4, 2, 1, 1], 1587 | ) 1588 | 1589 | # Transformer Decoder 1590 | self.TDec_x2 = DecoderTransformer_x2( 1591 | input_transform="multiple_select", 1592 | in_index=[0, 1, 2, 3, 4], 1593 | align_corners=True, 1594 | in_channels=self.embed_dims, 1595 | embedding_dim=256, 1596 | output_nc=output_nc, 1597 | decoder_softmax=decoder_softmax, 1598 | feature_strides=[2, 4, 8, 16, 32], 1599 | ) 1600 | 1601 | def forward(self, x1, x2): 1602 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)] 1603 | 1604 | cp = self.TDec_x2(fx1, fx2) 1605 | 1606 | # # Save to mat 1607 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4") 1608 | 1609 | # exit() 1610 | return cp 1611 | 1612 | 1613 | # Transformer Encoder with x4, x8, x16, x32 scales 1614 | class EncoderTransformer_v3(nn.Module): 1615 | def __init__( 1616 | self, 1617 | img_size=256, 1618 | patch_size=3, 1619 | in_chans=3, 1620 | num_classes=2, 1621 | embed_dims=[32, 64, 128, 256], 1622 | num_heads=[2, 2, 4, 8], 1623 | mlp_ratios=[4, 4, 4, 4], 1624 | qkv_bias=True, 1625 | qk_scale=None, 1626 | drop_rate=0.0, 1627 | attn_drop_rate=0.0, 1628 | drop_path_rate=0.0, 1629 | norm_layer=nn.LayerNorm, 1630 | depths=[3, 3, 6, 18], 1631 | sr_ratios=[8, 4, 2, 1], 1632 | ): 1633 | super().__init__() 1634 | self.num_classes = num_classes 1635 | self.depths = depths 1636 | self.embed_dims = embed_dims 1637 | 1638 | # patch embedding definitions 1639 | self.patch_embed1 = OverlapPatchEmbed( 1640 | img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, embed_dim=embed_dims[0] 1641 | ) 1642 | self.patch_embed2 = OverlapPatchEmbed( 1643 | img_size=img_size // 4, patch_size=patch_size, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1] 1644 | ) 1645 | self.patch_embed3 = OverlapPatchEmbed( 1646 | img_size=img_size // 8, patch_size=patch_size, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2] 1647 | ) 1648 | self.patch_embed4 = OverlapPatchEmbed( 1649 | img_size=img_size // 16, patch_size=patch_size, stride=2, in_chans=embed_dims[2], embed_dim=embed_dims[3] 1650 | ) 1651 | 1652 | # Stage-1 (x1/4 scale) 1653 | dpr = 
[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 1654 | cur = 0 1655 | self.block1 = nn.ModuleList( 1656 | [ 1657 | Block( 1658 | dim=embed_dims[0], 1659 | num_heads=num_heads[0], 1660 | mlp_ratio=mlp_ratios[0], 1661 | qkv_bias=qkv_bias, 1662 | qk_scale=qk_scale, 1663 | drop=drop_rate, 1664 | attn_drop=attn_drop_rate, 1665 | drop_path=dpr[cur + i], 1666 | norm_layer=norm_layer, 1667 | sr_ratio=sr_ratios[0], 1668 | ) 1669 | for i in range(depths[0]) 1670 | ] 1671 | ) 1672 | self.norm1 = norm_layer(embed_dims[0]) 1673 | 1674 | # Stage-2 (x1/8 scale) 1675 | cur += depths[0] 1676 | self.block2 = nn.ModuleList( 1677 | [ 1678 | Block( 1679 | dim=embed_dims[1], 1680 | num_heads=num_heads[1], 1681 | mlp_ratio=mlp_ratios[1], 1682 | qkv_bias=qkv_bias, 1683 | qk_scale=qk_scale, 1684 | drop=drop_rate, 1685 | attn_drop=attn_drop_rate, 1686 | drop_path=dpr[cur + i], 1687 | norm_layer=norm_layer, 1688 | sr_ratio=sr_ratios[1], 1689 | ) 1690 | for i in range(depths[1]) 1691 | ] 1692 | ) 1693 | self.norm2 = norm_layer(embed_dims[1]) 1694 | 1695 | # Stage-3 (x1/16 scale) 1696 | cur += depths[1] 1697 | self.block3 = nn.ModuleList( 1698 | [ 1699 | Block( 1700 | dim=embed_dims[2], 1701 | num_heads=num_heads[2], 1702 | mlp_ratio=mlp_ratios[2], 1703 | qkv_bias=qkv_bias, 1704 | qk_scale=qk_scale, 1705 | drop=drop_rate, 1706 | attn_drop=attn_drop_rate, 1707 | drop_path=dpr[cur + i], 1708 | norm_layer=norm_layer, 1709 | sr_ratio=sr_ratios[2], 1710 | ) 1711 | for i in range(depths[2]) 1712 | ] 1713 | ) 1714 | self.norm3 = norm_layer(embed_dims[2]) 1715 | 1716 | # Stage-4 (x1/32 scale) 1717 | cur += depths[2] 1718 | self.block4 = nn.ModuleList( 1719 | [ 1720 | Block( 1721 | dim=embed_dims[3], 1722 | num_heads=num_heads[3], 1723 | mlp_ratio=mlp_ratios[3], 1724 | qkv_bias=qkv_bias, 1725 | qk_scale=qk_scale, 1726 | drop=drop_rate, 1727 | attn_drop=attn_drop_rate, 1728 | drop_path=dpr[cur + i], 1729 | norm_layer=norm_layer, 1730 | sr_ratio=sr_ratios[3], 1731 | ) 1732 | for i in range(depths[3]) 1733 | ] 1734 | ) 1735 | self.norm4 = norm_layer(embed_dims[3]) 1736 | 1737 | self.apply(self._init_weights) 1738 | 1739 | def _init_weights(self, m): 1740 | if isinstance(m, nn.Linear): 1741 | trunc_normal_(m.weight, std=0.02) 1742 | if isinstance(m, nn.Linear) and m.bias is not None: 1743 | nn.init.constant_(m.bias, 0) 1744 | elif isinstance(m, nn.LayerNorm): 1745 | nn.init.constant_(m.bias, 0) 1746 | nn.init.constant_(m.weight, 1.0) 1747 | elif isinstance(m, nn.Conv2d): 1748 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 1749 | fan_out //= m.groups 1750 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 1751 | if m.bias is not None: 1752 | m.bias.data.zero_() 1753 | 1754 | def reset_drop_path(self, drop_path_rate): 1755 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 1756 | cur = 0 1757 | for i in range(self.depths[0]): 1758 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 1759 | 1760 | cur += self.depths[0] 1761 | for i in range(self.depths[1]): 1762 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 1763 | 1764 | cur += self.depths[1] 1765 | for i in range(self.depths[2]): 1766 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 1767 | 1768 | cur += self.depths[2] 1769 | for i in range(self.depths[3]): 1770 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 1771 | 1772 | def forward_features(self, x): 1773 | B = x.shape[0] 1774 | outs = [] 1775 | 1776 | # stage 1 1777 | x1, H1, W1 = self.patch_embed1(x) 1778 | for _i, blk in enumerate(self.block1): 1779 
| x1 = blk(x1, H1, W1) 1780 | x1 = self.norm1(x1) 1781 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1782 | outs.append(x1) 1783 | 1784 | # stage 2 1785 | x1, H1, W1 = self.patch_embed2(x1) 1786 | for _i, blk in enumerate(self.block2): 1787 | x1 = blk(x1, H1, W1) 1788 | x1 = self.norm2(x1) 1789 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1790 | outs.append(x1) 1791 | 1792 | # stage 3 1793 | x1, H1, W1 = self.patch_embed3(x1) 1794 | for _i, blk in enumerate(self.block3): 1795 | x1 = blk(x1, H1, W1) 1796 | x1 = self.norm3(x1) 1797 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1798 | outs.append(x1) 1799 | 1800 | # stage 4 1801 | x1, H1, W1 = self.patch_embed4(x1) 1802 | for _i, blk in enumerate(self.block4): 1803 | x1 = blk(x1, H1, W1) 1804 | x1 = self.norm4(x1) 1805 | x1 = x1.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() 1806 | outs.append(x1) 1807 | return outs 1808 | 1809 | def forward(self, x): 1810 | x = self.forward_features(x) 1811 | return x 1812 | 1813 | 1814 | class DecoderTransformer_v3(nn.Module): 1815 | """ 1816 | Transformer Decoder 1817 | """ 1818 | 1819 | def __init__( 1820 | self, 1821 | input_transform="multiple_select", 1822 | in_index=[0, 1, 2, 3], 1823 | align_corners=True, 1824 | in_channels=[32, 64, 128, 256], 1825 | embedding_dim=64, 1826 | output_nc=2, 1827 | decoder_softmax=False, 1828 | feature_strides=[2, 4, 8, 16], 1829 | ): 1830 | super().__init__() 1831 | # assert 1832 | assert len(feature_strides) == len(in_channels) 1833 | assert min(feature_strides) == feature_strides[0] 1834 | 1835 | # settings 1836 | self.feature_strides = feature_strides 1837 | self.input_transform = input_transform 1838 | self.in_index = in_index 1839 | self.align_corners = align_corners 1840 | self.in_channels = in_channels 1841 | self.embedding_dim = embedding_dim 1842 | self.output_nc = output_nc 1843 | (c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels) = self.in_channels 1844 | 1845 | # MLP decoder heads 1846 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim) 1847 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim) 1848 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim) 1849 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim) 1850 | 1851 | # convolutional Difference Modules 1852 | self.diff_c4 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) 1853 | self.diff_c3 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) 1854 | self.diff_c2 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) 1855 | self.diff_c1 = conv_diff(in_channels=2 * self.embedding_dim, out_channels=self.embedding_dim) 1856 | 1857 | # taking outputs from middle of the encoder 1858 | self.make_pred_c4 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1859 | self.make_pred_c3 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1860 | self.make_pred_c2 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1861 | self.make_pred_c1 = make_prediction(in_channels=self.embedding_dim, out_channels=self.output_nc) 1862 | 1863 | # Final linear fusion layer 1864 | self.linear_fuse = nn.Sequential( 1865 | nn.Conv2d( 1866 | in_channels=self.embedding_dim * len(in_channels), out_channels=self.embedding_dim, kernel_size=1 1867 | ), 1868 | nn.BatchNorm2d(self.embedding_dim), 
1869 | ) 1870 | 1871 | # Final prediction head 1872 | self.convd2x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 1873 | self.dense_2x = nn.Sequential(ResidualBlock(self.embedding_dim)) 1874 | self.convd1x = UpsampleConvLayer(self.embedding_dim, self.embedding_dim, kernel_size=4, stride=2) 1875 | self.dense_1x = nn.Sequential(ResidualBlock(self.embedding_dim)) 1876 | self.change_probability = ConvLayer(self.embedding_dim, self.output_nc, kernel_size=3, stride=1, padding=1) 1877 | 1878 | # Final activation 1879 | self.output_softmax = decoder_softmax 1880 | self.active = nn.Sigmoid() 1881 | 1882 | def _transform_inputs(self, inputs): 1883 | """Transform inputs for decoder. 1884 | Args: 1885 | inputs (list[Tensor]): List of multi-level img features. 1886 | Returns: 1887 | Tensor: The transformed inputs 1888 | """ 1889 | 1890 | if self.input_transform == "resize_concat": 1891 | inputs = [inputs[i] for i in self.in_index] 1892 | upsampled_inputs = [ 1893 | resize(input=x, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners) 1894 | for x in inputs 1895 | ] 1896 | inputs = torch.cat(upsampled_inputs, dim=1) 1897 | elif self.input_transform == "multiple_select": 1898 | inputs = [inputs[i] for i in self.in_index] 1899 | else: 1900 | inputs = inputs[self.in_index] 1901 | 1902 | return inputs 1903 | 1904 | def forward(self, inputs1, inputs2): 1905 | # Transforming encoder features (select layers) 1906 | x_1 = self._transform_inputs(inputs1) # len=4, 1/4, 1/8, 1/16, 1/32 1907 | x_2 = self._transform_inputs(inputs2) # len=4, 1/4, 1/8, 1/16, 1/32 1908 | 1909 | # img1 and img2 features 1910 | c1_1, c2_1, c3_1, c4_1 = x_1 1911 | c1_2, c2_2, c3_2, c4_2 = x_2 1912 | 1913 | ############## MLP decoder on C1-C4 ########### 1914 | n, _, h, w = c4_1.shape 1915 | 1916 | outputs = [] 1917 | # Stage 4: x1/32 scale 1918 | _c4_1 = self.linear_c4(c4_1).permute(0, 2, 1).reshape(n, -1, c4_1.shape[2], c4_1.shape[3]) 1919 | _c4_2 = self.linear_c4(c4_2).permute(0, 2, 1).reshape(n, -1, c4_2.shape[2], c4_2.shape[3]) 1920 | _c4 = self.diff_c4(torch.cat((_c4_1, _c4_2), dim=1)) 1921 | p_c4 = self.make_pred_c4(_c4) 1922 | outputs.append(p_c4) 1923 | _c4_up = resize(_c4, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1924 | 1925 | # Stage 3: x1/16 scale 1926 | _c3_1 = self.linear_c3(c3_1).permute(0, 2, 1).reshape(n, -1, c3_1.shape[2], c3_1.shape[3]) 1927 | _c3_2 = self.linear_c3(c3_2).permute(0, 2, 1).reshape(n, -1, c3_2.shape[2], c3_2.shape[3]) 1928 | _c3 = self.diff_c3(torch.cat((_c3_1, _c3_2), dim=1)) + F.interpolate(_c4, scale_factor=2, mode="bilinear") 1929 | p_c3 = self.make_pred_c3(_c3) 1930 | outputs.append(p_c3) 1931 | _c3_up = resize(_c3, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1932 | 1933 | # Stage 2: x1/8 scale 1934 | _c2_1 = self.linear_c2(c2_1).permute(0, 2, 1).reshape(n, -1, c2_1.shape[2], c2_1.shape[3]) 1935 | _c2_2 = self.linear_c2(c2_2).permute(0, 2, 1).reshape(n, -1, c2_2.shape[2], c2_2.shape[3]) 1936 | _c2 = self.diff_c2(torch.cat((_c2_1, _c2_2), dim=1)) + F.interpolate(_c3, scale_factor=2, mode="bilinear") 1937 | p_c2 = self.make_pred_c2(_c2) 1938 | outputs.append(p_c2) 1939 | _c2_up = resize(_c2, size=c1_2.size()[2:], mode="bilinear", align_corners=False) 1940 | 1941 | # Stage 1: x1/4 scale 1942 | _c1_1 = self.linear_c1(c1_1).permute(0, 2, 1).reshape(n, -1, c1_1.shape[2], c1_1.shape[3]) 1943 | _c1_2 = self.linear_c1(c1_2).permute(0, 2, 1).reshape(n, -1, c1_2.shape[2], c1_2.shape[3]) 1944 | _c1 = 
self.diff_c1(torch.cat((_c1_1, _c1_2), dim=1)) + F.interpolate(_c2, scale_factor=2, mode="bilinear") 1945 | p_c1 = self.make_pred_c1(_c1) 1946 | outputs.append(p_c1) 1947 | 1948 | # Linear Fusion of difference image from all scales 1949 | _c = self.linear_fuse(torch.cat((_c4_up, _c3_up, _c2_up, _c1), dim=1)) 1950 | 1951 | # #Dropout 1952 | # if dropout_ratio > 0: 1953 | # self.dropout = nn.Dropout2d(dropout_ratio) 1954 | # else: 1955 | # self.dropout = None 1956 | 1957 | # Upsampling x2 (x1/2 scale) 1958 | x = self.convd2x(_c) 1959 | # Residual block 1960 | x = self.dense_2x(x) 1961 | # Upsampling x2 (x1 scale) 1962 | x = self.convd1x(x) 1963 | # Residual block 1964 | x = self.dense_1x(x) 1965 | 1966 | # Final prediction 1967 | cp = self.change_probability(x) 1968 | 1969 | outputs.append(cp) 1970 | 1971 | if self.output_softmax: 1972 | temp = outputs 1973 | outputs = [] 1974 | for pred in temp: 1975 | outputs.append(self.active(pred)) 1976 | 1977 | return outputs 1978 | 1979 | 1980 | # ChangeFormerV5: 1981 | class ChangeFormerV5(nn.Module): 1982 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False, embed_dim=256): 1983 | super().__init__() 1984 | # Transformer Encoder 1985 | self.embed_dims = [64, 128, 320, 512] 1986 | self.depths = [3, 6, 16, 3] # [3, 3, 6, 18, 3] 1987 | self.embedding_dim = embed_dim 1988 | self.drop_rate = 0.0 1989 | self.attn_drop = 0.0 1990 | self.drop_path_rate = 0.1 1991 | 1992 | self.Tenc_x2 = EncoderTransformer_v3( 1993 | img_size=256, 1994 | patch_size=4, 1995 | in_chans=input_nc, 1996 | num_classes=output_nc, 1997 | embed_dims=self.embed_dims, 1998 | num_heads=[1, 2, 5, 8], 1999 | mlp_ratios=[4, 4, 4, 4], 2000 | qkv_bias=True, 2001 | qk_scale=None, 2002 | drop_rate=self.drop_rate, 2003 | attn_drop_rate=self.attn_drop, 2004 | drop_path_rate=self.drop_path_rate, 2005 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 2006 | depths=self.depths, 2007 | sr_ratios=[8, 4, 2, 1], 2008 | ) 2009 | 2010 | # Transformer Decoder 2011 | self.TDec_x2 = DecoderTransformer_v3( 2012 | input_transform="multiple_select", 2013 | in_index=[0, 1, 2, 3], 2014 | align_corners=False, 2015 | in_channels=self.embed_dims, 2016 | embedding_dim=self.embedding_dim, 2017 | output_nc=output_nc, 2018 | decoder_softmax=decoder_softmax, 2019 | feature_strides=[2, 4, 8, 16], 2020 | ) 2021 | 2022 | def forward(self, x1, x2): 2023 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)] 2024 | 2025 | cp = self.TDec_x2(fx1, fx2) 2026 | 2027 | # # Save to mat 2028 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4") 2029 | 2030 | # exit() 2031 | return cp 2032 | 2033 | 2034 | # ChangeFormerV6: 2035 | class ChangeFormerV6(nn.Module): 2036 | def __init__(self, input_nc=3, output_nc=2, decoder_softmax=False, embed_dim=256): 2037 | super().__init__() 2038 | # Transformer Encoder 2039 | self.embed_dims = [64, 128, 320, 512] 2040 | self.depths = [3, 3, 4, 3] # [3, 3, 6, 18, 3] 2041 | self.embedding_dim = embed_dim 2042 | self.drop_rate = 0.1 2043 | self.attn_drop = 0.1 2044 | self.drop_path_rate = 0.1 2045 | 2046 | self.Tenc_x2 = EncoderTransformer_v3( 2047 | img_size=256, 2048 | patch_size=7, 2049 | in_chans=input_nc, 2050 | num_classes=output_nc, 2051 | embed_dims=self.embed_dims, 2052 | num_heads=[1, 2, 4, 8], 2053 | mlp_ratios=[4, 4, 4, 4], 2054 | qkv_bias=True, 2055 | qk_scale=None, 2056 | drop_rate=self.drop_rate, 2057 | attn_drop_rate=self.attn_drop, 2058 | drop_path_rate=self.drop_path_rate, 2059 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 2060 | depths=self.depths, 2061 | sr_ratios=[8, 4, 
2, 1], 2062 | ) 2063 | 2064 | # Transformer Decoder 2065 | self.TDec_x2 = DecoderTransformer_v3( 2066 | input_transform="multiple_select", 2067 | in_index=[0, 1, 2, 3], 2068 | align_corners=False, 2069 | in_channels=self.embed_dims, 2070 | embedding_dim=self.embedding_dim, 2071 | output_nc=output_nc, 2072 | decoder_softmax=decoder_softmax, 2073 | feature_strides=[2, 4, 8, 16], 2074 | ) 2075 | 2076 | def forward(self, x1, x2): 2077 | [fx1, fx2] = [self.Tenc_x2(x1), self.Tenc_x2(x2)] 2078 | 2079 | cp = self.TDec_x2(fx1, fx2)[-1] 2080 | 2081 | # # Save to mat 2082 | # save_to_mat(x1, x2, fx1, fx2, cp, "ChangeFormerV4") 2083 | 2084 | # exit() 2085 | return cp 2086 | -------------------------------------------------------------------------------- /src/models/changeformer/ChangeFormerBaseNetworks.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | 7 | 8 | class ConvBlock(torch.nn.Module): 9 | def __init__( 10 | self, input_size, output_size, kernel_size=3, stride=1, padding=1, bias=True, activation="prelu", norm=None 11 | ): 12 | super().__init__() 13 | self.conv = torch.nn.Conv2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 14 | 15 | self.norm = norm 16 | if self.norm == "batch": 17 | self.bn = torch.nn.BatchNorm2d(output_size) 18 | elif self.norm == "instance": 19 | self.bn = torch.nn.InstanceNorm2d(output_size) 20 | 21 | self.activation = activation 22 | if self.activation == "relu": 23 | self.act = torch.nn.ReLU(True) 24 | elif self.activation == "prelu": 25 | self.act = torch.nn.PReLU() 26 | elif self.activation == "lrelu": 27 | self.act = torch.nn.LeakyReLU(0.2, True) 28 | elif self.activation == "tanh": 29 | self.act = torch.nn.Tanh() 30 | elif self.activation == "sigmoid": 31 | self.act = torch.nn.Sigmoid() 32 | 33 | def forward(self, x): 34 | if self.norm is not None: 35 | out = self.bn(self.conv(x)) 36 | else: 37 | out = self.conv(x) 38 | 39 | if self.activation is not None and self.activation != "no":  # treat None and "no" as no activation 40 | return self.act(out) 41 | else: 42 | return out 43 | 44 | 45 | class DeconvBlock(torch.nn.Module): 46 | def __init__( 47 | self, input_size, output_size, kernel_size=4, stride=2, padding=1, bias=True, activation="prelu", norm=None 48 | ): 49 | super().__init__() 50 | self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, bias=bias) 51 | 52 | self.norm = norm 53 | if self.norm == "batch": 54 | self.bn = torch.nn.BatchNorm2d(output_size) 55 | elif self.norm == "instance": 56 | self.bn = torch.nn.InstanceNorm2d(output_size) 57 | 58 | self.activation = activation 59 | if self.activation == "relu": 60 | self.act = torch.nn.ReLU(True) 61 | elif self.activation == "prelu": 62 | self.act = torch.nn.PReLU() 63 | elif self.activation == "lrelu": 64 | self.act = torch.nn.LeakyReLU(0.2, True) 65 | elif self.activation == "tanh": 66 | self.act = torch.nn.Tanh() 67 | elif self.activation == "sigmoid": 68 | self.act = torch.nn.Sigmoid() 69 | 70 | def forward(self, x): 71 | if self.norm is not None: 72 | out = self.bn(self.deconv(x)) 73 | else: 74 | out = self.deconv(x) 75 | 76 | if self.activation is not None and self.activation != "no":  # harmonized with ConvBlock 77 | return self.act(out) 78 | else: 79 | return out 80 | 81 | 82 | class ConvLayer(nn.Module): 83 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 84 | super().__init__() 85 | # reflection_padding = kernel_size // 2 86 | # self.reflection_pad = nn.ReflectionPad2d(reflection_padding) 87 | 
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) 88 | 89 | def forward(self, x): 90 | # out = self.reflection_pad(x) 91 | out = self.conv2d(x) 92 | return out 93 | 94 | 95 | class UpsampleConvLayer(torch.nn.Module): 96 | def __init__(self, in_channels, out_channels, kernel_size, stride): 97 | super().__init__() 98 | self.conv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=1) 99 | 100 | def forward(self, x): 101 | out = self.conv2d(x) 102 | return out 103 | 104 | 105 | class ResidualBlock(torch.nn.Module): 106 | def __init__(self, channels): 107 | super().__init__() 108 | self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 109 | self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1, padding=1) 110 | self.relu = nn.ReLU() 111 | 112 | def forward(self, x): 113 | residual = x 114 | out = self.relu(self.conv1(x)) 115 | out = self.conv2(out) * 0.1 116 | out = torch.add(out, residual) 117 | return out 118 | 119 | 120 | def init_linear(linear): 121 | init.xavier_normal_(linear.weight)  # in-place variant; init.xavier_normal is deprecated 122 | linear.bias.data.zero_() 123 | 124 | 125 | def init_conv(conv, glu=True): 126 | init.kaiming_normal_(conv.weight)  # in-place variant; init.kaiming_normal is deprecated 127 | if conv.bias is not None: 128 | conv.bias.data.zero_() 129 | 130 | 131 | class EqualLR: 132 | def __init__(self, name): 133 | self.name = name 134 | 135 | def compute_weight(self, module): 136 | weight = getattr(module, self.name + "_orig") 137 | fan_in = weight.data.size(1) * weight.data[0][0].numel() 138 | 139 | return weight * sqrt(2 / fan_in) 140 | 141 | @staticmethod 142 | def apply(module, name): 143 | fn = EqualLR(name) 144 | 145 | weight = getattr(module, name) 146 | del module._parameters[name] 147 | module.register_parameter(name + "_orig", nn.Parameter(weight.data)) 148 | module.register_forward_pre_hook(fn) 149 | 150 | return fn 151 | 152 | def __call__(self, module, input): 153 | weight = self.compute_weight(module) 154 | setattr(module, self.name, weight) 155 | 156 | 157 | def equal_lr(module, name="weight"): 158 | EqualLR.apply(module, name) 159 | 160 | return module 161 | -------------------------------------------------------------------------------- /src/models/changeformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/changeformer/__init__.py -------------------------------------------------------------------------------- /src/models/tiny_cd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/a-change-detection-reality-check/2be028cfce014852016a912b59b4a16d91352f61/src/models/tiny_cd/__init__.py -------------------------------------------------------------------------------- /src/models/tiny_cd/change_classifier.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import Tensor 3 | from torch.nn import Identity, Module, ModuleList 4 | 5 | from .layers import MixingBlock, MixingMaskAttentionBlock, PixelwiseLinear, UpMask 6 | 7 | 8 | class TinyCD(Module): 9 | def __init__(self, bkbn_name="efficientnet_b4", pretrained=True, output_layer_bkbn="3", freeze_backbone=False): 10 | super().__init__() 11 | 12 | # Load the pretrained backbone according to parameters: 13 | self._backbone = _get_backbone(bkbn_name, pretrained, output_layer_bkbn, 
freeze_backbone) 14 | 15 | # Initialize mixing blocks: 16 | self._first_mix = MixingMaskAttentionBlock(6, 3, [3, 10, 5], [10, 5, 1]) 17 | self._mixing_mask = ModuleList( 18 | [ 19 | MixingMaskAttentionBlock(48, 24, [24, 12, 6], [12, 6, 1]), 20 | MixingMaskAttentionBlock(64, 32, [32, 16, 8], [16, 8, 1]), 21 | MixingBlock(112, 56), 22 | ] 23 | ) 24 | 25 | # Initialize Upsampling blocks: 26 | self._up = ModuleList([UpMask(2, 56, 64), UpMask(2, 64, 64), UpMask(2, 64, 32)]) 27 | 28 | # Final classification layer: 29 | self._classify = PixelwiseLinear([32, 16, 8], [16, 8, 2], Identity()) 30 | 31 | def forward(self, ref: Tensor, test: Tensor) -> Tensor: 32 | features = self._encode(ref, test) 33 | latents = self._decode(features) 34 | return self._classify(latents) 35 | 36 | def _encode(self, ref, test) -> list[Tensor]: 37 | features = [self._first_mix(ref, test)] 38 | for num, layer in enumerate(self._backbone): 39 | ref, test = layer(ref), layer(test) 40 | if num != 0: 41 | features.append(self._mixing_mask[num - 1](ref, test)) 42 | return features 43 | 44 | def _decode(self, features) -> Tensor: 45 | upping = features[-1] 46 | for i, j in enumerate(range(-2, -5, -1)): 47 | upping = self._up[i](upping, features[j]) 48 | return upping 49 | 50 | 51 | def _get_backbone(bkbn_name, pretrained, output_layer_bkbn, freeze_backbone) -> ModuleList: 52 | # The whole model: 53 | entire_model = getattr(torchvision.models, bkbn_name)(pretrained=pretrained).features  # note: newer torchvision prefers the weights= argument 54 | 55 | # Slicing it: 56 | derived_model = ModuleList([]) 57 | for name, layer in entire_model.named_children(): 58 | derived_model.append(layer) 59 | if name == output_layer_bkbn: 60 | break 61 | 62 | # Freezing the backbone weights: 63 | if freeze_backbone: 64 | for param in derived_model.parameters(): 65 | param.requires_grad = False 66 | return derived_model 67 | -------------------------------------------------------------------------------- /src/models/tiny_cd/layers.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor, reshape, stack 2 | from torch.nn import Conv2d, InstanceNorm2d, Module, PReLU, Sequential, Upsample 3 | 4 | 5 | class PixelwiseLinear(Module): 6 | def __init__(self, fin: list[int], fout: list[int], last_activation: Module | None = None) -> None: 7 | assert len(fout) == len(fin) 8 | super().__init__() 9 | 10 | n = len(fin) 11 | self._linears = Sequential( 12 | *[ 13 | Sequential( 14 | Conv2d(fin[i], fout[i], kernel_size=1, bias=True), 15 | PReLU() if i < n - 1 or last_activation is None else last_activation, 16 | ) 17 | for i in range(n) 18 | ] 19 | ) 20 | 21 | def forward(self, x: Tensor) -> Tensor: 22 | # Processing the tensor: 23 | return self._linears(x) 24 | 25 | 26 | class MixingBlock(Module): 27 | def __init__(self, ch_in: int, ch_out: int): 28 | super().__init__() 29 | self._convmix = Sequential(Conv2d(ch_in, ch_out, 3, groups=ch_out, padding=1), PReLU(), InstanceNorm2d(ch_out)) 30 | 31 | def forward(self, x: Tensor, y: Tensor) -> Tensor: 32 | # Packing the tensors and interleaving the channels: 33 | mixed = stack((x, y), dim=2) 34 | mixed = reshape(mixed, (x.shape[0], -1, x.shape[2], x.shape[3])) 35 | 36 | # Mixing: 37 | return self._convmix(mixed) 38 | 39 | 40 | class MixingMaskAttentionBlock(Module): 41 | """Uses grouped convolution to build a lightweight attention mask.""" 42 | 43 | def __init__(self, ch_in: int, ch_out: int, fin: list[int], fout: list[int], generate_masked: bool = False): 44 | super().__init__() 45 | self._mixing = MixingBlock(ch_in, ch_out) 46 | 
self._linear = PixelwiseLinear(fin, fout) 47 | self._final_normalization = InstanceNorm2d(ch_out) if generate_masked else None 48 | self._mixing_out = MixingBlock(ch_in, ch_out) if generate_masked else None 49 | 50 | def forward(self, x: Tensor, y: Tensor) -> Tensor: 51 | z_mix = self._mixing(x, y) 52 | z = self._linear(z_mix) 53 | z_mix_out = 0 if self._mixing_out is None else self._mixing_out(x, y) 54 | 55 | return z if self._final_normalization is None else self._final_normalization(z_mix_out * z) 56 | 57 | 58 | class UpMask(Module): 59 | def __init__(self, scale_factor: float, nin: int, nout: int): 60 | super().__init__() 61 | self._upsample = Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=True) 62 | self._convolution = Sequential( 63 | Conv2d(nin, nin, 3, 1, groups=nin, padding=1), 64 | PReLU(), 65 | InstanceNorm2d(nin), 66 | Conv2d(nin, nout, kernel_size=1, stride=1), 67 | PReLU(), 68 | InstanceNorm2d(nout), 69 | ) 70 | 71 | def forward(self, x: Tensor, y: Tensor | None = None) -> Tensor: 72 | x = self._upsample(x) 73 | if y is not None: 74 | x = x * y 75 | return self._convolution(x) 76 | -------------------------------------------------------------------------------- /test_levircd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import lightning 6 | import pandas as pd 7 | from src.change_detection import ChangeDetectionTask 8 | from src.datasets.levircd import LEVIRCDDataModule 9 | from tqdm import tqdm 10 | 11 | 12 | def main(args): 13 | lightning.seed_everything(0) 14 | checkpoints = glob.glob(f"{args.ckpt_root}/**/checkpoints/epoch*.ckpt") 15 | runs = [ckpt.split(os.sep)[-3] for ckpt in checkpoints] 16 | 17 | metrics = {} 18 | for run, ckpt in tqdm(zip(runs, checkpoints, strict=False), total=len(runs)): 19 | datamodule = LEVIRCDDataModule( 20 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers 21 | ) 22 | module = ChangeDetectionTask.load_from_checkpoint(ckpt, map_location="cpu") 23 | trainer = lightning.Trainer( 24 | accelerator=args.accelerator, devices=[args.device], logger=False, precision="16-mixed" 25 | ) 26 | metrics[run] = trainer.test(model=module, datamodule=datamodule)[0] 27 | metrics[run]["model"] = module.hparams.model 28 | 29 | metrics = pd.DataFrame.from_dict(metrics, orient="index") 30 | metrics.to_csv(args.output_filename) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--root", type=str, default="./data/levircd") 36 | parser.add_argument("--ckpt-root", type=str, default="lightning_logs") 37 | parser.add_argument("--batch-size", type=int, default=8) 38 | parser.add_argument("--workers", type=int, default=16) 39 | parser.add_argument("--accelerator", type=str, default="gpu") 40 | parser.add_argument("--device", type=int, default=0) 41 | parser.add_argument("--output-filename", type=str, default="metrics.csv") 42 | args = parser.parse_args() 43 | main(args) 44 | -------------------------------------------------------------------------------- /test_whucd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | import lightning 6 | import pandas as pd 7 | from src.change_detection import ChangeDetectionTask 8 | from src.datasets.whucd import WHUCDDataModule 9 | from tqdm import tqdm 10 | 11 | 12 | def main(args): 13 | lightning.seed_everything(0) 14 | checkpoints = 
glob.glob(f"{args.ckpt_root}/**/checkpoints/epoch*.ckpt") 15 | runs = [ckpt.split(os.sep)[-3] for ckpt in checkpoints] 16 | 17 | metrics = {} 18 | for run, ckpt in tqdm(zip(runs, checkpoints, strict=False), total=len(runs)): 19 | datamodule = WHUCDDataModule( 20 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers 21 | ) 22 | module = ChangeDetectionTask.load_from_checkpoint(ckpt, map_location="cpu") 23 | trainer = lightning.Trainer( 24 | accelerator=args.accelerator, devices=[args.device], logger=False, precision="16-mixed" 25 | ) 26 | metrics[run] = trainer.test(model=module, datamodule=datamodule)[0] 27 | metrics[run]["model"] = module.hparams.model 28 | 29 | metrics = pd.DataFrame.from_dict(metrics, orient="index") 30 | metrics.to_csv(args.output_filename) 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--root", type=str, default="/workspace/storage/data/whucd-chipped") 36 | parser.add_argument("--ckpt-root", type=str, default="lightning_logs") 37 | parser.add_argument("--batch-size", type=int, default=8) 38 | parser.add_argument("--workers", type=int, default=16) 39 | parser.add_argument("--accelerator", type=str, default="gpu") 40 | parser.add_argument("--device", type=int, default=0) 41 | parser.add_argument("--output-filename", type=str, default="metrics.csv") 42 | args = parser.parse_args() 43 | main(args) 44 | -------------------------------------------------------------------------------- /train_levircd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lightning 4 | from lightning.pytorch.callbacks import ModelCheckpoint 5 | from src.change_detection import ChangeDetectionTask 6 | from src.datasets.levircd import LEVIRCDDataModule 7 | 8 | 9 | def main(args): 10 | for seed in range(args.num_seeds): 11 | lightning.seed_everything(seed) 12 | datamodule = LEVIRCDDataModule( 13 | root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers 14 | ) 15 | datamodule.train_root = args.train_root 16 | module = ChangeDetectionTask( 17 | model=args.model, backbone=args.backbone, weights=True, in_channels=3, num_classes=2, loss="ce", lr=args.lr 18 | ) 19 | 20 | callbacks = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=1) 21 | trainer = lightning.Trainer( 22 | accelerator=args.accelerator, 23 | devices=[args.device], 24 | logger=True, 25 | precision="16-mixed", 26 | max_epochs=args.epochs, 27 | log_every_n_steps=10, 28 | default_root_dir=f"logs-levircd-{args.model}", 29 | callbacks=[callbacks], 30 | ) 31 | trainer.fit(model=module, datamodule=datamodule) 32 | trainer.test(datamodule=datamodule, ckpt_path="best") 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--root", type=str, default="./data/levircd") 38 | parser.add_argument("--train-root", type=str, default="./data/levircd-train-chipped") 39 | parser.add_argument( 40 | "--model", 41 | type=str, 42 | default="unet", 43 | choices=["unet", "fcsiamconc", "fcsiamdiff", "changeformer", "tinycd", "bit"], 44 | ) 45 | parser.add_argument( 46 | "--backbone", type=str, default="resnet50", help="only works with unet, fcsiamdiff, or fcsiamconc" 47 | ) 48 | parser.add_argument("--epochs", type=int, default=200) 49 | parser.add_argument("--batch-size", type=int, default=8) 50 | parser.add_argument("--workers", type=int, default=8) 51 | parser.add_argument("--lr", type=float, default=0.01) 52 | 
parser.add_argument("--accelerator", type=str, default="gpu") 53 | parser.add_argument("--device", type=int, default=0) 54 | parser.add_argument("--num_seeds", type=int, default=10) 55 | args = parser.parse_args() 56 | main(args) 57 | -------------------------------------------------------------------------------- /train_whucd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import lightning 4 | from lightning.pytorch.callbacks import ModelCheckpoint 5 | from src.change_detection import ChangeDetectionTask 6 | from src.datasets.whucd import WHUCDDataModule 7 | 8 | 9 | def main(args): 10 | for seed in range(args.num_seeds): 11 | lightning.seed_everything(seed) 12 | datamodule = WHUCDDataModule( 13 | val_split_pct=0.1, root=args.root, batch_size=args.batch_size, patch_size=256, num_workers=args.workers 14 | ) 15 | module = ChangeDetectionTask( 16 | model=args.model, backbone=args.backbone, weights=True, in_channels=3, num_classes=2, loss="ce", lr=args.lr 17 | ) 18 | 19 | callbacks = ModelCheckpoint(monitor="val_loss", save_last=True, save_top_k=1) 20 | trainer = lightning.Trainer( 21 | accelerator=args.accelerator, 22 | devices=[args.device], 23 | logger=True, 24 | precision="16-mixed", 25 | max_epochs=args.epochs, 26 | log_every_n_steps=10, 27 | default_root_dir=f"logs-whucd-{args.model}", 28 | callbacks=[callbacks], 29 | ) 30 | trainer.fit(model=module, datamodule=datamodule) 31 | trainer.test(datamodule=datamodule, ckpt_path="best") 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--root", type=str, default="/workspace/storage/data/whucd-chipped") 37 | parser.add_argument( 38 | "--model", 39 | type=str, 40 | default="unet", 41 | choices=["unet", "fcsiamconc", "fcsiamdiff", "changeformer", "tinycd", "bit"], 42 | ) 43 | parser.add_argument( 44 | "--backbone", type=str, default="resnet50", help="only works with unet, fcsiamdiff, or fcsiamconc" 45 | ) 46 | parser.add_argument("--epochs", type=int, default=200) 47 | parser.add_argument("--batch-size", type=int, default=8) 48 | parser.add_argument("--workers", type=int, default=8) 49 | parser.add_argument("--lr", type=float, default=0.01) 50 | parser.add_argument("--accelerator", type=str, default="gpu") 51 | parser.add_argument("--device", type=int, default=0) 52 | parser.add_argument("--num_seeds", type=int, default=10) 53 | args = parser.parse_args() 54 | main(args) 55 | --------------------------------------------------------------------------------