├── .gitignore ├── LICENSE ├── README.md ├── base ├── __init__.py ├── base_dataloader.py ├── base_dataset.py ├── base_model.py └── base_trainer.py ├── config.json ├── dataloaders ├── __init__.py ├── ade20k.py ├── cityscapes.py ├── coco.py ├── deepscene.py ├── labels │ ├── ade20k.txt │ ├── cityscapes.txt │ ├── coco.txt │ ├── cocostuff_hierarchy.json │ ├── cocostuff_labels.txt │ ├── deepscene.txt │ ├── voc.txt │ ├── voc_context_classes-400.txt │ └── voc_context_classes-59.txt └── voc.py ├── images ├── colour_scheme.png ├── learning_rates.png ├── tb1.png └── tb2.png ├── inference.py ├── models ├── deeplabv3_plus_xception.py ├── __init__.py ├── deeplabv3_plus.py ├── duc_hdc.py ├── enet.py ├── fcn.py ├── gcn.py ├── pspnet.py ├── resnet.py ├── segnet.py ├── unet.py └── upernet.py ├── requirements.txt ├── train.py ├── trainer.py ├── tutorial.ipynb └── utils ├── __init__.py ├── helpers.py ├── logger.py ├── losses.py ├── lovasz_losses.py ├── lr_scheduler.py ├── metrics.py ├── palette.py ├── sync_batchnorm ├── __init__.py ├── batchnorm.py ├── batchnorm_reimpl.py ├── comm.py ├── replicate.py └── unittest.py ├── torchsummary.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # input data, saved log, checkpoints and pretrained models 104 | data/ 105 | input/ 106 | saved/ 107 | outputs/ 108 | datasets/ 109 | pretrained/ 110 | trained_models/ 111 | 112 | # editor, os cache directory 113 | .vscode/ 114 | .idea/ 115 | __MACOSX/ 116 | 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yassine 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_dataloader import * 2 | from .base_dataset import * 3 | from .base_model import * 4 | from .base_trainer import * 5 | 6 | 7 | -------------------------------------------------------------------------------- /base/base_dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | class BaseDataLoader(DataLoader): 8 | def __init__(self, dataset, batch_size, shuffle, num_workers, val_split = 0.0): 9 | self.shuffle = shuffle 10 | self.dataset = dataset 11 | self.nbr_examples = len(dataset) 12 | if val_split: self.train_sampler, self.val_sampler = self._split_sampler(val_split) 13 | else: self.train_sampler, self.val_sampler = None, None 14 | 15 | self.init_kwargs = { 16 | 'dataset': self.dataset, 17 | 'batch_size': batch_size, 18 | 'shuffle': self.shuffle, 19 | 'num_workers': num_workers, 20 | 'pin_memory': True 21 | } 22 | super(BaseDataLoader, self).__init__(sampler=self.train_sampler, **self.init_kwargs) 23 | 24 | def _split_sampler(self, split): 25 | if split == 0.0: 26 | return None, None 27 | 28 | self.shuffle = False 29 | 30 | split_indx = int(self.nbr_examples * split) 31 | np.random.seed(0) 32 | 33 | indxs = np.arange(self.nbr_examples) 34 | np.random.shuffle(indxs) 35 | train_indxs = indxs[split_indx:] 36 | val_indxs = indxs[:split_indx] 37 | self.nbr_examples = len(train_indxs) 38 | 39 | train_sampler = SubsetRandomSampler(train_indxs) 40 | val_sampler = SubsetRandomSampler(val_indxs) 41 | return train_sampler, val_sampler 42 | 43 | def get_val_loader(self): 44 | if self.val_sampler is None: 45 | return None 46 | #self.init_kwargs['batch_size'] = 1 47 | return DataLoader(sampler=self.val_sampler, **self.init_kwargs) 48 | 49 | class DataPrefetcher(object): 50 | def __init__(self, loader, device, stop_after=None): 51 | self.loader = loader 52 | self.dataset = loader.dataset 53 | self.stream = torch.cuda.Stream() 54 | self.stop_after = stop_after 55 | self.next_input = None 56 | self.next_target = None 57 | self.device = device 58 | 59 | def __len__(self): 60 | return len(self.loader) 61 | 62 | def preload(self): 63 | try: 64 | self.next_input, self.next_target = next(self.loaditer) 65 | except StopIteration: 66 | self.next_input = None 67 | self.next_target = None 68 | return 69 | with torch.cuda.stream(self.stream): 70 | self.next_input = self.next_input.cuda(device=self.device, non_blocking=True) 71 | self.next_target = self.next_target.cuda(device=self.device, non_blocking=True) 72 | 73 | def __iter__(self): 74 | count = 0 75 | self.loaditer = iter(self.loader) 76 | self.preload() 77 | while self.next_input is not None: 78 | torch.cuda.current_stream().wait_stream(self.stream) 79 | input = self.next_input 80 | target = self.next_target 81 | self.preload() 82 | count += 1 83 | yield input, target 84 | if type(self.stop_after) is int and (count > self.stop_after): 85 | break -------------------------------------------------------------------------------- /base/base_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import torch 5 | from torch.utils.data import Dataset 6 | 
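A minimal usage sketch for the two classes defined in base_dataloader.py above: a BaseDataLoader subclass carving out an internal train/val split, and DataPrefetcher overlapping host-to-device copies with compute on a side CUDA stream. The ADE20K loader comes from dataloaders/ade20k.py later in this repo; the data_dir path, batch size and split fraction are illustrative assumptions, and a CUDA device is assumed.

import torch
from base import DataPrefetcher
from dataloaders import ADE20K

# Build a train loader and reserve 10% of it for validation (SubsetRandomSampler under the hood).
train_loader = ADE20K(data_dir='./data/ADEChallengeData2016', batch_size=8, split='training',
                      base_size=400, crop_size=380, augment=True, shuffle=True, val_split=0.1)
val_loader = train_loader.get_val_loader()

# Wrap the loader so the next batch is already being copied to the GPU while the current
# batch is being processed.
device = torch.device('cuda:0')
for images, labels in DataPrefetcher(train_loader, device=device):
    pass  # forward / backward pass goes here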
from PIL import Image 7 | from torchvision import transforms 8 | from scipy import ndimage 9 | 10 | class BaseDataSet(Dataset): 11 | def __init__(self, root, split, mean, std, base_size=None, augment=True, val=False, 12 | crop_size=321, scale=True, flip=True, rotate=False, blur=False, return_id=False): 13 | self.root = root 14 | self.split = split 15 | self.mean = mean 16 | self.std = std 17 | self.augment = augment 18 | self.crop_size = crop_size 19 | if self.augment: 20 | self.base_size = base_size 21 | self.scale = scale 22 | self.flip = flip 23 | self.rotate = rotate 24 | self.blur = blur 25 | self.val = val 26 | self.files = [] 27 | self._set_files() 28 | self.to_tensor = transforms.ToTensor() 29 | self.normalize = transforms.Normalize(mean, std) 30 | self.return_id = return_id 31 | 32 | cv2.setNumThreads(0) 33 | 34 | def _set_files(self): 35 | raise NotImplementedError 36 | 37 | def _load_data(self, index): 38 | raise NotImplementedError 39 | 40 | def _val_augmentation(self, image, label): 41 | if self.crop_size: 42 | h, w = label.shape 43 | # Scale the smaller side to crop size 44 | if h < w: 45 | h, w = (self.crop_size, int(self.crop_size * w / h)) 46 | else: 47 | h, w = (int(self.crop_size * h / w), self.crop_size) 48 | 49 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR) 50 | label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST) 51 | label = np.asarray(label, dtype=np.int32) 52 | 53 | # Center Crop 54 | h, w = label.shape 55 | start_h = (h - self.crop_size )// 2 56 | start_w = (w - self.crop_size )// 2 57 | end_h = start_h + self.crop_size 58 | end_w = start_w + self.crop_size 59 | image = image[start_h:end_h, start_w:end_w] 60 | label = label[start_h:end_h, start_w:end_w] 61 | return image, label 62 | 63 | def _augmentation(self, image, label): 64 | h, w, _ = image.shape 65 | # Scaling, we set the bigger to base size, and the smaller 66 | # one is rescaled to maintain the same ratio, if we don't have any obj in the image, re-do the processing 67 | if self.base_size: 68 | if self.scale: 69 | longside = random.randint(int(self.base_size*0.5), int(self.base_size*2.0)) 70 | else: 71 | longside = self.base_size 72 | h, w = (longside, int(1.0 * longside * w / h + 0.5)) if h > w else (int(1.0 * longside * h / w + 0.5), longside) 73 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR) 74 | label = cv2.resize(label, (w, h), interpolation=cv2.INTER_NEAREST) 75 | 76 | h, w, _ = image.shape 77 | # Rotate the image with an angle between -10 and 10 78 | if self.rotate: 79 | angle = random.randint(-10, 10) 80 | center = (w / 2, h / 2) 81 | rot_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) 82 | image = cv2.warpAffine(image, rot_matrix, (w, h), flags=cv2.INTER_LINEAR)#, borderMode=cv2.BORDER_REFLECT) 83 | label = cv2.warpAffine(label, rot_matrix, (w, h), flags=cv2.INTER_NEAREST)#, borderMode=cv2.BORDER_REFLECT) 84 | 85 | # Padding to return the correct crop size 86 | if self.crop_size: 87 | pad_h = max(self.crop_size - h, 0) 88 | pad_w = max(self.crop_size - w, 0) 89 | pad_kwargs = { 90 | "top": 0, 91 | "bottom": pad_h, 92 | "left": 0, 93 | "right": pad_w, 94 | "borderType": cv2.BORDER_CONSTANT,} 95 | if pad_h > 0 or pad_w > 0: 96 | image = cv2.copyMakeBorder(image, value=0, **pad_kwargs) 97 | label = cv2.copyMakeBorder(label, value=0, **pad_kwargs) 98 | 99 | # Cropping 100 | h, w, _ = image.shape 101 | start_h = random.randint(0, h - self.crop_size) 102 | start_w = random.randint(0, w - self.crop_size) 103 | end_h = start_h + 
self.crop_size 104 | end_w = start_w + self.crop_size 105 | image = image[start_h:end_h, start_w:end_w] 106 | label = label[start_h:end_h, start_w:end_w] 107 | 108 | # Random H flip 109 | if self.flip: 110 | if random.random() > 0.5: 111 | image = np.fliplr(image).copy() 112 | label = np.fliplr(label).copy() 113 | 114 | # Gaussian Blud (sigma between 0 and 1.5) 115 | if self.blur: 116 | sigma = random.random() 117 | ksize = int(3.3 * sigma) 118 | ksize = ksize + 1 if ksize % 2 == 0 else ksize 119 | image = cv2.GaussianBlur(image, (ksize, ksize), sigmaX=sigma, sigmaY=sigma, borderType=cv2.BORDER_REFLECT_101) 120 | return image, label 121 | 122 | def __len__(self): 123 | return len(self.files) 124 | 125 | def __getitem__(self, index): 126 | image, label, image_id = self._load_data(index) 127 | if self.val: 128 | image, label = self._val_augmentation(image, label) 129 | elif self.augment: 130 | image, label = self._augmentation(image, label) 131 | 132 | label = torch.from_numpy(np.array(label, dtype=np.int32)).long() 133 | image = Image.fromarray(np.uint8(image)) 134 | if self.return_id: 135 | return self.normalize(self.to_tensor(image)), label, image_id 136 | return self.normalize(self.to_tensor(image)), label 137 | 138 | def __repr__(self): 139 | fmt_str = "Dataset: " + self.__class__.__name__ + "\n" 140 | fmt_str += " # data: {}\n".format(self.__len__()) 141 | fmt_str += " Split: {}\n".format(self.split) 142 | fmt_str += " Root: {}".format(self.root) 143 | return fmt_str 144 | 145 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | import numpy as np 4 | from utils.torchsummary import summary 5 | 6 | class BaseModel(nn.Module): 7 | def __init__(self): 8 | super(BaseModel, self).__init__() 9 | self.logger = logging.getLogger(self.__class__.__name__) 10 | 11 | def forward(self): 12 | raise NotImplementedError 13 | 14 | def summary(self): 15 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 16 | nbr_params = sum([np.prod(p.size()) for p in model_parameters]) 17 | self.logger.info(f'Nbr of trainable parameters: {nbr_params}') 18 | 19 | def __str__(self): 20 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 21 | nbr_params = sum([np.prod(p.size()) for p in model_parameters]) 22 | return super(BaseModel, self).__str__() + f'\nNbr of trainable parameters: {nbr_params}' 23 | #return summary(self, input_shape=(2, 3, 224, 224)) 24 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import math 5 | import torch 6 | import datetime 7 | from torch.utils import tensorboard 8 | from utils import helpers 9 | from utils import logger 10 | import utils.lr_scheduler 11 | from utils.sync_batchnorm import convert_model 12 | from utils.sync_batchnorm import DataParallelWithCallback 13 | 14 | def get_instance(module, name, config, *args): 15 | # GET THE CORRESPONDING CLASS / FCT 16 | return getattr(module, config[name]['type'])(*args, **config[name]['args']) 17 | 18 | class BaseTrainer: 19 | def __init__(self, model, loss, resume, config, train_loader, val_loader=None, train_logger=None): 20 | self.model = model 21 | self.loss = loss 22 | self.config = config 23 | self.train_loader = train_loader 24 | 
self.val_loader = val_loader 25 | self.train_logger = train_logger 26 | self.logger = logging.getLogger(self.__class__.__name__) 27 | self.do_validation = self.config['trainer']['val'] 28 | self.start_epoch = 1 29 | self.improved = False 30 | 31 | # SETTING THE DEVICE 32 | self.device, availble_gpus = self._get_available_devices(self.config['n_gpu']) 33 | if config["use_synch_bn"]: 34 | self.model = convert_model(self.model) 35 | self.model = DataParallelWithCallback(self.model, device_ids=availble_gpus) 36 | else: 37 | self.model = torch.nn.DataParallel(self.model, device_ids=availble_gpus) 38 | self.model.to(self.device) 39 | 40 | # CONFIGS 41 | cfg_trainer = self.config['trainer'] 42 | self.epochs = cfg_trainer['epochs'] 43 | self.save_period = cfg_trainer['save_period'] 44 | 45 | # OPTIMIZER 46 | if self.config['optimizer']['differential_lr']: 47 | if isinstance(self.model, torch.nn.DataParallel): 48 | trainable_params = [{'params': filter(lambda p:p.requires_grad, self.model.module.get_decoder_params())}, 49 | {'params': filter(lambda p:p.requires_grad, self.model.module.get_backbone_params()), 50 | 'lr': config['optimizer']['args']['lr'] / 10}] 51 | else: 52 | trainable_params = [{'params': filter(lambda p:p.requires_grad, self.model.get_decoder_params())}, 53 | {'params': filter(lambda p:p.requires_grad, self.model.get_backbone_params()), 54 | 'lr': config['optimizer']['args']['lr'] / 10}] 55 | else: 56 | trainable_params = filter(lambda p:p.requires_grad, self.model.parameters()) 57 | self.optimizer = get_instance(torch.optim, 'optimizer', config, trainable_params) 58 | self.lr_scheduler = getattr(utils.lr_scheduler, config['lr_scheduler']['type'])(self.optimizer, self.epochs, len(train_loader)) 59 | 60 | # MONITORING 61 | self.monitor = cfg_trainer.get('monitor', 'off') 62 | if self.monitor == 'off': 63 | self.mnt_mode = 'off' 64 | self.mnt_best = 0 65 | else: 66 | self.mnt_mode, self.mnt_metric = self.monitor.split() 67 | assert self.mnt_mode in ['min', 'max'] 68 | self.mnt_best = -math.inf if self.mnt_mode == 'max' else math.inf 69 | self.early_stoping = cfg_trainer.get('early_stop', math.inf) 70 | 71 | # CHECKPOINTS & TENSOBOARD 72 | start_time = datetime.datetime.now().strftime('%m-%d_%H-%M') 73 | self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], self.config['name'], start_time) 74 | helpers.dir_exists(self.checkpoint_dir) 75 | config_save_path = os.path.join(self.checkpoint_dir, 'config.json') 76 | with open(config_save_path, 'w') as handle: 77 | json.dump(self.config, handle, indent=4, sort_keys=True) 78 | 79 | writer_dir = os.path.join(cfg_trainer['log_dir'], self.config['name'], start_time) 80 | self.writer = tensorboard.SummaryWriter(writer_dir) 81 | 82 | if resume: self._resume_checkpoint(resume) 83 | 84 | def _get_available_devices(self, n_gpu): 85 | sys_gpu = torch.cuda.device_count() 86 | if sys_gpu == 0: 87 | self.logger.warning('No GPUs detected, using the CPU') 88 | n_gpu = 0 89 | elif n_gpu > sys_gpu: 90 | self.logger.warning(f'Nbr of GPU requested is {n_gpu} but only {sys_gpu} are available') 91 | n_gpu = sys_gpu 92 | 93 | device = torch.device('cuda:0' if n_gpu > 0 else 'cpu') 94 | self.logger.info(f'Detected GPUs: {sys_gpu} Requested: {n_gpu}') 95 | available_gpus = list(range(n_gpu)) 96 | return device, available_gpus 97 | 98 | def train(self): 99 | for epoch in range(self.start_epoch, self.epochs+1): 100 | # RUN TRAIN (AND VAL) 101 | results = self._train_epoch(epoch) 102 | if self.do_validation and epoch % self.config['trainer']['val_per_epochs'] 
== 0: 103 | results = self._valid_epoch(epoch) 104 | 105 | # LOGGING INFO 106 | self.logger.info(f'\n ## Info for epoch {epoch} ## ') 107 | for k, v in results.items(): 108 | self.logger.info(f' {str(k):15s}: {v}') 109 | 110 | if self.train_logger is not None: 111 | log = {'epoch' : epoch, **results} 112 | self.train_logger.add_entry(log) 113 | 114 | # CHECKING IF THIS IS THE BEST MODEL (ONLY FOR VAL) 115 | if self.mnt_mode != 'off' and epoch % self.config['trainer']['val_per_epochs'] == 0: 116 | try: 117 | if self.mnt_mode == 'min': self.improved = (log[self.mnt_metric] < self.mnt_best) 118 | else: self.improved = (log[self.mnt_metric] > self.mnt_best) 119 | except KeyError: 120 | self.logger.warning(f'The metrics being tracked ({self.mnt_metric}) has not been calculated. Training stops.') 121 | break 122 | 123 | if self.improved: 124 | self.mnt_best = log[self.mnt_metric] 125 | self.not_improved_count = 0 126 | else: 127 | self.not_improved_count += 1 128 | 129 | if self.not_improved_count > self.early_stoping: 130 | self.logger.info(f'\nPerformance didn\'t improve for {self.early_stoping} epochs') 131 | self.logger.warning('Training Stoped') 132 | break 133 | 134 | # SAVE CHECKPOINT 135 | if epoch % self.save_period == 0: 136 | self._save_checkpoint(epoch, save_best=self.improved) 137 | 138 | def _save_checkpoint(self, epoch, save_best=False): 139 | state = { 140 | 'arch': type(self.model).__name__, 141 | 'epoch': epoch, 142 | 'state_dict': self.model.state_dict(), 143 | 'optimizer': self.optimizer.state_dict(), 144 | 'monitor_best': self.mnt_best, 145 | 'config': self.config 146 | } 147 | filename = os.path.join(self.checkpoint_dir, f'checkpoint-epoch{epoch}.pth') 148 | self.logger.info(f'\nSaving a checkpoint: {filename} ...') 149 | torch.save(state, filename) 150 | 151 | if save_best: 152 | filename = os.path.join(self.checkpoint_dir, f'best_model.pth') 153 | torch.save(state, filename) 154 | self.logger.info("Saving current best: best_model.pth") 155 | 156 | def _resume_checkpoint(self, resume_path): 157 | self.logger.info(f'Loading checkpoint : {resume_path}') 158 | checkpoint = torch.load(resume_path) 159 | 160 | # Load last run info, the model params, the optimizer and the loggers 161 | self.start_epoch = checkpoint['epoch'] + 1 162 | self.mnt_best = checkpoint['monitor_best'] 163 | self.not_improved_count = 0 164 | 165 | if checkpoint['config']['arch'] != self.config['arch']: 166 | self.logger.warning({'Warning! Current model is not the same as the one in the checkpoint'}) 167 | self.model.load_state_dict(checkpoint['state_dict']) 168 | 169 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 170 | self.logger.warning({'Warning! 
Current optimizer is not the same as the one in the checkpoint'}) 171 | self.optimizer.load_state_dict(checkpoint['optimizer']) 172 | 173 | self.logger.info(f'Checkpoint <{resume_path}> (epoch {self.start_epoch}) was loaded') 174 | 175 | def _train_epoch(self, epoch): 176 | raise NotImplementedError 177 | 178 | def _valid_epoch(self, epoch): 179 | raise NotImplementedError 180 | 181 | def _eval_metrics(self, output, target): 182 | raise NotImplementedError 183 | 184 | 185 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "PSPNet", 3 | "n_gpu": 1, 4 | "use_synch_bn": false, 5 | 6 | "arch": { 7 | "type": "PSPNet", 8 | "args": { 9 | "backbone": "resnet50", 10 | "freeze_bn": false, 11 | "freeze_backbone": false 12 | } 13 | }, 14 | 15 | "train_loader": { 16 | "type": "CityScapes", 17 | "args":{ 18 | "data_dir": "./data/cityscapes", 19 | "batch_size": 8, 20 | "base_size": 400, 21 | "crop_size": 380, 22 | "augment": true, 23 | "shuffle": true, 24 | "scale": true, 25 | "flip": true, 26 | "rotate": true, 27 | "blur": false, 28 | "split": "train", 29 | "num_workers": 8 30 | } 31 | }, 32 | 33 | "val_loader": { 34 | "type": "CityScapes", 35 | "args":{ 36 | "data_dir": "./data/cityscapes", 37 | "batch_size": 8, 38 | "crop_size": 480, 39 | "val": true, 40 | "split": "val", 41 | "num_workers": 4 42 | } 43 | }, 44 | 45 | "optimizer": { 46 | "type": "SGD", 47 | "differential_lr": true, 48 | "args":{ 49 | "lr": 0.01, 50 | "weight_decay": 1e-4, 51 | "momentum": 0.9 52 | } 53 | }, 54 | 55 | "loss": "CrossEntropyLoss2d", 56 | "ignore_index": 255, 57 | "lr_scheduler": { 58 | "type": "Poly", 59 | "args": {} 60 | }, 61 | 62 | "trainer": { 63 | "epochs": 80, 64 | "save_dir": "saved/", 65 | "save_period": 10, 66 | 67 | "monitor": "max Mean_IoU", 68 | "early_stop": 10, 69 | 70 | "tensorboard": true, 71 | "log_dir": "saved/runs", 72 | "log_per_iter": 20, 73 | 74 | "val": true, 75 | "val_per_epochs": 5 76 | } 77 | } -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco import COCO 2 | from .voc import VOC 3 | from .ade20k import ADE20K 4 | from .cityscapes import CityScapes 5 | from .deepscene import DeepScene -------------------------------------------------------------------------------- /dataloaders/ade20k.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataSet, BaseDataLoader 2 | from utils import palette 3 | import numpy as np 4 | import os 5 | import torch 6 | import cv2 7 | from PIL import Image 8 | from glob import glob 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | 12 | 13 | class ADE20KDataset(BaseDataSet): 14 | """ 15 | ADE20K dataset 16 | http://groups.csail.mit.edu/vision/datasets/ADE20K/ 17 | """ 18 | def __init__(self, **kwargs): 19 | self.num_classes = 150 20 | self.palette = palette.ADE20K_palette 21 | super(ADE20KDataset, self).__init__(**kwargs) 22 | 23 | def _set_files(self): 24 | if self.split in ["training", "validation"]: 25 | self.image_dir = os.path.join(self.root, 'images', self.split) 26 | self.label_dir = os.path.join(self.root, 'annotations', self.split) 27 | self.files = [os.path.basename(path).split('.')[0] for path in glob(self.image_dir + '/*.jpg')] 28 | else: raise ValueError(f"Invalid split name 
{self.split}") 29 | 30 | def _load_data(self, index): 31 | image_id = self.files[index] 32 | image_path = os.path.join(self.image_dir, image_id + '.jpg') 33 | label_path = os.path.join(self.label_dir, image_id + '.png') 34 | image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32) 35 | label = np.asarray(Image.open(label_path), dtype=np.int32) - 1 # from -1 to 149 36 | return image, label, image_id 37 | 38 | class ADE20K(BaseDataLoader): 39 | def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, val=False, 40 | shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False): 41 | 42 | self.MEAN = [0.48897059, 0.46548275, 0.4294] 43 | self.STD = [0.22861765, 0.22948039, 0.24054667] 44 | 45 | kwargs = { 46 | 'root': data_dir, 47 | 'split': split, 48 | 'mean': self.MEAN, 49 | 'std': self.STD, 50 | 'augment': augment, 51 | 'crop_size': crop_size, 52 | 'base_size': base_size, 53 | 'scale': scale, 54 | 'flip': flip, 55 | 'blur': blur, 56 | 'rotate': rotate, 57 | 'return_id': return_id, 58 | 'val': val 59 | } 60 | 61 | self.dataset = ADE20KDataset(**kwargs) 62 | super(ADE20K, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split) 63 | -------------------------------------------------------------------------------- /dataloaders/cityscapes.py: -------------------------------------------------------------------------------- 1 | from base import BaseDataSet, BaseDataLoader 2 | from utils import palette 3 | from glob import glob 4 | import numpy as np 5 | import os 6 | import cv2 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | 12 | ignore_label = 255 13 | ID_TO_TRAINID = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label, 14 | 3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label, 15 | 7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4, 16 | 14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5, 17 | 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 18 | 28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18} 19 | 20 | class CityScapesDataset(BaseDataSet): 21 | def __init__(self, mode='fine', **kwargs): 22 | self.num_classes = 19 23 | self.mode = mode 24 | self.palette = palette.CityScpates_palette 25 | self.id_to_trainId = ID_TO_TRAINID 26 | super(CityScapesDataset, self).__init__(**kwargs) 27 | 28 | def _set_files(self): 29 | assert (self.mode == 'fine' and self.split in ['train', 'val']) or \ 30 | (self.mode == 'coarse' and self.split in ['train', 'train_extra', 'val']) 31 | 32 | SUFIX = '_gtFine_labelIds.png' 33 | if self.mode == 'coarse': 34 | img_dir_name = 'leftImg8bit_trainextra' if self.split == 'train_extra' else 'leftImg8bit_trainvaltest' 35 | label_path = os.path.join(self.root, 'gtCoarse', 'gtCoarse', self.split) 36 | else: 37 | img_dir_name = 'leftImg8bit_trainvaltest' 38 | label_path = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split) 39 | image_path = os.path.join(self.root, img_dir_name, 'leftImg8bit', self.split) 40 | assert os.listdir(image_path) == os.listdir(label_path) 41 | 42 | image_paths, label_paths = [], [] 43 | for city in os.listdir(image_path): 44 | image_paths.extend(sorted(glob(os.path.join(image_path, city, '*.png')))) 45 | label_paths.extend(sorted(glob(os.path.join(label_path, city, f'*{SUFIX}')))) 46 | self.files = list(zip(image_paths, 
label_paths)) 47 | 48 | def _load_data(self, index): 49 | image_path, label_path = self.files[index] 50 | image_id = os.path.splitext(os.path.basename(image_path))[0] 51 | image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32) 52 | label = np.asarray(Image.open(label_path), dtype=np.int32) 53 | for k, v in self.id_to_trainId.items(): 54 | label[label == k] = v 55 | return image, label, image_id 56 | 57 | 58 | 59 | class CityScapes(BaseDataLoader): 60 | def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, mode='fine', val=False, 61 | shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False): 62 | 63 | self.MEAN = [0.28689529, 0.32513294, 0.28389176] 64 | self.STD = [0.17613647, 0.18099176, 0.17772235] 65 | 66 | kwargs = { 67 | 'root': data_dir, 68 | 'split': split, 69 | 'mean': self.MEAN, 70 | 'std': self.STD, 71 | 'augment': augment, 72 | 'crop_size': crop_size, 73 | 'base_size': base_size, 74 | 'scale': scale, 75 | 'flip': flip, 76 | 'blur': blur, 77 | 'rotate': rotate, 78 | 'return_id': return_id, 79 | 'val': val 80 | } 81 | 82 | self.dataset = CityScapesDataset(mode=mode, **kwargs) 83 | super(CityScapes, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split) 84 | 85 | 86 | -------------------------------------------------------------------------------- /dataloaders/coco.py: -------------------------------------------------------------------------------- 1 | # Originally written by Kazuto Nakashima 2 | # https://github.com/kazuto1011/deeplab-pytorch 3 | 4 | from base import BaseDataSet, BaseDataLoader 5 | from PIL import Image 6 | from glob import glob 7 | import numpy as np 8 | import scipy.io as sio 9 | from utils import palette 10 | import torch 11 | import os 12 | import cv2 13 | 14 | class CocoStuff10k(BaseDataSet): 15 | def __init__(self, warp_image = True, **kwargs): 16 | self.warp_image = warp_image 17 | self.num_classes = 182 18 | self.palette = palette.COCO_palette 19 | super(CocoStuff10k, self).__init__(**kwargs) 20 | 21 | def _set_files(self): 22 | if self.split in ['train', 'test', 'all']: 23 | file_list = os.path.join(self.root, 'imageLists', self.split + '.txt') 24 | self.files = [name.rstrip() for name in tuple(open(file_list, "r"))] 25 | else: raise ValueError(f"Invalid split name {self.split} choose one of [train, test, all]") 26 | 27 | def _load_data(self, index): 28 | image_id = self.files[index] 29 | image_path = os.path.join(self.root, 'images', image_id + '.jpg') 30 | label_path = os.path.join(self.root, 'annotations', image_id + '.mat') 31 | image = np.asarray(Image.open(image_path), dtype=np.float32) 32 | label = sio.loadmat(label_path)['S'] 33 | label -= 1 # unlabeled (0 -> -1) 34 | label[label == -1] = 255 35 | if self.warp_image: 36 | image = cv2.resize(image, (513, 513), interpolation=cv2.INTER_LINEAR) 37 | label = np.asarray(Image.fromarray(label).resize((513, 513), resample=Image.NEAREST)) 38 | return image, label, image_id 39 | 40 | class CocoStuff164k(BaseDataSet): 41 | def __init__(self, **kwargs): 42 | self.num_classes = 182 43 | self.palette = palette.COCO_palette 44 | super(CocoStuff164k, self).__init__(**kwargs) 45 | 46 | def _set_files(self): 47 | if self.split in ['train2017', 'val2017']: 48 | file_list = sorted(glob(os.path.join(self.root, 'images', self.split + '/*.jpg'))) 49 | self.files = [os.path.basename(f).split('.')[0] for f in file_list] 50 | else: raise ValueError(f"Invalid split name {self.split}, 
either train2017 or val2017") 51 | 52 | def _load_data(self, index): 53 | image_id = self.files[index] 54 | image_path = os.path.join(self.root, 'images', self.split, image_id + '.jpg') 55 | label_path = os.path.join(self.root, 'annotations', self.split, image_id + '.png') 56 | image = np.asarray(Image.open(image_path).convert('RGB'), dtype=np.float32) 57 | label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) 58 | return image, label, image_id 59 | 60 | def get_parent_class(value, dictionary): 61 | for k, v in dictionary.items(): 62 | if isinstance(v, list): 63 | if value in v: 64 | yield k 65 | elif isinstance(v, dict): 66 | if value in list(v.keys()): 67 | yield k 68 | else: 69 | for res in get_parent_class(value, v): 70 | yield res 71 | 72 | class COCO(BaseDataLoader): 73 | def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, partition = 'CocoStuff164k', 74 | shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False, val=False): 75 | 76 | self.MEAN = [0.43931922, 0.41310471, 0.37480941] 77 | self.STD = [0.24272706, 0.23649098, 0.23429529] 78 | 79 | kwargs = { 80 | 'root': data_dir, 81 | 'split': split, 82 | 'mean': self.MEAN, 83 | 'std': self.STD, 84 | 'augment': augment, 85 | 'crop_size': crop_size, 86 | 'base_size': base_size, 87 | 'scale': scale, 88 | 'flip': flip, 89 | 'blur': blur, 90 | 'rotate': rotate, 91 | 'return_id': return_id, 92 | 'val': val 93 | } 94 | 95 | if partition == 'CocoStuff10k': self.dataset = CocoStuff10k(**kwargs) 96 | elif partition == 'CocoStuff164k': self.dataset = CocoStuff164k(**kwargs) 97 | else: raise ValueError(f"Please choose either CocoStuff10k / CocoStuff164k") 98 | 99 | super(COCO, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split) 100 | 101 | -------------------------------------------------------------------------------- /dataloaders/deepscene.py: -------------------------------------------------------------------------------- 1 | # Originally written by Dustin Franklin, adapted by Markus Schiffer 2 | # https://github.com/dusty-nv/pytorch-segmentation/blob/master/datasets/deepscene.py 3 | 4 | from base import BaseDataSet, BaseDataLoader 5 | from utils import palette 6 | import numpy as np 7 | import os 8 | import re 9 | from PIL import Image 10 | 11 | 12 | class DeepSceneDataset(BaseDataSet): 13 | """ 14 | DeepScene Freibrug Forest dataset 15 | http://deepscene.cs.uni-freiburg.de/ 16 | """ 17 | def __init__(self, **kwargs): 18 | self.num_classes = 7 19 | self.palette = palette.DeepScene_palette 20 | 21 | self.mask_mapping = {} 22 | 23 | for i in range(0, len(self.palette), 3): 24 | self.mask_mapping[tuple(self.palette[i:i+3])] = i // 3 25 | 26 | self.images = [] 27 | self.targets = [] 28 | 29 | super(DeepSceneDataset, self).__init__(**kwargs) 30 | 31 | def gather_images(self, images_path, labels_path): 32 | def sorted_alphanumeric(data): 33 | convert = lambda text: int(text) if text.isdigit() else text.lower() 34 | alphanum_key = lambda key: [ convert(c) for c in re.split('([0-9]+)', key) ] 35 | return sorted(data, key=alphanum_key) 36 | 37 | image_files = sorted_alphanumeric(os.listdir(images_path)) 38 | label_files = sorted_alphanumeric(os.listdir(labels_path)) 39 | 40 | if len(image_files) != len(label_files): 41 | print('warning: images path has a different number of files than labels path') 42 | print(' ({:d} files) - {:s}'.format(len(image_files), images_path)) 43 | print(' ({:d} files) - {:s}'.format(len(label_files), 
labels_path)) 44 | 45 | for n in range(len(image_files)): 46 | image_files[n] = os.path.join(images_path, image_files[n]) 47 | label_files[n] = os.path.join(labels_path, label_files[n]) 48 | 49 | return image_files, label_files 50 | 51 | def _set_files(self): 52 | if self.split in ["training", "validation"]: 53 | 54 | if self.split == "training": 55 | train_images, train_targets = self.gather_images(os.path.join(self.root, 'train/rgb'), 56 | os.path.join(self.root, 'train/GT_color')) 57 | 58 | self.images.extend(train_images) 59 | self.targets.extend(train_targets) 60 | 61 | elif self.split == "validation": 62 | val_images, val_targets = self.gather_images(os.path.join(self.root, 'test/rgb'), 63 | os.path.join(self.root, 'test/GT_color')) 64 | 65 | self.images.extend(val_images) 66 | self.targets.extend(val_targets) 67 | 68 | self.files = self.images 69 | 70 | else: raise ValueError(f"Invalid split name {self.split}") 71 | 72 | def _load_data(self, index): 73 | image_id = self.images[index] 74 | image = np.asarray(Image.open(self.images[index]).convert("RGB"), dtype=np.float32) 75 | target_rgb = np.asarray(Image.open(self.targets[index]).convert("RGB"), dtype=np.float32) 76 | target = np.zeros(target_rgb.shape[:2], dtype=np.int32) 77 | for cls in self.mask_mapping: 78 | target[(target_rgb == cls).all(axis=2)] = self.mask_mapping[cls] 79 | return image, target, image_id 80 | 81 | class DeepScene(BaseDataLoader): 82 | def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, val=False, 83 | shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False): 84 | 85 | self.MEAN = [0.485, 0.456, 0.406] 86 | self.STD = [0.229, 0.224, 0.225] 87 | 88 | kwargs = { 89 | 'root': data_dir, 90 | 'split': split, 91 | 'mean': self.MEAN, 92 | 'std': self.STD, 93 | 'augment': augment, 94 | 'crop_size': crop_size, 95 | 'base_size': base_size, 96 | 'scale': scale, 97 | 'flip': flip, 98 | 'blur': blur, 99 | 'rotate': rotate, 100 | 'return_id': return_id, 101 | 'val': val 102 | } 103 | 104 | self.dataset = DeepSceneDataset(**kwargs) 105 | super().__init__(self.dataset, batch_size, shuffle, num_workers, val_split) 106 | -------------------------------------------------------------------------------- /dataloaders/labels/ade20k.txt: -------------------------------------------------------------------------------- 1 | Idx Ratio Train Val Name 2 | 1 0.1576 11664 1172 wall 3 | 2 0.1072 6046 612 building, edifice 4 | 3 0.0878 8265 796 sky 5 | 4 0.0621 9336 917 floor, flooring 6 | 5 0.0480 6678 641 tree 7 | 6 0.0450 6604 643 ceiling 8 | 7 0.0398 4023 408 road, route 9 | 8 0.0231 1906 199 bed 10 | 9 0.0198 4688 460 windowpane, window 11 | 10 0.0183 2423 225 grass 12 | 11 0.0181 2874 294 cabinet 13 | 12 0.0166 3068 310 sidewalk, pavement 14 | 13 0.0160 5075 526 person, individual, someone, somebody, mortal, soul 15 | 14 0.0151 1804 190 earth, ground 16 | 15 0.0118 6666 796 door, double door 17 | 16 0.0110 4269 411 table 18 | 17 0.0109 1691 160 mountain, mount 19 | 18 0.0104 3999 441 plant, flora, plant life 20 | 19 0.0104 2149 217 curtain, drape, drapery, mantle, pall 21 | 20 0.0103 3261 318 chair 22 | 21 0.0098 3164 306 car, auto, automobile, machine, motorcar 23 | 22 0.0074 709 75 water 24 | 23 0.0067 3296 315 painting, picture 25 | 24 0.0065 1191 106 sofa, couch, lounge 26 | 25 0.0061 1516 162 shelf 27 | 26 0.0060 667 69 house 28 | 27 0.0053 651 57 sea 29 | 28 0.0052 1847 224 mirror 30 | 29 0.0046 1158 128 rug, carpet, carpeting 31 | 
30 0.0044 480 44 field 32 | 31 0.0044 1172 98 armchair 33 | 32 0.0044 1292 184 seat 34 | 33 0.0033 1386 138 fence, fencing 35 | 34 0.0031 698 61 desk 36 | 35 0.0030 781 73 rock, stone 37 | 36 0.0027 380 43 wardrobe, closet, press 38 | 37 0.0026 3089 302 lamp 39 | 38 0.0024 404 37 bathtub, bathing tub, bath, tub 40 | 39 0.0024 804 99 railing, rail 41 | 40 0.0023 1453 153 cushion 42 | 41 0.0023 411 37 base, pedestal, stand 43 | 42 0.0022 1440 162 box 44 | 43 0.0022 800 77 column, pillar 45 | 44 0.0020 2650 298 signboard, sign 46 | 45 0.0019 549 46 chest of drawers, chest, bureau, dresser 47 | 46 0.0019 367 36 counter 48 | 47 0.0018 311 30 sand 49 | 48 0.0018 1181 122 sink 50 | 49 0.0018 287 23 skyscraper 51 | 50 0.0018 468 38 fireplace, hearth, open fireplace 52 | 51 0.0018 402 43 refrigerator, icebox 53 | 52 0.0018 130 12 grandstand, covered stand 54 | 53 0.0018 561 64 path 55 | 54 0.0017 880 102 stairs, steps 56 | 55 0.0017 86 12 runway 57 | 56 0.0017 172 11 case, display case, showcase, vitrine 58 | 57 0.0017 198 18 pool table, billiard table, snooker table 59 | 58 0.0017 930 109 pillow 60 | 59 0.0015 139 18 screen door, screen 61 | 60 0.0015 564 52 stairway, staircase 62 | 61 0.0015 320 26 river 63 | 62 0.0015 261 29 bridge, span 64 | 63 0.0014 275 22 bookcase 65 | 64 0.0014 335 60 blind, screen 66 | 65 0.0014 792 75 coffee table, cocktail table 67 | 66 0.0014 395 49 toilet, can, commode, crapper, pot, potty, stool, throne 68 | 67 0.0014 1309 138 flower 69 | 68 0.0013 1112 113 book 70 | 69 0.0013 266 27 hill 71 | 70 0.0013 659 66 bench 72 | 71 0.0012 331 31 countertop 73 | 72 0.0012 531 56 stove, kitchen stove, range, kitchen range, cooking stove 74 | 73 0.0012 369 36 palm, palm tree 75 | 74 0.0012 144 9 kitchen island 76 | 75 0.0011 265 29 computer, computing machine, computing device, data processor, electronic computer, information processing system 77 | 76 0.0010 324 33 swivel chair 78 | 77 0.0009 304 27 boat 79 | 78 0.0009 170 20 bar 80 | 79 0.0009 68 6 arcade machine 81 | 80 0.0009 65 8 hovel, hut, hutch, shack, shanty 82 | 81 0.0009 248 25 bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle 83 | 82 0.0008 492 49 towel 84 | 83 0.0008 2510 269 light, light source 85 | 84 0.0008 440 39 truck, motortruck 86 | 85 0.0008 147 18 tower 87 | 86 0.0008 583 56 chandelier, pendant, pendent 88 | 87 0.0007 533 61 awning, sunshade, sunblind 89 | 88 0.0007 1989 239 streetlight, street lamp 90 | 89 0.0007 71 5 booth, cubicle, stall, kiosk 91 | 90 0.0007 618 53 television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box 92 | 91 0.0007 135 12 airplane, aeroplane, plane 93 | 92 0.0007 83 5 dirt track 94 | 93 0.0007 178 17 apparel, wearing apparel, dress, clothes 95 | 94 0.0006 1003 104 pole 96 | 95 0.0006 182 12 land, ground, soil 97 | 96 0.0006 452 50 bannister, banister, balustrade, balusters, handrail 98 | 97 0.0006 42 6 escalator, moving staircase, moving stairway 99 | 98 0.0006 307 31 ottoman, pouf, pouffe, puff, hassock 100 | 99 0.0006 965 114 bottle 101 | 100 0.0006 117 13 buffet, counter, sideboard 102 | 101 0.0006 354 35 poster, posting, placard, notice, bill, card 103 | 102 0.0006 108 9 stage 104 | 103 0.0006 557 55 van 105 | 104 0.0006 52 4 ship 106 | 105 0.0005 99 5 fountain 107 | 106 0.0005 57 4 conveyer belt, conveyor belt, conveyer, conveyor, transporter 108 | 107 0.0005 292 31 canopy 109 | 108 0.0005 77 9 washer, automatic washer, washing machine 110 | 109 0.0005 340 38 plaything, toy 111 | 
110 0.0005 66 3 swimming pool, swimming bath, natatorium 112 | 111 0.0005 465 49 stool 113 | 112 0.0005 50 4 barrel, cask 114 | 113 0.0005 622 75 basket, handbasket 115 | 114 0.0005 80 9 waterfall, falls 116 | 115 0.0005 59 3 tent, collapsible shelter 117 | 116 0.0005 531 72 bag 118 | 117 0.0005 282 30 minibike, motorbike 119 | 118 0.0005 73 7 cradle 120 | 119 0.0005 435 44 oven 121 | 120 0.0005 136 25 ball 122 | 121 0.0005 116 24 food, solid food 123 | 122 0.0004 266 31 step, stair 124 | 123 0.0004 58 12 tank, storage tank 125 | 124 0.0004 418 83 trade name, brand name, brand, marque 126 | 125 0.0004 319 43 microwave, microwave oven 127 | 126 0.0004 1193 139 pot, flowerpot 128 | 127 0.0004 97 23 animal, animate being, beast, brute, creature, fauna 129 | 128 0.0004 347 36 bicycle, bike, wheel, cycle 130 | 129 0.0004 52 5 lake 131 | 130 0.0004 246 22 dishwasher, dish washer, dishwashing machine 132 | 131 0.0004 108 13 screen, silver screen, projection screen 133 | 132 0.0004 201 30 blanket, cover 134 | 133 0.0004 285 21 sculpture 135 | 134 0.0004 268 27 hood, exhaust hood 136 | 135 0.0003 1020 108 sconce 137 | 136 0.0003 1282 122 vase 138 | 137 0.0003 528 65 traffic light, traffic signal, stoplight 139 | 138 0.0003 453 57 tray 140 | 139 0.0003 671 100 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 141 | 140 0.0003 397 44 fan 142 | 141 0.0003 92 8 pier, wharf, wharfage, dock 143 | 142 0.0003 228 18 crt screen 144 | 143 0.0003 570 59 plate 145 | 144 0.0003 217 22 monitor, monitoring device 146 | 145 0.0003 206 19 bulletin board, notice board 147 | 146 0.0003 130 14 shower 148 | 147 0.0003 178 28 radiator 149 | 148 0.0002 504 57 glass, drinking glass 150 | 149 0.0002 775 96 clock 151 | 150 0.0002 421 56 flag 152 | -------------------------------------------------------------------------------- /dataloaders/labels/cityscapes.txt: -------------------------------------------------------------------------------- 1 | name id trainId category catId hasInstances ignoreInEval color 2 | -------------------------------------------------------------------------------------------------------------------- 3 | unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 4 | ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 5 | rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 6 | out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 7 | static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 8 | dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 9 | ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 10 | road' , 7 , 0 , 'ground' , 1 , False , False , (128, 64,128) ), 11 | sidewalk' , 8 , 1 , 'ground' , 1 , False , False , (244, 35,232) ), 12 | parking' , 9 , 255 , 'ground' , 1 , False , True , (250,170,160) ), 13 | rail track' , 10 , 255 , 'ground' , 1 , False , True , (230,150,140) ), 14 | building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 15 | wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 16 | fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 17 | guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 18 | bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 19 | tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 20 | pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 21 | polegroup' , 18 , 255 , 'object' , 3 , 
False , True , (153,153,153) ), 22 | traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 23 | traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 24 | vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 25 | terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 26 | sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 27 | person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 28 | rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 29 | car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 30 | truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 31 | bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 32 | caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 33 | trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 34 | train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 35 | motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 36 | bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 37 | license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), -------------------------------------------------------------------------------- /dataloaders/labels/coco.txt: -------------------------------------------------------------------------------- 1 | 0 background 2 | 1 person 3 | 2 bicycle 4 | 3 car 5 | 4 motorcycle 6 | 5 airplane 7 | 6 bus 8 | 7 train 9 | 8 truck 10 | 9 boat 11 | 10 traffic light 12 | 11 fire hydrant 13 | 12 street sign 14 | 13 stop sign 15 | 14 parking meter 16 | 15 bench 17 | 16 bird 18 | 17 cat 19 | 18 dog 20 | 19 horse 21 | 20 sheep 22 | 21 cow 23 | 22 elephant 24 | 23 bear 25 | 24 zebra 26 | 25 giraffe 27 | 26 hat 28 | 27 backpack 29 | 28 umbrella 30 | 29 shoe 31 | 30 eye glasses 32 | 31 handbag 33 | 32 tie 34 | 33 suitcase 35 | 34 frisbee 36 | 35 skis 37 | 36 snowboard 38 | 37 sports ball 39 | 38 kite 40 | 39 baseball bat 41 | 40 baseball glove 42 | 41 skateboard 43 | 42 surfboard 44 | 43 tennis racket 45 | 44 bottle 46 | 45 plate 47 | 46 wine glass 48 | 47 cup 49 | 48 fork 50 | 49 knife 51 | 50 spoon 52 | 51 bowl 53 | 52 banana 54 | 53 apple 55 | 54 sandwich 56 | 55 orange 57 | 56 broccoli 58 | 57 carrot 59 | 58 hot dog 60 | 59 pizza 61 | 60 donut 62 | 61 cake 63 | 62 chair 64 | 63 couch 65 | 64 potted plant 66 | 65 bed 67 | 66 mirror 68 | 67 dining table 69 | 68 window 70 | 69 desk 71 | 70 toilet 72 | 71 door 73 | 72 tv 74 | 73 laptop 75 | 74 mouse 76 | 75 remote 77 | 76 keyboard 78 | 77 cell phone 79 | 78 microwave 80 | 79 oven 81 | 80 toaster 82 | 81 sink 83 | 82 refrigerator 84 | 83 blender 85 | 84 book 86 | 85 clock 87 | 86 vase 88 | 87 scissors 89 | 88 teddy bear 90 | 89 hair drier 91 | 90 toothbrush -------------------------------------------------------------------------------- /dataloaders/labels/cocostuff_hierarchy.json: -------------------------------------------------------------------------------- 1 | { 2 | "things": { 3 | "indoor-super-things": { 4 | "appliance-things": [ 5 | "microwave", 6 | "oven", 7 | "toaster", 8 | "sink", 9 | "refrigerator", 10 | "blender" 11 | ], 12 | "electronic-things": [ 13 | "tv", 14 | "laptop", 15 | "mouse", 16 | "remote", 17 | "keyboard", 18 | "cell phone" 19 | ], 20 | "food-things": [ 21 | "banana", 22 | "apple", 23 | "sandwich", 24 | "orange", 25 | "broccoli", 26 | "carrot", 27 | "hot dog", 28 | "pizza", 29 | "donut", 30 | "cake" 31 | ], 32 | "furniture-things": [ 33 | "chair", 34 | "couch", 35 | "potted 
plant", 36 | "bed", 37 | "mirror", 38 | "dining table", 39 | "window", 40 | "desk", 41 | "toilet", 42 | "door" 43 | ], 44 | "indoor-things": [ 45 | "book", 46 | "clock", 47 | "vase", 48 | "scissors", 49 | "teddy bear", 50 | "hair drier", 51 | "toothbrush", 52 | "hair brush" 53 | ], 54 | "kitchen-things": [ 55 | "bottle", 56 | "plate", 57 | "wine glass", 58 | "cup", 59 | "fork", 60 | "knife", 61 | "spoon", 62 | "bowl" 63 | ] 64 | }, 65 | "outdoor-super-things": { 66 | "accessory-things": [ 67 | "hat", 68 | "backpack", 69 | "umbrella", 70 | "shoe", 71 | "eye glasses", 72 | "handbag", 73 | "tie", 74 | "suitcase" 75 | ], 76 | "animal-things": [ 77 | "bird", 78 | "cat", 79 | "dog", 80 | "horse", 81 | "sheep", 82 | "cow", 83 | "elephant", 84 | "bear", 85 | "zebra", 86 | "giraffe" 87 | ], 88 | "outdoor-things": [ 89 | "traffic light", 90 | "fire hydrant", 91 | "street sign", 92 | "stop sign", 93 | "parking meter", 94 | "bench" 95 | ], 96 | "person-things": [ 97 | "person" 98 | ], 99 | "sports-things": [ 100 | "frisbee", 101 | "skis", 102 | "snowboard", 103 | "sports ball", 104 | "kite", 105 | "baseball bat", 106 | "baseball glove", 107 | "skateboard", 108 | "surfboard", 109 | "tennis racket" 110 | ], 111 | "vehicle-things": [ 112 | "bicycle", 113 | "car", 114 | "motorcycle", 115 | "airplane", 116 | "bus", 117 | "train", 118 | "truck", 119 | "boat" 120 | ] 121 | } 122 | }, 123 | "stuff": { 124 | "indoor-super-stuff": { 125 | "ceiling-stuff": [ 126 | "ceiling-tile", 127 | "ceiling-other" 128 | ], 129 | "floor-stuff": [ 130 | "carpet", 131 | "floor-tile", 132 | "floor-wood", 133 | "floor-marble", 134 | "floor-stone", 135 | "floor-other" 136 | ], 137 | "food-stuff": [ 138 | "fruit", 139 | "salad", 140 | "vegetable", 141 | "food-other" 142 | ], 143 | "furniture-stuff": [ 144 | "door-stuff", 145 | "desk-stuff", 146 | "table", 147 | "shelf", 148 | "cabinet", 149 | "cupboard", 150 | "mirror-stuff", 151 | "counter", 152 | "light", 153 | "stairs", 154 | "furniture-other" 155 | ], 156 | "rawmaterial-stuff": [ 157 | "cardboard", 158 | "paper", 159 | "plastic", 160 | "metal" 161 | ], 162 | "textile-stuff": [ 163 | "rug", 164 | "mat", 165 | "towel", 166 | "napkin", 167 | "clothes", 168 | "cloth", 169 | "curtain", 170 | "blanket", 171 | "pillow", 172 | "banner", 173 | "textile-other" 174 | ], 175 | "wall-stuff": [ 176 | "wall-tile", 177 | "wall-panel", 178 | "wall-wood", 179 | "wall-brick", 180 | "wall-stone", 181 | "wall-concrete", 182 | "wall-other" 183 | ], 184 | "window-stuff": [ 185 | "window-blind", 186 | "window-other" 187 | ] 188 | }, 189 | "outdoor-super-stuff": { 190 | "building-stuff": [ 191 | "house", 192 | "skyscraper", 193 | "bridge", 194 | "tent", 195 | "roof", 196 | "building-other" 197 | ], 198 | "ground-stuff": [ 199 | "sand", 200 | "snow", 201 | "dirt", 202 | "mud", 203 | "gravel", 204 | "road", 205 | "pavement", 206 | "railroad", 207 | "platform", 208 | "playingfield", 209 | "ground-other" 210 | ], 211 | "plant-stuff": [ 212 | "grass", 213 | "tree", 214 | "bush", 215 | "leaves", 216 | "flower", 217 | "branch", 218 | "moss", 219 | "straw", 220 | "plant-other" 221 | ], 222 | "sky-stuff": [ 223 | "clouds", 224 | "sky-other" 225 | ], 226 | "solid-stuff": [ 227 | "wood", 228 | "rock", 229 | "stone", 230 | "mountain", 231 | "hill", 232 | "solid-other" 233 | ], 234 | "structural-stuff": [ 235 | "fence", 236 | "cage", 237 | "net", 238 | "railing", 239 | "structural-other" 240 | ], 241 | "water-stuff": [ 242 | "fog", 243 | "river", 244 | "sea", 245 | "waterdrops", 246 | "water-other" 247 | ] 248 | } 
249 | } 250 | } -------------------------------------------------------------------------------- /dataloaders/labels/cocostuff_labels.txt: -------------------------------------------------------------------------------- 1 | 0 unlabeled 2 | 1 person 3 | 2 bicycle 4 | 3 car 5 | 4 motorcycle 6 | 5 airplane 7 | 6 bus 8 | 7 train 9 | 8 truck 10 | 9 boat 11 | 10 traffic light 12 | 11 fire hydrant 13 | 12 street sign 14 | 13 stop sign 15 | 14 parking meter 16 | 15 bench 17 | 16 bird 18 | 17 cat 19 | 18 dog 20 | 19 horse 21 | 20 sheep 22 | 21 cow 23 | 22 elephant 24 | 23 bear 25 | 24 zebra 26 | 25 giraffe 27 | 26 hat 28 | 27 backpack 29 | 28 umbrella 30 | 29 shoe 31 | 30 eye glasses 32 | 31 handbag 33 | 32 tie 34 | 33 suitcase 35 | 34 frisbee 36 | 35 skis 37 | 36 snowboard 38 | 37 sports ball 39 | 38 kite 40 | 39 baseball bat 41 | 40 baseball glove 42 | 41 skateboard 43 | 42 surfboard 44 | 43 tennis racket 45 | 44 bottle 46 | 45 plate 47 | 46 wine glass 48 | 47 cup 49 | 48 fork 50 | 49 knife 51 | 50 spoon 52 | 51 bowl 53 | 52 banana 54 | 53 apple 55 | 54 sandwich 56 | 55 orange 57 | 56 broccoli 58 | 57 carrot 59 | 58 hot dog 60 | 59 pizza 61 | 60 donut 62 | 61 cake 63 | 62 chair 64 | 63 couch 65 | 64 potted plant 66 | 65 bed 67 | 66 mirror 68 | 67 dining table 69 | 68 window 70 | 69 desk 71 | 70 toilet 72 | 71 door 73 | 72 tv 74 | 73 laptop 75 | 74 mouse 76 | 75 remote 77 | 76 keyboard 78 | 77 cell phone 79 | 78 microwave 80 | 79 oven 81 | 80 toaster 82 | 81 sink 83 | 82 refrigerator 84 | 83 blender 85 | 84 book 86 | 85 clock 87 | 86 vase 88 | 87 scissors 89 | 88 teddy bear 90 | 89 hair drier 91 | 90 toothbrush 92 | 91 hair brush 93 | 92 banner 94 | 93 blanket 95 | 94 branch 96 | 95 bridge 97 | 96 building-other 98 | 97 bush 99 | 98 cabinet 100 | 99 cage 101 | 100 cardboard 102 | 101 carpet 103 | 102 ceiling-other 104 | 103 ceiling-tile 105 | 104 cloth 106 | 105 clothes 107 | 106 clouds 108 | 107 counter 109 | 108 cupboard 110 | 109 curtain 111 | 110 desk-stuff 112 | 111 dirt 113 | 112 door-stuff 114 | 113 fence 115 | 114 floor-marble 116 | 115 floor-other 117 | 116 floor-stone 118 | 117 floor-tile 119 | 118 floor-wood 120 | 119 flower 121 | 120 fog 122 | 121 food-other 123 | 122 fruit 124 | 123 furniture-other 125 | 124 grass 126 | 125 gravel 127 | 126 ground-other 128 | 127 hill 129 | 128 house 130 | 129 leaves 131 | 130 light 132 | 131 mat 133 | 132 metal 134 | 133 mirror-stuff 135 | 134 moss 136 | 135 mountain 137 | 136 mud 138 | 137 napkin 139 | 138 net 140 | 139 paper 141 | 140 pavement 142 | 141 pillow 143 | 142 plant-other 144 | 143 plastic 145 | 144 platform 146 | 145 playingfield 147 | 146 railing 148 | 147 railroad 149 | 148 river 150 | 149 road 151 | 150 rock 152 | 151 roof 153 | 152 rug 154 | 153 salad 155 | 154 sand 156 | 155 sea 157 | 156 shelf 158 | 157 sky-other 159 | 158 skyscraper 160 | 159 snow 161 | 160 solid-other 162 | 161 stairs 163 | 162 stone 164 | 163 straw 165 | 164 structural-other 166 | 165 table 167 | 166 tent 168 | 167 textile-other 169 | 168 towel 170 | 169 tree 171 | 170 vegetable 172 | 171 wall-brick 173 | 172 wall-concrete 174 | 173 wall-other 175 | 174 wall-panel 176 | 175 wall-stone 177 | 176 wall-tile 178 | 177 wall-wood 179 | 178 water-other 180 | 179 waterdrops 181 | 180 window-blind 182 | 181 window-other 183 | 182 wood 184 | -------------------------------------------------------------------------------- /dataloaders/labels/deepscene.txt: -------------------------------------------------------------------------------- 1 | 0 Void 2 | 1 Road 3 | 2 Grass 4 | 
3 Vegetation 5 | 4 Tree 6 | 5 Sky 7 | 6 Obstacle -------------------------------------------------------------------------------- /dataloaders/labels/voc.txt: -------------------------------------------------------------------------------- 1 | 0 __background__ 2 | 1 aeroplane 3 | 2 bicycle 4 | 3 bird 5 | 4 boat 6 | 5 bottle 7 | 6 bus 8 | 7 car 9 | 8 cat 10 | 9 chair 11 | 10 cow 12 | 11 diningtable 13 | 12 dog 14 | 13 horse 15 | 14 motorbike 16 | 15 person 17 | 16 pottedplant 18 | 17 sheep 19 | 18 sofa 20 | 19 train 21 | 20 tvmonitor -------------------------------------------------------------------------------- /dataloaders/labels/voc_context_classes-400.txt: -------------------------------------------------------------------------------- 1 | 1: accordion 2 | 2: aeroplane 3 | 3: air conditioner 4 | 4: antenna 5 | 5: artillery 6 | 6: ashtray 7 | 7: atrium 8 | 8: baby carriage 9 | 9: bag 10 | 10: ball 11 | 11: balloon 12 | 12: bamboo weaving 13 | 13: barrel 14 | 14: baseball bat 15 | 15: basket 16 | 16: basketball backboard 17 | 17: bathtub 18 | 18: bed 19 | 19: bedclothes 20 | 20: beer 21 | 21: bell 22 | 22: bench 23 | 23: bicycle 24 | 24: binoculars 25 | 25: bird 26 | 26: bird cage 27 | 27: bird feeder 28 | 28: bird nest 29 | 29: blackboard 30 | 30: board 31 | 31: boat 32 | 32: bone 33 | 33: book 34 | 34: bottle 35 | 35: bottle opener 36 | 36: bowl 37 | 37: box 38 | 38: bracelet 39 | 39: brick 40 | 40: bridge 41 | 41: broom 42 | 42: brush 43 | 43: bucket 44 | 44: building 45 | 45: bus 46 | 46: cabinet 47 | 47: cabinet door 48 | 48: cage 49 | 49: cake 50 | 50: calculator 51 | 51: calendar 52 | 52: camel 53 | 53: camera 54 | 54: camera lens 55 | 55: can 56 | 56: candle 57 | 57: candle holder 58 | 58: cap 59 | 59: car 60 | 60: card 61 | 61: cart 62 | 62: case 63 | 63: casette recorder 64 | 64: cash register 65 | 65: cat 66 | 66: cd 67 | 67: cd player 68 | 68: ceiling 69 | 69: cell phone 70 | 70: cello 71 | 71: chain 72 | 72: chair 73 | 73: chessboard 74 | 74: chicken 75 | 75: chopstick 76 | 76: clip 77 | 77: clippers 78 | 78: clock 79 | 79: closet 80 | 80: cloth 81 | 81: clothes tree 82 | 82: coffee 83 | 83: coffee machine 84 | 84: comb 85 | 85: computer 86 | 86: concrete 87 | 87: cone 88 | 88: container 89 | 89: control booth 90 | 90: controller 91 | 91: cooker 92 | 92: copying machine 93 | 93: coral 94 | 94: cork 95 | 95: corkscrew 96 | 96: counter 97 | 97: court 98 | 98: cow 99 | 99: crabstick 100 | 100: crane 101 | 101: crate 102 | 102: cross 103 | 103: crutch 104 | 104: cup 105 | 105: curtain 106 | 106: cushion 107 | 107: cutting board 108 | 108: dais 109 | 109: disc 110 | 110: disc case 111 | 111: dishwasher 112 | 112: dock 113 | 113: dog 114 | 114: dolphin 115 | 115: door 116 | 116: drainer 117 | 117: dray 118 | 118: drink dispenser 119 | 119: drinking machine 120 | 120: drop 121 | 121: drug 122 | 122: drum 123 | 123: drum kit 124 | 124: duck 125 | 125: dumbbell 126 | 126: earphone 127 | 127: earrings 128 | 128: egg 129 | 129: electric fan 130 | 130: electric iron 131 | 131: electric pot 132 | 132: electric saw 133 | 133: electronic keyboard 134 | 134: engine 135 | 135: envelope 136 | 136: equipment 137 | 137: escalator 138 | 138: exhibition booth 139 | 139: extinguisher 140 | 140: eyeglass 141 | 141: fan 142 | 142: faucet 143 | 143: fax machine 144 | 144: fence 145 | 145: ferris wheel 146 | 146: fire extinguisher 147 | 147: fire hydrant 148 | 148: fire place 149 | 149: fish 150 | 150: fish tank 151 | 151: fishbowl 152 | 152: fishing net 153 | 153: fishing pole 154 | 154: flag 155 | 
155: flagstaff 156 | 156: flame 157 | 157: flashlight 158 | 158: floor 159 | 159: flower 160 | 160: fly 161 | 161: foam 162 | 162: food 163 | 163: footbridge 164 | 164: forceps 165 | 165: fork 166 | 166: forklift 167 | 167: fountain 168 | 168: fox 169 | 169: frame 170 | 170: fridge 171 | 171: frog 172 | 172: fruit 173 | 173: funnel 174 | 174: furnace 175 | 175: game controller 176 | 176: game machine 177 | 177: gas cylinder 178 | 178: gas hood 179 | 179: gas stove 180 | 180: gift box 181 | 181: glass 182 | 182: glass marble 183 | 183: globe 184 | 184: glove 185 | 185: goal 186 | 186: grandstand 187 | 187: grass 188 | 188: gravestone 189 | 189: ground 190 | 190: guardrail 191 | 191: guitar 192 | 192: gun 193 | 193: hammer 194 | 194: hand cart 195 | 195: handle 196 | 196: handrail 197 | 197: hanger 198 | 198: hard disk drive 199 | 199: hat 200 | 200: hay 201 | 201: headphone 202 | 202: heater 203 | 203: helicopter 204 | 204: helmet 205 | 205: holder 206 | 206: hook 207 | 207: horse 208 | 208: horse-drawn carriage 209 | 209: hot-air balloon 210 | 210: hydrovalve 211 | 211: ice 212 | 212: inflator pump 213 | 213: ipod 214 | 214: iron 215 | 215: ironing board 216 | 216: jar 217 | 217: kart 218 | 218: kettle 219 | 219: key 220 | 220: keyboard 221 | 221: kitchen range 222 | 222: kite 223 | 223: knife 224 | 224: knife block 225 | 225: ladder 226 | 226: ladder truck 227 | 227: ladle 228 | 228: laptop 229 | 229: leaves 230 | 230: lid 231 | 231: life buoy 232 | 232: light 233 | 233: light bulb 234 | 234: lighter 235 | 235: line 236 | 236: lion 237 | 237: lobster 238 | 238: lock 239 | 239: machine 240 | 240: mailbox 241 | 241: mannequin 242 | 242: map 243 | 243: mask 244 | 244: mat 245 | 245: match book 246 | 246: mattress 247 | 247: menu 248 | 248: metal 249 | 249: meter box 250 | 250: microphone 251 | 251: microwave 252 | 252: mirror 253 | 253: missile 254 | 254: model 255 | 255: money 256 | 256: monkey 257 | 257: mop 258 | 258: motorbike 259 | 259: mountain 260 | 260: mouse 261 | 261: mouse pad 262 | 262: musical instrument 263 | 263: napkin 264 | 264: net 265 | 265: newspaper 266 | 266: oar 267 | 267: ornament 268 | 268: outlet 269 | 269: oven 270 | 270: oxygen bottle 271 | 271: pack 272 | 272: pan 273 | 273: paper 274 | 274: paper box 275 | 275: paper cutter 276 | 276: parachute 277 | 277: parasol 278 | 278: parterre 279 | 279: patio 280 | 280: pelage 281 | 281: pen 282 | 282: pen container 283 | 283: pencil 284 | 284: person 285 | 285: photo 286 | 286: piano 287 | 287: picture 288 | 288: pig 289 | 289: pillar 290 | 290: pillow 291 | 291: pipe 292 | 292: pitcher 293 | 293: plant 294 | 294: plastic 295 | 295: plate 296 | 296: platform 297 | 297: player 298 | 298: playground 299 | 299: pliers 300 | 300: plume 301 | 301: poker 302 | 302: poker chip 303 | 303: pole 304 | 304: pool table 305 | 305: postcard 306 | 306: poster 307 | 307: pot 308 | 308: pottedplant 309 | 309: printer 310 | 310: projector 311 | 311: pumpkin 312 | 312: rabbit 313 | 313: racket 314 | 314: radiator 315 | 315: radio 316 | 316: rail 317 | 317: rake 318 | 318: ramp 319 | 319: range hood 320 | 320: receiver 321 | 321: recorder 322 | 322: recreational machines 323 | 323: remote control 324 | 324: road 325 | 325: robot 326 | 326: rock 327 | 327: rocket 328 | 328: rocking horse 329 | 329: rope 330 | 330: rug 331 | 331: ruler 332 | 332: runway 333 | 333: saddle 334 | 334: sand 335 | 335: saw 336 | 336: scale 337 | 337: scanner 338 | 338: scissors 339 | 339: scoop 340 | 340: screen 341 | 341: screwdriver 342 | 342: sculpture 343 | 
343: scythe 344 | 344: sewer 345 | 345: sewing machine 346 | 346: shed 347 | 347: sheep 348 | 348: shell 349 | 349: shelves 350 | 350: shoe 351 | 351: shopping cart 352 | 352: shovel 353 | 353: sidecar 354 | 354: sidewalk 355 | 355: sign 356 | 356: signal light 357 | 357: sink 358 | 358: skateboard 359 | 359: ski 360 | 360: sky 361 | 361: sled 362 | 362: slippers 363 | 363: smoke 364 | 364: snail 365 | 365: snake 366 | 366: snow 367 | 367: snowmobiles 368 | 368: sofa 369 | 369: spanner 370 | 370: spatula 371 | 371: speaker 372 | 372: speed bump 373 | 373: spice container 374 | 374: spoon 375 | 375: sprayer 376 | 376: squirrel 377 | 377: stage 378 | 378: stair 379 | 379: stapler 380 | 380: stick 381 | 381: sticky note 382 | 382: stone 383 | 383: stool 384 | 384: stove 385 | 385: straw 386 | 386: stretcher 387 | 387: sun 388 | 388: sunglass 389 | 389: sunshade 390 | 390: surveillance camera 391 | 391: swan 392 | 392: sweeper 393 | 393: swim ring 394 | 394: swimming pool 395 | 395: swing 396 | 396: switch 397 | 397: table 398 | 398: tableware 399 | 399: tank 400 | 400: tap 401 | 401: tape 402 | 402: tarp 403 | 403: telephone 404 | 404: telephone booth 405 | 405: tent 406 | 406: tire 407 | 407: toaster 408 | 408: toilet 409 | 409: tong 410 | 410: tool 411 | 411: toothbrush 412 | 412: towel 413 | 413: toy 414 | 414: toy car 415 | 415: track 416 | 416: train 417 | 417: trampoline 418 | 418: trash bin 419 | 419: tray 420 | 420: tree 421 | 421: tricycle 422 | 422: tripod 423 | 423: trophy 424 | 424: truck 425 | 425: tube 426 | 426: turtle 427 | 427: tvmonitor 428 | 428: tweezers 429 | 429: typewriter 430 | 430: umbrella 431 | 431: unknown 432 | 432: vacuum cleaner 433 | 433: vending machine 434 | 434: video camera 435 | 435: video game console 436 | 436: video player 437 | 437: video tape 438 | 438: violin 439 | 439: wakeboard 440 | 440: wall 441 | 441: wallet 442 | 442: wardrobe 443 | 443: washing machine 444 | 444: watch 445 | 445: water 446 | 446: water dispenser 447 | 447: water pipe 448 | 448: water skate board 449 | 449: watermelon 450 | 450: whale 451 | 451: wharf 452 | 452: wheel 453 | 453: wheelchair 454 | 454: window 455 | 455: window blinds 456 | 456: wineglass 457 | 457: wire 458 | 458: wood 459 | 459: wool -------------------------------------------------------------------------------- /dataloaders/labels/voc_context_classes-59.txt: -------------------------------------------------------------------------------- 1 | 0: background 2 | 1: aeroplane 3 | 2: bicycle 4 | 3: bird 5 | 4: boat 6 | 5: bottle 7 | 6: bus 8 | 7: car 9 | 8: cat 10 | 9: chair 11 | 10: cow 12 | 11: diningtable 13 | 12: dog 14 | 13: horse 15 | 14: motorbike 16 | 15: person 17 | 16: pottedplant 18 | 17: sheep 19 | 18: sofa 20 | 19: train 21 | 20: tvmonitor 22 | 21: bag 23 | 22: bed 24 | 23: bench 25 | 24: book 26 | 25: building 27 | 26: cabinet 28 | 27: ceiling 29 | 28: clothes 30 | 29: computer 31 | 30: cup 32 | 31: door 33 | 32: fence 34 | 33: floor 35 | 34: flower 36 | 35: food 37 | 36: grass 38 | 37: ground 39 | 38: keyboard 40 | 39: light 41 | 40: mountain 42 | 41: mouse 43 | 42: curtain 44 | 43: platform 45 | 44: sign 46 | 45: plate 47 | 46: road 48 | 47: rock 49 | 48: shelves 50 | 49: sidewalk 51 | 50: sky 52 | 51: snow 53 | 52: bedcloth 54 | 53: track 55 | 54: tree 56 | 55: truck 57 | 56: wall 58 | 57: water 59 | 58: window 60 | 59: wood -------------------------------------------------------------------------------- /dataloaders/voc.py: 
-------------------------------------------------------------------------------- 1 | # Originally written by Kazuto Nakashima 2 | # https://github.com/kazuto1011/deeplab-pytorch 3 | 4 | from base import BaseDataSet, BaseDataLoader 5 | from utils import palette 6 | import numpy as np 7 | import os 8 | import scipy 9 | import torch 10 | from PIL import Image 11 | import cv2 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | 15 | class VOCDataset(BaseDataSet): 16 | """ 17 | Pascal Voc dataset 18 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar 19 | """ 20 | def __init__(self, **kwargs): 21 | self.num_classes = 21 22 | self.palette = palette.get_voc_palette(self.num_classes) 23 | super(VOCDataset, self).__init__(**kwargs) 24 | 25 | def _set_files(self): 26 | self.root = os.path.join(self.root, 'VOCdevkit/VOC2012') 27 | self.image_dir = os.path.join(self.root, 'JPEGImages') 28 | self.label_dir = os.path.join(self.root, 'SegmentationClass') 29 | 30 | file_list = os.path.join(self.root, "ImageSets/Segmentation", self.split + ".txt") 31 | self.files = [line.rstrip() for line in tuple(open(file_list, "r"))] 32 | 33 | def _load_data(self, index): 34 | image_id = self.files[index] 35 | image_path = os.path.join(self.image_dir, image_id + '.jpg') 36 | label_path = os.path.join(self.label_dir, image_id + '.png') 37 | image = np.asarray(Image.open(image_path), dtype=np.float32) 38 | label = np.asarray(Image.open(label_path), dtype=np.int32) 39 | image_id = self.files[index].split("/")[-1].split(".")[0] 40 | return image, label, image_id 41 | 42 | class VOCAugDataset(BaseDataSet): 43 | """ 44 | Contrains both SBD and VOC 2012 dataset 45 | Annotations : https://github.com/DrSleep/tensorflow-deeplab-resnet#evaluation 46 | Image Sets: https://ucla.app.box.com/s/rd9z2xvwsfpksi7mi08i2xqrj7ab4keb/file/55053033642 47 | """ 48 | def __init__(self, **kwargs): 49 | self.num_classes = 21 50 | self.palette = palette.get_voc_palette(self.num_classes) 51 | super(VOCAugDataset, self).__init__(**kwargs) 52 | 53 | def _set_files(self): 54 | self.root = os.path.join(self.root, 'VOCdevkit/VOC2012') 55 | 56 | file_list = os.path.join(self.root, "ImageSets/Segmentation", self.split + ".txt") 57 | file_list = [line.rstrip().split(' ') for line in tuple(open(file_list, "r"))] 58 | self.files, self.labels = list(zip(*file_list)) 59 | 60 | def _load_data(self, index): 61 | image_path = os.path.join(self.root, self.files[index][1:]) 62 | label_path = os.path.join(self.root, self.labels[index][1:]) 63 | image = np.asarray(Image.open(image_path), dtype=np.float32) 64 | label = np.asarray(Image.open(label_path), dtype=np.int32) 65 | image_id = self.files[index].split("/")[-1].split(".")[0] 66 | return image, label, image_id 67 | 68 | 69 | class VOC(BaseDataLoader): 70 | def __init__(self, data_dir, batch_size, split, crop_size=None, base_size=None, scale=True, num_workers=1, val=False, 71 | shuffle=False, flip=False, rotate=False, blur= False, augment=False, val_split= None, return_id=False): 72 | 73 | self.MEAN = [0.45734706, 0.43338275, 0.40058118] 74 | self.STD = [0.23965294, 0.23532275, 0.2398498] 75 | 76 | kwargs = { 77 | 'root': data_dir, 78 | 'split': split, 79 | 'mean': self.MEAN, 80 | 'std': self.STD, 81 | 'augment': augment, 82 | 'crop_size': crop_size, 83 | 'base_size': base_size, 84 | 'scale': scale, 85 | 'flip': flip, 86 | 'blur': blur, 87 | 'rotate': rotate, 88 | 'return_id': return_id, 89 | 'val': val 90 | } 91 | 92 | if split in ["train_aug", 
"trainval_aug", "val_aug", "test_aug"]: 93 | self.dataset = VOCAugDataset(**kwargs) 94 | elif split in ["train", "trainval", "val", "test"]: 95 | self.dataset = VOCDataset(**kwargs) 96 | else: raise ValueError(f"Invalid split name {split}") 97 | super(VOC, self).__init__(self.dataset, batch_size, shuffle, num_workers, val_split) 98 | 99 | -------------------------------------------------------------------------------- /images/colour_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yassouali/pytorch-segmentation/33434fa987bbf348c961473c15f044b85e32ca7f/images/colour_scheme.png -------------------------------------------------------------------------------- /images/learning_rates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yassouali/pytorch-segmentation/33434fa987bbf348c961473c15f044b85e32ca7f/images/learning_rates.png -------------------------------------------------------------------------------- /images/tb1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yassouali/pytorch-segmentation/33434fa987bbf348c961473c15f044b85e32ca7f/images/tb1.png -------------------------------------------------------------------------------- /images/tb2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yassouali/pytorch-segmentation/33434fa987bbf348c961473c15f044b85e32ca7f/images/tb2.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | import os 4 | import numpy as np 5 | import json 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torchvision import transforms 10 | from scipy import ndimage 11 | from tqdm import tqdm 12 | from math import ceil 13 | from glob import glob 14 | from PIL import Image 15 | import dataloaders 16 | import models 17 | from utils.helpers import colorize_mask 18 | from collections import OrderedDict 19 | 20 | def pad_image(img, target_size): 21 | rows_to_pad = max(target_size[0] - img.shape[2], 0) 22 | cols_to_pad = max(target_size[1] - img.shape[3], 0) 23 | padded_img = F.pad(img, (0, cols_to_pad, 0, rows_to_pad), "constant", 0) 24 | return padded_img 25 | 26 | def sliding_predict(model, image, num_classes, flip=True): 27 | image_size = image.shape 28 | tile_size = (int(image_size[2]//2.5), int(image_size[3]//2.5)) 29 | overlap = 1/3 30 | 31 | stride = ceil(tile_size[0] * (1 - overlap)) 32 | 33 | num_rows = int(ceil((image_size[2] - tile_size[0]) / stride) + 1) 34 | num_cols = int(ceil((image_size[3] - tile_size[1]) / stride) + 1) 35 | total_predictions = np.zeros((num_classes, image_size[2], image_size[3])) 36 | count_predictions = np.zeros((image_size[2], image_size[3])) 37 | tile_counter = 0 38 | 39 | for row in range(num_rows): 40 | for col in range(num_cols): 41 | x_min, y_min = int(col * stride), int(row * stride) 42 | x_max = min(x_min + tile_size[1], image_size[3]) 43 | y_max = min(y_min + tile_size[0], image_size[2]) 44 | 45 | img = image[:, :, y_min:y_max, x_min:x_max] 46 | padded_img = pad_image(img, tile_size) 47 | tile_counter += 1 48 | padded_prediction = model(padded_img) 49 | if flip: 50 | fliped_img = padded_img.flip(-1) 51 | fliped_predictions = model(padded_img.flip(-1)) 
52 | padded_prediction = 0.5 * (fliped_predictions.flip(-1) + padded_prediction) 53 | predictions = padded_prediction[:, :, :img.shape[2], :img.shape[3]] 54 | count_predictions[y_min:y_max, x_min:x_max] += 1 55 | total_predictions[:, y_min:y_max, x_min:x_max] += predictions.data.cpu().numpy().squeeze(0) 56 | 57 | total_predictions /= count_predictions 58 | return total_predictions 59 | 60 | 61 | def multi_scale_predict(model, image, scales, num_classes, device, flip=False): 62 | input_size = (image.size(2), image.size(3)) 63 | upsample = nn.Upsample(size=input_size, mode='bilinear', align_corners=True) 64 | total_predictions = np.zeros((num_classes, image.size(2), image.size(3))) 65 | 66 | image = image.data.data.cpu().numpy() 67 | for scale in scales: 68 | scaled_img = ndimage.zoom(image, (1.0, 1.0, float(scale), float(scale)), order=1, prefilter=False) 69 | scaled_img = torch.from_numpy(scaled_img).to(device) 70 | scaled_prediction = upsample(model(scaled_img).cpu()) 71 | 72 | if flip: 73 | fliped_img = scaled_img.flip(-1).to(device) 74 | fliped_predictions = upsample(model(fliped_img).cpu()) 75 | scaled_prediction = 0.5 * (fliped_predictions.flip(-1) + scaled_prediction) 76 | total_predictions += scaled_prediction.data.cpu().numpy().squeeze(0) 77 | 78 | total_predictions /= len(scales) 79 | return total_predictions 80 | 81 | 82 | def save_images(image, mask, output_path, image_file, palette): 83 | # Saves the image, the model output and the results after the post processing 84 | w, h = image.size 85 | image_file = os.path.basename(image_file).split('.')[0] 86 | colorized_mask = colorize_mask(mask, palette) 87 | colorized_mask.save(os.path.join(output_path, image_file+'.png')) 88 | # output_im = Image.new('RGB', (w*2, h)) 89 | # output_im.paste(image, (0,0)) 90 | # output_im.paste(colorized_mask, (w,0)) 91 | # output_im.save(os.path.join(output_path, image_file+'_colorized.png')) 92 | # mask_img = Image.fromarray(mask, 'L') 93 | # mask_img.save(os.path.join(output_path, image_file+'.png')) 94 | 95 | def main(): 96 | args = parse_arguments() 97 | config = json.load(open(args.config)) 98 | 99 | # Dataset used for training the model 100 | dataset_type = config['train_loader']['type'] 101 | assert dataset_type in ['VOC', 'COCO', 'CityScapes', 'ADE20K', 'DeepScene'] 102 | if dataset_type == 'CityScapes': 103 | scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25] 104 | else: 105 | scales = [0.75, 1.0, 1.25, 1.5, 1.75, 2.0] 106 | loader = getattr(dataloaders, config['train_loader']['type'])(**config['train_loader']['args']) 107 | to_tensor = transforms.ToTensor() 108 | normalize = transforms.Normalize(loader.MEAN, loader.STD) 109 | num_classes = loader.dataset.num_classes 110 | palette = loader.dataset.palette 111 | 112 | # Model 113 | model = getattr(models, config['arch']['type'])(num_classes, **config['arch']['args']) 114 | availble_gpus = list(range(torch.cuda.device_count())) 115 | device = torch.device('cuda:0' if len(availble_gpus) > 0 else 'cpu') 116 | 117 | # Load checkpoint 118 | checkpoint = torch.load(args.model, map_location=device) 119 | if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys(): 120 | checkpoint = checkpoint['state_dict'] 121 | # If during training, we used data parallel 122 | if 'module' in list(checkpoint.keys())[0] and not isinstance(model, torch.nn.DataParallel): 123 | # for gpu inference, use data parallel 124 | if "cuda" in device.type: 125 | model = torch.nn.DataParallel(model) 126 | else: 127 | # for cpu inference, remove module 128 | 
new_state_dict = OrderedDict() 129 | for k, v in checkpoint.items(): 130 | name = k[7:] 131 | new_state_dict[name] = v 132 | checkpoint = new_state_dict 133 | # load 134 | model.load_state_dict(checkpoint) 135 | model.to(device) 136 | model.eval() 137 | 138 | if not os.path.exists('outputs'): 139 | os.makedirs('outputs') 140 | 141 | image_files = sorted(glob(os.path.join(args.images, f'*.{args.extension}'))) 142 | with torch.no_grad(): 143 | tbar = tqdm(image_files, ncols=100) 144 | for img_file in tbar: 145 | image = Image.open(img_file).convert('RGB') 146 | input = normalize(to_tensor(image)).unsqueeze(0) 147 | 148 | if args.mode == 'multiscale': 149 | prediction = multi_scale_predict(model, input, scales, num_classes, device) 150 | elif args.mode == 'sliding': 151 | prediction = sliding_predict(model, input, num_classes) 152 | else: 153 | prediction = model(input.to(device)) 154 | prediction = prediction.squeeze(0).cpu().numpy() 155 | prediction = F.softmax(torch.from_numpy(prediction), dim=0).argmax(0).cpu().numpy() 156 | save_images(image, prediction, args.output, img_file, palette) 157 | 158 | def parse_arguments(): 159 | parser = argparse.ArgumentParser(description='Inference') 160 | parser.add_argument('-c', '--config', default='VOC',type=str, 161 | help='The config used to train the model') 162 | parser.add_argument('-mo', '--mode', default='multiscale', type=str, 163 | help='Mode used for prediction: either [multiscale, sliding]') 164 | parser.add_argument('-m', '--model', default='model_weights.pth', type=str, 165 | help='Path to the .pth model checkpoint to be used in the prediction') 166 | parser.add_argument('-i', '--images', default=None, type=str, 167 | help='Path to the images to be segmented') 168 | parser.add_argument('-o', '--output', default='outputs', type=str, 169 | help='Output Path') 170 | parser.add_argument('-e', '--extension', default='jpg', type=str, 171 | help='The extension of the images to be segmented') 172 | args = parser.parse_args() 173 | return args 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fcn import FCN8 2 | from .unet import UNet, UNetResnet 3 | from .segnet import SegNet 4 | from .segnet import SegResNet 5 | from .enet import ENet 6 | from .gcn import GCN 7 | from .deeplabv3_plus import DeepLab 8 | from .duc_hdc import DeepLab_DUC_HDC 9 | from .upernet import UperNet 10 | from .pspnet import PSPNet 11 | from .pspnet import PSPDenseNet -------------------------------------------------------------------------------- /models/duc_hdc.py: -------------------------------------------------------------------------------- 1 | from base import BaseModel 2 | import torch 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | import torch.utils.model_zoo as model_zoo 8 | from utils.helpers import initialize_weights 9 | from itertools import chain 10 | 11 | ''' 12 | -> Dense upsampling convolution block 13 | ''' 14 | 15 | class DUC(nn.Module): 16 | def __init__(self, in_channels, out_channles, upscale): 17 | super(DUC, self).__init__() 18 | out_channles = out_channles * (upscale ** 2) 19 | self.conv = nn.Conv2d(in_channels, out_channles, 1, bias=False) 20 | self.bn = nn.BatchNorm2d(out_channles) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.pixl_shf = nn.PixelShuffle(upscale_factor=upscale) 23 
| 24 | initialize_weights(self) 25 | kernel = self.icnr(self.conv.weight, scale=upscale) 26 | self.conv.weight.data.copy_(kernel) 27 | 28 | def forward(self, x): 29 | x = self.relu(self.bn(self.conv(x))) 30 | x = self.pixl_shf(x) 31 | return x 32 | 33 | def icnr(self, x, scale=2, init=nn.init.kaiming_normal): 34 | ''' 35 | Even with pixel shuffle we still have check board artifacts, 36 | the solution is to initialize the d**2 feature maps with the same 37 | radom weights: https://arxiv.org/pdf/1707.02937.pdf 38 | ''' 39 | new_shape = [int(x.shape[0] / (scale ** 2))] + list(x.shape[1:]) 40 | subkernel = torch.zeros(new_shape) 41 | subkernel = init(subkernel) 42 | subkernel = subkernel.transpose(0, 1) 43 | subkernel = subkernel.contiguous().view(subkernel.shape[0], 44 | subkernel.shape[1], -1) 45 | kernel = subkernel.repeat(1, 1, scale ** 2) 46 | transposed_shape = [x.shape[1]] + [x.shape[0]] + list(x.shape[2:]) 47 | kernel = kernel.contiguous().view(transposed_shape) 48 | kernel = kernel.transpose(0, 1) 49 | return kernel 50 | 51 | ''' 52 | -> ResNet BackBone 53 | ''' 54 | 55 | class ResNet_HDC_DUC(nn.Module): 56 | def __init__(self, in_channels, output_stride, pretrained=True, dilation_bigger=False): 57 | super(ResNet_HDC_DUC, self).__init__() 58 | 59 | model = models.resnet101(pretrained=pretrained) 60 | if not pretrained or in_channels != 3: 61 | self.layer0 = nn.Sequential( 62 | nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False), 63 | nn.BatchNorm2d(64), 64 | nn.ReLU(inplace=True), 65 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | ) 67 | initialize_weights(self.layer0) 68 | else: 69 | self.layer0 = nn.Sequential(*list(model.children())[:4]) 70 | 71 | self.layer1 = model.layer1 72 | self.layer2 = model.layer2 73 | self.layer3 = model.layer3 74 | self.layer4 = model.layer4 75 | 76 | if output_stride == 4: list(self.layer0.children())[0].stride = (1, 1) 77 | 78 | d_res4b = [] 79 | if dilation_bigger: 80 | d_res4b.extend([1, 2, 5, 9]*5 + [1, 2, 5]) 81 | d_res5b = [5, 9, 17] 82 | else: 83 | # Dialtion-RF 84 | d_res4b.extend([1, 2, 3]*7 + [2, 2]) 85 | d_res5b = [3, 4, 5] 86 | 87 | l_index = 0 88 | for n, m in self.layer3.named_modules(): 89 | if 'conv2' in n: 90 | d = d_res4b[l_index] 91 | m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1) 92 | l_index += 1 93 | elif 'downsample.0' in n: 94 | m.stride = (1, 1) 95 | 96 | l_index = 0 97 | for n, m in self.layer4.named_modules(): 98 | if 'conv2' in n: 99 | d = d_res5b[l_index] 100 | m.dilation, m.padding, m.stride = (d, d), (d, d), (1, 1) 101 | l_index += 1 102 | elif 'downsample.0' in n: 103 | m.stride = (1, 1) 104 | 105 | def forward(self, x): 106 | x = self.layer0(x) 107 | x = self.layer1(x) 108 | low_level_features = x 109 | x = self.layer2(x) 110 | x = self.layer3(x) 111 | x = self.layer4(x) 112 | 113 | return x, low_level_features 114 | 115 | ''' 116 | -> The Atrous Spatial Pyramid Pooling 117 | ''' 118 | 119 | def assp_branch(in_channels, out_channles, kernel_size, dilation): 120 | padding = 0 if kernel_size == 1 else dilation 121 | return nn.Sequential( 122 | nn.Conv2d(in_channels, out_channles, kernel_size, padding=padding, dilation=dilation, bias=False), 123 | nn.BatchNorm2d(out_channles), 124 | nn.ReLU(inplace=True)) 125 | 126 | class ASSP(nn.Module): 127 | def __init__(self, in_channels, output_stride, assp_channels=6): 128 | super(ASSP, self).__init__() 129 | 130 | assert output_stride in [4, 8], 'Only output strides of 8 or 16 are suported' 131 | assert assp_channels in [4, 6], 'Number of suported ASSP 
branches are 4 or 6' 132 | dilations = [1, 6, 12, 18, 24, 36] 133 | dilations = dilations[:assp_channels] 134 | self.assp_channels = assp_channels 135 | 136 | self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0]) 137 | self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1]) 138 | self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2]) 139 | self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3]) 140 | if self.assp_channels == 6: 141 | self.aspp5 = assp_branch(in_channels, 256, 3, dilation=dilations[4]) 142 | self.aspp6 = assp_branch(in_channels, 256, 3, dilation=dilations[5]) 143 | 144 | self.avg_pool = nn.Sequential( 145 | nn.AdaptiveAvgPool2d((1, 1)), 146 | nn.Conv2d(in_channels, 256, 1, bias=False), 147 | nn.BatchNorm2d(256), 148 | nn.ReLU(inplace=True)) 149 | 150 | self.conv1 = nn.Conv2d(256*(self.assp_channels + 1), 256, 1, bias=False) 151 | self.bn1 = nn.BatchNorm2d(256) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.dropout = nn.Dropout(0.5) 154 | 155 | initialize_weights(self) 156 | 157 | def forward(self, x): 158 | x1 = self.aspp1(x) 159 | x2 = self.aspp2(x) 160 | x3 = self.aspp3(x) 161 | x4 = self.aspp4(x) 162 | if self.assp_channels == 6: 163 | x5 = self.aspp5(x) 164 | x6 = self.aspp6(x) 165 | x_avg_pool = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True) 166 | 167 | if self.assp_channels == 6: 168 | x = self.conv1(torch.cat((x1, x2, x3, x4, x5, x6, x_avg_pool), dim=1)) 169 | else: 170 | x = self.conv1(torch.cat((x1, x2, x3, x4, x_avg_pool), dim=1)) 171 | x = self.bn1(x) 172 | x = self.dropout(self.relu(x)) 173 | 174 | return x 175 | 176 | ''' 177 | -> Decoder 178 | ''' 179 | 180 | class Decoder(nn.Module): 181 | def __init__(self, low_level_channels, num_classes): 182 | super(Decoder, self).__init__() 183 | self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False) 184 | self.bn1 = nn.BatchNorm2d(48) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.DUC = DUC(256, 256, upscale=2) 187 | 188 | self.output = nn.Sequential( 189 | nn.Conv2d(48+256, 256, 3, stride=1, padding=1, bias=False), 190 | nn.BatchNorm2d(256), 191 | nn.ReLU(inplace=True), 192 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False), 193 | nn.BatchNorm2d(256), 194 | nn.ReLU(inplace=True), 195 | nn.Dropout(0.1), 196 | nn.Conv2d(256, num_classes, 1, stride=1), 197 | ) 198 | initialize_weights(self) 199 | 200 | def forward(self, x, low_level_features): 201 | low_level_features = self.conv1(low_level_features) 202 | low_level_features = self.relu(self.bn1(low_level_features)) 203 | x = self.DUC(x) 204 | if x.size() != low_level_features.size(): 205 | # One pixel added with a conv with stride 2 when the input size in odd 206 | x = x[:, :, :low_level_features.size(2), :low_level_features.size(3)] 207 | x = self.output(torch.cat((low_level_features, x), dim=1)) 208 | return x 209 | 210 | ''' 211 | -> Deeplab V3 + with DUC & HDC 212 | ''' 213 | 214 | class DeepLab_DUC_HDC(BaseModel): 215 | def __init__(self, num_classes, in_channels=3, pretrained=True, output_stride=8, freeze_bn=False, **_): 216 | super(DeepLab_DUC_HDC, self).__init__() 217 | 218 | self.backbone = ResNet_HDC_DUC(in_channels=in_channels, output_stride=output_stride, pretrained=pretrained) 219 | low_level_channels = 256 220 | 221 | self.ASSP = ASSP(in_channels=2048, output_stride=output_stride) 222 | self.decoder = Decoder(low_level_channels, num_classes) 223 | self.DUC_out = DUC(num_classes, num_classes, 4) 224 | if freeze_bn: self.freeze_bn() 
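# With the default output_stride of 8: the backbone emits 2048-channel features at 1/8 of the input resolution, the decoder's DUC (upscale=2) lifts them to 1/4 so they can be fused with the 256-channel layer1 features, and DUC_out (upscale=4) restores full resolution (e.g. a 3x512x512 input yields num_classes x 512 x 512 logits).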
225 | if freeze_backbone: 226 | set_trainable([self.backbone], False) 227 | 228 | def forward(self, x): 229 | H, W = x.size(2), x.size(3) 230 | x, low_level_features = self.backbone(x) 231 | x = self.ASSP(x) 232 | x = self.decoder(x, low_level_features) 233 | x = self.DUC_out(x) 234 | return x 235 | 236 | def get_backbone_params(self): 237 | return self.backbone.parameters() 238 | 239 | def get_decoder_params(self): 240 | return chain(self.ASSP.parameters(), self.decoder.parameters(), self.DUC_out.parameters()) 241 | 242 | def freeze_bn(self): 243 | for module in self.modules(): 244 | if isinstance(module, nn.BatchNorm2d): module.eval() 245 | 246 | -------------------------------------------------------------------------------- /models/enet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from base import BaseModel 5 | from utils.helpers import initialize_weights 6 | from itertools import chain 7 | 8 | class InitalBlock(nn.Module): 9 | def __init__(self, in_channels, use_prelu=True): 10 | super(InitalBlock, self).__init__() 11 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) 12 | self.conv = nn.Conv2d(in_channels, 16 - in_channels, 3, padding=1, stride=2) 13 | self.bn = nn.BatchNorm2d(16) 14 | self.prelu = nn.PReLU(16) if use_prelu else nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = torch.cat((self.pool(x), self.conv(x)), dim=1) 18 | x = self.bn(x) 19 | x = self.prelu(x) 20 | return x 21 | 22 | class BottleNeck(nn.Module): 23 | def __init__(self, in_channels, out_channels=None, activation=None, dilation=1, downsample=False, proj_ratio=4, 24 | upsample=False, asymetric=False, regularize=True, p_drop=None, use_prelu=True): 25 | super(BottleNeck, self).__init__() 26 | 27 | self.pad = 0 28 | self.upsample = upsample 29 | self.downsample = downsample 30 | if out_channels is None: out_channels = in_channels 31 | else: self.pad = out_channels - in_channels 32 | 33 | if regularize: assert p_drop is not None 34 | if downsample: assert not upsample 35 | elif upsample: assert not downsample 36 | inter_channels = in_channels//proj_ratio 37 | 38 | # Main 39 | if upsample: 40 | self.spatil_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False) 41 | self.bn_up = nn.BatchNorm2d(out_channels) 42 | self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) 43 | elif downsample: 44 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 45 | 46 | # Bottleneck 47 | if downsample: 48 | self.conv1 = nn.Conv2d(in_channels, inter_channels, 2, stride=2, bias=False) 49 | else: 50 | self.conv1 = nn.Conv2d(in_channels, inter_channels, 1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(inter_channels) 52 | self.prelu1 = nn.PReLU() if use_prelu else nn.ReLU(inplace=True) 53 | 54 | if asymetric: 55 | self.conv2 = nn.Sequential( 56 | nn.Conv2d(inter_channels, inter_channels, kernel_size=(1,5), padding=(0,2)), 57 | nn.BatchNorm2d(inter_channels), 58 | nn.PReLU() if use_prelu else nn.ReLU(inplace=True), 59 | nn.Conv2d(inter_channels, inter_channels, kernel_size=(5,1), padding=(2,0)), 60 | ) 61 | elif upsample: 62 | self.conv2 = nn.ConvTranspose2d(inter_channels, inter_channels, kernel_size=3, padding=1, 63 | output_padding=1, stride=2, bias=False) 64 | else: 65 | self.conv2 = nn.Conv2d(inter_channels, inter_channels, 3, padding=dilation, dilation=dilation, bias=False) 66 | self.bn2 = nn.BatchNorm2d(inter_channels) 67 | self.prelu2 = nn.PReLU() if use_prelu else 
nn.ReLU(inplace=True) 68 | 69 | self.conv3 = nn.Conv2d(inter_channels, out_channels, 1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(out_channels) 71 | self.prelu3 = nn.PReLU() if use_prelu else nn.ReLU(inplace=True) 72 | 73 | self.regularizer = nn.Dropout2d(p_drop) if regularize else None 74 | self.prelu_out = nn.PReLU() if use_prelu else nn.ReLU(inplace=True) 75 | 76 | def forward(self, x, indices=None, output_size=None): 77 | # Main branch 78 | identity = x 79 | if self.upsample: 80 | assert (indices is not None) and (output_size is not None) 81 | identity = self.bn_up(self.spatil_conv(identity)) 82 | if identity.size() != indices.size(): 83 | pad = (indices.size(3) - identity.size(3), 0, indices.size(2) - identity.size(2), 0) 84 | identity = F.pad(identity, pad, "constant", 0) 85 | identity = self.unpool(identity, indices=indices)#, output_size=output_size) 86 | elif self.downsample: 87 | identity, idx = self.pool(identity) 88 | 89 | ''' 90 | if self.pad > 0: 91 | if self.pad % 2 == 0 : pad = (0, 0, 0, 0, self.pad//2, self.pad//2) 92 | else: pad = (0, 0, 0, 0, self.pad//2, self.pad//2+1) 93 | identity = F.pad(identity, pad, "constant", 0) 94 | ''' 95 | 96 | if self.pad > 0: 97 | extras = torch.zeros((identity.size(0), self.pad, identity.size(2), identity.size(3))) 98 | if torch.cuda.is_available(): extras = extras.cuda(0) 99 | identity = torch.cat((identity, extras), dim = 1) 100 | 101 | # Bottleneck 102 | x = self.conv1(x) 103 | x = self.bn1(x) 104 | x = self.prelu1(x) 105 | x = self.conv2(x) 106 | x = self.bn2(x) 107 | x = self.prelu2(x) 108 | x = self.conv3(x) 109 | x = self.bn3(x) 110 | x = self.prelu3(x) 111 | if self.regularizer is not None: 112 | x = self.regularizer(x) 113 | 114 | # When the input dim is odd, we might have a mismatch of one pixel 115 | if identity.size() != x.size(): 116 | pad = (identity.size(3) - x.size(3), 0, identity.size(2) - x.size(2), 0) 117 | x = F.pad(x, pad, "constant", 0) 118 | 119 | x += identity 120 | x = self.prelu_out(x) 121 | 122 | if self.downsample: 123 | return x, idx 124 | return x 125 | 126 | class ENet(BaseModel): 127 | def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_): 128 | super(ENet, self).__init__() 129 | self.initial = InitalBlock(in_channels) 130 | 131 | # Stage 1 132 | self.bottleneck10 = BottleNeck(16, 64, downsample=True, p_drop=0.01) 133 | self.bottleneck11 = BottleNeck(64, p_drop=0.01) 134 | self.bottleneck12 = BottleNeck(64, p_drop=0.01) 135 | self.bottleneck13 = BottleNeck(64, p_drop=0.01) 136 | self.bottleneck14 = BottleNeck(64, p_drop=0.01) 137 | 138 | # Stage 2 139 | self.bottleneck20 = BottleNeck(64, 128, downsample=True, p_drop=0.1) 140 | self.bottleneck21 = BottleNeck(128, p_drop=0.1) 141 | self.bottleneck22 = BottleNeck(128, dilation=2, p_drop=0.1) 142 | self.bottleneck23 = BottleNeck(128, asymetric=True, p_drop=0.1) 143 | self.bottleneck24 = BottleNeck(128, dilation=4, p_drop=0.1) 144 | self.bottleneck25 = BottleNeck(128, p_drop=0.1) 145 | self.bottleneck26 = BottleNeck(128, dilation=8, p_drop=0.1) 146 | self.bottleneck27 = BottleNeck(128, asymetric=True, p_drop=0.1) 147 | self.bottleneck28 = BottleNeck(128, dilation=16, p_drop=0.1) 148 | 149 | # Stage 3 150 | self.bottleneck31 = BottleNeck(128, p_drop=0.1) 151 | self.bottleneck32 = BottleNeck(128, dilation=2, p_drop=0.1) 152 | self.bottleneck33 = BottleNeck(128, asymetric=True, p_drop=0.1) 153 | self.bottleneck34 = BottleNeck(128, dilation=4, p_drop=0.1) 154 | self.bottleneck35 = BottleNeck(128, p_drop=0.1) 155 | self.bottleneck36 = BottleNeck(128, 
dilation=8, p_drop=0.1) 156 | self.bottleneck37 = BottleNeck(128, asymetric=True, p_drop=0.1) 157 | self.bottleneck38 = BottleNeck(128, dilation=16, p_drop=0.1) 158 | 159 | # Stage 4 160 | self.bottleneck40 = BottleNeck(128, 64, upsample=True, p_drop=0.1, use_prelu=False) 161 | self.bottleneck41 = BottleNeck(64, p_drop=0.1, use_prelu=False) 162 | self.bottleneck42 = BottleNeck(64, p_drop=0.1, use_prelu=False) 163 | 164 | # Stage 5 165 | self.bottleneck50 = BottleNeck(64, 16, upsample=True, p_drop=0.1, use_prelu=False) 166 | self.bottleneck51 = BottleNeck(16, p_drop=0.1, use_prelu=False) 167 | 168 | # Stage 6 169 | self.fullconv = nn.ConvTranspose2d(16, num_classes, kernel_size=3, padding=1, 170 | output_padding=1, stride=2, bias=False) 171 | initialize_weights(self) 172 | if freeze_bn: self.freeze_bn() 173 | 174 | def forward(self, x): 175 | x = self.initial(x) 176 | 177 | # Stage 1 178 | sz1 = x.size() 179 | x, indices1 = self.bottleneck10(x) 180 | x = self.bottleneck11(x) 181 | x = self.bottleneck12(x) 182 | x = self.bottleneck13(x) 183 | x = self.bottleneck14(x) 184 | 185 | # Stage 2 186 | sz2 = x.size() 187 | x, indices2 = self.bottleneck20(x) 188 | x = self.bottleneck21(x) 189 | x = self.bottleneck22(x) 190 | x = self.bottleneck23(x) 191 | x = self.bottleneck24(x) 192 | x = self.bottleneck25(x) 193 | x = self.bottleneck26(x) 194 | x = self.bottleneck27(x) 195 | x = self.bottleneck28(x) 196 | 197 | # Stage 3 198 | x = self.bottleneck31(x) 199 | x = self.bottleneck32(x) 200 | x = self.bottleneck33(x) 201 | x = self.bottleneck34(x) 202 | x = self.bottleneck35(x) 203 | x = self.bottleneck36(x) 204 | x = self.bottleneck37(x) 205 | x = self.bottleneck38(x) 206 | 207 | # Stage 4 208 | x = self.bottleneck40(x, indices=indices2, output_size=sz2) 209 | x = self.bottleneck41(x) 210 | x = self.bottleneck42(x) 211 | 212 | # Stage 5 213 | x = self.bottleneck50(x, indices=indices1, output_size=sz1) 214 | x = self.bottleneck51(x) 215 | 216 | # Stage 6 217 | x = self.fullconv(x) 218 | return x 219 | 220 | def get_backbone_params(self): 221 | # There is no backbone for unet, all the parameters are trained from scratch 222 | return [] 223 | 224 | def get_decoder_params(self): 225 | return self.parameters() 226 | 227 | def freeze_bn(self): 228 | for module in self.modules(): 229 | if isinstance(module, nn.BatchNorm2d): module.eval() 230 | 231 | -------------------------------------------------------------------------------- /models/fcn.py: -------------------------------------------------------------------------------- 1 | from base import BaseModel 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from utils.helpers import get_upsampling_weight 6 | import torch 7 | from itertools import chain 8 | 9 | class FCN8(BaseModel): 10 | def __init__(self, num_classes, pretrained=True, freeze_bn=False, **_): 11 | super(FCN8, self).__init__() 12 | vgg = models.vgg16(pretrained) 13 | features = list(vgg.features.children()) 14 | classifier = list(vgg.classifier.children()) 15 | 16 | # Pad the input to enable small inputs and allow matching feature maps 17 | features[0].padding = (100, 100) 18 | 19 | # Enbale ceil in max pool, to avoid different sizes when upsampling 20 | for layer in features: 21 | if 'MaxPool' in layer.__class__.__name__: 22 | layer.ceil_mode = True 23 | 24 | # Extract pool3, pool4 and pool5 from the VGG net 25 | self.pool3 = nn.Sequential(*features[:17]) 26 | self.pool4 = nn.Sequential(*features[17:24]) 27 | self.pool5 = 
nn.Sequential(*features[24:]) 28 | 29 | # Adjust the depth of pool3 and pool4 to num_classes 30 | self.adj_pool3 = nn.Conv2d(256, num_classes, kernel_size=1) 31 | self.adj_pool4 = nn.Conv2d(512, num_classes, kernel_size=1) 32 | 33 | # Replace the FC layer of VGG with conv layers 34 | conv6 = nn.Conv2d(512, 4096, kernel_size=7) 35 | conv7 = nn.Conv2d(4096, 4096, kernel_size=1) 36 | output = nn.Conv2d(4096, num_classes, kernel_size=1) 37 | 38 | # Copy the weights from VGG's FC pretrained layers 39 | conv6.weight.data.copy_(classifier[0].weight.data.view( 40 | conv6.weight.data.size())) 41 | conv6.bias.data.copy_(classifier[0].bias.data) 42 | 43 | conv7.weight.data.copy_(classifier[3].weight.data.view( 44 | conv7.weight.data.size())) 45 | conv7.bias.data.copy_(classifier[3].bias.data) 46 | 47 | # Get the outputs 48 | self.output = nn.Sequential(conv6, nn.ReLU(inplace=True), nn.Dropout(), 49 | conv7, nn.ReLU(inplace=True), nn.Dropout(), 50 | output) 51 | 52 | # We'll need three upsampling layers, upsampling (x2 +2) the ouputs 53 | # upsampling (x2 +2) addition of pool4 and upsampled output 54 | # upsampling (x8 +8) the final value (pool3 + added output and pool4) 55 | self.up_output = nn.ConvTranspose2d(num_classes, num_classes, 56 | kernel_size=4, stride=2, bias=False) 57 | self.up_pool4_out = nn.ConvTranspose2d(num_classes, num_classes, 58 | kernel_size=4, stride=2, bias=False) 59 | self.up_final = nn.ConvTranspose2d(num_classes, num_classes, 60 | kernel_size=16, stride=8, bias=False) 61 | 62 | # We'll use guassian kernels for the upsampling weights 63 | self.up_output.weight.data.copy_( 64 | get_upsampling_weight(num_classes, num_classes, 4)) 65 | self.up_pool4_out.weight.data.copy_( 66 | get_upsampling_weight(num_classes, num_classes, 4)) 67 | self.up_final.weight.data.copy_( 68 | get_upsampling_weight(num_classes, num_classes, 16)) 69 | 70 | # We'll freeze the wights, this is a fixed upsampling and not deconv 71 | for m in self.modules(): 72 | if isinstance(m, nn.ConvTranspose2d): 73 | m.weight.requires_grad = False 74 | if freeze_bn: self.freeze_bn() 75 | if freeze_backbone: 76 | set_trainable([self.pool3, self.pool4, self.pool5], False) 77 | 78 | def forward(self, x): 79 | imh_H, img_W = x.size()[2], x.size()[3] 80 | 81 | # Forward the image 82 | pool3 = self.pool3(x) 83 | pool4 = self.pool4(pool3) 84 | pool5 = self.pool5(pool4) 85 | 86 | # Get the outputs and upsmaple them 87 | output = self.output(pool5) 88 | up_output = self.up_output(output) 89 | 90 | # Adjust pool4 and add the uped-outputs to pool4 91 | adjstd_pool4 = self.adj_pool4(0.01 * pool4) 92 | add_out_pool4 = self.up_pool4_out(adjstd_pool4[:, :, 5: (5 + up_output.size()[2]), 93 | 5: (5 + up_output.size()[3])] 94 | + up_output) 95 | 96 | # Adjust pool3 and add it to the uped last addition 97 | adjstd_pool3 = self.adj_pool3(0.0001 * pool3) 98 | final_value = self.up_final(adjstd_pool3[:, :, 9: (9 + add_out_pool4.size()[2]), 9: (9 + add_out_pool4.size()[3])] 99 | + add_out_pool4) 100 | 101 | # Remove the corresponding padded regions to the input img size 102 | final_value = final_value[:, :, 31: (31 + imh_H), 31: (31 + img_W)].contiguous() 103 | return final_value 104 | 105 | def get_backbone_params(self): 106 | return chain(self.pool3.parameters(), self.pool4.parameters(), self.pool5.parameters(), self.output.parameters()) 107 | 108 | def get_decoder_params(self): 109 | return chain(self.up_output.parameters(), self.adj_pool4.parameters(), self.up_pool4_out.parameters(), 110 | self.adj_pool3.parameters(), 
self.up_final.parameters()) 111 | 112 | def freeze_bn(self): 113 | for module in self.modules(): 114 | if isinstance(module, nn.BatchNorm2d): module.eval() 115 | 116 | -------------------------------------------------------------------------------- /models/gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | from base import BaseModel 6 | from utils.helpers import initialize_weights 7 | from itertools import chain 8 | 9 | ''' 10 | -> BackBone Resnet_GCN 11 | ''' 12 | 13 | class Block_Resnet_GCN(nn.Module): 14 | def __init__(self, kernel_size, in_channels, out_channels, stride=1): 15 | super(Block_Resnet_GCN, self).__init__() 16 | self.conv11 = nn.Conv2d(in_channels, out_channels, bias=False, stride=stride, 17 | kernel_size=(kernel_size, 1), padding=(kernel_size//2, 0), ) 18 | self.bn11 = nn.BatchNorm2d(out_channels) 19 | self.relu11 = nn.ReLU(inplace=True) 20 | self.conv12 = nn.Conv2d(out_channels, out_channels, bias=False, stride=stride, 21 | kernel_size=(1, kernel_size), padding=(0, kernel_size//2)) 22 | self.bn12 = nn.BatchNorm2d(out_channels) 23 | self.relu12 = nn.ReLU(inplace=True) 24 | 25 | self.conv21 = nn.Conv2d(in_channels, out_channels, bias=False, stride=stride, 26 | kernel_size=(1, kernel_size), padding=(0, kernel_size//2)) 27 | self.bn21 = nn.BatchNorm2d(out_channels) 28 | self.relu21 = nn.ReLU(inplace=True) 29 | self.conv22 = nn.Conv2d(out_channels, out_channels, bias=False, stride=stride, 30 | kernel_size=(kernel_size, 1), padding=(kernel_size//2, 0)) 31 | self.bn22 = nn.BatchNorm2d(out_channels) 32 | self.relu22 = nn.ReLU(inplace=True) 33 | 34 | 35 | def forward(self, x): 36 | x1 = self.conv11(x) 37 | x1 = self.bn11(x1) 38 | x1 = self.relu11(x1) 39 | x1 = self.conv12(x1) 40 | x1 = self.bn12(x1) 41 | x1 = self.relu12(x1) 42 | 43 | x2 = self.conv21(x) 44 | x2 = self.bn21(x2) 45 | x2 = self.relu21(x2) 46 | x2 = self.conv22(x2) 47 | x2 = self.bn22(x2) 48 | x2 = self.relu22(x2) 49 | 50 | x = x1 + x2 51 | return x 52 | 53 | class BottleneckGCN(nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size, out_channels_gcn, stride=1): 55 | super(BottleneckGCN, self).__init__() 56 | if in_channels != out_channels or stride != 1: 57 | self.downsample = nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride), 59 | nn.BatchNorm2d(out_channels)) 60 | else: self.downsample = None 61 | 62 | self.gcn = Block_Resnet_GCN(kernel_size, in_channels, out_channels_gcn) 63 | self.conv1x1 = nn.Conv2d(out_channels_gcn, out_channels, 1, stride=stride, bias=False) 64 | self.bn1x1 = nn.BatchNorm2d(out_channels) 65 | 66 | def forward(self, x): 67 | identity = x 68 | if self.downsample is not None: 69 | identity = self.downsample(identity) 70 | 71 | x = self.gcn(x) 72 | x = self.conv1x1(x) 73 | x = self.bn1x1(x) 74 | 75 | x += identity 76 | return x 77 | 78 | class ResnetGCN(nn.Module): 79 | def __init__(self, in_channels, backbone, out_channels_gcn=(85, 128), kernel_sizes=(5, 7)): 80 | super(ResnetGCN, self).__init__() 81 | resnet = getattr(torchvision.models, backbone)(pretrained=False) 82 | 83 | if in_channels == 3: conv1 = resnet.conv1 84 | else: conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 85 | self.initial = nn.Sequential( 86 | conv1, 87 | resnet.bn1, 88 | resnet.relu, 89 | resnet.maxpool) 90 | 91 | self.layer1 = resnet.layer1 92 | self.layer2 = resnet.layer2 93 | self.layer3 = 
nn.Sequential( 94 | BottleneckGCN(512, 1024, kernel_sizes[0], out_channels_gcn[0], stride=2), 95 | *[BottleneckGCN(1024, 1024, kernel_sizes[0], out_channels_gcn[0])]*5) 96 | self.layer4 = nn.Sequential( 97 | BottleneckGCN(1024, 2048, kernel_sizes[1], out_channels_gcn[1], stride=2), 98 | *[BottleneckGCN(1024, 1024, kernel_sizes[1], out_channels_gcn[1])]*5) 99 | initialize_weights(self) 100 | 101 | def forward(self, x): 102 | x = self.initial(x) 103 | conv1_sz = (x.size(2), x.size(3)) 104 | x1 = self.layer1(x) 105 | x2 = self.layer2(x1) 106 | x3 = self.layer3(x2) 107 | x4 = self.layer4(x3) 108 | return x1, x2, x3, x4, conv1_sz 109 | 110 | ''' 111 | -> BackBone Resnet 112 | ''' 113 | 114 | class Resnet(nn.Module): 115 | def __init__(self, in_channels, backbone, out_channels_gcn=(85, 128), 116 | pretrained=True, kernel_sizes=(5, 7)): 117 | super(Resnet, self).__init__() 118 | resnet = getattr(torchvision.models, backbone)(pretrained) 119 | 120 | if in_channels == 3: conv1 = resnet.conv1 121 | else: conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 122 | self.initial = nn.Sequential( 123 | conv1, 124 | resnet.bn1, 125 | resnet.relu, 126 | resnet.maxpool) 127 | 128 | self.layer1 = resnet.layer1 129 | self.layer2 = resnet.layer2 130 | self.layer3 = resnet.layer3 131 | self.layer4 = resnet.layer4 132 | if not pretrained: initialize_weights(self) 133 | 134 | def forward(self, x): 135 | x = self.initial(x) 136 | conv1_sz = (x.size(2), x.size(3)) 137 | x1 = self.layer1(x) 138 | x2 = self.layer2(x1) 139 | x3 = self.layer3(x2) 140 | x4 = self.layer4(x3) 141 | return x1, x2, x3, x4, conv1_sz 142 | 143 | ''' 144 | -> Global Convolutionnal Network 145 | ''' 146 | 147 | class GCN_Block(nn.Module): 148 | def __init__(self, kernel_size, in_channels, out_channels): 149 | super(GCN_Block, self).__init__() 150 | 151 | assert kernel_size % 2 == 1, 'Kernel size must be odd' 152 | self.conv11 = nn.Conv2d(in_channels, out_channels, 153 | kernel_size=(kernel_size, 1), padding=(kernel_size//2, 0)) 154 | self.conv12 = nn.Conv2d(out_channels, out_channels, 155 | kernel_size=(1, kernel_size), padding=(0, kernel_size//2)) 156 | 157 | self.conv21 = nn.Conv2d(in_channels, out_channels, 158 | kernel_size=(1, kernel_size), padding=(0, kernel_size//2)) 159 | self.conv22 = nn.Conv2d(out_channels, out_channels, 160 | kernel_size=(kernel_size, 1), padding=(kernel_size//2, 0)) 161 | initialize_weights(self) 162 | 163 | def forward(self, x): 164 | x1 = self.conv11(x) 165 | x1 = self.conv12(x1) 166 | x2 = self.conv21(x) 167 | x2 = self.conv22(x2) 168 | 169 | x = x1 + x2 170 | return x 171 | 172 | class BR_Block(nn.Module): 173 | def __init__(self, num_channels): 174 | super(BR_Block, self).__init__() 175 | self.bn1 = nn.BatchNorm2d(num_channels) 176 | self.relu1 = nn.ReLU(inplace=True) 177 | self.conv1 = nn.Conv2d(num_channels, num_channels, 3, padding=1) 178 | self.bn2 = nn.BatchNorm2d(num_channels) 179 | self.relu2 = nn.ReLU(inplace=True) 180 | self.conv2 = nn.Conv2d(num_channels, num_channels, 3, padding=1) 181 | initialize_weights(self) 182 | 183 | def forward(self, x): 184 | identity = x 185 | # x = self.conv1(self.relu1(self.bn1(x))) 186 | # x = self.conv2(self.relu2(self.bn2(x))) 187 | x = self.conv2(self.relu2(self.conv1(x))) 188 | x += identity 189 | return x 190 | 191 | class GCN(BaseModel): 192 | def __init__(self, num_classes, in_channels=3, pretrained=True, use_resnet_gcn=False, backbone='resnet50', use_deconv=False, 193 | num_filters=11, freeze_bn=False, **_): 194 | super(GCN, 
self).__init__() 195 | self.use_deconv = use_deconv 196 | if use_resnet_gcn: 197 | self.backbone = ResnetGCN(in_channels, backbone=backbone) 198 | else: 199 | self.backbone = Resnet(in_channels, pretrained=pretrained, backbone=backbone) 200 | 201 | if (backbone == 'resnet34' or backbone == 'resnet18'): resnet_channels = [64, 128, 256, 512] 202 | else: resnet_channels = [256, 512, 1024, 2048] 203 | 204 | self.gcn1 = GCN_Block(num_filters, resnet_channels[0], num_classes) 205 | self.br1 = BR_Block(num_classes) 206 | self.gcn2 = GCN_Block(num_filters, resnet_channels[1], num_classes) 207 | self.br2 = BR_Block(num_classes) 208 | self.gcn3 = GCN_Block(num_filters, resnet_channels[2], num_classes) 209 | self.br3 = BR_Block(num_classes) 210 | self.gcn4 = GCN_Block(num_filters, resnet_channels[3], num_classes) 211 | self.br4 = BR_Block(num_classes) 212 | 213 | self.br5 = BR_Block(num_classes) 214 | self.br6 = BR_Block(num_classes) 215 | self.br7 = BR_Block(num_classes) 216 | self.br8 = BR_Block(num_classes) 217 | self.br9 = BR_Block(num_classes) 218 | 219 | if self.use_deconv: 220 | self.decon1 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, padding=1, 221 | output_padding=1, stride=2, bias=False) 222 | self.decon2 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, padding=1, 223 | output_padding=1, stride=2, bias=False) 224 | self.decon3 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, padding=1, 225 | output_padding=1, stride=2, bias=False) 226 | self.decon4 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, padding=1, 227 | output_padding=1, stride=2, bias=False) 228 | self.decon5 = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=3, padding=1, 229 | output_padding=1, stride=2, bias=False) 230 | self.final_conv = nn.Conv2d(num_classes, num_classes, kernel_size=1) 231 | if freeze_bn: self.freeze_bn() 232 | if freeze_backbone: 233 | set_trainable([self.backbone], False) 234 | 235 | def forward(self, x): 236 | x1, x2, x3, x4, conv1_sz = self.backbone(x) 237 | 238 | x1 = self.br1(self.gcn1(x1)) 239 | x2 = self.br2(self.gcn2(x2)) 240 | x3 = self.br3(self.gcn3(x3)) 241 | x4 = self.br4(self.gcn4(x4)) 242 | 243 | if self.use_deconv: 244 | # Padding because when using deconv, if the size is odd, we'll have an alignment error 245 | x4 = self.decon4(x4) 246 | if x4.size() != x3.size(): x4 = self._pad(x4, x3) 247 | x3 = self.decon3(self.br5(x3 + x4)) 248 | if x3.size() != x2.size(): x3 = self._pad(x3, x2) 249 | x2 = self.decon2(self.br6(x2 + x3)) 250 | x1 = self.decon1(self.br7(x1 + x2)) 251 | 252 | x = self.br9(self.decon5(self.br8(x1))) 253 | else: 254 | x4 = F.interpolate(x4, size=x3.size()[2:], mode='bilinear', align_corners=True) 255 | x3 = F.interpolate(self.br5(x3 + x4), size=x2.size()[2:], mode='bilinear', align_corners=True) 256 | x2 = F.interpolate(self.br6(x2 + x3), size=x1.size()[2:], mode='bilinear', align_corners=True) 257 | x1 = F.interpolate(self.br7(x1 + x2), size=conv1_sz, mode='bilinear', align_corners=True) 258 | 259 | x = self.br9(F.interpolate(self.br8(x1), size=x.size()[2:], mode='bilinear', align_corners=True)) 260 | return self.final_conv(x) 261 | 262 | def _pad(self, x_topad, x): 263 | pad = (x.size(3) - x_topad.size(3), 0, x.size(2) - x_topad.size(2), 0) 264 | x_topad = F.pad(x_topad, pad, "constant", 0) 265 | return x_topad 266 | 267 | def get_backbone_params(self): 268 | return self.backbone.parameters() 269 | 270 | def get_decoder_params(self): 271 | return [p for n, p in self.named_parameters() if 'backbone' not in n] 
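# The backbone/decoder parameter split above lets a trainer build an optimizer with two groups, e.g. a smaller learning rate for the pretrained ResNet backbone than for the freshly initialized GCN/BR blocks. A minimal, illustrative sketch (the values are placeholders, not taken from this repo's config):
#   optimizer = torch.optim.SGD([{'params': model.get_backbone_params(), 'lr': 0.001},
#                                {'params': model.get_decoder_params(), 'lr': 0.01}], momentum=0.9)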
272 | 273 | def freeze_bn(self): 274 | for module in self.modules(): 275 | if isinstance(module, nn.BatchNorm2d): module.eval() 276 | 277 | -------------------------------------------------------------------------------- /models/pspnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from models import resnet 6 | from torchvision import models 7 | from base import BaseModel 8 | from utils.helpers import initialize_weights, set_trainable 9 | from itertools import chain 10 | 11 | class _PSPModule(nn.Module): 12 | def __init__(self, in_channels, bin_sizes, norm_layer): 13 | super(_PSPModule, self).__init__() 14 | out_channels = in_channels // len(bin_sizes) 15 | self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer) 16 | for b_s in bin_sizes]) 17 | self.bottleneck = nn.Sequential( 18 | nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), out_channels, 19 | kernel_size=3, padding=1, bias=False), 20 | norm_layer(out_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout2d(0.1) 23 | ) 24 | 25 | def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer): 26 | prior = nn.AdaptiveAvgPool2d(output_size=bin_sz) 27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 28 | bn = norm_layer(out_channels) 29 | relu = nn.ReLU(inplace=True) 30 | return nn.Sequential(prior, conv, bn, relu) 31 | 32 | def forward(self, features): 33 | h, w = features.size()[2], features.size()[3] 34 | pyramids = [features] 35 | pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', 36 | align_corners=True) for stage in self.stages]) 37 | output = self.bottleneck(torch.cat(pyramids, dim=1)) 38 | return output 39 | 40 | 41 | class PSPNet(BaseModel): 42 | def __init__(self, num_classes, in_channels=3, backbone='resnet152', pretrained=True, use_aux=True, freeze_bn=False, freeze_backbone=False): 43 | super(PSPNet, self).__init__() 44 | norm_layer = nn.BatchNorm2d 45 | model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer) 46 | m_out_sz = model.fc.in_features 47 | self.use_aux = use_aux 48 | 49 | self.initial = nn.Sequential(*list(model.children())[:4]) 50 | if in_channels != 3: 51 | self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 52 | self.initial = nn.Sequential(*self.initial) 53 | 54 | self.layer1 = model.layer1 55 | self.layer2 = model.layer2 56 | self.layer3 = model.layer3 57 | self.layer4 = model.layer4 58 | 59 | self.master_branch = nn.Sequential( 60 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer), 61 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 62 | ) 63 | 64 | self.auxiliary_branch = nn.Sequential( 65 | nn.Conv2d(m_out_sz//2, m_out_sz//4, kernel_size=3, padding=1, bias=False), 66 | norm_layer(m_out_sz//4), 67 | nn.ReLU(inplace=True), 68 | nn.Dropout2d(0.1), 69 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 70 | ) 71 | 72 | initialize_weights(self.master_branch, self.auxiliary_branch) 73 | if freeze_bn: self.freeze_bn() 74 | if freeze_backbone: 75 | set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False) 76 | 77 | def forward(self, x): 78 | input_size = (x.size()[2], x.size()[3]) 79 | x = self.initial(x) 80 | x = self.layer1(x) 81 | x = self.layer2(x) 82 | x_aux = self.layer3(x) 83 | x = self.layer4(x_aux) 84 | 85 | output = self.master_branch(x) 86 | output = 
F.interpolate(output, size=input_size, mode='bilinear') 87 | output = output[:, :, :input_size[0], :input_size[1]] 88 | 89 | if self.training and self.use_aux: 90 | aux = self.auxiliary_branch(x_aux) 91 | aux = F.interpolate(aux, size=input_size, mode='bilinear') 92 | aux = aux[:, :, :input_size[0], :input_size[1]] 93 | return output, aux 94 | return output 95 | 96 | def get_backbone_params(self): 97 | return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(), 98 | self.layer3.parameters(), self.layer4.parameters()) 99 | 100 | def get_decoder_params(self): 101 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters()) 102 | 103 | def freeze_bn(self): 104 | for module in self.modules(): 105 | if isinstance(module, nn.BatchNorm2d): module.eval() 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | ## PSP with dense net as the backbone 116 | 117 | class PSPDenseNet(BaseModel): 118 | def __init__(self, num_classes, in_channels=3, backbone='densenet201', pretrained=True, use_aux=True, freeze_bn=False, **_): 119 | super(PSPDenseNet, self).__init__() 120 | self.use_aux = use_aux 121 | model = getattr(models, backbone)(pretrained) 122 | m_out_sz = model.classifier.in_features 123 | aux_out_sz = model.features.transition3.conv.out_channels 124 | 125 | if not pretrained or in_channels != 3: 126 | # If we're training from scratch, better to use 3x3 convs 127 | block0 = [nn.Conv2d(in_channels, 64, 3, stride=2, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] 128 | block0.extend( 129 | [nn.Conv2d(64, 64, 3, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] * 2 130 | ) 131 | self.block0 = nn.Sequential( 132 | *block0, 133 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 134 | ) 135 | initialize_weights(self.block0) 136 | else: 137 | self.block0 = nn.Sequential(*list(model.features.children())[:4]) 138 | 139 | self.block1 = model.features.denseblock1 140 | self.block2 = model.features.denseblock2 141 | self.block3 = model.features.denseblock3 142 | self.block4 = model.features.denseblock4 143 | 144 | self.transition1 = model.features.transition1 145 | # No pooling 146 | self.transition2 = nn.Sequential( 147 | *list(model.features.transition2.children())[:-1]) 148 | self.transition3 = nn.Sequential( 149 | *list(model.features.transition3.children())[:-1]) 150 | 151 | for n, m in self.block3.named_modules(): 152 | if 'conv2' in n: 153 | m.dilation, m.padding = (2,2), (2,2) 154 | for n, m in self.block4.named_modules(): 155 | if 'conv2' in n: 156 | m.dilation, m.padding = (4,4), (4,4) 157 | 158 | self.master_branch = nn.Sequential( 159 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=nn.BatchNorm2d), 160 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 161 | ) 162 | 163 | self.auxiliary_branch = nn.Sequential( 164 | nn.Conv2d(aux_out_sz, m_out_sz//4, kernel_size=3, padding=1, bias=False), 165 | nn.BatchNorm2d(m_out_sz//4), 166 | nn.ReLU(inplace=True), 167 | nn.Dropout2d(0.1), 168 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 169 | ) 170 | 171 | initialize_weights(self.master_branch, self.auxiliary_branch) 172 | if freeze_bn: self.freeze_bn() 173 | 174 | def forward(self, x): 175 | input_size = (x.size()[2], x.size()[3]) 176 | 177 | x = self.block0(x) 178 | x = self.block1(x) 179 | x = self.transition1(x) 180 | x = self.block2(x) 181 | x = self.transition2(x) 182 | x = self.block3(x) 183 | x_aux = self.transition3(x) 184 | x = self.block4(x_aux) 185 | 186 | output = self.master_branch(x) 187 | output = 
F.interpolate(output, size=input_size, mode='bilinear') 188 | 189 | if self.training and self.use_aux: 190 | aux = self.auxiliary_branch(x_aux) 191 | aux = F.interpolate(aux, size=input_size, mode='bilinear') 192 | return output, aux 193 | return output 194 | 195 | def get_backbone_params(self): 196 | return chain(self.block0.parameters(), self.block1.parameters(), self.block2.parameters(), 197 | self.block3.parameters(), self.transition1.parameters(), self.transition2.parameters(), 198 | self.transition3.parameters()) 199 | 200 | def get_decoder_params(self): 201 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters()) 202 | 203 | def freeze_bn(self): 204 | for module in self.modules(): 205 | if isinstance(module, nn.BatchNorm2d): module.eval() 206 | -------------------------------------------------------------------------------- /models/segnet.py: -------------------------------------------------------------------------------- 1 | from base import BaseModel 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | from utils.helpers import set_trainable 7 | from itertools import chain 8 | from math import ceil 9 | 10 | import copy 11 | 12 | 13 | class SegNet(BaseModel): 14 | def __init__(self, num_classes, in_channels=3, pretrained=True, freeze_bn=False, freeze_backbone=False, **_): 15 | super(SegNet, self).__init__() 16 | vgg_bn = models.vgg16_bn(weights='VGG16_BN_Weights.IMAGENET1K_V1') 17 | encoder = list(vgg_bn.features.children()) 18 | 19 | # Adjust the input size 20 | if in_channels != 3: 21 | encoder[0] = nn.Conv2d( 22 | in_channels, 64, kernel_size=3, stride=1, padding=1) 23 | 24 | # Encoder, VGG without any maxpooling 25 | self.stage1_encoder = nn.Sequential(*encoder[:6]) 26 | self.stage2_encoder = nn.Sequential(*encoder[7:13]) 27 | self.stage3_encoder = nn.Sequential(*encoder[14:23]) 28 | self.stage4_encoder = nn.Sequential(*encoder[24:33]) 29 | self.stage5_encoder = nn.Sequential(*encoder[34:-1]) 30 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 31 | 32 | # Decoder, same as the encoder but reversed, maxpool will not be used 33 | # 2023.11.2 Serious error!!! 
Shallow copy of the list, where elements within the list point to the same object, and many modules (especially the BN layer) may be duplicated 34 | # decoder = encoder 35 | 36 | # List deep copy 37 | decoder = copy.deepcopy(encoder) 38 | 39 | decoder = [i for i in list(reversed(decoder)) 40 | if not isinstance(i, nn.MaxPool2d)] 41 | # Replace the last conv layer 42 | decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 43 | # When reversing, we also reversed conv->batchN->relu, correct it 44 | decoder = [item for i in range(0, len(decoder), 3) 45 | for item in decoder[i:i+3][::-1]] 46 | # Replace some conv layers & batchN after them 47 | for i, module in enumerate(decoder): 48 | if isinstance(module, nn.Conv2d): 49 | if module.in_channels != module.out_channels: 50 | decoder[i+1] = nn.BatchNorm2d(module.in_channels) 51 | decoder[i] = nn.Conv2d( 52 | module.out_channels, module.in_channels, kernel_size=3, stride=1, padding=1) 53 | 54 | self.stage1_decoder = nn.Sequential(*decoder[0:9]) 55 | self.stage2_decoder = nn.Sequential(*decoder[9:18]) 56 | self.stage3_decoder = nn.Sequential(*decoder[18:27]) 57 | self.stage4_decoder = nn.Sequential(*decoder[27:33]) 58 | self.stage5_decoder = nn.Sequential(*decoder[33:], 59 | nn.Conv2d( 60 | 64, num_classes, kernel_size=3, stride=1, padding=1) 61 | ) 62 | self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) 63 | 64 | self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, 65 | self.stage4_decoder, self.stage5_decoder) 66 | if freeze_bn: 67 | self.freeze_bn() 68 | if freeze_backbone: 69 | set_trainable([self.stage1_encoder, self.stage2_encoder, 70 | self.stage3_encoder, self.stage4_encoder, self.stage5_encoder], False) 71 | 72 | def _initialize_weights(self, *stages): 73 | for modules in stages: 74 | for module in modules.modules(): 75 | if isinstance(module, nn.Conv2d): 76 | nn.init.kaiming_normal_(module.weight) 77 | if module.bias is not None: 78 | module.bias.data.zero_() 79 | elif isinstance(module, nn.BatchNorm2d): 80 | module.weight.data.fill_(1) 81 | module.bias.data.zero_() 82 | 83 | def forward(self, x): 84 | # Encoder 85 | x = self.stage1_encoder(x) 86 | x1_size = x.size() 87 | x, indices1 = self.pool(x) 88 | 89 | x = self.stage2_encoder(x) 90 | x2_size = x.size() 91 | x, indices2 = self.pool(x) 92 | 93 | x = self.stage3_encoder(x) 94 | x3_size = x.size() 95 | x, indices3 = self.pool(x) 96 | 97 | x = self.stage4_encoder(x) 98 | x4_size = x.size() 99 | x, indices4 = self.pool(x) 100 | 101 | x = self.stage5_encoder(x) 102 | x5_size = x.size() 103 | x, indices5 = self.pool(x) 104 | 105 | # Decoder 106 | x = self.unpool(x, indices=indices5, output_size=x5_size) 107 | x = self.stage1_decoder(x) 108 | 109 | x = self.unpool(x, indices=indices4, output_size=x4_size) 110 | x = self.stage2_decoder(x) 111 | 112 | x = self.unpool(x, indices=indices3, output_size=x3_size) 113 | x = self.stage3_decoder(x) 114 | 115 | x = self.unpool(x, indices=indices2, output_size=x2_size) 116 | x = self.stage4_decoder(x) 117 | 118 | x = self.unpool(x, indices=indices1, output_size=x1_size) 119 | x = self.stage5_decoder(x) 120 | 121 | return x 122 | 123 | def get_backbone_params(self): 124 | return [] 125 | 126 | def get_decoder_params(self): 127 | return self.parameters() 128 | 129 | def freeze_bn(self): 130 | for module in self.modules(): 131 | if isinstance(module, nn.BatchNorm2d): 132 | module.eval() 133 | 134 | 135 | class DecoderBottleneck(nn.Module): 136 | def __init__(self, inchannels): 137 | 
super(DecoderBottleneck, self).__init__() 138 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, 139 | kernel_size=1, bias=False) 140 | self.bn1 = nn.BatchNorm2d(inchannels//4) 141 | self.conv2 = nn.ConvTranspose2d( 142 | inchannels//4, inchannels//4, kernel_size=2, stride=2, bias=False) 143 | self.bn2 = nn.BatchNorm2d(inchannels//4) 144 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//2, 1, bias=False) 145 | self.bn3 = nn.BatchNorm2d(inchannels//2) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.downsample = nn.Sequential( 148 | nn.ConvTranspose2d(inchannels, inchannels//2, 149 | kernel_size=2, stride=2, bias=False), 150 | nn.BatchNorm2d(inchannels//2)) 151 | 152 | def forward(self, x): 153 | out = self.conv1(x) 154 | out = self.bn1(out) 155 | out = self.relu(out) 156 | out = self.conv2(out) 157 | out = self.bn2(out) 158 | out = self.relu(out) 159 | out = self.conv3(out) 160 | out = self.bn3(out) 161 | 162 | identity = self.downsample(x) 163 | out += identity 164 | out = self.relu(out) 165 | return out 166 | 167 | 168 | class LastBottleneck(nn.Module): 169 | def __init__(self, inchannels): 170 | super(LastBottleneck, self).__init__() 171 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, 172 | kernel_size=1, bias=False) 173 | self.bn1 = nn.BatchNorm2d(inchannels//4) 174 | self.conv2 = nn.Conv2d(inchannels//4, inchannels // 175 | 4, kernel_size=3, padding=1, bias=False) 176 | self.bn2 = nn.BatchNorm2d(inchannels//4) 177 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//4, 1, bias=False) 178 | self.bn3 = nn.BatchNorm2d(inchannels//4) 179 | self.relu = nn.ReLU(inplace=True) 180 | self.downsample = nn.Sequential( 181 | nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False), 182 | nn.BatchNorm2d(inchannels//4)) 183 | 184 | def forward(self, x): 185 | out = self.conv1(x) 186 | out = self.bn1(out) 187 | out = self.relu(out) 188 | out = self.conv2(out) 189 | out = self.bn2(out) 190 | out = self.relu(out) 191 | out = self.conv3(out) 192 | out = self.bn3(out) 193 | 194 | identity = self.downsample(x) 195 | out += identity 196 | out = self.relu(out) 197 | return out 198 | 199 | 200 | class SegResNet(BaseModel): 201 | def __init__(self, num_classes, in_channels=3, pretrained=True, freeze_bn=False, freeze_backbone=False, **_): 202 | super(SegResNet, self).__init__() 203 | resnet50 = models.resnet50(pretrained=pretrained) 204 | encoder = list(resnet50.children()) 205 | if in_channels != 3: 206 | encoder[0] = nn.Conv2d( 207 | in_channels, 64, kernel_size=3, stride=1, padding=1) 208 | encoder[3].return_indices = True 209 | 210 | # Encoder 211 | self.first_conv = nn.Sequential(*encoder[:4]) 212 | resnet50_blocks = list(resnet50.children())[4:-2] 213 | self.encoder = nn.Sequential(*resnet50_blocks) 214 | 215 | # Decoder 216 | resnet50_untrained = models.resnet50(pretrained=False) 217 | resnet50_blocks = list(resnet50_untrained.children())[4:-2][::-1] 218 | decoder = [] 219 | channels = (2048, 1024, 512) 220 | for i, block in enumerate(resnet50_blocks[:-1]): 221 | new_block = list(block.children())[::-1][:-1] 222 | decoder.append(nn.Sequential( 223 | *new_block, DecoderBottleneck(channels[i]))) 224 | new_block = list(resnet50_blocks[-1].children())[::-1][:-1] 225 | decoder.append(nn.Sequential(*new_block, LastBottleneck(256))) 226 | 227 | self.decoder = nn.Sequential(*decoder) 228 | self.last_conv = nn.Sequential( 229 | nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False), 230 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1) 231 | ) 232 | if freeze_bn: 233 | 
self.freeze_bn() 234 | if freeze_backbone: 235 | set_trainable([self.first_conv, self.encoder], False) 236 | 237 | def forward(self, x): 238 | inputsize = x.size() 239 | 240 | # Encoder 241 | x, indices = self.first_conv(x) 242 | x = self.encoder(x) 243 | 244 | # Decoder 245 | x = self.decoder(x) 246 | h_diff = ceil((x.size()[2] - indices.size()[2]) / 2) 247 | w_diff = ceil((x.size()[3] - indices.size()[3]) / 2) 248 | if indices.size()[2] % 2 == 1: 249 | x = x[:, :, h_diff:x.size()[2]-(h_diff-1), 250 | w_diff: x.size()[3]-(w_diff-1)] 251 | else: 252 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff] 253 | 254 | x = F.max_unpool2d(x, indices, kernel_size=2, stride=2) 255 | x = self.last_conv(x) 256 | 257 | if inputsize != x.size(): 258 | h_diff = (x.size()[2] - inputsize[2]) // 2 259 | w_diff = (x.size()[3] - inputsize[3]) // 2 260 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff] 261 | if h_diff % 2 != 0: 262 | x = x[:, :, :-1, :] 263 | if w_diff % 2 != 0: 264 | x = x[:, :, :, :-1] 265 | 266 | return x 267 | 268 | def get_backbone_params(self): 269 | return chain(self.first_conv.parameters(), self.encoder.parameters()) 270 | 271 | def get_decoder_params(self): 272 | return chain(self.decoder.parameters(), self.last_conv.parameters()) 273 | 274 | def freeze_bn(self): 275 | for module in self.modules(): 276 | if isinstance(module, nn.BatchNorm2d): 277 | module.eval() 278 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from base import BaseModel 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from itertools import chain 6 | from base import BaseModel 7 | from utils.helpers import initialize_weights, set_trainable 8 | from itertools import chain 9 | from models import resnet 10 | 11 | 12 | def x2conv(in_channels, out_channels, inner_channels=None): 13 | inner_channels = out_channels // 2 if inner_channels is None else inner_channels 14 | down_conv = nn.Sequential( 15 | nn.Conv2d(in_channels, inner_channels, kernel_size=3, padding=1, bias=False), 16 | nn.BatchNorm2d(inner_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(inner_channels, out_channels, kernel_size=3, padding=1, bias=False), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True)) 21 | return down_conv 22 | 23 | class encoder(nn.Module): 24 | def __init__(self, in_channels, out_channels): 25 | super(encoder, self).__init__() 26 | self.down_conv = x2conv(in_channels, out_channels) 27 | self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True) 28 | 29 | def forward(self, x): 30 | x = self.down_conv(x) 31 | x = self.pool(x) 32 | return x 33 | 34 | class decoder(nn.Module): 35 | def __init__(self, in_channels, out_channels): 36 | super(decoder, self).__init__() 37 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 38 | self.up_conv = x2conv(in_channels, out_channels) 39 | 40 | def forward(self, x_copy, x, interpolate=True): 41 | x = self.up(x) 42 | 43 | if (x.size(2) != x_copy.size(2)) or (x.size(3) != x_copy.size(3)): 44 | if interpolate: 45 | # Iterpolating instead of padding 46 | x = F.interpolate(x, size=(x_copy.size(2), x_copy.size(3)), 47 | mode="bilinear", align_corners=True) 48 | else: 49 | # Padding in case the incomping volumes are of different sizes 50 | diffY = x_copy.size()[2] - x.size()[2] 51 | diffX = x_copy.size()[3] - x.size()[3] 52 | x = F.pad(x, (diffX // 2, diffX - 
diffX // 2, 53 | diffY // 2, diffY - diffY // 2)) 54 | 55 | # Concatenate 56 | x = torch.cat([x_copy, x], dim=1) 57 | x = self.up_conv(x) 58 | return x 59 | 60 | 61 | class UNet(BaseModel): 62 | def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_): 63 | super(UNet, self).__init__() 64 | 65 | self.start_conv = x2conv(in_channels, 64) 66 | self.down1 = encoder(64, 128) 67 | self.down2 = encoder(128, 256) 68 | self.down3 = encoder(256, 512) 69 | self.down4 = encoder(512, 1024) 70 | 71 | self.middle_conv = x2conv(1024, 1024) 72 | 73 | self.up1 = decoder(1024, 512) 74 | self.up2 = decoder(512, 256) 75 | self.up3 = decoder(256, 128) 76 | self.up4 = decoder(128, 64) 77 | self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1) 78 | self._initialize_weights() 79 | 80 | if freeze_bn: 81 | self.freeze_bn() 82 | 83 | def _initialize_weights(self): 84 | for module in self.modules(): 85 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 86 | nn.init.kaiming_normal_(module.weight) 87 | if module.bias is not None: 88 | module.bias.data.zero_() 89 | elif isinstance(module, nn.BatchNorm2d): 90 | module.weight.data.fill_(1) 91 | module.bias.data.zero_() 92 | 93 | def forward(self, x): 94 | x1 = self.start_conv(x) 95 | x2 = self.down1(x1) 96 | x3 = self.down2(x2) 97 | x4 = self.down3(x3) 98 | x = self.middle_conv(self.down4(x4)) 99 | 100 | x = self.up1(x4, x) 101 | x = self.up2(x3, x) 102 | x = self.up3(x2, x) 103 | x = self.up4(x1, x) 104 | 105 | x = self.final_conv(x) 106 | return x 107 | 108 | def get_backbone_params(self): 109 | # There is no backbone for unet, all the parameters are trained from scratch 110 | return [] 111 | 112 | def get_decoder_params(self): 113 | return self.parameters() 114 | 115 | def freeze_bn(self): 116 | for module in self.modules(): 117 | if isinstance(module, nn.BatchNorm2d): module.eval() 118 | 119 | 120 | 121 | 122 | """ 123 | -> Unet with a resnet backbone 124 | """ 125 | 126 | class UNetResnet(BaseModel): 127 | def __init__(self, num_classes, in_channels=3, backbone='resnet50', pretrained=True, freeze_bn=False, freeze_backbone=False, **_): 128 | super(UNetResnet, self).__init__() 129 | model = getattr(resnet, backbone)(pretrained, norm_layer=nn.BatchNorm2d) 130 | 131 | self.initial = list(model.children())[:4] 132 | if in_channels != 3: 133 | self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 134 | self.initial = nn.Sequential(*self.initial) 135 | 136 | # encoder 137 | self.layer1 = model.layer1 138 | self.layer2 = model.layer2 139 | self.layer3 = model.layer3 140 | self.layer4 = model.layer4 141 | 142 | # decoder 143 | self.conv1 = nn.Conv2d(2048, 192, kernel_size=3, stride=1, padding=1) 144 | self.upconv1 = nn.ConvTranspose2d(192, 128, 4, 2, 1, bias=False) 145 | 146 | self.conv2 = nn.Conv2d(1152, 128, kernel_size=3, stride=1, padding=1) 147 | self.upconv2 = nn.ConvTranspose2d(128, 96, 4, 2, 1, bias=False) 148 | 149 | self.conv3 = nn.Conv2d(608, 96, kernel_size=3, stride=1, padding=1) 150 | self.upconv3 = nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False) 151 | 152 | self.conv4 = nn.Conv2d(320, 64, kernel_size=3, stride=1, padding=1) 153 | self.upconv4 = nn.ConvTranspose2d(64, 48, 4, 2, 1, bias=False) 154 | 155 | self.conv5 = nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1) 156 | self.upconv5 = nn.ConvTranspose2d(48, 32, 4, 2, 1, bias=False) 157 | 158 | self.conv6 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1) 159 | self.conv7 = nn.Conv2d(32, num_classes, kernel_size=1, bias=False) 
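        # Channel bookkeeping for the skip connections used in forward() below
        # (resnet50-style backbone with 256/512/1024/2048 stage channels):
        # conv2 sees 128 (upconv1) + 1024 (layer3) = 1152 channels,
        # conv3 sees 96 (upconv2) + 512 (layer2) = 608 channels, and
        # conv4 sees 64 (upconv3) + 256 (layer1) = 320 channels after concatenation.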
160 | 161 | initialize_weights(self) 162 | 163 | if freeze_bn: 164 | self.freeze_bn() 165 | if freeze_backbone: 166 | set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False) 167 | 168 | def forward(self, x): 169 | H, W = x.size(2), x.size(3) 170 | x1 = self.layer1(self.initial(x)) 171 | x2 = self.layer2(x1) 172 | x3 = self.layer3(x2) 173 | x4 = self.layer4(x3) 174 | 175 | x = self.upconv1(self.conv1(x4)) 176 | x = F.interpolate(x, size=(x3.size(2), x3.size(3)), mode="bilinear", align_corners=True) 177 | x = torch.cat([x, x3], dim=1) 178 | x = self.upconv2(self.conv2(x)) 179 | 180 | x = F.interpolate(x, size=(x2.size(2), x2.size(3)), mode="bilinear", align_corners=True) 181 | x = torch.cat([x, x2], dim=1) 182 | x = self.upconv3(self.conv3(x)) 183 | 184 | x = F.interpolate(x, size=(x1.size(2), x1.size(3)), mode="bilinear", align_corners=True) 185 | x = torch.cat([x, x1], dim=1) 186 | 187 | x = self.upconv4(self.conv4(x)) 188 | 189 | x = self.upconv5(self.conv5(x)) 190 | 191 | # if the input is not divisible by the output stride 192 | if x.size(2) != H or x.size(3) != W: 193 | x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=True) 194 | 195 | x = self.conv7(self.conv6(x)) 196 | return x 197 | 198 | def get_backbone_params(self): 199 | return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(), 200 | self.layer3.parameters(), self.layer4.parameters()) 201 | 202 | def get_decoder_params(self): 203 | return chain(self.conv1.parameters(), self.upconv1.parameters(), self.conv2.parameters(), self.upconv2.parameters(), 204 | self.conv3.parameters(), self.upconv3.parameters(), self.conv4.parameters(), self.upconv4.parameters(), 205 | self.conv5.parameters(), self.upconv5.parameters(), self.conv6.parameters(), self.conv7.parameters()) 206 | 207 | def freeze_bn(self): 208 | for module in self.modules(): 209 | if isinstance(module, nn.BatchNorm2d): module.eval() 210 | -------------------------------------------------------------------------------- /models/upernet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from base import BaseModel 6 | from utils.helpers import initialize_weights 7 | from itertools import chain 8 | 9 | class PSPModule(nn.Module): 10 | # In the original inmplementation they use precise RoI pooling 11 | # Instead of using adaptative average pooling 12 | def __init__(self, in_channels, bin_sizes=[1, 2, 4, 6]): 13 | super(PSPModule, self).__init__() 14 | out_channels = in_channels // len(bin_sizes) 15 | self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) 16 | for b_s in bin_sizes]) 17 | self.bottleneck = nn.Sequential( 18 | nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), in_channels, 19 | kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(in_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout2d(0.1) 23 | ) 24 | 25 | def _make_stages(self, in_channels, out_channels, bin_sz): 26 | prior = nn.AdaptiveAvgPool2d(output_size=bin_sz) 27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 28 | bn = nn.BatchNorm2d(out_channels) 29 | relu = nn.ReLU(inplace=True) 30 | return nn.Sequential(prior, conv, bn, relu) 31 | 32 | def forward(self, features): 33 | h, w = features.size()[2], features.size()[3] 34 | pyramids = [features] 35 | pyramids.extend([F.interpolate(stage(features), size=(h, w), 
mode='bilinear', 36 | align_corners=True) for stage in self.stages]) 37 | output = self.bottleneck(torch.cat(pyramids, dim=1)) 38 | return output 39 | 40 | class ResNet(nn.Module): 41 | def __init__(self, in_channels=3, output_stride=16, backbone='resnet101', pretrained=True): 42 | super(ResNet, self).__init__() 43 | model = getattr(models, backbone)(pretrained) 44 | if not pretrained or in_channels != 3: 45 | self.initial = nn.Sequential( 46 | nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False), 47 | nn.BatchNorm2d(64), 48 | nn.ReLU(inplace=True), 49 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 50 | ) 51 | initialize_weights(self.initial) 52 | else: 53 | self.initial = nn.Sequential(*list(model.children())[:4]) 54 | 55 | self.layer1 = model.layer1 56 | self.layer2 = model.layer2 57 | self.layer3 = model.layer3 58 | self.layer4 = model.layer4 59 | 60 | if output_stride == 16: s3, s4, d3, d4 = (2, 1, 1, 2) 61 | elif output_stride == 8: s3, s4, d3, d4 = (1, 1, 2, 4) 62 | 63 | if output_stride == 8: 64 | for n, m in self.layer3.named_modules(): 65 | if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'): 66 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3) 67 | elif 'conv2' in n: 68 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3) 69 | elif 'downsample.0' in n: 70 | m.stride = (s3, s3) 71 | 72 | for n, m in self.layer4.named_modules(): 73 | if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'): 74 | m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4) 75 | elif 'conv2' in n: 76 | m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4) 77 | elif 'downsample.0' in n: 78 | m.stride = (s4, s4) 79 | 80 | def forward(self, x): 81 | x = self.initial(x) 82 | x1 = self.layer1(x) 83 | x2 = self.layer2(x1) 84 | x3 = self.layer3(x2) 85 | x4 = self.layer4(x3) 86 | 87 | return [x1, x2, x3, x4] 88 | 89 | def up_and_add(x, y): 90 | return F.interpolate(x, size=(y.size(2), y.size(3)), mode='bilinear', align_corners=True) + y 91 | 92 | class FPN_fuse(nn.Module): 93 | def __init__(self, feature_channels=[256, 512, 1024, 2048], fpn_out=256): 94 | super(FPN_fuse, self).__init__() 95 | assert feature_channels[0] == fpn_out 96 | self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) 97 | for ft_size in feature_channels[1:]]) 98 | self.smooth_conv = nn.ModuleList([nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] 99 | * (len(feature_channels)-1)) 100 | self.conv_fusion = nn.Sequential( 101 | nn.Conv2d(len(feature_channels)*fpn_out, fpn_out, kernel_size=3, padding=1, bias=False), 102 | nn.BatchNorm2d(fpn_out), 103 | nn.ReLU(inplace=True) 104 | ) 105 | 106 | def forward(self, features): 107 | 108 | features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1)] 109 | P = [up_and_add(features[i], features[i-1]) for i in reversed(range(1, len(features)))] 110 | P = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, P)] 111 | P = list(reversed(P)) 112 | P.append(features[-1]) #P = [P1, P2, P3, P4] 113 | H, W = P[0].size(2), P[0].size(3) 114 | P[1:] = [F.interpolate(feature, size=(H, W), mode='bilinear', align_corners=True) for feature in P[1:]] 115 | 116 | x = self.conv_fusion(torch.cat((P), dim=1)) 117 | return x 118 | 119 | class UperNet(BaseModel): 120 | # Implementing only the object path 121 | def __init__(self, num_classes, in_channels=3, backbone='resnet101', pretrained=True, use_aux=True, fpn_out=256, freeze_bn=False, **_): 122 | super(UperNet, self).__init__() 123 | 124 
| if backbone == 'resnet34' or backbone == 'resnet18': 125 | feature_channels = [64, 128, 256, 512] 126 | else: 127 | feature_channels = [256, 512, 1024, 2048] 128 | self.backbone = ResNet(in_channels, backbone=backbone, pretrained=pretrained) 129 | self.PPN = PSPModule(feature_channels[-1]) 130 | self.FPN = FPN_fuse(feature_channels, fpn_out=fpn_out) 131 | self.head = nn.Conv2d(fpn_out, num_classes, kernel_size=3, padding=1) 132 | if freeze_bn: self.freeze_bn() 133 | if freeze_backbone: 134 | set_trainable([self.backbone], False) 135 | 136 | def forward(self, x): 137 | input_size = (x.size()[2], x.size()[3]) 138 | 139 | features = self.backbone(x) 140 | features[-1] = self.PPN(features[-1]) 141 | x = self.head(self.FPN(features)) 142 | 143 | x = F.interpolate(x, size=input_size, mode='bilinear') 144 | return x 145 | 146 | def get_backbone_params(self): 147 | return self.backbone.parameters() 148 | 149 | def get_decoder_params(self): 150 | return chain(self.PPN.parameters(), self.FPN.parameters(), self.head.parameters()) 151 | 152 | def freeze_bn(self): 153 | for module in self.modules(): 154 | if isinstance(module, nn.BatchNorm2d): module.eval() 155 | 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | torchvision==0.3.0 3 | tqdm==4.32.2 4 | tensorboard==1.14.0 5 | Pillow==6.2.0 6 | opencv-python==4.1.0.25 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torch 5 | import dataloaders 6 | import models 7 | import inspect 8 | import math 9 | from utils import losses 10 | from utils import Logger 11 | from utils.torchsummary import summary 12 | from trainer import Trainer 13 | 14 | def get_instance(module, name, config, *args): 15 | # GET THE CORRESPONDING CLASS / FCT 16 | return getattr(module, config[name]['type'])(*args, **config[name]['args']) 17 | 18 | def main(config, resume): 19 | train_logger = Logger() 20 | 21 | # DATA LOADERS 22 | train_loader = get_instance(dataloaders, 'train_loader', config) 23 | val_loader = get_instance(dataloaders, 'val_loader', config) 24 | 25 | # MODEL 26 | model = get_instance(models, 'arch', config, train_loader.dataset.num_classes) 27 | print(f'\n{model}\n') 28 | 29 | # LOSS 30 | loss = getattr(losses, config['loss'])(ignore_index = config['ignore_index']) 31 | 32 | # TRAINING 33 | trainer = Trainer( 34 | model=model, 35 | loss=loss, 36 | resume=resume, 37 | config=config, 38 | train_loader=train_loader, 39 | val_loader=val_loader, 40 | train_logger=train_logger) 41 | 42 | trainer.train() 43 | 44 | if __name__=='__main__': 45 | # PARSE THE ARGS 46 | parser = argparse.ArgumentParser(description='PyTorch Training') 47 | parser.add_argument('-c', '--config', default='config.json',type=str, 48 | help='Path to the config file (default: config.json)') 49 | parser.add_argument('-r', '--resume', default=None, type=str, 50 | help='Path to the .pth model checkpoint to resume training') 51 | parser.add_argument('-d', '--device', default=None, type=str, 52 | help='indices of GPUs to enable (default: all)') 53 | args = parser.parse_args() 54 | 55 | config = json.load(open(args.config)) 56 | if args.resume: 57 | config = torch.load(args.resume)['config'] 58 | if args.device: 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 60 | 
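    # get_instance() above resolves config[name]['type'] to a class in the given module
    # and instantiates it with config[name]['args']. As an illustrative sketch only
    # (not necessarily the shipped config.json), an entry such as
    #   "arch": {"type": "PSPNet", "args": {"backbone": "resnet50", "freeze_bn": false}}
    # would build models.PSPNet(num_classes, backbone='resnet50', freeze_bn=False),
    # with num_classes supplied positionally from the training dataset.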
61 | main(config, args.resume) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | from torchvision.utils import make_grid 5 | from torchvision import transforms 6 | from utils import transforms as local_transforms 7 | from base import BaseTrainer, DataPrefetcher 8 | from utils.helpers import colorize_mask 9 | from utils.metrics import eval_metrics, AverageMeter 10 | from tqdm import tqdm 11 | 12 | class Trainer(BaseTrainer): 13 | def __init__(self, model, loss, resume, config, train_loader, val_loader=None, train_logger=None, prefetch=True): 14 | super(Trainer, self).__init__(model, loss, resume, config, train_loader, val_loader, train_logger) 15 | 16 | self.wrt_mode, self.wrt_step = 'train_', 0 17 | self.log_step = config['trainer'].get('log_per_iter', int(np.sqrt(self.train_loader.batch_size))) 18 | if config['trainer']['log_per_iter']: self.log_step = int(self.log_step / self.train_loader.batch_size) + 1 19 | 20 | self.num_classes = self.train_loader.dataset.num_classes 21 | 22 | # TRANSORMS FOR VISUALIZATION 23 | self.restore_transform = transforms.Compose([ 24 | local_transforms.DeNormalize(self.train_loader.MEAN, self.train_loader.STD), 25 | transforms.ToPILImage()]) 26 | self.viz_transform = transforms.Compose([ 27 | transforms.Resize((400, 400)), 28 | transforms.ToTensor()]) 29 | 30 | if self.device == torch.device('cpu'): prefetch = False 31 | if prefetch: 32 | self.train_loader = DataPrefetcher(train_loader, device=self.device) 33 | self.val_loader = DataPrefetcher(val_loader, device=self.device) 34 | 35 | torch.backends.cudnn.benchmark = True 36 | 37 | def _train_epoch(self, epoch): 38 | self.logger.info('\n') 39 | 40 | self.model.train() 41 | if self.config['arch']['args']['freeze_bn']: 42 | if isinstance(self.model, torch.nn.DataParallel): self.model.module.freeze_bn() 43 | else: self.model.freeze_bn() 44 | self.wrt_mode = 'train' 45 | 46 | tic = time.time() 47 | self._reset_metrics() 48 | tbar = tqdm(self.train_loader, ncols=130) 49 | for batch_idx, (data, target) in enumerate(tbar): 50 | self.data_time.update(time.time() - tic) 51 | #data, target = data.to(self.device), target.to(self.device) 52 | self.lr_scheduler.step(epoch=epoch-1) 53 | 54 | # LOSS & OPTIMIZE 55 | self.optimizer.zero_grad() 56 | output = self.model(data) 57 | if self.config['arch']['type'][:3] == 'PSP': 58 | assert output[0].size()[2:] == target.size()[1:] 59 | assert output[0].size()[1] == self.num_classes 60 | loss = self.loss(output[0], target) 61 | loss += self.loss(output[1], target) * 0.4 62 | output = output[0] 63 | else: 64 | assert output.size()[2:] == target.size()[1:] 65 | assert output.size()[1] == self.num_classes 66 | loss = self.loss(output, target) 67 | 68 | if isinstance(self.loss, torch.nn.DataParallel): 69 | loss = loss.mean() 70 | loss.backward() 71 | self.optimizer.step() 72 | self.total_loss.update(loss.item()) 73 | 74 | # measure elapsed time 75 | self.batch_time.update(time.time() - tic) 76 | tic = time.time() 77 | 78 | # LOGGING & TENSORBOARD 79 | if batch_idx % self.log_step == 0: 80 | self.wrt_step = (epoch - 1) * len(self.train_loader) + batch_idx 81 | self.writer.add_scalar(f'{self.wrt_mode}/loss', loss.item(), self.wrt_step) 82 | 83 | # FOR EVAL 84 | seg_metrics = eval_metrics(output, target, self.num_classes) 85 | self._update_seg_metrics(*seg_metrics) 86 | pixAcc, mIoU, _ = 
self._get_seg_metrics().values() 87 | 88 | # PRINT INFO 89 | tbar.set_description('TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} | B {:.2f} D {:.2f} |'.format( 90 | epoch, self.total_loss.average, 91 | pixAcc, mIoU, 92 | self.batch_time.average, self.data_time.average)) 93 | 94 | # METRICS TO TENSORBOARD 95 | seg_metrics = self._get_seg_metrics() 96 | for k, v in list(seg_metrics.items())[:-1]: 97 | self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step) 98 | for i, opt_group in enumerate(self.optimizer.param_groups): 99 | self.writer.add_scalar(f'{self.wrt_mode}/Learning_rate_{i}', opt_group['lr'], self.wrt_step) 100 | #self.writer.add_scalar(f'{self.wrt_mode}/Momentum_{k}', opt_group['momentum'], self.wrt_step) 101 | 102 | # RETURN LOSS & METRICS 103 | log = {'loss': self.total_loss.average, 104 | **seg_metrics} 105 | 106 | #if self.lr_scheduler is not None: self.lr_scheduler.step() 107 | return log 108 | 109 | def _valid_epoch(self, epoch): 110 | if self.val_loader is None: 111 | self.logger.warning('Not data loader was passed for the validation step, No validation is performed !') 112 | return {} 113 | self.logger.info('\n###### EVALUATION ######') 114 | 115 | self.model.eval() 116 | self.wrt_mode = 'val' 117 | 118 | self._reset_metrics() 119 | tbar = tqdm(self.val_loader, ncols=130) 120 | with torch.no_grad(): 121 | val_visual = [] 122 | for batch_idx, (data, target) in enumerate(tbar): 123 | #data, target = data.to(self.device), target.to(self.device) 124 | # LOSS 125 | output = self.model(data) 126 | loss = self.loss(output, target) 127 | if isinstance(self.loss, torch.nn.DataParallel): 128 | loss = loss.mean() 129 | self.total_loss.update(loss.item()) 130 | 131 | seg_metrics = eval_metrics(output, target, self.num_classes) 132 | self._update_seg_metrics(*seg_metrics) 133 | 134 | # LIST OF IMAGE TO VIZ (15 images) 135 | if len(val_visual) < 15: 136 | target_np = target.data.cpu().numpy() 137 | output_np = output.data.max(1)[1].cpu().numpy() 138 | val_visual.append([data[0].data.cpu(), target_np[0], output_np[0]]) 139 | 140 | # PRINT INFO 141 | pixAcc, mIoU, _ = self._get_seg_metrics().values() 142 | tbar.set_description('EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'.format( epoch, 143 | self.total_loss.average, 144 | pixAcc, mIoU)) 145 | 146 | # WRTING & VISUALIZING THE MASKS 147 | val_img = [] 148 | palette = self.train_loader.dataset.palette 149 | for d, t, o in val_visual: 150 | d = self.restore_transform(d) 151 | t, o = colorize_mask(t, palette), colorize_mask(o, palette) 152 | d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB') 153 | [d, t, o] = [self.viz_transform(x) for x in [d, t, o]] 154 | val_img.extend([d, t, o]) 155 | val_img = torch.stack(val_img, 0) 156 | val_img = make_grid(val_img.cpu(), nrow=3, padding=5) 157 | self.writer.add_image(f'{self.wrt_mode}/inputs_targets_predictions', val_img, self.wrt_step) 158 | 159 | # METRICS TO TENSORBOARD 160 | self.wrt_step = (epoch) * len(self.val_loader) 161 | self.writer.add_scalar(f'{self.wrt_mode}/loss', self.total_loss.average, self.wrt_step) 162 | seg_metrics = self._get_seg_metrics() 163 | for k, v in list(seg_metrics.items())[:-1]: 164 | self.writer.add_scalar(f'{self.wrt_mode}/{k}', v, self.wrt_step) 165 | 166 | log = { 167 | 'val_loss': self.total_loss.average, 168 | **seg_metrics 169 | } 170 | 171 | return log 172 | 173 | def _reset_metrics(self): 174 | self.batch_time = AverageMeter() 175 | self.data_time = AverageMeter() 176 | self.total_loss = AverageMeter() 177 | 
self.total_inter, self.total_union = 0, 0 178 | self.total_correct, self.total_label = 0, 0 179 | 180 | def _update_seg_metrics(self, correct, labeled, inter, union): 181 | self.total_correct += correct 182 | self.total_label += labeled 183 | self.total_inter += inter 184 | self.total_union += union 185 | 186 | def _get_seg_metrics(self): 187 | pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label) 188 | IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union) 189 | mIoU = IoU.mean() 190 | return { 191 | "Pixel_Accuracy": np.round(pixAcc, 3), 192 | "Mean_IoU": np.round(mIoU, 3), 193 | "Class_IoU": dict(zip(range(self.num_classes), np.round(IoU, 3))) 194 | } -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import Logger -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import math 6 | import PIL 7 | 8 | def dir_exists(path): 9 | if not os.path.exists(path): 10 | os.makedirs(path) 11 | 12 | def initialize_weights(*models): 13 | for model in models: 14 | for m in model.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') 17 | elif isinstance(m, nn.BatchNorm2d): 18 | m.weight.data.fill_(1.) 19 | m.bias.data.fill_(1e-4) 20 | elif isinstance(m, nn.Linear): 21 | m.weight.data.normal_(0.0, 0.0001) 22 | m.bias.data.zero_() 23 | 24 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 25 | factor = (kernel_size + 1) // 2 26 | if kernel_size % 2 == 1: 27 | center = factor - 1 28 | else: 29 | center = factor - 0.5 30 | og = np.ogrid[:kernel_size, :kernel_size] 31 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 32 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 33 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt 34 | return torch.from_numpy(weight).float() 35 | 36 | def colorize_mask(mask, palette): 37 | zero_pad = 256 * 3 - len(palette) 38 | for i in range(zero_pad): 39 | palette.append(0) 40 | new_mask = PIL.Image.fromarray(mask.astype(np.uint8)).convert('P') 41 | new_mask.putpalette(palette) 42 | return new_mask 43 | 44 | def set_trainable_attr(m,b): 45 | m.trainable = b 46 | for p in m.parameters(): p.requires_grad = b 47 | 48 | def apply_leaf(m, f): 49 | c = m if isinstance(m, (list, tuple)) else list(m.children()) 50 | if isinstance(m, nn.Module): 51 | f(m) 52 | if len(c)>0: 53 | for l in c: 54 | apply_leaf(l,f) 55 | 56 | def set_trainable(l, b): 57 | apply_leaf(l, lambda m: set_trainable_attr(m,b)) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | 4 | logging.basicConfig(level=logging.INFO, format='') 5 | 6 | class Logger: 7 | def __init__(self): 8 | self.entries = {} 9 | 10 | def add_entry(self, entry): 11 | self.entries[len(self.entries) + 1] = entry 12 | 13 | def __str__(self): 14 | return json.dumps(self.entries, sort_keys=True, indent=4) 15 | -------------------------------------------------------------------------------- /utils/losses.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from sklearn.utils import class_weight 6 | from utils.lovasz_losses import lovasz_softmax 7 | 8 | def make_one_hot(labels, classes): 9 | one_hot = torch.FloatTensor(labels.size()[0], classes, labels.size()[2], labels.size()[3]).zero_().to(labels.device) 10 | target = one_hot.scatter_(1, labels.data, 1) 11 | return target 12 | 13 | def get_weights(target): 14 | t_np = target.view(-1).data.cpu().numpy() 15 | 16 | classes, counts = np.unique(t_np, return_counts=True) 17 | cls_w = np.median(counts) / counts 18 | #cls_w = class_weight.compute_class_weight('balanced', classes, t_np) 19 | 20 | weights = np.ones(7) 21 | weights[classes] = cls_w 22 | return torch.from_numpy(weights).float().cuda() 23 | 24 | class CrossEntropyLoss2d(nn.Module): 25 | def __init__(self, weight=None, ignore_index=255, reduction='mean'): 26 | super(CrossEntropyLoss2d, self).__init__() 27 | self.CE = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) 28 | 29 | def forward(self, output, target): 30 | loss = self.CE(output, target) 31 | return loss 32 | 33 | class DiceLoss(nn.Module): 34 | def __init__(self, smooth=1., ignore_index=255): 35 | super(DiceLoss, self).__init__() 36 | self.ignore_index = ignore_index 37 | self.smooth = smooth 38 | 39 | def forward(self, output, target): 40 | if self.ignore_index not in range(target.min(), target.max()): 41 | if (target == self.ignore_index).sum() > 0: 42 | target[target == self.ignore_index] = target.min() 43 | target = make_one_hot(target.unsqueeze(dim=1), classes=output.size()[1]) 44 | output = F.softmax(output, dim=1) 45 | output_flat = output.contiguous().view(-1) 46 | target_flat = target.contiguous().view(-1) 47 | intersection = (output_flat * target_flat).sum() 48 | loss = 1 - ((2. 
* intersection + self.smooth) / 49 | (output_flat.sum() + target_flat.sum() + self.smooth)) 50 | return loss 51 | 52 | class FocalLoss(nn.Module): 53 | def __init__(self, gamma=2, alpha=None, ignore_index=255, size_average=True): 54 | super(FocalLoss, self).__init__() 55 | self.gamma = gamma 56 | self.size_average = size_average 57 | self.CE_loss = nn.CrossEntropyLoss(reduce=False, ignore_index=ignore_index, weight=alpha) 58 | 59 | def forward(self, output, target): 60 | logpt = self.CE_loss(output, target) 61 | pt = torch.exp(-logpt) 62 | loss = ((1-pt)**self.gamma) * logpt 63 | if self.size_average: 64 | return loss.mean() 65 | return loss.sum() 66 | 67 | class CE_DiceLoss(nn.Module): 68 | def __init__(self, smooth=1, reduction='mean', ignore_index=255, weight=None): 69 | super(CE_DiceLoss, self).__init__() 70 | self.smooth = smooth 71 | self.dice = DiceLoss() 72 | self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction, ignore_index=ignore_index) 73 | 74 | def forward(self, output, target): 75 | CE_loss = self.cross_entropy(output, target) 76 | dice_loss = self.dice(output, target) 77 | return CE_loss + dice_loss 78 | 79 | class LovaszSoftmax(nn.Module): 80 | def __init__(self, classes='present', per_image=False, ignore_index=255): 81 | super(LovaszSoftmax, self).__init__() 82 | self.smooth = classes 83 | self.per_image = per_image 84 | self.ignore_index = ignore_index 85 | 86 | def forward(self, output, target): 87 | logits = F.softmax(output, dim=1) 88 | loss = lovasz_softmax(logits, target, ignore=self.ignore_index) 89 | return loss 90 | -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py 5 | """ 6 | 7 | from __future__ import print_function, division 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | import numpy as np 13 | try: 14 | from itertools import ifilterfalse 15 | except ImportError: # py3k 16 | from itertools import filterfalse as ifilterfalse 17 | 18 | 19 | def lovasz_grad(gt_sorted): 20 | """ 21 | Computes gradient of the Lovasz extension w.r.t sorted errors 22 | See Alg. 1 in paper 23 | """ 24 | p = len(gt_sorted) 25 | gts = gt_sorted.sum() 26 | intersection = gts - gt_sorted.float().cumsum(0) 27 | union = gts + (1 - gt_sorted).float().cumsum(0) 28 | jaccard = 1. 
- intersection / union 29 | if p > 1: # cover 1-pixel case 30 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 31 | return jaccard 32 | 33 | 34 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 35 | """ 36 | IoU for foreground class 37 | binary: 1 foreground, 0 background 38 | """ 39 | if not per_image: 40 | preds, labels = (preds,), (labels,) 41 | ious = [] 42 | for pred, label in zip(preds, labels): 43 | intersection = ((label == 1) & (pred == 1)).sum() 44 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 45 | if not union: 46 | iou = EMPTY 47 | else: 48 | iou = float(intersection) / float(union) 49 | ious.append(iou) 50 | iou = mean(ious) # mean accross images if per_image 51 | return 100 * iou 52 | 53 | 54 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 55 | """ 56 | Array of IoU for each (non ignored) class 57 | """ 58 | if not per_image: 59 | preds, labels = (preds,), (labels,) 60 | ious = [] 61 | for pred, label in zip(preds, labels): 62 | iou = [] 63 | for i in range(C): 64 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 65 | intersection = ((label == i) & (pred == i)).sum() 66 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 67 | if not union: 68 | iou.append(EMPTY) 69 | else: 70 | iou.append(float(intersection) / float(union)) 71 | ious.append(iou) 72 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 73 | return 100 * np.array(ious) 74 | 75 | 76 | # --------------------------- BINARY LOSSES --------------------------- 77 | 78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 79 | """ 80 | Binary Lovasz hinge loss 81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 83 | per_image: compute the loss per image instead of per batch 84 | ignore: void class id 85 | """ 86 | if per_image: 87 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 88 | for log, lab in zip(logits, labels)) 89 | else: 90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 91 | return loss 92 | 93 | 94 | def lovasz_hinge_flat(logits, labels): 95 | """ 96 | Binary Lovasz hinge loss 97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 98 | labels: [P] Tensor, binary ground truth labels (0 or 1) 99 | ignore: label to ignore 100 | """ 101 | if len(labels) == 0: 102 | # only void pixels, the gradients should be 0 103 | return logits.sum() * 0. 104 | signs = 2. * labels.float() - 1. 105 | errors = (1. 
- logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels = flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, classes='present'): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 177 | """ 178 | if probas.numel() == 0: 179 | # only void pixels, the gradients should be 0 180 | return probas * 0. 
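    # Per-class Lovasz extension: for every selected class c, the absolute errors
    # |fg - p_c| are sorted in decreasing order and dotted with lovasz_grad(fg_sorted)
    # (Alg. 1 in the paper), the gradient of the Jaccard surrogate; the per-class
    # losses are then averaged by mean() at the end of the loop below.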
181 | C = probas.size(1) 182 | losses = [] 183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 184 | for c in class_to_sum: 185 | fg = (labels == c).float() # foreground for class c 186 | if (classes is 'present' and fg.sum() == 0): 187 | continue 188 | if C == 1: 189 | if len(classes) > 1: 190 | raise ValueError('Sigmoid output possible only with 1 class') 191 | class_pred = probas[:, 0] 192 | else: 193 | class_pred = probas[:, c] 194 | errors = (Variable(fg) - class_pred).abs() 195 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 196 | perm = perm.data 197 | fg_sorted = fg[perm] 198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 199 | return mean(losses) 200 | 201 | 202 | def flatten_probas(probas, labels, ignore=None): 203 | """ 204 | Flattens predictions in the batch 205 | """ 206 | if probas.dim() == 3: 207 | # assumes output of a sigmoid layer 208 | B, H, W = probas.size() 209 | probas = probas.view(B, 1, H, W) 210 | B, C, H, W = probas.size() 211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 212 | labels = labels.view(-1) 213 | if ignore is None: 214 | return probas, labels 215 | valid = (labels != ignore) 216 | vprobas = probas[valid.nonzero().squeeze()] 217 | vlabels = labels[valid] 218 | return vprobas, vlabels 219 | 220 | def xloss(logits, labels, ignore=None): 221 | """ 222 | Cross entropy loss 223 | """ 224 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 225 | 226 | 227 | # --------------------------- HELPER FUNCTIONS --------------------------- 228 | def isnan(x): 229 | return x != x 230 | 231 | 232 | def mean(l, ignore_nan=False, empty=0): 233 | """ 234 | nanmean compatible with generators. 235 | """ 236 | l = iter(l) 237 | if ignore_nan: 238 | l = ifilterfalse(isnan, l) 239 | try: 240 | n = 1 241 | acc = next(l) 242 | except StopIteration: 243 | if empty == 'raise': 244 | raise ValueError('Empty mean') 245 | return empty 246 | for n, v in enumerate(l, 2): 247 | acc += v 248 | if n == 1: 249 | return acc 250 | return acc / n 251 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | class Poly(_LRScheduler): 5 | def __init__(self, optimizer, num_epochs, iters_per_epoch=0, warmup_epochs=0, last_epoch=-1): 6 | self.iters_per_epoch = iters_per_epoch 7 | self.cur_iter = 0 8 | self.N = num_epochs * iters_per_epoch 9 | self.warmup_iters = warmup_epochs * iters_per_epoch 10 | super(Poly, self).__init__(optimizer, last_epoch) 11 | 12 | def get_lr(self): 13 | T = self.last_epoch * self.iters_per_epoch + self.cur_iter 14 | factor = pow((1 - 1.0 * T / self.N), 0.9) 15 | if self.warmup_iters > 0 and T < self.warmup_iters: 16 | factor = 1.0 * T / self.warmup_iters 17 | 18 | self.cur_iter %= self.iters_per_epoch 19 | self.cur_iter += 1 20 | return [base_lr * factor for base_lr in self.base_lrs] 21 | 22 | 23 | class OneCycle(_LRScheduler): 24 | def __init__(self, optimizer, num_epochs, iters_per_epoch=0, last_epoch=-1, 25 | momentums = (0.85, 0.95), div_factor = 25, phase1=0.3): 26 | self.iters_per_epoch = iters_per_epoch 27 | self.cur_iter = 0 28 | self.N = num_epochs * iters_per_epoch 29 | self.phase1_iters = int(self.N * phase1) 30 | self.phase2_iters = (self.N - self.phase1_iters) 31 | self.momentums = momentums 32 | self.mom_diff = momentums[1] - 
momentums[0] 33 | 34 | self.low_lrs = [opt_grp['lr']/div_factor for opt_grp in optimizer.param_groups] 35 | self.final_lrs = [opt_grp['lr']/(div_factor * 1e4) for opt_grp in optimizer.param_groups] 36 | super(OneCycle, self).__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self): 39 | T = self.last_epoch * self.iters_per_epoch + self.cur_iter 40 | self.cur_iter %= self.iters_per_epoch 41 | self.cur_iter += 1 42 | 43 | # Going from base_lr / 25 -> base_lr 44 | if T <= self.phase1_iters: 45 | cos_anneling = (1 + math.cos(math.pi * T / self.phase1_iters)) / 2 46 | for i in range(len(self.optimizer.param_groups)): 47 | self.optimizer.param_groups[i]['momentum'] = self.momentums[0] + self.mom_diff * cos_anneling 48 | 49 | return [base_lr - (base_lr - low_lr) * cos_anneling 50 | for base_lr, low_lr in zip(self.base_lrs, self.low_lrs)] 51 | 52 | # Going from base_lr -> base_lr / (25e4) 53 | T -= self.phase1_iters 54 | cos_anneling = (1 + math.cos(math.pi * T / self.phase2_iters)) / 2 55 | 56 | for i in range(len(self.optimizer.param_groups)): 57 | self.optimizer.param_groups[i]['momentum'] = self.momentums[1] - self.mom_diff * cos_anneling 58 | return [final_lr + (base_lr - final_lr) * cos_anneling 59 | for base_lr, final_lr in zip(self.base_lrs, self.final_lrs)] 60 | 61 | 62 | if __name__ == "__main__": 63 | import torchvision 64 | import torch 65 | import matplotlib.pylab as plt 66 | 67 | resnet = torchvision.models.resnet34() 68 | params = { 69 | "lr": 0.01, 70 | "weight_decay": 0.001, 71 | "momentum": 0.9 72 | } 73 | optimizer = torch.optim.SGD(params=resnet.parameters(), **params) 74 | 75 | epochs = 2 76 | iters_per_epoch = 100 77 | lrs = [] 78 | mementums = [] 79 | lr_scheduler = OneCycle(optimizer, epochs, iters_per_epoch) 80 | #lr_scheduler = Poly(optimizer, epochs, iters_per_epoch) 81 | 82 | for epoch in range(epochs): 83 | for i in range(iters_per_epoch): 84 | lr_scheduler.step(epoch=epoch) 85 | lr_scheduler(optimizer, i, epoch) 86 | lrs.append(optimizer.param_groups[0]['lr']) 87 | mementums.append(optimizer.param_groups[0]['momentum']) 88 | 89 | plt.ylabel("learning rate") 90 | plt.xlabel("iteration") 91 | plt.plot(lrs) 92 | plt.show() 93 | 94 | plt.ylabel("momentum") 95 | plt.xlabel("iteration") 96 | plt.plot(mementums) 97 | plt.show() 98 | 99 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value""" 8 | def __init__(self): 9 | self.initialized = False 10 | self.val = None 11 | self.avg = None 12 | self.sum = None 13 | self.count = None 14 | 15 | def initialize(self, val, weight): 16 | self.val = val 17 | self.avg = val 18 | self.sum = np.multiply(val, weight) 19 | self.count = weight 20 | self.initialized = True 21 | 22 | def update(self, val, weight=1): 23 | if not self.initialized: 24 | self.initialize(val, weight) 25 | else: 26 | self.add(val, weight) 27 | 28 | def add(self, val, weight): 29 | self.val = val 30 | self.sum = np.add(self.sum, np.multiply(val, weight)) 31 | self.count = self.count + weight 32 | self.avg = self.sum / self.count 33 | 34 | @property 35 | def value(self): 36 | return self.val 37 | 38 | @property 39 | def average(self): 40 | return np.round(self.avg, 5) 41 | 42 | def batch_pix_accuracy(predict, target, labeled): 43 | pixel_labeled = 
labeled.sum() 44 | pixel_correct = ((predict == target) * labeled).sum() 45 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 46 | return pixel_correct.cpu().numpy(), pixel_labeled.cpu().numpy() 47 | 48 | def batch_intersection_union(predict, target, num_class, labeled): 49 | predict = predict * labeled.long() 50 | intersection = predict * (predict == target).long() 51 | 52 | area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1) 53 | area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1) 54 | area_lab = torch.histc(target.float(), bins=num_class, max=num_class, min=1) 55 | area_union = area_pred + area_lab - area_inter 56 | assert (area_inter <= area_union).all(), "Intersection area should be smaller than Union area" 57 | return area_inter.cpu().numpy(), area_union.cpu().numpy() 58 | 59 | def eval_metrics(output, target, num_class): 60 | _, predict = torch.max(output.data, 1) 61 | predict = predict + 1 62 | target = target + 1 63 | 64 | labeled = (target > 0) * (target <= num_class) 65 | correct, num_labeled = batch_pix_accuracy(predict, target, labeled) 66 | inter, union = batch_intersection_union(predict, target, num_class, labeled) 67 | return [np.round(correct, 5), np.round(num_labeled, 5), np.round(inter, 5), np.round(union, 5)] 68 | -------------------------------------------------------------------------------- /utils/palette.py: -------------------------------------------------------------------------------- 1 | 2 | def get_voc_palette(num_classes): 3 | n = num_classes 4 | palette = [0]*(n*3) 5 | for j in range(0,n): 6 | lab = j 7 | palette[j*3+0] = 0 8 | palette[j*3+1] = 0 9 | palette[j*3+2] = 0 10 | i = 0 11 | while (lab > 0): 12 | palette[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 13 | palette[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 14 | palette[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 15 | i = i + 1 16 | lab >>= 3 17 | return palette 18 | 19 | ADE20K_palette = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200, 20 | 3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224, 21 | 5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143, 22 | 255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255, 23 | 6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9, 24 | 92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41, 25 | 10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8, 26 | 0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0, 27 | 163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224 28 | ,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200, 29 | 200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255, 30 | 163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0, 31 | 255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245, 32 | 255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255, 33 | 255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0, 34 | 122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163, 35 | 255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184, 36 | 255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163, 37 | 0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255, 38 | 0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255, 39 | 20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0, 40 | 255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255, 
41 | 255,214,0,25,194,194,102,255,0,92,0,255] 42 | 43 | CityScpates_palette = [128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153, 44 | 250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142, 45 | 0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192, 46 | 128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192, 47 | 128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192, 48 | 128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192, 49 | 192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192, 50 | 128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192, 51 | 192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128, 52 | 160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224, 53 | 128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160, 54 | 192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224, 55 | 192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160, 56 | 128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224, 57 | 128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192, 58 | 160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192, 59 | 192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128, 60 | 128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128, 61 | 192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128, 62 | 224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192, 63 | 224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128, 64 | 160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192, 65 | 192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192, 66 | 128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224, 67 | 192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160, 68 | 128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160, 69 | 128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224, 70 | 128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224, 71 | 128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32, 72 | 160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192, 73 | 96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96, 74 | 192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224, 75 | 96,192,96,224,192,0,0,0] 76 | 77 | 78 | COCO_palette = [31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 79 | 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 80 | 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 81 | 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 82 | 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 83 | 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 84 | 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 85 | 
227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 86 | 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 87 | 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 88 | 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 89 | 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127 90 | , 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 91 | 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 92 | 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 93 | 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 94 | 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 95 | 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 96 | 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 97 | 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 98 | 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 99 | 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 100 | 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 101 | 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14, 44, 160, 44, 214, 39, 102 | 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 127, 188, 189, 34, 23, 190, 207, 31, 119, 103 | 180, 255, 127, 14, 44, 160, 44, 214, 39, 40, 148, 103, 189, 140, 86, 75, 227, 119, 194, 127, 127, 104 | 127, 188, 189, 34, 23, 190, 207, 31, 119, 180, 255, 127, 14] 105 | 106 | DeepScene_palette = [255, 0, 0, 170, 170, 170, 0, 255, 0, 102, 102, 51, 0, 60, 0, 0, 120, 255, 0, 0, 0] 107 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
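# Illustrative usage sketch (an assumption, not part of the original file): the
# BatchNorm2dReimpl class defined below can be sanity-checked against the stock
# nn.BatchNorm2d in training mode, for example:
#
#   import torch, torch.nn as nn
#   bn_ref = nn.BatchNorm2d(16)                   # reference implementation
#   bn_re = BatchNorm2dReimpl(16)                 # re-implementation defined below
#   bn_re.weight.data.copy_(bn_ref.weight.data)   # align the affine parameters
#   bn_re.bias.data.copy_(bn_ref.bias.data)
#   x = torch.randn(8, 16, 32, 32)
#   assert torch.allclose(bn_ref(x), bn_re(x), atol=1e-4)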
10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 
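            # store the result and wake the single consumer blocked in get() (one-to-one pipe)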
29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 
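        # push every slave's slice of the result back through its registered FutureResult pipe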
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /utils/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /utils/torchsummary.py: -------------------------------------------------------------------------------- 1 | """ 2 | A modied version of the code by Tae Hwan Jung 3 | https://github.com/graykode/modelsummary 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | 11 | def summary(model, input_shape, batch_size=-1, intputshow=True): 12 | 13 | def register_hook(module): 14 | def hook(module, input, output=None): 15 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 16 | module_idx = len(summary) 17 | 18 | m_key = "%s-%i" % (class_name, module_idx + 1) 19 | summary[m_key] = OrderedDict() 20 | summary[m_key]["input_shape"] = list(input[0].size()) 21 | summary[m_key]["input_shape"][0] = batch_size 22 | 23 | params = 0 24 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 25 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 26 | summary[m_key]["trainable"] = module.weight.requires_grad 27 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 28 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 29 | summary[m_key]["nb_params"] = params 30 | 31 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) 32 | and not (module == model)) and 'torch' in str(module.__class__): 33 | if intputshow is True: 34 | hooks.append(module.register_forward_pre_hook(hook)) 35 | else: 36 | hooks.append(module.register_forward_hook(hook)) 37 | 38 | # create properties 39 | summary = OrderedDict() 40 | hooks = [] 41 | 42 | # register hook 43 | model.apply(register_hook) 44 | model(torch.zeros(input_shape)) 45 | 46 | # remove these hooks 47 | for h in hooks: 48 | h.remove() 49 | 50 | model_info = '' 51 | 52 | 
model_info += "-----------------------------------------------------------------------\n" 53 | line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #") 54 | model_info += line_new + '\n' 55 | model_info += "=======================================================================\n" 56 | 57 | total_params = 0 58 | total_output = 0 59 | trainable_params = 0 60 | for layer in summary: 61 | line_new = "{:>25} {:>25} {:>15}".format( 62 | layer, 63 | str(summary[layer]["input_shape"]), 64 | "{0:,}".format(summary[layer]["nb_params"]), 65 | ) 66 | 67 | total_params += summary[layer]["nb_params"] 68 | if intputshow is True: 69 | total_output += np.prod(summary[layer]["input_shape"]) 70 | else: 71 | total_output += np.prod(summary[layer]["output_shape"]) 72 | if "trainable" in summary[layer]: 73 | if summary[layer]["trainable"] == True: 74 | trainable_params += summary[layer]["nb_params"] 75 | 76 | model_info += line_new + '\n' 77 | 78 | model_info += "=======================================================================\n" 79 | model_info += "Total params: {0:,}\n".format(total_params) 80 | model_info += "Trainable params: {0:,}\n".format(trainable_params) 81 | model_info += "Non-trainable params: {0:,}\n".format(total_params - trainable_params) 82 | model_info += "-----------------------------------------------------------------------\n" 83 | 84 | return model_info -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from PIL import Image, ImageOps, ImageFilter 4 | from skimage.filters import gaussian 5 | import torch 6 | import math 7 | import numbers 8 | import random 9 | 10 | class RandomVerticalFlip(object): 11 | def __call__(self, img): 12 | if random.random() < 0.5: 13 | return img.transpose(Image.FLIP_TOP_BOTTOM) 14 | return img 15 | 16 | class DeNormalize(object): 17 | def __init__(self, mean, std): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | for t, m, s in zip(tensor, self.mean, self.std): 23 | t.mul_(s).add_(m) 24 | return tensor 25 | 26 | class MaskToTensor(object): 27 | def __call__(self, img): 28 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 29 | 30 | class FreeScale(object): 31 | def __init__(self, size, interpolation=Image.BILINEAR): 32 | self.size = tuple(reversed(size)) # size: (h, w) 33 | self.interpolation = interpolation 34 | 35 | def __call__(self, img): 36 | return img.resize(self.size, self.interpolation) 37 | 38 | class FlipChannels(object): 39 | def __call__(self, img): 40 | img = np.array(img)[:, :, ::-1] 41 | return Image.fromarray(img.astype(np.uint8)) 42 | 43 | class RandomGaussianBlur(object): 44 | def __call__(self, img): 45 | sigma = 0.15 + random.random() * 1.15 46 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 47 | blurred_img *= 255 48 | return Image.fromarray(blurred_img.astype(np.uint8)) 49 | 50 | class Compose(object): 51 | def __init__(self, transforms): 52 | self.transforms = transforms 53 | 54 | def __call__(self, img, mask): 55 | assert img.size == mask.size 56 | for t in self.transforms: 57 | img, mask = t(img, mask) 58 | return img, mask 59 | 60 | class RandomCrop(object): 61 | def __init__(self, size, padding=0): 62 | if isinstance(size, numbers.Number): 63 | self.size = (int(size), int(size)) 64 | else: 65 | self.size = size 66 | self.padding = padding 67 | 68 | def 
__call__(self, img, mask): 69 | if self.padding > 0: 70 | img = ImageOps.expand(img, border=self.padding, fill=0) 71 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 72 | 73 | assert img.size == mask.size 74 | w, h = img.size 75 | th, tw = self.size 76 | if w == tw and h == th: 77 | return img, mask 78 | if w < tw or h < th: 79 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 80 | 81 | x1 = random.randint(0, w - tw) 82 | y1 = random.randint(0, h - th) 83 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 84 | 85 | 86 | class CenterCrop(object): 87 | def __init__(self, size): 88 | if isinstance(size, numbers.Number): 89 | self.size = (int(size), int(size)) 90 | else: 91 | self.size = size 92 | 93 | def __call__(self, img, mask): 94 | assert img.size == mask.size 95 | w, h = img.size 96 | th, tw = self.size 97 | x1 = int(round((w - tw) / 2.)) 98 | y1 = int(round((h - th) / 2.)) 99 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 100 | 101 | 102 | class RandomHorizontallyFlip(object): 103 | def __call__(self, img, mask): 104 | if random.random() < 0.5: 105 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 106 | return img, mask 107 | 108 | class Scale(object): 109 | def __init__(self, size): 110 | self.size = size 111 | 112 | def __call__(self, img, mask): 113 | assert img.size == mask.size 114 | w, h = img.size 115 | if (w >= h and w == self.size) or (h >= w and h == self.size): 116 | return img, mask 117 | if w > h: 118 | ow = self.size 119 | oh = int(self.size * h / w) 120 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 121 | else: 122 | oh = self.size 123 | ow = int(self.size * w / h) 124 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 125 | 126 | class RandomSizedCrop(object): 127 | def __init__(self, size): 128 | self.size = size 129 | 130 | def __call__(self, img, mask): 131 | assert img.size == mask.size 132 | for attempt in range(10): 133 | area = img.size[0] * img.size[1] 134 | target_area = random.uniform(0.45, 1.0) * area 135 | aspect_ratio = random.uniform(0.5, 2) 136 | 137 | w = int(round(math.sqrt(target_area * aspect_ratio))) 138 | h = int(round(math.sqrt(target_area / aspect_ratio))) 139 | 140 | if random.random() < 0.5: 141 | w, h = h, w 142 | 143 | if w <= img.size[0] and h <= img.size[1]: 144 | x1 = random.randint(0, img.size[0] - w) 145 | y1 = random.randint(0, img.size[1] - h) 146 | 147 | img = img.crop((x1, y1, x1 + w, y1 + h)) 148 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 149 | assert (img.size == (w, h)) 150 | 151 | return img.resize((self.size, self.size), Image.BILINEAR), mask.resize((self.size, self.size), 152 | Image.NEAREST) 153 | 154 | # Fallback 155 | scale = Scale(self.size) 156 | crop = CenterCrop(self.size) 157 | return crop(*scale(img, mask)) 158 | 159 | class RandomRotate(object): 160 | def __init__(self, degree): 161 | self.degree = degree 162 | 163 | def __call__(self, img, mask): 164 | rotate_degree = random.random() * 2 * self.degree - self.degree 165 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 166 | 167 | class RandomSized(object): 168 | def __init__(self, size): 169 | self.size = size 170 | self.scale = Scale(self.size) 171 | self.crop = RandomCrop(self.size) 172 | 173 | def __call__(self, img, mask): 174 | assert img.size == mask.size 175 | 176 | w = int(random.uniform(0.5, 2) * 
img.size[0]) 177 | h = int(random.uniform(0.5, 2) * img.size[1]) 178 | 179 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 180 | 181 | return self.crop(*self.scale(img, mask)) 182 | 183 | class SlidingCropOld(object): 184 | def __init__(self, crop_size, stride_rate, ignore_label): 185 | self.crop_size = crop_size 186 | self.stride_rate = stride_rate 187 | self.ignore_label = ignore_label 188 | 189 | def _pad(self, img, mask): 190 | h, w = img.shape[: 2] 191 | pad_h = max(self.crop_size - h, 0) 192 | pad_w = max(self.crop_size - w, 0) 193 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 194 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 195 | return img, mask 196 | 197 | def __call__(self, img, mask): 198 | assert img.size == mask.size 199 | 200 | w, h = img.size 201 | long_size = max(h, w) 202 | 203 | img = np.array(img) 204 | mask = np.array(mask) 205 | 206 | if long_size > self.crop_size: 207 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 208 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 209 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 210 | img_sublist, mask_sublist = [], [] 211 | for yy in range(h_step_num): 212 | for xx in range(w_step_num): 213 | sy, sx = yy * stride, xx * stride 214 | ey, ex = sy + self.crop_size, sx + self.crop_size 215 | img_sub = img[sy: ey, sx: ex, :] 216 | mask_sub = mask[sy: ey, sx: ex] 217 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 218 | img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 219 | mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 220 | return img_sublist, mask_sublist 221 | else: 222 | img, mask = self._pad(img, mask) 223 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 224 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 225 | return img, mask 226 | 227 | 228 | class SlidingCrop(object): 229 | def __init__(self, crop_size, stride_rate, ignore_label): 230 | self.crop_size = crop_size 231 | self.stride_rate = stride_rate 232 | self.ignore_label = ignore_label 233 | 234 | def _pad(self, img, mask): 235 | h, w = img.shape[: 2] 236 | pad_h = max(self.crop_size - h, 0) 237 | pad_w = max(self.crop_size - w, 0) 238 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 239 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 240 | return img, mask, h, w 241 | 242 | def __call__(self, img, mask): 243 | assert img.size == mask.size 244 | 245 | w, h = img.size 246 | long_size = max(h, w) 247 | 248 | img = np.array(img) 249 | mask = np.array(mask) 250 | 251 | if long_size > self.crop_size: 252 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 253 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 254 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 255 | img_slices, mask_slices, slices_info = [], [], [] 256 | for yy in range(h_step_num): 257 | for xx in range(w_step_num): 258 | sy, sx = yy * stride, xx * stride 259 | ey, ex = sy + self.crop_size, sx + self.crop_size 260 | img_sub = img[sy: ey, sx: ex, :] 261 | mask_sub = mask[sy: ey, sx: ex] 262 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 263 | img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 264 | mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 265 | slices_info.append([sy, 
ey, sx, ex, sub_h, sub_w]) 266 | return img_slices, mask_slices, slices_info 267 | else: 268 | img, mask, sub_h, sub_w = self._pad(img, mask) 269 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 270 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 271 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 272 | --------------------------------------------------------------------------------
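The joint transforms in utils/transforms.py all take an (image, mask) pair and return the pair transformed in lockstep, so they can be chained with the Compose wrapper defined in the same file. A minimal sketch of a training-time augmentation pipeline built from these classes follows; the crop size, rotation range and file names are illustrative assumptions, not values taken from the repository's config.

from PIL import Image
from utils.transforms import Compose, RandomHorizontallyFlip, RandomRotate, RandomCrop, MaskToTensor

# Illustrative pipeline: every transform receives and returns (img, mask), so the
# augmentation stays aligned between the image and its label map.
joint_transform = Compose([
    RandomHorizontallyFlip(),
    RandomRotate(degree=10),      # rotate image and mask by the same random angle
    RandomCrop(size=321),         # crop (or resize up) both to 321x321
])

img = Image.open('example_image.jpg').convert('RGB')    # placeholder file names
mask = Image.open('example_mask.png')                   # palettised label map

img, mask = joint_transform(img, mask)
mask_tensor = MaskToTensor()(mask)                      # HxW LongTensor of class ids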