├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── __init__.py
│   ├── base.py
│   ├── base_seg.py
│   ├── cityscapes.py
│   └── sampler.py
├── demo.py
├── eval.py
├── model
│   ├── __init__.py
│   ├── basic.py
│   ├── lednet.py
│   ├── loss.py
│   └── lr_scheduler.py
├── png
│   ├── demo.png
│   ├── gt.png
│   └── output.png
├── train.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── metric.py
    ├── metric_seg.py
    ├── parallel.py
    ├── util.py
    └── visual.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Ace
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LEDNet
2 | This is an unofficial implementation of [LEDNet](https://arxiv.org/abs/1905.02423).
3 |
4 | > The official implementation: [LEDNet-official](https://github.com/xiaoyufenfei/LEDNet)
5 |
6 |
9 |
10 | ## Environment
11 |
12 | - Python 3.6
13 | - PyTorch 1.1
14 |
15 | ## Performance
16 |
17 | - Base Size 1024, Crop Size 768, fine annotations only. (new version, with dropout)
18 |
19 | | Model | Paper | OHEM | Drop-rate | lr | Epoch | val (crop) | val |
20 | | :----: | :---: | :--: | :-------: | :----: | :---: | :---------: | :----------------------------------------------------------: |
21 | | LEDNet | / | ✗ | 0.1 | 0.0005 | 800 | 60.32/94.51 | 66.29/94.40 |
22 | | LEDNet | / | ✗ | 0.1 | 0.005 | 600 | 61.29/94.75 | 66.56/94.72 |
23 | | LEDNet | / | ✗ | 0.3 | 0.01 | 800 | 63.84/94.83 | [69.09/94.75](https://drive.google.com/open?id=1oelPUKAnZYD75RruyBQU9HZKneMEMIAp) |
24 |
25 | > Note:
26 | >
27 | > - The paper only provides test results: 69.2/86.8 (class mIoU/category mIoU).
28 | > - The training settings differ slightly from the original paper (which uses 1024x512).
29 |
30 | Some ways to improve performance (an example command combining them is given after this list):
31 | 
32 | 1. use a larger learning rate (e.g. 0.01)
33 | 2. train for more epochs (e.g. 1000)
34 | 3. use a larger training input size (e.g. Base Size 1344, Crop Size 1024)
35 |
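For example, a run combining all three suggestions might look like the following (flags as in the training command below; the batch size is still limited by GPU memory):

```shell
$ export NGPUS=4
$ python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py --dataset citys --batch-size 8 --base-size 1344 --crop-size 1024 --epochs 1000 --lr 0.01
```
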
36 | ## Demo
37 |
38 | Please download the [pretrained](https://drive.google.com/open?id=1oelPUKAnZYD75RruyBQU9HZKneMEMIAp) model first.
39 |
40 | ```shell
41 | $ python demo.py [--input-pic png/demo.png] [--pretrained your-root-of-pretrained] [--cuda true]
42 | ```
43 |
44 | ## Evaluation
45 |
46 | The default data root is `~/.torch/datasets` (you can download the dataset elsewhere and create a soft link to it).
47 |
48 | ```shell
49 | $ python eval.py [--mode testval] [--pretrained root-of-pretrained-model] [--cuda true]
50 | ```
51 |
52 | ## Training
53 |
54 | Distributed training is recommended.
55 |
56 | ```shell
57 | $ export NGPUS=4
58 | $ python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py [--dataset citys] [--batch-size 8] [--base-size 1024] [--crop-size 768] [--epochs 800] [--warmup-factor 0.1] [--warmup-iters 200] [--log-step 10] [--save-epoch 40] [--lr 0.005]
59 | ```
60 |
61 | ## Prepare data
62 |
63 | You can follow [gluon-cv-cityscapes](https://gluon-cv.mxnet.io/build/examples_datasets/cityscapes.html#sphx-glr-build-examples-datasets-cityscapes-py) to prepare the dataset; the layout expected by the dataloader is sketched below.
64 |
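Based on the paths used in `data/cityscapes.py`, the dataloader expects roughly this layout under the data root (city subfolders as in the official Cityscapes package; filenames shown are illustrative):

```
~/.torch/datasets/citys
├── leftImg8bit
│   ├── train
│   │   └── <city>/<city>_..._leftImg8bit.png
│   └── val
└── gtFine
    ├── train
    │   └── <city>/<city>_..._gtFine_labelIds.png
    └── val
```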
65 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .cityscapes import CitySegmentation
2 |
3 | datasets = {
4 | 'citys': CitySegmentation,
5 | }
6 |
7 |
8 | def get_segmentation_dataset(name, **kwargs):
9 | """Segmentation Datasets"""
10 | return datasets[name.lower()](**kwargs)
11 |
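A minimal usage sketch (mirroring how `eval.py` and `train.py` call this factory; it assumes the Cityscapes data has already been prepared under `~/.torch/datasets/citys`, and the keyword arguments are forwarded to `CitySegmentation`):

```python
from data import get_segmentation_dataset

# 'citys' is the only registered dataset; split/mode follow data/cityscapes.py
val_dataset = get_segmentation_dataset('citys', split='val', mode='val')
```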
--------------------------------------------------------------------------------
/data/base.py:
--------------------------------------------------------------------------------
1 | """Base dataset methods."""
2 | import os
3 | from torch.utils import data
4 |
5 |
6 | class ClassProperty(object):
7 | """Readonly @ClassProperty descriptor for internal usage."""
8 |
9 | def __init__(self, fget):
10 | self.fget = fget
11 |
12 | def __get__(self, owner_self, owner_cls):
13 | return self.fget(owner_cls)
14 |
15 |
16 | class SimpleDataset(data.Dataset):
17 | """Simple Dataset wrapper for lists and arrays.
18 |
19 | Parameters
20 | ----------
21 | data : dataset-like object
22 | Any object that implements `len()` and `[]`.
23 | """
24 |
25 | def __init__(self, data):
26 | self._data = data
27 |
28 | def __len__(self):
29 | return len(self._data)
30 |
31 | def __getitem__(self, idx):
32 | return self._data[idx]
33 |
34 |
35 | class _LazyTransformDataset(data.Dataset):
36 | """Lazily transformed dataset."""
37 |
38 | def __init__(self, data, fn):
39 | super(_LazyTransformDataset, self).__init__()
40 | self._data = data
41 | self._fn = fn
42 |
43 | def __len__(self):
44 | return len(self._data)
45 |
46 | def __getitem__(self, idx):
47 | item = self._data[idx]
48 | if isinstance(item, tuple):
49 | return self._fn(*item)
50 | return self._fn(item)
51 |
52 | def transform(self, fn):
53 | self._fn = fn
54 |
55 |
56 | class VisionDataset(data.Dataset):
57 | """Base Dataset with directory checker.
58 |
59 | Parameters
60 | ----------
61 | root : str
62 |         The root directory of the dataset, e.g. '~/.torch/datasets/foo', where
63 |         `foo` is the name of the dataset.
64 | """
65 |
66 | def __init__(self, root):
67 | super(VisionDataset, self).__init__()
68 | if not os.path.isdir(os.path.expanduser(root)):
69 | helper_msg = "{} is not a valid dir. Did you forget to initialize \
70 | datasets described in: \
71 | `http://gluon-cv.mxnet.io/build/examples_datasets/index.html`? \
72 | You need to initialize each dataset only once.".format(root)
73 | raise OSError(helper_msg)
74 |
75 | @property
76 | def classes(self):
77 | raise NotImplementedError
78 |
79 | @property
80 | def num_class(self):
81 | """Number of categories."""
82 | return len(self.classes)
83 |
84 | def transform(self, fn, lazy=True):
85 | """Returns a new dataset with each sample transformed by the
86 | transformer function `fn`.
87 |
88 | Parameters
89 | ----------
90 | fn : callable
91 | A transformer function that takes a sample as input and
92 | returns the transformed sample.
93 | lazy : bool, default True
94 | If False, transforms all samples at once. Otherwise,
95 | transforms each sample on demand. Note that if `fn`
96 | is stochastic, you must set lazy to True or you will
97 | get the same result on all epochs.
98 |
99 | Returns
100 | -------
101 | Dataset
102 | The transformed dataset.
103 | """
104 | trans = _LazyTransformDataset(self, fn)
105 | if lazy:
106 | return trans
107 | return SimpleDataset([i for i in trans])
108 |
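A runnable sketch of the lazy-transform mechanics (editor's note; it uses toy integer samples in place of images and calls the internal `_LazyTransformDataset` directly for illustration):

```python
from data.base import SimpleDataset, _LazyTransformDataset

# Tuple samples are unpacked, so (img, mask)-style pairs can be transformed jointly.
# With lazy=False in VisionDataset.transform the whole dataset is materialized once,
# so stochastic transforms should keep lazy=True.
base = SimpleDataset([(1, 10), (2, 20)])
lazy = _LazyTransformDataset(base, lambda img, mask: (img * 2, mask))
print(lazy[0])  # (2, 10)
```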
--------------------------------------------------------------------------------
/data/base_seg.py:
--------------------------------------------------------------------------------
1 | """Base segmentation dataset"""
2 | import torch
3 | import random
4 | import numpy as np
5 | from PIL import Image, ImageOps, ImageFilter
6 |
7 | from data.base import VisionDataset
8 |
9 |
10 | class SegmentationDataset(VisionDataset):
11 | """Segmentation Base Dataset"""
12 |
13 | # pylint: disable=abstract-method
14 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
15 | super(SegmentationDataset, self).__init__(root)
16 | self.root = root
17 | self.transform = transform
18 | self.split = split
19 | self.mode = mode if mode is not None else split
20 | self.base_size = base_size
21 | self.crop_size = crop_size
22 |
23 | def _val_sync_transform(self, img, mask):
24 | outsize = self.crop_size
25 | short_size = outsize
26 | w, h = img.size
27 | if w > h:
28 | oh = short_size
29 | ow = int(1.0 * w * oh / h)
30 | else:
31 | ow = short_size
32 | oh = int(1.0 * h * ow / w)
33 | img = img.resize((ow, oh), Image.BILINEAR)
34 | mask = mask.resize((ow, oh), Image.NEAREST)
35 | # center crop
36 | w, h = img.size
37 | x1 = int(round((w - outsize) / 2.))
38 | y1 = int(round((h - outsize) / 2.))
39 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
40 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
41 | # final transform
42 | img, mask = self._img_transform(img), self._mask_transform(mask)
43 | return img, mask
44 |
45 | def _sync_transform(self, img, mask):
46 | # random mirror
47 | if random.random() < 0.5:
48 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
49 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
50 | crop_size = self.crop_size
51 | # random scale (short edge)
52 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
53 | w, h = img.size
54 | if h > w:
55 | ow = short_size
56 | oh = int(1.0 * h * ow / w)
57 | else:
58 | oh = short_size
59 | ow = int(1.0 * w * oh / h)
60 | img = img.resize((ow, oh), Image.BILINEAR)
61 | mask = mask.resize((ow, oh), Image.NEAREST)
62 | # pad crop
63 | if short_size < crop_size:
64 | padh = crop_size - oh if oh < crop_size else 0
65 | padw = crop_size - ow if ow < crop_size else 0
66 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
67 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
68 | # random crop crop_size
69 | w, h = img.size
70 | x1 = random.randint(0, w - crop_size)
71 | y1 = random.randint(0, h - crop_size)
72 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
73 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
74 | # gaussian blur as in PSP
75 | if random.random() < 0.5:
76 | img = img.filter(ImageFilter.GaussianBlur(
77 | radius=random.random()))
78 | # final transform
79 | img, mask = self._img_transform(img), self._mask_transform(mask)
80 | return img, mask
81 |
82 | def _img_transform(self, img):
83 | # return torch.from_numpy(np.array(img))
84 | return np.array(img)
85 |
86 | def _mask_transform(self, mask):
87 | # return torch.from_numpy(np.array(mask).astype('int32'))
88 | return np.array(mask).astype('int64')
89 |
90 | @property
91 | def num_class(self):
92 | """Number of categories."""
93 | return self.NUM_CLASS
94 |
95 | @property
96 | def pred_offset(self):
97 | return 0
98 |
99 |
100 | def ms_batchify_fn(data):
101 | """Multi-size batchify function"""
102 | if isinstance(data[0], (str, torch.Tensor)):
103 | return list(data)
104 | elif isinstance(data[0], tuple):
105 | data = zip(*data)
106 | return [ms_batchify_fn(i) for i in data]
107 | raise RuntimeError('unknown datatype')
108 |
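A small sketch of what `ms_batchify_fn` produces (editor's note; it is useful when samples have different spatial sizes and cannot be stacked into a single tensor):

```python
import torch
from data.base_seg import ms_batchify_fn

# Two samples with different sizes stay as lists instead of being stacked.
batch = [(torch.zeros(3, 4, 4), torch.zeros(4, 4)),
         (torch.zeros(3, 6, 6), torch.zeros(6, 6))]
images, masks = ms_batchify_fn(batch)
print(len(images), images[0].shape, images[1].shape)
```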
--------------------------------------------------------------------------------
/data/cityscapes.py:
--------------------------------------------------------------------------------
1 | """Cityscapes Dataloader"""
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 |
7 | from data.base_seg import SegmentationDataset
8 |
9 |
10 | def _get_city_pairs(folder, split='train'):
11 | def get_path_pairs(img_folder, mask_folder):
12 | img_paths = []
13 | mask_paths = []
14 | for root, _, files in os.walk(img_folder):
15 | for filename in files:
16 | if filename.endswith(".png"):
17 | imgpath = os.path.join(root, filename)
18 | foldername = os.path.basename(os.path.dirname(imgpath))
19 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
20 | maskpath = os.path.join(mask_folder, foldername, maskname)
21 | if os.path.isfile(imgpath) and os.path.isfile(maskpath):
22 | img_paths.append(imgpath)
23 | mask_paths.append(maskpath)
24 | else:
25 | print('cannot find the mask or image:', imgpath, maskpath)
26 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
27 | return img_paths, mask_paths
28 |
29 | if split in ('train', 'val'):
30 | img_folder = os.path.join(folder, 'leftImg8bit/' + split)
31 | mask_folder = os.path.join(folder, 'gtFine/' + split)
32 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
33 | return img_paths, mask_paths
34 | else:
35 | assert split == 'trainval'
36 | print('trainval set')
37 | train_img_folder = os.path.join(folder, 'leftImg8bit/train')
38 | train_mask_folder = os.path.join(folder, 'gtFine/train')
39 | val_img_folder = os.path.join(folder, 'leftImg8bit/val')
40 | val_mask_folder = os.path.join(folder, 'gtFine/val')
41 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
42 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
43 | img_paths = train_img_paths + val_img_paths
44 | mask_paths = train_mask_paths + val_mask_paths
45 | return img_paths, mask_paths
46 |
47 |
48 | class CitySegmentation(SegmentationDataset):
49 | """Cityscapes Dataloader"""
50 | # pylint: disable=abstract-method
51 | BASE_DIR = 'cityscapes'
52 | NUM_CLASS = 19
53 |
54 | def __init__(self, root=os.path.expanduser('~/.torch/datasets/citys'), split='train',
55 | mode=None, transform=None, **kwargs):
56 | super(CitySegmentation, self).__init__(
57 | root, split, mode, transform, **kwargs)
58 | # self.root = os.path.join(root, self.BASE_DIR)
59 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
60 | assert (len(self.images) == len(self.mask_paths))
61 | if len(self.images) == 0:
62 | raise RuntimeError("Found 0 images in subfolders of: \
63 | " + self.root + "\n")
64 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
65 | 23, 24, 25, 26, 27, 28, 31, 32, 33]
66 | self._key = np.array([-1, -1, -1, -1, -1, -1,
67 | -1, -1, 0, 1, -1, -1,
68 | 2, 3, 4, -1, -1, -1,
69 | 5, -1, 6, 7, 8, 9,
70 | 10, 11, 12, 13, 14, 15,
71 | -1, -1, 16, 17, 18])
72 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
73 |
74 | def _class_to_index(self, mask):
75 | # assert the values
76 | values = np.unique(mask)
77 | for value in values:
78 | assert (value in self._mapping)
79 | index = np.digitize(mask.ravel(), self._mapping, right=True)
80 | return self._key[index].reshape(mask.shape)
81 |
82 | def __getitem__(self, index):
83 | img = Image.open(self.images[index]).convert('RGB')
84 | if self.mode == 'test':
85 | if self.transform is not None:
86 | img = self.transform(img)
87 | return img, os.path.basename(self.images[index])
88 | # mask = self.masks[index]
89 | mask = Image.open(self.mask_paths[index])
90 | # synchrosized transform
91 | if self.mode == 'train':
92 | img, mask = self._sync_transform(img, mask)
93 | elif self.mode == 'val':
94 | img, mask = self._val_sync_transform(img, mask)
95 | else:
96 | assert self.mode == 'testval'
97 | img, mask = self._img_transform(img), self._mask_transform(mask)
98 | # general resize, normalize and toTensor
99 | if self.transform is not None:
100 | img = self.transform(img)
101 | return img, mask
102 |
103 | def _mask_transform(self, mask):
104 | target = self._class_to_index(np.array(mask).astype('int64'))
105 | return torch.from_numpy(target)
106 |
107 | def __len__(self):
108 | return len(self.images)
109 |
110 |
111 | if __name__ == '__main__':
112 | data = CitySegmentation(split='val', mode='val')
113 |     print(data[0][0].shape, data[0][1].shape)
114 |
--------------------------------------------------------------------------------
/data/sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | # Code is copy-pasted exactly as in torch.utils.data.distributed.
3 | # FIXME remove this once c10d fixes the bug it has
4 | import math
5 | import torch
6 | import torch.distributed as dist
7 | from torch.utils.data.sampler import Sampler, BatchSampler
8 |
9 |
10 | class DistributedSampler(Sampler):
11 | """Sampler that restricts data loading to a subset of the dataset.
12 | It is especially useful in conjunction with
13 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
14 | process can pass a DistributedSampler instance as a DataLoader sampler,
15 | and load a subset of the original dataset that is exclusive to it.
16 | .. note::
17 | Dataset is assumed to be of constant size.
18 | Arguments:
19 | dataset: Dataset used for sampling.
20 | num_replicas (optional): Number of processes participating in
21 | distributed training.
22 | rank (optional): Rank of the current process within num_replicas.
23 | """
24 |
25 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
26 | if num_replicas is None:
27 | if not dist.is_available():
28 | raise RuntimeError("Requires distributed package to be available")
29 | num_replicas = dist.get_world_size()
30 | if rank is None:
31 | if not dist.is_available():
32 | raise RuntimeError("Requires distributed package to be available")
33 | rank = dist.get_rank()
34 | self.dataset = dataset
35 | self.num_replicas = num_replicas
36 | self.rank = rank
37 | self.epoch = 0
38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
39 | self.total_size = self.num_samples * self.num_replicas
40 | self.shuffle = shuffle
41 |
42 | def __iter__(self):
43 | if self.shuffle:
44 | # deterministically shuffle based on epoch
45 | g = torch.Generator()
46 | g.manual_seed(self.epoch)
47 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
48 | else:
49 | indices = torch.arange(len(self.dataset)).tolist()
50 |
51 | # add extra samples to make it evenly divisible
52 | indices += indices[: (self.total_size - len(indices))]
53 | assert len(indices) == self.total_size
54 |
55 | # subsample
56 | offset = self.num_samples * self.rank
57 | indices = indices[offset: offset + self.num_samples]
58 | assert len(indices) == self.num_samples
59 |
60 | return iter(indices)
61 |
62 | def __len__(self):
63 | return self.num_samples
64 |
65 | def set_epoch(self, epoch):
66 | self.epoch = epoch
67 |
68 |
69 | def make_data_sampler(dataset, shuffle, distributed):
70 | if distributed:
71 | return DistributedSampler(dataset, shuffle=shuffle)
72 | if shuffle:
73 | sampler = torch.utils.data.sampler.RandomSampler(dataset)
74 | else:
75 | sampler = torch.utils.data.sampler.SequentialSampler(dataset)
76 | return sampler
77 |
78 |
79 | class IterationBasedBatchSampler(BatchSampler):
80 | """
81 | Wraps a BatchSampler, resampling from it until
82 | a specified number of iterations have been sampled
83 | """
84 |
85 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
86 | self.batch_sampler = batch_sampler
87 | self.num_iterations = num_iterations
88 | self.start_iter = start_iter
89 |
90 | def __iter__(self):
91 | iteration = self.start_iter
92 | while iteration <= self.num_iterations:
93 | # if the underlying sampler has a set_epoch method, like
94 | # DistributedSampler, used for making each process see
95 | # a different split of the dataset, then set it
96 | if hasattr(self.batch_sampler.sampler, "set_epoch"):
97 | self.batch_sampler.sampler.set_epoch(iteration)
98 | for batch in self.batch_sampler:
99 | iteration += 1
100 | if iteration > self.num_iterations:
101 | break
102 | yield batch
103 |
104 | def __len__(self):
105 | return self.num_iterations
106 |
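A composition sketch mirroring how `train.py` chains these pieces (editor's note; the tensor stands in for a real dataset and the iteration count is arbitrary):

```python
import torch
from torch.utils.data.sampler import BatchSampler
from data.sampler import make_data_sampler, IterationBasedBatchSampler

dataset = torch.arange(100)                      # stand-in for a real dataset
sampler = make_data_sampler(dataset, shuffle=True, distributed=False)
batch_sampler = BatchSampler(sampler, batch_size=8, drop_last=True)
# Re-iterates the wrapped BatchSampler until num_iterations batches have been yielded.
batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=20)
print(len(list(batch_sampler)))                  # 20 batches regardless of dataset length
```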
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | from PIL import Image
5 | import matplotlib.pyplot as plt
6 | import matplotlib.image as mpimg
7 |
8 | import torch
9 | from torchvision import transforms
10 |
11 | cur_path = os.path.dirname(__file__)
12 | sys.path.insert(0, os.path.join(cur_path, '..'))
13 | from model.lednet import LEDNet
14 | import utils as ptutil
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(description='Demo for LEDNet from a given image')
19 |
20 | parser.add_argument('--input-pic', type=str, default=os.path.join(cur_path, 'png/demo.png'),
21 | help='path to the input picture')
22 | parser.add_argument('--pretrained', type=str,
23 | default=os.path.expanduser('~/cbb/own/pretrained/seg/lednet/LEDNet_final.pth'),
24 | help='Default Pre-trained model root.')
25 | parser.add_argument('--cuda', type=ptutil.str2bool, default='true',
26 | help='demo with GPU')
27 |
28 | opt = parser.parse_args()
29 | return opt
30 |
31 |
32 | if __name__ == '__main__':
33 | args = parse_args()
34 | device = torch.device('cpu')
35 | if args.cuda:
36 | device = torch.device('cuda')
37 | # Load Model
38 | model = LEDNet(19).to(device)
39 |     model.load_state_dict(torch.load(args.pretrained, map_location=device))
40 | model.eval()
41 |
42 | # Load Images
43 |     img = Image.open(args.input_pic).convert('RGB')
44 |
45 | # Transform
46 | transform_fn = transforms.Compose([
47 | transforms.ToTensor(),
48 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
49 | ])
50 |
51 | img = transform_fn(img).unsqueeze(0).to(device)
52 | with torch.no_grad():
53 | output = model(img)
54 |
55 | predict = torch.argmax(output, 1).squeeze(0).cpu().numpy()
56 | mask = ptutil.get_color_pallete(predict, 'citys')
57 | mask.save(os.path.join(cur_path, 'png/output.png'))
58 | mmask = mpimg.imread(os.path.join(cur_path, 'png/output.png'))
59 | plt.imshow(mmask)
60 | plt.show()
61 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | import torch
7 | from torch.utils import data
8 | from torchvision import transforms
9 |
10 | cur_path = os.path.dirname(__file__)
11 | sys.path.insert(0, os.path.join(cur_path, '..'))
12 | import utils as ptutil
13 | from model.lednet import LEDNet
14 | from data import get_segmentation_dataset
15 | from data.sampler import make_data_sampler
16 | from utils.metric_seg import SegmentationMetric
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description='Eval Segmentation.')
21 | parser.add_argument('--batch-size', type=int, default=1,
22 | help='Training mini-batch size')
23 | parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
24 | default=4, help='Number of data workers')
25 | parser.add_argument('--dataset', type=str, default='citys',
26 | help='Select dataset.')
27 | parser.add_argument('--split', type=str, default='val',
28 | help='Select val|test, evaluate in val or test data')
29 | parser.add_argument('--mode', type=str, default='testval',
30 |                         help='Select testval|val: testval evaluates without crop, val evaluates with crop')
31 | parser.add_argument('--base-size', type=int, default=1024,
32 | help='base image size')
33 | parser.add_argument('--crop-size', type=int, default=768,
34 | help='crop image size')
35 |
36 | parser.add_argument('--pretrained', type=str,
37 | default='./LEDNet_iter_073600.pth',
38 | help='Default Pre-trained model root.')
39 |
40 | # device
41 | parser.add_argument('--cuda', type=ptutil.str2bool, default='true',
42 | help='Training with GPUs.')
43 | parser.add_argument('--local_rank', type=int, default=0)
44 | parser.add_argument('--init-method', type=str, default="env://")
45 |
46 | args = parser.parse_args()
47 | return args
48 |
49 |
50 | def validate(net, val_data, metric, device):
51 | net.eval()
52 | tbar = tqdm(val_data)
53 | for i, (data, targets) in enumerate(tbar):
54 | data, targets = data.to(device), targets.to(device)
55 | with torch.no_grad():
56 | predicts = net(data)
57 | metric.update(targets, predicts)
58 | return metric
59 |
60 |
61 | if __name__ == '__main__':
62 | args = parse_args()
63 |
64 | device = torch.device('cpu')
65 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
66 | distributed = num_gpus > 1
67 | if args.cuda and torch.cuda.is_available():
68 | torch.backends.cudnn.benchmark = False if args.mode == 'testval' else True
69 | device = torch.device('cuda')
70 | else:
71 | distributed = False
72 |
73 | if distributed:
74 | torch.cuda.set_device(args.local_rank)
75 | torch.distributed.init_process_group(backend="nccl", init_method=args.init_method)
76 |
77 | # Load Model
78 | model = LEDNet(19)
79 |     model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
80 | model.keep_shape = True if args.mode == 'testval' else False
81 | model.to(device)
82 |
83 | # testing data
84 | input_transform = transforms.Compose([
85 | transforms.ToTensor(),
86 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
87 | ])
88 |
89 | data_kwargs = {'base_size': args.base_size, 'crop_size': args.crop_size, 'transform': input_transform}
90 |
91 | val_dataset = get_segmentation_dataset(args.dataset, split=args.split, mode=args.mode, **data_kwargs)
92 | sampler = make_data_sampler(val_dataset, False, distributed)
93 | batch_sampler = data.BatchSampler(sampler=sampler, batch_size=args.batch_size, drop_last=False)
94 | val_data = data.DataLoader(val_dataset, shuffle=False, batch_sampler=batch_sampler,
95 | num_workers=args.num_workers)
96 | metric = SegmentationMetric(val_dataset.num_class)
97 |
98 | metric = validate(model, val_data, metric, device)
99 | ptutil.synchronize()
100 | pixAcc, mIoU = ptutil.accumulate_metric(metric)
101 | if ptutil.is_main_process():
102 | print('pixAcc: %.4f, mIoU: %.4f' % (pixAcc, mIoU))
103 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/model/__init__.py
--------------------------------------------------------------------------------
/model/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # helper function
7 | def channel_shuffle(x, groups):
8 | b, n, h, w = x.shape
9 | channels_per_group = n // groups
10 |
11 | # reshape
12 | x = x.view(b, groups, channels_per_group, h, w)
13 | x = torch.transpose(x, 1, 2).contiguous()
14 |
15 | # flatten
16 | x = x.view(b, -1, h, w)
17 |
18 | return x
19 |
20 |
21 | def basic_conv(in_channel, channel, kernel=3, stride=1):
22 | return nn.Sequential(
23 | nn.Conv2d(in_channel, channel, kernel, stride, kernel // 2, bias=False),
24 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True)
25 | )
26 |
27 |
28 | # basic module
29 | # TODO: may add bn and relu
30 | class DownSampling(nn.Module):
31 | def __init__(self, in_channel, out_channel):
32 | super(DownSampling, self).__init__()
33 | self.conv = nn.Conv2d(in_channel, out_channel, 3, stride=2, padding=1)
34 | self.pool = nn.MaxPool2d(2, ceil_mode=True)
35 | self.bn = nn.BatchNorm2d(out_channel + in_channel)
36 |
37 | def forward(self, x):
38 | x1 = self.conv(x)
39 | x2 = self.pool(x)
40 | x = torch.cat([x1, x2], dim=1)
41 | x = F.relu_(self.bn(x))
42 | return x
43 |
44 |
45 | class SSnbt(nn.Module):
46 | def __init__(self, channel, dilate=1, drop_prob=0.01):
47 | super(SSnbt, self).__init__()
48 | channel = channel // 2
49 | self.left = nn.Sequential(
50 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0)), nn.ReLU(inplace=True),
51 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1), bias=False),
52 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
53 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1)),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate), bias=False),
56 | nn.BatchNorm2d(channel), nn.Dropout2d(drop_prob, inplace=True)
57 | )
58 | self.right = nn.Sequential(
59 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1)), nn.ReLU(inplace=True),
60 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0), bias=False),
61 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
62 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate)),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1), bias=False),
65 | nn.BatchNorm2d(channel), nn.Dropout2d(drop_prob, inplace=True)
66 | )
67 |
68 | def forward(self, x):
69 | x1, x2 = x.split(x.shape[1] // 2, 1)
70 | x1 = self.left(x1)
71 | x2 = self.right(x2)
72 | out = torch.cat([x1, x2], 1)
73 | x = F.relu(out + x)
74 | return channel_shuffle(x, 2)
75 |
76 |
77 | class SSnbtv2(nn.Module):
78 | def __init__(self, channel, dilate=1, drop_prob=0.01):
79 | super(SSnbtv2, self).__init__()
80 | channel = channel // 2
81 | self.left = nn.Sequential(
82 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0)), nn.ReLU(inplace=True),
83 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1), bias=False),
84 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
85 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1)),
86 | nn.ReLU(inplace=True),
87 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate), bias=False),
88 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), nn.Dropout2d(drop_prob, inplace=True)
89 | )
90 | self.right = nn.Sequential(
91 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, 1)), nn.ReLU(inplace=True),
92 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (1, 0), bias=False),
93 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True),
94 | nn.Conv2d(channel, channel, (1, 3), (1, 1), (0, dilate), dilation=(1, dilate)),
95 | nn.ReLU(inplace=True),
96 | nn.Conv2d(channel, channel, (3, 1), (1, 1), (dilate, 0), dilation=(dilate, 1), bias=False),
97 | nn.BatchNorm2d(channel), nn.ReLU(inplace=True), nn.Dropout2d(drop_prob, inplace=True)
98 | )
99 |
100 | def forward(self, x):
101 | x1, x2 = x.split(x.shape[1] // 2, 1)
102 | x1 = self.left(x1)
103 | x2 = self.right(x2)
104 | out = torch.cat([x1, x2], 1)
105 | x = F.relu(out + x)
106 | return channel_shuffle(x, 2)
107 |
108 |
109 | class APN(nn.Module):
110 | def __init__(self, channel, classes):
111 | super(APN, self).__init__()
112 | self.conv1 = basic_conv(channel, channel, 3, 2)
113 | self.conv2 = basic_conv(channel, channel, 5, 2)
114 | self.conv3 = basic_conv(channel, channel, 7, 2)
115 | self.branch1 = basic_conv(channel, classes, 1, 1)
116 | self.branch2 = basic_conv(channel, classes, 1, 1)
117 | self.branch3 = basic_conv(channel, classes, 1, 1)
118 | self.branch4 = basic_conv(channel, classes, 1, 1)
119 | self.branch5 = nn.Sequential(
120 | nn.AdaptiveAvgPool2d(output_size=1),
121 | basic_conv(channel, classes, 1, 1)
122 | )
123 |
124 | def forward(self, x):
125 | _, _, h, w = x.shape
126 | out3 = self.conv1(x)
127 | out2 = self.conv2(out3)
128 | out = self.branch1(self.conv3(out2))
129 | out = F.interpolate(out, size=((h + 3) // 4, (w + 3) // 4), mode='bilinear', align_corners=True)
130 | out = out + self.branch2(out2)
131 | out = F.interpolate(out, size=((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=True)
132 | out = out + self.branch3(out3)
133 | out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=True)
134 | out = out * self.branch4(x)
135 | out = out + self.branch5(x)
136 | return out
137 |
138 |
139 | if __name__ == '__main__':
140 | # model = DownSampling(32)
141 | # a = torch.randn(1, 32, 512, 256)
142 | # out = model(a)
143 | # print(out.shape)
144 |
145 | # model = SSnbt(10, 2)
146 | # a = torch.randn(1, 20, 10, 10)
147 | # out = model(a)
148 | # print(out.shape)
149 | # model = basic_conv(10, 20, 3, 2)
150 | # a = torch.randn(1, 10, 128, 65)
151 | # out = model(a)
152 | # print(out.shape)
153 |
154 | model = APN(64, 10)
155 | x = torch.randn(2, 64, 127, 65)
156 | out = model(x)
157 | print(out.shape)
158 |
--------------------------------------------------------------------------------
/model/lednet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch.nn.functional as F
3 | from model.basic import DownSampling, SSnbt, APN
4 |
5 |
6 | class LEDNet(nn.Module):
7 | def __init__(self, nclass, drop=0.1):
8 | super(LEDNet, self).__init__()
9 | self.encoder = nn.Sequential(
10 | DownSampling(3, 29), SSnbt(32, 1, 0.1 * drop), SSnbt(32, 1, 0.1 * drop), SSnbt(32, 1, 0.1 * drop),
11 | DownSampling(32, 32), SSnbt(64, 1, 0.1 * drop), SSnbt(64, 1, 0.1 * drop),
12 | DownSampling(64, 64), SSnbt(128, 1, drop), SSnbt(128, 2, drop), SSnbt(128, 5, drop),
13 | SSnbt(128, 9, drop), SSnbt(128, 2, drop), SSnbt(128, 5, drop), SSnbt(128, 9, drop), SSnbt(128, 17, drop)
14 | )
15 | self.decoder = APN(128, nclass)
16 |
17 | def forward(self, x):
18 | _, _, h, w = x.shape
19 | x = self.encoder(x)
20 | x = self.decoder(x)
21 | return F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True)
22 |
23 |
24 | if __name__ == '__main__':
25 | net = LEDNet(21)
26 | import torch
27 |
28 | a = torch.randn(2, 3, 554, 253)
29 | out = net(a)
30 | print(out.shape)
31 |
--------------------------------------------------------------------------------
/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MixSoftmaxCrossEntropyLoss(nn.Module):
7 | def __init__(self, ignore_label=-1, **kwargs):
8 | super(MixSoftmaxCrossEntropyLoss, self).__init__(**kwargs)
9 | self.ignore_label = ignore_label
10 |
11 | def forward(self, preds, target):
12 | return dict(loss=F.cross_entropy(preds, target, ignore_index=self.ignore_label))
13 |
14 |
15 | # TODO: add aux support
16 | class OHEMSoftmaxCrossEntropyLoss(nn.Module):
17 | def __init__(self, ignore_label=-1, thresh=0.6, min_kept=256,
18 | down_ratio=1, reduction='mean', use_weight=False):
19 | super(OHEMSoftmaxCrossEntropyLoss, self).__init__()
20 | self.ignore_label = ignore_label
21 | self.thresh = float(thresh)
22 | self.min_kept = int(min_kept)
23 | self.down_ratio = down_ratio
24 | if use_weight:
25 | weight = torch.FloatTensor(
26 | [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
27 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
28 | 1.0865, 1.1529, 1.0507])
29 | self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction,
30 | weight=weight,
31 | ignore_index=ignore_label)
32 | else:
33 | self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction,
34 | ignore_index=ignore_label)
35 |
36 | def base_forward(self, pred, target):
37 | b, c, h, w = pred.size()
38 | target = target.view(-1)
39 | valid_mask = target.ne(self.ignore_label)
40 | target = target * valid_mask.long()
41 | num_valid = valid_mask.sum()
42 |
43 | prob = F.softmax(pred, dim=1)
44 | prob = (prob.transpose(0, 1)).reshape(c, -1)
45 |
46 | if self.min_kept < num_valid and num_valid > 0:
47 | prob = prob.masked_fill_(1 - valid_mask, 1)
48 | mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)]
49 | threshold = self.thresh
50 | if self.min_kept > 0:
51 | index = mask_prob.argsort()
52 | threshold_index = index[min(len(index), self.min_kept) - 1]
53 | if mask_prob[threshold_index] > self.thresh:
54 | threshold = mask_prob[threshold_index]
55 | kept_mask = mask_prob.le(threshold)
56 | target = target * kept_mask.long()
57 | valid_mask = valid_mask * kept_mask
58 |
59 | target = target.masked_fill_(1 - valid_mask, self.ignore_label)
60 | target = target.view(b, h, w)
61 |
62 | return self.criterion(pred, target)
63 |
64 | def forward(self, preds, target):
65 | for i, pred in enumerate(preds):
66 | if i == 0:
67 | loss = self.base_forward(pred, target)
68 | else:
69 | loss = loss + self.base_forward(pred, target)
70 | return dict(loss=loss)
71 |
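A usage sketch (editor's note): `OHEMSoftmaxCrossEntropyLoss.forward` iterates over `preds`, so it expects an iterable of prediction tensors; the `min_kept` value below follows the heuristic used in `train.py`. The `1 - valid_mask` arithmetic inside `base_forward` relies on the uint8 masks of the repo's pinned PyTorch 1.1; newer PyTorch versions would need `~valid_mask` instead.

```python
import torch
from model.loss import OHEMSoftmaxCrossEntropyLoss

batch_size, num_classes, crop_size = 2, 19, 64
criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7,
                                        min_kept=batch_size * crop_size ** 2 // 16)
pred = torch.randn(batch_size, num_classes, crop_size, crop_size)
target = torch.randint(0, num_classes, (batch_size, crop_size, crop_size))
loss = criterion([pred], target)['loss']
```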
--------------------------------------------------------------------------------
/model/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
3 |
4 |
5 | class WarmupMultiStepLR(MultiStepLR):
6 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
7 | warmup_iters=500, last_epoch=-1):
8 | self.warmup_factor = warmup_factor
9 | self.warmup_iters = warmup_iters
10 | super().__init__(optimizer, milestones, gamma, last_epoch)
11 |
12 | def get_lr(self):
13 | if self.last_epoch <= self.warmup_iters:
14 | alpha = self.last_epoch / self.warmup_iters
15 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
16 | # print(self.base_lrs[0]*warmup_factor)
17 | return [lr * warmup_factor for lr in self.base_lrs]
18 | else:
19 | lr = super().get_lr()
20 | return lr
21 |
22 |
23 | class WarmupCosineLR(_LRScheduler):
24 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500,
25 | eta_min=0, last_epoch=-1):
26 | self.warmup_factor = warmup_factor
27 | self.warmup_iters = warmup_iters
28 | self.T_max, self.eta_min = T_max, eta_min
29 | super().__init__(optimizer, last_epoch)
30 |
31 | def get_lr(self):
32 | if self.last_epoch <= self.warmup_iters:
33 | alpha = self.last_epoch / self.warmup_iters
34 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
35 | # print(self.base_lrs[0]*warmup_factor)
36 | return [lr * warmup_factor for lr in self.base_lrs]
37 | else:
38 | return [self.eta_min + (base_lr - self.eta_min) *
39 | (1 + math.cos(
40 | math.pi * (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters))) / 2
41 | for base_lr in self.base_lrs]
42 |
43 |
44 | class WarmupPolyLR(_LRScheduler):
45 | def __init__(self, optimizer, T_max, warmup_factor=1.0 / 3, warmup_iters=500,
46 | eta_min=0, power=0.9, last_epoch=-1):
47 | self.warmup_factor = warmup_factor
48 | self.warmup_iters = warmup_iters
49 | self.power = power
50 | self.T_max, self.eta_min = T_max, eta_min
51 | super().__init__(optimizer, last_epoch)
52 |
53 | def get_lr(self):
54 | if self.last_epoch <= self.warmup_iters:
55 | alpha = self.last_epoch / self.warmup_iters
56 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
57 | # print(self.base_lrs[0]*warmup_factor)
58 | return [lr * warmup_factor for lr in self.base_lrs]
59 | else:
60 | return [self.eta_min + (base_lr - self.eta_min) *
61 | math.pow(1 - (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters),
62 | self.power) for base_lr in self.base_lrs]
63 |
64 |
65 | if __name__ == '__main__':
66 |     import torch
67 |     optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
68 |     scheduler = WarmupPolyLR(optimizer, T_max=1000)
69 | 
--------------------------------------------------------------------------------
/png/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/demo.png
--------------------------------------------------------------------------------
/png/gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/gt.png
--------------------------------------------------------------------------------
/png/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AceCoooool/LEDNet/8887545ef0c0eba8b8e5d92f9452764d7bd55bb3/png/output.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import datetime
5 | import argparse
6 | from tqdm import tqdm
7 |
8 | import torch
9 | from torch import optim
10 | from torch.backends import cudnn
11 | from torch.utils import data
12 | from torchvision import transforms
13 |
14 | cur_path = os.path.dirname(__file__)
15 | sys.path.insert(0, os.path.join(cur_path, '..'))
16 | import utils as ptutil
17 | from utils.metric_seg import SegmentationMetric
18 | from data import get_segmentation_dataset
19 | from data.sampler import make_data_sampler, IterationBasedBatchSampler
20 | from model.loss import MixSoftmaxCrossEntropyLoss, OHEMSoftmaxCrossEntropyLoss
21 | from model.lr_scheduler import WarmupPolyLR
22 | from model.lednet import LEDNet
23 |
24 |
25 | def parse_args():
26 | """Training Options for Segmentation Experiments"""
27 | parser = argparse.ArgumentParser(description='LEDNet Segmentation')
28 | parser.add_argument('--dataset', type=str, default='citys',
29 | help='dataset name (default: citys)')
30 | parser.add_argument('--workers', '-j', type=int, default=4,
31 | metavar='N', help='dataloader threads')
32 | parser.add_argument('--base-size', type=int, default=512, # 1024
33 | help='base image size')
34 | parser.add_argument('--crop-size', type=int, default=360, # 512
35 | help='crop image size')
36 | parser.add_argument('--train-split', type=str, default='train',
37 | help='dataset train split (default: train)')
38 | parser.add_argument('--drop-rate', type=float, default=0.3,
39 | help='drop rate of SSnbt')
40 | # training hyper params
41 | parser.add_argument('--ohem', type=ptutil.str2bool, default='false',
42 | help='whether using ohem loss')
43 | parser.add_argument('--epochs', type=int, default=240, metavar='N',
44 |                         help='number of epochs to train (default: 240)')
45 |     parser.add_argument('--start_epoch', type=int, default=0,
46 |                         metavar='N', help='start epochs (default: 0)')
47 |     parser.add_argument('--batch-size', type=int, default=2,
48 |                         metavar='N', help='input batch size for \
49 |                         training (default: 2)')
50 |     parser.add_argument('--test-batch-size', type=int, default=1,
51 |                         metavar='N', help='input batch size for \
52 |                         testing (default: 1)')
53 |     parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
54 |                         help='learning rate (default: 1e-4, scaled by batch size and GPU count)')
55 | parser.add_argument('--momentum', type=float, default=0.9,
56 | metavar='M', help='momentum (default: 0.9)')
57 | parser.add_argument('--weight-decay', type=float, default=1e-4,
58 | metavar='M', help='w-decay (default: 1e-4)')
59 | parser.add_argument('--warmup-iters', type=int, default=200, # 500
60 | help='warmup iterations')
61 | parser.add_argument('--warmup-factor', type=float, default=1.0 / 3,
62 | help='warm up start lr=warmup_factor*lr')
63 | parser.add_argument('--eval-epochs', type=int, default=-1,
64 | help='validate interval')
65 | parser.add_argument('--skip-eval', type=ptutil.str2bool, default='False',
66 | help='whether to skip evaluation')
67 | # cuda and logging
68 | parser.add_argument('--no-cuda', type=ptutil.str2bool, default='False',
69 | help='disables CUDA training')
70 | parser.add_argument('--local_rank', type=int, default=0)
71 | parser.add_argument('--init-method', type=str, default="env://")
72 | parser.add_argument('--dtype', type=str, default='float32',
73 | help='data type for training. default is float32')
74 | # checking point
75 | parser.add_argument('--log-step', type=int, default=1,
76 | help='iteration to show results')
77 | parser.add_argument('--save-epoch', type=int, default=10,
78 | help='epoch interval to save model.')
79 | parser.add_argument('--save-dir', type=str, default=cur_path,
80 |                         help='Directory for saving checkpoints and logs.')
81 | parser.add_argument('--resume', type=str, default=None,
82 | help='put the path to resuming file if needed')
83 |
84 | # the parser
85 | args = parser.parse_args()
86 |
87 | args.lr = args.lr * args.batch_size
88 | return args
89 |
90 |
91 | class Trainer(object):
92 | def __init__(self, args):
93 | self.device = torch.device(args.device)
94 | # image transform
95 | input_transform = transforms.Compose([
96 | transforms.ToTensor(),
97 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
98 | ])
99 | # dataset and dataloader
100 | data_kwargs = {'transform': input_transform, 'base_size': args.base_size,
101 | 'crop_size': args.crop_size}
102 | trainset = get_segmentation_dataset(
103 | args.dataset, split=args.train_split, mode='train', **data_kwargs)
104 | args.per_iter = len(trainset) // (args.num_gpus * args.batch_size)
105 | args.max_iter = args.epochs * args.per_iter
106 | if args.distributed:
107 | sampler = data.DistributedSampler(trainset)
108 | else:
109 | sampler = data.RandomSampler(trainset)
110 | train_sampler = data.sampler.BatchSampler(sampler, args.batch_size, True)
111 | train_sampler = IterationBasedBatchSampler(train_sampler, num_iterations=args.max_iter)
112 | self.train_loader = data.DataLoader(trainset, batch_sampler=train_sampler, pin_memory=True,
113 | num_workers=args.workers)
114 | if not args.skip_eval or 0 < args.eval_epochs < args.epochs:
115 | valset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
116 | val_sampler = make_data_sampler(valset, False, args.distributed)
117 | val_batch_sampler = data.sampler.BatchSampler(val_sampler, args.test_batch_size, False)
118 | self.valid_loader = data.DataLoader(valset, batch_sampler=val_batch_sampler,
119 | num_workers=args.workers, pin_memory=True)
120 |
121 | # create network
122 | self.net = LEDNet(trainset.NUM_CLASS, args.drop_rate)
123 |
124 | if args.distributed:
125 | self.net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.net)
126 | self.net.to(self.device)
127 | # resume checkpoint if needed
128 | if args.resume is not None:
129 | if os.path.isfile(args.resume):
130 | self.net.load_state_dict(torch.load(args.resume))
131 | else:
132 | raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
133 |
134 | # create criterion
135 | if args.ohem:
136 | min_kept = args.batch_size * args.crop_size ** 2 // 16
137 | self.criterion = OHEMSoftmaxCrossEntropyLoss(thresh=0.7, min_kept=min_kept, use_weight=False)
138 | else:
139 | self.criterion = MixSoftmaxCrossEntropyLoss()
140 |
141 | # optimizer and lr scheduling
142 | self.optimizer = optim.SGD(self.net.parameters(), lr=args.lr, momentum=args.momentum,
143 | weight_decay=args.weight_decay)
144 | self.scheduler = WarmupPolyLR(self.optimizer, T_max=args.max_iter, warmup_factor=args.warmup_factor,
145 | warmup_iters=args.warmup_iters, power=0.9)
146 |
147 | if args.distributed:
148 | self.net = torch.nn.parallel.DistributedDataParallel(
149 | self.net, device_ids=[args.local_rank], output_device=args.local_rank)
150 |
151 | # evaluation metrics
152 | self.metric = SegmentationMetric(trainset.num_class)
153 | self.args = args
154 |
155 | def training(self):
156 | self.net.train()
157 | save_to_disk = ptutil.get_rank() == 0
158 | start_training_time = time.time()
159 | trained_time = 0
160 | tic = time.time()
161 | end = time.time()
162 | iteration, max_iter = 0, self.args.max_iter
163 | save_iter, eval_iter = self.args.per_iter * self.args.save_epoch, self.args.per_iter * self.args.eval_epochs
164 | # save_iter, eval_iter = 10, 10
165 |
166 | logger.info("Start training, total epochs {:3d} = total iteration: {:6d}".format(self.args.epochs, max_iter))
167 |
168 | for i, (image, target) in enumerate(self.train_loader):
169 | iteration += 1
170 | self.scheduler.step()
171 | self.optimizer.zero_grad()
172 | image, target = image.to(self.device), target.to(self.device)
173 | outputs = self.net(image)
174 | loss_dict = self.criterion(outputs, target)
175 | # reduce losses over all GPUs for logging purposes
176 | loss_dict_reduced = ptutil.reduce_loss_dict(loss_dict)
177 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
178 |
179 | loss = sum(loss for loss in loss_dict.values())
180 | loss.backward()
181 | self.optimizer.step()
182 | trained_time += time.time() - end
183 | end = time.time()
184 | if iteration % args.log_step == 0:
185 | eta_seconds = int((trained_time / iteration) * (max_iter - iteration))
186 | log_str = ["Iteration {:06d} , Lr: {:.5f}, Cost: {:.2f}s, Eta: {}"
187 | .format(iteration, self.optimizer.param_groups[0]['lr'], time.time() - tic,
188 | str(datetime.timedelta(seconds=eta_seconds))),
189 | "total_loss: {:.3f}".format(losses_reduced.item())]
190 | log_str = ', '.join(log_str)
191 | logger.info(log_str)
192 | tic = time.time()
193 | if save_to_disk and iteration % save_iter == 0:
194 | model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth"
195 | .format('LEDNet', iteration))
196 | self.save_model(model_path)
197 |             # Evaluate during training to track pixAcc/mIoU and see whether performance improves
198 | if args.eval_epochs > 0 and iteration % eval_iter == 0 and not iteration == max_iter:
199 | metrics = self.validate()
200 | ptutil.synchronize()
201 | pixAcc, mIoU = ptutil.accumulate_metric(metrics)
202 | if pixAcc is not None:
203 | logger.info('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
204 | self.net.train()
205 | if save_to_disk:
206 | model_path = os.path.join(self.args.save_dir, "{}_iter_{:06d}.pth"
207 | .format('LEDNet', max_iter))
208 | self.save_model(model_path)
209 | # compute training time
210 | total_training_time = int(time.time() - start_training_time)
211 | total_time_str = str(datetime.timedelta(seconds=total_training_time))
212 | logger.info("Total training time: {} ({:.4f} s / it)".format(total_time_str, total_training_time / max_iter))
213 | # eval after training
214 | if not self.args.skip_eval:
215 | metrics = self.validate()
216 | ptutil.synchronize()
217 | pixAcc, mIoU = ptutil.accumulate_metric(metrics)
218 | if pixAcc is not None:
219 | logger.info('After training, pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
220 |
221 | def validate(self):
222 | # total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
223 | self.metric.reset()
224 | torch.cuda.empty_cache()
225 | if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
226 | model = self.net.module
227 | else:
228 | model = self.net
229 | model.eval()
230 | tbar = tqdm(self.valid_loader)
231 | for i, (image, target) in enumerate(tbar):
232 | # if i == 10: break
233 | image, target = image.to(self.device), target.to(self.device)
234 | with torch.no_grad():
235 | outputs = model(image)
236 | self.metric.update(target, outputs)
237 | return self.metric
238 |
239 | def save_model(self, model_path):
240 | if isinstance(self.net, torch.nn.parallel.DistributedDataParallel):
241 | model = self.net.module
242 | else:
243 | model = self.net
244 | torch.save(model.state_dict(), model_path)
245 | logger.info("Saved checkpoint to {}".format(model_path))
246 |
247 |
248 | if __name__ == "__main__":
249 | args = parse_args()
250 |
251 | # device setting
252 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
253 | args.distributed = num_gpus > 1
254 | args.num_gpus = num_gpus
255 | if not args.no_cuda and torch.cuda.is_available():
256 | torch.backends.cudnn.benchmark = True
257 | args.device = "cuda"
258 | else:
259 | args.distributed = False
260 | args.device = "cpu"
261 | if args.distributed:
262 | torch.cuda.set_device(args.local_rank)
263 | torch.distributed.init_process_group(backend="nccl", init_method=args.init_method)
264 |
265 | args.lr = args.lr * args.num_gpus # scale by num gpus
266 |
267 | logger = ptutil.setup_logger('Segmentation', args.save_dir, ptutil.get_rank(), 'log_seg.txt', 'w')
268 | logger.info("Using {} GPUs".format(num_gpus))
269 | logger.info(args)
270 | trainer = Trainer(args)
271 |
272 | trainer.training()
273 | torch.cuda.empty_cache()
274 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .metric_seg import *
2 | from .util import *
3 | from .logger import *
4 | from .parallel import *
5 | from .visual import *
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 |
6 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", filemode='a'):
7 | logger = logging.getLogger(name)
8 | logger.setLevel(logging.DEBUG)
9 | # don't log results for the non-master process
10 | if distributed_rank > 0:
11 | return logger
12 | ch = logging.StreamHandler(stream=sys.stdout)
13 | ch.setLevel(logging.DEBUG)
14 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
15 | ch.setFormatter(formatter)
16 | logger.addHandler(ch)
17 |
18 | if save_dir:
19 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=filemode)
20 | fh.setLevel(logging.DEBUG)
21 | fh.setFormatter(formatter)
22 | logger.addHandler(fh)
23 |
24 | return logger
25 |
--------------------------------------------------------------------------------
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def check_label_shapes(labels, preds, wrap=False, shape=False):
5 | """Helper function for checking shape of label and prediction
6 |
7 | Parameters
8 | ----------
9 | labels : list of `tensor`
10 | The labels of the data.
11 |
12 | preds : list of `tensor`
13 | Predicted values.
14 |
15 | wrap : boolean
16 |         If True, wrap labels/preds in a list if they are a single tensor.
17 |
18 | shape : boolean
19 | If True, check the shape of labels and preds;
20 | Otherwise only check their length.
21 | """
22 | if not shape:
23 | label_shape, pred_shape = len(labels), len(preds)
24 | else:
25 | label_shape, pred_shape = labels.shape, preds.shape
26 |
27 | if label_shape != pred_shape:
28 | raise ValueError("Shape of labels {} does not match shape of "
29 | "predictions {}".format(label_shape, pred_shape))
30 |
31 | if wrap:
32 | if isinstance(labels, torch.Tensor):
33 | labels = [labels]
34 | if isinstance(preds, torch.Tensor):
35 | preds = [preds]
36 |
37 | return labels, preds
38 |
39 |
40 | class EvalMetric(object):
41 | """Base class for all evaluation metrics.
42 |
43 | .. note::
44 |
45 | This is a base class that provides common metric interfaces.
46 | One should not use this class directly, but instead create new metric
47 | classes that extend it.
48 |
49 | Parameters
50 | ----------
51 | name : str
52 | Name of this metric instance for display.
53 | output_names : list of str, or None
54 | Name of predictions that should be used when updating with update_dict.
55 | By default include all predictions.
56 | label_names : list of str, or None
57 | Name of labels that should be used when updating with update_dict.
58 | By default include all labels.
59 | """
60 |
61 | def __init__(self, name, output_names=None,
62 | label_names=None, **kwargs):
63 | self.name = str(name)
64 | self.output_names = output_names
65 | self.label_names = label_names
66 | self._has_global_stats = kwargs.pop("has_global_stats", False)
67 | self._kwargs = kwargs
68 | # self.reset()
69 |
70 | def __str__(self):
71 | return "EvalMetric: {}".format(dict(self.get_name_value()))
72 |
73 | def get_config(self):
74 | """Save configurations of metric. Can be recreated
75 | from configs with metric.create(``**config``)
76 | """
77 | config = self._kwargs.copy()
78 | config.update({
79 | 'metric': self.__class__.__name__,
80 | 'name': self.name,
81 | 'output_names': self.output_names,
82 | 'label_names': self.label_names})
83 | return config
84 |
85 | def update_dict(self, label, pred):
86 | """Update the internal evaluation with named label and pred
87 |
88 | Parameters
89 | ----------
90 |         label : dict of str -> `Tensor`
91 |             Name-to-tensor mapping for labels.
92 |
93 |         pred : dict of str -> `Tensor`
94 |             Name-to-tensor mapping of predicted outputs.
95 | """
96 | if self.output_names is not None:
97 | pred = [pred[name] for name in self.output_names]
98 | else:
99 | pred = list(pred.values())
100 |
101 | if self.label_names is not None:
102 | label = [label[name] for name in self.label_names]
103 | else:
104 | label = list(label.values())
105 |
106 | self.update(label, pred)
107 |
108 | def update(self, labels, preds):
109 | """Updates the internal evaluation result.
110 |
111 | Parameters
112 | ----------
113 |         labels : list of `Tensor`
114 |             The labels of the data.
115 |
116 |         preds : list of `Tensor`
117 | Predicted values.
118 | """
119 | raise NotImplementedError()
120 |
121 | def reset(self):
122 | """Resets the internal evaluation result to initial state."""
123 | self.num_inst = 0
124 | self.sum_metric = 0.0
125 | self.global_num_inst = 0
126 | self.global_sum_metric = 0.0
127 |
128 | def reset_local(self):
129 | """Resets the local portion of the internal evaluation results
130 | to initial state."""
131 | self.num_inst = 0
132 | self.sum_metric = 0.0
133 |
134 | def get(self):
135 | """Gets the current evaluation result.
136 |
137 | Returns
138 | -------
139 | names : list of str
140 | Name of the metrics.
141 | values : list of float
142 | Value of the evaluations.
143 | """
144 | if self.num_inst == 0:
145 | return (self.name, float('nan'))
146 | else:
147 | return (self.name, self.sum_metric / self.num_inst)
148 |
149 | def get_global(self):
150 | """Gets the current global evaluation result.
151 |
152 | Returns
153 | -------
154 | names : list of str
155 | Name of the metrics.
156 | values : list of float
157 | Value of the evaluations.
158 | """
159 | if self._has_global_stats:
160 | if self.global_num_inst == 0:
161 | return (self.name, float('nan'))
162 | else:
163 | return (self.name, self.global_sum_metric / self.global_num_inst)
164 | else:
165 | return self.get()
166 |
167 | def get_name_value(self):
168 | """Returns zipped name and value pairs.
169 |
170 | Returns
171 | -------
172 | list of tuples
173 | A (name, value) tuple list.
174 | """
175 | name, value = self.get()
176 | if not isinstance(name, list):
177 | name = [name]
178 | if not isinstance(value, list):
179 | value = [value]
180 | return list(zip(name, value))
181 |
182 | def get_global_name_value(self):
183 | """Returns zipped name and value pairs for global results.
184 |
185 | Returns
186 | -------
187 | list of tuples
188 | A (name, value) tuple list.
189 | """
190 | if self._has_global_stats:
191 | name, value = self.get_global()
192 | if not isinstance(name, list):
193 | name = [name]
194 | if not isinstance(value, list):
195 | value = [value]
196 | return list(zip(name, value))
197 | else:
198 | return self.get_name_value()
199 |
--------------------------------------------------------------------------------
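EvalMetric above is only a skeleton: concrete metrics override update() and accumulate into sum_metric / num_inst (and their global counterparts), which get() then divides. A minimal hypothetical subclass, not part of the repository, illustrating that contract:

    import torch
    from utils.metric import EvalMetric, check_label_shapes

    class Accuracy(EvalMetric):
        """Toy metric: fraction of elements where pred equals label."""
        def __init__(self):
            super(Accuracy, self).__init__('accuracy')
            self.reset()

        def update(self, labels, preds):
            labels, preds = check_label_shapes(labels, preds, wrap=True)
            for label, pred in zip(labels, preds):
                correct = (pred == label).sum().item()
                self.sum_metric += correct
                self.global_sum_metric += correct
                self.num_inst += label.numel()
                self.global_num_inst += label.numel()

    metric = Accuracy()
    metric.update(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))
    print(metric.get())          # ('accuracy', 0.666...)
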
/utils/metric_seg.py:
--------------------------------------------------------------------------------
1 | """Evaluation Metrics for Semantic Segmentation"""
2 | import torch
3 | from utils.metric import EvalMetric
4 |
5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union']
6 |
7 |
8 | class SegmentationMetric(EvalMetric):
9 | """Computes pixAcc and mIoU metric scores
10 | """
11 |
12 | def __init__(self, nclass):
13 | super(SegmentationMetric, self).__init__('pixAcc & mIoU')
14 | self.nclass = nclass
15 | self.reset()
16 |
17 | def update(self, labels, preds):
18 | """Updates the internal evaluation result.
19 |
20 | Parameters
21 | ----------
22 |         labels : `Tensor` or list of `Tensor`
23 |             The labels of the data.
24 |
25 |         preds : `Tensor` or list of `Tensor`
26 | Predicted values.
27 | """
28 |
29 | def evaluate_worker(self, label, pred):
30 | correct, labeled = batch_pix_accuracy(pred, label)
31 | inter, union = batch_intersection_union(pred, label, self.nclass)
32 | self.total_correct += correct
33 | self.total_label += labeled
34 | if self.total_inter.device != inter.device:
35 | self.total_inter = self.total_inter.to(inter.device)
36 | self.total_union = self.total_union.to(union.device)
37 | self.total_inter += inter
38 | self.total_union += union
39 |
40 | if isinstance(preds, torch.Tensor):
41 | evaluate_worker(self, labels, preds)
42 | elif isinstance(preds, (list, tuple)):
43 | for (label, pred) in zip(labels, preds):
44 | evaluate_worker(self, label, pred)
45 |
46 | def get(self):
47 | """Gets the current evaluation result.
48 |
49 | Returns
50 | -------
51 | metrics : tuple of float
52 | pixAcc and mIoU
53 | """
54 |         pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # machine-epsilon term avoids division by zero
55 |         IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
56 | mIoU = IoU.mean().item()
57 | return pixAcc, mIoU
58 |
59 | def reset(self):
60 | """Resets the internal evaluation result to initial state."""
61 | self.total_inter = torch.zeros(self.nclass)
62 | self.total_union = torch.zeros(self.nclass)
63 | self.total_correct = 0
64 | self.total_label = 0
65 |
66 | def get_value(self):
67 | return {'total_inter': self.total_inter, 'total_union': self.total_union,
68 | 'total_correct': self.total_correct, 'total_label': self.total_label}
69 |
70 | def combine_value(self, values):
71 | if self.total_inter.is_cuda:
72 | device = torch.device('cuda')
73 | self.total_inter += values['total_inter'].to(device)
74 | self.total_union += values['total_union'].to(device)
75 | else:
76 | self.total_inter += values['total_inter']
77 | self.total_union += values['total_union']
78 | self.total_correct += values['total_correct']
79 | self.total_label += values['total_label']
80 |
81 | # def combine_metric(self, metric):
82 | # if self.total_inter.is_cuda:
83 | # metric.total_inter = metric.total_inter.to(self.total_inter.device)
84 | # self.total_inter += metric.total_inter
85 | # metric.total_union = metric.total_union.to(self.total_union.device)
86 | # self.total_union += metric.total_union
87 | # else:
88 | # self.total_inter += metric.total_inter
89 | # self.total_union += metric.total_union
90 | # self.total_correct += metric.total_correct
91 | # self.total_label += metric.total_label
92 |
93 |
94 | def batch_pix_accuracy(output, target):
95 | """PixAcc"""
96 |     # inputs are torch tensors: output is 4D (N, C, H, W) class scores, target is 3D (N, H, W) labels
97 |     # category -1 is the ignored class, typically background / boundary
98 |     predict = torch.argmax(output, 1).long() + 1  # argmax over class scores first, then cast to long
99 |
100 | target = target.long() + 1
101 |
102 | pixel_labeled = torch.sum(target > 0).item()
103 | pixel_correct = torch.sum((predict == target) * (target > 0)).item()
104 |
105 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
106 | return pixel_correct, pixel_labeled
107 |
108 |
109 | def batch_intersection_union(output, target, nclass):
110 | """mIoU"""
111 |     # inputs are torch tensors: output is 4D (N, C, H, W) class scores, target is 3D (N, H, W) labels
112 |     # category -1 is the ignored class, typically background / boundary
113 | mini = 1
114 | maxi = nclass
115 | nbins = nclass
116 | predict = torch.argmax(output, 1) + 1
117 | target = target.float() + 1
118 |
119 | predict = predict.float() * (target > 0).float()
120 | intersection = predict * (predict == target).float()
121 | # areas of intersection and union
122 | area_inter = torch.histc(intersection, bins=nbins, min=mini, max=maxi)
123 | area_pred = torch.histc(predict, bins=nbins, min=mini, max=maxi)
124 | area_lab = torch.histc(target, bins=nbins, min=mini, max=maxi)
125 | area_union = area_pred + area_lab - area_inter
126 | assert torch.sum(area_inter > area_union).item() == 0, \
127 | "Intersection area should be smaller than Union area"
128 | return area_inter.float(), area_union.float()
129 |
130 |
131 | if __name__ == '__main__':
132 | a = torch.Tensor([[1.0, 2.0], [3.0, 4.0]]).cuda()
133 | b = torch.LongTensor([[1, 3], [3, 4]]).cuda()
134 | metric = SegmentationMetric(4)
135 | metric.update(a, b)
136 | print(metric.get())
137 |
--------------------------------------------------------------------------------
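A minimal sketch of driving SegmentationMetric above during validation; the logits and labels here are random stand-ins for real model output and ground truth:

    import torch
    from utils.metric_seg import SegmentationMetric

    nclass = 19                                      # Cityscapes uses 19 training classes
    metric = SegmentationMetric(nclass)

    # stand-in for one batch: class scores (N, C, H, W) and labels (N, H, W)
    logits = torch.randn(2, nclass, 64, 64)
    labels = torch.randint(0, nclass, (2, 64, 64))

    with torch.no_grad():
        metric.update(labels, logits)                # signature is update(labels, preds)

    pixAcc, mIoU = metric.get()
    print('pixAcc: {:.4f}, mIoU: {:.4f}'.format(pixAcc, mIoU))
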
/utils/parallel.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | def get_world_size():
8 | if not torch.distributed.is_initialized():
9 | return 1
10 | return torch.distributed.get_world_size()
11 |
12 |
13 | def get_rank():
14 | if not torch.distributed.is_initialized():
15 | return 0
16 | return torch.distributed.get_rank()
17 |
18 |
19 | def is_main_process():
20 | return get_rank() == 0
21 |
22 |
23 | def synchronize():
24 | """
25 | Helper function to synchronize (barrier) among all processes when
26 | using distributed training
27 | """
28 | if not dist.is_available():
29 | return
30 | if not dist.is_initialized():
31 | return
32 | world_size = dist.get_world_size()
33 | if world_size == 1:
34 | return
35 | dist.barrier()
36 |
37 |
38 | def all_gather(data):
39 | """
40 | Run all_gather on arbitrary picklable data (not necessarily tensors)
41 | Args:
42 | data: any picklable object
43 | Returns:
44 | list[data]: list of data gathered from each rank
45 | """
46 | world_size = get_world_size()
47 | if world_size == 1:
48 | return [data]
49 |
50 | # serialized to a Tensor
51 | buffer = pickle.dumps(data)
52 | storage = torch.ByteStorage.from_buffer(buffer)
53 | tensor = torch.ByteTensor(storage).to("cuda")
54 |
55 | # obtain Tensor size of each rank
56 | local_size = torch.IntTensor([tensor.numel()]).to("cuda")
57 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
58 | dist.all_gather(size_list, local_size)
59 | size_list = [int(size.item()) for size in size_list]
60 | max_size = max(size_list)
61 |
62 | # receiving Tensor from all ranks
63 | # we pad the tensor because torch all_gather does not support
64 | # gathering tensors of different shapes
65 | tensor_list = []
66 | for _ in size_list:
67 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
68 | if local_size != max_size:
69 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
70 | tensor = torch.cat((tensor, padding), dim=0)
71 | dist.all_gather(tensor_list, tensor)
72 |
73 | data_list = []
74 | for size, tensor in zip(size_list, tensor_list):
75 | buffer = tensor.cpu().numpy().tobytes()[:size]
76 | data_list.append(pickle.loads(buffer))
77 |
78 | return data_list
79 |
80 |
81 | def reduce_loss_dict(loss_dict):
82 | """
83 | Reduce the loss dictionary from all processes so that process with rank
84 | 0 has the averaged results. Returns a dict with the same fields as
85 | loss_dict, after reduction.
86 | """
87 | world_size = get_world_size()
88 | if world_size < 2:
89 | return loss_dict
90 | with torch.no_grad():
91 | loss_names = []
92 | all_losses = []
93 | for k in sorted(loss_dict.keys()):
94 | loss_names.append(k)
95 | all_losses.append(loss_dict[k])
96 | all_losses = torch.stack(all_losses, dim=0)
97 | dist.reduce(all_losses, dst=0)
98 | if dist.get_rank() == 0:
99 | # only main process gets accumulated, so only divide by
100 | # world_size in this case
101 | all_losses /= world_size
102 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
103 | return reduced_losses
104 |
105 |
106 | def accumulate_metric(metric):
107 | all_values = all_gather(metric.get_value())
108 | if not is_main_process():
109 | return None, None
110 | for value in all_values[1:]:
111 | metric.combine_value(value)
112 | return metric.get()
113 |
--------------------------------------------------------------------------------
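A minimal sketch of where the helpers above sit in a distributed run (they all degrade gracefully to no-ops when torch.distributed is not initialized); the loss value and the commented metric call are placeholders:

    import torch
    from utils.parallel import reduce_loss_dict, accumulate_metric, synchronize, is_main_process

    # training side: average per-rank losses onto rank 0 for logging
    loss_dict = {'seg_loss': torch.tensor(0.7)}      # hypothetical per-rank loss (lives on GPU in real runs)
    reduced = reduce_loss_dict(loss_dict)
    if is_main_process():
        print({k: v.item() for k, v in reduced.items()})

    # evaluation side: gather per-rank metric state and merge it on rank 0
    # pixAcc, mIoU = accumulate_metric(metric)       # returns (None, None) on non-master ranks
    synchronize()                                    # barrier before continuing
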
/utils/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def str2bool(v):
5 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
6 | return True
7 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
8 | return False
9 | else:
10 | raise argparse.ArgumentTypeError('Boolean value expected.')
11 |
--------------------------------------------------------------------------------
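str2bool above is intended as an argparse type so boolean flags accept yes/no-style strings; a minimal sketch with a hypothetical flag name:

    import argparse
    from utils.util import str2bool

    parser = argparse.ArgumentParser()
    parser.add_argument('--distributed', type=str2bool, default=True)   # hypothetical flag
    args = parser.parse_args(['--distributed', 'false'])
    print(args.distributed)   # False
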
/utils/visual.py:
--------------------------------------------------------------------------------
1 | """Segmentation Utils"""
2 | from PIL import Image
3 |
4 | __all__ = ['get_color_pallete']
5 |
6 |
7 | def get_color_pallete(npimg, dataset='pascal_voc'):
8 |     """Visualize a segmentation mask by applying the dataset's color palette.
9 |
10 | Parameters
11 | ----------
12 | npimg : numpy.ndarray
13 |         Single channel label mask with shape `H, W`.
14 |     dataset : str, default: 'pascal_voc'
15 |         The dataset the model was trained on. ('pascal_voc', 'pascal_aug', 'ade20k', 'citys')
16 |
17 | Returns
18 | -------
19 | out_img : PIL.Image
20 | Image with color pallete
21 |
22 | """
23 | # recovery boundary
24 | if dataset in ('pascal_voc', 'pascal_aug'):
25 | npimg[npimg == -1] = 255
26 | # put colormap
27 | if dataset == 'ade20k':
28 | npimg = npimg + 1
29 | out_img = Image.fromarray(npimg.astype('uint8'))
30 | out_img.putpalette(adepallete)
31 | return out_img
32 | elif dataset == 'citys':
33 | out_img = Image.fromarray(npimg.astype('uint8'))
34 | out_img.putpalette(cityspallete)
35 | return out_img
36 | out_img = Image.fromarray(npimg.astype('uint8'))
37 | out_img.putpalette(vocpallete)
38 | return out_img
39 |
40 |
41 | def _getvocpallete(num_cls):
42 | n = num_cls
43 | pallete = [0] * (n * 3)
44 | for j in range(0, n):
45 | lab = j
46 | pallete[j * 3 + 0] = 0
47 | pallete[j * 3 + 1] = 0
48 | pallete[j * 3 + 2] = 0
49 | i = 0
50 | while (lab > 0):
51 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
52 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
53 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
54 | i = i + 1
55 | lab >>= 3
56 | return pallete
57 |
58 |
59 | vocpallete = _getvocpallete(256)
60 |
61 | adepallete = [
62 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204,
63 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82,
64 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255,
65 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6,
66 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255,
67 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15,
68 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255,
69 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163,
70 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255,
71 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0,
72 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0,
73 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255,
74 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255,
75 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0,
76 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0,
77 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0,
78 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0,
79 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255]
80 |
81 | cityspallete = [
82 | 128, 64, 128,
83 | 244, 35, 232,
84 | 70, 70, 70,
85 | 102, 102, 156,
86 | 190, 153, 153,
87 | 153, 153, 153,
88 | 250, 170, 30,
89 | 220, 220, 0,
90 | 107, 142, 35,
91 | 152, 251, 152,
92 | 0, 130, 180,
93 | 220, 20, 60,
94 | 255, 0, 0,
95 | 0, 0, 142,
96 | 0, 0, 70,
97 | 0, 60, 100,
98 | 0, 80, 100,
99 | 0, 0, 230,
100 | 119, 11, 32,
101 | ]
102 |
--------------------------------------------------------------------------------
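A minimal sketch of colorizing a predicted Cityscapes label map with get_color_pallete above; the mask is random stand-in data and the output filename is arbitrary:

    import numpy as np
    from utils.visual import get_color_pallete

    # stand-in for a predicted label map (H, W) with class ids 0..18
    mask = np.random.randint(0, 19, size=(512, 1024))

    color_mask = get_color_pallete(mask, dataset='citys')   # palettized PIL.Image
    color_mask.save('prediction_color.png')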