├── README.md
├── code
│   ├── CRAPL-LICENSE.txt
│   ├── Dockerfile
│   ├── aux
│   │   └── fully_convolutional_change_detection
│   │       ├── .gitignore
│   │       ├── fresunet.py
│   │       ├── fully-convolutional-change-detection.ipynb
│   │       ├── siamunet_conc.py
│   │       ├── siamunet_diff.py
│   │       └── unet.py
│   ├── data_loader.py
│   └── test.py
└── previews
    └── preview.png

/README.md:
--------------------------------------------------------------------------------
 1 | # Fusing Multi-modal Data for Supervised Change Detection
 2 | 
 3 | ![Paper preview](previews/preview.png)
 4 | >
 5 | > _Exemplary observations used in our multi-modal bi-temporal change detection work. Left: rows show optical Sentinel-2 data as well as co-registered Sentinel-1 radar data. Each column denotes a different time point, pre- and post-change. Right: the associated change map. The Sentinel-2 observations and pixel-wise annotations are by Daudt et al. (2018); we curated and provide additional Sentinel-1 SAR data._
 6 | ----
 7 | This repository presents the Sentinel-1 SAR data for multi-modal change detection (data available [here](https://mediatum.ub.tum.de/1619966), paper [here](https://www.int-arch-photogramm-remote-sens-spatial-inf-sci.net/XLIII-B3-2021/243/2021/isprs-archives-XLIII-B3-2021-243-2021.pdf)) utilized in the work of
 8 | > Ebel, Patrick and Saha, Sudipan and Zhu, Xiao Xiang (2021). Fusing Multi-modal Data for Supervised Change Detection. ISPRS, XXIV ISPRS Congress 2021, 04 - 10 July 2021, Nice, France / Virtual. Paper [here](https://www.int-arch-photogramm-remote-sens-spatial-inf-sci.net/XLIII-B3-2021/243/2021/isprs-archives-XLIII-B3-2021-243-2021.pdf)
 9 | 
10 | Please also consider the work of
11 | > Daudt, R. C., Le Saux, B., & Boulch, A. (2018, October). Fully convolutional siamese networks for change detection. In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
12 | 
13 | whose original data set we built upon and extended. The original Sentinel-2 observations and pixel-wise annotations are available [here](https://ieee-dataport.org/open-access/oscd-onera-satellite-change-detection). We curated and provide additional Sentinel-1 SAR measurements, co-registered and temporally aligned with the original data, to enable multi-sensor change detection. You can find the complementary Sentinel-1 SAR data [here](https://mediatum.ub.tum.de/1619966).
14 | 
15 | ### Updates:
16 | 
17 | 1. Provided code under `/code`. Credits: the code builds on Rodrigo Daudt's [repository](https://github.com/rcdaudt/fully_convolutional_change_detection) and extends it. A minimal usage sketch is given below.
18 | 
19 | 2. You may also be interested in our follow-up publications
20 | > S. Saha, P. Ebel and X. X. Zhu, "Self-Supervised Multisensor Change Detection," in IEEE Transactions on Geoscience and Remote Sensing, doi: 10.1109/TGRS.2021.3109957. [url](https://ieeexplore.ieee.org/document/9538396), [code](https://gitlab.lrz.de/ai4eo/cd/tree/main/sarOpticalMultisensorTgrs2021)
21 | 
22 | and
23 | 
24 | > S. Saha, M. Shahzad, P. Ebel and X. X. Zhu, "Supervised Change Detection Using Pre-Change Optical-SAR and Post-Change SAR Data," in IEEE Journal of Selected Topics in Applied Earth Observations and Remote Sensing, 2022, doi: 10.1109/JSTARS.2022.3206898. [url](https://ieeexplore.ieee.org/document/9901395), [code](https://github.com/sudipansaha/optSarSarCdJstars2022)
25 | 
26 | which build on this data set, released with our ISPRS Congress publication.
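The snippet below is a minimal, illustrative sketch of how the co-registered Sentinel-1 rasters can be fused with the OSCD Sentinel-2 tiles and passed to the `FresUNet` shipped under `/code/aux/fully_convolutional_change_detection`. The directory layout, the file names (`s1_imgs_1/VV.tif`, etc.), and the RGB + VV/VH channel selection are placeholder assumptions for illustration only; the data handling actually used in our experiments is implemented in `/code/data_loader.py` and `/code/test.py`.

```python
# Minimal sketch: early (channel-level) fusion of Sentinel-2 optical and
# Sentinel-1 SAR patches for bi-temporal change detection with FresUNet.
# NOTE: all paths and band file names below are hypothetical placeholders;
# the SAR rasters are assumed to be co-registered to the Sentinel-2 grid.
import numpy as np
import torch
from skimage import io

from fresunet import FresUNet  # code/aux/fully_convolutional_change_detection/fresunet.py


def load_stack(paths):
    """Read single-band GeoTIFFs, stack to (C, H, W), and standardise per image."""
    bands = [io.imread(p).astype("float32") for p in paths]
    img = np.stack(bands, axis=0)
    return (img - img.mean()) / (img.std() + 1e-8)


# hypothetical pre-/post-change inputs for one city/tile
s2_pre  = load_stack(["city/imgs_1/B04.tif", "city/imgs_1/B03.tif", "city/imgs_1/B02.tif"])
s2_post = load_stack(["city/imgs_2/B04.tif", "city/imgs_2/B03.tif", "city/imgs_2/B02.tif"])
s1_pre  = load_stack(["city/s1_imgs_1/VV.tif", "city/s1_imgs_1/VH.tif"])
s1_post = load_stack(["city/s1_imgs_2/VV.tif", "city/s1_imgs_2/VH.tif"])

# early fusion: concatenate optical and SAR channels per acquisition date
pre  = torch.from_numpy(np.concatenate([s2_pre,  s1_pre],  axis=0)).unsqueeze(0)
post = torch.from_numpy(np.concatenate([s2_post, s1_post], axis=0)).unsqueeze(0)

# FresUNet concatenates both dates internally, hence 2 * (3 optical + 2 SAR) input channels
net = FresUNet(input_nbr=2 * (3 + 2), label_nbr=2)
net.eval()
with torch.no_grad():
    log_probs = net(pre, post)        # (1, 2, H, W) log-softmax scores
change_map = log_probs.argmax(dim=1)  # class 1 = change (given trained weights)
```

Instead of this early fusion at the input, `siamunet_conc.py` additionally contains a multi-branch variant (`SiamUnet_conc_multi`) with separate encoders for the optical and SAR inputs.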
27 | -------------------------------------------------------------------------------- /code/CRAPL-LICENSE.txt: -------------------------------------------------------------------------------- 1 | THE CRAPL v0 BETA 1 2 | 3 | 4 | 0. Information about the CRAPL 5 | 6 | If you have questions or concerns about the CRAPL, or you need more 7 | information about this license, please contact: 8 | 9 | Matthew Might 10 | http://matt.might.net/ 11 | 12 | 13 | I. Preamble 14 | 15 | Science thrives on openness. 16 | 17 | In modern science, it is often infeasible to replicate claims without 18 | access to the software underlying those claims. 19 | 20 | Let's all be honest: when scientists write code, aesthetics and 21 | software engineering principles take a back seat to having running, 22 | working code before a deadline. 23 | 24 | So, let's release the ugly. And, let's be proud of that. 25 | 26 | 27 | II. Definitions 28 | 29 | 1. "This License" refers to version 0 beta 1 of the Community 30 | Research and Academic Programming License (the CRAPL). 31 | 32 | 2. "The Program" refers to the medley of source code, shell scripts, 33 | executables, objects, libraries and build files supplied to You, 34 | or these files as modified by You. 35 | 36 | [Any appearance of design in the Program is purely coincidental and 37 | should not in any way be mistaken for evidence of thoughtful 38 | software construction.] 39 | 40 | 3. "You" refers to the person or persons brave and daft enough to use 41 | the Program. 42 | 43 | 4. "The Documentation" refers to the Program. 44 | 45 | 5. "The Author" probably refers to the caffeine-addled graduate 46 | student that got the Program to work moments before a submission 47 | deadline. 48 | 49 | 50 | III. Terms 51 | 52 | 1. By reading this sentence, You have agreed to the terms and 53 | conditions of this License. 54 | 55 | 2. If the Program shows any evidence of having been properly tested 56 | or verified, You will disregard this evidence. 57 | 58 | 3. You agree to hold the Author free from shame, embarrassment or 59 | ridicule for any hacks, kludges or leaps of faith found within the 60 | Program. 61 | 62 | 4. You recognize that any request for support for the Program will be 63 | discarded with extreme prejudice. 64 | 65 | 5. The Author reserves all rights to the Program, except for any 66 | rights granted under any additional licenses attached to the 67 | Program. 68 | 69 | 70 | IV. Permissions 71 | 72 | 1. You are permitted to use the Program to validate published 73 | scientific claims. 74 | 75 | 2. You are permitted to use the Program to validate scientific claims 76 | submitted for peer review, under the condition that You keep 77 | modifications to the Program confidential until those claims have 78 | been published. 79 | 80 | 3. You are permitted to use and/or modify the Program for the 81 | validation of novel scientific claims if You make a good-faith 82 | attempt to notify the Author of Your work and Your claims prior to 83 | submission for publication. 84 | 85 | 4. If You publicly release any claims or data that were supported or 86 | generated by the Program or a modification thereof, in whole or in 87 | part, You will release any inputs supplied to the Program and any 88 | modifications You made to the Progam. This License will be in 89 | effect for the modified program. 90 | 91 | 92 | V. Disclaimer of Warranty 93 | 94 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 95 | APPLICABLE LAW. 
EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 96 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT 97 | WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT 98 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 99 | A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND 100 | PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE 101 | DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR 102 | CORRECTION. 103 | 104 | 105 | VI. Limitation of Liability 106 | 107 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 108 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR 109 | CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 110 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES 111 | ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT 112 | NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR 113 | LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM 114 | TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER 115 | PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 116 | -------------------------------------------------------------------------------- /code/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.8.1-cuda11.1-cudnn8-runtime 2 | 3 | # install dependencies 4 | RUN conda install -c conda-forge cupy 5 | RUN conda install -c conda-forge opencv 6 | RUN pip install scipy sklearn rasterio natsort matplotlib scikit-image pandas tqdm natsort 7 | 8 | # add directories and files 9 | RUN mkdir -p ./aux 10 | ADD aux ./aux 11 | ADD data_loader.py ./data_loader.py 12 | ADD test.py ./test.py 13 | -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/fresunet.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2019. Multitask learning for large-scale semantic change detection. Computer Vision and Image Understanding, 187, p.102783. 4 | 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.modules.padding import ReplicationPad2d 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) 15 | 16 | 17 | class BasicBlock_ss(nn.Module): 18 | 19 | def __init__(self, inplanes, planes = None, subsamp=1): 20 | super(BasicBlock_ss, self).__init__() 21 | if planes == None: 22 | planes = inplanes * subsamp 23 | self.conv1 = conv3x3(inplanes, planes) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.subsamp = subsamp 29 | self.doit = planes != inplanes 30 | if self.doit: 31 | self.couple = nn.Conv2d(inplanes, planes, kernel_size=1) 32 | self.bnc = nn.BatchNorm2d(planes) 33 | 34 | def forward(self, x): 35 | if self.doit: 36 | residual = self.couple(x) 37 | residual = self.bnc(residual) 38 | else: 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | if self.subsamp > 1: 46 | out = F.max_pool2d(out, kernel_size=self.subsamp, stride=self.subsamp) 47 | residual = F.max_pool2d(residual, kernel_size=self.subsamp, stride=self.subsamp) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | 59 | class BasicBlock_us(nn.Module): 60 | 61 | def __init__(self, inplanes, upsamp=1): 62 | super(BasicBlock_us, self).__init__() 63 | planes = int(inplanes / upsamp) # assumes integer result, fix later 64 | self.conv1 = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = conv3x3(planes, planes) 68 | self.bn2 = 
nn.BatchNorm2d(planes) 69 | self.upsamp = upsamp 70 | self.couple = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 71 | self.bnc = nn.BatchNorm2d(planes) 72 | 73 | def forward(self, x): 74 | residual = self.couple(x) 75 | residual = self.bnc(residual) 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class FresUNet(nn.Module): 92 | """FresUNet segmentation network.""" 93 | 94 | def __init__(self, input_nbr, label_nbr): 95 | """Init FresUNet fields.""" 96 | super(FresUNet, self).__init__() 97 | 98 | self.input_nbr = input_nbr 99 | 100 | cur_depth = input_nbr 101 | 102 | base_depth = 8 103 | 104 | # Encoding stage 1 105 | self.encres1_1 = BasicBlock_ss(cur_depth, planes = base_depth) 106 | cur_depth = base_depth 107 | d1 = base_depth 108 | self.encres1_2 = BasicBlock_ss(cur_depth, subsamp=2) 109 | cur_depth *= 2 110 | 111 | # Encoding stage 2 112 | self.encres2_1 = BasicBlock_ss(cur_depth) 113 | d2 = cur_depth 114 | self.encres2_2 = BasicBlock_ss(cur_depth, subsamp=2) 115 | cur_depth *= 2 116 | 117 | # Encoding stage 3 118 | self.encres3_1 = BasicBlock_ss(cur_depth) 119 | d3 = cur_depth 120 | self.encres3_2 = BasicBlock_ss(cur_depth, subsamp=2) 121 | cur_depth *= 2 122 | 123 | # Encoding stage 4 124 | self.encres4_1 = BasicBlock_ss(cur_depth) 125 | d4 = cur_depth 126 | self.encres4_2 = BasicBlock_ss(cur_depth, subsamp=2) 127 | cur_depth *= 2 128 | 129 | # Decoding stage 4 130 | self.decres4_1 = BasicBlock_ss(cur_depth) 131 | self.decres4_2 = BasicBlock_us(cur_depth, upsamp=2) 132 | cur_depth = int(cur_depth/2) 133 | 134 | # Decoding stage 3 135 | self.decres3_1 = BasicBlock_ss(cur_depth + d4, planes = cur_depth) 136 | self.decres3_2 = BasicBlock_us(cur_depth, upsamp=2) 137 | cur_depth = int(cur_depth/2) 138 | 139 | # Decoding stage 2 140 | self.decres2_1 = BasicBlock_ss(cur_depth + d3, planes = cur_depth) 141 | self.decres2_2 = BasicBlock_us(cur_depth, upsamp=2) 142 | cur_depth = int(cur_depth/2) 143 | 144 | # Decoding stage 1 145 | self.decres1_1 = BasicBlock_ss(cur_depth + d2, planes = cur_depth) 146 | self.decres1_2 = BasicBlock_us(cur_depth, upsamp=2) 147 | cur_depth = int(cur_depth/2) 148 | 149 | # Output 150 | self.coupling = nn.Conv2d(cur_depth + d1, label_nbr, kernel_size=1) 151 | self.sm = nn.LogSoftmax(dim=1) 152 | 153 | def forward(self, x1, x2): 154 | 155 | x = torch.cat((x1, x2), 1) 156 | 157 | s1_1 = x.size() 158 | x1 = self.encres1_1(x) 159 | x = self.encres1_2(x1) 160 | 161 | s2_1 = x.size() 162 | x2 = self.encres2_1(x) 163 | x = self.encres2_2(x2) 164 | 165 | s3_1 = x.size() 166 | x3 = self.encres3_1(x) 167 | x = self.encres3_2(x3) 168 | 169 | s4_1 = x.size() 170 | x4 = self.encres4_1(x) 171 | x = self.encres4_2(x4) 172 | 173 | x = self.decres4_1(x) 174 | x = self.decres4_2(x) 175 | s4_2 = x.size() 176 | pad4 = ReplicationPad2d((0, s4_1[3] - s4_2[3], 0, s4_1[2] - s4_2[2])) 177 | x = pad4(x) 178 | 179 | # x = self.decres3_1(x) 180 | x = self.decres3_1(torch.cat((x, x4), 1)) 181 | x = self.decres3_2(x) 182 | s3_2 = x.size() 183 | pad3 = ReplicationPad2d((0, s3_1[3] - s3_2[3], 0, s3_1[2] - s3_2[2])) 184 | x = pad3(x) 185 | 186 | x = self.decres2_1(torch.cat((x, x3), 1)) 187 | x = self.decres2_2(x) 188 | s2_2 = x.size() 189 | pad2 = ReplicationPad2d((0, s2_1[3] - s2_2[3], 0, s2_1[2] - s2_2[2])) 190 | x = pad2(x) 191 | 192 | x = 
self.decres1_1(torch.cat((x, x2), 1)) 193 | x = self.decres1_2(x) 194 | s1_2 = x.size() 195 | pad1 = ReplicationPad2d((0, s1_1[3] - s1_2[3], 0, s1_1[2] - s1_2[2])) 196 | x = pad1(x) 197 | 198 | x = self.coupling(torch.cat((x, x1), 1)) 199 | x = self.sm(x) 200 | 201 | return x -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/fully-convolutional-change-detection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Fully Convolutional Networks for Change Detection\n", 8 | "\n", 9 | "Example code for training the network presented in the paper:\n", 10 | "\n", 11 | "```\n", 12 | "Daudt, R.C., Le Saux, B. and Boulch, A., 2018, October. Fully convolutional siamese networks for change detection. In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.\n", 13 | "```\n", 14 | "\n", 15 | "Code uses the OSCD dataset:\n", 16 | "\n", 17 | "```\n", 18 | "Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2018, July. Urban change detection for multispectral earth observation using convolutional neural networks. In IGARSS 2018-2018 IEEE International Geoscience and Remote Sensing Symposium (pp. 2115-2118). IEEE.\n", 19 | "```\n", 20 | "\n", 21 | "\n", 22 | "FresUNet architecture from paper:\n", 23 | "\n", 24 | "```\n", 25 | "Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2019. Multitask learning for large-scale semantic change detection. Computer Vision and Image Understanding, 187, p.102783.\n", 26 | "```\n", 27 | "\n", 28 | "Please consider all relevant papers if you use this code." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Rodrigo Daudt\n", 38 | "# rcdaudt.github.io\n", 39 | "# rodrigo.daudt@onera.fr" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Imports\n", 49 | "\n", 50 | "# PyTorch\n", 51 | "import torch\n", 52 | "import torch.nn as nn\n", 53 | "from torch.utils.data import Dataset, DataLoader\n", 54 | "from torch.autograd import Variable\n", 55 | "import torchvision.transforms as tr\n", 56 | "\n", 57 | "# Models\n", 58 | "from unet import Unet\n", 59 | "from siamunet_conc import SiamUnet_conc\n", 60 | "from siamunet_diff import SiamUnet_diff\n", 61 | "from fresunet import FresUNet\n", 62 | "\n", 63 | "# Other\n", 64 | "import os\n", 65 | "import numpy as np\n", 66 | "import random\n", 67 | "from skimage import io\n", 68 | "from scipy.ndimage import zoom\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "%matplotlib inline\n", 71 | "from tqdm import tqdm as tqdm\n", 72 | "from pandas import read_csv\n", 73 | "from math import floor, ceil, sqrt, exp\n", 74 | "from IPython import display\n", 75 | "import time\n", 76 | "from itertools import chain\n", 77 | "import time\n", 78 | "import warnings\n", 79 | "from pprint import pprint\n", 80 | "\n", 81 | "\n", 82 | "\n", 83 | "print('IMPORTS OK')\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# Global Variables' Definitions\n", 100 | "\n", 101 | "PATH_TO_DATASET = './OSCD/'\n", 102 
| "IS_PROTOTYPE = False\n", 103 | "\n", 104 | "FP_MODIFIER = 10 # Tuning parameter, use 1 if unsure\n", 105 | "\n", 106 | "BATCH_SIZE = 32\n", 107 | "PATCH_SIDE = 96\n", 108 | "N_EPOCHS = 50\n", 109 | "\n", 110 | "NORMALISE_IMGS = True\n", 111 | "\n", 112 | "TRAIN_STRIDE = int(PATCH_SIDE/2) - 1\n", 113 | "\n", 114 | "TYPE = 3 # 0-RGB | 1-RGBIr | 2-All bands s.t. resulution <= 20m | 3-All bands\n", 115 | "\n", 116 | "LOAD_TRAINED = False\n", 117 | "\n", 118 | "DATA_AUG = True\n", 119 | "\n", 120 | "\n", 121 | "print('DEFINITIONS OK')" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Functions\n", 131 | "\n", 132 | "def adjust_shape(I, s):\n", 133 | " \"\"\"Adjust shape of grayscale image I to s.\"\"\"\n", 134 | " \n", 135 | " # crop if necesary\n", 136 | " I = I[:s[0],:s[1]]\n", 137 | " si = I.shape\n", 138 | " \n", 139 | " # pad if necessary \n", 140 | " p0 = max(0,s[0] - si[0])\n", 141 | " p1 = max(0,s[1] - si[1])\n", 142 | " \n", 143 | " return np.pad(I,((0,p0),(0,p1)),'edge')\n", 144 | " \n", 145 | "\n", 146 | "def read_sentinel_img(path):\n", 147 | " \"\"\"Read cropped Sentinel-2 image: RGB bands.\"\"\"\n", 148 | " im_name = os.listdir(path)[0][:-7]\n", 149 | " r = io.imread(path + im_name + \"B04.tif\")\n", 150 | " g = io.imread(path + im_name + \"B03.tif\")\n", 151 | " b = io.imread(path + im_name + \"B02.tif\")\n", 152 | " \n", 153 | " I = np.stack((r,g,b),axis=2).astype('float')\n", 154 | " \n", 155 | " if NORMALISE_IMGS:\n", 156 | " I = (I - I.mean()) / I.std()\n", 157 | "\n", 158 | " return I\n", 159 | "\n", 160 | "def read_sentinel_img_4(path):\n", 161 | " \"\"\"Read cropped Sentinel-2 image: RGB and NIR bands.\"\"\"\n", 162 | " im_name = os.listdir(path)[0][:-7]\n", 163 | " r = io.imread(path + im_name + \"B04.tif\")\n", 164 | " g = io.imread(path + im_name + \"B03.tif\")\n", 165 | " b = io.imread(path + im_name + \"B02.tif\")\n", 166 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 167 | " \n", 168 | " I = np.stack((r,g,b,nir),axis=2).astype('float')\n", 169 | " \n", 170 | " if NORMALISE_IMGS:\n", 171 | " I = (I - I.mean()) / I.std()\n", 172 | "\n", 173 | " return I\n", 174 | "\n", 175 | "def read_sentinel_img_leq20(path):\n", 176 | " \"\"\"Read cropped Sentinel-2 image: bands with resolution less than or equals to 20m.\"\"\"\n", 177 | " im_name = os.listdir(path)[0][:-7]\n", 178 | " \n", 179 | " r = io.imread(path + im_name + \"B04.tif\")\n", 180 | " s = r.shape\n", 181 | " g = io.imread(path + im_name + \"B03.tif\")\n", 182 | " b = io.imread(path + im_name + \"B02.tif\")\n", 183 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 184 | " \n", 185 | " ir1 = adjust_shape(zoom(io.imread(path + im_name + \"B05.tif\"),2),s)\n", 186 | " ir2 = adjust_shape(zoom(io.imread(path + im_name + \"B06.tif\"),2),s)\n", 187 | " ir3 = adjust_shape(zoom(io.imread(path + im_name + \"B07.tif\"),2),s)\n", 188 | " nir2 = adjust_shape(zoom(io.imread(path + im_name + \"B8A.tif\"),2),s)\n", 189 | " swir2 = adjust_shape(zoom(io.imread(path + im_name + \"B11.tif\"),2),s)\n", 190 | " swir3 = adjust_shape(zoom(io.imread(path + im_name + \"B12.tif\"),2),s)\n", 191 | " \n", 192 | " I = np.stack((r,g,b,nir,ir1,ir2,ir3,nir2,swir2,swir3),axis=2).astype('float')\n", 193 | " \n", 194 | " if NORMALISE_IMGS:\n", 195 | " I = (I - I.mean()) / I.std()\n", 196 | "\n", 197 | " return I\n", 198 | "\n", 199 | "def read_sentinel_img_leq60(path):\n", 200 | " \"\"\"Read cropped Sentinel-2 image: all 
bands.\"\"\"\n", 201 | " im_name = os.listdir(path)[0][:-7]\n", 202 | " \n", 203 | " r = io.imread(path + im_name + \"B04.tif\")\n", 204 | " s = r.shape\n", 205 | " g = io.imread(path + im_name + \"B03.tif\")\n", 206 | " b = io.imread(path + im_name + \"B02.tif\")\n", 207 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 208 | " \n", 209 | " ir1 = adjust_shape(zoom(io.imread(path + im_name + \"B05.tif\"),2),s)\n", 210 | " ir2 = adjust_shape(zoom(io.imread(path + im_name + \"B06.tif\"),2),s)\n", 211 | " ir3 = adjust_shape(zoom(io.imread(path + im_name + \"B07.tif\"),2),s)\n", 212 | " nir2 = adjust_shape(zoom(io.imread(path + im_name + \"B8A.tif\"),2),s)\n", 213 | " swir2 = adjust_shape(zoom(io.imread(path + im_name + \"B11.tif\"),2),s)\n", 214 | " swir3 = adjust_shape(zoom(io.imread(path + im_name + \"B12.tif\"),2),s)\n", 215 | " \n", 216 | " uv = adjust_shape(zoom(io.imread(path + im_name + \"B01.tif\"),6),s)\n", 217 | " wv = adjust_shape(zoom(io.imread(path + im_name + \"B09.tif\"),6),s)\n", 218 | " swirc = adjust_shape(zoom(io.imread(path + im_name + \"B10.tif\"),6),s)\n", 219 | " \n", 220 | " I = np.stack((r,g,b,nir,ir1,ir2,ir3,nir2,swir2,swir3,uv,wv,swirc),axis=2).astype('float')\n", 221 | " \n", 222 | " if NORMALISE_IMGS:\n", 223 | " I = (I - I.mean()) / I.std()\n", 224 | "\n", 225 | " return I\n", 226 | "\n", 227 | "def read_sentinel_img_trio(path):\n", 228 | " \"\"\"Read cropped Sentinel-2 image pair and change map.\"\"\"\n", 229 | "# read images\n", 230 | " if TYPE == 0:\n", 231 | " I1 = read_sentinel_img(path + '/imgs_1/')\n", 232 | " I2 = read_sentinel_img(path + '/imgs_2/')\n", 233 | " elif TYPE == 1:\n", 234 | " I1 = read_sentinel_img_4(path + '/imgs_1/')\n", 235 | " I2 = read_sentinel_img_4(path + '/imgs_2/')\n", 236 | " elif TYPE == 2:\n", 237 | " I1 = read_sentinel_img_leq20(path + '/imgs_1/')\n", 238 | " I2 = read_sentinel_img_leq20(path + '/imgs_2/')\n", 239 | " elif TYPE == 3:\n", 240 | " I1 = read_sentinel_img_leq60(path + '/imgs_1/')\n", 241 | " I2 = read_sentinel_img_leq60(path + '/imgs_2/')\n", 242 | " \n", 243 | " cm = io.imread(path + '/cm/cm.png', as_gray=True) != 0\n", 244 | " \n", 245 | " # crop if necessary\n", 246 | " s1 = I1.shape\n", 247 | " s2 = I2.shape\n", 248 | " I2 = np.pad(I2,((0, s1[0] - s2[0]), (0, s1[1] - s2[1]), (0,0)),'edge')\n", 249 | " \n", 250 | " \n", 251 | " return I1, I2, cm\n", 252 | "\n", 253 | "\n", 254 | "\n", 255 | "def reshape_for_torch(I):\n", 256 | " \"\"\"Transpose image for PyTorch coordinates.\"\"\"\n", 257 | "# out = np.swapaxes(I,1,2)\n", 258 | "# out = np.swapaxes(out,0,1)\n", 259 | "# out = out[np.newaxis,:]\n", 260 | " out = I.transpose((2, 0, 1))\n", 261 | " return torch.from_numpy(out)\n", 262 | "\n", 263 | "\n", 264 | "\n", 265 | "class ChangeDetectionDataset(Dataset):\n", 266 | " \"\"\"Change Detection dataset class, used for both training and test data.\"\"\"\n", 267 | "\n", 268 | " def __init__(self, path, train = True, patch_side = 96, stride = None, use_all_bands = False, transform=None):\n", 269 | " \"\"\"\n", 270 | " Args:\n", 271 | " csv_file (string): Path to the csv file with annotations.\n", 272 | " root_dir (string): Directory with all the images.\n", 273 | " transform (callable, optional): Optional transform to be applied\n", 274 | " on a sample.\n", 275 | " \"\"\"\n", 276 | " \n", 277 | " # basics\n", 278 | " self.transform = transform\n", 279 | " self.path = path\n", 280 | " self.patch_side = patch_side\n", 281 | " if not stride:\n", 282 | " self.stride = 1\n", 283 | " else:\n", 284 | " self.stride = 
stride\n", 285 | " \n", 286 | " if train:\n", 287 | " fname = 'train.txt'\n", 288 | " else:\n", 289 | " fname = 'test.txt'\n", 290 | " \n", 291 | "# print(path + fname)\n", 292 | " self.names = read_csv(path + fname).columns\n", 293 | " self.n_imgs = self.names.shape[0]\n", 294 | " \n", 295 | " n_pix = 0\n", 296 | " true_pix = 0\n", 297 | " \n", 298 | " \n", 299 | " # load images\n", 300 | " self.imgs_1 = {}\n", 301 | " self.imgs_2 = {}\n", 302 | " self.change_maps = {}\n", 303 | " self.n_patches_per_image = {}\n", 304 | " self.n_patches = 0\n", 305 | " self.patch_coords = []\n", 306 | " for im_name in tqdm(self.names):\n", 307 | " # load and store each image\n", 308 | " I1, I2, cm = read_sentinel_img_trio(self.path + im_name)\n", 309 | " self.imgs_1[im_name] = reshape_for_torch(I1)\n", 310 | " self.imgs_2[im_name] = reshape_for_torch(I2)\n", 311 | " self.change_maps[im_name] = cm\n", 312 | " \n", 313 | " s = cm.shape\n", 314 | " n_pix += np.prod(s)\n", 315 | " true_pix += cm.sum()\n", 316 | " \n", 317 | " # calculate the number of patches\n", 318 | " s = self.imgs_1[im_name].shape\n", 319 | " n1 = ceil((s[1] - self.patch_side + 1) / self.stride)\n", 320 | " n2 = ceil((s[2] - self.patch_side + 1) / self.stride)\n", 321 | " n_patches_i = n1 * n2\n", 322 | " self.n_patches_per_image[im_name] = n_patches_i\n", 323 | " self.n_patches += n_patches_i\n", 324 | " \n", 325 | " # generate path coordinates\n", 326 | " for i in range(n1):\n", 327 | " for j in range(n2):\n", 328 | " # coordinates in (x1, x2, y1, y2)\n", 329 | " current_patch_coords = (im_name, \n", 330 | " [self.stride*i, self.stride*i + self.patch_side, self.stride*j, self.stride*j + self.patch_side],\n", 331 | " [self.stride*(i + 1), self.stride*(j + 1)])\n", 332 | " self.patch_coords.append(current_patch_coords)\n", 333 | " \n", 334 | " self.weights = [ FP_MODIFIER * 2 * true_pix / n_pix, 2 * (n_pix - true_pix) / n_pix]\n", 335 | " \n", 336 | " \n", 337 | "\n", 338 | " def get_img(self, im_name):\n", 339 | " return self.imgs_1[im_name], self.imgs_2[im_name], self.change_maps[im_name]\n", 340 | "\n", 341 | " def __len__(self):\n", 342 | " return self.n_patches\n", 343 | "\n", 344 | " def __getitem__(self, idx):\n", 345 | " current_patch_coords = self.patch_coords[idx]\n", 346 | " im_name = current_patch_coords[0]\n", 347 | " limits = current_patch_coords[1]\n", 348 | " centre = current_patch_coords[2]\n", 349 | " \n", 350 | " I1 = self.imgs_1[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]\n", 351 | " I2 = self.imgs_2[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]\n", 352 | " \n", 353 | " label = self.change_maps[im_name][limits[0]:limits[1], limits[2]:limits[3]]\n", 354 | " label = torch.from_numpy(1*np.array(label)).float()\n", 355 | " \n", 356 | " sample = {'I1': I1, 'I2': I2, 'label': label}\n", 357 | " \n", 358 | " if self.transform:\n", 359 | " sample = self.transform(sample)\n", 360 | "\n", 361 | " return sample\n", 362 | "\n", 363 | "\n", 364 | "\n", 365 | "class RandomFlip(object):\n", 366 | " \"\"\"Flip randomly the images in a sample.\"\"\"\n", 367 | "\n", 368 | "# def __init__(self):\n", 369 | "# return\n", 370 | "\n", 371 | " def __call__(self, sample):\n", 372 | " I1, I2, label = sample['I1'], sample['I2'], sample['label']\n", 373 | " \n", 374 | " if random.random() > 0.5:\n", 375 | " I1 = I1.numpy()[:,:,::-1].copy()\n", 376 | " I1 = torch.from_numpy(I1)\n", 377 | " I2 = I2.numpy()[:,:,::-1].copy()\n", 378 | " I2 = torch.from_numpy(I2)\n", 379 | " label = label.numpy()[:,::-1].copy()\n", 380 | " label = 
torch.from_numpy(label)\n", 381 | "\n", 382 | " return {'I1': I1, 'I2': I2, 'label': label}\n", 383 | "\n", 384 | "\n", 385 | "\n", 386 | "class RandomRot(object):\n", 387 | " \"\"\"Rotate randomly the images in a sample.\"\"\"\n", 388 | "\n", 389 | "# def __init__(self):\n", 390 | "# return\n", 391 | "\n", 392 | " def __call__(self, sample):\n", 393 | " I1, I2, label = sample['I1'], sample['I2'], sample['label']\n", 394 | " \n", 395 | " n = random.randint(0, 3)\n", 396 | " if n:\n", 397 | " I1 = sample['I1'].numpy()\n", 398 | " I1 = np.rot90(I1, n, axes=(1, 2)).copy()\n", 399 | " I1 = torch.from_numpy(I1)\n", 400 | " I2 = sample['I2'].numpy()\n", 401 | " I2 = np.rot90(I2, n, axes=(1, 2)).copy()\n", 402 | " I2 = torch.from_numpy(I2)\n", 403 | " label = sample['label'].numpy()\n", 404 | " label = np.rot90(label, n, axes=(0, 1)).copy()\n", 405 | " label = torch.from_numpy(label)\n", 406 | "\n", 407 | " return {'I1': I1, 'I2': I2, 'label': label}\n", 408 | "\n", 409 | "\n", 410 | "\n", 411 | "\n", 412 | "\n", 413 | "print('UTILS OK')\n" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "# Dataset\n", 430 | "\n", 431 | "\n", 432 | "if DATA_AUG:\n", 433 | " data_transform = tr.Compose([RandomFlip(), RandomRot()])\n", 434 | "else:\n", 435 | " data_transform = None\n", 436 | "\n", 437 | "\n", 438 | " \n", 439 | "\n", 440 | "train_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train = True, stride = TRAIN_STRIDE, transform=data_transform)\n", 441 | "weights = torch.FloatTensor(train_dataset.weights).cuda()\n", 442 | "print(weights)\n", 443 | "train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)\n", 444 | "\n", 445 | "test_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train = False, stride = TRAIN_STRIDE)\n", 446 | "test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)\n", 447 | "\n", 448 | "\n", 449 | "print('DATASETS OK')" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "# print(weights)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "# 0-RGB | 1-RGBIr | 2-All bands s.t. 
resulution <= 20m | 3-All bands\n", 468 | "\n", 469 | "if TYPE == 0:\n", 470 | "# net, net_name = Unet(2*3, 2), 'FC-EF'\n", 471 | "# net, net_name = SiamUnet_conc(3, 2), 'FC-Siam-conc'\n", 472 | "# net, net_name = SiamUnet_diff(3, 2), 'FC-Siam-diff'\n", 473 | " net, net_name = FresUNet(2*3, 2), 'FresUNet'\n", 474 | "elif TYPE == 1:\n", 475 | "# net, net_name = Unet(2*4, 2), 'FC-EF'\n", 476 | "# net, net_name = SiamUnet_conc(4, 2), 'FC-Siam-conc'\n", 477 | "# net, net_name = SiamUnet_diff(4, 2), 'FC-Siam-diff'\n", 478 | " net, net_name = FresUNet(2*4, 2), 'FresUNet'\n", 479 | "elif TYPE == 2:\n", 480 | "# net, net_name = Unet(2*10, 2), 'FC-EF'\n", 481 | "# net, net_name = SiamUnet_conc(10, 2), 'FC-Siam-conc'\n", 482 | "# net, net_name = SiamUnet_diff(10, 2), 'FC-Siam-diff'\n", 483 | " net, net_name = FresUNet(2*10, 2), 'FresUNet'\n", 484 | "elif TYPE == 3:\n", 485 | "# net, net_name = Unet(2*13, 2), 'FC-EF'\n", 486 | "# net, net_name = SiamUnet_conc(13, 2), 'FC-Siam-conc'\n", 487 | "# net, net_name = SiamUnet_diff(13, 2), 'FC-Siam-diff'\n", 488 | " net, net_name = FresUNet(2*13, 2), 'FresUNet'\n", 489 | "\n", 490 | "\n", 491 | "net.cuda()\n", 492 | "\n", 493 | "criterion = nn.NLLLoss(weight=weights) # to be used with logsoftmax output\n", 494 | "\n", 495 | "print('NETWORK OK')" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "def count_parameters(model):\n", 505 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 506 | "\n", 507 | "print('Number of trainable parameters:', count_parameters(net))" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "# net.load_state_dict(torch.load('net-best_epoch-1_fm-0.7394933126157746.pth.tar'))\n", 517 | "\n", 518 | "def train(n_epochs = N_EPOCHS, save = True):\n", 519 | " t = np.linspace(1, n_epochs, n_epochs)\n", 520 | " \n", 521 | " epoch_train_loss = 0 * t\n", 522 | " epoch_train_accuracy = 0 * t\n", 523 | " epoch_train_change_accuracy = 0 * t\n", 524 | " epoch_train_nochange_accuracy = 0 * t\n", 525 | " epoch_train_precision = 0 * t\n", 526 | " epoch_train_recall = 0 * t\n", 527 | " epoch_train_Fmeasure = 0 * t\n", 528 | " epoch_test_loss = 0 * t\n", 529 | " epoch_test_accuracy = 0 * t\n", 530 | " epoch_test_change_accuracy = 0 * t\n", 531 | " epoch_test_nochange_accuracy = 0 * t\n", 532 | " epoch_test_precision = 0 * t\n", 533 | " epoch_test_recall = 0 * t\n", 534 | " epoch_test_Fmeasure = 0 * t\n", 535 | " \n", 536 | "# mean_acc = 0\n", 537 | "# best_mean_acc = 0\n", 538 | " fm = 0\n", 539 | " best_fm = 0\n", 540 | " \n", 541 | " lss = 1000\n", 542 | " best_lss = 1000\n", 543 | " \n", 544 | " plt.figure(num=1)\n", 545 | " plt.figure(num=2)\n", 546 | " plt.figure(num=3)\n", 547 | " \n", 548 | " \n", 549 | " optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)\n", 550 | "# optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)\n", 551 | " \n", 552 | " \n", 553 | " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)\n", 554 | " \n", 555 | " \n", 556 | " for epoch_index in tqdm(range(n_epochs)):\n", 557 | " net.train()\n", 558 | " print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(N_EPOCHS))\n", 559 | "\n", 560 | " tot_count = 0\n", 561 | " tot_loss = 0\n", 562 | " tot_accurate = 0\n", 563 | " class_correct = list(0. for i in range(2))\n", 564 | " class_total = list(0. 
for i in range(2))\n", 565 | "# for batch_index, batch in enumerate(tqdm(data_loader)):\n", 566 | " for batch in train_loader:\n", 567 | " I1 = Variable(batch['I1'].float().cuda())\n", 568 | " I2 = Variable(batch['I2'].float().cuda())\n", 569 | " label = torch.squeeze(Variable(batch['label'].cuda()))\n", 570 | "\n", 571 | " optimizer.zero_grad()\n", 572 | " output = net(I1, I2)\n", 573 | " loss = criterion(output, label.long())\n", 574 | " loss.backward()\n", 575 | " optimizer.step()\n", 576 | " \n", 577 | " scheduler.step()\n", 578 | "\n", 579 | "\n", 580 | " epoch_train_loss[epoch_index], epoch_train_accuracy[epoch_index], cl_acc, pr_rec = test(train_dataset)\n", 581 | " epoch_train_nochange_accuracy[epoch_index] = cl_acc[0]\n", 582 | " epoch_train_change_accuracy[epoch_index] = cl_acc[1]\n", 583 | " epoch_train_precision[epoch_index] = pr_rec[0]\n", 584 | " epoch_train_recall[epoch_index] = pr_rec[1]\n", 585 | " epoch_train_Fmeasure[epoch_index] = pr_rec[2]\n", 586 | " \n", 587 | "# epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = test(test_dataset)\n", 588 | " epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = test(test_dataset)\n", 589 | " epoch_test_nochange_accuracy[epoch_index] = cl_acc[0]\n", 590 | " epoch_test_change_accuracy[epoch_index] = cl_acc[1]\n", 591 | " epoch_test_precision[epoch_index] = pr_rec[0]\n", 592 | " epoch_test_recall[epoch_index] = pr_rec[1]\n", 593 | " epoch_test_Fmeasure[epoch_index] = pr_rec[2]\n", 594 | "\n", 595 | " plt.figure(num=1)\n", 596 | " plt.clf()\n", 597 | " l1_1, = plt.plot(t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1], label='Train loss')\n", 598 | " l1_2, = plt.plot(t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1], label='Test loss')\n", 599 | " plt.legend(handles=[l1_1, l1_2])\n", 600 | " plt.grid()\n", 601 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 602 | " plt.gcf().gca().set_xlim(left = 0)\n", 603 | " plt.title('Loss')\n", 604 | " display.clear_output(wait=True)\n", 605 | " display.display(plt.gcf())\n", 606 | "\n", 607 | " plt.figure(num=2)\n", 608 | " plt.clf()\n", 609 | " l2_1, = plt.plot(t[:epoch_index + 1], epoch_train_accuracy[:epoch_index + 1], label='Train accuracy')\n", 610 | " l2_2, = plt.plot(t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1], label='Test accuracy')\n", 611 | " plt.legend(handles=[l2_1, l2_2])\n", 612 | " plt.grid()\n", 613 | " plt.gcf().gca().set_ylim(0, 100)\n", 614 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 615 | "# plt.gcf().gca().set_xlim(left = 0)\n", 616 | " plt.title('Accuracy')\n", 617 | " display.clear_output(wait=True)\n", 618 | " display.display(plt.gcf())\n", 619 | "\n", 620 | " plt.figure(num=3)\n", 621 | " plt.clf()\n", 622 | " l3_1, = plt.plot(t[:epoch_index + 1], epoch_train_nochange_accuracy[:epoch_index + 1], label='Train accuracy: no change')\n", 623 | " l3_2, = plt.plot(t[:epoch_index + 1], epoch_train_change_accuracy[:epoch_index + 1], label='Train accuracy: change')\n", 624 | " l3_3, = plt.plot(t[:epoch_index + 1], epoch_test_nochange_accuracy[:epoch_index + 1], label='Test accuracy: no change')\n", 625 | " l3_4, = plt.plot(t[:epoch_index + 1], epoch_test_change_accuracy[:epoch_index + 1], label='Test accuracy: change')\n", 626 | " plt.legend(handles=[l3_1, l3_2, l3_3, l3_4])\n", 627 | " plt.grid()\n", 628 | " plt.gcf().gca().set_ylim(0, 100)\n", 629 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 630 | "# plt.gcf().gca().set_xlim(left = 0)\n", 631 | " plt.title('Accuracy per class')\n", 632 | " 
display.clear_output(wait=True)\n", 633 | " display.display(plt.gcf())\n", 634 | "\n", 635 | " plt.figure(num=4)\n", 636 | " plt.clf()\n", 637 | " l4_1, = plt.plot(t[:epoch_index + 1], epoch_train_precision[:epoch_index + 1], label='Train precision')\n", 638 | " l4_2, = plt.plot(t[:epoch_index + 1], epoch_train_recall[:epoch_index + 1], label='Train recall')\n", 639 | " l4_3, = plt.plot(t[:epoch_index + 1], epoch_train_Fmeasure[:epoch_index + 1], label='Train Dice/F1')\n", 640 | " l4_4, = plt.plot(t[:epoch_index + 1], epoch_test_precision[:epoch_index + 1], label='Test precision')\n", 641 | " l4_5, = plt.plot(t[:epoch_index + 1], epoch_test_recall[:epoch_index + 1], label='Test recall')\n", 642 | " l4_6, = plt.plot(t[:epoch_index + 1], epoch_test_Fmeasure[:epoch_index + 1], label='Test Dice/F1')\n", 643 | " plt.legend(handles=[l4_1, l4_2, l4_3, l4_4, l4_5, l4_6])\n", 644 | " plt.grid()\n", 645 | " plt.gcf().gca().set_ylim(0, 1)\n", 646 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 647 | "# plt.gcf().gca().set_xlim(left = 0)\n", 648 | " plt.title('Precision, Recall and F-measure')\n", 649 | " display.clear_output(wait=True)\n", 650 | " display.display(plt.gcf())\n", 651 | " \n", 652 | " \n", 653 | "# mean_acc = (epoch_test_nochange_accuracy[epoch_index] + epoch_test_change_accuracy[epoch_index])/2\n", 654 | "# if mean_acc > best_mean_acc:\n", 655 | "# best_mean_acc = mean_acc\n", 656 | "# save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_acc-' + str(mean_acc) + '.pth.tar'\n", 657 | "# torch.save(net.state_dict(), save_str)\n", 658 | " \n", 659 | " \n", 660 | "# fm = pr_rec[2]\n", 661 | " fm = epoch_train_Fmeasure[epoch_index]\n", 662 | " if fm > best_fm:\n", 663 | " best_fm = fm\n", 664 | " save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'\n", 665 | " torch.save(net.state_dict(), save_str)\n", 666 | " \n", 667 | " lss = epoch_train_loss[epoch_index]\n", 668 | " if lss < best_lss:\n", 669 | " best_lss = lss\n", 670 | " save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'\n", 671 | " torch.save(net.state_dict(), save_str)\n", 672 | " \n", 673 | " \n", 674 | "# print('Epoch loss: ' + str(tot_loss/tot_count))\n", 675 | " if save:\n", 676 | " im_format = 'png'\n", 677 | " # im_format = 'eps'\n", 678 | "\n", 679 | " plt.figure(num=1)\n", 680 | " plt.savefig(net_name + '-01-loss.' + im_format)\n", 681 | "\n", 682 | " plt.figure(num=2)\n", 683 | " plt.savefig(net_name + '-02-accuracy.' + im_format)\n", 684 | "\n", 685 | " plt.figure(num=3)\n", 686 | " plt.savefig(net_name + '-03-accuracy-per-class.' + im_format)\n", 687 | "\n", 688 | " plt.figure(num=4)\n", 689 | " plt.savefig(net_name + '-04-prec-rec-fmeas.' 
+ im_format)\n", 690 | " \n", 691 | " out = {'train_loss': epoch_train_loss[-1],\n", 692 | " 'train_accuracy': epoch_train_accuracy[-1],\n", 693 | " 'train_nochange_accuracy': epoch_train_nochange_accuracy[-1],\n", 694 | " 'train_change_accuracy': epoch_train_change_accuracy[-1],\n", 695 | " 'test_loss': epoch_test_loss[-1],\n", 696 | " 'test_accuracy': epoch_test_accuracy[-1],\n", 697 | " 'test_nochange_accuracy': epoch_test_nochange_accuracy[-1],\n", 698 | " 'test_change_accuracy': epoch_test_change_accuracy[-1]}\n", 699 | " \n", 700 | " print('pr_c, rec_c, f_meas, pr_nc, rec_nc')\n", 701 | " print(pr_rec)\n", 702 | " \n", 703 | " return out\n", 704 | "\n", 705 | "L = 1024\n", 706 | "N = 2\n", 707 | "\n", 708 | "def test(dset):\n", 709 | " net.eval()\n", 710 | " tot_loss = 0\n", 711 | " tot_count = 0\n", 712 | " tot_accurate = 0\n", 713 | " \n", 714 | " n = 2\n", 715 | " class_correct = list(0. for i in range(n))\n", 716 | " class_total = list(0. for i in range(n))\n", 717 | " class_accuracy = list(0. for i in range(n))\n", 718 | " \n", 719 | " tp = 0\n", 720 | " tn = 0\n", 721 | " fp = 0\n", 722 | " fn = 0\n", 723 | " \n", 724 | " for img_index in dset.names:\n", 725 | " I1_full, I2_full, cm_full = dset.get_img(img_index)\n", 726 | " \n", 727 | " s = cm_full.shape\n", 728 | " \n", 729 | "\n", 730 | " steps0 = np.arange(0,s[0],ceil(s[0]/N))\n", 731 | " steps1 = np.arange(0,s[1],ceil(s[1]/N))\n", 732 | " for ii in range(N):\n", 733 | " for jj in range(N):\n", 734 | " xmin = steps0[ii]\n", 735 | " if ii == N-1:\n", 736 | " xmax = s[0]\n", 737 | " else:\n", 738 | " xmax = steps0[ii+1]\n", 739 | " ymin = jj\n", 740 | " if jj == N-1:\n", 741 | " ymax = s[1]\n", 742 | " else:\n", 743 | " ymax = steps1[jj+1]\n", 744 | " I1 = I1_full[:, xmin:xmax, ymin:ymax]\n", 745 | " I2 = I2_full[:, xmin:xmax, ymin:ymax]\n", 746 | " cm = cm_full[xmin:xmax, ymin:ymax]\n", 747 | "\n", 748 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 749 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 750 | " cm = Variable(torch.unsqueeze(torch.from_numpy(1.0*cm),0).float()).cuda()\n", 751 | "\n", 752 | "\n", 753 | " output = net(I1, I2)\n", 754 | " loss = criterion(output, cm.long())\n", 755 | " # print(loss)\n", 756 | " tot_loss += loss.data * np.prod(cm.size())\n", 757 | " tot_count += np.prod(cm.size())\n", 758 | "\n", 759 | " _, predicted = torch.max(output.data, 1)\n", 760 | "\n", 761 | " c = (predicted.int() == cm.data.int())\n", 762 | " for i in range(c.size(1)):\n", 763 | " for j in range(c.size(2)):\n", 764 | " l = int(cm.data[0, i, j])\n", 765 | " class_correct[l] += c[0, i, j]\n", 766 | " class_total[l] += 1\n", 767 | " \n", 768 | " pr = (predicted.int() > 0).cpu().numpy()\n", 769 | " gt = (cm.data.int() > 0).cpu().numpy()\n", 770 | " \n", 771 | " tp += np.logical_and(pr, gt).sum()\n", 772 | " tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()\n", 773 | " fp += np.logical_and(pr, np.logical_not(gt)).sum()\n", 774 | " fn += np.logical_and(np.logical_not(pr), gt).sum()\n", 775 | " \n", 776 | " net_loss = tot_loss/tot_count\n", 777 | " net_accuracy = 100 * (tp + tn)/tot_count\n", 778 | " \n", 779 | " for i in range(n):\n", 780 | " class_accuracy[i] = 100 * class_correct[i] / max(class_total[i],0.00001)\n", 781 | "\n", 782 | " prec = tp / (tp + fp)\n", 783 | " rec = tp / (tp + fn)\n", 784 | " f_meas = 2 * prec * rec / (prec + rec)\n", 785 | " prec_nc = tn / (tn + fn)\n", 786 | " rec_nc = tn / (tn + fp)\n", 787 | " \n", 788 | " pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]\n", 
789 | " \n", 790 | " return net_loss, net_accuracy, class_accuracy, pr_rec\n", 791 | " \n", 792 | " \n", 793 | "\n", 794 | "\n", 795 | "\n", 796 | " \n" 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": {}, 803 | "outputs": [], 804 | "source": [ 805 | "if LOAD_TRAINED:\n", 806 | " net.load_state_dict(torch.load('net_final.pth.tar'))\n", 807 | " print('LOAD OK')\n", 808 | "else:\n", 809 | " t_start = time.time()\n", 810 | " out_dic = train()\n", 811 | " t_end = time.time()\n", 812 | " print(out_dic)\n", 813 | " print('Elapsed time:')\n", 814 | " print(t_end - t_start)\n", 815 | "\n", 816 | "\n" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": {}, 823 | "outputs": [], 824 | "source": [ 825 | "if not LOAD_TRAINED:\n", 826 | " torch.save(net.state_dict(), 'net_final.pth.tar')\n", 827 | " print('SAVE OK')" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": null, 833 | "metadata": {}, 834 | "outputs": [], 835 | "source": [ 836 | "\n", 837 | "\n", 838 | "def save_test_results(dset):\n", 839 | " for name in tqdm(dset.names):\n", 840 | " with warnings.catch_warnings():\n", 841 | " I1, I2, cm = dset.get_img(name)\n", 842 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 843 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 844 | " out = net(I1, I2)\n", 845 | " _, predicted = torch.max(out.data, 1)\n", 846 | " I = np.stack((255*cm,255*np.squeeze(predicted.cpu().numpy()),255*cm),2)\n", 847 | " io.imsave(f'{net_name}-{name}.png',I)\n", 848 | "\n", 849 | "\n", 850 | "\n", 851 | "t_start = time.time()\n", 852 | "# save_test_results(train_dataset)\n", 853 | "save_test_results(test_dataset)\n", 854 | "t_end = time.time()\n", 855 | "print('Elapsed time: {}'.format(t_end - t_start))\n" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": {}, 862 | "outputs": [], 863 | "source": [ 864 | "L = 1024\n", 865 | "\n", 866 | "def kappa(tp, tn, fp, fn):\n", 867 | " N = tp + tn + fp + fn\n", 868 | " p0 = (tp + tn) / N\n", 869 | " pe = ((tp+fp)*(tp+fn) + (tn+fp)*(tn+fn)) / (N * N)\n", 870 | " \n", 871 | " return (p0 - pe) / (1 - pe)\n", 872 | "\n", 873 | "def test(dset):\n", 874 | " net.eval()\n", 875 | " tot_loss = 0\n", 876 | " tot_count = 0\n", 877 | " tot_accurate = 0\n", 878 | " \n", 879 | " n = 2\n", 880 | " class_correct = list(0. for i in range(n))\n", 881 | " class_total = list(0. for i in range(n))\n", 882 | " class_accuracy = list(0. 
for i in range(n))\n", 883 | " \n", 884 | " tp = 0\n", 885 | " tn = 0\n", 886 | " fp = 0\n", 887 | " fn = 0\n", 888 | " \n", 889 | " for img_index in tqdm(dset.names):\n", 890 | " I1_full, I2_full, cm_full = dset.get_img(img_index)\n", 891 | " \n", 892 | " s = cm_full.shape\n", 893 | " \n", 894 | " for ii in range(ceil(s[0]/L)):\n", 895 | " for jj in range(ceil(s[1]/L)):\n", 896 | " xmin = L*ii\n", 897 | " xmax = min(L*(ii+1),s[1])\n", 898 | " ymin = L*jj\n", 899 | " ymax = min(L*(jj+1),s[1])\n", 900 | " I1 = I1_full[:, xmin:xmax, ymin:ymax]\n", 901 | " I2 = I2_full[:, xmin:xmax, ymin:ymax]\n", 902 | " cm = cm_full[xmin:xmax, ymin:ymax]\n", 903 | "\n", 904 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 905 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 906 | " cm = Variable(torch.unsqueeze(torch.from_numpy(1.0*cm),0).float()).cuda()\n", 907 | "\n", 908 | " output = net(I1, I2)\n", 909 | " \n", 910 | " loss = criterion(output, cm.long())\n", 911 | " tot_loss += loss.data * np.prod(cm.size())\n", 912 | " tot_count += np.prod(cm.size())\n", 913 | "\n", 914 | " _, predicted = torch.max(output.data, 1)\n", 915 | "\n", 916 | " c = (predicted.int() == cm.data.int())\n", 917 | " for i in range(c.size(1)):\n", 918 | " for j in range(c.size(2)):\n", 919 | " l = int(cm.data[0, i, j])\n", 920 | " class_correct[l] += c[0, i, j]\n", 921 | " class_total[l] += 1\n", 922 | " \n", 923 | " pr = (predicted.int() > 0).cpu().numpy()\n", 924 | " gt = (cm.data.int() > 0).cpu().numpy()\n", 925 | " \n", 926 | " tp += np.logical_and(pr, gt).sum()\n", 927 | " tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()\n", 928 | " fp += np.logical_and(pr, np.logical_not(gt)).sum()\n", 929 | " fn += np.logical_and(np.logical_not(pr), gt).sum()\n", 930 | " \n", 931 | " net_loss = tot_loss/tot_count \n", 932 | " net_loss = float(net_loss.cpu().numpy())\n", 933 | " \n", 934 | " net_accuracy = 100 * (tp + tn)/tot_count\n", 935 | " \n", 936 | " for i in range(n):\n", 937 | " class_accuracy[i] = 100 * class_correct[i] / max(class_total[i],0.00001)\n", 938 | " class_accuracy[i] = float(class_accuracy[i].cpu().numpy())\n", 939 | "\n", 940 | " prec = tp / (tp + fp)\n", 941 | " rec = tp / (tp + fn)\n", 942 | " dice = 2 * prec * rec / (prec + rec)\n", 943 | " prec_nc = tn / (tn + fn)\n", 944 | " rec_nc = tn / (tn + fp)\n", 945 | " \n", 946 | " pr_rec = [prec, rec, dice, prec_nc, rec_nc]\n", 947 | " \n", 948 | " k = kappa(tp, tn, fp, fn)\n", 949 | " \n", 950 | " return {'net_loss': net_loss, \n", 951 | " 'net_accuracy': net_accuracy, \n", 952 | " 'class_accuracy': class_accuracy, \n", 953 | " 'precision': prec, \n", 954 | " 'recall': rec, \n", 955 | " 'dice': dice, \n", 956 | " 'kappa': k}\n", 957 | "\n", 958 | "results = test(test_dataset)\n", 959 | "pprint(results)" 960 | ] 961 | }, 962 | { 963 | "cell_type": "code", 964 | "execution_count": null, 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": null, 972 | "metadata": {}, 973 | "outputs": [], 974 | "source": [] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": null, 979 | "metadata": {}, 980 | "outputs": [], 981 | "source": [] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": null, 986 | "metadata": {}, 987 | "outputs": [], 988 | "source": [] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": null, 993 | "metadata": {}, 994 | "outputs": [], 995 | "source": [] 996 | } 997 | ], 998 | "metadata": { 999 | 
"kernelspec": { 1000 | "display_name": "Python 3", 1001 | "language": "python", 1002 | "name": "python3" 1003 | }, 1004 | "language_info": { 1005 | "codemirror_mode": { 1006 | "name": "ipython", 1007 | "version": 3 1008 | }, 1009 | "file_extension": ".py", 1010 | "mimetype": "text/x-python", 1011 | "name": "python", 1012 | "nbconvert_exporter": "python", 1013 | "pygments_lexer": "ipython3", 1014 | "version": "3.7.6" 1015 | } 1016 | }, 1017 | "nbformat": 4, 1018 | "nbformat_minor": 4 1019 | } 1020 | -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/siamunet_conc.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | # this simple Siamese U-Net uses 2 branches 11 | # 1. before change: S2; S1, weights shared with 2. 12 | # 2. after change: S2; S1, weights shared with 1. 13 | 14 | class SiamUnet_conc(nn.Module): 15 | """SiamUnet_conc segmentation network.""" 16 | 17 | def __init__(self, input_nbr, label_nbr): 18 | super(SiamUnet_conc, self).__init__() 19 | 20 | self.input_nbr = input_nbr 21 | 22 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 23 | self.bn11 = nn.BatchNorm2d(16) 24 | self.do11 = nn.Dropout2d(p=0.2) 25 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 26 | self.bn12 = nn.BatchNorm2d(16) 27 | self.do12 = nn.Dropout2d(p=0.2) 28 | 29 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 30 | self.bn21 = nn.BatchNorm2d(32) 31 | self.do21 = nn.Dropout2d(p=0.2) 32 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 33 | self.bn22 = nn.BatchNorm2d(32) 34 | self.do22 = nn.Dropout2d(p=0.2) 35 | 36 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 37 | self.bn31 = nn.BatchNorm2d(64) 38 | self.do31 = nn.Dropout2d(p=0.2) 39 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 40 | self.bn32 = nn.BatchNorm2d(64) 41 | self.do32 = nn.Dropout2d(p=0.2) 42 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 43 | self.bn33 = nn.BatchNorm2d(64) 44 | self.do33 = nn.Dropout2d(p=0.2) 45 | 46 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 47 | self.bn41 = nn.BatchNorm2d(128) 48 | self.do41 = nn.Dropout2d(p=0.2) 49 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 50 | self.bn42 = nn.BatchNorm2d(128) 51 | self.do42 = nn.Dropout2d(p=0.2) 52 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 53 | self.bn43 = nn.BatchNorm2d(128) 54 | self.do43 = nn.Dropout2d(p=0.2) 55 | 56 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 57 | 58 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 59 | self.bn43d = nn.BatchNorm2d(128) 60 | self.do43d = nn.Dropout2d(p=0.2) 61 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 62 | self.bn42d = nn.BatchNorm2d(128) 63 | self.do42d = nn.Dropout2d(p=0.2) 64 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 65 | self.bn41d = nn.BatchNorm2d(64) 66 | self.do41d = nn.Dropout2d(p=0.2) 67 | 68 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, 
stride=2, output_padding=1) 69 | 70 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 71 | self.bn33d = nn.BatchNorm2d(64) 72 | self.do33d = nn.Dropout2d(p=0.2) 73 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 74 | self.bn32d = nn.BatchNorm2d(64) 75 | self.do32d = nn.Dropout2d(p=0.2) 76 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 77 | self.bn31d = nn.BatchNorm2d(32) 78 | self.do31d = nn.Dropout2d(p=0.2) 79 | 80 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 81 | 82 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 83 | self.bn22d = nn.BatchNorm2d(32) 84 | self.do22d = nn.Dropout2d(p=0.2) 85 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 86 | self.bn21d = nn.BatchNorm2d(16) 87 | self.do21d = nn.Dropout2d(p=0.2) 88 | 89 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 90 | 91 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 92 | self.bn12d = nn.BatchNorm2d(16) 93 | self.do12d = nn.Dropout2d(p=0.2) 94 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 95 | 96 | self.sm = nn.LogSoftmax(dim=1) 97 | 98 | def forward(self, s2_1, s2_2, s1_1=None, s1_2=None): 99 | 100 | if s1_1 is not None and s1_2 is not None: 101 | s2_1 = torch.cat((s2_1, s1_1), 1) 102 | s2_2 = torch.cat((s2_2, s1_2), 1) 103 | 104 | """Forward method.""" 105 | # Stage 1 106 | x11 = self.do11(F.relu(self.bn11(self.conv11(s2_1)))) 107 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 108 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 109 | 110 | 111 | # Stage 2 112 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 113 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 114 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 115 | 116 | # Stage 3 117 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 118 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 119 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 120 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 121 | 122 | # Stage 4 123 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 124 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 125 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 126 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 127 | 128 | 129 | #################################################### 130 | # Stage 1 131 | x11 = self.do11(F.relu(self.bn11(self.conv11(s2_2)))) 132 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 133 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 134 | 135 | # Stage 2 136 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 137 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 138 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 139 | 140 | # Stage 3 141 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 142 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 143 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 144 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 145 | 146 | # Stage 4 147 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 148 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 149 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 150 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 151 | 152 | 153 | #################################################### 154 | # Stage 4d 155 | x4d = self.upconv4(x4p) 156 | pad4 = ReplicationPad2d((0, 
x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 157 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 158 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 159 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 160 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 161 | 162 | # Stage 3d 163 | x3d = self.upconv3(x41d) 164 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 165 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 166 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 167 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 168 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 169 | 170 | # Stage 2d 171 | x2d = self.upconv2(x31d) 172 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 173 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 174 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 175 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 176 | 177 | # Stage 1d 178 | x1d = self.upconv1(x21d) 179 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 180 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 181 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 182 | x11d = self.conv11d(x12d) 183 | 184 | return self.sm(x11d) 185 | 186 | # this complex Siamese U-Net uses 4 branches 187 | # 1. before change: S2, weights shared with 2. 188 | # 2. after change: S2, weights shared with 1. 189 | # 3. before change: S1, weights shared with 4. 190 | # 4. after change: S1, weights shared with 3. 191 | 192 | class SiamUnet_conc_multi(nn.Module): 193 | """SiamUnet_conc segmentation network.""" 194 | 195 | def __init__(self, input_nbr, label_nbr): 196 | super(SiamUnet_conc_multi, self).__init__() 197 | 198 | self.input_nbr_1, self.input_nbr_2 = input_nbr 199 | 200 | ################################# encoder S2 ################################# 201 | 202 | # 16 channels 203 | self.conv11 = nn.Conv2d(self.input_nbr_1, 16, kernel_size=3, padding=1) 204 | self.bn11 = nn.BatchNorm2d(16) 205 | self.do11 = nn.Dropout2d(p=0.2) 206 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 207 | self.bn12 = nn.BatchNorm2d(16) 208 | self.do12 = nn.Dropout2d(p=0.2) 209 | 210 | # 32 channels 211 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 212 | self.bn21 = nn.BatchNorm2d(32) 213 | self.do21 = nn.Dropout2d(p=0.2) 214 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 215 | self.bn22 = nn.BatchNorm2d(32) 216 | self.do22 = nn.Dropout2d(p=0.2) 217 | 218 | # 64 channels 219 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 220 | self.bn31 = nn.BatchNorm2d(64) 221 | self.do31 = nn.Dropout2d(p=0.2) 222 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 223 | self.bn32 = nn.BatchNorm2d(64) 224 | self.do32 = nn.Dropout2d(p=0.2) 225 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 226 | self.bn33 = nn.BatchNorm2d(64) 227 | self.do33 = nn.Dropout2d(p=0.2) 228 | 229 | # 128 channels 230 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 231 | self.bn41 = nn.BatchNorm2d(128) 232 | self.do41 = nn.Dropout2d(p=0.2) 233 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 234 | self.bn42 = nn.BatchNorm2d(128) 235 | self.do42 = nn.Dropout2d(p=0.2) 236 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 237 | self.bn43 = nn.BatchNorm2d(128) 238 | self.do43 = nn.Dropout2d(p=0.2) 239 | 240 | ################################# encoder S1 
################################# 241 | 242 | # 16 channels 243 | self.conv11_b = nn.Conv2d(self.input_nbr_2, 16, kernel_size=3, padding=1) 244 | self.bn11_b = nn.BatchNorm2d(16) 245 | self.do11_b = nn.Dropout2d(p=0.2) 246 | self.conv12_b = nn.Conv2d(16, 16, kernel_size=3, padding=1) 247 | self.bn12_b = nn.BatchNorm2d(16) 248 | self.do12_b = nn.Dropout2d(p=0.2) 249 | 250 | # 32 channels 251 | self.conv21_b = nn.Conv2d(16, 32, kernel_size=3, padding=1) 252 | self.bn21_b = nn.BatchNorm2d(32) 253 | self.do21_b = nn.Dropout2d(p=0.2) 254 | self.conv22_b = nn.Conv2d(32, 32, kernel_size=3, padding=1) 255 | self.bn22_b = nn.BatchNorm2d(32) 256 | self.do22_b = nn.Dropout2d(p=0.2) 257 | 258 | # 64 channels 259 | self.conv31_b = nn.Conv2d(32, 64, kernel_size=3, padding=1) 260 | self.bn31_b = nn.BatchNorm2d(64) 261 | self.do31_b = nn.Dropout2d(p=0.2) 262 | self.conv32_b = nn.Conv2d(64, 64, kernel_size=3, padding=1) 263 | self.bn32_b = nn.BatchNorm2d(64) 264 | self.do32_b = nn.Dropout2d(p=0.2) 265 | self.conv33_b = nn.Conv2d(64, 64, kernel_size=3, padding=1) 266 | self.bn33_b = nn.BatchNorm2d(64) 267 | self.do33_b = nn.Dropout2d(p=0.2) 268 | 269 | # 128 channels 270 | self.conv41_b = nn.Conv2d(64, 128, kernel_size=3, padding=1) 271 | self.bn41_b = nn.BatchNorm2d(128) 272 | self.do41_b = nn.Dropout2d(p=0.2) 273 | self.conv42_b = nn.Conv2d(128, 128, kernel_size=3, padding=1) 274 | self.bn42_b = nn.BatchNorm2d(128) 275 | self.do42_b = nn.Dropout2d(p=0.2) 276 | self.conv43_b = nn.Conv2d(128, 128, kernel_size=3, padding=1) 277 | self.bn43_b = nn.BatchNorm2d(128) 278 | self.do43_b = nn.Dropout2d(p=0.2) 279 | 280 | ################################# decoder ################################# 281 | 282 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 283 | 284 | self.conv43d = nn.ConvTranspose2d(384+128+128, 128, kernel_size=3, padding=1) # added S1+S2 channels here 285 | self.bn43d = nn.BatchNorm2d(128) 286 | self.do43d = nn.Dropout2d(p=0.2) 287 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 288 | self.bn42d = nn.BatchNorm2d(128) 289 | self.do42d = nn.Dropout2d(p=0.2) 290 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 291 | self.bn41d = nn.BatchNorm2d(64) 292 | self.do41d = nn.Dropout2d(p=0.2) 293 | 294 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 295 | 296 | self.conv33d = nn.ConvTranspose2d(192+64+64, 64, kernel_size=3, padding=1) # added S1+S2 channels here 297 | self.bn33d = nn.BatchNorm2d(64) 298 | self.do33d = nn.Dropout2d(p=0.2) 299 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 300 | self.bn32d = nn.BatchNorm2d(64) 301 | self.do32d = nn.Dropout2d(p=0.2) 302 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 303 | self.bn31d = nn.BatchNorm2d(32) 304 | self.do31d = nn.Dropout2d(p=0.2) 305 | 306 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 307 | 308 | self.conv22d = nn.ConvTranspose2d(96+32+32, 32, kernel_size=3, padding=1) # added S1+S2 channels here 309 | self.bn22d = nn.BatchNorm2d(32) 310 | self.do22d = nn.Dropout2d(p=0.2) 311 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 312 | self.bn21d = nn.BatchNorm2d(16) 313 | self.do21d = nn.Dropout2d(p=0.2) 314 | 315 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 316 | 317 | self.conv12d = nn.ConvTranspose2d(48+16+16, 16, kernel_size=3, 
padding=1) # added S1+S2 channels here 318 | self.bn12d = nn.BatchNorm2d(16) 319 | self.do12d = nn.Dropout2d(p=0.2) 320 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 321 | 322 | self.sm = nn.LogSoftmax(dim=1) 323 | 324 | def forward(self, s2_1, s2_2, s1_1, s1_2): 325 | 326 | """Forward method.""" 327 | 328 | #################################################### encoder S2 #################################################### 329 | 330 | # siamese processing of input s2_1 331 | # Stage 1 332 | x11 = self.do11(F.relu(self.bn11(self.conv11(s2_1)))) 333 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 334 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 335 | 336 | 337 | # Stage 2 338 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 339 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 340 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 341 | 342 | # Stage 3 343 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 344 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 345 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 346 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 347 | 348 | # Stage 4 349 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 350 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 351 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 352 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 353 | 354 | 355 | #################################################### 356 | 357 | # siamese processing of input s2_2 358 | # Stage 1 359 | x11 = self.do11(F.relu(self.bn11(self.conv11(s2_2)))) 360 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 361 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 362 | 363 | # Stage 2 364 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 365 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 366 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 367 | 368 | # Stage 3 369 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 370 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 371 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 372 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 373 | 374 | # Stage 4 375 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 376 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 377 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 378 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 379 | 380 | #################################################### encoder S1 #################################################### 381 | 382 | # siamese processing of input s1_1 383 | # Stage 1 384 | x11_b = self.do11_b(F.relu(self.bn11_b(self.conv11_b(s1_1)))) 385 | x12_1_b = self.do12_b(F.relu(self.bn12_b(self.conv12_b(x11_b)))) 386 | x1p_b = F.max_pool2d(x12_1_b, kernel_size=2, stride=2) 387 | 388 | 389 | # Stage 2 390 | x21_b = self.do21_b(F.relu(self.bn21_b(self.conv21_b(x1p_b)))) 391 | x22_1_b = self.do22_b(F.relu(self.bn22_b(self.conv22_b(x21_b)))) 392 | x2p_b = F.max_pool2d(x22_1_b, kernel_size=2, stride=2) 393 | 394 | # Stage 3 395 | x31_b = self.do31_b(F.relu(self.bn31_b(self.conv31_b(x2p_b)))) 396 | x32_b = self.do32_b(F.relu(self.bn32_b(self.conv32_b(x31_b)))) 397 | x33_1_b = self.do33_b(F.relu(self.bn33_b(self.conv33_b(x32_b)))) 398 | x3p_b = F.max_pool2d(x33_1_b, kernel_size=2, stride=2) 399 | 400 | # Stage 4 401 | x41_b = self.do41_b(F.relu(self.bn41_b(self.conv41_b(x3p_b)))) 402 | x42_b = self.do42_b(F.relu(self.bn42_b(self.conv42_b(x41_b)))) 403 | x43_1_b = 
self.do43_b(F.relu(self.bn43_b(self.conv43_b(x42_b)))) 404 | x4p_b = F.max_pool2d(x43_1_b, kernel_size=2, stride=2) 405 | 406 | 407 | #################################################### 408 | 409 | # siamese processing of input s1_2 410 | # Stage 1 411 | x11_b = self.do11_b(F.relu(self.bn11_b(self.conv11_b(s1_2)))) 412 | x12_2_b = self.do12_b(F.relu(self.bn12_b(self.conv12_b(x11_b)))) 413 | x1p_b = F.max_pool2d(x12_2_b, kernel_size=2, stride=2) 414 | 415 | # Stage 2 416 | x21_b = self.do21_b(F.relu(self.bn21_b(self.conv21_b(x1p_b)))) 417 | x22_2_b = self.do22_b(F.relu(self.bn22_b(self.conv22_b(x21_b)))) 418 | x2p_b = F.max_pool2d(x22_2_b, kernel_size=2, stride=2) 419 | 420 | # Stage 3 421 | x31_b = self.do31_b(F.relu(self.bn31_b(self.conv31_b(x2p_b)))) 422 | x32_b = self.do32_b(F.relu(self.bn32_b(self.conv32_b(x31_b)))) 423 | x33_2_b = self.do33_b(F.relu(self.bn33_b(self.conv33_b(x32_b)))) 424 | x3p_b = F.max_pool2d(x33_2_b, kernel_size=2, stride=2) 425 | 426 | # Stage 4 427 | x41_b = self.do41_b(F.relu(self.bn41_b(self.conv41_b(x3p_b)))) 428 | x42_b = self.do42_b(F.relu(self.bn42_b(self.conv42_b(x41_b)))) 429 | x43_2_b = self.do43_b(F.relu(self.bn43_b(self.conv43_b(x42_b)))) 430 | x4p_b = F.max_pool2d(x43_2_b, kernel_size=2, stride=2) 431 | 432 | #################################################### decoder #################################################### 433 | # Stage 4d 434 | x4d = self.upconv4(x4p) 435 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 436 | x4d = torch.cat((pad4(x4d), x43_1, x43_2, x43_1_b, x43_2_b), 1) 437 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 438 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 439 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 440 | 441 | # Stage 3d 442 | x3d = self.upconv3(x41d) 443 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 444 | x3d = torch.cat((pad3(x3d), x33_1, x33_2, x33_1_b, x33_2_b), 1) 445 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 446 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 447 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 448 | 449 | # Stage 2d 450 | x2d = self.upconv2(x31d) 451 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 452 | x2d = torch.cat((pad2(x2d), x22_1, x22_2, x22_1_b, x22_2_b), 1) 453 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 454 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 455 | 456 | # Stage 1d 457 | x1d = self.upconv1(x21d) 458 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 459 | x1d = torch.cat((pad1(x1d), x12_1, x12_2, x12_1_b, x12_2_b), 1) 460 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 461 | x11d = self.conv11d(x12d) 462 | 463 | return self.sm(x11d) -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/siamunet_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 
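# this Siamese U-Net uses 2 weight-sharing branches
# 1. before change: x1, weights shared with 2.
# 2. after change: x2, weights shared with 1.
# unlike siamunet_conc.py, each skip connection passes the absolute feature difference
# torch.abs(feat_t1 - feat_t2) to the decoder rather than concatenating both feature maps,
# so the decoder stages take 256/128/64/32 input channels instead of 384/192/96/48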
4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_diff(nn.Module): 11 | """SiamUnet_diff segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_diff, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | 
self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | 147 | # Stage 4d 148 | x4d = self.upconv4(x4p) 149 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 150 | x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1) 151 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 152 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 153 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 154 | 155 | # Stage 3d 156 | x3d = self.upconv3(x41d) 157 | pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 158 | x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1) 159 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 160 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 161 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 162 | 163 | # Stage 2d 164 | x2d = self.upconv2(x31d) 165 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 166 | x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1) 167 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 168 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 169 | 170 | # Stage 1d 171 | x1d = self.upconv1(x21d) 172 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 173 | x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1) 174 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 175 | x11d = self.conv11d(x12d) 176 
| 177 | return self.sm(x11d) 178 | 179 | 180 | -------------------------------------------------------------------------------- /code/aux/fully_convolutional_change_detection/unet.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class Unet(nn.Module): 11 | """EF segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(Unet, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | 53 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 54 | 55 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 56 | self.bn43d = nn.BatchNorm2d(128) 57 | self.do43d = nn.Dropout2d(p=0.2) 58 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 59 | self.bn42d = nn.BatchNorm2d(128) 60 | self.do42d = nn.Dropout2d(p=0.2) 61 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 62 | self.bn41d = nn.BatchNorm2d(64) 63 | self.do41d = nn.Dropout2d(p=0.2) 64 | 65 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 66 | 67 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 68 | self.bn33d = nn.BatchNorm2d(64) 69 | self.do33d = nn.Dropout2d(p=0.2) 70 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 71 | self.bn32d = nn.BatchNorm2d(64) 72 | self.do32d = nn.Dropout2d(p=0.2) 73 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 74 | self.bn31d = nn.BatchNorm2d(32) 75 | self.do31d = nn.Dropout2d(p=0.2) 76 | 77 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 78 | 79 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, 
padding=1) 80 | self.bn22d = nn.BatchNorm2d(32) 81 | self.do22d = nn.Dropout2d(p=0.2) 82 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 83 | self.bn21d = nn.BatchNorm2d(16) 84 | self.do21d = nn.Dropout2d(p=0.2) 85 | 86 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 87 | 88 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 89 | self.bn12d = nn.BatchNorm2d(16) 90 | self.do12d = nn.Dropout2d(p=0.2) 91 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 92 | 93 | self.sm = nn.LogSoftmax(dim=1) 94 | 95 | def forward(self, s2_1, s2_2, s1_1=None, s1_2=None): 96 | 97 | x = torch.cat((s2_1, s2_2), 1) 98 | if s1_1 is not None and s1_2 is not None: x = torch.cat((x, s1_1, s1_2), 1) 99 | 100 | """Forward method.""" 101 | # Stage 1 102 | x11 = self.do11(F.relu(self.bn11(self.conv11(x)))) 103 | x12 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 104 | x1p = F.max_pool2d(x12, kernel_size=2, stride=2) 105 | 106 | # Stage 2 107 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 108 | x22 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 109 | x2p = F.max_pool2d(x22, kernel_size=2, stride=2) 110 | 111 | # Stage 3 112 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 113 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 114 | x33 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 115 | x3p = F.max_pool2d(x33, kernel_size=2, stride=2) 116 | 117 | # Stage 4 118 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 119 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 120 | x43 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 121 | x4p = F.max_pool2d(x43, kernel_size=2, stride=2) 122 | 123 | 124 | # Stage 4d 125 | x4d = self.upconv4(x4p) 126 | pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2))) 127 | x4d = torch.cat((pad4(x4d), x43), 1) 128 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 129 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 130 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 131 | 132 | # Stage 3d 133 | x3d = self.upconv3(x41d) 134 | pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2))) 135 | x3d = torch.cat((pad3(x3d), x33), 1) 136 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 137 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 138 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 139 | 140 | # Stage 2d 141 | x2d = self.upconv2(x31d) 142 | pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2))) 143 | x2d = torch.cat((pad2(x2d), x22), 1) 144 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 145 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 146 | 147 | # Stage 1d 148 | x1d = self.upconv1(x21d) 149 | pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2))) 150 | x1d = torch.cat((pad1(x1d), x12), 1) 151 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 152 | x11d = self.conv11d(x12d) 153 | 154 | return self.sm(x11d) 155 | 156 | 157 | -------------------------------------------------------------------------------- /code/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import os 4 | import random 5 | import numpy as np 6 | import rasterio 7 | from tqdm import tqdm 8 | from natsort import natsorted 9 | import torchvision.transforms as tr 
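# sliding-window patch anchors: get_window_offsets (defined just below) returns the top-left
# (y, x) corners of stride-spaced windows of size psize, keeping only anchors up to
# (img_dim // psize) * psize - psize in each dimension so every kept window lies inside the image.
# minimal usage sketch (illustrative values, not part of the original file):
#   get_window_offsets((10, 10), psize=4, stride=4)
#   -> array([[0, 0], [0, 4], [4, 0], [4, 4]])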
10 | 11 | def get_window_offsets(img_dim, psize, stride): 12 | max_dim = (np.array(img_dim) // psize) * psize - psize 13 | 14 | ys = np.arange(0, img_dim[0], stride) 15 | xs = np.arange(0, img_dim[1], stride) 16 | 17 | tlc = np.array(np.meshgrid(ys, xs)).T.reshape(-1, 2) 18 | tlc = tlc[tlc[:, 0] <= max_dim[0]] 19 | tlc = tlc[tlc[:, 1] <= max_dim[1]] 20 | 21 | return tlc.astype(int) 22 | 23 | class multiCD(Dataset): 24 | def __init__(self, root, split="all", fp_modifier=10, transform=None, s2_channel_type=3, run_on_onera_patches=False, use_pre_sliced=True, normalize=False): 25 | 26 | self.splits = {"train": ['abudhabi', 'aguasclaras', 'beihai', 'beirut', 'bercy', 'bordeaux', 'cupertino', 'hongkong', 'mumbai', 'nantes', 'paris', 'pisa', 'rennes', 'saclay_e'], 27 | "test": ['brasilia', 'chongqing', 'dubai', 'lasvegas', 'milano', 'montpellier', 'norcia', 'rio', 'saclay_w', 'valencia']} 28 | self.splits["all"] = self.splits["train"] + self.splits["test"] 29 | 30 | self.split = split 31 | self.names = natsorted(self.splits[self.split]) # get list of names of ROI 32 | self.root_dir = root 33 | self.run_on_onera = run_on_onera_patches 34 | self.normalize = normalize # whether to z-score S1 & S2 or not, only do on whole images (not on patches) 35 | 36 | # settings pertaining to offline or online slicing 37 | self.use_pre_sliced = use_pre_sliced 38 | self.patch_indices = [] 39 | self.patch_size = 96 40 | self.stride = 1 # int(self.patch_size/2) - 1 41 | 42 | # 0-RGB | 1-RGBIr | 2-All bands s.t. resolution <= 20m | 3-All bands | 4-All bands Sentinel-2 & Sentinel-1 | 5-RGB bands Sentinel-2 & Sentinel-1 43 | 44 | if s2_channel_type in [0, 5]: 45 | self.s2_channels = 1 + np.arange(1, 4) 46 | elif s2_channel_type == 1: 47 | self.s2_channels = 1 + np.array([1, 2, 3, 7]) 48 | elif s2_channel_type == 2: 49 | self.s2_channels = 1 + np.array([1, 2, 3, 4, 5, 6, 7, 8, 11, 12]) 50 | elif s2_channel_type in [3, 4]: 51 | self.s2_channels = 1 + np.arange(0, 13) 52 | 53 | # keep track of changed pixels and total pixels in samples (of the change map labels) 54 | self.true_pix = 0 55 | self.n_pix = 0 56 | 57 | if not self.use_pre_sliced: 58 | # load full-scene images into memory 59 | self.full_scenes = dict() 60 | for roi in tqdm(self.names): 61 | # get the full-scene images, note: get_img_np_instead_of_tensor is already doing whole-image preprocessing 62 | self.full_scenes[roi] = self.get_img_np_instead_of_tensor(roi) 63 | self.true_pix += np.count_nonzero(self.full_scenes[roi]['label']) 64 | self.n_pix += self.full_scenes[roi]['label'].size 65 | self.full_scenes[roi]['patches'] = get_window_offsets(self.full_scenes[roi]['time_1']['S2'].shape[1:], self.patch_size, self.stride) 66 | # pre-compute lookup tables of patches, to be used with online slicing 67 | self.patch_indices = np.cumsum([0]+[len(self.full_scenes[roi]['patches']) for roi in self.names]) 68 | self.indices_dict = {} 69 | for idx, index in enumerate(self.patch_indices): self.indices_dict[idx] = (self.names + ['end'])[idx] 70 | 71 | if transform: 72 | data_transform = tr.Compose([RandomFlip(), RandomRot()]) 73 | else: 74 | data_transform = None 75 | 76 | # define patch transform for data augmentation 77 | self.transform = data_transform 78 | self.paths = [] if self.split=='test' or not self.use_pre_sliced else self.get_paths() # get a list of paths 79 | self.n_samples = len(self.paths) if self.use_pre_sliced else self.patch_indices[-1] 80 | 81 | # calculate a weighting of pixels 82 | self.weights = [] if self.split=='test' else [fp_modifier * 2 * 
self.true_pix / self.n_pix, 2 * (self.n_pix - self.true_pix) / self.n_pix] 83 | 84 | def get_paths(self): 85 | paths = [] 86 | modalities = ["S1", "S2"] 87 | time_points = [1, 2] 88 | s2_patch_dir = 'S2_patches_original' if self.run_on_onera else 'S2_patches' 89 | 90 | for roi in tqdm(self.splits[self.split]): 91 | path = os.path.join(self.root_dir, f"{modalities[0]}_patches", roi, f"imgs_{time_points[0]}") 92 | 93 | # get all files 94 | s1_1 = [os.path.join(path, f) for f in os.listdir(path) if (os.path.isfile(os.path.join(path, f)) and ".tif" in f)] 95 | # sort via file names according to dates 96 | s1_1 = natsorted(s1_1) 97 | 98 | # get paired files and check for proper directory structure 99 | s1_2 = [f.replace('imgs_1','imgs_2') for f in s1_1 if os.path.isfile(f.replace('imgs_1','imgs_2'))] 100 | s2_1 = [f.replace('S1_patches', s2_patch_dir) for f in s1_1 if os.path.isfile(f.replace('S1_patches', s2_patch_dir))] 101 | s2_2 = [f.replace('imgs_1','imgs_2') for f in s2_1 if os.path.isfile(f.replace('imgs_1','imgs_2'))] 102 | label = [f.replace('S1_patches','masks_patches').replace('imgs_1/','') for f in s1_1 if os.path.isfile(f.replace('S1_patches','masks_patches').replace('imgs_1/',''))] 103 | 104 | assert len(s1_1) == len(s1_2) == len(s2_1) == len(s2_2) == len(label) 105 | 106 | for idx in range(len(s1_1)): 107 | sample = {'time_1': {'S1': s1_1[idx], 'S2': s2_1[idx]}, 108 | 'time_2': {'S1': s1_2[idx], 'S2': s2_2[idx]}, 109 | 'label': label[idx]} 110 | paths.append(sample) 111 | 112 | # keep track of number of changed and total pixels 113 | patch_pix = self.read_img(label[idx], [1]) - 1 114 | self.true_pix += np.sum(patch_pix) 115 | self.n_pix += np.prod(patch_pix.shape) 116 | return paths 117 | 118 | def read_img(self, path_IMG, bands): 119 | tif = rasterio.open(path_IMG) 120 | return tif.read(tuple(bands)).astype(np.float32) 121 | 122 | def rescale(self, img, oldMin, oldMax): 123 | oldRange = oldMax - oldMin 124 | img = (img - oldMin) / oldRange 125 | return img 126 | 127 | def process_MS(self, img): 128 | intensity_min, intensity_max = 0, 10000 # define a reasonable range of MS intensities 129 | img = np.clip(img, intensity_min, intensity_max) # intensity clipping to a global unified MS intensity range 130 | img = self.rescale(img, intensity_min, intensity_max) # project to [0,1], preserve global intensities (across patches) 131 | return img 132 | 133 | def process_SAR(self, img): 134 | dB_min, dB_max = -25, 0 # define a reasonable range of SAR dB 135 | img = np.clip(img, dB_min, dB_max) # intensity clipping to a global unified SAR dB range 136 | img = self.rescale(img, dB_min, dB_max) 137 | return img 138 | 139 | def get_img(self, roi_name): 140 | # get path to full images 141 | containing_split = "Train" if roi_name in self.splits["train"] else "Test" 142 | containing_split = f"Onera Satellite Change Detection dataset - {containing_split} Labels" 143 | 144 | s1_1_path = os.path.join(self.root_dir, "S1", roi_name, "imgs_1", "transformed") 145 | s1_2_path = os.path.join(self.root_dir, "S1", roi_name, "imgs_2", "transformed") 146 | s1_1 = torch.from_numpy(self.process_SAR(self.read_img(os.path.join(s1_1_path, os.listdir(s1_1_path)[0]), [1, 2]))) 147 | s1_2 = torch.from_numpy(self.process_SAR(self.read_img(os.path.join(s1_2_path, os.listdir(s1_2_path)[0]), [1, 2]))) 148 | 149 | if self.run_on_onera: 150 | s2_1_path = os.path.join(self.root_dir, "Onera Satellite Change Detection dataset - Images", roi_name, "imgs_1_rect") 151 | s2_2_path = os.path.join(self.root_dir, "Onera Satellite 
Change Detection dataset - Images", roi_name, "imgs_2_rect") 152 | indices = self.s2_channels - 1 # convert rasterio indices back to np indices 153 | s2_1 = torch.from_numpy(self.process_MS(np.array([self.read_img(os.path.join(s2_1_path, s2_1_file), [1]) for s2_1_file in natsorted(os.listdir(s2_1_path))])[indices, 0, ...])) 154 | s2_2 = torch.from_numpy(self.process_MS(np.array([self.read_img(os.path.join(s2_2_path, s2_2_file), [1]) for s2_2_file in natsorted(os.listdir(s2_2_path))])[indices, 0, ...])) 155 | else: 156 | s2_1_path = os.path.join(self.root_dir, "S2", roi_name, "imgs_1", "transformed") 157 | s2_2_path = os.path.join(self.root_dir, "S2", roi_name, "imgs_2", "transformed") 158 | s2_1 = torch.from_numpy(self.process_MS(self.read_img(os.path.join(s2_1_path, os.listdir(s2_1_path)[0]), self.s2_channels))) 159 | s2_2 = torch.from_numpy(self.process_MS(self.read_img(os.path.join(s2_2_path, os.listdir(s2_2_path)[0]), self.s2_channels))) 160 | if self.normalize: 161 | # z-standardize the whole images (after already doing preprocessing, this may matter for the non-linear SAR transforms) 162 | s2_1 = (s2_1 - s2_1.mean()) / s2_1.std() 163 | s2_2 = (s2_2 - s2_2.mean()) / s2_2.std() 164 | 165 | mask_path = os.path.join(self.root_dir, containing_split, roi_name, "cm", f"{roi_name}-cm.tif") 166 | label = self.read_img(mask_path, [1])[0] - 1 167 | 168 | imgs = {'time_1': {'S1': s1_1, 'S2': s2_1}, 169 | 'time_2': {'S1': s1_2, 'S2': s2_2}, 170 | 'label': label} 171 | return imgs 172 | 173 | def get_img_np_instead_of_tensor(self, roi_name): 174 | # get path to full images 175 | 176 | containing_split = "Train" if roi_name in self.splits["train"] else "Test" 177 | containing_split = f"Onera Satellite Change Detection dataset - {containing_split} Labels" 178 | 179 | s1_1_path = os.path.join(self.root_dir, "S1", roi_name, "imgs_1", "transformed") 180 | s1_2_path = os.path.join(self.root_dir, "S1", roi_name, "imgs_2", "transformed") 181 | s1_1 = self.process_SAR(self.read_img(os.path.join(s1_1_path, os.listdir(s1_1_path)[0]), [1, 2])) 182 | s1_2 = self.process_SAR(self.read_img(os.path.join(s1_2_path, os.listdir(s1_2_path)[0]), [1, 2])) 183 | 184 | if self.run_on_onera: 185 | s2_1_path = os.path.join(self.root_dir, "Onera Satellite Change Detection dataset - Images", roi_name, "imgs_1_rect") 186 | s2_2_path = os.path.join(self.root_dir, "Onera Satellite Change Detection dataset - Images", roi_name, "imgs_2_rect") 187 | indices = self.s2_channels - 1 # convert rasterio indices back to np indices 188 | s2_1 = self.process_MS(np.array([self.read_img(os.path.join(s2_1_path, s2_1_file), [1]) for s2_1_file in natsorted(os.listdir(s2_1_path))])[indices, 0, ...]) 189 | s2_2 = self.process_MS(np.array([self.read_img(os.path.join(s2_2_path, s2_2_file), [1]) for s2_2_file in natsorted(os.listdir(s2_2_path))])[indices, 0, ...]) 190 | else: 191 | s2_1_path = os.path.join(self.root_dir, "S2", roi_name, "imgs_1", "transformed") 192 | s2_2_path = os.path.join(self.root_dir, "S2", roi_name, "imgs_2", "transformed") 193 | s2_1 = self.process_MS(self.read_img(os.path.join(s2_1_path, os.listdir(s2_1_path)[0]), self.s2_channels)) 194 | s2_2 = self.process_MS(self.read_img(os.path.join(s2_2_path, os.listdir(s2_2_path)[0]), self.s2_channels)) 195 | if self.normalize: 196 | # z-standardize the whole images (after already doing preprocessing, this may matter for the non-linear SAR transforms) 197 | s2_1 = (s2_1 - s2_1.mean()) / s2_1.std() 198 | s2_2 = (s2_2 - s2_2.mean()) / s2_2.std() 199 | 200 | mask_path = 
os.path.join(self.root_dir, containing_split, roi_name, "cm", f"{roi_name}-cm.tif") 201 | label = self.read_img(mask_path, [1])[0] - 1 202 | 203 | imgs = {'time_1': {'S1': s1_1, 'S2': s2_1}, 204 | 'time_2': {'S1': s1_2, 'S2': s2_2}, 205 | 'label': label} 206 | return imgs 207 | 208 | def __getitem__(self, idx): 209 | 210 | if self.use_pre_sliced: 211 | # use patches sliced before (offline) by pre-processing script 212 | s1_1 = self.process_SAR(self.read_img(self.paths[idx]['time_1']['S1'], [1, 2])) 213 | s2_1 = self.process_MS(self.read_img(self.paths[idx]['time_1']['S2'], self.s2_channels)) 214 | s1_2 = self.process_SAR(self.read_img(self.paths[idx]['time_2']['S1'], [1, 2])) 215 | s2_2 = self.process_MS(self.read_img(self.paths[idx]['time_2']['S2'], self.s2_channels)) 216 | label= self.read_img(self.paths[idx]['label'], [1]) - 1 217 | 218 | sample = {'time_1': {'S1': s1_1, 'S2': s2_1}, 219 | 'time_2': {'S1': s1_2, 'S2': s2_2}, 220 | 'label': label, 221 | 'idx': idx, 222 | } 223 | else: 224 | # load full-scene images and slice into patches online 225 | # (mirror-pad) and slice into patches 226 | first_idx = np.where(self.patch_indices>idx)[0][0]-1 227 | self.indices_dict[first_idx] 228 | self.full_scenes[self.indices_dict[first_idx]]['patches'] 229 | # map the queried index to the current ROI's patch slice anchors (top left corner) 230 | patch_idx = self.full_scenes[self.indices_dict[first_idx]]['patches'][idx-self.patch_indices[first_idx]] 231 | # read the actual patch, given the full-scene image and the patch indices 232 | s1_1 = self.full_scenes[list(self.full_scenes.keys())[first_idx]]['time_1']['S1'][:, patch_idx[0]:(patch_idx[0]+self.patch_size), patch_idx[1]:(patch_idx[1]+self.patch_size)] 233 | s1_2 = self.full_scenes[list(self.full_scenes.keys())[first_idx]]['time_2']['S1'][:, patch_idx[0]:(patch_idx[0]+self.patch_size), patch_idx[1]:(patch_idx[1]+self.patch_size)] 234 | s2_1 = self.full_scenes[list(self.full_scenes.keys())[first_idx]]['time_1']['S2'][:, patch_idx[0]:(patch_idx[0]+self.patch_size), patch_idx[1]:(patch_idx[1]+self.patch_size)] 235 | s2_2 = self.full_scenes[list(self.full_scenes.keys())[first_idx]]['time_2']['S2'][:, patch_idx[0]:(patch_idx[0]+self.patch_size), patch_idx[1]:(patch_idx[1]+self.patch_size)] 236 | label= self.full_scenes[list(self.full_scenes.keys())[first_idx]]['label'][patch_idx[0]:(patch_idx[0]+self.patch_size), patch_idx[1]:(patch_idx[1]+self.patch_size)][None] 237 | 238 | # note: no preprocessing done here for online slicing as we already do preprocess when loading the full scene, on a whole-image basis 239 | sample = {'time_1': {'S1': s1_1, 'S2': s2_1}, 240 | 'time_2': {'S1': s1_2, 'S2': s2_2}, 241 | 'label': label, 242 | 'idx': idx, 243 | } 244 | # apply data augmentation 245 | if self.transform: sample = self.transform(sample) 246 | 247 | return sample 248 | 249 | def __len__(self): 250 | # length of generated list 251 | return self.n_samples 252 | 253 | 254 | class RandomFlip(object): 255 | """Flip randomly the images in a sample, right to left side.""" 256 | 257 | def __call__(self, sample): 258 | I1, I2, I1_b, I2_b, label = sample['time_1']['S2'], sample['time_2']['S2'], sample['time_1']['S1'], sample['time_2']['S1'], sample['label'] 259 | 260 | if random.random() > 0.5: 261 | I1 = I1[:, :, ::-1].copy() 262 | I2 = I2[:, :, ::-1].copy() 263 | I1_b = I1_b[:, :, ::-1].copy() 264 | I2_b = I2_b[:, :, ::-1].copy() 265 | label = label[:, :, ::-1].copy() 266 | sample = {'time_1': {'S1': I1_b, 'S2': I1}, 267 | 'time_2': {'S1': I2_b, 'S2': I2}, 268 | 
'label': label, 269 | 'idx': sample['idx'], 270 | } 271 | return sample 272 | 273 | 274 | class RandomRot(object): 275 | """Rotate randomly the images in a sample.""" 276 | 277 | def __call__(self, sample): 278 | I1, I2, I1_b, I2_b, label = sample['time_1']['S2'], sample['time_2']['S2'], sample['time_1']['S1'], sample['time_2']['S1'], sample['label'] 279 | 280 | n = random.randint(0, 3) 281 | if n: 282 | I1 = I1 283 | I1 = np.rot90(I1, n, axes=(1, 2)).copy() 284 | I2 = I2 285 | I2 = np.rot90(I2, n, axes=(1, 2)).copy() 286 | I1_b = I1_b 287 | I1_b = np.rot90(I1_b, n, axes=(1, 2)).copy() 288 | I2_b = I2_b 289 | I2_b = np.rot90(I2_b, n, axes=(1, 2)).copy() 290 | label = sample['label'] 291 | label = np.rot90(label, n, axes=(1, 2)).copy() 292 | sample = {'time_1': {'S1': I1_b, 'S2': I1}, 293 | 'time_2': {'S1': I2_b, 'S2': I2}, 294 | 'label': label, 295 | 'idx': sample['idx'], 296 | } 297 | return sample 298 | -------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | # PyTorch 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | # Data 7 | from data_loader import multiCD 8 | 9 | # Models 10 | from aux.fully_convolutional_change_detection.unet import Unet 11 | from aux.fully_convolutional_change_detection.siamunet_conc import SiamUnet_conc, SiamUnet_conc_multi 12 | from aux.fully_convolutional_change_detection.siamunet_diff import SiamUnet_diff 13 | from aux.fully_convolutional_change_detection.fresunet import FresUNet 14 | 15 | # Other 16 | import os 17 | import numpy as np 18 | import random 19 | from skimage import io 20 | from scipy.ndimage import zoom 21 | import matplotlib.pyplot as plt 22 | from tqdm import tqdm as tqdm 23 | from pandas import read_csv 24 | from math import floor, ceil, sqrt, exp 25 | from IPython import display 26 | import time 27 | from itertools import chain 28 | import time 29 | import warnings 30 | from pprint import pprint 31 | 32 | # Global Variables' Definitions 33 | PATH_TO_DATASET = "/mnt/data/ONERA_s1_s2/" 34 | FP_MODIFIER = 10 # Tuning parameter, use 1 if unsure 35 | BATCH_SIZE = 32 # number of elements in a batch 36 | NUM_THREADS = 6 # number of parallel threads in data loader 37 | NET = 'SiamUnet_conc' # 'Unet', 'SiamUnet_conc-simple', 'SiamUnet_conc', 'SiamUnet_diff', 'FresUNet' 38 | N_EPOCHS = 50 # number of epochs to train the network 39 | TYPE = 4 # type of input to the network: 0-RGB | 1-RGBIr | 2-All bands s.t. resulution <= 20m | 3-All bands | 4-All bands Sentinel-2 & Sentinel-1 | 5-RGB bands Sentinel-2 & Sentinel-1 40 | LOAD_TRAINED = False # whether to load a pre-trained model or train the network 41 | DATA_AUG = True # whether to apply data augmentation (mirroring, rotating) or not 42 | ONERA_PATCHES = True # whether to train on patches sliced on the original Onera images or not 43 | NORMALISE_IMGS = True # z-standardizing on full-image basis, note: only implemented for online slicing! 
44 | PRE_SLICED = False # whether to use pre-sliced patches (processed offline) for training or do online-slicing instead 45 | 46 | L = 1024 47 | N = 2 48 | 49 | # not applicable for our data loader 50 | # PATCH_SIDE = 96 51 | # TRAIN_STRIDE = int(PATCH_SIDE/2) - 1 52 | 53 | augm_str = 'Augm' if DATA_AUG else 'noAugm' 54 | save_path = os.path.join('/mnt/results/multimodal_CD_ISPRS', 'results', f'{NET}-type-{TYPE}-epochs-{N_EPOCHS}-{augm_str}') 55 | 56 | if not os.path.exists(save_path): 57 | os.makedirs(save_path) 58 | os.makedirs(os.path.join(save_path, 'plots')) 59 | os.makedirs(os.path.join(save_path, 'checkpoints')) 60 | 61 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # match IDs of nvidia-smi 62 | os.environ["CUDA_VISIBLE_DEVICES"] = str(0) # set only passed devices visible 63 | 64 | 65 | def count_parameters(model): 66 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 67 | 68 | 69 | def train(net, train_loader, train_dataset, test_dataset, n_epochs=N_EPOCHS, save=True): 70 | t = np.linspace(1, n_epochs, n_epochs) 71 | 72 | epoch_train_loss = 0 * t 73 | epoch_train_accuracy = 0 * t 74 | epoch_train_change_accuracy = 0 * t 75 | epoch_train_nochange_accuracy = 0 * t 76 | epoch_train_precision = 0 * t 77 | epoch_train_recall = 0 * t 78 | epoch_train_Fmeasure = 0 * t 79 | epoch_test_loss = 0 * t 80 | epoch_test_accuracy = 0 * t 81 | epoch_test_change_accuracy = 0 * t 82 | epoch_test_nochange_accuracy = 0 * t 83 | epoch_test_precision = 0 * t 84 | epoch_test_recall = 0 * t 85 | epoch_test_Fmeasure = 0 * t 86 | 87 | # mean_acc = 0 88 | # best_mean_acc = 0 89 | fm = 0 90 | best_fm = 0 91 | 92 | lss = 1000 93 | best_lss = 1000 94 | 95 | plt.figure(num=1) 96 | plt.figure(num=2) 97 | plt.figure(num=3) 98 | 99 | optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4) 100 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95) 101 | 102 | # train the network for a given number of epochs 103 | for epoch_index in tqdm(range(n_epochs)): 104 | net.train() 105 | print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(N_EPOCHS)) 106 | 107 | tot_count = 0 108 | tot_loss = 0 109 | tot_accurate = 0 110 | class_correct = list(0. for i in range(2)) 111 | class_total = list(0. 
for i in range(2)) 112 | 113 | # iterate over batch 114 | for batch in train_loader: 115 | # inserted multi-modality here 116 | S2_1 = Variable(batch['time_1']['S2'].float().cuda()) 117 | S2_2 = Variable(batch['time_2']['S2'].float().cuda()) 118 | if TYPE in [4, 5]: 119 | S1_1 = Variable(batch['time_1']['S1'].float().cuda()) 120 | S1_2 = Variable(batch['time_2']['S1'].float().cuda()) 121 | label = torch.squeeze(Variable(batch['label'].cuda())) 122 | 123 | # get predictions, compute losses and optimize network 124 | optimizer.zero_grad() 125 | # outputs of the network are [N x 2 x H x W] 126 | # label is of shape [32, 96, 96] 127 | if TYPE in [4, 5]: 128 | output = net(S2_1, S2_2, S1_1, S1_2) 129 | else: 130 | output = net(S2_1, S2_2) 131 | loss = criterion(output, label.long()) 132 | loss.backward() 133 | optimizer.step() 134 | 135 | # step in lr scheduler 136 | scheduler.step() 137 | 138 | # evaluate network statistics on train split and keep track 139 | epoch_train_loss[epoch_index], epoch_train_accuracy[epoch_index], cl_acc, pr_rec = test(train_dataset) 140 | epoch_train_nochange_accuracy[epoch_index] = cl_acc[0] 141 | epoch_train_change_accuracy[epoch_index] = cl_acc[1] 142 | epoch_train_precision[epoch_index] = pr_rec[0] 143 | epoch_train_recall[epoch_index] = pr_rec[1] 144 | epoch_train_Fmeasure[epoch_index] = pr_rec[2] 145 | 146 | # evaluate network statistics on test split and keep track 147 | epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = test(test_dataset) 148 | epoch_test_nochange_accuracy[epoch_index] = cl_acc[0] 149 | epoch_test_change_accuracy[epoch_index] = cl_acc[1] 150 | epoch_test_precision[epoch_index] = pr_rec[0] 151 | epoch_test_recall[epoch_index] = pr_rec[1] 152 | epoch_test_Fmeasure[epoch_index] = pr_rec[2] 153 | 154 | print(f'Test F1 in epoch {epoch_index}: {epoch_test_Fmeasure[epoch_index]}') 155 | plt.figure(num=1) 156 | plt.clf() 157 | l1_1, = plt.plot(t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1], label='Train loss') 158 | l1_2, = plt.plot(t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1], label='Test loss') 159 | plt.legend(handles=[l1_1, l1_2]) 160 | plt.grid() 161 | plt.gcf().gca().set_xlim(left=0) 162 | plt.title('Loss') 163 | display.clear_output(wait=True) 164 | display.display(plt.gcf()) 165 | 166 | plt.figure(num=2) 167 | plt.clf() 168 | l2_1, = plt.plot(t[:epoch_index + 1], epoch_train_accuracy[:epoch_index + 1], label='Train accuracy') 169 | l2_2, = plt.plot(t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1], label='Test accuracy') 170 | plt.legend(handles=[l2_1, l2_2]) 171 | plt.grid() 172 | plt.gcf().gca().set_ylim(0, 100) 173 | plt.title('Accuracy') 174 | display.clear_output(wait=True) 175 | display.display(plt.gcf()) 176 | 177 | plt.figure(num=3) 178 | plt.clf() 179 | l3_1, = plt.plot(t[:epoch_index + 1], epoch_train_nochange_accuracy[:epoch_index + 1], 180 | label='Train accuracy: no change') 181 | l3_2, = plt.plot(t[:epoch_index + 1], epoch_train_change_accuracy[:epoch_index + 1], 182 | label='Train accuracy: change') 183 | l3_3, = plt.plot(t[:epoch_index + 1], epoch_test_nochange_accuracy[:epoch_index + 1], 184 | label='Test accuracy: no change') 185 | l3_4, = plt.plot(t[:epoch_index + 1], epoch_test_change_accuracy[:epoch_index + 1], 186 | label='Test accuracy: change') 187 | plt.legend(handles=[l3_1, l3_2, l3_3, l3_4]) 188 | plt.grid() 189 | plt.gcf().gca().set_ylim(0, 100) 190 | plt.title('Accuracy per class') 191 | display.clear_output(wait=True) 192 | display.display(plt.gcf()) 193 
| 194 | plt.figure(num=4) 195 | plt.clf() 196 | l4_1, = plt.plot(t[:epoch_index + 1], epoch_train_precision[:epoch_index + 1], label='Train precision') 197 | l4_2, = plt.plot(t[:epoch_index + 1], epoch_train_recall[:epoch_index + 1], label='Train recall') 198 | l4_3, = plt.plot(t[:epoch_index + 1], epoch_train_Fmeasure[:epoch_index + 1], label='Train Dice/F1') 199 | l4_4, = plt.plot(t[:epoch_index + 1], epoch_test_precision[:epoch_index + 1], label='Test precision') 200 | l4_5, = plt.plot(t[:epoch_index + 1], epoch_test_recall[:epoch_index + 1], label='Test recall') 201 | l4_6, = plt.plot(t[:epoch_index + 1], epoch_test_Fmeasure[:epoch_index + 1], label='Test Dice/F1') 202 | plt.legend(handles=[l4_1, l4_2, l4_3, l4_4, l4_5, l4_6]) 203 | plt.grid() 204 | plt.gcf().gca().set_ylim(0, 1) 205 | plt.title('Precision, Recall and F-measure') 206 | display.clear_output(wait=True) 207 | display.display(plt.gcf()) 208 | 209 | fm = epoch_train_Fmeasure[epoch_index] 210 | if fm > best_fm: 211 | best_fm = fm 212 | save_str = os.path.join(save_path, 'checkpoints', 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar') 213 | torch.save(net.state_dict(), save_str) 214 | 215 | lss = epoch_train_loss[epoch_index] 216 | if lss < best_lss: 217 | best_lss = lss 218 | save_str = os.path.join(save_path, 'checkpoints', 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar') 219 | torch.save(net.state_dict(), save_str) 220 | 221 | if save: 222 | im_format = 'png' 223 | plt.figure(num=1) 224 | plt.savefig(os.path.join(save_path, 'plots', net_name + '-01-loss.' + im_format)) 225 | plt.figure(num=2) 226 | plt.savefig(os.path.join(save_path, 'plots', net_name + '-02-accuracy.' + im_format)) 227 | plt.figure(num=3) 228 | plt.savefig(os.path.join(save_path, 'plots', net_name + '-03-accuracy-per-class.' + im_format)) 229 | plt.figure(num=4) 230 | plt.savefig(os.path.join(save_path, 'plots', net_name + '-04-prec-rec-fmeas.' + im_format)) 231 | 232 | out = {'train_loss': epoch_train_loss[-1], 233 | 'train_accuracy': epoch_train_accuracy[-1], 234 | 'train_nochange_accuracy': epoch_train_nochange_accuracy[-1], 235 | 'train_change_accuracy': epoch_train_change_accuracy[-1], 236 | 'test_loss': epoch_test_loss[-1], 237 | 'test_accuracy': epoch_test_accuracy[-1], 238 | 'test_nochange_accuracy': epoch_test_nochange_accuracy[-1], 239 | 'test_change_accuracy': epoch_test_change_accuracy[-1]} 240 | 241 | print('pr_c, rec_c, f_meas, pr_nc, rec_nc') 242 | print(pr_rec) 243 | 244 | return out 245 | 246 | # run network on full-scene ROI and evaluate performance 247 | def test(dset): 248 | net.eval() 249 | tot_loss = 0 250 | tot_count = 0 251 | 252 | n = 2 253 | class_correct = list(0. for i in range(n)) 254 | class_total = list(0. for i in range(n)) 255 | class_accuracy = list(0. 
for i in range(n)) 256 | 257 | tp = 0 258 | tn = 0 259 | fp = 0 260 | fn = 0 261 | 262 | # iterate over all ROI, load modalities and 263 | for img_index in dset.names: 264 | print(f"Testing for ROI {img_index}") 265 | # inserted multi-modality here 266 | full_imgs = dset.get_img(img_index) 267 | S2_1_full, S2_2_full, cm_full = full_imgs['time_1']['S2'], full_imgs['time_2']['S2'], full_imgs['label'] 268 | s = cm_full.shape 269 | 270 | if TYPE in [4, 5]: 271 | S1_1_full, S1_2_full = full_imgs['time_1']['S1'], full_imgs['time_2']['S1'] 272 | 273 | steps0 = np.arange(0, s[0], ceil(s[0] / N)) 274 | steps1 = np.arange(0, s[1], ceil(s[1] / N)) 275 | for ii in range(N): 276 | for jj in range(N): 277 | xmin = steps0[ii] 278 | if ii == N - 1: 279 | xmax = s[0] 280 | else: 281 | xmax = steps0[ii + 1] 282 | ymin = jj 283 | if jj == N - 1: 284 | ymax = s[1] 285 | else: 286 | ymax = steps1[jj + 1] 287 | # inserted multi-modality here 288 | S2_1 = S2_1_full[:, xmin:xmax, ymin:ymax] 289 | S2_2 = S2_2_full[:, xmin:xmax, ymin:ymax] 290 | cm = cm_full[xmin:xmax, ymin:ymax] 291 | 292 | S2_1 = Variable(torch.unsqueeze(S2_1, 0).float()).cuda() 293 | S2_2 = Variable(torch.unsqueeze(S2_2, 0).float()).cuda() 294 | cm = Variable(torch.unsqueeze(torch.from_numpy(1.0 * cm), 0).float()).cuda() 295 | 296 | if TYPE in [4, 5]: 297 | S1_1 = S1_1_full[:, xmin:xmax, ymin:ymax] 298 | S1_2 = S1_2_full[:, xmin:xmax, ymin:ymax] 299 | S1_1 = Variable(torch.unsqueeze(S1_1, 0).float()).cuda() 300 | S1_2 = Variable(torch.unsqueeze(S1_2, 0).float()).cuda() 301 | 302 | # predict output via network and compute losses 303 | # outputs of the network are [N x 2 x H x W] 304 | if TYPE in [4, 5]: 305 | output = net(S2_1, S2_2, S1_1, S1_2) 306 | else: 307 | output = net(S2_1, S2_2) 308 | loss = criterion(output, cm.long()) 309 | tot_loss += loss.data * np.prod(cm.size()) 310 | tot_count += np.prod(cm.size()) 311 | 312 | _, predicted = torch.max(output.data, 1) 313 | 314 | # compare predictions with change maps and count correct predictions 315 | c = (predicted.int() == cm.data.int()) 316 | for i in range(c.size(1)): 317 | for j in range(c.size(2)): 318 | l = int(cm.data[0, i, j]) 319 | class_correct[l] += c[0, i, j] 320 | class_total[l] += 1 321 | 322 | pr = (predicted.int() > 0).cpu().numpy() 323 | gt = (cm.data.int() > 0).cpu().numpy() 324 | 325 | # evaluate TP, TN, FP & FN 326 | tp += np.logical_and(pr, gt).sum() 327 | tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum() 328 | fp += np.logical_and(pr, np.logical_not(gt)).sum() 329 | fn += np.logical_and(np.logical_not(pr), gt).sum() 330 | 331 | # compute goodness of predictions 332 | net_loss = tot_loss / tot_count 333 | net_accuracy = 100 * (tp + tn) / tot_count 334 | 335 | for i in range(n): # compute classwise accuracies 336 | class_accuracy[i] = 100 * class_correct[i] / max(class_total[i], 0.00001) 337 | 338 | # get precision, recall etc 339 | prec = tp / (tp + fp) 340 | rec = tp / (tp + fn) 341 | f_meas = 2 * prec * rec / (prec + rec) 342 | prec_nc = tn / (tn + fn) 343 | rec_nc = tn / (tn + fp) 344 | pr_rec = [prec, rec, f_meas, prec_nc, rec_nc] 345 | 346 | return net_loss, net_accuracy, class_accuracy, pr_rec 347 | 348 | # run predictions for a given network and data set 349 | # and save all predictions as png 350 | def save_test_results(net, dset): 351 | for name in tqdm(dset.names): 352 | print(f"Saving prediction on ROI {name}") 353 | with warnings.catch_warnings(): 354 | full_imgs = dset.get_img(name) 355 | # inserted multi-modality here 356 | S2_1, S2_2, cm = 


# run predictions for a given network and data set
# and save all predictions as png
def save_test_results(net, dset):
    for name in tqdm(dset.names):
        print(f"Saving prediction on ROI {name}")
        with warnings.catch_warnings():
            full_imgs = dset.get_img(name)
            # inserted multi-modality here
            S2_1, S2_2, cm = full_imgs['time_1']['S2'], full_imgs['time_2']['S2'], full_imgs['label']
            S2_1 = Variable(torch.unsqueeze(S2_1, 0).float()).cuda()
            S2_2 = Variable(torch.unsqueeze(S2_2, 0).float()).cuda()
            if TYPE in [4, 5]:
                S1_1 = full_imgs['time_1']['S1']
                S1_2 = full_imgs['time_2']['S1']
                S1_1 = Variable(torch.unsqueeze(S1_1, 0).float()).cuda()
                S1_2 = Variable(torch.unsqueeze(S1_2, 0).float()).cuda()
                out = net(S2_1, S2_2, S1_1, S1_2)
            else:
                out = net(S2_1, S2_2)
            _, predicted = torch.max(out.data, 1)
            # save plot of difference maps: ground truth fills the red & blue channels, the prediction the green channel
            I = np.stack((255 * cm, 255 * np.squeeze(predicted.cpu().numpy()), 255 * cm), 2)
            io.imsave(os.path.join(save_path, 'plots', f'{net_name}-{name}.png'), I)


if __name__ == '__main__':

    # initialize data set instances for train and test splits
    data_loader_train = multiCD(PATH_TO_DATASET, split="train", transform=DATA_AUG, run_on_onera_patches=ONERA_PATCHES, use_pre_sliced=PRE_SLICED, normalize=NORMALISE_IMGS)
    train_loader = torch.utils.data.DataLoader(data_loader_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_THREADS)
    data_loader_test = multiCD(PATH_TO_DATASET, split="test", run_on_onera_patches=ONERA_PATCHES, normalize=NORMALISE_IMGS)
    # note: the test loader is never used; testing is always done on full-scene images (not on the patches)
    # test_loader = torch.utils.data.DataLoader(data_loader_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_THREADS)

    # get train split weighting of pixel labels
    weights = torch.FloatTensor(data_loader_train.weights).cuda()
    print(f"Train data set weighting is: {weights}")
    print(f"Total pixel numbers: {data_loader_train.n_pix}")
    print(f"Changed pixel numbers: {data_loader_train.true_pix}")
    print(f"Change-to-total ratio: {data_loader_train.true_pix / data_loader_train.n_pix}")
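
    # Editor's note, not part of the original script: OSCD-style change maps are heavily
    # imbalanced (changed pixels are rare), so the per-class weights computed by the data
    # loader are handed to nn.NLLLoss further below to keep the rare "change" class from
    # being drowned out during training. A hypothetical sanity check for the two-class setup:
    #
    #   assert weights.numel() == 2, "expected one weight for 'no change' and one for 'change'"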

    # 0-RGB | 1-RGBIr | 2-All bands s.t. resolution <= 20m | 3-All bands | 4-All bands Sentinel-2 & Sentinel-1 | 5-RGB bands Sentinel-2 & Sentinel-1

    if TYPE == 0:
        if NET == 'Unet': net, net_name = Unet(2*3, 2), 'FC-EF'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc(3, 2), 'FC-Siam-conc'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff(3, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet(2*3, 2), 'FresUNet'
    elif TYPE == 1:
        if NET == 'Unet': net, net_name = Unet(2*4, 2), 'FC-EF'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc(4, 2), 'FC-Siam-conc'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff(4, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet(2*4, 2), 'FresUNet'
    elif TYPE == 2:
        if NET == 'Unet': net, net_name = Unet(2*10, 2), 'FC-EF'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc(10, 2), 'FC-Siam-conc'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff(10, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet(2*10, 2), 'FresUNet'
    elif TYPE == 3:
        if NET == 'Unet': net, net_name = Unet(2*13, 2), 'FC-EF'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc(13, 2), 'FC-Siam-conc'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff(13, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet(2*13, 2), 'FresUNet'
    elif TYPE == 4:
        if NET == 'Unet': net, net_name = Unet(2*13+2*2, 2), 'FC-EF-multi'  # same architecture as the other network
        if NET == 'SiamUnet_conc-simple': net, net_name = SiamUnet_conc(13+2, 2), 'FC-Siam-conc-simple'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc_multi((13, 2), 2), 'FC-Siam-conc-complex'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff_multi(13, 2, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet_multi(2*13 + 2*13, 2), 'FresUNet'
    elif TYPE == 5:
        if NET == 'Unet': net, net_name = Unet(2*3+2*2, 2), 'FC-EF-multi'  # same architecture as the other network
        if NET == 'SiamUnet_conc-simple': net, net_name = SiamUnet_conc(3+2, 2), 'FC-Siam-conc-simple'
        if NET == 'SiamUnet_conc': net, net_name = SiamUnet_conc_multi((3, 2), 2), 'FC-Siam-conc-complex'
        if NET == 'SiamUnet_diff': net, net_name = SiamUnet_diff_multi(3, 2, 2), 'FC-Siam-diff'
        if NET == 'FresUNet': net, net_name = FresUNet_multi(2*3 + 2*3, 2), 'FresUNet'
    net.cuda()

    # define loss: logsoftmax output
    criterion = nn.NLLLoss(weight=weights)
    print('Number of trainable parameters:', count_parameters(net))
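
    # Editor's note, not part of the original script: nn.NLLLoss expects log-probabilities,
    # so this setup presumes the networks end in a LogSoftmax layer (as Daudt et al.'s
    # FC-EF / FC-Siam / FresUNet reference implementations under aux/ do); the single-module
    # alternative would be nn.CrossEntropyLoss on raw logits. count_parameters() is defined
    # earlier in the script, presumably along the lines of this common sketch (assumption):
    #
    #   def count_parameters(model):
    #       return sum(p.numel() for p in model.parameters() if p.requires_grad)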

    # either load a pre-trained model or train a model from scratch
    if LOAD_TRAINED:
        # load e.g. net.load_state_dict(torch.load('net-best_epoch-1_fm-0.7394933126157746.pth.tar'))
        net.load_state_dict(torch.load(os.path.join(save_path, 'checkpoints', 'net_final.pth.tar')))
        print('LOAD OK')
    else:
        t_start = time.time()
        # train the network and, at the end of each epoch,
        # get its performance on train & test split
        out_dic = train(net, train_loader, data_loader_train, data_loader_test)
        t_end = time.time()
        print(out_dic)
        print('Elapsed time:')
        print(t_end - t_start)
        torch.save(net.state_dict(), os.path.join(save_path, 'checkpoints', 'net_final.pth.tar'))
        print('SAVE OK')

    t_start = time.time()
    save_test_results(net, data_loader_test)
    t_end = time.time()
    print('Elapsed time: {}'.format(t_end - t_start))
--------------------------------------------------------------------------------
/previews/preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PatrickTUM/multimodalCD_ISPRS21/0339b43a53f466a92b672bceeb3c183aa783e9a0/previews/preview.png
--------------------------------------------------------------------------------