├── .gitignore ├── README.md ├── fresunet.py ├── fully-convolutional-change-detection.ipynb ├── siamunet_conc.py ├── siamunet_diff.py └── unet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fully_convolutional_change_detection 2 | 3 | Fully convolutional network architectures for change detection using remote sensing images. 4 | 5 | [Rodrigo Caye Daudt, Bertrand Le Saux, Alexandre Boulch. (2018, October). Fully convolutional siamese networks for change detection. In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.](https://ieeexplore.ieee.org/abstract/document/8451652) 6 | 7 | [arXiv](https://arxiv.org/abs/1810.08462) -------------------------------------------------------------------------------- /fresunet.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2019. Multitask learning for large-scale semantic change detection. Computer Vision and Image Understanding, 187, p.102783. 
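# 
# Usage sketch (added for illustration; not part of the original file). FresUNet 
# takes the two co-registered image tensors and returns per-pixel log-probabilities 
# over the no-change / change classes. The shapes below assume 13-band Sentinel-2 
# patches, matching the FresUNet(2*13, 2) instantiation in the training notebook; 
# the tensor names are hypothetical. 
# 
#     import torch 
#     from fresunet import FresUNet 
# 
#     net = FresUNet(input_nbr=2 * 13, label_nbr=2)  # channels of both dates, 2 classes 
#     t1 = torch.randn(1, 13, 96, 96)                # image at date 1 
#     t2 = torch.randn(1, 13, 96, 96)                # image at date 2 
#     log_probs = net(t1, t2)                        # (1, 2, 96, 96), LogSoftmax output 
#     change_map = log_probs.argmax(dim=1)           # per-pixel change prediction 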
4 | 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.modules.padding import ReplicationPad2d 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1) 15 | 16 | 17 | class BasicBlock_ss(nn.Module): 18 | 19 | def __init__(self, inplanes, planes = None, subsamp=1): 20 | super(BasicBlock_ss, self).__init__() 21 | if planes == None: 22 | planes = inplanes * subsamp 23 | self.conv1 = conv3x3(inplanes, planes) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.subsamp = subsamp 29 | self.doit = planes != inplanes 30 | if self.doit: 31 | self.couple = nn.Conv2d(inplanes, planes, kernel_size=1) 32 | self.bnc = nn.BatchNorm2d(planes) 33 | 34 | def forward(self, x): 35 | if self.doit: 36 | residual = self.couple(x) 37 | residual = self.bnc(residual) 38 | else: 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | if self.subsamp > 1: 46 | out = F.max_pool2d(out, kernel_size=self.subsamp, stride=self.subsamp) 47 | residual = F.max_pool2d(residual, kernel_size=self.subsamp, stride=self.subsamp) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | 59 | class BasicBlock_us(nn.Module): 60 | 61 | def __init__(self, inplanes, upsamp=1): 62 | super(BasicBlock_us, self).__init__() 63 | planes = int(inplanes / upsamp) # assumes integer result, fix later 64 | self.conv1 = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = conv3x3(planes, planes) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.upsamp = upsamp 70 | self.couple = nn.ConvTranspose2d(inplanes, planes, kernel_size=3, padding=1, stride=upsamp, output_padding=1) 71 | self.bnc = nn.BatchNorm2d(planes) 72 | 73 | def forward(self, x): 74 | residual = self.couple(x) 75 | residual = self.bnc(residual) 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | out += residual 86 | out = self.relu(out) 87 | 88 | return out 89 | 90 | 91 | class FresUNet(nn.Module): 92 | """FresUNet segmentation network.""" 93 | 94 | def __init__(self, input_nbr, label_nbr): 95 | """Init FresUNet fields.""" 96 | super(FresUNet, self).__init__() 97 | 98 | self.input_nbr = input_nbr 99 | 100 | cur_depth = input_nbr 101 | 102 | base_depth = 8 103 | 104 | # Encoding stage 1 105 | self.encres1_1 = BasicBlock_ss(cur_depth, planes = base_depth) 106 | cur_depth = base_depth 107 | d1 = base_depth 108 | self.encres1_2 = BasicBlock_ss(cur_depth, subsamp=2) 109 | cur_depth *= 2 110 | 111 | # Encoding stage 2 112 | self.encres2_1 = BasicBlock_ss(cur_depth) 113 | d2 = cur_depth 114 | self.encres2_2 = BasicBlock_ss(cur_depth, subsamp=2) 115 | cur_depth *= 2 116 | 117 | # Encoding stage 3 118 | self.encres3_1 = BasicBlock_ss(cur_depth) 119 | d3 = cur_depth 120 | self.encres3_2 = BasicBlock_ss(cur_depth, subsamp=2) 121 | cur_depth *= 2 122 | 123 | # Encoding stage 4 124 | self.encres4_1 = BasicBlock_ss(cur_depth) 125 | d4 = cur_depth 126 | self.encres4_2 = BasicBlock_ss(cur_depth, subsamp=2) 127 | cur_depth *= 2 128 | 129 | # 
Decoding stage 4 130 | self.decres4_1 = BasicBlock_ss(cur_depth) 131 | self.decres4_2 = BasicBlock_us(cur_depth, upsamp=2) 132 | cur_depth = int(cur_depth/2) 133 | 134 | # Decoding stage 3 135 | self.decres3_1 = BasicBlock_ss(cur_depth + d4, planes = cur_depth) 136 | self.decres3_2 = BasicBlock_us(cur_depth, upsamp=2) 137 | cur_depth = int(cur_depth/2) 138 | 139 | # Decoding stage 2 140 | self.decres2_1 = BasicBlock_ss(cur_depth + d3, planes = cur_depth) 141 | self.decres2_2 = BasicBlock_us(cur_depth, upsamp=2) 142 | cur_depth = int(cur_depth/2) 143 | 144 | # Decoding stage 1 145 | self.decres1_1 = BasicBlock_ss(cur_depth + d2, planes = cur_depth) 146 | self.decres1_2 = BasicBlock_us(cur_depth, upsamp=2) 147 | cur_depth = int(cur_depth/2) 148 | 149 | # Output 150 | self.coupling = nn.Conv2d(cur_depth + d1, label_nbr, kernel_size=1) 151 | self.sm = nn.LogSoftmax(dim=1) 152 | 153 | def forward(self, x1, x2): 154 | 155 | x = torch.cat((x1, x2), 1) 156 | 157 | # pad5 = ReplicationPad2d((0, x53.size(3) - x5d.size(3), 0, x53.size(2) - x5d.size(2))) 158 | 159 | s1_1 = x.size() 160 | x1 = self.encres1_1(x) 161 | x = self.encres1_2(x1) 162 | 163 | s2_1 = x.size() 164 | x2 = self.encres2_1(x) 165 | x = self.encres2_2(x2) 166 | 167 | s3_1 = x.size() 168 | x3 = self.encres3_1(x) 169 | x = self.encres3_2(x3) 170 | 171 | s4_1 = x.size() 172 | x4 = self.encres4_1(x) 173 | x = self.encres4_2(x4) 174 | 175 | x = self.decres4_1(x) 176 | x = self.decres4_2(x) 177 | s4_2 = x.size() 178 | pad4 = ReplicationPad2d((0, s4_1[3] - s4_2[3], 0, s4_1[2] - s4_2[2])) 179 | x = pad4(x) 180 | 181 | # x = self.decres3_1(x) 182 | x = self.decres3_1(torch.cat((x, x4), 1)) 183 | x = self.decres3_2(x) 184 | s3_2 = x.size() 185 | pad3 = ReplicationPad2d((0, s3_1[3] - s3_2[3], 0, s3_1[2] - s3_2[2])) 186 | x = pad3(x) 187 | 188 | x = self.decres2_1(torch.cat((x, x3), 1)) 189 | x = self.decres2_2(x) 190 | s2_2 = x.size() 191 | pad2 = ReplicationPad2d((0, s2_1[3] - s2_2[3], 0, s2_1[2] - s2_2[2])) 192 | x = pad2(x) 193 | 194 | x = self.decres1_1(torch.cat((x, x2), 1)) 195 | x = self.decres1_2(x) 196 | s1_2 = x.size() 197 | pad1 = ReplicationPad2d((0, s1_1[3] - s1_2[3], 0, s1_1[2] - s1_2[2])) 198 | x = pad1(x) 199 | 200 | x = self.coupling(torch.cat((x, x1), 1)) 201 | x = self.sm(x) 202 | 203 | return x -------------------------------------------------------------------------------- /fully-convolutional-change-detection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Fully Convolutional Networks for Change Detection\n", 8 | "\n", 9 | "Example code for training the network presented in the paper:\n", 10 | "\n", 11 | "```\n", 12 | "Daudt, R.C., Le Saux, B. and Boulch, A., 2018, October. Fully convolutional siamese networks for change detection. In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.\n", 13 | "```\n", 14 | "\n", 15 | "Code uses the OSCD dataset:\n", 16 | "\n", 17 | "```\n", 18 | "Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2018, July. Urban change detection for multispectral earth observation using convolutional neural networks. In IGARSS 2018-2018 IEEE International Geoscience and Remote Sensing Symposium (pp. 2115-2118). IEEE.\n", 19 | "```\n", 20 | "\n", 21 | "\n", 22 | "FresUNet architecture from paper:\n", 23 | "\n", 24 | "```\n", 25 | "Daudt, R.C., Le Saux, B., Boulch, A. and Gousseau, Y., 2019. 
Multitask learning for large-scale semantic change detection. Computer Vision and Image Understanding, 187, p.102783.\n", 26 | "```\n", 27 | "\n", 28 | "Please consider all relevant papers if you use this code." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Rodrigo Daudt\n", 38 | "# rcdaudt.github.io\n", 39 | "# rodrigo.daudt@onera.fr" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "# Imports\n", 49 | "\n", 50 | "# PyTorch\n", 51 | "import torch\n", 52 | "import torch.nn as nn\n", 53 | "from torch.utils.data import Dataset, DataLoader\n", 54 | "from torch.autograd import Variable\n", 55 | "import torchvision.transforms as tr\n", 56 | "\n", 57 | "# Models\n", 58 | "from unet import Unet\n", 59 | "from siamunet_conc import SiamUnet_conc\n", 60 | "from siamunet_diff import SiamUnet_diff\n", 61 | "from fresunet import FresUNet\n", 62 | "\n", 63 | "# Other\n", 64 | "import os\n", 65 | "import numpy as np\n", 66 | "import random\n", 67 | "from skimage import io\n", 68 | "from scipy.ndimage import zoom\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "%matplotlib inline\n", 71 | "from tqdm import tqdm as tqdm\n", 72 | "from pandas import read_csv\n", 73 | "from math import floor, ceil, sqrt, exp\n", 74 | "from IPython import display\n", 75 | "import time\n", 76 | "from itertools import chain\n", 77 | "import time\n", 78 | "import warnings\n", 79 | "from pprint import pprint\n", 80 | "\n", 81 | "\n", 82 | "\n", 83 | "print('IMPORTS OK')\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# Global Variables' Definitions\n", 100 | "\n", 101 | "PATH_TO_DATASET = './OSCD/'\n", 102 | "IS_PROTOTYPE = False\n", 103 | "\n", 104 | "FP_MODIFIER = 10 # Tuning parameter, use 1 if unsure\n", 105 | "\n", 106 | "BATCH_SIZE = 32\n", 107 | "PATCH_SIDE = 96\n", 108 | "N_EPOCHS = 50\n", 109 | "\n", 110 | "NORMALISE_IMGS = True\n", 111 | "\n", 112 | "TRAIN_STRIDE = int(PATCH_SIDE/2) - 1\n", 113 | "\n", 114 | "TYPE = 3 # 0-RGB | 1-RGBIr | 2-All bands s.t. 
resolution <= 20m | 3-All bands\n", 115 | "\n", 116 | "LOAD_TRAINED = False\n", 117 | "\n", 118 | "DATA_AUG = True\n", 119 | "\n", 120 | "\n", 121 | "print('DEFINITIONS OK')" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "# Functions\n", 131 | "\n", 132 | "def adjust_shape(I, s):\n", 133 | " \"\"\"Adjust shape of grayscale image I to s.\"\"\"\n", 134 | " \n", 135 | " # crop if necessary\n", 136 | " I = I[:s[0],:s[1]]\n", 137 | " si = I.shape\n", 138 | " \n", 139 | " # pad if necessary \n", 140 | " p0 = max(0,s[0] - si[0])\n", 141 | " p1 = max(0,s[1] - si[1])\n", 142 | " \n", 143 | " return np.pad(I,((0,p0),(0,p1)),'edge')\n", 144 | " \n", 145 | "\n", 146 | "def read_sentinel_img(path):\n", 147 | " \"\"\"Read cropped Sentinel-2 image: RGB bands.\"\"\"\n", 148 | " im_name = os.listdir(path)[0][:-7]\n", 149 | " r = io.imread(path + im_name + \"B04.tif\")\n", 150 | " g = io.imread(path + im_name + \"B03.tif\")\n", 151 | " b = io.imread(path + im_name + \"B02.tif\")\n", 152 | " \n", 153 | " I = np.stack((r,g,b),axis=2).astype('float')\n", 154 | " \n", 155 | " if NORMALISE_IMGS:\n", 156 | " I = (I - I.mean()) / I.std()\n", 157 | "\n", 158 | " return I\n", 159 | "\n", 160 | "def read_sentinel_img_4(path):\n", 161 | " \"\"\"Read cropped Sentinel-2 image: RGB and NIR bands.\"\"\"\n", 162 | " im_name = os.listdir(path)[0][:-7]\n", 163 | " r = io.imread(path + im_name + \"B04.tif\")\n", 164 | " g = io.imread(path + im_name + \"B03.tif\")\n", 165 | " b = io.imread(path + im_name + \"B02.tif\")\n", 166 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 167 | " \n", 168 | " I = np.stack((r,g,b,nir),axis=2).astype('float')\n", 169 | " \n", 170 | " if NORMALISE_IMGS:\n", 171 | " I = (I - I.mean()) / I.std()\n", 172 | "\n", 173 | " return I\n", 174 | "\n", 175 | "def read_sentinel_img_leq20(path):\n", 176 | " \"\"\"Read cropped Sentinel-2 image: bands with resolution less than or equal to 20m.\"\"\"\n", 177 | " im_name = os.listdir(path)[0][:-7]\n", 178 | " \n", 179 | " r = io.imread(path + im_name + \"B04.tif\")\n", 180 | " s = r.shape\n", 181 | " g = io.imread(path + im_name + \"B03.tif\")\n", 182 | " b = io.imread(path + im_name + \"B02.tif\")\n", 183 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 184 | " \n", 185 | " ir1 = adjust_shape(zoom(io.imread(path + im_name + \"B05.tif\"),2),s)\n", 186 | " ir2 = adjust_shape(zoom(io.imread(path + im_name + \"B06.tif\"),2),s)\n", 187 | " ir3 = adjust_shape(zoom(io.imread(path + im_name + \"B07.tif\"),2),s)\n", 188 | " nir2 = adjust_shape(zoom(io.imread(path + im_name + \"B8A.tif\"),2),s)\n", 189 | " swir2 = adjust_shape(zoom(io.imread(path + im_name + \"B11.tif\"),2),s)\n", 190 | " swir3 = adjust_shape(zoom(io.imread(path + im_name + \"B12.tif\"),2),s)\n", 191 | " \n", 192 | " I = np.stack((r,g,b,nir,ir1,ir2,ir3,nir2,swir2,swir3),axis=2).astype('float')\n", 193 | " \n", 194 | " if NORMALISE_IMGS:\n", 195 | " I = (I - I.mean()) / I.std()\n", 196 | "\n", 197 | " return I\n", 198 | "\n", 199 | "def read_sentinel_img_leq60(path):\n", 200 | " \"\"\"Read cropped Sentinel-2 image: all bands.\"\"\"\n", 201 | " im_name = os.listdir(path)[0][:-7]\n", 202 | " \n", 203 | " r = io.imread(path + im_name + \"B04.tif\")\n", 204 | " s = r.shape\n", 205 | " g = io.imread(path + im_name + \"B03.tif\")\n", 206 | " b = io.imread(path + im_name + \"B02.tif\")\n", 207 | " nir = io.imread(path + im_name + \"B08.tif\")\n", 208 | " \n", 209 | " ir1 = 
adjust_shape(zoom(io.imread(path + im_name + \"B05.tif\"),2),s)\n", 210 | " ir2 = adjust_shape(zoom(io.imread(path + im_name + \"B06.tif\"),2),s)\n", 211 | " ir3 = adjust_shape(zoom(io.imread(path + im_name + \"B07.tif\"),2),s)\n", 212 | " nir2 = adjust_shape(zoom(io.imread(path + im_name + \"B8A.tif\"),2),s)\n", 213 | " swir2 = adjust_shape(zoom(io.imread(path + im_name + \"B11.tif\"),2),s)\n", 214 | " swir3 = adjust_shape(zoom(io.imread(path + im_name + \"B12.tif\"),2),s)\n", 215 | " \n", 216 | " uv = adjust_shape(zoom(io.imread(path + im_name + \"B01.tif\"),6),s)\n", 217 | " wv = adjust_shape(zoom(io.imread(path + im_name + \"B09.tif\"),6),s)\n", 218 | " swirc = adjust_shape(zoom(io.imread(path + im_name + \"B10.tif\"),6),s)\n", 219 | " \n", 220 | " I = np.stack((r,g,b,nir,ir1,ir2,ir3,nir2,swir2,swir3,uv,wv,swirc),axis=2).astype('float')\n", 221 | " \n", 222 | " if NORMALISE_IMGS:\n", 223 | " I = (I - I.mean()) / I.std()\n", 224 | "\n", 225 | " return I\n", 226 | "\n", 227 | "def read_sentinel_img_trio(path):\n", 228 | " \"\"\"Read cropped Sentinel-2 image pair and change map.\"\"\"\n", 229 | "# read images\n", 230 | " if TYPE == 0:\n", 231 | " I1 = read_sentinel_img(path + '/imgs_1/')\n", 232 | " I2 = read_sentinel_img(path + '/imgs_2/')\n", 233 | " elif TYPE == 1:\n", 234 | " I1 = read_sentinel_img_4(path + '/imgs_1/')\n", 235 | " I2 = read_sentinel_img_4(path + '/imgs_2/')\n", 236 | " elif TYPE == 2:\n", 237 | " I1 = read_sentinel_img_leq20(path + '/imgs_1/')\n", 238 | " I2 = read_sentinel_img_leq20(path + '/imgs_2/')\n", 239 | " elif TYPE == 3:\n", 240 | " I1 = read_sentinel_img_leq60(path + '/imgs_1/')\n", 241 | " I2 = read_sentinel_img_leq60(path + '/imgs_2/')\n", 242 | " \n", 243 | " cm = io.imread(path + '/cm/cm.png', as_gray=True) != 0\n", 244 | " \n", 245 | " # crop if necessary\n", 246 | " s1 = I1.shape\n", 247 | " s2 = I2.shape\n", 248 | " I2 = np.pad(I2,((0, s1[0] - s2[0]), (0, s1[1] - s2[1]), (0,0)),'edge')\n", 249 | " \n", 250 | " \n", 251 | " return I1, I2, cm\n", 252 | "\n", 253 | "\n", 254 | "\n", 255 | "def reshape_for_torch(I):\n", 256 | " \"\"\"Transpose image for PyTorch coordinates.\"\"\"\n", 257 | "# out = np.swapaxes(I,1,2)\n", 258 | "# out = np.swapaxes(out,0,1)\n", 259 | "# out = out[np.newaxis,:]\n", 260 | " out = I.transpose((2, 0, 1))\n", 261 | " return torch.from_numpy(out)\n", 262 | "\n", 263 | "\n", 264 | "\n", 265 | "class ChangeDetectionDataset(Dataset):\n", 266 | " \"\"\"Change Detection dataset class, used for both training and test data.\"\"\"\n", 267 | "\n", 268 | " def __init__(self, path, train = True, patch_side = 96, stride = None, use_all_bands = False, transform=None):\n", 269 | " \"\"\"\n", 270 | " Args:\n", 271 | " csv_file (string): Path to the csv file with annotations.\n", 272 | " root_dir (string): Directory with all the images.\n", 273 | " transform (callable, optional): Optional transform to be applied\n", 274 | " on a sample.\n", 275 | " \"\"\"\n", 276 | " \n", 277 | " # basics\n", 278 | " self.transform = transform\n", 279 | " self.path = path\n", 280 | " self.patch_side = patch_side\n", 281 | " if not stride:\n", 282 | " self.stride = 1\n", 283 | " else:\n", 284 | " self.stride = stride\n", 285 | " \n", 286 | " if train:\n", 287 | " fname = 'train.txt'\n", 288 | " else:\n", 289 | " fname = 'test.txt'\n", 290 | " \n", 291 | "# print(path + fname)\n", 292 | " self.names = read_csv(path + fname).columns\n", 293 | " self.n_imgs = self.names.shape[0]\n", 294 | " \n", 295 | " n_pix = 0\n", 296 | " true_pix = 0\n", 297 | " \n", 298 | " 
\n", 299 | " # load images\n", 300 | " self.imgs_1 = {}\n", 301 | " self.imgs_2 = {}\n", 302 | " self.change_maps = {}\n", 303 | " self.n_patches_per_image = {}\n", 304 | " self.n_patches = 0\n", 305 | " self.patch_coords = []\n", 306 | " for im_name in tqdm(self.names):\n", 307 | " # load and store each image\n", 308 | " I1, I2, cm = read_sentinel_img_trio(self.path + im_name)\n", 309 | " self.imgs_1[im_name] = reshape_for_torch(I1)\n", 310 | " self.imgs_2[im_name] = reshape_for_torch(I2)\n", 311 | " self.change_maps[im_name] = cm\n", 312 | " \n", 313 | " s = cm.shape\n", 314 | " n_pix += np.prod(s)\n", 315 | " true_pix += cm.sum()\n", 316 | " \n", 317 | " # calculate the number of patches\n", 318 | " s = self.imgs_1[im_name].shape\n", 319 | " n1 = ceil((s[1] - self.patch_side + 1) / self.stride)\n", 320 | " n2 = ceil((s[2] - self.patch_side + 1) / self.stride)\n", 321 | " n_patches_i = n1 * n2\n", 322 | " self.n_patches_per_image[im_name] = n_patches_i\n", 323 | " self.n_patches += n_patches_i\n", 324 | " \n", 325 | " # generate path coordinates\n", 326 | " for i in range(n1):\n", 327 | " for j in range(n2):\n", 328 | " # coordinates in (x1, x2, y1, y2)\n", 329 | " current_patch_coords = (im_name, \n", 330 | " [self.stride*i, self.stride*i + self.patch_side, self.stride*j, self.stride*j + self.patch_side],\n", 331 | " [self.stride*(i + 1), self.stride*(j + 1)])\n", 332 | " self.patch_coords.append(current_patch_coords)\n", 333 | " \n", 334 | " self.weights = [ FP_MODIFIER * 2 * true_pix / n_pix, 2 * (n_pix - true_pix) / n_pix]\n", 335 | " \n", 336 | " \n", 337 | "\n", 338 | " def get_img(self, im_name):\n", 339 | " return self.imgs_1[im_name], self.imgs_2[im_name], self.change_maps[im_name]\n", 340 | "\n", 341 | " def __len__(self):\n", 342 | " return self.n_patches\n", 343 | "\n", 344 | " def __getitem__(self, idx):\n", 345 | " current_patch_coords = self.patch_coords[idx]\n", 346 | " im_name = current_patch_coords[0]\n", 347 | " limits = current_patch_coords[1]\n", 348 | " centre = current_patch_coords[2]\n", 349 | " \n", 350 | " I1 = self.imgs_1[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]\n", 351 | " I2 = self.imgs_2[im_name][:, limits[0]:limits[1], limits[2]:limits[3]]\n", 352 | " \n", 353 | " label = self.change_maps[im_name][limits[0]:limits[1], limits[2]:limits[3]]\n", 354 | " label = torch.from_numpy(1*np.array(label)).float()\n", 355 | " \n", 356 | " sample = {'I1': I1, 'I2': I2, 'label': label}\n", 357 | " \n", 358 | " if self.transform:\n", 359 | " sample = self.transform(sample)\n", 360 | "\n", 361 | " return sample\n", 362 | "\n", 363 | "\n", 364 | "\n", 365 | "class RandomFlip(object):\n", 366 | " \"\"\"Flip randomly the images in a sample.\"\"\"\n", 367 | "\n", 368 | "# def __init__(self):\n", 369 | "# return\n", 370 | "\n", 371 | " def __call__(self, sample):\n", 372 | " I1, I2, label = sample['I1'], sample['I2'], sample['label']\n", 373 | " \n", 374 | " if random.random() > 0.5:\n", 375 | " I1 = I1.numpy()[:,:,::-1].copy()\n", 376 | " I1 = torch.from_numpy(I1)\n", 377 | " I2 = I2.numpy()[:,:,::-1].copy()\n", 378 | " I2 = torch.from_numpy(I2)\n", 379 | " label = label.numpy()[:,::-1].copy()\n", 380 | " label = torch.from_numpy(label)\n", 381 | "\n", 382 | " return {'I1': I1, 'I2': I2, 'label': label}\n", 383 | "\n", 384 | "\n", 385 | "\n", 386 | "class RandomRot(object):\n", 387 | " \"\"\"Rotate randomly the images in a sample.\"\"\"\n", 388 | "\n", 389 | "# def __init__(self):\n", 390 | "# return\n", 391 | "\n", 392 | " def __call__(self, sample):\n", 393 | " 
I1, I2, label = sample['I1'], sample['I2'], sample['label']\n", 394 | " \n", 395 | " n = random.randint(0, 3)\n", 396 | " if n:\n", 397 | " I1 = sample['I1'].numpy()\n", 398 | " I1 = np.rot90(I1, n, axes=(1, 2)).copy()\n", 399 | " I1 = torch.from_numpy(I1)\n", 400 | " I2 = sample['I2'].numpy()\n", 401 | " I2 = np.rot90(I2, n, axes=(1, 2)).copy()\n", 402 | " I2 = torch.from_numpy(I2)\n", 403 | " label = sample['label'].numpy()\n", 404 | " label = np.rot90(label, n, axes=(0, 1)).copy()\n", 405 | " label = torch.from_numpy(label)\n", 406 | "\n", 407 | " return {'I1': I1, 'I2': I2, 'label': label}\n", 408 | "\n", 409 | "\n", 410 | "\n", 411 | "\n", 412 | "\n", 413 | "print('UTILS OK')\n" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "metadata": {}, 420 | "outputs": [], 421 | "source": [] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [ 429 | "# Dataset\n", 430 | "\n", 431 | "\n", 432 | "if DATA_AUG:\n", 433 | " data_transform = tr.Compose([RandomFlip(), RandomRot()])\n", 434 | "else:\n", 435 | " data_transform = None\n", 436 | "\n", 437 | "\n", 438 | " \n", 439 | "\n", 440 | "train_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train = True, stride = TRAIN_STRIDE, transform=data_transform)\n", 441 | "weights = torch.FloatTensor(train_dataset.weights).cuda()\n", 442 | "print(weights)\n", 443 | "train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)\n", 444 | "\n", 445 | "test_dataset = ChangeDetectionDataset(PATH_TO_DATASET, train = False, stride = TRAIN_STRIDE)\n", 446 | "test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = True, num_workers = 4)\n", 447 | "\n", 448 | "\n", 449 | "print('DATASETS OK')" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": {}, 456 | "outputs": [], 457 | "source": [ 458 | "# print(weights)" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "# 0-RGB | 1-RGBIr | 2-All bands s.t. 
resolution <= 20m | 3-All bands\n", 468 | "\n", 469 | "if TYPE == 0:\n", 470 | "# net, net_name = Unet(2*3, 2), 'FC-EF'\n", 471 | "# net, net_name = SiamUnet_conc(3, 2), 'FC-Siam-conc'\n", 472 | "# net, net_name = SiamUnet_diff(3, 2), 'FC-Siam-diff'\n", 473 | " net, net_name = FresUNet(2*3, 2), 'FresUNet'\n", 474 | "elif TYPE == 1:\n", 475 | "# net, net_name = Unet(2*4, 2), 'FC-EF'\n", 476 | "# net, net_name = SiamUnet_conc(4, 2), 'FC-Siam-conc'\n", 477 | "# net, net_name = SiamUnet_diff(4, 2), 'FC-Siam-diff'\n", 478 | " net, net_name = FresUNet(2*4, 2), 'FresUNet'\n", 479 | "elif TYPE == 2:\n", 480 | "# net, net_name = Unet(2*10, 2), 'FC-EF'\n", 481 | "# net, net_name = SiamUnet_conc(10, 2), 'FC-Siam-conc'\n", 482 | "# net, net_name = SiamUnet_diff(10, 2), 'FC-Siam-diff'\n", 483 | " net, net_name = FresUNet(2*10, 2), 'FresUNet'\n", 484 | "elif TYPE == 3:\n", 485 | "# net, net_name = Unet(2*13, 2), 'FC-EF'\n", 486 | "# net, net_name = SiamUnet_conc(13, 2), 'FC-Siam-conc'\n", 487 | "# net, net_name = SiamUnet_diff(13, 2), 'FC-Siam-diff'\n", 488 | " net, net_name = FresUNet(2*13, 2), 'FresUNet'\n", 489 | "\n", 490 | "\n", 491 | "net.cuda()\n", 492 | "\n", 493 | "criterion = nn.NLLLoss(weight=weights) # to be used with logsoftmax output\n", 494 | "\n", 495 | "print('NETWORK OK')" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "def count_parameters(model):\n", 505 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 506 | "\n", 507 | "print('Number of trainable parameters:', count_parameters(net))" 508 | ] 509 | }, 510 | { 511 | "cell_type": "code", 512 | "execution_count": null, 513 | "metadata": {}, 514 | "outputs": [], 515 | "source": [ 516 | "# net.load_state_dict(torch.load('net-best_epoch-1_fm-0.7394933126157746.pth.tar'))\n", 517 | "\n", 518 | "def train(n_epochs = N_EPOCHS, save = True):\n", 519 | " t = np.linspace(1, n_epochs, n_epochs)\n", 520 | " \n", 521 | " epoch_train_loss = 0 * t\n", 522 | " epoch_train_accuracy = 0 * t\n", 523 | " epoch_train_change_accuracy = 0 * t\n", 524 | " epoch_train_nochange_accuracy = 0 * t\n", 525 | " epoch_train_precision = 0 * t\n", 526 | " epoch_train_recall = 0 * t\n", 527 | " epoch_train_Fmeasure = 0 * t\n", 528 | " epoch_test_loss = 0 * t\n", 529 | " epoch_test_accuracy = 0 * t\n", 530 | " epoch_test_change_accuracy = 0 * t\n", 531 | " epoch_test_nochange_accuracy = 0 * t\n", 532 | " epoch_test_precision = 0 * t\n", 533 | " epoch_test_recall = 0 * t\n", 534 | " epoch_test_Fmeasure = 0 * t\n", 535 | " \n", 536 | "# mean_acc = 0\n", 537 | "# best_mean_acc = 0\n", 538 | " fm = 0\n", 539 | " best_fm = 0\n", 540 | " \n", 541 | " lss = 1000\n", 542 | " best_lss = 1000\n", 543 | " \n", 544 | " plt.figure(num=1)\n", 545 | " plt.figure(num=2)\n", 546 | " plt.figure(num=3)\n", 547 | " \n", 548 | " \n", 549 | " optimizer = torch.optim.Adam(net.parameters(), weight_decay=1e-4)\n", 550 | "# optimizer = torch.optim.Adam(net.parameters(), lr=0.0005)\n", 551 | " \n", 552 | " \n", 553 | " scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.95)\n", 554 | " \n", 555 | " \n", 556 | " for epoch_index in tqdm(range(n_epochs)):\n", 557 | " net.train()\n", 558 | " print('Epoch: ' + str(epoch_index + 1) + ' of ' + str(N_EPOCHS))\n", 559 | "\n", 560 | " tot_count = 0\n", 561 | " tot_loss = 0\n", 562 | " tot_accurate = 0\n", 563 | " class_correct = list(0. for i in range(2))\n", 564 | " class_total = list(0. 
for i in range(2))\n", 565 | "# for batch_index, batch in enumerate(tqdm(data_loader)):\n", 566 | " for batch in train_loader:\n", 567 | " I1 = Variable(batch['I1'].float().cuda())\n", 568 | " I2 = Variable(batch['I2'].float().cuda())\n", 569 | " label = torch.squeeze(Variable(batch['label'].cuda()))\n", 570 | "\n", 571 | " optimizer.zero_grad()\n", 572 | " output = net(I1, I2)\n", 573 | " loss = criterion(output, label.long())\n", 574 | " loss.backward()\n", 575 | " optimizer.step()\n", 576 | " \n", 577 | " scheduler.step()\n", 578 | "\n", 579 | "\n", 580 | " epoch_train_loss[epoch_index], epoch_train_accuracy[epoch_index], cl_acc, pr_rec = test(train_dataset)\n", 581 | " epoch_train_nochange_accuracy[epoch_index] = cl_acc[0]\n", 582 | " epoch_train_change_accuracy[epoch_index] = cl_acc[1]\n", 583 | " epoch_train_precision[epoch_index] = pr_rec[0]\n", 584 | " epoch_train_recall[epoch_index] = pr_rec[1]\n", 585 | " epoch_train_Fmeasure[epoch_index] = pr_rec[2]\n", 586 | " \n", 587 | "# epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = test(test_dataset)\n", 588 | " epoch_test_loss[epoch_index], epoch_test_accuracy[epoch_index], cl_acc, pr_rec = test(test_dataset)\n", 589 | " epoch_test_nochange_accuracy[epoch_index] = cl_acc[0]\n", 590 | " epoch_test_change_accuracy[epoch_index] = cl_acc[1]\n", 591 | " epoch_test_precision[epoch_index] = pr_rec[0]\n", 592 | " epoch_test_recall[epoch_index] = pr_rec[1]\n", 593 | " epoch_test_Fmeasure[epoch_index] = pr_rec[2]\n", 594 | "\n", 595 | " plt.figure(num=1)\n", 596 | " plt.clf()\n", 597 | " l1_1, = plt.plot(t[:epoch_index + 1], epoch_train_loss[:epoch_index + 1], label='Train loss')\n", 598 | " l1_2, = plt.plot(t[:epoch_index + 1], epoch_test_loss[:epoch_index + 1], label='Test loss')\n", 599 | " plt.legend(handles=[l1_1, l1_2])\n", 600 | " plt.grid()\n", 601 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 602 | " plt.gcf().gca().set_xlim(left = 0)\n", 603 | " plt.title('Loss')\n", 604 | " display.clear_output(wait=True)\n", 605 | " display.display(plt.gcf())\n", 606 | "\n", 607 | " plt.figure(num=2)\n", 608 | " plt.clf()\n", 609 | " l2_1, = plt.plot(t[:epoch_index + 1], epoch_train_accuracy[:epoch_index + 1], label='Train accuracy')\n", 610 | " l2_2, = plt.plot(t[:epoch_index + 1], epoch_test_accuracy[:epoch_index + 1], label='Test accuracy')\n", 611 | " plt.legend(handles=[l2_1, l2_2])\n", 612 | " plt.grid()\n", 613 | " plt.gcf().gca().set_ylim(0, 100)\n", 614 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 615 | "# plt.gcf().gca().set_xlim(left = 0)\n", 616 | " plt.title('Accuracy')\n", 617 | " display.clear_output(wait=True)\n", 618 | " display.display(plt.gcf())\n", 619 | "\n", 620 | " plt.figure(num=3)\n", 621 | " plt.clf()\n", 622 | " l3_1, = plt.plot(t[:epoch_index + 1], epoch_train_nochange_accuracy[:epoch_index + 1], label='Train accuracy: no change')\n", 623 | " l3_2, = plt.plot(t[:epoch_index + 1], epoch_train_change_accuracy[:epoch_index + 1], label='Train accuracy: change')\n", 624 | " l3_3, = plt.plot(t[:epoch_index + 1], epoch_test_nochange_accuracy[:epoch_index + 1], label='Test accuracy: no change')\n", 625 | " l3_4, = plt.plot(t[:epoch_index + 1], epoch_test_change_accuracy[:epoch_index + 1], label='Test accuracy: change')\n", 626 | " plt.legend(handles=[l3_1, l3_2, l3_3, l3_4])\n", 627 | " plt.grid()\n", 628 | " plt.gcf().gca().set_ylim(0, 100)\n", 629 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 630 | "# plt.gcf().gca().set_xlim(left = 0)\n", 631 | " plt.title('Accuracy per class')\n", 632 | " 
display.clear_output(wait=True)\n", 633 | " display.display(plt.gcf())\n", 634 | "\n", 635 | " plt.figure(num=4)\n", 636 | " plt.clf()\n", 637 | " l4_1, = plt.plot(t[:epoch_index + 1], epoch_train_precision[:epoch_index + 1], label='Train precision')\n", 638 | " l4_2, = plt.plot(t[:epoch_index + 1], epoch_train_recall[:epoch_index + 1], label='Train recall')\n", 639 | " l4_3, = plt.plot(t[:epoch_index + 1], epoch_train_Fmeasure[:epoch_index + 1], label='Train Dice/F1')\n", 640 | " l4_4, = plt.plot(t[:epoch_index + 1], epoch_test_precision[:epoch_index + 1], label='Test precision')\n", 641 | " l4_5, = plt.plot(t[:epoch_index + 1], epoch_test_recall[:epoch_index + 1], label='Test recall')\n", 642 | " l4_6, = plt.plot(t[:epoch_index + 1], epoch_test_Fmeasure[:epoch_index + 1], label='Test Dice/F1')\n", 643 | " plt.legend(handles=[l4_1, l4_2, l4_3, l4_4, l4_5, l4_6])\n", 644 | " plt.grid()\n", 645 | " plt.gcf().gca().set_ylim(0, 1)\n", 646 | "# plt.gcf().gca().set_ylim(bottom = 0)\n", 647 | "# plt.gcf().gca().set_xlim(left = 0)\n", 648 | " plt.title('Precision, Recall and F-measure')\n", 649 | " display.clear_output(wait=True)\n", 650 | " display.display(plt.gcf())\n", 651 | " \n", 652 | " \n", 653 | "# mean_acc = (epoch_test_nochange_accuracy[epoch_index] + epoch_test_change_accuracy[epoch_index])/2\n", 654 | "# if mean_acc > best_mean_acc:\n", 655 | "# best_mean_acc = mean_acc\n", 656 | "# save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_acc-' + str(mean_acc) + '.pth.tar'\n", 657 | "# torch.save(net.state_dict(), save_str)\n", 658 | " \n", 659 | " \n", 660 | "# fm = pr_rec[2]\n", 661 | " fm = epoch_train_Fmeasure[epoch_index]\n", 662 | " if fm > best_fm:\n", 663 | " best_fm = fm\n", 664 | " save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_fm-' + str(fm) + '.pth.tar'\n", 665 | " torch.save(net.state_dict(), save_str)\n", 666 | " \n", 667 | " lss = epoch_train_loss[epoch_index]\n", 668 | " if lss < best_lss:\n", 669 | " best_lss = lss\n", 670 | " save_str = 'net-best_epoch-' + str(epoch_index + 1) + '_loss-' + str(lss) + '.pth.tar'\n", 671 | " torch.save(net.state_dict(), save_str)\n", 672 | " \n", 673 | " \n", 674 | "# print('Epoch loss: ' + str(tot_loss/tot_count))\n", 675 | " if save:\n", 676 | " im_format = 'png'\n", 677 | " # im_format = 'eps'\n", 678 | "\n", 679 | " plt.figure(num=1)\n", 680 | " plt.savefig(net_name + '-01-loss.' + im_format)\n", 681 | "\n", 682 | " plt.figure(num=2)\n", 683 | " plt.savefig(net_name + '-02-accuracy.' + im_format)\n", 684 | "\n", 685 | " plt.figure(num=3)\n", 686 | " plt.savefig(net_name + '-03-accuracy-per-class.' + im_format)\n", 687 | "\n", 688 | " plt.figure(num=4)\n", 689 | " plt.savefig(net_name + '-04-prec-rec-fmeas.' 
+ im_format)\n", 690 | " \n", 691 | " out = {'train_loss': epoch_train_loss[-1],\n", 692 | " 'train_accuracy': epoch_train_accuracy[-1],\n", 693 | " 'train_nochange_accuracy': epoch_train_nochange_accuracy[-1],\n", 694 | " 'train_change_accuracy': epoch_train_change_accuracy[-1],\n", 695 | " 'test_loss': epoch_test_loss[-1],\n", 696 | " 'test_accuracy': epoch_test_accuracy[-1],\n", 697 | " 'test_nochange_accuracy': epoch_test_nochange_accuracy[-1],\n", 698 | " 'test_change_accuracy': epoch_test_change_accuracy[-1]}\n", 699 | " \n", 700 | " print('pr_c, rec_c, f_meas, pr_nc, rec_nc')\n", 701 | " print(pr_rec)\n", 702 | " \n", 703 | " return out\n", 704 | "\n", 705 | "L = 1024\n", 706 | "N = 2\n", 707 | "\n", 708 | "def test(dset):\n", 709 | " net.eval()\n", 710 | " tot_loss = 0\n", 711 | " tot_count = 0\n", 712 | " tot_accurate = 0\n", 713 | " \n", 714 | " n = 2\n", 715 | " class_correct = list(0. for i in range(n))\n", 716 | " class_total = list(0. for i in range(n))\n", 717 | " class_accuracy = list(0. for i in range(n))\n", 718 | " \n", 719 | " tp = 0\n", 720 | " tn = 0\n", 721 | " fp = 0\n", 722 | " fn = 0\n", 723 | " \n", 724 | " for img_index in dset.names:\n", 725 | " I1_full, I2_full, cm_full = dset.get_img(img_index)\n", 726 | " \n", 727 | " s = cm_full.shape\n", 728 | " \n", 729 | "\n", 730 | " steps0 = np.arange(0,s[0],ceil(s[0]/N))\n", 731 | " steps1 = np.arange(0,s[1],ceil(s[1]/N))\n", 732 | " for ii in range(N):\n", 733 | " for jj in range(N):\n", 734 | " xmin = steps0[ii]\n", 735 | " if ii == N-1:\n", 736 | " xmax = s[0]\n", 737 | " else:\n", 738 | " xmax = steps0[ii+1]\n", 739 | " ymin = steps1[jj]\n", 740 | " if jj == N-1:\n", 741 | " ymax = s[1]\n", 742 | " else:\n", 743 | " ymax = steps1[jj+1]\n", 744 | " I1 = I1_full[:, xmin:xmax, ymin:ymax]\n", 745 | " I2 = I2_full[:, xmin:xmax, ymin:ymax]\n", 746 | " cm = cm_full[xmin:xmax, ymin:ymax]\n", 747 | "\n", 748 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 749 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 750 | " cm = Variable(torch.unsqueeze(torch.from_numpy(1.0*cm),0).float()).cuda()\n", 751 | "\n", 752 | "\n", 753 | " output = net(I1, I2)\n", 754 | " loss = criterion(output, cm.long())\n", 755 | " # print(loss)\n", 756 | " tot_loss += loss.data * np.prod(cm.size())\n", 757 | " tot_count += np.prod(cm.size())\n", 758 | "\n", 759 | " _, predicted = torch.max(output.data, 1)\n", 760 | "\n", 761 | " c = (predicted.int() == cm.data.int())\n", 762 | " for i in range(c.size(1)):\n", 763 | " for j in range(c.size(2)):\n", 764 | " l = int(cm.data[0, i, j])\n", 765 | " class_correct[l] += c[0, i, j]\n", 766 | " class_total[l] += 1\n", 767 | " \n", 768 | " pr = (predicted.int() > 0).cpu().numpy()\n", 769 | " gt = (cm.data.int() > 0).cpu().numpy()\n", 770 | " \n", 771 | " tp += np.logical_and(pr, gt).sum()\n", 772 | " tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()\n", 773 | " fp += np.logical_and(pr, np.logical_not(gt)).sum()\n", 774 | " fn += np.logical_and(np.logical_not(pr), gt).sum()\n", 775 | " \n", 776 | " net_loss = tot_loss/tot_count\n", 777 | " net_accuracy = 100 * (tp + tn)/tot_count\n", 778 | " \n", 779 | " for i in range(n):\n", 780 | " class_accuracy[i] = 100 * class_correct[i] / max(class_total[i],0.00001)\n", 781 | "\n", 782 | " prec = tp / (tp + fp)\n", 783 | " rec = tp / (tp + fn)\n", 784 | " f_meas = 2 * prec * rec / (prec + rec)\n", 785 | " prec_nc = tn / (tn + fn)\n", 786 | " rec_nc = tn / (tn + fp)\n", 787 | " \n", 788 | " pr_rec = [prec, rec, f_meas, prec_nc, rec_nc]\n", 
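"    # Note (added): pr_rec lists change-class precision and recall, the change-class\n", "    # F1/Dice score, then no-change precision and recall -- the same order as the\n", "    # header 'pr_c, rec_c, f_meas, pr_nc, rec_nc' printed in train().\n", 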
789 | " \n", 790 | " return net_loss, net_accuracy, class_accuracy, pr_rec\n", 791 | " \n", 792 | " \n", 793 | "\n", 794 | "\n", 795 | "\n", 796 | " \n" 797 | ] 798 | }, 799 | { 800 | "cell_type": "code", 801 | "execution_count": null, 802 | "metadata": {}, 803 | "outputs": [], 804 | "source": [ 805 | "if LOAD_TRAINED:\n", 806 | " net.load_state_dict(torch.load('net_final.pth.tar'))\n", 807 | " print('LOAD OK')\n", 808 | "else:\n", 809 | " t_start = time.time()\n", 810 | " out_dic = train()\n", 811 | " t_end = time.time()\n", 812 | " print(out_dic)\n", 813 | " print('Elapsed time:')\n", 814 | " print(t_end - t_start)\n", 815 | "\n", 816 | "\n" 817 | ] 818 | }, 819 | { 820 | "cell_type": "code", 821 | "execution_count": null, 822 | "metadata": {}, 823 | "outputs": [], 824 | "source": [ 825 | "if not LOAD_TRAINED:\n", 826 | " torch.save(net.state_dict(), 'net_final.pth.tar')\n", 827 | " print('SAVE OK')" 828 | ] 829 | }, 830 | { 831 | "cell_type": "code", 832 | "execution_count": null, 833 | "metadata": {}, 834 | "outputs": [], 835 | "source": [ 836 | "\n", 837 | "\n", 838 | "def save_test_results(dset):\n", 839 | " for name in tqdm(dset.names):\n", 840 | " with warnings.catch_warnings():\n", 841 | " I1, I2, cm = dset.get_img(name)\n", 842 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 843 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 844 | " out = net(I1, I2)\n", 845 | " _, predicted = torch.max(out.data, 1)\n", 846 | " I = np.stack((255*cm,255*np.squeeze(predicted.cpu().numpy()),255*cm),2)\n", 847 | " io.imsave(f'{net_name}-{name}.png',I)\n", 848 | "\n", 849 | "\n", 850 | "\n", 851 | "t_start = time.time()\n", 852 | "# save_test_results(train_dataset)\n", 853 | "save_test_results(test_dataset)\n", 854 | "t_end = time.time()\n", 855 | "print('Elapsed time: {}'.format(t_end - t_start))\n" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": {}, 862 | "outputs": [], 863 | "source": [ 864 | "L = 1024\n", 865 | "\n", 866 | "def kappa(tp, tn, fp, fn):\n", 867 | " N = tp + tn + fp + fn\n", 868 | " p0 = (tp + tn) / N\n", 869 | " pe = ((tp+fp)*(tp+fn) + (tn+fp)*(tn+fn)) / (N * N)\n", 870 | " \n", 871 | " return (p0 - pe) / (1 - pe)\n", 872 | "\n", 873 | "def test(dset):\n", 874 | " net.eval()\n", 875 | " tot_loss = 0\n", 876 | " tot_count = 0\n", 877 | " tot_accurate = 0\n", 878 | " \n", 879 | " n = 2\n", 880 | " class_correct = list(0. for i in range(n))\n", 881 | " class_total = list(0. for i in range(n))\n", 882 | " class_accuracy = list(0. 
for i in range(n))\n", 883 | " \n", 884 | " tp = 0\n", 885 | " tn = 0\n", 886 | " fp = 0\n", 887 | " fn = 0\n", 888 | " \n", 889 | " for img_index in tqdm(dset.names):\n", 890 | " I1_full, I2_full, cm_full = dset.get_img(img_index)\n", 891 | " \n", 892 | " s = cm_full.shape\n", 893 | " \n", 894 | " for ii in range(ceil(s[0]/L)):\n", 895 | " for jj in range(ceil(s[1]/L)):\n", 896 | " xmin = L*ii\n", 897 | " xmax = min(L*(ii+1),s[0])\n", 898 | " ymin = L*jj\n", 899 | " ymax = min(L*(jj+1),s[1])\n", 900 | " I1 = I1_full[:, xmin:xmax, ymin:ymax]\n", 901 | " I2 = I2_full[:, xmin:xmax, ymin:ymax]\n", 902 | " cm = cm_full[xmin:xmax, ymin:ymax]\n", 903 | "\n", 904 | " I1 = Variable(torch.unsqueeze(I1, 0).float()).cuda()\n", 905 | " I2 = Variable(torch.unsqueeze(I2, 0).float()).cuda()\n", 906 | " cm = Variable(torch.unsqueeze(torch.from_numpy(1.0*cm),0).float()).cuda()\n", 907 | "\n", 908 | " output = net(I1, I2)\n", 909 | " \n", 910 | " loss = criterion(output, cm.long())\n", 911 | " tot_loss += loss.data * np.prod(cm.size())\n", 912 | " tot_count += np.prod(cm.size())\n", 913 | "\n", 914 | " _, predicted = torch.max(output.data, 1)\n", 915 | "\n", 916 | " c = (predicted.int() == cm.data.int())\n", 917 | " for i in range(c.size(1)):\n", 918 | " for j in range(c.size(2)):\n", 919 | " l = int(cm.data[0, i, j])\n", 920 | " class_correct[l] += c[0, i, j]\n", 921 | " class_total[l] += 1\n", 922 | " \n", 923 | " pr = (predicted.int() > 0).cpu().numpy()\n", 924 | " gt = (cm.data.int() > 0).cpu().numpy()\n", 925 | " \n", 926 | " tp += np.logical_and(pr, gt).sum()\n", 927 | " tn += np.logical_and(np.logical_not(pr), np.logical_not(gt)).sum()\n", 928 | " fp += np.logical_and(pr, np.logical_not(gt)).sum()\n", 929 | " fn += np.logical_and(np.logical_not(pr), gt).sum()\n", 930 | " \n", 931 | " net_loss = tot_loss/tot_count \n", 932 | " net_loss = float(net_loss.cpu().numpy())\n", 933 | " \n", 934 | " net_accuracy = 100 * (tp + tn)/tot_count\n", 935 | " \n", 936 | " for i in range(n):\n", 937 | " class_accuracy[i] = 100 * class_correct[i] / max(class_total[i],0.00001)\n", 938 | " class_accuracy[i] = float(class_accuracy[i].cpu().numpy())\n", 939 | "\n", 940 | " prec = tp / (tp + fp)\n", 941 | " rec = tp / (tp + fn)\n", 942 | " dice = 2 * prec * rec / (prec + rec)\n", 943 | " prec_nc = tn / (tn + fn)\n", 944 | " rec_nc = tn / (tn + fp)\n", 945 | " \n", 946 | " pr_rec = [prec, rec, dice, prec_nc, rec_nc]\n", 947 | " \n", 948 | " k = kappa(tp, tn, fp, fn)\n", 949 | " \n", 950 | " return {'net_loss': net_loss, \n", 951 | " 'net_accuracy': net_accuracy, \n", 952 | " 'class_accuracy': class_accuracy, \n", 953 | " 'precision': prec, \n", 954 | " 'recall': rec, \n", 955 | " 'dice': dice, \n", 956 | " 'kappa': k}\n", 957 | "\n", 958 | "results = test(test_dataset)\n", 959 | "pprint(results)" 960 | ] 961 | }, 962 | { 963 | "cell_type": "code", 964 | "execution_count": null, 965 | "metadata": {}, 966 | "outputs": [], 967 | "source": [] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": null, 972 | "metadata": {}, 973 | "outputs": [], 974 | "source": [] 975 | }, 976 | { 977 | "cell_type": "code", 978 | "execution_count": null, 979 | "metadata": {}, 980 | "outputs": [], 981 | "source": [] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": null, 986 | "metadata": {}, 987 | "outputs": [], 988 | "source": [] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": null, 993 | "metadata": {}, 994 | "outputs": [], 995 | "source": [] 996 | } 997 | ], 998 | "metadata": { 999 | 
"kernelspec": { 1000 | "display_name": "Python 3", 1001 | "language": "python", 1002 | "name": "python3" 1003 | }, 1004 | "language_info": { 1005 | "codemirror_mode": { 1006 | "name": "ipython", 1007 | "version": 3 1008 | }, 1009 | "file_extension": ".py", 1010 | "mimetype": "text/x-python", 1011 | "name": "python", 1012 | "nbconvert_exporter": "python", 1013 | "pygments_lexer": "ipython3", 1014 | "version": "3.7.6" 1015 | } 1016 | }, 1017 | "nbformat": 4, 1018 | "nbformat_minor": 4 1019 | } 1020 | -------------------------------------------------------------------------------- /siamunet_conc.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_conc(nn.Module): 11 | """SiamUnet_conc segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_conc, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = 
nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | """Forward method.""" 97 | # Stage 1 98 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 99 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 100 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 101 | 102 | 103 | # Stage 2 104 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 105 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 106 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 107 | 108 | # Stage 3 109 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 110 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 111 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 112 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 113 | 114 | # Stage 4 115 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 116 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 117 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 118 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 119 | 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | # Stage 2 128 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 129 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 130 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 131 | 132 | # Stage 3 133 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 134 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 135 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 136 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 137 | 138 | # Stage 4 139 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 140 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 141 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 142 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 143 | 144 | 145 | #################################################### 146 | # Stage 4d 147 | x4d = self.upconv4(x4p) 148 | pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2))) 149 | x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1) 150 | x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d)))) 151 | x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d)))) 152 | x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d)))) 153 | 154 | # Stage 3d 155 | x3d = self.upconv3(x41d) 156 | pad3 = 
ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2))) 157 | x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1) 158 | x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d)))) 159 | x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d)))) 160 | x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d)))) 161 | 162 | # Stage 2d 163 | x2d = self.upconv2(x31d) 164 | pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2))) 165 | x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1) 166 | x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d)))) 167 | x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d)))) 168 | 169 | # Stage 1d 170 | x1d = self.upconv1(x21d) 171 | pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2))) 172 | x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1) 173 | x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d)))) 174 | x11d = self.conv11d(x12d) 175 | 176 | return self.sm(x11d) 177 | 178 | 179 | -------------------------------------------------------------------------------- /siamunet_diff.py: -------------------------------------------------------------------------------- 1 | # Rodrigo Caye Daudt 2 | # https://rcdaudt.github.io/ 3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.padding import ReplicationPad2d 9 | 10 | class SiamUnet_diff(nn.Module): 11 | """SiamUnet_diff segmentation network.""" 12 | 13 | def __init__(self, input_nbr, label_nbr): 14 | super(SiamUnet_diff, self).__init__() 15 | 16 | self.input_nbr = input_nbr 17 | 18 | self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1) 19 | self.bn11 = nn.BatchNorm2d(16) 20 | self.do11 = nn.Dropout2d(p=0.2) 21 | self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1) 22 | self.bn12 = nn.BatchNorm2d(16) 23 | self.do12 = nn.Dropout2d(p=0.2) 24 | 25 | self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1) 26 | self.bn21 = nn.BatchNorm2d(32) 27 | self.do21 = nn.Dropout2d(p=0.2) 28 | self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1) 29 | self.bn22 = nn.BatchNorm2d(32) 30 | self.do22 = nn.Dropout2d(p=0.2) 31 | 32 | self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1) 33 | self.bn31 = nn.BatchNorm2d(64) 34 | self.do31 = nn.Dropout2d(p=0.2) 35 | self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 36 | self.bn32 = nn.BatchNorm2d(64) 37 | self.do32 = nn.Dropout2d(p=0.2) 38 | self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 39 | self.bn33 = nn.BatchNorm2d(64) 40 | self.do33 = nn.Dropout2d(p=0.2) 41 | 42 | self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 43 | self.bn41 = nn.BatchNorm2d(128) 44 | self.do41 = nn.Dropout2d(p=0.2) 45 | self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 46 | self.bn42 = nn.BatchNorm2d(128) 47 | self.do42 = nn.Dropout2d(p=0.2) 48 | self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 49 | self.bn43 = nn.BatchNorm2d(128) 50 | self.do43 = nn.Dropout2d(p=0.2) 51 | 52 | self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1) 53 | 54 | self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1) 55 | self.bn43d = nn.BatchNorm2d(128) 56 | self.do43d = nn.Dropout2d(p=0.2) 57 | self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1) 58 | self.bn42d = 
nn.BatchNorm2d(128) 59 | self.do42d = nn.Dropout2d(p=0.2) 60 | self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 61 | self.bn41d = nn.BatchNorm2d(64) 62 | self.do41d = nn.Dropout2d(p=0.2) 63 | 64 | self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1) 65 | 66 | self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1) 67 | self.bn33d = nn.BatchNorm2d(64) 68 | self.do33d = nn.Dropout2d(p=0.2) 69 | self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1) 70 | self.bn32d = nn.BatchNorm2d(64) 71 | self.do32d = nn.Dropout2d(p=0.2) 72 | self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 73 | self.bn31d = nn.BatchNorm2d(32) 74 | self.do31d = nn.Dropout2d(p=0.2) 75 | 76 | self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1) 77 | 78 | self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1) 79 | self.bn22d = nn.BatchNorm2d(32) 80 | self.do22d = nn.Dropout2d(p=0.2) 81 | self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 82 | self.bn21d = nn.BatchNorm2d(16) 83 | self.do21d = nn.Dropout2d(p=0.2) 84 | 85 | self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1) 86 | 87 | self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1) 88 | self.bn12d = nn.BatchNorm2d(16) 89 | self.do12d = nn.Dropout2d(p=0.2) 90 | self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1) 91 | 92 | self.sm = nn.LogSoftmax(dim=1) 93 | 94 | def forward(self, x1, x2): 95 | 96 | 97 | """Forward method.""" 98 | # Stage 1 99 | x11 = self.do11(F.relu(self.bn11(self.conv11(x1)))) 100 | x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 101 | x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2) 102 | 103 | 104 | # Stage 2 105 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 106 | x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 107 | x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2) 108 | 109 | # Stage 3 110 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 111 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 112 | x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 113 | x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2) 114 | 115 | # Stage 4 116 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 117 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 118 | x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 119 | x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2) 120 | 121 | #################################################### 122 | # Stage 1 123 | x11 = self.do11(F.relu(self.bn11(self.conv11(x2)))) 124 | x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11)))) 125 | x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2) 126 | 127 | 128 | # Stage 2 129 | x21 = self.do21(F.relu(self.bn21(self.conv21(x1p)))) 130 | x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21)))) 131 | x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2) 132 | 133 | # Stage 3 134 | x31 = self.do31(F.relu(self.bn31(self.conv31(x2p)))) 135 | x32 = self.do32(F.relu(self.bn32(self.conv32(x31)))) 136 | x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32)))) 137 | x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2) 138 | 139 | # Stage 4 140 | x41 = self.do41(F.relu(self.bn41(self.conv41(x3p)))) 141 | x42 = self.do42(F.relu(self.bn42(self.conv42(x41)))) 142 | x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42)))) 143 | x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2) 144 | 145 | 146 | 147 | # Stage 4d 148 | 
148 |         x4d = self.upconv4(x4p)
149 |         pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))  # pad upsampled map to the skip tensor's size (handles odd dims)
150 |         x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)  # skip connection: absolute difference of the two encodings
151 |         x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
152 |         x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
153 |         x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
154 | 
155 |         # Stage 3d
156 |         x3d = self.upconv3(x41d)
157 |         pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
158 |         x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
159 |         x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
160 |         x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
161 |         x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
162 | 
163 |         # Stage 2d
164 |         x2d = self.upconv2(x31d)
165 |         pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
166 |         x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
167 |         x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
168 |         x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
169 | 
170 |         # Stage 1d
171 |         x1d = self.upconv1(x21d)
172 |         pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
173 |         x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
174 |         x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
175 |         x11d = self.conv11d(x12d)
176 | 
177 |         return self.sm(x11d)
178 | 
179 | 
180 | 
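A minimal usage sketch for the siamese-difference network above (not part of the repository files). The class name SiamUnet_diff and the constructor signature (input_nbr, label_nbr) are assumptions inferred from the other models here; the 3-band inputs and binary labels are illustrative. Any input side length divisible by 16 avoids the padding path, since the encoder pools four times.

    # Hypothetical usage sketch; SiamUnet_diff and all shapes below are
    # illustrative assumptions, not repository defaults.
    import torch
    from siamunet_diff import SiamUnet_diff

    net = SiamUnet_diff(input_nbr=3, label_nbr=2)  # 3-band images, change / no-change
    t1 = torch.randn(4, 3, 96, 96)                 # batch of 4 images, first date
    t2 = torch.randn(4, 3, 96, 96)                 # same locations, second date
    log_probs = net(t1, t2)                        # (4, 2, 96, 96) LogSoftmax outputs
    labels = torch.randint(0, 2, (4, 96, 96))      # dummy per-pixel change map
    loss = torch.nn.NLLLoss()(log_probs, labels)   # NLLLoss pairs with LogSoftmax

Because every skip connection carries torch.abs(x..._1 - x..._2), the decoder compares the two dates at every scale rather than only at the bottleneck.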
--------------------------------------------------------------------------------
/unet.py:
--------------------------------------------------------------------------------
1 | # Rodrigo Caye Daudt
2 | # https://rcdaudt.github.io/
3 | # Daudt, R. C., Le Saux, B., & Boulch, A. "Fully convolutional siamese networks for change detection". In 2018 25th IEEE International Conference on Image Processing (ICIP) (pp. 4063-4067). IEEE.
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn.modules.padding import ReplicationPad2d
9 | 
10 | class Unet(nn.Module):
11 |     """EF segmentation network."""
12 | 
13 |     def __init__(self, input_nbr, label_nbr):
14 |         super(Unet, self).__init__()
15 | 
16 |         self.input_nbr = input_nbr
17 | 
18 |         self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
19 |         self.bn11 = nn.BatchNorm2d(16)
20 |         self.do11 = nn.Dropout2d(p=0.2)
21 |         self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
22 |         self.bn12 = nn.BatchNorm2d(16)
23 |         self.do12 = nn.Dropout2d(p=0.2)
24 | 
25 |         self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
26 |         self.bn21 = nn.BatchNorm2d(32)
27 |         self.do21 = nn.Dropout2d(p=0.2)
28 |         self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
29 |         self.bn22 = nn.BatchNorm2d(32)
30 |         self.do22 = nn.Dropout2d(p=0.2)
31 | 
32 |         self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
33 |         self.bn31 = nn.BatchNorm2d(64)
34 |         self.do31 = nn.Dropout2d(p=0.2)
35 |         self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
36 |         self.bn32 = nn.BatchNorm2d(64)
37 |         self.do32 = nn.Dropout2d(p=0.2)
38 |         self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
39 |         self.bn33 = nn.BatchNorm2d(64)
40 |         self.do33 = nn.Dropout2d(p=0.2)
41 | 
42 |         self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
43 |         self.bn41 = nn.BatchNorm2d(128)
44 |         self.do41 = nn.Dropout2d(p=0.2)
45 |         self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
46 |         self.bn42 = nn.BatchNorm2d(128)
47 |         self.do42 = nn.Dropout2d(p=0.2)
48 |         self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
49 |         self.bn43 = nn.BatchNorm2d(128)
50 |         self.do43 = nn.Dropout2d(p=0.2)
51 | 
52 | 
53 |         self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)
54 | 
55 |         self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
56 |         self.bn43d = nn.BatchNorm2d(128)
57 |         self.do43d = nn.Dropout2d(p=0.2)
58 |         self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
59 |         self.bn42d = nn.BatchNorm2d(128)
60 |         self.do42d = nn.Dropout2d(p=0.2)
61 |         self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
62 |         self.bn41d = nn.BatchNorm2d(64)
63 |         self.do41d = nn.Dropout2d(p=0.2)
64 | 
65 |         self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)
66 | 
67 |         self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
68 |         self.bn33d = nn.BatchNorm2d(64)
69 |         self.do33d = nn.Dropout2d(p=0.2)
70 |         self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
71 |         self.bn32d = nn.BatchNorm2d(64)
72 |         self.do32d = nn.Dropout2d(p=0.2)
73 |         self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
74 |         self.bn31d = nn.BatchNorm2d(32)
75 |         self.do31d = nn.Dropout2d(p=0.2)
76 | 
77 |         self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)
78 | 
79 |         self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
80 |         self.bn22d = nn.BatchNorm2d(32)
81 |         self.do22d = nn.Dropout2d(p=0.2)
82 |         self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
83 |         self.bn21d = nn.BatchNorm2d(16)
84 |         self.do21d = nn.Dropout2d(p=0.2)
85 | 
86 |         self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)
87 | 
88 |         self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
89 |         self.bn12d = nn.BatchNorm2d(16)
90 |         self.do12d = nn.Dropout2d(p=0.2)
91 |         self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
92 | 
93 |         self.sm = nn.LogSoftmax(dim=1)
94 | 
95 |     def forward(self, x1, x2):
96 |         """Forward method."""
97 |         # Early fusion: concatenate the two images along the channel axis
98 |         x = torch.cat((x1, x2), 1)
99 | 
100 |         # Stage 1
101 |         x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
102 |         x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
103 |         x1p = F.max_pool2d(x12, kernel_size=2, stride=2)
104 | 
105 |         # Stage 2
106 |         x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
107 |         x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
108 |         x2p = F.max_pool2d(x22, kernel_size=2, stride=2)
109 | 
110 |         # Stage 3
111 |         x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
112 |         x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
113 |         x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
114 |         x3p = F.max_pool2d(x33, kernel_size=2, stride=2)
115 | 
116 |         # Stage 4
117 |         x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
118 |         x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
119 |         x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
120 |         x4p = F.max_pool2d(x43, kernel_size=2, stride=2)
121 | 
122 | 
123 |         # Stage 4d
124 |         x4d = self.upconv4(x4p)
125 |         pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
126 |         x4d = torch.cat((pad4(x4d), x43), 1)
127 |         x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
128 |         x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
129 |         x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))
130 | 
131 |         # Stage 3d
132 |         x3d = self.upconv3(x41d)
133 |         pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
134 |         x3d = torch.cat((pad3(x3d), x33), 1)
135 |         x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
136 |         x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
137 |         x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))
138 | 
139 |         # Stage 2d
140 |         x2d = self.upconv2(x31d)
141 |         pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
142 |         x2d = torch.cat((pad2(x2d), x22), 1)
143 |         x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
144 |         x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))
145 | 
146 |         # Stage 1d
147 |         x1d = self.upconv1(x21d)
148 |         pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
149 |         x1d = torch.cat((pad1(x1d), x12), 1)
150 |         x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
151 |         x11d = self.conv11d(x12d)
152 | 
153 |         return self.sm(x11d)
154 | 
155 | 
156 | 
--------------------------------------------------------------------------------
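For the early-fusion Unet above, note that forward concatenates the two images along the channel axis before the first convolution, so input_nbr must count the channels of both dates together. A minimal, hypothetical training step (tensor shapes, optimizer settings, and dummy data are illustrative assumptions, not repository defaults):

    # Hypothetical training-step sketch for the early-fusion Unet.
    import torch
    from unet import Unet

    net = Unet(input_nbr=6, label_nbr=2)        # 6 = 3 bands (date 1) + 3 bands (date 2)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = torch.nn.NLLLoss()              # complements the model's LogSoftmax output

    t1 = torch.randn(4, 3, 96, 96)
    t2 = torch.randn(4, 3, 96, 96)
    target = torch.randint(0, 2, (4, 96, 96))   # per-pixel change labels

    optimizer.zero_grad()
    loss = criterion(net(t1, t2), target)
    loss.backward()
    optimizer.step()

The ReplicationPad2d calls in the decoder re-grow feature maps whose size was rounded down by pooling, so inputs whose sides are not multiples of 16 are handled as well.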