├── .gitignore ├── README.md ├── __init__.py ├── cityscapes_demo.ipynb ├── datasets ├── __init__.py └── cityscapes.py ├── eval ├── cityscapes.txt └── pretraining.txt ├── eval_cityscapes.py ├── metric ├── __init__.py ├── confusionmatrix.py ├── iou.py └── metric.py ├── model ├── __init__.py ├── backbone.py ├── decoder.py └── dfanet.py ├── plugin.py ├── pretrain_backbone.py ├── train_cityscapes.py └── utils ├── __init__.py ├── joint_transforms.py ├── misc.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/ 107 | *.idea/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFANet_PyTorch 2 | Unofficial implementation of Deep Feature Aggregation Networks for real-time semantic segmentation. 3 | 4 | Hanchao Li, Pengfei Xiong, Haoqiang Fan and Jian Sun. DFANet Deep Feature Aggregation for Real-Time Semantic Segmentation. In CoRR (2019). 
5 | https://arxiv.org/abs/1904.02216 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandylin/DFANet_PyTorch/724ea05ba8b44d2dedb0a3389b4dfaaf57f5a2b6/__init__.py -------------------------------------------------------------------------------- /cityscapes_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "from datasets.cityscapes import Cityscapes\n", 11 | "from plugin import DFANetPlugin\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "\n", 14 | "net_plugin = DFANetPlugin(2048, 1024, False) # or 1024x1024?\n", 15 | "\n", 16 | "path_dataset = \"\" # fill in..\n", 17 | "test_dataset = Cityscapes(path_dataset, split='test', mode='fine', target_type='semantic')\n", 18 | "\n", 19 | "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=4)\n", 20 | "\n", 21 | "data_iterator = iter(test_loader)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "image, label = data_iterator.next()\n", 31 | "cv2.imshow('ground truth', label)\n", 32 | "\n", 33 | "output = plugin.process(image)\n", 34 | "cv2.imshow('DFANet output', output)" 35 | ] 36 | } 37 | ], 38 | "metadata": { 39 | "kernelspec": { 40 | "display_name": "Python 3", 41 | "language": "python", 42 | "name": "python3" 43 | }, 44 | "language_info": { 45 | "codemirror_mode": { 46 | "name": "ipython", 47 | "version": 3 48 | }, 49 | "file_extension": ".py", 50 | "mimetype": "text/x-python", 51 | "name": "python", 52 | "nbconvert_exporter": "python", 53 | "pygments_lexer": "ipython3", 54 | "version": "3.6.4" 55 | } 56 | }, 57 | "nbformat": 4, 58 | "nbformat_minor": 2 59 | } 60 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandylin/DFANet_PyTorch/724ea05ba8b44d2dedb0a3389b4dfaaf57f5a2b6/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import namedtuple 4 | 5 | import torch.utils.data as data 6 | from PIL import Image 7 | 8 | class Cityscapes(data.Dataset): 9 | """`Cityscapes `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory of dataset where directory ``leftImg8bit`` 13 | and ``gtFine`` or ``gtCoarse`` are located. 14 | split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine" 15 | otherwise ``train``, ``train_extra`` or ``val`` 16 | mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse`` 17 | target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon`` 18 | or ``color``. Can also be a list to output a tuple with all specified target types. 19 | transform (callable, optional): A function/transform that takes in a PIL image 20 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 21 | target_transform (callable, optional): A function/transform that takes in the 22 | target and transforms it. 23 | 24 | Examples: 25 | 26 | Get semantic segmentation target 27 | 28 | .. code-block:: python 29 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 30 | target_type='semantic') 31 | 32 | img, smnt = dataset[0] 33 | 34 | Get multiple targets 35 | 36 | .. code-block:: python 37 | dataset = Cityscapes('./data/cityscapes', split='train', mode='fine', 38 | target_type=['instance', 'color', 'polygon']) 39 | 40 | img, (inst, col, poly) = dataset[0] 41 | 42 | Validate on the "coarse" set 43 | 44 | .. code-block:: python 45 | dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse', 46 | target_type='semantic') 47 | 48 | img, smnt = dataset[0] 49 | """ 50 | 51 | # Based on https://github.com/mcordts/cityscapesScripts 52 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', 53 | 'has_instances', 'ignore_in_eval', 'color']) 54 | 55 | classes = [ 56 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 57 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 58 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 59 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 60 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 61 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 62 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 63 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 64 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 65 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 66 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 67 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 68 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 69 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), 70 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 71 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 72 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 73 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 74 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 75 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 76 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 77 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 78 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 79 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 80 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 81 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 82 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 83 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 84 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 85 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 
90)), 86 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 87 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 88 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 89 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 90 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 91 | ] 92 | 93 | def __init__(self, root, split='train', mode='fine', target_type='instance', 94 | transform=None, target_transform=None): 95 | self.root = os.path.expanduser(root) 96 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 97 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 98 | self.targets_dir = os.path.join(self.root, self.mode, split) 99 | self.transform = transform 100 | self.target_transform = target_transform 101 | self.target_type = target_type 102 | self.split = split 103 | self.images = [] 104 | self.targets = [] 105 | 106 | if mode not in ['fine', 'coarse']: 107 | raise ValueError('Invalid mode! Please use mode="fine" or mode="coarse"') 108 | 109 | if mode == 'fine' and split not in ['train', 'test', 'val']: 110 | raise ValueError('Invalid split for mode "fine"! Please use split="train", split="test"' 111 | ' or split="val"') 112 | elif mode == 'coarse' and split not in ['train', 'train_extra', 'val']: 113 | raise ValueError('Invalid split for mode "coarse"! Please use split="train", split="train_extra"' 114 | ' or split="val"') 115 | 116 | if not isinstance(target_type, list): 117 | self.target_type = [target_type] 118 | 119 | if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type): 120 | raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"' 121 | ' or "color"') 122 | 123 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 124 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' 125 | ' specified "split" and "mode" are inside the "root" directory') 126 | 127 | for city in os.listdir(self.images_dir): 128 | img_dir = os.path.join(self.images_dir, city) 129 | target_dir = os.path.join(self.targets_dir, city) 130 | for file_name in os.listdir(img_dir): 131 | target_types = [] 132 | for t in self.target_type: 133 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 134 | self._get_target_suffix(self.mode, t)) 135 | target_types.append(os.path.join(target_dir, target_name)) 136 | 137 | self.images.append(os.path.join(img_dir, file_name)) 138 | self.targets.append(target_types) 139 | 140 | def __getitem__(self, index): 141 | """ 142 | Args: 143 | index (int): Index 144 | Returns: 145 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 146 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
147 | """ 148 | 149 | image = Image.open(self.images[index]).convert('RGB') 150 | 151 | targets = [] 152 | for i, t in enumerate(self.target_type): 153 | if t == 'polygon': 154 | target = self._load_json(self.targets[index][i]) 155 | else: 156 | target = Image.open(self.targets[index][i]) 157 | 158 | targets.append(target) 159 | 160 | target = tuple(targets) if len(targets) > 1 else targets[0] 161 | 162 | if self.transform: 163 | image, target = self.transform(image, target) 164 | 165 | if self.target_transform: 166 | target = self.target_transform(target) 167 | 168 | return image, target 169 | 170 | def __len__(self): 171 | return len(self.images) 172 | 173 | def __repr__(self): 174 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 175 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 176 | fmt_str += ' Split: {}\n'.format(self.split) 177 | fmt_str += ' Mode: {}\n'.format(self.mode) 178 | fmt_str += ' Type: {}\n'.format(self.target_type) 179 | fmt_str += ' Root Location: {}\n'.format(self.root) 180 | tmp = ' Transforms (if any): ' 181 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 182 | tmp = ' Target Transforms (if any): ' 183 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 184 | return fmt_str 185 | 186 | def _load_json(self, path): 187 | with open(path, 'r') as file: 188 | data = json.load(file) 189 | return data 190 | 191 | def _get_target_suffix(self, mode, target_type): 192 | if target_type == 'instance': 193 | return '{}_instanceIds.png'.format(mode) 194 | elif target_type == 'semantic': 195 | return '{}_labelIds.png'.format(mode) 196 | elif target_type == 'color': 197 | return '{}_color.png'.format(mode) 198 | else: 199 | return '{}_polygons.json'.format(mode) 200 | -------------------------------------------------------------------------------- /eval/cityscapes.txt: -------------------------------------------------------------------------------- 1 | After 440 epochs: 2 | Training Loss: 1.214 -------------------------------------------------------------------------------- /eval/pretraining.txt: -------------------------------------------------------------------------------- 1 | After 90 epochs: 2 | Top1 accuracy: 57.320% 3 | Top5 accuracy: 80.642% 4 | Training loss = 2.153 -------------------------------------------------------------------------------- /eval_cityscapes.py: -------------------------------------------------------------------------------- 1 | # Adapted from official PyTorch Tutorial 2 | import argparse 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | import sys 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | from datasets.cityscapes import Cityscapes 20 | from utils import joint_transforms 21 | from metric.iou import IoU 22 | from model.dfanet import DFANet 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch Cityscapes Training') 25 | parser.add_argument('data', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 28 | help='number of data loading workers (default: 4)') 29 | parser.add_argument('--epochs', default=1, type=int, metavar='N', 30 | help='number of total epochs 
to run') 31 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 32 | help='manual epoch number (useful on restarts)') 33 | parser.add_argument('-b', '--batch-size', default=4, type=int, 34 | metavar='N', 35 | help='mini-batch size (default: 256), this is the total ' 36 | 'batch size of all GPUs on the current node when ' 37 | 'using Data Parallel or Distributed Data Parallel') 38 | parser.add_argument('--lr', '--learning-rate', default=2e-1, type=float, 39 | metavar='LR', help='initial learning rate', dest='lr') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--power', default=0.9, type=float, metavar='M', 43 | help='power for poly learning rate policy') 44 | parser.add_argument('--wd', '--weight-decay', default=1e-5, type=float, 45 | metavar='W', help='weight decay (default: 1e-4)', 46 | dest='weight_decay') 47 | parser.add_argument('-p', '--print-freq', default=10, type=int, 48 | metavar='N', help='print frequency (default: 10)') 49 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 50 | help='path to latest checkpoint (default: none)') 51 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 52 | help='evaluate model on validation set') 53 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 54 | help='use pre-trained model') 55 | parser.add_argument('--world-size', default=-1, type=int, 56 | help='number of nodes for distributed training') 57 | parser.add_argument('--rank', default=-1, type=int, 58 | help='node rank for distributed training') 59 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 60 | help='url used to set up distributed training') 61 | parser.add_argument('--dist-backend', default='nccl', type=str, 62 | help='distributed backend') 63 | parser.add_argument('--seed', default=None, type=int, 64 | help='seed for initializing training. ') 65 | parser.add_argument('--gpu', default=None, type=int, 66 | help='GPU id to use.') 67 | parser.add_argument('--multiprocessing-distributed', action='store_true', 68 | help='Use multi-processing distributed training to launch ' 69 | 'N processes per node, which has N GPUs. This is the ' 70 | 'fastest way to use PyTorch for either single node or ' 71 | 'multi node data parallel training') 72 | 73 | best_mIoU = 0 74 | 75 | 76 | def main(): 77 | args = parser.parse_args() 78 | 79 | if args.seed is not None: 80 | random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | cudnn.deterministic = True 83 | warnings.warn('You have chosen to seed training. ' 84 | 'This will turn on the CUDNN deterministic setting, ' 85 | 'which can slow down your training considerably! ' 86 | 'You may see unexpected behavior when restarting ' 87 | 'from checkpoints.') 88 | 89 | if args.gpu is not None: 90 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 91 | 'disable data parallelism.') 92 | 93 | if args.dist_url == "env://" and args.world_size == -1: 94 | args.world_size = int(os.environ["WORLD_SIZE"]) 95 | 96 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 97 | 98 | ngpus_per_node = torch.cuda.device_count() 99 | if args.multiprocessing_distributed: 100 | # Since we have ngpus_per_node processes per node, the total world_size 101 | # needs to be adjusted accordingly 102 | args.world_size = ngpus_per_node * args.world_size 103 | # Use torch.multiprocessing.spawn to launch distributed processes: the 104 | # main_worker process function 105 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 106 | else: 107 | # Simply call main_worker function 108 | main_worker(args.gpu, ngpus_per_node, args) 109 | 110 | 111 | def main_worker(gpu, ngpus_per_node, args): 112 | global best_mIoU 113 | args.gpu = gpu 114 | 115 | if args.gpu is not None: 116 | print("Use GPU: {} for training".format(args.gpu)) 117 | 118 | if args.distributed: 119 | if args.dist_url == "env://" and args.rank == -1: 120 | args.rank = int(os.environ["RANK"]) 121 | if args.multiprocessing_distributed: 122 | # For multiprocessing distributed training, rank needs to be the 123 | # global rank among all the processes 124 | args.rank = args.rank * ngpus_per_node + gpu 125 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 126 | world_size=args.world_size, rank=args.rank) 127 | # create model 128 | if args.pretrained: 129 | print("=> using pre-trained model 'DFANet'") 130 | model = DFANet(pretrained=True, pretrained_backbone=False) 131 | else: 132 | print("=> creating model 'DFANet'") 133 | model = DFANet(pretrained=False, pretrained_backbone=True) 134 | 135 | if args.distributed: 136 | # For multiprocessing distributed, DistributedDataParallel constructor 137 | # should always set the single device scope, otherwise, 138 | # DistributedDataParallel will use all available devices. 
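        # For illustration (the 2-GPU node is an assumption; the defaults come from
        # the argparse section above): with --batch-size 4 and --workers 4 on a
        # node with 2 GPUs, the division a few lines below gives each
        # DistributedDataParallel process a per-GPU batch size of 4 / 2 = 2 and
        # 4 / 2 = 2 data-loading workers.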
139 | if args.gpu is not None: 140 | torch.cuda.set_device(args.gpu) 141 | model.cuda(args.gpu) 142 | # When using a single GPU per process and per 143 | # DistributedDataParallel, we need to divide the batch size 144 | # ourselves based on the total number of GPUs we have 145 | args.batch_size = int(args.batch_size / ngpus_per_node) 146 | args.workers = int(args.workers / ngpus_per_node) 147 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 148 | else: 149 | model.cuda() 150 | # DistributedDataParallel will divide and allocate batch_size to all 151 | # available GPUs if device_ids are not set 152 | model = torch.nn.parallel.DistributedDataParallel(model) 153 | elif args.gpu is not None: 154 | torch.cuda.set_device(args.gpu) 155 | model = model.cuda(args.gpu) 156 | else: 157 | # DataParallel will divide and allocate batch_size to all available GPUs 158 | model = torch.nn.DataParallel(model).cuda() 159 | 160 | # define loss function (criterion) and optimizer 161 | criterion = nn.CrossEntropyLoss(ignore_index=19).cuda(args.gpu) 162 | 163 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 164 | momentum=args.momentum, 165 | weight_decay=args.weight_decay) 166 | 167 | metric = IoU(20, ignore_index=19) 168 | 169 | # optionally resume from a checkpoint 170 | if args.resume: 171 | if os.path.isfile(args.resume): 172 | print("=> loading checkpoint '{}'".format(args.resume)) 173 | checkpoint = torch.load(args.resume) 174 | args.start_epoch = checkpoint['epoch'] 175 | best_mIoU = checkpoint['best_mIoU'] 176 | if args.gpu is not None: 177 | # best_mIoU may be from a checkpoint from a different GPU 178 | best_mIoU = best_mIoU.to(args.gpu) 179 | model.load_state_dict(checkpoint['state_dict']) 180 | optimizer.load_state_dict(checkpoint['optimizer']) 181 | print("=> loaded checkpoint '{}' (epoch {})" 182 | .format(args.resume, checkpoint['epoch'])) 183 | else: 184 | print("=> no checkpoint found at '{}'".format(args.resume)) 185 | 186 | cudnn.benchmark = True 187 | 188 | # Data loading code 189 | train_dataset = Cityscapes(args.data, split='train', mode='fine', target_type='semantic', 190 | transform=joint_transforms.Compose([ 191 | joint_transforms.RandomHorizontalFlip(), 192 | joint_transforms.RandomSized(1024), 193 | joint_transforms.ToTensor(), 194 | joint_transforms.Normalize( 195 | mean=[0.485, 0.456, 0.406], 196 | std=[0.229, 0.224, 0.225]) 197 | ])) 198 | 199 | if args.distributed: 200 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 201 | else: 202 | train_sampler = None 203 | 204 | train_loader = torch.utils.data.DataLoader( 205 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 206 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 207 | 208 | val_loader = torch.utils.data.DataLoader( 209 | Cityscapes(args.data, split='val', mode='fine', target_type='semantic', 210 | transform=joint_transforms.Compose([ 211 | joint_transforms.RandomHorizontalFlip(), 212 | joint_transforms.RandomSized(1024), 213 | joint_transforms.ToTensor(), 214 | joint_transforms.Normalize( 215 | mean=[0.485, 0.456, 0.406], 216 | std=[0.229, 0.224, 0.225]) 217 | ])), 218 | batch_size=args.batch_size, shuffle=False, 219 | num_workers=args.workers, pin_memory=True) 220 | 221 | if args.evaluate: 222 | validate(val_loader, model, criterion, args) 223 | return 224 | 225 | for epoch in range(args.start_epoch, args.epochs): 226 | if args.distributed: 227 | train_sampler.set_epoch(epoch) 228 | 229 | # evaluate on 
training data 230 | train_mIoU, train_loss = validate(train_loader, model, criterion, metric, args) 231 | 232 | # evaluate on validation set 233 | val_mIoU, val_loss = validate(val_loader, model, criterion, metric, args) 234 | 235 | print("Train mIoU: {}".format(train_mIoU)) 236 | print("Train Loss: {}".format(train_loss)) 237 | print("Val mIoU: {}".format(val_mIoU)) 238 | print("Val mIoU: {}".format(val_loss)) 239 | 240 | 241 | def validate(val_loader, model, criterion, metric, args): 242 | mIoU = AverageMeter('mIoU', ':6.2f') 243 | progress = ProgressMeter(len(val_loader), mIoU) 244 | # switch to evaluate mode 245 | model.eval() 246 | metric.reset() 247 | avg_loss = 0 248 | iter = 0 249 | with torch.no_grad(): 250 | for i, (input, target) in enumerate(val_loader): 251 | if args.gpu is not None: 252 | input = input.cuda(args.gpu, non_blocking=True) 253 | target = target.cuda(args.gpu, non_blocking=True) 254 | 255 | # compute output 256 | output = model(input) 257 | avg_loss += criterion(output.view(output.shape[0], 19, -1), target.view(target.shape[0], -1)) 258 | 259 | # measure mIoU and record loss 260 | metric.reset() 261 | metric.add(output.max(1)[1].view(output.shape[0], 1024, 1024), target.view(target.shape[0], 1024, 1024)) 262 | mIoU.update(metric.value()[1]) 263 | iter += 1 264 | 265 | if i % args.print_freq == 0: 266 | progress.print(i) 267 | 268 | return metric.value()[1], avg_loss / iter 269 | 270 | 271 | def save_checkpoint(state, is_best, filename='./checkpoints/Cityscapes.pth.tar'): 272 | torch.save(state, filename) 273 | if is_best: 274 | shutil.copyfile(filename, 'Cityscapes_best.pth.tar') 275 | 276 | 277 | class AverageMeter(object): 278 | """Computes and stores the average and current value""" 279 | def __init__(self, name, fmt=':f'): 280 | self.name = name 281 | self.fmt = fmt 282 | self.reset() 283 | 284 | def reset(self): 285 | self.val = 0 286 | self.avg = 0 287 | self.sum = 0 288 | self.count = 0 289 | 290 | def update(self, val, n=1): 291 | self.val = val 292 | self.sum += val * n 293 | self.count += n 294 | self.avg = self.sum / self.count 295 | 296 | def __str__(self): 297 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 298 | return fmtstr.format(**self.__dict__) 299 | 300 | 301 | class ProgressMeter(object): 302 | def __init__(self, num_batches, *meters, prefix=""): 303 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 304 | self.meters = meters 305 | self.prefix = prefix 306 | 307 | def print(self, batch): 308 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 309 | entries += [str(meter) for meter in self.meters] 310 | print('\t'.join(entries)) 311 | 312 | def _get_batch_fmtstr(self, num_batches): 313 | num_digits = len(str(num_batches // 1)) 314 | fmt = '{:' + str(num_digits) + 'd}' 315 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 316 | 317 | 318 | def adjust_learning_rate(optimizer, epoch, args): 319 | """Polynomial decay learning rate policy.""" 320 | lr = args.lr * (1 - epoch/args.epochs)**args.power 321 | for param_group in optimizer.param_groups: 322 | param_group['lr'] = lr 323 | 324 | 325 | if __name__ == '__main__': 326 | main() 327 | -------------------------------------------------------------------------------- /metric/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandylin/DFANet_PyTorch/724ea05ba8b44d2dedb0a3389b4dfaaf57f5a2b6/metric/__init__.py 
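A quick standalone sketch of the polynomial ("poly") decay implemented by adjust_learning_rate above. The base learning rate (2e-1) and power (0.9) mirror the argparse defaults of eval_cityscapes.py; the 440-epoch horizon is taken from eval/cityscapes.txt, and the sampled epochs are arbitrary illustration points:

base_lr, power, total_epochs = 2e-1, 0.9, 440

for epoch in (0, 110, 220, 330, 439):
    # lr = args.lr * (1 - epoch / args.epochs) ** args.power, as in the script
    lr = base_lr * (1 - epoch / total_epochs) ** power
    print("epoch {:3d}: lr = {:.4f}".format(epoch, lr))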
-------------------------------------------------------------------------------- /metric/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from metric import metric 4 | 5 | 6 | class ConfusionMatrix(metric.Metric): 7 | """Constructs a confusion matrix for a multi-class classification problems. 8 | 9 | Does not support multi-label, multi-class problems. 10 | 11 | Keyword arguments: 12 | - num_classes (int): number of classes in the classification problem. 13 | - normalized (boolean, optional): Determines whether or not the confusion 14 | matrix is normalized or not. Default: False. 15 | 16 | Modified from: https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py 17 | """ 18 | 19 | def __init__(self, num_classes, normalized=False): 20 | super().__init__() 21 | 22 | self.conf = np.ndarray((num_classes, num_classes), dtype=np.int32) 23 | self.normalized = normalized 24 | self.num_classes = num_classes 25 | self.reset() 26 | 27 | def reset(self): 28 | self.conf.fill(0) 29 | 30 | def add(self, predicted, target): 31 | """Computes the confusion matrix 32 | 33 | The shape of the confusion matrix is K x K, where K is the number 34 | of classes. 35 | 36 | Keyword arguments: 37 | - predicted (Tensor or numpy.ndarray): Can be an N x K tensor/array of 38 | predicted scores obtained from the model for N examples and K classes, 39 | or an N-tensor/array of integer values between 0 and K-1. 40 | - target (Tensor or numpy.ndarray): Can be an N x K tensor/array of 41 | ground-truth classes for N examples and K classes, or an N-tensor/array 42 | of integer values between 0 and K-1. 43 | 44 | """ 45 | # If target and/or predicted are tensors, convert them to numpy arrays 46 | if torch.is_tensor(predicted): 47 | predicted = predicted.cpu().numpy() 48 | if torch.is_tensor(target): 49 | target = target.cpu().numpy() 50 | 51 | assert predicted.shape[0] == target.shape[0], \ 52 | 'number of targets and predicted outputs do not match' 53 | 54 | if np.ndim(predicted) != 1: 55 | assert predicted.shape[1] == self.num_classes, \ 56 | 'number of predictions does not match size of confusion matrix' 57 | predicted = np.argmax(predicted, 1) 58 | else: 59 | assert (predicted.max() < self.num_classes) and (predicted.min() >= 0), \ 60 | 'predicted values are not between 0 and k-1' 61 | 62 | if np.ndim(target) != 1: 63 | assert target.shape[1] == self.num_classes, \ 64 | 'Onehot target does not match size of confusion matrix' 65 | assert (target >= 0).all() and (target <= 1).all(), \ 66 | 'in one-hot encoding, target values should be 0 or 1' 67 | assert (target.sum(1) == 1).all(), \ 68 | 'multi-label setting is not supported' 69 | target = np.argmax(target, 1) 70 | else: 71 | assert (target.max() < self.num_classes) and (target.min() >= 0), \ 72 | 'target values are not between 0 and k-1' 73 | 74 | # hack for bincounting 2 arrays together 75 | x = predicted + self.num_classes * target 76 | bincount_2d = np.bincount( 77 | x.astype(np.int32), minlength=self.num_classes**2) 78 | assert bincount_2d.size == self.num_classes**2 79 | conf = bincount_2d.reshape((self.num_classes, self.num_classes)) 80 | 81 | self.conf += conf 82 | 83 | def value(self): 84 | """ 85 | Returns: 86 | Confustion matrix of K rows and K columns, where rows corresponds 87 | to ground-truth targets and columns corresponds to predicted 88 | targets. 
89 | """ 90 | if self.normalized: 91 | conf = self.conf.astype(np.float32) 92 | return conf / conf.sum(1).clip(min=1e-12)[:, None] 93 | else: 94 | return self.conf 95 | -------------------------------------------------------------------------------- /metric/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from metric import metric 4 | from metric.confusionmatrix import ConfusionMatrix 5 | 6 | 7 | class IoU(metric.Metric): 8 | """Computes the intersection over union (IoU) per class and corresponding 9 | mean (mIoU). 10 | 11 | Intersection over union (IoU) is a common evaluation metric for semantic 12 | segmentation. The predictions are first accumulated in a confusion matrix 13 | and the IoU is computed from it as follows: 14 | 15 | IoU = true_positive / (true_positive + false_positive + false_negative). 16 | 17 | Keyword arguments: 18 | - num_classes (int): number of classes in the classification problem 19 | - normalized (boolean, optional): Determines whether or not the confusion 20 | matrix is normalized or not. Default: False. 21 | - ignore_index (int or iterable, optional): Index of the classes to ignore 22 | when computing the IoU. Can be an int, or any iterable of ints. 23 | """ 24 | 25 | def __init__(self, num_classes, normalized=False, ignore_index=None): 26 | super().__init__() 27 | self.conf_metric = ConfusionMatrix(num_classes, normalized) 28 | 29 | if ignore_index is None: 30 | self.ignore_index = None 31 | elif isinstance(ignore_index, int): 32 | self.ignore_index = (ignore_index,) 33 | else: 34 | try: 35 | self.ignore_index = tuple(ignore_index) 36 | except TypeError: 37 | raise ValueError("'ignore_index' must be an int or iterable") 38 | 39 | def reset(self): 40 | self.conf_metric.reset() 41 | 42 | def add(self, predicted, target): 43 | """Adds the predicted and target pair to the IoU metric. 44 | 45 | Keyword arguments: 46 | - predicted (Tensor): Can be a (N, K, H, W) tensor of 47 | predicted scores obtained from the model for N examples and K classes, 48 | or (N, H, W) tensor of integer values between 0 and K-1. 49 | - target (Tensor): Can be a (N, K, H, W) tensor of 50 | target scores for N examples and K classes, or (N, H, W) tensor of 51 | integer values between 0 and K-1. 52 | 53 | """ 54 | # Dimensions check 55 | assert predicted.size(0) == target.size(0), \ 56 | 'number of targets and predicted outputs do not match' 57 | assert predicted.dim() == 3 or predicted.dim() == 4, \ 58 | "predictions must be of dimension (N, H, W) or (N, K, H, W)" 59 | assert target.dim() == 3 or target.dim() == 4, \ 60 | "targets must be of dimension (N, H, W) or (N, K, H, W)" 61 | 62 | # If the tensor is in categorical format convert it to integer format 63 | if predicted.dim() == 4: 64 | _, predicted = predicted.max(1) 65 | if target.dim() == 4: 66 | _, target = target.max(1) 67 | 68 | self.conf_metric.add(predicted.view(-1), target.view(-1)) 69 | 70 | def value(self): 71 | """Computes the IoU and mean IoU. 72 | 73 | The mean computation ignores NaN elements of the IoU array. 74 | 75 | Returns: 76 | Tuple: (IoU, mIoU). The first output is the per class IoU, 77 | for K classes it's numpy.ndarray with K elements. The second output, 78 | is the mean IoU. 
79 | """ 80 | conf_matrix = self.conf_metric.value() 81 | if self.ignore_index is not None: 82 | for index in self.ignore_index: 83 | conf_matrix[:, self.ignore_index] = 0 84 | conf_matrix[self.ignore_index, :] = 0 85 | true_positive = np.diag(conf_matrix) 86 | false_positive = np.sum(conf_matrix, 0) - true_positive 87 | false_negative = np.sum(conf_matrix, 1) - true_positive 88 | 89 | # Just in case we get a division by 0, ignore/hide the error 90 | with np.errstate(divide='ignore', invalid='ignore'): 91 | iou = true_positive / (true_positive + false_positive + false_negative) 92 | 93 | return iou, np.nanmean(iou) 94 | 95 | 96 | if __name__ == '__main__': 97 | mIoU = IoU(3, ignore_index=2) 98 | target = torch.from_numpy(np.array([ 99 | [ 100 | [0, 0, 0, 1, 1], 101 | [0, 0, 2, 1, 1], 102 | [0, 2, 2, 1, 1], 103 | [0, 2, 2, 1, 1], 104 | [1, 0, 0, 1, 1] 105 | ], 106 | [ 107 | [0, 0, 0, 1, 2], 108 | [0, 2, 1, 1, 2], 109 | [0, 2, 1, 1, 2], 110 | [0, 1, 1, 1, 2], 111 | [1, 2, 2, 2, 2] 112 | ], 113 | ])).view(2, 5, 5) 114 | output = torch.from_numpy(np.array([ 115 | [ 116 | [0, 0, 0, 1, 1], 117 | [0, 0, 2, 1, 1], 118 | [0, 1, 2, 1, 1], 119 | [0, 1, 2, 1, 1], 120 | [1, 0, 0, 1, 1] 121 | ], 122 | [ 123 | [0, 0, 0, 1, 2], 124 | [0, 2, 1, 1, 2], 125 | [0, 2, 1, 1, 2], 126 | [0, 1, 1, 1, 2], 127 | [1, 1, 1, 1, 1] 128 | ], 129 | ])).view(2, 5, 5) 130 | mIoU.add(output, target) 131 | print(mIoU.value()) 132 | -------------------------------------------------------------------------------- /metric/metric.py: -------------------------------------------------------------------------------- 1 | class Metric(object): 2 | """Base class for all metrics. 3 | 4 | From: https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py 5 | """ 6 | def reset(self): 7 | pass 8 | 9 | def add(self): 10 | pass 11 | 12 | def value(self): 13 | pass 14 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandylin/DFANet_PyTorch/724ea05ba8b44d2dedb0a3389b4dfaaf57f5a2b6/model/__init__.py -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates an Xception Model as defined in: 3 | 4 | Francois Chollet 5 | Xception: Deep Learning with Depthwise Separable Convolutions 6 | https://arxiv.org/pdf/1610.02357.pdf 7 | 8 | This weights ported from the Keras implementation. 
Achieves the following performance on the validation set: 9 | 10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 11 | 12 | REMEMBER to set your image size to 3x299x299 for both test and validation 13 | 14 | normalize = utils.Normalize(mean=[0.5, 0.5, 0.5], 15 | std=[0.5, 0.5, 0.5]) 16 | 17 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 18 | """ 19 | import math 20 | import torch 21 | import torch.nn as nn 22 | 23 | model_url = './checkpoints/XceptionA_best.pth.tar' 24 | 25 | 26 | class SeparableConv2d(nn.Module): 27 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 28 | super(SeparableConv2d, self).__init__() 29 | 30 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, 31 | bias=bias) 32 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 33 | 34 | def forward(self, x): 35 | x = self.conv1(x) 36 | x = self.pointwise(x) 37 | return x 38 | 39 | 40 | class Block(nn.Module): 41 | def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True): 42 | super(Block, self).__init__() 43 | 44 | if out_filters != in_filters or strides != 1: 45 | self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False) 46 | self.skipbn = nn.BatchNorm2d(out_filters) 47 | else: 48 | self.skip = None 49 | 50 | self.relu = nn.ReLU(inplace=True) 51 | rep = [] 52 | 53 | filters = in_filters 54 | if grow_first: 55 | rep.append(self.relu) 56 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 57 | rep.append(nn.BatchNorm2d(out_filters)) 58 | filters = out_filters 59 | 60 | for i in range(reps - 1): 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(filters, filters, 3, stride=1, padding=1, bias=False)) 63 | rep.append(nn.BatchNorm2d(filters)) 64 | 65 | if not grow_first: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(in_filters, out_filters, 3, stride=1, padding=1, bias=False)) 68 | rep.append(nn.BatchNorm2d(out_filters)) 69 | 70 | if not start_with_relu: 71 | rep = rep[1:] 72 | else: 73 | rep[0] = nn.ReLU(inplace=False) 74 | 75 | if strides != 1: 76 | rep.append(nn.MaxPool2d(3, strides, 1)) 77 | self.rep = nn.Sequential(*rep) 78 | 79 | def forward(self, inp): 80 | x = self.rep(inp) 81 | 82 | if self.skip is not None: 83 | skip = self.skip(inp) 84 | skip = self.skipbn(skip) 85 | else: 86 | skip = inp 87 | 88 | x += skip 89 | return x 90 | 91 | 92 | class XceptionA(nn.Module): 93 | """ 94 | Xception optimized for the ImageNet dataset, as specified in 95 | https://arxiv.org/pdf/1610.02357.pdf 96 | 97 | Modified Xception A architecture, as specified in 98 | https://arxiv.org/pdf/1904.02216.pdf 99 | """ 100 | 101 | def __init__(self, num_classes=1000): 102 | """ Constructor 103 | Args: 104 | num_classes: number of classes 105 | """ 106 | super(XceptionA, self).__init__() 107 | 108 | self.num_classes = num_classes 109 | 110 | self.conv1 = nn.Conv2d(3, 8, 3, 2, 1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(8) 112 | 113 | # conv for reducing channel size in input for non-first backbone stages 114 | self.enc2_conv = nn.Conv2d(240, 8, 1, 1, bias=False) # bias=False? 
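        # Channel bookkeeping for the three reduction convolutions consumed by
        # forward_concat() below: the 240 inputs here are the 192-channel
        # upsampled fca of the previous backbone stage concatenated with that
        # stage's 48-channel enc2 output; likewise enc3_conv takes
        # 144 = 48 (this stage's enc2) + 96 (previous stage's enc3), and
        # enc4_conv takes 288 = 96 (this stage's enc3) + 192 (previous stage's enc4).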
115 | 116 | self.enc2_1 = Block(8, 12, 4, 1, start_with_relu=True, grow_first=True) 117 | self.enc2_2 = Block(12, 12, 4, 1, start_with_relu=True, grow_first=True) 118 | self.enc2_3 = Block(12, 48, 4, 2, start_with_relu=True, grow_first=True) 119 | self.enc2 = nn.Sequential(self.enc2_1, self.enc2_2, self.enc2_3) 120 | 121 | self.enc3_conv = nn.Conv2d(144, 48, 1, 1, bias=False) 122 | 123 | self.enc3_1 = Block(48, 24, 6, 1, start_with_relu=True, grow_first=True) 124 | self.enc3_2 = Block(24, 24, 6, 1, start_with_relu=True, grow_first=True) 125 | self.enc3_3 = Block(24, 96, 6, 2, start_with_relu=True, grow_first=True) 126 | self.enc3 = nn.Sequential(self.enc3_1, self.enc3_2, self.enc3_3) 127 | 128 | self.enc4_conv = nn.Conv2d(288, 96, 1, 1, bias=False) 129 | 130 | self.enc4_1 = Block(96, 48, 4, 1, start_with_relu=True, grow_first=True) 131 | self.enc4_2 = Block(48, 48, 4, 1, start_with_relu=True, grow_first=True) 132 | self.enc4_3 = Block(48, 192, 4, 2, start_with_relu=True, grow_first=True) 133 | self.enc4 = nn.Sequential(self.enc4_1, self.enc4_2, self.enc4_3) 134 | 135 | self.pooling = nn.AdaptiveAvgPool2d(1) 136 | self.fc = nn.Linear(192, num_classes) 137 | self.fca = nn.Conv2d(num_classes, 192, 1) 138 | 139 | # ------- init weights -------- 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 143 | m.weight.data.normal_(0, math.sqrt(2. / n)) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | m.weight.data.fill_(1) 146 | m.bias.data.zero_() 147 | # ----------------------------- 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | 153 | enc2 = self.enc2(x) 154 | enc3 = self.enc3(enc2) 155 | enc4 = self.enc4(enc3) 156 | pool = self.pooling(enc4) 157 | fc = self.fc(pool.view(pool.size(0), -1)) 158 | fca = self.fca(fc.view(fc.size(0), -1, 1, 1)) 159 | fca = enc4 * fca 160 | 161 | return enc2, enc3, enc4, fc, fca 162 | 163 | def forward_concat(self, fca_concat, enc2_concat, enc3_concat, enc4_concat): 164 | """For second and third stage.""" 165 | enc2 = self.enc2(self.enc2_conv(torch.cat((fca_concat, enc2_concat), dim=1))) 166 | enc3 = self.enc3(self.enc3_conv(torch.cat((enc2, enc3_concat), dim=1))) 167 | enc4 = self.enc4(self.enc4_conv(torch.cat((enc3, enc4_concat), dim=1))) 168 | pool = self.pooling(enc4) 169 | fc = self.fc(pool.view(pool.size(0), -1)) 170 | fca = self.fca(fc.view(fc.size(0), -1, 1, 1)) 171 | fca = enc4 * fca 172 | 173 | return enc2, enc3, enc4, fc, fca 174 | 175 | 176 | def backbone(pretrained=False, **kwargs): 177 | """ 178 | Construct Xception. 179 | """ 180 | 181 | model = XceptionA(**kwargs) 182 | if pretrained: 183 | # from collections import OrderedDict 184 | # state_dict = torch.load(model_url) 185 | # new_state_dict = OrderedDict() 186 | # 187 | # for k, v in state_dict.items(): 188 | # name = k[7:] # remove 'module.' 
of data parallel 189 | # new_state_dict[name] = v 190 | # 191 | # model.load_state_dict(new_state_dict, strict=False) 192 | model.load_state_dict(torch.load(model_url), strict=False) 193 | return model 194 | -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | 8 | def __init__(self, in_channels, out_channels, kernel_size=1): 9 | super(ConvBlock, self).__init__() 10 | self.relu = nn.ReLU() 11 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size) 12 | self.bn = nn.BatchNorm2d(out_channels) 13 | 14 | def forward(self, x): 15 | x_relu = self.relu(x) 16 | x_conv = self.conv(x_relu) 17 | x_bn = self.bn(x_conv) 18 | return x_bn 19 | 20 | 21 | class Decoder(nn.Module): 22 | 23 | def __init__(self, n_classes=19): 24 | super(Decoder, self).__init__() 25 | self.n_classes = n_classes 26 | self.enc1_conv = ConvBlock(48, 32, 1) # not sure about the out channels 27 | 28 | self.enc2_conv = ConvBlock(48, 32, 1) 29 | self.enc2_up = nn.UpsamplingBilinear2d(scale_factor=2) 30 | 31 | self.enc3_conv = ConvBlock(48, 32, 1) 32 | self.enc3_up = nn.UpsamplingBilinear2d(scale_factor=4) 33 | 34 | self.enc_conv = ConvBlock(32, n_classes, 1) 35 | 36 | self.fca1_conv = ConvBlock(192, n_classes, 1) 37 | self.fca1_up = nn.UpsamplingBilinear2d(scale_factor=4) 38 | 39 | self.fca2_conv = ConvBlock(192, n_classes, 1) 40 | self.fca2_up = nn.UpsamplingBilinear2d(scale_factor=8) 41 | 42 | self.fca3_conv = ConvBlock(192, n_classes, 1) 43 | self.fca3_up = nn.UpsamplingBilinear2d(scale_factor=16) 44 | 45 | self.final_up = nn.UpsamplingBilinear2d(scale_factor=4) 46 | 47 | # ------- init weights -------- 48 | for m in self.modules(): 49 | if isinstance(m, nn.Conv2d): 50 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 51 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | # ----------------------------- 56 | 57 | def forward(self, enc1, enc2, enc3, fca1, fca2, fca3): 58 | """Note that enc1 denotes the output of the enc4 module of backbone instance 1.""" 59 | e1 = self.enc1_conv(enc1) 60 | e2 = self.enc2_up(self.enc2_conv(enc2)) 61 | e3 = self.enc3_up(self.enc3_conv(enc3)) 62 | 63 | e = self.enc_conv(e1 + e2 + e3) 64 | 65 | f1 = self.fca1_up(self.fca1_conv(fca1)) 66 | f2 = self.fca2_up(self.fca1_conv(fca2)) 67 | f3 = self.fca3_up(self.fca1_conv(fca3)) 68 | 69 | o = self.final_up(e + f1 + f2 + f3) 70 | 71 | return o 72 | 73 | -------------------------------------------------------------------------------- /model/dfanet.py: -------------------------------------------------------------------------------- 1 | from model.backbone import backbone 2 | from model.decoder import Decoder 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | model_url = './Cityscapes_best.pth.tar' 8 | 9 | 10 | class DFANet(nn.Module): 11 | 12 | def __init__(self, n_classes=19, pretrained=False, pretrained_backbone=True): 13 | super(DFANet, self).__init__() 14 | self.backbone1 = backbone(pretrained=pretrained_backbone) 15 | self.backbone1_up = nn.UpsamplingBilinear2d(scale_factor=4) 16 | 17 | self.backbone2 = backbone(pretrained=pretrained_backbone) 18 | self.backbone2_up = nn.UpsamplingBilinear2d(scale_factor=4) 19 | 20 | self.backbone3 = backbone(pretrained=pretrained_backbone) 21 | 22 | self.decoder = Decoder(n_classes=n_classes) 23 | 24 | if pretrained: 25 | self.load_state_dict(torch.load(model_url)["state_dict"]) 26 | 27 | def forward(self, x): 28 | enc1_2, enc1_3, enc1_4, fc1, fca1 = self.backbone1(x) 29 | fca1_up = self.backbone1_up(fca1) 30 | 31 | enc2_2, enc2_3, enc2_4, fc2, fca2 = self.backbone2.forward_concat(fca1_up, enc1_2, enc1_3, enc1_4) 32 | fca2_up = self.backbone2_up(fca2) 33 | 34 | enc3_2, enc3_3, enc3_4, fc3, fca3 = self.backbone3.forward_concat(fca2_up, enc2_2, enc2_3, enc2_4) 35 | 36 | out = self.decoder(enc1_2, enc2_2, enc3_2, fca1, fca2, fca3) 37 | 38 | return out 39 | 40 | -------------------------------------------------------------------------------- /plugin.py: -------------------------------------------------------------------------------- 1 | from model.dfanet import DFANet 2 | import numpy as np 3 | from collections import OrderedDict 4 | import cv2 5 | import torch 6 | 7 | 8 | cityscapes_color_dict = { 9 | 0: (128, 64, 128), 10 | 1: (244, 35, 232), 11 | 2: (70, 70, 70), 12 | 3: (102, 102, 156), 13 | 4: (190, 153, 153), 14 | 5: (153, 153, 153), 15 | 6: (250, 170, 30), 16 | 7: (220, 220, 0), 17 | 8: (107, 142, 35), 18 | 9: (152, 251, 152), 19 | 10: (70, 130, 180), 20 | 11: (220, 20, 60), 21 | 12: (255, 0, 0), 22 | 13: (0, 0, 142), 23 | 14: (0, 0, 70), 24 | 15: (0, 60, 100), 25 | 16: (0, 80, 100), 26 | 17: (0, 0, 230), 27 | 18: (119, 11, 32), 28 | 19: (0, 0, 0) 29 | } 30 | 31 | mask_to_colormap = np.vectorize(lambda x: cityscapes_color_dict[x]) 32 | 33 | 34 | class DFANetPlugin(object): 35 | def __init__(self, im_height, im_width, use_cuda, model_url='./Cityscapes_best.pth.tar', opacity=0.4): 36 | super().__init__() 37 | self.name = "DFANet" 38 | self.im_height = im_height 39 | self.im_width = im_width 40 | self.use_cuda = use_cuda 41 | self.opacity = opacity 42 | 43 | self.model = DFANet(pretrained=False, pretrained_backbone=False) 44 | 45 | state_dict = torch.load(model_url) 46 | new_state_dict = OrderedDict() 47 | for k, v in state_dict.items(): 48 | 
name = k[7:] # remove 'module.' from nn.DataParallel 49 | new_state_dict[name] = v 50 | self.model.load_state_dict(new_state_dict) 51 | 52 | self.model.eval() 53 | if use_cuda: 54 | self.model.cuda() 55 | 56 | def process(self, image): 57 | x = cv2.resize(image, dsize=(1024, 1024, 3)).transpose((2, 0, 1)) 58 | x = torch.from_numpy(x).view(1, 3, 1024, 1024) 59 | if self.use_cuda: 60 | x = x.cuda() 61 | _, mask = self.model(x).max(1) 62 | if self.use_cuda: 63 | mask = mask.cpu() 64 | mask = mask.numpy() 65 | colormap = mask_to_colormap(mask) 66 | colormap = np.array(colormap).transpose((1, 2, 0)) 67 | colormap = cv2.resize(colormap, dsize=image.shape, interpolation=cv2.INTER_NEAREST) 68 | output = self.opacity * colormap + (1 - self.opacity) * image 69 | return output 70 | 71 | def release(self): 72 | del self.model 73 | -------------------------------------------------------------------------------- /pretrain_backbone.py: -------------------------------------------------------------------------------- 1 | # Adapted from official PyTorch Tutorial 2 | import argparse 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | import sys 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torchvision.transforms as transforms 20 | import torchvision.datasets as datasets 21 | from model.backbone import backbone 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 24 | parser.add_argument('data', metavar='DIR', 25 | help='path to dataset') 26 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 27 | help='number of data loading workers (default: 4)') 28 | parser.add_argument('--epochs', default=60, type=int, metavar='N', 29 | help='number of total epochs to run') 30 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 31 | help='manual epoch number (useful on restarts)') 32 | parser.add_argument('-b', '--batch-size', default=256, type=int, 33 | metavar='N', 34 | help='mini-batch size (default: 256), this is the total ' 35 | 'batch size of all GPUs on the current node when ' 36 | 'using Data Parallel or Distributed Data Parallel') 37 | parser.add_argument('--lr', '--learning-rate', default=0.3, type=float, 38 | metavar='LR', help='initial learning rate', dest='lr') 39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 40 | help='momentum') 41 | parser.add_argument('--wd', '--weight-decay', default=4e-5, type=float, 42 | metavar='W', help='weight decay (default: 1e-4)', 43 | dest='weight_decay') 44 | parser.add_argument('-p', '--print-freq', default=10, type=int, 45 | metavar='N', help='print frequency (default: 10)') 46 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 47 | help='path to latest checkpoint (default: none)') 48 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 49 | help='evaluate model on validation set') 50 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 51 | help='use pre-trained model') 52 | parser.add_argument('--world-size', default=-1, type=int, 53 | help='number of nodes for distributed training') 54 | parser.add_argument('--rank', default=-1, type=int, 55 | help='node rank for distributed training') 56 | 
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 57 | help='url used to set up distributed training') 58 | parser.add_argument('--dist-backend', default='nccl', type=str, 59 | help='distributed backend') 60 | parser.add_argument('--seed', default=None, type=int, 61 | help='seed for initializing training. ') 62 | parser.add_argument('--gpu', default=None, type=int, 63 | help='GPU id to use.') 64 | parser.add_argument('--multiprocessing-distributed', action='store_true', 65 | help='Use multi-processing distributed training to launch ' 66 | 'N processes per node, which has N GPUs. This is the ' 67 | 'fastest way to use PyTorch for either single node or ' 68 | 'multi node data parallel training') 69 | 70 | best_acc1 = 0 71 | 72 | 73 | def main(): 74 | args = parser.parse_args() 75 | 76 | if args.seed is not None: 77 | random.seed(args.seed) 78 | torch.manual_seed(args.seed) 79 | cudnn.deterministic = True 80 | warnings.warn('You have chosen to seed training. ' 81 | 'This will turn on the CUDNN deterministic setting, ' 82 | 'which can slow down your training considerably! ' 83 | 'You may see unexpected behavior when restarting ' 84 | 'from checkpoints.') 85 | 86 | if args.gpu is not None: 87 | warnings.warn('You have chosen a specific GPU. This will completely ' 88 | 'disable data parallelism.') 89 | 90 | if args.dist_url == "env://" and args.world_size == -1: 91 | args.world_size = int(os.environ["WORLD_SIZE"]) 92 | 93 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 94 | 95 | ngpus_per_node = torch.cuda.device_count() 96 | if args.multiprocessing_distributed: 97 | # Since we have ngpus_per_node processes per node, the total world_size 98 | # needs to be adjusted accordingly 99 | args.world_size = ngpus_per_node * args.world_size 100 | # Use torch.multiprocessing.spawn to launch distributed processes: the 101 | # main_worker process function 102 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 103 | else: 104 | # Simply call main_worker function 105 | main_worker(args.gpu, ngpus_per_node, args) 106 | 107 | 108 | def main_worker(gpu, ngpus_per_node, args): 109 | global best_acc1 110 | args.gpu = gpu 111 | 112 | if args.gpu is not None: 113 | print("Use GPU: {} for training".format(args.gpu)) 114 | 115 | if args.distributed: 116 | if args.dist_url == "env://" and args.rank == -1: 117 | args.rank = int(os.environ["RANK"]) 118 | if args.multiprocessing_distributed: 119 | # For multiprocessing distributed training, rank needs to be the 120 | # global rank among all the processes 121 | args.rank = args.rank * ngpus_per_node + gpu 122 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 123 | world_size=args.world_size, rank=args.rank) 124 | # create model 125 | if args.pretrained: 126 | print("=> using pre-trained model 'XceptionA'") 127 | model = backbone(pretrained=True) 128 | else: 129 | print("=> creating model 'XceptionA'") 130 | model = backbone() 131 | 132 | if args.distributed: 133 | # For multiprocessing distributed, DistributedDataParallel constructor 134 | # should always set the single device scope, otherwise, 135 | # DistributedDataParallel will use all available devices. 
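        # A hypothetical single-node, multi-GPU launch of this pretraining script
        # (the dataset path and rendezvous address are placeholders, not taken
        # from this repository):
        #   python pretrain_backbone.py /path/to/imagenet \
        #       --multiprocessing-distributed --dist-url tcp://127.0.0.1:23456 \
        #       --world-size 1 --rank 0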
136 | if args.gpu is not None: 137 | torch.cuda.set_device(args.gpu) 138 | model.cuda(args.gpu) 139 | # When using a single GPU per process and per 140 | # DistributedDataParallel, we need to divide the batch size 141 | # ourselves based on the total number of GPUs we have 142 | args.batch_size = int(args.batch_size / ngpus_per_node) 143 | args.workers = int(args.workers / ngpus_per_node) 144 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 145 | else: 146 | model.cuda() 147 | # DistributedDataParallel will divide and allocate batch_size to all 148 | # available GPUs if device_ids are not set 149 | model = torch.nn.parallel.DistributedDataParallel(model) 150 | elif args.gpu is not None: 151 | torch.cuda.set_device(args.gpu) 152 | model = model.cuda(args.gpu) 153 | else: 154 | # DataParallel will divide and allocate batch_size to all available GPUs 155 | model = torch.nn.DataParallel(model).cuda() 156 | 157 | # define loss function (criterion) and optimizer 158 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 159 | 160 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 161 | momentum=args.momentum, 162 | weight_decay=args.weight_decay) 163 | 164 | # optionally resume from a checkpoint 165 | if args.resume: 166 | if os.path.isfile(args.resume): 167 | print("=> loading checkpoint '{}'".format(args.resume)) 168 | checkpoint = torch.load(args.resume) 169 | args.start_epoch = checkpoint['epoch'] 170 | best_acc1 = checkpoint['best_acc1'] 171 | if args.gpu is not None: 172 | # best_acc1 may be from a checkpoint from a different GPU 173 | best_acc1 = best_acc1.to(args.gpu) 174 | model.load_state_dict(checkpoint['state_dict']) 175 | optimizer.load_state_dict(checkpoint['optimizer']) 176 | print("=> loaded checkpoint '{}' (epoch {})" 177 | .format(args.resume, checkpoint['epoch'])) 178 | else: 179 | print("=> no checkpoint found at '{}'".format(args.resume)) 180 | 181 | cudnn.benchmark = True 182 | 183 | # Data loading code 184 | traindir = os.path.join(args.data, 'train') 185 | valdir = os.path.join(args.data, 'val') 186 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 187 | std=[0.229, 0.224, 0.225]) 188 | 189 | train_dataset = datasets.ImageFolder( 190 | traindir, 191 | transforms.Compose([ 192 | transforms.RandomResizedCrop(224), 193 | transforms.RandomHorizontalFlip(), 194 | transforms.ToTensor(), 195 | normalize, 196 | ])) 197 | 198 | if args.distributed: 199 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 200 | else: 201 | train_sampler = None 202 | 203 | train_loader = torch.utils.data.DataLoader( 204 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 205 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 206 | 207 | val_loader = torch.utils.data.DataLoader( 208 | datasets.ImageFolder(valdir, transforms.Compose([ 209 | transforms.Resize(256), 210 | transforms.CenterCrop(224), 211 | transforms.ToTensor(), 212 | normalize, 213 | ])), 214 | batch_size=args.batch_size, shuffle=False, 215 | num_workers=args.workers, pin_memory=True) 216 | 217 | if args.evaluate: 218 | validate(val_loader, model, criterion, args) 219 | return 220 | 221 | for epoch in range(args.start_epoch, args.epochs): 222 | if args.distributed: 223 | train_sampler.set_epoch(epoch) 224 | adjust_learning_rate(optimizer, epoch, args) 225 | 226 | # train for one epoch 227 | train(train_loader, model, criterion, optimizer, epoch, args) 228 | 229 | # evaluate on validation set 230 | acc1 = 
validate(val_loader, model, criterion, args) 231 | 232 | # remember best acc@1 and save checkpoint 233 | is_best = acc1 > best_acc1 234 | best_acc1 = max(acc1, best_acc1) 235 | 236 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 237 | and args.rank % ngpus_per_node == 0): 238 | save_checkpoint({ 239 | 'epoch': epoch + 1, 240 | 'state_dict': model.state_dict(), 241 | 'best_acc1': best_acc1, 242 | 'optimizer': optimizer.state_dict(), 243 | }, is_best) 244 | 245 | 246 | def train(train_loader, model, criterion, optimizer, epoch, args): 247 | batch_time = AverageMeter('Time', ':6.3f') 248 | data_time = AverageMeter('Data', ':6.3f') 249 | losses = AverageMeter('Loss', ':.4e') 250 | top1 = AverageMeter('Acc@1', ':6.2f') 251 | top5 = AverageMeter('Acc@5', ':6.2f') 252 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 253 | top5, prefix="Epoch: [{}]".format(epoch)) 254 | 255 | # switch to train mode 256 | model.train() 257 | 258 | end = time.time() 259 | for i, (input, target) in enumerate(train_loader): 260 | # measure data loading time 261 | data_time.update(time.time() - end) 262 | 263 | if args.gpu is not None: 264 | input = input.cuda(args.gpu, non_blocking=True) 265 | target = target.cuda(args.gpu, non_blocking=True) 266 | 267 | # compute output 268 | _, _, _, output, _ = model(input) 269 | loss = criterion(output, target) 270 | 271 | # measure accuracy and record loss 272 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 273 | losses.update(loss.item(), input.size(0)) 274 | top1.update(acc1[0], input.size(0)) 275 | top5.update(acc5[0], input.size(0)) 276 | 277 | # compute gradient and do SGD step 278 | optimizer.zero_grad() 279 | loss.backward() 280 | optimizer.step() 281 | 282 | # measure elapsed time 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | if i % args.print_freq == 0: 287 | progress.print(i) 288 | 289 | 290 | def validate(val_loader, model, criterion, args): 291 | batch_time = AverageMeter('Time', ':6.3f') 292 | losses = AverageMeter('Loss', ':.4e') 293 | top1 = AverageMeter('Acc@1', ':6.2f') 294 | top5 = AverageMeter('Acc@5', ':6.2f') 295 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 296 | prefix='Test: ') 297 | 298 | # switch to evaluate mode 299 | model.eval() 300 | 301 | with torch.no_grad(): 302 | end = time.time() 303 | for i, (input, target) in enumerate(val_loader): 304 | if args.gpu is not None: 305 | input = input.cuda(args.gpu, non_blocking=True) 306 | target = target.cuda(args.gpu, non_blocking=True) 307 | 308 | # compute output 309 | _, _, _, output, _ = model(input) 310 | loss = criterion(output, target) 311 | 312 | # measure accuracy and record loss 313 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 314 | losses.update(loss.item(), input.size(0)) 315 | top1.update(acc1[0], input.size(0)) 316 | top5.update(acc5[0], input.size(0)) 317 | 318 | # measure elapsed time 319 | batch_time.update(time.time() - end) 320 | end = time.time() 321 | 322 | if i % args.print_freq == 0: 323 | progress.print(i) 324 | 325 | # TODO: this should also be done with the ProgressMeter 326 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 327 | .format(top1=top1, top5=top5)) 328 | 329 | return top1.avg 330 | 331 | 332 | def save_checkpoint(state, is_best, filename='./checkpoints/XceptionA.pth.tar'): 333 | torch.save(state, filename) 334 | if is_best: 335 | shutil.copyfile(filename, 'XceptionA_best.pth.tar') 336 | 337 | 338 | class AverageMeter(object): 339 
| """Computes and stores the average and current value""" 340 | def __init__(self, name, fmt=':f'): 341 | self.name = name 342 | self.fmt = fmt 343 | self.reset() 344 | 345 | def reset(self): 346 | self.val = 0 347 | self.avg = 0 348 | self.sum = 0 349 | self.count = 0 350 | 351 | def update(self, val, n=1): 352 | self.val = val 353 | self.sum += val * n 354 | self.count += n 355 | self.avg = self.sum / self.count 356 | 357 | def __str__(self): 358 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 359 | return fmtstr.format(**self.__dict__) 360 | 361 | 362 | class ProgressMeter(object): 363 | def __init__(self, num_batches, *meters, prefix=""): 364 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 365 | self.meters = meters 366 | self.prefix = prefix 367 | 368 | def print(self, batch): 369 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 370 | entries += [str(meter) for meter in self.meters] 371 | print('\t'.join(entries)) 372 | 373 | def _get_batch_fmtstr(self, num_batches): 374 | num_digits = len(str(num_batches // 1)) 375 | fmt = '{:' + str(num_digits) + 'd}' 376 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 377 | 378 | 379 | def adjust_learning_rate(optimizer, epoch, args): 380 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 381 | lr = args.lr * (0.1 ** (epoch // 30)) 382 | for param_group in optimizer.param_groups: 383 | param_group['lr'] = lr 384 | 385 | 386 | def accuracy(output, target, topk=(1,)): 387 | """Computes the accuracy over the k top predictions for the specified values of k""" 388 | with torch.no_grad(): 389 | maxk = max(topk) 390 | batch_size = target.size(0) 391 | 392 | _, pred = output.topk(maxk, 1, True, True) 393 | pred = pred.t() 394 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 395 | 396 | res = [] 397 | for k in topk: 398 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 399 | res.append(correct_k.mul_(100.0 / batch_size)) 400 | return res 401 | 402 | 403 | if __name__ == '__main__': 404 | main() -------------------------------------------------------------------------------- /train_cityscapes.py: -------------------------------------------------------------------------------- 1 | # Adapted from official PyTorch Tutorial 2 | import argparse 3 | import os 4 | import random 5 | import shutil 6 | import time 7 | import warnings 8 | import sys 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | from datasets.cityscapes import Cityscapes 20 | from utils import joint_transforms 21 | from metric.iou import IoU 22 | from model.dfanet import DFANet 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch Cityscapes Training') 25 | parser.add_argument('data', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 28 | help='number of data loading workers (default: 4)') 29 | parser.add_argument('--epochs', default=500, type=int, metavar='N', 30 | help='number of total epochs to run') 31 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 32 | help='manual epoch number (useful on restarts)') 33 | parser.add_argument('-b', '--batch-size', default=4, type=int, 34 | metavar='N', 35 | help='mini-batch size (default: 256), this is the total ' 36 
| 'batch size of all GPUs on the current node when ' 37 | 'using Data Parallel or Distributed Data Parallel') 38 | parser.add_argument('--lr', '--learning-rate', default=2e-1, type=float, 39 | metavar='LR', help='initial learning rate', dest='lr') 40 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 41 | help='momentum') 42 | parser.add_argument('--power', default=0.9, type=float, metavar='M', 43 | help='power for poly learning rate policy') 44 | parser.add_argument('--wd', '--weight-decay', default=1e-5, type=float, 45 | metavar='W', help='weight decay (default: 1e-5)', 46 | dest='weight_decay') 47 | parser.add_argument('-p', '--print-freq', default=10, type=int, 48 | metavar='N', help='print frequency (default: 10)') 49 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 50 | help='path to latest checkpoint (default: none)') 51 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 52 | help='evaluate model on validation set') 53 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 54 | help='use pre-trained model') 55 | parser.add_argument('--world-size', default=-1, type=int, 56 | help='number of nodes for distributed training') 57 | parser.add_argument('--rank', default=-1, type=int, 58 | help='node rank for distributed training') 59 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 60 | help='url used to set up distributed training') 61 | parser.add_argument('--dist-backend', default='nccl', type=str, 62 | help='distributed backend') 63 | parser.add_argument('--seed', default=None, type=int, 64 | help='seed for initializing training. ') 65 | parser.add_argument('--gpu', default=None, type=int, 66 | help='GPU id to use.') 67 | parser.add_argument('--multiprocessing-distributed', action='store_true', 68 | help='Use multi-processing distributed training to launch ' 69 | 'N processes per node, which has N GPUs. This is the ' 70 | 'fastest way to use PyTorch for either single node or ' 71 | 'multi node data parallel training') 72 | 73 | best_mIoU = 0 74 | 75 | 76 | def main(): 77 | args = parser.parse_args() 78 | 79 | if args.seed is not None: 80 | random.seed(args.seed) 81 | torch.manual_seed(args.seed) 82 | cudnn.deterministic = True 83 | warnings.warn('You have chosen to seed training. ' 84 | 'This will turn on the CUDNN deterministic setting, ' 85 | 'which can slow down your training considerably! ' 86 | 'You may see unexpected behavior when restarting ' 87 | 'from checkpoints.') 88 | 89 | if args.gpu is not None: 90 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 91 | 'disable data parallelism.') 92 | 93 | if args.dist_url == "env://" and args.world_size == -1: 94 | args.world_size = int(os.environ["WORLD_SIZE"]) 95 | 96 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 97 | 98 | ngpus_per_node = torch.cuda.device_count() 99 | if args.multiprocessing_distributed: 100 | # Since we have ngpus_per_node processes per node, the total world_size 101 | # needs to be adjusted accordingly 102 | args.world_size = ngpus_per_node * args.world_size 103 | # Use torch.multiprocessing.spawn to launch distributed processes: the 104 | # main_worker process function 105 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 106 | else: 107 | # Simply call main_worker function 108 | main_worker(args.gpu, ngpus_per_node, args) 109 | 110 | 111 | def main_worker(gpu, ngpus_per_node, args): 112 | global best_mIoU 113 | args.gpu = gpu 114 | 115 | if args.gpu is not None: 116 | print("Use GPU: {} for training".format(args.gpu)) 117 | 118 | if args.distributed: 119 | if args.dist_url == "env://" and args.rank == -1: 120 | args.rank = int(os.environ["RANK"]) 121 | if args.multiprocessing_distributed: 122 | # For multiprocessing distributed training, rank needs to be the 123 | # global rank among all the processes 124 | args.rank = args.rank * ngpus_per_node + gpu 125 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 126 | world_size=args.world_size, rank=args.rank) 127 | # create model 128 | if args.pretrained: 129 | print("=> using pre-trained model 'DFANet'") 130 | model = DFANet(pretrained=True, pretrained_backbone=True) 131 | else: 132 | print("=> creating model 'DFANet'") 133 | model = DFANet(pretrained=False, pretrained_backbone=True) 134 | 135 | if args.distributed: 136 | # For multiprocessing distributed, DistributedDataParallel constructor 137 | # should always set the single device scope, otherwise, 138 | # DistributedDataParallel will use all available devices. 
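# The branches below choose the execution mode for DFANet: DistributedDataParallel pinned to a
# single GPU per process (the batch size and worker count are split across processes),
# DistributedDataParallel over all visible GPUs, one explicitly selected GPU without data
# parallelism, or plain DataParallel as the fallback.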
139 | if args.gpu is not None: 140 | torch.cuda.set_device(args.gpu) 141 | model.cuda(args.gpu) 142 | # When using a single GPU per process and per 143 | # DistributedDataParallel, we need to divide the batch size 144 | # ourselves based on the total number of GPUs we have 145 | args.batch_size = int(args.batch_size / ngpus_per_node) 146 | args.workers = int(args.workers / ngpus_per_node) 147 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 148 | else: 149 | model.cuda() 150 | # DistributedDataParallel will divide and allocate batch_size to all 151 | # available GPUs if device_ids are not set 152 | model = torch.nn.parallel.DistributedDataParallel(model) 153 | elif args.gpu is not None: 154 | torch.cuda.set_device(args.gpu) 155 | model = model.cuda(args.gpu) 156 | else: 157 | # DataParallel will divide and allocate batch_size to all available GPUs 158 | model = torch.nn.DataParallel(model).cuda() 159 | 160 | # define loss function (criterion) and optimizer 161 | criterion = nn.CrossEntropyLoss(ignore_index=19).cuda(args.gpu) 162 | 163 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 164 | momentum=args.momentum, 165 | weight_decay=args.weight_decay) 166 | 167 | metric = IoU(20, ignore_index=19) 168 | 169 | # optionally resume from a checkpoint 170 | if args.resume: 171 | if os.path.isfile(args.resume): 172 | print("=> loading checkpoint '{}'".format(args.resume)) 173 | checkpoint = torch.load(args.resume) 174 | args.start_epoch = checkpoint['epoch'] 175 | best_mIoU = checkpoint['best_mIoU'] 176 | if args.gpu is not None: 177 | # best_mIoU may be from a checkpoint from a different GPU 178 | best_mIoU = best_mIoU.to(args.gpu) 179 | model.load_state_dict(checkpoint['state_dict']) 180 | optimizer.load_state_dict(checkpoint['optimizer']) 181 | print("=> loaded checkpoint '{}' (epoch {})" 182 | .format(args.resume, checkpoint['epoch'])) 183 | else: 184 | print("=> no checkpoint found at '{}'".format(args.resume)) 185 | 186 | cudnn.benchmark = True 187 | 188 | # Data loading code 189 | train_dataset = Cityscapes(args.data, split='train', mode='fine', target_type='semantic', 190 | transform=joint_transforms.Compose([ 191 | joint_transforms.RandomHorizontalFlip(), 192 | joint_transforms.RandomSized(1024), 193 | joint_transforms.ToTensor(), 194 | joint_transforms.Normalize( 195 | mean=[0.485, 0.456, 0.406], 196 | std=[0.229, 0.224, 0.225]) 197 | ])) 198 | 199 | if args.distributed: 200 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 201 | else: 202 | train_sampler = None 203 | 204 | train_loader = torch.utils.data.DataLoader( 205 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 206 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 207 | 208 | val_loader = torch.utils.data.DataLoader( 209 | Cityscapes(args.data, split='val', mode='fine', target_type='semantic', 210 | transform=joint_transforms.Compose([ 211 | joint_transforms.RandomHorizontalFlip(), 212 | joint_transforms.RandomSized(1024), 213 | joint_transforms.ToTensor(), 214 | joint_transforms.Normalize( 215 | mean=[0.485, 0.456, 0.406], 216 | std=[0.229, 0.224, 0.225]) 217 | ])), 218 | batch_size=args.batch_size, shuffle=False, 219 | num_workers=args.workers, pin_memory=True) 220 | 221 | if args.evaluate: 222 | validate(val_loader, model, criterion, metric, args) 223 | return 224 | 225 | for epoch in range(args.start_epoch, args.epochs): 226 | if args.distributed: 227 | train_sampler.set_epoch(epoch) 228 | 229 | 
adjust_learning_rate(optimizer, epoch, args) 230 | 231 | # train for one epoch 232 | train(train_loader, model, criterion, optimizer, metric, epoch, args) 233 | 234 | # evaluate on validation set 235 | mIoU = validate(val_loader, model, criterion, metric, args) 236 | 237 | # remember best mIoU and save checkpoint 238 | is_best = mIoU > best_mIoU 239 | best_mIoU = max(mIoU, best_mIoU) 240 | 241 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 242 | and args.rank % ngpus_per_node == 0): 243 | save_checkpoint({ 244 | 'epoch': epoch + 1, 245 | 'state_dict': model.state_dict(), 246 | 'best_mIoU': best_mIoU, 247 | 'optimizer': optimizer.state_dict(), 248 | }, is_best) 249 | 250 | 251 | def train(train_loader, model, criterion, optimizer, metric, epoch, args): 252 | batch_time = AverageMeter('Time', ':6.3f') 253 | data_time = AverageMeter('Data', ':6.3f') 254 | losses = AverageMeter('Loss', ':.4e') 255 | mIoU = AverageMeter('mIoU', ':6.2f') 256 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, mIoU, prefix="Epoch: [{}]".format(epoch)) 257 | 258 | # switch to train mode 259 | model.train() 260 | 261 | end = time.time() 262 | for i, (input, target) in enumerate(train_loader): 263 | # measure data loading time 264 | data_time.update(time.time() - end) 265 | 266 | if args.gpu is not None: 267 | input = input.cuda(args.gpu, non_blocking=True) 268 | target = target.cuda(args.gpu, non_blocking=True) 269 | 270 | # compute output 271 | output = model(input) 272 | loss = criterion(output.view(output.shape[0], 19, -1), target.view(target.shape[0], -1)) 273 | 274 | # measure accuracy and record loss 275 | losses.update(loss.item(), input.size(0)) 276 | metric.reset() 277 | metric.add(output.max(1)[1].view(output.shape[0], 1024, 1024), target.view(target.shape[0], 1024, 1024)) 278 | mIoU.update(metric.value()[1]) 279 | 280 | # compute gradient and do SGD step 281 | optimizer.zero_grad() 282 | loss.backward() 283 | optimizer.step() 284 | 285 | # measure elapsed time 286 | batch_time.update(time.time() - end) 287 | end = time.time() 288 | 289 | if i % args.print_freq == 0: 290 | progress.print(i) 291 | 292 | 293 | def validate(val_loader, model, criterion, metric, args): 294 | batch_time = AverageMeter('Time', ':6.3f') 295 | losses = AverageMeter('Loss', ':.4e') 296 | mIoU = AverageMeter('mIoU', ':6.2f') 297 | progress = ProgressMeter(len(val_loader), batch_time, losses, mIoU, prefix='Test: ') 298 | 299 | # switch to evaluate mode 300 | model.eval() 301 | 302 | with torch.no_grad(): 303 | end = time.time() 304 | for i, (input, target) in enumerate(val_loader): 305 | if args.gpu is not None: 306 | input = input.cuda(args.gpu, non_blocking=True) 307 | target = target.cuda(args.gpu, non_blocking=True) 308 | 309 | # compute output 310 | output = model(input) 311 | loss = criterion(output.view(output.shape[0], 19, -1), target.view(target.shape[0], -1)) 312 | 313 | # measure mIoU and record loss 314 | losses.update(loss.item(), input.size(0)) 315 | metric.reset() 316 | metric.add(output.max(1)[1].view(output.shape[0], 1024, 1024), target.view(target.shape[0], 1024, 1024)) 317 | mIoU.update(metric.value()[1]) 318 | 319 | # measure elapsed time 320 | batch_time.update(time.time() - end) 321 | end = time.time() 322 | 323 | if i % args.print_freq == 0: 324 | progress.print(i) 325 | 326 | print(' * mIoU {mIoU.avg:.3f}'.format(mIoU=mIoU)) 327 | 328 | return mIoU.avg # ? 
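# Note on the metric bookkeeping above: the IoU metric is reset at every iteration, so the value
# fed to the mIoU AverageMeter is a per-batch mean IoU and mIoU.avg is the average of those
# per-batch scores rather than the IoU accumulated over the whole split; main_worker() compares
# this returned average against best_mIoU to decide which checkpoint to keep.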
329 | 330 | 331 | def save_checkpoint(state, is_best, filename='./checkpoints/Cityscapes.pth.tar'): 332 | torch.save(state, filename) 333 | if is_best: 334 | shutil.copyfile(filename, 'Cityscapes_best.pth.tar') 335 | 336 | 337 | class AverageMeter(object): 338 | """Computes and stores the average and current value""" 339 | def __init__(self, name, fmt=':f'): 340 | self.name = name 341 | self.fmt = fmt 342 | self.reset() 343 | 344 | def reset(self): 345 | self.val = 0 346 | self.avg = 0 347 | self.sum = 0 348 | self.count = 0 349 | 350 | def update(self, val, n=1): 351 | self.val = val 352 | self.sum += val * n 353 | self.count += n 354 | self.avg = self.sum / self.count 355 | 356 | def __str__(self): 357 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 358 | return fmtstr.format(**self.__dict__) 359 | 360 | 361 | class ProgressMeter(object): 362 | def __init__(self, num_batches, *meters, prefix=""): 363 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 364 | self.meters = meters 365 | self.prefix = prefix 366 | 367 | def print(self, batch): 368 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 369 | entries += [str(meter) for meter in self.meters] 370 | print('\t'.join(entries)) 371 | 372 | def _get_batch_fmtstr(self, num_batches): 373 | num_digits = len(str(num_batches // 1)) 374 | fmt = '{:' + str(num_digits) + 'd}' 375 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 376 | 377 | 378 | def adjust_learning_rate(optimizer, epoch, args): 379 | """Polynomial decay learning rate policy.""" 380 | lr = args.lr * (1 - epoch/args.epochs)**args.power 381 | for param_group in optimizer.param_groups: 382 | param_group['lr'] = lr 383 | 384 | 385 | if __name__ == '__main__': 386 | main() 387 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jandylin/DFANet_PyTorch/724ea05ba8b44d2dedb0a3389b4dfaaf57f5a2b6/utils/__init__.py -------------------------------------------------------------------------------- /utils/joint_transforms.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 2 | import math 3 | import numbers 4 | import random 5 | 6 | from PIL import Image, ImageOps 7 | import numpy as np 8 | import torchvision.transforms.functional as F 9 | 10 | 11 | class Compose(object): 12 | def __init__(self, transforms): 13 | self.transforms = transforms 14 | 15 | def __call__(self, img, mask): 16 | assert img.size == mask.size 17 | for t in self.transforms: 18 | img, mask = t(img, mask) 19 | return img, mask 20 | 21 | 22 | _trainID_map = {0: 19, 23 | 1: 19, 24 | 2: 19, 25 | 3: 19, 26 | 4: 19, 27 | 5: 19, 28 | 6: 19, 29 | 7: 0, 30 | 8: 1, 31 | 9: 19, 32 | 10: 19, 33 | 11: 2, 34 | 12: 3, 35 | 13: 4, 36 | 14: 19, 37 | 15: 19, 38 | 16: 19, 39 | 17: 5, 40 | 18: 19, 41 | 19: 6, 42 | 20: 7, 43 | 21: 8, 44 | 22: 9, 45 | 23: 10, 46 | 24: 11, 47 | 25: 12, 48 | 26: 13, 49 | 27: 14, 50 | 28: 15, 51 | 29: 19, 52 | 30: 19, 53 | 31: 16, 54 | 32: 17, 55 | 33: 18, 56 | -1: 19} 57 | 58 | 59 | def _cityscapes_trainID_map(id): 60 | return _trainID_map[id] 61 | 62 | 63 | class ToTensor(object): 64 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 
65 | 66 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 67 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 68 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 69 | or if the numpy.ndarray has dtype = np.uint8 70 | 71 | In the other cases, tensors are returned without scaling. 72 | """ 73 | 74 | def __call__(self, pic, mask): 75 | """ 76 | Args: 77 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 78 | 79 | Returns: 80 | Tensor: Converted image. 81 | """ 82 | mask = np.array(mask, dtype=np.int) 83 | trainID_map = np.vectorize(_cityscapes_trainID_map) 84 | mask = trainID_map(mask) 85 | return F.to_tensor(pic), F.to_tensor(mask) 86 | 87 | 88 | class Normalize(object): 89 | """Normalize a tensor image with mean and standard deviation. 90 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 91 | will normalize each channel of the input ``torch.*Tensor`` i.e. 92 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 93 | 94 | .. note:: 95 | This transform acts out of place, i.e., it does not mutates the input tensor. 96 | 97 | Args: 98 | mean (sequence): Sequence of means for each channel. 99 | std (sequence): Sequence of standard deviations for each channel. 100 | """ 101 | 102 | def __init__(self, mean, std, inplace=False): 103 | self.mean = mean 104 | self.std = std 105 | self.inplace = inplace 106 | 107 | def __call__(self, tensor, mask): 108 | """ 109 | Args: 110 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 111 | 112 | Returns: 113 | Tensor: Normalized Tensor image. 114 | """ 115 | return F.normalize(tensor, self.mean, self.std, self.inplace), mask 116 | 117 | 118 | class RandomCrop(object): 119 | def __init__(self, size, padding=0): 120 | if isinstance(size, numbers.Number): 121 | self.size = (int(size), int(size)) 122 | else: 123 | self.size = size 124 | self.padding = padding 125 | 126 | def __call__(self, img, mask): 127 | if self.padding > 0: 128 | img = ImageOps.expand(img, border=self.padding, fill=0) 129 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 130 | 131 | assert img.size == mask.size 132 | w, h = img.size 133 | th, tw = self.size 134 | if w == tw and h == th: 135 | return img, mask 136 | if w < tw or h < th: 137 | return img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST) 138 | 139 | x1 = random.randint(0, w - tw) 140 | y1 = random.randint(0, h - th) 141 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 142 | 143 | 144 | class CenterCrop(object): 145 | def __init__(self, size): 146 | if isinstance(size, numbers.Number): 147 | self.size = (int(size), int(size)) 148 | else: 149 | self.size = size 150 | 151 | def __call__(self, img, mask): 152 | assert img.size == mask.size 153 | w, h = img.size 154 | th, tw = self.size 155 | x1 = int(round((w - tw) / 2.)) 156 | y1 = int(round((h - th) / 2.)) 157 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 158 | 159 | 160 | class RandomHorizontalFlip(object): 161 | def __call__(self, img, mask): 162 | if random.random() < 0.5: 163 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT) 164 | return img, mask 165 | 166 | 167 | class FreeScale(object): 168 | def __init__(self, size): 169 | self.size = tuple(reversed(size)) # size: (h, w) 170 | 171 | def __call__(self, img, mask): 172 | assert img.size == mask.size 173 | return 
img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST) 174 | 175 | 176 | class Scale(object): 177 | def __init__(self, size): 178 | self.size = size 179 | 180 | def __call__(self, img, mask): 181 | assert img.size == mask.size 182 | w, h = img.size 183 | if (w >= h and w == self.size) or (h >= w and h == self.size): 184 | return img, mask 185 | if w > h: 186 | ow = self.size 187 | oh = int(self.size * h / w) 188 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 189 | else: 190 | oh = self.size 191 | ow = int(self.size * w / h) 192 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 193 | 194 | 195 | class RandomSizedCrop(object): 196 | def __init__(self, size): 197 | self.size = size 198 | 199 | def __call__(self, img, mask): 200 | assert img.size == mask.size 201 | for attempt in range(10): 202 | area = img.size[0] * img.size[1] 203 | target_area = random.uniform(0.45, 1.0) * area 204 | aspect_ratio = random.uniform(0.5, 2) 205 | 206 | w = int(round(math.sqrt(target_area * aspect_ratio))) 207 | h = int(round(math.sqrt(target_area / aspect_ratio))) 208 | 209 | if random.random() < 0.5: 210 | w, h = h, w 211 | 212 | if w <= img.size[0] and h <= img.size[1]: 213 | x1 = random.randint(0, img.size[0] - w) 214 | y1 = random.randint(0, img.size[1] - h) 215 | 216 | img = img.crop((x1, y1, x1 + w, y1 + h)) 217 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 218 | assert (img.size == (w, h)) 219 | 220 | return img.resize((self.size, self.size), Image.BILINEAR), mask.resize((self.size, self.size), 221 | Image.NEAREST) 222 | 223 | # Fallback 224 | scale = Scale(self.size) 225 | crop = CenterCrop(self.size) 226 | return crop(*scale(img, mask)) 227 | 228 | 229 | class RandomRotate(object): 230 | def __init__(self, degree): 231 | self.degree = degree 232 | 233 | def __call__(self, img, mask): 234 | rotate_degree = random.random() * 2 * self.degree - self.degree 235 | return img.rotate(rotate_degree, Image.BILINEAR), mask.rotate(rotate_degree, Image.NEAREST) 236 | 237 | 238 | class RandomSized(object): 239 | def __init__(self, size): 240 | self.size = size 241 | self.scale = Scale(self.size) 242 | self.crop = RandomCrop(self.size) 243 | 244 | def __call__(self, img, mask): 245 | assert img.size == mask.size 246 | 247 | w = int(random.uniform(0.75, 1.75) * img.size[0]) 248 | h = int(random.uniform(0.75, 1.75) * img.size[1]) 249 | 250 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 251 | 252 | return self.crop(*self.scale(img, mask)) 253 | 254 | 255 | class SlidingCropOld(object): 256 | def __init__(self, crop_size, stride_rate, ignore_label): 257 | self.crop_size = crop_size 258 | self.stride_rate = stride_rate 259 | self.ignore_label = ignore_label 260 | 261 | def _pad(self, img, mask): 262 | h, w = img.shape[: 2] 263 | pad_h = max(self.crop_size - h, 0) 264 | pad_w = max(self.crop_size - w, 0) 265 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 266 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 267 | return img, mask 268 | 269 | def __call__(self, img, mask): 270 | assert img.size == mask.size 271 | 272 | w, h = img.size 273 | long_size = max(h, w) 274 | 275 | img = np.array(img) 276 | mask = np.array(mask) 277 | 278 | if long_size > self.crop_size: 279 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 280 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 281 | w_step_num = 
int(math.ceil((w - self.crop_size) / float(stride))) + 1 282 | img_sublist, mask_sublist = [], [] 283 | for yy in range(h_step_num): 284 | for xx in range(w_step_num): 285 | sy, sx = yy * stride, xx * stride 286 | ey, ex = sy + self.crop_size, sx + self.crop_size 287 | img_sub = img[sy: ey, sx: ex, :] 288 | mask_sub = mask[sy: ey, sx: ex] 289 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 290 | img_sublist.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 291 | mask_sublist.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 292 | return img_sublist, mask_sublist 293 | else: 294 | img, mask = self._pad(img, mask) 295 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 296 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 297 | return img, mask 298 | 299 | 300 | class SlidingCrop(object): 301 | def __init__(self, crop_size, stride_rate, ignore_label): 302 | self.crop_size = crop_size 303 | self.stride_rate = stride_rate 304 | self.ignore_label = ignore_label 305 | 306 | def _pad(self, img, mask): 307 | h, w = img.shape[: 2] 308 | pad_h = max(self.crop_size - h, 0) 309 | pad_w = max(self.crop_size - w, 0) 310 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 311 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', constant_values=self.ignore_label) 312 | return img, mask, h, w 313 | 314 | def __call__(self, img, mask): 315 | assert img.size == mask.size 316 | 317 | w, h = img.size 318 | long_size = max(h, w) 319 | 320 | img = np.array(img) 321 | mask = np.array(mask) 322 | 323 | if long_size > self.crop_size: 324 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 325 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 326 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 327 | img_slices, mask_slices, slices_info = [], [], [] 328 | for yy in range(h_step_num): 329 | for xx in range(w_step_num): 330 | sy, sx = yy * stride, xx * stride 331 | ey, ex = sy + self.crop_size, sx + self.crop_size 332 | img_sub = img[sy: ey, sx: ex, :] 333 | mask_sub = mask[sy: ey, sx: ex] 334 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 335 | img_slices.append(Image.fromarray(img_sub.astype(np.uint8)).convert('RGB')) 336 | mask_slices.append(Image.fromarray(mask_sub.astype(np.uint8)).convert('P')) 337 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w]) 338 | return img_slices, mask_slices, slices_info 339 | else: 340 | img, mask, sub_h, sub_w = self._pad(img, mask) 341 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 342 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 343 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 344 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # from https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/master/utils/misc.py 2 | import os 3 | from math import ceil 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.autograd import Variable 10 | 11 | 12 | def check_mkdir(dir_name): 13 | if not os.path.exists(dir_name): 14 | os.mkdir(dir_name) 15 | 16 | 17 | def initialize_weights(*models): 18 | for model in models: 19 | for module in model.modules(): 20 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 21 | nn.init.kaiming_normal(module.weight) 22 | if module.bias is not 
None: 23 | module.bias.data.zero_() 24 | elif isinstance(module, nn.BatchNorm2d): 25 | module.weight.data.fill_(1) 26 | module.bias.data.zero_() 27 | 28 | 29 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 30 | factor = (kernel_size + 1) // 2 31 | if kernel_size % 2 == 1: 32 | center = factor - 1 33 | else: 34 | center = factor - 0.5 35 | og = np.ogrid[:kernel_size, :kernel_size] 36 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 37 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 38 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt 39 | return torch.from_numpy(weight).float() 40 | 41 | 42 | class CrossEntropyLoss2d(nn.Module): 43 | def __init__(self, weight=None, size_average=True, ignore_index=255): 44 | super(CrossEntropyLoss2d, self).__init__() 45 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 46 | 47 | def forward(self, inputs, targets): 48 | return self.nll_loss(F.log_softmax(inputs), targets) 49 | 50 | 51 | class FocalLoss2d(nn.Module): 52 | def __init__(self, gamma=2, weight=None, size_average=True, ignore_index=255): 53 | super(FocalLoss2d, self).__init__() 54 | self.gamma = gamma 55 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 56 | 57 | def forward(self, inputs, targets): 58 | return self.nll_loss((1 - F.softmax(inputs)) ** self.gamma * F.log_softmax(inputs), targets) 59 | 60 | 61 | def _fast_hist(label_pred, label_true, num_classes): 62 | mask = (label_true >= 0) & (label_true < num_classes) 63 | hist = np.bincount( 64 | num_classes * label_true[mask].astype(int) + 65 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 66 | return hist 67 | 68 | 69 | def evaluate(predictions, gts, num_classes): 70 | hist = np.zeros((num_classes, num_classes)) 71 | for lp, lt in zip(predictions, gts): 72 | hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes) 73 | # axis 0: gt, axis 1: prediction 74 | acc = np.diag(hist).sum() / hist.sum() 75 | acc_cls = np.diag(hist) / hist.sum(axis=1) 76 | acc_cls = np.nanmean(acc_cls) 77 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 78 | mean_iu = np.nanmean(iu) 79 | freq = hist.sum(axis=1) / hist.sum() 80 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 81 | return acc, acc_cls, mean_iu, fwavacc 82 | 83 | 84 | class AverageMeter(object): 85 | def __init__(self): 86 | self.reset() 87 | 88 | def reset(self): 89 | self.val = 0 90 | self.avg = 0 91 | self.sum = 0 92 | self.count = 0 93 | 94 | def update(self, val, n=1): 95 | self.val = val 96 | self.sum += val * n 97 | self.count += n 98 | self.avg = self.sum / self.count 99 | 100 | 101 | class PolyLR(object): 102 | def __init__(self, optimizer, curr_iter, max_iter, lr_decay): 103 | self.max_iter = float(max_iter) 104 | self.init_lr_groups = [] 105 | for p in optimizer.param_groups: 106 | self.init_lr_groups.append(p['lr']) 107 | self.param_groups = optimizer.param_groups 108 | self.curr_iter = curr_iter 109 | self.lr_decay = lr_decay 110 | 111 | def step(self): 112 | for idx, p in enumerate(self.param_groups): 113 | p['lr'] = self.init_lr_groups[idx] * (1 - self.curr_iter / self.max_iter) ** self.lr_decay 114 | 115 | 116 | # just a try, not recommend to use 117 | class Conv2dDeformable(nn.Module): 118 | def __init__(self, regular_filter, cuda=True): 119 | super(Conv2dDeformable, self).__init__() 120 | assert isinstance(regular_filter, nn.Conv2d) 121 | self.regular_filter = 
regular_filter 122 | self.offset_filter = nn.Conv2d(regular_filter.in_channels, 2 * regular_filter.in_channels, kernel_size=3, 123 | padding=1, bias=False) 124 | self.offset_filter.weight.data.normal_(0, 0.0005) 125 | self.input_shape = None 126 | self.grid_w = None 127 | self.grid_h = None 128 | self.cuda = cuda 129 | 130 | def forward(self, x): 131 | x_shape = x.size() # (b, c, h, w) 132 | offset = self.offset_filter(x) # (b, 2*c, h, w) 133 | offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1) # (b, c, h, w) 134 | offset_w = offset_w.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 135 | offset_h = offset_h.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])) # (b*c, h, w) 136 | if not self.input_shape or self.input_shape != x_shape: 137 | self.input_shape = x_shape 138 | grid_w, grid_h = np.meshgrid(np.linspace(-1, 1, x_shape[3]), np.linspace(-1, 1, x_shape[2])) # (h, w) 139 | grid_w = torch.Tensor(grid_w) 140 | grid_h = torch.Tensor(grid_h) 141 | if self.cuda: 142 | grid_w = grid_w.cuda() 143 | grid_h = grid_h.cuda() 144 | self.grid_w = nn.Parameter(grid_w) 145 | self.grid_h = nn.Parameter(grid_h) 146 | offset_w = offset_w + self.grid_w # (b*c, h, w) 147 | offset_h = offset_h + self.grid_h # (b*c, h, w) 148 | x = x.contiguous().view(-1, int(x_shape[2]), int(x_shape[3])).unsqueeze(1) # (b*c, 1, h, w) 149 | x = F.grid_sample(x, torch.stack((offset_h, offset_w), 3)) # (b*c, h, w) 150 | x = x.contiguous().view(-1, int(x_shape[1]), int(x_shape[2]), int(x_shape[3])) # (b, c, h, w) 151 | x = self.regular_filter(x) 152 | return x 153 | 154 | 155 | def sliced_forward(single_forward): 156 | def _pad(x, crop_size): 157 | h, w = x.size()[2:] 158 | pad_h = max(crop_size - h, 0) 159 | pad_w = max(crop_size - w, 0) 160 | x = F.pad(x, (0, pad_w, 0, pad_h)) 161 | return x, pad_h, pad_w 162 | 163 | def wrapper(self, x): 164 | batch_size, _, ori_h, ori_w = x.size() 165 | if self.training and self.use_aux: 166 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 167 | aux_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 168 | for s in self.scales: 169 | new_size = (int(ori_h * s), int(ori_w * s)) 170 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 171 | scaled_x = Variable(scaled_x).cuda() 172 | scaled_h, scaled_w = scaled_x.size()[2:] 173 | long_size = max(scaled_h, scaled_w) 174 | print(scaled_x.size()) 175 | 176 | if long_size > self.crop_size: 177 | count = torch.zeros((scaled_h, scaled_w)) 178 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 179 | aux_outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 180 | stride = int(ceil(self.crop_size * self.stride_rate)) 181 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 182 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 183 | for yy in range(h_step_num): 184 | for xx in range(w_step_num): 185 | sy, sx = yy * stride, xx * stride 186 | ey, ex = sy + self.crop_size, sx + self.crop_size 187 | x_sub = scaled_x[:, :, sy: ey, sx: ex] 188 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 189 | print(x_sub.size()) 190 | outputs_sub, aux_sub = single_forward(self, x_sub) 191 | 192 | if sy + self.crop_size > scaled_h: 193 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 194 | aux_sub = aux_sub[:, :, : -pad_h, :] 195 | 196 | if sx + self.crop_size > scaled_w: 197 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 
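# (Border tiles were zero-padded up to crop_size before the forward pass, so the padded rows and
# columns are sliced off here, for the auxiliary prediction as well, before the tile is written
# back into the full-resolution output and the per-pixel count map.)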
198 | aux_sub = aux_sub[:, :, :, : -pad_w] 199 | 200 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 201 | aux_outputs[:, :, sy: ey, sx: ex] = aux_sub 202 | 203 | count[sy: ey, sx: ex] += 1 204 | count = Variable(count).cuda() 205 | outputs = (outputs / count) 206 | aux_outputs = (outputs / count) 207 | else: 208 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 209 | outputs, aux_outputs = single_forward(self, scaled_x) 210 | outputs = outputs[:, :, : -pad_h, : -pad_w] 211 | aux_outputs = aux_outputs[:, :, : -pad_h, : -pad_w] 212 | outputs_all_scales += outputs 213 | aux_all_scales += aux_outputs 214 | return outputs_all_scales / len(self.scales), aux_all_scales 215 | else: 216 | outputs_all_scales = Variable(torch.zeros((batch_size, self.num_classes, ori_h, ori_w))).cuda() 217 | for s in self.scales: 218 | new_size = (int(ori_h * s), int(ori_w * s)) 219 | scaled_x = F.upsample(x, size=new_size, mode='bilinear') 220 | scaled_h, scaled_w = scaled_x.size()[2:] 221 | long_size = max(scaled_h, scaled_w) 222 | 223 | if long_size > self.crop_size: 224 | count = torch.zeros((scaled_h, scaled_w)) 225 | outputs = Variable(torch.zeros((batch_size, self.num_classes, scaled_h, scaled_w))).cuda() 226 | stride = int(ceil(self.crop_size * self.stride_rate)) 227 | h_step_num = int(ceil((scaled_h - self.crop_size) / stride)) + 1 228 | w_step_num = int(ceil((scaled_w - self.crop_size) / stride)) + 1 229 | for yy in range(h_step_num): 230 | for xx in range(w_step_num): 231 | sy, sx = yy * stride, xx * stride 232 | ey, ex = sy + self.crop_size, sx + self.crop_size 233 | x_sub = scaled_x[:, :, sy: ey, sx: ex] 234 | x_sub, pad_h, pad_w = _pad(x_sub, self.crop_size) 235 | 236 | outputs_sub = single_forward(self, x_sub) 237 | 238 | if sy + self.crop_size > scaled_h: 239 | outputs_sub = outputs_sub[:, :, : -pad_h, :] 240 | 241 | if sx + self.crop_size > scaled_w: 242 | outputs_sub = outputs_sub[:, :, :, : -pad_w] 243 | 244 | outputs[:, :, sy: ey, sx: ex] = outputs_sub 245 | 246 | count[sy: ey, sx: ex] += 1 247 | count = Variable(count).cuda() 248 | outputs = (outputs / count) 249 | else: 250 | scaled_x, pad_h, pad_w = _pad(scaled_x, self.crop_size) 251 | outputs = single_forward(self, scaled_x) 252 | outputs = outputs[:, :, : -pad_h, : -pad_w] 253 | outputs_all_scales += outputs 254 | return outputs_all_scales 255 | 256 | return wrapper -------------------------------------------------------------------------------- /utils/transforms.py: -------------------------------------------------------------------------------- 1 | # from https://raw.githubusercontent.com/zijundeng/pytorch-semantic-segmentation/master/utils/transforms.py 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image, ImageFilter 7 | 8 | 9 | class RandomVerticalFlip(object): 10 | def __call__(self, img): 11 | if random.random() < 0.5: 12 | return img.transpose(Image.FLIP_TOP_BOTTOM) 13 | return img 14 | 15 | 16 | class DeNormalize(object): 17 | def __init__(self, mean, std): 18 | self.mean = mean 19 | self.std = std 20 | 21 | def __call__(self, tensor): 22 | for t, m, s in zip(tensor, self.mean, self.std): 23 | t.mul_(s).add_(m) 24 | return tensor 25 | 26 | 27 | class MaskToTensor(object): 28 | def __call__(self, img): 29 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 30 | 31 | 32 | class FreeScale(object): 33 | def __init__(self, size, interpolation=Image.BILINEAR): 34 | self.size = tuple(reversed(size)) # size: (h, w) 35 | self.interpolation = interpolation 36 | 37 | def __call__(self, img): 
38 | return img.resize(self.size, self.interpolation) 39 | 40 | 41 | class FlipChannels(object): 42 | def __call__(self, img): 43 | img = np.array(img)[:, :, ::-1] 44 | return Image.fromarray(img.astype(np.uint8)) 45 | 46 | --------------------------------------------------------------------------------
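Appendix (added for this write-up; not a file in the repository): a small, self-contained sketch of how the confusion-matrix helpers in utils/misc.py turn predictions into the usual segmentation scores. The 2x2 label maps and the two-class setup are invented purely for illustration.

import numpy as np
from utils.misc import evaluate

# Toy prediction and ground-truth label maps (ground truth ends up on axis 0 of the confusion matrix).
pred = np.array([[0, 0],
                 [1, 1]])
gt = np.array([[0, 1],
               [1, 1]])

# evaluate() accepts lists of label maps and accumulates a single confusion matrix:
# hist = [[1, 0],
#         [1, 2]]
acc, acc_cls, mean_iu, fwavacc = evaluate([pred], [gt], num_classes=2)

# acc     = (1 + 2) / 4           = 0.75          overall pixel accuracy
# acc_cls = mean(1/1, 2/3)        = 5/6  ~ 0.83   mean per-class accuracy
# mean_iu = mean(1/2, 2/3)        = 7/12 ~ 0.58   mean intersection over union
# fwavacc = 1/4 * 1/2 + 3/4 * 2/3 = 5/8  = 0.625  frequency-weighted IoU
print(acc, acc_cls, mean_iu, fwavacc)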