├── .gitignore
├── LICENSE
├── README.md
├── README_zh-CN.md
├── core
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataloader
│   │   │   ├── __init__.py
│   │   │   ├── ade.py
│   │   │   ├── cityscapes.py
│   │   │   ├── lip_parsing.py
│   │   │   ├── mscoco.py
│   │   │   ├── pascal_aug.py
│   │   │   ├── pascal_voc.py
│   │   │   ├── sbu_shadow.py
│   │   │   ├── segbase.py
│   │   │   └── utils.py
│   │   └── downloader
│   │       ├── __init__.py
│   │       ├── ade20k.py
│   │       ├── cityscapes.py
│   │       ├── mscoco.py
│   │       ├── pascal_voc.py
│   │       └── sbu_shadow.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_models
│   │   │   ├── __init__.py
│   │   │   ├── densenet.py
│   │   │   ├── eespnet.py
│   │   │   ├── hrnet.py
│   │   │   ├── mobilenetv2.py
│   │   │   ├── resnet.py
│   │   │   ├── resnetv1b.py
│   │   │   ├── resnext.py
│   │   │   ├── vgg.py
│   │   │   └── xception.py
│   │   ├── bisenet.py
│   │   ├── cgnet.py
│   │   ├── danet.py
│   │   ├── deeplabv3.py
│   │   ├── deeplabv3_plus.py
│   │   ├── denseaspp.py
│   │   ├── dfanet.py
│   │   ├── dunet.py
│   │   ├── encnet.py
│   │   ├── enet.py
│   │   ├── espnet.py
│   │   ├── fcn.py
│   │   ├── fcnv2.py
│   │   ├── hrnet.py
│   │   ├── icnet.py
│   │   ├── lednet.py
│   │   ├── model_store.py
│   │   ├── model_zoo.py
│   │   ├── ocnet.py
│   │   ├── psanet.py
│   │   ├── pspnet.py
│   │   ├── segbase.py
│   │   └── swnet.py
│   ├── nn
│   │   ├── __init__.py
│   │   ├── basic.py
│   │   └── jpu.py
│   └── utils
│       ├── __init__.py
│       ├── distributed.py
│       ├── download.py
│       ├── filesystem.py
│       ├── logger.py
│       ├── loss.py
│       ├── lr_scheduler.py
│       ├── parallel.py
│       ├── score.py
│       └── visualize.py
├── datasets
│   ├── ade
│   ├── citys
│   ├── sbu
│   └── voc
├── docs
│   ├── DETAILS.md
│   ├── QQ.jpg
│   ├── WeChat.jpeg
│   ├── requirements.yml
│   └── weimar_000091_000019_gtFine_color.png
├── requirements.txt
├── scripts
│   ├── demo.py
│   ├── eval.py
│   ├── fcn32s_vgg16_pascal_voc.sh
│   ├── fcn32s_vgg16_pascal_voc_dist.sh
│   └── train.py
└── tests
    ├── README.md
    ├── runs
    │   ├── bisenet_epoch_100.png
    │   ├── danet_epoch_100.png
    │   ├── denseaspp_epoch_40.png
    │   ├── dunet_epoch_100.png
    │   ├── encnet_epoch_100.png
    │   ├── enet_epoch_100.png
    │   ├── fcn16s_epoch_200.png
    │   ├── fcn32s_epoch_300.png
    │   ├── fcn8s_epoch_100.png
    │   ├── icnet_epoch_100.png
    │   ├── ocnet_epoch_100.png
    │   └── psp_epoch_100.png
    ├── test_img.jpg
    ├── test_mask.png
    ├── test_model.py
    └── test_module.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib64/
18 | parts/
19 | sdist/
20 | var/
21 | wheels/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 | *.idea
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # pycharm
107 |
108 | # premodel
109 | weights/
110 | *.pkl
111 | *.pth
112 |
113 | # dataset
114 | datasets/
115 | VOCdevkit/
116 | eval/
117 |
118 | # overfitting test
119 |
120 | # run result
121 | /tests/runs
122 | /runs
123 |
124 | # model
125 | /models/hrnet.py
126 | /models/psanet_old.py
127 | /scripts/debug.py
128 |
129 | # nn
130 | nn/sync_bn/
131 |
132 | # venv
133 | AwsmSemSegPytorch-env/
134 | .vscode/launch.json
135 | .vscode/settings.json
136 |
137 |
138 | # built files
139 | core/nn/sync_bn/lib/gpu/build.ninja
140 | core/nn/sync_bn/lib/gpu/.ninja_log
141 | core/nn/sync_bn/lib/gpu/.ninja_deps
142 |
143 |
--------------------------------------------------------------------------------
/README_zh-CN.md:
--------------------------------------------------------------------------------
1 | ## Chinese Documentation
2 |
3 | [English](/README.md) | Simplified Chinese
4 |
5 | [![python-image]][python-url]
6 | [![pytorch-image]][pytorch-url]
7 | [![lic-image]][lic-url]
8 |
9 |
10 |
11 | This project aims to provide a concise, easy-to-use, and extensible semantic segmentation toolbox based on [PyTorch](https://pytorch.org/).
12 |
13 | The code on the master branch currently supports **PyTorch 1.1.0 and later**.
14 |
15 | ## Installation
16 |
17 | ```
18 | # semantic-segmentation-pytorch dependencies
19 | pip install ninja tqdm
20 |
21 | # follow PyTorch installation in https://pytorch.org/get-started/locally/
22 | conda install pytorch torchvision -c pytorch
23 |
24 | # install PyTorch Segmentation
25 | git clone https://github.com/Tramac/awesome-semantic-segmentation-pytorch.git
26 | ```
27 |
28 | ## Usage
29 |
30 | ### Training
31 | -----------------
32 | - **Single-GPU training**
33 | ```
34 | # for example, train fcn32s_vgg16_pascal_voc:
35 | python train.py --model fcn32s --backbone vgg16 --dataset pascal_voc --lr 0.0001 --epochs 50
36 | ```
37 | - **Distributed training**
38 | ```
39 | # for example, train fcn32s_vgg16_pascal_voc with 4 GPUs:
40 | export NGPUS=4
41 | python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py --model fcn32s --backbone vgg16 --dataset pascal_voc --lr 0.0001 --epochs 50
42 | ```
43 |
44 | ### Evaluation
45 | -----------------
46 | - **Single-GPU evaluation**
47 | ```
48 | # for example, evaluate fcn32s_vgg16_pascal_voc
49 | python eval.py --model fcn32s --backbone vgg16 --dataset pascal_voc
50 | ```
51 | - **Multi-GPU evaluation**
52 | ```
53 | # for example, evaluate fcn32s_vgg16_pascal_voc with 4 GPUs:
54 | export NGPUS=4
55 | python -m torch.distributed.launch --nproc_per_node=$NGPUS eval.py --model fcn32s --backbone vgg16 --dataset pascal_voc
56 | ```
57 |
58 | ### Demo
59 | ```
60 | cd ./scripts
61 | # for new users:
62 | python demo.py --model fcn32s_vgg16_voc --input-pic ../tests/test_img.jpg
63 | # you need to add 'test.jpg' yourself:
64 | python demo.py --model fcn32s_vgg16_voc --input-pic ../datasets/test.jpg
65 | ```
66 |
67 | ### Model Zoo
68 | -------------------------------------
69 | - [FCN](https://arxiv.org/abs/1411.4038)
70 | - [ENet](https://arxiv.org/pdf/1606.02147)
71 | - [PSPNet](https://arxiv.org/pdf/1612.01105)
72 | - [ICNet](https://arxiv.org/pdf/1704.08545)
73 | - [DeepLabv3](https://arxiv.org/abs/1706.05587)
74 | - [DeepLabv3+](https://arxiv.org/pdf/1802.02611)
75 | - [DenseASPP](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf)
76 | - [EncNet](https://arxiv.org/abs/1803.08904v1)
77 | - [BiSeNet](https://arxiv.org/abs/1808.00897)
78 | - [PSANet](https://hszhao.github.io/papers/eccv18_psanet.pdf)
79 | - [DANet](https://arxiv.org/pdf/1809.02983)
80 | - [OCNet](https://arxiv.org/pdf/1809.00916)
81 | - [CGNet](https://arxiv.org/pdf/1811.08201)
82 | - [ESPNetv2](https://arxiv.org/abs/1811.11431)
83 | - [DUNet(DUpsampling)](https://arxiv.org/abs/1903.02120)
84 | - [FastFCN(JPU)](https://arxiv.org/abs/1903.11816)
85 | - [LEDNet](https://arxiv.org/abs/1905.02423)
86 | - [Fast-SCNN](https://github.com/Tramac/Fast-SCNN-pytorch)
87 | - [LightSeg](https://github.com/Tramac/Lightweight-Segmentation)
88 | - [DFANet](https://arxiv.org/abs/1904.02216)
89 |
90 | See [here](https://github.com/Tramac/awesome-semantic-segmentation-pytorch/blob/master/docs/DETAILS.md) for details on which backbones each model supports.
91 |
92 | ```
93 | .{SEG_ROOT}
94 | ├── core
95 | │   ├── models
96 | │   │   ├── bisenet.py
97 | │   │   ├── danet.py
98 | │   │   ├── deeplabv3.py
99 | │   │   ├── deeplabv3_plus.py
100 | │   │   ├── denseaspp.py
101 | │   │   ├── dunet.py
102 | │   │   ├── encnet.py
103 | │   │   ├── fcn.py
104 | │   │   ├── pspnet.py
105 | │   │   ├── icnet.py
106 | │   │   ├── enet.py
107 | │   │   ├── ocnet.py
108 | │   │   ├── psanet.py
109 | │   │   ├── cgnet.py
110 | │   │   ├── espnet.py
111 | │   │   ├── lednet.py
112 | │   │   ├── dfanet.py
113 | │   │   └── ......
114 | ```
115 |
116 | ### Datasets
117 | You can download a given dataset with its downloader script, for example:
118 | ```
119 | cd ./core/data/downloader
120 | python ade20k.py --download-dir ../datasets/ade
121 | ```
122 |
123 | | Dataset | training set | validation set | testing set |
124 | | :----------------------------------------------------------: | :----------: | :------------: | :---------: |
125 | | [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) | 1464 | 1449 | ✘ |
126 | | [VOCAug](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) | 11355 | 2857 | ✘ |
127 | | [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) | 20210 | 2000 | ✘ |
128 | | [Cityscapes](https://www.cityscapes-dataset.com/downloads/) | 2975 | 500 | ✘ |
129 | | [COCO](http://cocodataset.org/#download) | | | |
130 | | [SBU-shadow](http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip) | 4085 | 638 | ✘ |
131 | | [LIP(Look into Person)](http://sysu-hcp.net/lip/) | 30462 | 10000 | 10000 |
132 |
133 | ```
134 | .{SEG_ROOT}
135 | ├── core
136 | │   ├── data
137 | │   │   ├── dataloader
138 | │   │   │   ├── ade.py
139 | │   │   │   ├── cityscapes.py
140 | │   │   │   ├── mscoco.py
141 | │   │   │   ├── pascal_aug.py
142 | │   │   │   ├── pascal_voc.py
143 | │   │   │   └── sbu_shadow.py
144 | │   │   └── downloader
145 | │   │       ├── ade20k.py
146 | │   │       ├── cityscapes.py
147 | │   │       ├── mscoco.py
148 | │   │       ├── pascal_voc.py
149 | │   │       └── sbu_shadow.py
150 | ```
151 |
152 | ## Results
153 | |Methods|Backbone|TrainSet|EvalSet|crop_size|epochs|JPU|Mean IoU|pixAcc|
154 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
155 | |FCN32s|vgg16|train|val|480|60|✘|47.50|85.39|
156 | |FCN16s|vgg16|train|val|480|60|✘|49.16|85.98|
157 | |FCN8s|vgg16|train|val|480|60|✘|48.87|85.02|
158 | |FCN32s|resnet50|train|val|480|50|✘|54.60|88.57|
159 | |PSPNet|resnet50|train|val|480|60|✘|63.44|89.78|
160 | |DeepLabv3|resnet50|train|val|480|60|✘|60.15|88.36|
161 |
162 | `lr=1e-4, batch_size=4, epochs=80`.
163 | Note: the results above were obtained with the default hyperparameters in `train.py`; for better numbers, use the settings reported in each paper.
164 |
165 | ## Versions
166 |
167 | - v0.1.0: compared with the master branch, this version additionally includes `ccnet` and `psanet`, which depend on compiled custom layers; follow the instructions on that branch if you need them.
168 |
169 | ## To Do
170 |
171 | - [x] add train script
172 | - [ ] remove syncbn
173 | - [ ] train & evaluate
174 | - [x] test distributed training
175 | - [x] fix syncbn ([Why SyncBN?](https://tramac.github.io/2019/02/25/%E8%B7%A8%E5%8D%A1%E5%90%8C%E6%AD%A5%20Batch%20Normalization[%E8%BD%AC]/))
176 | - [x] add distributed ([How DIST?](https://tramac.github.io/2019/03/06/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83-PyTorch/))
177 |
178 | ## References
179 | - [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding)
180 | - [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)
181 | - [gluon-cv](https://github.com/dmlc/gluon-cv)
182 | - [imagenet](https://github.com/pytorch/examples/tree/master/imagenet)
183 |
184 | ## WeChat & QQ
185 | Since many people have reached out for help through Zhihu private messages, and those questions should not go unanswered, WeChat and QQ groups for this project are provided below; hopefully they help anyone who needs them.
186 | 
187 |
188 |
193 |
194 | [python-image]: https://img.shields.io/badge/Python-2.x|3.x-ff69b4.svg
195 | [python-url]: https://www.python.org/
196 | [pytorch-image]: https://img.shields.io/badge/PyTorch-1.1-2BAF2B.svg
197 | [pytorch-url]: https://pytorch.org/
198 | [lic-image]: https://img.shields.io/badge/Apache-2.0-blue.svg
199 | [lic-url]: https://github.com/Tramac/Awesome-semantic-segmentation-pytorch/blob/master/LICENSE
200 |
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
1 | from . import nn, models, utils, data
--------------------------------------------------------------------------------
/core/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/core/data/__init__.py
--------------------------------------------------------------------------------
/core/data/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides data loaders and transformers for popular vision datasets.
3 | """
4 | from .mscoco import COCOSegmentation
5 | from .cityscapes import CitySegmentation
6 | from .ade import ADE20KSegmentation
7 | from .pascal_voc import VOCSegmentation
8 | from .pascal_aug import VOCAugSegmentation
9 | from .sbu_shadow import SBUSegmentation
10 |
11 | datasets = {
12 | 'ade20k': ADE20KSegmentation,
13 | 'pascal_voc': VOCSegmentation,
14 | 'pascal_aug': VOCAugSegmentation,
15 | 'coco': COCOSegmentation,
16 | 'citys': CitySegmentation,
17 | 'sbu': SBUSegmentation,
18 | }
19 |
20 |
21 | def get_segmentation_dataset(name, **kwargs):
22 | """Segmentation Datasets"""
23 | return datasets[name.lower()](**kwargs)
24 |
--------------------------------------------------------------------------------
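
A minimal sketch of how the registry above is typically used, assuming the Pascal VOC data has already been prepared under `../datasets/voc`:

```python
import torch.utils.data as data
from torchvision import transforms

from core.data.dataloader import get_segmentation_dataset

# normalization transform, as in the dataset docstrings
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
])

# 'pascal_voc' is looked up (case-insensitively) in the `datasets` dict;
# all keyword arguments are forwarded to the dataset constructor.
trainset = get_segmentation_dataset('pascal_voc', split='train', mode='train',
                                    transform=input_transform)
train_loader = data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)
```
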
/core/data/dataloader/cityscapes.py:
--------------------------------------------------------------------------------
1 | """Prepare Cityscapes dataset"""
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from .segbase import SegmentationDataset
8 |
9 |
10 | class CitySegmentation(SegmentationDataset):
11 | """Cityscapes Semantic Segmentation Dataset.
12 |
13 | Parameters
14 | ----------
15 | root : string
16 | Path to Cityscapes folder. Default is './datasets/citys'
17 | split: string
18 | 'train', 'val' or 'test'
19 | transform : callable, optional
20 | A function that transforms the image
21 | Examples
22 | --------
23 | >>> from torchvision import transforms
24 | >>> import torch.utils.data as data
25 | >>> # Transforms for Normalization
26 | >>> input_transform = transforms.Compose([
27 | >>> transforms.ToTensor(),
28 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
29 | >>> ])
30 | >>> # Create Dataset
31 | >>> trainset = CitySegmentation(split='train', transform=input_transform)
32 | >>> # Create Training Loader
33 | >>> train_data = data.DataLoader(
34 | >>> trainset, 4, shuffle=True,
35 | >>> num_workers=4)
36 | """
37 | BASE_DIR = 'cityscapes'
38 | NUM_CLASS = 19
39 |
40 | def __init__(self, root='../datasets/citys', split='train', mode=None, transform=None, **kwargs):
41 | super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs)
42 | # self.root = os.path.join(root, self.BASE_DIR)
43 |         assert os.path.exists(self.root), "Please set up the dataset using core/data/downloader/cityscapes.py"
44 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split)
45 | assert (len(self.images) == len(self.mask_paths))
46 | if len(self.images) == 0:
47 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
48 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
49 | 23, 24, 25, 26, 27, 28, 31, 32, 33]
50 | self._key = np.array([-1, -1, -1, -1, -1, -1,
51 | -1, -1, 0, 1, -1, -1,
52 | 2, 3, 4, -1, -1, -1,
53 | 5, -1, 6, 7, 8, 9,
54 | 10, 11, 12, 13, 14, 15,
55 | -1, -1, 16, 17, 18])
56 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
57 |
58 | def _class_to_index(self, mask):
59 | # assert the value
60 | values = np.unique(mask)
61 | for value in values:
62 | assert (value in self._mapping)
63 | index = np.digitize(mask.ravel(), self._mapping, right=True)
64 | return self._key[index].reshape(mask.shape)
65 |
66 | def __getitem__(self, index):
67 | img = Image.open(self.images[index]).convert('RGB')
68 | if self.mode == 'test':
69 | if self.transform is not None:
70 | img = self.transform(img)
71 | return img, os.path.basename(self.images[index])
72 | mask = Image.open(self.mask_paths[index])
73 |         # synchronized transform
74 | if self.mode == 'train':
75 | img, mask = self._sync_transform(img, mask)
76 | elif self.mode == 'val':
77 | img, mask = self._val_sync_transform(img, mask)
78 | else:
79 | assert self.mode == 'testval'
80 | img, mask = self._img_transform(img), self._mask_transform(mask)
81 | # general resize, normalize and toTensor
82 | if self.transform is not None:
83 | img = self.transform(img)
84 | return img, mask, os.path.basename(self.images[index])
85 |
86 | def _mask_transform(self, mask):
87 | target = self._class_to_index(np.array(mask).astype('int32'))
88 | return torch.LongTensor(np.array(target).astype('int32'))
89 |
90 | def __len__(self):
91 | return len(self.images)
92 |
93 | @property
94 | def pred_offset(self):
95 | return 0
96 |
97 |
98 | def _get_city_pairs(folder, split='train'):
99 | def get_path_pairs(img_folder, mask_folder):
100 | img_paths = []
101 | mask_paths = []
102 | for root, _, files in os.walk(img_folder):
103 | for filename in files:
104 | if filename.endswith('.png'):
105 | imgpath = os.path.join(root, filename)
106 | foldername = os.path.basename(os.path.dirname(imgpath))
107 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
108 | maskpath = os.path.join(mask_folder, foldername, maskname)
109 | if os.path.isfile(imgpath) and os.path.isfile(maskpath):
110 | img_paths.append(imgpath)
111 | mask_paths.append(maskpath)
112 | else:
113 | print('cannot find the mask or image:', imgpath, maskpath)
114 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
115 | return img_paths, mask_paths
116 |
117 | if split in ('train', 'val'):
118 | img_folder = os.path.join(folder, 'leftImg8bit/' + split)
119 | mask_folder = os.path.join(folder, 'gtFine/' + split)
120 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
121 | return img_paths, mask_paths
122 | else:
123 | assert split == 'trainval'
124 | print('trainval set')
125 | train_img_folder = os.path.join(folder, 'leftImg8bit/train')
126 | train_mask_folder = os.path.join(folder, 'gtFine/train')
127 | val_img_folder = os.path.join(folder, 'leftImg8bit/val')
128 | val_mask_folder = os.path.join(folder, 'gtFine/val')
129 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder)
130 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder)
131 | img_paths = train_img_paths + val_img_paths
132 | mask_paths = train_mask_paths + val_mask_paths
133 | return img_paths, mask_paths
134 |
135 |
136 | if __name__ == '__main__':
137 | dataset = CitySegmentation()
138 |
--------------------------------------------------------------------------------
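
The `_class_to_index` remapping above can be hard to read; here is a standalone sketch of the same `np.digitize` trick on a tiny synthetic mask (the `key` table is copied from the class, the raw label ids are made up for illustration):

```python
import numpy as np

# train-id lookup table, indexed by raw Cityscapes label id + 1
key = np.array([-1, -1, -1, -1, -1, -1,
                -1, -1, 0, 1, -1, -1,
                2, 3, 4, -1, -1, -1,
                5, -1, 6, 7, 8, 9,
                10, 11, 12, 13, 14, 15,
                -1, -1, 16, 17, 18])
mapping = np.array(range(-1, len(key) - 1)).astype('int32')  # values -1..33

raw = np.array([[7, 8], [26, 0]], dtype='int32')       # raw label ids
index = np.digitize(raw.ravel(), mapping, right=True)  # raw id v -> slot v + 1
print(key[index].reshape(raw.shape))
# [[ 0  1]
#  [13 -1]]  -> road=0, sidewalk=1, car=13, void=-1 (ignored)
```
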
/core/data/dataloader/lip_parsing.py:
--------------------------------------------------------------------------------
1 | """Look into Person Dataset"""
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from core.data.dataloader.segbase import SegmentationDataset
8 |
9 |
10 | class LIPSegmentation(SegmentationDataset):
11 | """Look into person parsing dataset """
12 |
13 | BASE_DIR = 'LIP'
14 | NUM_CLASS = 20
15 |
16 | def __init__(self, root='../datasets/LIP', split='train', mode=None, transform=None, **kwargs):
17 | super(LIPSegmentation, self).__init__(root, split, mode, transform, **kwargs)
18 | _trainval_image_dir = os.path.join(root, 'TrainVal_images')
19 | _testing_image_dir = os.path.join(root, 'Testing_images')
20 | _trainval_mask_dir = os.path.join(root, 'TrainVal_parsing_annotations')
21 | if split == 'train':
22 | _image_dir = os.path.join(_trainval_image_dir, 'train_images')
23 | _mask_dir = os.path.join(_trainval_mask_dir, 'train_segmentations')
24 | _split_f = os.path.join(_trainval_image_dir, 'train_id.txt')
25 | elif split == 'val':
26 | _image_dir = os.path.join(_trainval_image_dir, 'val_images')
27 | _mask_dir = os.path.join(_trainval_mask_dir, 'val_segmentations')
28 | _split_f = os.path.join(_trainval_image_dir, 'val_id.txt')
29 | elif split == 'test':
30 | _image_dir = os.path.join(_testing_image_dir, 'testing_images')
31 | _split_f = os.path.join(_testing_image_dir, 'test_id.txt')
32 | else:
33 | raise RuntimeError('Unknown dataset split.')
34 |
35 | self.images = []
36 | self.masks = []
37 | with open(os.path.join(_split_f), 'r') as lines:
38 | for line in lines:
39 | _image = os.path.join(_image_dir, line.rstrip('\n') + '.jpg')
40 | assert os.path.isfile(_image)
41 | self.images.append(_image)
42 | if split != 'test':
43 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + '.png')
44 | assert os.path.isfile(_mask)
45 | self.masks.append(_mask)
46 |
47 | if split != 'test':
48 | assert (len(self.images) == len(self.masks))
49 | print('Found {} {} images in the folder {}'.format(len(self.images), split, root))
50 |
51 | def __getitem__(self, index):
52 | img = Image.open(self.images[index]).convert('RGB')
53 | if self.mode == 'test':
54 | img = self._img_transform(img)
55 | if self.transform is not None:
56 | img = self.transform(img)
57 | return img, os.path.basename(self.images[index])
58 | mask = Image.open(self.masks[index])
59 | # synchronized transform
60 | if self.mode == 'train':
61 | img, mask = self._sync_transform(img, mask)
62 | elif self.mode == 'val':
63 | img, mask = self._val_sync_transform(img, mask)
64 | else:
65 | assert self.mode == 'testval'
66 | img, mask = self._img_transform(img), self._mask_transform(mask)
67 | # general resize, normalize and toTensor
68 | if self.transform is not None:
69 | img = self.transform(img)
70 |
71 | return img, mask, os.path.basename(self.images[index])
72 |
73 | def __len__(self):
74 | return len(self.images)
75 |
76 | def _mask_transform(self, mask):
77 | target = np.array(mask).astype('int32')
78 | return torch.from_numpy(target).long()
79 |
80 | @property
81 | def classes(self):
82 | """Category name."""
83 | return ('background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes',
84 | 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt',
85 | 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe',
86 | 'rightShoe')
87 |
88 |
89 | if __name__ == '__main__':
90 | dataset = LIPSegmentation(base_size=280, crop_size=256)
--------------------------------------------------------------------------------
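
A hedged usage sketch for the loader above; it assumes the official LIP archives are unpacked under `../datasets/LIP` in the `TrainVal_images` / `TrainVal_parsing_annotations` layout the constructor expects:

```python
from core.data.dataloader.lip_parsing import LIPSegmentation

trainset = LIPSegmentation(root='../datasets/LIP', split='train', mode='train',
                           base_size=280, crop_size=256)
img, mask, name = trainset[0]  # numpy image, LongTensor mask, file name
print(len(trainset), trainset.num_class, name, mask.shape)
```
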
/core/data/dataloader/mscoco.py:
--------------------------------------------------------------------------------
1 | """MSCOCO Semantic Segmentation pretraining for VOC."""
2 | import os
3 | import pickle
4 | import torch
5 | import numpy as np
6 |
7 | from tqdm import trange
8 | from PIL import Image
9 | from .segbase import SegmentationDataset
10 |
11 |
12 | class COCOSegmentation(SegmentationDataset):
13 | """COCO Semantic Segmentation Dataset for VOC Pre-training.
14 |
15 | Parameters
16 | ----------
17 | root : string
18 |         Path to COCO folder. Default is './datasets/coco'
19 | split: string
20 | 'train', 'val' or 'test'
21 | transform : callable, optional
22 | A function that transforms the image
23 | Examples
24 | --------
25 | >>> from torchvision import transforms
26 | >>> import torch.utils.data as data
27 | >>> # Transforms for Normalization
28 | >>> input_transform = transforms.Compose([
29 | >>> transforms.ToTensor(),
30 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
31 | >>> ])
32 | >>> # Create Dataset
33 | >>> trainset = COCOSegmentation(split='train', transform=input_transform)
34 | >>> # Create Training Loader
35 | >>> train_data = data.DataLoader(
36 | >>> trainset, 4, shuffle=True,
37 | >>> num_workers=4)
38 | """
39 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
40 | 1, 64, 20, 63, 7, 72]
41 | NUM_CLASS = 21
42 |
43 | def __init__(self, root='../datasets/coco', split='train', mode=None, transform=None, **kwargs):
44 | super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
45 | # lazy import pycocotools
46 | from pycocotools.coco import COCO
47 | from pycocotools import mask
48 | if split == 'train':
49 | print('train set')
50 | ann_file = os.path.join(root, 'annotations/instances_train2017.json')
51 | ids_file = os.path.join(root, 'annotations/train_ids.mx')
52 | self.root = os.path.join(root, 'train2017')
53 | else:
54 | print('val set')
55 | ann_file = os.path.join(root, 'annotations/instances_val2017.json')
56 | ids_file = os.path.join(root, 'annotations/val_ids.mx')
57 | self.root = os.path.join(root, 'val2017')
58 | self.coco = COCO(ann_file)
59 | self.coco_mask = mask
60 | if os.path.exists(ids_file):
61 | with open(ids_file, 'rb') as f:
62 | self.ids = pickle.load(f)
63 | else:
64 | ids = list(self.coco.imgs.keys())
65 | self.ids = self._preprocess(ids, ids_file)
66 | self.transform = transform
67 |
68 | def __getitem__(self, index):
69 | coco = self.coco
70 | img_id = self.ids[index]
71 | img_metadata = coco.loadImgs(img_id)[0]
72 | path = img_metadata['file_name']
73 | img = Image.open(os.path.join(self.root, path)).convert('RGB')
74 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
75 | mask = Image.fromarray(self._gen_seg_mask(
76 | cocotarget, img_metadata['height'], img_metadata['width']))
77 |         # synchronized transform
78 | if self.mode == 'train':
79 | img, mask = self._sync_transform(img, mask)
80 | elif self.mode == 'val':
81 | img, mask = self._val_sync_transform(img, mask)
82 | else:
83 | assert self.mode == 'testval'
84 | img, mask = self._img_transform(img), self._mask_transform(mask)
85 | # general resize, normalize and toTensor
86 | if self.transform is not None:
87 | img = self.transform(img)
88 |         return img, mask, os.path.basename(path)
89 |
90 | def _mask_transform(self, mask):
91 | return torch.LongTensor(np.array(mask).astype('int32'))
92 |
93 | def _gen_seg_mask(self, target, h, w):
94 | mask = np.zeros((h, w), dtype=np.uint8)
95 | coco_mask = self.coco_mask
96 | for instance in target:
97 |             rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
98 | m = coco_mask.decode(rle)
99 | cat = instance['category_id']
100 | if cat in self.CAT_LIST:
101 | c = self.CAT_LIST.index(cat)
102 | else:
103 | continue
104 | if len(m.shape) < 3:
105 | mask[:, :] += (mask == 0) * (m * c)
106 | else:
107 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
108 | return mask
109 |
110 | def _preprocess(self, ids, ids_file):
111 |         print("Preprocessing mask, this will take a while. " + \
112 |               "But don't worry, it only runs once for each split.")
113 | tbar = trange(len(ids))
114 | new_ids = []
115 | for i in tbar:
116 | img_id = ids[i]
117 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
118 | img_metadata = self.coco.loadImgs(img_id)[0]
119 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width'])
120 | # more than 1k pixels
121 | if (mask > 0).sum() > 1000:
122 | new_ids.append(img_id)
123 | tbar.set_description('Doing: {}/{}, got {} qualified images'. \
124 | format(i, len(ids), len(new_ids)))
125 | print('Found number of qualified images: ', len(new_ids))
126 | with open(ids_file, 'wb') as f:
127 | pickle.dump(new_ids, f)
128 | return new_ids
129 |
130 | @property
131 | def classes(self):
132 | """Category names."""
133 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
134 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
135 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
136 | 'tv')
137 |
--------------------------------------------------------------------------------
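
The painting rule in `_gen_seg_mask` is subtle: a pixel keeps the first category drawn onto it, because every later instance is multiplied by `(mask == 0)`. A pure-numpy sketch, with hand-made binary masks standing in for the decoded pycocotools RLE masks:

```python
import numpy as np

mask = np.zeros((4, 4), dtype=np.uint8)
person = np.zeros((4, 4), dtype=np.uint8); person[:2, :] = 1  # class index 15
dog = np.zeros((4, 4), dtype=np.uint8); dog[1:3, :] = 1       # class index 12

for m, c in [(person, 15), (dog, 12)]:
    # only still-empty pixels accept the new category
    mask[:, :] += (mask == 0) * (m * c).astype(np.uint8)

print(mask)
# row 0: 15, row 1: 15 (person keeps the overlap), row 2: 12, row 3: 0
```
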
/core/data/dataloader/pascal_aug.py:
--------------------------------------------------------------------------------
1 | """Pascal Augmented VOC Semantic Segmentation Dataset."""
2 | import os
3 | import torch
4 | import scipy.io as sio
5 | import numpy as np
6 |
7 | from PIL import Image
8 | from .segbase import SegmentationDataset
9 |
10 |
11 | class VOCAugSegmentation(SegmentationDataset):
12 | """Pascal VOC Augmented Semantic Segmentation Dataset.
13 |
14 | Parameters
15 | ----------
16 | root : string
17 | Path to VOCdevkit folder. Default is './datasets/voc'
18 | split: string
19 | 'train', 'val' or 'test'
20 | transform : callable, optional
21 | A function that transforms the image
22 | Examples
23 | --------
24 | >>> from torchvision import transforms
25 | >>> import torch.utils.data as data
26 | >>> # Transforms for Normalization
27 | >>> input_transform = transforms.Compose([
28 | >>> transforms.ToTensor(),
29 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
30 | >>> ])
31 | >>> # Create Dataset
32 | >>> trainset = VOCAugSegmentation(split='train', transform=input_transform)
33 | >>> # Create Training Loader
34 | >>> train_data = data.DataLoader(
35 | >>> trainset, 4, shuffle=True,
36 | >>> num_workers=4)
37 | """
38 | BASE_DIR = 'VOCaug/dataset/'
39 | NUM_CLASS = 21
40 |
41 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs):
42 | super(VOCAugSegmentation, self).__init__(root, split, mode, transform, **kwargs)
43 | # train/val/test splits are pre-cut
44 | _voc_root = os.path.join(root, self.BASE_DIR)
45 | _mask_dir = os.path.join(_voc_root, 'cls')
46 | _image_dir = os.path.join(_voc_root, 'img')
47 | if split == 'train':
48 | _split_f = os.path.join(_voc_root, 'trainval.txt')
49 | elif split == 'val':
50 | _split_f = os.path.join(_voc_root, 'val.txt')
51 | else:
52 | raise RuntimeError('Unknown dataset split: {}'.format(split))
53 |
54 | self.images = []
55 | self.masks = []
56 | with open(os.path.join(_split_f), "r") as lines:
57 | for line in lines:
58 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg")
59 | assert os.path.isfile(_image)
60 | self.images.append(_image)
61 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".mat")
62 | assert os.path.isfile(_mask)
63 | self.masks.append(_mask)
64 |
65 | assert (len(self.images) == len(self.masks))
66 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root))
67 |
68 | def __getitem__(self, index):
69 | img = Image.open(self.images[index]).convert('RGB')
70 | target = self._load_mat(self.masks[index])
71 |         # synchronized transform
72 | if self.mode == 'train':
73 | img, target = self._sync_transform(img, target)
74 | elif self.mode == 'val':
75 | img, target = self._val_sync_transform(img, target)
76 | else:
77 | raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode))
78 | # general resize, normalize and toTensor
79 | if self.transform is not None:
80 | img = self.transform(img)
81 | return img, target, os.path.basename(self.images[index])
82 |
83 | def _mask_transform(self, mask):
84 | return torch.LongTensor(np.array(mask).astype('int32'))
85 |
86 | def _load_mat(self, filename):
87 | mat = sio.loadmat(filename, mat_dtype=True, squeeze_me=True, struct_as_record=False)
88 | mask = mat['GTcls'].Segmentation
89 | return Image.fromarray(mask)
90 |
91 | def __len__(self):
92 | return len(self.images)
93 |
94 | @property
95 | def classes(self):
96 | """Category names."""
97 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
98 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
99 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
100 | 'tv')
101 |
102 |
103 | if __name__ == '__main__':
104 | dataset = VOCAugSegmentation()
--------------------------------------------------------------------------------
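
For reference, a short sketch of decoding one SBD annotation the same way `_load_mat` does; the filename is a placeholder for any id listed in `trainval.txt`:

```python
import scipy.io as sio
from PIL import Image

mat = sio.loadmat('../datasets/voc/VOCaug/dataset/cls/2008_000002.mat',
                  mat_dtype=True, squeeze_me=True, struct_as_record=False)
seg = mat['GTcls'].Segmentation  # 2-D array of VOC class indices (0..20)
mask = Image.fromarray(seg)      # PIL mask, ready for the sync transforms
print(mask.size)
```
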
/core/data/dataloader/pascal_voc.py:
--------------------------------------------------------------------------------
1 | """Pascal VOC Semantic Segmentation Dataset."""
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from .segbase import SegmentationDataset
8 |
9 |
10 | class VOCSegmentation(SegmentationDataset):
11 | """Pascal VOC Semantic Segmentation Dataset.
12 |
13 | Parameters
14 | ----------
15 | root : string
16 | Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
17 | split: string
18 | 'train', 'val' or 'test'
19 | transform : callable, optional
20 | A function that transforms the image
21 | Examples
22 | --------
23 | >>> from torchvision import transforms
24 | >>> import torch.utils.data as data
25 | >>> # Transforms for Normalization
26 | >>> input_transform = transforms.Compose([
27 | >>> transforms.ToTensor(),
28 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
29 | >>> ])
30 | >>> # Create Dataset
31 | >>> trainset = VOCSegmentation(split='train', transform=input_transform)
32 | >>> # Create Training Loader
33 | >>> train_data = data.DataLoader(
34 | >>> trainset, 4, shuffle=True,
35 | >>> num_workers=4)
36 | """
37 | BASE_DIR = 'VOC2012'
38 | NUM_CLASS = 21
39 |
40 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs):
41 | super(VOCSegmentation, self).__init__(root, split, mode, transform, **kwargs)
42 | _voc_root = os.path.join(root, self.BASE_DIR)
43 | _mask_dir = os.path.join(_voc_root, 'SegmentationClass')
44 | _image_dir = os.path.join(_voc_root, 'JPEGImages')
45 | # train/val/test splits are pre-cut
46 | _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation')
47 | if split == 'train':
48 | _split_f = os.path.join(_splits_dir, 'train.txt')
49 | elif split == 'val':
50 | _split_f = os.path.join(_splits_dir, 'val.txt')
51 | elif split == 'test':
52 | _split_f = os.path.join(_splits_dir, 'test.txt')
53 | else:
54 | raise RuntimeError('Unknown dataset split.')
55 |
56 | self.images = []
57 | self.masks = []
58 | with open(os.path.join(_split_f), "r") as lines:
59 | for line in lines:
60 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg")
61 | assert os.path.isfile(_image)
62 | self.images.append(_image)
63 | if split != 'test':
64 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png")
65 | assert os.path.isfile(_mask)
66 | self.masks.append(_mask)
67 |
68 | if split != 'test':
69 | assert (len(self.images) == len(self.masks))
70 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root))
71 |
72 | def __getitem__(self, index):
73 | img = Image.open(self.images[index]).convert('RGB')
74 | if self.mode == 'test':
75 | img = self._img_transform(img)
76 | if self.transform is not None:
77 | img = self.transform(img)
78 | return img, os.path.basename(self.images[index])
79 | mask = Image.open(self.masks[index])
80 | # synchronized transform
81 | if self.mode == 'train':
82 | img, mask = self._sync_transform(img, mask)
83 | elif self.mode == 'val':
84 | img, mask = self._val_sync_transform(img, mask)
85 | else:
86 | assert self.mode == 'testval'
87 | img, mask = self._img_transform(img), self._mask_transform(mask)
88 | # general resize, normalize and toTensor
89 | if self.transform is not None:
90 | img = self.transform(img)
91 |
92 | return img, mask, os.path.basename(self.images[index])
93 |
94 | def __len__(self):
95 | return len(self.images)
96 |
97 | def _mask_transform(self, mask):
98 | target = np.array(mask).astype('int32')
99 | target[target == 255] = -1
100 | return torch.from_numpy(target).long()
101 |
102 | @property
103 | def classes(self):
104 | """Category names."""
105 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle',
106 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
107 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
108 | 'tv')
109 |
110 |
111 | if __name__ == '__main__':
112 | dataset = VOCSegmentation()
--------------------------------------------------------------------------------
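
The `_mask_transform` above encodes the VOC ignore convention: pixel value 255 in the annotation PNGs marks "don't care" and is remapped to -1 so the loss can skip it (e.g. `ignore_index=-1` in cross entropy). A tiny sketch with made-up pixel values:

```python
import numpy as np
import torch

png = np.array([[0, 15], [255, 7]], dtype='int32')  # background, person, ignore, car
png[png == 255] = -1
print(torch.from_numpy(png).long())
# tensor([[ 0, 15],
#         [-1,  7]])
```
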
/core/data/dataloader/sbu_shadow.py:
--------------------------------------------------------------------------------
1 | """SBU Shadow Segmentation Dataset."""
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 | from PIL import Image
7 | from .segbase import SegmentationDataset
8 |
9 |
10 | class SBUSegmentation(SegmentationDataset):
11 | """SBU Shadow Segmentation Dataset
12 | """
13 | NUM_CLASS = 2
14 |
15 | def __init__(self, root='../datasets/sbu', split='train', mode=None, transform=None, **kwargs):
16 | super(SBUSegmentation, self).__init__(root, split, mode, transform, **kwargs)
17 | assert os.path.exists(self.root)
18 | self.images, self.masks = _get_sbu_pairs(self.root, self.split)
19 | assert (len(self.images) == len(self.masks))
20 | if len(self.images) == 0:
21 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
22 |
23 | def __getitem__(self, index):
24 | img = Image.open(self.images[index]).convert('RGB')
25 | if self.mode == 'test':
26 | if self.transform is not None:
27 | img = self.transform(img)
28 | return img, os.path.basename(self.images[index])
29 | mask = Image.open(self.masks[index])
30 |         # synchronized transform
31 | if self.mode == 'train':
32 | img, mask = self._sync_transform(img, mask)
33 | elif self.mode == 'val':
34 | img, mask = self._val_sync_transform(img, mask)
35 | else:
36 | assert self.mode == 'testval'
37 | img, mask = self._img_transform(img), self._mask_transform(mask)
38 | # general resize, normalize and toTensor
39 | if self.transform is not None:
40 | img = self.transform(img)
41 | return img, mask, os.path.basename(self.images[index])
42 |
43 | def _mask_transform(self, mask):
44 | target = np.array(mask).astype('int32')
45 | target[target > 0] = 1
46 | return torch.from_numpy(target).long()
47 |
48 | def __len__(self):
49 | return len(self.images)
50 |
51 | @property
52 | def pred_offset(self):
53 | return 0
54 |
55 |
56 | def _get_sbu_pairs(folder, split='train'):
57 | def get_path_pairs(img_folder, mask_folder):
58 | img_paths = []
59 | mask_paths = []
60 | for root, _, files in os.walk(img_folder):
61 |             # print(root)  # debug output, disabled
62 | for filename in files:
63 | if filename.endswith('.jpg'):
64 | imgpath = os.path.join(root, filename)
65 | maskname = filename.replace('.jpg', '.png')
66 | maskpath = os.path.join(mask_folder, maskname)
67 | if os.path.isfile(imgpath) and os.path.isfile(maskpath):
68 | img_paths.append(imgpath)
69 | mask_paths.append(maskpath)
70 | else:
71 | print('cannot find the mask or image:', imgpath, maskpath)
72 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
73 | return img_paths, mask_paths
74 |
75 | if split == 'train':
76 | img_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowImages')
77 | mask_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowMasks')
78 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
79 | else:
80 | assert split in ('val', 'test')
81 | img_folder = os.path.join(folder, 'SBU-Test/ShadowImages')
82 | mask_folder = os.path.join(folder, 'SBU-Test/ShadowMasks')
83 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
84 | return img_paths, mask_paths
85 |
86 |
87 | if __name__ == '__main__':
88 | dataset = SBUSegmentation(base_size=280, crop_size=256)
--------------------------------------------------------------------------------
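
The `_mask_transform` above collapses SBU masks to two classes: shadow pixels are stored as nonzero intensities, so any positive value becomes class 1. A one-screen sketch with made-up values:

```python
import numpy as np
import torch

raw = np.array([[0, 128], [255, 0]], dtype='int32')
raw[raw > 0] = 1
print(torch.from_numpy(raw).long())
# tensor([[0, 1],
#         [1, 0]])
```
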
/core/data/dataloader/segbase.py:
--------------------------------------------------------------------------------
1 | """Base segmentation dataset"""
2 | import random
3 | import numpy as np
4 |
5 | from PIL import Image, ImageOps, ImageFilter
6 |
7 | __all__ = ['SegmentationDataset']
8 |
9 |
10 | class SegmentationDataset(object):
11 | """Segmentation Base Dataset"""
12 |
13 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
14 | super(SegmentationDataset, self).__init__()
15 | self.root = root
16 | self.transform = transform
17 | self.split = split
18 | self.mode = mode if mode is not None else split
19 | self.base_size = base_size
20 | self.crop_size = crop_size
21 |
22 | def _val_sync_transform(self, img, mask):
23 | outsize = self.crop_size
24 | short_size = outsize
25 | w, h = img.size
26 | if w > h:
27 | oh = short_size
28 | ow = int(1.0 * w * oh / h)
29 | else:
30 | ow = short_size
31 | oh = int(1.0 * h * ow / w)
32 | img = img.resize((ow, oh), Image.BILINEAR)
33 | mask = mask.resize((ow, oh), Image.NEAREST)
34 | # center crop
35 | w, h = img.size
36 | x1 = int(round((w - outsize) / 2.))
37 | y1 = int(round((h - outsize) / 2.))
38 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
39 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
40 | # final transform
41 | img, mask = self._img_transform(img), self._mask_transform(mask)
42 | return img, mask
43 |
44 | def _sync_transform(self, img, mask):
45 | # random mirror
46 | if random.random() < 0.5:
47 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
48 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
49 | crop_size = self.crop_size
50 | # random scale (short edge)
51 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
52 | w, h = img.size
53 | if h > w:
54 | ow = short_size
55 | oh = int(1.0 * h * ow / w)
56 | else:
57 | oh = short_size
58 | ow = int(1.0 * w * oh / h)
59 | img = img.resize((ow, oh), Image.BILINEAR)
60 | mask = mask.resize((ow, oh), Image.NEAREST)
61 | # pad crop
62 | if short_size < crop_size:
63 | padh = crop_size - oh if oh < crop_size else 0
64 | padw = crop_size - ow if ow < crop_size else 0
65 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
66 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
67 | # random crop crop_size
68 | w, h = img.size
69 | x1 = random.randint(0, w - crop_size)
70 | y1 = random.randint(0, h - crop_size)
71 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
72 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
73 | # gaussian blur as in PSP
74 | if random.random() < 0.5:
75 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
76 | # final transform
77 | img, mask = self._img_transform(img), self._mask_transform(mask)
78 | return img, mask
79 |
80 | def _img_transform(self, img):
81 | return np.array(img)
82 |
83 | def _mask_transform(self, mask):
84 | return np.array(mask).astype('int32')
85 |
86 | @property
87 | def num_class(self):
88 | """Number of categories."""
89 | return self.NUM_CLASS
90 |
91 | @property
92 | def pred_offset(self):
93 | return 0
94 |
--------------------------------------------------------------------------------
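
A self-contained walk-through of the geometry in `_sync_transform`, on a synthetic image: scale the short edge to a random size, zero-pad on the bottom/right if the result is smaller than the crop, then take an aligned random crop from image and mask alike:

```python
import random
from PIL import Image, ImageOps

random.seed(0)
base_size, crop_size = 520, 480
img = Image.new('RGB', (640, 360))
mask = Image.new('L', (640, 360))

short_size = random.randint(int(base_size * 0.5), int(base_size * 2.0))
w, h = img.size
if h > w:
    ow, oh = short_size, int(1.0 * h * short_size / w)
else:
    oh, ow = short_size, int(1.0 * w * short_size / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)  # NEAREST keeps label values intact
if short_size < crop_size:                   # pad so a crop_size crop always fits
    padh, padw = max(crop_size - oh, 0), max(crop_size - ow, 0)
    img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
    mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
w, h = img.size
x1 = random.randint(0, w - crop_size)        # identical crop for image and mask
y1 = random.randint(0, h - crop_size)
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
print('scaled to', (ow, oh), '; crop at', (x1, y1), '->', img.size)
```
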
/core/data/dataloader/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 | import errno
4 | import tarfile
5 | from six.moves import urllib
6 | from torch.utils.model_zoo import tqdm
7 |
8 | def gen_bar_updater():
9 | pbar = tqdm(total=None)
10 |
11 | def bar_update(count, block_size, total_size):
12 | if pbar.total is None and total_size:
13 | pbar.total = total_size
14 | progress_bytes = count * block_size
15 | pbar.update(progress_bytes - pbar.n)
16 |
17 | return bar_update
18 |
19 | def check_integrity(fpath, md5=None):
20 | if md5 is None:
21 | return True
22 | if not os.path.isfile(fpath):
23 | return False
24 | md5o = hashlib.md5()
25 | with open(fpath, 'rb') as f:
26 | # read in 1MB chunks
27 | for chunk in iter(lambda: f.read(1024 * 1024), b''):
28 | md5o.update(chunk)
29 | md5c = md5o.hexdigest()
30 | if md5c != md5:
31 | return False
32 | return True
33 |
34 | def makedir_exist_ok(dirpath):
35 | try:
36 | os.makedirs(dirpath)
37 | except OSError as e:
38 | if e.errno == errno.EEXIST:
39 | pass
40 | else:
41 |             raise
42 |
43 | def download_url(url, root, filename=None, md5=None):
44 | """Download a file from a url and place it in root."""
45 | root = os.path.expanduser(root)
46 | if not filename:
47 | filename = os.path.basename(url)
48 | fpath = os.path.join(root, filename)
49 |
50 | makedir_exist_ok(root)
51 |
52 | # downloads file
53 | if os.path.isfile(fpath) and check_integrity(fpath, md5):
54 | print('Using downloaded and verified file: ' + fpath)
55 | else:
56 | try:
57 | print('Downloading ' + url + ' to ' + fpath)
58 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
59 | except OSError:
60 | if url[:5] == 'https':
61 | url = url.replace('https:', 'http:')
62 | print('Failed download. Trying https -> http instead.'
63 | ' Downloading ' + url + ' to ' + fpath)
64 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
65 |
66 | def download_extract(url, root, filename, md5):
67 | download_url(url, root, filename, md5)
68 | with tarfile.open(os.path.join(root, filename), "r") as tar:
69 | def is_within_directory(directory, target):
70 |
71 | abs_directory = os.path.abspath(directory)
72 | abs_target = os.path.abspath(target)
73 |
74 | prefix = os.path.commonprefix([abs_directory, abs_target])
75 |
76 | return prefix == abs_directory
77 |
78 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
79 |
80 | for member in tar.getmembers():
81 | member_path = os.path.join(path, member.name)
82 | if not is_within_directory(path, member_path):
83 | raise Exception("Attempted Path Traversal in Tar File")
84 |
85 | tar.extractall(path, members, numeric_owner=numeric_owner)
86 |
87 |
88 | safe_extract(tar, path=root)
--------------------------------------------------------------------------------
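
Hypothetical usage of the download helpers above; the URL and filename are placeholders, not real assets of this repo:

```python
from core.data.dataloader.utils import download_url, check_integrity

url = 'http://example.com/files/sample.tar'      # placeholder URL
download_url(url, root='./downloads', filename='sample.tar', md5=None)
print(check_integrity('./downloads/sample.tar'))  # always True when md5 is None
```

Note that `check_integrity` only verifies content when an md5 digest is supplied; with `md5=None` it reports success without touching the file.
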
/core/data/downloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/core/data/downloader/__init__.py
--------------------------------------------------------------------------------
/core/data/downloader/ade20k.py:
--------------------------------------------------------------------------------
1 | """Prepare ADE20K dataset"""
2 | import os
3 | import sys
4 | import argparse
5 | import zipfile
6 |
7 | # TODO: optim code
8 | cur_path = os.path.abspath(os.path.dirname(__file__))
9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0]
10 | sys.path.append(root_path)
11 |
12 | from core.utils import download, makedirs
13 |
14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/ade')
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(
19 | description='Initialize ADE20K dataset.',
20 |         epilog='Example: python ade20k.py --download-dir ../datasets/ade',
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def download_ade(path, overwrite=False):
28 | _AUG_DOWNLOAD_URLS = [
29 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip',
30 | '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
31 | (
32 | 'http://data.csail.mit.edu/places/ADEchallenge/release_test.zip',
33 | 'e05747892219d10e9243933371a497e905a4860c'), ]
34 | download_dir = os.path.join(path, 'downloads')
35 | makedirs(download_dir)
36 | for url, checksum in _AUG_DOWNLOAD_URLS:
37 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
38 | # extract
39 | with zipfile.ZipFile(filename, "r") as zip_ref:
40 | zip_ref.extractall(path=path)
41 |
42 |
43 | if __name__ == '__main__':
44 | args = parse_args()
45 | makedirs(os.path.expanduser('~/.torch/datasets'))
46 | if args.download_dir is not None:
47 | if os.path.isdir(_TARGET_DIR):
48 | os.remove(_TARGET_DIR)
49 | # make symlink
50 | os.symlink(args.download_dir, _TARGET_DIR)
51 | download_ade(_TARGET_DIR, overwrite=False)
52 |
--------------------------------------------------------------------------------
/core/data/downloader/cityscapes.py:
--------------------------------------------------------------------------------
1 | """Prepare Cityscapes dataset"""
2 | import os
3 | import sys
4 | import argparse
5 | import zipfile
6 |
7 | # TODO: optim code
8 | cur_path = os.path.abspath(os.path.dirname(__file__))
9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0]
10 | sys.path.append(root_path)
11 |
12 | from core.utils import download, makedirs, check_sha1
13 |
14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/citys')
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(
19 |         description='Initialize Cityscapes dataset.',
20 | epilog='Example: python prepare_cityscapes.py',
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk')
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def download_city(path, overwrite=False):
28 | _CITY_DOWNLOAD_URLS = [
29 | ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'),
30 | ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')]
31 | download_dir = os.path.join(path, 'downloads')
32 | makedirs(download_dir)
33 | for filename, checksum in _CITY_DOWNLOAD_URLS:
34 |         if not check_sha1(os.path.join(download_dir, filename), checksum):
35 |             raise UserWarning('The content hash of {} does not match. ' \
36 |                               'Cityscapes requires registration, so please download ' \
37 |                               'the archives from https://www.cityscapes-dataset.com/downloads/ ' \
38 |                               'into {} and rerun this script.'.format(filename, download_dir))
39 | # extract
40 |         with zipfile.ZipFile(os.path.join(download_dir, filename), "r") as zip_ref:
41 | zip_ref.extractall(path=path)
42 | print("Extracted", filename)
43 |
44 |
45 | if __name__ == '__main__':
46 | args = parse_args()
47 | makedirs(os.path.expanduser('~/.torch/datasets'))
48 | if args.download_dir is not None:
49 | if os.path.isdir(_TARGET_DIR):
50 | os.remove(_TARGET_DIR)
51 | # make symlink
52 | os.symlink(args.download_dir, _TARGET_DIR)
53 | else:
54 | download_city(_TARGET_DIR, overwrite=False)
55 |
--------------------------------------------------------------------------------
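
Because Cityscapes requires a registered account, `download_city` above only verifies and extracts archives you have fetched by hand. A sketch of checking one of them with the same helper (the path is a placeholder):

```python
import os
from core.utils import check_sha1

zip_path = os.path.expanduser('~/.torch/datasets/citys/downloads/gtFine_trainvaltest.zip')
print('checksum ok:', check_sha1(zip_path, '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'))
```
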
/core/data/downloader/mscoco.py:
--------------------------------------------------------------------------------
1 | """Prepare MS COCO datasets"""
2 | import os
3 | import sys
4 | import argparse
5 | import zipfile
6 |
7 | # TODO: optim code
8 | cur_path = os.path.abspath(os.path.dirname(__file__))
9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0]
10 | sys.path.append(root_path)
11 |
12 | from core.utils import download, makedirs, try_import_pycocotools
13 |
14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/coco')
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(
19 | description='Initialize MS COCO dataset.',
20 | epilog='Example: python mscoco.py --download-dir ~/mscoco',
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument('--download-dir', type=str, default='~/mscoco/', help='dataset directory on disk')
23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set')
24 | parser.add_argument('--overwrite', action='store_true',
25 | help='overwrite downloaded files if set, in case they are corrupted')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def download_coco(path, overwrite=False):
31 | _DOWNLOAD_URLS = [
32 | ('http://images.cocodataset.org/zips/train2017.zip',
33 | '10ad623668ab00c62c096f0ed636d6aff41faca5'),
34 | ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
35 | '8551ee4bb5860311e79dace7e79cb91e432e78b3'),
36 | ('http://images.cocodataset.org/zips/val2017.zip',
37 | '4950dc9d00dbe1c933ee0170f5797584351d2a41'),
38 | # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip',
39 | # '46cdcf715b6b4f67e980b529534e79c2edffe084'),
40 | # test2017.zip, for those who want to attend the competition.
41 | # ('http://images.cocodataset.org/zips/test2017.zip',
42 | # '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'),
43 | ]
44 | makedirs(path)
45 | for url, checksum in _DOWNLOAD_URLS:
46 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
47 | # extract
48 | with zipfile.ZipFile(filename) as zf:
49 | zf.extractall(path=path)
50 |
51 |
52 | if __name__ == '__main__':
53 | args = parse_args()
54 | path = os.path.expanduser(args.download_dir)
55 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \
56 | or not os.path.isdir(os.path.join(path, 'val2017')) \
57 | or not os.path.isdir(os.path.join(path, 'annotations')):
58 | if args.no_download:
59 |             raise ValueError(('{} is not a valid directory, make sure it is present.'
60 |                               ' Or remove "--no-download" to let the script fetch it'.format(path)))
61 | else:
62 | download_coco(path, overwrite=args.overwrite)
63 |
64 | # make symlink
65 | makedirs(os.path.expanduser('~/.torch/datasets'))
66 | if os.path.isdir(_TARGET_DIR):
67 | os.remove(_TARGET_DIR)
68 | os.symlink(path, _TARGET_DIR)
69 | try_import_pycocotools()
70 |
--------------------------------------------------------------------------------
/core/data/downloader/pascal_voc.py:
--------------------------------------------------------------------------------
1 | """Prepare PASCAL VOC datasets"""
2 | import os
3 | import sys
4 | import shutil
5 | import argparse
6 | import tarfile
7 |
8 | # TODO: optim code
9 | cur_path = os.path.abspath(os.path.dirname(__file__))
10 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0]
11 | sys.path.append(root_path)
12 |
13 | from core.utils import download, makedirs
14 |
15 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/voc')
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(
20 | description='Initialize PASCAL VOC dataset.',
21 | epilog='Example: python pascal_voc.py --download-dir ~/VOCdevkit',
22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 | parser.add_argument('--download-dir', type=str, default='~/VOCdevkit/', help='dataset directory on disk')
24 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set')
25 | parser.add_argument('--overwrite', action='store_true',
26 | help='overwrite downloaded files if set, in case they are corrupted')
27 | args = parser.parse_args()
28 | return args
29 |
30 |
31 | #####################################################################################
32 | # Download and extract VOC datasets into ``path``
33 |
34 | def download_voc(path, overwrite=False):
35 | _DOWNLOAD_URLS = [
36 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
37 | '34ed68851bce2a36e2a223fa52c661d592c66b3c'),
38 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
39 | '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'),
40 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
41 | '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')]
42 | makedirs(path)
43 | for url, checksum in _DOWNLOAD_URLS:
44 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
45 | # extract
46 | with tarfile.open(filename) as tar:
47 | def is_within_directory(directory, target):
48 |
49 | abs_directory = os.path.abspath(directory)
50 | abs_target = os.path.abspath(target)
51 |
52 | prefix = os.path.commonprefix([abs_directory, abs_target])
53 |
54 | return prefix == abs_directory
55 |
56 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
57 |
58 | for member in tar.getmembers():
59 | member_path = os.path.join(path, member.name)
60 | if not is_within_directory(path, member_path):
61 | raise Exception("Attempted Path Traversal in Tar File")
62 |
63 | tar.extractall(path, members, numeric_owner=numeric_owner)
64 |
65 |
66 | safe_extract(tar, path=path)
67 |
68 |
69 | #####################################################################################
70 | # Download and extract the VOC augmented segmentation dataset into ``path``
71 |
72 | def download_aug(path, overwrite=False):
73 | _AUG_DOWNLOAD_URLS = [
74 | ('http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz',
75 | '7129e0a480c2d6afb02b517bb18ac54283bfaa35')]
76 | makedirs(path)
77 | for url, checksum in _AUG_DOWNLOAD_URLS:
78 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum)
79 | # extract
80 | with tarfile.open(filename) as tar:
81 | def is_within_directory(directory, target):
82 |
83 | abs_directory = os.path.abspath(directory)
84 | abs_target = os.path.abspath(target)
85 |
86 | prefix = os.path.commonprefix([abs_directory, abs_target])
87 |
88 | return prefix == abs_directory
89 |
90 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
91 |
92 | for member in tar.getmembers():
93 | member_path = os.path.join(path, member.name)
94 | if not is_within_directory(path, member_path):
95 | raise Exception("Attempted Path Traversal in Tar File")
96 |
97 | tar.extractall(path, members, numeric_owner=numeric_owner)
98 |
99 |
100 | safe_extract(tar, path=path)
101 | shutil.move(os.path.join(path, 'benchmark_RELEASE'),
102 | os.path.join(path, 'VOCaug'))
103 | filenames = ['VOCaug/dataset/train.txt', 'VOCaug/dataset/val.txt']
104 | # generate trainval.txt
105 | with open(os.path.join(path, 'VOCaug/dataset/trainval.txt'), 'w') as outfile:
106 | for fname in filenames:
107 | fname = os.path.join(path, fname)
108 | with open(fname) as infile:
109 | for line in infile:
110 | outfile.write(line)
111 |
112 |
113 | if __name__ == '__main__':
114 | args = parse_args()
115 | path = os.path.expanduser(args.download_dir)
116 |     if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \
117 |             or not os.path.isdir(os.path.join(path, 'VOC2012')):
118 |         if args.no_download:
119 |             raise ValueError(('{} is not a valid directory, make sure it is present.'
120 |                               ' Or remove "--no-download" so the script can fetch it'.format(path)))
121 | else:
122 | download_voc(path, overwrite=args.overwrite)
123 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007'))
124 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012'))
125 | shutil.rmtree(os.path.join(path, 'VOCdevkit'))
126 |
127 | if not os.path.isdir(os.path.join(path, 'VOCaug')):
128 | if args.no_download:
129 |             raise ValueError(('{} is not a valid directory, make sure it is present.'
130 |                               ' Or remove "--no-download" so the script can fetch it'.format(path)))
131 | else:
132 | download_aug(path, overwrite=args.overwrite)
133 |
134 | # make symlink
135 | makedirs(os.path.expanduser('~/.torch/datasets'))
136 |     if os.path.islink(_TARGET_DIR) or os.path.isfile(_TARGET_DIR):
137 |         os.remove(_TARGET_DIR)  # drop a stale link before re-linking
138 | os.symlink(path, _TARGET_DIR)
139 |
--------------------------------------------------------------------------------
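A quick standalone check of the path-traversal guard used by safe_extract above (a minimal sketch; the member paths are hypothetical). Note that os.path.commonprefix compares character-wise, so a sibling such as '/data/voc-extra' would also slip through; os.path.commonpath is the stricter alternative.

import os

def is_within_directory(directory, target):
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    return os.path.commonprefix([abs_directory, abs_target]) == abs_directory

assert is_within_directory('/data/voc', '/data/voc/VOC2012/readme.txt')    # stays inside
assert not is_within_directory('/data/voc', '/data/voc/../../etc/passwd')  # traversal rejected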
/core/data/downloader/sbu_shadow.py:
--------------------------------------------------------------------------------
1 | """Prepare SBU Shadow datasets"""
2 | import os
3 | import sys
4 | import argparse
5 | import zipfile
6 |
7 | # TODO: optim code
8 | cur_path = os.path.abspath(os.path.dirname(__file__))
9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0]
10 | sys.path.append(root_path)
11 |
12 | from core.utils import download, makedirs
13 |
14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/sbu')
15 |
16 |
17 | def parse_args():
18 | parser = argparse.ArgumentParser(
19 | description='Initialize SBU Shadow dataset.',
20 | epilog='Example: python sbu_shadow.py --download-dir ~/SBU-shadow',
21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
22 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk')
23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set')
24 | parser.add_argument('--overwrite', action='store_true',
25 | help='overwrite downloaded files if set, in case they are corrupted')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | #####################################################################################
31 | # Download and extract SBU shadow datasets into ``path``
32 |
33 | def download_sbu(path, overwrite=False):
34 |     _DOWNLOAD_URLS = [
35 |         'http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip',
36 |     ]
37 |     download_dir = os.path.join(path, 'downloads')
38 |     makedirs(download_dir)
39 |     for url in _DOWNLOAD_URLS:
40 |         filename = download(url, path=download_dir, overwrite=overwrite)
41 | # extract
42 | with zipfile.ZipFile(filename, "r") as zf:
43 | zf.extractall(path=path)
44 | print("Extracted", filename)
45 |
46 |
47 | if __name__ == '__main__':
48 | args = parse_args()
49 | makedirs(os.path.expanduser('~/.torch/datasets'))
50 | if args.download_dir is not None:
51 |         if os.path.islink(_TARGET_DIR) or os.path.isfile(_TARGET_DIR):
52 |             os.remove(_TARGET_DIR)
53 | # make symlink
54 | os.symlink(args.download_dir, _TARGET_DIR)
55 | else:
56 | download_sbu(_TARGET_DIR, overwrite=False)
57 |
--------------------------------------------------------------------------------
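The __main__ block above assumes _TARGET_DIR is either absent or an existing symlink; os.remove fails on a real directory. A more defensive re-link helper might look like this (a sketch, not part of the original script):

import os

def relink(src, dst):
    """Point dst at src, replacing a stale symlink or file but never a real directory."""
    if os.path.islink(dst) or os.path.isfile(dst):
        os.remove(dst)
    elif os.path.isdir(dst):
        raise RuntimeError('%s is a real directory; refusing to replace it' % dst)
    os.symlink(os.path.expanduser(src), dst)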
/core/models/__init__.py:
--------------------------------------------------------------------------------
1 | """Model Zoo"""
2 | from .model_zoo import get_model, get_model_list
--------------------------------------------------------------------------------
/core/models/base_models/__init__.py:
--------------------------------------------------------------------------------
1 | from .densenet import *
2 | from .resnet import *
3 | from .resnetv1b import *
4 | from .vgg import *
5 | from .eespnet import *
6 | from .xception import *
7 |
--------------------------------------------------------------------------------
/core/models/base_models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """MobileNet and MobileNetV2."""
2 | import torch
3 | import torch.nn as nn
4 |
5 | from core.nn import _ConvBNReLU, _DepthwiseConv, InvertedResidual
6 |
7 | __all__ = ['MobileNet', 'MobileNetV2', 'get_mobilenet', 'get_mobilenet_v2',
8 | 'mobilenet1_0', 'mobilenet_v2_1_0', 'mobilenet0_75', 'mobilenet_v2_0_75',
9 | 'mobilenet0_5', 'mobilenet_v2_0_5', 'mobilenet0_25', 'mobilenet_v2_0_25']
10 |
11 |
12 | class MobileNet(nn.Module):
13 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
14 | super(MobileNet, self).__init__()
15 | conv_dw_setting = [
16 | [64, 1, 1],
17 | [128, 2, 2],
18 | [256, 2, 2],
19 | [512, 6, 2],
20 | [1024, 2, 2]]
21 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32
22 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)]
23 |
24 | for c, n, s in conv_dw_setting:
25 | out_channels = int(c * multiplier)
26 | for i in range(n):
27 | stride = s if i == 0 else 1
28 | features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer))
29 | input_channels = out_channels
30 | features.append(nn.AdaptiveAvgPool2d(1))
31 | self.features = nn.Sequential(*features)
32 |
33 | self.classifier = nn.Linear(int(1024 * multiplier), num_classes)
34 |
35 | # weight initialization
36 | for m in self.modules():
37 | if isinstance(m, nn.Conv2d):
38 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
39 | if m.bias is not None:
40 | nn.init.zeros_(m.bias)
41 | elif isinstance(m, nn.BatchNorm2d):
42 | nn.init.ones_(m.weight)
43 | nn.init.zeros_(m.bias)
44 | elif isinstance(m, nn.Linear):
45 | nn.init.normal_(m.weight, 0, 0.01)
46 | nn.init.zeros_(m.bias)
47 |
48 | def forward(self, x):
49 | x = self.features(x)
50 | x = self.classifier(x.view(x.size(0), x.size(1)))
51 | return x
52 |
53 |
54 | class MobileNetV2(nn.Module):
55 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs):
56 | super(MobileNetV2, self).__init__()
57 | inverted_residual_setting = [
58 | # t, c, n, s
59 | [1, 16, 1, 1],
60 | [6, 24, 2, 2],
61 | [6, 32, 3, 2],
62 | [6, 64, 4, 2],
63 | [6, 96, 3, 1],
64 | [6, 160, 3, 2],
65 | [6, 320, 1, 1]]
66 | # building first layer
67 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32
68 | last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
69 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)]
70 |
71 | # building inverted residual blocks
72 | for t, c, n, s in inverted_residual_setting:
73 | out_channels = int(c * multiplier)
74 | for i in range(n):
75 | stride = s if i == 0 else 1
76 | features.append(InvertedResidual(input_channels, out_channels, stride, t, norm_layer))
77 | input_channels = out_channels
78 |
79 | # building last several layers
80 | features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer))
81 | features.append(nn.AdaptiveAvgPool2d(1))
82 | self.features = nn.Sequential(*features)
83 |
84 |         self.classifier = nn.Sequential(
85 |             nn.Dropout(0.2),  # input is (N, C) after global pooling, so plain Dropout
86 |             nn.Linear(last_channels, num_classes))
87 |
88 | # weight initialization
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
92 | if m.bias is not None:
93 | nn.init.zeros_(m.bias)
94 | elif isinstance(m, nn.BatchNorm2d):
95 | nn.init.ones_(m.weight)
96 | nn.init.zeros_(m.bias)
97 | elif isinstance(m, nn.Linear):
98 | nn.init.normal_(m.weight, 0, 0.01)
99 | if m.bias is not None:
100 | nn.init.zeros_(m.bias)
101 |
102 | def forward(self, x):
103 | x = self.features(x)
104 | x = self.classifier(x.view(x.size(0), x.size(1)))
105 | return x
106 |
107 |
108 | # Constructor
109 | def get_mobilenet(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
110 | model = MobileNet(multiplier=multiplier, **kwargs)
111 |
112 | if pretrained:
113 |         raise ValueError("Pretrained weights are not supported for MobileNet")
114 | return model
115 |
116 |
117 | def get_mobilenet_v2(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs):
118 | model = MobileNetV2(multiplier=multiplier, **kwargs)
119 |
120 | if pretrained:
121 |         raise ValueError("Pretrained weights are not supported for MobileNetV2")
122 | return model
123 |
124 |
125 | def mobilenet1_0(**kwargs):
126 | return get_mobilenet(1.0, **kwargs)
127 |
128 |
129 | def mobilenet_v2_1_0(**kwargs):
130 | return get_mobilenet_v2(1.0, **kwargs)
131 |
132 |
133 | def mobilenet0_75(**kwargs):
134 | return get_mobilenet(0.75, **kwargs)
135 |
136 |
137 | def mobilenet_v2_0_75(**kwargs):
138 | return get_mobilenet_v2(0.75, **kwargs)
139 |
140 |
141 | def mobilenet0_5(**kwargs):
142 | return get_mobilenet(0.5, **kwargs)
143 |
144 |
145 | def mobilenet_v2_0_5(**kwargs):
146 | return get_mobilenet_v2(0.5, **kwargs)
147 |
148 |
149 | def mobilenet0_25(**kwargs):
150 | return get_mobilenet(0.25, **kwargs)
151 |
152 |
153 | def mobilenet_v2_0_25(**kwargs):
154 | return get_mobilenet_v2(0.25, **kwargs)
155 |
156 |
157 | if __name__ == '__main__':
158 | model = mobilenet0_5()
159 |
--------------------------------------------------------------------------------
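A smoke test for the width multiplier (a sketch, assuming this file is importable as core.models.base_models.mobilenetv2): with multiplier 0.5 the depthwise stack ends at int(1024 * 0.5) = 512 channels, exactly what the classifier Linear expects.

import torch
from core.models.base_models.mobilenetv2 import mobilenet0_5

model = mobilenet0_5()
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
assert out.shape == (1, 1000)  # ImageNet-style logits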
/core/models/base_models/resnext.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 | __all__ = ['ResNext', 'resnext50_32x4d', 'resnext101_32x8d']
5 |
6 | model_urls = {
7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
9 | }
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
16 |                  base_width=64, dilation=1, norm_layer=nn.BatchNorm2d, **kwargs):
17 | super(Bottleneck, self).__init__()
18 | width = int(planes * (base_width / 64.)) * groups
19 |
20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False)
21 | self.bn1 = norm_layer(width)
22 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False)
23 | self.bn2 = norm_layer(width)
24 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False)
25 | self.bn3 = norm_layer(planes * self.expansion)
26 | self.relu = nn.ReLU(True)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | def forward(self, x):
31 | identity = x
32 |
33 | out = self.conv1(x)
34 | out = self.bn1(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv2(out)
38 | out = self.bn2(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv3(out)
42 | out = self.bn3(out)
43 |
44 | if self.downsample is not None:
45 | identity = self.downsample(x)
46 |
47 | out += identity
48 | out = self.relu(out)
49 |
50 | return out
51 |
52 |
53 | class ResNext(nn.Module):
54 |
55 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1,
56 | width_per_group=64, dilated=False, norm_layer=nn.BatchNorm2d, **kwargs):
57 | super(ResNext, self).__init__()
58 | self.inplanes = 64
59 | self.groups = groups
60 | self.base_width = width_per_group
61 |
62 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False)
63 | self.bn1 = norm_layer(self.inplanes)
64 | self.relu = nn.ReLU(True)
65 | self.maxpool = nn.MaxPool2d(3, 2, 1)
66 |
67 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
69 | if dilated:
70 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer)
71 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer)
72 | else:
73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
74 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
75 |
76 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
77 | self.fc = nn.Linear(512 * block.expansion, num_classes)
78 |
79 | for m in self.modules():
80 | if isinstance(m, nn.Conv2d):
81 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
82 | elif isinstance(m, nn.BatchNorm2d):
83 | nn.init.constant_(m.weight, 1)
84 | nn.init.constant_(m.bias, 0)
85 |
86 | if zero_init_residual:
87 | for m in self.modules():
88 | if isinstance(m, Bottleneck):
89 | nn.init.constant_(m.bn3.weight, 0)
90 |
91 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d):
92 | downsample = None
93 | if stride != 1 or self.inplanes != planes * block.expansion:
94 | downsample = nn.Sequential(
95 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False),
96 | norm_layer(planes * block.expansion)
97 | )
98 |
99 | layers = list()
100 | if dilation in (1, 2):
101 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
102 | self.base_width, norm_layer=norm_layer))
103 | elif dilation == 4:
104 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
105 | self.base_width, dilation=2, norm_layer=norm_layer))
106 | else:
107 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
108 | self.inplanes = planes * block.expansion
109 | for _ in range(1, blocks):
110 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width,
111 | dilation=dilation, norm_layer=norm_layer))
112 |
113 | return nn.Sequential(*layers)
114 |
115 | def forward(self, x):
116 | x = self.conv1(x)
117 | x = self.bn1(x)
118 | x = self.relu(x)
119 | x = self.maxpool(x)
120 |
121 | x = self.layer1(x)
122 | x = self.layer2(x)
123 | x = self.layer3(x)
124 | x = self.layer4(x)
125 |
126 | x = self.avgpool(x)
127 | x = x.view(x.size(0), -1)
128 | x = self.fc(x)
129 |
130 | return x
131 |
132 |
133 | def resnext50_32x4d(pretrained=False, **kwargs):
134 | kwargs['groups'] = 32
135 | kwargs['width_per_group'] = 4
136 | model = ResNext(Bottleneck, [3, 4, 6, 3], **kwargs)
137 | if pretrained:
138 | state_dict = model_zoo.load_url(model_urls['resnext50_32x4d'])
139 | model.load_state_dict(state_dict)
140 | return model
141 |
142 |
143 | def resnext101_32x8d(pretrained=False, **kwargs):
144 | kwargs['groups'] = 32
145 | kwargs['width_per_group'] = 8
146 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs)
147 | if pretrained:
148 | state_dict = model_zoo.load_url(model_urls['resnext101_32x8d'])
149 | model.load_state_dict(state_dict)
150 | return model
151 |
152 |
153 | if __name__ == '__main__':
154 | model = resnext101_32x8d()
155 |
--------------------------------------------------------------------------------
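The bottleneck width formula above is worth a worked check: for resnext50_32x4d (groups=32, width_per_group=4), the stage-1 blocks run their grouped 3x3 conv at int(64 * (4 / 64.)) * 32 = 128 channels, double the 64 a plain ResNet-50 bottleneck would use.

planes, base_width, groups = 64, 4, 32  # stage-1 settings of resnext50_32x4d
width = int(planes * (base_width / 64.)) * groups
assert width == 128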
/core/models/base_models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 |
5 | __all__ = [
6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
7 | 'vgg19_bn', 'vgg19',
8 | ]
9 |
10 | model_urls = {
11 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
12 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
13 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
14 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
15 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
16 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
17 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
18 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
19 | }
20 |
21 |
22 | class VGG(nn.Module):
23 | def __init__(self, features, num_classes=1000, init_weights=True):
24 | super(VGG, self).__init__()
25 | self.features = features
26 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
27 | self.classifier = nn.Sequential(
28 | nn.Linear(512 * 7 * 7, 4096),
29 | nn.ReLU(True),
30 | nn.Dropout(),
31 | nn.Linear(4096, 4096),
32 | nn.ReLU(True),
33 | nn.Dropout(),
34 | nn.Linear(4096, num_classes)
35 | )
36 | if init_weights:
37 | self._initialize_weights()
38 |
39 | def forward(self, x):
40 | x = self.features(x)
41 | x = self.avgpool(x)
42 | x = x.view(x.size(0), -1)
43 | x = self.classifier(x)
44 | return x
45 |
46 | def _initialize_weights(self):
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d):
49 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
50 | if m.bias is not None:
51 | nn.init.constant_(m.bias, 0)
52 | elif isinstance(m, nn.BatchNorm2d):
53 | nn.init.constant_(m.weight, 1)
54 | nn.init.constant_(m.bias, 0)
55 | elif isinstance(m, nn.Linear):
56 | nn.init.normal_(m.weight, 0, 0.01)
57 | nn.init.constant_(m.bias, 0)
58 |
59 |
60 | def make_layers(cfg, batch_norm=False):
61 | layers = []
62 | in_channels = 3
63 | for v in cfg:
64 | if v == 'M':
65 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
66 | else:
67 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
68 | if batch_norm:
69 |                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
70 | else:
71 | layers += [conv2d, nn.ReLU(inplace=True)]
72 | in_channels = v
73 | return nn.Sequential(*layers)
74 |
75 |
76 | cfg = {
77 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
78 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
79 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
80 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
81 | }
82 |
83 |
84 | def vgg11(pretrained=False, **kwargs):
85 | """VGG 11-layer model (configuration "A")
86 | Args:
87 | pretrained (bool): If True, returns a model pre-trained on ImageNet
88 | """
89 | if pretrained:
90 | kwargs['init_weights'] = False
91 | model = VGG(make_layers(cfg['A']), **kwargs)
92 | if pretrained:
93 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
94 | return model
95 |
96 |
97 | def vgg11_bn(pretrained=False, **kwargs):
98 | """VGG 11-layer model (configuration "A") with batch normalization
99 | Args:
100 | pretrained (bool): If True, returns a model pre-trained on ImageNet
101 | """
102 | if pretrained:
103 | kwargs['init_weights'] = False
104 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
105 | if pretrained:
106 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
107 | return model
108 |
109 |
110 | def vgg13(pretrained=False, **kwargs):
111 | """VGG 13-layer model (configuration "B")
112 | Args:
113 | pretrained (bool): If True, returns a model pre-trained on ImageNet
114 | """
115 | if pretrained:
116 | kwargs['init_weights'] = False
117 | model = VGG(make_layers(cfg['B']), **kwargs)
118 | if pretrained:
119 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
120 | return model
121 |
122 |
123 | def vgg13_bn(pretrained=False, **kwargs):
124 | """VGG 13-layer model (configuration "B") with batch normalization
125 | Args:
126 | pretrained (bool): If True, returns a model pre-trained on ImageNet
127 | """
128 | if pretrained:
129 | kwargs['init_weights'] = False
130 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
131 | if pretrained:
132 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
133 | return model
134 |
135 |
136 | def vgg16(pretrained=False, **kwargs):
137 | """VGG 16-layer model (configuration "D")
138 | Args:
139 | pretrained (bool): If True, returns a model pre-trained on ImageNet
140 | """
141 | if pretrained:
142 | kwargs['init_weights'] = False
143 | model = VGG(make_layers(cfg['D']), **kwargs)
144 | if pretrained:
145 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
146 | return model
147 |
148 |
149 | def vgg16_bn(pretrained=False, **kwargs):
150 | """VGG 16-layer model (configuration "D") with batch normalization
151 | Args:
152 | pretrained (bool): If True, returns a model pre-trained on ImageNet
153 | """
154 | if pretrained:
155 | kwargs['init_weights'] = False
156 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
157 | if pretrained:
158 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
159 | return model
160 |
161 |
162 | def vgg19(pretrained=False, **kwargs):
163 | """VGG 19-layer model (configuration "E")
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | if pretrained:
168 | kwargs['init_weights'] = False
169 | model = VGG(make_layers(cfg['E']), **kwargs)
170 | if pretrained:
171 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
172 | return model
173 |
174 |
175 | def vgg19_bn(pretrained=False, **kwargs):
176 | """VGG 19-layer model (configuration 'E') with batch normalization
177 | Args:
178 | pretrained (bool): If True, returns a model pre-trained on ImageNet
179 | """
180 | if pretrained:
181 | kwargs['init_weights'] = False
182 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
183 | if pretrained:
184 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
185 | return model
186 |
187 |
188 | if __name__ == '__main__':
189 | img = torch.randn((4, 3, 480, 480))
190 | model = vgg16(pretrained=False)
191 | out = model(img)
192 |
--------------------------------------------------------------------------------
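A quick check that configuration 'D' really is VGG16 (a sketch, assuming the module is importable): 13 convolutional entries plus the 3 fully connected layers in the classifier give the 16 weight layers the name counts.

from core.models.base_models.vgg import cfg

num_convs = len([v for v in cfg['D'] if v != 'M'])
assert num_convs == 13  # 13 convs + 3 FC layers = VGG16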
/core/models/deeplabv3.py:
--------------------------------------------------------------------------------
1 | """Pyramid Scene Parsing Network"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .segbase import SegBaseModel
7 | from .fcn import _FCNHead
8 |
9 | __all__ = ['DeepLabV3', 'get_deeplabv3', 'get_deeplabv3_resnet50_voc', 'get_deeplabv3_resnet101_voc',
10 | 'get_deeplabv3_resnet152_voc', 'get_deeplabv3_resnet50_ade', 'get_deeplabv3_resnet101_ade',
11 | 'get_deeplabv3_resnet152_ade']
12 |
13 |
14 | class DeepLabV3(SegBaseModel):
15 | r"""DeepLabV3
16 |
17 | Parameters
18 | ----------
19 | nclass : int
20 | Number of categories for the training dataset.
21 | backbone : string
22 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
23 | 'resnet101' or 'resnet152').
24 | norm_layer : object
25 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
26 |         can be replaced by synchronized cross-GPU BatchNormalization).
27 | aux : bool
28 | Auxiliary loss.
29 |
30 | Reference:
31 | Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation."
32 | arXiv preprint arXiv:1706.05587 (2017).
33 | """
34 |
35 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
36 | super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
37 | self.head = _DeepLabHead(nclass, **kwargs)
38 | if self.aux:
39 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
40 |
41 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
42 |
43 | def forward(self, x):
44 | size = x.size()[2:]
45 | _, _, c3, c4 = self.base_forward(x)
46 | outputs = []
47 | x = self.head(c4)
48 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
49 | outputs.append(x)
50 |
51 | if self.aux:
52 | auxout = self.auxlayer(c3)
53 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
54 | outputs.append(auxout)
55 | return tuple(outputs)
56 |
57 |
58 | class _DeepLabHead(nn.Module):
59 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
60 | super(_DeepLabHead, self).__init__()
61 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
62 | self.block = nn.Sequential(
63 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
64 | norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)),
65 | nn.ReLU(True),
66 | nn.Dropout(0.1),
67 | nn.Conv2d(256, nclass, 1)
68 | )
69 |
70 | def forward(self, x):
71 | x = self.aspp(x)
72 | return self.block(x)
73 |
74 |
75 | class _ASPPConv(nn.Module):
76 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
77 | super(_ASPPConv, self).__init__()
78 | self.block = nn.Sequential(
79 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
80 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
81 | nn.ReLU(True)
82 | )
83 |
84 | def forward(self, x):
85 | return self.block(x)
86 |
87 |
88 | class _AsppPooling(nn.Module):
89 |     def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs=None, **kwargs):
90 | super(_AsppPooling, self).__init__()
91 | self.gap = nn.Sequential(
92 | nn.AdaptiveAvgPool2d(1),
93 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
94 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
95 | nn.ReLU(True)
96 | )
97 |
98 | def forward(self, x):
99 | size = x.size()[2:]
100 | pool = self.gap(x)
101 | out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
102 | return out
103 |
104 |
105 | class _ASPP(nn.Module):
106 |     def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs=None, **kwargs):
107 | super(_ASPP, self).__init__()
108 | out_channels = 256
109 | self.b0 = nn.Sequential(
110 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
111 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
112 | nn.ReLU(True)
113 | )
114 |
115 | rate1, rate2, rate3 = tuple(atrous_rates)
116 | self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
117 | self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
118 | self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
119 | self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
120 |
121 | self.project = nn.Sequential(
122 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
123 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
124 | nn.ReLU(True),
125 | nn.Dropout(0.5)
126 | )
127 |
128 | def forward(self, x):
129 | feat1 = self.b0(x)
130 | feat2 = self.b1(x)
131 | feat3 = self.b2(x)
132 | feat4 = self.b3(x)
133 | feat5 = self.b4(x)
134 | x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
135 | x = self.project(x)
136 | return x
137 |
138 |
139 | def get_deeplabv3(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
140 | pretrained_base=True, **kwargs):
141 | acronyms = {
142 | 'pascal_voc': 'pascal_voc',
143 | 'pascal_aug': 'pascal_aug',
144 | 'ade20k': 'ade',
145 | 'coco': 'coco',
146 | 'citys': 'citys',
147 | }
148 | from ..data.dataloader import datasets
149 | model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
150 | if pretrained:
151 | from .model_store import get_model_file
152 | device = torch.device(kwargs['local_rank'])
153 | model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
154 | map_location=device))
155 | return model
156 |
157 |
158 | def get_deeplabv3_resnet50_voc(**kwargs):
159 | return get_deeplabv3('pascal_voc', 'resnet50', **kwargs)
160 |
161 |
162 | def get_deeplabv3_resnet101_voc(**kwargs):
163 | return get_deeplabv3('pascal_voc', 'resnet101', **kwargs)
164 |
165 |
166 | def get_deeplabv3_resnet152_voc(**kwargs):
167 | return get_deeplabv3('pascal_voc', 'resnet152', **kwargs)
168 |
169 |
170 | def get_deeplabv3_resnet50_ade(**kwargs):
171 | return get_deeplabv3('ade20k', 'resnet50', **kwargs)
172 |
173 |
174 | def get_deeplabv3_resnet101_ade(**kwargs):
175 | return get_deeplabv3('ade20k', 'resnet101', **kwargs)
176 |
177 |
178 | def get_deeplabv3_resnet152_ade(**kwargs):
179 | return get_deeplabv3('ade20k', 'resnet152', **kwargs)
180 |
181 |
182 | if __name__ == '__main__':
183 | model = get_deeplabv3_resnet50_voc()
184 | img = torch.randn(2, 3, 480, 480)
185 | output = model(img)
186 |
--------------------------------------------------------------------------------
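The ASPP head concatenates five 256-channel branches (a 1x1 conv, three atrous convs at the given rates, and global pooling), so the projection sees 5 * 256 = 1280 channels and maps them back to 256. A minimal shape check (a sketch, assuming the package imports cleanly):

import torch
from core.models.deeplabv3 import _ASPP

aspp = _ASPP(2048, [12, 24, 36], norm_layer=torch.nn.BatchNorm2d, norm_kwargs=None)
feat = aspp(torch.randn(2, 2048, 30, 30))
assert feat.shape == (2, 256, 30, 30)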
/core/models/deeplabv3_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_models.xception import get_xception
6 | from .deeplabv3 import _ASPP
7 | from .fcn import _FCNHead
8 | from ..nn import _ConvBNReLU
9 |
10 | __all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc']
11 |
12 |
13 | class DeepLabV3Plus(nn.Module):
14 | r"""DeepLabV3Plus
15 | Parameters
16 | ----------
17 | nclass : int
18 | Number of categories for the training dataset.
19 | backbone : string
20 | Pre-trained dilated backbone network type (default:'xception').
21 | norm_layer : object
22 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
23 |         can be replaced by synchronized cross-GPU BatchNormalization).
24 | aux : bool
25 | Auxiliary loss.
26 |
27 | Reference:
28 | Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic
29 | Image Segmentation."
30 | """
31 |
32 | def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs):
33 | super(DeepLabV3Plus, self).__init__()
34 | self.aux = aux
35 | self.nclass = nclass
36 | output_stride = 8 if dilated else 32
37 |
38 | self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs)
39 |
40 | # deeplabv3 plus
41 | self.head = _DeepLabHead(nclass, **kwargs)
42 | if aux:
43 | self.auxlayer = _FCNHead(728, nclass, **kwargs)
44 |
45 | def base_forward(self, x):
46 | # Entry flow
47 | x = self.pretrained.conv1(x)
48 | x = self.pretrained.bn1(x)
49 | x = self.pretrained.relu(x)
50 |
51 | x = self.pretrained.conv2(x)
52 | x = self.pretrained.bn2(x)
53 | x = self.pretrained.relu(x)
54 |
55 | x = self.pretrained.block1(x)
56 | # add relu here
57 | x = self.pretrained.relu(x)
58 | low_level_feat = x
59 |
60 | x = self.pretrained.block2(x)
61 | x = self.pretrained.block3(x)
62 |
63 | # Middle flow
64 | x = self.pretrained.midflow(x)
65 | mid_level_feat = x
66 |
67 | # Exit flow
68 | x = self.pretrained.block20(x)
69 | x = self.pretrained.relu(x)
70 | x = self.pretrained.conv3(x)
71 | x = self.pretrained.bn3(x)
72 | x = self.pretrained.relu(x)
73 |
74 | x = self.pretrained.conv4(x)
75 | x = self.pretrained.bn4(x)
76 | x = self.pretrained.relu(x)
77 |
78 | x = self.pretrained.conv5(x)
79 | x = self.pretrained.bn5(x)
80 | x = self.pretrained.relu(x)
81 | return low_level_feat, mid_level_feat, x
82 |
83 | def forward(self, x):
84 | size = x.size()[2:]
85 | c1, c3, c4 = self.base_forward(x)
86 | outputs = list()
87 | x = self.head(c4, c1)
88 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
89 | outputs.append(x)
90 | if self.aux:
91 | auxout = self.auxlayer(c3)
92 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
93 | outputs.append(auxout)
94 | return tuple(outputs)
95 |
96 |
97 | class _DeepLabHead(nn.Module):
98 | def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs):
99 | super(_DeepLabHead, self).__init__()
100 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs)
101 | self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer)
102 | self.block = nn.Sequential(
103 | _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer),
104 | nn.Dropout(0.5),
105 | _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer),
106 | nn.Dropout(0.1),
107 | nn.Conv2d(256, nclass, 1))
108 |
109 | def forward(self, x, c1):
110 | size = c1.size()[2:]
111 | c1 = self.c1_block(c1)
112 | x = self.aspp(x)
113 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
114 | return self.block(torch.cat([x, c1], dim=1))
115 |
116 |
117 | def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models',
118 | pretrained_base=True, **kwargs):
119 | acronyms = {
120 | 'pascal_voc': 'pascal_voc',
121 | 'pascal_aug': 'pascal_aug',
122 | 'ade20k': 'ade',
123 | 'coco': 'coco',
124 | 'citys': 'citys',
125 | }
126 | from ..data.dataloader import datasets
127 | model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
128 | if pretrained:
129 | from .model_store import get_model_file
130 | device = torch.device(kwargs['local_rank'])
131 | model.load_state_dict(
132 | torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root),
133 | map_location=device))
134 | return model
135 |
136 |
137 | def get_deeplabv3_plus_xception_voc(**kwargs):
138 | return get_deeplabv3_plus('pascal_voc', 'xception', **kwargs)
139 |
140 |
141 | if __name__ == '__main__':
142 | model = get_deeplabv3_plus_xception_voc()
143 |
--------------------------------------------------------------------------------
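The decoder fuses the upsampled ASPP output (256 channels) with the reduced low-level feature (48 channels), which is where the 304-channel fusion block comes from. A shape sketch for the head alone (assumes the norm_kwargs default fixed in deeplabv3.py above; the tensor sizes are hypothetical):

import torch
from core.models.deeplabv3_plus import _DeepLabHead

head = _DeepLabHead(nclass=21, c1_channels=128)
x = torch.randn(2, 2048, 16, 16)   # high-level (encoder output) features
c1 = torch.randn(2, 128, 64, 64)   # low-level (entry-flow) features
out = head(x, c1)
assert out.shape == (2, 21, 64, 64)  # logits at the low-level feature resolution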
/core/models/dfanet.py:
--------------------------------------------------------------------------------
1 | """ Deep Feature Aggregation"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from core.models.base_models import Enc, FCAttention, get_xception_a
7 | from core.nn import _ConvBNReLU
8 |
9 | __all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys']
10 |
11 |
12 | class DFANet(nn.Module):
13 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs):
14 | super(DFANet, self).__init__()
15 | self.pretrained = get_xception_a(pretrained_base, **kwargs)
16 |
17 | self.enc2_2 = Enc(240, 48, 4, **kwargs)
18 | self.enc3_2 = Enc(144, 96, 6, **kwargs)
19 | self.enc4_2 = Enc(288, 192, 4, **kwargs)
20 | self.fca_2 = FCAttention(192, **kwargs)
21 |
22 | self.enc2_3 = Enc(240, 48, 4, **kwargs)
23 | self.enc3_3 = Enc(144, 96, 6, **kwargs)
24 | self.enc3_4 = Enc(288, 192, 4, **kwargs)
25 | self.fca_3 = FCAttention(192, **kwargs)
26 |
27 | self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
28 | self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
29 | self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs)
30 | self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs)
31 |
32 | self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
33 | self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
34 | self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs)
35 | self.conv_out = nn.Conv2d(32, nclass, 1)
36 |
37 | self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3',
38 | 'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce',
39 | 'fca_2_reduce', 'fca_3_reduce', 'conv_out'])
40 |
41 | def forward(self, x):
42 | # backbone
43 | stage1_conv1 = self.pretrained.conv1(x)
44 | stage1_enc2 = self.pretrained.enc2(stage1_conv1)
45 | stage1_enc3 = self.pretrained.enc3(stage1_enc2)
46 | stage1_enc4 = self.pretrained.enc4(stage1_enc3)
47 | stage1_fca = self.pretrained.fca(stage1_enc4)
48 | stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True)
49 |
50 | # stage2
51 | stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1))
52 | stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1))
53 | stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1))
54 | stage2_fca = self.fca_2(stage2_enc4)
55 | stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True)
56 |
57 | # stage3
58 | stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1))
59 | stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1))
60 | stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1))
61 | stage3_fca = self.fca_3(stage3_enc4)
62 |
63 | stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2)
64 |         stage2_enc2_decoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2,
65 | mode='bilinear', align_corners=True)
66 | stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4,
67 | mode='bilinear', align_corners=True)
68 |         fusion = stage1_enc2_decoder + stage2_enc2_decoder + stage3_enc2_decoder
69 | fusion = self.conv_fusion(fusion)
70 |
71 | stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4,
72 | mode='bilinear', align_corners=True)
73 | stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8,
74 | mode='bilinear', align_corners=True)
75 | stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16,
76 | mode='bilinear', align_corners=True)
77 | fusion = fusion + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder
78 |
79 | outputs = list()
80 | out = self.conv_out(fusion)
81 | out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True)
82 | outputs.append(out)
83 |
84 | return tuple(outputs)
85 |
86 |
87 | def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models',
88 | pretrained_base=True, **kwargs):
89 | acronyms = {
90 | 'pascal_voc': 'pascal_voc',
91 | 'pascal_aug': 'pascal_aug',
92 | 'ade20k': 'ade',
93 | 'coco': 'coco',
94 | 'citys': 'citys',
95 | }
96 | from ..data.dataloader import datasets
97 | model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
98 | if pretrained:
99 | from .model_store import get_model_file
100 | device = torch.device(kwargs['local_rank'])
101 | model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root),
102 | map_location=device))
103 | return model
104 |
105 |
106 | def get_dfanet_citys(**kwargs):
107 | return get_dfanet('citys', **kwargs)
108 |
109 |
110 | if __name__ == '__main__':
111 | model = get_dfanet_citys()
112 |
--------------------------------------------------------------------------------
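The opaque channel counts in the stage-2/3 Enc blocks above follow directly from concatenation; the encoder stages produce 48, 96 and 192 channels (inferred from the reduce convolutions in this file), and each fused input adds either the 192-channel attention output or the previous stage's features:

assert 48 + 192 == 240  # enc2_x input: encoder enc2 + upsampled fca output
assert 96 + 48 == 144   # enc3_x input: encoder enc3 + current-stage enc2
assert 192 + 96 == 288  # enc4_x input: encoder enc4 + current-stage enc3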
/core/models/dunet.py:
--------------------------------------------------------------------------------
1 | """Decoders Matter for Semantic Segmentation"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .segbase import SegBaseModel
7 | from .fcn import _FCNHead
8 |
9 | __all__ = ['DUNet', 'get_dunet', 'get_dunet_resnet50_pascal_voc',
10 | 'get_dunet_resnet101_pascal_voc', 'get_dunet_resnet152_pascal_voc']
11 |
12 |
13 | # Note: this implementation may deviate from the paper, which omits many implementation details.
14 | class DUNet(SegBaseModel):
15 | """Decoders Matter for Semantic Segmentation
16 |
17 | Reference:
18 | Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan.
19 | "Decoders Matter for Semantic Segmentation:
20 | Data-Dependent Decoding Enables Flexible Feature Aggregation." CVPR, 2019
21 | """
22 |
23 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs):
24 | super(DUNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
25 | self.head = _DUHead(2144, **kwargs)
26 | self.dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs)
27 | if aux:
28 | self.auxlayer = _FCNHead(1024, 256, **kwargs)
29 | self.aux_dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs)
30 |
31 | self.__setattr__('exclusive',
32 | ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if aux else ['dupsample', 'head'])
33 |
34 | def forward(self, x):
35 | c1, c2, c3, c4 = self.base_forward(x)
36 | outputs = []
37 | x = self.head(c2, c3, c4)
38 | x = self.dupsample(x)
39 | outputs.append(x)
40 |
41 | if self.aux:
42 | auxout = self.auxlayer(c3)
43 | auxout = self.aux_dupsample(auxout)
44 | outputs.append(auxout)
45 | return tuple(outputs)
46 |
47 |
48 | class FeatureFused(nn.Module):
49 | """Module for fused features"""
50 |
51 | def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d, **kwargs):
52 | super(FeatureFused, self).__init__()
53 | self.conv2 = nn.Sequential(
54 | nn.Conv2d(512, inter_channels, 1, bias=False),
55 | norm_layer(inter_channels),
56 | nn.ReLU(True)
57 | )
58 | self.conv3 = nn.Sequential(
59 | nn.Conv2d(1024, inter_channels, 1, bias=False),
60 | norm_layer(inter_channels),
61 | nn.ReLU(True)
62 | )
63 |
64 | def forward(self, c2, c3, c4):
65 | size = c4.size()[2:]
66 | c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True))
67 | c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True))
68 | fused_feature = torch.cat([c4, c3, c2], dim=1)
69 | return fused_feature
70 |
71 |
72 | class _DUHead(nn.Module):
73 | def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs):
74 | super(_DUHead, self).__init__()
75 | self.fuse = FeatureFused(norm_layer=norm_layer, **kwargs)
76 | self.block = nn.Sequential(
77 | nn.Conv2d(in_channels, 256, 3, padding=1, bias=False),
78 | norm_layer(256),
79 | nn.ReLU(True),
80 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
81 | norm_layer(256),
82 | nn.ReLU(True)
83 | )
84 |
85 | def forward(self, c2, c3, c4):
86 | fused_feature = self.fuse(c2, c3, c4)
87 | out = self.block(fused_feature)
88 | return out
89 |
90 |
91 | class DUpsampling(nn.Module):
92 | """DUsampling module"""
93 |
94 | def __init__(self, in_channels, out_channels, scale_factor=2, **kwargs):
95 | super(DUpsampling, self).__init__()
96 | self.scale_factor = scale_factor
97 | self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False)
98 |
99 | def forward(self, x):
100 | x = self.conv_w(x)
101 | n, c, h, w = x.size()
102 |
103 | # N, C, H, W --> N, W, H, C
104 | x = x.permute(0, 3, 2, 1).contiguous()
105 |
106 | # N, W, H, C --> N, W, H * scale, C // scale
107 | x = x.view(n, w, h * self.scale_factor, c // self.scale_factor)
108 |
109 | # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
110 | x = x.permute(0, 2, 1, 3).contiguous()
111 |
112 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
113 | x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * self.scale_factor))
114 |
115 | # N, H * scale, W * scale, C // (scale ** 2) -- > N, C // (scale ** 2), H * scale, W * scale
116 | x = x.permute(0, 3, 1, 2)
117 |
118 | return x
119 |
120 |
121 | def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
122 | root='~/.torch/models', pretrained_base=True, **kwargs):
123 | acronyms = {
124 | 'pascal_voc': 'pascal_voc',
125 | 'pascal_aug': 'pascal_aug',
126 | 'ade20k': 'ade',
127 | 'coco': 'coco',
128 | 'citys': 'citys',
129 | }
130 | from ..data.dataloader import datasets
131 | model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
132 | if pretrained:
133 | from .model_store import get_model_file
134 | device = torch.device(kwargs['local_rank'])
135 | model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root),
136 | map_location=device))
137 | return model
138 |
139 |
140 | def get_dunet_resnet50_pascal_voc(**kwargs):
141 | return get_dunet('pascal_voc', 'resnet50', **kwargs)
142 |
143 |
144 | def get_dunet_resnet101_pascal_voc(**kwargs):
145 | return get_dunet('pascal_voc', 'resnet101', **kwargs)
146 |
147 |
148 | def get_dunet_resnet152_pascal_voc(**kwargs):
149 | return get_dunet('pascal_voc', 'resnet152', **kwargs)
150 |
151 |
152 | if __name__ == '__main__':
153 | img = torch.randn(2, 3, 256, 256)
154 | model = get_dunet_resnet50_pascal_voc()
155 | outputs = model(img)
156 |
--------------------------------------------------------------------------------
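DUpsampling rearranges the (N, nclass * s * s, H, W) tensor produced by its 1x1 conv into (N, nclass, H * s, W * s); the learned conv makes it data-dependent, while the reshape itself matches F.pixel_shuffle up to channel ordering. A minimal shape check (a sketch):

import torch
from core.models.dunet import DUpsampling

dup = DUpsampling(256, 21, scale_factor=8)
out = dup(torch.randn(2, 256, 16, 16))
assert out.shape == (2, 21, 128, 128)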
/core/models/espnet.py:
--------------------------------------------------------------------------------
1 | "ESPNetv2: A Light-weight, Power Efficient, and General Purpose for Semantic Segmentation"
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from core.models.base_models import eespnet, EESP
7 | from core.nn import _ConvBNPReLU, _BNPReLU
8 | __all__ = ['ESPNetV2', 'get_espnet', 'get_espnet_citys']
9 |
10 | class ESPNetV2(nn.Module):
11 | r"""ESPNetV2
12 |
13 | Parameters
14 | ----------
15 | nclass : int
16 | Number of categories for the training dataset.
17 |     backbone : string
18 |         Backbone network type (unused here; ESPNetV2 builds on its own
19 |         EESPNet encoder).
20 |     norm_layer : object
21 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
22 |         can be replaced by synchronized cross-GPU BatchNormalization).
23 | aux : bool
24 | Auxiliary loss.
25 |
26 | Reference:
27 | Sachin Mehta, et al. "ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network."
28 | arXiv preprint arXiv:1811.11431 (2018).
29 | """
30 |
31 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs):
32 | super(ESPNetV2, self).__init__()
33 | self.pretrained = eespnet(pretrained=pretrained_base, **kwargs)
34 | self.proj_L4_C = _ConvBNPReLU(256, 128, 1, **kwargs)
35 | self.pspMod = nn.Sequential(
36 | EESP(256, 128, stride=1, k=4, r_lim=7, **kwargs),
37 | _PSPModule(128, 128, **kwargs))
38 | self.project_l3 = nn.Sequential(
39 | nn.Dropout2d(0.1),
40 | nn.Conv2d(128, nclass, 1, bias=False))
41 | self.act_l3 = _BNPReLU(nclass, **kwargs)
42 | self.project_l2 = _ConvBNPReLU(64 + nclass, nclass, 1, **kwargs)
43 | self.project_l1 = nn.Sequential(
44 | nn.Dropout2d(0.1),
45 | nn.Conv2d(32 + nclass, nclass, 1, bias=False))
46 |
47 | self.aux = aux
48 |
49 | self.__setattr__('exclusive', ['proj_L4_C', 'pspMod', 'project_l3', 'act_l3', 'project_l2', 'project_l1'])
50 |
51 | def forward(self, x):
52 | size = x.size()[2:]
53 | out_l1, out_l2, out_l3, out_l4 = self.pretrained(x, seg=True)
54 | out_l4_proj = self.proj_L4_C(out_l4)
55 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True)
56 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1))
57 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4)
58 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act)
59 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True)
60 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1))
61 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True)
62 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1))
63 |
64 | outputs = list()
65 | merge1_l1 = F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True)
66 | outputs.append(merge1_l1)
67 | if self.aux:
68 | # different from paper
69 | auxout = F.interpolate(proj_merge_l3_bef_act, size, mode='bilinear', align_corners=True)
70 | outputs.append(auxout)
71 |
72 | return tuple(outputs)
73 |
74 |
75 | # different from PSPNet
76 | class _PSPModule(nn.Module):
77 | def __init__(self, in_channels, out_channels=1024, sizes=(1, 2, 4, 8), **kwargs):
78 | super(_PSPModule, self).__init__()
79 | self.stages = nn.ModuleList(
80 | [nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels, bias=False) for _ in sizes])
81 | self.project = _ConvBNPReLU(in_channels * (len(sizes) + 1), out_channels, 1, 1, **kwargs)
82 |
83 | def forward(self, x):
84 | size = x.size()[2:]
85 | feats = [x]
86 | for stage in self.stages:
87 | x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
88 | upsampled = F.interpolate(stage(x), size, mode='bilinear', align_corners=True)
89 | feats.append(upsampled)
90 | return self.project(torch.cat(feats, dim=1))
91 |
92 |
93 | def get_espnet(dataset='pascal_voc', backbone='', pretrained=False, root='~/.torch/models',
94 | pretrained_base=False, **kwargs):
95 | acronyms = {
96 | 'pascal_voc': 'pascal_voc',
97 | 'pascal_aug': 'pascal_aug',
98 | 'ade20k': 'ade',
99 | 'coco': 'coco',
100 | 'citys': 'citys',
101 | }
102 | from core.data.dataloader import datasets
103 | model = ESPNetV2(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
104 | if pretrained:
105 | from .model_store import get_model_file
106 | device = torch.device(kwargs['local_rank'])
107 | model.load_state_dict(torch.load(get_model_file('espnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
108 | map_location=device))
109 | return model
110 |
111 |
112 | def get_espnet_citys(**kwargs):
113 | return get_espnet('citys', **kwargs)
114 |
115 |
116 | if __name__ == '__main__':
117 | model = get_espnet_citys()
118 |
--------------------------------------------------------------------------------
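Each _PSPModule stage halves the spatial size with a stride-2 average pool before its depthwise conv, so the four stages see 1/2, 1/4, 1/8 and 1/16 resolution; all of them are upsampled back and concatenated with the input, giving in_channels * (len(sizes) + 1) channels for the projection. A shape check (a sketch, assuming the package imports cleanly):

import torch
from core.models.espnet import _PSPModule

psp = _PSPModule(128, 128)
out = psp(torch.randn(2, 128, 32, 32))
assert out.shape == (2, 128, 32, 32)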
/core/models/fcnv2.py:
--------------------------------------------------------------------------------
1 | """Fully Convolutional Network with Stride of 8"""
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .segbase import SegBaseModel
9 |
10 | __all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_voc',
11 | 'get_fcn_resnet101_voc', 'get_fcn_resnet152_voc']
12 |
13 |
14 | class FCN(SegBaseModel):
15 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs):
16 | super(FCN, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
17 | self.head = _FCNHead(2048, nclass, **kwargs)
18 | if aux:
19 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
20 |
21 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
22 |
23 | def forward(self, x):
24 | size = x.size()[2:]
25 | _, _, c3, c4 = self.base_forward(x)
26 |
27 | outputs = []
28 | x = self.head(c4)
29 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
30 | outputs.append(x)
31 | if self.aux:
32 | auxout = self.auxlayer(c3)
33 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
34 | outputs.append(auxout)
35 | return tuple(outputs)
36 |
37 |
38 | class _FCNHead(nn.Module):
39 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
40 | super(_FCNHead, self).__init__()
41 | inter_channels = in_channels // 4
42 | self.block = nn.Sequential(
43 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
44 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)),
45 | nn.ReLU(True),
46 | nn.Dropout(0.1),
47 | nn.Conv2d(inter_channels, channels, 1)
48 | )
49 |
50 | def forward(self, x):
51 | return self.block(x)
52 |
53 |
54 | def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
55 | pretrained_base=True, **kwargs):
56 | acronyms = {
57 | 'pascal_voc': 'pascal_voc',
58 | 'pascal_aug': 'pascal_aug',
59 | 'ade20k': 'ade',
60 | 'coco': 'coco',
61 | 'citys': 'citys',
62 | }
63 | from ..data.dataloader import datasets
64 | model = FCN(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
65 | if pretrained:
66 | from .model_store import get_model_file
67 | device = torch.device(kwargs['local_rank'])
68 | model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root),
69 | map_location=device))
70 | return model
71 |
72 |
73 | def get_fcn_resnet50_voc(**kwargs):
74 | return get_fcn('pascal_voc', 'resnet50', **kwargs)
75 |
76 |
77 | def get_fcn_resnet101_voc(**kwargs):
78 | return get_fcn('pascal_voc', 'resnet101', **kwargs)
79 |
80 |
81 | def get_fcn_resnet152_voc(**kwargs):
82 | return get_fcn('pascal_voc', 'resnet152', **kwargs)
83 |
--------------------------------------------------------------------------------
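_FCNHead squeezes channels to in_channels // 4 with a 3x3 conv before the 1x1 classifier. A quick shape check (a sketch):

import torch
from core.models.fcnv2 import _FCNHead

head = _FCNHead(2048, 21)
out = head(torch.randn(2, 2048, 30, 30))
assert out.shape == (2, 21, 30, 30)  # spatial size unchanged, channels -> nclass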
/core/models/hrnet.py:
--------------------------------------------------------------------------------
1 | """High-Resolution Representations for Semantic Segmentation"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class HRNet(nn.Module):
7 | """HRNet
8 |
9 | Parameters
10 | ----------
11 | nclass : int
12 | Number of categories for the training dataset.
13 |     backbone : string
14 |         Pre-trained backbone network type (currently unused; the network
15 |         body is not implemented yet).
16 |     norm_layer : object
17 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
18 |         can be replaced by synchronized cross-GPU BatchNormalization).
19 | aux : bool
20 | Auxiliary loss.
21 | Reference:
22 |         Ke Sun, et al. "High-Resolution Representations for Labeling Pixels and Regions."
23 | arXiv preprint arXiv:1904.04514 (2019).
24 | """
25 | def __init__(self, nclass, backbone='', aux=False, pretrained_base=False, **kwargs):
26 |         super(HRNet, self).__init__()  # TODO: network body not implemented yet; placeholder only
27 |
28 | def forward(self, x):
29 | pass
--------------------------------------------------------------------------------
/core/models/icnet.py:
--------------------------------------------------------------------------------
1 | """Image Cascade Network"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .segbase import SegBaseModel
7 |
8 | __all__ = ['ICNet', 'get_icnet', 'get_icnet_resnet50_citys',
9 | 'get_icnet_resnet101_citys', 'get_icnet_resnet152_citys']
10 |
11 |
12 | class ICNet(SegBaseModel):
13 | """Image Cascade Network"""
14 |
15 | def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained_base=True, **kwargs):
16 | super(ICNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
17 | self.conv_sub1 = nn.Sequential(
18 | _ConvBNReLU(3, 32, 3, 2, **kwargs),
19 | _ConvBNReLU(32, 32, 3, 2, **kwargs),
20 | _ConvBNReLU(32, 64, 3, 2, **kwargs)
21 | )
22 |
23 | self.ppm = PyramidPoolingModule()
24 |
25 | self.head = _ICHead(nclass, **kwargs)
26 |
27 | self.__setattr__('exclusive', ['conv_sub1', 'head'])
28 |
29 | def forward(self, x):
30 | # sub 1
31 | x_sub1 = self.conv_sub1(x)
32 |
33 | # sub 2
34 | x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
35 | _, x_sub2, _, _ = self.base_forward(x_sub2)
36 |
37 | # sub 4
38 | x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True)
39 | _, _, _, x_sub4 = self.base_forward(x_sub4)
40 | # add PyramidPoolingModule
41 | x_sub4 = self.ppm(x_sub4)
42 | outputs = self.head(x_sub1, x_sub2, x_sub4)
43 |
44 | return tuple(outputs)
45 |
46 | class PyramidPoolingModule(nn.Module):
47 |     def __init__(self, pyramids=(1, 2, 3, 6)):
48 | super(PyramidPoolingModule, self).__init__()
49 | self.pyramids = pyramids
50 |
51 | def forward(self, input):
52 | feat = input
53 | height, width = input.shape[2:]
54 | for bin_size in self.pyramids:
55 | x = F.adaptive_avg_pool2d(input, output_size=bin_size)
56 | x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True)
57 | feat = feat + x
58 | return feat
59 |
60 | class _ICHead(nn.Module):
61 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
62 | super(_ICHead, self).__init__()
63 | #self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs)
64 | self.cff_12 = CascadeFeatureFusion(128, 64, 128, nclass, norm_layer, **kwargs)
65 | self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs)
66 |
67 | self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False)
68 |
69 | def forward(self, x_sub1, x_sub2, x_sub4):
70 | outputs = list()
71 | x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2)
72 | outputs.append(x_24_cls)
73 | #x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1)
74 | x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1)
75 | outputs.append(x_12_cls)
76 |
77 | up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True)
78 | up_x2 = self.conv_cls(up_x2)
79 | outputs.append(up_x2)
80 | up_x8 = F.interpolate(up_x2, scale_factor=4, mode='bilinear', align_corners=True)
81 | outputs.append(up_x8)
82 | # 1 -> 1/4 -> 1/8 -> 1/16
83 | outputs.reverse()
84 |
85 | return outputs
86 |
87 |
88 | class _ConvBNReLU(nn.Module):
89 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1,
90 | groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs):
91 | super(_ConvBNReLU, self).__init__()
92 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
93 | self.bn = norm_layer(out_channels)
94 | self.relu = nn.ReLU(True)
95 |
96 | def forward(self, x):
97 | x = self.conv(x)
98 | x = self.bn(x)
99 | x = self.relu(x)
100 | return x
101 |
102 |
103 | class CascadeFeatureFusion(nn.Module):
104 | """CFF Unit"""
105 |
106 | def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
107 | super(CascadeFeatureFusion, self).__init__()
108 | self.conv_low = nn.Sequential(
109 | nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False),
110 | norm_layer(out_channels)
111 | )
112 | self.conv_high = nn.Sequential(
113 | nn.Conv2d(high_channels, out_channels, 1, bias=False),
114 | norm_layer(out_channels)
115 | )
116 | self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False)
117 |
118 | def forward(self, x_low, x_high):
119 | x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True)
120 | x_low = self.conv_low(x_low)
121 | x_high = self.conv_high(x_high)
122 | x = x_low + x_high
123 | x = F.relu(x, inplace=True)
124 | x_low_cls = self.conv_low_cls(x_low)
125 |
126 | return x, x_low_cls
127 |
128 |
129 | def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.torch/models',
130 | pretrained_base=True, **kwargs):
131 | acronyms = {
132 | 'pascal_voc': 'pascal_voc',
133 | 'pascal_aug': 'pascal_aug',
134 | 'ade20k': 'ade',
135 | 'coco': 'coco',
136 | 'citys': 'citys',
137 | }
138 | from ..data.dataloader import datasets
139 | model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
140 | if pretrained:
141 | from .model_store import get_model_file
142 | device = torch.device(kwargs['local_rank'])
143 | model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root),
144 | map_location=device))
145 | return model
146 |
147 |
148 | def get_icnet_resnet50_citys(**kwargs):
149 | return get_icnet('citys', 'resnet50', **kwargs)
150 |
151 |
152 | def get_icnet_resnet101_citys(**kwargs):
153 | return get_icnet('citys', 'resnet101', **kwargs)
154 |
155 |
156 | def get_icnet_resnet152_citys(**kwargs):
157 | return get_icnet('citys', 'resnet152', **kwargs)
158 |
159 |
160 | if __name__ == '__main__':
161 | img = torch.randn(1, 3, 256, 256)
162 | model = get_icnet_resnet50_citys()
163 | outputs = model(img)
164 |
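A quick shape check for the three-branch cascade above (a sketch: it assumes the repo root is on PYTHONPATH, uses `pretrained_base=False` so nothing is downloaded, and takes 19 as the Cityscapes class count):

```python
import torch
from core.models.icnet import ICNet

# randomly initialized backbone; eval() so BatchNorm works with batch size 1
model = ICNet(nclass=19, backbone='resnet50', pretrained_base=False)
model.eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 256, 256))
# _ICHead reverses the list, so outputs run full, 1/4, 1/8, 1/16 resolution:
for out in outputs:
    print(out.shape)  # (1, 19, 256, 256), (1, 19, 64, 64), (1, 19, 32, 32), (1, 19, 16, 16)
```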
--------------------------------------------------------------------------------
/core/models/lednet.py:
--------------------------------------------------------------------------------
1 | """LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from core.nn import _ConvBNReLU
7 |
8 | __all__ = ['LEDNet', 'get_lednet', 'get_lednet_citys']
9 |
10 | class LEDNet(nn.Module):
11 | r"""LEDNet
12 |
13 | Parameters
14 | ----------
15 | nclass : int
16 | Number of categories for the training dataset.
17 |     backbone : string
18 |         Unused by LEDNet, which has its own lightweight encoder; kept for a
19 |         consistent constructor interface across models.
20 |     norm_layer : object
21 |         Normalization layer used in the network (default: :class:`nn.BatchNorm2d`;
22 |         can be replaced for Synchronized Cross-GPU BatchNormalization).
23 | aux : bool
24 | Auxiliary loss.
25 |
26 | Reference:
27 | Yu Wang, et al. "LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation."
28 | arXiv preprint arXiv:1905.02423 (2019).
29 | """
30 |
31 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=True, **kwargs):
32 | super(LEDNet, self).__init__()
33 | self.encoder = nn.Sequential(
34 | Downsampling(3, 32),
35 | SSnbt(32, **kwargs), SSnbt(32, **kwargs), SSnbt(32, **kwargs),
36 | Downsampling(32, 64),
37 | SSnbt(64, **kwargs), SSnbt(64, **kwargs),
38 | Downsampling(64, 128),
39 | SSnbt(128, **kwargs),
40 | SSnbt(128, 2, **kwargs),
41 | SSnbt(128, 5, **kwargs),
42 | SSnbt(128, 9, **kwargs),
43 | SSnbt(128, 2, **kwargs),
44 | SSnbt(128, 5, **kwargs),
45 | SSnbt(128, 9, **kwargs),
46 | SSnbt(128, 17, **kwargs),
47 | )
48 | self.decoder = APNModule(128, nclass)
49 |
50 | self.__setattr__('exclusive', ['encoder', 'decoder'])
51 |
52 | def forward(self, x):
53 | size = x.size()[2:]
54 | x = self.encoder(x)
55 | x = self.decoder(x)
56 | outputs = list()
57 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
58 | outputs.append(x)
59 |
60 | return tuple(outputs)
61 |
62 |
63 | class Downsampling(nn.Module):
64 | def __init__(self, in_channels, out_channels, **kwargs):
65 | super(Downsampling, self).__init__()
66 | self.conv1 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False)
67 | self.conv2 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False)
68 | self.pool = nn.MaxPool2d(kernel_size=2, stride=1)
69 |
70 | def forward(self, x):
71 | x1 = self.conv1(x)
72 | x1 = self.pool(x1)
73 |
74 | x2 = self.conv2(x)
75 | x2 = self.pool(x2)
76 |
77 | return torch.cat([x1, x2], dim=1)
78 |
79 |
80 | class SSnbt(nn.Module):
81 | def __init__(self, in_channels, dilation=1, norm_layer=nn.BatchNorm2d, **kwargs):
82 | super(SSnbt, self).__init__()
83 | inter_channels = in_channels // 2
84 | self.branch1 = nn.Sequential(
85 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False),
86 | nn.ReLU(True),
87 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False),
88 | norm_layer(inter_channels),
89 | nn.ReLU(True),
90 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1),
91 | bias=False),
92 | nn.ReLU(True),
93 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation),
94 | bias=False),
95 | norm_layer(inter_channels),
96 | nn.ReLU(True))
97 |
98 | self.branch2 = nn.Sequential(
99 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False),
100 | nn.ReLU(True),
101 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False),
102 | norm_layer(inter_channels),
103 | nn.ReLU(True),
104 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation),
105 | bias=False),
106 | nn.ReLU(True),
107 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1),
108 | bias=False),
109 | norm_layer(inter_channels),
110 | nn.ReLU(True))
111 |
112 | self.relu = nn.ReLU(True)
113 |
114 | @staticmethod
115 | def channel_shuffle(x, groups):
116 | n, c, h, w = x.size()
117 |
118 | channels_per_group = c // groups
119 | x = x.view(n, groups, channels_per_group, h, w)
120 | x = torch.transpose(x, 1, 2).contiguous()
121 | x = x.view(n, -1, h, w)
122 |
123 | return x
124 |
125 | def forward(self, x):
126 | # channels split
127 | x1, x2 = x.split(x.size(1) // 2, 1)
128 |
129 | x1 = self.branch1(x1)
130 | x2 = self.branch2(x2)
131 |
132 | out = torch.cat([x1, x2], dim=1)
133 | out = self.relu(out + x)
134 | out = self.channel_shuffle(out, groups=2)
135 |
136 | return out
137 |
138 |
139 | class APNModule(nn.Module):
140 | def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
141 | super(APNModule, self).__init__()
142 | self.conv1 = _ConvBNReLU(in_channels, in_channels, 3, 2, 1, norm_layer=norm_layer)
143 | self.conv2 = _ConvBNReLU(in_channels, in_channels, 5, 2, 2, norm_layer=norm_layer)
144 | self.conv3 = _ConvBNReLU(in_channels, in_channels, 7, 2, 3, norm_layer=norm_layer)
145 | self.level1 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
146 | self.level2 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
147 | self.level3 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
148 | self.level4 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer)
149 | self.level5 = nn.Sequential(
150 | nn.AdaptiveAvgPool2d(1),
151 | _ConvBNReLU(in_channels, nclass, 1))
152 |
153 | def forward(self, x):
154 |         h, w = x.size()[2:]  # size()[2:] is (height, width)
155 |         branch3 = self.conv1(x)
156 |         branch2 = self.conv2(branch3)
157 |         branch1 = self.conv3(branch2)
158 |
159 |         out = self.level1(branch1)
160 |         out = F.interpolate(out, ((h + 3) // 4, (w + 3) // 4), mode='bilinear', align_corners=True)
161 |         out = self.level2(branch2) + out
162 |         out = F.interpolate(out, ((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=True)
163 |         out = self.level3(branch3) + out
164 |         out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
165 | out = self.level4(x) * out
166 | out = self.level5(x) + out
167 | return out
168 |
169 |
170 | def get_lednet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models',
171 | pretrained_base=True, **kwargs):
172 | acronyms = {
173 | 'pascal_voc': 'pascal_voc',
174 | 'pascal_aug': 'pascal_aug',
175 | 'ade20k': 'ade',
176 | 'coco': 'coco',
177 | 'citys': 'citys',
178 | }
179 | from ..data.dataloader import datasets
180 | model = LEDNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
181 | if pretrained:
182 | from .model_store import get_model_file
183 | device = torch.device(kwargs['local_rank'])
184 | model.load_state_dict(torch.load(get_model_file('lednet_%s' % (acronyms[dataset]), root=root),
185 | map_location=device))
186 | return model
187 |
188 |
189 | def get_lednet_citys(**kwargs):
190 | return get_lednet('citys', **kwargs)
191 |
192 |
193 | if __name__ == '__main__':
194 | model = get_lednet_citys()
195 |
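A standalone illustration of the channel shuffle used by `SSnbt` above: with `groups=2` the two branch halves are interleaved, so information mixes across branches in the next block (pure PyTorch, no repo imports needed):

```python
import torch

def channel_shuffle(x, groups):
    # same logic as SSnbt.channel_shuffle: group, transpose, flatten
    n, c, h, w = x.size()
    x = x.view(n, groups, c // groups, h, w)
    x = torch.transpose(x, 1, 2).contiguous()
    return x.view(n, -1, h, w)

x = torch.arange(4.).view(1, 4, 1, 1)    # channels [0, 1, 2, 3]
print(channel_shuffle(x, 2).flatten())   # tensor([0., 2., 1., 3.])
```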
--------------------------------------------------------------------------------
/core/models/model_store.py:
--------------------------------------------------------------------------------
1 | """Model store which provides pretrained models."""
2 | from __future__ import print_function
3 |
4 | import os
5 | import zipfile
6 |
7 | from ..utils.download import download, check_sha1
8 |
9 | __all__ = ['get_model_file', 'get_resnet_file']
10 |
11 | _model_sha1 = {name: checksum for checksum, name in [
12 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'),
13 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
14 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
15 | ]}
16 |
17 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
18 | _url_format = '{repo_url}encoding/models/{file_name}.zip'
19 |
20 |
21 | def short_hash(name):
22 | if name not in _model_sha1:
23 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
24 | return _model_sha1[name][:8]
25 |
26 |
27 | def get_resnet_file(name, root='~/.torch/models'):
28 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
29 | root = os.path.expanduser(root)
30 |
31 | file_path = os.path.join(root, file_name + '.pth')
32 | sha1_hash = _model_sha1[name]
33 | if os.path.exists(file_path):
34 | if check_sha1(file_path, sha1_hash):
35 | return file_path
36 | else:
37 |             print('Mismatch in the content of model file {} detected.'
38 |                   ' Downloading again.'.format(file_path))
39 | else:
40 | print('Model file {} is not found. Downloading.'.format(file_path))
41 |
42 | if not os.path.exists(root):
43 | os.makedirs(root)
44 |
45 | zip_file_path = os.path.join(root, file_name + '.zip')
46 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
47 | if repo_url[-1] != '/':
48 | repo_url = repo_url + '/'
49 | download(_url_format.format(repo_url=repo_url, file_name=file_name),
50 | path=zip_file_path,
51 | overwrite=True)
52 | with zipfile.ZipFile(zip_file_path) as zf:
53 | zf.extractall(root)
54 | os.remove(zip_file_path)
55 |
56 | if check_sha1(file_path, sha1_hash):
57 | return file_path
58 | else:
59 | raise ValueError('Downloaded file has different hash. Please try again.')
60 |
61 |
62 | def get_model_file(name, root='~/.torch/models'):
63 | root = os.path.expanduser(root)
64 | file_path = os.path.join(root, name + '.pth')
65 | if os.path.exists(file_path):
66 | return file_path
67 | else:
68 |         raise ValueError('Model file is not found. Please download or train it first.')
69 |
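The on-disk naming that `short_hash` and `get_resnet_file` above produce, reduced to a standalone sketch:

```python
_model_sha1 = {'resnet50': '25c4b50959ef024fcc050213a06b614899f94b3d'}

def backbone_file_name(name):
    # '<name>-<first 8 hex chars of its sha1>.pth', as assembled in get_resnet_file
    return '{}-{}.pth'.format(name, _model_sha1[name][:8])

print(backbone_file_name('resnet50'))  # resnet50-25c4b509.pth
```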
--------------------------------------------------------------------------------
/core/models/model_zoo.py:
--------------------------------------------------------------------------------
1 | """Model store which handles pretrained models """
2 | from .fcn import *
3 | from .fcnv2 import *
4 | from .pspnet import *
5 | from .deeplabv3 import *
6 | from .deeplabv3_plus import *
7 | from .danet import *
8 | from .denseaspp import *
9 | from .bisenet import *
10 | from .encnet import *
11 | from .dunet import *
12 | from .icnet import *
13 | from .enet import *
14 | from .ocnet import *
15 | from .psanet import *
16 | from .cgnet import *
17 | from .espnet import *
18 | from .lednet import *
19 | from .dfanet import *
20 |
21 | __all__ = ['get_model', 'get_model_list', 'get_segmentation_model']
22 |
23 | _models = {
24 | 'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc,
25 | 'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc,
26 | 'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc,
27 | 'fcn_resnet50_voc': get_fcn_resnet50_voc,
28 | 'fcn_resnet101_voc': get_fcn_resnet101_voc,
29 | 'fcn_resnet152_voc': get_fcn_resnet152_voc,
30 | 'psp_resnet50_voc': get_psp_resnet50_voc,
31 | 'psp_resnet50_ade': get_psp_resnet50_ade,
32 | 'psp_resnet101_voc': get_psp_resnet101_voc,
33 | 'psp_resnet101_ade': get_psp_resnet101_ade,
34 | 'psp_resnet101_citys': get_psp_resnet101_citys,
35 | 'psp_resnet101_coco': get_psp_resnet101_coco,
36 | 'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc,
37 | 'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc,
38 | 'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc,
39 | 'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade,
40 | 'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade,
41 | 'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade,
42 | 'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc,
43 |     'danet_resnet50_citys': get_danet_resnet50_citys,
44 | 'danet_resnet101_citys': get_danet_resnet101_citys,
45 | 'danet_resnet152_citys': get_danet_resnet152_citys,
46 | 'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys,
47 | 'denseaspp_densenet161_citys': get_denseaspp_densenet161_citys,
48 | 'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys,
49 | 'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys,
50 | 'bisenet_resnet18_citys': get_bisenet_resnet18_citys,
51 | 'encnet_resnet50_ade': get_encnet_resnet50_ade,
52 | 'encnet_resnet101_ade': get_encnet_resnet101_ade,
53 | 'encnet_resnet152_ade': get_encnet_resnet152_ade,
54 | 'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc,
55 | 'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc,
56 | 'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc,
57 | 'icnet_resnet50_citys': get_icnet_resnet50_citys,
58 | 'icnet_resnet101_citys': get_icnet_resnet101_citys,
59 | 'icnet_resnet152_citys': get_icnet_resnet152_citys,
60 | 'enet_citys': get_enet_citys,
61 | 'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys,
62 | 'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys,
63 | 'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys,
64 | 'psanet_resnet50_voc': get_psanet_resnet50_voc,
65 | 'psanet_resnet101_voc': get_psanet_resnet101_voc,
66 | 'psanet_resnet152_voc': get_psanet_resnet152_voc,
67 | 'psanet_resnet50_citys': get_psanet_resnet50_citys,
68 | 'psanet_resnet101_citys': get_psanet_resnet101_citys,
69 | 'psanet_resnet152_citys': get_psanet_resnet152_citys,
70 | 'cgnet_citys': get_cgnet_citys,
71 | 'espnet_citys': get_espnet_citys,
72 | 'lednet_citys': get_lednet_citys,
73 | 'dfanet_citys': get_dfanet_citys,
74 | }
75 |
76 |
77 | def get_model(name, **kwargs):
78 | name = name.lower()
79 | if name not in _models:
80 | err_str = '"%s" is not among the following model list:\n\t' % (name)
81 | err_str += '%s' % ('\n\t'.join(sorted(_models.keys())))
82 | raise ValueError(err_str)
83 | net = _models[name](**kwargs)
84 | return net
85 |
86 |
87 | def get_model_list():
88 | return _models.keys()
89 |
90 |
91 | def get_segmentation_model(model, **kwargs):
92 | models = {
93 | 'fcn32s': get_fcn32s,
94 | 'fcn16s': get_fcn16s,
95 | 'fcn8s': get_fcn8s,
96 | 'fcn': get_fcn,
97 | 'psp': get_psp,
98 | 'deeplabv3': get_deeplabv3,
99 | 'deeplabv3_plus': get_deeplabv3_plus,
100 | 'danet': get_danet,
101 | 'denseaspp': get_denseaspp,
102 | 'bisenet': get_bisenet,
103 | 'encnet': get_encnet,
104 | 'dunet': get_dunet,
105 | 'icnet': get_icnet,
106 | 'enet': get_enet,
107 | 'ocnet': get_ocnet,
108 | 'psanet': get_psanet,
109 | 'cgnet': get_cgnet,
110 | 'espnet': get_espnet,
111 | 'lednet': get_lednet,
112 | 'dfanet': get_dfanet,
113 | }
114 | return models[model](**kwargs)
115 |
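Typical registry lookups (a sketch: it assumes the repo root is on PYTHONPATH and that the FCN getter forwards `pretrained`/`pretrained_base` the same way `get_psp` does, so nothing is downloaded):

```python
from core.models.model_zoo import get_model, get_model_list

print(sorted(get_model_list())[:3])  # a few of the registered keys
# an unknown key raises ValueError listing every valid name
net = get_model('fcn32s_vgg16_voc', pretrained=False, pretrained_base=False)
```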
--------------------------------------------------------------------------------
/core/models/psanet.py:
--------------------------------------------------------------------------------
1 | """Point-wise Spatial Attention Network"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from core.nn import _ConvBNReLU
7 | from core.models.segbase import SegBaseModel
8 | from core.models.fcn import _FCNHead
9 |
10 | __all__ = ['PSANet', 'get_psanet', 'get_psanet_resnet50_voc', 'get_psanet_resnet101_voc',
11 | 'get_psanet_resnet152_voc', 'get_psanet_resnet50_citys', 'get_psanet_resnet101_citys',
12 | 'get_psanet_resnet152_citys']
13 |
14 |
15 | class PSANet(SegBaseModel):
16 | r"""PSANet
17 |
18 | Parameters
19 | ----------
20 | nclass : int
21 | Number of categories for the training dataset.
22 | backbone : string
23 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
24 | 'resnet101' or 'resnet152').
25 | norm_layer : object
26 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
27 |         for Synchronized Cross-GPU BatchNormalization).
28 | aux : bool
29 | Auxiliary loss.
30 |
31 | Reference:
32 | Hengshuang Zhao, et al. "PSANet: Point-wise Spatial Attention Network for Scene Parsing."
33 | ECCV-2018.
34 | """
35 |
36 |     def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
37 | super(PSANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
38 | self.head = _PSAHead(nclass, **kwargs)
39 | if aux:
40 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
41 |
42 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
43 |
44 | def forward(self, x):
45 | size = x.size()[2:]
46 | _, _, c3, c4 = self.base_forward(x)
47 | outputs = list()
48 | x = self.head(c4)
49 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
50 | outputs.append(x)
51 |
52 | if self.aux:
53 | auxout = self.auxlayer(c3)
54 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
55 | outputs.append(auxout)
56 | return tuple(outputs)
57 |
58 |
59 | class _PSAHead(nn.Module):
60 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs):
61 | super(_PSAHead, self).__init__()
62 |         # psa_out_channels = (crop_size // 8) ** 2, i.e. 3600 for a 480x480 crop
63 | self.psa = _PointwiseSpatialAttention(2048, 3600, norm_layer)
64 |
65 | self.conv_post = _ConvBNReLU(1024, 2048, 1, norm_layer=norm_layer)
66 | self.project = nn.Sequential(
67 | _ConvBNReLU(4096, 512, 3, padding=1, norm_layer=norm_layer),
68 | nn.Dropout2d(0.1, False),
69 | nn.Conv2d(512, nclass, 1))
70 |
71 | def forward(self, x):
72 | global_feature = self.psa(x)
73 | out = self.conv_post(global_feature)
74 | out = torch.cat([x, out], dim=1)
75 | out = self.project(out)
76 |
77 | return out
78 |
79 |
80 | class _PointwiseSpatialAttention(nn.Module):
81 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
82 | super(_PointwiseSpatialAttention, self).__init__()
83 | reduced_channels = 512
84 | self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)
85 | self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer)
86 |
87 | def forward(self, x):
88 | collect_fm = self.collect_attention(x)
89 | distribute_fm = self.distribute_attention(x)
90 | psa_fm = torch.cat([collect_fm, distribute_fm], dim=1)
91 | return psa_fm
92 |
93 |
94 | class _AttentionGeneration(nn.Module):
95 | def __init__(self, in_channels, reduced_channels, out_channels, norm_layer, **kwargs):
96 | super(_AttentionGeneration, self).__init__()
97 | self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer)
98 | self.attention = nn.Sequential(
99 | _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer),
100 | nn.Conv2d(reduced_channels, out_channels, 1, bias=False))
101 |
102 | self.reduced_channels = reduced_channels
103 |
104 | def forward(self, x):
105 | reduce_x = self.conv_reduce(x)
106 | attention = self.attention(reduce_x)
107 | n, c, h, w = attention.size()
108 | attention = attention.view(n, c, -1)
109 | reduce_x = reduce_x.view(n, self.reduced_channels, -1)
110 | fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1))
111 | fm = fm.view(n, self.reduced_channels, h, w)
112 |
113 | return fm
114 |
115 |
116 | def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
117 | pretrained_base=True, **kwargs):
118 | acronyms = {
119 | 'pascal_voc': 'pascal_voc',
120 | 'pascal_aug': 'pascal_aug',
121 | 'ade20k': 'ade',
122 | 'coco': 'coco',
123 | 'citys': 'citys',
124 | }
125 | from core.data.dataloader import datasets
126 | model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
127 | if pretrained:
128 | from .model_store import get_model_file
129 | device = torch.device(kwargs['local_rank'])
130 |         model.load_state_dict(torch.load(get_model_file('psanet_%s_%s' % (backbone, acronyms[dataset]), root=root),
131 | map_location=device))
132 | return model
133 |
134 |
135 | def get_psanet_resnet50_voc(**kwargs):
136 | return get_psanet('pascal_voc', 'resnet50', **kwargs)
137 |
138 |
139 | def get_psanet_resnet101_voc(**kwargs):
140 | return get_psanet('pascal_voc', 'resnet101', **kwargs)
141 |
142 |
143 | def get_psanet_resnet152_voc(**kwargs):
144 | return get_psanet('pascal_voc', 'resnet152', **kwargs)
145 |
146 |
147 | def get_psanet_resnet50_citys(**kwargs):
148 | return get_psanet('citys', 'resnet50', **kwargs)
149 |
150 |
151 | def get_psanet_resnet101_citys(**kwargs):
152 | return get_psanet('citys', 'resnet101', **kwargs)
153 |
154 |
155 | def get_psanet_resnet152_citys(**kwargs):
156 | return get_psanet('citys', 'resnet152', **kwargs)
157 |
158 |
159 | if __name__ == '__main__':
160 | model = get_psanet_resnet50_voc()
161 | img = torch.randn(1, 3, 480, 480)
162 | output = model(img)
163 |
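Why `_PSAHead` hard-codes 3600 as the attention output channels: with a 480x480 crop at output stride 8, the c4 feature map is 60x60, so the flattened spatial size h*w equals 3600 and the `bmm` inside `_AttentionGeneration` lines up. A standalone shape check:

```python
import torch

n, c, h, w = 1, 512, 60, 60               # c4 of a 480x480 crop at stride 8
reduce_x = torch.randn(n, c, h * w)       # flattened reduced features
attention = torch.randn(n, h * w, h * w)  # out_channels must equal h*w
fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1))
print(fm.view(n, c, h, w).shape)          # torch.Size([1, 512, 60, 60])
```

For any other crop size the constant has to change, which is the relation the comment in `_PSAHead` records.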
--------------------------------------------------------------------------------
/core/models/pspnet.py:
--------------------------------------------------------------------------------
1 | """Pyramid Scene Parsing Network"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .segbase import SegBaseModel
7 | from .fcn import _FCNHead
8 |
9 | __all__ = ['PSPNet', 'get_psp', 'get_psp_resnet50_voc', 'get_psp_resnet50_ade', 'get_psp_resnet101_voc',
10 | 'get_psp_resnet101_ade', 'get_psp_resnet101_citys', 'get_psp_resnet101_coco']
11 |
12 |
13 | class PSPNet(SegBaseModel):
14 | r"""Pyramid Scene Parsing Network
15 |
16 | Parameters
17 | ----------
18 | nclass : int
19 | Number of categories for the training dataset.
20 | backbone : string
21 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
22 | 'resnet101' or 'resnet152').
23 | norm_layer : object
24 |         Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`;
25 |         for Synchronized Cross-GPU BatchNormalization).
26 | aux : bool
27 | Auxiliary loss.
28 |
29 | Reference:
30 | Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia.
31 | "Pyramid scene parsing network." *CVPR*, 2017
32 | """
33 |
34 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
35 | super(PSPNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
36 | self.head = _PSPHead(nclass, **kwargs)
37 | if self.aux:
38 | self.auxlayer = _FCNHead(1024, nclass, **kwargs)
39 |
40 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])
41 |
42 | def forward(self, x):
43 | size = x.size()[2:]
44 | _, _, c3, c4 = self.base_forward(x)
45 | outputs = []
46 | x = self.head(c4)
47 | x = F.interpolate(x, size, mode='bilinear', align_corners=True)
48 | outputs.append(x)
49 |
50 | if self.aux:
51 | auxout = self.auxlayer(c3)
52 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
53 | outputs.append(auxout)
54 | return tuple(outputs)
55 |
56 |
57 | def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs):
58 | return nn.Sequential(
59 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
60 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
61 | nn.ReLU(True)
62 | )
63 |
64 |
65 | class _PyramidPooling(nn.Module):
66 | def __init__(self, in_channels, **kwargs):
67 | super(_PyramidPooling, self).__init__()
68 | out_channels = int(in_channels / 4)
69 | self.avgpool1 = nn.AdaptiveAvgPool2d(1)
70 | self.avgpool2 = nn.AdaptiveAvgPool2d(2)
71 | self.avgpool3 = nn.AdaptiveAvgPool2d(3)
72 | self.avgpool4 = nn.AdaptiveAvgPool2d(6)
73 | self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
74 | self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
75 | self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
76 | self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs)
77 |
78 | def forward(self, x):
79 | size = x.size()[2:]
80 | feat1 = F.interpolate(self.conv1(self.avgpool1(x)), size, mode='bilinear', align_corners=True)
81 | feat2 = F.interpolate(self.conv2(self.avgpool2(x)), size, mode='bilinear', align_corners=True)
82 | feat3 = F.interpolate(self.conv3(self.avgpool3(x)), size, mode='bilinear', align_corners=True)
83 | feat4 = F.interpolate(self.conv4(self.avgpool4(x)), size, mode='bilinear', align_corners=True)
84 | return torch.cat([x, feat1, feat2, feat3, feat4], dim=1)
85 |
86 |
87 | class _PSPHead(nn.Module):
88 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
89 | super(_PSPHead, self).__init__()
90 | self.psp = _PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
91 | self.block = nn.Sequential(
92 | nn.Conv2d(4096, 512, 3, padding=1, bias=False),
93 | norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)),
94 | nn.ReLU(True),
95 | nn.Dropout(0.1),
96 | nn.Conv2d(512, nclass, 1)
97 | )
98 |
99 | def forward(self, x):
100 | x = self.psp(x)
101 | return self.block(x)
102 |
103 |
104 | def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
105 | pretrained_base=True, **kwargs):
106 | r"""Pyramid Scene Parsing Network
107 |
108 | Parameters
109 | ----------
110 | dataset : str, default pascal_voc
111 | The dataset that model pretrained on. (pascal_voc, ade20k)
112 | pretrained : bool or str
113 | Boolean value controls whether to load the default pretrained weights for model.
114 | String value represents the hashtag for a certain version of pretrained weights.
115 | root : str, default '~/.torch/models'
116 | Location for keeping the model parameters.
117 | pretrained_base : bool or str, default True
118 |         This will load the pretrained backbone network, which was trained on ImageNet.
119 | Examples
120 | --------
121 | >>> model = get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False)
122 | >>> print(model)
123 | """
124 | acronyms = {
125 | 'pascal_voc': 'pascal_voc',
126 | 'pascal_aug': 'pascal_aug',
127 | 'ade20k': 'ade',
128 | 'coco': 'coco',
129 | 'citys': 'citys',
130 | }
131 | from ..data.dataloader import datasets
132 | model = PSPNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
133 | if pretrained:
134 | from .model_store import get_model_file
135 | device = torch.device(kwargs['local_rank'])
136 | model.load_state_dict(torch.load(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root),
137 | map_location=device))
138 | return model
139 |
140 |
141 | def get_psp_resnet50_voc(**kwargs):
142 | return get_psp('pascal_voc', 'resnet50', **kwargs)
143 |
144 |
145 | def get_psp_resnet50_ade(**kwargs):
146 | return get_psp('ade20k', 'resnet50', **kwargs)
147 |
148 |
149 | def get_psp_resnet101_voc(**kwargs):
150 | return get_psp('pascal_voc', 'resnet101', **kwargs)
151 |
152 |
153 | def get_psp_resnet101_ade(**kwargs):
154 | return get_psp('ade20k', 'resnet101', **kwargs)
155 |
156 |
157 | def get_psp_resnet101_citys(**kwargs):
158 | return get_psp('citys', 'resnet101', **kwargs)
159 |
160 |
161 | def get_psp_resnet101_coco(**kwargs):
162 | return get_psp('coco', 'resnet101', **kwargs)
163 |
164 |
165 | if __name__ == '__main__':
166 | model = get_psp_resnet50_voc()
167 | img = torch.randn(4, 3, 480, 480)
168 | output = model(img)
169 |
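Channel bookkeeping for `_PyramidPooling`: each of the four pooled branches carries in_channels/4 = 512 channels, so the concatenation is 2048 + 4*512 = 4096, matching the 4096 -> 512 conv in `_PSPHead`. A standalone check of the same arithmetic:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 2048, 60, 60)
feats = [x]
for bins in (1, 2, 3, 6):                      # the four pyramid bin sizes
    pooled = F.adaptive_avg_pool2d(x, bins)
    reduced = nn.Conv2d(2048, 512, 1, bias=False)(pooled)
    feats.append(F.interpolate(reduced, x.shape[2:], mode='bilinear', align_corners=True))
print(torch.cat(feats, dim=1).shape)           # torch.Size([1, 4096, 60, 60])
```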
--------------------------------------------------------------------------------
/core/models/segbase.py:
--------------------------------------------------------------------------------
1 | """Base Model for Semantic Segmentation"""
2 | import torch.nn as nn
3 |
4 | from ..nn import JPU
5 | from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s
6 |
7 | __all__ = ['SegBaseModel']
8 |
9 |
10 | class SegBaseModel(nn.Module):
11 | r"""Base Model for Semantic Segmentation
12 |
13 | Parameters
14 | ----------
15 | backbone : string
16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
17 | 'resnet101' or 'resnet152').
18 | """
19 |
20 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs):
21 | super(SegBaseModel, self).__init__()
22 |         dilated = not jpu  # JPU replaces dilation in the backbone
23 | self.aux = aux
24 | self.nclass = nclass
25 | if backbone == 'resnet50':
26 | self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
27 | elif backbone == 'resnet101':
28 | self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
29 | elif backbone == 'resnet152':
30 | self.pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs)
31 | else:
32 | raise RuntimeError('unknown backbone: {}'.format(backbone))
33 |
34 | self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None
35 |
36 | def base_forward(self, x):
37 | """forwarding pre-trained network"""
38 | x = self.pretrained.conv1(x)
39 | x = self.pretrained.bn1(x)
40 | x = self.pretrained.relu(x)
41 | x = self.pretrained.maxpool(x)
42 | c1 = self.pretrained.layer1(x)
43 | c2 = self.pretrained.layer2(c1)
44 | c3 = self.pretrained.layer3(c2)
45 | c4 = self.pretrained.layer4(c3)
46 |
47 | if self.jpu:
48 | return self.jpu(c1, c2, c3, c4)
49 | else:
50 | return c1, c2, c3, c4
51 |
52 | def evaluate(self, x):
53 | """evaluating network with inputs and targets"""
54 | return self.forward(x)[0]
55 |
56 | def demo(self, x):
57 | pred = self.forward(x)
58 | if self.aux:
59 | pred = pred[0]
60 | return pred
61 |
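Feature strides returned by `base_forward` with the dilated backbone: c1 sits at stride 4, and c2 through c4 all sit at stride 8 because layer3/layer4 swap downsampling for dilation. A quick check (a sketch; assumes the repo root is on PYTHONPATH and uses random weights):

```python
import torch
from core.models.segbase import SegBaseModel

m = SegBaseModel(21, aux=False, backbone='resnet50', pretrained_base=False)
m.eval()
with torch.no_grad():
    c1, c2, c3, c4 = m.base_forward(torch.randn(1, 3, 480, 480))
print([tuple(f.shape[2:]) for f in (c1, c2, c3, c4)])
# [(120, 120), (60, 60), (60, 60), (60, 60)]
```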
--------------------------------------------------------------------------------
/core/nn/__init__.py:
--------------------------------------------------------------------------------
1 | """Seg NN Modules"""
2 | from .jpu import *
3 | from .basic import *
4 |
--------------------------------------------------------------------------------
/core/nn/basic.py:
--------------------------------------------------------------------------------
1 | """Basic Module for Semantic Segmentation"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual']
7 |
8 |
9 | class _ConvBNReLU(nn.Module):
10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
11 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs):
12 | super(_ConvBNReLU, self).__init__()
13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
14 | self.bn = norm_layer(out_channels)
15 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | x = self.relu(x)
21 | return x
22 |
23 |
24 | class _ConvBNPReLU(nn.Module):
25 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
26 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
27 | super(_ConvBNPReLU, self).__init__()
28 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
29 | self.bn = norm_layer(out_channels)
30 | self.prelu = nn.PReLU(out_channels)
31 |
32 | def forward(self, x):
33 | x = self.conv(x)
34 | x = self.bn(x)
35 | x = self.prelu(x)
36 | return x
37 |
38 |
39 | class _ConvBN(nn.Module):
40 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
41 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs):
42 | super(_ConvBN, self).__init__()
43 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False)
44 | self.bn = norm_layer(out_channels)
45 |
46 | def forward(self, x):
47 | x = self.conv(x)
48 | x = self.bn(x)
49 | return x
50 |
51 |
52 | class _BNPReLU(nn.Module):
53 | def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs):
54 | super(_BNPReLU, self).__init__()
55 | self.bn = norm_layer(out_channels)
56 | self.prelu = nn.PReLU(out_channels)
57 |
58 | def forward(self, x):
59 | x = self.bn(x)
60 | x = self.prelu(x)
61 | return x
62 |
63 |
64 | # -----------------------------------------------------------------
65 | # For PSPNet
66 | # -----------------------------------------------------------------
67 | class _PSPModule(nn.Module):
68 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs):
69 | super(_PSPModule, self).__init__()
70 | out_channels = int(in_channels / 4)
71 | self.avgpools = nn.ModuleList()
72 | self.convs = nn.ModuleList()
73 | for size in sizes:
74 |             self.avgpools.append(nn.AdaptiveAvgPool2d(size))
75 | self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs))
76 |
77 | def forward(self, x):
78 | size = x.size()[2:]
79 | feats = [x]
80 |         for avgpool, conv in zip(self.avgpools, self.convs):
81 | feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True))
82 | return torch.cat(feats, dim=1)
83 |
84 |
85 | # -----------------------------------------------------------------
86 | # For MobileNet
87 | # -----------------------------------------------------------------
88 | class _DepthwiseConv(nn.Module):
89 | """conv_dw in MobileNet"""
90 |
91 | def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs):
92 | super(_DepthwiseConv, self).__init__()
93 | self.conv = nn.Sequential(
94 | _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer),
95 | _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer))
96 |
97 | def forward(self, x):
98 | return self.conv(x)
99 |
100 |
101 | # -----------------------------------------------------------------
102 | # For MobileNetV2
103 | # -----------------------------------------------------------------
104 | class InvertedResidual(nn.Module):
105 | def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs):
106 | super(InvertedResidual, self).__init__()
107 | assert stride in [1, 2]
108 | self.use_res_connect = stride == 1 and in_channels == out_channels
109 |
110 | layers = list()
111 | inter_channels = int(round(in_channels * expand_ratio))
112 | if expand_ratio != 1:
113 | # pw
114 | layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer))
115 | layers.extend([
116 | # dw
117 | _ConvBNReLU(inter_channels, inter_channels, 3, stride, 1,
118 | groups=inter_channels, relu6=True, norm_layer=norm_layer),
119 | # pw-linear
120 | nn.Conv2d(inter_channels, out_channels, 1, bias=False),
121 | norm_layer(out_channels)])
122 | self.conv = nn.Sequential(*layers)
123 |
124 | def forward(self, x):
125 | if self.use_res_connect:
126 | return x + self.conv(x)
127 | else:
128 | return self.conv(x)
129 |
130 |
131 | if __name__ == '__main__':
132 | x = torch.randn(1, 32, 64, 64)
133 | model = InvertedResidual(32, 64, 2, 1)
134 | out = model(x)
135 |
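`_PSPModule` mirrors pspnet's `_PyramidPooling`, concatenating the input with four pooled-and-reduced branches; a quick shape check (assumes the repo root is on PYTHONPATH):

```python
import torch
from core.nn.basic import _PSPModule

m = _PSPModule(2048)
m.eval()
print(m(torch.randn(1, 2048, 60, 60)).shape)  # torch.Size([1, 4096, 60, 60])
```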
--------------------------------------------------------------------------------
/core/nn/jpu.py:
--------------------------------------------------------------------------------
1 | """Joint Pyramid Upsampling"""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | __all__ = ['JPU']
7 |
8 |
9 | class SeparableConv2d(nn.Module):
10 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1,
11 | dilation=1, bias=False, norm_layer=nn.BatchNorm2d):
12 | super(SeparableConv2d, self).__init__()
13 | self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias)
14 | self.bn = norm_layer(inplanes)
15 | self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | x = self.pointwise(x)
21 | return x
22 |
23 |
24 | # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py
25 | class JPU(nn.Module):
26 | def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs):
27 | super(JPU, self).__init__()
28 |
29 | self.conv5 = nn.Sequential(
30 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False),
31 | norm_layer(width),
32 | nn.ReLU(True))
33 | self.conv4 = nn.Sequential(
34 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False),
35 | norm_layer(width),
36 | nn.ReLU(True))
37 | self.conv3 = nn.Sequential(
38 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False),
39 | norm_layer(width),
40 | nn.ReLU(True))
41 |
42 | self.dilation1 = nn.Sequential(
43 | SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False),
44 | norm_layer(width),
45 | nn.ReLU(True))
46 | self.dilation2 = nn.Sequential(
47 | SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False),
48 | norm_layer(width),
49 | nn.ReLU(True))
50 | self.dilation3 = nn.Sequential(
51 | SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False),
52 | norm_layer(width),
53 | nn.ReLU(True))
54 | self.dilation4 = nn.Sequential(
55 | SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False),
56 | norm_layer(width),
57 | nn.ReLU(True))
58 |
59 | def forward(self, *inputs):
60 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])]
61 | size = feats[-1].size()[2:]
62 | feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True)
63 | feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True)
64 | feat = torch.cat(feats, dim=1)
65 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)],
66 | dim=1)
67 |
68 | return inputs[0], inputs[1], inputs[2], feat
69 |
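A JPU shape check: the three deepest features are projected to `width` channels, upsampled to the stride-8 map, and the four dilation branches are concatenated, so the fused map carries 4*width channels (2048 for width=512) while c1..c3 pass through untouched (assumes the repo root is on PYTHONPATH):

```python
import torch
from core.nn.jpu import JPU

jpu = JPU([512, 1024, 2048], width=512)
jpu.eval()
c1 = torch.randn(1, 256, 120, 120)
c2 = torch.randn(1, 512, 60, 60)
c3 = torch.randn(1, 1024, 30, 30)
c4 = torch.randn(1, 2048, 15, 15)
_, _, _, feat = jpu(c1, c2, c3, c4)
print(feat.shape)  # torch.Size([1, 2048, 60, 60])
```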
--------------------------------------------------------------------------------
/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Utility functions."""
2 | from __future__ import absolute_import
3 |
4 | from .download import download, check_sha1
5 | from .filesystem import makedirs, try_import_pycocotools
6 |
--------------------------------------------------------------------------------
/core/utils/download.py:
--------------------------------------------------------------------------------
1 | """Download files with progress bar."""
2 | import os
3 | import hashlib
4 | import requests
5 | from tqdm import tqdm
6 |
7 | def check_sha1(filename, sha1_hash):
8 | """Check whether the sha1 hash of the file content matches the expected hash.
9 | Parameters
10 | ----------
11 | filename : str
12 | Path to the file.
13 | sha1_hash : str
14 | Expected sha1 hash in hexadecimal digits.
15 | Returns
16 | -------
17 | bool
18 | Whether the file content matches the expected hash.
19 | """
20 | sha1 = hashlib.sha1()
21 | with open(filename, 'rb') as f:
22 | while True:
23 | data = f.read(1048576)
24 | if not data:
25 | break
26 | sha1.update(data)
27 |
28 | sha1_file = sha1.hexdigest()
29 | l = min(len(sha1_file), len(sha1_hash))
30 |     return sha1_file[0:l] == sha1_hash[0:l]
31 |
32 | def download(url, path=None, overwrite=False, sha1_hash=None):
33 | """Download an given URL
34 | Parameters
35 | ----------
36 | url : str
37 | URL to download
38 | path : str, optional
39 | Destination path to store downloaded file. By default stores to the
40 | current directory with same name as in url.
41 | overwrite : bool, optional
42 | Whether to overwrite destination file if already exists.
43 | sha1_hash : str, optional
44 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
45 | but doesn't match.
46 | Returns
47 | -------
48 | str
49 | The file path of the downloaded file.
50 | """
51 | if path is None:
52 | fname = url.split('/')[-1]
53 | else:
54 | path = os.path.expanduser(path)
55 | if os.path.isdir(path):
56 | fname = os.path.join(path, url.split('/')[-1])
57 | else:
58 | fname = path
59 |
60 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
61 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
62 | if not os.path.exists(dirname):
63 | os.makedirs(dirname)
64 |
65 | print('Downloading %s from %s...'%(fname, url))
66 | r = requests.get(url, stream=True)
67 | if r.status_code != 200:
68 | raise RuntimeError("Failed downloading url %s"%url)
69 | total_length = r.headers.get('content-length')
70 | with open(fname, 'wb') as f:
71 | if total_length is None: # no content length header
72 | for chunk in r.iter_content(chunk_size=1024):
73 | if chunk: # filter out keep-alive new chunks
74 | f.write(chunk)
75 | else:
76 | total_length = int(total_length)
77 | for chunk in tqdm(r.iter_content(chunk_size=1024),
78 | total=int(total_length / 1024. + 0.5),
79 | unit='KB', unit_scale=False, dynamic_ncols=True):
80 | f.write(chunk)
81 |
82 | if sha1_hash and not check_sha1(fname, sha1_hash):
83 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \
84 | 'The repo may be outdated or download may be incomplete. ' \
85 | 'If the "repo_url" is overridden, consider switching to ' \
86 | 'the default repo.'.format(fname))
87 |
88 | return fname
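Note that `check_sha1` compares only as many hex digits as the shorter argument provides, so a truncated hash prefix also verifies. A standalone demo:

```python
import hashlib
import tempfile
from core.utils.download import check_sha1  # assumes repo root on PYTHONPATH

with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b'hello')
    path = f.name
full = hashlib.sha1(b'hello').hexdigest()
print(check_sha1(path, full))      # True
print(check_sha1(path, full[:8]))  # True: only the common prefix is compared
```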
--------------------------------------------------------------------------------
/core/utils/filesystem.py:
--------------------------------------------------------------------------------
1 | """Filesystem utility functions."""
2 | from __future__ import absolute_import
3 | import os
4 | import errno
5 |
6 |
7 | def makedirs(path):
8 | """Create directory recursively if not exists.
9 |     Similar to `mkdir -p`, you can skip checking existence before this function.
10 | Parameters
11 | ----------
12 | path : str
13 | Path of the desired dir
14 | """
15 | try:
16 | os.makedirs(path)
17 | except OSError as exc:
18 | if exc.errno != errno.EEXIST:
19 | raise
20 |
21 |
22 | def try_import(package, message=None):
23 | """Try import specified package, with custom message support.
24 | Parameters
25 | ----------
26 | package : str
27 | The name of the targeting package.
28 | message : str, default is None
29 | If not None, this function will raise customized error message when import error is found.
30 | Returns
31 | -------
32 | module if found, raise ImportError otherwise
33 | """
34 | try:
35 | return __import__(package)
36 | except ImportError as e:
37 | if not message:
38 | raise e
39 | raise ImportError(message)
40 |
41 |
42 | def try_import_cv2():
43 | """Try import cv2 at runtime.
44 | Returns
45 | -------
46 | cv2 module if found. Raise ImportError otherwise
47 | """
48 |     msg = "cv2 is required; you can install it via a package manager, e.g. 'apt-get', \
49 | or `pip install opencv-python --user` (note that this is an unofficial PyPI package)."
50 | return try_import('cv2', msg)
51 |
52 |
53 | def import_try_install(package, extern_url=None):
54 | """Try import the specified package.
55 |     If the package is not installed, try to install it with pip and import it again.
56 | Parameters
57 | ----------
58 | package : str
59 | The name of the package trying to import.
60 | extern_url : str or None, optional
61 | The external url if package is not hosted on PyPI.
62 | For example, you can install a package using:
63 | "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx".
64 | In this case, you can pass the url to the extern_url.
65 | Returns
66 | -------
67 |
68 | The imported python module.
69 | """
70 | try:
71 | return __import__(package)
72 | except ImportError:
73 | try:
74 | from pip import main as pipmain
75 | except ImportError:
76 | from pip._internal import main as pipmain
77 |
78 | # trying to install package
79 | url = package if extern_url is None else extern_url
80 | pipmain(['install', '--user', url]) # will raise SystemExit Error if fails
81 |
82 | # trying to load again
83 | try:
84 | return __import__(package)
85 | except ImportError:
86 | import sys
87 | import site
88 | user_site = site.getusersitepackages()
89 | if user_site not in sys.path:
90 | sys.path.append(user_site)
91 | return __import__(package)
92 | return __import__(package)
93 |
94 |
95 | """Import helper for pycocotools"""
96 |
97 |
98 | # NOTE: for developers
99 | # please do not import any pycocotools in __init__ because we are trying to lazy
100 | # import pycocotools to avoid install it for other users who may not use it.
101 | # only import when you actually use it
102 |
103 |
104 | def try_import_pycocotools():
105 | """Tricks to optionally install and import pycocotools"""
106 | # first we can try import pycocotools
107 | try:
108 | import pycocotools as _
109 | except ImportError:
110 | import os
111 | # we need to install pycootools, which is a bit tricky
112 | # pycocotools sdist requires Cython, numpy(already met)
113 | import_try_install('cython')
114 | # pypi pycocotools is not compatible with windows
115 | win_url = 'git+https://github.com/zhreshold/cocoapi.git#subdirectory=PythonAPI'
116 | try:
117 | if os.name == 'nt':
118 | import_try_install('pycocotools', win_url)
119 | else:
120 | import_try_install('pycocotools')
121 | except ImportError:
122 | faq = 'cocoapi FAQ'
123 | raise ImportError('Cannot import or install pycocotools, please refer to %s.' % faq)
124 |
--------------------------------------------------------------------------------
/core/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | import sys
5 |
6 | __all__ = ['setup_logger']
7 |
8 |
9 | # reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py
10 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'):
11 | logger = logging.getLogger(name)
12 | logger.setLevel(logging.DEBUG)
13 | # don't log results for the non-master process
14 | if distributed_rank > 0:
15 | return logger
16 | ch = logging.StreamHandler(stream=sys.stdout)
17 | ch.setLevel(logging.DEBUG)
18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
19 | ch.setFormatter(formatter)
20 | logger.addHandler(ch)
21 |
22 | if save_dir:
23 | if not os.path.exists(save_dir):
24 | os.makedirs(save_dir)
25 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
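Typical use (a sketch): rank 0 writes to stdout and `<save_dir>/log.txt`, while any rank > 0 gets a logger with no handlers, which keeps multi-process training logs clean:

```python
from core.utils.logger import setup_logger  # assumes repo root on PYTHONPATH

logger = setup_logger('semantic_seg', save_dir='./runs', distributed_rank=0, mode='w')
logger.info('training started')
```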
--------------------------------------------------------------------------------
/core/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | """Popular Learning Rate Schedulers"""
2 | from __future__ import division
3 | import math
4 | import torch
5 |
6 | from bisect import bisect_right
7 |
8 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR']
9 |
10 |
11 | class LRScheduler(object):
12 | r"""Learning Rate Scheduler
13 |
14 | Parameters
15 | ----------
16 | mode : str
17 | Modes for learning rate scheduler.
18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
19 | base_lr : float
20 | Base learning rate, i.e. the starting learning rate.
21 | target_lr : float
22 | Target learning rate, i.e. the ending learning rate.
23 | With constant mode target_lr is ignored.
24 | niters : int
25 | Number of iterations to be scheduled.
26 | nepochs : int
27 | Number of epochs to be scheduled.
28 | iters_per_epoch : int
29 | Number of iterations in each epoch.
30 | offset : int
31 | Number of iterations before this scheduler.
32 | power : float
33 | Power parameter of poly scheduler.
34 | step_iter : list
35 | A list of iterations to decay the learning rate.
36 | step_epoch : list
37 | A list of epochs to decay the learning rate.
38 | step_factor : float
39 | Learning rate decay factor.
40 | """
41 |
42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0,
43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0):
44 | super(LRScheduler, self).__init__()
45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine'])
46 |
47 | if mode == 'step':
48 | assert (step_iter is not None or step_epoch is not None)
49 | self.niters = niters
50 | self.step = step_iter
51 | epoch_iters = nepochs * iters_per_epoch
52 | if epoch_iters > 0:
53 | self.niters = epoch_iters
54 | if step_epoch is not None:
55 | self.step = [s * iters_per_epoch for s in step_epoch]
56 |
57 | self.step_factor = step_factor
58 | self.base_lr = base_lr
59 | self.target_lr = base_lr if mode == 'constant' else target_lr
60 | self.offset = offset
61 | self.power = power
62 | self.warmup_iters = warmup_epochs * iters_per_epoch
63 | self.mode = mode
64 |
65 | def __call__(self, optimizer, num_update):
66 | self.update(num_update)
67 | assert self.learning_rate >= 0
68 | self._adjust_learning_rate(optimizer, self.learning_rate)
69 |
70 | def update(self, num_update):
71 | N = self.niters - 1
72 | T = num_update - self.offset
73 | T = min(max(0, T), N)
74 |
75 | if self.mode == 'constant':
76 | factor = 0
77 | elif self.mode == 'linear':
78 | factor = 1 - T / N
79 | elif self.mode == 'poly':
80 | factor = pow(1 - T / N, self.power)
81 | elif self.mode == 'cosine':
82 | factor = (1 + math.cos(math.pi * T / N)) / 2
83 | elif self.mode == 'step':
84 | if self.step is not None:
85 | count = sum([1 for s in self.step if s <= T])
86 | factor = pow(self.step_factor, count)
87 | else:
88 | factor = 1
89 | else:
90 | raise NotImplementedError
91 |
92 | # warm up lr schedule
93 | if self.warmup_iters > 0 and T < self.warmup_iters:
94 | factor = factor * 1.0 * T / self.warmup_iters
95 |
96 | if self.mode == 'step':
97 | self.learning_rate = self.base_lr * factor
98 | else:
99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor
100 |
101 | def _adjust_learning_rate(self, optimizer, lr):
102 | optimizer.param_groups[0]['lr'] = lr
103 | # enlarge the lr at the head
104 | for i in range(1, len(optimizer.param_groups)):
105 | optimizer.param_groups[i]['lr'] = lr * 10
106 |
107 |
108 | # ideally warmup would be a separate scheduler composed with MultiStepLR,
109 | # but the current _LRScheduler design doesn't allow it
110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py
111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
112 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
113 | warmup_iters=500, warmup_method="linear", last_epoch=-1):
114 |         if list(milestones) != sorted(milestones):
115 |             raise ValueError(
116 |                 "Milestones should be a list of increasing integers. Got {}".format(milestones))
117 |         if warmup_method not in ("constant", "linear"):
118 |             raise ValueError(
119 |                 "Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method))
120 | 
121 |         self.milestones = milestones
122 |         self.gamma = gamma
123 |         self.warmup_factor = warmup_factor
124 |         self.warmup_iters = warmup_iters
125 |         self.warmup_method = warmup_method
126 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)  # super() last: its first step() needs the attributes above
127 |
128 | def get_lr(self):
129 | warmup_factor = 1
130 | if self.last_epoch < self.warmup_iters:
131 | if self.warmup_method == 'constant':
132 | warmup_factor = self.warmup_factor
133 |             elif self.warmup_method == 'linear':
134 | alpha = float(self.last_epoch) / self.warmup_iters
135 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
136 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
137 | for base_lr in self.base_lrs]
138 |
139 |
140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
142 | warmup_iters=500, warmup_method='linear', last_epoch=-1):
143 | if warmup_method not in ("constant", "linear"):
144 | raise ValueError(
145 | "Only 'constant' or 'linear' warmup_method accepted "
146 | "got {}".format(warmup_method))
147 |
148 | self.target_lr = target_lr
149 | self.max_iters = max_iters
150 | self.power = power
151 | self.warmup_factor = warmup_factor
152 | self.warmup_iters = warmup_iters
153 | self.warmup_method = warmup_method
154 |
155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch)
156 |
157 | def get_lr(self):
158 | N = self.max_iters - self.warmup_iters
159 | T = self.last_epoch - self.warmup_iters
160 | if self.last_epoch < self.warmup_iters:
161 | if self.warmup_method == 'constant':
162 | warmup_factor = self.warmup_factor
163 | elif self.warmup_method == 'linear':
164 | alpha = float(self.last_epoch) / self.warmup_iters
165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
166 | else:
167 | raise ValueError("Unknown warmup type.")
168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs]
169 | factor = pow(1 - T / N, self.power)
170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs]
171 |
172 |
173 | if __name__ == '__main__':
174 | import torch
175 | import torch.nn as nn
176 |
177 | model = nn.Conv2d(16, 16, 3, 1, 1)
178 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
179 |     lr_scheduler = WarmupPolyLR(optimizer, max_iters=1000)
180 |
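The warmup-then-poly schedule in action: the learning rate ramps for the first `warmup_iters` steps, then decays as `(1 - T/N)**power` toward `target_lr`. A runnable loop (SGD and the step counts here are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn
from core.utils.lr_scheduler import WarmupPolyLR  # assumes repo root on PYTHONPATH

model = nn.Conv2d(16, 16, 3, 1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = WarmupPolyLR(optimizer, max_iters=1000, power=0.9, warmup_iters=100)
for step in range(1000):
    optimizer.step()
    scheduler.step()
    if step in (0, 99, 500, 999):
        print(step, optimizer.param_groups[0]['lr'])  # ramps up, then decays to target_lr
```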
--------------------------------------------------------------------------------
/core/utils/parallel.py:
--------------------------------------------------------------------------------
1 | """Utils for Semantic Segmentation"""
2 | import threading
3 | import torch
4 | import torch.cuda.comm as comm
5 | from torch.nn.parallel.data_parallel import DataParallel
6 | from torch.nn.parallel._functions import Broadcast
7 | from torch.autograd import Function
8 |
9 | __all__ = ['DataParallelModel', 'DataParallelCriterion']
10 |
11 |
12 | class Reduce(Function):
13 | @staticmethod
14 | def forward(ctx, *inputs):
15 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))]
16 | inputs = sorted(inputs, key=lambda i: i.get_device())
17 | return comm.reduce_add(inputs)
18 |
19 | @staticmethod
20 | def backward(ctx, gradOutputs):
21 | return Broadcast.apply(ctx.target_gpus, gradOutputs)
22 |
23 |
24 | class DataParallelModel(DataParallel):
25 | """Data parallelism
26 |
27 | Hide the difference of single/multiple GPUs to the user.
28 | In the forward pass, the module is replicated on each device,
29 | and each replica handles a portion of the input. During the backwards
30 | pass, gradients from each replica are summed into the original module.
31 |
32 | The batch size should be larger than the number of GPUs used.
33 |
34 | Parameters
35 | ----------
36 | module : object
37 | Network to be parallelized.
38 | sync : bool
39 | enable synchronization (default: False).
40 | Inputs:
41 | - **inputs**: list of input
42 | Outputs:
43 | - **outputs**: list of output
44 | Example::
45 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
46 | >>> output = net(input_var) # input_var can be on any device, including CPU
47 | """
48 |
49 | def gather(self, outputs, output_device):
50 | return outputs
51 |
52 | def replicate(self, module, device_ids):
53 | modules = super(DataParallelModel, self).replicate(module, device_ids)
54 | return modules
55 |
56 |
57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py
58 | class DataParallelCriterion(DataParallel):
59 | """
60 |     Calculate loss across multiple GPUs, which balances the memory usage for
61 |     Semantic Segmentation.
62 | 
63 |     The targets are split across the specified devices by chunking in
64 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
65 |
66 | Example::
67 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2])
68 | >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2])
69 | >>> y = net(x)
70 | >>> loss = criterion(y, target)
71 | """
72 |
73 | def forward(self, inputs, *targets, **kwargs):
74 | # the inputs should be the outputs of DataParallelModel
75 | if not self.device_ids:
76 | return self.module(inputs, *targets, **kwargs)
77 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
78 | if len(self.device_ids) == 1:
79 | return self.module(inputs, *targets[0], **kwargs[0])
80 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
81 | outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs)
82 | return Reduce.apply(*outputs) / len(outputs)
83 |
84 |
85 | def get_a_var(obj):
86 | if isinstance(obj, torch.Tensor):
87 | return obj
88 |
89 | if isinstance(obj, list) or isinstance(obj, tuple):
90 | for result in map(get_a_var, obj):
91 | if isinstance(result, torch.Tensor):
92 | return result
93 |
94 | if isinstance(obj, dict):
95 | for result in map(get_a_var, obj.items()):
96 | if isinstance(result, torch.Tensor):
97 | return result
98 | return None
99 |
100 |
101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
102 |     r"""Applies each `module` in :attr:`modules` in parallel on arguments
103 |     contained in :attr:`inputs` (positional), :attr:`targets` (positional) and :attr:`kwargs_tup` (keyword)
104 |     on each of :attr:`devices`.
105 |
106 | Args:
107 | modules (Module): modules to be parallelized
108 | inputs (tensor): inputs to the modules
109 | targets (tensor): targets to the modules
110 | devices (list of int or torch.device): CUDA devices
111 |     :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if given), and
112 |     :attr:`devices` (if given) should all have the same length. Moreover, each
113 | element of :attr:`inputs` can either be a single object as the only argument
114 | to a module, or a collection of positional arguments.
115 | """
116 | assert len(modules) == len(inputs)
117 | assert len(targets) == len(inputs)
118 | if kwargs_tup is not None:
119 | assert len(modules) == len(kwargs_tup)
120 | else:
121 | kwargs_tup = ({},) * len(modules)
122 | if devices is not None:
123 | assert len(modules) == len(devices)
124 | else:
125 | devices = [None] * len(modules)
126 | lock = threading.Lock()
127 | results = {}
128 | grad_enabled = torch.is_grad_enabled()
129 |
130 | def _worker(i, module, input, target, kwargs, device=None):
131 | torch.set_grad_enabled(grad_enabled)
132 | if device is None:
133 | device = get_a_var(input).get_device()
134 | try:
135 | with torch.cuda.device(device):
136 |                 output = module(*(list(input) + list(target)), **kwargs)  # target may arrive as a tuple after scatter
137 | with lock:
138 | results[i] = output
139 | except Exception as e:
140 | with lock:
141 | results[i] = e
142 |
143 | if len(modules) > 1:
144 | threads = [threading.Thread(target=_worker,
145 | args=(i, module, input, target, kwargs, device))
146 | for i, (module, input, target, kwargs, device) in
147 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
148 |
149 | for thread in threads:
150 | thread.start()
151 | for thread in threads:
152 | thread.join()
153 | else:
154 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
155 |
156 | outputs = []
157 | for i in range(len(inputs)):
158 | output = results[i]
159 | if isinstance(output, Exception):
160 | raise output
161 | outputs.append(output)
162 | return outputs
163 |
--------------------------------------------------------------------------------
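Putting the two classes together follows the docstring examples above: `DataParallelModel` leaves its per-GPU outputs on their devices, and `DataParallelCriterion` scatters the targets so each GPU scores only its own chunk before `Reduce` combines the losses. A minimal sketch, assuming at least two visible GPUs; the `SegLoss` wrapper is a toy criterion (the repo's models return tuples of predictions, which is why it unpacks `*inputs`):

```python
import torch
import torch.nn as nn
from core.utils.parallel import DataParallelModel, DataParallelCriterion
from core.models.model_zoo import get_segmentation_model


class SegLoss(nn.Module):
    """Toy criterion: receives the model's prediction tuple plus the target chunk."""

    def __init__(self):
        super(SegLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, *inputs):
        *preds, target = inputs        # predictions first, target last
        return self.ce(preds[0], target)


model = get_segmentation_model(model='fcn32s', dataset='pascal_voc',
                               backbone='vgg16', aux=False,
                               norm_layer=nn.BatchNorm2d)
net = DataParallelModel(model, device_ids=[0, 1]).cuda()
criterion = DataParallelCriterion(SegLoss(), device_ids=[0, 1])

x = torch.randn(4, 3, 320, 320).cuda()
target = torch.randint(0, 21, (4, 320, 320)).cuda()

y = net(x)                   # list of per-GPU outputs, deliberately not gathered
loss = criterion(y, target)  # per-GPU losses, combined by Reduce
loss.backward()
```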
/core/utils/score.py:
--------------------------------------------------------------------------------
1 | """Evaluation Metrics for Semantic Segmentation"""
2 | import torch
3 | import numpy as np
4 |
5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union',
6 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score']
7 |
8 |
9 | class SegmentationMetric(object):
10 | """Computes pixAcc and mIoU metric scores
11 | """
12 |
13 | def __init__(self, nclass):
14 | super(SegmentationMetric, self).__init__()
15 | self.nclass = nclass
16 | self.reset()
17 |
18 | def update(self, preds, labels):
19 | """Updates the internal evaluation result.
20 |
21 | Parameters
22 | ----------
 23 |         preds : torch.Tensor or list of torch.Tensor
 24 |             Predicted values.
 25 |         labels : torch.Tensor or list of torch.Tensor
 26 |             The labels of the data.
27 | """
28 |
29 | def evaluate_worker(self, pred, label):
30 | correct, labeled = batch_pix_accuracy(pred, label)
31 | inter, union = batch_intersection_union(pred, label, self.nclass)
32 |
33 | self.total_correct += correct
34 | self.total_label += labeled
35 | if self.total_inter.device != inter.device:
36 | self.total_inter = self.total_inter.to(inter.device)
37 | self.total_union = self.total_union.to(union.device)
38 | self.total_inter += inter
39 | self.total_union += union
40 |
41 | if isinstance(preds, torch.Tensor):
42 | evaluate_worker(self, preds, labels)
43 | elif isinstance(preds, (list, tuple)):
44 | for (pred, label) in zip(preds, labels):
45 | evaluate_worker(self, pred, label)
46 |
47 | def get(self):
48 | """Gets the current evaluation result.
49 |
50 | Returns
51 | -------
52 | metrics : tuple of float
53 | pixAcc and mIoU
54 | """
 55 |         pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # 2.22e-16 is np.spacing(1); avoids division by zero
 56 |         IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
57 | mIoU = IoU.mean().item()
58 | return pixAcc, mIoU
59 |
60 | def reset(self):
61 | """Resets the internal evaluation result to initial state."""
62 | self.total_inter = torch.zeros(self.nclass)
63 | self.total_union = torch.zeros(self.nclass)
64 | self.total_correct = 0
65 | self.total_label = 0
66 |
67 |
68 | # pytorch version
69 | def batch_pix_accuracy(output, target):
70 | """PixAcc"""
 71 |     # inputs are torch tensors: output is 4D (N, C, H, W), target is 3D (N, H, W)
72 | predict = torch.argmax(output.long(), 1) + 1
73 | target = target.long() + 1
74 |
75 | pixel_labeled = torch.sum(target > 0).item()
76 | pixel_correct = torch.sum((predict == target) * (target > 0)).item()
77 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
78 | return pixel_correct, pixel_labeled
79 |
80 |
81 | def batch_intersection_union(output, target, nclass):
82 | """mIoU"""
 83 |     # inputs are torch tensors: output is 4D (N, C, H, W), target is 3D (N, H, W)
84 | mini = 1
85 | maxi = nclass
86 | nbins = nclass
87 | predict = torch.argmax(output, 1) + 1
88 | target = target.float() + 1
89 |
90 | predict = predict.float() * (target > 0).float()
91 | intersection = predict * (predict == target).float()
92 | # areas of intersection and union
 93 |     # zeros in `intersection` mark unlabeled or mispredicted pixels; starting the histogram at min=1 excludes them (the key difference from np.bincount)
94 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
95 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
96 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
97 | area_union = area_pred + area_lab - area_inter
98 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area"
99 | return area_inter.float(), area_union.float()
100 |
101 |
102 | def pixelAccuracy(imPred, imLab):
103 | """
104 | This function takes the prediction and label of a single image, returns pixel-wise accuracy
105 | To compute over many images do:
106 |     for i in range(Nimages):
107 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \
108 | pixelAccuracy(imPred[i], imLab[i])
109 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled))
110 | """
111 | # Remove classes from unlabeled pixels in gt image.
112 | # We should not penalize detections in unlabeled portions of the image.
113 | pixel_labeled = np.sum(imLab >= 0)
114 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
115 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
116 | return (pixel_accuracy, pixel_correct, pixel_labeled)
117 |
118 |
119 | def intersectionAndUnion(imPred, imLab, numClass):
120 | """
121 | This function takes the prediction and label of a single image,
122 | returns intersection and union areas for each class
123 | To compute over many images do:
124 | for i in range(Nimages):
125 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i])
126 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1)
127 | """
128 | # Remove classes from unlabeled pixels in gt image.
129 | # We should not penalize detections in unlabeled portions of the image.
130 | imPred = imPred * (imLab >= 0)
131 |
132 | # Compute area intersection:
133 | intersection = imPred * (imPred == imLab)
134 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass))
135 |
136 | # Compute area union:
137 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
138 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
139 | area_union = area_pred + area_lab - area_intersection
140 | return (area_intersection, area_union)
141 |
142 |
143 | def hist_info(pred, label, num_cls):
144 | assert pred.shape == label.shape
145 | k = (label >= 0) & (label < num_cls)
146 | labeled = np.sum(k)
147 | correct = np.sum((pred[k] == label[k]))
148 |
149 | return np.bincount(num_cls * label[k].astype(int) + pred[k], minlength=num_cls ** 2).reshape(num_cls,
150 | num_cls), labeled, correct
151 |
152 |
153 | def compute_score(hist, correct, labeled):
154 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
155 | mean_IU = np.nanmean(iu)
156 | mean_IU_no_back = np.nanmean(iu[1:])
157 | freq = hist.sum(1) / hist.sum()
158 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
159 | mean_pixel_acc = correct / labeled
160 |
161 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
162 |
--------------------------------------------------------------------------------
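A quick usage sketch of the metrics above: `SegmentationMetric.update` accepts a single `(output, target)` pair or lists of them, while `hist_info`/`compute_score` provide the numpy confusion-matrix route used by `tests/test_model.py`. Shapes below are illustrative.

```python
import torch
import numpy as np
from core.utils.score import SegmentationMetric, hist_info, compute_score

# Torch route: accumulate pixAcc/mIoU over batches, e.g. during validation.
metric = SegmentationMetric(nclass=21)
output = torch.randn(2, 21, 60, 60)          # raw per-class scores (N, C, H, W)
target = torch.randint(0, 21, (2, 60, 60))   # class indices (N, H, W)
metric.update(output, target)
pixAcc, mIoU = metric.get()
print('pixAcc: {:.3f}, mIoU: {:.3f}'.format(pixAcc, mIoU))
metric.reset()                               # start fresh for the next epoch

# Numpy route: hist_info builds a 21x21 confusion matrix (rows = labels,
# columns = predictions) with the bincount trick; compute_score reduces it.
pred = np.random.randint(0, 21, (2, 60, 60))
label = np.random.randint(0, 21, (2, 60, 60))
hist, labeled, correct = hist_info(pred, label, num_cls=21)
iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled)
```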
/core/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color',
6 | 'show_prediction', 'show_colorful_images', 'save_colorful_images']
7 |
8 |
9 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False):
10 | n = iu.size
11 | lines = []
12 | for i in range(n):
13 | if class_names is None:
14 | cls = 'Class %d:' % (i + 1)
15 | else:
16 | cls = '%d %s' % (i + 1, class_names[i])
17 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100))
 18 |     mean_IU = np.nanmean(iu)
 19 |     mean_IU_no_back = np.nanmean(iu[1:])
 20 |     if show_no_back:
 21 |         lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % (
 22 |             mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100))
 23 |     else:
 24 |         lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100))
25 | lines.append('=================================================')
26 | line = "\n".join(lines)
27 |
28 | print(line)
29 |
30 |
31 | def set_img_color(img, label, colors, background=0, show255=False):
32 | for i in range(len(colors)):
33 | if i != background:
34 | img[np.where(label == i)] = colors[i]
35 | if show255:
36 | img[np.where(label == 255)] = 255
37 |
38 | return img
39 |
40 |
41 | def show_prediction(img, pred, colors, background=0):
42 | im = np.array(img, np.uint8)
43 | set_img_color(im, pred, colors, background)
44 | out = np.array(im)
45 |
46 | return out
47 |
48 |
49 | def show_colorful_images(prediction, palettes):
50 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
51 | im.show()
52 |
53 |
54 | def save_colorful_images(prediction, filename, output_dir, palettes):
55 | '''
56 | :param prediction: [B, H, W, C]
57 | '''
58 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
59 | fn = os.path.join(output_dir, filename)
60 | out_dir = os.path.split(fn)[0]
61 | if not os.path.exists(out_dir):
 62 |         os.makedirs(out_dir)
63 | im.save(fn)
64 |
65 |
66 | def get_color_pallete(npimg, dataset='pascal_voc'):
67 | """Visualize image.
68 |
69 | Parameters
70 | ----------
71 | npimg : numpy.ndarray
 72 |         Single channel label image with shape `H, W`.
 73 |     dataset : str, default: 'pascal_voc'
 74 |         The dataset the model was pretrained on. ('pascal_voc', 'ade20k', 'citys')
 75 |     Returns
 76 |     -------
 77 |     out_img : PIL.Image
 78 |         Image with color palette
79 | """
80 | # recovery boundary
81 | if dataset in ('pascal_voc', 'pascal_aug'):
82 | npimg[npimg == -1] = 255
83 | # put colormap
84 | if dataset == 'ade20k':
85 | npimg = npimg + 1
86 | out_img = Image.fromarray(npimg.astype('uint8'))
87 | out_img.putpalette(adepallete)
88 | return out_img
89 | elif dataset == 'citys':
90 | out_img = Image.fromarray(npimg.astype('uint8'))
91 | out_img.putpalette(cityspallete)
92 | return out_img
93 | out_img = Image.fromarray(npimg.astype('uint8'))
94 | out_img.putpalette(vocpallete)
95 | return out_img
96 |
97 |
98 | def _getvocpallete(num_cls):
99 | n = num_cls
100 | pallete = [0] * (n * 3)
101 | for j in range(0, n):
102 | lab = j
103 | pallete[j * 3 + 0] = 0
104 | pallete[j * 3 + 1] = 0
105 | pallete[j * 3 + 2] = 0
106 | i = 0
107 | while (lab > 0):
108 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
109 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
110 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
111 | i = i + 1
112 | lab >>= 3
113 | return pallete
114 |
115 |
116 | vocpallete = _getvocpallete(256)
117 |
118 | adepallete = [
119 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204,
120 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82,
121 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255,
122 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6,
123 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255,
124 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15,
125 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255,
126 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163,
127 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255,
128 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0,
129 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0,
130 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255,
131 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255,
132 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0,
133 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0,
134 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0,
135 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0,
136 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255]
137 |
138 | cityspallete = [
139 | 128, 64, 128,
140 | 244, 35, 232,
141 | 70, 70, 70,
142 | 102, 102, 156,
143 | 190, 153, 153,
144 | 153, 153, 153,
145 | 250, 170, 30,
146 | 220, 220, 0,
147 | 107, 142, 35,
148 | 152, 251, 152,
149 | 0, 130, 180,
150 | 220, 20, 60,
151 | 255, 0, 0,
152 | 0, 0, 142,
153 | 0, 0, 70,
154 | 0, 60, 100,
155 | 0, 80, 100,
156 | 0, 0, 230,
157 | 119, 11, 32,
158 | ]
159 |
--------------------------------------------------------------------------------
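The bit trick in `_getvocpallete` above is the standard PASCAL VOC colormap generator: each 3-bit group of the class index contributes one bit to R, G, and B, filled in from the most significant bit downwards. A small worked example of the same loop for a single label:

```python
# Standalone version of the per-label loop in _getvocpallete.
def voc_color(label):
    r = g = b = 0
    i = 0
    while label > 0:
        r |= ((label >> 0) & 1) << (7 - i)   # bit 0 of each group -> red
        g |= ((label >> 1) & 1) << (7 - i)   # bit 1 -> green
        b |= ((label >> 2) & 1) << (7 - i)   # bit 2 -> blue
        i += 1
        label >>= 3
    return r, g, b

assert voc_color(1) == (128, 0, 0)       # class 1 (aeroplane): dark red
assert voc_color(2) == (0, 128, 0)       # class 2 (bicycle): dark green
assert voc_color(15) == (192, 128, 128)  # class 15 (person)
```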
/datasets/ade:
--------------------------------------------------------------------------------
1 | /home/tramac/PycharmProjects/Data_zoo/ade
--------------------------------------------------------------------------------
/datasets/citys:
--------------------------------------------------------------------------------
1 | /home/tramac/PycharmProjects/Data_zoo/citys
--------------------------------------------------------------------------------
/datasets/sbu:
--------------------------------------------------------------------------------
1 | /home/tramac/PycharmProjects/Data_zoo/SBU-shadow
--------------------------------------------------------------------------------
/datasets/voc:
--------------------------------------------------------------------------------
1 | /home/tramac/PycharmProjects/Data_zoo/VOCdevkit
--------------------------------------------------------------------------------
/docs/DETAILS.md:
--------------------------------------------------------------------------------
1 | ### Model & Backbone
2 |
3 | | Model | Scratch | VGG16 | ResNet18 | ResNet50 | ResNet101 | ResNet152 | DenseNet121 | DenseNet169 |
4 | | :-------: | :-----: | :---: | :------: | :------: | :-------: | :-------: | :---------: | :---------: |
5 | | FCN32s | ✘ | ✓ | | | | | | |
6 | | FCN16s | | ✓ | | | | | | |
7 | | FCN8s | | ✓ | | | | | | |
8 | | FCNv2 | | | | ✓ | ✓ | ✓ | | |
9 | | PSPNet | | | | ✓ | ✓ | ✓ | | |
10 | | DeepLabv3 | | | | ✓ | ✓ | ✓ | | |
11 | | DenseASPP | | | | | | | ✓ | ✓ |
12 | | DANet | | | | ✓ | ✓ | ✓ | | |
13 | | BiSeNet | | | ✓ | | | | | |
14 | | EncNet | | | | ✓ | ✓ | ✓ | | |
15 | | ICNet | | | | ✓ | ✓ | ✓ | | |
16 | | DUNet | | | | ✓ | ✓ | ✓ | | |
17 | | ENet | ✓ | | | | | | | |
18 | | OCNet | | | | ✓ | ✓ | ✓ | | |
19 | | CCNet | | | | ✓ | ✓ | ✓ | | |
20 | | PSANet | | | | ✓ | ✓ | ✓ | | |
21 | | CGNet | ✓ | | | | | | | |
22 | | ESPNet | ✓ | | | | | | | |
23 | | LEDNet | ✓ | | | | | | | |
24 | | DFANet | ✓ | | | | | | | |
25 |
--------------------------------------------------------------------------------
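Each ✓ above marks a `model`/`backbone` pair the model zoo can build. A hedged sketch of constructing one such pair (keyword names taken from the calls to `get_segmentation_model` in `scripts/eval.py` and `tests/test_model.py`):

```python
import torch.nn as nn
from core.models.model_zoo import get_segmentation_model

# PSPNet + ResNet50 is a checked cell in the table above.
model = get_segmentation_model(model='psp', dataset='pascal_voc',
                               backbone='resnet50', aux=False,
                               norm_layer=nn.BatchNorm2d)
```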
/docs/QQ.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/QQ.jpg
--------------------------------------------------------------------------------
/docs/WeChat.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/WeChat.jpeg
--------------------------------------------------------------------------------
/docs/requirements.yml:
--------------------------------------------------------------------------------
1 | name: seg_requirements
2 | dependencies:
 3 |   - python>=3.6
 4 |   - numpy
 5 |   - cudatoolkit
 6 |   - pip:
 7 |     - Pillow
 8 |     - tqdm
 9 |     - requests
10 |     - torch>=1.0
11 |     - torchvision
12 |
--------------------------------------------------------------------------------
/docs/weimar_000091_000019_gtFine_color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/weimar_000091_000019_gtFine_color.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | torch>=1.1.0
3 | torchvision>=0.3.0
4 | Pillow
5 | 
--------------------------------------------------------------------------------
/scripts/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import torch
5 |
6 | cur_path = os.path.abspath(os.path.dirname(__file__))
7 | root_path = os.path.split(cur_path)[0]
8 | sys.path.append(root_path)
9 |
10 | from torchvision import transforms
11 | from PIL import Image
12 | from core.utils.visualize import get_color_pallete
13 | from core.models import get_model
14 |
15 | parser = argparse.ArgumentParser(
16 | description='Predict segmentation result from a given image')
17 | parser.add_argument('--model', type=str, default='fcn32s_vgg16_voc',
18 |                     help='model name (default: fcn32s_vgg16_voc)')
19 | parser.add_argument('--dataset', type=str, default='pascal_aug', choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys'],
20 |                     help='dataset name (default: pascal_aug)')
21 | parser.add_argument('--save-folder', default='~/.torch/models',
22 | help='Directory for saving checkpoint models')
23 | parser.add_argument('--input-pic', type=str, default='../datasets/voc/VOC2012/JPEGImages/2007_000032.jpg',
24 | help='path to the input picture')
25 | parser.add_argument('--outdir', default='./eval', type=str,
26 | help='path to save the predict result')
27 | parser.add_argument('--local_rank', type=int, default=0)
28 | args = parser.parse_args()
29 |
30 |
31 | def demo(config):
32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33 | # output folder
34 | if not os.path.exists(config.outdir):
35 | os.makedirs(config.outdir)
36 |
37 | # image transform
38 | transform = transforms.Compose([
39 | transforms.ToTensor(),
40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
41 | ])
42 | image = Image.open(config.input_pic).convert('RGB')
43 | images = transform(image).unsqueeze(0).to(device)
44 |
45 | model = get_model(args.model, local_rank=args.local_rank, pretrained=True, root=args.save_folder).to(device)
46 | print('Finished loading model!')
47 |
48 | model.eval()
49 | with torch.no_grad():
50 | output = model(images)
51 |
52 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy()
53 | mask = get_color_pallete(pred, args.dataset)
54 | outname = os.path.splitext(os.path.split(args.input_pic)[-1])[0] + '.png'
55 | mask.save(os.path.join(args.outdir, outname))
56 |
57 |
58 | if __name__ == '__main__':
59 | demo(args)
60 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import sys
5 |
6 | cur_path = os.path.abspath(os.path.dirname(__file__))
7 | root_path = os.path.split(cur_path)[0]
8 | sys.path.append(root_path)
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.utils.data as data
13 | import torch.backends.cudnn as cudnn
14 |
15 | from torchvision import transforms
16 | from core.data.dataloader import get_segmentation_dataset
17 | from core.models.model_zoo import get_segmentation_model
18 | from core.utils.score import SegmentationMetric
19 | from core.utils.visualize import get_color_pallete
20 | from core.utils.logger import setup_logger
21 | from core.utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler
22 |
23 | from train import parse_args
24 |
25 |
26 | class Evaluator(object):
27 | def __init__(self, args):
28 | self.args = args
29 | self.device = torch.device(args.device)
30 |
31 | # image transform
32 | input_transform = transforms.Compose([
33 | transforms.ToTensor(),
34 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
35 | ])
36 |
37 | # dataset and dataloader
38 | val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='testval', transform=input_transform)
39 | val_sampler = make_data_sampler(val_dataset, False, args.distributed)
40 | val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
41 | self.val_loader = data.DataLoader(dataset=val_dataset,
42 | batch_sampler=val_batch_sampler,
43 | num_workers=args.workers,
44 | pin_memory=True)
45 |
46 | # create network
47 | BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
48 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
49 | aux=args.aux, pretrained=True, pretrained_base=False,
50 | local_rank=args.local_rank,
51 | norm_layer=BatchNorm2d).to(self.device)
52 | if args.distributed:
53 | self.model = nn.parallel.DistributedDataParallel(self.model,
54 | device_ids=[args.local_rank], output_device=args.local_rank)
55 | self.model.to(self.device)
56 |
57 | self.metric = SegmentationMetric(val_dataset.num_class)
58 |
59 | def eval(self):
60 | self.metric.reset()
61 | self.model.eval()
62 | if self.args.distributed:
63 | model = self.model.module
64 | else:
65 | model = self.model
 66 |         logger.info("Start validation, Total samples: {:d}".format(len(self.val_loader)))
67 | for i, (image, target, filename) in enumerate(self.val_loader):
68 | image = image.to(self.device)
69 | target = target.to(self.device)
70 |
71 | with torch.no_grad():
72 | outputs = model(image)
73 | self.metric.update(outputs[0], target)
74 | pixAcc, mIoU = self.metric.get()
75 | logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
76 | i + 1, pixAcc * 100, mIoU * 100))
77 |
78 | if self.args.save_pred:
79 | pred = torch.argmax(outputs[0], 1)
80 | pred = pred.cpu().data.numpy()
81 |
82 | predict = pred.squeeze(0)
83 | mask = get_color_pallete(predict, self.args.dataset)
84 | mask.save(os.path.join(outdir, os.path.splitext(filename[0])[0] + '.png'))
85 | synchronize()
86 |
87 |
88 | if __name__ == '__main__':
89 | args = parse_args()
90 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
91 | args.distributed = num_gpus > 1
92 | if not args.no_cuda and torch.cuda.is_available():
93 | cudnn.benchmark = True
94 | args.device = "cuda"
95 | else:
96 | args.distributed = False
97 | args.device = "cpu"
98 | if args.distributed:
99 | torch.cuda.set_device(args.local_rank)
100 | torch.distributed.init_process_group(backend="nccl", init_method="env://")
101 | synchronize()
102 |
103 | # TODO: optim code
104 | args.save_pred = True
105 | if args.save_pred:
106 | outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset)
107 | if not os.path.exists(outdir):
108 | os.makedirs(outdir)
109 |
110 | logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(),
111 | filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+')
112 |
113 | evaluator = Evaluator(args)
114 | evaluator.eval()
115 | torch.cuda.empty_cache()
116 |
--------------------------------------------------------------------------------
/scripts/fcn32s_vgg16_pascal_voc.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # train
4 | CUDA_VISIBLE_DEVICES=0 python train.py --model fcn32s \
5 | --backbone vgg16 --dataset pascal_voc \
6 | --lr 0.0001 --epochs 80
--------------------------------------------------------------------------------
/scripts/fcn32s_vgg16_pascal_voc_dist.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # train
4 | export NGPUS=4
5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$NGPUS \
6 |     train.py --model fcn32s --backbone vgg16 --dataset pascal_voc \
7 |     --lr 0.01 --epochs 80 --batch-size 16
--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | # Overfitting Test
2 |
3 | In order to ensure the correctness of the models, the project provides an overfitting-test script (a trick that makes the train set and the val set contain the same images).
4 | Observing the convergence process of the different models is quite interesting :joy:
5 |
6 | ### Usage
7 |
8 |
9 |
10 | (a) img: 2007_000033.jpg (b) mask: 2007_000033.png
11 |
12 | ### Test Result
13 | | Model | backbone | epoch | mIoU | pixAcc |
14 | | :-----: | :----: | :-----: | :-----: | :------: |
15 | | FCN32s | vgg16 | 200 | 94.0% | 98.2% |
16 | | FCN16s | vgg16 | 200 | 99.2% | 99.8% |
17 | | FCN8s | vgg16 | 100 | 99.8% | 99.9% |
18 | | DANet | resnet50 | 100 | 99.5% | 99.9% |
19 | | EncNet | resnet50 | 100 | 99.7% | 99.9% |
20 | | DUNet | resnet50 | 100 | 98.8% | 99.6% |
21 | | PSPNet | resnet50 | 100 | 99.8% | 99.9% |
22 | | BiSeNet | resnet18 | 100 | 99.6% | 99.9% |
23 | | DenseASPP | densenet121 | 40 | 100% | 100% |
24 | | ICNet | resnet50 | 100 | 98.8% | 99.6% |
25 | | ENet | scratch | 100 | 99.9% | 100% |
26 | | OCNet | resnet50 | 100 | 99.8% | 100% |
27 |
28 | ### Visualization
29 |
30 |
31 |
32 |
33 |
34 |
35 | FCN32s FCN16s FCN8s DANet EncNet DUNet PSPNet BiSeNet DenseASPP
36 |
37 |
38 |
39 |
40 | ICNet ENet OCNet
41 |
42 | ### Conclusion
43 | - The result of FCN32s is the worst.
44 | - There are gridding artifacts in DUNet results.
45 | - The result of BiSeNet is poor when `lr=1e-3`; the lr needs to be set to `1e-2`.
46 | - DenseASPP converges fastest, reaching 100%.
47 | - The lr of ENet also needs to be set to `1e-2`; the edges in its result are not smooth.
--------------------------------------------------------------------------------
/tests/runs/bisenet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/bisenet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/danet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/danet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/denseaspp_epoch_40.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/denseaspp_epoch_40.png
--------------------------------------------------------------------------------
/tests/runs/dunet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/dunet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/encnet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/encnet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/enet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/enet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/fcn16s_epoch_200.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn16s_epoch_200.png
--------------------------------------------------------------------------------
/tests/runs/fcn32s_epoch_300.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn32s_epoch_300.png
--------------------------------------------------------------------------------
/tests/runs/fcn8s_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn8s_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/icnet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/icnet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/ocnet_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/ocnet_epoch_100.png
--------------------------------------------------------------------------------
/tests/runs/psp_epoch_100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/psp_epoch_100.png
--------------------------------------------------------------------------------
/tests/test_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/test_img.jpg
--------------------------------------------------------------------------------
/tests/test_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/test_mask.png
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | """Model overfitting test"""
2 | import argparse
3 | import time
4 | import os
5 | import sys
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.backends.cudnn as cudnn
10 | import numpy as np
11 |
12 | cur_path = os.path.abspath(os.path.dirname(__file__))
13 | root_path = os.path.split(cur_path)[0]
14 | sys.path.append(root_path)
15 |
16 | from torchvision import transforms
17 | from core.models.model_zoo import get_segmentation_model
18 | from core.utils.loss import MixSoftmaxCrossEntropyLoss, EncNetLoss, ICNetLoss
19 | from core.utils.lr_scheduler import LRScheduler
20 | from core.utils.score import hist_info, compute_score
21 | from core.utils.visualize import get_color_pallete
22 | from PIL import Image
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(description='Semantic Segmentation Overfitting Test')
27 | # model
28 | parser.add_argument('--model', type=str, default='fcn32s',
29 | choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn', 'psp',
30 | 'deeplabv3', 'danet', 'denseaspp', 'bisenet', 'encnet',
31 | 'dunet', 'icnet', 'enet', 'ocnet'],
32 | help='model name (default: fcn32s)')
33 | parser.add_argument('--backbone', type=str, default='vgg16',
34 | choices=['vgg16', 'resnet18', 'resnet50', 'resnet101',
35 |                             'resnet152', 'densenet121', 'densenet161', 'densenet169', 'densenet201'],
36 | help='backbone name (default: vgg16)')
37 | parser.add_argument('--dataset', type=str, default='pascal_voc',
38 | choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys',
39 | 'sbu'],
40 | help='dataset name (default: pascal_voc)')
41 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
42 | help='number of epochs to train (default: 100)')
43 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
44 | help='learning rate (default: 1e-3)')
45 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
46 | help='momentum (default: 0.9)')
47 | parser.add_argument('--weight-decay', type=float, default=1e-4, metavar='M',
48 |                         help='w-decay (default: 1e-4)')
49 | args = parser.parse_args()
50 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
51 | cudnn.benchmark = True
52 | args.device = device
53 | print(args)
54 | return args
55 |
56 |
57 | class VOCSegmentation(object):
58 | def __init__(self):
59 | super(VOCSegmentation, self).__init__()
60 | self.img = Image.open('test_img.jpg').convert('RGB')
61 | self.mask = Image.open('test_mask.png')
62 |
63 | self.img = self.img.resize((504, 368), Image.BILINEAR)
64 | self.mask = self.mask.resize((504, 368), Image.NEAREST)
65 |
66 | def get(self):
67 | img, mask = self._img_transform(self.img), self._mask_transform(self.mask)
68 | return img, mask
69 |
70 | def _img_transform(self, img):
71 | input_transform = transforms.Compose([
72 | transforms.ToTensor(),
73 | transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
74 | img = input_transform(img)
75 | img = img.unsqueeze(0)
76 |
77 | # For adaptive pooling
78 | # img = torch.cat([img, img], dim=0)
79 | return img
80 |
81 | def _mask_transform(self, mask):
82 | target = np.array(mask).astype('int32')
83 | target[target == 255] = -1
84 | target = torch.from_numpy(target).long()
85 | target = target.unsqueeze(0)
86 |
87 | # For adaptive pooling
88 | # target = torch.cat([target, target], dim=0)
89 | return target
90 |
91 |
92 | class Trainer(object):
93 | def __init__(self, args):
94 | self.args = args
95 |
96 | self.img, self.target = VOCSegmentation().get()
97 |
98 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
99 | aux=False, norm_layer=nn.BatchNorm2d).to(args.device)
100 |
101 | self.criterion = MixSoftmaxCrossEntropyLoss(False, 0., ignore_label=-1).to(args.device)
102 |
103 | # for EncNet
104 | # self.criterion = EncNetLoss(nclass=21, ignore_label=-1).to(args.device)
105 | # for ICNet
106 | # self.criterion = ICNetLoss(nclass=21, ignore_index=-1).to(args.device)
107 |
108 | self.optimizer = torch.optim.Adam(self.model.parameters(),
109 | lr=args.lr,
110 | weight_decay=args.weight_decay)
111 | self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs,
112 | iters_per_epoch=1, power=0.9)
113 |
114 | def train(self):
115 | self.model.train()
116 | start_time = time.time()
117 | for epoch in range(self.args.epochs):
118 | self.lr_scheduler(self.optimizer, epoch)
119 | cur_lr = self.lr_scheduler.learning_rate
120 | # self.lr_scheduler(self.optimizer, epoch)
121 | for param_group in self.optimizer.param_groups:
122 | param_group['lr'] = cur_lr
123 |
124 | images = self.img.to(self.args.device)
125 | targets = self.target.to(self.args.device)
126 |
127 | outputs = self.model(images)
128 | loss = self.criterion(outputs, targets)
129 |
130 | self.optimizer.zero_grad()
131 | loss['loss'].backward()
132 | self.optimizer.step()
133 |
134 | pred = torch.argmax(outputs[0], 1).cpu().data.numpy()
135 | mask = get_color_pallete(pred.squeeze(0), self.args.dataset)
136 | save_pred(self.args, epoch, mask)
137 | hist, labeled, correct = hist_info(pred, targets.cpu().numpy(), 21)
138 | _, mIoU, _, pixAcc = compute_score(hist, correct, labeled)
139 |
140 | print('Epoch: [%2d/%2d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f || pixAcc: %.3f || mIoU: %.3f' % (
141 | epoch, self.args.epochs, time.time() - start_time, cur_lr, loss['loss'].item(), pixAcc, mIoU))
142 |
143 |
144 | def save_pred(args, epoch, mask):
145 | directory = "runs/%s/" % (args.model)
146 | if not os.path.exists(directory):
147 | os.makedirs(directory)
148 | filename = directory + '{}_epoch_{}.png'.format(args.model, epoch + 1)
149 | mask.save(filename)
150 |
151 |
152 | if __name__ == '__main__':
153 | args = parse_args()
154 | trainer = Trainer(args)
155 | print('Test model: ', args.model)
156 | trainer.train()
157 |
--------------------------------------------------------------------------------
/tests/test_module.py:
--------------------------------------------------------------------------------
1 | import core
2 | import torch
3 | import numpy as np
4 |
5 | from torch.autograd import Variable
6 |
7 | EPS = 1e-3
8 | ATOL = 1e-3
9 |
10 |
11 | def _assert_tensor_close(a, b, atol=ATOL, rtol=EPS):
12 | npa, npb = a.cpu().numpy(), b.cpu().numpy()
13 | assert np.allclose(npa, npb, rtol=rtol, atol=atol), \
14 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(
15 | a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
16 |
17 |
18 | def testSyncBN():
19 | def _check_batchnorm_result(bn1, bn2, input, is_train, cuda=False):
20 | def _find_bn(module):
21 | for m in module.modules():
22 | if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
23 | core.nn.SyncBatchNorm)):
24 | return m
25 |
26 | def _syncParameters(bn1, bn2):
27 | bn1.reset_parameters()
28 | bn2.reset_parameters()
29 | if bn1.affine and bn2.affine:
30 | bn2.weight.data.copy_(bn1.weight.data)
31 | bn2.bias.data.copy_(bn1.bias.data)
32 | bn2.running_mean.copy_(bn1.running_mean)
33 | bn2.running_var.copy_(bn1.running_var)
34 |
35 | bn1.train(mode=is_train)
36 | bn2.train(mode=is_train)
37 |
38 | if cuda:
39 | input = input.cuda()
40 | # using the same values for gamma and beta
41 | _syncParameters(_find_bn(bn1), _find_bn(bn2))
42 |
43 | input1 = Variable(input.clone().detach(), requires_grad=True)
44 | input2 = Variable(input.clone().detach(), requires_grad=True)
45 | if is_train:
46 | bn1.train()
47 | bn2.train()
48 | output1 = bn1(input1)
49 | output2 = bn2(input2)
50 | else:
51 | bn1.eval()
52 | bn2.eval()
53 | with torch.no_grad():
54 | output1 = bn1(input1)
55 | output2 = bn2(input2)
56 | # assert forwarding
57 | # _assert_tensor_close(input1.data, input2.data)
58 | _assert_tensor_close(output1.data, output2.data)
59 | if not is_train:
60 | return
61 | (output1 ** 2).sum().backward()
62 | (output2 ** 2).sum().backward()
63 | _assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data)
64 | _assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data)
65 | _assert_tensor_close(input1.grad.data, input2.grad.data)
66 | _assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
67 | # _assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
68 |
69 | bn = torch.nn.BatchNorm2d(10).cuda().double()
70 | sync_bn = core.nn.SyncBatchNorm(10, inplace=True, sync=True).cuda().double()
71 | sync_bn = torch.nn.DataParallel(sync_bn).cuda()
72 | # check with unsync version
73 | # _check_batchnorm_result(bn, sync_bn, torch.rand(2, 1, 2, 2).double(), True, cuda=True)
74 | for i in range(10):
75 | print(i)
76 | _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), True, cuda=True)
77 | # _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True)
78 |
79 |
80 | if __name__ == '__main__':
81 | import nose
82 |
83 | nose.runmodule()
84 |
--------------------------------------------------------------------------------