├── .gitignore
├── LICENSE
├── README.md
├── networks
│   ├── __init__.py
│   └── network.py
├── recursive_noise_diffusion_diagram.svg
├── requirements.txt
├── rnd.yml
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── cityscapes_loader.py
    ├── evaluation.py
    ├── pascal_voc_loader.py
    ├── trainer.py
    ├── uavid_loader.py
    ├── utils.py
    └── vaihingen_buildings_loader.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Output 132 | output/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Benedikt Kolbeinsson 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Recursive Noise Diffusion 2 | 3 | ![Recursive Noise Diffusion diagram](recursive_noise_diffusion_diagram.svg?raw=true "Recursive Noise Diffusion") 4 | 5 | This repo is the official implementation of 6 | [Multi-Class Segmentation from Aerial Views using Recursive Noise Diffusion](https://arxiv.org/abs/2212.00787) 7 | 8 | The core idea of Recursive Noise Diffusion is the _recursive denoising_ process, as shown in the figure above. 9 | Training with _recursive denoising_ involves progressing through each time step t from T to 1, recursively (as the name suggests), which allows a portion of the predicted error to propagate. 10 | This process is initialised with pure noise. The noise function diffuses the previous predicted segmentation, then the model denoises this diffused segmentation given a conditioning RGB image. The denoised predicted segmentation is compared to the ground truth. Notably, the ground truth segmentation is never used as part of the input to the model. This process is agnostic to the choice of noise function, diffusion model and loss. 11 | 12 | ## Getting Started 13 | 14 | ### Setup 15 | 16 | - Clone this repo: 17 | 18 | ```bash 19 | git clone https://github.com/benediktkol/recursive-noise-diffusion.git 20 | cd recursive-noise-diffusion 21 | ``` 22 | 23 | - Install requirements: 24 | 25 | ```bash 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | - [Optional] Create a conda environment: 30 | 31 | ```bash 32 | conda env create -f rnd.yml 33 | conda activate rnd 34 | ``` 35 | 36 | - Download data 37 | - [Vaihingen Buildings](https://drive.google.com/open?id=1nenpWH4BdplSiHdfXs0oYfiA5qL42plB) 38 | - [UAVid](https://uavid.nl) 39 | 40 | - File structure 41 | 42 | ```bash 43 | data 44 | ├── UAVid 45 | │ ├── uavid_test 46 | │ │ ├── seq21 47 | │ │ │ └── Images 48 | │ │ │ ├── 000000.png 49 | │ │ │ └── ... 50 | │ │ └── ... 51 | │ ├── uavid_train 52 | │ │ ├── seq1 53 | │ │ │ ├── Images 54 | │ │ │ │ ├── 000000.png 55 | │ │ │ │ └── ... 56 | │ │ │ └── Labels 57 | │ │ │ ├── 000000.png 58 | │ │ │ └── ... 59 | │ │ └── ... 
60 | │ └── uavid_val 61 | │ ├── seq16 62 | │ │ ├── Images 63 | │ │ │ ├── 000000.png 64 | │ │ │ └── ... 65 | │ │ └── Labels 66 | │ │ ├── 000000.png 67 | │ │ └── ... 68 | │ └── ... 69 | └── Vaihingen_buildings 70 | ├── all_buildings_mask_001.png 71 | ├── ... 72 | ├── building_001.png 73 | ├── ... 74 | ├── building_gt_001.png 75 | ├── ... 76 | ├── building_mask_001.png 77 | └── ... 78 | recursive-noise-diffusion (this repo) 79 | ├── test.py 80 | ├── train.py 81 | └── ... 82 | ``` 83 | 84 | 85 | ### Train 86 | 87 | To train a model use ```train.py```, for example: 88 | 89 | ```bash 90 | python train.py --dataset vaihingen --scale_procedure loop --n_scales 3 --n_timesteps 25 91 | ``` 92 | 93 | ### Evaluation 94 | 95 | To evaluate a model use ```test.py```, for example: 96 | 97 | ```bash 98 | python test.py --load_checkpoint /path/to/checkpoint.pt --dataset vaihingen --scale_procedure loop --n_scales 3 --n_timesteps 25 99 | ``` 100 | 101 | ## Cite 102 | 103 | ``` 104 | @InProceedings{Kolbeinsson_2024_WACV, 105 | author = {Kolbeinsson, Benedikt and Mikolajczyk, Krystian}, 106 | title = {Multi-Class Segmentation From Aerial Views Using Recursive Noise Diffusion}, 107 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 108 | month = {January}, 109 | year = {2024}, 110 | pages = {8439-8449} 111 | } 112 | ``` 113 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benediktkol/recursive-noise-diffusion/183696b2d26f245a35a019980fca9adc20f0f4ef/networks/__init__.py -------------------------------------------------------------------------------- /networks/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from einops import rearrange 6 | from torch import nn, einsum 7 | 8 | 9 | def exists(x): 10 | return x is not None 11 | 12 | class Residual(nn.Module): 13 | def __init__(self, fn): 14 | super().__init__() 15 | self.fn = fn 16 | 17 | def forward(self, x, *args, **kwargs): 18 | return self.fn(x, *args, **kwargs) + x 19 | 20 | def Upsample(dim_in, dim_out): 21 | return nn.ConvTranspose2d(dim_in, dim_out, 4, 2, 1) 22 | 23 | def Downsample(dim_in, dim_out): 24 | return nn.Conv2d(dim_in, dim_out, 4, 2, 1) 25 | 26 | class SinusoidalPositionEmbeddings(nn.Module): 27 | def __init__(self, dim): 28 | super().__init__() 29 | self.dim = dim 30 | 31 | def forward(self, time): 32 | device = time.device 33 | half_dim = self.dim // 2 34 | embeddings = math.log(10000) / (half_dim - 1) 35 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 36 | embeddings = time[:, None] * embeddings[None, :] 37 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 38 | return embeddings 39 | 40 | class Block(nn.Module): 41 | def __init__(self, dim, dim_out, groups = 8): 42 | super().__init__() 43 | self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) 44 | self.norm = nn.GroupNorm(groups, dim_out) 45 | self.act = nn.SiLU() 46 | 47 | def forward(self, x): 48 | x = self.proj(x) 49 | x = self.norm(x) 50 | x = self.act(x) 51 | return x 52 | 53 | 54 | class ResNetBlock(nn.Module): 55 | """https://arxiv.org/abs/1512.03385""" 56 | 57 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): 58 | super().__init__() 59 | self.mlp = ( 60 | 
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) 61 | if exists(time_emb_dim) 62 | else None 63 | ) 64 | 65 | self.block1 = Block(dim, dim_out, groups=groups) 66 | self.block2 = Block(dim_out, dim_out, groups=groups) 67 | self.res_conv = nn.Conv2d(dim, dim_out, 1) #if dim != dim_out else nn.Identity() 68 | 69 | def forward(self, x, time_emb=None): 70 | h = self.block1(x) 71 | 72 | if exists(self.mlp) and exists(time_emb): 73 | time_emb = self.mlp(time_emb) 74 | h = rearrange(time_emb, "b c -> b c 1 1") + h 75 | 76 | h = self.block2(h) 77 | return h + self.res_conv(x) 78 | 79 | class Attention(nn.Module): 80 | def __init__(self, dim, heads=4, dim_head=32): 81 | super().__init__() 82 | self.scale = dim_head**-0.5 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x).chunk(3, dim=1) 91 | q, k, v = map( 92 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 93 | ) 94 | q = q * self.scale 95 | 96 | sim = einsum("b h d i, b h d j -> b h i j", q, k) 97 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 98 | attn = sim.softmax(dim=-1) 99 | 100 | out = einsum("b h i j, b h d j -> b h i d", attn, v) 101 | out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) 102 | return self.to_out(out) 103 | 104 | class LinearAttention(nn.Module): 105 | def __init__(self, dim, heads=4, dim_head=32): 106 | super().__init__() 107 | self.scale = dim_head**-0.5 108 | self.heads = heads 109 | hidden_dim = dim_head * heads 110 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 111 | 112 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 113 | nn.GroupNorm(1, dim)) 114 | 115 | def forward(self, x): 116 | b, c, h, w = x.shape 117 | qkv = self.to_qkv(x).chunk(3, dim=1) 118 | q, k, v = map( 119 | lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv 120 | ) 121 | 122 | q = q.softmax(dim=-2) 123 | k = k.softmax(dim=-1) 124 | 125 | q = q * self.scale 126 | context = torch.einsum("b h d n, b h e n -> b h d e", k, v) 127 | 128 | out = torch.einsum("b h d e, b h d n -> b h e n", context, q) 129 | out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) 130 | return self.to_out(out) 131 | 132 | class PreNorm(nn.Module): 133 | def __init__(self, dim, fn): 134 | super().__init__() 135 | self.fn = fn 136 | self.norm = nn.GroupNorm(1, dim) 137 | 138 | def forward(self, x): 139 | x = self.norm(x) 140 | return self.fn(x) 141 | 142 | 143 | class NetworkConfig: 144 | """Configuration for the network.""" 145 | # Default configuration 146 | image_channels=3 147 | n_classes=19 148 | dim=32 149 | dim_mults=(1, 2, 4, 8) 150 | resnet_block_groups=8 151 | 152 | # diffusion parameters 153 | n_timesteps = 10 154 | n_scales = 3 155 | max_patch_size = 512 156 | scale_procedure = "loop" # "linear" or "loop" 157 | 158 | # ensemble parameters 159 | built_in_ensemble = False 160 | 161 | def __init__(self, **kwargs): 162 | for k,v in kwargs.items(): 163 | setattr(self, k, v) 164 | 165 | 166 | class Network(nn.Module): 167 | def __init__( 168 | self, 169 | network_config=NetworkConfig(), 170 | ): 171 | super().__init__() 172 | self.config = network_config 173 | image_channels = self.config.image_channels 174 | n_classes = self.config.n_classes 175 | dim = self.config.dim 176 | dim_mults = self.config.dim_mults 177 | resnet_block_groups = 
self.config.resnet_block_groups 178 | 179 | # determine dimensions 180 | self.image_channels = image_channels 181 | self.n_classes = n_classes 182 | self.dims = [c * dim for c in dim_mults] 183 | 184 | # time embedding 185 | time_dim = dim * 4 186 | self.time_mlp = nn.Sequential( 187 | SinusoidalPositionEmbeddings(dim), 188 | nn.Linear(dim, time_dim), 189 | nn.GELU(), 190 | nn.Linear(time_dim, time_dim), 191 | ) 192 | 193 | # image initial 194 | self.image_initial = nn.ModuleList([ 195 | ResNetBlock(image_channels, self.dims[0], time_emb_dim=time_dim, groups=resnet_block_groups), 196 | ResNetBlock(self.dims[0], self.dims[0], groups=resnet_block_groups), 197 | ResNetBlock(self.dims[0], self.dims[0], groups=resnet_block_groups) 198 | ]) 199 | 200 | # segmentation initial 201 | self.seg_initial = nn.ModuleList([ 202 | ResNetBlock(n_classes, self.dims[0], time_emb_dim=time_dim, groups=resnet_block_groups), 203 | ResNetBlock(self.dims[0], self.dims[0], groups=resnet_block_groups), 204 | ResNetBlock(self.dims[0], self.dims[0], groups=resnet_block_groups) 205 | ]) 206 | 207 | # layers 208 | self.down = nn.ModuleList([]) 209 | self.up = nn.ModuleList([]) 210 | 211 | # encoder 212 | for i in range(len(dim_mults)-1): # each dblock 213 | dim_in = self.dims[i] 214 | dim_out = self.dims[i+1] 215 | 216 | self.down.append( 217 | nn.ModuleList([ 218 | ResNetBlock(dim_in, dim_in, time_emb_dim=time_dim, groups=resnet_block_groups), 219 | ResNetBlock(dim_in, dim_in, groups=resnet_block_groups), 220 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 221 | Downsample(dim_in, dim_out), 222 | 223 | ]) 224 | ) 225 | 226 | # decoder 227 | for i in range(len(dim_mults)-1): # each ublock 228 | dim_in = self.dims[-i-1] 229 | dim_out = self.dims[-i-2] 230 | if i == 0: 231 | dim_in_plus_concat = dim_in 232 | else: 233 | dim_in_plus_concat = dim_in * 2 234 | 235 | self.up.append( 236 | nn.ModuleList([ 237 | ResNetBlock(dim_in_plus_concat, dim_in, time_emb_dim=time_dim, groups=resnet_block_groups), 238 | ResNetBlock(dim_in, dim_in, groups=resnet_block_groups), 239 | Residual(PreNorm(dim_in, LinearAttention(dim_in))), 240 | Upsample(dim_in, dim_out), 241 | ]) 242 | ) 243 | 244 | # final 245 | self.final = nn.Sequential(ResNetBlock(self.dims[0]*2, self.dims[0], groups=resnet_block_groups), 246 | ResNetBlock(self.dims[0], self.dims[0], groups=resnet_block_groups), 247 | nn.Conv2d(self.dims[0], n_classes, 1)) 248 | 249 | 250 | 251 | def forward(self, seg, img, time): 252 | # time embedding 253 | t = self.time_mlp(time) 254 | 255 | # segmentation initial 256 | resnetblock1, resnetblock2, resnetblock3 = self.seg_initial 257 | seg_emb = resnetblock1(seg, t) 258 | seg_emb = resnetblock2(seg_emb) 259 | seg_emb = resnetblock3(seg_emb) 260 | 261 | # image initial 262 | resnetblock1, resnetblock2, resnetblock3 = self.image_initial 263 | img_emb = resnetblock1(img, t) 264 | img_emb = resnetblock2(img_emb) 265 | img_emb = resnetblock3(img_emb) 266 | 267 | # add embeddings together 268 | x = seg_emb + img_emb 269 | 270 | # skip connections 271 | h = [] 272 | 273 | # downsample 274 | for resnetblock1, resnetblock2, attn, downsample in self.down: 275 | x = resnetblock1(x, t) 276 | x = resnetblock2(x) 277 | x = attn(x) 278 | h.append(x) 279 | x = downsample(x) 280 | 281 | # upsample 282 | for resnetblock1, resnetblock2, attn, upsample in self.up: 283 | x = resnetblock1(x, t) 284 | x = resnetblock2(x) 285 | x = attn(x) 286 | x = upsample(x) 287 | x = torch.cat((x, h.pop()), dim=1) 288 | 289 | return self.final(x) 
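For orientation, here is a minimal, self-contained sketch of how the Network module above can be driven by the recursive denoising procedure described in the README. This is not the repository's training code (that lives in utils/trainer.py and utils/utils.py, which are not reproduced here): the diffuse() helper, the optimiser, the tensor shapes and the hyperparameter values below are illustrative assumptions only.

```python
import torch
import torch.nn.functional as F

from networks.network import Network, NetworkConfig


def diffuse(seg, t):
    """Stand-in noise function: blend the previous prediction with uniform noise.

    The repository defines its own noise function in utils/utils.py (not shown), which may differ.
    """
    noise = torch.rand_like(seg)
    return (1 - t) * seg + t * noise


config = NetworkConfig(n_classes=8, n_timesteps=10)      # illustrative values
model = Network(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

b, h, w = 2, 64, 64                                      # spatial size must be divisible by 8 (one factor of 2 per U-Net stage)
image = torch.rand(b, config.image_channels, h, w)       # conditioning RGB image
seg_gt = torch.randint(0, config.n_classes, (b, h, w))   # ground truth, used only in the loss

# Recursive denoising: start from pure noise and step t from T down to 1,
# always diffusing the previous prediction, never the ground truth.
seg_prev = torch.rand(b, config.n_classes, h, w)
for step in range(config.n_timesteps):
    t = torch.full((b,), (config.n_timesteps - step) / config.n_timesteps)
    seg_diffused = diffuse(seg_prev.softmax(dim=1), t.view(b, 1, 1, 1))
    seg_pred = model(seg_diffused, image, t)             # (b, n_classes, h, w) logits
    loss = F.cross_entropy(seg_pred, seg_gt)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    seg_prev = seg_pred.detach()                         # a portion of the predicted error carries over to the next step
```

At inference the same loop is run without the loss; utils/evaluation.py (denoise_loop_scales and denoise_linear_scales) additionally iterates over spatial scales and hands each scale to denoise_scale (defined in utils/utils.py, not shown) together with a maximum patch size.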
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | imageio==2.28.1 3 | numpy==1.24.3 4 | Pillow==9.5.0 5 | torch==2.0.1 6 | torchmetrics==0.11.4 7 | torchvision==0.15.2 8 | tqdm==4.65.0 9 | -------------------------------------------------------------------------------- /rnd.yml: -------------------------------------------------------------------------------- 1 | name: rnd 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - abseil-cpp=20211102.0=h27087fc_1 11 | - absl-py=1.4.0=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h5764c6d_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - async-timeout=4.0.2=py310h06a4308_0 15 | - attrs=23.1.0=pyh71513ae_1 16 | - blas=1.0=mkl 17 | - blinker=1.6.2=pyhd8ed1ab_0 18 | - brotlipy=0.7.0=py310h7f8727e_1002 19 | - bzip2=1.0.8=h7b6447c_0 20 | - c-ares=1.19.0=h5eee18b_0 21 | - ca-certificates=2023.5.7=hbcca054_0 22 | - cachetools=5.3.0=pyhd8ed1ab_0 23 | - certifi=2023.5.7=pyhd8ed1ab_0 24 | - cffi=1.15.1=py310h5eee18b_3 25 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 26 | - click=8.1.3=unix_pyhd8ed1ab_2 27 | - colorama=0.4.6=pyhd8ed1ab_0 28 | - cryptography=39.0.1=py310h9ce1e76_0 29 | - cuda-cudart=11.8.89=0 30 | - cuda-cupti=11.8.87=0 31 | - cuda-libraries=11.8.0=0 32 | - cuda-nvrtc=11.8.89=0 33 | - cuda-nvtx=11.8.86=0 34 | - cuda-runtime=11.8.0=0 35 | - einops=0.6.1=pyhd8ed1ab_0 36 | - ffmpeg=4.3=hf484d3e_0 37 | - filelock=3.9.0=py310h06a4308_0 38 | - freetype=2.12.1=h4a9f257_0 39 | - frozenlist=1.3.3=py310h5eee18b_0 40 | - giflib=5.2.1=h5eee18b_3 41 | - gmp=6.2.1=h295c915_3 42 | - gmpy2=2.1.2=py310heeb90bb_0 43 | - gnutls=3.6.15=he1e5248_0 44 | - google-auth=2.18.0=pyh1a96a4e_0 45 | - google-auth-oauthlib=1.0.0=pyhd8ed1ab_0 46 | - grpc-cpp=1.48.2=h5bf31a4_0 47 | - grpcio=1.48.2=py310h5bf31a4_0 48 | - idna=3.4=py310h06a4308_0 49 | - imageio=2.28.1=pyh24c5eb1_0 50 | - importlib-metadata=6.6.0=pyha770c72_0 51 | - intel-openmp=2023.1.0=hdb19cb5_46305 52 | - jinja2=3.1.2=py310h06a4308_0 53 | - jpeg=9e=h5eee18b_1 54 | - lame=3.100=h7b6447c_0 55 | - lcms2=2.12=h3be6417_0 56 | - ld_impl_linux-64=2.38=h1181459_1 57 | - lerc=3.0=h295c915_0 58 | - libcublas=11.11.3.6=0 59 | - libcufft=10.9.0.58=0 60 | - libcufile=1.6.1.9=0 61 | - libcurand=10.3.2.106=0 62 | - libcusolver=11.4.1.48=0 63 | - libcusparse=11.7.5.86=0 64 | - libdeflate=1.17=h5eee18b_0 65 | - libffi=3.4.4=h6a678d5_0 66 | - libgcc-ng=11.2.0=h1234567_1 67 | - libgomp=11.2.0=h1234567_1 68 | - libiconv=1.16=h7f8727e_2 69 | - libidn2=2.3.2=h7f8727e_0 70 | - libnpp=11.8.0.86=0 71 | - libnvjpeg=11.9.0.86=0 72 | - libpng=1.6.39=h5eee18b_0 73 | - libprotobuf=3.20.3=he621ea3_0 74 | - libstdcxx-ng=11.2.0=h1234567_1 75 | - libtasn1=4.19.0=h5eee18b_0 76 | - libtiff=4.5.0=h6a678d5_2 77 | - libunistring=0.9.10=h27cfd23_0 78 | - libuuid=1.41.5=h5eee18b_0 79 | - libwebp=1.2.4=h11a3e52_1 80 | - libwebp-base=1.2.4=h5eee18b_1 81 | - lz4-c=1.9.4=h6a678d5_0 82 | - markdown=3.4.3=pyhd8ed1ab_0 83 | - markupsafe=2.1.1=py310h7f8727e_0 84 | - mkl=2023.1.0=h6d00ec8_46342 85 | - mkl-service=2.4.0=py310h5eee18b_1 86 | - mkl_fft=1.3.6=py310h1128e8f_1 87 | - mkl_random=1.2.2=py310h1128e8f_1 88 | - mpc=1.1.0=h10f8cd9_1 89 | - mpfr=4.0.2=hb69a4c5_1 90 | - multidict=6.0.2=py310h5eee18b_0 91 | - ncurses=6.4=h6a678d5_0 92 | - nettle=3.7.3=hbbd107a_1 93 | - networkx=2.8.4=py310h06a4308_1 94 | - 
numpy=1.24.3=py310h5f9d8c6_1 95 | - numpy-base=1.24.3=py310hb5e798b_1 96 | - oauthlib=3.2.2=pyhd8ed1ab_0 97 | - openh264=2.1.1=h4ff587b_0 98 | - openssl=1.1.1t=h7f8727e_0 99 | - packaging=23.1=pyhd8ed1ab_0 100 | - pillow=9.4.0=py310h6a678d5_0 101 | - pip=23.0.1=py310h06a4308_0 102 | - protobuf=3.20.3=py310h6a678d5_0 103 | - pyasn1=0.4.8=py_0 104 | - pyasn1-modules=0.2.7=py_0 105 | - pycparser=2.21=pyhd3eb1b0_0 106 | - pyjwt=2.7.0=pyhd8ed1ab_0 107 | - pyopenssl=23.0.0=py310h06a4308_0 108 | - pysocks=1.7.1=py310h06a4308_0 109 | - python=3.10.11=h7a1cb2a_2 110 | - python_abi=3.10=2_cp310 111 | - pytorch=2.0.1=py3.10_cuda11.8_cudnn8.7.0_0 112 | - pytorch-cuda=11.8=h7e8668a_5 113 | - pytorch-mutex=1.0=cuda 114 | - pyu2f=0.1.5=pyhd8ed1ab_0 115 | - re2=2022.04.01=h27087fc_0 116 | - readline=8.2=h5eee18b_0 117 | - requests=2.29.0=py310h06a4308_0 118 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 119 | - rsa=4.9=pyhd8ed1ab_0 120 | - setuptools=66.0.0=py310h06a4308_0 121 | - six=1.16.0=pyh6c4a22f_0 122 | - sqlite=3.41.2=h5eee18b_0 123 | - sympy=1.11.1=py310h06a4308_0 124 | - tbb=2021.8.0=hdb19cb5_0 125 | - tensorboard=2.13.0=pyhd8ed1ab_0 126 | - tensorboard-data-server=0.7.0=py310h52d8a92_0 127 | - tk=8.6.12=h1ccaba5_0 128 | - torchaudio=2.0.2=py310_cu118 129 | - torchmetrics=0.11.4=pyhd8ed1ab_0 130 | - torchtriton=2.0.0=py310 131 | - torchvision=0.15.2=py310_cu118 132 | - tqdm=4.65.0=pyhd8ed1ab_1 133 | - typing_extensions=4.5.0=py310h06a4308_0 134 | - tzdata=2023c=h04d1e81_0 135 | - urllib3=1.26.15=py310h06a4308_0 136 | - werkzeug=2.3.4=pyhd8ed1ab_0 137 | - wheel=0.38.4=py310h06a4308_0 138 | - xz=5.4.2=h5eee18b_0 139 | - yarl=1.7.2=py310h5764c6d_2 140 | - zipp=3.15.0=pyhd8ed1ab_0 141 | - zlib=1.2.13=h5eee18b_0 142 | - zstd=1.5.5=hc292b87_0 143 | - pip: 144 | - mpmath==1.2.1 145 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2022, Benedikt Kolbeinsson 3 | 4 | """This script tests a model.""" 5 | 6 | 7 | ################################### Import ################################### 8 | import argparse 9 | import logging 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | from networks.network import Network, NetworkConfig 15 | from utils.cityscapes_loader import CityscapesLoader 16 | from utils.evaluation import Evaluator 17 | from utils.pascal_voc_loader import PascalVOCLoader 18 | from utils.trainer import TrainerConfig 19 | from utils.utils import set_seed 20 | from utils.uavid_loader import UAVidLoader 21 | from utils.vaihingen_buildings_loader import VaihingenBuildingsLoader 22 | 23 | 24 | #################################### Setup #################################### 25 | 26 | def make_parser(): 27 | """Creat an argument parser""" 28 | 29 | parser = argparse.ArgumentParser(description=__doc__) 30 | 31 | # ------------ Optional arguments ------------ # 32 | # Network 33 | parser.add_argument("--network", "-n", metavar='NET', type=str, action="store", default=TrainerConfig.network, 34 | help="Network architecture", dest="network") 35 | # Hyperparameters 36 | parser.add_argument("--batch_size", "-b", metavar='B', type=int, action="store", default=TrainerConfig.batch_size, 37 | help="Batch size", dest="batch_size") 38 | # Diffusion parameters 39 | parser.add_argument("--n_timesteps", metavar='T', type=int, action="store", default=NetworkConfig.n_timesteps, 40 | help="Number of timesteps", dest="n_timesteps") 
41 | parser.add_argument("--n_scales", metavar='L', type=int, action="store", default=NetworkConfig.n_scales, 42 | help="Number of scales", dest="n_scales") 43 | parser.add_argument("--max_patch_size", metavar='P', type=int, action="store", default=NetworkConfig.max_patch_size, 44 | help="Max patch size", dest="max_patch_size") 45 | parser.add_argument("--scale_procedure", metavar='SP', type=str, action="store", default=NetworkConfig.scale_procedure, 46 | help="Scale procedure", dest="scale_procedure") 47 | # Ensemble 48 | parser.add_argument("--ensemble", metavar='E', type=int, action="store", default=1, 49 | help="Number of models to ensemble", dest="ensemble") 50 | # Directories 51 | parser.add_argument("--checkpoint_dir", metavar='CD', type=str, action="store", default=TrainerConfig.checkpoint_dir, 52 | help="Checkpoint directory", dest="checkpoint_dir") 53 | parser.add_argument("--log_dir", metavar='LG', type=str, action="store", default=TrainerConfig.log_dir, 54 | help="Log directory", dest="log_dir") 55 | # Dataset 56 | parser.add_argument("--dataset", metavar='DS', type=str, action="store", default="uavid", 57 | help="Dataset to be used", dest="dataset_selection") 58 | # Checkpoint 59 | parser.add_argument("--load_checkpoint", metavar='FILE', type=str, action="store", default=TrainerConfig.load_checkpoint, 60 | help="Load checkpoint from a .pt file", dest="load_checkpoint") 61 | # Other 62 | parser.add_argument("--seed", "-s", metavar='S', type=int, action="store", default=TrainerConfig.seed, 63 | help="Set random seed for deterministic results", dest="seed") 64 | parser.add_argument("--n_workers", metavar='W', type=int, action="store", default=TrainerConfig.n_workers, 65 | help="Number of workers", dest="n_workers") 66 | parser.add_argument("-v", "--verbose", action="count", default=0, 67 | help="Verbosity (-v, -vv, etc)") 68 | 69 | return parser 70 | 71 | def box_text(text, title=None): 72 | """Add a title and a box around text""" 73 | lines = text.splitlines() 74 | width = max(len(line) for line in lines) + 4 75 | if title: 76 | title = ' ' + title + ' ' 77 | message = '┌{:─^{width}}┐\n'.format(title, width=width) 78 | else: 79 | message = '┌{:─^{width}}┐\n'.format('', width=width) 80 | 81 | for line in lines: 82 | message += '│{:^{width}}│\n'.format(line, width=width) 83 | message += '└{:─^{width}}┘'.format('', width=width) 84 | return message 85 | 86 | def print_all_arguments(): 87 | """Print all arguments""" 88 | message = '' 89 | for key, value in vars(ARGS).items(): 90 | message += '{: >21}: {: <21}\n'.format(str(key), str(value)) 91 | print(box_text(message, 'ARGUMENTS')) 92 | 93 | def setup_logging(): 94 | """Set logging level""" 95 | base_loglevel = logging.WARNING 96 | loglevel = max(base_loglevel - ARGS.verbose * 10, logging.DEBUG) 97 | logging.basicConfig(level=loglevel, 98 | format='%(message)s') 99 | 100 | 101 | 102 | #################################### Code #################################### 103 | 104 | 105 | 106 | 107 | 108 | 109 | #################################### Main #################################### 110 | 111 | def main(): 112 | """Main entry point of the module""" 113 | # logging setup 114 | setup_logging() 115 | 116 | # print arguments 117 | print_all_arguments() 118 | 119 | # make deterministic (optional) 120 | if ARGS.seed is not None: 121 | set_seed(ARGS.seed) 122 | 123 | # define dataset 124 | if ARGS.dataset_selection == "cityscapes": 125 | test_dataset = CityscapesLoader(root='../data/cityscapes/', split='test', is_transform=True) 126 | elif 
ARGS.dataset_selection == "pascal": 127 | test_dataset = PascalVOCLoader(root='../data/VOC2012/', split='test', is_transform=True) 128 | elif ARGS.dataset_selection == "vaihingen": 129 | # Dataset can be downloaded from https://drive.google.com/open?id=1nenpWH4BdplSiHdfXs0oYfiA5qL42plB 130 | test_dataset = VaihingenBuildingsLoader(root='../data/Vaihingen_buildings/', split='test', is_transform=True) 131 | elif ARGS.dataset_selection == "uavid": 132 | test_dataset = UAVidLoader(root='../data/UAVid/', split='val', is_transform=True) 133 | 134 | # define dataset loader 135 | test_dataloader = DataLoader(test_dataset, batch_size=ARGS.batch_size, shuffle=False, num_workers=ARGS.n_workers) 136 | 137 | # define the model 138 | network_config = NetworkConfig( 139 | n_timesteps=ARGS.n_timesteps, 140 | n_scales=ARGS.n_scales, 141 | max_patch_size=ARGS.max_patch_size, 142 | scale_procedure=ARGS.scale_procedure, 143 | n_classes=test_dataset.n_classes 144 | ) 145 | model = Network(network_config) 146 | 147 | # load checkpoint if specified 148 | checkpoint = None 149 | if ARGS.load_checkpoint is not None: 150 | checkpoint = torch.load(ARGS.load_checkpoint) 151 | model.load_state_dict(checkpoint['model_state_dict']) 152 | 153 | # use GPU if available 154 | device = 'cpu' 155 | if torch.cuda.is_available(): 156 | device = torch.cuda.current_device() 157 | model.to(device) 158 | logging.info("Using device: {}".format(device)) 159 | 160 | # evaluate 161 | evaluator = Evaluator(model, network_config, device, test_data_loader=test_dataloader) 162 | evaluator.test(ensemble=ARGS.ensemble) 163 | 164 | 165 | 166 | 167 | 168 | if __name__ == "__main__": 169 | PARSER = make_parser() 170 | ARGS = PARSER.parse_args() 171 | main() 172 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2022, Benedikt Kolbeinsson 3 | 4 | """This script trains a diffusion model.""" 5 | 6 | 7 | ################################### Import ################################### 8 | import argparse 9 | import logging 10 | import torch 11 | 12 | from torch.utils.data import DataLoader 13 | 14 | from networks.network import Network, NetworkConfig 15 | from utils.cityscapes_loader import CityscapesLoader 16 | from utils.pascal_voc_loader import PascalVOCLoader 17 | from utils.trainer import Trainer, TrainerConfig 18 | from utils.utils import set_seed 19 | from utils.uavid_loader import UAVidLoader 20 | from utils.vaihingen_buildings_loader import VaihingenBuildingsLoader 21 | 22 | 23 | #################################### Setup #################################### 24 | 25 | def make_parser(): 26 | """Creat an argument parser""" 27 | 28 | parser = argparse.ArgumentParser(description=__doc__) 29 | 30 | # ------------ Optional arguments ------------ # 31 | # Network 32 | parser.add_argument("--network", "-n", metavar='NET', type=str, action="store", default=TrainerConfig.network, 33 | help="Network architecture", dest="network") 34 | # Hyperparameters 35 | parser.add_argument("--epochs", "-e", metavar='E', type=int, action="store", default=TrainerConfig.max_epochs, 36 | help="Max number of epochs", dest="epochs") 37 | parser.add_argument("--batch_size", "-b", metavar='B', type=int, action="store", default=TrainerConfig.batch_size, 38 | help="Batch size", dest="batch_size") 39 | parser.add_argument("--learning_rate", "-l", metavar='LR', type=float, action="store", 
default=TrainerConfig.learning_rate, 40 | help="Learning rate", dest="learning_rate") 41 | parser.add_argument("--momentum", "-m", metavar='M', type=float, action="store", default=TrainerConfig.momentum, 42 | help="Momentum", dest="momentum") 43 | parser.add_argument("--weight_decay", "-w", metavar='WD', type=float, action="store", default=TrainerConfig.weight_decay, 44 | help="Weight decay", dest="weight_decay") 45 | parser.add_argument("--lr_decay", "-d", metavar='D', type=bool, action="store", default=TrainerConfig.lr_decay, 46 | help="Use learning rate decay", dest="lr_decay") 47 | parser.add_argument("--lr_decay_gamma", "-g", metavar='G', type=float, action="store", default=TrainerConfig.lr_decay_gamma, 48 | help="Learning rate decay gamma", dest="lr_decay_gamma") 49 | # Diffusion parameters 50 | parser.add_argument("--n_timesteps", metavar='T', type=int, action="store", default=NetworkConfig.n_timesteps, 51 | help="Number of timesteps", dest="n_timesteps") 52 | parser.add_argument("--n_scales", metavar='L', type=int, action="store", default=NetworkConfig.n_scales, 53 | help="Number of scales", dest="n_scales") 54 | parser.add_argument("--max_patch_size", metavar='P', type=int, action="store", default=NetworkConfig.max_patch_size, 55 | help="Max patch size", dest="max_patch_size") 56 | parser.add_argument("--scale_procedure", metavar='SP', type=str, action="store", default=NetworkConfig.scale_procedure, 57 | help="Scale procedure (loop or linear)", dest="scale_procedure") 58 | # Diffusion other options 59 | parser.add_argument("--train_on_n_scales", metavar='NS', type=int, action="store", default=NetworkConfig.n_scales + 1, 60 | help="Only train first NS scales", dest="train_on_n_scales") 61 | parser.add_argument("--not_recursive", action="store_true", default=False, 62 | help="Do not use recursive diffusion", dest="not_recursive") 63 | # Directories 64 | parser.add_argument("--checkpoint_dir", metavar='CD', type=str, action="store", default=TrainerConfig.checkpoint_dir, 65 | help="Checkpoint directory", dest="checkpoint_dir") 66 | parser.add_argument("--log_dir", metavar='LG', type=str, action="store", default=TrainerConfig.log_dir, 67 | help="Log directory", dest="log_dir") 68 | # Dataset 69 | parser.add_argument("--dataset", metavar='DS', type=str, action="store", default=TrainerConfig.dataset_selection, 70 | help="Dataset to be used", dest="dataset_selection") 71 | # Checkpoint 72 | parser.add_argument("--load_checkpoint", metavar='FILE', type=str, action="store", default=TrainerConfig.load_checkpoint, 73 | help="Load checkpoint from a .pt file", dest="load_checkpoint") 74 | parser.add_argument("--weights_only", action="store_true", default=False, 75 | help="Load weights only", dest="weights_only") 76 | # Other 77 | parser.add_argument("--seed", "-s", metavar='S', type=int, action="store", default=TrainerConfig.seed, 78 | help="Set random seed for deterministic results", dest="seed") 79 | parser.add_argument("--n_workers", metavar='W', type=int, action="store", default=TrainerConfig.n_workers, 80 | help="Number of workers", dest="n_workers") 81 | parser.add_argument("-v", "--verbose", action="count", default=0, 82 | help="Verbosity (-v, -vv, etc)") 83 | 84 | return parser 85 | 86 | def box_text(text, title=None): 87 | """Add a title and a box around text""" 88 | lines = text.splitlines() 89 | width = max(len(line) for line in lines) + 4 90 | if title: 91 | title = ' ' + title + ' ' 92 | message = '┌{:─^{width}}┐\n'.format(title, width=width) 93 | else: 94 | message = 
'┌{:─^{width}}┐\n'.format('', width=width) 95 | 96 | for line in lines: 97 | message += '│{:^{width}}│\n'.format(line, width=width) 98 | message += '└{:─^{width}}┘'.format('', width=width) 99 | return message 100 | 101 | def print_all_arguments(): 102 | """Print all arguments""" 103 | message = '' 104 | for key, value in vars(ARGS).items(): 105 | message += '{: >21}: {: <21}\n'.format(str(key), str(value)) 106 | print(box_text(message, 'ARGUMENTS')) 107 | 108 | def setup_logging(): 109 | """Set logging level""" 110 | base_loglevel = logging.WARNING 111 | loglevel = max(base_loglevel - ARGS.verbose * 10, logging.DEBUG) 112 | logging.basicConfig(level=loglevel, 113 | format='%(message)s') 114 | 115 | 116 | 117 | #################################### Code #################################### 118 | 119 | 120 | 121 | 122 | 123 | 124 | #################################### Main #################################### 125 | 126 | def main(): 127 | """Main entry point of the module""" 128 | # logging setup 129 | setup_logging() 130 | 131 | # print arguments 132 | print_all_arguments() 133 | 134 | # make deterministic (optional) 135 | if ARGS.seed is not None: 136 | set_seed(ARGS.seed) 137 | 138 | # define dataset 139 | if ARGS.dataset_selection == "cityscapes": 140 | train_dataset = CityscapesLoader(root='../data/cityscapes/', split='train', is_transform=True) 141 | val_dataset = CityscapesLoader(root='../data/cityscapes/', split='val', is_transform=True) 142 | elif ARGS.dataset_selection == "pascal": 143 | train_dataset = PascalVOCLoader(root='../data/VOC2012/', split='train', is_transform=True, img_size=512) 144 | val_dataset = PascalVOCLoader(root='../data/VOC2012/', split='val', is_transform=True, img_size=512) 145 | elif ARGS.dataset_selection == "vaihingen": 146 | train_dataset = VaihingenBuildingsLoader(root='../data/Vaihingen_buildings/', split='train', is_transform=True) 147 | val_dataset = VaihingenBuildingsLoader(root='../data/Vaihingen_buildings/', split='val', is_transform=True) 148 | elif ARGS.dataset_selection == "uavid": 149 | train_dataset = UAVidLoader(root='../data/UAVid/', split='train', is_transform=True) 150 | val_dataset = UAVidLoader(root='../data/UAVid/', split='val', is_transform=True) 151 | 152 | assert ARGS.dataset_selection in ["cityscapes", "pascal", "vaihingen", "uavid"], "Supported datasets are: cityscapes, pascal, vaihingen, uavid" 153 | 154 | # define dataset loader 155 | train_dataloader = DataLoader(train_dataset, batch_size=ARGS.batch_size, shuffle=True, num_workers=ARGS.n_workers) 156 | val_dataloader = DataLoader(val_dataset, batch_size=ARGS.batch_size, shuffle=False, num_workers=ARGS.n_workers) 157 | 158 | # define the model 159 | network_config = NetworkConfig( 160 | n_timesteps=ARGS.n_timesteps, 161 | n_scales=ARGS.n_scales, 162 | max_patch_size=ARGS.max_patch_size, 163 | scale_procedure=ARGS.scale_procedure, 164 | n_classes=train_dataset.n_classes 165 | ) 166 | model = Network(network_config) 167 | 168 | # load checkpoint if specified 169 | checkpoint = None 170 | if ARGS.load_checkpoint is not None: 171 | checkpoint = torch.load(ARGS.load_checkpoint) 172 | model.load_state_dict(checkpoint['model_state_dict']) 173 | 174 | # use GPU if available 175 | device = 'cpu' 176 | if torch.cuda.is_available(): 177 | device = torch.cuda.current_device() 178 | model.cuda() 179 | logging.info("Using device: {}".format(device)) 180 | 181 | # define trainer 182 | trainer_config = TrainerConfig( 183 | max_epochs=ARGS.epochs, batch_size=ARGS.batch_size, 184 | 
learning_rate=ARGS.learning_rate, momentum=ARGS.momentum, 185 | weight_decay=ARGS.weight_decay, lr_decay=ARGS.lr_decay, 186 | lr_decay_gamma=ARGS.lr_decay_gamma, checkpoint_dir=ARGS.checkpoint_dir, 187 | log_dir=ARGS.log_dir, load_checkpoint=ARGS.load_checkpoint, 188 | n_workers=ARGS.n_workers, network=ARGS.network, 189 | train_on_n_scales=ARGS.train_on_n_scales, not_recursive=ARGS.not_recursive, 190 | dataset_selection=ARGS.dataset_selection, 191 | device=device, checkpoint=checkpoint, weights_only=ARGS.weights_only 192 | ) 193 | trainer = Trainer(model, network_config, trainer_config, train_dataloader, val_dataloader) 194 | 195 | # train model 196 | trainer.train() 197 | 198 | 199 | 200 | 201 | if __name__ == "__main__": 202 | PARSER = make_parser() 203 | ARGS = PARSER.parse_args() 204 | main() 205 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/benediktkol/recursive-noise-diffusion/183696b2d26f245a35a019980fca9adc20f0f4ef/utils/__init__.py -------------------------------------------------------------------------------- /utils/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import os 4 | import torch 5 | from torchvision.io import read_image 6 | import torchvision.transforms as transforms 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch.utils import data 10 | 11 | def decode_segmap(seg, is_one_hot=False): 12 | colors = torch.tensor([ 13 | [128, 64, 128], 14 | [244, 35, 232], 15 | [70, 70, 70], 16 | [102, 102, 156], 17 | [190, 153, 153], 18 | [153, 153, 153], 19 | [250, 170, 30], 20 | [220, 220, 0], 21 | [107, 142, 35], 22 | [152, 251, 152], 23 | [0, 130, 180], 24 | [220, 20, 60], 25 | [255, 0, 0], 26 | [0, 0, 142], 27 | [0, 0, 70], 28 | [0, 60, 100], 29 | [0, 80, 100], 30 | [0, 0, 230], 31 | [119, 11, 32], 32 | [0, 0, 0] 33 | ], dtype=torch.uint8) 34 | if is_one_hot: 35 | seg = torch.argmax(seg, dim=0) 36 | # convert classes to colors 37 | seg_img = torch.empty((seg.shape[0], seg.shape[1], 3), dtype=torch.uint8) 38 | for c in range(20): 39 | seg_img[seg == c, :] = colors[c] 40 | return seg_img.permute(2, 0, 1) 41 | 42 | 43 | class CityscapesLoader(data.Dataset): 44 | """CityscapesLoader 45 | 46 | https://www.cityscapes-dataset.com 47 | 48 | Data is derived from CityScapes, and can be downloaded from here: 49 | https://www.cityscapes-dataset.com/downloads/ 50 | 51 | Many Thanks to @fvisin for the loader repo: 52 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 53 | """ 54 | def recursive_glob(self, rootdir=".", suffix=""): 55 | """Performs recursive glob with given suffix and rootdir 56 | :param rootdir is the root directory 57 | :param suffix is the suffix to be searched 58 | """ 59 | return [ 60 | os.path.join(looproot, filename) 61 | for looproot, _, filenames in os.walk(rootdir) 62 | for filename in filenames 63 | if filename.endswith(suffix) 64 | ] 65 | 66 | 67 | colors = [ 68 | [128, 64, 128], 69 | [244, 35, 232], 70 | [70, 70, 70], 71 | [102, 102, 156], 72 | [190, 153, 153], 73 | [153, 153, 153], 74 | [250, 170, 30], 75 | [220, 220, 0], 76 | [107, 142, 35], 77 | [152, 251, 152], 78 | [0, 130, 180], 79 | [220, 20, 60], 80 | [255, 0, 0], 81 | [0, 0, 142], 82 | [0, 0, 70], 83 | [0, 60, 100], 84 | [0, 80, 100], 85 | [0, 0, 230], 86 | [119, 11, 32], 87 | 
[0, 0, 0] 88 | ] 89 | 90 | label_colours = dict(zip(range(19), colors)) 91 | 92 | 93 | def __init__( 94 | self, 95 | root, 96 | split="train", 97 | is_transform=False, 98 | img_size=(1024, 2048), 99 | augmentations=None, 100 | img_norm=True, 101 | test_mode=False, 102 | ): 103 | """__init__ 104 | 105 | :param root: 106 | :param split: 107 | :param is_transform: 108 | :param img_size: 109 | :param augmentations 110 | """ 111 | self.root = root 112 | self.split = split 113 | self.test_mode = test_mode 114 | self.is_transform = is_transform 115 | self.augmentations = augmentations 116 | self.img_norm = img_norm 117 | self.n_classes = 19 118 | self.ignore_index = 255 119 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 120 | self.mean = torch.tensor([0.485, 0.456, 0.406]) 121 | self.std = torch.tensor([0.229, 0.224, 0.225]) 122 | self.files = {} 123 | 124 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 125 | self.annotations_base = os.path.join(self.root, "gtFine", self.split) 126 | self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix=".png") 127 | 128 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 129 | self.valid_classes = [ 130 | 7, 131 | 8, 132 | 11, 133 | 12, 134 | 13, 135 | 17, 136 | 19, 137 | 20, 138 | 21, 139 | 22, 140 | 23, 141 | 24, 142 | 25, 143 | 26, 144 | 27, 145 | 28, 146 | 31, 147 | 32, 148 | 33, 149 | self.ignore_index, 150 | ] 151 | self.class_names = [ 152 | "road", 153 | "sidewalk", 154 | "building", 155 | "wall", 156 | "fence", 157 | "pole", 158 | "traffic_light", 159 | "traffic_sign", 160 | "vegetation", 161 | "terrain", 162 | "sky", 163 | "person", 164 | "rider", 165 | "car", 166 | "truck", 167 | "bus", 168 | "train", 169 | "motorcycle", 170 | "bicycle", 171 | "unlabelled", 172 | ] 173 | 174 | self.class_map = dict(zip(self.valid_classes, range(len(self.valid_classes)))) 175 | 176 | if not self.files[split]: 177 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 178 | 179 | logging.info('Found {} {} images'.format(len(self.files[split]), split)) 180 | 181 | def __len__(self): 182 | """__len__""" 183 | return len(self.files[self.split]) 184 | 185 | def __getitem__(self, index): 186 | """__getitem__ 187 | 188 | :param index: 189 | """ 190 | img_path = self.files[self.split][index].rstrip() 191 | lbl_path = os.path.join( 192 | self.annotations_base, 193 | img_path.split(os.sep)[-2], 194 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 195 | ) 196 | # Read image and label 197 | img = read_image(img_path) 198 | lbl = read_image(lbl_path).squeeze(0).long() 199 | 200 | lbl = self.encode_segmap(lbl) 201 | 202 | if self.augmentations is not None: 203 | img, lbl = self.augmentations(img, lbl) 204 | 205 | if self.is_transform: 206 | img, lbl = self.transform(img, lbl) 207 | 208 | return img, lbl 209 | 210 | def transform(self, img, lbl): 211 | """transform 212 | 213 | :param img: 214 | :param lbl: 215 | """ 216 | 217 | # # Random crop 218 | # if self.split == "train": 219 | # i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(self.img_size[0], self.img_size[1])) 220 | # img = TF.crop(img, i, j, h, w) 221 | # lbl = TF.crop(lbl, i, j, h, w) 222 | 223 | # Random horizontal flipping 224 | if self.split == "train": 225 | if torch.rand(1).item() < 0.5: 226 | img = TF.hflip(img) 227 | lbl = TF.hflip(lbl) 228 | # Random color jitter 229 | if np.random.random() < 0.25: 230 | img = TF.adjust_brightness(img, 0.5 + 
np.random.random()) 231 | img = TF.adjust_contrast(img, 0.5 + np.random.random()) 232 | img = TF.adjust_saturation(img, 0.5 + np.random.random()) 233 | img = TF.adjust_hue(img, 0.10 * (np.random.random() - 0.5)) 234 | # Random resized crop 235 | if np.random.random() < 0.25: 236 | i, j, h, w = transforms.RandomResizedCrop.get_params(img, scale=(0.5, 1), ratio=(2, 2), antialias=True) 237 | img = TF.resized_crop(img, i, j, h, w, self.img_size, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True) 238 | lbl = TF.resized_crop(lbl.unsqueeze(0), i, j, h, w, self.img_size, interpolation=transforms.InterpolationMode.NEAREST, antialias=True).squeeze() 239 | 240 | 241 | 242 | # Normalize 243 | if self.img_norm: 244 | img = TF.normalize(img.float(), self.mean, self.std) 245 | 246 | 247 | return img, lbl 248 | 249 | 250 | def encode_segmap(self, mask): 251 | # Put all void classes to zero 252 | for _voidc in self.void_classes: 253 | mask[mask == _voidc] = self.ignore_index 254 | for _validc in self.valid_classes: 255 | mask[mask == _validc] = self.class_map[_validc] 256 | return mask 257 | 258 | 259 | # if __name__ == "__main__": 260 | # import matplotlib.pyplot as plt 261 | 262 | # augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5)]) 263 | 264 | # local_path = "/datasets01/cityscapes/112817/" 265 | # dst = cityscapesLoader(local_path, is_transform=True, augmentations=augmentations) 266 | # bs = 4 267 | # trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0) 268 | # for i, data_samples in enumerate(trainloader): 269 | # imgs, labels = data_samples 270 | # import pdb 271 | 272 | # pdb.set_trace() 273 | # imgs = imgs.numpy()[:, ::-1, :, :] 274 | # imgs = np.transpose(imgs, [0, 2, 3, 1]) 275 | # f, axarr = plt.subplots(bs, 2) 276 | # for j in range(bs): 277 | # axarr[j][0].imshow(imgs[j]) 278 | # axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j])) 279 | # plt.show() 280 | # a = input() 281 | # if a == "ex": 282 | # break 283 | # else: 284 | # plt.close() -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | """Evaluates the performance of a model""" 2 | import logging 3 | import math 4 | import torch 5 | import torch.nn.functional as F 6 | import torchvision 7 | 8 | from torch.utils.tensorboard import SummaryWriter 9 | from torchmetrics import JaccardIndex, F1Score 10 | from tqdm import tqdm 11 | 12 | from utils.cityscapes_loader import decode_segmap as decode_segmap_cityscapes 13 | from utils.utils import diffuse, denoise_scale 14 | from utils.uavid_loader import decode_segmap as decode_segmap_uavid 15 | from utils.vaihingen_buildings_loader import decode_segmap as decode_segmap_vaihingen 16 | 17 | def segmentation_cross_entropy(predicted_segmentation, target_segmentation): 18 | """Returns Cross Entropy Loss""" 19 | weights = torch.tensor([1.79, 1.0, 2.17, 1.17, 3.2, 27.14, 21.56, 190.25], dtype=torch.float32).to(target_segmentation.device) 20 | criterion = torch.nn.CrossEntropyLoss(weight=weights, reduction='sum') 21 | loss = criterion(predicted_segmentation, target_segmentation) 22 | return loss 23 | 24 | def noise_mse(noise_predicted, noise_target): 25 | """Returns MSE Loss""" 26 | criterion = torch.nn.MSELoss(reduction='mean') 27 | loss = criterion(noise_predicted, noise_target) 28 | return loss 29 | 30 | def compute_total_loss(segmentation_cross_entropy): 31 | """Returns total loss""" 32 | total_loss = 
(1 * segmentation_cross_entropy) 33 | return total_loss 34 | 35 | def write_images_to_tensorboard(writer, epoch, image=None, seg_diffused=None, seg_predicted=None, seg_gt=None, datasplit='validation', dataset_name='cityscapes'): 36 | """Writes images to TensorBoard""" 37 | # decode segmap based on dataset 38 | if dataset_name == 'cityscapes': 39 | decode_segmap = decode_segmap_cityscapes 40 | elif dataset_name == 'uavid': 41 | decode_segmap = decode_segmap_uavid 42 | elif dataset_name == 'vaihingen': 43 | decode_segmap = decode_segmap_vaihingen 44 | else: 45 | raise NotImplementedError('Dataset {} not implemented'.format(dataset_name)) 46 | if image is not None: 47 | image = torchvision.utils.make_grid(image, normalize=True) # normalize to [0,1] and convert to uint8 48 | writer.add_images('{}/image'.format(datasplit), image, epoch, dataformats='CHW') 49 | if seg_diffused is not None: 50 | seg_diffused = decode_segmap(seg_diffused, is_one_hot=True) 51 | writer.add_images('{}/seg_diffused'.format(datasplit), seg_diffused, epoch, dataformats='CHW') 52 | if seg_predicted is not None: 53 | seg_predicted = decode_segmap(seg_predicted, is_one_hot=True) 54 | writer.add_images('{}/seg_predicted'.format(datasplit), seg_predicted, epoch, dataformats='CHW') 55 | if seg_gt is not None: 56 | seg_gt = decode_segmap(seg_gt, is_one_hot=False) 57 | writer.add_images('{}/seg_gt'.format(datasplit), seg_gt, epoch, dataformats='CHW') 58 | 59 | def denoise_loop_scales(model, device, network_config, images): 60 | """Denoises all scales for a single timestep""" 61 | # Calculate scale sizes (smallest first) 62 | scale_sizes = [(images.shape[2] // (2**(network_config.n_scales - i -1)), images.shape[3] // (2**(network_config.n_scales - i -1))) for i in range(network_config.n_scales)] 63 | 64 | # Initialize first prediction (random noise) 65 | seg_previous_scaled = torch.rand(images.shape[0], network_config.n_classes, images.shape[2], images.shape[3]) 66 | 67 | # Initialize built in ensemble 68 | seg_denoised_ensemble = torch.zeros(images.shape[0], network_config.n_classes, images.shape[2], images.shape[3]) 69 | 70 | # Denoise whole segmentation map in steps 71 | for timestep in range(network_config.n_timesteps): # for each step 72 | 73 | for scale in range(network_config.n_scales): # for each scale 74 | # Resize to current scale 75 | images_scaled = F.interpolate(images, size=scale_sizes[scale], mode='bilinear', align_corners=False) 76 | seg_previous_scaled = F.interpolate(seg_previous_scaled.float(), size=scale_sizes[scale], mode='bilinear', align_corners=False).softmax(dim=1) 77 | 78 | # Diffuse 79 | t = torch.tensor([(network_config.n_timesteps - (timestep + scale/network_config.n_scales)) / network_config.n_timesteps]) # time step 80 | seg_diffused = diffuse(seg_previous_scaled, t) 81 | # Denoise 82 | seg_denoised = denoise_scale(model, device, seg_diffused, images_scaled, t, patch_size=network_config.max_patch_size) 83 | 84 | # Update the previous segmentation map 85 | seg_previous_scaled = seg_denoised 86 | 87 | # Add to ensemble 88 | if network_config.built_in_ensemble: 89 | if timestep == 0: 90 | seg_denoised_ensemble = seg_denoised 91 | else: 92 | seg_denoised_ensemble = seg_denoised_ensemble / 2 + seg_denoised / 2 93 | 94 | seg_previous_scaled = seg_denoised_ensemble 95 | 96 | return seg_denoised 97 | 98 | def denoise_linear_scales(model, device, network_config, images): 99 | """Denoises one scale at a each timestep""" 100 | # Calculate scale sizes (smallest first) 101 | scale_sizes = [(images.shape[2] // 
(2**(network_config.n_scales - i -1)), images.shape[3] // (2**(network_config.n_scales - i -1))) for i in range(network_config.n_scales)] 102 | 103 | # Initialize first prediction (random noise) 104 | seg_previous_scaled = torch.rand(images.shape[0], network_config.n_classes, images.shape[2], images.shape[3]) 105 | 106 | # Denoise whole segmentation map in steps 107 | for timestep in range(network_config.n_timesteps): # for each step 108 | # Get the current scale 109 | timesteps_per_scale = math.ceil(network_config.n_timesteps / network_config.n_scales) 110 | scale = timestep // timesteps_per_scale 111 | 112 | # Resize to current scale 113 | if timestep % timesteps_per_scale == 0: 114 | images_scaled = F.interpolate(images, size=scale_sizes[scale], mode='bilinear', align_corners=False) 115 | seg_previous_scaled = F.interpolate(seg_previous_scaled.float(), size=scale_sizes[scale], mode='bilinear', align_corners=False) 116 | 117 | # Diffuse 118 | t = torch.tensor([(network_config.n_timesteps - (timestep + scale/network_config.n_scales)) / network_config.n_timesteps]) # time step 119 | seg_diffused = diffuse(seg_previous_scaled, t) 120 | # Denoise 121 | seg_denoised = denoise_scale(model, device, seg_diffused, images_scaled, t, patch_size=network_config.max_patch_size) 122 | 123 | # Update the previous segmentation map 124 | seg_previous_scaled = seg_denoised 125 | 126 | return seg_denoised 127 | 128 | def denoise(model, device, network_config, images): 129 | """Denoises the segmentation map""" 130 | if network_config.scale_procedure == 'loop': 131 | seg_denoised = denoise_loop_scales(model, device, network_config, images) 132 | elif network_config.scale_procedure == 'linear': 133 | seg_denoised = denoise_linear_scales(model, device, network_config, images) 134 | 135 | return seg_denoised 136 | 137 | class Evaluator: 138 | """Evaluates the performance of a model""" 139 | def __init__(self, model, network_config, device, dataset_selection=None, test_data_loader=None, validation_data_loader=None, writer=None): 140 | self.model = model 141 | self.network_config = network_config 142 | self.device = device 143 | self.dataset_selection = dataset_selection 144 | self.test_data_loader = test_data_loader 145 | self.validation_data_loader = validation_data_loader 146 | self.writer = writer 147 | 148 | def evaluate(self, data_loader, epoch=1, is_test=True, ensemble=1): # epoch=None 149 | """Evaluates the model on the given dataset""" 150 | model = self.model 151 | network_config = self.network_config 152 | model.eval() 153 | 154 | if self.dataset_selection == 'cityscapes': 155 | ignore_index = 19 156 | n_ignore = 1 157 | else: 158 | ignore_index = None 159 | n_ignore = 0 160 | 161 | jaccard_index = JaccardIndex(task="multiclass", num_classes=data_loader.dataset.n_classes + n_ignore, ignore_index=ignore_index) 162 | jaccard_per_class = JaccardIndex(task="multiclass", num_classes=data_loader.dataset.n_classes + n_ignore, ignore_index=ignore_index, average='none') 163 | f1_score = F1Score(num_classes=data_loader.dataset.n_classes + n_ignore, mdmc_average='samplewise') 164 | 165 | with torch.no_grad(): 166 | pbar_eval = tqdm(enumerate(data_loader), total=len(data_loader), desc='{}'.format('Test' if is_test else 'Validation'), leave=is_test, bar_format='{l_bar}{bar:50}{r_bar}') 167 | for it, samples in pbar_eval: 168 | # Unpack the samples 169 | images, seg_gt = samples 170 | 171 | seg_denoised = denoise(model, self.device, network_config, images) 172 | 173 | # Ensamble 174 | for i in range(ensemble-1): 175 
| seg_denoised += denoise(model, self.device, network_config, images) 176 | seg_denoised /= ensemble 177 | 178 | # Compute loss 179 | seg_predicted = seg_denoised.view(seg_denoised.shape[0], seg_denoised.shape[1], -1).argmax(dim=1) 180 | seg_target = seg_gt.view(seg_gt.shape[0], -1) 181 | jaccard_index.update(seg_predicted, seg_target) 182 | jaccard_per_class.update(seg_predicted, seg_target) 183 | f1_score.update(seg_predicted, seg_target) 184 | 185 | # Write images to tensorboard 186 | if self.writer is not None: 187 | if it < 8: 188 | write_images_to_tensorboard(self.writer, epoch, image=images[0], seg_predicted=seg_denoised[0], seg_gt=seg_gt[0], datasplit='validation/{}'.format(it)) 189 | 190 | 191 | # Overall metrics 192 | jaccard_index_total = jaccard_index.compute() 193 | jaccard_per_class_total = jaccard_per_class.compute() 194 | f1_score_total = f1_score.compute() 195 | 196 | # Text report 197 | report = 'Jaccard index: {:.4f} | F1 score: {:.4f}'.format(jaccard_index_total, f1_score_total) 198 | report_per_class = 'Jaccard index per class: {}'.format(jaccard_per_class_total) 199 | if self.writer is None: 200 | logging.log(logging.WARNING, report) 201 | logging.log(logging.WARNING, report_per_class) 202 | else: 203 | logging.info('{} | {} | {}'.format("Test" if is_test else "Validation | Epoch: {}".format(epoch), report, report_per_class)) 204 | 205 | # Write to tensorboard 206 | if self.writer is not None: 207 | self.writer.add_scalar('{}/JaccardIndex'.format('test' if is_test else 'validation'), jaccard_index_total, epoch) 208 | self.writer.add_scalar('{}/F1Score'.format('test' if is_test else 'validation'), f1_score_total, epoch) 209 | 210 | 211 | def validate(self, epoch): 212 | """Evaluates the model on the validation dataset""" 213 | self.evaluate(self.validation_data_loader, epoch, is_test=False) 214 | 215 | def test(self, ensemble=1): 216 | """Evaluates the model on the test dataset""" 217 | self.evaluate(self.test_data_loader, is_test=True, ensemble=ensemble) 218 | 219 | 220 | -------------------------------------------------------------------------------- /utils/pascal_voc_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as pjoin 3 | import collections 4 | import imageio 5 | import torch 6 | import numpy as np 7 | import glob 8 | 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from torch.utils import data 12 | from torchvision import transforms 13 | # from torchvision.transforms import InterpolationMode 14 | 15 | def decode_segmap(seg, is_one_hot=False, plot=False): 16 | """Decode segmentation class labels into a color image 17 | 18 | Args: 19 | seg (np.ndarray): an (M,N) array of integer values denoting 20 | the class label at each spatial location. 21 | plot (bool, optional): whether to show the resulting color image 22 | in a figure. 23 | 24 | Returns: 25 | (np.ndarray, optional): the resulting decoded color image. 
26 | """ 27 | colors = torch.tensor([ 28 | [0, 0, 0], 29 | [128, 0, 0], 30 | [0, 128, 0], 31 | [128, 128, 0], 32 | [0, 0, 128], 33 | [128, 0, 128], 34 | [0, 128, 128], 35 | [128, 128, 128], 36 | [64, 0, 0], 37 | [192, 0, 0], 38 | [64, 128, 0], 39 | [192, 128, 0], 40 | [64, 0, 128], 41 | [192, 0, 128], 42 | [64, 128, 128], 43 | [192, 128, 128], 44 | [0, 64, 0], 45 | [128, 64, 0], 46 | [0, 192, 0], 47 | [128, 192, 0], 48 | [0, 64, 128], 49 | [224, 224, 192], 50 | ], dtype=torch.uint8) 51 | if is_one_hot: 52 | seg = torch.argmax(seg, dim=0) 53 | # convert classes to colors 54 | seg = seg.type(dtype=torch.uint8) 55 | seg_img = torch.zeros((seg.shape[0], seg.shape[1], 3), dtype=torch.uint8) 56 | for c in range(22): 57 | seg_img[seg == c, :] = colors[c] 58 | return seg_img.permute(2, 0, 1) 59 | 60 | class PascalVOCLoader(data.Dataset): 61 | """Data loader for the Pascal VOC semantic segmentation dataset. 62 | 63 | Annotations from both the original VOC data (which consist of RGB images 64 | in which colours map to specific classes) and the SBD (Berkely) dataset 65 | (where annotations are stored as .mat files) are converted into a common 66 | `label_mask` format. Under this format, each mask is an (M,N) array of 67 | integer values from 0 to 21, where 0 represents the background class. 68 | 69 | The label masks are stored in a new folder, called `pre_encoded`, which 70 | is added as a subdirectory of the `SegmentationClass` folder in the 71 | original Pascal VOC data layout. 72 | 73 | A total of five data splits are provided for working with the VOC data: 74 | train: The original VOC 2012 training data - 1464 images 75 | val: The original VOC 2012 validation data - 1449 images 76 | trainval: The combination of `train` and `val` - 2913 images 77 | train_aug: The unique images present in both the train split and 78 | training images from SBD: - 8829 images (the unique members 79 | of the result of combining lists of length 1464 and 8498) 80 | train_aug_val: The original VOC 2012 validation data minus the images 81 | present in `train_aug` (This is done with the same logic as 82 | the validation set used in FCN PAMI paper, but with VOC 2012 83 | rather than VOC 2011) - 904 images 84 | """ 85 | 86 | def __init__( 87 | self, 88 | root, 89 | sbd_path=None, 90 | split="train_aug", 91 | is_transform=False, 92 | img_size=512, 93 | augmentations=None, 94 | img_norm=True, 95 | test_mode=False, 96 | ): 97 | self.root = root 98 | self.sbd_path = sbd_path 99 | self.split = split 100 | self.is_transform = is_transform 101 | self.augmentations = augmentations 102 | self.img_norm = img_norm 103 | self.test_mode = test_mode 104 | self.n_classes = 21 105 | self.mean = np.array([104.00699, 116.66877, 122.67892]) 106 | self.files = collections.defaultdict(list) 107 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 108 | 109 | if not self.test_mode: 110 | for split in ["train", "val", "trainval"]: 111 | path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt") 112 | file_list = tuple(open(path, "r")) 113 | file_list = [id_.rstrip() for id_ in file_list] 114 | self.files[split] = file_list 115 | self.setup_annotations() 116 | 117 | self.tf = transforms.Compose( 118 | [ 119 | transforms.ToTensor(), 120 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 121 | ] 122 | ) 123 | 124 | def __len__(self): 125 | return len(self.files[self.split]) 126 | 127 | def __getitem__(self, index): 128 | im_name = self.files[self.split][index] 129 | im_path = pjoin(self.root, 
"JPEGImages", im_name + ".jpg") 130 | lbl_path = pjoin(self.root, "SegmentationClass/pre_encoded", im_name + ".png") 131 | im = Image.open(im_path) 132 | lbl = Image.open(lbl_path) 133 | if self.augmentations is not None: 134 | im, lbl = self.augmentations(im, lbl) 135 | if self.is_transform: 136 | im, lbl = self.transform(im, lbl) 137 | return im, lbl 138 | 139 | def transform(self, img, lbl): 140 | if self.img_size == ("same", "same"): 141 | pass 142 | else: 143 | img = img.resize((self.img_size[0], self.img_size[1])) # uint8 with RGB mode 144 | lbl = lbl.resize((self.img_size[0], self.img_size[1]), Image.NEAREST) 145 | img = self.tf(img) 146 | lbl = torch.from_numpy(np.array(lbl)).long() 147 | lbl[lbl == 255] = 0 148 | return img, lbl 149 | 150 | def get_pascal_labels(self): 151 | """Load the mapping that associates pascal classes with label colors 152 | 153 | Returns: 154 | np.ndarray with dimensions (21, 3) 155 | """ 156 | return np.asarray( 157 | [ 158 | [0, 0, 0], 159 | [128, 0, 0], 160 | [0, 128, 0], 161 | [128, 128, 0], 162 | [0, 0, 128], 163 | [128, 0, 128], 164 | [0, 128, 128], 165 | [128, 128, 128], 166 | [64, 0, 0], 167 | [192, 0, 0], 168 | [64, 128, 0], 169 | [192, 128, 0], 170 | [64, 0, 128], 171 | [192, 0, 128], 172 | [64, 128, 128], 173 | [192, 128, 128], 174 | [0, 64, 0], 175 | [128, 64, 0], 176 | [0, 192, 0], 177 | [128, 192, 0], 178 | [0, 64, 128], 179 | [224, 224, 192], 180 | ] 181 | ) 182 | 183 | def encode_segmap(self, mask): 184 | """Encode segmentation label images as pascal classes 185 | 186 | Args: 187 | mask (np.ndarray): raw segmentation label image of dimension 188 | (M, N, 3), in which the Pascal classes are encoded as colours. 189 | 190 | Returns: 191 | (np.ndarray): class map with dimensions (M,N), where the value at 192 | a given location is the integer denoting the class index. 193 | """ 194 | mask = mask.astype(int) 195 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 196 | for ii, label in enumerate(self.get_pascal_labels()): 197 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 198 | label_mask = label_mask.astype(int) 199 | return label_mask 200 | 201 | def decode_segmap(self, label_mask, plot=False): 202 | """Decode segmentation class labels into a color image 203 | 204 | Args: 205 | label_mask (np.ndarray): an (M,N) array of integer values denoting 206 | the class label at each spatial location. 207 | plot (bool, optional): whether to show the resulting color image 208 | in a figure. 209 | 210 | Returns: 211 | (np.ndarray, optional): the resulting decoded color image. 212 | """ 213 | label_colours = self.get_pascal_labels() 214 | r = label_mask.copy() 215 | g = label_mask.copy() 216 | b = label_mask.copy() 217 | for ll in range(0, self.n_classes): 218 | r[label_mask == ll] = label_colours[ll, 0] 219 | g[label_mask == ll] = label_colours[ll, 1] 220 | b[label_mask == ll] = label_colours[ll, 2] 221 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 222 | rgb[:, :, 0] = r / 255.0 223 | rgb[:, :, 1] = g / 255.0 224 | rgb[:, :, 2] = b / 255.0 225 | 226 | return rgb 227 | 228 | def setup_annotations(self): 229 | """Sets up Berkley annotations by adding image indices to the 230 | `train_aug` split and pre-encode all segmentation labels into the 231 | common label_mask format (if this has not already been done). 
This 232 | function also defines the `train_aug` and `train_aug_val` data splits 233 | according to the description in the class docstring 234 | """ 235 | # sbd_path = self.sbd_path 236 | target_path = pjoin(self.root, "SegmentationClass/pre_encoded") 237 | if not os.path.exists(target_path): 238 | os.makedirs(target_path) 239 | # path = pjoin(sbd_path, "dataset/train.txt") 240 | # sbd_train_list = tuple(open(path, "r")) 241 | # sbd_train_list = [id_.rstrip() for id_ in sbd_train_list] 242 | # train_aug = self.files["train"] + sbd_train_list 243 | train_aug = self.files["train"] 244 | 245 | # keep unique elements (stable) 246 | train_aug = [train_aug[i] for i in sorted(np.unique(train_aug, return_index=True)[1])] 247 | self.files["train_aug"] = train_aug 248 | set_diff = set(self.files["val"]) - set(train_aug) # remove overlap 249 | self.files["train_aug_val"] = list(set_diff) 250 | 251 | pre_encoded = glob.glob(pjoin(target_path, "*.png")) 252 | expected = np.unique(self.files["train_aug"] + self.files["val"]).size 253 | 254 | if len(pre_encoded) != expected: 255 | print("Pre-encoding segmentation masks...") 256 | # for ii in tqdm(sbd_train_list): 257 | # lbl_path = pjoin(sbd_path, "dataset/cls", ii + ".mat") 258 | # data = io.loadmat(lbl_path) 259 | # lbl = data["GTcls"][0]["Segmentation"][0].astype(np.int32) 260 | # lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 261 | # m.imsave(pjoin(target_path, ii + ".png"), lbl) 262 | 263 | for ii in tqdm(self.files["trainval"]): 264 | fname = ii + ".png" 265 | lbl_path = pjoin(self.root, "SegmentationClass", fname) 266 | lbl = imageio.imread(lbl_path, pilmode='RGB') 267 | lbl = np.array(lbl, dtype=np.uint8) 268 | lbl = self.encode_segmap(lbl) 269 | # lbl = m.toimage(lbl, high=lbl.max(), low=lbl.min()) 270 | lbl = np.array(lbl, dtype=np.uint8) 271 | imageio.imsave(pjoin(target_path, fname), lbl) 272 | 273 | # assert expected == 9733, "unexpected dataset sizes" 274 | 275 | -------------------------------------------------------------------------------- /utils/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import time 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | from pathlib import Path 10 | from torch.cuda.amp import GradScaler 11 | from torch.optim.lr_scheduler import ExponentialLR 12 | from torch.utils.tensorboard import SummaryWriter 13 | from tqdm import tqdm 14 | from tqdm.contrib.logging import logging_redirect_tqdm 15 | 16 | from utils.evaluation import Evaluator, segmentation_cross_entropy, noise_mse, write_images_to_tensorboard 17 | from utils.utils import diffuse, get_patch_indices, dynamic_range 18 | 19 | 20 | class TrainerConfig: 21 | """ 22 | Config settings (hyperparameters) for training. 
23 | """ 24 | # optimization parameters 25 | max_epochs = 100 26 | batch_size = 2 27 | learning_rate = 1e-5 28 | momentum = None 29 | weight_decay = 0.001 30 | grad_norm_clip = 0.95 31 | 32 | # learning rate decay params 33 | lr_decay = True 34 | lr_decay_gamma = 0.98 35 | 36 | # network 37 | network = 'unet' 38 | 39 | # diffusion other settings 40 | train_on_n_scales = None 41 | not_recursive = False 42 | 43 | # checkpoint settings 44 | checkpoint_dir = 'output/checkpoints/' 45 | log_dir = 'output/logs/' 46 | load_checkpoint = None 47 | checkpoint = None 48 | weights_only = False 49 | 50 | # data 51 | dataset_selection = 'uavid' 52 | 53 | # other 54 | eval_every = 2 55 | save_every = 2 56 | seed = 0 57 | n_workers = 8 58 | 59 | def __init__(self, **kwargs): 60 | for k,v in kwargs.items(): 61 | setattr(self, k, v) 62 | 63 | def save_config_file(self, filename): 64 | Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True) 65 | logging.info("Saving TrainerConfig file: {}".format(filename)) 66 | with open(filename, 'w') as f: 67 | for k,v in vars(self).items(): 68 | f.write("{}={}\n".format(k,v)) 69 | 70 | class Trainer: 71 | 72 | def __init__(self, model, network_config, config, train_data_loader, validation_data_loader=None): 73 | self.model = model 74 | self.network_config = network_config 75 | self.config = config 76 | self.train_data_loader = train_data_loader 77 | self.validation_data_loader = validation_data_loader 78 | self.device = config.device 79 | 80 | def create_run_name(self): 81 | """Creates a unique run name based on current time and network""" 82 | self.run_name = '{}_{}'.format(time.strftime("%Y%m%d-%H%M"), self.config.network) 83 | 84 | def save_checkpoint(self, model, optimizer, scheduler, epoch, id=None): 85 | """Saves a model checkpoint""" 86 | if id is None: 87 | id = "e{}".format(epoch) 88 | path = os.path.normpath(self.config.checkpoint_dir + "{}/{}_{}.pt".format(self.run_name, self.run_name, id)) # path/time_network/time_network_epoch.pt 89 | Path(os.path.dirname(path)).mkdir(parents=True, exist_ok=True) 90 | logging.info("Saving checkpoint: {}".format(path)) 91 | torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict()}, path) 92 | 93 | def get_optimizer(self): 94 | """Defines the optimizer""" 95 | # optimizer = optim.SGD(self.model.parameters(), lr=self.config.learning_rate, momentum=self.config.momentum, weight_decay=self.config.weight_decay) 96 | optimizer = optim.AdamW(self.model.parameters(), lr=self.config.learning_rate, betas=(0.9, 0.999), weight_decay=self.config.weight_decay) 97 | if (self.config.checkpoint is not None) and (self.config.weights_only is False): 98 | optimizer.load_state_dict(self.config.checkpoint['optimizer_state_dict']) 99 | return optimizer 100 | 101 | def get_scheduler(self, optimizer): 102 | """Defines the learning rate scheduler""" 103 | scheduler = ExponentialLR(optimizer, gamma=self.config.lr_decay_gamma) 104 | if (self.config.checkpoint is not None) and (self.config.weights_only is False): 105 | scheduler.load_state_dict(self.config.checkpoint['scheduler_state_dict']) 106 | return scheduler 107 | 108 | def denoise_loop_scales(self, model, network_config, config, images, seg_gt_one_hot, optimizer, scaler): 109 | """Denoises all scales for a single timestep""" 110 | # Calculate scale sizes (smallest first) 111 | scale_sizes = [(images.shape[2] // (2**(network_config.n_scales - i -1)), images.shape[3] // 
(2**(network_config.n_scales - i -1))) for i in range(network_config.n_scales)] 112 | 113 | # Initialize first prediction (random noise) 114 | seg_previous_scaled = torch.rand(images.shape[0], network_config.n_classes, images.shape[2], images.shape[3]) 115 | 116 | # Denoise whole segmentation map in steps 117 | for timestep in range(network_config.n_timesteps): # for each step 118 | loss_per_scale = torch.zeros(network_config.n_scales) 119 | 120 | for scale in range(network_config.n_scales): # for each scale 121 | # break if we don't want to train on all scales 122 | if scale > config.train_on_n_scales - 1: 123 | break 124 | # Resize to current scale 125 | images_scaled = F.interpolate(images, size=scale_sizes[scale], mode='bilinear', align_corners=False) 126 | seg_gt_scaled = F.interpolate(seg_gt_one_hot, size=scale_sizes[scale], mode='bilinear', align_corners=False) 127 | seg_previous_scaled = F.interpolate(seg_previous_scaled, size=scale_sizes[scale], mode='bilinear', align_corners=False) 128 | 129 | patch_indices = get_patch_indices(scale_sizes[scale], network_config.max_patch_size, overlap=False) 130 | 131 | # Create a new tensor to store the denoised segmentation map 132 | seg_denoised = torch.zeros(seg_previous_scaled.shape) 133 | # Create a tensor to store the number of times a pixel has been denoised 134 | n_denoised = torch.zeros(seg_previous_scaled.shape) 135 | 136 | for x, y, patch_size in patch_indices: # for each patch 137 | # Get the patch 138 | img_patch = images_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().cuda(non_blocking=True) 139 | seg_gt_patch = seg_gt_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().cuda(non_blocking=True) 140 | seg_patch_previous = seg_previous_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().cuda(non_blocking=True).softmax(dim=1) 141 | if config.not_recursive: 142 | if timestep + scale > 0: 143 | seg_patch_previous = seg_gt_patch 144 | 145 | # Diffuse 146 | t = torch.tensor([(network_config.n_timesteps - (timestep + scale/network_config.n_scales)) / network_config.n_timesteps]).cuda(non_blocking=True) # time step 147 | seg_patch_diffused = diffuse(seg_patch_previous, t).detach() # diffuse segmentation map 148 | noise_gt = seg_patch_diffused - seg_gt_patch # The noise added in the diffusion process + the error from the previous step 149 | 150 | # Zero the parameter gradients 151 | optimizer.zero_grad() 152 | 153 | # Runs the forward pass with autocasting 154 | with torch.cuda.amp.autocast(): 155 | # Forward pass 156 | noise_predicted = model(seg_patch_diffused, img_patch, t) # predict the noise 157 | seg_patch_denoised = seg_patch_diffused - noise_predicted # denoise the patch 158 | 159 | # Compute loss 160 | losses = {} 161 | noise_mse_loss = noise_mse(noise_predicted, noise_gt) 162 | losses['noise_mse'] = noise_mse_loss 163 | # seg_cross_entropy_loss = segmentation_cross_entropy(seg_patch_denoised, seg_gt_patch.argmax(dim=1)) 164 | # losses['seg_cross_entropy'] = seg_cross_entropy_loss 165 | total_loss = noise_mse_loss 166 | 167 | # Backward pass 168 | # total_loss.backward() 169 | scaler.scale(total_loss).backward() 170 | 171 | # Clip the gradients 172 | # torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 173 | 174 | # Update the parameters 175 | # optimizer.step() 176 | scaler.step(optimizer) 177 | 178 | # Update the scale for the next iteration. 
179 |                     scaler.update() 180 | 181 |                     # Add the denoised patch to the segmentation map 182 |                     seg_patch_denoised = seg_patch_denoised.detach().cpu() # detach from the graph 183 |                     seg_denoised[:, :, x:x+patch_size, y:y+patch_size] += seg_patch_denoised 184 |                     n_denoised[:, :, x:x+patch_size, y:y+patch_size] += 1 185 | 186 | 187 |                 # Average the denoised patches 188 |                 seg_denoised = seg_denoised / n_denoised 189 | 190 |                 # # Adjust range 191 |                 # seg_denoised = dynamic_range(seg_denoised) 192 | 193 |                 # Update the previous segmentation map 194 |                 seg_previous_scaled = seg_denoised 195 | 196 |         return seg_denoised, losses 197 | 198 |     def denoise_linear_scales(self, model, network_config, config, images, seg_gt_one_hot, optimizer, scaler): 199 |         """Denoises one scale at each timestep""" 200 |         # Calculate scale sizes (smallest first) 201 |         scale_sizes = [(images.shape[2] // (2**(network_config.n_scales - i -1)), images.shape[3] // (2**(network_config.n_scales - i -1))) for i in range(network_config.n_scales)] 202 | 203 |         # Initialize first prediction (random noise) 204 |         seg_previous_scaled = torch.rand(images.shape[0], network_config.n_classes, images.shape[2], images.shape[3]) 205 | 206 |         # Denoise whole segmentation map in steps 207 |         for timestep in range(network_config.n_timesteps): # for each step 208 |             # Get the current scale 209 |             timesteps_per_scale = math.ceil(network_config.n_timesteps / network_config.n_scales) 210 |             scale = timestep // timesteps_per_scale 211 | 212 |             # Resize to current scale 213 |             if timestep % timesteps_per_scale == 0: 214 |                 images_scaled = F.interpolate(images, size=scale_sizes[scale], mode='bilinear', align_corners=False) 215 |                 seg_gt_scaled = F.interpolate(seg_gt_one_hot, size=scale_sizes[scale], mode='nearest') 216 |                 seg_previous_scaled = F.interpolate(seg_previous_scaled.float(), size=scale_sizes[scale], mode='bilinear', align_corners=False) 217 | 218 |             patch_indices = get_patch_indices(scale_sizes[scale], network_config.max_patch_size, overlap=False) 219 | 220 |             # Create a new tensor to store the denoised segmentation map 221 |             seg_denoised = torch.zeros(seg_previous_scaled.shape) 222 |             # Create a tensor to store the number of times a pixel has been denoised 223 |             n_denoised = torch.zeros(seg_previous_scaled.shape) 224 | 225 |             for x, y, patch_size in patch_indices: # for each patch 226 |                 # Get the patch 227 |                 img_patch = images_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().to(self.device).contiguous() 228 |                 seg_gt_patch = seg_gt_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().to(self.device).contiguous() 229 |                 seg_patch_previous = seg_previous_scaled[:, :, x:x+patch_size, y:y+patch_size].detach().to(self.device).contiguous() 230 | 231 |                 # Zero the parameter gradients 232 |                 optimizer.zero_grad() 233 | 234 |                 # Diffuse 235 |                 t = torch.tensor([(network_config.n_timesteps - timestep) / network_config.n_timesteps]).to(self.device) # time step 236 |                 seg_patch_diffused = diffuse(seg_patch_previous, t).detach() # diffuse segmentation map 237 |                 noise_gt = seg_patch_diffused - seg_gt_patch # The noise added in the diffusion process + the error from the previous step 238 | 239 |                 # Forward pass 240 |                 noise_predicted = model(seg_patch_diffused, img_patch, t) # predict the noise 241 |                 seg_patch_denoised = seg_patch_diffused - noise_predicted # denoise the patch 242 | 243 |                 # Compute loss 244 |                 losses = {} 245 |                 noise_mse_loss = noise_mse(noise_predicted, noise_gt) 246 |                 losses['noise_mse'] = noise_mse_loss 247 |                 seg_cross_entropy_loss = 
segmentation_cross_entropy(seg_patch_denoised, seg_gt_patch.argmax(dim=1)) 248 | losses['seg_cross_entropy'] = seg_cross_entropy_loss 249 | total_loss = noise_mse_loss + seg_cross_entropy_loss 250 | 251 | # Backward pass 252 | total_loss.backward() 253 | 254 | # Clip the gradients 255 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 256 | 257 | # Update the parameters 258 | optimizer.step() 259 | 260 | # Add the denoised patch to the segmentation map 261 | seg_patch_denoised = seg_patch_denoised.detach().cpu() # detach from the graph 262 | seg_denoised[:, :, x:x+patch_size, y:y+patch_size] += seg_patch_denoised 263 | n_denoised[:, :, x:x+patch_size, y:y+patch_size] += 1 264 | 265 | # Average the denoised patches 266 | seg_denoised = seg_denoised / n_denoised 267 | 268 | # Update the previous segmentation map 269 | seg_previous_scaled = seg_denoised 270 | 271 | return seg_denoised, losses 272 | 273 | def denoise_and_backprop(self, model, network_config, config, images, seg_gt_one_hot, optimizer, scaler): 274 | """Denoises and backpropagates the error""" 275 | if network_config.scale_procedure == 'loop': 276 | seg_denoised, losses = self.denoise_loop_scales(model, network_config, config, images, seg_gt_one_hot, optimizer, scaler) 277 | elif network_config.scale_procedure == 'linear': 278 | seg_denoised, losses = self.denoise_linear_scales(model, network_config, config, images, seg_gt_one_hot, optimizer, scaler) 279 | 280 | return seg_denoised, losses 281 | 282 | def train(self): 283 | """Trains the model""" 284 | self.create_run_name() 285 | model = self.model 286 | network_config = self.network_config 287 | config = self.config 288 | optimizer = self.get_optimizer() 289 | scaler = GradScaler() 290 | scheduler = self.get_scheduler(optimizer) 291 | writer = SummaryWriter(log_dir=(config.log_dir + self.run_name)) 292 | evaluator = Evaluator(model, network_config, self.device, dataset_selection=config.dataset_selection, validation_data_loader=self.validation_data_loader, writer=writer) 293 | 294 | config.save_config_file(os.path.normpath(config.checkpoint_dir + "{}/{}_config.txt".format(self.run_name, self.run_name))) 295 | 296 | def run_epoch(): 297 | model.train() 298 | 299 | pbar_epoch = tqdm(enumerate(self.train_data_loader), total=len(self.train_data_loader), desc='Epoch {}/{}'.format(epoch+1, config.max_epochs), leave=False, bar_format='{l_bar}{bar:50}{r_bar}') 300 | for it, samples in pbar_epoch: 301 | # Unpack the samples 302 | images, seg_gt = samples 303 | seg_gt_one_hot = F.one_hot(seg_gt, num_classes=network_config.n_classes+1).permute(0,3,1,2)[:,:-1,:,:].float() # make one hot (if remove void class [:,:-1,:,:]) 304 | 305 | # Denoise and backpropagate 306 | seg_denoised, losses = self.denoise_and_backprop(model, network_config, config, images, seg_gt_one_hot, optimizer, scaler) 307 | 308 | # Write to tensorboard 309 | it_total = it + epoch*len(self.train_data_loader) 310 | if it_total % 10 == 0 and it_total > 0: 311 | for loss_name, loss in losses.items(): 312 | writer.add_scalar('train/{}'.format(loss_name), loss, it_total) 313 | 314 | # Write images to tensorboard 315 | if it % 200 == 0: 316 | write_images_to_tensorboard(writer, it_total, image=images[0], seg_predicted=seg_denoised[0], seg_gt=seg_gt[0], datasplit='train', dataset_name=config.dataset_selection) 317 | 318 | scheduler.step() 319 | 320 | 321 | with logging_redirect_tqdm(): 322 | pbar_total = tqdm(range(config.max_epochs), desc='Total', bar_format='{l_bar}{bar:50}{r_bar}') 323 | for epoch in 
pbar_total: 324 | # Run an epoch 325 | run_epoch() 326 | 327 | # Save checkpoint 328 | if (epoch+1) % config.save_every == 0: 329 | self.save_checkpoint(model, optimizer, scheduler, epoch+1) 330 | 331 | # Evaluate 332 | if self.validation_data_loader is not None: 333 | if (epoch+1) % config.eval_every == 0: 334 | evaluator.validate(epoch+1) 335 | 336 | writer.flush() 337 | writer.close() -------------------------------------------------------------------------------- /utils/uavid_loader.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import torch 5 | from torchvision.io import read_image 6 | import torchvision.transforms as T 7 | import torchvision.transforms.functional as TF 8 | 9 | from torch.utils import data 10 | 11 | def decode_segmap(seg, is_one_hot=False): 12 | colors = torch.tensor([ 13 | [0,0,0], # Background / clutter 14 | [128,0,0], # Building 15 | [128,64,128], # Road 16 | [0,128,0], # Tree 17 | [128,128,0], # Low vegetation 18 | [64,0,128], # Moving car 19 | [192,0,192], # Static car 20 | [64,64,0] # Human 21 | ], dtype=torch.uint8) 22 | if is_one_hot: 23 | seg = torch.argmax(seg, dim=0) 24 | # convert classes to colors 25 | seg_img = torch.empty((seg.shape[0], seg.shape[1], 3), dtype=torch.uint8) 26 | for c in range(colors.shape[0]): 27 | seg_img[seg == c, :] = colors[c] 28 | return seg_img.permute(2, 0, 1) 29 | 30 | 31 | class UAVidLoader(data.Dataset): 32 | """UAVid dataloader""" 33 | 34 | def encode_segmap(self, segcolors): 35 | """RGB colors to class labels""" 36 | colors = torch.tensor([ 37 | [0,0,0], # Background / clutter 38 | [128,0,0], # Building 39 | [128,64,128], # Road 40 | [0,128,0], # Tree 41 | [128,128,0], # Low vegetation 42 | [64,0,128], # Moving car 43 | [192,0,192], # Static car 44 | [64,64,0] # Human 45 | ], dtype=torch.uint8) 46 | segcolors = segcolors.permute(1, 2, 0) 47 | label_map = torch.zeros((segcolors.shape[0], segcolors.shape[1]), dtype=torch.long) 48 | for i, color in enumerate(colors): 49 | label_map[(segcolors == color).all(dim=2)] = i 50 | return label_map 51 | 52 | 53 | 54 | def __init__( 55 | self, 56 | root, 57 | split="train", 58 | is_transform=False, 59 | img_size=(1024,2048), 60 | augmentations=None, 61 | img_norm=True 62 | ): 63 | self.root = root 64 | self.split = split 65 | self.is_transform = is_transform 66 | self.n_classes = 8 67 | self.augmentations = augmentations 68 | self.img_size = [img_size[0], img_size[1]] if isinstance(img_size, tuple) else img_size 69 | self.img_norm = img_norm 70 | self.mean = torch.tensor([0.485, 0.456, 0.406]) 71 | self.std = torch.tensor([0.229, 0.224, 0.225]) 72 | self.images = {} 73 | self.labels = {} 74 | 75 | self.setup() 76 | 77 | def setup(self): 78 | image_list = [] 79 | label_list = [] 80 | for seq in os.listdir(os.path.join(self.root, "uavid_{}".format(self.split))): 81 | for i in range(10): 82 | image_list.append(os.path.join(self.root, "uavid_{}".format(self.split), seq, "Images", "{:06d}.png".format(i*100))) 83 | label_list.append(os.path.join(self.root, "uavid_{}".format(self.split), seq, "Labels", "{:06d}.png".format(i*100))) 84 | 85 | self.images[self.split] = image_list 86 | self.labels[self.split] = label_list 87 | 88 | 89 | def __len__(self): 90 | return len(self.images[self.split]) 91 | 92 | def __getitem__(self, index): 93 | img_path = self.images[self.split][index] 94 | lbl_path = self.labels[self.split][index] 95 | # Read image and label 96 | img = read_image(img_path) 97 | lbl = 
read_image(lbl_path).squeeze(0).long() 98 | 99 | # Resize 100 | img = TF.resize(img, (1024,2048), antialias=True) 101 | lbl = TF.resize(lbl, (1024,2048), interpolation=TF.InterpolationMode.NEAREST, antialias=True) 102 | 103 | # RandomCrop 104 | i, j, h, w = T.RandomCrop.get_params(img, output_size=(self.img_size[0], self.img_size[1])) 105 | img = TF.crop(img, i, j, h, w) 106 | lbl = TF.crop(lbl, i, j, h, w) 107 | 108 | 109 | # Encode labels 110 | lbl = self.encode_segmap(lbl) 111 | 112 | # Augmentations 113 | if self.split == "train": 114 | # Random flips 115 | if np.random.random() < 0.5: 116 | img = TF.hflip(img) 117 | lbl = TF.hflip(lbl) 118 | # Random color jitter 119 | if np.random.random() < 0.25: 120 | img = TF.adjust_brightness(img, 0.75 + np.random.random() * 0.5) 121 | img = TF.adjust_contrast(img, 0.75 + np.random.random() * 0.5) 122 | img = TF.adjust_saturation(img, 0.75 + np.random.random() * 0.5) 123 | img = TF.adjust_hue(img, 0.10 * (np.random.random() - 0.5)) 124 | 125 | # Normalize 126 | if self.img_norm: 127 | img = TF.normalize(img.float(), self.mean, self.std) 128 | 129 | return img, lbl 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Assortment of useful methods 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import random 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | def set_seed(seed): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | logging.info('This run is deterministic with seed {}'.format(seed)) 17 | 18 | def diffuse(seg_start, t, noise_type='normal_add', schedule='linear'): 19 | """Diffuse a segmentation map by adding noise based on the diffusion schedule""" 20 | # Define the diffusion schedule 21 | if schedule == 'linear': 22 | noise_factor = t 23 | elif schedule == 'square': 24 | noise_factor = t ** 2 25 | 26 | noise_factor = noise_factor.view(-1, 1, 1, 1) 27 | 28 | # Create noise 29 | if noise_type == 'normal_add': 30 | noise = torch.randn(seg_start.shape, device=seg_start.device) 31 | diffused = seg_start + noise * noise_factor 32 | elif noise_type == 'normal_average': 33 | noise = torch.randn(seg_start.shape, device=seg_start.device) 34 | diffused = (1 - noise_factor) * (seg_start - 0.5) + 0.5 + noise_factor * noise 35 | elif noise_type == 'uniform': 36 | noise = torch.rand(seg_start.shape, device=seg_start.device) 37 | diffused = seg_start + noise * noise_factor 38 | elif noise_type == 'binary': 39 | noise = F.one_hot(torch.randint(0, seg_start.shape[1], (seg_start.shape[0], seg_start.shape[2], seg_start.shape[3]), device=seg_start.device), num_classes=seg_start.shape[1]).permute(0,3,1,2).float() 40 | noise_factor = noise_factor.expand(-1, 1, seg_start.shape[2], seg_start.shape[3]) 41 | noise_factor_bernoulli = torch.bernoulli(noise_factor) 42 | diffused = seg_start * (1 - noise_factor_bernoulli) + noise * noise_factor_bernoulli 43 | elif noise_type == 'none': 44 | diffused = seg_start 45 | 46 | # Clip the values 47 | diffused = torch.clamp(diffused, 0, 1) 48 | 49 | return diffused 50 | 51 | def get_patch_indices(img_size, patch_size, overlap=True): 52 | """ 53 | Get the indices of the patches in an image 54 | """ 55 | # Adjust patch size if necessary 56 | # if min(img_size[0], img_size[1]) < patch_size: 57 | # patch_size = min(img_size[0], img_size[1]) 58 | 59 | # Set stride 60 | if 
overlap: 61 | stride = patch_size // 2 62 | else: 63 | stride = patch_size 64 | 65 | # Get the number of patches 66 | n_patches_h = max((img_size[0] - patch_size) // stride + 1, 1) 67 | n_patches_w = max((img_size[1] - patch_size) // stride + 1, 1) 68 | 69 | # Get the indices of the patches 70 | patch_indices = [] 71 | for i in range(n_patches_h): 72 | for j in range(n_patches_w): 73 | x = i * stride 74 | y = j * stride 75 | patch_indices.append((x, y, patch_size)) 76 | 77 | return patch_indices 78 | 79 | def dynamic_range(x, mode='argmax'): 80 | """ 81 | Adjust the dynamic range of a tensor 82 | """ 83 | if mode == 'softmax': 84 | x = torch.softmax(x, dim=1) 85 | elif mode == 'argmax': 86 | v,_ = x.topk(1, dim=1) 87 | x = x/v 88 | x = x.trunc() 89 | elif mode == 'sigmoid': 90 | x = torch.sigmoid(x) 91 | elif mode == 'clamp': 92 | x = torch.clamp(x, 0, 1) 93 | elif mode == 'dynamic': 94 | s = max(x.max(), 1) 95 | x = torch.clamp(x, min=-s, max=s) 96 | x = x / s 97 | return x 98 | 99 | def denoise_scale(model, device, seg_diffused, images, t, patch_size=512, overlap=False, use_dynamic_range=False): 100 | """ 101 | Denoise a segmentation map using a model 102 | """ 103 | # Get the indices of the patches 104 | img_size = seg_diffused.shape[2:] 105 | patch_indices = get_patch_indices(img_size, patch_size, overlap) 106 | 107 | # Create a new tensor to store the denoised segmentation map 108 | seg_denoised = torch.zeros(seg_diffused.shape) 109 | # Create a tensor to store the number of times a pixel has been denoised 110 | n_denoised = torch.zeros(seg_diffused.shape) 111 | 112 | # Denoise each patch 113 | for x, y, patch_size in patch_indices: 114 | # Get the patch 115 | img_patch = images[:, :, x:x+patch_size, y:y+patch_size].detach().to(device).contiguous() 116 | seg_patch_diffused = seg_diffused[:, :, x:x+patch_size, y:y+patch_size].detach().to(device).contiguous() 117 | 118 | # Denoise the patch 119 | noise_predicted = model(seg_patch_diffused, img_patch, t.to(device)) # predict the noise 120 | seg_patch_denoised = seg_patch_diffused - noise_predicted # denoise the patch 121 | 122 | # Add the denoised patch to the segmentation map 123 | seg_denoised[:, :, x:x+patch_size, y:y+patch_size] += seg_patch_denoised.cpu() 124 | n_denoised[:, :, x:x+patch_size, y:y+patch_size] += 1 125 | 126 | # Average the denoised segmentation map 127 | seg_denoised /= n_denoised 128 | 129 | # Adjust range 130 | if use_dynamic_range: 131 | seg_denoised = dynamic_range(seg_denoised) 132 | 133 | return seg_denoised 134 | -------------------------------------------------------------------------------- /utils/vaihingen_buildings_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from torchvision.io import read_image 5 | import torchvision.transforms.functional as TF 6 | 7 | from torch.utils import data 8 | 9 | def decode_segmap(seg, is_one_hot=False): 10 | colors = torch.tensor([ 11 | [0, 0, 0], 12 | [255, 255, 255], 13 | ], dtype=torch.uint8) 14 | if is_one_hot: 15 | seg = torch.argmax(seg, dim=0) 16 | # convert classes to colors 17 | seg_img = torch.empty((seg.shape[0], seg.shape[1], 3), dtype=torch.uint8) 18 | for c in range(colors.shape[0]): 19 | seg_img[seg == c, :] = colors[c] 20 | return seg_img.permute(2, 0, 1) 21 | 22 | class VaihingenBuildingsLoader(data.Dataset): 23 | """Vaihingen Buildings dataloader""" 24 | 25 | def __init__( 26 | self, 27 | root, 28 | split="train", 29 | is_transform=True, 30 | img_size=512, 
31 | augmentations=None, 32 | img_norm=True, 33 | ): 34 | self.root = root 35 | if split == "val": # There is no separate validation data 36 | split = "test" 37 | self.split = split 38 | self.is_transform = is_transform 39 | self.n_classes = 2 40 | self.augmentations = augmentations 41 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 42 | self.img_norm = img_norm 43 | self.mean = np.array([0.485, 0.456, 0.406]) 44 | self.std = np.array([0.229, 0.224, 0.225]) 45 | self.images = {} 46 | self.labels = {} 47 | 48 | self.setup() 49 | 50 | def setup(self): 51 | n_train = 100 52 | image_list = [] 53 | label_list = [] 54 | for i in range(168): 55 | image_list.append(os.path.join(self.root, "building_{:03d}.png".format(i+1))) 56 | label_list.append(os.path.join(self.root, "building_mask_{:03d}.png".format(i+1))) 57 | self.images["train"] = image_list[:n_train] 58 | self.labels["train"] = label_list[:n_train] 59 | self.images["test"] = image_list[n_train:] 60 | self.labels["test"] = label_list[n_train:] 61 | 62 | def __len__(self): 63 | return len(self.images[self.split]) 64 | 65 | def __getitem__(self, index): 66 | img_path = self.images[self.split][index] 67 | lbl_path = self.labels[self.split][index] 68 | # Read image and label 69 | img = read_image(img_path) 70 | lbl = read_image(lbl_path).long() 71 | 72 | # Resize 73 | img = TF.resize(img, self.img_size, antialias=True) 74 | lbl = TF.resize(lbl, self.img_size, interpolation=TF.InterpolationMode.NEAREST, antialias=True).squeeze(0).long() 75 | 76 | if self.split == "train": 77 | # Random flips 78 | if np.random.random() < 0.5: 79 | img = TF.vflip(img) 80 | lbl = TF.vflip(lbl) 81 | if np.random.random() < 0.5: 82 | img = TF.hflip(img) 83 | lbl = TF.hflip(lbl) 84 | # Random rotations 85 | if np.random.random() < 0.5: 86 | lbl = lbl.unsqueeze(0) 87 | angle = np.random.randint(-180, 180) 88 | img = TF.rotate(img, angle) 89 | lbl = TF.rotate(lbl, angle) 90 | lbl = lbl.squeeze(0) 91 | # Random color jitter 92 | if np.random.random() < 0.5: 93 | img = TF.adjust_contrast(img, 0.75 + np.random.random() * 0.5) 94 | img = TF.adjust_saturation(img, 0.75 + np.random.random() * 0.5) 95 | img = TF.adjust_hue(img, np.random.random() * 0.05) 96 | 97 | # Normalize 98 | if self.img_norm: 99 | img = TF.normalize(img.float(), self.mean, self.std) 100 | 101 | return img, lbl 102 | 103 | 104 | 105 | 106 | --------------------------------------------------------------------------------
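The snippet below is an illustrative sketch, not a file from this repository. It exercises the two standalone helpers shown above in `utils/utils.py` (`diffuse` and `get_patch_indices`), which the training and evaluation loops build on. It assumes the code is run from the repository root so that `utils` is importable; the batch size, class count and spatial sizes are arbitrary demonstration values.

```python
# Illustrative sketch, not part of the repository.
# Assumes: run from the repo root (so `utils` is importable); the shapes and
# class count below are arbitrary demonstration values.
import torch
import torch.nn.functional as F

from utils.utils import diffuse, get_patch_indices

# Dummy one-hot "previous prediction": batch of 1, 8 classes, 64x64 pixels.
seg_prev = F.one_hot(torch.randint(0, 8, (1, 64, 64)), num_classes=8)
seg_prev = seg_prev.permute(0, 3, 1, 2).float()

# In the training/evaluation loops, t starts near 1 (heavy noise) and moves
# towards 0. With the defaults ('normal_add' noise, linear schedule), diffuse()
# adds Gaussian noise scaled by t and clamps the result to [0, 1].
for t_value in (1.0, 0.5, 0.1):
    t = torch.tensor([t_value])
    seg_diffused = diffuse(seg_prev, t)
    print(t_value, float((seg_diffused - seg_prev).abs().mean()))

# Non-overlapping patch tiling as used by denoise_scale(): a 1024x2048 map with
# 512x512 patches yields a 2x4 grid of (x, y, patch_size) tuples.
print(get_patch_indices((1024, 2048), 512, overlap=False))
```

These are the same two operations each step of recursive denoising relies on: the previous prediction is re-diffused with `diffuse()` at the current time step, then denoised patch by patch over the indices returned by `get_patch_indices()`.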