├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base.py ├── base_seg.py ├── cityscapes.py └── sampler.py ├── demo.py ├── eval.py ├── model ├── __init__.py ├── basic.py ├── lednet.py ├── loss.py └── lr_scheduler.py ├── png ├── demo.png ├── gt.png └── output.png ├── train.py └── utils ├── __init__.py ├── logger.py ├── metric.py ├── metric_seg.py ├── parallel.py ├── util.py └── visual.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ace 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LEDNet 2 | This is an unofficial implement of [LEDNet](https://arxiv.org/abs/1905.02423). 3 | 4 | > the official version:[LEDNet-official](https://github.com/xiaoyufenfei/LEDNet) 5 | 6 |
7 | 8 |
9 | 10 | ## Environment 11 | 12 | - Python 3.6 13 | - PyTorch 1.1 14 | 15 | ## Performance 16 | 17 | - Base Size 1024, Crop Size 768, fine annotations only (new version, with dropout) 18 | 19 | | Model | Paper | OHEM | Drop-rate | lr | Epoch | val (crop) | val | 20 | | :----: | :---: | :--: | :-------: | :----: | :---: | :---------: | :----------------------------------------------------------: | 21 | | LEDNet | / | ✗ | 0.1 | 0.0005 | 800 | 60.32/94.51 | 66.29/94.40 | 22 | | LEDNet | / | ✗ | 0.1 | 0.005 | 600 | 61.29/94.75 | 66.56/94.72 | 23 | | LEDNet | / | ✗ | 0.3 | 0.01 | 800 | 63.84/94.83 | [69.09/94.75](https://drive.google.com/open?id=1oelPUKAnZYD75RruyBQU9HZKneMEMIAp) | 24 | 25 | > Note: 26 | > 27 | > - The paper only provides the test results: 69.2/86.8 (class mIoU/category mIoU). 28 | > - The training setting here differs slightly from the original paper (which uses 1024x512). 29 | 30 | Some things you can try to improve performance: 31 | 32 | 1. use a larger learning rate (e.g. 0.01) 33 | 2. train for more epochs (e.g. 1000) 34 | 3. use a larger training input size (e.g. Base Size 1344, Crop Size 1024) 35 | 36 | ## Demo 37 | 38 | Please download the [pretrained](https://drive.google.com/open?id=1oelPUKAnZYD75RruyBQU9HZKneMEMIAp) model first. 39 | 40 | ```shell 41 | $ python demo.py [--input-pic png/demo.png] [--pretrained your-root-of-pretrained] [--cuda true] 42 | ``` 43 | 44 | ## Evaluation 45 | 46 | The default data root is `~/.torch/datasets` (you can download the dataset and create a soft link to it). 47 | 48 | ```shell 49 | $ python eval.py [--mode testval] [--pretrained root-of-pretrained-model] [--cuda true] 50 | ``` 51 | 52 | ## Training 53 | 54 | Distributed training is recommended. 55 | 56 | ```shell 57 | $ export NGPUS=4 58 | $ python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py [--dataset citys] [--batch-size 8] [--base-size 1024] [--crop-size 768] [--epochs 800] [--warmup-factor 0.1] [--warmup-iters 200] [--log-step 10] [--save-epoch 40] [--lr 0.005] 59 | ``` 60 | 61 | ## Prepare data 62 | 63 | You can refer to [gluon-cv-cityscapes](https://gluon-cv.mxnet.io/build/examples_datasets/cityscapes.html#sphx-glr-build-examples-datasets-cityscapes-py) to prepare the dataset. 64 | 65 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .cityscapes import CitySegmentation 2 | 3 | datasets = { 4 | 'citys': CitySegmentation, 5 | } 6 | 7 | 8 | def get_segmentation_dataset(name, **kwargs): 9 | """Segmentation Datasets""" 10 | return datasets[name.lower()](**kwargs) 11 | -------------------------------------------------------------------------------- /data/base.py: -------------------------------------------------------------------------------- 1 | """Base dataset methods.""" 2 | import os 3 | from torch.utils import data 4 | 5 | 6 | class ClassProperty(object): 7 | """Readonly @ClassProperty descriptor for internal usage.""" 8 | 9 | def __init__(self, fget): 10 | self.fget = fget 11 | 12 | def __get__(self, owner_self, owner_cls): 13 | return self.fget(owner_cls) 14 | 15 | 16 | class SimpleDataset(data.Dataset): 17 | """Simple Dataset wrapper for lists and arrays. 18 | 19 | Parameters 20 | ---------- 21 | data : dataset-like object 22 | Any object that implements `len()` and `[]`. 
23 | """ 24 | 25 | def __init__(self, data): 26 | self._data = data 27 | 28 | def __len__(self): 29 | return len(self._data) 30 | 31 | def __getitem__(self, idx): 32 | return self._data[idx] 33 | 34 | 35 | class _LazyTransformDataset(data.Dataset): 36 | """Lazily transformed dataset.""" 37 | 38 | def __init__(self, data, fn): 39 | super(_LazyTransformDataset, self).__init__() 40 | self._data = data 41 | self._fn = fn 42 | 43 | def __len__(self): 44 | return len(self._data) 45 | 46 | def __getitem__(self, idx): 47 | item = self._data[idx] 48 | if isinstance(item, tuple): 49 | return self._fn(*item) 50 | return self._fn(item) 51 | 52 | def transform(self, fn): 53 | self._fn = fn 54 | 55 | 56 | class VisionDataset(data.Dataset): 57 | """Base Dataset with directory checker. 58 | 59 | Parameters 60 | ---------- 61 | root : str 62 | The root path of xxx.names, by default is '~/.mxnet/datasets/foo', where 63 | `foo` is the name of the dataset. 64 | """ 65 | 66 | def __init__(self, root): 67 | super(VisionDataset, self).__init__() 68 | if not os.path.isdir(os.path.expanduser(root)): 69 | helper_msg = "{} is not a valid dir. Did you forget to initialize \ 70 | datasets described in: \ 71 | `http://gluon-cv.mxnet.io/build/examples_datasets/index.html`? \ 72 | You need to initialize each dataset only once.".format(root) 73 | raise OSError(helper_msg) 74 | 75 | @property 76 | def classes(self): 77 | raise NotImplementedError 78 | 79 | @property 80 | def num_class(self): 81 | """Number of categories.""" 82 | return len(self.classes) 83 | 84 | def transform(self, fn, lazy=True): 85 | """Returns a new dataset with each sample transformed by the 86 | transformer function `fn`. 87 | 88 | Parameters 89 | ---------- 90 | fn : callable 91 | A transformer function that takes a sample as input and 92 | returns the transformed sample. 93 | lazy : bool, default True 94 | If False, transforms all samples at once. Otherwise, 95 | transforms each sample on demand. Note that if `fn` 96 | is stochastic, you must set lazy to True or you will 97 | get the same result on all epochs. 98 | 99 | Returns 100 | ------- 101 | Dataset 102 | The transformed dataset. 
103 | """ 104 | trans = _LazyTransformDataset(self, fn) 105 | if lazy: 106 | return trans 107 | return SimpleDataset([i for i in trans]) 108 | -------------------------------------------------------------------------------- /data/base_seg.py: -------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import torch 3 | import random 4 | import numpy as np 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | from data.base import VisionDataset 8 | 9 | 10 | class SegmentationDataset(VisionDataset): 11 | """Segmentation Base Dataset""" 12 | 13 | # pylint: disable=abstract-method 14 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 15 | super(SegmentationDataset, self).__init__(root) 16 | self.root = root 17 | self.transform = transform 18 | self.split = split 19 | self.mode = mode if mode is not None else split 20 | self.base_size = base_size 21 | self.crop_size = crop_size 22 | 23 | def _val_sync_transform(self, img, mask): 24 | outsize = self.crop_size 25 | short_size = outsize 26 | w, h = img.size 27 | if w > h: 28 | oh = short_size 29 | ow = int(1.0 * w * oh / h) 30 | else: 31 | ow = short_size 32 | oh = int(1.0 * h * ow / w) 33 | img = img.resize((ow, oh), Image.BILINEAR) 34 | mask = mask.resize((ow, oh), Image.NEAREST) 35 | # center crop 36 | w, h = img.size 37 | x1 = int(round((w - outsize) / 2.)) 38 | y1 = int(round((h - outsize) / 2.)) 39 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) 40 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) 41 | # final transform 42 | img, mask = self._img_transform(img), self._mask_transform(mask) 43 | return img, mask 44 | 45 | def _sync_transform(self, img, mask): 46 | # random mirror 47 | if random.random() < 0.5: 48 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 49 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 50 | crop_size = self.crop_size 51 | # random scale (short edge) 52 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 53 | w, h = img.size 54 | if h > w: 55 | ow = short_size 56 | oh = int(1.0 * h * ow / w) 57 | else: 58 | oh = short_size 59 | ow = int(1.0 * w * oh / h) 60 | img = img.resize((ow, oh), Image.BILINEAR) 61 | mask = mask.resize((ow, oh), Image.NEAREST) 62 | # pad crop 63 | if short_size < crop_size: 64 | padh = crop_size - oh if oh < crop_size else 0 65 | padw = crop_size - ow if ow < crop_size else 0 66 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 67 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 68 | # random crop crop_size 69 | w, h = img.size 70 | x1 = random.randint(0, w - crop_size) 71 | y1 = random.randint(0, h - crop_size) 72 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 73 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 74 | # gaussian blur as in PSP 75 | if random.random() < 0.5: 76 | img = img.filter(ImageFilter.GaussianBlur( 77 | radius=random.random())) 78 | # final transform 79 | img, mask = self._img_transform(img), self._mask_transform(mask) 80 | return img, mask 81 | 82 | def _img_transform(self, img): 83 | # return torch.from_numpy(np.array(img)) 84 | return np.array(img) 85 | 86 | def _mask_transform(self, mask): 87 | # return torch.from_numpy(np.array(mask).astype('int32')) 88 | return np.array(mask).astype('int64') 89 | 90 | @property 91 | def num_class(self): 92 | """Number of categories.""" 93 | return self.NUM_CLASS 94 | 95 | @property 96 | def pred_offset(self): 97 | return 0 98 | 99 | 100 | def 
ms_batchify_fn(data): 101 | """Multi-size batchify function""" 102 | if isinstance(data[0], (str, torch.Tensor)): 103 | return list(data) 104 | elif isinstance(data[0], tuple): 105 | data = zip(*data) 106 | return [ms_batchify_fn(i) for i in data] 107 | raise RuntimeError('unknown datatype') 108 | -------------------------------------------------------------------------------- /data/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Cityscapes Dataloader""" 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | 7 | from data.base_seg import SegmentationDataset 8 | 9 | 10 | def _get_city_pairs(folder, split='train'): 11 | def get_path_pairs(img_folder, mask_folder): 12 | img_paths = [] 13 | mask_paths = [] 14 | for root, _, files in os.walk(img_folder): 15 | for filename in files: 16 | if filename.endswith(".png"): 17 | imgpath = os.path.join(root, filename) 18 | foldername = os.path.basename(os.path.dirname(imgpath)) 19 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds') 20 | maskpath = os.path.join(mask_folder, foldername, maskname) 21 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 22 | img_paths.append(imgpath) 23 | mask_paths.append(maskpath) 24 | else: 25 | print('cannot find the mask or image:', imgpath, maskpath) 26 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 27 | return img_paths, mask_paths 28 | 29 | if split in ('train', 'val'): 30 | img_folder = os.path.join(folder, 'leftImg8bit/' + split) 31 | mask_folder = os.path.join(folder, 'gtFine/' + split) 32 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 33 | return img_paths, mask_paths 34 | else: 35 | assert split == 'trainval' 36 | print('trainval set') 37 | train_img_folder = os.path.join(folder, 'leftImg8bit/train') 38 | train_mask_folder = os.path.join(folder, 'gtFine/train') 39 | val_img_folder = os.path.join(folder, 'leftImg8bit/val') 40 | val_mask_folder = os.path.join(folder, 'gtFine/val') 41 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder) 42 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder) 43 | img_paths = train_img_paths + val_img_paths 44 | mask_paths = train_mask_paths + val_mask_paths 45 | return img_paths, mask_paths 46 | 47 | 48 | class CitySegmentation(SegmentationDataset): 49 | """Cityscapes Dataloader""" 50 | # pylint: disable=abstract-method 51 | BASE_DIR = 'cityscapes' 52 | NUM_CLASS = 19 53 | 54 | def __init__(self, root=os.path.expanduser('~/.torch/datasets/citys'), split='train', 55 | mode=None, transform=None, **kwargs): 56 | super(CitySegmentation, self).__init__( 57 | root, split, mode, transform, **kwargs) 58 | # self.root = os.path.join(root, self.BASE_DIR) 59 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split) 60 | assert (len(self.images) == len(self.mask_paths)) 61 | if len(self.images) == 0: 62 | raise RuntimeError("Found 0 images in subfolders of: \ 63 | " + self.root + "\n") 64 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 65 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 66 | self._key = np.array([-1, -1, -1, -1, -1, -1, 67 | -1, -1, 0, 1, -1, -1, 68 | 2, 3, 4, -1, -1, -1, 69 | 5, -1, 6, 7, 8, 9, 70 | 10, 11, 12, 13, 14, 15, 71 | -1, -1, 16, 17, 18]) 72 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 73 | 74 | def _class_to_index(self, mask): 75 | # assert the values 76 | values = np.unique(mask) 77 | for value in 
values: 78 | assert (value in self._mapping) 79 | index = np.digitize(mask.ravel(), self._mapping, right=True) 80 | return self._key[index].reshape(mask.shape) 81 | 82 | def __getitem__(self, index): 83 | img = Image.open(self.images[index]).convert('RGB') 84 | if self.mode == 'test': 85 | if self.transform is not None: 86 | img = self.transform(img) 87 | return img, os.path.basename(self.images[index]) 88 | # mask = self.masks[index] 89 | mask = Image.open(self.mask_paths[index]) 90 | # synchronized transform 91 | if self.mode == 'train': 92 | img, mask = self._sync_transform(img, mask) 93 | elif self.mode == 'val': 94 | img, mask = self._val_sync_transform(img, mask) 95 | else: 96 | assert self.mode == 'testval' 97 | img, mask = self._img_transform(img), self._mask_transform(mask) 98 | # general resize, normalize and toTensor 99 | if self.transform is not None: 100 | img = self.transform(img) 101 | return img, mask 102 | 103 | def _mask_transform(self, mask): 104 | target = self._class_to_index(np.array(mask).astype('int64')) 105 | return torch.from_numpy(target) 106 | 107 | def __len__(self): 108 | return len(self.images) 109 | 110 | 111 | if __name__ == '__main__': 112 | data = CitySegmentation(split='val', mode='val') 113 | print(data[0][0].shape, data[0][1].shape) 114 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 3 | # FIXME remove this once c10d fixes the bug it has 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.sampler import Sampler, BatchSampler 8 | 9 | 10 | class DistributedSampler(Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset. 12 | It is especially useful in conjunction with 13 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 14 | process can pass a DistributedSampler instance as a DataLoader sampler, 15 | and load a subset of the original dataset that is exclusive to it. 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 
23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | self.shuffle = shuffle 41 | 42 | def __iter__(self): 43 | if self.shuffle: 44 | # deterministically shuffle based on epoch 45 | g = torch.Generator() 46 | g.manual_seed(self.epoch) 47 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 48 | else: 49 | indices = torch.arange(len(self.dataset)).tolist() 50 | 51 | # add extra samples to make it evenly divisible 52 | indices += indices[: (self.total_size - len(indices))] 53 | assert len(indices) == self.total_size 54 | 55 | # subsample 56 | offset = self.num_samples * self.rank 57 | indices = indices[offset: offset + self.num_samples] 58 | assert len(indices) == self.num_samples 59 | 60 | return iter(indices) 61 | 62 | def __len__(self): 63 | return self.num_samples 64 | 65 | def set_epoch(self, epoch): 66 | self.epoch = epoch 67 | 68 | 69 | def make_data_sampler(dataset, shuffle, distributed): 70 | if distributed: 71 | return DistributedSampler(dataset, shuffle=shuffle) 72 | if shuffle: 73 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 74 | else: 75 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 76 | return sampler 77 | 78 | 79 | class IterationBasedBatchSampler(BatchSampler): 80 | """ 81 | Wraps a BatchSampler, resampling from it until 82 | a specified number of iterations have been sampled 83 | """ 84 | 85 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 86 | self.batch_sampler = batch_sampler 87 | self.num_iterations = num_iterations 88 | self.start_iter = start_iter 89 | 90 | def __iter__(self): 91 | iteration = self.start_iter 92 | while iteration <= self.num_iterations: 93 | # if the underlying sampler has a set_epoch method, like 94 | # DistributedSampler, used for making each process see 95 | # a different split of the dataset, then set it 96 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 97 | self.batch_sampler.sampler.set_epoch(iteration) 98 | for batch in self.batch_sampler: 99 | iteration += 1 100 | if iteration > self.num_iterations: 101 | break 102 | yield batch 103 | 104 | def __len__(self): 105 | return self.num_iterations 106 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | import matplotlib.image as mpimg 7 | 8 | import torch 9 | from torchvision import transforms 10 | 11 | cur_path = os.path.dirname(__file__) 12 | sys.path.insert(0, os.path.join(cur_path, '..')) 13 | from model.lednet import LEDNet 14 | import utils as ptutil 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='Demo for LEDNet from a given image') 19 | 20 | parser.add_argument('--input-pic', type=str, default=os.path.join(cur_path, 'png/demo.png'), 21 | 
help='path to the input picture') 22 | parser.add_argument('--pretrained', type=str, 23 | default=os.path.expanduser('~/cbb/own/pretrained/seg/lednet/LEDNet_final.pth'), 24 | help='Default Pre-trained model root.') 25 | parser.add_argument('--cuda', type=ptutil.str2bool, default='true', 26 | help='demo with GPU') 27 | 28 | opt = parser.parse_args() 29 | return opt 30 | 31 | 32 | if __name__ == '__main__': 33 | args = parse_args() 34 | device = torch.device('cpu') 35 | if args.cuda: 36 | device = torch.device('cuda') 37 | # Load Model 38 | model = LEDNet(19).to(device) 39 | model.load_state_dict(torch.load(args.pretrained)) 40 | model.eval() 41 | 42 | # Load Images 43 | img = Image.open(args.input_pic) 44 | 45 | # Transform 46 | transform_fn = transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 49 | ]) 50 | 51 | img = transform_fn(img).unsqueeze(0).to(device) 52 | with torch.no_grad(): 53 | output = model(img) 54 | 55 | predict = torch.argmax(output, 1).squeeze(0).cpu().numpy() 56 | mask = ptutil.get_color_pallete(predict, 'citys') 57 | mask.save(os.path.join(cur_path, 'png/output.png')) 58 | mmask = mpimg.imread(os.path.join(cur_path, 'png/output.png')) 59 | plt.imshow(mmask) 60 | plt.show() 61 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch 7 | from torch.utils import data 8 | from torchvision import transforms 9 | 10 | cur_path = os.path.dirname(__file__) 11 | sys.path.insert(0, os.path.join(cur_path, '..')) 12 | import utils as ptutil 13 | from model.lednet import LEDNet 14 | from data import get_segmentation_dataset 15 | from data.sampler import make_data_sampler 16 | from utils.metric_seg import SegmentationMetric 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Eval Segmentation.') 21 | parser.add_argument('--batch-size', type=int, default=1, 22 | help='Training mini-batch size') 23 | parser.add_argument('--num-workers', '-j', dest='num_workers', type=int, 24 | default=4, help='Number of data workers') 25 | parser.add_argument('--dataset', type=str, default='citys', 26 | help='Select dataset.') 27 | parser.add_argument('--split', type=str, default='val', 28 | help='Select val|test, evaluate in val or test data') 29 | parser.add_argument('--mode', type=str, default='testval', 30 | help='Select testval|val, w/o corp and with crop') 31 | parser.add_argument('--base-size', type=int, default=1024, 32 | help='base image size') 33 | parser.add_argument('--crop-size', type=int, default=768, 34 | help='crop image size') 35 | 36 | parser.add_argument('--pretrained', type=str, 37 | default='./LEDNet_iter_073600.pth', 38 | help='Default Pre-trained model root.') 39 | 40 | # device 41 | parser.add_argument('--cuda', type=ptutil.str2bool, default='true', 42 | help='Training with GPUs.') 43 | parser.add_argument('--local_rank', type=int, default=0) 44 | parser.add_argument('--init-method', type=str, default="env://") 45 | 46 | args = parser.parse_args() 47 | return args 48 | 49 | 50 | def validate(net, val_data, metric, device): 51 | net.eval() 52 | tbar = tqdm(val_data) 53 | for i, (data, targets) in enumerate(tbar): 54 | data, targets = data.to(device), targets.to(device) 55 | with torch.no_grad(): 56 | predicts = net(data) 57 | metric.update(targets, predicts) 58 | return 
metric 59 | 60 | 61 | if __name__ == '__main__': 62 | args = parse_args() 63 | 64 | device = torch.device('cpu') 65 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 66 | distributed = num_gpus > 1 67 | if args.cuda and torch.cuda.is_available(): 68 | torch.backends.cudnn.benchmark = False if args.mode == 'testval' else True 69 | device = torch.device('cuda') 70 | else: 71 | distributed = False 72 | 73 | if distributed: 74 | torch.cuda.set_device(args.local_rank) 75 | torch.distributed.init_process_group(backend="nccl", init_method=args.init_method) 76 | 77 | # Load Model 78 | model = LEDNet(19) 79 | model.load_state_dict(torch.load(args.pretrained)) 80 | model.keep_shape = True if args.mode == 'testval' else False 81 | model.to(device) 82 | 83 | # testing data 84 | input_transform = transforms.Compose([ 85 | transforms.ToTensor(), 86 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 87 | ]) 88 | 89 | data_kwargs = {'base_size': args.base_size, 'crop_size': args.crop_size, 'transform': input_transform} 90 | 91 | val_dataset = get_segmentation_dataset(args.dataset, split=args.split, mode=args.mode, **data_kwargs) 92 | sampler = make_data_sampler(val_dataset, False, distributed) 93 | batch_sampler = data.BatchSampler(sampler=sampler, batch_size=args.batch_size, drop_last=False) 94 | val_data = data.DataLoader(val_dataset, shuffle=False, batch_sampler=batch_sampler, 95 | num_workers=args.num_workers) 96 | metric = SegmentationMetric(val_dataset.num_class) 97 | 98 | metric = validate(model, val_data, metric, device) 99 | ptutil.synchronize() 100 | pixAcc, mIoU = ptutil.accumulate_metric(metric) 101 | if ptutil.is_main_process(): 102 | print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU)) 103 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/model/__init__.py -------------------------------------------------------------------------------- /model/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # helper function 7 | def channel_shuffle(x, groups): 8 | b, n, h, w = x.shape 9 | channels_per_group = n // groups 10 | 11 | # reshape 12 | x = x.view(b, groups, channels_per_group, h, w) 13 | x = torch.transpose(x, 1, 2).contiguous() 14 | 15 | # flatten 16 | x = x.view(b, -1, h, w) 17 | 18 | return x 19 | 20 | 21 | def basic_conv(in_channel, channel, kernel=3, stride=1): 22 | return nn.Sequential( 23 | nn.Conv2d(in_channel, channel, kernel, stride, kernel // 2, bias=False), 24 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | # basic module 29 | # TODO: may add bn and relu 30 | class DownSampling(nn.Module): 31 | def __init__(self, in_channel, out_channel): 32 | super(DownSampling, self).__init__() 33 | self.conv = nn.Conv2d(in_channel, out_channel, 3, stride=2, padding=1) 34 | self.pool = nn.MaxPool2d(2, ceil_mode=True) 35 | self.bn = nn.BatchNorm2d(out_channel + in_channel) 36 | 37 | def forward(self, x): 38 | x1 = self.conv(x) 39 | x2 = self.pool(x) 40 | x = torch.cat([x1, x2], dim=1) 41 | x = F.relu_(self.bn(x)) 42 | return x 43 | 44 | 45 | class SSnbt(nn.Module): 46 | def __init__(self, channel, dilate=1, drop_prob=0.01): 47 | super(SSnbt, self).__init__() 48 | channel = channel // 2 
49 | self.left = nn.Sequential( 50 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0)), nn.ReLU(inplace=True), 51 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1), bias=False), 52 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), 53 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1)), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate), bias=False), 56 | nn.BatchNorm2d(channel), nn.Dropout2d(drop_prob, inplace=True) 57 | ) 58 | self.right = nn.Sequential( 59 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1)), nn.ReLU(inplace=True), 60 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0), bias=False), 61 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), 62 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate)), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1), bias=False), 65 | nn.BatchNorm2d(channel), nn.Dropout2d(drop_prob, inplace=True) 66 | ) 67 | 68 | def forward(self, x): 69 | x1, x2 = x.split(x.shape[1] // 2, 1) 70 | x1 = self.left(x1) 71 | x2 = self.right(x2) 72 | out = torch.cat([x1, x2], 1) 73 | x = F.relu(out + x) 74 | return channel_shuffle(x, 2) 75 | 76 | 77 | class SSnbtv2(nn.Module): 78 | def __init__(self, channel, dilate=1, drop_prob=0.01): 79 | super(SSnbtv2, self).__init__() 80 | channel = channel // 2 81 | self.left = nn.Sequential( 82 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0)), nn.ReLU(inplace=True), 83 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1), bias=False), 84 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), 85 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1)), 86 | nn.ReLU(inplace=True), 87 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate), bias=False), 88 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), nn.Dropout2d(drop_prob, inplace=True) 89 | ) 90 | self.right = nn.Sequential( 91 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1)), nn.ReLU(inplace=True), 92 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0), bias=False), 93 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), 94 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate)), 95 | nn.ReLU(inplace=True), 96 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1), bias=False), 97 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), nn.Dropout2d(drop_prob, inplace=True) 98 | ) 99 | 100 | def forward(self, x): 101 | x1, x2 = x.split(x.shape[1] // 2, 1) 102 | x1 = self.left(x1) 103 | x2 = self.right(x2) 104 | out = torch.cat([x1, x2], 1) 105 | x = F.relu(out + x) 106 | return channel_shuffle(x, 2) 107 | 108 | 109 | class APN(nn.Module): 110 | def __init__(self, channel, classes): 111 | super(APN, self).__init__() 112 | self.conv1 = basic_conv(channel, channel, 3, 2) 113 | self.conv2 = basic_conv(channel, channel, 5, 2) 114 | self.conv3 = basic_conv(channel, channel, 7, 2) 115 | self.branch1 = basic_conv(channel, classes, 1, 1) 116 | self.branch2 = basic_conv(channel, classes, 1, 1) 117 | self.branch3 = basic_conv(channel, classes, 1, 1) 118 | self.branch4 = basic_conv(channel, classes, 1, 1) 119 | self.branch5 = nn.Sequential( 120 | nn.AdaptiveAvgPool2d(output_size=1), 121 | basic_conv(channel, classes, 1, 1) 122 | ) 123 | 124 | def forward(self, x): 125 | _, _, h, w = x.shape 126 | out3 = self.conv1(x) 127 | out2 = self.conv2(out3) 128 | out = self.branch1(self.conv3(out2)) 129 | out = 
F.interpolate(out, size=((h + 3) // 4, (w + 3) // 4), mode='bilinear', align_corners=True) 130 | out = out + self.branch2(out2) 131 | out = F.interpolate(out, size=((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=True) 132 | out = out + self.branch3(out3) 133 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True) 134 | out = out * self.branch4(x) 135 | out = out + self.branch5(x) 136 | return out 137 | 138 | 139 | if __name__ == '__main__': 140 | # model = DownSampling(32) 141 | # a = torch.randn(1, 32, 512, 256) 142 | # out = model(a) 143 | # print(out.shape) 144 | 145 | # model = SSnbt(10, 2) 146 | # a = torch.randn(1, 20, 10, 10) 147 | # out = model(a) 148 | # print(out.shape) 149 | # model = basic_conv(10, 20, 3, 2) 150 | # a = torch.randn(1, 10, 128, 65) 151 | # out = model(a) 152 | # print(out.shape) 153 | 154 | model = APN(64, 10) 155 | x = torch.randn(2, 64, 127, 65) 156 | out = model(x) 157 | print(out.shape) 158 | -------------------------------------------------------------------------------- /model/lednet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | from model.basic import DownSampling, SSnbt, APN 4 | 5 | 6 | class LEDNet(nn.Module): 7 | def __init__(self, nclass, drop=0.1): 8 | super(LEDNet, self).__init__() 9 | self.encoder = nn.Sequential( 10 | DownSampling(3, 29), SSnbt(32, 1, 0.1 * drop), SSnbt(32, 1, 0.1 * drop), SSnbt(32, 1, 0.1 * drop), 11 | DownSampling(32, 32), SSnbt(64, 1, 0.1 * drop), SSnbt(64, 1, 0.1 * drop), 12 | DownSampling(64, 64), SSnbt(128, 1, drop), SSnbt(128, 2, drop), SSnbt(128, 5, drop), 13 | SSnbt(128, 9, drop), SSnbt(128, 2, drop), SSnbt(128, 5, drop), SSnbt(128, 9, drop), SSnbt(128, 17, drop) 14 | ) 15 | self.decoder = APN(128, nclass) 16 | 17 | def forward(self, x): 18 | _, _, h, w = x.shape 19 | x = self.encoder(x) 20 | x = self.decoder(x) 21 | return F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) 22 | 23 | 24 | if __name__ == '__main__': 25 | net = LEDNet(21) 26 | import torch 27 | 28 | a = torch.randn(2, 3, 554, 253) 29 | out = net(a) 30 | print(out.shape) 31 | -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MixSoftmaxCrossEntropyLoss(nn.Module): 7 | def __init__(self, ignore_label=-1, **kwargs): 8 | super(MixSoftmaxCrossEntropyLoss, self).__init__(**kwargs) 9 | self.ignore_label = ignore_label 10 | 11 | def forward(self, preds, target): 12 | return dict(loss=F.cross_entropy(preds, target, ignore_index=self.ignore_label)) 13 | 14 | 15 | # TODO: add aux support 16 | class OHEMSoftmaxCrossEntropyLoss(nn.Module): 17 | def __init__(self, ignore_label=-1, thresh=0.6, min_kept=256, 18 | down_ratio=1, reduction='mean', use_weight=False): 19 | super(OHEMSoftmaxCrossEntropyLoss, self).__init__() 20 | self.ignore_label = ignore_label 21 | self.thresh = float(thresh) 22 | self.min_kept = int(min_kept) 23 | self.down_ratio = down_ratio 24 | if use_weight: 25 | weight = torch.FloatTensor( 26 | [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 27 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 28 | 1.0865, 1.1529, 1.0507]) 29 | self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction, 30 | weight=weight, 31 | ignore_index=ignore_label) 32 | else: 
33 | self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction, 34 | ignore_index=ignore_label) 35 | 36 | def base_forward(self, pred, target): 37 | b, c, h, w = pred.size() 38 | target = target.view(-1) 39 | valid_mask = target.ne(self.ignore_label) 40 | target = target * valid_mask.long() 41 | num_valid = valid_mask.sum() 42 | 43 | prob = F.softmax(pred, dim=1) 44 | prob = (prob.transpose(0, 1)).reshape(c, -1) 45 | 46 | if self.min_kept < num_valid and num_valid > 0: 47 | prob = prob.masked_fill_(1 - valid_mask, 1) 48 | mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)] 49 | threshold = self.thresh 50 | if self.min_kept > 0: 51 | index = mask_prob.argsort() 52 | threshold_index = index[min(len(index), self.min_kept) - 1] 53 | if mask_prob[threshold_index] > self.thresh: 54 | threshold = mask_prob[threshold_index] 55 | kept_mask = mask_prob.le(threshold) 56 | target = target * kept_mask.long() 57 | valid_mask = valid_mask * kept_mask 58 | 59 | target = target.masked_fill_(1 - valid_mask, self.ignore_label) 60 | target = target.view(b, h, w) 61 | 62 | return self.criterion(pred, target) 63 | 64 | def forward(self, preds, target): 65 | for i, pred in enumerate(preds): 66 | if i == 0: 67 | loss = self.base_forward(pred, target) 68 | else: 69 | loss = loss + self.base_forward(pred, target) 70 | return dict(loss=loss) 71 | -------------------------------------------------------------------------------- /model/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler 3 | 4 | 5 | class WarmupMultiStepLR(MultiStepLR): 6 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 7 | warmup_iters=500, last_epoch=-1): 8 | self.warmup_factor = warmup_factor 9 | self.warmup_iters = warmup_iters 10 | super().__init__(optimizer, milestones, gamma, last_epoch) 11 | 12 | def get_lr(self): 13 | if self.last_epoch <= self.warmup_iters: 14 | alpha = self.last_epoch / self.warmup_iters 15 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 16 | # print(self.base_lrs[0]*warmup_factor) 17 | return [lr * warmup_factor for lr in self.base_lrs] 18 | else: 19 | lr = super().get_lr() 20 | return lr 21 | 22 | 23 | class WarmupCosineLR(_LRScheduler): 24 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500, 25 | eta_min=0, last_epoch=-1): 26 | self.warmup_factor = warmup_factor 27 | self.warmup_iters = warmup_iters 28 | self.T_max, self.eta_min = T_max, eta_min 29 | super().__init__(optimizer, last_epoch) 30 | 31 | def get_lr(self): 32 | if self.last_epoch <= self.warmup_iters: 33 | alpha = self.last_epoch / self.warmup_iters 34 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 35 | # print(self.base_lrs[0]*warmup_factor) 36 | return [lr * warmup_factor for lr in self.base_lrs] 37 | else: 38 | return [self.eta_min + (base_lr - self.eta_min) * 39 | (1 + math.cos( 40 | math.pi * (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters))) / 2 41 | for base_lr in self.base_lrs] 42 | 43 | 44 | class WarmupPolyLR(_LRScheduler): 45 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500, 46 | eta_min=0, power=0.9, last_epoch=-1): 47 | self.warmup_factor = warmup_factor 48 | self.warmup_iters = warmup_iters 49 | self.power = power 50 | self.T_max, self.eta_min = T_max, eta_min 51 | super().__init__(optimizer, last_epoch) 52 | 53 | def get_lr(self): 54 | if self.last_epoch <= 
self.warmup_iters: 55 | alpha = self.last_epoch / self.warmup_iters 56 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 57 | # print(self.base_lrs[0]*warmup_factor) 58 | return [lr * warmup_factor for lr in self.base_lrs] 59 | else: 60 | return [self.eta_min + (base_lr - self.eta_min) * 61 | math.pow(1 - (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters), 62 | self.power) for base_lr in self.base_lrs] 63 | 64 | 65 | if __name__ == '__main__': 66 | optim = WarmupPolyLR() 67 | -------------------------------------------------------------------------------- /png/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/demo.png -------------------------------------------------------------------------------- /png/gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/gt.png -------------------------------------------------------------------------------- /png/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/output.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import datetime 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | import torch 9 | from torch import optim 10 | from torch.backends import cudnn 11 | from torch.utils import data 12 | from torchvision import transforms 13 | 14 | cur_path = os.path.dirname(__file__) 15 | sys.path.insert(0, os.path.join(cur_path, '..')) 16 | import utils as ptutil 17 | from utils.metric_seg import SegmentationMetric 18 | from data import get_segmentation_dataset 19 | from data.sampler import make_data_sampler, IterationBasedBatchSampler 20 | from model.loss import MixSoftmaxCrossEntropyLoss, OHEMSoftmaxCrossEntropyLoss 21 | from model.lr_scheduler import WarmupPolyLR 22 | from model.lednet import LEDNet 23 | 24 | 25 | def parse_args(): 26 | """Training Options for Segmentation Experiments""" 27 | parser = argparse.ArgumentParser(description='LEDNet Segmentation') 28 | parser.add_argument('--dataset', type=str, default='citys', 29 | help='dataset name (default: citys)') 30 | parser.add_argument('--workers', '-j', type=int, default=4, 31 | metavar='N', help='dataloader threads') 32 | parser.add_argument('--base-size', type=int, default=512, # 1024 33 | help='base image size') 34 | parser.add_argument('--crop-size', type=int, default=360, # 512 35 | help='crop image size') 36 | parser.add_argument('--train-split', type=str, default='train', 37 | help='dataset train split (default: train)') 38 | parser.add_argument('--drop-rate', type=float, default=0.3, 39 | help='drop rate of SSnbt') 40 | # training hyper params 41 | parser.add_argument('--ohem', type=ptutil.str2bool, default='false', 42 | help='whether using ohem loss') 43 | parser.add_argument('--epochs', type=int, default=240, metavar='N', 44 | help='number of epochs to train (default: 50)') 45 | parser.add_argument('--start_epoch', type=int, default=0, 46 | metavar='N', help='start epochs (default:0)') 47 | parser.add_argument('--batch-size', type=int, default=2, 48 | metavar='N', help='input batch 
size for \ 49 | training (default: 16)') 50 | parser.add_argument('--test-batch-size', type=int, default=1, 51 | metavar='N', help='input batch size for \ 52 | testing (default: 16)') 53 | parser.add_argument('--lr', type=float, default=1e-4, metavar='LR', 54 | help='learning rate (default: 1e-3)') 55 | parser.add_argument('--momentum', type=float, default=0.9, 56 | metavar='M', help='momentum (default: 0.9)') 57 | parser.add_argument('--weight-decay', type=float, default=1e-4, 58 | metavar='M', help='w-decay (default: 1e-4)') 59 | parser.add_argument('--warmup-iters', type=int, default=200, # 500 60 | help='warmup iterations') 61 | parser.add_argument('--warmup-factor', type=float, default=1.0 / 3, 62 | help='warm up start lr=warmup_factor*lr') 63 | parser.add_argument('--eval-epochs', type=int, default=-1, 64 | help='validate interval') 65 | parser.add_argument('--skip-eval', type=ptutil.str2bool, default='False', 66 | help='whether to skip evaluation') 67 | # cuda and logging 68 | parser.add_argument('--no-cuda', type=ptutil.str2bool, default='False', 69 | help='disables CUDA training') 70 | parser.add_argument('--local_rank', type=int, default=0) 71 | parser.add_argument('--init-method', type=str, default="env://") 72 | parser.add_argument('--dtype', type=str, default='float32', 73 | help='data type for training. default is float32') 74 | # checking point 75 | parser.add_argument('--log-step', type=int, default=1, 76 | help='iteration to show results') 77 | parser.add_argument('--save-epoch', type=int, default=10, 78 | help='epoch interval to save model.') 79 | parser.add_argument('--save-dir', type=str, default=cur_path, 80 | help='Resume from previously saved parameters if not None.') 81 | parser.add_argument('--resume', type=str, default=None, 82 | help='put the path to resuming file if needed') 83 | 84 | # the parser 85 | args = parser.parse_args() 86 | 87 | args.lr = args.lr * args.batch_size 88 | return args 89 | 90 | 91 | class Trainer(object): 92 | def __init__(self, args): 93 | self.device = torch.device(args.device) 94 | # image transform 95 | input_transform = transforms.Compose([ 96 | transforms.ToTensor(), 97 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 98 | ]) 99 | # dataset and dataloader 100 | data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 101 | 'crop_size': args.crop_size} 102 | trainset = get_segmentation_dataset( 103 | args.dataset, split=args.train_split, mode='train', **data_kwargs) 104 | args.per_iter = len(trainset) // (args.num_gpus * args.batch_size) 105 | args.max_iter = args.epochs * args.per_iter 106 | if args.distributed: 107 | sampler = data.DistributedSampler(trainset) 108 | else: 109 | sampler = data.RandomSampler(trainset) 110 | train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True) 111 | train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=args.max_iter) 112 | self.train_loader = data.DataLoader(trainset, batch_sampler=train_sampler, pin_memory=True, 113 | num_workers=args.workers) 114 | if not args.skip_eval or 0 < args.eval_epochs < args.epochs: 115 | valset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs) 116 | val_sampler = make_data_sampler(valset, False, args.distributed) 117 | val_batch_sampler = data.sampler.BatchSampler(val_sampler, args.test_batch_size, False) 118 | self.valid_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler, 119 | num_workers=args.workers, pin_memory=True) 120 | 121 | # create network 122 | 
self.net = LEDNet(trainset.NUM_CLASS, args.drop_rate) 123 | 124 | if args.distributed: 125 | self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net) 126 | self.net.to(self.device) 127 | # resume checkpoint if needed 128 | if args.resume is not None: 129 | if os.path.isfile(args.resume): 130 | self.net.load_state_dict(torch.load(args.resume)) 131 | else: 132 | raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume)) 133 | 134 | # create criterion 135 | if args.ohem: 136 | min_kept = args.batch_size * args.crop_size ** 2 // 16 137 | self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7, min_kept=min_kept, use_weight=False) 138 | else: 139 | self.criterion = MixSoftmaxCrossEntropyLoss() 140 | 141 | # optimizer and lr scheduling 142 | self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=args.momentum, 143 | weight_decay=args.weight_decay) 144 | self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter, warmup_factor=args.warmup_factor, 145 | warmup_iters=args.warmup_iters, power=0.9) 146 | 147 | if args.distributed: 148 | self.net = torch.nn.parallel.DistributedDataParallel( 149 | self.net, device_ids=[args.local_rank], output_device=args.local_rank) 150 | 151 | # evaluation metrics 152 | self.metric = SegmentationMetric(trainset.num_class) 153 | self.args = args 154 | 155 | def training(self): 156 | self.net.train() 157 | save_to_disk = ptutil.get_rank() == 0 158 | start_training_time = time.time() 159 | trained_time = 0 160 | tic = time.time() 161 | end = time.time() 162 | iteration, max_iter = 0, self.args.max_iter 163 | save_iter, eval_iter = self.args.per_iter * self.args.save_epoch, self.args.per_iter * self.args.eval_epochs 164 | # save_iter, eval_iter = 10, 10 165 | 166 | logger.info("Start training, total epochs {:3d} = total iteration: {:6d}".format(self.args.epochs, max_iter)) 167 | 168 | for i, (image, target) in enumerate(self.train_loader): 169 | iteration += 1 170 | self.scheduler.step() 171 | self.optimizer.zero_grad() 172 | image, target = image.to(self.device), target.to(self.device) 173 | outputs = self.net(image) 174 | loss_dict = self.criterion(outputs, target) 175 | # reduce losses over all GPUs for logging purposes 176 | loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict) 177 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 178 | 179 | loss = sum(loss for loss in loss_dict.values()) 180 | loss.backward() 181 | self.optimizer.step() 182 | trained_time += time.time() - end 183 | end = time.time() 184 | if iteration % args.log_step == 0: 185 | eta_seconds = int((trained_time / iteration) * (max_iter - iteration)) 186 | log_str = ["Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}" 187 | .format(iteration, self.optimizer.param_groups[0]['lr'], time.time() - tic, 188 | str(datetime.timedelta(seconds=eta_seconds))), 189 | "total_loss: {:.3f}".format(losses_reduced.item())] 190 | log_str = ', '.join(log_str) 191 | logger.info(log_str) 192 | tic = time.time() 193 | if save_to_disk and iteration % save_iter == 0: 194 | model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth" 195 | .format('LEDNet', iteration)) 196 | self.save_model(model_path) 197 | # Do eval when training, to trace the mAP changes and see performance improved whether or nor 198 | if args.eval_epochs > 0 and iteration % eval_iter == 0 and not iteration == max_iter: 199 | metrics = self.validate() 200 | ptutil.synchronize() 201 | pixAcc, mIoU = ptutil.accumulate_metric(metrics) 202 | if pixAcc is not None: 203 | 
logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU)) 204 | self.net.train() 205 | if save_to_disk: 206 | model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth" 207 | .format('LEDNet', max_iter)) 208 | self.save_model(model_path) 209 | # compute training time 210 | total_training_time = int(time.time() - start_training_time) 211 | total_time_str = str(datetime.timedelta(seconds=total_training_time)) 212 | logger.info("Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max_iter)) 213 | # eval after training 214 | if not self.args.skip_eval: 215 | metrics = self.validate() 216 | ptutil.synchronize() 217 | pixAcc, mIoU = ptutil.accumulate_metric(metrics) 218 | if pixAcc is not None: 219 | logger.info('After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU)) 220 | 221 | def validate(self): 222 | # total_inter, total_union, total_correct, total_label = 0, 0, 0, 0 223 | self.metric.reset() 224 | torch.cuda.empty_cache() 225 | if isinstance(self.net, torch.nn.parallel.DistributedDataParallel): 226 | model = self.net.module 227 | else: 228 | model = self.net 229 | model.eval() 230 | tbar = tqdm(self.valid_loader) 231 | for i, (image, target) in enumerate(tbar): 232 | # if i == 10: break 233 | image, target = image.to(self.device), target.to(self.device) 234 | with torch.no_grad(): 235 | outputs = model(image) 236 | self.metric.update(target, outputs) 237 | return self.metric 238 | 239 | def save_model(self, model_path): 240 | if isinstance(self.net, torch.nn.parallel.DistributedDataParallel): 241 | model = self.net.module 242 | else: 243 | model = self.net 244 | torch.save(model.state_dict(), model_path) 245 | logger.info("Saved checkpoint to {}".format(model_path)) 246 | 247 | 248 | if __name__ == "__main__": 249 | args = parse_args() 250 | 251 | # device setting 252 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 253 | args.distributed = num_gpus > 1 254 | args.num_gpus = num_gpus 255 | if not args.no_cuda and torch.cuda.is_available(): 256 | torch.backends.cudnn.benchmark = True 257 | args.device = "cuda" 258 | else: 259 | args.distributed = False 260 | args.device = "cpu" 261 | if args.distributed: 262 | torch.cuda.set_device(args.local_rank) 263 | torch.distributed.init_process_group(backend="nccl", init_method=args.init_method) 264 | 265 | args.lr = args.lr * args.num_gpus # scale by num gpus 266 | 267 | logger = ptutil.setup_logger('Segmentation', args.save_dir, ptutil.get_rank(), 'log_seg.txt', 'w') 268 | logger.info("Using {} GPUs".format(num_gpus)) 269 | logger.info(args) 270 | trainer = Trainer(args) 271 | 272 | trainer.training() 273 | torch.cuda.empty_cache() 274 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metric_seg import * 2 | from .util import * 3 | from .logger import * 4 | from .parallel import * 5 | from .visual import * -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", filemode='a'): 7 | logger = logging.getLogger(name) 8 | logger.setLevel(logging.DEBUG) 9 | # don't log results for the non-master process 10 | if distributed_rank > 0: 11 | return logger 12 | ch = 
logging.StreamHandler(stream=sys.stdout) 13 | ch.setLevel(logging.DEBUG) 14 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 15 | ch.setFormatter(formatter) 16 | logger.addHandler(ch) 17 | 18 | if save_dir: 19 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=filemode) 20 | fh.setLevel(logging.DEBUG) 21 | fh.setFormatter(formatter) 22 | logger.addHandler(fh) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def check_label_shapes(labels, preds, wrap=False, shape=False): 5 | """Helper function for checking shape of label and prediction 6 | 7 | Parameters 8 | ---------- 9 | labels : list of `tensor` 10 | The labels of the data. 11 | 12 | preds : list of `tensor` 13 | Predicted values. 14 | 15 | wrap : boolean 16 | If True, wrap labels/preds in a list if they are single NDArray 17 | 18 | shape : boolean 19 | If True, check the shape of labels and preds; 20 | Otherwise only check their length. 21 | """ 22 | if not shape: 23 | label_shape, pred_shape = len(labels), len(preds) 24 | else: 25 | label_shape, pred_shape = labels.shape, preds.shape 26 | 27 | if label_shape != pred_shape: 28 | raise ValueError("Shape of labels {} does not match shape of " 29 | "predictions {}".format(label_shape, pred_shape)) 30 | 31 | if wrap: 32 | if isinstance(labels, torch.Tensor): 33 | labels = [labels] 34 | if isinstance(preds, torch.Tensor): 35 | preds = [preds] 36 | 37 | return labels, preds 38 | 39 | 40 | class EvalMetric(object): 41 | """Base class for all evaluation metrics. 42 | 43 | .. note:: 44 | 45 | This is a base class that provides common metric interfaces. 46 | One should not use this class directly, but instead create new metric 47 | classes that extend it. 48 | 49 | Parameters 50 | ---------- 51 | name : str 52 | Name of this metric instance for display. 53 | output_names : list of str, or None 54 | Name of predictions that should be used when updating with update_dict. 55 | By default include all predictions. 56 | label_names : list of str, or None 57 | Name of labels that should be used when updating with update_dict. 58 | By default include all labels. 59 | """ 60 | 61 | def __init__(self, name, output_names=None, 62 | label_names=None, **kwargs): 63 | self.name = str(name) 64 | self.output_names = output_names 65 | self.label_names = label_names 66 | self._has_global_stats = kwargs.pop("has_global_stats", False) 67 | self._kwargs = kwargs 68 | # self.reset() 69 | 70 | def __str__(self): 71 | return "EvalMetric: {}".format(dict(self.get_name_value())) 72 | 73 | def get_config(self): 74 | """Save configurations of metric. Can be recreated 75 | from configs with metric.create(``**config``) 76 | """ 77 | config = self._kwargs.copy() 78 | config.update({ 79 | 'metric': self.__class__.__name__, 80 | 'name': self.name, 81 | 'output_names': self.output_names, 82 | 'label_names': self.label_names}) 83 | return config 84 | 85 | def update_dict(self, label, pred): 86 | """Update the internal evaluation with named label and pred 87 | 88 | Parameters 89 | ---------- 90 | labels : OrderedDict of str -> NDArray 91 | name to array mapping for labels. 92 | 93 | preds : OrderedDict of str -> NDArray 94 | name to array mapping of predicted outputs. 
95 | """ 96 | if self.output_names is not None: 97 | pred = [pred[name] for name in self.output_names] 98 | else: 99 | pred = list(pred.values()) 100 | 101 | if self.label_names is not None: 102 | label = [label[name] for name in self.label_names] 103 | else: 104 | label = list(label.values()) 105 | 106 | self.update(label, pred) 107 | 108 | def update(self, labels, preds): 109 | """Updates the internal evaluation result. 110 | 111 | Parameters 112 | ---------- 113 | labels : list of `NDArray` 114 | The labels of the data. 115 | 116 | preds : list of `NDArray` 117 | Predicted values. 118 | """ 119 | raise NotImplementedError() 120 | 121 | def reset(self): 122 | """Resets the internal evaluation result to initial state.""" 123 | self.num_inst = 0 124 | self.sum_metric = 0.0 125 | self.global_num_inst = 0 126 | self.global_sum_metric = 0.0 127 | 128 | def reset_local(self): 129 | """Resets the local portion of the internal evaluation results 130 | to initial state.""" 131 | self.num_inst = 0 132 | self.sum_metric = 0.0 133 | 134 | def get(self): 135 | """Gets the current evaluation result. 136 | 137 | Returns 138 | ------- 139 | names : list of str 140 | Name of the metrics. 141 | values : list of float 142 | Value of the evaluations. 143 | """ 144 | if self.num_inst == 0: 145 | return (self.name, float('nan')) 146 | else: 147 | return (self.name, self.sum_metric / self.num_inst) 148 | 149 | def get_global(self): 150 | """Gets the current global evaluation result. 151 | 152 | Returns 153 | ------- 154 | names : list of str 155 | Name of the metrics. 156 | values : list of float 157 | Value of the evaluations. 158 | """ 159 | if self._has_global_stats: 160 | if self.global_num_inst == 0: 161 | return (self.name, float('nan')) 162 | else: 163 | return (self.name, self.global_sum_metric / self.global_num_inst) 164 | else: 165 | return self.get() 166 | 167 | def get_name_value(self): 168 | """Returns zipped name and value pairs. 169 | 170 | Returns 171 | ------- 172 | list of tuples 173 | A (name, value) tuple list. 174 | """ 175 | name, value = self.get() 176 | if not isinstance(name, list): 177 | name = [name] 178 | if not isinstance(value, list): 179 | value = [value] 180 | return list(zip(name, value)) 181 | 182 | def get_global_name_value(self): 183 | """Returns zipped name and value pairs for global results. 184 | 185 | Returns 186 | ------- 187 | list of tuples 188 | A (name, value) tuple list. 189 | """ 190 | if self._has_global_stats: 191 | name, value = self.get_global() 192 | if not isinstance(name, list): 193 | name = [name] 194 | if not isinstance(value, list): 195 | value = [value] 196 | return list(zip(name, value)) 197 | else: 198 | return self.get_name_value() 199 | -------------------------------------------------------------------------------- /utils/metric_seg.py: -------------------------------------------------------------------------------- 1 | """Evaluation Metrics for Semantic Segmentation""" 2 | import torch 3 | from utils.metric import EvalMetric 4 | 5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union'] 6 | 7 | 8 | class SegmentationMetric(EvalMetric): 9 | """Computes pixAcc and mIoU metric scores 10 | """ 11 | 12 | def __init__(self, nclass): 13 | super(SegmentationMetric, self).__init__('pixAcc & mIoU') 14 | self.nclass = nclass 15 | self.reset() 16 | 17 | def update(self, labels, preds): 18 | """Updates the internal evaluation result. 
19 | 20 | Parameters 21 | ---------- 22 | labels : 'NDArray' or list of `NDArray` 23 | The labels of the data. 24 | 25 | preds : 'NDArray' or list of `NDArray` 26 | Predicted values. 27 | """ 28 | 29 | def evaluate_worker(self, label, pred): 30 | correct, labeled = batch_pix_accuracy(pred, label) 31 | inter, union = batch_intersection_union(pred, label, self.nclass) 32 | self.total_correct += correct 33 | self.total_label += labeled 34 | if self.total_inter.device != inter.device: 35 | self.total_inter = self.total_inter.to(inter.device) 36 | self.total_union = self.total_union.to(union.device) 37 | self.total_inter += inter 38 | self.total_union += union 39 | 40 | if isinstance(preds, torch.Tensor): 41 | evaluate_worker(self, labels, preds) 42 | elif isinstance(preds, (list, tuple)): 43 | for (label, pred) in zip(labels, preds): 44 | evaluate_worker(self, label, pred) 45 | 46 | def get(self): 47 | """Gets the current evaluation result. 48 | 49 | Returns 50 | ------- 51 | metrics : tuple of float 52 | pixAcc and mIoU 53 | """ 54 | pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label) 55 | IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union) 56 | mIoU = IoU.mean().item() 57 | return pixAcc, mIoU 58 | 59 | def reset(self): 60 | """Resets the internal evaluation result to initial state.""" 61 | self.total_inter = torch.zeros(self.nclass) 62 | self.total_union = torch.zeros(self.nclass) 63 | self.total_correct = 0 64 | self.total_label = 0 65 | 66 | def get_value(self): 67 | return {'total_inter': self.total_inter, 'total_union': self.total_union, 68 | 'total_correct': self.total_correct, 'total_label': self.total_label} 69 | 70 | def combine_value(self, values): 71 | if self.total_inter.is_cuda: 72 | device = torch.device('cuda') 73 | self.total_inter += values['total_inter'].to(device) 74 | self.total_union += values['total_union'].to(device) 75 | else: 76 | self.total_inter += values['total_inter'] 77 | self.total_union += values['total_union'] 78 | self.total_correct += values['total_correct'] 79 | self.total_label += values['total_label'] 80 | 81 | # def combine_metric(self, metric): 82 | # if self.total_inter.is_cuda: 83 | # metric.total_inter = metric.total_inter.to(self.total_inter.device) 84 | # self.total_inter += metric.total_inter 85 | # metric.total_union = metric.total_union.to(self.total_union.device) 86 | # self.total_union += metric.total_union 87 | # else: 88 | # self.total_inter += metric.total_inter 89 | # self.total_union += metric.total_union 90 | # self.total_correct += metric.total_correct 91 | # self.total_label += metric.total_label 92 | 93 | 94 | def batch_pix_accuracy(output, target): 95 | """PixAcc""" 96 | # inputs are NDarray, output 4D, target 3D 97 | # the category -1 is ignored class, typically for background / boundary 98 | predict = torch.argmax(output.long(), 1) + 1 99 | 100 | target = target.long() + 1 101 | 102 | pixel_labeled = torch.sum(target > 0).item() 103 | pixel_correct = torch.sum((predict == target) * (target > 0)).item() 104 | 105 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 106 | return pixel_correct, pixel_labeled 107 | 108 | 109 | def batch_intersection_union(output, target, nclass): 110 | """mIoU""" 111 | # inputs are NDarray, output 4D, target 3D 112 | # the category -1 is ignored class, typically for background / boundary 113 | mini = 1 114 | maxi = nclass 115 | nbins = nclass 116 | predict = torch.argmax(output, 1) + 1 117 | target = target.float() + 1 
118 | 119 | predict = predict.float() * (target > 0).float() 120 | intersection = predict * (predict == target).float() 121 | # areas of intersection and union 122 | area_inter = torch.histc(intersection, bins=nbins, min=mini, max=maxi) 123 | area_pred = torch.histc(predict, bins=nbins, min=mini, max=maxi) 124 | area_lab = torch.histc(target, bins=nbins, min=mini, max=maxi) 125 | area_union = area_pred + area_lab - area_inter 126 | assert torch.sum(area_inter > area_union).item() == 0, \ 127 | "Intersection area should be smaller than Union area" 128 | return area_inter.float(), area_union.float() 129 | 130 | 131 | if __name__ == '__main__': 132 | a = torch.Tensor([[1.0, 2.0], [3.0, 4.0]]).cuda() 133 | b = torch.LongTensor([[1, 3], [3, 4]]).cuda() 134 | metric = SegmentationMetric(4) 135 | metric.update(a, b) 136 | print(metric.get()) 137 | -------------------------------------------------------------------------------- /utils/parallel.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | def get_world_size(): 8 | if not torch.distributed.is_initialized(): 9 | return 1 10 | return torch.distributed.get_world_size() 11 | 12 | 13 | def get_rank(): 14 | if not torch.distributed.is_initialized(): 15 | return 0 16 | return torch.distributed.get_rank() 17 | 18 | 19 | def is_main_process(): 20 | return get_rank() == 0 21 | 22 | 23 | def synchronize(): 24 | """ 25 | Helper function to synchronize (barrier) among all processes when 26 | using distributed training 27 | """ 28 | if not dist.is_available(): 29 | return 30 | if not dist.is_initialized(): 31 | return 32 | world_size = dist.get_world_size() 33 | if world_size == 1: 34 | return 35 | dist.barrier() 36 | 37 | 38 | def all_gather(data): 39 | """ 40 | Run all_gather on arbitrary picklable data (not necessarily tensors) 41 | Args: 42 | data: any picklable object 43 | Returns: 44 | list[data]: list of data gathered from each rank 45 | """ 46 | world_size = get_world_size() 47 | if world_size == 1: 48 | return [data] 49 | 50 | # serialized to a Tensor 51 | buffer = pickle.dumps(data) 52 | storage = torch.ByteStorage.from_buffer(buffer) 53 | tensor = torch.ByteTensor(storage).to("cuda") 54 | 55 | # obtain Tensor size of each rank 56 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 57 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 58 | dist.all_gather(size_list, local_size) 59 | size_list = [int(size.item()) for size in size_list] 60 | max_size = max(size_list) 61 | 62 | # receiving Tensor from all ranks 63 | # we pad the tensor because torch all_gather does not support 64 | # gathering tensors of different shapes 65 | tensor_list = [] 66 | for _ in size_list: 67 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 68 | if local_size != max_size: 69 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 70 | tensor = torch.cat((tensor, padding), dim=0) 71 | dist.all_gather(tensor_list, tensor) 72 | 73 | data_list = [] 74 | for size, tensor in zip(size_list, tensor_list): 75 | buffer = tensor.cpu().numpy().tobytes()[:size] 76 | data_list.append(pickle.loads(buffer)) 77 | 78 | return data_list 79 | 80 | 81 | def reduce_loss_dict(loss_dict): 82 | """ 83 | Reduce the loss dictionary from all processes so that process with rank 84 | 0 has the averaged results. Returns a dict with the same fields as 85 | loss_dict, after reduction. 
86 | """ 87 | world_size = get_world_size() 88 | if world_size < 2: 89 | return loss_dict 90 | with torch.no_grad(): 91 | loss_names = [] 92 | all_losses = [] 93 | for k in sorted(loss_dict.keys()): 94 | loss_names.append(k) 95 | all_losses.append(loss_dict[k]) 96 | all_losses = torch.stack(all_losses, dim=0) 97 | dist.reduce(all_losses, dst=0) 98 | if dist.get_rank() == 0: 99 | # only main process gets accumulated, so only divide by 100 | # world_size in this case 101 | all_losses /= world_size 102 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 103 | return reduced_losses 104 | 105 | 106 | def accumulate_metric(metric): 107 | all_values = all_gather(metric.get_value()) 108 | if not is_main_process(): 109 | return None, None 110 | for value in all_values[1:]: 111 | metric.combine_value(value) 112 | return metric.get() 113 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 6 | return True 7 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 8 | return False 9 | else: 10 | raise argparse.ArgumentTypeError('Boolean value expected.') 11 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | """Segmentation Utils""" 2 | from PIL import Image 3 | 4 | __all__ = ['get_color_pallete'] 5 | 6 | 7 | def get_color_pallete(npimg, dataset='pascal_voc'): 8 | """Visualize image. 9 | 10 | Parameters 11 | ---------- 12 | npimg : numpy.ndarray 13 | Single channel image with shape `H, W, 1`. 14 | dataset : str, default: 'pascal_voc' 15 | The dataset that model pretrained on. 
('pascal_voc', 'ade20k') 16 | 17 | Returns 18 | ------- 19 | out_img : PIL.Image 20 | Image with color pallete 21 | 22 | """ 23 | # recovery boundary 24 | if dataset in ('pascal_voc', 'pascal_aug'): 25 | npimg[npimg == -1] = 255 26 | # put colormap 27 | if dataset == 'ade20k': 28 | npimg = npimg + 1 29 | out_img = Image.fromarray(npimg.astype('uint8')) 30 | out_img.putpalette(adepallete) 31 | return out_img 32 | elif dataset == 'citys': 33 | out_img = Image.fromarray(npimg.astype('uint8')) 34 | out_img.putpalette(cityspallete) 35 | return out_img 36 | out_img = Image.fromarray(npimg.astype('uint8')) 37 | out_img.putpalette(vocpallete) 38 | return out_img 39 | 40 | 41 | def _getvocpallete(num_cls): 42 | n = num_cls 43 | pallete = [0] * (n * 3) 44 | for j in range(0, n): 45 | lab = j 46 | pallete[j * 3 + 0] = 0 47 | pallete[j * 3 + 1] = 0 48 | pallete[j * 3 + 2] = 0 49 | i = 0 50 | while (lab > 0): 51 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 52 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 53 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 54 | i = i + 1 55 | lab >>= 3 56 | return pallete 57 | 58 | 59 | vocpallete = _getvocpallete(256) 60 | 61 | adepallete = [ 62 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204, 63 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82, 64 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255, 65 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6, 66 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255, 67 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15, 68 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255, 69 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163, 70 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255, 71 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0, 72 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0, 73 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255, 74 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255, 75 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0, 76 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0, 77 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0, 78 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0, 79 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255] 80 | 81 | cityspallete = [ 82 | 128, 64, 128, 83 | 244, 35, 232, 84 | 70, 70, 70, 85 | 102, 102, 156, 86 | 190, 153, 153, 87 | 153, 153, 153, 88 | 250, 170, 30, 89 | 220, 220, 0, 90 | 107, 142, 35, 91 | 152, 251, 152, 92 | 0, 130, 180, 93 | 220, 20, 60, 94 | 255, 0, 0, 95 | 0, 0, 142, 96 | 0, 0, 70, 97 | 0, 60, 100, 98 | 0, 
80, 100, 99 | 0, 0, 230, 100 | 119, 11, 32, 101 | ] 102 | --------------------------------------------------------------------------------
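
The utilities above are consumed by `train.py`, `eval.py`, and `demo.py`. For orientation only, below is a minimal standalone sketch (not part of the repository) that exercises `SegmentationMetric`, `setup_logger`, and `get_color_pallete` on random tensors. It assumes it is run from the repository root so the `utils` package is importable; the 19-class setting, the random inputs, and the output file name `pred_color.png` are arbitrary choices for illustration.

    # minimal usage sketch, assuming execution from the repository root
    import torch
    from utils.logger import setup_logger
    from utils.metric_seg import SegmentationMetric
    from utils.visual import get_color_pallete

    # console-only logger (no save_dir, rank 0)
    logger = setup_logger('Segmentation', None, 0)

    num_classes = 19  # Cityscapes
    metric = SegmentationMetric(num_classes)

    # fake batch: logits of shape (N, C, H, W), labels of shape (N, H, W)
    logits = torch.randn(2, num_classes, 64, 128)
    labels = torch.randint(0, num_classes, (2, 64, 128))

    # accumulate pixAcc / mIoU statistics and read them back
    metric.update(labels, logits)
    pixAcc, mIoU = metric.get()
    logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))

    # colorize the first predicted mask with the Cityscapes palette
    pred = torch.argmax(logits, 1)[0].cpu().numpy()
    get_color_pallete(pred, dataset='citys').save('pred_color.png')

In a distributed run the same metric object is gathered across ranks with `utils.parallel.accumulate_metric`, as done at the end of `train.py`.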