├── .gitignore
├── LICENSE
├── README.md
├── black.toml
├── compress.py
├── cut_test_dataset.py
├── fit_predict.py
├── fit_predict2.py
├── inria
│   ├── augmentations.py
│   ├── dataset.py
│   ├── factory.py
│   ├── losses.py
│   ├── metric.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── can.py
│   │   ├── deeplab.py
│   │   ├── efficient_unet.py
│   │   ├── fpn.py
│   │   ├── hg.py
│   │   ├── hrnet.py
│   │   ├── u2net.py
│   │   └── unet.py
│   ├── optim.py
│   ├── pseudo.py
│   ├── scheduler.py
│   └── visualization.py
├── make_train_tiles.py
├── requirements.txt
├── run_tensorboard.cmd
├── run_tensorboard.sh
├── sample_color.jpg
├── scripts
│   ├── train.cmd
│   ├── train_b6.sh
│   ├── train_hrnet18_4x1080Ti.sh
│   ├── train_hrnet18_p3.8xlarge.sh
│   ├── train_hrnet34_p3.8xlarge.sh
│   ├── train_hrnet34_unet64_aug_hard.sh
│   ├── train_hrnet34_unet64_aug_light.sh
│   ├── train_hrnet34_unet64_aug_medium.sh
│   ├── train_hrnet34_unet64_aug_none.sh
│   ├── train_hrnet48_p3.8xlarge.sh
│   ├── train_mixnet.sh
│   ├── train_resnets.sh
│   ├── train_seresnext50_can.sh
│   └── train_u2net.sh
├── submit.py
└── tests
    ├── mask.png
    ├── test_data.py
    └── test_models.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | .idea/
106 | runs/
107 | old_runs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Eugene Khvedchenya
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Catalyst-Inria-Segmentation-Example
2 | An example project showing the power of Catalyst for training a segmentation model for the Inria Satellite Segmentation Challenge
3 | 
4 | # Dependencies
5 | 
6 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex.git
7 | pip install git+https://github.com/mapillary/inplace_abn
--------------------------------------------------------------------------------
/black.toml:
--------------------------------------------------------------------------------
1 | # Example configuration for Black.
2 | 
3 | # NOTE: you have to use single-quoted strings in TOML for regular expressions.
4 | # It's the equivalent of r-strings in Python. Multiline strings are treated as
5 | # verbose regular expressions by Black. Use [ ] to denote a significant space
6 | # character.
7 | 
8 | [tool.black]
9 | line-length = 119
10 | target-version = ['py35', 'py36', 'py37', 'py38']
11 | include = '\.pyi?$'
12 | exclude = '''
13 | /(
14 |     \.eggs
15 |   | \.git
16 |   | \.hg
17 |   | \.mypy_cache
18 |   | \.tox
19 |   | \.venv
20 |   | _build
21 |   | buck-out
22 |   | build
23 |   | dist
24 | )/
25 | '''
26 | 
--------------------------------------------------------------------------------
/compress.py:
--------------------------------------------------------------------------------
1 | # Run this script from the terminal as follows:
2 | # python compress.py input_dir output_dir
3 | # replacing input_dir by the directory containing the *.tif images
4 | # to compress, and output_dir by the destination folder
5 | 
6 | # Requires GDAL installed on the machine!
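# A minimal usage sketch (paths below are hypothetical; gdal_translate must be on PATH):
#     python compress.py /data/inria/test_predictions /data/inria/test_predictions_compressed
# Each *.tif in the input directory is rewritten as a 1-bit, CCITTFAX4-compressed copy in the output directory.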
7 | 8 | 9 | import os 10 | import subprocess 11 | import sys 12 | 13 | input_dir = sys.argv[1] 14 | output_dir = sys.argv[2] 15 | 16 | for file in os.listdir(input_dir): 17 | if file.endswith(".tif"): 18 | input_file = os.path.join(input_dir, file) 19 | output_file = os.path.join(output_dir, file) 20 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 21 | command = ( 22 | "gdal_translate --config GDAL_PAM_ENABLED NO -co COMPRESS=CCITTFAX4 -co NBITS=1 " 23 | + input_file 24 | + " " 25 | + output_file 26 | ) 27 | print(command) 28 | subprocess.call(command, shell=True) 29 | 30 | print("Done!") 31 | -------------------------------------------------------------------------------- /cut_test_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2, os 4 | from pytorch_toolbelt.inference.tiles import ImageSlicer 5 | from pytorch_toolbelt.utils.fs import id_from_fname, read_image_as_is 6 | import pandas as pd 7 | from tqdm import tqdm 8 | 9 | from inria.dataset import TEST_LOCATIONS 10 | 11 | 12 | def split_image(image_fname, output_dir, tile_size, tile_step, image_margin): 13 | os.makedirs(output_dir, exist_ok=True) 14 | image = read_image_as_is(image_fname) 15 | image_id = id_from_fname(image_fname) 16 | 17 | slicer = ImageSlicer(image.shape, tile_size, tile_step, image_margin) 18 | tiles = slicer.split(image) 19 | 20 | fnames = [] 21 | for i, tile in enumerate(tiles): 22 | output_fname = os.path.join(output_dir, f"{image_id}_tile_{i}.png") 23 | cv2.imwrite(output_fname, tile) 24 | fnames.append(output_fname) 25 | 26 | return fnames 27 | 28 | 29 | def cut_dataset_in_patches(data_dir, tile_size, tile_step, image_margin): 30 | locations = TEST_LOCATIONS 31 | 32 | train_data = [] 33 | 34 | # For validation, we remove the first five images of every location (e.g., austin{1-5}.tif, chicago{1-5}.tif) from the training set. 
35 | # That is suggested validation strategy by competition host 36 | for loc in locations: 37 | for i in range(1, 37): 38 | train_data.append(f"{loc}{i}") 39 | 40 | train_imgs = [os.path.join(data_dir, "test", "images", f"{fname}.tif") for fname in train_data] 41 | 42 | images_dir = os.path.join(data_dir, "test_tiles", "images") 43 | 44 | for train_img in tqdm(train_imgs, total=len(train_imgs), desc="test_imgs"): 45 | split_image(train_img, images_dir, tile_size, tile_step, image_margin) 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument( 51 | "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset" 52 | ) 53 | args = parser.parse_args() 54 | 55 | cut_dataset_in_patches(args.data_dir, tile_size=(768, 768), tile_step=(512, 512), image_margin=0) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /fit_predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | from datetime import datetime 8 | from functools import partial 9 | from typing import List, Tuple, Dict 10 | 11 | import catalyst 12 | import cv2 13 | import numpy as np 14 | import torch 15 | from catalyst.contrib.nn import OneCycleLRWithWarmup 16 | from catalyst.data import DistributedSamplerWrapper 17 | from catalyst.dl import ( 18 | SupervisedRunner, 19 | CriterionCallback, 20 | OptimizerCallback, 21 | SchedulerCallback, 22 | MetricAggregationCallback, 23 | Callback, 24 | ) 25 | from catalyst.utils import load_checkpoint, unpack_checkpoint 26 | from pytorch_toolbelt.optimization.functional import get_optimizable_parameters 27 | from pytorch_toolbelt.utils import fs 28 | from pytorch_toolbelt.utils.catalyst import ( 29 | ShowPolarBatchesCallback, 30 | HyperParametersCallback, 31 | BestMetricCheckpointCallback, 32 | PixelAccuracyCallback, 33 | report_checkpoint, 34 | clean_checkpoint, 35 | ) 36 | from pytorch_toolbelt.utils.random import set_manual_seed 37 | from pytorch_toolbelt.utils.torch_utils import count_parameters, transfer_weights 38 | from sklearn.utils import compute_sample_weight 39 | from torch import nn 40 | from torch.optim.lr_scheduler import CyclicLR 41 | from torch.utils.data import DataLoader, WeightedRandomSampler, DistributedSampler 42 | 43 | from inria.dataset import ( 44 | read_inria_image, 45 | INPUT_IMAGE_KEY, 46 | OUTPUT_MASK_KEY, 47 | INPUT_MASK_KEY, 48 | get_pseudolabeling_dataset, 49 | get_datasets, 50 | UNLABELED_SAMPLE, 51 | OUTPUT_MASK_8_KEY, 52 | OUTPUT_MASK_4_KEY, 53 | OUTPUT_MASK_16_KEY, 54 | INPUT_IMAGE_ID_KEY, 55 | get_xview2_extra_dataset, 56 | INPUT_MASK_WEIGHT_KEY, 57 | OUTPUT_MASK_2_KEY, 58 | OUTPUT_DSV_MASK_1_KEY, 59 | OUTPUT_DSV_MASK_2_KEY, 60 | OUTPUT_DSV_MASK_3_KEY, 61 | OUTPUT_DSV_MASK_4_KEY, 62 | OUTPUT_DSV_MASK_5_KEY, 63 | OUTPUT_DSV_MASK_6_KEY, 64 | ) 65 | from inria.factory import predict 66 | from inria.losses import get_loss, ResizeTargetToPrediction2d 67 | from inria.metric import JaccardMetricPerImageWithOptimalThreshold 68 | from inria.models import get_model 69 | from inria.optim import get_optimizer 70 | from inria.pseudo import BCEOnlinePseudolabelingCallback2d 71 | from inria.scheduler import get_scheduler 72 | from inria.visualization import draw_inria_predictions 73 | 74 | 75 | def get_criterions( 76 | criterions, 77 | criterions_stride1_dsv1=None, 78 | 
criterions_stride1_dsv2=None, 79 | criterions_stride1_dsv3=None, 80 | criterions_stride1_dsv4=None, 81 | criterions_stride1_dsv5=None, 82 | criterions_stride1_dsv6=None, 83 | criterions_stride2=None, 84 | criterions_stride4=None, 85 | criterions_stride8=None, 86 | criterions_stride16=None, 87 | ignore_index=None, 88 | ) -> Tuple[List[Callback], Dict]: 89 | criterions_dict = {} 90 | losses = [] 91 | callbacks = [] 92 | 93 | # Create main losses 94 | for loss_name, loss_weight in criterions: 95 | criterion_callback = CriterionCallback( 96 | prefix=f"{OUTPUT_MASK_KEY}/" + loss_name, 97 | input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], 98 | output_key=OUTPUT_MASK_KEY, 99 | criterion_key=f"{OUTPUT_MASK_KEY}/" + loss_name, 100 | multiplier=float(loss_weight), 101 | ) 102 | 103 | criterions_dict[criterion_callback.criterion_key] = get_loss(loss_name, ignore_index=ignore_index) 104 | callbacks.append(criterion_callback) 105 | losses.append(criterion_callback.prefix) 106 | print("Using loss", loss_name, loss_weight) 107 | 108 | for supervision_losses, supervision_output in zip( 109 | [ 110 | criterions_stride1_dsv1, 111 | criterions_stride1_dsv2, 112 | criterions_stride1_dsv3, 113 | criterions_stride1_dsv4, 114 | criterions_stride1_dsv5, 115 | criterions_stride1_dsv6, 116 | ], 117 | [ 118 | OUTPUT_DSV_MASK_1_KEY, 119 | OUTPUT_DSV_MASK_2_KEY, 120 | OUTPUT_DSV_MASK_3_KEY, 121 | OUTPUT_DSV_MASK_4_KEY, 122 | OUTPUT_DSV_MASK_5_KEY, 123 | OUTPUT_DSV_MASK_6_KEY, 124 | ], 125 | ): 126 | if supervision_losses is not None: 127 | for loss_name, loss_weight in supervision_losses: 128 | prefix = f"{supervision_output}/" + loss_name 129 | criterion_callback = CriterionCallback( 130 | prefix=prefix, 131 | input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], 132 | output_key=supervision_output, 133 | criterion_key=prefix, 134 | multiplier=float(loss_weight), 135 | ) 136 | 137 | criterions_dict[criterion_callback.criterion_key] = get_loss(loss_name, ignore_index=ignore_index) 138 | callbacks.append(criterion_callback) 139 | losses.append(criterion_callback.prefix) 140 | print("Using loss", loss_name, loss_weight) 141 | 142 | # Additional supervision losses 143 | for supervision_losses, supervision_output in zip( 144 | [criterions_stride2, criterions_stride4, criterions_stride8, criterions_stride16], 145 | [OUTPUT_MASK_2_KEY, OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY], 146 | ): 147 | if supervision_losses is not None: 148 | for loss_name, loss_weight in supervision_losses: 149 | prefix = f"{supervision_output}/" + loss_name 150 | criterion_callback = CriterionCallback( 151 | prefix=prefix, 152 | input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], 153 | output_key=supervision_output, 154 | criterion_key=prefix, 155 | multiplier=float(loss_weight), 156 | ) 157 | 158 | criterions_dict[criterion_callback.criterion_key] = ResizeTargetToPrediction2d( 159 | get_loss(loss_name, ignore_index=ignore_index) 160 | ) 161 | callbacks.append(criterion_callback) 162 | losses.append(criterion_callback.prefix) 163 | print("Using loss", loss_name, loss_weight) 164 | 165 | callbacks.append(MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum")) 166 | return callbacks, criterions_dict 167 | 168 | 169 | def main(): 170 | parser = argparse.ArgumentParser() 171 | 172 | ########################################################################################### 173 | # 
Distributed-training related stuff 174 | parser.add_argument("--local_rank", type=int, default=0) 175 | ########################################################################################### 176 | 177 | parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process") 178 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 179 | parser.add_argument("-v", "--verbose", action="store_true") 180 | parser.add_argument("--fast", action="store_true") 181 | parser.add_argument( 182 | "-dd", 183 | "--data-dir", 184 | type=str, 185 | help="Data directory for INRIA sattelite dataset", 186 | default=os.environ.get("INRIA_DATA_DIR"), 187 | ) 188 | parser.add_argument( 189 | "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset" 190 | ) 191 | parser.add_argument("-m", "--model", type=str, default="resnet34_fpncat128", help="") 192 | parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64") 193 | parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run") 194 | # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement') 195 | # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs') 196 | # parser.add_argument('-ft', '--fine-tune', action='store_true') 197 | parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate") 198 | parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion") 199 | parser.add_argument( 200 | "-l2", 201 | "--criterion2", 202 | type=str, 203 | required=False, 204 | action="append", 205 | nargs="+", 206 | help="Criterion for stride 2 mask", 207 | ) 208 | parser.add_argument( 209 | "-l4", 210 | "--criterion4", 211 | type=str, 212 | required=False, 213 | action="append", 214 | nargs="+", 215 | help="Criterion for stride 4 mask", 216 | ) 217 | parser.add_argument( 218 | "-l8", 219 | "--criterion8", 220 | type=str, 221 | required=False, 222 | action="append", 223 | nargs="+", 224 | help="Criterion for stride 8 mask", 225 | ) 226 | parser.add_argument( 227 | "-l16", 228 | "--criterion16", 229 | type=str, 230 | required=False, 231 | action="append", 232 | nargs="+", 233 | help="Criterion for stride 16 mask", 234 | ) 235 | 236 | parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer") 237 | parser.add_argument( 238 | "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights" 239 | ) 240 | parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers") 241 | parser.add_argument("-a", "--augmentations", default="hard", type=str, help="") 242 | parser.add_argument("-tm", "--train-mode", default="random", type=str, help="") 243 | parser.add_argument("--run-mode", default="fit_predict", type=str, help="") 244 | parser.add_argument("--transfer", default=None, type=str, help="") 245 | parser.add_argument("--fp16", action="store_true") 246 | parser.add_argument("--size", default=512, type=int) 247 | parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="") 248 | parser.add_argument("-x", "--experiment", default=None, type=str, help="") 249 | parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer") 250 | parser.add_argument("--opl", 
action="store_true") 251 | parser.add_argument( 252 | "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters" 253 | ) 254 | parser.add_argument("-wd", "--weight-decay", default=0, type=float, help="L2 weight decay") 255 | parser.add_argument("--show", action="store_true") 256 | parser.add_argument("--dsv", action="store_true") 257 | 258 | args = parser.parse_args() 259 | 260 | args.is_master = args.local_rank == 0 261 | args.distributed = False 262 | fp16 = args.fp16 263 | 264 | if "WORLD_SIZE" in os.environ: 265 | args.distributed = int(os.environ["WORLD_SIZE"]) > 1 266 | args.world_size = int(os.environ["WORLD_SIZE"]) 267 | # args.world_size = torch.distributed.get_world_size() 268 | 269 | print("Initializing init_process_group", args.local_rank) 270 | 271 | torch.cuda.set_device(args.local_rank) 272 | torch.distributed.init_process_group(backend="nccl") 273 | print("Initialized init_process_group", args.local_rank) 274 | 275 | is_master = args.is_master | (not args.distributed) 276 | 277 | if args.distributed: 278 | distributed_params = {"rank": args.local_rank, "syncbn": True} 279 | if args.fp16: 280 | distributed_params["amp"] = True 281 | else: 282 | if args.fp16: 283 | distributed_params = {} 284 | distributed_params["amp"] = True 285 | else: 286 | distributed_params = False 287 | 288 | set_manual_seed(args.seed + args.local_rank) 289 | catalyst.utils.set_global_seed(args.seed + args.local_rank) 290 | torch.backends.cudnn.deterministic = True 291 | torch.backends.cudnn.benchmark = False 292 | 293 | data_dir = args.data_dir 294 | if data_dir is None: 295 | raise ValueError("--data-dir must be set") 296 | 297 | num_workers = args.workers 298 | num_epochs = args.epochs 299 | batch_size = args.batch_size 300 | learning_rate = args.learning_rate 301 | model_name = args.model 302 | optimizer_name = args.optimizer 303 | image_size = args.size, args.size 304 | fast = args.fast 305 | augmentations = args.augmentations 306 | train_mode = args.train_mode 307 | scheduler_name = args.scheduler 308 | experiment = args.experiment 309 | dropout = args.dropout 310 | online_pseudolabeling = args.opl 311 | criterions = args.criterion 312 | criterions2 = args.criterion2 313 | criterions4 = args.criterion4 314 | criterions8 = args.criterion8 315 | criterions16 = args.criterion16 316 | 317 | verbose = args.verbose 318 | show = args.show 319 | accumulation_steps = args.accumulation_steps 320 | weight_decay = args.weight_decay 321 | extra_data_xview2 = args.data_dir_xview2 322 | 323 | run_train = num_epochs > 0 324 | need_weight_mask = any(c[0] == "wbce" for c in criterions) 325 | 326 | custom_model_kwargs = {} 327 | if dropout is not None: 328 | custom_model_kwargs["dropout"] = float(dropout) 329 | 330 | if any([criterions2, criterions4, criterions8, criterions16]): 331 | custom_model_kwargs["need_supervision_masks"] = True 332 | print("Enabling supervision masks") 333 | 334 | model: nn.Module = get_model(model_name, **custom_model_kwargs).cuda() 335 | 336 | if args.transfer: 337 | transfer_checkpoint = fs.auto_file(args.transfer) 338 | print("Transfering weights from model checkpoint", transfer_checkpoint) 339 | checkpoint = load_checkpoint(transfer_checkpoint) 340 | pretrained_dict = checkpoint["model_state_dict"] 341 | 342 | transfer_weights(model, pretrained_dict) 343 | 344 | if args.checkpoint: 345 | checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) 346 | unpack_checkpoint(checkpoint, model=model) 347 | 348 | print("Loaded model weights from:", 
args.checkpoint) 349 | report_checkpoint(checkpoint) 350 | 351 | main_metric = "optimized_jaccard" 352 | 353 | current_time = datetime.now().strftime("%y%m%d_%H_%M") 354 | checkpoint_prefix = f"{current_time}_{args.model}" 355 | 356 | if fp16: 357 | checkpoint_prefix += "_fp16" 358 | 359 | if fast: 360 | checkpoint_prefix += "_fast" 361 | 362 | if online_pseudolabeling: 363 | checkpoint_prefix += "_opl" 364 | 365 | if extra_data_xview2: 366 | checkpoint_prefix += "_with_xview2" 367 | 368 | if experiment is not None: 369 | checkpoint_prefix = experiment 370 | 371 | default_callbacks = [ 372 | PixelAccuracyCallback(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY), 373 | # JaccardMetricPerImage(input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="jaccard"), 374 | JaccardMetricPerImageWithOptimalThreshold( 375 | input_key=INPUT_MASK_KEY, output_key=OUTPUT_MASK_KEY, prefix="optimized_jaccard" 376 | ), 377 | ] 378 | 379 | if is_master: 380 | 381 | default_callbacks += [ 382 | BestMetricCheckpointCallback(target_metric="optimized_jaccard", target_metric_minimize=False), 383 | HyperParametersCallback( 384 | hparam_dict={ 385 | "model": model_name, 386 | "scheduler": scheduler_name, 387 | "optimizer": optimizer_name, 388 | "augmentations": augmentations, 389 | "size": args.size, 390 | "weight_decay": weight_decay, 391 | "epochs": num_epochs, 392 | "dropout": None if dropout is None else float(dropout), 393 | } 394 | ), 395 | ] 396 | 397 | if show: 398 | visualize_inria_predictions = partial( 399 | draw_inria_predictions, 400 | inputs_to_labels=lambda x: x.ge(0.5).squeeze(1), 401 | outputs_to_labels=lambda x: x.float().sigmoid().ge(0.5).squeeze(1), 402 | image_key=INPUT_IMAGE_KEY, 403 | image_id_key=INPUT_IMAGE_ID_KEY, 404 | targets_key=INPUT_MASK_KEY, 405 | outputs_key=OUTPUT_MASK_KEY, 406 | max_images=16, 407 | ) 408 | default_callbacks += [ 409 | ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False), 410 | ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True), 411 | ] 412 | 413 | train_ds, valid_ds, train_sampler = get_datasets( 414 | data_dir=data_dir, 415 | image_size=image_size, 416 | augmentation=augmentations, 417 | train_mode=train_mode, 418 | buildings_only=(train_mode == "tiles"), 419 | fast=fast, 420 | need_weight_mask=need_weight_mask, 421 | ) 422 | 423 | if extra_data_xview2 is not None: 424 | extra_train_ds, _ = get_xview2_extra_dataset( 425 | extra_data_xview2, 426 | image_size=image_size, 427 | augmentation=augmentations, 428 | fast=fast, 429 | need_weight_mask=need_weight_mask, 430 | ) 431 | 432 | weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds)) 433 | train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2) 434 | 435 | train_ds = train_ds + extra_train_ds 436 | print("Using extra data from xView2 with", len(extra_train_ds), "samples") 437 | 438 | if run_train: 439 | loaders = collections.OrderedDict() 440 | callbacks = default_callbacks.copy() 441 | criterions_dict = {} 442 | losses = [] 443 | 444 | ignore_index = None 445 | if online_pseudolabeling: 446 | ignore_index = UNLABELED_SAMPLE 447 | unlabeled_label = get_pseudolabeling_dataset( 448 | data_dir, include_masks=False, augmentation=None, image_size=image_size 449 | ) 450 | 451 | unlabeled_train = get_pseudolabeling_dataset( 452 | data_dir, include_masks=True, augmentation=augmentations, image_size=image_size 453 | ) 454 | 455 | if args.distributed: 456 | label_sampler = 
DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False) 457 | else: 458 | label_sampler = None 459 | 460 | loaders["infer"] = DataLoader( 461 | unlabeled_label, 462 | batch_size=batch_size // 2, 463 | num_workers=num_workers, 464 | pin_memory=True, 465 | sampler=label_sampler, 466 | drop_last=False, 467 | ) 468 | 469 | if train_sampler is not None: 470 | num_samples = 2 * train_sampler.num_samples 471 | else: 472 | num_samples = 2 * len(train_ds) 473 | weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label)) 474 | 475 | train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True) 476 | train_ds = train_ds + unlabeled_train 477 | 478 | callbacks += [ 479 | BCEOnlinePseudolabelingCallback2d( 480 | unlabeled_train, 481 | pseudolabel_loader="infer", 482 | prob_threshold=0.7, 483 | output_key=OUTPUT_MASK_KEY, 484 | unlabeled_class=UNLABELED_SAMPLE, 485 | label_frequency=5, 486 | ) 487 | ] 488 | 489 | print("Using online pseudolabeling with ", len(unlabeled_label), "samples") 490 | 491 | valid_sampler = None 492 | if args.distributed: 493 | if train_sampler is not None: 494 | train_sampler = DistributedSamplerWrapper( 495 | train_sampler, args.world_size, args.local_rank, shuffle=True 496 | ) 497 | else: 498 | train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True) 499 | valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False) 500 | 501 | loaders["train"] = DataLoader( 502 | train_ds, 503 | batch_size=batch_size, 504 | num_workers=num_workers, 505 | pin_memory=True, 506 | drop_last=True, 507 | shuffle=train_sampler is None, 508 | sampler=train_sampler, 509 | ) 510 | 511 | loaders["valid"] = DataLoader( 512 | valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler 513 | ) 514 | 515 | if model_name in {"U2NETP", "U2NET"}: 516 | dsv_criterions = criterions 517 | else: 518 | dsv_criterions = None 519 | 520 | loss_callbacks, loss_criterions = get_criterions( 521 | criterions=criterions, 522 | criterions_stride1_dsv1=dsv_criterions, 523 | criterions_stride1_dsv2=dsv_criterions, 524 | criterions_stride1_dsv3=dsv_criterions, 525 | criterions_stride1_dsv4=dsv_criterions, 526 | criterions_stride1_dsv5=dsv_criterions, 527 | criterions_stride1_dsv6=dsv_criterions, 528 | criterions_stride2=criterions2, 529 | criterions_stride4=criterions4, 530 | criterions_stride8=criterions8, 531 | criterions_stride16=criterions16, 532 | ) 533 | callbacks += loss_callbacks 534 | 535 | optimizer = get_optimizer( 536 | optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay 537 | ) 538 | scheduler = get_scheduler( 539 | scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"]) 540 | ) 541 | if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)): 542 | callbacks += [SchedulerCallback(mode="batch")] 543 | 544 | log_dir = os.path.join("runs", checkpoint_prefix) 545 | 546 | if is_master: 547 | os.makedirs(log_dir, exist_ok=False) 548 | config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") 549 | with open(config_fname, "w") as f: 550 | train_session_args = vars(args) 551 | f.write(json.dumps(train_session_args, indent=2)) 552 | 553 | print("Train session :", checkpoint_prefix) 554 | print(" FP16 mode :", fp16) 555 | print(" Fast mode :", args.fast) 556 | print(" Train mode :", train_mode) 557 | print(" Epochs :", num_epochs) 558 | 
print(" Workers :", num_workers) 559 | print(" Data dir :", data_dir) 560 | print(" Log dir :", log_dir) 561 | print(" Augmentations :", augmentations) 562 | print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds)) 563 | print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds)) 564 | print("Model :", model_name) 565 | print(" Parameters :", count_parameters(model)) 566 | print(" Image size :", image_size) 567 | print("Optimizer :", optimizer_name) 568 | print(" Learning rate :", learning_rate) 569 | print(" Batch size :", batch_size) 570 | print(" Criterion :", criterions) 571 | print(" Use weight mask:", need_weight_mask) 572 | if args.distributed: 573 | print("Distributed") 574 | print(" World size :", args.world_size) 575 | print(" Local rank :", args.local_rank) 576 | print(" Is master :", args.is_master) 577 | 578 | # model training 579 | runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda") 580 | runner.train( 581 | fp16=distributed_params, 582 | model=model, 583 | criterion=loss_criterions, 584 | optimizer=optimizer, 585 | scheduler=scheduler, 586 | callbacks=callbacks, 587 | loaders=loaders, 588 | logdir=os.path.join(log_dir, "main"), 589 | num_epochs=num_epochs, 590 | verbose=verbose, 591 | main_metric=main_metric, 592 | minimize_metric=False, 593 | checkpoint_data={"cmd_args": vars(args)}, 594 | ) 595 | 596 | # Training is finished. Let's run predictions using best checkpoint weights 597 | if is_master: 598 | best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") 599 | 600 | model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth") 601 | clean_checkpoint(best_checkpoint, model_checkpoint) 602 | 603 | unpack_checkpoint(torch.load(model_checkpoint), model=model) 604 | 605 | mask = predict( 606 | model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size 607 | ) 608 | mask = ((mask > 0) * 255).astype(np.uint8) 609 | name = os.path.join(log_dir, "sample_color.jpg") 610 | cv2.imwrite(name, mask) 611 | 612 | 613 | if __name__ == "__main__": 614 | main() 615 | -------------------------------------------------------------------------------- /fit_predict2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | from datetime import datetime 8 | from functools import partial 9 | from typing import List, Tuple, Dict 10 | 11 | import catalyst 12 | import cv2 13 | import numpy as np 14 | import torch 15 | from catalyst.contrib.nn import OneCycleLRWithWarmup 16 | from catalyst.data import DistributedSamplerWrapper 17 | from catalyst.dl import ( 18 | SupervisedRunner, 19 | CriterionCallback, 20 | OptimizerCallback, 21 | SchedulerCallback, 22 | MetricAggregationCallback, 23 | Callback, 24 | ) 25 | from catalyst.utils import load_checkpoint, unpack_checkpoint 26 | from pytorch_toolbelt.optimization.functional import get_optimizable_parameters 27 | from pytorch_toolbelt.utils import fs 28 | from pytorch_toolbelt.utils.catalyst import ( 29 | ShowPolarBatchesCallback, 30 | HyperParametersCallback, 31 | BestMetricCheckpointCallback, 32 | PixelAccuracyCallback, 33 | report_checkpoint, 34 | clean_checkpoint, 35 | ) 36 | from pytorch_toolbelt.utils.random import set_manual_seed 37 | from pytorch_toolbelt.utils.torch_utils import count_parameters, transfer_weights 38 | from sklearn.utils import 
compute_sample_weight 39 | from torch import nn 40 | from torch.optim.lr_scheduler import CyclicLR 41 | from torch.utils.data import DataLoader, WeightedRandomSampler, DistributedSampler 42 | 43 | from inria.dataset import ( 44 | read_inria_image, 45 | INPUT_IMAGE_KEY, 46 | OUTPUT_MASK_KEY, 47 | INPUT_MASK_KEY, 48 | get_pseudolabeling_dataset, 49 | get_datasets, 50 | UNLABELED_SAMPLE, 51 | OUTPUT_MASK_8_KEY, 52 | OUTPUT_MASK_4_KEY, 53 | OUTPUT_MASK_16_KEY, 54 | INPUT_IMAGE_ID_KEY, 55 | get_xview2_extra_dataset, 56 | INPUT_MASK_WEIGHT_KEY, 57 | OUTPUT_MASK_2_KEY, 58 | decode_depth_mask, 59 | depth2mask, 60 | mask_to_ce_target, 61 | ) 62 | from inria.factory import predict 63 | from inria.losses import get_loss, ResizeTargetToPrediction2d 64 | from inria.metric import JaccardMetricPerImage 65 | from inria.models import get_model 66 | from inria.optim import get_optimizer 67 | from inria.pseudo import BCEOnlinePseudolabelingCallback2d 68 | from inria.scheduler import get_scheduler 69 | from inria.visualization import draw_inria_predictions 70 | 71 | 72 | def get_criterions( 73 | criterions, 74 | criterions_stride2=None, 75 | criterions_stride4=None, 76 | criterions_stride8=None, 77 | criterions_stride16=None, 78 | ignore_index=None, 79 | ) -> Tuple[List[Callback], Dict]: 80 | criterions_dict = {} 81 | losses = [] 82 | callbacks = [] 83 | 84 | # Create main losses 85 | for loss_name, loss_weight in criterions: 86 | criterion_callback = CriterionCallback( 87 | prefix=f"{OUTPUT_MASK_KEY}/" + loss_name, 88 | input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], 89 | output_key=OUTPUT_MASK_KEY, 90 | criterion_key=f"{OUTPUT_MASK_KEY}/" + loss_name, 91 | multiplier=float(loss_weight), 92 | ) 93 | 94 | criterions_dict[criterion_callback.criterion_key] = get_loss(loss_name, ignore_index=ignore_index) 95 | callbacks.append(criterion_callback) 96 | losses.append(criterion_callback.prefix) 97 | print("Using loss", loss_name, loss_weight) 98 | 99 | # Additional supervision losses 100 | for supervision_losses, supervision_output in zip( 101 | [criterions_stride2, criterions_stride4, criterions_stride8, criterions_stride16], 102 | [OUTPUT_MASK_2_KEY, OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY], 103 | ): 104 | if supervision_losses is not None: 105 | for loss_name, loss_weight in supervision_losses: 106 | prefix = f"{supervision_output}/" + loss_name 107 | criterion_callback = CriterionCallback( 108 | prefix=prefix, 109 | input_key=INPUT_MASK_KEY if loss_name != "wbce" else [INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY], 110 | output_key=supervision_output, 111 | criterion_key=prefix, 112 | multiplier=float(loss_weight), 113 | ) 114 | 115 | criterions_dict[criterion_callback.criterion_key] = ResizeTargetToPrediction2d( 116 | get_loss(loss_name, ignore_index=ignore_index) 117 | ) 118 | callbacks.append(criterion_callback) 119 | losses.append(criterion_callback.prefix) 120 | print("Using loss", loss_name, loss_weight) 121 | 122 | callbacks.append(MetricAggregationCallback(prefix="loss", metrics=losses, mode="sum")) 123 | return callbacks, criterions_dict 124 | 125 | 126 | def main(): 127 | parser = argparse.ArgumentParser() 128 | 129 | ########################################################################################### 130 | # Distributed-training related stuff 131 | parser.add_argument("--local_rank", type=int, default=0) 132 | ########################################################################################### 133 | 134 | 
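# Note: --local_rank (above) and the WORLD_SIZE environment variable (checked further below) are
# normally supplied by the PyTorch distributed launcher. An illustrative multi-GPU launch, with
# placeholder data path and loss arguments, might look like:
#     python -m torch.distributed.launch --nproc_per_node=4 fit_predict2.py -dd /path/to/inria -l bce 1 --fp16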
parser.add_argument("-acc", "--accumulation-steps", type=int, default=1, help="Number of batches to process") 135 | parser.add_argument("--seed", type=int, default=42, help="Random seed") 136 | parser.add_argument("-v", "--verbose", action="store_true") 137 | parser.add_argument("--fast", action="store_true") 138 | parser.add_argument( 139 | "-dd", 140 | "--data-dir", 141 | type=str, 142 | help="Data directory for INRIA sattelite dataset", 143 | default=os.environ.get("INRIA_DATA_DIR"), 144 | ) 145 | parser.add_argument( 146 | "-dd-xview2", "--data-dir-xview2", type=str, required=False, help="Data directory for external xView2 dataset" 147 | ) 148 | parser.add_argument("-m", "--model", type=str, default="b6_unet32_s2", help="") 149 | parser.add_argument("-b", "--batch-size", type=int, default=8, help="Batch Size during training, e.g. -b 64") 150 | parser.add_argument("-e", "--epochs", type=int, default=100, help="Epoch to run") 151 | # parser.add_argument('-es', '--early-stopping', type=int, default=None, help='Maximum number of epochs without improvement') 152 | # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs') 153 | # parser.add_argument('-ft', '--fine-tune', action='store_true') 154 | parser.add_argument("-lr", "--learning-rate", type=float, default=1e-3, help="Initial learning rate") 155 | parser.add_argument("-l", "--criterion", type=str, required=True, action="append", nargs="+", help="Criterion") 156 | parser.add_argument( 157 | "-l2", 158 | "--criterion2", 159 | type=str, 160 | required=False, 161 | action="append", 162 | nargs="+", 163 | help="Criterion for stride 2 mask", 164 | ) 165 | parser.add_argument( 166 | "-l4", 167 | "--criterion4", 168 | type=str, 169 | required=False, 170 | action="append", 171 | nargs="+", 172 | help="Criterion for stride 4 mask", 173 | ) 174 | parser.add_argument( 175 | "-l8", 176 | "--criterion8", 177 | type=str, 178 | required=False, 179 | action="append", 180 | nargs="+", 181 | help="Criterion for stride 8 mask", 182 | ) 183 | parser.add_argument( 184 | "-l16", 185 | "--criterion16", 186 | type=str, 187 | required=False, 188 | action="append", 189 | nargs="+", 190 | help="Criterion for stride 16 mask", 191 | ) 192 | 193 | parser.add_argument("-o", "--optimizer", default="RAdam", help="Name of the optimizer") 194 | parser.add_argument( 195 | "-c", "--checkpoint", type=str, default=None, help="Checkpoint filename to use as initial model weights" 196 | ) 197 | parser.add_argument("-w", "--workers", default=8, type=int, help="Num workers") 198 | parser.add_argument("-a", "--augmentations", default="hard", type=str, help="") 199 | parser.add_argument("-tm", "--train-mode", default="random", type=str, help="") 200 | parser.add_argument("--run-mode", default="fit_predict", type=str, help="") 201 | parser.add_argument("--transfer", default=None, type=str, help="") 202 | parser.add_argument("--fp16", action="store_true") 203 | parser.add_argument("--size", default=512, type=int) 204 | parser.add_argument("-s", "--scheduler", default="multistep", type=str, help="") 205 | parser.add_argument("-x", "--experiment", default=None, type=str, help="") 206 | parser.add_argument("-d", "--dropout", default=None, type=float, help="Dropout before head layer") 207 | parser.add_argument("--opl", action="store_true") 208 | parser.add_argument( 209 | "--warmup", default=0, type=int, help="Number of warmup epochs with reduced LR on encoder parameters" 210 | ) 211 | parser.add_argument("-wd", "--weight-decay", 
default=0, type=float, help="L2 weight decay") 212 | parser.add_argument("--show", action="store_true") 213 | parser.add_argument("--dsv", action="store_true") 214 | 215 | args = parser.parse_args() 216 | 217 | args.is_master = args.local_rank == 0 218 | args.distributed = False 219 | fp16 = args.fp16 220 | 221 | if "WORLD_SIZE" in os.environ: 222 | args.distributed = int(os.environ["WORLD_SIZE"]) > 1 223 | args.world_size = int(os.environ["WORLD_SIZE"]) 224 | # args.world_size = torch.distributed.get_world_size() 225 | 226 | print("Initializing init_process_group", args.local_rank) 227 | 228 | torch.cuda.set_device(args.local_rank) 229 | torch.distributed.init_process_group(backend="nccl") 230 | print("Initialized init_process_group", args.local_rank) 231 | 232 | is_master = args.is_master | (not args.distributed) 233 | 234 | if args.distributed: 235 | distributed_params = {"rank": args.local_rank, "syncbn": True} 236 | if args.fp16: 237 | distributed_params["amp"] = True 238 | else: 239 | if args.fp16: 240 | distributed_params = {} 241 | distributed_params["amp"] = True 242 | else: 243 | distributed_params = False 244 | 245 | set_manual_seed(args.seed + args.local_rank) 246 | catalyst.utils.set_global_seed(args.seed + args.local_rank) 247 | torch.backends.cudnn.deterministic = False 248 | torch.backends.cudnn.benchmark = True 249 | 250 | data_dir = args.data_dir 251 | if data_dir is None: 252 | raise ValueError("--data-dir must be set") 253 | 254 | num_workers = args.workers 255 | num_epochs = args.epochs 256 | batch_size = args.batch_size 257 | learning_rate = args.learning_rate 258 | model_name = args.model 259 | optimizer_name = args.optimizer 260 | image_size = args.size, args.size 261 | fast = args.fast 262 | augmentations = args.augmentations 263 | train_mode = args.train_mode 264 | scheduler_name = args.scheduler 265 | experiment = args.experiment 266 | dropout = args.dropout 267 | online_pseudolabeling = args.opl 268 | criterions = args.criterion 269 | criterions2 = args.criterion2 270 | criterions4 = args.criterion4 271 | criterions8 = args.criterion8 272 | criterions16 = args.criterion16 273 | 274 | verbose = args.verbose 275 | show = args.show 276 | accumulation_steps = args.accumulation_steps 277 | weight_decay = args.weight_decay 278 | extra_data_xview2 = args.data_dir_xview2 279 | 280 | run_train = num_epochs > 0 281 | need_weight_mask = any(c[0] == "wbce" for c in criterions) 282 | 283 | custom_model_kwargs = {"full_size_mask": False} 284 | if dropout is not None: 285 | custom_model_kwargs["dropout"] = float(dropout) 286 | 287 | if any([criterions2, criterions4, criterions8, criterions16]): 288 | custom_model_kwargs["need_supervision_masks"] = True 289 | print("Enabling supervision masks") 290 | 291 | model: nn.Module = get_model(model_name, num_classes=16, **custom_model_kwargs).cuda() 292 | 293 | if args.transfer: 294 | transfer_checkpoint = fs.auto_file(args.transfer) 295 | print("Transfering weights from model checkpoint", transfer_checkpoint) 296 | checkpoint = load_checkpoint(transfer_checkpoint) 297 | pretrained_dict = checkpoint["model_state_dict"] 298 | 299 | transfer_weights(model, pretrained_dict) 300 | 301 | if args.checkpoint: 302 | checkpoint = load_checkpoint(fs.auto_file(args.checkpoint)) 303 | unpack_checkpoint(checkpoint, model=model) 304 | 305 | print("Loaded model weights from:", args.checkpoint) 306 | report_checkpoint(checkpoint) 307 | 308 | main_metric = "jaccard" 309 | 310 | current_time = datetime.now().strftime("%y%m%d_%H_%M") 311 | 
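# The prefix built below resolves to something like "200315_14_30_b6_unet32_s2" (timestamp + model name);
# optional "_fp16"/"_fast"/"_opl"/"_with_xview2" suffixes are appended, unless -x/--experiment replaces it entirely.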
checkpoint_prefix = f"{current_time}_{args.model}" 312 | 313 | if fp16: 314 | checkpoint_prefix += "_fp16" 315 | 316 | if fast: 317 | checkpoint_prefix += "_fast" 318 | 319 | if online_pseudolabeling: 320 | checkpoint_prefix += "_opl" 321 | 322 | if extra_data_xview2: 323 | checkpoint_prefix += "_with_xview2" 324 | 325 | if experiment is not None: 326 | checkpoint_prefix = experiment 327 | 328 | default_callbacks = [ 329 | JaccardMetricPerImage( 330 | input_key=INPUT_MASK_KEY, 331 | output_key=OUTPUT_MASK_KEY, 332 | prefix="jaccard", 333 | inputs_to_labels=depth2mask, 334 | outputs_to_labels=decode_depth_mask, 335 | ), 336 | ] 337 | 338 | if is_master: 339 | 340 | default_callbacks += [ 341 | BestMetricCheckpointCallback(target_metric="jaccard", target_metric_minimize=False), 342 | HyperParametersCallback( 343 | hparam_dict={ 344 | "model": model_name, 345 | "scheduler": scheduler_name, 346 | "optimizer": optimizer_name, 347 | "augmentations": augmentations, 348 | "size": args.size, 349 | "weight_decay": weight_decay, 350 | "epochs": num_epochs, 351 | "dropout": None if dropout is None else float(dropout), 352 | } 353 | ), 354 | ] 355 | 356 | if show: 357 | visualize_inria_predictions = partial( 358 | draw_inria_predictions, 359 | image_key=INPUT_IMAGE_KEY, 360 | image_id_key=INPUT_IMAGE_ID_KEY, 361 | targets_key=INPUT_MASK_KEY, 362 | outputs_key=OUTPUT_MASK_KEY, 363 | inputs_to_labels=depth2mask, 364 | outputs_to_labels=decode_depth_mask, 365 | max_images=16, 366 | ) 367 | default_callbacks += [ 368 | ShowPolarBatchesCallback(visualize_inria_predictions, metric="accuracy", minimize=False), 369 | ShowPolarBatchesCallback(visualize_inria_predictions, metric="loss", minimize=True), 370 | ] 371 | 372 | train_ds, valid_ds, train_sampler = get_datasets( 373 | data_dir=data_dir, 374 | image_size=image_size, 375 | augmentation=augmentations, 376 | train_mode=train_mode, 377 | buildings_only=(train_mode == "tiles"), 378 | fast=fast, 379 | need_weight_mask=need_weight_mask, 380 | make_mask_target_fn=mask_to_ce_target, 381 | ) 382 | 383 | if extra_data_xview2 is not None: 384 | extra_train_ds, _ = get_xview2_extra_dataset( 385 | extra_data_xview2, 386 | image_size=image_size, 387 | augmentation=augmentations, 388 | fast=fast, 389 | need_weight_mask=need_weight_mask, 390 | ) 391 | 392 | weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(extra_train_ds)) 393 | train_sampler = WeightedRandomSampler(weights, train_sampler.num_samples * 2) 394 | 395 | train_ds = train_ds + extra_train_ds 396 | print("Using extra data from xView2 with", len(extra_train_ds), "samples") 397 | 398 | if run_train: 399 | loaders = collections.OrderedDict() 400 | callbacks = default_callbacks.copy() 401 | criterions_dict = {} 402 | losses = [] 403 | 404 | ignore_index = None 405 | if online_pseudolabeling: 406 | ignore_index = UNLABELED_SAMPLE 407 | unlabeled_label = get_pseudolabeling_dataset( 408 | data_dir, include_masks=False, augmentation=None, image_size=image_size 409 | ) 410 | 411 | unlabeled_train = get_pseudolabeling_dataset( 412 | data_dir, include_masks=True, augmentation=augmentations, image_size=image_size 413 | ) 414 | 415 | if args.distributed: 416 | label_sampler = DistributedSampler(unlabeled_label, args.world_size, args.local_rank, shuffle=False) 417 | else: 418 | label_sampler = None 419 | 420 | loaders["infer"] = DataLoader( 421 | unlabeled_label, 422 | batch_size=batch_size // 2, 423 | num_workers=num_workers, 424 | pin_memory=True, 425 | sampler=label_sampler, 426 | 
drop_last=False, 427 | ) 428 | 429 | if train_sampler is not None: 430 | num_samples = 2 * train_sampler.num_samples 431 | else: 432 | num_samples = 2 * len(train_ds) 433 | weights = compute_sample_weight("balanced", [0] * len(train_ds) + [1] * len(unlabeled_label)) 434 | 435 | train_sampler = WeightedRandomSampler(weights, num_samples, replacement=True) 436 | train_ds = train_ds + unlabeled_train 437 | 438 | callbacks += [ 439 | BCEOnlinePseudolabelingCallback2d( 440 | unlabeled_train, 441 | pseudolabel_loader="infer", 442 | prob_threshold=0.7, 443 | output_key=OUTPUT_MASK_KEY, 444 | unlabeled_class=UNLABELED_SAMPLE, 445 | label_frequency=5, 446 | ) 447 | ] 448 | 449 | print("Using online pseudolabeling with ", len(unlabeled_label), "samples") 450 | 451 | valid_sampler = None 452 | if args.distributed: 453 | if train_sampler is not None: 454 | train_sampler = DistributedSamplerWrapper( 455 | train_sampler, args.world_size, args.local_rank, shuffle=True 456 | ) 457 | else: 458 | train_sampler = DistributedSampler(train_ds, args.world_size, args.local_rank, shuffle=True) 459 | valid_sampler = DistributedSampler(valid_ds, args.world_size, args.local_rank, shuffle=False) 460 | 461 | loaders["train"] = DataLoader( 462 | train_ds, 463 | batch_size=batch_size, 464 | num_workers=num_workers, 465 | pin_memory=True, 466 | drop_last=True, 467 | shuffle=train_sampler is None, 468 | sampler=train_sampler, 469 | ) 470 | 471 | loaders["valid"] = DataLoader( 472 | valid_ds, batch_size=batch_size, num_workers=num_workers, pin_memory=True, sampler=valid_sampler 473 | ) 474 | 475 | loss_callbacks, loss_criterions = get_criterions( 476 | criterions, criterions2, criterions4, criterions8, criterions16 477 | ) 478 | callbacks += loss_callbacks 479 | 480 | optimizer = get_optimizer( 481 | optimizer_name, get_optimizable_parameters(model), learning_rate, weight_decay=weight_decay 482 | ) 483 | scheduler = get_scheduler( 484 | scheduler_name, optimizer, lr=learning_rate, num_epochs=num_epochs, batches_in_epoch=len(loaders["train"]) 485 | ) 486 | if isinstance(scheduler, (CyclicLR, OneCycleLRWithWarmup)): 487 | callbacks += [SchedulerCallback(mode="batch")] 488 | 489 | log_dir = os.path.join("runs", checkpoint_prefix) 490 | 491 | if is_master: 492 | os.makedirs(log_dir, exist_ok=False) 493 | config_fname = os.path.join(log_dir, f"{checkpoint_prefix}.json") 494 | with open(config_fname, "w") as f: 495 | train_session_args = vars(args) 496 | f.write(json.dumps(train_session_args, indent=2)) 497 | 498 | print("Train session :", checkpoint_prefix) 499 | print(" FP16 mode :", fp16) 500 | print(" Fast mode :", args.fast) 501 | print(" Train mode :", train_mode) 502 | print(" Epochs :", num_epochs) 503 | print(" Workers :", num_workers) 504 | print(" Data dir :", data_dir) 505 | print(" Log dir :", log_dir) 506 | print(" Augmentations :", augmentations) 507 | print(" Train size :", "batches", len(loaders["train"]), "dataset", len(train_ds)) 508 | print(" Valid size :", "batches", len(loaders["valid"]), "dataset", len(valid_ds)) 509 | print("Model :", model_name) 510 | print(" Parameters :", count_parameters(model)) 511 | print(" Image size :", image_size) 512 | print("Optimizer :", optimizer_name) 513 | print(" Learning rate :", learning_rate) 514 | print(" Batch size :", batch_size) 515 | print(" Criterion :", criterions) 516 | print(" Use weight mask:", need_weight_mask) 517 | if args.distributed: 518 | print("Distributed") 519 | print(" World size :", args.world_size) 520 | print(" Local rank :", args.local_rank) 521 
| print(" Is master :", args.is_master) 522 | 523 | # model training 524 | runner = SupervisedRunner(input_key=INPUT_IMAGE_KEY, output_key=None, device="cuda") 525 | runner.train( 526 | fp16=distributed_params, 527 | model=model, 528 | criterion=loss_criterions, 529 | optimizer=optimizer, 530 | scheduler=scheduler, 531 | callbacks=callbacks, 532 | loaders=loaders, 533 | logdir=os.path.join(log_dir, "main"), 534 | num_epochs=num_epochs, 535 | verbose=verbose, 536 | main_metric=main_metric, 537 | minimize_metric=False, 538 | checkpoint_data={"cmd_args": vars(args)}, 539 | ) 540 | 541 | # Training is finished. Let's run predictions using best checkpoint weights 542 | if is_master: 543 | best_checkpoint = os.path.join(log_dir, "main", "checkpoints", "best.pth") 544 | 545 | model_checkpoint = os.path.join(log_dir, f"{checkpoint_prefix}.pth") 546 | clean_checkpoint(best_checkpoint, model_checkpoint) 547 | 548 | unpack_checkpoint(torch.load(model_checkpoint), model=model) 549 | 550 | mask = predict( 551 | model, read_inria_image("sample_color.jpg"), image_size=image_size, batch_size=args.batch_size 552 | ) 553 | mask = ((mask > 0) * 255).astype(np.uint8) 554 | name = os.path.join(log_dir, "sample_color.jpg") 555 | cv2.imwrite(name, mask) 556 | 557 | 558 | if __name__ == "__main__": 559 | main() 560 | -------------------------------------------------------------------------------- /inria/augmentations.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import cv2 3 | from typing import Tuple, List 4 | 5 | __all__ = [ 6 | "crop_transform", 7 | "safe_augmentations", 8 | "light_augmentations", 9 | "medium_augmentations", 10 | "hard_augmentations", 11 | "get_augmentations", 12 | ] 13 | 14 | 15 | def crop_transform(image_size: Tuple[int, int], min_scale=0.75, max_scale=1.25, input_size=5000): 16 | return A.OneOrOther( 17 | A.RandomSizedCrop( 18 | (int(image_size[0] * min_scale), int(min(input_size, image_size[0] * max_scale))), 19 | image_size[0], 20 | image_size[1], 21 | ), 22 | A.CropNonEmptyMaskIfExists(image_size[0], image_size[1]), 23 | ) 24 | 25 | 26 | def crop_transform_xview2(image_size: Tuple[int, int], min_scale=0.4, max_scale=0.75, input_size=1024): 27 | return A.OneOrOther( 28 | A.RandomSizedCrop( 29 | (int(image_size[0] * min_scale), int(min(input_size, image_size[0] * max_scale))), 30 | image_size[0], 31 | image_size[1], 32 | ), 33 | A.Compose( 34 | [A.Resize(input_size * 2, input_size * 2), A.CropNonEmptyMaskIfExists(image_size[0], image_size[1])] 35 | ), 36 | ) 37 | 38 | 39 | def safe_augmentations() -> List[A.DualTransform]: 40 | return [ 41 | # D4 Augmentations 42 | A.RandomRotate90(p=1), 43 | A.Transpose(p=0.5), 44 | ] 45 | 46 | 47 | def light_augmentations(mask_dropout=True) -> List[A.DualTransform]: 48 | return [ 49 | # D4 Augmentations 50 | A.RandomRotate90(p=1), 51 | A.Transpose(p=0.5), 52 | A.RandomBrightnessContrast(), 53 | A.ShiftScaleRotate(scale_limit=0.05, rotate_limit=15, border_mode=cv2.BORDER_CONSTANT), 54 | ] 55 | 56 | 57 | def medium_augmentations(mask_dropout=True) -> List[A.DualTransform]: 58 | return [ 59 | A.HorizontalFlip(), 60 | A.ShiftScaleRotate(scale_limit=0.1, rotate_limit=15, border_mode=cv2.BORDER_CONSTANT), 61 | # Add occasion blur/sharpening 62 | A.OneOf([A.GaussianBlur(), A.IAASharpen(), A.NoOp()]), 63 | # Spatial-preserving augmentations: 64 | A.OneOf([A.CoarseDropout(), A.MaskDropout(max_objects=5) if mask_dropout else A.NoOp(), A.NoOp()]), 65 | A.GaussNoise(), 66 | 
A.OneOf([A.RandomBrightnessContrast(), A.CLAHE(), A.HueSaturationValue(), A.RGBShift(), A.RandomGamma()]), 67 | # Weather effects 68 | A.RandomFog(fog_coef_lower=0.01, fog_coef_upper=0.3, p=0.1), 69 | ] 70 | 71 | 72 | def hard_augmentations(mask_dropout=True) -> List[A.DualTransform]: 73 | return [ 74 | # D4 Augmentations 75 | A.RandomRotate90(p=1), 76 | A.Transpose(p=0.5), 77 | # Spatial augmentations 78 | A.OneOf( 79 | [ 80 | A.ShiftScaleRotate(scale_limit=0.2, rotate_limit=45, border_mode=cv2.BORDER_REFLECT101), 81 | A.ElasticTransform(border_mode=cv2.BORDER_REFLECT101, alpha_affine=5), 82 | ] 83 | ), 84 | # Color augmentations 85 | A.OneOf( 86 | [ 87 | A.RandomBrightnessContrast(brightness_by_max=True), 88 | A.CLAHE(), 89 | A.FancyPCA(), 90 | A.HueSaturationValue(), 91 | A.RGBShift(), 92 | A.RandomGamma(), 93 | ] 94 | ), 95 | # Dropout & Shuffle 96 | A.OneOf( 97 | [ 98 | A.RandomGridShuffle(), 99 | A.CoarseDropout(), 100 | A.MaskDropout(max_objects=2, mask_fill_value=0) if mask_dropout else A.NoOp(), 101 | ] 102 | ), 103 | # Add occasion blur 104 | A.OneOf([A.GaussianBlur(), A.GaussNoise(), A.IAAAdditiveGaussianNoise()]), 105 | # Weather effects 106 | A.RandomFog(fog_coef_lower=0.01, fog_coef_upper=0.3, p=0.1), 107 | ] 108 | 109 | 110 | def get_augmentations(augmentation: str) -> List[A.DualTransform]: 111 | if augmentation == "hard": 112 | aug_transform = hard_augmentations() 113 | elif augmentation == "medium": 114 | aug_transform = medium_augmentations() 115 | elif augmentation == "light": 116 | aug_transform = light_augmentations() 117 | elif augmentation == "safe": 118 | aug_transform = safe_augmentations() 119 | else: 120 | aug_transform = [] 121 | 122 | return aug_transform 123 | -------------------------------------------------------------------------------- /inria/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Callable, Optional, Tuple 3 | 4 | import albumentations as A 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from PIL import Image 10 | from pytorch_toolbelt.inference.tiles import ImageSlicer 11 | from pytorch_toolbelt.utils import fs 12 | from pytorch_toolbelt.utils.catalyst import PseudolabelDatasetMixin 13 | from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, tensor_from_mask_image, image_to_tensor 14 | from scipy.ndimage import binary_dilation, binary_erosion 15 | from torch.utils.data import WeightedRandomSampler, Dataset, ConcatDataset 16 | 17 | from .augmentations import * 18 | from .augmentations import crop_transform_xview2, get_augmentations 19 | 20 | INPUT_IMAGE_KEY = "image" 21 | INPUT_IMAGE_ID_KEY = "image_id" 22 | INPUT_MASK_KEY = "true_mask" 23 | INPUT_MASK_WEIGHT_KEY = "true_weights" 24 | OUTPUT_MASK_KEY = "pred_mask" 25 | OUTPUT_OFFSET_KEY = "pred_offset" 26 | INPUT_INDEX_KEY = "index" 27 | 28 | # Smaller masks for deep supervision 29 | def output_mask_name_for_stride(stride: int): 30 | return f"pred_mask_{stride}" 31 | 32 | 33 | OUTPUT_MASK_2_KEY = output_mask_name_for_stride(2) 34 | OUTPUT_MASK_4_KEY = output_mask_name_for_stride(4) 35 | OUTPUT_MASK_8_KEY = output_mask_name_for_stride(8) 36 | OUTPUT_MASK_16_KEY = output_mask_name_for_stride(16) 37 | OUTPUT_MASK_32_KEY = output_mask_name_for_stride(32) 38 | OUTPUT_MASK_64_KEY = output_mask_name_for_stride(64) 39 | 40 | OUTPUT_DSV_MASK_1_KEY = "output_dsv_mask_1" 41 | OUTPUT_DSV_MASK_2_KEY = "output_dsv_mask_2" 42 | OUTPUT_DSV_MASK_3_KEY = "output_dsv_mask_3" 43 | 
OUTPUT_DSV_MASK_4_KEY = "output_dsv_mask_4" 44 | OUTPUT_DSV_MASK_5_KEY = "output_dsv_mask_5" 45 | OUTPUT_DSV_MASK_6_KEY = "output_dsv_mask_6" 46 | 47 | 48 | OUTPUT_CLASS_KEY = "pred_classes" 49 | 50 | UNLABELED_SAMPLE = 127 51 | 52 | # NOISY SAMPLES 53 | # chicago27 54 | # vienna30 55 | # austin23 56 | # chicago26 57 | 58 | TRAIN_LOCATIONS = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"] 59 | TEST_LOCATIONS = ["bellingham", "bloomington", "innsbruck", "sfo", "tyrol-e"] 60 | 61 | 62 | def read_inria_image(fname): 63 | image = cv2.imread(fname) 64 | if image is None: 65 | raise IOError("Cannot read " + fname) 66 | return image 67 | 68 | 69 | def read_inria_mask(fname): 70 | mask = fs.read_image_as_is(fname) 71 | if mask is None: 72 | raise IOError("Cannot read " + fname) 73 | cv2.threshold(mask, thresh=0, maxval=1, type=cv2.THRESH_BINARY, dst=mask) 74 | return mask 75 | 76 | 77 | def read_inria_mask_with_pseudolabel(fname): 78 | mask = fs.read_image_as_is(fname).astype(np.float32) / 255.0 79 | return mask 80 | 81 | 82 | def read_xview_mask(fname): 83 | mask = np.array(Image.open(fname))  # Read using PIL since it supports paletted images 84 | if len(mask.shape) == 3: 85 | mask = np.squeeze(mask, axis=-1) 86 | return mask 87 | 88 | 89 | def compute_weight_mask(mask: np.ndarray, edge_weight=4) -> np.ndarray: 90 | binary_mask = mask > 0 91 | weight_mask = np.ones(mask.shape[:2]).astype(np.float32) 92 | 93 | if binary_mask.any(): 94 | dilated = binary_dilation(binary_mask, structure=np.ones((5, 5), dtype=bool)) 95 | eroded = binary_erosion(binary_mask, structure=np.ones((5, 5), dtype=bool)) 96 | 97 | a = dilated & ~binary_mask 98 | b = binary_mask & ~eroded 99 | 100 | weight_mask = (a | b).astype(np.float32) * edge_weight + 1 101 | weight_mask = cv2.GaussianBlur(weight_mask, ksize=(5, 5), sigmaX=5) 102 | return weight_mask 103 | 104 | 105 | def mask2depth(mask: np.ndarray): 106 | """ 107 | Take a binary mask image and convert it to a mask with stride 2, where each pixel encodes its 2x2 block as one of 16 classes 108 | :param mask: 109 | :return: 110 | """ 111 | mask = (mask > 0).astype(np.uint8) 112 | mask = mask.reshape((mask.shape[0] // 2, 2, mask.shape[1] // 2, 2)) 113 | mask = np.transpose(mask, (0, 2, 1, 3))  # [R/2, C/2, 2,2] 114 | mask = mask.reshape((mask.shape[0], mask.shape[1], 4)) 115 | mask = np.packbits(mask, axis=-1, bitorder="little") 116 | return mask.squeeze(-1) 117 | 118 | 119 | def depth2mask(mask: np.ndarray): 120 | mask = mask.reshape((mask.shape[0], mask.shape[1], 1)) 121 | mask = np.unpackbits(mask.astype(np.uint8), axis=-1, count=4, bitorder="little") 122 | mask = mask.reshape((mask.shape[0], mask.shape[1], 2, 2)) 123 | mask = np.transpose(mask, (0, 2, 1, 3)) 124 | mask = mask.reshape((mask.shape[0] * 2, mask.shape[2] * 2)) 125 | return mask 126 | 127 | 128 | def decode_depth_mask(mask: np.ndarray): 129 | mask = np.argmax(mask, axis=0) 130 | return depth2mask(mask) 131 | 132 | 133 | def mask_to_bce_target(mask): 134 | return image_to_tensor(mask).float() 135 | 136 | 137 | def mask_to_ce_target(mask): 138 | mask = mask2depth(mask) 139 | return torch.from_numpy(mask).long() 140 | 141 | 142 | class InriaImageMaskDataset(Dataset, PseudolabelDatasetMixin): 143 | def __init__( 144 | self, 145 | image_filenames: List[str], 146 | mask_filenames: Optional[List[str]], 147 | transform: A.Compose, 148 | image_loader=read_inria_image, 149 | mask_loader=read_inria_mask, 150 | need_weight_mask=False, 151 | image_ids=None, 152 | make_mask_target_fn: Callable = mask_to_bce_target, 153 | ): 154 | if 
mask_filenames is not None and len(image_filenames) != len(mask_filenames): 155 | raise ValueError("Number of images does not correspond to number of targets") 156 | 157 | self.image_ids = [fs.id_from_fname(fname) for fname in image_filenames] if image_ids is None else image_ids 158 | self.need_weight_mask = need_weight_mask 159 | 160 | self.images = image_filenames 161 | self.masks = mask_filenames 162 | self.get_image = image_loader 163 | self.get_mask = mask_loader 164 | 165 | self.transform = transform 166 | self.make_mask_target_fn = make_mask_target_fn 167 | 168 | def __len__(self): 169 | return len(self.images) 170 | 171 | def set_target(self, index: int, value: np.ndarray): 172 | mask_fname = self.masks[index] 173 | 174 | value = (value * 255).astype(np.uint8) 175 | cv2.imwrite(mask_fname, value) 176 | 177 | def __getitem__(self, index): 178 | image = self.get_image(self.images[index]) 179 | 180 | if self.masks is not None: 181 | mask = self.get_mask(self.masks[index]) 182 | else: 183 | mask = np.ones((image.shape[0], image.shape[1], 1), dtype=np.uint8) * UNLABELED_SAMPLE 184 | 185 | data = self.transform(image=image, mask=mask) 186 | 187 | sample = { 188 | INPUT_IMAGE_KEY: image_to_tensor(data["image"]), 189 | INPUT_IMAGE_ID_KEY: self.image_ids[index], 190 | INPUT_INDEX_KEY: index, 191 | INPUT_MASK_KEY: self.make_mask_target_fn(data["mask"]), 192 | } 193 | 194 | if self.need_weight_mask: 195 | sample[INPUT_MASK_WEIGHT_KEY] = image_to_tensor(compute_weight_mask(data["mask"])).float() 196 | 197 | return sample 198 | 199 | 200 | class _InrialTiledImageMaskDataset(Dataset): 201 | def __init__( 202 | self, 203 | image_fname: str, 204 | mask_fname: str, 205 | image_loader: Callable, 206 | target_loader: Callable, 207 | tile_size, 208 | tile_step, 209 | image_margin=0, 210 | transform=None, 211 | target_shape=None, 212 | need_weight_mask=False, 213 | keep_in_mem=False, 214 | make_mask_target_fn: Callable = mask_to_bce_target, 215 | ): 216 | self.image_fname = image_fname 217 | self.mask_fname = mask_fname 218 | self.image_loader = image_loader 219 | self.mask_loader = target_loader 220 | self.image = None 221 | self.mask = None 222 | self.need_weight_mask = need_weight_mask 223 | 224 | if target_shape is None or keep_in_mem: 225 | image = image_loader(image_fname) 226 | mask = target_loader(mask_fname) 227 | if image.shape[0] != mask.shape[0] or image.shape[1] != mask.shape[1]: 228 | raise ValueError( 229 | f"Image size {image.shape} and mask shape {mask.shape} must have equal width and height" 230 | ) 231 | 232 | target_shape = image.shape 233 | 234 | self.slicer = ImageSlicer(target_shape, tile_size, tile_step, image_margin) 235 | 236 | self.transform = transform 237 | self.image_ids = [fs.id_from_fname(image_fname)] * len(self.slicer.crops) 238 | self.crop_coords_str = [f"[{crop[0]};{crop[1]};{crop[2]};{crop[3]};]" for crop in self.slicer.crops] 239 | self.make_mask_target_fn = make_mask_target_fn 240 | 241 | def _get_image(self, index): 242 | image = self.image_loader(self.image_fname) 243 | image = self.slicer.cut_patch(image, index) 244 | return image 245 | 246 | def _get_mask(self, index): 247 | mask = self.mask_loader(self.mask_fname) 248 | mask = self.slicer.cut_patch(mask, index) 249 | return mask 250 | 251 | def __len__(self): 252 | return len(self.slicer.crops) 253 | 254 | def __getitem__(self, index): 255 | image = self._get_image(index) 256 | mask = self._get_mask(index) 257 | data = self.transform(image=image, mask=mask) 258 | 259 | image = data["image"] 260 | mask = 
data["mask"] 261 | 262 | data = { 263 | INPUT_IMAGE_KEY: image_to_tensor(image), 264 | INPUT_MASK_KEY: self.make_mask_target_fn(mask), 265 | INPUT_IMAGE_ID_KEY: self.image_ids[index], 266 | "crop_coords": self.crop_coords_str[index], 267 | } 268 | 269 | if self.need_weight_mask: 270 | data[INPUT_MASK_WEIGHT_KEY] = tensor_from_mask_image(compute_weight_mask(mask)).float() 271 | 272 | return data 273 | 274 | 275 | class InrialTiledImageMaskDataset(ConcatDataset): 276 | def __init__( 277 | self, 278 | image_filenames: List[str], 279 | target_filenames: List[str], 280 | image_loader=read_inria_image, 281 | target_loader=read_inria_mask, 282 | need_weight_mask=False, 283 | **kwargs, 284 | ): 285 | if len(image_filenames) != len(target_filenames): 286 | raise ValueError("Number of images does not correspond to number of targets") 287 | 288 | datasets = [] 289 | for image, mask in zip(image_filenames, target_filenames): 290 | dataset = _InrialTiledImageMaskDataset( 291 | image, mask, image_loader, target_loader, need_weight_mask=need_weight_mask, **kwargs 292 | ) 293 | datasets.append(dataset) 294 | super().__init__(datasets) 295 | 296 | 297 | def get_datasets( 298 | data_dir: str, 299 | image_size=(224, 224), 300 | augmentation="hard", 301 | train_mode="random", 302 | sanity_check=False, 303 | fast=False, 304 | buildings_only=True, 305 | need_weight_mask=False, 306 | make_mask_target_fn: Callable = mask_to_bce_target, 307 | ) -> Tuple[Dataset, Dataset, Optional[WeightedRandomSampler]]: 308 | """ 309 | Create train and validation datasets 310 | :param data_dir: Inria dataset directory 311 | :param fast: Fast training mode. Use only one image per location for training and one image per location for validation 312 | :param image_size: Size of image crops during training & validation 313 | :param augmentation: Type of image augmentations to use 314 | :param train_mode: 315 | 'random' - crops tiles from source images randomly. 316 | 'tiles' - crop image in overlapping tiles (guaranteed to process entire dataset) 317 | :return: (trainset, validset, train_sampler) 318 | """ 319 | 320 | normalize = A.Normalize() 321 | 322 | assert train_mode in {"random", "tiles"} 323 | locations = TRAIN_LOCATIONS 324 | 325 | valid_transform = normalize 326 | train_augmentation = get_augmentations(augmentation) 327 | 328 | if train_mode == "random": 329 | 330 | train_data = [] 331 | valid_data = [] 332 | 333 | # For validation, we remove the first five images of every location (e.g., austin{1-5}.tif, chicago{1-5}.tif) from the training set. 334 | # This is the validation strategy suggested by the competition host 335 | 336 | if fast: 337 | # Fast training mode. 
Use only one image per location for training and one image per location for validation 338 | for loc in locations: 339 | valid_data.append(f"{loc}1") 340 | train_data.append(f"{loc}6") 341 | else: 342 | for loc in locations: 343 | for i in range(1, 6): 344 | valid_data.append(f"{loc}{i}") 345 | for i in range(6, 37): 346 | train_data.append(f"{loc}{i}") 347 | 348 | train_img = [os.path.join(data_dir, "train", "images", f"{fname}.tif") for fname in train_data] 349 | valid_img = [os.path.join(data_dir, "train", "images", f"{fname}.tif") for fname in valid_data] 350 | 351 | train_mask = [os.path.join(data_dir, "train", "gt", f"{fname}.tif") for fname in train_data] 352 | valid_mask = [os.path.join(data_dir, "train", "gt", f"{fname}.tif") for fname in valid_data] 353 | 354 | train_crop = crop_transform(image_size, input_size=5000) 355 | train_transform = A.Compose([train_crop] + train_augmentation + [normalize]) 356 | 357 | trainset = InriaImageMaskDataset( 358 | train_img, 359 | train_mask, 360 | need_weight_mask=need_weight_mask, 361 | transform=train_transform, 362 | make_mask_target_fn=make_mask_target_fn, 363 | ) 364 | 365 | num_train_samples = int(len(trainset) * (5000 * 5000) / (image_size[0] * image_size[1])) 366 | crops_in_image = (5000 * 5000) / (image_size[0] * image_size[1]) 367 | if fast: 368 | num_train_samples = 128 369 | 370 | train_sampler = WeightedRandomSampler(torch.ones(len(trainset)) * crops_in_image, num_train_samples) 371 | 372 | validset = InrialTiledImageMaskDataset( 373 | valid_img, 374 | valid_mask, 375 | transform=valid_transform, 376 | # For validation we don't want tiles overlap 377 | tile_size=image_size, 378 | tile_step=image_size, 379 | target_shape=(5000, 5000), 380 | need_weight_mask=need_weight_mask, 381 | make_mask_target_fn=make_mask_target_fn, 382 | ) 383 | 384 | elif train_mode == "tiles": 385 | inria_tiles = pd.read_csv(os.path.join(data_dir, "inria_tiles.csv")) 386 | inria_tiles["image"] = inria_tiles["image"].apply(lambda x: os.path.join(data_dir, x)) 387 | inria_tiles["mask"] = inria_tiles["mask"].apply(lambda x: os.path.join(data_dir, x)) 388 | 389 | if buildings_only: 390 | inria_tiles = inria_tiles[inria_tiles["has_buildings"]] 391 | 392 | train_img = inria_tiles[inria_tiles["train"] == 1]["image"].tolist() 393 | train_mask = inria_tiles[inria_tiles["train"] == 1]["mask"].tolist() 394 | train_img_ids = inria_tiles[inria_tiles["train"] == 1]["image_id"].tolist() 395 | 396 | if fast: 397 | train_img = train_img[:128] 398 | train_mask = train_mask[:128] 399 | train_img_ids = train_img_ids[:128] 400 | 401 | train_crop = crop_transform(image_size, input_size=768) 402 | train_transform = A.Compose([train_crop] + train_augmentation + [normalize]) 403 | 404 | trainset = InriaImageMaskDataset( 405 | train_img, 406 | train_mask, 407 | image_ids=train_img_ids, 408 | need_weight_mask=need_weight_mask, 409 | transform=train_transform, 410 | make_mask_target_fn=make_mask_target_fn, 411 | ) 412 | 413 | valid_data = [] 414 | for loc in locations: 415 | for i in range(1, 6): 416 | valid_data.append(f"{loc}{i}") 417 | 418 | valid_img = [os.path.join(data_dir, "train", "images", f"{fname}.tif") for fname in valid_data] 419 | valid_mask = [os.path.join(data_dir, "train", "gt", f"{fname}.tif") for fname in valid_data] 420 | 421 | if fast: 422 | valid_img = valid_img[0:1] 423 | valid_mask = valid_mask[0:1] 424 | 425 | validset = InrialTiledImageMaskDataset( 426 | valid_img, 427 | valid_mask, 428 | transform=valid_transform, 429 | # For validation we don't want tiles 
overlap 430 | tile_size=image_size, 431 | tile_step=image_size, 432 | target_shape=(5000, 5000), 433 | need_weight_mask=need_weight_mask, 434 | make_mask_target_fn=make_mask_target_fn, 435 | ) 436 | 437 | train_sampler = None 438 | else: 439 | raise ValueError(train_mode) 440 | 441 | if sanity_check: 442 | first_batch = [trainset[i] for i in range(32)] 443 | return first_batch * 50, first_batch, None 444 | 445 | return trainset, validset, train_sampler 446 | 447 | 448 | def get_xview2_extra_dataset( 449 | data_dir: str, image_size=(224, 224), augmentation="hard", need_weight_mask=False, fast=False 450 | ) -> Tuple[Dataset, WeightedRandomSampler]: 451 | """ 452 | Create an additional train dataset from the xView2 dataset 453 | :param data_dir: xView2 dataset directory 454 | :param fast: Fast training mode. Use only the first 128 images of each split for training 455 | :param image_size: Size of image crops during training & validation 456 | :param need_weight_mask: If True, adds 'edge' target mask 457 | :param augmentation: Type of image augmentations to use 458 | ('safe', 'light', 'medium' or 'hard'; 459 | anything else disables augmentations) 460 | :return: (trainset, train_sampler) 461 | """ 462 | 463 | if augmentation == "hard": 464 | train_transform = hard_augmentations() 465 | elif augmentation == "medium": 466 | train_transform = medium_augmentations() 467 | elif augmentation == "light": 468 | train_transform = light_augmentations() 469 | elif augmentation == "safe": 470 | train_transform = safe_augmentations() 471 | else: 472 | train_transform = [] 473 | 474 | def is_pre_image(fname): 475 | return "_pre_" in fname 476 | 477 | train1_img = list(filter(is_pre_image, fs.find_images_in_dir(os.path.join(data_dir, "train", "images")))) 478 | train1_msk = list(filter(is_pre_image, fs.find_images_in_dir(os.path.join(data_dir, "train", "masks")))) 479 | 480 | train2_img = list(filter(is_pre_image, fs.find_images_in_dir(os.path.join(data_dir, "tier3", "images")))) 481 | train2_msk = list(filter(is_pre_image, fs.find_images_in_dir(os.path.join(data_dir, "tier3", "masks")))) 482 | 483 | if fast: 484 | train1_img = train1_img[:128] 485 | train1_msk = train1_msk[:128] 486 | 487 | train2_img = train2_img[:128] 488 | train2_msk = train2_msk[:128] 489 | 490 | train_transform = A.Compose([crop_transform_xview2(image_size, input_size=1024)] + train_transform) 491 | 492 | trainset = InriaImageMaskDataset( 493 | image_filenames=train1_img + train2_img, 494 | mask_filenames=train1_msk + train2_msk, 495 | transform=train_transform, 496 | mask_loader=read_xview_mask, 497 | need_weight_mask=need_weight_mask, 498 | ) 499 | 500 | num_train_samples = int(len(trainset) * (1024 * 1024) / (image_size[0] * image_size[1])) 501 | crops_in_image = (1024 * 1024) / (image_size[0] * image_size[1]) 502 | if fast: 503 | num_train_samples = 128 504 | 505 | train_sampler = WeightedRandomSampler(torch.ones(len(trainset)) * crops_in_image, num_train_samples) 506 | 507 | return trainset, None if fast else train_sampler 508 | 509 | 510 | def get_pseudolabeling_dataset( 511 | data_dir: str, include_masks: bool, image_size=(224, 224), augmentation=None, need_weight_mask=False 512 | ): 513 | images = fs.find_images_in_dir(os.path.join(data_dir, "test_tiles", "images")) 514 | 515 | masks_dir = os.path.join(data_dir, "test_tiles", "masks") 516 | os.makedirs(masks_dir, exist_ok=True) 517 | 518 | masks = [os.path.join(masks_dir, 
fs.id_from_fname(image_fname) + ".png") for image_fname in images] 519 | 520 | normalize = A.Normalize() 521 | 522 | if augmentation == "hard": 523 | augs = hard_augmentations(mask_dropout=False) 524 | crop = [crop_transform(image_size, input_size=768)] 525 | elif augmentation == "medium": 526 | augs = medium_augmentations(mask_dropout=False) 527 | crop = [crop_transform(image_size, input_size=768)] 528 | elif augmentation == "light": 529 | augs = light_augmentations(mask_dropout=False) 530 | crop = [crop_transform(image_size, input_size=768)] 531 | else: 532 | augs = [] 533 | crop = [] 534 | 535 | transfrom = A.Compose(crop + augs + [normalize]) 536 | return InriaImageMaskDataset( 537 | images, 538 | masks if include_masks else None, 539 | transform=transfrom, 540 | image_loader=read_inria_image, 541 | mask_loader=read_inria_mask_with_pseudolabel, 542 | need_weight_mask=need_weight_mask, 543 | ) 544 | -------------------------------------------------------------------------------- /inria/factory.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.pool import Pool 2 | from typing import List, Dict 3 | 4 | import albumentations as A 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from pytorch_toolbelt.inference.tiles import CudaTileMerger, ImageSlicer, TileMerger 9 | from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, to_numpy 10 | from torch import nn 11 | from torch.utils.data import Dataset, DataLoader 12 | from tqdm import tqdm 13 | 14 | from inria.dataset import OUTPUT_MASK_KEY 15 | 16 | 17 | class InMemoryDataset(Dataset): 18 | def __init__(self, data: List[Dict], transform: A.Compose): 19 | self.data = data 20 | self.transform = transform 21 | 22 | def __len__(self): 23 | return len(self.data) 24 | 25 | def __getitem__(self, item): 26 | return self.transform(**self.data[item]) 27 | 28 | 29 | def _tensor_from_rgb_image(image: np.ndarray, **kwargs): 30 | return tensor_from_rgb_image(image) 31 | 32 | 33 | class PickModelOutput(nn.Module): 34 | def __init__(self, model, key): 35 | super().__init__() 36 | self.model = model 37 | self.target_key = key 38 | 39 | def forward(self, input): 40 | output = self.model(input) 41 | return output[self.target_key] 42 | 43 | 44 | @torch.no_grad() 45 | def predict(model: nn.Module, image: np.ndarray, image_size, normalize=A.Normalize(), batch_size=1) -> np.ndarray: 46 | 47 | tile_step = (image_size[0] // 2, image_size[1] // 2) 48 | 49 | tile_slicer = ImageSlicer(image.shape, image_size, tile_step) 50 | tile_merger = TileMerger(tile_slicer.target_shape, 1, tile_slicer.weight, device="cuda") 51 | patches = tile_slicer.split(image) 52 | 53 | transform = A.Compose([normalize, A.Lambda(image=_tensor_from_rgb_image)]) 54 | 55 | data = list( 56 | {"image": patch, "coords": np.array(coords, dtype=np.int)} 57 | for (patch, coords) in zip(patches, tile_slicer.crops) 58 | ) 59 | for batch in DataLoader( 60 | InMemoryDataset(data, transform), 61 | pin_memory=True, 62 | batch_size=batch_size, 63 | num_workers=4, 64 | shuffle=False, 65 | drop_last=False, 66 | ): 67 | image = batch["image"].cuda(non_blocking=True) 68 | coords = batch["coords"] 69 | output = model(image) 70 | tile_merger.integrate_batch(output, coords) 71 | 72 | mask = tile_merger.merge() 73 | 74 | mask = np.moveaxis(to_numpy(mask), 0, -1) 75 | mask = tile_slicer.crop_to_orignal_size(mask) 76 | 77 | return mask 78 | -------------------------------------------------------------------------------- /inria/losses.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from pytorch_toolbelt.losses import * 6 | from torch import nn 7 | from torch.nn import KLDivLoss 8 | 9 | from inria.dataset import INPUT_MASK_KEY, INPUT_MASK_WEIGHT_KEY 10 | 11 | __all__ = ["get_loss", "WeightedBCEWithLogits", "KLDivLossWithLogits"] 12 | 13 | 14 | class BinaryKLDivLossWithLogits(KLDivLoss): 15 | """ 16 | """ 17 | 18 | def __init__(self, ignore_index=None): 19 | super().__init__() 20 | self.ignore_index = ignore_index 21 | 22 | def forward(self, input, target): 23 | # Resize target to size of input 24 | input_size = input.size()[2:] 25 | target_size = target.size()[2:] 26 | if input_size != target_size: 27 | if self.ignore_index is not None: 28 | raise ValueError("In case ignore_index is not None, input and output tensors must have equal size") 29 | target = F.interpolate(target, size=input_size, mode="bilinear", align_corners=False) 30 | 31 | if self.ignore_index is not None: 32 | mask = target != self.ignore_index 33 | input = input[mask] 34 | target = target[mask] 35 | 36 | if len(target) == 0: 37 | return 0 38 | 39 | input = torch.cat([input, 1 - input], dim=1) 40 | log_p = F.logsigmoid(input) 41 | 42 | target = torch.cat([target, 1 - target], dim=1) 43 | 44 | loss = F.kl_div(log_p, target, reduction="mean") 45 | return loss 46 | 47 | 48 | class ResizePredictionTarget2d(nn.Module): 49 | """ 50 | Wrapper around loss, that rescale model output to target size 51 | """ 52 | 53 | def __init__(self, loss): 54 | super().__init__() 55 | self.loss = loss 56 | 57 | def forward(self, input, target): 58 | input = F.interpolate(input, target.size()[2:], mode="bilinear", align_corners=False) 59 | return self.loss(input, target) 60 | 61 | 62 | class ResizeTargetToPrediction2d(nn.Module): 63 | """ 64 | Wrapper around loss, that rescale target tensor to the size of output of the model. 
65 | Note: This will corrupt binary labels and not indended for multiclass case 66 | """ 67 | 68 | def __init__(self, loss): 69 | super().__init__() 70 | self.loss = loss 71 | 72 | def forward(self, input, target): 73 | target = F.interpolate(target, input.size()[2:], mode="bilinear", align_corners=False) 74 | return self.loss(input, target) 75 | 76 | 77 | 78 | class WeightedBCEWithLogits(nn.Module): 79 | def __init__(self, mask_key, weight_key, ignore_index: Optional[int] = -100, reduction="mean"): 80 | super().__init__() 81 | self.ignore_index = ignore_index 82 | self.reduction = reduction 83 | self.weight_key = weight_key 84 | self.mask_key = mask_key 85 | 86 | def forward(self, label_input, target: Dict[str, torch.Tensor]): 87 | targets = target[self.mask_key] 88 | weights = target[self.weight_key] 89 | 90 | if self.ignore_index is not None: 91 | not_ignored_mask = (targets != self.ignore_index).float() 92 | 93 | loss = F.binary_cross_entropy_with_logits(label_input, targets, reduction="none") * weights 94 | 95 | if self.ignore_index is not None: 96 | loss = loss * not_ignored_mask.float() 97 | 98 | if self.reduction == "mean": 99 | loss = loss.mean() 100 | 101 | if self.reduction == "sum": 102 | loss = loss.sum() 103 | 104 | return loss 105 | 106 | 107 | class KLDivLossWithLogits(KLDivLoss): 108 | """ 109 | """ 110 | 111 | def __init__(self): 112 | super().__init__() 113 | 114 | def forward(self, input, target): 115 | 116 | # Resize target to size of input 117 | target = F.interpolate(target, size=input.size()[2:], mode="bilinear", align_corners=False) 118 | 119 | input = torch.cat([input, 1 - input], dim=1) 120 | log_p = F.logsigmoid(input) 121 | 122 | target = torch.cat([target, 1 - target], dim=1) 123 | 124 | loss = F.kl_div(log_p, target, reduction="mean") 125 | return loss 126 | 127 | 128 | def get_loss(loss_name: str, ignore_index=None): 129 | if loss_name.lower() == "kl": 130 | return KLDivLossWithLogits() 131 | 132 | if loss_name.lower() == "bce": 133 | return SoftBCEWithLogitsLoss(ignore_index=ignore_index) 134 | 135 | if loss_name.lower() == "ce": 136 | return nn.CrossEntropyLoss() 137 | 138 | if loss_name.lower() == "wbce": 139 | return WeightedBCEWithLogits( 140 | mask_key=INPUT_MASK_KEY, weight_key=INPUT_MASK_WEIGHT_KEY, ignore_index=ignore_index 141 | ) 142 | 143 | if loss_name.lower() == "soft_bce": 144 | return SoftBCEWithLogitsLoss(smooth_factor=0.1, ignore_index=ignore_index) 145 | 146 | if loss_name.lower() == "focal": 147 | return BinaryFocalLoss(alpha=None, gamma=1.5, ignore_index=ignore_index) 148 | 149 | if loss_name.lower() == "jaccard": 150 | assert ignore_index is None 151 | return JaccardLoss(mode="binary") 152 | 153 | if loss_name.lower() == "lovasz": 154 | assert ignore_index is None 155 | return BinaryLovaszLoss() 156 | 157 | if loss_name.lower() == "log_jaccard": 158 | assert ignore_index is None 159 | return JaccardLoss(mode="binary", log_loss=True) 160 | 161 | if loss_name.lower() == "dice": 162 | assert ignore_index is None 163 | return DiceLoss(mode="binary", log_loss=False) 164 | 165 | if loss_name.lower() == "log_dice": 166 | assert ignore_index is None 167 | return DiceLoss(mode="binary", log_loss=True) 168 | 169 | raise KeyError(loss_name) 170 | -------------------------------------------------------------------------------- /inria/metric.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Callable 3 | 4 | import numpy as np 5 | import torch 6 | from 
pytorch_toolbelt.utils import to_numpy 7 | from pytorch_toolbelt.utils.catalyst import get_tensorboard_logger 8 | from pytorch_toolbelt.utils.distributed import all_gather 9 | from catalyst.core import Callback, CallbackNode, CallbackOrder, IRunner 10 | from catalyst.dl import registry 11 | from catalyst.utils.distributed import get_rank 12 | 13 | __all__ = ["JaccardMetricPerImage", "JaccardMetricPerImageWithOptimalThreshold"] 14 | 15 | 16 | @registry.Callback 17 | class JaccardMetricPerImage(Callback): 18 | """ 19 | Jaccard metric callback which computes IoU metric per image and is aware that image is tiled. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | inputs_to_labels: Callable, 25 | outputs_to_labels: Callable, 26 | input_key: str = "targets", 27 | output_key: str = "logits", 28 | image_id_key: str = "image_id", 29 | prefix: str = "jaccard", 30 | ): 31 | super().__init__(CallbackOrder.Metric, CallbackNode.All) 32 | """ 33 | :param input_key: input key to use for precision calculation; specifies our `y_true`. 34 | :param output_key: output key to use for precision calculation; specifies our `y_pred`. 35 | """ 36 | self.prefix = prefix 37 | self.output_key = output_key 38 | self.input_key = input_key 39 | self.image_id_key = image_id_key 40 | self.scores_per_image = {} 41 | self.locations = ["austin", "chicago", "kitsap", "tyrol-w", "vienna"] 42 | self.inputs_to_labels = inputs_to_labels 43 | self.outputs_to_labels = outputs_to_labels 44 | 45 | def on_loader_start(self, state): 46 | self.scores_per_image = {} 47 | 48 | def on_batch_end(self, runner: IRunner): 49 | image_ids = runner.input[self.image_id_key] 50 | outputs = to_numpy(runner.output[self.output_key].detach()) 51 | targets = to_numpy(runner.input[self.input_key].detach()) 52 | 53 | for img_id, y_true, y_pred in zip(image_ids, targets, outputs): 54 | if img_id not in self.scores_per_image: 55 | self.scores_per_image[img_id] = {"intersection": 0, "union": 0} 56 | 57 | y_true_labels = self.inputs_to_labels(y_true) 58 | y_pred_labels = self.outputs_to_labels(y_pred) 59 | intersection = (y_true_labels * y_pred_labels).sum() 60 | union = y_true_labels.sum() + y_pred_labels.sum() - intersection 61 | 62 | self.scores_per_image[img_id]["intersection"] += float(intersection) 63 | self.scores_per_image[img_id]["union"] += float(union) 64 | 65 | def on_loader_end(self, runner: IRunner): 66 | # Gather statistics from all nodes 67 | gathered_scores_per_image = all_gather(self.scores_per_image) 68 | all_scores_per_image = defaultdict(lambda: {"intersection": 0.0, "union": 0.0}) 69 | for scores_per_image in gathered_scores_per_image: 70 | for image_id, values in scores_per_image.items(): 71 | all_scores_per_image[image_id]["intersection"] += values["intersection"] 72 | all_scores_per_image[image_id]["union"] += values["union"] 73 | 74 | eps = 1e-7 75 | ious_per_image = [] 76 | ious_per_location = defaultdict(list) 77 | 78 | for image_id, values in all_scores_per_image.items(): 79 | intersection = values["intersection"] 80 | union = values["union"] 81 | metric = intersection / (union + eps) 82 | ious_per_image.append(metric) 83 | 84 | for location in self.locations: 85 | if str.startswith(image_id, location): 86 | ious_per_location[location].append(metric) 87 | 88 | metric = float(np.mean(ious_per_image)) 89 | runner.loader_metrics[self.prefix] = metric 90 | 91 | for location, ious in ious_per_location.items(): 92 | runner.loader_metrics[f"{self.prefix}/{location}"] = float(np.mean(ious)) 93 | 94 | 95 | class 
JaccardMetricPerImageWithOptimalThreshold(Callback): 96 | """ 97 | Callback that computes an optimal threshold for binarizing logits and theoretical IoU score at given threshold. 98 | """ 99 | 100 | def __init__( 101 | self, 102 | input_key: str = "targets", 103 | output_key: str = "logits", 104 | image_id_key: str = "image_id", 105 | prefix: str = "optimal_threshold", 106 | ): 107 | super().__init__(CallbackOrder.Metric) 108 | """ 109 | :param input_key: input key to use for precision calculation; specifies our `y_true`. 110 | :param output_key: output key to use for precision calculation; specifies our `y_pred`. 111 | """ 112 | self.prefix = prefix 113 | self.output_key = output_key 114 | self.input_key = input_key 115 | self.image_id_key = image_id_key 116 | self.thresholds = torch.arange(0.3, 0.6, 0.025).detach() 117 | self.scores_per_image = {} 118 | 119 | def on_loader_start(self, runner: IRunner): 120 | self.scores_per_image = {} 121 | 122 | @torch.no_grad() 123 | def on_batch_end(self, runner: IRunner): 124 | image_id = runner.input[self.image_id_key] 125 | outputs = runner.output[self.output_key].detach().sigmoid() 126 | targets = runner.input[self.input_key].detach() 127 | 128 | # Flatten images for easy computing IoU 129 | assert outputs.size(1) == 1 130 | assert targets.size(1) == 1 131 | outputs = outputs.view(outputs.size(0), -1, 1) > self.thresholds.to(outputs.dtype).to(outputs.device).view( 132 | 1, 1, -1 133 | ) 134 | targets = targets.view(targets.size(0), -1) == 1 135 | n = len(self.thresholds) 136 | 137 | for i, threshold in enumerate(self.thresholds): 138 | # Binarize outputs 139 | outputs_i = outputs[..., i] 140 | intersection = torch.sum(targets & outputs_i, dim=1) 141 | union = torch.sum(targets | outputs_i, dim=1) 142 | 143 | for img_id, img_intersection, img_union in zip(image_id, intersection, union): 144 | if img_id not in self.scores_per_image: 145 | self.scores_per_image[img_id] = {"intersection": np.zeros(n), "union": np.zeros(n)} 146 | 147 | self.scores_per_image[img_id]["intersection"][i] += float(img_intersection) 148 | self.scores_per_image[img_id]["union"][i] += float(img_union) 149 | 150 | def on_loader_end(self, runner: IRunner): 151 | eps = 1e-7 152 | ious_per_image = [] 153 | 154 | # Gather statistics from all nodes 155 | all_gathered_scores_per_image = all_gather(self.scores_per_image) 156 | 157 | n = len(self.thresholds) 158 | all_scores_per_image = defaultdict(lambda: {"intersection": np.zeros(n), "union": np.zeros(n)}) 159 | for scores_per_image in all_gathered_scores_per_image: 160 | for image_id, values in scores_per_image.items(): 161 | all_scores_per_image[image_id]["intersection"] += values["intersection"] 162 | all_scores_per_image[image_id]["union"] += values["union"] 163 | 164 | for image_id, values in all_scores_per_image.items(): 165 | intersection = values["intersection"] 166 | union = values["union"] 167 | metric = intersection / (union + eps) 168 | ious_per_image.append(metric) 169 | 170 | thresholds = to_numpy(self.thresholds) 171 | iou = np.mean(ious_per_image, axis=0) 172 | assert len(iou) == len(thresholds) 173 | 174 | threshold_index = np.argmax(iou) 175 | iou_at_threshold = iou[threshold_index] 176 | threshold_value = thresholds[threshold_index] 177 | 178 | runner.loader_metrics[self.prefix + "/" + "threshold"] = float(threshold_value) 179 | runner.loader_metrics[self.prefix] = float(iou_at_threshold) 180 | 181 | if get_rank() in {-1, 0}: 182 | logger = get_tensorboard_logger(runner) 183 | logger.add_histogram(self.prefix, iou, 
global_step=runner.epoch) 184 | -------------------------------------------------------------------------------- /inria/models/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Dict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from . import fpn, unet, deeplab, hrnet, hg, can, efficient_unet, u2net 7 | 8 | __all__ = ["get_model", "model_from_checkpoint"] 9 | 10 | 11 | def get_model(model_name: str, pretrained=True, **kwargs) -> nn.Module: 12 | from catalyst.dl import registry 13 | 14 | model_fn = registry.MODEL.get(model_name) 15 | return model_fn(pretrained=pretrained, **kwargs) 16 | 17 | 18 | def model_from_checkpoint(checkpoint_name: str, strict=True, **kwargs) -> Tuple[nn.Module, Dict]: 19 | checkpoint = torch.load(checkpoint_name, map_location="cpu") 20 | model_state_dict = checkpoint["model_state_dict"] 21 | model_name = checkpoint["checkpoint_data"]["cmd_args"]["model"] 22 | 23 | model = get_model(model_name, pretrained=False, **kwargs) 24 | model.load_state_dict(model_state_dict, strict=strict) 25 | 26 | return model.eval(), checkpoint 27 | -------------------------------------------------------------------------------- /inria/models/can.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Union, List, Dict 3 | 4 | from pytorch_toolbelt.modules import encoders as E 5 | from pytorch_toolbelt.modules import decoders as D, conv1x1 6 | from pytorch_toolbelt.modules.decoders.can import CANDecoder 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | 10 | from inria.dataset import OUTPUT_MASK_KEY 11 | 12 | __all__ = ["CANSegmentationModel", "seresnext50_can"] 13 | 14 | 15 | class CANSegmentationModel(nn.Module): 16 | def __init__( 17 | self, encoder: E.EncoderModule, features=256, num_classes: int = 1, dropout=0.25, full_size_mask=True 18 | ): 19 | super().__init__() 20 | self.encoder = encoder 21 | 22 | self.decoder = CANDecoder(encoder.channels, out_channels=features) 23 | 24 | self.mask = nn.Sequential( 25 | OrderedDict([("drop", nn.Dropout2d(dropout)), ("conv", conv1x1(features, num_classes))]) 26 | ) 27 | 28 | self.full_size_mask = full_size_mask 29 | 30 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 31 | x_size = x.size() 32 | x = self.encoder(x) 33 | x = self.decoder(x) 34 | 35 | # Decode mask 36 | mask = self.mask(x[0]) 37 | 38 | if self.full_size_mask: 39 | mask = F.interpolate(mask, size=x_size[2:], mode="bilinear", align_corners=False) 40 | 41 | output = {OUTPUT_MASK_KEY: mask} 42 | return output 43 | 44 | 45 | def seresnext50_can(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 46 | encoder = E.Resnet50Encoder(pretrained=pretrained, layers=[1, 2, 3, 4]) 47 | if input_channels != 3: 48 | encoder.change_input_channels(input_channels) 49 | 50 | return CANSegmentationModel(encoder, num_classes=num_classes, features=256, dropout=dropout) 51 | -------------------------------------------------------------------------------- /inria/models/deeplab.py: -------------------------------------------------------------------------------- 1 | from pytorch_toolbelt.modules import ABN 2 | from pytorch_toolbelt.modules import decoders as D 3 | from pytorch_toolbelt.modules import encoders as E 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from ..dataset import OUTPUT_MASK_32_KEY, OUTPUT_MASK_KEY 8 | from catalyst.registry import Model 9 | 10 | 
__all__ = ["DeeplabV3SegmentationModel", "resnet34_deeplab128", "seresnext101_deeplab256"] 11 | 12 | 13 | class DeeplabV3SegmentationModel(nn.Module): 14 | def __init__( 15 | self, 16 | encoder: E.EncoderModule, 17 | num_classes: int, 18 | dropout=0.25, 19 | abn_block=ABN, 20 | high_level_bottleneck=256, 21 | low_level_bottleneck=32, 22 | full_size_mask=True, 23 | ): 24 | super().__init__() 25 | self.encoder = encoder 26 | 27 | self.decoder = D.DeeplabV3Decoder( 28 | feature_maps=encoder.output_filters, 29 | output_stride=encoder.output_strides[-1], 30 | num_classes=num_classes, 31 | high_level_bottleneck=high_level_bottleneck, 32 | low_level_bottleneck=low_level_bottleneck, 33 | abn_block=abn_block, 34 | dropout=dropout, 35 | ) 36 | 37 | self.full_size_mask = full_size_mask 38 | 39 | def forward(self, x): 40 | enc_features = self.encoder(x) 41 | 42 | # Decode mask 43 | mask, dsv = self.decoder(enc_features) 44 | 45 | if self.full_size_mask: 46 | mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False) 47 | 48 | output = {OUTPUT_MASK_KEY: mask, OUTPUT_MASK_32_KEY: dsv} 49 | 50 | return output 51 | 52 | 53 | @Model 54 | def resnet34_deeplab128(num_classes=1, dropout=0.0, pretrained=True): 55 | encoder = E.Resnet34Encoder(pretrained=pretrained) 56 | return DeeplabV3SegmentationModel(encoder, num_classes=num_classes, high_level_bottleneck=128, dropout=dropout) 57 | 58 | 59 | @Model 60 | def seresnext101_deeplab256(num_classes=1, dropout=0.0, pretrained=True): 61 | encoder = E.SEResNeXt101Encoder(pretrained=pretrained) 62 | return DeeplabV3SegmentationModel(encoder, num_classes=num_classes, high_level_bottleneck=256, dropout=dropout) 63 | -------------------------------------------------------------------------------- /inria/models/efficient_unet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import Union, List, Dict 4 | 5 | from pytorch_toolbelt.modules import conv1x1, UnetBlock, ACT_RELU, ABN, ACT_SWISH, Swish, DropBlock2D 6 | from pytorch_toolbelt.modules import encoders as E 7 | from pytorch_toolbelt.modules.decoders import UNetDecoder 8 | from pytorch_toolbelt.modules.encoders import EncoderModule 9 | from .timm_encoders import B4Encoder, B0Encoder, B6Encoder 10 | from torch import nn, Tensor 11 | from torch.nn import functional as F 12 | from timm.models.efficientnet_blocks import InvertedResidual 13 | from ..dataset import OUTPUT_MASK_KEY 14 | from catalyst.registry import Model 15 | 16 | 17 | __all__ = [ 18 | "EfficientUnetBlock", 19 | "EfficientUnetSegmentationModel", 20 | "b4_effunet32_s2", 21 | ] 22 | 23 | 24 | class EfficientUnetBlock(nn.Module): 25 | def __init__(self, in_channels: int, out_channels: int, activation=Swish, drop_path_rate=0.0): 26 | super().__init__() 27 | self.ir = InvertedResidual(in_channels, out_channels, act_layer=activation, se_ratio=0.25, exp_ratio=4) 28 | self.drop = DropBlock2D(drop_path_rate, 2) 29 | self.conv1 = nn.Sequential( 30 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 31 | nn.BatchNorm2d(out_channels), 32 | activation(inplace=True), 33 | ) 34 | self.conv2 = nn.Sequential( 35 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 36 | nn.BatchNorm2d(out_channels), 37 | activation(inplace=True), 38 | ) 39 | 40 | def forward(self, x): 41 | x = self.ir(x) 42 | x = self.drop(x) 43 | x = self.conv1(x) 44 | x = self.conv2(x) 45 | return x 46 | 47 
| 48 | class EfficientUNetDecoder(UNetDecoder): 49 | def __init__( 50 | self, 51 | feature_maps: List[int], 52 | decoder_features: List[int], 53 | upsample_block=nn.UpsamplingNearest2d, 54 | activation=Swish, 55 | ): 56 | super().__init__( 57 | feature_maps, 58 | unet_block=partial(EfficientUnetBlock, activation=activation, drop_path_rate=0.2), 59 | decoder_features=decoder_features, 60 | upsample_block=upsample_block, 61 | ) 62 | 63 | 64 | class EfficientUnetSegmentationModel(nn.Module): 65 | def __init__( 66 | self, 67 | encoder: EncoderModule, 68 | unet_channels: Union[int, List[int]], 69 | num_classes: int = 1, 70 | dropout=0.25, 71 | full_size_mask=True, 72 | activation=Swish, 73 | ): 74 | super().__init__() 75 | self.encoder = encoder 76 | 77 | self.decoder = EfficientUNetDecoder( 78 | feature_maps=encoder.channels, decoder_features=unet_channels, activation=activation 79 | ) 80 | 81 | self.mask = nn.Sequential( 82 | OrderedDict([("drop", nn.Dropout2d(dropout)), ("conv", conv1x1(self.decoder.channels[0], num_classes))]) 83 | ) 84 | 85 | self.full_size_mask = full_size_mask 86 | 87 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 88 | x_size = x.size() 89 | enc = self.encoder(x) 90 | dec = self.decoder(enc) 91 | 92 | # Decode mask 93 | mask = self.mask(dec[0]) 94 | 95 | if self.full_size_mask: 96 | mask = F.interpolate(mask, size=x_size[2:], mode="bilinear", align_corners=False) 97 | 98 | output = {OUTPUT_MASK_KEY: mask} 99 | return output 100 | 101 | 102 | @Model 103 | def b4_effunet32_s2(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 104 | encoder = B4Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 105 | if input_channels != 3: 106 | encoder.change_input_channels(input_channels) 107 | 108 | return EfficientUnetSegmentationModel( 109 | encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], activation=Swish, dropout=dropout 110 | ) 111 | -------------------------------------------------------------------------------- /inria/models/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | 4 | from pytorch_toolbelt.modules import ABN, conv1x1, ACT_RELU, FPNContextBlock, FPNBottleneckBlock, ACT_SWISH, FPNFuse 5 | from pytorch_toolbelt.modules import encoders as E 6 | from pytorch_toolbelt.modules.decoders import FPNSumDecoder, FPNCatDecoder 7 | from pytorch_toolbelt.modules.encoders import EncoderModule 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from .timm_encoders import B4Encoder 12 | from ..dataset import OUTPUT_MASK_KEY 13 | from catalyst.registry import Model 14 | 15 | __all__ = [ 16 | "FPNSumSegmentationModel", 17 | "FPNCatSegmentationModel", 18 | "resnet34_fpncat128", 19 | "resnet152_fpncat256", 20 | "seresnext50_fpncat128", 21 | "seresnext101_fpncat256", 22 | "seresnext101_fpnsum256", 23 | "effnetB4_fpncat128", 24 | ] 25 | 26 | 27 | class FPNSumSegmentationModel(nn.Module): 28 | def __init__( 29 | self, encoder: EncoderModule, num_classes: int, dropout=0.25, full_size_mask=True, fpn_channels=256, 30 | ): 31 | super().__init__() 32 | self.encoder = encoder 33 | 34 | self.decoder = FPNSumDecoder(feature_maps=encoder.output_filters, fpn_channels=fpn_channels,) 35 | self.mask = nn.Sequential( 36 | OrderedDict([("drop", nn.Dropout2d(dropout)), ("conv", conv1x1(fpn_channels, num_classes))]) 37 | ) 38 | 39 | self.full_size_mask = full_size_mask 40 | 41 | def forward(self, x): 42 | x_size = x.size() 43 | x = 
self.encoder(x) 44 | x = self.decoder(x) 45 | 46 | # Decode mask 47 | mask = self.mask(x[0]) 48 | 49 | if self.full_size_mask: 50 | mask = F.interpolate(mask, size=x_size[2:], mode="bilinear", align_corners=False) 51 | 52 | output = { 53 | OUTPUT_MASK_KEY: mask, 54 | } 55 | 56 | return output 57 | 58 | 59 | class FPNCatSegmentationModel(nn.Module): 60 | def __init__( 61 | self, 62 | encoder: EncoderModule, 63 | num_classes: int, 64 | dropout=0.25, 65 | fpn_channels=256, 66 | abn_block=ABN, 67 | activation=ACT_RELU, 68 | full_size_mask=True, 69 | ): 70 | super().__init__() 71 | self.encoder = encoder 72 | 73 | abn_block = partial(abn_block, activation=activation) 74 | 75 | self.decoder = FPNCatDecoder( 76 | encoder.channels, 77 | context_block=partial(FPNContextBlock, abn_block=abn_block), 78 | bottleneck_block=partial(FPNBottleneckBlock, abn_block=abn_block), 79 | fpn_channels=fpn_channels, 80 | ) 81 | 82 | self.fuse = FPNFuse() 83 | self.mask = nn.Sequential( 84 | OrderedDict([("drop", nn.Dropout2d(dropout)), ("conv", conv1x1(sum(self.decoder.channels), num_classes))]) 85 | ) 86 | self.full_size_mask = full_size_mask 87 | 88 | def forward(self, x): 89 | x_size = x.size() 90 | x = self.encoder(x) 91 | x = self.decoder(x) 92 | x = self.fuse(x) 93 | # Decode mask 94 | mask = self.mask(x) 95 | 96 | if self.full_size_mask: 97 | mask = F.interpolate(mask, size=x_size[2:], mode="bilinear", align_corners=False) 98 | 99 | output = { 100 | OUTPUT_MASK_KEY: mask, 101 | } 102 | 103 | return output 104 | 105 | 106 | @Model 107 | def resnet34_fpncat128(num_classes=5, dropout=0.0, pretrained=True): 108 | encoder = E.Resnet34Encoder(pretrained=pretrained) 109 | return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) 110 | 111 | 112 | @Model 113 | def seresnext50_fpncat128(num_classes=5, dropout=0.0, pretrained=True): 114 | encoder = E.SEResNeXt50Encoder(pretrained=pretrained) 115 | return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) 116 | 117 | 118 | @Model 119 | def seresnext101_fpncat256(num_classes=5, dropout=0.0, pretrained=True): 120 | encoder = E.SEResNeXt101Encoder(pretrained=pretrained) 121 | return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=256, dropout=dropout) 122 | 123 | 124 | @Model 125 | def seresnext101_fpnsum256(num_classes=5, dropout=0.0, pretrained=True): 126 | encoder = E.SEResNeXt101Encoder(pretrained=pretrained) 127 | return FPNSumSegmentationModel(encoder, num_classes=num_classes, fpn_channels=256, dropout=dropout) 128 | 129 | 130 | @Model 131 | def resnet152_fpncat256(num_classes=5, dropout=0.0, pretrained=True): 132 | encoder = E.Resnet152Encoder(pretrained=pretrained) 133 | return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=256, dropout=dropout) 134 | 135 | 136 | @Model 137 | def effnetB4_fpncat128(num_classes=5, dropout=0.0, pretrained=True): 138 | encoder = E.EfficientNetB4Encoder(abn_params={"activation": "swish"}, pretrained=pretrained) 139 | return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) 140 | 141 | 142 | @Model 143 | def b4_fpn_cat(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 144 | encoder = B4Encoder(pretrained=pretrained, layers=[1, 2, 3, 4]) 145 | if input_channels != 3: 146 | encoder.change_input_channels(input_channels) 147 | 148 | return FPNCatSegmentationModel( 149 | encoder, num_classes=num_classes, fpn_channels=64, activation=ACT_SWISH, 
dropout=dropout 150 | ) 151 | -------------------------------------------------------------------------------- /inria/models/hg.py: -------------------------------------------------------------------------------- 1 | from pytorch_toolbelt.modules.encoders import EncoderModule, StackedHGEncoder 2 | from pytorch_toolbelt.modules.encoders.hourglass import StackedSupervisedHGEncoder 3 | from torch import nn 4 | from torch.nn import PixelShuffle 5 | 6 | from inria.dataset import OUTPUT_MASK_KEY, OUTPUT_MASK_4_KEY 7 | import torch.nn.functional as F 8 | from catalyst.registry import Model 9 | 10 | 11 | class HGSegmentationDecoderNaked(nn.Module): 12 | def __init__(self, input_channels: int, stride: int, mask_channels: int): 13 | super().__init__() 14 | 15 | self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=1) 16 | self.bn1 = nn.BatchNorm2d(input_channels) 17 | 18 | self.conv2 = nn.Conv2d(input_channels, mask_channels, kernel_size=1) 19 | 20 | def forward(self, x): 21 | x = self.conv1(x) 22 | x = self.bn1(x) 23 | x = F.relu(x, inplace=True) 24 | x = self.conv2(x) 25 | return x 26 | 27 | 28 | class HGSegmentationDecoder(nn.Module): 29 | def __init__(self, input_channels: int, stride: int, mask_channels: int): 30 | super().__init__() 31 | self.expand = nn.Conv2d(input_channels, input_channels, kernel_size=1) 32 | self.up = PixelShuffle(upscale_factor=stride) 33 | 34 | mid_channels = input_channels // (2 ** stride) 35 | self.conv1 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False) 36 | self.bn1 = nn.BatchNorm2d(mid_channels) 37 | self.act1 = nn.ReLU(True) 38 | 39 | self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False) 40 | self.bn2 = nn.BatchNorm2d(mid_channels) 41 | self.act2 = nn.ReLU(True) 42 | 43 | self.final = nn.Conv2d(mid_channels, mask_channels, kernel_size=1) 44 | 45 | def forward(self, x): 46 | x = self.expand(x) 47 | x = self.up(x) 48 | 49 | x = self.act1(self.bn1(self.conv1(x))) 50 | x = self.act2(self.bn2(self.conv2(x))) 51 | 52 | x = self.final(x) 53 | return x 54 | 55 | 56 | class HGSegmentationModel(nn.Module): 57 | def __init__(self, encoder: EncoderModule, num_classes: int, full_size_mask=True): 58 | super().__init__() 59 | self.encoder = encoder 60 | self.decoder = HGSegmentationDecoder(encoder.output_filters[-1], encoder.output_strides[-1], num_classes) 61 | self.full_size_mask = full_size_mask 62 | 63 | def forward(self, x): 64 | features = self.encoder(x) 65 | 66 | # Decode mask 67 | mask = self.decoder(features[-1]) 68 | 69 | if self.full_size_mask: 70 | mask = F.interpolate(mask, x.size()[2:], mode="bilinear", align_corners=False) 71 | 72 | output = {OUTPUT_MASK_KEY: mask} 73 | return output 74 | 75 | 76 | class SupervisedHGSegmentationModel(nn.Module): 77 | def __init__(self, encoder: EncoderModule, num_classes: int, full_size_mask=True): 78 | super().__init__() 79 | self.encoder = encoder 80 | self.decoder = HGSegmentationDecoder(encoder.output_filters[-1], encoder.output_strides[-1], num_classes) 81 | self.full_size_mask = full_size_mask 82 | 83 | def forward(self, x): 84 | features, supervision = self.encoder(x) 85 | 86 | # Decode mask 87 | mask = self.decoder(features[-1]) 88 | 89 | if self.full_size_mask: 90 | mask = F.interpolate(mask, x.size()[2:], mode="bilinear", align_corners=False) 91 | 92 | output = {OUTPUT_MASK_KEY: mask} 93 | for i, sup in enumerate(supervision): 94 | output[OUTPUT_MASK_4_KEY + "_after_hg_" + str(i)] = sup 95 | 96 | return output 97 | 98 | 99 | @Model 100 | def 
hg4(num_classes=1, dropout=0, pretrained=False): 101 | encoder = StackedHGEncoder(stack_level=4) 102 | return HGSegmentationModel(encoder, num_classes=num_classes) 103 | 104 | 105 | @Model 106 | def shg4(num_classes=1, dropout=0, pretrained=False): 107 | encoder = StackedSupervisedHGEncoder(input_channels=3, stack_level=4, supervision_channels=num_classes) 108 | return SupervisedHGSegmentationModel(encoder, num_classes=num_classes) 109 | 110 | 111 | @Model 112 | def hg8(num_classes=1, dropout=0, pretrained=False): 113 | encoder = StackedHGEncoder(stack_level=8) 114 | return HGSegmentationModel(encoder, num_classes=num_classes) 115 | 116 | 117 | @Model 118 | def shg8(num_classes=1, dropout=0, pretrained=False): 119 | encoder = StackedSupervisedHGEncoder(input_channels=3, stack_level=8, supervision_channels=num_classes) 120 | return SupervisedHGSegmentationModel(encoder, num_classes=num_classes) 121 | -------------------------------------------------------------------------------- /inria/models/hrnet.py: -------------------------------------------------------------------------------- 1 | from pytorch_toolbelt.modules import encoders as E 2 | from pytorch_toolbelt.modules.decoders import HRNetSegmentationDecoder 3 | from pytorch_toolbelt.modules.encoders import EncoderModule 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from ..dataset import OUTPUT_MASK_KEY 8 | from catalyst.registry import Model 9 | 10 | __all__ = ["HRNetSegmentationModel", "hrnet18", "hrnet34", "hrnet48"] 11 | 12 | 13 | class HRNetSegmentationModel(nn.Module): 14 | def __init__(self, encoder: EncoderModule, num_classes: int, dropout=0.0, full_size_mask=True): 15 | super().__init__() 16 | self.encoder = encoder 17 | 18 | self.decoder = HRNetSegmentationDecoder( 19 | feature_maps=encoder.output_filters, output_channels=num_classes, dropout=dropout 20 | ) 21 | 22 | self.full_size_mask = full_size_mask 23 | 24 | def forward(self, x): 25 | enc_features = self.encoder(x) 26 | 27 | # Decode mask 28 | mask = self.decoder(enc_features) 29 | 30 | if self.full_size_mask: 31 | mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False) 32 | 33 | output = {OUTPUT_MASK_KEY: mask} 34 | 35 | return output 36 | 37 | 38 | @Model 39 | def hrnet18(num_classes=1, dropout=0.0, pretrained=True): 40 | encoder = E.HRNetV2Encoder18(pretrained=pretrained) 41 | return HRNetSegmentationModel(encoder, num_classes=num_classes, dropout=dropout) 42 | 43 | 44 | @Model 45 | def hrnet34(num_classes=1, dropout=0.0, pretrained=True): 46 | encoder = E.HRNetV2Encoder34(pretrained=pretrained) 47 | return HRNetSegmentationModel(encoder, num_classes=num_classes, dropout=dropout) 48 | 49 | 50 | @Model 51 | def hrnet48(num_classes=1, dropout=0.0, pretrained=True): 52 | encoder = E.HRNetV2Encoder48(pretrained=pretrained) 53 | return HRNetSegmentationModel(encoder, num_classes=num_classes, dropout=dropout) 54 | -------------------------------------------------------------------------------- /inria/models/u2net.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from catalyst.registry import Model 7 | 8 | from inria.dataset import ( 9 | OUTPUT_MASK_KEY, 10 | OUTPUT_DSV_MASK_1_KEY, 11 | OUTPUT_DSV_MASK_2_KEY, 12 | OUTPUT_DSV_MASK_3_KEY, 13 | OUTPUT_DSV_MASK_4_KEY, 14 | OUTPUT_DSV_MASK_5_KEY, 15 | OUTPUT_DSV_MASK_6_KEY, 16 | ) 17 | 18 | 19 | class REBNCONV(nn.Module): 20 | def 
__init__(self, in_ch=3, out_ch=3, dirate=1): 21 | super(REBNCONV, self).__init__() 22 | 23 | self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate) 24 | self.bn_s1 = nn.BatchNorm2d(out_ch) 25 | self.relu_s1 = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | 29 | hx = x 30 | xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) 31 | 32 | return xout 33 | 34 | 35 | ## upsample tensor 'src' to have the same spatial size with tensor 'tar' 36 | def _upsample_like(src, tar): 37 | 38 | src = torch.nn.functional.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False) 39 | 40 | return src 41 | 42 | 43 | ### RSU-7 ### 44 | class RSU7(nn.Module): # UNet07DRES(nn.Module): 45 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 46 | super(RSU7, self).__init__() 47 | 48 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 49 | 50 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 51 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 52 | 53 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 54 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 55 | 56 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 57 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 58 | 59 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 60 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 61 | 62 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 63 | self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 64 | 65 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) 66 | 67 | self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) 68 | 69 | self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 70 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 71 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 72 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 73 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 74 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 75 | 76 | def forward(self, x): 77 | 78 | hx = x 79 | hxin = self.rebnconvin(hx) 80 | 81 | hx1 = self.rebnconv1(hxin) 82 | hx = self.pool1(hx1) 83 | 84 | hx2 = self.rebnconv2(hx) 85 | hx = self.pool2(hx2) 86 | 87 | hx3 = self.rebnconv3(hx) 88 | hx = self.pool3(hx3) 89 | 90 | hx4 = self.rebnconv4(hx) 91 | hx = self.pool4(hx4) 92 | 93 | hx5 = self.rebnconv5(hx) 94 | hx = self.pool5(hx5) 95 | 96 | hx6 = self.rebnconv6(hx) 97 | 98 | hx7 = self.rebnconv7(hx6) 99 | 100 | hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) 101 | hx6dup = _upsample_like(hx6d, hx5) 102 | 103 | hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) 104 | hx5dup = _upsample_like(hx5d, hx4) 105 | 106 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 107 | hx4dup = _upsample_like(hx4d, hx3) 108 | 109 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 110 | hx3dup = _upsample_like(hx3d, hx2) 111 | 112 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 113 | hx2dup = _upsample_like(hx2d, hx1) 114 | 115 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 116 | 117 | return hx1d + hxin 118 | 119 | 120 | ### RSU-6 ### 121 | class RSU6(nn.Module): # UNet06DRES(nn.Module): 122 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 123 | super(RSU6, self).__init__() 124 | 125 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 126 | 127 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 128 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 129 | 130 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 131 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 132 
| 133 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 134 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 135 | 136 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 137 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 138 | 139 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) 140 | 141 | self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) 142 | 143 | self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 144 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 145 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 146 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 147 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 148 | 149 | def forward(self, x): 150 | 151 | hx = x 152 | 153 | hxin = self.rebnconvin(hx) 154 | 155 | hx1 = self.rebnconv1(hxin) 156 | hx = self.pool1(hx1) 157 | 158 | hx2 = self.rebnconv2(hx) 159 | hx = self.pool2(hx2) 160 | 161 | hx3 = self.rebnconv3(hx) 162 | hx = self.pool3(hx3) 163 | 164 | hx4 = self.rebnconv4(hx) 165 | hx = self.pool4(hx4) 166 | 167 | hx5 = self.rebnconv5(hx) 168 | 169 | hx6 = self.rebnconv6(hx5) 170 | 171 | hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) 172 | hx5dup = _upsample_like(hx5d, hx4) 173 | 174 | hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) 175 | hx4dup = _upsample_like(hx4d, hx3) 176 | 177 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 178 | hx3dup = _upsample_like(hx3d, hx2) 179 | 180 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 181 | hx2dup = _upsample_like(hx2d, hx1) 182 | 183 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 184 | 185 | return hx1d + hxin 186 | 187 | 188 | ### RSU-5 ### 189 | class RSU5(nn.Module): # UNet05DRES(nn.Module): 190 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 191 | super(RSU5, self).__init__() 192 | 193 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 194 | 195 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 196 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 197 | 198 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 199 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 200 | 201 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 202 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 203 | 204 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) 205 | 206 | self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) 207 | 208 | self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 209 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 210 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 211 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 212 | 213 | def forward(self, x): 214 | 215 | hx = x 216 | 217 | hxin = self.rebnconvin(hx) 218 | 219 | hx1 = self.rebnconv1(hxin) 220 | hx = self.pool1(hx1) 221 | 222 | hx2 = self.rebnconv2(hx) 223 | hx = self.pool2(hx2) 224 | 225 | hx3 = self.rebnconv3(hx) 226 | hx = self.pool3(hx3) 227 | 228 | hx4 = self.rebnconv4(hx) 229 | 230 | hx5 = self.rebnconv5(hx4) 231 | 232 | hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) 233 | hx4dup = _upsample_like(hx4d, hx3) 234 | 235 | hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) 236 | hx3dup = _upsample_like(hx3d, hx2) 237 | 238 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 239 | hx2dup = _upsample_like(hx2d, hx1) 240 | 241 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 242 | 243 | return hx1d + hxin 244 | 245 | 246 | ### RSU-4 ### 247 | class RSU4(nn.Module): # UNet04DRES(nn.Module): 248 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 249 | 
super(RSU4, self).__init__() 250 | 251 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 252 | 253 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 254 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 255 | 256 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) 257 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 258 | 259 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) 260 | 261 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) 262 | 263 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 264 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) 265 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 266 | 267 | def forward(self, x): 268 | 269 | hx = x 270 | 271 | hxin = self.rebnconvin(hx) 272 | 273 | hx1 = self.rebnconv1(hxin) 274 | hx = self.pool1(hx1) 275 | 276 | hx2 = self.rebnconv2(hx) 277 | hx = self.pool2(hx2) 278 | 279 | hx3 = self.rebnconv3(hx) 280 | 281 | hx4 = self.rebnconv4(hx3) 282 | 283 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 284 | hx3dup = _upsample_like(hx3d, hx2) 285 | 286 | hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) 287 | hx2dup = _upsample_like(hx2d, hx1) 288 | 289 | hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) 290 | 291 | return hx1d + hxin 292 | 293 | 294 | ### RSU-4F ### 295 | class RSU4F(nn.Module): # UNet04FRES(nn.Module): 296 | def __init__(self, in_ch=3, mid_ch=12, out_ch=3): 297 | super(RSU4F, self).__init__() 298 | 299 | self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) 300 | 301 | self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) 302 | self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) 303 | self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) 304 | 305 | self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) 306 | 307 | self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) 308 | self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) 309 | self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) 310 | 311 | def forward(self, x): 312 | 313 | hx = x 314 | 315 | hxin = self.rebnconvin(hx) 316 | 317 | hx1 = self.rebnconv1(hxin) 318 | hx2 = self.rebnconv2(hx1) 319 | hx3 = self.rebnconv3(hx2) 320 | 321 | hx4 = self.rebnconv4(hx3) 322 | 323 | hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) 324 | hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) 325 | hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) 326 | 327 | return hx1d + hxin 328 | 329 | 330 | ##### U^2-Net #### 331 | @Model 332 | class U2NET(nn.Module): 333 | def __init__(self, input_channels=3, num_classes=1, pretrained=False): 334 | super(U2NET, self).__init__() 335 | 336 | self.stage1 = RSU7(input_channels, 32, 64) 337 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 338 | 339 | self.stage2 = RSU6(64, 32, 128) 340 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 341 | 342 | self.stage3 = RSU5(128, 64, 256) 343 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 344 | 345 | self.stage4 = RSU4(256, 128, 512) 346 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 347 | 348 | self.stage5 = RSU4F(512, 256, 512) 349 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 350 | 351 | self.stage6 = RSU4F(512, 256, 512) 352 | 353 | # decoder 354 | self.stage5d = RSU4F(1024, 256, 512) 355 | self.stage4d = RSU4(1024, 128, 256) 356 | self.stage3d = RSU5(512, 64, 128) 357 | self.stage2d = RSU6(256, 32, 64) 358 | self.stage1d = RSU7(128, 16, 64) 359 | 360 | self.side1 = nn.Conv2d(64, num_classes, 3, padding=1) 361 | self.side2 = nn.Conv2d(64, num_classes, 3, padding=1) 362 | self.side3 = 
nn.Conv2d(128, num_classes, 3, padding=1) 363 | self.side4 = nn.Conv2d(256, num_classes, 3, padding=1) 364 | self.side5 = nn.Conv2d(512, num_classes, 3, padding=1) 365 | self.side6 = nn.Conv2d(512, num_classes, 3, padding=1) 366 | 367 | self.outconv = nn.Conv2d(6, num_classes, 1) 368 | # if pretrained: 369 | # state_dict = torch.load(os.path.join(os.path.dirname(__file__), "u2net.pth")) 370 | # self.load_state_dict(state_dict) 371 | 372 | def forward(self, x): 373 | 374 | hx = x 375 | 376 | # stage 1 377 | hx1 = self.stage1(hx) 378 | hx = self.pool12(hx1) 379 | 380 | # stage 2 381 | hx2 = self.stage2(hx) 382 | hx = self.pool23(hx2) 383 | 384 | # stage 3 385 | hx3 = self.stage3(hx) 386 | hx = self.pool34(hx3) 387 | 388 | # stage 4 389 | hx4 = self.stage4(hx) 390 | hx = self.pool45(hx4) 391 | 392 | # stage 5 393 | hx5 = self.stage5(hx) 394 | hx = self.pool56(hx5) 395 | 396 | # stage 6 397 | hx6 = self.stage6(hx) 398 | hx6up = _upsample_like(hx6, hx5) 399 | 400 | # -------------------- decoder -------------------- 401 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 402 | hx5dup = _upsample_like(hx5d, hx4) 403 | 404 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 405 | hx4dup = _upsample_like(hx4d, hx3) 406 | 407 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 408 | hx3dup = _upsample_like(hx3d, hx2) 409 | 410 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 411 | hx2dup = _upsample_like(hx2d, hx1) 412 | 413 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 414 | 415 | # side output 416 | d1 = self.side1(hx1d) 417 | 418 | d2 = self.side2(hx2d) 419 | d2 = _upsample_like(d2, d1) 420 | 421 | d3 = self.side3(hx3d) 422 | d3 = _upsample_like(d3, d1) 423 | 424 | d4 = self.side4(hx4d) 425 | d4 = _upsample_like(d4, d1) 426 | 427 | d5 = self.side5(hx5d) 428 | d5 = _upsample_like(d5, d1) 429 | 430 | d6 = self.side6(hx6) 431 | d6 = _upsample_like(d6, d1) 432 | 433 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 434 | return { 435 | OUTPUT_MASK_KEY: d0, 436 | OUTPUT_DSV_MASK_1_KEY: d1, 437 | OUTPUT_DSV_MASK_2_KEY: d2, 438 | OUTPUT_DSV_MASK_3_KEY: d3, 439 | OUTPUT_DSV_MASK_4_KEY: d4, 440 | OUTPUT_DSV_MASK_5_KEY: d5, 441 | OUTPUT_DSV_MASK_6_KEY: d6, 442 | } 443 | 444 | 445 | ### U^2-Net small ### 446 | @Model 447 | class U2NETP(nn.Module): 448 | def __init__(self, input_channels=3, num_classes=1, pretrained=False): 449 | super(U2NETP, self).__init__() 450 | self.stage1 = RSU7(input_channels, 16, 64) 451 | self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 452 | 453 | self.stage2 = RSU6(64, 16, 64) 454 | self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 455 | 456 | self.stage3 = RSU5(64, 16, 64) 457 | self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 458 | 459 | self.stage4 = RSU4(64, 16, 64) 460 | self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 461 | 462 | self.stage5 = RSU4F(64, 16, 64) 463 | self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) 464 | 465 | self.stage6 = RSU4F(64, 16, 64) 466 | 467 | # decoder 468 | self.stage5d = RSU4F(128, 16, 64) 469 | self.stage4d = RSU4(128, 16, 64) 470 | self.stage3d = RSU5(128, 16, 64) 471 | self.stage2d = RSU6(128, 16, 64) 472 | self.stage1d = RSU7(128, 16, 64) 473 | 474 | self.side1 = nn.Conv2d(64, num_classes, 3, padding=1) 475 | self.side2 = nn.Conv2d(64, num_classes, 3, padding=1) 476 | self.side3 = nn.Conv2d(64, num_classes, 3, padding=1) 477 | self.side4 = nn.Conv2d(64, num_classes, 3, padding=1) 478 | self.side5 = nn.Conv2d(64, num_classes, 3, padding=1) 479 | self.side6 = nn.Conv2d(64, num_classes, 3, 
padding=1) 480 | 481 | self.outconv = nn.Conv2d(6, num_classes, 1) 482 | 483 | # if pretrained: 484 | # state_dict = torch.load(os.path.join(os.path.dirname(__file__), "u2netp.pth")) 485 | # self.load_state_dict(state_dict) 486 | 487 | def forward(self, x): 488 | 489 | hx = x 490 | 491 | # stage 1 492 | hx1 = self.stage1(hx) 493 | hx = self.pool12(hx1) 494 | 495 | # stage 2 496 | hx2 = self.stage2(hx) 497 | hx = self.pool23(hx2) 498 | 499 | # stage 3 500 | hx3 = self.stage3(hx) 501 | hx = self.pool34(hx3) 502 | 503 | # stage 4 504 | hx4 = self.stage4(hx) 505 | hx = self.pool45(hx4) 506 | 507 | # stage 5 508 | hx5 = self.stage5(hx) 509 | hx = self.pool56(hx5) 510 | 511 | # stage 6 512 | hx6 = self.stage6(hx) 513 | hx6up = _upsample_like(hx6, hx5) 514 | 515 | # decoder 516 | hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) 517 | hx5dup = _upsample_like(hx5d, hx4) 518 | 519 | hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) 520 | hx4dup = _upsample_like(hx4d, hx3) 521 | 522 | hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) 523 | hx3dup = _upsample_like(hx3d, hx2) 524 | 525 | hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) 526 | hx2dup = _upsample_like(hx2d, hx1) 527 | 528 | hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) 529 | 530 | # side output 531 | d1 = self.side1(hx1d) 532 | 533 | d2 = self.side2(hx2d) 534 | d2 = _upsample_like(d2, d1) 535 | 536 | d3 = self.side3(hx3d) 537 | d3 = _upsample_like(d3, d1) 538 | 539 | d4 = self.side4(hx4d) 540 | d4 = _upsample_like(d4, d1) 541 | 542 | d5 = self.side5(hx5d) 543 | d5 = _upsample_like(d5, d1) 544 | 545 | d6 = self.side6(hx6) 546 | d6 = _upsample_like(d6, d1) 547 | 548 | d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1)) 549 | 550 | return { 551 | OUTPUT_MASK_KEY: d0, 552 | OUTPUT_DSV_MASK_1_KEY: d1, 553 | OUTPUT_DSV_MASK_2_KEY: d2, 554 | OUTPUT_DSV_MASK_3_KEY: d3, 555 | OUTPUT_DSV_MASK_4_KEY: d4, 556 | OUTPUT_DSV_MASK_5_KEY: d5, 557 | OUTPUT_DSV_MASK_6_KEY: d6, 558 | } 559 | 560 | 561 | __all__ = ["U2NETP", "U2NET"] 562 | -------------------------------------------------------------------------------- /inria/models/unet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from functools import partial 3 | from typing import Union, List, Dict, Type 4 | 5 | from pytorch_toolbelt.modules import conv1x1, UnetBlock, ACT_RELU, ABN, ACT_SWISH, ResidualDeconvolutionUpsample2d 6 | from pytorch_toolbelt.modules import encoders as E 7 | from pytorch_toolbelt.modules.decoders import UNetDecoder 8 | from pytorch_toolbelt.modules.encoders import EncoderModule 9 | from torch import nn, Tensor 10 | from torch.nn import functional as F 11 | 12 | from ..dataset import OUTPUT_MASK_KEY, output_mask_name_for_stride 13 | from catalyst.registry import Model 14 | 15 | __all__ = [ 16 | "UnetSegmentationModel", 17 | "resnet18_unet32", 18 | "resnet34_unet32", 19 | "resnet50_unet32", 20 | "resnet101_unet64", 21 | "resnet152_unet32", 22 | "densenet121_unet32", 23 | "densenet161_unet32", 24 | "densenet169_unet32", 25 | "densenet201_unet32", 26 | "b0_unet32_s2", 27 | "b4_unet32", 28 | "b6_unet32_s2", 29 | "b6_unet32_s2_bi", 30 | "b6_unet32_s2_tc", 31 | "b6_unet32_s2_rdtc", 32 | ] 33 | 34 | 35 | class UnetSegmentationModel(nn.Module): 36 | def __init__( 37 | self, 38 | encoder: EncoderModule, 39 | unet_channels: Union[int, List[int]], 40 | num_classes: int = 1, 41 | dropout=0.25, 42 | full_size_mask=True, 43 | activation=ACT_RELU, 44 | upsample_block: Union[Type[nn.Upsample], 
Type[ResidualDeconvolutionUpsample2d]] = nn.UpsamplingNearest2d, 45 | need_supervision_masks=False, 46 | last_upsample_block=None, 47 | ): 48 | super().__init__() 49 | self.encoder = encoder 50 | 51 | abn_block = partial(ABN, activation=activation) 52 | self.decoder = UNetDecoder( 53 | feature_maps=encoder.channels, 54 | decoder_features=unet_channels, 55 | unet_block=partial(UnetBlock, abn_block=abn_block), 56 | upsample_block=upsample_block, 57 | ) 58 | 59 | if last_upsample_block is not None: 60 | self.last_upsample_block = last_upsample_block(unet_channels[0]) 61 | self.mask = nn.Sequential( 62 | OrderedDict( 63 | [ 64 | ("drop", nn.Dropout2d(dropout)), 65 | ("conv", conv1x1(self.last_upsample_block.out_channels, num_classes)), 66 | ] 67 | ) 68 | ) 69 | else: 70 | self.last_upsample_block = None 71 | 72 | self.mask = nn.Sequential( 73 | OrderedDict([("drop", nn.Dropout2d(dropout)), ("conv", conv1x1(unet_channels[0], num_classes))]) 74 | ) 75 | 76 | if need_supervision_masks: 77 | self.supervision = nn.ModuleList([conv1x1(channels, num_classes) for channels in self.decoder.channels]) 78 | self.supervision_names = [output_mask_name_for_stride(stride) for stride in self.encoder.strides] 79 | else: 80 | self.supervision = None 81 | self.supervision_names = None 82 | 83 | self.full_size_mask = full_size_mask 84 | 85 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 86 | x_size = x.size() 87 | x = self.encoder(x) 88 | x = self.decoder(x) 89 | 90 | # Decode mask 91 | if self.last_upsample_block is not None: 92 | mask = self.mask(self.last_upsample_block(x[0])) 93 | else: 94 | mask = self.mask(x[0]) 95 | if self.full_size_mask: 96 | mask = F.interpolate(mask, size=x_size[2:], mode="bilinear", align_corners=False) 97 | 98 | output = {OUTPUT_MASK_KEY: mask} 99 | 100 | if self.supervision is not None: 101 | for feature_map, supervision, name in zip(x, self.supervision, self.supervision_names): 102 | output[name] = supervision(feature_map) 103 | 104 | return output 105 | 106 | 107 | @Model 108 | def resnet18_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 109 | encoder = E.Resnet18Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 110 | if input_channels != 3: 111 | encoder.change_input_channels(input_channels) 112 | 113 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 114 | 115 | 116 | @Model 117 | def resnet34_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 118 | encoder = E.Resnet34Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 119 | if input_channels != 3: 120 | encoder.change_input_channels(input_channels) 121 | 122 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 123 | 124 | 125 | @Model 126 | def resnet50_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 127 | encoder = E.Resnet50Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 128 | if input_channels != 3: 129 | encoder.change_input_channels(input_channels) 130 | 131 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 132 | 133 | 134 | @Model 135 | def resnet101_unet64(input_channels=3, num_classes=1, dropout=0.5, pretrained=True): 136 | encoder = E.Resnet101Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 137 | if input_channels != 3: 138 | encoder.change_input_channels(input_channels) 139 | 140 | return UnetSegmentationModel(encoder, 
num_classes=num_classes, unet_channels=[64, 128, 256, 512], dropout=dropout) 141 | 142 | 143 | @Model 144 | def resnet152_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 145 | encoder = E.Resnet152Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 146 | if input_channels != 3: 147 | encoder.change_input_channels(input_channels) 148 | 149 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 150 | 151 | 152 | # Densenets 153 | 154 | 155 | @Model 156 | def densenet121_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 157 | encoder = E.DenseNet121Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 158 | if input_channels != 3: 159 | encoder.change_input_channels(input_channels) 160 | 161 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 162 | 163 | 164 | @Model 165 | def densenet161_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 166 | encoder = E.DenseNet161Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 167 | if input_channels != 3: 168 | encoder.change_input_channels(input_channels) 169 | 170 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 171 | 172 | 173 | @Model 174 | def densenet169_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 175 | encoder = E.DenseNet169Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 176 | if input_channels != 3: 177 | encoder.change_input_channels(input_channels) 178 | 179 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 180 | 181 | 182 | @Model 183 | def densenet201_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 184 | encoder = E.DenseNet201Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 185 | if input_channels != 3: 186 | encoder.change_input_channels(input_channels) 187 | 188 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 189 | 190 | 191 | # HRNet 192 | 193 | 194 | @Model 195 | def hrnet18_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 196 | encoder = E.HRNetV2Encoder18(pretrained=pretrained) 197 | if input_channels != 3: 198 | encoder.change_input_channels(input_channels) 199 | 200 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 201 | 202 | 203 | @Model 204 | def hrnet34_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 205 | encoder = E.HRNetV2Encoder34(pretrained=pretrained) 206 | if input_channels != 3: 207 | encoder.change_input_channels(input_channels) 208 | 209 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 210 | 211 | 212 | @Model 213 | def hrnet48_unet32(input_channels=3, num_classes=1, dropout=0.0, pretrained=True): 214 | encoder = E.HRNetV2Encoder48(pretrained=pretrained) 215 | if input_channels != 3: 216 | encoder.change_input_channels(input_channels) 217 | 218 | return UnetSegmentationModel(encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], dropout=dropout) 219 | 220 | 221 | # B0-Unet 222 | @Model 223 | def b0_unet32_s2(input_channels=3, num_classes=1, dropout=0.1, pretrained=True): 224 | encoder = E.B0Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 225 | if input_channels != 3: 
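# Descriptive note (editor comment): change_input_channels swaps the encoder stem so an ImageNet-pretrained encoder can accept a non-RGB number of input channels.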
226 | encoder.change_input_channels(input_channels) 227 | 228 | return UnetSegmentationModel( 229 | encoder, num_classes=num_classes, unet_channels=[16, 32, 64, 128], activation=ACT_SWISH, dropout=dropout 230 | ) 231 | 232 | 233 | @Model 234 | def b4_unet32(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 235 | encoder = E.B4Encoder(pretrained=pretrained) 236 | if input_channels != 3: 237 | encoder.change_input_channels(input_channels) 238 | 239 | return UnetSegmentationModel( 240 | encoder, num_classes=num_classes, unet_channels=[32, 64, 128], activation=ACT_SWISH, dropout=dropout 241 | ) 242 | 243 | 244 | @Model 245 | def b4_unet32_s2(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 246 | encoder = E.B4Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 247 | if input_channels != 3: 248 | encoder.change_input_channels(input_channels) 249 | 250 | return UnetSegmentationModel( 251 | encoder, num_classes=num_classes, unet_channels=[32, 64, 128, 256], activation=ACT_SWISH, dropout=dropout 252 | ) 253 | 254 | 255 | @Model 256 | def b6_unet32_s2(input_channels=3, num_classes=1, dropout=0.2, full_size_mask=True, pretrained=True): 257 | encoder = E.B6Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 258 | if input_channels != 3: 259 | encoder.change_input_channels(input_channels) 260 | 261 | return UnetSegmentationModel( 262 | encoder, 263 | num_classes=num_classes, 264 | unet_channels=[32, 64, 128, 256], 265 | activation=ACT_SWISH, 266 | dropout=dropout, 267 | full_size_mask=full_size_mask, 268 | ) 269 | 270 | 271 | @Model 272 | def b6_unet32_s2_bi(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 273 | encoder = E.B6Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 274 | if input_channels != 3: 275 | encoder.change_input_channels(input_channels) 276 | 277 | return UnetSegmentationModel( 278 | encoder, 279 | num_classes=num_classes, 280 | unet_channels=[32, 64, 128, 256], 281 | activation=ACT_SWISH, 282 | dropout=dropout, 283 | upsample_block=nn.UpsamplingBilinear2d, 284 | ) 285 | 286 | 287 | @Model 288 | def b6_unet32_s2_tc(input_channels=3, num_classes=1, dropout=0.2, pretrained=True): 289 | encoder = E.B6Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 290 | if input_channels != 3: 291 | encoder.change_input_channels(input_channels) 292 | 293 | from pytorch_toolbelt.modules.upsample import DeconvolutionUpsample2d 294 | 295 | return UnetSegmentationModel( 296 | encoder, 297 | num_classes=num_classes, 298 | unet_channels=[32, 64, 128, 256], 299 | activation=ACT_SWISH, 300 | dropout=dropout, 301 | upsample_block=DeconvolutionUpsample2d, 302 | ) 303 | 304 | 305 | @Model 306 | def b6_unet32_s2_rdtc(input_channels=3, num_classes=1, dropout=0.2, need_supervision_masks=False, pretrained=True): 307 | encoder = E.B6Encoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 308 | if input_channels != 3: 309 | encoder.change_input_channels(input_channels) 310 | 311 | from pytorch_toolbelt.modules.upsample import ResidualDeconvolutionUpsample2d 312 | 313 | return UnetSegmentationModel( 314 | encoder, 315 | num_classes=num_classes, 316 | unet_channels=[32, 64, 128, 256], 317 | activation=ACT_SWISH, 318 | dropout=dropout, 319 | need_supervision_masks=need_supervision_masks, 320 | upsample_block=ResidualDeconvolutionUpsample2d, 321 | last_upsample_block=ResidualDeconvolutionUpsample2d, 322 | ) 323 | 324 | 325 | @Model 326 | def mxxl_unet32_s1(input_channels=3, num_classes=1, dropout=0.5, pretrained=True, need_supervision_masks=False): 327 | encoder 
= E.MixNetXLEncoder(pretrained=pretrained, layers=[0, 1, 2, 3, 4]) 328 | if input_channels != 3: 329 | encoder.change_input_channels(input_channels) 330 | 331 | return UnetSegmentationModel( 332 | encoder, 333 | num_classes=num_classes, 334 | unet_channels=[32, 64, 128, 256], 335 | activation=ACT_SWISH, 336 | dropout=dropout, 337 | need_supervision_masks=need_supervision_masks, 338 | ) 339 | -------------------------------------------------------------------------------- /inria/optim.py: -------------------------------------------------------------------------------- 1 | from torch.optim.optimizer import Optimizer 2 | 3 | __all__ = ["get_optimizer"] 4 | 5 | 6 | def get_optimizer( 7 | optimizer_name: str, parameters, learning_rate: float, weight_decay=1e-5, eps=1e-5, **kwargs 8 | ) -> Optimizer: 9 | from torch.optim import SGD, Adam, RMSprop, AdamW 10 | from torch_optimizer import RAdam, Lamb, DiffGrad, NovoGrad, Ranger 11 | 12 | if optimizer_name.lower() == "sgd": 13 | return SGD(parameters, learning_rate, momentum=0.9, nesterov=True, weight_decay=weight_decay, **kwargs) 14 | 15 | if optimizer_name.lower() == "adam": 16 | return Adam(parameters, learning_rate, weight_decay=weight_decay, eps=eps, **kwargs) # As Jeremy suggests 17 | 18 | if optimizer_name.lower() == "rms": 19 | return RMSprop(parameters, learning_rate, weight_decay=weight_decay, **kwargs) 20 | 21 | if optimizer_name.lower() == "adamw": 22 | return AdamW(parameters, learning_rate, weight_decay=weight_decay, eps=eps, **kwargs) 23 | 24 | if optimizer_name.lower() == "radam": 25 | return RAdam(parameters, learning_rate, weight_decay=weight_decay, eps=eps, **kwargs) # As Jeremy suggests 26 | 27 | # Optimizers from torch-optimizer 28 | if optimizer_name.lower() == "ranger": 29 | return Ranger(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs) 30 | 31 | if optimizer_name.lower() == "lamb": 32 | return Lamb(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs) 33 | 34 | if optimizer_name.lower() == "diffgrad": 35 | return DiffGrad(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs) 36 | 37 | if optimizer_name.lower() == "novograd": 38 | return NovoGrad(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs) 39 | 40 | # Optimizers from Apex (Fused version is faster on GPU with tensor cores) 41 | if optimizer_name.lower() == "fused_lamb": 42 | from apex.optimizers import FusedLAMB 43 | 44 | return FusedLAMB(parameters, learning_rate, eps=eps, weight_decay=weight_decay, **kwargs) 45 | 46 | if optimizer_name.lower() == "fused_sgd": 47 | from apex.optimizers import FusedSGD 48 | 49 | return FusedSGD(parameters, learning_rate, momentum=0.9, nesterov=True, weight_decay=weight_decay, **kwargs) 50 | 51 | if optimizer_name.lower() == "fused_adam": 52 | from apex.optimizers import FusedAdam 53 | 54 | return FusedAdam(parameters, learning_rate, eps=eps, weight_decay=weight_decay, adam_w_mode=True, **kwargs) 55 | 56 | raise ValueError("Unsupported optimizer name " + optimizer_name) 57 | -------------------------------------------------------------------------------- /inria/pseudo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from catalyst.dl import Callback, CallbackOrder, IRunner 3 | from pytorch_toolbelt.utils.catalyst import PseudolabelDatasetMixin 4 | from pytorch_toolbelt.utils.torch_utils import to_numpy 5 | 6 | __all__ = ["BCEOnlinePseudolabelingCallback2d"] 7 | 8 | 9 | class 
BCEOnlinePseudolabelingCallback2d(Callback): 10 | """ 11 | Online pseudo-labeling callback for multi-class problem. 12 | 13 | >>> unlabeled_train = get_test_dataset( 14 | >>> data_dir, image_size=image_size, augmentation=augmentations 15 | >>> ) 16 | >>> unlabeled_eval = get_test_dataset( 17 | >>> data_dir, image_size=image_size 18 | >>> ) 19 | >>> 20 | >>> callbacks += [ 21 | >>> MulticlassOnlinePseudolabelingCallback( 22 | >>> unlabeled_train.targets, 23 | >>> pseudolabel_loader="label", 24 | >>> prob_threshold=0.9) 25 | >>> ] 26 | >>> train_ds = train_ds + unlabeled_train 27 | >>> 28 | >>> loaders = collections.OrderedDict() 29 | >>> loaders["train"] = DataLoader(train_ds) 30 | >>> loaders["valid"] = DataLoader(valid_ds) 31 | >>> loaders["label"] = DataLoader(unlabeled_eval, shuffle=False) # ! shuffle=False is important ! 32 | """ 33 | 34 | def __init__( 35 | self, 36 | unlabeled_ds: PseudolabelDatasetMixin, 37 | pseudolabel_loader="infer", 38 | prob_threshold=0.9, 39 | sample_index_key="index", 40 | output_key="logits", 41 | unlabeled_class=-100, 42 | label_smoothing=0.0, 43 | label_frequency=1, 44 | ): 45 | assert 1.0 > prob_threshold > 0.5 46 | 47 | super().__init__(CallbackOrder.External) 48 | self.unlabeled_ds = unlabeled_ds 49 | self.pseudolabel_loader = pseudolabel_loader 50 | self.prob_threshold = prob_threshold 51 | self.sample_index_key = sample_index_key 52 | self.output_key = output_key 53 | self.unlabeled_class = unlabeled_class 54 | self.label_smoothing = label_smoothing 55 | self.last_labeled_epoch = None 56 | self.label_frequency = label_frequency 57 | 58 | # def on_epoch_start(self, state: RunnerState): 59 | # pass 60 | 61 | # def on_loader_start(self, state: RunnerState): 62 | # if state.loader_name == self.pseudolabel_loader: 63 | # self.predictions = [] 64 | 65 | def on_stage_start(self, runner: IRunner): 66 | self.last_labeled_epoch = None 67 | 68 | def on_loader_start(self, runner: IRunner): 69 | if runner.loader_name == self.pseudolabel_loader: 70 | self.should_relabel = self.last_labeled_epoch is None or ( 71 | runner.epoch == self.last_labeled_epoch + self.label_frequency 72 | ) 73 | print("Should relabel", self.should_relabel, runner.epoch) 74 | 75 | def on_loader_end(self, runner: "IRunner"): 76 | if runner.loader_name == self.pseudolabel_loader and self.should_relabel: 77 | self.last_labeled_epoch = runner.epoch 78 | print("Set last_labeled_epoch", runner.epoch) 79 | 80 | def get_probabilities(self, state: IRunner): 81 | probs = state.output[self.output_key].detach().sigmoid() 82 | indexes = state.input[self.sample_index_key] 83 | 84 | return to_numpy(probs), to_numpy(indexes) 85 | 86 | def on_batch_end(self, runner: IRunner): 87 | if runner.loader_name != self.pseudolabel_loader: 88 | return 89 | 90 | if not self.should_relabel: 91 | return 92 | 93 | # Get predictions for batch 94 | probs, indexes = self.get_probabilities(runner) 95 | 96 | for p, sample_index in zip(probs, indexes): 97 | # confident_negatives = p < (1.0 - self.prob_threshold) 98 | # confident_positives = p > self.prob_threshold 99 | # rest = ~confident_negatives & ~confident_positives 100 | # 101 | # p = p.copy() 102 | # p[confident_negatives] = 0 + self.label_smoothing 103 | # p[confident_positives] = 1 - self.label_smoothing 104 | # p[rest] = self.unlabeled_class 105 | p = np.moveaxis(p, 0, -1) 106 | 107 | self.unlabeled_ds.set_target(sample_index, p) 108 | -------------------------------------------------------------------------------- /inria/scheduler.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | from catalyst.contrib.nn import OneCycleLRWithWarmup 5 | from torch.optim.lr_scheduler import ( 6 | ExponentialLR, 7 | CyclicLR, 8 | MultiStepLR, 9 | CosineAnnealingLR, 10 | CosineAnnealingWarmRestarts, 11 | ) 12 | 13 | 14 | class CosineAnnealingWarmRestartsWithDecay(CosineAnnealingWarmRestarts): 15 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, gamma=0.9): 16 | super().__init__(optimizer, T_0, T_mult, eta_min, last_epoch) 17 | self.gamma = gamma 18 | 19 | def get_lr(self): 20 | if not self._get_lr_called_within_step: 21 | warnings.warn( 22 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 23 | DeprecationWarning, 24 | ) 25 | 26 | return [ 27 | self.eta_min 28 | + (base_lr * self.gamma ** self.last_epoch - self.eta_min) 29 | * (1 + math.cos(math.pi * self.T_cur / self.T_i)) 30 | / 2 31 | for base_lr in self.base_lrs 32 | ] 33 | 34 | 35 | def get_scheduler(scheduler_name: str, optimizer, lr, num_epochs, batches_in_epoch=None): 36 | if scheduler_name is None or scheduler_name.lower() == "none": 37 | return None 38 | 39 | if scheduler_name.lower() == "cos": 40 | return CosineAnnealingLR(optimizer, num_epochs, eta_min=1e-6) 41 | 42 | if scheduler_name.lower() == "cos2": 43 | return CosineAnnealingLR(optimizer, num_epochs, eta_min=float(lr * 0.5)) 44 | 45 | if scheduler_name.lower() == "cosr": 46 | return CosineAnnealingWarmRestarts(optimizer, T_0=max(2, num_epochs // 4), eta_min=1e-6) 47 | 48 | if scheduler_name.lower() == "cosrd": 49 | return CosineAnnealingWarmRestartsWithDecay(optimizer, T_0=max(2, num_epochs // 6), gamma=0.96, eta_min=1e-6) 50 | 51 | if scheduler_name.lower() in {"1cycle", "one_cycle"}: 52 | return OneCycleLRWithWarmup( 53 | optimizer, 54 | lr_range=(lr, 1e-6), 55 | num_steps=batches_in_epoch * num_epochs, 56 | warmup_fraction=0.05, 57 | decay_fraction=0.1, 58 | ) 59 | 60 | if scheduler_name.lower() == "exp": 61 | return ExponentialLR(optimizer, gamma=0.95) 62 | 63 | if scheduler_name.lower() == "clr": 64 | return CyclicLR( 65 | optimizer, 66 | base_lr=1e-6, 67 | max_lr=lr, 68 | step_size_up=batches_in_epoch // 4, 69 | # mode='exp_range', 70 | gamma=0.99, 71 | ) 72 | 73 | if scheduler_name.lower() == "multistep": 74 | return MultiStepLR( 75 | optimizer, milestones=[int(num_epochs * 0.5), int(num_epochs * 0.7), int(num_epochs * 0.9)], gamma=0.3 76 | ) 77 | 78 | if scheduler_name.lower() == "simple": 79 | return MultiStepLR(optimizer, milestones=[int(num_epochs * 0.4), int(num_epochs * 0.7)], gamma=0.1) 80 | 81 | raise KeyError(scheduler_name) 82 | -------------------------------------------------------------------------------- /inria/visualization.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List, Union 2 | 3 | import cv2 4 | import numpy as np 5 | from pytorch_toolbelt.utils.torch_utils import rgb_image_from_tensor, to_numpy 6 | 7 | from inria.dataset import ( 8 | OUTPUT_OFFSET_KEY, 9 | OUTPUT_MASK_4_KEY, 10 | OUTPUT_MASK_32_KEY, 11 | OUTPUT_MASK_16_KEY, 12 | OUTPUT_MASK_8_KEY, 13 | OUTPUT_MASK_2_KEY, 14 | ) 15 | 16 | 17 | def draw_inria_predictions( 18 | input: dict, 19 | output: dict, 20 | inputs_to_labels:Callable, 21 | outputs_to_labels: Callable, 22 | image_key="features", 23 | image_id_key: Optional[str] = "image_id", 24 | targets_key="targets", 25 | outputs_key="logits", 26 | mean=(0.485, 
0.456, 0.406), 27 | std=(0.229, 0.224, 0.225), 28 | max_images=None, 29 | image_format: Union[str, Callable] = "bgr", 30 | ) -> List[np.ndarray]: 31 | """ 32 | Render visualization of the model's predictions for a binary segmentation problem. 33 | This function draws a color-coded overlay on top of the image, with color codes meaning: 34 | - green: True positives 35 | - red: False-negatives 36 | - yellow: False-positives 37 | 38 | :param input: Input batch (model's input batch) 39 | :param output: Output batch (model predictions) 40 | :param image_key: Key for getting image 41 | :param image_id_key: Key for getting image id/fname 42 | :param targets_key: Key for getting ground-truth mask 43 | :param outputs_key: Key for getting model logits for predicted mask 44 | :param mean: Mean vector used during normalization 45 | :param std: Std vector used during normalization 46 | :param max_images: Maximum number of images to visualize from batch 47 | (If you have huge batch, saving hundreds of images may make TensorBoard slow) 48 | :param inputs_to_labels: Function used to convert ground-truth target values to binary masks 49 | (Thresholding at 0.5 is safe for both smoothed and hard labels) 50 | :param outputs_to_labels: Function used to convert model predictions (raw logits) to binary masks 51 | (Thresholding logits at 0.0 is equivalent to 0.5 after applying sigmoid activation) 52 | :param image_format: Source format of the image tensor to convert to RGB representation. 53 | Can be string ("gray", "rgb", "bgr") or function `convert(np.ndarray)->np.ndarray`. 54 | :return: List of images 55 | """ 56 | images = [] 57 | num_samples = len(input[image_key]) 58 | if max_images is not None: 59 | num_samples = min(num_samples, max_images) 60 | 61 | true_masks = to_numpy(inputs_to_labels(input[targets_key])).astype(bool) 62 | pred_masks = to_numpy(outputs_to_labels(output[outputs_key])).astype(bool) 63 | 64 | for i in range(num_samples): 65 | image = rgb_image_from_tensor(input[image_key][i], mean, std) 66 | 67 | if image_format == "bgr": 68 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 69 | elif image_format == "gray": 70 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 71 | elif hasattr(image_format, "__call__"): 72 | image = image_format(image) 73 | 74 | overlay = image.copy() 75 | true_mask = true_masks[i] 76 | pred_mask = pred_masks[i] 77 | 78 | overlay[true_mask & pred_mask] = np.array( 79 | [0, 250, 0], dtype=overlay.dtype 80 | ) # Correct predictions (Hits) painted with green 81 | overlay[true_mask & ~pred_mask] = np.array([250, 0, 0], dtype=overlay.dtype) # Misses painted with red 82 | overlay[~true_mask & pred_mask] = np.array( 83 | [250, 250, 0], dtype=overlay.dtype 84 | ) # False alarm painted with yellow 85 | overlay = cv2.addWeighted(image, 0.5, overlay, 0.5, 0, dtype=cv2.CV_8U) 86 | 87 | if OUTPUT_OFFSET_KEY in output: 88 | offset = to_numpy(output[OUTPUT_OFFSET_KEY][i]) * 32 89 | offset = np.expand_dims(offset, -1) 90 | 91 | x = offset[0, ...].clip(min=0, max=1) * np.array([255, 0, 0]) + (-offset[0, ...]).clip( 92 | min=0, max=1 93 | ) * np.array([0, 0, 255]) 94 | y = offset[1, ...].clip(min=0, max=1) * np.array([255, 0, 255]) + (-offset[1, ...]).clip( 95 | min=0, max=1 96 | ) * np.array([0, 255, 0]) 97 | 98 | offset = (x + y).clip(0, 255).astype(np.uint8) 99 | offset = cv2.resize(offset, (image.shape[1], image.shape[0])) 100 | overlay = np.row_stack([overlay, offset]) 101 | 102 | dsv_inputs = [OUTPUT_MASK_2_KEY, OUTPUT_MASK_4_KEY, OUTPUT_MASK_8_KEY, OUTPUT_MASK_16_KEY, OUTPUT_MASK_32_KEY] 103 | for dsv_input_key in dsv_inputs:
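# Descriptive note (editor comment): each deep-supervision output that is present (strides 2..32) is rendered as a grayscale probability map and stacked below the main overlay, so all heads can be inspected in a single TensorBoard image.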
104 | if dsv_input_key in output: 105 | dsv_p = to_numpy(output[dsv_input_key][i].detach().float().sigmoid().squeeze(0)) 106 | dsv_p = cv2.resize((dsv_p * 255).astype(np.uint8), (image.shape[1], image.shape[0])) 107 | dsv_p = cv2.cvtColor(dsv_p, cv2.COLOR_GRAY2RGB) 108 | overlay = np.row_stack([overlay, dsv_p]) 109 | 110 | if image_id_key is not None and image_id_key in input: 111 | image_id = input[image_id_key][i] 112 | cv2.putText(overlay, str(image_id), (10, 15), cv2.FONT_HERSHEY_PLAIN, 1, (250, 250, 250)) 113 | 114 | images.append(overlay) 115 | return images 116 | -------------------------------------------------------------------------------- /make_train_tiles.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | import cv2, os 5 | from pytorch_toolbelt.inference.tiles import ImageSlicer 6 | from pytorch_toolbelt.utils import fs 7 | from pytorch_toolbelt.utils.fs import id_from_fname, read_image_as_is 8 | import pandas as pd 9 | from tqdm import tqdm 10 | 11 | from inria.dataset import TRAIN_LOCATIONS, read_inria_mask 12 | 13 | 14 | def split_image(image_fname, output_dir, tile_size, tile_step, image_margin): 15 | os.makedirs(output_dir, exist_ok=True) 16 | image = read_image_as_is(image_fname) 17 | image_id = id_from_fname(image_fname) 18 | 19 | slicer = ImageSlicer(image.shape, tile_size, tile_step, image_margin) 20 | tiles = slicer.split(image) 21 | 22 | fnames = [] 23 | for i, tile in enumerate(tiles): 24 | output_fname = os.path.join(output_dir, f"{image_id}_tile_{i}.png") 25 | cv2.imwrite(output_fname, tile) 26 | fnames.append(output_fname) 27 | 28 | return fnames 29 | 30 | 31 | def cut_train_dataset_in_patches(data_dir, tile_size, tile_step, image_margin): 32 | 33 | train_data = [] 34 | valid_data = [] 35 | 36 | # For validation, we remove the first five images of every location (e.g., austin{1-5}.tif, chicago{1-5}.tif) from the training set. 
37 | # That is suggested validation strategy by competition host 38 | for loc in TRAIN_LOCATIONS: 39 | for i in range(1, 6): 40 | valid_data.append(f"{loc}{i}") 41 | for i in range(6, 37): 42 | train_data.append(f"{loc}{i}") 43 | 44 | train_imgs = [os.path.join(data_dir, "train", "images", f"{fname}.tif") for fname in train_data] 45 | valid_imgs = [os.path.join(data_dir, "train", "images", f"{fname}.tif") for fname in valid_data] 46 | 47 | train_masks = [os.path.join(data_dir, "train", "gt", f"{fname}.tif") for fname in train_data] 48 | valid_masks = [os.path.join(data_dir, "train", "gt", f"{fname}.tif") for fname in valid_data] 49 | 50 | images_dir = os.path.join(data_dir, "train_tiles", "images") 51 | masks_dir = os.path.join(data_dir, "train_tiles", "gt") 52 | 53 | df = defaultdict(list) 54 | 55 | for train_img in tqdm(train_imgs, total=len(train_imgs), desc="train_imgs"): 56 | img_tiles = split_image(train_img, images_dir, tile_size, tile_step, image_margin) 57 | df["image"].extend(img_tiles) 58 | df["train"].extend([1] * len(img_tiles)) 59 | df["image_id"].extend([fs.id_from_fname(train_img)] * len(img_tiles)) 60 | 61 | for train_msk in tqdm(train_masks, total=len(train_masks), desc="train_masks"): 62 | msk_tiles = split_image(train_msk, masks_dir, tile_size, tile_step, image_margin) 63 | df["mask"].extend(msk_tiles) 64 | df["has_buildings"].extend([read_inria_mask(x).any() for x in msk_tiles]) 65 | 66 | for valid_img in tqdm(valid_imgs, total=len(valid_imgs), desc="valid_imgs"): 67 | img_tiles = split_image(valid_img, images_dir, tile_size, tile_size, image_margin) 68 | df["image"].extend(img_tiles) 69 | df["train"].extend([0] * len(img_tiles)) 70 | df["image_id"].extend([fs.id_from_fname(valid_img)] * len(img_tiles)) 71 | 72 | for valid_msk in tqdm(valid_masks, total=len(valid_masks), desc="valid_masks"): 73 | msk_tiles = split_image(valid_msk, masks_dir, tile_size, tile_size, image_margin) 74 | df["mask"].extend(msk_tiles) 75 | df["has_buildings"].extend([read_inria_mask(x).any() for x in msk_tiles]) 76 | 77 | return pd.DataFrame.from_dict(df) 78 | 79 | 80 | def cut_test_dataset_in_patches(data_dir, tile_size, tile_step, image_margin): 81 | train_imgs = fs.find_images_in_dir(os.path.join(data_dir, "test", "images")) 82 | 83 | images_dir = os.path.join(data_dir, "test_tiles", "images") 84 | 85 | df = defaultdict(list) 86 | 87 | for train_img in tqdm(train_imgs, total=len(train_imgs), desc="test_imgs"): 88 | img_tiles = split_image(train_img, images_dir, tile_size, tile_step, image_margin) 89 | df["image"].extend(img_tiles) 90 | df["image_id"].extend([fs.id_from_fname(train_img)] * len(img_tiles)) 91 | 92 | return pd.DataFrame.from_dict(df) 93 | 94 | 95 | def main(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | "-dd", "--data-dir", type=str, required=True, help="Data directory for INRIA sattelite dataset" 99 | ) 100 | args = parser.parse_args() 101 | 102 | # df = cut_train_dataset_in_patches(args.data_dir, tile_size=(768, 768), tile_step=(512, 512), image_margin=0) 103 | # df.to_csv(os.path.join(args.data_dir, "inria_tiles.csv"), index=False) 104 | 105 | df = cut_test_dataset_in_patches(args.data_dir, tile_size=(768, 768), tile_step=(512, 512), image_margin=0) 106 | df.to_csv(os.path.join(args.data_dir, "test_tiles.csv"), index=False) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | catalyst @ git+https://github.com/BloodAxe/catalyst 3 | albumentations>=0.5.0 4 | pytorch-toolbelt>=0.4.1 5 | torch-optimizer==0.1.0 -------------------------------------------------------------------------------- /run_tensorboard.cmd: -------------------------------------------------------------------------------- 1 | @call c:\Anaconda3\Scripts\activate.bat tb 2 | set CUDA_VISIBLE_DEVICES= 3 | tensorboard --logdir runs --host 0.0.0.0 --port 5555 --window_title "Inria" -------------------------------------------------------------------------------- /run_tensorboard.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES= 3 | tensorboard --logdir runs --host 0.0.0.0 --port 5555 -------------------------------------------------------------------------------- /sample_color.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BloodAxe/Catalyst-Inria-Segmentation-Example/79cc11716c7936cdbb98786e05a496f63141eff1/sample_color.jpg -------------------------------------------------------------------------------- /scripts/train.cmd: -------------------------------------------------------------------------------- 1 | @call c:\Anaconda3\Scripts\activate.bat pytorch14 2 | REM --fp16 -m b4_unet32 --train-mode tiles --show -b 10 -w 4 --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 100 --criterion bce 1 --criterion dice 0.1 -v -dd d:\datasets\AerialImageDataset 3 | 4 | python fit_predict.py --fp16 -m b4_unet32_s2 --train-mode tiles --show -b 8 -w 4 --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 100 --criterion bce 1 --criterion dice 0.1 -v -dd d:\datasets\AerialImageDataset 5 | -------------------------------------------------------------------------------- /scripts/train_b6.sh: -------------------------------------------------------------------------------- 1 | #export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 2 | #python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 3 | # -b 6 -m b6_unet32_s2\ 4 | # --train-mode tiles -b 8 --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 100\ 5 | # --criterion bce 1 --criterion dice 1 6 | 7 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 8 | python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 9 | -b 6 -m b6_unet32_s2_tc\ 10 | --train-mode tiles --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 50 --seed 555\ 11 | --criterion bce 1 --criterion dice 1 --transfer /home/bloodaxe/develop/Catalyst-Inria-Segmentation-Example/runs/200829_17_44_b6_unet32_s2_fp16_local_rank_0/main/checkpoints_optimized_jaccard/best.pth 12 | 13 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 14 | python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 15 | -b 6 -m b6_unet32_s2_rdtc\ 16 | --train-mode tiles --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 50 --seed 555\ 17 | --criterion bce 1 --criterion dice 1\ 18 | -l2 bce 0.75\ 19 | -l4 bce 0.5\ 20 | -l8 bce 0.25\ 21 | -l16 bce 0.125\ 22 | --transfer /home/bloodaxe/develop/Catalyst-Inria-Segmentation-Example/runs/200830_23_50_b6_unet32_s2_bi_fp16_local_rank_0/200830_23_50_b6_unet32_s2_bi_fp16_local_rank_0.pth -------------------------------------------------------------------------------- /scripts/train_hrnet18_4x1080Ti.sh:
-------------------------------------------------------------------------------- 1 | python fit_predict.py -m hrnet18 --fp16 -b 64 -w 16 -dd /home/bloodaxe/data/AerialImageDataset --size 512 -s 1cycle -o SGD -a light -lr 1e-3 -e 100 -d 0.1 --criterion bce 1 -v -------------------------------------------------------------------------------- /scripts/train_hrnet18_p3.8xlarge.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py -m hrnet18 -b 64 -w 24 -dd /home/ubuntu/data/inria/AerialImageDataset --size 512 --train-mode tiles -s cos -o RAdam -a hard -lr 1e-3 -e 200 -d 0.1 --criterion bce 1 -v -------------------------------------------------------------------------------- /scripts/train_hrnet34_p3.8xlarge.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py -m hrnet34 -b 48 -w 24 -dd /home/ubuntu/data/inria/AerialImageDataset --size 512 --train-mode tiles -s cos -o RAdam -a medium -lr 1e-3 -e 200 -d 0.1 --criterion bce 1 -v -------------------------------------------------------------------------------- /scripts/train_hrnet34_unet64_aug_hard.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py\ 2 | -dd "/home/bloodaxe/data/AerialImageDataset"\ 3 | -m hrnet34_unet64\ 4 | -a hard\ 5 | -b 48\ 6 | -o RAdam\ 7 | -w 24\ 8 | --fp16\ 9 | -e 100\ 10 | -s cos\ 11 | -lr 1e-3\ 12 | -wd 1e-6\ 13 | --show\ 14 | --seed 123\ 15 | -l bce 1\ 16 | -v 17 | 18 | -------------------------------------------------------------------------------- /scripts/train_hrnet34_unet64_aug_light.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py\ 2 | -dd "/home/bloodaxe/data/AerialImageDataset"\ 3 | -m hrnet34_unet64\ 4 | -a light\ 5 | -b 48\ 6 | -o RAdam\ 7 | -w 24\ 8 | --fp16\ 9 | -e 100\ 10 | -s cos\ 11 | -lr 1e-3\ 12 | -wd 1e-6\ 13 | --show\ 14 | --seed 123\ 15 | -l bce 1\ 16 | -v 17 | 18 | -------------------------------------------------------------------------------- /scripts/train_hrnet34_unet64_aug_medium.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py\ 2 | -dd "/home/bloodaxe/data/AerialImageDataset"\ 3 | -m hrnet34_unet64\ 4 | -a medium\ 5 | -b 48\ 6 | -o RAdam\ 7 | -w 24\ 8 | --fp16\ 9 | -e 100\ 10 | -s cos\ 11 | -lr 1e-3\ 12 | -wd 1e-6\ 13 | --show\ 14 | --seed 123\ 15 | -l bce 1\ 16 | -v 17 | 18 | -------------------------------------------------------------------------------- /scripts/train_hrnet34_unet64_aug_none.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py\ 2 | -dd "/home/bloodaxe/data/AerialImageDataset"\ 3 | -m hrnet34_unet64\ 4 | -a None\ 5 | -b 48\ 6 | -o RAdam\ 7 | -w 24\ 8 | --fp16\ 9 | -e 100\ 10 | -s cos\ 11 | -lr 1e-3\ 12 | -wd 1e-6\ 13 | --show\ 14 | --seed 123\ 15 | -l bce 1\ 16 | -v 17 | -------------------------------------------------------------------------------- /scripts/train_hrnet48_p3.8xlarge.sh: -------------------------------------------------------------------------------- 1 | python fit_predict.py -m hrnet48 -b 32 -w 24 -dd /home/ubuntu/data/inria/AerialImageDataset --size 512 --train-mode tiles -s cos -o RAdam -a hard -lr 1e-3 -e 200 -d 0.1 --criterion bce 1 -v -------------------------------------------------------------------------------- /scripts/train_mixnet.sh: 
-------------------------------------------------------------------------------- 1 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 2 | python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 3 | -b 10 -m mxxl_unet32_s1\ 4 | --train-mode tiles --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 100 --seed 555\ 5 | --criterion bce 1 --criterion dice 1 6 | -------------------------------------------------------------------------------- /scripts/train_resnets.sh: -------------------------------------------------------------------------------- 1 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 2 | 3 | python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py --fp16 -w 8\ 4 | -b 16 -m resnet101_unet64 --size 512\ 5 | -o RAdam -lr 3e-4 -wd 1e-5 -a hard\ 6 | -e 50 -s cos2\ 7 | --criterion bce 1\ 8 | --criterion dice 1\ 9 | --train-mode tiles --show --verbose 10 | -------------------------------------------------------------------------------- /scripts/train_seresnext50_can.sh: -------------------------------------------------------------------------------- 1 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 2 | #python fit_predict.py -b 40 -m seresnext50_can -s cos -lr 1e-4 -a hard -v -e 10 -e 50 -l bce 1 --train-mode tiles --show 3 | python fit_predict.py -b 40 -m seresnext50_can -s cosrd -d 0.25 -lr 4e-5 -a hard -v -e 100 -w 16 -e 50 -l bce 1 --train-mode tiles --show -c /home/bloodaxe/develop/Catalyst-Inria-Segmentation-Example/runs/May01_12_45_seresnext50_can/main/checkpoints/best.pth -------------------------------------------------------------------------------- /scripts/train_u2net.sh: -------------------------------------------------------------------------------- 1 | 2 | #export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 3 | #python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 4 | # -b 8 -m U2NET\ 5 | # --train-mode tiles --size 512 -s cos -o RAdam -a hard -lr 3e-4 -e 50 --seed 555\ 6 | # --criterion bce 1 --criterion dice 1 -c /home/bloodaxe/develop/Catalyst-Inria-Segmentation-Example/runs/201107_10_23_U2NET_fp16/main/checkpoints_optimized_jaccard/last.pth 7 | 8 | export INRIA_DATA_DIR="/home/bloodaxe/datasets/AerialImageDataset" 9 | python -m torch.distributed.launch --nproc_per_node=4 fit_predict.py -w 6 --fp16 -v\ 10 | -b 8 -m U2NET\ 11 | --train-mode tiles --size 512 -s cos2 -o RAdam -a hard -lr 5e-5 -e 50 --seed 666\ 12 | --criterion bce 1 --criterion dice 1 -c /home/bloodaxe/develop/Catalyst-Inria-Segmentation-Example/runs/201108_10_29_U2NET_fp16/main/checkpoints_optimized_jaccard/best.pth -------------------------------------------------------------------------------- /submit.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from pytorch_toolbelt.inference.tta import TTAWrapper, d4_image2mask, fliplr_image2mask, MultiscaleTTAWrapper 9 | from pytorch_toolbelt.utils.catalyst import report_checkpoint 10 | from torch import nn 11 | 12 | from tqdm import tqdm 13 | from pytorch_toolbelt.utils.fs import auto_file, find_in_dir 14 | 15 | from inria.dataset import read_inria_image, OUTPUT_MASK_KEY 16 | from inria.factory import predict, PickModelOutput 17 | from inria.models import model_from_checkpoint 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | 
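# Descriptive note (editor comment): inference CLI — checkpoint and data directory are required; batch size and the test-time augmentation mode (fliplr / d4 / multiscale variants) are optional.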
parser.add_argument("-m", "--model", type=str, default="unet", help="") 23 | parser.add_argument("-dd", "--data-dir", type=str, default=None, required=True, help="Data dir") 24 | parser.add_argument( 25 | "-c", 26 | "--checkpoint", 27 | type=str, 28 | default=None, 29 | required=True, 30 | help="Checkpoint filename to use as initial model weights", 31 | ) 32 | parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size for inference") 33 | parser.add_argument("-tta", "--tta", default=None, type=str, help="Type of TTA to use [fliplr, d4]") 34 | args = parser.parse_args() 35 | 36 | data_dir = args.data_dir 37 | checkpoint_file = auto_file(args.checkpoint) 38 | run_dir = os.path.dirname(checkpoint_file) 39 | out_dir = os.path.join(run_dir, "submit") 40 | os.makedirs(out_dir, exist_ok=True) 41 | 42 | model, checkpoint = model_from_checkpoint(checkpoint_file, strict=False) 43 | threshold = checkpoint["epoch_metrics"].get("valid_optimized_jaccard/threshold", 0.5) 44 | print(report_checkpoint(checkpoint)) 45 | print("Using threshold", threshold) 46 | 47 | model = nn.Sequential(PickModelOutput(model, OUTPUT_MASK_KEY), nn.Sigmoid()) 48 | 49 | if args.tta == "fliplr": 50 | model = TTAWrapper(model, fliplr_image2mask) 51 | elif args.tta == "d4": 52 | model = TTAWrapper(model, d4_image2mask) 53 | elif args.tta == "ms-d2": 54 | model = TTAWrapper(model, fliplr_image2mask) 55 | model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128]) 56 | elif args.tta == "ms-d4": 57 | model = TTAWrapper(model, d4_image2mask) 58 | model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128]) 59 | elif args.tta == "ms": 60 | model = MultiscaleTTAWrapper(model, size_offsets=[-128, -64, 64, 128]) 61 | else: 62 | pass 63 | 64 | model = model.cuda() 65 | if torch.cuda.device_count() > 1: 66 | model = nn.DataParallel(model) 67 | 68 | model = model.eval() 69 | 70 | # mask = predict(model, read_inria_image("sample_color.jpg"), image_size=(512, 512), batch_size=args.batch_size * torch.cuda.device_count()) 71 | # mask = ((mask > threshold) * 255).astype(np.uint8) 72 | # name = os.path.join(run_dir, "sample_color.jpg") 73 | # cv2.imwrite(name, mask) 74 | 75 | test_predictions_dir = os.path.join(out_dir, "test_predictions") 76 | test_predictions_dir_compressed = os.path.join(out_dir, "test_predictions_compressed") 77 | 78 | if args.tta is not None: 79 | test_predictions_dir += f"_{args.tta}" 80 | test_predictions_dir_compressed += f"_{args.tta}" 81 | 82 | os.makedirs(test_predictions_dir, exist_ok=True) 83 | os.makedirs(test_predictions_dir_compressed, exist_ok=True) 84 | 85 | test_images = find_in_dir(os.path.join(data_dir, "test", "images")) 86 | for fname in tqdm(test_images, total=len(test_images)): 87 | predicted_mask_fname = os.path.join(test_predictions_dir, os.path.basename(fname)) 88 | 89 | if not os.path.isfile(predicted_mask_fname): 90 | image = read_inria_image(fname) 91 | mask = predict(model, image, image_size=(512, 512), batch_size=args.batch_size * torch.cuda.device_count()) 92 | mask = ((mask > threshold) * 255).astype(np.uint8) 93 | cv2.imwrite(predicted_mask_fname, mask) 94 | 95 | name_compressed = os.path.join(test_predictions_dir_compressed, os.path.basename(fname)) 96 | command = ( 97 | "gdal_translate --config GDAL_PAM_ENABLED NO -co COMPRESS=CCITTFAX4 -co NBITS=1 " 98 | + predicted_mask_fname 99 | + " " 100 | + name_compressed 101 | ) 102 | subprocess.call(command, shell=True) 103 | 104 | 105 | if __name__ == "__main__": 106 | # Give no chance to randomness 107 | 
torch.manual_seed(0) 108 | np.random.seed(0) 109 | torch.backends.cudnn.deterministic = True 110 | torch.backends.cudnn.benchmark = False 111 | 112 | main() 113 | -------------------------------------------------------------------------------- /tests/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BloodAxe/Catalyst-Inria-Segmentation-Example/79cc11716c7936cdbb98786e05a496f63141eff1/tests/mask.png -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from inria.dataset import compute_weight_mask, read_inria_mask, read_xview_mask, mask2depth, depth2mask 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | 6 | def test_compute_weight_mask(): 7 | mask = read_xview_mask("mask.png") 8 | 9 | w = compute_weight_mask(mask, edge_weight=4) 10 | 11 | plt.figure(figsize=(10, 10)) 12 | plt.tight_layout() 13 | plt.imshow(w) 14 | plt.axis("off") 15 | plt.show() 16 | 17 | 18 | def test_mask2depth(): 19 | x = np.random.randint(0, 2, (512, 512), dtype=np.uint8) 20 | a = mask2depth(x) 21 | print(np.bincount(a.flatten(), minlength=16)) 22 | y = depth2mask(a) 23 | np.testing.assert_equal(x, y) 24 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters 4 | 5 | from inria.models import get_model 6 | from inria.models.efficient_unet import b4_effunet32_s2 7 | from inria.models.u2net import U2NETP, U2NET 8 | from inria.models.unet import b6_unet32_s2_rdtc, b6_unet32_s2_tc 9 | 10 | 11 | @torch.no_grad() 12 | def test_b4_effunet32_s2(): 13 | model = maybe_cuda(b4_effunet32_s2()) 14 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 15 | output = model(x) 16 | print(count_parameters(model)) 17 | for key, value in output.items(): 18 | print(key, value.size(), value.mean(), value.std()) 19 | 20 | 21 | @torch.no_grad() 22 | def test_b6_unet32_s2_tc(): 23 | model = b6_unet32_s2_tc() 24 | model = maybe_cuda(model.eval()) 25 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 26 | output = model(x) 27 | print(count_parameters(model)) 28 | for key, value in output.items(): 29 | print(key, value.size(), value.mean(), value.std()) 30 | 31 | 32 | @torch.no_grad() 33 | def test_b6_unet32_s2_rdtc(): 34 | model = b6_unet32_s2_rdtc(need_supervision_masks=True) 35 | model = maybe_cuda(model.eval()) 36 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 37 | output = model(x) 38 | print(count_parameters(model)) 39 | for key, value in output.items(): 40 | print(key, value.size(), value.mean(), value.std()) 41 | 42 | 43 | @torch.no_grad() 44 | def test_test_b6_unet32_s2_tc(): 45 | model = b6_unet32_s2_tc() 46 | model = maybe_cuda(model.eval()) 47 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 48 | output = model(x) 49 | print(count_parameters(model)) 50 | for key, value in output.items(): 51 | print(key, value.size(), value.mean(), value.std()) 52 | 53 | 54 | @torch.no_grad() 55 | def test_U2NETP(): 56 | model = U2NETP() 57 | model = maybe_cuda(model.eval()) 58 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 59 | output = model(x) 60 | print(count_parameters(model)) 61 | for key, value in output.items(): 62 | print(key, value.size(), value.mean(), value.std()) 63 | 64 | 65 | 
@torch.no_grad() 66 | def test_U2NET(): 67 | model = U2NET() 68 | model = maybe_cuda(model.eval()) 69 | x = maybe_cuda(torch.rand((2, 3, 512, 512))) 70 | output = model(x) 71 | print(count_parameters(model)) 72 | for key, value in output.items(): 73 | print(key, value.size(), value.mean(), value.std()) 74 | --------------------------------------------------------------------------------