├── .gitignore ├── LICENSE ├── README.md ├── README_zh-CN.md ├── core ├── __init__.py ├── data │ ├── __init__.py │ ├── dataloader │ │ ├── __init__.py │ │ ├── ade.py │ │ ├── cityscapes.py │ │ ├── lip_parsing.py │ │ ├── mscoco.py │ │ ├── pascal_aug.py │ │ ├── pascal_voc.py │ │ ├── sbu_shadow.py │ │ ├── segbase.py │ │ └── utils.py │ └── downloader │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── cityscapes.py │ │ ├── mscoco.py │ │ ├── pascal_voc.py │ │ └── sbu_shadow.py ├── models │ ├── __init__.py │ ├── base_models │ │ ├── __init__.py │ │ ├── densenet.py │ │ ├── eespnet.py │ │ ├── hrnet.py │ │ ├── mobilenetv2.py │ │ ├── resnet.py │ │ ├── resnetv1b.py │ │ ├── resnext.py │ │ ├── vgg.py │ │ └── xception.py │ ├── bisenet.py │ ├── cgnet.py │ ├── danet.py │ ├── deeplabv3.py │ ├── deeplabv3_plus.py │ ├── denseaspp.py │ ├── dfanet.py │ ├── dunet.py │ ├── encnet.py │ ├── enet.py │ ├── espnet.py │ ├── fcn.py │ ├── fcnv2.py │ ├── hrnet.py │ ├── icnet.py │ ├── lednet.py │ ├── model_store.py │ ├── model_zoo.py │ ├── ocnet.py │ ├── psanet.py │ ├── pspnet.py │ ├── segbase.py │ └── swnet.py ├── nn │ ├── __init__.py │ ├── basic.py │ └── jpu.py └── utils │ ├── __init__.py │ ├── distributed.py │ ├── download.py │ ├── filesystem.py │ ├── logger.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── parallel.py │ ├── score.py │ └── visualize.py ├── datasets ├── ade ├── citys ├── sbu └── voc ├── docs ├── DETAILS.md ├── QQ.jpg ├── WeChat.jpeg ├── requirements.yml └── weimar_000091_000019_gtFine_color.png ├── requirements.txt ├── scripts ├── demo.py ├── eval.py ├── fcn32s_vgg16_pascal_voc.sh ├── fcn32s_vgg16_pascal_voc_dist.sh └── train.py └── tests ├── README.md ├── runs ├── bisenet_epoch_100.png ├── danet_epoch_100.png ├── denseaspp_epoch_40.png ├── dunet_epoch_100.png ├── encnet_epoch_100.png ├── enet_epoch_100.png ├── fcn16s_epoch_200.png ├── fcn32s_epoch_300.png ├── fcn8s_epoch_100.png ├── icnet_epoch_100.png ├── ocnet_epoch_100.png └── psp_epoch_100.png ├── test_img.jpg ├── test_mask.png ├── test_model.py └── test_module.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | *.idea 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # pycharm
107 |
108 | # pretrained models
109 | weights/
110 | *.pkl
111 | *.pth
112 |
113 | # dataset
114 | datasets/
115 | VOCdevkit/
116 | eval/
117 |
118 | # overfitting test
119 |
120 | # run result
121 | /tests/runs
122 | /runs
123 |
124 | # model
125 | /models/hrnet.py
126 | /models/psanet_old.py
127 | /scripts/debug.py
128 |
129 | # nn
130 | nn/sync_bn/
131 |
132 | # venv
133 | AwsmSemSegPytorch-env/
134 | .vscode/launch.json
135 | .vscode/settings.json
136 |
137 |
138 | # built files
139 | core/nn/sync_bn/lib/gpu/build.ninja
140 | core/nn/sync_bn/lib/gpu/.ninja_log
141 | core/nn/sync_bn/lib/gpu/.ninja_deps
142 |
143 | -------------------------------------------------------------------------------- /README_zh-CN.md: -------------------------------------------------------------------------------- 1 | ## Chinese Documentation
2 |
3 | [English](/README.md) | Simplified Chinese
4 |
5 | [![python-image]][python-url]
6 | [![pytorch-image]][pytorch-url]
7 | [![lic-image]][lic-url]
8 |
9 |

10 |
11 | This project aims to provide a concise, easy-to-use, and extensible toolbox for semantic segmentation based on [PyTorch](https://pytorch.org/).
12 |
13 | The code on the master branch requires **PyTorch 1.1.0 or later**.
14 |
15 | ## Installation
16 |
17 | ```
18 | # semantic-segmentation-pytorch dependencies
19 | pip install ninja tqdm
20 |
21 | # follow the PyTorch installation guide at https://pytorch.org/get-started/locally/
22 | conda install pytorch torchvision -c pytorch
23 |
24 | # install PyTorch Segmentation
25 | git clone https://github.com/Tramac/awesome-semantic-segmentation-pytorch.git
26 | ```
27 |
28 | ## Usage
29 |
30 | ### Training
31 | -----------------
32 | - **Single-GPU training**
33 | ```
34 | # for example, train fcn32s_vgg16_pascal_voc:
35 | python train.py --model fcn32s --backbone vgg16 --dataset pascal_voc --lr 0.0001 --epochs 50
36 | ```
37 | - **Distributed training**
38 | ```
39 | # for example, train fcn32s_vgg16_pascal_voc with 4 GPUs:
40 | export NGPUS=4
41 | python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py --model fcn32s --backbone vgg16 --dataset pascal_voc --lr 0.0001 --epochs 50
42 | ```
43 |
44 | ### Evaluation
45 | -----------------
46 | - **Single-GPU evaluation**
47 | ```
48 | # for example, evaluate fcn32s_vgg16_pascal_voc
49 | python eval.py --model fcn32s --backbone vgg16 --dataset pascal_voc
50 | ```
51 | - **Multi-GPU evaluation**
52 | ```
53 | # for example, evaluate fcn32s_vgg16_pascal_voc with 4 GPUs:
54 | export NGPUS=4
55 | python -m torch.distributed.launch --nproc_per_node=$NGPUS eval.py --model fcn32s --backbone vgg16 --dataset pascal_voc
56 | ```
57 |
58 | ### Demo
59 | ```
60 | cd ./scripts
61 | # for new users:
62 | python demo.py --model fcn32s_vgg16_voc --input-pic ../tests/test_img.jpg
63 | # you need to provide 'test.jpg' yourself:
64 | python demo.py --model fcn32s_vgg16_voc --input-pic ../datasets/test.jpg
65 | ```
66 |
67 | ### Model Zoo
68 | -------------------------------------
69 | - [FCN](https://arxiv.org/abs/1411.4038)
70 | - [ENet](https://arxiv.org/pdf/1606.02147)
71 | - [PSPNet](https://arxiv.org/pdf/1612.01105)
72 | - [ICNet](https://arxiv.org/pdf/1704.08545)
73 | - [DeepLabv3](https://arxiv.org/abs/1706.05587)
74 | - [DeepLabv3+](https://arxiv.org/pdf/1802.02611)
75 | - [DenseASPP](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yang_DenseASPP_for_Semantic_CVPR_2018_paper.pdf)
76 | - [EncNet](https://arxiv.org/abs/1803.08904v1)
77 | - [BiSeNet](https://arxiv.org/abs/1808.00897)
78 | - [PSANet](https://hszhao.github.io/papers/eccv18_psanet.pdf)
79 | - [DANet](https://arxiv.org/pdf/1809.02983)
80 | - [OCNet](https://arxiv.org/pdf/1809.00916)
81 | - [CGNet](https://arxiv.org/pdf/1811.08201)
82 | - [ESPNetv2](https://arxiv.org/abs/1811.11431)
83 | - [DUNet(DUpsampling)](https://arxiv.org/abs/1903.02120)
84 | - [FastFCN(JPU)](https://arxiv.org/abs/1903.11816)
85 | - [LEDNet](https://arxiv.org/abs/1905.02423)
86 | - [Fast-SCNN](https://github.com/Tramac/Fast-SCNN-pytorch)
87 | - [LightSeg](https://github.com/Tramac/Lightweight-Segmentation)
88 | - [DFANet](https://arxiv.org/abs/1904.02216)
89 |
90 | See [here](https://github.com/Tramac/awesome-semantic-segmentation-pytorch/blob/master/docs/DETAILS.md) for details of which backbones each model supports.
91 |
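As a quick illustration of the Python API behind the CLI commands above (a sketch: the model-name format follows `scripts/demo.py`, the `pretrained` keyword is an assumption, and the output structure may vary per model):

```python
import torch
from core.models import get_model  # re-exported from core/models/model_zoo.py

# 'fcn32s_vgg16_voc' matches the name used by scripts/demo.py; pretrained=False
# (assumed keyword) avoids fetching checkpoint weights.
model = get_model('fcn32s_vgg16_voc', pretrained=False).eval()
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 480, 480))  # dummy normalized image batch
# most nets in this repo return a tuple, main prediction first, auxiliary heads after
print(outputs[0].shape)  # expected: torch.Size([1, 21, 480, 480]) for the 21 VOC classes
```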
92 | ```
93 | .{SEG_ROOT}
94 | ├── core
95 | │   ├── models
96 | │   │   ├── bisenet.py
97 | │   │   ├── danet.py
98 | │   │   ├── deeplabv3.py
99 | │   │   ├── deeplabv3_plus.py
100 | │   │   ├── denseaspp.py
101 | │   │   ├── dunet.py
102 | │   │   ├── encnet.py
103 | │   │   ├── fcn.py
104 | │   │   ├── pspnet.py
105 | │   │   ├── icnet.py
106 | │   │   ├── enet.py
107 | │   │   ├── ocnet.py
108 | │   │   ├── psanet.py
109 | │   │   ├── cgnet.py
110 | │   │   ├── espnet.py
111 | │   │   ├── lednet.py
112 | │   │   ├── dfanet.py
113 | │   │   ├── ......
114 | ```
115 |
116 | ### Datasets
117 | You can download a specific dataset as follows, for example:
118 | ```
119 | cd ./core/data/downloader
120 | python ade20k.py --download-dir ../datasets/ade
121 | ```
122 |
123 | | Dataset | training set | validation set | testing set |
124 | | :----------------------------------------------------------: | :----------: | :------------: | :---------: |
125 | | [VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) | 1464 | 1449 | ✘ |
126 | | [VOCAug](http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz) | 11355 | 2857 | ✘ |
127 | | [ADE20K](http://groups.csail.mit.edu/vision/datasets/ADE20K/) | 20210 | 2000 | ✘ |
128 | | [Cityscapes](https://www.cityscapes-dataset.com/downloads/) | 2975 | 500 | ✘ |
129 | | [COCO](http://cocodataset.org/#download) | | | |
130 | | [SBU-shadow](http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip) | 4085 | 638 | ✘ |
131 | | [LIP(Look into Person)](http://sysu-hcp.net/lip/) | 30462 | 10000 | 10000 |
132 |
133 | ```
134 | .{SEG_ROOT}
135 | ├── core
136 | │   ├── data
137 | │   │   ├── dataloader
138 | │   │   │   ├── ade.py
139 | │   │   │   ├── cityscapes.py
140 | │   │   │   ├── mscoco.py
141 | │   │   │   ├── pascal_aug.py
142 | │   │   │   ├── pascal_voc.py
143 | │   │   │   ├── sbu_shadow.py
144 | │   │   └── downloader
145 | │   │   ├── ade20k.py
146 | │   │   ├── cityscapes.py
147 | │   │   ├── mscoco.py
148 | │   │   ├── pascal_voc.py
149 | │   │   └── sbu_shadow.py
150 | ```
151 |
152 | ## Some Results
153 | |Methods|Backbone|TrainSet|EvalSet|crop_size|epochs|JPU|Mean IoU|pixAcc|
154 | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
155 | |FCN32s|vgg16|train|val|480|60|✘|47.50|85.39|
156 | |FCN16s|vgg16|train|val|480|60|✘|49.16|85.98|
157 | |FCN8s|vgg16|train|val|480|60|✘|48.87|85.02|
158 | |FCN32s|resnet50|train|val|480|50|✘|54.60|88.57|
159 | |PSPNet|resnet50|train|val|480|60|✘|63.44|89.78|
160 | |DeepLabv3|resnet50|train|val|480|60|✘|60.15|88.36|
161 |
162 | `lr=1e-4, batch_size=4, epochs=80`.
163 | Note: All results above were obtained with the default arguments in `train.py`; for better numbers, refer to the hyperparameters reported in the corresponding papers.
164 |
165 | ## Versions
166 |
167 | - v0.1.0: Compared with the master branch, this version additionally includes `ccnet` and `psanet`, which depend on compiled custom layers; follow the instructions on that branch if you need them.
168 |
169 | ## To Do
170 |
171 | - [x] add train script
172 | - [ ] remove syncbn
173 | - [ ] train & evaluate
174 | - [x] test distributed training
175 | - [x] fix syncbn ([Why SyncBN?](https://tramac.github.io/2019/02/25/%E8%B7%A8%E5%8D%A1%E5%90%8C%E6%AD%A5%20Batch%20Normalization[%E8%BD%AC]/))
176 | - [x] add distributed ([How DIST?](https://tramac.github.io/2019/03/06/%E5%88%86%E5%B8%83%E5%BC%8F%E8%AE%AD%E7%BB%83-PyTorch/))
177 |
178 | ## References
179 | - [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding)
180 | - [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark)
181 | - [gluon-cv](https://github.com/dmlc/gluon-cv)
182 | - [imagenet](https://github.com/pytorch/examples/tree/master/imagenet)
183 |
184 | ## WeChat&QQ
185 | Since many people have reached out for help via Zhihu private messages, and we don't want to leave those questions unanswered, WeChat and QQ groups for this project are provided below; we hope they help anyone in need.
186 |

187 | 188 | 193 | 194 | [python-image]: https://img.shields.io/badge/Python-2.x|3.x-ff69b4.svg 195 | [python-url]: https://www.python.org/ 196 | [pytorch-image]: https://img.shields.io/badge/PyTorch-1.1-2BAF2B.svg 197 | [pytorch-url]: https://pytorch.org/ 198 | [lic-image]: https://img.shields.io/badge/Apache-2.0-blue.svg 199 | [lic-url]: https://github.com/Tramac/Awesome-semantic-segmentation-pytorch/blob/master/LICENSE 200 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nn, models, utils, data -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/core/data/__init__.py -------------------------------------------------------------------------------- /core/data/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides data loaders and transformers for popular vision datasets. 3 | """ 4 | from .mscoco import COCOSegmentation 5 | from .cityscapes import CitySegmentation 6 | from .ade import ADE20KSegmentation 7 | from .pascal_voc import VOCSegmentation 8 | from .pascal_aug import VOCAugSegmentation 9 | from .sbu_shadow import SBUSegmentation 10 | 11 | datasets = { 12 | 'ade20k': ADE20KSegmentation, 13 | 'pascal_voc': VOCSegmentation, 14 | 'pascal_aug': VOCAugSegmentation, 15 | 'coco': COCOSegmentation, 16 | 'citys': CitySegmentation, 17 | 'sbu': SBUSegmentation, 18 | } 19 | 20 | 21 | def get_segmentation_dataset(name, **kwargs): 22 | """Segmentation Datasets""" 23 | return datasets[name.lower()](**kwargs) 24 | -------------------------------------------------------------------------------- /core/data/dataloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class CitySegmentation(SegmentationDataset): 11 | """Cityscapes Semantic Segmentation Dataset. 12 | 13 | Parameters 14 | ---------- 15 | root : string 16 | Path to Cityscapes folder. 
Default is './datasets/citys' 17 | split: string 18 | 'train', 'val' or 'test' 19 | transform : callable, optional 20 | A function that transforms the image 21 | Examples 22 | -------- 23 | >>> from torchvision import transforms 24 | >>> import torch.utils.data as data 25 | >>> # Transforms for Normalization 26 | >>> input_transform = transforms.Compose([ 27 | >>> transforms.ToTensor(), 28 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), 29 | >>> ]) 30 | >>> # Create Dataset 31 | >>> trainset = CitySegmentation(split='train', transform=input_transform) 32 | >>> # Create Training Loader 33 | >>> train_data = data.DataLoader( 34 | >>> trainset, 4, shuffle=True, 35 | >>> num_workers=4) 36 | """ 37 | BASE_DIR = 'cityscapes' 38 | NUM_CLASS = 19 39 | 40 | def __init__(self, root='../datasets/citys', split='train', mode=None, transform=None, **kwargs): 41 | super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs) 42 | # self.root = os.path.join(root, self.BASE_DIR) 43 | assert os.path.exists(self.root), "Please setup the dataset using ../datasets/cityscapes.py" 44 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split) 45 | assert (len(self.images) == len(self.mask_paths)) 46 | if len(self.images) == 0: 47 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 48 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 49 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 50 | self._key = np.array([-1, -1, -1, -1, -1, -1, 51 | -1, -1, 0, 1, -1, -1, 52 | 2, 3, 4, -1, -1, -1, 53 | 5, -1, 6, 7, 8, 9, 54 | 10, 11, 12, 13, 14, 15, 55 | -1, -1, 16, 17, 18]) 56 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 57 | 58 | def _class_to_index(self, mask): 59 | # assert the value 60 | values = np.unique(mask) 61 | for value in values: 62 | assert (value in self._mapping) 63 | index = np.digitize(mask.ravel(), self._mapping, right=True) 64 | return self._key[index].reshape(mask.shape) 65 | 66 | def __getitem__(self, index): 67 | img = Image.open(self.images[index]).convert('RGB') 68 | if self.mode == 'test': 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | return img, os.path.basename(self.images[index]) 72 | mask = Image.open(self.mask_paths[index]) 73 | # synchrosized transform 74 | if self.mode == 'train': 75 | img, mask = self._sync_transform(img, mask) 76 | elif self.mode == 'val': 77 | img, mask = self._val_sync_transform(img, mask) 78 | else: 79 | assert self.mode == 'testval' 80 | img, mask = self._img_transform(img), self._mask_transform(mask) 81 | # general resize, normalize and toTensor 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, mask, os.path.basename(self.images[index]) 85 | 86 | def _mask_transform(self, mask): 87 | target = self._class_to_index(np.array(mask).astype('int32')) 88 | return torch.LongTensor(np.array(target).astype('int32')) 89 | 90 | def __len__(self): 91 | return len(self.images) 92 | 93 | @property 94 | def pred_offset(self): 95 | return 0 96 | 97 | 98 | def _get_city_pairs(folder, split='train'): 99 | def get_path_pairs(img_folder, mask_folder): 100 | img_paths = [] 101 | mask_paths = [] 102 | for root, _, files in os.walk(img_folder): 103 | for filename in files: 104 | if filename.endswith('.png'): 105 | imgpath = os.path.join(root, filename) 106 | foldername = os.path.basename(os.path.dirname(imgpath)) 107 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds') 108 | maskpath = os.path.join(mask_folder, foldername, 
maskname) 109 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 110 | img_paths.append(imgpath) 111 | mask_paths.append(maskpath) 112 | else: 113 | print('cannot find the mask or image:', imgpath, maskpath) 114 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 115 | return img_paths, mask_paths 116 | 117 | if split in ('train', 'val'): 118 | img_folder = os.path.join(folder, 'leftImg8bit/' + split) 119 | mask_folder = os.path.join(folder, 'gtFine/' + split) 120 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 121 | return img_paths, mask_paths 122 | else: 123 | assert split == 'trainval' 124 | print('trainval set') 125 | train_img_folder = os.path.join(folder, 'leftImg8bit/train') 126 | train_mask_folder = os.path.join(folder, 'gtFine/train') 127 | val_img_folder = os.path.join(folder, 'leftImg8bit/val') 128 | val_mask_folder = os.path.join(folder, 'gtFine/val') 129 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder) 130 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder) 131 | img_paths = train_img_paths + val_img_paths 132 | mask_paths = train_mask_paths + val_mask_paths 133 | return img_paths, mask_paths 134 | 135 | 136 | if __name__ == '__main__': 137 | dataset = CitySegmentation() 138 | -------------------------------------------------------------------------------- /core/data/dataloader/lip_parsing.py: -------------------------------------------------------------------------------- 1 | """Look into Person Dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from core.data.dataloader.segbase import SegmentationDataset 8 | 9 | 10 | class LIPSegmentation(SegmentationDataset): 11 | """Look into person parsing dataset """ 12 | 13 | BASE_DIR = 'LIP' 14 | NUM_CLASS = 20 15 | 16 | def __init__(self, root='../datasets/LIP', split='train', mode=None, transform=None, **kwargs): 17 | super(LIPSegmentation, self).__init__(root, split, mode, transform, **kwargs) 18 | _trainval_image_dir = os.path.join(root, 'TrainVal_images') 19 | _testing_image_dir = os.path.join(root, 'Testing_images') 20 | _trainval_mask_dir = os.path.join(root, 'TrainVal_parsing_annotations') 21 | if split == 'train': 22 | _image_dir = os.path.join(_trainval_image_dir, 'train_images') 23 | _mask_dir = os.path.join(_trainval_mask_dir, 'train_segmentations') 24 | _split_f = os.path.join(_trainval_image_dir, 'train_id.txt') 25 | elif split == 'val': 26 | _image_dir = os.path.join(_trainval_image_dir, 'val_images') 27 | _mask_dir = os.path.join(_trainval_mask_dir, 'val_segmentations') 28 | _split_f = os.path.join(_trainval_image_dir, 'val_id.txt') 29 | elif split == 'test': 30 | _image_dir = os.path.join(_testing_image_dir, 'testing_images') 31 | _split_f = os.path.join(_testing_image_dir, 'test_id.txt') 32 | else: 33 | raise RuntimeError('Unknown dataset split.') 34 | 35 | self.images = [] 36 | self.masks = [] 37 | with open(os.path.join(_split_f), 'r') as lines: 38 | for line in lines: 39 | _image = os.path.join(_image_dir, line.rstrip('\n') + '.jpg') 40 | assert os.path.isfile(_image) 41 | self.images.append(_image) 42 | if split != 'test': 43 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + '.png') 44 | assert os.path.isfile(_mask) 45 | self.masks.append(_mask) 46 | 47 | if split != 'test': 48 | assert (len(self.images) == len(self.masks)) 49 | print('Found {} {} images in the folder {}'.format(len(self.images), split, root)) 50 | 51 | def 
__getitem__(self, index):
52 | img = Image.open(self.images[index]).convert('RGB')
53 | if self.mode == 'test':
54 | img = self._img_transform(img)
55 | if self.transform is not None:
56 | img = self.transform(img)
57 | return img, os.path.basename(self.images[index])
58 | mask = Image.open(self.masks[index])
59 | # synchronized transform
60 | if self.mode == 'train':
61 | img, mask = self._sync_transform(img, mask)
62 | elif self.mode == 'val':
63 | img, mask = self._val_sync_transform(img, mask)
64 | else:
65 | assert self.mode == 'testval'
66 | img, mask = self._img_transform(img), self._mask_transform(mask)
67 | # general resize, normalize and toTensor
68 | if self.transform is not None:
69 | img = self.transform(img)
70 |
71 | return img, mask, os.path.basename(self.images[index])
72 |
73 | def __len__(self):
74 | return len(self.images)
75 |
76 | def _mask_transform(self, mask):
77 | target = np.array(mask).astype('int32')
78 | return torch.from_numpy(target).long()
79 |
80 | @property
81 | def classes(self):
82 | """Category names."""
83 | return ('background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes',
84 | 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt',
85 | 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe',
86 | 'rightShoe')
87 |
88 |
89 | if __name__ == '__main__':
90 | dataset = LIPSegmentation(base_size=280, crop_size=256)
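Note that `LIPSegmentation` is not registered in the `datasets` dict of `core/data/dataloader/__init__.py` shown earlier, so `get_segmentation_dataset('lip', ...)` would raise a `KeyError` out of the box. A minimal, hypothetical sketch of how one could register and use it, assuming the LIP data has been placed under `../datasets/LIP`:

```python
# Hypothetical registration sketch; 'lip' is not a key shipped in the repo's
# datasets dict, and the data must already exist under ../datasets/LIP.
from core.data.dataloader import datasets, get_segmentation_dataset
from core.data.dataloader.lip_parsing import LIPSegmentation

datasets['lip'] = LIPSegmentation
trainset = get_segmentation_dataset('lip', split='train', base_size=520, crop_size=480)
print(len(trainset), trainset.num_class)  # expected: 30462 images (per the README table), 20 classes
```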
-------------------------------------------------------------------------------- /core/data/dataloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """MSCOCO Semantic Segmentation pretraining for VOC."""
2 | import os
3 | import pickle
4 | import torch
5 | import numpy as np
6 |
7 | from tqdm import trange
8 | from PIL import Image
9 | from .segbase import SegmentationDataset
10 |
11 |
12 | class COCOSegmentation(SegmentationDataset):
13 | """COCO Semantic Segmentation Dataset for VOC Pre-training.
14 |
15 | Parameters
16 | ----------
17 | root : string
18 | Path to COCO folder. Default is './datasets/coco'
19 | split: string
20 | 'train', 'val' or 'test'
21 | transform : callable, optional
22 | A function that transforms the image
23 | Examples
24 | --------
25 | >>> from torchvision import transforms
26 | >>> import torch.utils.data as data
27 | >>> # Transforms for Normalization
28 | >>> input_transform = transforms.Compose([
29 | >>> transforms.ToTensor(),
30 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)),
31 | >>> ])
32 | >>> # Create Dataset
33 | >>> trainset = COCOSegmentation(split='train', transform=input_transform)
34 | >>> # Create Training Loader
35 | >>> train_data = data.DataLoader(
36 | >>> trainset, 4, shuffle=True,
37 | >>> num_workers=4)
38 | """
39 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
40 | 1, 64, 20, 63, 7, 72]
41 | NUM_CLASS = 21
42 |
43 | def __init__(self, root='../datasets/coco', split='train', mode=None, transform=None, **kwargs):
44 | super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs)
45 | # lazy import pycocotools
46 | from pycocotools.coco import COCO
47 | from pycocotools import mask
48 | if split == 'train':
49 | print('train set')
50 | ann_file = os.path.join(root, 'annotations/instances_train2017.json')
51 | ids_file = os.path.join(root, 'annotations/train_ids.mx')
52 | self.root = os.path.join(root, 'train2017')
53 | else:
54 | print('val set')
55 | ann_file = os.path.join(root, 'annotations/instances_val2017.json')
56 | ids_file = os.path.join(root, 'annotations/val_ids.mx')
57 | self.root = os.path.join(root, 'val2017')
58 | self.coco = COCO(ann_file)
59 | self.coco_mask = mask
60 | if os.path.exists(ids_file):
61 | with open(ids_file, 'rb') as f:
62 | self.ids = pickle.load(f)
63 | else:
64 | ids = list(self.coco.imgs.keys())
65 | self.ids = self._preprocess(ids, ids_file)
66 | self.transform = transform
67 |
68 | def __getitem__(self, index):
69 | coco = self.coco
70 | img_id = self.ids[index]
71 | img_metadata = coco.loadImgs(img_id)[0]
72 | path = img_metadata['file_name']
73 | img = Image.open(os.path.join(self.root, path)).convert('RGB')
74 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
75 | mask = Image.fromarray(self._gen_seg_mask(
76 | cocotarget, img_metadata['height'], img_metadata['width']))
77 | # synchronized transform
78 | if self.mode == 'train':
79 | img, mask = self._sync_transform(img, mask)
80 | elif self.mode == 'val':
81 | img, mask = self._val_sync_transform(img, mask)
82 | else:
83 | assert self.mode == 'testval'
84 | img, mask = self._img_transform(img), self._mask_transform(mask)
85 | # general resize, normalize and toTensor
86 | if self.transform is not None:
87 | img = self.transform(img)
88 | return img, mask, os.path.basename(path)  # self.ids holds int image ids, not paths
89 |
90 | def _mask_transform(self, mask):
91 | return torch.LongTensor(np.array(mask).astype('int32'))
92 |
93 | def _gen_seg_mask(self, target, h, w):
94 | mask = np.zeros((h, w), dtype=np.uint8)
95 | coco_mask = self.coco_mask
96 | for instance in target:
97 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
98 | m = coco_mask.decode(rle)
99 | cat = instance['category_id']
100 | if cat in self.CAT_LIST:
101 | c = self.CAT_LIST.index(cat)
102 | else:
103 | continue
104 | if len(m.shape) < 3:
105 | mask[:, :] += (mask == 0) * (m * c)
106 | else:
107 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
108 | return mask
109 |
110 | def _preprocess(self, ids, ids_file):
111 | print("Preprocessing mask, this will take a while."
+ \ 112 | "But don't worry, it only run once for each split.") 113 | tbar = trange(len(ids)) 114 | new_ids = [] 115 | for i in tbar: 116 | img_id = ids[i] 117 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 118 | img_metadata = self.coco.loadImgs(img_id)[0] 119 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width']) 120 | # more than 1k pixels 121 | if (mask > 0).sum() > 1000: 122 | new_ids.append(img_id) 123 | tbar.set_description('Doing: {}/{}, got {} qualified images'. \ 124 | format(i, len(ids), len(new_ids))) 125 | print('Found number of qualified images: ', len(new_ids)) 126 | with open(ids_file, 'wb') as f: 127 | pickle.dump(new_ids, f) 128 | return new_ids 129 | 130 | @property 131 | def classes(self): 132 | """Category names.""" 133 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 134 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 135 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 136 | 'tv') 137 | -------------------------------------------------------------------------------- /core/data/dataloader/pascal_aug.py: -------------------------------------------------------------------------------- 1 | """Pascal Augmented VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import scipy.io as sio 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from .segbase import SegmentationDataset 9 | 10 | 11 | class VOCAugSegmentation(SegmentationDataset): 12 | """Pascal VOC Augmented Semantic Segmentation Dataset. 13 | 14 | Parameters 15 | ---------- 16 | root : string 17 | Path to VOCdevkit folder. Default is './datasets/voc' 18 | split: string 19 | 'train', 'val' or 'test' 20 | transform : callable, optional 21 | A function that transforms the image 22 | Examples 23 | -------- 24 | >>> from torchvision import transforms 25 | >>> import torch.utils.data as data 26 | >>> # Transforms for Normalization 27 | >>> input_transform = transforms.Compose([ 28 | >>> transforms.ToTensor(), 29 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 30 | >>> ]) 31 | >>> # Create Dataset 32 | >>> trainset = VOCAugSegmentation(split='train', transform=input_transform) 33 | >>> # Create Training Loader 34 | >>> train_data = data.DataLoader( 35 | >>> trainset, 4, shuffle=True, 36 | >>> num_workers=4) 37 | """ 38 | BASE_DIR = 'VOCaug/dataset/' 39 | NUM_CLASS = 21 40 | 41 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs): 42 | super(VOCAugSegmentation, self).__init__(root, split, mode, transform, **kwargs) 43 | # train/val/test splits are pre-cut 44 | _voc_root = os.path.join(root, self.BASE_DIR) 45 | _mask_dir = os.path.join(_voc_root, 'cls') 46 | _image_dir = os.path.join(_voc_root, 'img') 47 | if split == 'train': 48 | _split_f = os.path.join(_voc_root, 'trainval.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_voc_root, 'val.txt') 51 | else: 52 | raise RuntimeError('Unknown dataset split: {}'.format(split)) 53 | 54 | self.images = [] 55 | self.masks = [] 56 | with open(os.path.join(_split_f), "r") as lines: 57 | for line in lines: 58 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 59 | assert os.path.isfile(_image) 60 | self.images.append(_image) 61 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".mat") 62 | assert os.path.isfile(_mask) 63 | self.masks.append(_mask) 64 | 65 | assert (len(self.images) == len(self.masks)) 66 | print('Found {} images in the folder 
{}'.format(len(self.images), _voc_root)) 67 | 68 | def __getitem__(self, index): 69 | img = Image.open(self.images[index]).convert('RGB') 70 | target = self._load_mat(self.masks[index]) 71 | # synchrosized transform 72 | if self.mode == 'train': 73 | img, target = self._sync_transform(img, target) 74 | elif self.mode == 'val': 75 | img, target = self._val_sync_transform(img, target) 76 | else: 77 | raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode)) 78 | # general resize, normalize and toTensor 79 | if self.transform is not None: 80 | img = self.transform(img) 81 | return img, target, os.path.basename(self.images[index]) 82 | 83 | def _mask_transform(self, mask): 84 | return torch.LongTensor(np.array(mask).astype('int32')) 85 | 86 | def _load_mat(self, filename): 87 | mat = sio.loadmat(filename, mat_dtype=True, squeeze_me=True, struct_as_record=False) 88 | mask = mat['GTcls'].Segmentation 89 | return Image.fromarray(mask) 90 | 91 | def __len__(self): 92 | return len(self.images) 93 | 94 | @property 95 | def classes(self): 96 | """Category names.""" 97 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 98 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 99 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 100 | 'tv') 101 | 102 | 103 | if __name__ == '__main__': 104 | dataset = VOCAugSegmentation() -------------------------------------------------------------------------------- /core/data/dataloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Pascal VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class VOCSegmentation(SegmentationDataset): 11 | """Pascal VOC Semantic Segmentation Dataset. 12 | 13 | Parameters 14 | ---------- 15 | root : string 16 | Path to VOCdevkit folder. 
Default is './datasets/VOCdevkit' 17 | split: string 18 | 'train', 'val' or 'test' 19 | transform : callable, optional 20 | A function that transforms the image 21 | Examples 22 | -------- 23 | >>> from torchvision import transforms 24 | >>> import torch.utils.data as data 25 | >>> # Transforms for Normalization 26 | >>> input_transform = transforms.Compose([ 27 | >>> transforms.ToTensor(), 28 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 29 | >>> ]) 30 | >>> # Create Dataset 31 | >>> trainset = VOCSegmentation(split='train', transform=input_transform) 32 | >>> # Create Training Loader 33 | >>> train_data = data.DataLoader( 34 | >>> trainset, 4, shuffle=True, 35 | >>> num_workers=4) 36 | """ 37 | BASE_DIR = 'VOC2012' 38 | NUM_CLASS = 21 39 | 40 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs): 41 | super(VOCSegmentation, self).__init__(root, split, mode, transform, **kwargs) 42 | _voc_root = os.path.join(root, self.BASE_DIR) 43 | _mask_dir = os.path.join(_voc_root, 'SegmentationClass') 44 | _image_dir = os.path.join(_voc_root, 'JPEGImages') 45 | # train/val/test splits are pre-cut 46 | _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation') 47 | if split == 'train': 48 | _split_f = os.path.join(_splits_dir, 'train.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_splits_dir, 'val.txt') 51 | elif split == 'test': 52 | _split_f = os.path.join(_splits_dir, 'test.txt') 53 | else: 54 | raise RuntimeError('Unknown dataset split.') 55 | 56 | self.images = [] 57 | self.masks = [] 58 | with open(os.path.join(_split_f), "r") as lines: 59 | for line in lines: 60 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 61 | assert os.path.isfile(_image) 62 | self.images.append(_image) 63 | if split != 'test': 64 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png") 65 | assert os.path.isfile(_mask) 66 | self.masks.append(_mask) 67 | 68 | if split != 'test': 69 | assert (len(self.images) == len(self.masks)) 70 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root)) 71 | 72 | def __getitem__(self, index): 73 | img = Image.open(self.images[index]).convert('RGB') 74 | if self.mode == 'test': 75 | img = self._img_transform(img) 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | return img, os.path.basename(self.images[index]) 79 | mask = Image.open(self.masks[index]) 80 | # synchronized transform 81 | if self.mode == 'train': 82 | img, mask = self._sync_transform(img, mask) 83 | elif self.mode == 'val': 84 | img, mask = self._val_sync_transform(img, mask) 85 | else: 86 | assert self.mode == 'testval' 87 | img, mask = self._img_transform(img), self._mask_transform(mask) 88 | # general resize, normalize and toTensor 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, mask, os.path.basename(self.images[index]) 93 | 94 | def __len__(self): 95 | return len(self.images) 96 | 97 | def _mask_transform(self, mask): 98 | target = np.array(mask).astype('int32') 99 | target[target == 255] = -1 100 | return torch.from_numpy(target).long() 101 | 102 | @property 103 | def classes(self): 104 | """Category names.""" 105 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 106 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 107 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 108 | 'tv') 109 | 110 | 111 | if __name__ == '__main__': 112 | dataset = VOCSegmentation() 
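The `Examples` section of the docstring above, expanded into a runnable sketch (it assumes VOC2012 has been prepared under `../datasets/voc`, e.g. via `core/data/downloader/pascal_voc.py`):

```python
import torch.utils.data as data
from torchvision import transforms
from core.data.dataloader.pascal_voc import VOCSegmentation

# ImageNet normalization, as in the docstring example
input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
])
trainset = VOCSegmentation(split='train', transform=input_transform)
train_data = data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)

for img, mask, name in train_data:
    # with the default crop_size=480: img is [4, 3, 480, 480], mask is [4, 480, 480];
    # mask uses -1 as the ignore index (255 is remapped in _mask_transform above)
    print(img.shape, mask.shape)
    break
```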
-------------------------------------------------------------------------------- /core/data/dataloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """SBU Shadow Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class SBUSegmentation(SegmentationDataset): 11 | """SBU Shadow Segmentation Dataset 12 | """ 13 | NUM_CLASS = 2 14 | 15 | def __init__(self, root='../datasets/sbu', split='train', mode=None, transform=None, **kwargs): 16 | super(SBUSegmentation, self).__init__(root, split, mode, transform, **kwargs) 17 | assert os.path.exists(self.root) 18 | self.images, self.masks = _get_sbu_pairs(self.root, self.split) 19 | assert (len(self.images) == len(self.masks)) 20 | if len(self.images) == 0: 21 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 22 | 23 | def __getitem__(self, index): 24 | img = Image.open(self.images[index]).convert('RGB') 25 | if self.mode == 'test': 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | return img, os.path.basename(self.images[index]) 29 | mask = Image.open(self.masks[index]) 30 | # synchrosized transform 31 | if self.mode == 'train': 32 | img, mask = self._sync_transform(img, mask) 33 | elif self.mode == 'val': 34 | img, mask = self._val_sync_transform(img, mask) 35 | else: 36 | assert self.mode == 'testval' 37 | img, mask = self._img_transform(img), self._mask_transform(mask) 38 | # general resize, normalize and toTensor 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | return img, mask, os.path.basename(self.images[index]) 42 | 43 | def _mask_transform(self, mask): 44 | target = np.array(mask).astype('int32') 45 | target[target > 0] = 1 46 | return torch.from_numpy(target).long() 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | @property 52 | def pred_offset(self): 53 | return 0 54 | 55 | 56 | def _get_sbu_pairs(folder, split='train'): 57 | def get_path_pairs(img_folder, mask_folder): 58 | img_paths = [] 59 | mask_paths = [] 60 | for root, _, files in os.walk(img_folder): 61 | print(root) 62 | for filename in files: 63 | if filename.endswith('.jpg'): 64 | imgpath = os.path.join(root, filename) 65 | maskname = filename.replace('.jpg', '.png') 66 | maskpath = os.path.join(mask_folder, maskname) 67 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 68 | img_paths.append(imgpath) 69 | mask_paths.append(maskpath) 70 | else: 71 | print('cannot find the mask or image:', imgpath, maskpath) 72 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 73 | return img_paths, mask_paths 74 | 75 | if split == 'train': 76 | img_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowImages') 77 | mask_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowMasks') 78 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 79 | else: 80 | assert split in ('val', 'test') 81 | img_folder = os.path.join(folder, 'SBU-Test/ShadowImages') 82 | mask_folder = os.path.join(folder, 'SBU-Test/ShadowMasks') 83 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 84 | return img_paths, mask_paths 85 | 86 | 87 | if __name__ == '__main__': 88 | dataset = SBUSegmentation(base_size=280, crop_size=256) -------------------------------------------------------------------------------- /core/data/dataloader/segbase.py: 
-------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | __all__ = ['SegmentationDataset'] 8 | 9 | 10 | class SegmentationDataset(object): 11 | """Segmentation Base Dataset""" 12 | 13 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 14 | super(SegmentationDataset, self).__init__() 15 | self.root = root 16 | self.transform = transform 17 | self.split = split 18 | self.mode = mode if mode is not None else split 19 | self.base_size = base_size 20 | self.crop_size = crop_size 21 | 22 | def _val_sync_transform(self, img, mask): 23 | outsize = self.crop_size 24 | short_size = outsize 25 | w, h = img.size 26 | if w > h: 27 | oh = short_size 28 | ow = int(1.0 * w * oh / h) 29 | else: 30 | ow = short_size 31 | oh = int(1.0 * h * ow / w) 32 | img = img.resize((ow, oh), Image.BILINEAR) 33 | mask = mask.resize((ow, oh), Image.NEAREST) 34 | # center crop 35 | w, h = img.size 36 | x1 = int(round((w - outsize) / 2.)) 37 | y1 = int(round((h - outsize) / 2.)) 38 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) 39 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) 40 | # final transform 41 | img, mask = self._img_transform(img), self._mask_transform(mask) 42 | return img, mask 43 | 44 | def _sync_transform(self, img, mask): 45 | # random mirror 46 | if random.random() < 0.5: 47 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 48 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 49 | crop_size = self.crop_size 50 | # random scale (short edge) 51 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 52 | w, h = img.size 53 | if h > w: 54 | ow = short_size 55 | oh = int(1.0 * h * ow / w) 56 | else: 57 | oh = short_size 58 | ow = int(1.0 * w * oh / h) 59 | img = img.resize((ow, oh), Image.BILINEAR) 60 | mask = mask.resize((ow, oh), Image.NEAREST) 61 | # pad crop 62 | if short_size < crop_size: 63 | padh = crop_size - oh if oh < crop_size else 0 64 | padw = crop_size - ow if ow < crop_size else 0 65 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 66 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 67 | # random crop crop_size 68 | w, h = img.size 69 | x1 = random.randint(0, w - crop_size) 70 | y1 = random.randint(0, h - crop_size) 71 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 72 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 73 | # gaussian blur as in PSP 74 | if random.random() < 0.5: 75 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 76 | # final transform 77 | img, mask = self._img_transform(img), self._mask_transform(mask) 78 | return img, mask 79 | 80 | def _img_transform(self, img): 81 | return np.array(img) 82 | 83 | def _mask_transform(self, mask): 84 | return np.array(mask).astype('int32') 85 | 86 | @property 87 | def num_class(self): 88 | """Number of categories.""" 89 | return self.NUM_CLASS 90 | 91 | @property 92 | def pred_offset(self): 93 | return 0 94 | -------------------------------------------------------------------------------- /core/data/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import errno 4 | import tarfile 5 | from six.moves import urllib 6 | from torch.utils.model_zoo import tqdm 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, 
total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | def check_integrity(fpath, md5=None): 20 | if md5 is None: 21 | return True 22 | if not os.path.isfile(fpath): 23 | return False 24 | md5o = hashlib.md5() 25 | with open(fpath, 'rb') as f: 26 | # read in 1MB chunks 27 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 28 | md5o.update(chunk) 29 | md5c = md5o.hexdigest() 30 | if md5c != md5: 31 | return False 32 | return True 33 | 34 | def makedir_exist_ok(dirpath): 35 | try: 36 | os.makedirs(dirpath) 37 | except OSError as e: 38 | if e.errno == errno.EEXIST: 39 | pass 40 | else: 41 | pass 42 | 43 | def download_url(url, root, filename=None, md5=None): 44 | """Download a file from a url and place it in root.""" 45 | root = os.path.expanduser(root) 46 | if not filename: 47 | filename = os.path.basename(url) 48 | fpath = os.path.join(root, filename) 49 | 50 | makedir_exist_ok(root) 51 | 52 | # downloads file 53 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 54 | print('Using downloaded and verified file: ' + fpath) 55 | else: 56 | try: 57 | print('Downloading ' + url + ' to ' + fpath) 58 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 59 | except OSError: 60 | if url[:5] == 'https': 61 | url = url.replace('https:', 'http:') 62 | print('Failed download. Trying https -> http instead.' 63 | ' Downloading ' + url + ' to ' + fpath) 64 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 65 | 66 | def download_extract(url, root, filename, md5): 67 | download_url(url, root, filename, md5) 68 | with tarfile.open(os.path.join(root, filename), "r") as tar: 69 | def is_within_directory(directory, target): 70 | 71 | abs_directory = os.path.abspath(directory) 72 | abs_target = os.path.abspath(target) 73 | 74 | prefix = os.path.commonprefix([abs_directory, abs_target]) 75 | 76 | return prefix == abs_directory 77 | 78 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 79 | 80 | for member in tar.getmembers(): 81 | member_path = os.path.join(path, member.name) 82 | if not is_within_directory(path, member_path): 83 | raise Exception("Attempted Path Traversal in Tar File") 84 | 85 | tar.extractall(path, members, numeric_owner=numeric_owner) 86 | 87 | 88 | safe_extract(tar, path=root) -------------------------------------------------------------------------------- /core/data/downloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/core/data/downloader/__init__.py -------------------------------------------------------------------------------- /core/data/downloader/ade20k.py: -------------------------------------------------------------------------------- 1 | """Prepare ADE20K dataset""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/ade') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize ADE20K dataset.', 20 | epilog='Example: python 
setup_ade20k.py', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def download_ade(path, overwrite=False): 28 | _AUG_DOWNLOAD_URLS = [ 29 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', 30 | '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'), 31 | ( 32 | 'http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 33 | 'e05747892219d10e9243933371a497e905a4860c'), ] 34 | download_dir = os.path.join(path, 'downloads') 35 | makedirs(download_dir) 36 | for url, checksum in _AUG_DOWNLOAD_URLS: 37 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum) 38 | # extract 39 | with zipfile.ZipFile(filename, "r") as zip_ref: 40 | zip_ref.extractall(path=path) 41 | 42 | 43 | if __name__ == '__main__': 44 | args = parse_args() 45 | makedirs(os.path.expanduser('~/.torch/datasets')) 46 | if args.download_dir is not None: 47 | if os.path.isdir(_TARGET_DIR): 48 | os.remove(_TARGET_DIR) 49 | # make symlink 50 | os.symlink(args.download_dir, _TARGET_DIR) 51 | download_ade(_TARGET_DIR, overwrite=False) 52 | -------------------------------------------------------------------------------- /core/data/downloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs, check_sha1 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/citys') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize ADE20K dataset.', 20 | epilog='Example: python prepare_cityscapes.py', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def download_city(path, overwrite=False): 28 | _CITY_DOWNLOAD_URLS = [ 29 | ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'), 30 | ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')] 31 | download_dir = os.path.join(path, 'downloads') 32 | makedirs(download_dir) 33 | for filename, checksum in _CITY_DOWNLOAD_URLS: 34 | if not check_sha1(filename, checksum): 35 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 36 | 'The repo may be outdated or download may be incomplete. 
' \ 37 | 'If the "repo_url" is overridden, consider switching to ' \ 38 | 'the default repo.'.format(filename)) 39 | # extract 40 | with zipfile.ZipFile(filename, "r") as zip_ref: 41 | zip_ref.extractall(path=path) 42 | print("Extracted", filename) 43 | 44 | 45 | if __name__ == '__main__': 46 | args = parse_args() 47 | makedirs(os.path.expanduser('~/.torch/datasets')) 48 | if args.download_dir is not None: 49 | if os.path.isdir(_TARGET_DIR): 50 | os.remove(_TARGET_DIR) 51 | # make symlink 52 | os.symlink(args.download_dir, _TARGET_DIR) 53 | else: 54 | download_city(_TARGET_DIR, overwrite=False) 55 | -------------------------------------------------------------------------------- /core/data/downloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """Prepare MS COCO datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs, try_import_pycocotools 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/coco') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize MS COCO dataset.', 20 | epilog='Example: python mscoco.py --download-dir ~/mscoco', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', type=str, default='~/mscoco/', help='dataset directory on disk') 23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 24 | parser.add_argument('--overwrite', action='store_true', 25 | help='overwrite downloaded files if set, in case they are corrupted') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def download_coco(path, overwrite=False): 31 | _DOWNLOAD_URLS = [ 32 | ('http://images.cocodataset.org/zips/train2017.zip', 33 | '10ad623668ab00c62c096f0ed636d6aff41faca5'), 34 | ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', 35 | '8551ee4bb5860311e79dace7e79cb91e432e78b3'), 36 | ('http://images.cocodataset.org/zips/val2017.zip', 37 | '4950dc9d00dbe1c933ee0170f5797584351d2a41'), 38 | # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip', 39 | # '46cdcf715b6b4f67e980b529534e79c2edffe084'), 40 | # test2017.zip, for those who want to attend the competition. 41 | # ('http://images.cocodataset.org/zips/test2017.zip', 42 | # '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'), 43 | ] 44 | makedirs(path) 45 | for url, checksum in _DOWNLOAD_URLS: 46 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 47 | # extract 48 | with zipfile.ZipFile(filename) as zf: 49 | zf.extractall(path=path) 50 | 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | path = os.path.expanduser(args.download_dir) 55 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \ 56 | or not os.path.isdir(os.path.join(path, 'val2017')) \ 57 | or not os.path.isdir(os.path.join(path, 'annotations')): 58 | if args.no_download: 59 | raise ValueError(('{} is not a valid directory, make sure it is present.' 
60 | ' Or you should not disable "--no-download" to grab it'.format(path))) 61 | else: 62 | download_coco(path, overwrite=args.overwrite) 63 | 64 | # make symlink 65 | makedirs(os.path.expanduser('~/.torch/datasets')) 66 | if os.path.isdir(_TARGET_DIR): 67 | os.remove(_TARGET_DIR) 68 | os.symlink(path, _TARGET_DIR) 69 | try_import_pycocotools() 70 | -------------------------------------------------------------------------------- /core/data/downloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Prepare PASCAL VOC datasets""" 2 | import os 3 | import sys 4 | import shutil 5 | import argparse 6 | import tarfile 7 | 8 | # TODO: optim code 9 | cur_path = os.path.abspath(os.path.dirname(__file__)) 10 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 11 | sys.path.append(root_path) 12 | 13 | from core.utils import download, makedirs 14 | 15 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/voc') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='Initialize PASCAL VOC dataset.', 21 | epilog='Example: python pascal_voc.py --download-dir ~/VOCdevkit', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument('--download-dir', type=str, default='~/VOCdevkit/', help='dataset directory on disk') 24 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 25 | parser.add_argument('--overwrite', action='store_true', 26 | help='overwrite downloaded files if set, in case they are corrupted') 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | ##################################################################################### 32 | # Download and extract VOC datasets into ``path`` 33 | 34 | def download_voc(path, overwrite=False): 35 | _DOWNLOAD_URLS = [ 36 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 37 | '34ed68851bce2a36e2a223fa52c661d592c66b3c'), 38 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 39 | '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'), 40 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 41 | '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')] 42 | makedirs(path) 43 | for url, checksum in _DOWNLOAD_URLS: 44 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 45 | # extract 46 | with tarfile.open(filename) as tar: 47 | def is_within_directory(directory, target): 48 | 49 | abs_directory = os.path.abspath(directory) 50 | abs_target = os.path.abspath(target) 51 | 52 | prefix = os.path.commonprefix([abs_directory, abs_target]) 53 | 54 | return prefix == abs_directory 55 | 56 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 57 | 58 | for member in tar.getmembers(): 59 | member_path = os.path.join(path, member.name) 60 | if not is_within_directory(path, member_path): 61 | raise Exception("Attempted Path Traversal in Tar File") 62 | 63 | tar.extractall(path, members, numeric_owner=numeric_owner) 64 | 65 | 66 | safe_extract(tar, path=path) 67 | 68 | 69 | ##################################################################################### 70 | # Download and extract the VOC augmented segmentation dataset into ``path`` 71 | 72 | def download_aug(path, overwrite=False): 73 | _AUG_DOWNLOAD_URLS = [ 74 | ('http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz', 75 | '7129e0a480c2d6afb02b517bb18ac54283bfaa35')] 76 | 
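# The loop below mirrors download_voc() above: core.utils.download() is
# expected to skip the fetch when a local copy already matches the given
# SHA-1 checksum, and safe_extract() rejects any archive member whose
# resolved path would escape the target directory (the classic tarfile
# path-traversal pitfall that a plain tar.extractall() allows).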
makedirs(path) 77 | for url, checksum in _AUG_DOWNLOAD_URLS: 78 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 79 | # extract 80 | with tarfile.open(filename) as tar: 81 | def is_within_directory(directory, target): 82 | 83 | abs_directory = os.path.abspath(directory) 84 | abs_target = os.path.abspath(target) 85 | 86 | prefix = os.path.commonprefix([abs_directory, abs_target]) 87 | 88 | return prefix == abs_directory 89 | 90 | def safe_extract(tar, path=".", members=None, *, numeric_owner=False): 91 | 92 | for member in tar.getmembers(): 93 | member_path = os.path.join(path, member.name) 94 | if not is_within_directory(path, member_path): 95 | raise Exception("Attempted Path Traversal in Tar File") 96 | 97 | tar.extractall(path, members, numeric_owner=numeric_owner) 98 | 99 | 100 | safe_extract(tar, path=path) 101 | shutil.move(os.path.join(path, 'benchmark_RELEASE'), 102 | os.path.join(path, 'VOCaug')) 103 | filenames = ['VOCaug/dataset/train.txt', 'VOCaug/dataset/val.txt'] 104 | # generate trainval.txt 105 | with open(os.path.join(path, 'VOCaug/dataset/trainval.txt'), 'w') as outfile: 106 | for fname in filenames: 107 | fname = os.path.join(path, fname) 108 | with open(fname) as infile: 109 | for line in infile: 110 | outfile.write(line) 111 | 112 | 113 | if __name__ == '__main__': 114 | args = parse_args() 115 | path = os.path.expanduser(args.download_dir) 116 | if not os.path.isfile(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \ 117 | or not os.path.isdir(os.path.join(path, 'VOC2012')): 118 | if args.no_download: 119 | raise ValueError(('{} is not a valid directory, make sure it is present.' 120 | ' Or you should not disable "--no-download" to grab it'.format(path))) 121 | else: 122 | download_voc(path, overwrite=args.overwrite) 123 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007')) 124 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012')) 125 | shutil.rmtree(os.path.join(path, 'VOCdevkit')) 126 | 127 | if not os.path.isdir(os.path.join(path, 'VOCaug')): 128 | if args.no_download: 129 | raise ValueError(('{} is not a valid directory, make sure it is present.' 
130 | ' Or do not set "--no-download" so the script can grab it automatically'.format(path))) 131 | else: 132 | download_aug(path, overwrite=args.overwrite) 133 | 134 | # make symlink 135 | makedirs(os.path.expanduser('~/.torch/datasets')) 136 | if os.path.isdir(_TARGET_DIR): 137 | os.remove(_TARGET_DIR) 138 | os.symlink(path, _TARGET_DIR) 139 | -------------------------------------------------------------------------------- /core/data/downloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """Prepare SBU Shadow datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/sbu') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize SBU Shadow dataset.', 20 | epilog='Example: python sbu_shadow.py --download-dir ~/SBU-shadow', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') 23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 24 | parser.add_argument('--overwrite', action='store_true', 25 | help='overwrite downloaded files if set, in case they are corrupted') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | ##################################################################################### 31 | # Download and extract SBU shadow datasets into ``path`` 32 | 33 | def download_sbu(path, overwrite=False): 34 | _DOWNLOAD_URLS = [ 35 | ('http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip'), 36 | ] 37 | download_dir = os.path.join(path, 'downloads') 38 | makedirs(download_dir) 39 | for url in _DOWNLOAD_URLS: 40 | filename = download(url, path=download_dir, overwrite=overwrite) 41 | # extract 42 | with zipfile.ZipFile(filename, "r") as zf: 43 | zf.extractall(path=path) 44 | print("Extracted", filename) 45 | 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | makedirs(os.path.expanduser('~/.torch/datasets')) 50 | if args.download_dir is not None: 51 | if os.path.isdir(_TARGET_DIR): 52 | os.remove(_TARGET_DIR) 53 | # make symlink 54 | os.symlink(args.download_dir, _TARGET_DIR) 55 | else: 56 | download_sbu(_TARGET_DIR, overwrite=args.overwrite) 57 | -------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Model Zoo""" 2 | from .model_zoo import get_model, get_model_list -------------------------------------------------------------------------------- /core/models/base_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import * 2 | from .resnet import * 3 | from .resnetv1b import * 4 | from .vgg import * 5 | from .eespnet import * 6 | from .xception import * 7 | -------------------------------------------------------------------------------- /core/models/base_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """MobileNet and MobileNetV2.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from core.nn import _ConvBNReLU,
_DepthwiseConv, InvertedResidual 6 | 7 | __all__ = ['MobileNet', 'MobileNetV2', 'get_mobilenet', 'get_mobilenet_v2', 8 | 'mobilenet1_0', 'mobilenet_v2_1_0', 'mobilenet0_75', 'mobilenet_v2_0_75', 9 | 'mobilenet0_5', 'mobilenet_v2_0_5', 'mobilenet0_25', 'mobilenet_v2_0_25'] 10 | 11 | 12 | class MobileNet(nn.Module): 13 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs): 14 | super(MobileNet, self).__init__() 15 | conv_dw_setting = [ 16 | [64, 1, 1], 17 | [128, 2, 2], 18 | [256, 2, 2], 19 | [512, 6, 2], 20 | [1024, 2, 2]] 21 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 22 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)] 23 | 24 | for c, n, s in conv_dw_setting: 25 | out_channels = int(c * multiplier) 26 | for i in range(n): 27 | stride = s if i == 0 else 1 28 | features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer)) 29 | input_channels = out_channels 30 | features.append(nn.AdaptiveAvgPool2d(1)) 31 | self.features = nn.Sequential(*features) 32 | 33 | self.classifier = nn.Linear(int(1024 * multiplier), num_classes) 34 | 35 | # weight initialization 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | nn.init.zeros_(m.bias) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | nn.init.ones_(m.weight) 43 | nn.init.zeros_(m.bias) 44 | elif isinstance(m, nn.Linear): 45 | nn.init.normal_(m.weight, 0, 0.01) 46 | nn.init.zeros_(m.bias) 47 | 48 | def forward(self, x): 49 | x = self.features(x) 50 | x = self.classifier(x.view(x.size(0), x.size(1))) 51 | return x 52 | 53 | 54 | class MobileNetV2(nn.Module): 55 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs): 56 | super(MobileNetV2, self).__init__() 57 | inverted_residual_setting = [ 58 | # t, c, n, s 59 | [1, 16, 1, 1], 60 | [6, 24, 2, 2], 61 | [6, 32, 3, 2], 62 | [6, 64, 4, 2], 63 | [6, 96, 3, 1], 64 | [6, 160, 3, 2], 65 | [6, 320, 1, 1]] 66 | # building first layer 67 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 68 | last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 69 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)] 70 | 71 | # building inverted residual blocks 72 | for t, c, n, s in inverted_residual_setting: 73 | out_channels = int(c * multiplier) 74 | for i in range(n): 75 | stride = s if i == 0 else 1 76 | features.append(InvertedResidual(input_channels, out_channels, stride, t, norm_layer)) 77 | input_channels = out_channels 78 | 79 | # building last several layers 80 | features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer)) 81 | features.append(nn.AdaptiveAvgPool2d(1)) 82 | self.features = nn.Sequential(*features) 83 | 84 | self.classifier = nn.Sequential( 85 | nn.Dropout2d(0.2), 86 | nn.Linear(last_channels, num_classes)) 87 | 88 | # weight initialization 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 92 | if m.bias is not None: 93 | nn.init.zeros_(m.bias) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.ones_(m.weight) 96 | nn.init.zeros_(m.bias) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.normal_(m.weight, 0, 0.01) 99 | if m.bias is not None: 100 | nn.init.zeros_(m.bias) 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | x = self.classifier(x.view(x.size(0), x.size(1))) 105 
| return x 106 | 107 | 108 | # Constructor 109 | def get_mobilenet(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs): 110 | model = MobileNet(multiplier=multiplier, **kwargs) 111 | 112 | if pretrained: 113 | raise ValueError("Not support pretrained") 114 | return model 115 | 116 | 117 | def get_mobilenet_v2(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs): 118 | model = MobileNetV2(multiplier=multiplier, **kwargs) 119 | 120 | if pretrained: 121 | raise ValueError("Not support pretrained") 122 | return model 123 | 124 | 125 | def mobilenet1_0(**kwargs): 126 | return get_mobilenet(1.0, **kwargs) 127 | 128 | 129 | def mobilenet_v2_1_0(**kwargs): 130 | return get_mobilenet_v2(1.0, **kwargs) 131 | 132 | 133 | def mobilenet0_75(**kwargs): 134 | return get_mobilenet(0.75, **kwargs) 135 | 136 | 137 | def mobilenet_v2_0_75(**kwargs): 138 | return get_mobilenet_v2(0.75, **kwargs) 139 | 140 | 141 | def mobilenet0_5(**kwargs): 142 | return get_mobilenet(0.5, **kwargs) 143 | 144 | 145 | def mobilenet_v2_0_5(**kwargs): 146 | return get_mobilenet_v2(0.5, **kwargs) 147 | 148 | 149 | def mobilenet0_25(**kwargs): 150 | return get_mobilenet(0.25, **kwargs) 151 | 152 | 153 | def mobilenet_v2_0_25(**kwargs): 154 | return get_mobilenet_v2(0.25, **kwargs) 155 | 156 | 157 | if __name__ == '__main__': 158 | model = mobilenet0_5() 159 | -------------------------------------------------------------------------------- /core/models/base_models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | __all__ = ['ResNext', 'resnext50_32x4d', 'resnext101_32x8d'] 5 | 6 | model_urls = { 7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 9 | } 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 16 | base_width=64, dilation=1, norm_layer=None, **kwargs): 17 | super(Bottleneck, self).__init__() 18 | width = int(planes * (base_width / 64.)) * groups 19 | 20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 21 | self.bn1 = norm_layer(width) 22 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False) 23 | self.bn2 = norm_layer(width) 24 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 25 | self.bn3 = norm_layer(planes * self.expansion) 26 | self.relu = nn.ReLU(True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | identity = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class ResNext(nn.Module): 54 | 55 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, 56 | width_per_group=64, dilated=False, norm_layer=nn.BatchNorm2d, **kwargs): 57 | super(ResNext, self).__init__() 58 | self.inplanes = 64 59 | self.groups = groups 60 | self.base_width = width_per_group 61 | 62 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False) 63 | self.bn1 = norm_layer(self.inplanes) 64 | 
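# Stem of ResNeXt: the 7x7/stride-2 conv + BN here, together with the ReLU and
# 3x3/stride-2 max-pool that follow, reduce resolution 4x before layer1.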
self.relu = nn.ReLU(True) 65 | self.maxpool = nn.MaxPool2d(3, 2, 1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 69 | if dilated: 70 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer) 71 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer) 72 | else: 73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 74 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 75 | 76 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 77 | self.fc = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 82 | elif isinstance(m, nn.BatchNorm2d): 83 | nn.init.constant_(m.weight, 1) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | if zero_init_residual: 87 | for m in self.modules(): 88 | if isinstance(m, Bottleneck): 89 | nn.init.constant_(m.bn3.weight, 0) 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 96 | norm_layer(planes * block.expansion) 97 | ) 98 | 99 | layers = list() 100 | if dilation in (1, 2): 101 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 102 | self.base_width, norm_layer=norm_layer)) 103 | elif dilation == 4: 104 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 105 | self.base_width, dilation=2, norm_layer=norm_layer)) 106 | else: 107 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 108 | self.inplanes = planes * block.expansion 109 | for _ in range(1, blocks): 110 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, 111 | dilation=dilation, norm_layer=norm_layer)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | x = self.maxpool(x) 120 | 121 | x = self.layer1(x) 122 | x = self.layer2(x) 123 | x = self.layer3(x) 124 | x = self.layer4(x) 125 | 126 | x = self.avgpool(x) 127 | x = x.view(x.size(0), -1) 128 | x = self.fc(x) 129 | 130 | return x 131 | 132 | 133 | def resnext50_32x4d(pretrained=False, **kwargs): 134 | kwargs['groups'] = 32 135 | kwargs['width_per_group'] = 4 136 | model = ResNext(Bottleneck, [3, 4, 6, 3], **kwargs) 137 | if pretrained: 138 | state_dict = model_zoo.load_url(model_urls['resnext50_32x4d']) 139 | model.load_state_dict(state_dict) 140 | return model 141 | 142 | 143 | def resnext101_32x8d(pretrained=False, **kwargs): 144 | kwargs['groups'] = 32 145 | kwargs['width_per_group'] = 8 146 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs) 147 | if pretrained: 148 | state_dict = model_zoo.load_url(model_urls['resnext101_32x8d']) 149 | model.load_state_dict(state_dict) 150 | return model 151 | 152 | 153 | if __name__ == '__main__': 154 | model = resnext101_32x8d() 155 | -------------------------------------------------------------------------------- /core/models/base_models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = [ 6 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 7 | 'vgg19_bn', 'vgg19', 8 | ] 9 | 10 | model_urls = { 11 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 12 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 13 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 14 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 15 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 16 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 17 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 18 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 19 | } 20 | 21 | 22 | class VGG(nn.Module): 23 | def __init__(self, features, num_classes=1000, init_weights=True): 24 | super(VGG, self).__init__() 25 | self.features = features 26 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 27 | self.classifier = nn.Sequential( 28 | nn.Linear(512 * 7 * 7, 4096), 29 | nn.ReLU(True), 30 | nn.Dropout(), 31 | nn.Linear(4096, 4096), 32 | nn.ReLU(True), 33 | nn.Dropout(), 34 | nn.Linear(4096, num_classes) 35 | ) 36 | if init_weights: 37 | self._initialize_weights() 38 | 39 | def forward(self, x): 40 | x = self.features(x) 41 | x = self.avgpool(x) 42 | x = x.view(x.size(0), -1) 43 | x = self.classifier(x) 44 | return x 45 | 46 | def _initialize_weights(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 50 | if m.bias is not None: 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.BatchNorm2d): 53 | nn.init.constant_(m.weight, 1) 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.Linear): 56 | nn.init.normal_(m.weight, 0, 0.01) 57 | nn.init.constant_(m.bias, 0) 58 | 59 | 60 | def make_layers(cfg, batch_norm=False): 61 | layers = [] 62 | in_channels = 3 63 | for v in cfg: 64 | if v == 'M': 65 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 66 | else: 67 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 68 | if batch_norm: 69 | layers += (conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)) 70 | else: 71 | layers += [conv2d, nn.ReLU(inplace=True)] 72 | in_channels = v 73 | return nn.Sequential(*layers) 74 | 75 | 76 | cfg = { 77 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 78 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 79 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 80 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 81 | } 82 | 83 | 84 | def vgg11(pretrained=False, **kwargs): 85 | """VGG 11-layer model (configuration "A") 86 | Args: 87 | pretrained (bool): If True, returns a model pre-trained on ImageNet 88 | """ 89 | if pretrained: 90 | kwargs['init_weights'] = False 91 | model = VGG(make_layers(cfg['A']), **kwargs) 92 | if pretrained: 93 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 94 | return model 95 | 96 | 97 | def vgg11_bn(pretrained=False, **kwargs): 98 | """VGG 11-layer model (configuration "A") with batch normalization 99 | Args: 100 | pretrained (bool): If True, returns a model pre-trained on ImageNet 101 | """ 102 | if pretrained: 103 | kwargs['init_weights'] = False 104 | model = VGG(make_layers(cfg['A'], batch_norm=True), 
**kwargs) 105 | if pretrained: 106 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 107 | return model 108 | 109 | 110 | def vgg13(pretrained=False, **kwargs): 111 | """VGG 13-layer model (configuration "B") 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | if pretrained: 116 | kwargs['init_weights'] = False 117 | model = VGG(make_layers(cfg['B']), **kwargs) 118 | if pretrained: 119 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 120 | return model 121 | 122 | 123 | def vgg13_bn(pretrained=False, **kwargs): 124 | """VGG 13-layer model (configuration "B") with batch normalization 125 | Args: 126 | pretrained (bool): If True, returns a model pre-trained on ImageNet 127 | """ 128 | if pretrained: 129 | kwargs['init_weights'] = False 130 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 131 | if pretrained: 132 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 133 | return model 134 | 135 | 136 | def vgg16(pretrained=False, **kwargs): 137 | """VGG 16-layer model (configuration "D") 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | """ 141 | if pretrained: 142 | kwargs['init_weights'] = False 143 | model = VGG(make_layers(cfg['D']), **kwargs) 144 | if pretrained: 145 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 146 | return model 147 | 148 | 149 | def vgg16_bn(pretrained=False, **kwargs): 150 | """VGG 16-layer model (configuration "D") with batch normalization 151 | Args: 152 | pretrained (bool): If True, returns a model pre-trained on ImageNet 153 | """ 154 | if pretrained: 155 | kwargs['init_weights'] = False 156 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 157 | if pretrained: 158 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 159 | return model 160 | 161 | 162 | def vgg19(pretrained=False, **kwargs): 163 | """VGG 19-layer model (configuration "E") 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | if pretrained: 168 | kwargs['init_weights'] = False 169 | model = VGG(make_layers(cfg['E']), **kwargs) 170 | if pretrained: 171 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 172 | return model 173 | 174 | 175 | def vgg19_bn(pretrained=False, **kwargs): 176 | """VGG 19-layer model (configuration 'E') with batch normalization 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | if pretrained: 181 | kwargs['init_weights'] = False 182 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 185 | return model 186 | 187 | 188 | if __name__ == '__main__': 189 | img = torch.randn((4, 3, 480, 480)) 190 | model = vgg16(pretrained=False) 191 | out = model(img) 192 | -------------------------------------------------------------------------------- /core/models/deeplabv3.py: -------------------------------------------------------------------------------- 1 | """Pyramid Scene Parsing Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .fcn import _FCNHead 8 | 9 | __all__ = ['DeepLabV3', 'get_deeplabv3', 'get_deeplabv3_resnet50_voc', 'get_deeplabv3_resnet101_voc', 10 | 'get_deeplabv3_resnet152_voc', 'get_deeplabv3_resnet50_ade', 'get_deeplabv3_resnet101_ade', 11 | 
'get_deeplabv3_resnet152_ade'] 12 | 13 | 14 | class DeepLabV3(SegBaseModel): 15 | r"""DeepLabV3 16 | 17 | Parameters 18 | ---------- 19 | nclass : int 20 | Number of categories for the training dataset. 21 | backbone : string 22 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 23 | 'resnet101' or 'resnet152'). 24 | norm_layer : object 25 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 26 | for Synchronized Cross-GPU BachNormalization). 27 | aux : bool 28 | Auxiliary loss. 29 | 30 | Reference: 31 | Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation." 32 | arXiv preprint arXiv:1706.05587 (2017). 33 | """ 34 | 35 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs): 36 | super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 37 | self.head = _DeepLabHead(nclass, **kwargs) 38 | if self.aux: 39 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 40 | 41 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 42 | 43 | def forward(self, x): 44 | size = x.size()[2:] 45 | _, _, c3, c4 = self.base_forward(x) 46 | outputs = [] 47 | x = self.head(c4) 48 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 49 | outputs.append(x) 50 | 51 | if self.aux: 52 | auxout = self.auxlayer(c3) 53 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 54 | outputs.append(auxout) 55 | return tuple(outputs) 56 | 57 | 58 | class _DeepLabHead(nn.Module): 59 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 60 | super(_DeepLabHead, self).__init__() 61 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs) 62 | self.block = nn.Sequential( 63 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 64 | norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)), 65 | nn.ReLU(True), 66 | nn.Dropout(0.1), 67 | nn.Conv2d(256, nclass, 1) 68 | ) 69 | 70 | def forward(self, x): 71 | x = self.aspp(x) 72 | return self.block(x) 73 | 74 | 75 | class _ASPPConv(nn.Module): 76 | def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs): 77 | super(_ASPPConv, self).__init__() 78 | self.block = nn.Sequential( 79 | nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False), 80 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 81 | nn.ReLU(True) 82 | ) 83 | 84 | def forward(self, x): 85 | return self.block(x) 86 | 87 | 88 | class _AsppPooling(nn.Module): 89 | def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, **kwargs): 90 | super(_AsppPooling, self).__init__() 91 | self.gap = nn.Sequential( 92 | nn.AdaptiveAvgPool2d(1), 93 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 94 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 95 | nn.ReLU(True) 96 | ) 97 | 98 | def forward(self, x): 99 | size = x.size()[2:] 100 | pool = self.gap(x) 101 | out = F.interpolate(pool, size, mode='bilinear', align_corners=True) 102 | return out 103 | 104 | 105 | class _ASPP(nn.Module): 106 | def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs): 107 | super(_ASPP, self).__init__() 108 | out_channels = 256 109 | self.b0 = nn.Sequential( 110 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 111 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 112 | nn.ReLU(True) 113 
| ) 114 | 115 | rate1, rate2, rate3 = tuple(atrous_rates) 116 | self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs) 117 | self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs) 118 | self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs) 119 | self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 120 | 121 | self.project = nn.Sequential( 122 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 123 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 124 | nn.ReLU(True), 125 | nn.Dropout(0.5) 126 | ) 127 | 128 | def forward(self, x): 129 | feat1 = self.b0(x) 130 | feat2 = self.b1(x) 131 | feat3 = self.b2(x) 132 | feat4 = self.b3(x) 133 | feat5 = self.b4(x) 134 | x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 135 | x = self.project(x) 136 | return x 137 | 138 | 139 | def get_deeplabv3(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 140 | pretrained_base=True, **kwargs): 141 | acronyms = { 142 | 'pascal_voc': 'pascal_voc', 143 | 'pascal_aug': 'pascal_aug', 144 | 'ade20k': 'ade', 145 | 'coco': 'coco', 146 | 'citys': 'citys', 147 | } 148 | from ..data.dataloader import datasets 149 | model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 150 | if pretrained: 151 | from .model_store import get_model_file 152 | device = torch.device(kwargs['local_rank']) 153 | model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root), 154 | map_location=device)) 155 | return model 156 | 157 | 158 | def get_deeplabv3_resnet50_voc(**kwargs): 159 | return get_deeplabv3('pascal_voc', 'resnet50', **kwargs) 160 | 161 | 162 | def get_deeplabv3_resnet101_voc(**kwargs): 163 | return get_deeplabv3('pascal_voc', 'resnet101', **kwargs) 164 | 165 | 166 | def get_deeplabv3_resnet152_voc(**kwargs): 167 | return get_deeplabv3('pascal_voc', 'resnet152', **kwargs) 168 | 169 | 170 | def get_deeplabv3_resnet50_ade(**kwargs): 171 | return get_deeplabv3('ade20k', 'resnet50', **kwargs) 172 | 173 | 174 | def get_deeplabv3_resnet101_ade(**kwargs): 175 | return get_deeplabv3('ade20k', 'resnet101', **kwargs) 176 | 177 | 178 | def get_deeplabv3_resnet152_ade(**kwargs): 179 | return get_deeplabv3('ade20k', 'resnet152', **kwargs) 180 | 181 | 182 | if __name__ == '__main__': 183 | model = get_deeplabv3_resnet50_voc() 184 | img = torch.randn(2, 3, 480, 480) 185 | output = model(img) 186 | -------------------------------------------------------------------------------- /core/models/deeplabv3_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_models.xception import get_xception 6 | from .deeplabv3 import _ASPP 7 | from .fcn import _FCNHead 8 | from ..nn import _ConvBNReLU 9 | 10 | __all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc'] 11 | 12 | 13 | class DeepLabV3Plus(nn.Module): 14 | r"""DeepLabV3Plus 15 | Parameters 16 | ---------- 17 | nclass : int 18 | Number of categories for the training dataset. 19 | backbone : string 20 | Pre-trained dilated backbone network type (default:'xception'). 21 | norm_layer : object 22 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 23 | for Synchronized Cross-GPU BachNormalization). 
24 | aux : bool 25 | Auxiliary loss. 26 | 27 | Reference: 28 | Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic 29 | Image Segmentation." 30 | """ 31 | 32 | def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs): 33 | super(DeepLabV3Plus, self).__init__() 34 | self.aux = aux 35 | self.nclass = nclass 36 | output_stride = 8 if dilated else 32 37 | 38 | self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs) 39 | 40 | # deeplabv3 plus 41 | self.head = _DeepLabHead(nclass, **kwargs) 42 | if aux: 43 | self.auxlayer = _FCNHead(728, nclass, **kwargs) 44 | 45 | def base_forward(self, x): 46 | # Entry flow 47 | x = self.pretrained.conv1(x) 48 | x = self.pretrained.bn1(x) 49 | x = self.pretrained.relu(x) 50 | 51 | x = self.pretrained.conv2(x) 52 | x = self.pretrained.bn2(x) 53 | x = self.pretrained.relu(x) 54 | 55 | x = self.pretrained.block1(x) 56 | # add relu here 57 | x = self.pretrained.relu(x) 58 | low_level_feat = x 59 | 60 | x = self.pretrained.block2(x) 61 | x = self.pretrained.block3(x) 62 | 63 | # Middle flow 64 | x = self.pretrained.midflow(x) 65 | mid_level_feat = x 66 | 67 | # Exit flow 68 | x = self.pretrained.block20(x) 69 | x = self.pretrained.relu(x) 70 | x = self.pretrained.conv3(x) 71 | x = self.pretrained.bn3(x) 72 | x = self.pretrained.relu(x) 73 | 74 | x = self.pretrained.conv4(x) 75 | x = self.pretrained.bn4(x) 76 | x = self.pretrained.relu(x) 77 | 78 | x = self.pretrained.conv5(x) 79 | x = self.pretrained.bn5(x) 80 | x = self.pretrained.relu(x) 81 | return low_level_feat, mid_level_feat, x 82 | 83 | def forward(self, x): 84 | size = x.size()[2:] 85 | c1, c3, c4 = self.base_forward(x) 86 | outputs = list() 87 | x = self.head(c4, c1) 88 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 89 | outputs.append(x) 90 | if self.aux: 91 | auxout = self.auxlayer(c3) 92 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 93 | outputs.append(auxout) 94 | return tuple(outputs) 95 | 96 | 97 | class _DeepLabHead(nn.Module): 98 | def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs): 99 | super(_DeepLabHead, self).__init__() 100 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs) 101 | self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, norm_layer=norm_layer) 102 | self.block = nn.Sequential( 103 | _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer), 104 | nn.Dropout(0.5), 105 | _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer), 106 | nn.Dropout(0.1), 107 | nn.Conv2d(256, nclass, 1)) 108 | 109 | def forward(self, x, c1): 110 | size = c1.size()[2:] 111 | c1 = self.c1_block(c1) 112 | x = self.aspp(x) 113 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 114 | return self.block(torch.cat([x, c1], dim=1)) 115 | 116 | 117 | def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models', 118 | pretrained_base=True, **kwargs): 119 | acronyms = { 120 | 'pascal_voc': 'pascal_voc', 121 | 'pascal_aug': 'pascal_aug', 122 | 'ade20k': 'ade', 123 | 'coco': 'coco', 124 | 'citys': 'citys', 125 | } 126 | from ..data.dataloader import datasets 127 | model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 128 | if pretrained: 129 | from .model_store import get_model_file 130 | device = torch.device(kwargs['local_rank']) 131 | model.load_state_dict( 132 | 
torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root), 133 | map_location=device)) 134 | return model 135 | 136 | 137 | def get_deeplabv3_plus_xception_voc(**kwargs): 138 | return get_deeplabv3_plus('pascal_voc', 'xception', **kwargs) 139 | 140 | 141 | if __name__ == '__main__': 142 | model = get_deeplabv3_plus_xception_voc() 143 | -------------------------------------------------------------------------------- /core/models/dfanet.py: -------------------------------------------------------------------------------- 1 | """ Deep Feature Aggregation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.models.base_models import Enc, FCAttention, get_xception_a 7 | from core.nn import _ConvBNReLU 8 | 9 | __all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys'] 10 | 11 | 12 | class DFANet(nn.Module): 13 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs): 14 | super(DFANet, self).__init__() 15 | self.pretrained = get_xception_a(pretrained_base, **kwargs) 16 | 17 | self.enc2_2 = Enc(240, 48, 4, **kwargs) 18 | self.enc3_2 = Enc(144, 96, 6, **kwargs) 19 | self.enc4_2 = Enc(288, 192, 4, **kwargs) 20 | self.fca_2 = FCAttention(192, **kwargs) 21 | 22 | self.enc2_3 = Enc(240, 48, 4, **kwargs) 23 | self.enc3_3 = Enc(144, 96, 6, **kwargs) 24 | self.enc3_4 = Enc(288, 192, 4, **kwargs) 25 | self.fca_3 = FCAttention(192, **kwargs) 26 | 27 | self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 28 | self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 29 | self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 30 | self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs) 31 | 32 | self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 33 | self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 34 | self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 35 | self.conv_out = nn.Conv2d(32, nclass, 1) 36 | 37 | self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3', 38 | 'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce', 39 | 'fca_2_reduce', 'fca_3_reduce', 'conv_out']) 40 | 41 | def forward(self, x): 42 | # backbone 43 | stage1_conv1 = self.pretrained.conv1(x) 44 | stage1_enc2 = self.pretrained.enc2(stage1_conv1) 45 | stage1_enc3 = self.pretrained.enc3(stage1_enc2) 46 | stage1_enc4 = self.pretrained.enc4(stage1_enc3) 47 | stage1_fca = self.pretrained.fca(stage1_enc4) 48 | stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True) 49 | 50 | # stage2 51 | stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1)) 52 | stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1)) 53 | stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1)) 54 | stage2_fca = self.fca_2(stage2_enc4) 55 | stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True) 56 | 57 | # stage3 58 | stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1)) 59 | stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1)) 60 | stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1)) 61 | stage3_fca = self.fca_3(stage3_enc4) 62 | 63 | stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2) 64 | stage2_enc2_docoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2, 65 | mode='bilinear', align_corners=True) 66 | stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4, 
67 | mode='bilinear', align_corners=True) 68 | fusion = stage1_enc2_decoder + stage2_enc2_docoder + stage3_enc2_decoder 69 | fusion = self.conv_fusion(fusion) 70 | 71 | stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4, 72 | mode='bilinear', align_corners=True) 73 | stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8, 74 | mode='bilinear', align_corners=True) 75 | stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16, 76 | mode='bilinear', align_corners=True) 77 | fusion = fusion + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder 78 | 79 | outputs = list() 80 | out = self.conv_out(fusion) 81 | out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True) 82 | outputs.append(out) 83 | 84 | return tuple(outputs) 85 | 86 | 87 | def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', 88 | pretrained_base=True, **kwargs): 89 | acronyms = { 90 | 'pascal_voc': 'pascal_voc', 91 | 'pascal_aug': 'pascal_aug', 92 | 'ade20k': 'ade', 93 | 'coco': 'coco', 94 | 'citys': 'citys', 95 | } 96 | from ..data.dataloader import datasets 97 | model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 98 | if pretrained: 99 | from .model_store import get_model_file 100 | device = torch.device(kwargs['local_rank']) 101 | model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root), 102 | map_location=device)) 103 | return model 104 | 105 | 106 | def get_dfanet_citys(**kwargs): 107 | return get_dfanet('citys', **kwargs) 108 | 109 | 110 | if __name__ == '__main__': 111 | model = get_dfanet_citys() 112 | -------------------------------------------------------------------------------- /core/models/dunet.py: -------------------------------------------------------------------------------- 1 | """Decoders Matter for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .fcn import _FCNHead 8 | 9 | __all__ = ['DUNet', 'get_dunet', 'get_dunet_resnet50_pascal_voc', 10 | 'get_dunet_resnet101_pascal_voc', 'get_dunet_resnet152_pascal_voc'] 11 | 12 | 13 | # The model may be wrong because lots of details missing in paper. 14 | class DUNet(SegBaseModel): 15 | """Decoders Matter for Semantic Segmentation 16 | 17 | Reference: 18 | Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan. 19 | "Decoders Matter for Semantic Segmentation: 20 | Data-Dependent Decoding Enables Flexible Feature Aggregation." 
CVPR, 2019 21 | """ 22 | 23 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs): 24 | super(DUNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 25 | self.head = _DUHead(2144, **kwargs) 26 | self.dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) 27 | if aux: 28 | self.auxlayer = _FCNHead(1024, 256, **kwargs) 29 | self.aux_dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) 30 | 31 | self.__setattr__('exclusive', 32 | ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if aux else ['dupsample', 'head']) 33 | 34 | def forward(self, x): 35 | c1, c2, c3, c4 = self.base_forward(x) 36 | outputs = [] 37 | x = self.head(c2, c3, c4) 38 | x = self.dupsample(x) 39 | outputs.append(x) 40 | 41 | if self.aux: 42 | auxout = self.auxlayer(c3) 43 | auxout = self.aux_dupsample(auxout) 44 | outputs.append(auxout) 45 | return tuple(outputs) 46 | 47 | 48 | class FeatureFused(nn.Module): 49 | """Module for fused features""" 50 | 51 | def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d, **kwargs): 52 | super(FeatureFused, self).__init__() 53 | self.conv2 = nn.Sequential( 54 | nn.Conv2d(512, inter_channels, 1, bias=False), 55 | norm_layer(inter_channels), 56 | nn.ReLU(True) 57 | ) 58 | self.conv3 = nn.Sequential( 59 | nn.Conv2d(1024, inter_channels, 1, bias=False), 60 | norm_layer(inter_channels), 61 | nn.ReLU(True) 62 | ) 63 | 64 | def forward(self, c2, c3, c4): 65 | size = c4.size()[2:] 66 | c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True)) 67 | c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True)) 68 | fused_feature = torch.cat([c4, c3, c2], dim=1) 69 | return fused_feature 70 | 71 | 72 | class _DUHead(nn.Module): 73 | def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs): 74 | super(_DUHead, self).__init__() 75 | self.fuse = FeatureFused(norm_layer=norm_layer, **kwargs) 76 | self.block = nn.Sequential( 77 | nn.Conv2d(in_channels, 256, 3, padding=1, bias=False), 78 | norm_layer(256), 79 | nn.ReLU(True), 80 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 81 | norm_layer(256), 82 | nn.ReLU(True) 83 | ) 84 | 85 | def forward(self, c2, c3, c4): 86 | fused_feature = self.fuse(c2, c3, c4) 87 | out = self.block(fused_feature) 88 | return out 89 | 90 | 91 | class DUpsampling(nn.Module): 92 | """DUsampling module""" 93 | 94 | def __init__(self, in_channels, out_channels, scale_factor=2, **kwargs): 95 | super(DUpsampling, self).__init__() 96 | self.scale_factor = scale_factor 97 | self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False) 98 | 99 | def forward(self, x): 100 | x = self.conv_w(x) 101 | n, c, h, w = x.size() 102 | 103 | # N, C, H, W --> N, W, H, C 104 | x = x.permute(0, 3, 2, 1).contiguous() 105 | 106 | # N, W, H, C --> N, W, H * scale, C // scale 107 | x = x.view(n, w, h * self.scale_factor, c // self.scale_factor) 108 | 109 | # N, W, H * scale, C // scale --> N, H * scale, W, C // scale 110 | x = x.permute(0, 2, 1, 3).contiguous() 111 | 112 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) 113 | x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * self.scale_factor)) 114 | 115 | # N, H * scale, W * scale, C // (scale ** 2) -- > N, C // (scale ** 2), H * scale, W * scale 116 | x = x.permute(0, 3, 1, 2) 117 | 118 | return x 119 | 120 | 121 | def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False, 122 | 
root='~/.torch/models', pretrained_base=True, **kwargs): 123 | acronyms = { 124 | 'pascal_voc': 'pascal_voc', 125 | 'pascal_aug': 'pascal_aug', 126 | 'ade20k': 'ade', 127 | 'coco': 'coco', 128 | 'citys': 'citys', 129 | } 130 | from ..data.dataloader import datasets 131 | model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 132 | if pretrained: 133 | from .model_store import get_model_file 134 | device = torch.device(kwargs['local_rank']) 135 | model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root), 136 | map_location=device)) 137 | return model 138 | 139 | 140 | def get_dunet_resnet50_pascal_voc(**kwargs): 141 | return get_dunet('pascal_voc', 'resnet50', **kwargs) 142 | 143 | 144 | def get_dunet_resnet101_pascal_voc(**kwargs): 145 | return get_dunet('pascal_voc', 'resnet101', **kwargs) 146 | 147 | 148 | def get_dunet_resnet152_pascal_voc(**kwargs): 149 | return get_dunet('pascal_voc', 'resnet152', **kwargs) 150 | 151 | 152 | if __name__ == '__main__': 153 | img = torch.randn(2, 3, 256, 256) 154 | model = get_dunet_resnet50_pascal_voc() 155 | outputs = model(img) 156 | -------------------------------------------------------------------------------- /core/models/espnet.py: -------------------------------------------------------------------------------- 1 | "ESPNetv2: A Light-weight, Power Efficient, and General Purpose for Semantic Segmentation" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.models.base_models import eespnet, EESP 7 | from core.nn import _ConvBNPReLU, _BNPReLU 8 | 9 | 10 | class ESPNetV2(nn.Module): 11 | r"""ESPNetV2 12 | 13 | Parameters 14 | ---------- 15 | nclass : int 16 | Number of categories for the training dataset. 17 | backbone : string 18 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 19 | 'resnet101' or 'resnet152'). 20 | norm_layer : object 21 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 22 | for Synchronized Cross-GPU BachNormalization). 23 | aux : bool 24 | Auxiliary loss. 25 | 26 | Reference: 27 | Sachin Mehta, et al. "ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network." 28 | arXiv preprint arXiv:1811.11431 (2018). 
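    Note: the ``backbone`` description above is shared boilerplate; this model
    always builds its encoder with ``eespnet`` and does not use the ``backbone``
    argument.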
29 | """ 30 | 31 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs): 32 | super(ESPNetV2, self).__init__() 33 | self.pretrained = eespnet(pretrained=pretrained_base, **kwargs) 34 | self.proj_L4_C = _ConvBNPReLU(256, 128, 1, **kwargs) 35 | self.pspMod = nn.Sequential( 36 | EESP(256, 128, stride=1, k=4, r_lim=7, **kwargs), 37 | _PSPModule(128, 128, **kwargs)) 38 | self.project_l3 = nn.Sequential( 39 | nn.Dropout2d(0.1), 40 | nn.Conv2d(128, nclass, 1, bias=False)) 41 | self.act_l3 = _BNPReLU(nclass, **kwargs) 42 | self.project_l2 = _ConvBNPReLU(64 + nclass, nclass, 1, **kwargs) 43 | self.project_l1 = nn.Sequential( 44 | nn.Dropout2d(0.1), 45 | nn.Conv2d(32 + nclass, nclass, 1, bias=False)) 46 | 47 | self.aux = aux 48 | 49 | self.__setattr__('exclusive', ['proj_L4_C', 'pspMod', 'project_l3', 'act_l3', 'project_l2', 'project_l1']) 50 | 51 | def forward(self, x): 52 | size = x.size()[2:] 53 | out_l1, out_l2, out_l3, out_l4 = self.pretrained(x, seg=True) 54 | out_l4_proj = self.proj_L4_C(out_l4) 55 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True) 56 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1)) 57 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4) 58 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act) 59 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True) 60 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1)) 61 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True) 62 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1)) 63 | 64 | outputs = list() 65 | merge1_l1 = F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True) 66 | outputs.append(merge1_l1) 67 | if self.aux: 68 | # different from paper 69 | auxout = F.interpolate(proj_merge_l3_bef_act, size, mode='bilinear', align_corners=True) 70 | outputs.append(auxout) 71 | 72 | return tuple(outputs) 73 | 74 | 75 | # different from PSPNet 76 | class _PSPModule(nn.Module): 77 | def __init__(self, in_channels, out_channels=1024, sizes=(1, 2, 4, 8), **kwargs): 78 | super(_PSPModule, self).__init__() 79 | self.stages = nn.ModuleList( 80 | [nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels, bias=False) for _ in sizes]) 81 | self.project = _ConvBNPReLU(in_channels * (len(sizes) + 1), out_channels, 1, 1, **kwargs) 82 | 83 | def forward(self, x): 84 | size = x.size()[2:] 85 | feats = [x] 86 | for stage in self.stages: 87 | x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1) 88 | upsampled = F.interpolate(stage(x), size, mode='bilinear', align_corners=True) 89 | feats.append(upsampled) 90 | return self.project(torch.cat(feats, dim=1)) 91 | 92 | 93 | def get_espnet(dataset='pascal_voc', backbone='', pretrained=False, root='~/.torch/models', 94 | pretrained_base=False, **kwargs): 95 | acronyms = { 96 | 'pascal_voc': 'pascal_voc', 97 | 'pascal_aug': 'pascal_aug', 98 | 'ade20k': 'ade', 99 | 'coco': 'coco', 100 | 'citys': 'citys', 101 | } 102 | from core.data.dataloader import datasets 103 | model = ESPNetV2(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 104 | if pretrained: 105 | from .model_store import get_model_file 106 | device = torch.device(kwargs['local_rank']) 107 | model.load_state_dict(torch.load(get_model_file('espnet_%s_%s' % (backbone, acronyms[dataset]), root=root), 108 | map_location=device)) 109 | return model 110 | 111 | 112 | def 
get_espnet_citys(**kwargs): 113 | return get_espnet('citys', **kwargs) 114 | 115 | 116 | if __name__ == '__main__': 117 | model = get_espnet_citys() 118 | -------------------------------------------------------------------------------- /core/models/fcnv2.py: -------------------------------------------------------------------------------- 1 | """Fully Convolutional Network with Stride of 8""" 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .segbase import SegBaseModel 9 | 10 | __all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_voc', 11 | 'get_fcn_resnet101_voc', 'get_fcn_resnet152_voc'] 12 | 13 | 14 | class FCN(SegBaseModel): 15 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs): 16 | super(FCN, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 17 | self.head = _FCNHead(2048, nclass, **kwargs) 18 | if aux: 19 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 20 | 21 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 22 | 23 | def forward(self, x): 24 | size = x.size()[2:] 25 | _, _, c3, c4 = self.base_forward(x) 26 | 27 | outputs = [] 28 | x = self.head(c4) 29 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 30 | outputs.append(x) 31 | if self.aux: 32 | auxout = self.auxlayer(c3) 33 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 34 | outputs.append(auxout) 35 | return tuple(outputs) 36 | 37 | 38 | class _FCNHead(nn.Module): 39 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 40 | super(_FCNHead, self).__init__() 41 | inter_channels = in_channels // 4 42 | self.block = nn.Sequential( 43 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 44 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 45 | nn.ReLU(True), 46 | nn.Dropout(0.1), 47 | nn.Conv2d(inter_channels, channels, 1) 48 | ) 49 | 50 | def forward(self, x): 51 | return self.block(x) 52 | 53 | 54 | def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 55 | pretrained_base=True, **kwargs): 56 | acronyms = { 57 | 'pascal_voc': 'pascal_voc', 58 | 'pascal_aug': 'pascal_aug', 59 | 'ade20k': 'ade', 60 | 'coco': 'coco', 61 | 'citys': 'citys', 62 | } 63 | from ..data.dataloader import datasets 64 | model = FCN(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 65 | if pretrained: 66 | from .model_store import get_model_file 67 | device = torch.device(kwargs['local_rank']) 68 | model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root), 69 | map_location=device)) 70 | return model 71 | 72 | 73 | def get_fcn_resnet50_voc(**kwargs): 74 | return get_fcn('pascal_voc', 'resnet50', **kwargs) 75 | 76 | 77 | def get_fcn_resnet101_voc(**kwargs): 78 | return get_fcn('pascal_voc', 'resnet101', **kwargs) 79 | 80 | 81 | def get_fcn_resnet152_voc(**kwargs): 82 | return get_fcn('pascal_voc', 'resnet152', **kwargs) 83 | -------------------------------------------------------------------------------- /core/models/hrnet.py: -------------------------------------------------------------------------------- 1 | """High-Resolution Representations for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class HRNet(nn.Module): 7 | """HRNet 8 | 9 | Parameters 10 | ---------- 
11 | nclass : int 12 | Number of categories for the training dataset. 13 | backbone : string 14 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 15 | 'resnet101' or 'resnet152'). 16 | norm_layer : object 17 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 18 | for Synchronized Cross-GPU BachNormalization). 19 | aux : bool 20 | Auxiliary loss. 21 | Reference: 22 | Ke Sun. "High-Resolution Representations for Labeling Pixels and Regions." 23 | arXiv preprint arXiv:1904.04514 (2019). 24 | """ 25 | def __init__(self, nclass, backbone='', aux=False, pretrained_base=False, **kwargs): 26 | super(HRNet, self).__init__() 27 | 28 | def forward(self, x): 29 | pass -------------------------------------------------------------------------------- /core/models/icnet.py: -------------------------------------------------------------------------------- 1 | """Image Cascade Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | 8 | __all__ = ['ICNet', 'get_icnet', 'get_icnet_resnet50_citys', 9 | 'get_icnet_resnet101_citys', 'get_icnet_resnet152_citys'] 10 | 11 | 12 | class ICNet(SegBaseModel): 13 | """Image Cascade Network""" 14 | 15 | def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained_base=True, **kwargs): 16 | super(ICNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 17 | self.conv_sub1 = nn.Sequential( 18 | _ConvBNReLU(3, 32, 3, 2, **kwargs), 19 | _ConvBNReLU(32, 32, 3, 2, **kwargs), 20 | _ConvBNReLU(32, 64, 3, 2, **kwargs) 21 | ) 22 | 23 | self.ppm = PyramidPoolingModule() 24 | 25 | self.head = _ICHead(nclass, **kwargs) 26 | 27 | self.__setattr__('exclusive', ['conv_sub1', 'head']) 28 | 29 | def forward(self, x): 30 | # sub 1 31 | x_sub1 = self.conv_sub1(x) 32 | 33 | # sub 2 34 | x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) 35 | _, x_sub2, _, _ = self.base_forward(x_sub2) 36 | 37 | # sub 4 38 | x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 39 | _, _, _, x_sub4 = self.base_forward(x_sub4) 40 | # add PyramidPoolingModule 41 | x_sub4 = self.ppm(x_sub4) 42 | outputs = self.head(x_sub1, x_sub2, x_sub4) 43 | 44 | return tuple(outputs) 45 | 46 | class PyramidPoolingModule(nn.Module): 47 | def __init__(self, pyramids=[1,2,3,6]): 48 | super(PyramidPoolingModule, self).__init__() 49 | self.pyramids = pyramids 50 | 51 | def forward(self, input): 52 | feat = input 53 | height, width = input.shape[2:] 54 | for bin_size in self.pyramids: 55 | x = F.adaptive_avg_pool2d(input, output_size=bin_size) 56 | x = F.interpolate(x, size=(height, width), mode='bilinear', align_corners=True) 57 | feat = feat + x 58 | return feat 59 | 60 | class _ICHead(nn.Module): 61 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 62 | super(_ICHead, self).__init__() 63 | #self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs) 64 | self.cff_12 = CascadeFeatureFusion(128, 64, 128, nclass, norm_layer, **kwargs) 65 | self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs) 66 | 67 | self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False) 68 | 69 | def forward(self, x_sub1, x_sub2, x_sub4): 70 | outputs = list() 71 | x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2) 72 | outputs.append(x_24_cls) 73 | #x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1) 74 | x_cff_12, x_12_cls = self.cff_12(x_cff_24, x_sub1) 75 | 
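# Cascade order: the finest branch (x_sub1) is fused with the already-fused
# coarse output x_cff_24 rather than with x_sub2 directly; the commented-out
# call above keeps the alternative wiring for reference.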
outputs.append(x_12_cls) 76 | 77 | up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True) 78 | up_x2 = self.conv_cls(up_x2) 79 | outputs.append(up_x2) 80 | up_x8 = F.interpolate(up_x2, scale_factor=4, mode='bilinear', align_corners=True) 81 | outputs.append(up_x8) 82 | # 1 -> 1/4 -> 1/8 -> 1/16 83 | outputs.reverse() 84 | 85 | return outputs 86 | 87 | 88 | class _ConvBNReLU(nn.Module): 89 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, 90 | groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs): 91 | super(_ConvBNReLU, self).__init__() 92 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) 93 | self.bn = norm_layer(out_channels) 94 | self.relu = nn.ReLU(True) 95 | 96 | def forward(self, x): 97 | x = self.conv(x) 98 | x = self.bn(x) 99 | x = self.relu(x) 100 | return x 101 | 102 | 103 | class CascadeFeatureFusion(nn.Module): 104 | """CFF Unit""" 105 | 106 | def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 107 | super(CascadeFeatureFusion, self).__init__() 108 | self.conv_low = nn.Sequential( 109 | nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False), 110 | norm_layer(out_channels) 111 | ) 112 | self.conv_high = nn.Sequential( 113 | nn.Conv2d(high_channels, out_channels, 1, bias=False), 114 | norm_layer(out_channels) 115 | ) 116 | self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False) 117 | 118 | def forward(self, x_low, x_high): 119 | x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True) 120 | x_low = self.conv_low(x_low) 121 | x_high = self.conv_high(x_high) 122 | x = x_low + x_high 123 | x = F.relu(x, inplace=True) 124 | x_low_cls = self.conv_low_cls(x_low) 125 | 126 | return x, x_low_cls 127 | 128 | 129 | def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.torch/models', 130 | pretrained_base=True, **kwargs): 131 | acronyms = { 132 | 'pascal_voc': 'pascal_voc', 133 | 'pascal_aug': 'pascal_aug', 134 | 'ade20k': 'ade', 135 | 'coco': 'coco', 136 | 'citys': 'citys', 137 | } 138 | from ..data.dataloader import datasets 139 | model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 140 | if pretrained: 141 | from .model_store import get_model_file 142 | device = torch.device(kwargs['local_rank']) 143 | model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root), 144 | map_location=device)) 145 | return model 146 | 147 | 148 | def get_icnet_resnet50_citys(**kwargs): 149 | return get_icnet('citys', 'resnet50', **kwargs) 150 | 151 | 152 | def get_icnet_resnet101_citys(**kwargs): 153 | return get_icnet('citys', 'resnet101', **kwargs) 154 | 155 | 156 | def get_icnet_resnet152_citys(**kwargs): 157 | return get_icnet('citys', 'resnet152', **kwargs) 158 | 159 | 160 | if __name__ == '__main__': 161 | img = torch.randn(1, 3, 256, 256) 162 | model = get_icnet_resnet50_citys() 163 | outputs = model(img) 164 | -------------------------------------------------------------------------------- /core/models/lednet.py: -------------------------------------------------------------------------------- 1 | """LEDNet: A Lightweight Encoder-Decoder Network for Real-time Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.nn import _ConvBNReLU 7 | 8 | __all__ = ['LEDNet', 
'get_lednet', 'get_lednet_citys'] 9 | 10 | class LEDNet(nn.Module): 11 | r"""LEDNet 12 | 13 | Parameters 14 | ---------- 15 | nclass : int 16 | Number of categories for the training dataset. 17 | backbone : string 18 | Backbone name, kept only for API compatibility 19 | (LEDNet does not use a pre-trained backbone). 20 | norm_layer : object 21 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 22 | for Synchronized Cross-GPU BatchNormalization). 23 | aux : bool 24 | Auxiliary loss. 25 | 26 | Reference: 27 | Yu Wang, et al. "LEDNet: A Lightweight Encoder-Decoder Network for Real-Time Semantic Segmentation." 28 | arXiv preprint arXiv:1905.02423 (2019). 29 | """ 30 | 31 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=True, **kwargs): 32 | super(LEDNet, self).__init__() 33 | self.encoder = nn.Sequential( 34 | Downsampling(3, 32), 35 | SSnbt(32, **kwargs), SSnbt(32, **kwargs), SSnbt(32, **kwargs), 36 | Downsampling(32, 64), 37 | SSnbt(64, **kwargs), SSnbt(64, **kwargs), 38 | Downsampling(64, 128), 39 | SSnbt(128, **kwargs), 40 | SSnbt(128, 2, **kwargs), 41 | SSnbt(128, 5, **kwargs), 42 | SSnbt(128, 9, **kwargs), 43 | SSnbt(128, 2, **kwargs), 44 | SSnbt(128, 5, **kwargs), 45 | SSnbt(128, 9, **kwargs), 46 | SSnbt(128, 17, **kwargs), 47 | ) 48 | self.decoder = APNModule(128, nclass) 49 | 50 | self.__setattr__('exclusive', ['encoder', 'decoder']) 51 | 52 | def forward(self, x): 53 | size = x.size()[2:] 54 | x = self.encoder(x) 55 | x = self.decoder(x) 56 | outputs = list() 57 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 58 | outputs.append(x) 59 | 60 | return tuple(outputs) 61 | 62 | 63 | class Downsampling(nn.Module): 64 | def __init__(self, in_channels, out_channels, **kwargs): 65 | super(Downsampling, self).__init__() 66 | self.conv1 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False) 67 | self.conv2 = nn.Conv2d(in_channels, out_channels // 2, 3, 2, 2, bias=False) 68 | self.pool = nn.MaxPool2d(kernel_size=2, stride=1) 69 | 70 | def forward(self, x): 71 | x1 = self.conv1(x) 72 | x1 = self.pool(x1) 73 | 74 | x2 = self.conv2(x) 75 | x2 = self.pool(x2) 76 | 77 | return torch.cat([x1, x2], dim=1) 78 | 79 | 
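# --- Added commentary (not part of the original file): SSnbt below is LEDNet's
# "split-shuffle-non-bottleneck" block. It splits the input channels in half,
# pushes each half through a factorized (3x1 then 1x3) convolution branch with
# the given dilation, concatenates the two halves, adds the residual input, and
# finally channel-shuffles so that information mixes across the halves in the
# next block.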
80 | class SSnbt(nn.Module): 81 | def __init__(self, in_channels, dilation=1, norm_layer=nn.BatchNorm2d, **kwargs): 82 | super(SSnbt, self).__init__() 83 | inter_channels = in_channels // 2 84 | self.branch1 = nn.Sequential( 85 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False), 86 | nn.ReLU(True), 87 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False), 88 | norm_layer(inter_channels), 89 | nn.ReLU(True), 90 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1), 91 | bias=False), 92 | nn.ReLU(True), 93 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation), 94 | bias=False), 95 | norm_layer(inter_channels), 96 | nn.ReLU(True)) 97 | 98 | self.branch2 = nn.Sequential( 99 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, 1), bias=False), 100 | nn.ReLU(True), 101 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(1, 0), bias=False), 102 | norm_layer(inter_channels), 103 | nn.ReLU(True), 104 | nn.Conv2d(inter_channels, inter_channels, (1, 3), padding=(0, dilation), dilation=(1, dilation), 105 | bias=False), 106 | nn.ReLU(True), 107 | nn.Conv2d(inter_channels, inter_channels, (3, 1), padding=(dilation, 0), dilation=(dilation, 1), 108 | bias=False), 109 | norm_layer(inter_channels), 110 | nn.ReLU(True)) 111 | 112 | self.relu = nn.ReLU(True) 113 | 114 | @staticmethod 115 | def channel_shuffle(x, groups): 116 | n, c, h, w = x.size() 117 | 118 | channels_per_group = c // groups 119 | x = x.view(n, groups, channels_per_group, h, w) 120 | x = torch.transpose(x, 1, 2).contiguous() 121 | x = x.view(n, -1, h, w) 122 | 123 | return x 124 | 125 | def forward(self, x): 126 | # channels split 127 | x1, x2 = x.split(x.size(1) // 2, 1) 128 | 129 | x1 = self.branch1(x1) 130 | x2 = self.branch2(x2) 131 | 132 | out = torch.cat([x1, x2], dim=1) 133 | out = self.relu(out + x) 134 | out = self.channel_shuffle(out, groups=2) 135 | 136 | return out 137 | 138 | 139 | class APNModule(nn.Module): 140 | def __init__(self, in_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 141 | super(APNModule, self).__init__() 142 | self.conv1 = _ConvBNReLU(in_channels, in_channels, 3, 2, 1, norm_layer=norm_layer) 143 | self.conv2 = _ConvBNReLU(in_channels, in_channels, 5, 2, 2, norm_layer=norm_layer) 144 | self.conv3 = _ConvBNReLU(in_channels, in_channels, 7, 2, 3, norm_layer=norm_layer) 145 | self.level1 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer) 146 | self.level2 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer) 147 | self.level3 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer) 148 | self.level4 = _ConvBNReLU(in_channels, nclass, 1, norm_layer=norm_layer) 149 | self.level5 = nn.Sequential( 150 | nn.AdaptiveAvgPool2d(1), 151 | _ConvBNReLU(in_channels, nclass, 1)) 152 | 153 | def forward(self, x): 154 | h, w = x.size()[2:] 155 | branch3 = self.conv1(x) 156 | branch2 = self.conv2(branch3) 157 | branch1 = self.conv3(branch2) 158 | 159 | out = self.level1(branch1) 160 | out = F.interpolate(out, ((h + 3) // 4, (w + 3) // 4), mode='bilinear', align_corners=True) 161 | out = self.level2(branch2) + out 162 | out = F.interpolate(out, ((h + 1) // 2, (w + 1) // 2), mode='bilinear', align_corners=True) 163 | out = self.level3(branch3) + out 164 | out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) 165 | out = self.level4(x) * out 166 | out = self.level5(x) + out 167 | return out 168 | 169 | 170 | def get_lednet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', 171 | pretrained_base=True, **kwargs): 172 | acronyms = { 173 | 'pascal_voc': 'pascal_voc', 174 | 'pascal_aug': 'pascal_aug', 175 | 'ade20k': 'ade', 176 | 'coco': 'coco', 177 | 'citys': 'citys', 178 | } 179 | from ..data.dataloader import datasets 180 | model = LEDNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 181 | if pretrained: 182 | from .model_store import get_model_file 183 | device = torch.device(kwargs['local_rank']) 184 | model.load_state_dict(torch.load(get_model_file('lednet_%s' % (acronyms[dataset]), root=root), 185 | map_location=device)) 186 | return model 187 | 188 | 189 | def get_lednet_citys(**kwargs): 190 | return get_lednet('citys', **kwargs) 191 | 192 | 193 | if __name__ == '__main__': 194 | model = get_lednet_citys() 195 | 
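A minimal smoke test for the model above (a sketch, assuming the repository root is on PYTHONPATH; nclass=19 matches Cityscapes):

    import torch
    from core.models.lednet import LEDNet

    model = LEDNet(nclass=19).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 512, 512))[0]  # forward returns a tuple
    print(out.shape)  # torch.Size([1, 19, 512, 512])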
-------------------------------------------------------------------------------- /core/models/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | 4 | import os 5 | import zipfile 6 | 7 | from ..utils.download import download, check_sha1 8 | 9 | __all__ = ['get_model_file', 'get_resnet_file'] 10 | 11 | _model_sha1 = {name: checksum for checksum, name in [ 12 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'), 13 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'), 14 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'), 15 | ]} 16 | 17 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' 18 | _url_format = '{repo_url}encoding/models/{file_name}.zip' 19 | 20 | 21 | def short_hash(name): 22 | if name not in _model_sha1: 23 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 24 | return _model_sha1[name][:8] 25 | 26 | 27 | def get_resnet_file(name, root='~/.torch/models'): 28 | file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name)) 29 | root = os.path.expanduser(root) 30 | 31 | file_path = os.path.join(root, file_name + '.pth') 32 | sha1_hash = _model_sha1[name] 33 | if os.path.exists(file_path): 34 | if check_sha1(file_path, sha1_hash): 35 | return file_path 36 | else: 37 | msg = 'Mismatch in the content of model file {} detected. Downloading again.' 38 | print(msg.format(file_path)) 39 | else: 40 | print('Model file {} is not found. Downloading.'.format(file_path)) 41 | 42 | if not os.path.exists(root): 43 | os.makedirs(root) 44 | 45 | zip_file_path = os.path.join(root, file_name + '.zip') 46 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 47 | if repo_url[-1] != '/': 48 | repo_url = repo_url + '/' 49 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 50 | path=zip_file_path, 51 | overwrite=True) 52 | with zipfile.ZipFile(zip_file_path) as zf: 53 | zf.extractall(root) 54 | os.remove(zip_file_path) 55 | 56 | if check_sha1(file_path, sha1_hash): 57 | return file_path 58 | else: 59 | raise ValueError('Downloaded file has different hash. Please try again.') 60 | 61 | 62 | def get_model_file(name, root='~/.torch/models'): 63 | root = os.path.expanduser(root) 64 | file_path = os.path.join(root, name + '.pth') 65 | if os.path.exists(file_path): 66 | return file_path 67 | else: 68 | raise ValueError('Model file {} is not found. Please download it or train the model first.'.format(file_path)) 69 | 
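A usage sketch for the checkpoint store above (assumes network access; the zip is fetched, unpacked, and sha1-verified under the cache root):

    from core.models.model_store import get_resnet_file

    ckpt = get_resnet_file('resnet50', root='~/.torch/models')
    print(ckpt)  # e.g. ~/.torch/models/resnet50-25c4b509.pth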
-------------------------------------------------------------------------------- /core/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | """Model store which handles pretrained models.""" 2 | from .fcn import * 3 | from .fcnv2 import * 4 | from .pspnet import * 5 | from .deeplabv3 import * 6 | from .deeplabv3_plus import * 7 | from .danet import * 8 | from .denseaspp import * 9 | from .bisenet import * 10 | from .encnet import * 11 | from .dunet import * 12 | from .icnet import * 13 | from .enet import * 14 | from .ocnet import * 15 | from .psanet import * 16 | from .cgnet import * 17 | from .espnet import * 18 | from .lednet import * 19 | from .dfanet import * 20 | 21 | __all__ = ['get_model', 'get_model_list', 'get_segmentation_model'] 22 | 23 | _models = { 24 | 'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc, 25 | 'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc, 26 | 'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc, 27 | 'fcn_resnet50_voc': get_fcn_resnet50_voc, 28 | 'fcn_resnet101_voc': get_fcn_resnet101_voc, 29 | 'fcn_resnet152_voc': get_fcn_resnet152_voc, 30 | 'psp_resnet50_voc': get_psp_resnet50_voc, 31 | 'psp_resnet50_ade': get_psp_resnet50_ade, 32 | 'psp_resnet101_voc': get_psp_resnet101_voc, 33 | 'psp_resnet101_ade': get_psp_resnet101_ade, 34 | 'psp_resnet101_citys': get_psp_resnet101_citys, 35 | 'psp_resnet101_coco': get_psp_resnet101_coco, 36 | 'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc, 37 | 'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc, 38 | 'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc, 39 | 'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade, 40 | 'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade, 41 | 'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade, 42 | 'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc, 43 | 'danet_resnet50_citys': get_danet_resnet50_citys, 44 | 'danet_resnet101_citys': get_danet_resnet101_citys, 45 | 'danet_resnet152_citys': get_danet_resnet152_citys, 46 | 'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys, 47 | 'denseaspp_densenet161_citys': get_denseaspp_densenet161_citys, 48 | 'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys, 49 | 'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys, 50 | 'bisenet_resnet18_citys': get_bisenet_resnet18_citys, 51 | 'encnet_resnet50_ade': get_encnet_resnet50_ade, 52 | 'encnet_resnet101_ade': get_encnet_resnet101_ade, 53 | 'encnet_resnet152_ade': get_encnet_resnet152_ade, 54 | 'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc, 55 | 'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc, 56 | 'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc, 57 | 'icnet_resnet50_citys': get_icnet_resnet50_citys, 58 | 'icnet_resnet101_citys': get_icnet_resnet101_citys, 59 | 'icnet_resnet152_citys': get_icnet_resnet152_citys, 60 | 'enet_citys': get_enet_citys, 61 | 'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys, 62 | 'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys, 63 | 'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys, 64 | 'psanet_resnet50_voc': get_psanet_resnet50_voc, 65 | 'psanet_resnet101_voc': get_psanet_resnet101_voc, 66 | 'psanet_resnet152_voc': get_psanet_resnet152_voc, 67 | 'psanet_resnet50_citys': get_psanet_resnet50_citys, 68 | 'psanet_resnet101_citys': get_psanet_resnet101_citys, 69 | 'psanet_resnet152_citys': get_psanet_resnet152_citys, 70 | 'cgnet_citys': get_cgnet_citys, 71 | 'espnet_citys': get_espnet_citys, 72 | 'lednet_citys': get_lednet_citys, 73 | 'dfanet_citys': get_dfanet_citys, 74 | } 75 | 76 | 
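# Added commentary: keys follow the pattern '<arch>_<backbone>_<dataset>' (the
# backbone part is omitted for backbone-free networks such as enet, cgnet,
# espnet, lednet and dfanet), e.g. get_model('psp_resnet101_citys', pretrained=False).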
77 | def get_model(name, **kwargs): 78 | name = name.lower() 79 | if name not in _models: 80 | err_str = '"%s" is not among the following model list:\n\t' % (name) 81 | err_str += '%s' % ('\n\t'.join(sorted(_models.keys()))) 82 | raise ValueError(err_str) 83 | net = _models[name](**kwargs) 84 | return net 85 | 86 | 87 | def get_model_list(): 88 | return _models.keys() 89 | 90 | 91 | def get_segmentation_model(model, **kwargs): 92 | models = { 93 | 'fcn32s': get_fcn32s, 94 | 'fcn16s': get_fcn16s, 95 | 'fcn8s': get_fcn8s, 96 | 'fcn': get_fcn, 97 | 'psp': get_psp, 98 | 'deeplabv3': get_deeplabv3, 99 | 'deeplabv3_plus': get_deeplabv3_plus, 100 | 'danet': get_danet, 101 | 'denseaspp': get_denseaspp, 102 | 'bisenet': get_bisenet, 103 | 'encnet': get_encnet, 104 | 'dunet': get_dunet, 105 | 'icnet': get_icnet, 106 | 'enet': get_enet, 107 | 'ocnet': get_ocnet, 108 | 'psanet': get_psanet, 109 | 'cgnet': get_cgnet, 110 | 'espnet': get_espnet, 111 | 'lednet': get_lednet, 112 | 'dfanet': get_dfanet, 113 | } 114 | return models[model](**kwargs) 115 | 
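A sketch of the two lookup styles above; pretrained_base=False (assumed here to be accepted by the underlying constructors) avoids downloading ImageNet backbone weights:

    from core.models.model_zoo import get_model, get_model_list

    print(sorted(get_model_list()))
    net = get_model('fcn32s_vgg16_voc', pretrained=False, pretrained_base=False)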
34 | """ 35 | 36 | def __init__(self, nclass, backbone='resnet', aux=False, pretrained_base=True, **kwargs): 37 | super(PSANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 38 | self.head = _PSAHead(nclass, **kwargs) 39 | if aux: 40 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 41 | 42 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 43 | 44 | def forward(self, x): 45 | size = x.size()[2:] 46 | _, _, c3, c4 = self.base_forward(x) 47 | outputs = list() 48 | x = self.head(c4) 49 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 50 | outputs.append(x) 51 | 52 | if self.aux: 53 | auxout = self.auxlayer(c3) 54 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 55 | outputs.append(auxout) 56 | return tuple(outputs) 57 | 58 | 59 | class _PSAHead(nn.Module): 60 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 61 | super(_PSAHead, self).__init__() 62 | # psa_out_channels = crop_size // 8 ** 2 63 | self.psa = _PointwiseSpatialAttention(2048, 3600, norm_layer) 64 | 65 | self.conv_post = _ConvBNReLU(1024, 2048, 1, norm_layer=norm_layer) 66 | self.project = nn.Sequential( 67 | _ConvBNReLU(4096, 512, 3, padding=1, norm_layer=norm_layer), 68 | nn.Dropout2d(0.1, False), 69 | nn.Conv2d(512, nclass, 1)) 70 | 71 | def forward(self, x): 72 | global_feature = self.psa(x) 73 | out = self.conv_post(global_feature) 74 | out = torch.cat([x, out], dim=1) 75 | out = self.project(out) 76 | 77 | return out 78 | 79 | 80 | class _PointwiseSpatialAttention(nn.Module): 81 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): 82 | super(_PointwiseSpatialAttention, self).__init__() 83 | reduced_channels = 512 84 | self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 85 | self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 86 | 87 | def forward(self, x): 88 | collect_fm = self.collect_attention(x) 89 | distribute_fm = self.distribute_attention(x) 90 | psa_fm = torch.cat([collect_fm, distribute_fm], dim=1) 91 | return psa_fm 92 | 93 | 94 | class _AttentionGeneration(nn.Module): 95 | def __init__(self, in_channels, reduced_channels, out_channels, norm_layer, **kwargs): 96 | super(_AttentionGeneration, self).__init__() 97 | self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer) 98 | self.attention = nn.Sequential( 99 | _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer), 100 | nn.Conv2d(reduced_channels, out_channels, 1, bias=False)) 101 | 102 | self.reduced_channels = reduced_channels 103 | 104 | def forward(self, x): 105 | reduce_x = self.conv_reduce(x) 106 | attention = self.attention(reduce_x) 107 | n, c, h, w = attention.size() 108 | attention = attention.view(n, c, -1) 109 | reduce_x = reduce_x.view(n, self.reduced_channels, -1) 110 | fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1)) 111 | fm = fm.view(n, self.reduced_channels, h, w) 112 | 113 | return fm 114 | 115 | 116 | def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 117 | pretrained_base=True, **kwargs): 118 | acronyms = { 119 | 'pascal_voc': 'pascal_voc', 120 | 'pascal_aug': 'pascal_aug', 121 | 'ade20k': 'ade', 122 | 'coco': 'coco', 123 | 'citys': 'citys', 124 | } 125 | from core.data.dataloader import datasets 126 | model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, 
116 | def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 117 | pretrained_base=True, **kwargs): 118 | acronyms = { 119 | 'pascal_voc': 'pascal_voc', 120 | 'pascal_aug': 'pascal_aug', 121 | 'ade20k': 'ade', 122 | 'coco': 'coco', 123 | 'citys': 'citys', 124 | } 125 | from core.data.dataloader import datasets 126 | model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 127 | if pretrained: 128 | from .model_store import get_model_file 129 | device = torch.device(kwargs['local_rank']) 130 | model.load_state_dict(torch.load(get_model_file('psanet_%s_%s' % (backbone, acronyms[dataset]), root=root), 131 | map_location=device)) 132 | return model 133 | 134 | 135 | def get_psanet_resnet50_voc(**kwargs): 136 | return get_psanet('pascal_voc', 'resnet50', **kwargs) 137 | 138 | 139 | def get_psanet_resnet101_voc(**kwargs): 140 | return get_psanet('pascal_voc', 'resnet101', **kwargs) 141 | 142 | 143 | def get_psanet_resnet152_voc(**kwargs): 144 | return get_psanet('pascal_voc', 'resnet152', **kwargs) 145 | 146 | 147 | def get_psanet_resnet50_citys(**kwargs): 148 | return get_psanet('citys', 'resnet50', **kwargs) 149 | 150 | 151 | def get_psanet_resnet101_citys(**kwargs): 152 | return get_psanet('citys', 'resnet101', **kwargs) 153 | 154 | 155 | def get_psanet_resnet152_citys(**kwargs): 156 | return get_psanet('citys', 'resnet152', **kwargs) 157 | 158 | 159 | if __name__ == '__main__': 160 | model = get_psanet_resnet50_voc() 161 | img = torch.randn(1, 3, 480, 480) 162 | output = model(img) 163 | 
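Given the 60x60 feature-map constraint noted above, a forward-pass sketch (Pascal VOC has 21 classes; pretrained_base=False skips the backbone download):

    import torch
    from core.models.psanet import get_psanet

    model = get_psanet('pascal_voc', 'resnet50', pretrained_base=False).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 480, 480))[0]
    print(out.shape)  # torch.Size([1, 21, 480, 480])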
-------------------------------------------------------------------------------- /core/models/pspnet.py: -------------------------------------------------------------------------------- 1 | """Pyramid Scene Parsing Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .fcn import _FCNHead 8 | 9 | __all__ = ['PSPNet', 'get_psp', 'get_psp_resnet50_voc', 'get_psp_resnet50_ade', 'get_psp_resnet101_voc', 10 | 'get_psp_resnet101_ade', 'get_psp_resnet101_citys', 'get_psp_resnet101_coco'] 11 | 12 | 13 | class PSPNet(SegBaseModel): 14 | r"""Pyramid Scene Parsing Network 15 | 16 | Parameters 17 | ---------- 18 | nclass : int 19 | Number of categories for the training dataset. 20 | backbone : string 21 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 22 | 'resnet101' or 'resnet152'). 23 | norm_layer : object 24 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 25 | for Synchronized Cross-GPU BatchNormalization). 26 | aux : bool 27 | Auxiliary loss. 28 | 29 | Reference: 30 | Zhao, Hengshuang, Jianping Shi, Xiaojuan Qi, Xiaogang Wang, and Jiaya Jia. 31 | "Pyramid scene parsing network." *CVPR*, 2017. 32 | """ 33 | 34 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs): 35 | super(PSPNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 36 | self.head = _PSPHead(nclass, **kwargs) 37 | if self.aux: 38 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 39 | 40 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 41 | 42 | def forward(self, x): 43 | size = x.size()[2:] 44 | _, _, c3, c4 = self.base_forward(x) 45 | outputs = [] 46 | x = self.head(c4) 47 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 48 | outputs.append(x) 49 | 50 | if self.aux: 51 | auxout = self.auxlayer(c3) 52 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 53 | outputs.append(auxout) 54 | return tuple(outputs) 55 | 56 | 57 | def _PSP1x1Conv(in_channels, out_channels, norm_layer, norm_kwargs): 58 | return nn.Sequential( 59 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 60 | norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)), 61 | nn.ReLU(True) 62 | ) 63 | 64 | 65 | class _PyramidPooling(nn.Module): 66 | def __init__(self, in_channels, **kwargs): 67 | super(_PyramidPooling, self).__init__() 68 | out_channels = int(in_channels / 4) 69 | self.avgpool1 = nn.AdaptiveAvgPool2d(1) 70 | self.avgpool2 = nn.AdaptiveAvgPool2d(2) 71 | self.avgpool3 = nn.AdaptiveAvgPool2d(3) 72 | self.avgpool4 = nn.AdaptiveAvgPool2d(6) 73 | self.conv1 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 74 | self.conv2 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 75 | self.conv3 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 76 | self.conv4 = _PSP1x1Conv(in_channels, out_channels, **kwargs) 77 | 78 | def forward(self, x): 79 | size = x.size()[2:] 80 | feat1 = F.interpolate(self.conv1(self.avgpool1(x)), size, mode='bilinear', align_corners=True) 81 | feat2 = F.interpolate(self.conv2(self.avgpool2(x)), size, mode='bilinear', align_corners=True) 82 | feat3 = F.interpolate(self.conv3(self.avgpool3(x)), size, mode='bilinear', align_corners=True) 83 | feat4 = F.interpolate(self.conv4(self.avgpool4(x)), size, mode='bilinear', align_corners=True) 84 | return torch.cat([x, feat1, feat2, feat3, feat4], dim=1) 85 | 86 | 87 | class _PSPHead(nn.Module): 88 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 89 | super(_PSPHead, self).__init__() 90 | self.psp = _PyramidPooling(2048, norm_layer=norm_layer, norm_kwargs=norm_kwargs) 91 | self.block = nn.Sequential( 92 | nn.Conv2d(4096, 512, 3, padding=1, bias=False), 93 | norm_layer(512, **({} if norm_kwargs is None else norm_kwargs)), 94 | nn.ReLU(True), 95 | nn.Dropout(0.1), 96 | nn.Conv2d(512, nclass, 1) 97 | ) 98 | 99 | def forward(self, x): 100 | x = self.psp(x) 101 | return self.block(x) 102 | 103 | 104 | def get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 105 | pretrained_base=True, **kwargs): 106 | r"""Pyramid Scene Parsing Network 107 | 108 | Parameters 109 | ---------- 110 | dataset : str, default pascal_voc 111 | The dataset that the model was pretrained on. (pascal_voc, ade20k) 112 | pretrained : bool or str 113 | Boolean value controls whether to load the default pretrained weights for model. 114 | String value represents the hashtag for a certain version of pretrained weights. 115 | root : str, default '~/.torch/models' 116 | Location for keeping the model parameters. 
117 | pretrained_base : bool or str, default True 118 | This will load pretrained backbone network, that was trained on ImageNet. 119 | Examples 120 | -------- 121 | >>> model = get_psp(dataset='pascal_voc', backbone='resnet50', pretrained=False) 122 | >>> print(model) 123 | """ 124 | acronyms = { 125 | 'pascal_voc': 'pascal_voc', 126 | 'pascal_aug': 'pascal_aug', 127 | 'ade20k': 'ade', 128 | 'coco': 'coco', 129 | 'citys': 'citys', 130 | } 131 | from ..data.dataloader import datasets 132 | model = PSPNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 133 | if pretrained: 134 | from .model_store import get_model_file 135 | device = torch.device(kwargs['local_rank']) 136 | model.load_state_dict(torch.load(get_model_file('psp_%s_%s' % (backbone, acronyms[dataset]), root=root), 137 | map_location=device)) 138 | return model 139 | 140 | 141 | def get_psp_resnet50_voc(**kwargs): 142 | return get_psp('pascal_voc', 'resnet50', **kwargs) 143 | 144 | 145 | def get_psp_resnet50_ade(**kwargs): 146 | return get_psp('ade20k', 'resnet50', **kwargs) 147 | 148 | 149 | def get_psp_resnet101_voc(**kwargs): 150 | return get_psp('pascal_voc', 'resnet101', **kwargs) 151 | 152 | 153 | def get_psp_resnet101_ade(**kwargs): 154 | return get_psp('ade20k', 'resnet101', **kwargs) 155 | 156 | 157 | def get_psp_resnet101_citys(**kwargs): 158 | return get_psp('citys', 'resnet101', **kwargs) 159 | 160 | 161 | def get_psp_resnet101_coco(**kwargs): 162 | return get_psp('coco', 'resnet101', **kwargs) 163 | 164 | 165 | if __name__ == '__main__': 166 | model = get_psp_resnet50_voc() 167 | img = torch.randn(4, 3, 480, 480) 168 | output = model(img) 169 | -------------------------------------------------------------------------------- /core/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import torch.nn as nn 3 | 4 | from ..nn import JPU 5 | from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s 6 | 7 | __all__ = ['SegBaseModel'] 8 | 9 | 10 | class SegBaseModel(nn.Module): 11 | r"""Base Model for Semantic Segmentation 12 | 13 | Parameters 14 | ---------- 15 | backbone : string 16 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 17 | 'resnet101' or 'resnet152'). 
18 | """ 19 | 20 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs): 21 | super(SegBaseModel, self).__init__() 22 | dilated = False if jpu else True 23 | self.aux = aux 24 | self.nclass = nclass 25 | if backbone == 'resnet50': 26 | self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 27 | elif backbone == 'resnet101': 28 | self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 29 | elif backbone == 'resnet152': 30 | self.pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 31 | else: 32 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 33 | 34 | self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None 35 | 36 | def base_forward(self, x): 37 | """forwarding pre-trained network""" 38 | x = self.pretrained.conv1(x) 39 | x = self.pretrained.bn1(x) 40 | x = self.pretrained.relu(x) 41 | x = self.pretrained.maxpool(x) 42 | c1 = self.pretrained.layer1(x) 43 | c2 = self.pretrained.layer2(c1) 44 | c3 = self.pretrained.layer3(c2) 45 | c4 = self.pretrained.layer4(c3) 46 | 47 | if self.jpu: 48 | return self.jpu(c1, c2, c3, c4) 49 | else: 50 | return c1, c2, c3, c4 51 | 52 | def evaluate(self, x): 53 | """evaluating network with inputs and targets""" 54 | return self.forward(x)[0] 55 | 56 | def demo(self, x): 57 | pred = self.forward(x) 58 | if self.aux: 59 | pred = pred[0] 60 | return pred 61 | -------------------------------------------------------------------------------- /core/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | from .jpu import * 3 | from .basic import * 4 | -------------------------------------------------------------------------------- /core/nn/basic.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual'] 7 | 8 | 9 | class _ConvBNReLU(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 11 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs): 12 | super(_ConvBNReLU, self).__init__() 13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 14 | self.bn = norm_layer(out_channels) 15 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | class _ConvBNPReLU(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 26 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 27 | super(_ConvBNPReLU, self).__init__() 28 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 29 | self.bn = norm_layer(out_channels) 30 | self.prelu = nn.PReLU(out_channels) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | x = self.bn(x) 35 | x = self.prelu(x) 36 | return x 37 | 38 | 39 | class _ConvBN(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 41 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 42 | super(_ConvBN, self).__init__() 43 | self.conv = nn.Conv2d(in_channels, out_channels, 
-------------------------------------------------------------------------------- /core/nn/basic.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual'] 7 | 8 | 9 | class _ConvBNReLU(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 11 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs): 12 | super(_ConvBNReLU, self).__init__() 13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 14 | self.bn = norm_layer(out_channels) 15 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | class _ConvBNPReLU(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 26 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 27 | super(_ConvBNPReLU, self).__init__() 28 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 29 | self.bn = norm_layer(out_channels) 30 | self.prelu = nn.PReLU(out_channels) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | x = self.bn(x) 35 | x = self.prelu(x) 36 | return x 37 | 38 | 39 | class _ConvBN(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 41 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 42 | super(_ConvBN, self).__init__() 43 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 44 | self.bn = norm_layer(out_channels) 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.bn(x) 49 | return x 50 | 51 | 52 | class _BNPReLU(nn.Module): 53 | def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): 54 | super(_BNPReLU, self).__init__() 55 | self.bn = norm_layer(out_channels) 56 | self.prelu = nn.PReLU(out_channels) 57 | 58 | def forward(self, x): 59 | x = self.bn(x) 60 | x = self.prelu(x) 61 | return x 62 | 63 | 64 | # ----------------------------------------------------------------- 65 | # For PSPNet 66 | # ----------------------------------------------------------------- 67 | class _PSPModule(nn.Module): 68 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs): 69 | super(_PSPModule, self).__init__() 70 | out_channels = int(in_channels / 4) 71 | self.avgpools = nn.ModuleList() 72 | self.convs = nn.ModuleList() 73 | for size in sizes: 74 | self.avgpools.append(nn.AdaptiveAvgPool2d(size)) 75 | self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs)) 76 | 77 | def forward(self, x): 78 | size = x.size()[2:] 79 | feats = [x] 80 | for avgpool, conv in zip(self.avgpools, self.convs): 81 | feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True)) 82 | return torch.cat(feats, dim=1) 83 | 84 | 85 | # ----------------------------------------------------------------- 86 | # For MobileNet 87 | # ----------------------------------------------------------------- 88 | class _DepthwiseConv(nn.Module): 89 | """conv_dw in MobileNet""" 90 | 91 | def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs): 92 | super(_DepthwiseConv, self).__init__() 93 | self.conv = nn.Sequential( 94 | _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer), 95 | _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)) 96 | 97 | def forward(self, x): 98 | return self.conv(x) 99 | 100 | 101 | # ----------------------------------------------------------------- 102 | # For MobileNetV2 103 | # ----------------------------------------------------------------- 104 | class InvertedResidual(nn.Module): 105 | def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs): 106 | super(InvertedResidual, self).__init__() 107 | assert stride in [1, 2] 108 | self.use_res_connect = stride == 1 and in_channels == out_channels 109 | 110 | layers = list() 111 | inter_channels = int(round(in_channels * expand_ratio)) 112 | if expand_ratio != 1: 113 | # pw 114 | layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer)) 115 | layers.extend([ 116 | # dw 117 | _ConvBNReLU(inter_channels, inter_channels, 3, stride, 1, 118 | groups=inter_channels, relu6=True, norm_layer=norm_layer), 119 | # pw-linear 120 | nn.Conv2d(inter_channels, out_channels, 1, bias=False), 121 | norm_layer(out_channels)]) 122 | self.conv = nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | if self.use_res_connect: 126 | return x + self.conv(x) 127 | else: 128 | return self.conv(x) 129 | 130 | 131 | if __name__ == '__main__': 132 | x = torch.randn(1, 32, 64, 64) 133 | model = InvertedResidual(32, 64, 2, 1) 134 | out = model(x) 135 | -------------------------------------------------------------------------------- /core/nn/jpu.py: -------------------------------------------------------------------------------- 1 | """Joint Pyramid 
Upsampling""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['JPU'] 7 | 8 | 9 | class SeparableConv2d(nn.Module): 10 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, 11 | dilation=1, bias=False, norm_layer=nn.BatchNorm2d): 12 | super(SeparableConv2d, self).__init__() 13 | self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) 14 | self.bn = norm_layer(inplanes) 15 | self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.pointwise(x) 21 | return x 22 | 23 | 24 | # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py 25 | class JPU(nn.Module): 26 | def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs): 27 | super(JPU, self).__init__() 28 | 29 | self.conv5 = nn.Sequential( 30 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 31 | norm_layer(width), 32 | nn.ReLU(True)) 33 | self.conv4 = nn.Sequential( 34 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 35 | norm_layer(width), 36 | nn.ReLU(True)) 37 | self.conv3 = nn.Sequential( 38 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), 39 | norm_layer(width), 40 | nn.ReLU(True)) 41 | 42 | self.dilation1 = nn.Sequential( 43 | SeparableConv2d(3 * width, width, 3, padding=1, dilation=1, bias=False), 44 | norm_layer(width), 45 | nn.ReLU(True)) 46 | self.dilation2 = nn.Sequential( 47 | SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False), 48 | norm_layer(width), 49 | nn.ReLU(True)) 50 | self.dilation3 = nn.Sequential( 51 | SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False), 52 | norm_layer(width), 53 | nn.ReLU(True)) 54 | self.dilation4 = nn.Sequential( 55 | SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False), 56 | norm_layer(width), 57 | nn.ReLU(True)) 58 | 59 | def forward(self, *inputs): 60 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])] 61 | size = feats[-1].size()[2:] 62 | feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True) 63 | feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True) 64 | feat = torch.cat(feats, dim=1) 65 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], 66 | dim=1) 67 | 68 | return inputs[0], inputs[1], inputs[2], feat 69 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__ import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs, try_import_pycocotools 6 | -------------------------------------------------------------------------------- /core/utils/download.py: -------------------------------------------------------------------------------- 1 | """Download files with progress bar.""" 2 | import os 3 | import hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | def check_sha1(filename, sha1_hash): 8 | """Check whether the sha1 hash of the file content matches the expected hash. 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the file. 13 | sha1_hash : str 14 | Expected sha1 hash in hexadecimal digits. 
15 | Returns 16 | ------- 17 | bool 18 | Whether the file content matches the expected hash. 19 | """ 20 | sha1 = hashlib.sha1() 21 | with open(filename, 'rb') as f: 22 | while True: 23 | data = f.read(1048576) 24 | if not data: 25 | break 26 | sha1.update(data) 27 | 28 | sha1_file = sha1.hexdigest() 29 | l = min(len(sha1_file), len(sha1_hash)) 30 | return sha1_file[0:l] == sha1_hash[0:l] 31 | 32 | def download(url, path=None, overwrite=False, sha1_hash=None): 33 | """Download a given URL 34 | Parameters 35 | ---------- 36 | url : str 37 | URL to download 38 | path : str, optional 39 | Destination path to store downloaded file. By default stores to the 40 | current directory with same name as in url. 41 | overwrite : bool, optional 42 | Whether to overwrite destination file if already exists. 43 | sha1_hash : str, optional 44 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 45 | but doesn't match. 46 | Returns 47 | ------- 48 | str 49 | The file path of the downloaded file. 50 | """ 51 | if path is None: 52 | fname = url.split('/')[-1] 53 | else: 54 | path = os.path.expanduser(path) 55 | if os.path.isdir(path): 56 | fname = os.path.join(path, url.split('/')[-1]) 57 | else: 58 | fname = path 59 | 60 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 61 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 62 | if not os.path.exists(dirname): 63 | os.makedirs(dirname) 64 | 65 | print('Downloading %s from %s...' % (fname, url)) 66 | r = requests.get(url, stream=True) 67 | if r.status_code != 200: 68 | raise RuntimeError("Failed downloading url %s" % url) 69 | total_length = r.headers.get('content-length') 70 | with open(fname, 'wb') as f: 71 | if total_length is None: # no content length header 72 | for chunk in r.iter_content(chunk_size=1024): 73 | if chunk: # filter out keep-alive new chunks 74 | f.write(chunk) 75 | else: 76 | total_length = int(total_length) 77 | for chunk in tqdm(r.iter_content(chunk_size=1024), 78 | total=int(total_length / 1024. + 0.5), 79 | unit='KB', unit_scale=False, dynamic_ncols=True): 80 | f.write(chunk) 81 | 82 | if sha1_hash and not check_sha1(fname, sha1_hash): 83 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 84 | 'The repo may be outdated or download may be incomplete. ' \ 85 | 'If the "repo_url" is overridden, consider switching to ' \ 86 | 'the default repo.'.format(fname)) 87 | 88 | return fname -------------------------------------------------------------------------------- /core/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | 6 | 7 | def makedirs(path): 8 | """Create directory recursively if not exists. 9 | Similar to `mkdir -p`, you can skip checking existence before this function. 10 | Parameters 11 | ---------- 12 | path : str 13 | Path of the desired dir 14 | """ 15 | try: 16 | os.makedirs(path) 17 | except OSError as exc: 18 | if exc.errno != errno.EEXIST: 19 | raise 20 | 21 | 22 | def try_import(package, message=None): 23 | """Try import specified package, with custom message support. 24 | Parameters 25 | ---------- 26 | package : str 27 | The name of the targeting package. 28 | message : str, default is None 29 | If not None, this function will raise customized error message when import error is found. 
30 | Returns 31 | ------- 32 | module if found, raise ImportError otherwise 33 | """ 34 | try: 35 | return __import__(package) 36 | except ImportError as e: 37 | if not message: 38 | raise e 39 | raise ImportError(message) 40 | 41 | 42 | def try_import_cv2(): 43 | """Try import cv2 at runtime. 44 | Returns 45 | ------- 46 | cv2 module if found. Raise ImportError otherwise 47 | """ 48 | msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ 49 | or `pip install opencv-python --user` (note that this is unofficial PYPI package)." 50 | return try_import('cv2', msg) 51 | 52 | 53 | def import_try_install(package, extern_url=None): 54 | """Try import the specified package. 55 | If the package not installed, try use pip to install and import if success. 56 | Parameters 57 | ---------- 58 | package : str 59 | The name of the package trying to import. 60 | extern_url : str or None, optional 61 | The external url if package is not hosted on PyPI. 62 | For example, you can install a package using: 63 | "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". 64 | In this case, you can pass the url to the extern_url. 65 | Returns 66 | ------- 67 | 68 | The imported python module. 69 | """ 70 | try: 71 | return __import__(package) 72 | except ImportError: 73 | try: 74 | from pip import main as pipmain 75 | except ImportError: 76 | from pip._internal import main as pipmain 77 | 78 | # trying to install package 79 | url = package if extern_url is None else extern_url 80 | pipmain(['install', '--user', url]) # will raise SystemExit Error if fails 81 | 82 | # trying to load again 83 | try: 84 | return __import__(package) 85 | except ImportError: 86 | import sys 87 | import site 88 | user_site = site.getusersitepackages() 89 | if user_site not in sys.path: 90 | sys.path.append(user_site) 91 | return __import__(package) 92 | return __import__(package) 93 | 94 | 95 | """Import helper for pycocotools""" 96 | 97 | 98 | # NOTE: for developers 99 | # please do not import any pycocotools in __init__ because we are trying to lazy 100 | # import pycocotools to avoid install it for other users who may not use it. 101 | # only import when you actually use it 102 | 103 | 104 | def try_import_pycocotools(): 105 | """Tricks to optionally install and import pycocotools""" 106 | # first we can try import pycocotools 107 | try: 108 | import pycocotools as _ 109 | except ImportError: 110 | import os 111 | # we need to install pycootools, which is a bit tricky 112 | # pycocotools sdist requires Cython, numpy(already met) 113 | import_try_install('cython') 114 | # pypi pycocotools is not compatible with windows 115 | win_url = 'git+https://github.com/zhreshold/cocoapi.git#subdirectory=PythonAPI' 116 | try: 117 | if os.name == 'nt': 118 | import_try_install('pycocotools', win_url) 119 | else: 120 | import_try_install('pycocotools') 121 | except ImportError: 122 | faq = 'cocoapi FAQ' 123 | raise ImportError('Cannot import or install pycocotools, please refer to %s.' % faq) 124 | -------------------------------------------------------------------------------- /core/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import logging 3 | import os 4 | import sys 5 | 6 | __all__ = ['setup_logger'] 7 | 8 | 9 | # reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py 10 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 11 | logger = logging.getLogger(name) 12 | logger.setLevel(logging.DEBUG) 13 | # don't log results for the non-master process 14 | if distributed_rank > 0: 15 | return logger 16 | ch = logging.StreamHandler(stream=sys.stdout) 17 | ch.setLevel(logging.DEBUG) 18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 19 | ch.setFormatter(formatter) 20 | logger.addHandler(ch) 21 | 22 | if save_dir: 23 | if not os.path.exists(save_dir): 24 | os.makedirs(save_dir) 25 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /core/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """Popular Learning Rate Schedulers""" 2 | from __future__ import division 3 | import math 4 | import torch 5 | 6 | from bisect import bisect_right 7 | 8 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR'] 9 | 10 | 11 | class LRScheduler(object): 12 | r"""Learning Rate Scheduler 13 | 14 | Parameters 15 | ---------- 16 | mode : str 17 | Modes for learning rate scheduler. 18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'. 19 | base_lr : float 20 | Base learning rate, i.e. the starting learning rate. 21 | target_lr : float 22 | Target learning rate, i.e. the ending learning rate. 23 | With constant mode target_lr is ignored. 24 | niters : int 25 | Number of iterations to be scheduled. 26 | nepochs : int 27 | Number of epochs to be scheduled. 28 | iters_per_epoch : int 29 | Number of iterations in each epoch. 30 | offset : int 31 | Number of iterations before this scheduler. 32 | power : float 33 | Power parameter of poly scheduler. 34 | step_iter : list 35 | A list of iterations to decay the learning rate. 36 | step_epoch : list 37 | A list of epochs to decay the learning rate. 38 | step_factor : float 39 | Learning rate decay factor. 
40 | """ 41 | 42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0, 43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0): 44 | super(LRScheduler, self).__init__() 45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine']) 46 | 47 | if mode == 'step': 48 | assert (step_iter is not None or step_epoch is not None) 49 | self.niters = niters 50 | self.step = step_iter 51 | epoch_iters = nepochs * iters_per_epoch 52 | if epoch_iters > 0: 53 | self.niters = epoch_iters 54 | if step_epoch is not None: 55 | self.step = [s * iters_per_epoch for s in step_epoch] 56 | 57 | self.step_factor = step_factor 58 | self.base_lr = base_lr 59 | self.target_lr = base_lr if mode == 'constant' else target_lr 60 | self.offset = offset 61 | self.power = power 62 | self.warmup_iters = warmup_epochs * iters_per_epoch 63 | self.mode = mode 64 | 65 | def __call__(self, optimizer, num_update): 66 | self.update(num_update) 67 | assert self.learning_rate >= 0 68 | self._adjust_learning_rate(optimizer, self.learning_rate) 69 | 70 | def update(self, num_update): 71 | N = self.niters - 1 72 | T = num_update - self.offset 73 | T = min(max(0, T), N) 74 | 75 | if self.mode == 'constant': 76 | factor = 0 77 | elif self.mode == 'linear': 78 | factor = 1 - T / N 79 | elif self.mode == 'poly': 80 | factor = pow(1 - T / N, self.power) 81 | elif self.mode == 'cosine': 82 | factor = (1 + math.cos(math.pi * T / N)) / 2 83 | elif self.mode == 'step': 84 | if self.step is not None: 85 | count = sum([1 for s in self.step if s <= T]) 86 | factor = pow(self.step_factor, count) 87 | else: 88 | factor = 1 89 | else: 90 | raise NotImplementedError 91 | 92 | # warm up lr schedule 93 | if self.warmup_iters > 0 and T < self.warmup_iters: 94 | factor = factor * 1.0 * T / self.warmup_iters 95 | 96 | if self.mode == 'step': 97 | self.learning_rate = self.base_lr * factor 98 | else: 99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor 100 | 101 | def _adjust_learning_rate(self, optimizer, lr): 102 | optimizer.param_groups[0]['lr'] = lr 103 | # enlarge the lr at the head 104 | for i in range(1, len(optimizer.param_groups)): 105 | optimizer.param_groups[i]['lr'] = lr * 10 106 | 107 | 108 | # separating MultiStepLR with WarmupLR 109 | # but the current LRScheduler design doesn't allow it 110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py 111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 112 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 113 | warmup_iters=500, warmup_method="linear", last_epoch=-1): 114 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 115 | if not list(milestones) == sorted(milestones): 116 | raise ValueError( 117 | "Milestones should be a list of" " increasing integers. 
Got {}", milestones) 118 | if warmup_method not in ("constant", "linear"): 119 | raise ValueError( 120 | "Only 'constant' or 'linear' warmup_method accepted got {}".format(warmup_method)) 121 | 122 | self.milestones = milestones 123 | self.gamma = gamma 124 | self.warmup_factor = warmup_factor 125 | self.warmup_iters = warmup_iters 126 | self.warmup_method = warmup_method 127 | 128 | def get_lr(self): 129 | warmup_factor = 1 130 | if self.last_epoch < self.warmup_iters: 131 | if self.warmup_method == 'constant': 132 | warmup_factor = self.warmup_factor 133 | elif self.warmup_factor == 'linear': 134 | alpha = float(self.last_epoch) / self.warmup_iters 135 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 136 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 137 | for base_lr in self.base_lrs] 138 | 139 | 140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3, 142 | warmup_iters=500, warmup_method='linear', last_epoch=-1): 143 | if warmup_method not in ("constant", "linear"): 144 | raise ValueError( 145 | "Only 'constant' or 'linear' warmup_method accepted " 146 | "got {}".format(warmup_method)) 147 | 148 | self.target_lr = target_lr 149 | self.max_iters = max_iters 150 | self.power = power 151 | self.warmup_factor = warmup_factor 152 | self.warmup_iters = warmup_iters 153 | self.warmup_method = warmup_method 154 | 155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch) 156 | 157 | def get_lr(self): 158 | N = self.max_iters - self.warmup_iters 159 | T = self.last_epoch - self.warmup_iters 160 | if self.last_epoch < self.warmup_iters: 161 | if self.warmup_method == 'constant': 162 | warmup_factor = self.warmup_factor 163 | elif self.warmup_method == 'linear': 164 | alpha = float(self.last_epoch) / self.warmup_iters 165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 166 | else: 167 | raise ValueError("Unknown warmup type.") 168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs] 169 | factor = pow(1 - T / N, self.power) 170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs] 171 | 172 | 173 | if __name__ == '__main__': 174 | import torch 175 | import torch.nn as nn 176 | 177 | model = nn.Conv2d(16, 16, 3, 1, 1) 178 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 179 | lr_scheduler = WarmupPolyLR(optimizer, niters=1000) 180 | -------------------------------------------------------------------------------- /core/utils/parallel.py: -------------------------------------------------------------------------------- 1 | """Utils for Semantic Segmentation""" 2 | import threading 3 | import torch 4 | import torch.cuda.comm as comm 5 | from torch.nn.parallel.data_parallel import DataParallel 6 | from torch.nn.parallel._functions import Broadcast 7 | from torch.autograd import Function 8 | 9 | __all__ = ['DataParallelModel', 'DataParallelCriterion'] 10 | 11 | 12 | class Reduce(Function): 13 | @staticmethod 14 | def forward(ctx, *inputs): 15 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 16 | inputs = sorted(inputs, key=lambda i: i.get_device()) 17 | return comm.reduce_add(inputs) 18 | 19 | @staticmethod 20 | def backward(ctx, gradOutputs): 21 | return Broadcast.apply(ctx.target_gpus, gradOutputs) 22 | 23 | 24 | class DataParallelModel(DataParallel): 25 | """Data parallelism 26 | 27 | Hide the difference 
of single/multiple GPUs to the user. 28 | In the forward pass, the module is replicated on each device, 29 | and each replica handles a portion of the input. During the backwards 30 | pass, gradients from each replica are summed into the original module. 31 | 32 | The batch size should be larger than the number of GPUs used. 33 | 34 | Parameters 35 | ---------- 36 | module : object 37 | Network to be parallelized. 38 | sync : bool 39 | enable synchronization (default: False). 40 | Inputs: 41 | - **inputs**: list of input 42 | Outputs: 43 | - **outputs**: list of output 44 | Example:: 45 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 46 | >>> output = net(input_var) # input_var can be on any device, including CPU 47 | """ 48 | 49 | def gather(self, outputs, output_device): 50 | return outputs 51 | 52 | def replicate(self, module, device_ids): 53 | modules = super(DataParallelModel, self).replicate(module, device_ids) 54 | return modules 55 | 56 | 57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py 58 | class DataParallelCriterion(DataParallel): 59 | """ 60 | Calculate loss in multiple-GPUs, which balance the memory usage for 61 | Semantic Segmentation. 62 | 63 | The targets are splitted across the specified devices by chunking in 64 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 65 | 66 | Example:: 67 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 68 | >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 69 | >>> y = net(x) 70 | >>> loss = criterion(y, target) 71 | """ 72 | 73 | def forward(self, inputs, *targets, **kwargs): 74 | # the inputs should be the outputs of DataParallelModel 75 | if not self.device_ids: 76 | return self.module(inputs, *targets, **kwargs) 77 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 78 | if len(self.device_ids) == 1: 79 | return self.module(inputs, *targets[0], **kwargs[0]) 80 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 81 | outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs) 82 | return Reduce.apply(*outputs) / len(outputs) 83 | 84 | 85 | def get_a_var(obj): 86 | if isinstance(obj, torch.Tensor): 87 | return obj 88 | 89 | if isinstance(obj, list) or isinstance(obj, tuple): 90 | for result in map(get_a_var, obj): 91 | if isinstance(result, torch.Tensor): 92 | return result 93 | 94 | if isinstance(obj, dict): 95 | for result in map(get_a_var, obj.items()): 96 | if isinstance(result, torch.Tensor): 97 | return result 98 | return None 99 | 100 | 101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 102 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 103 | contained in :attr:`inputs` (positional), attr:'targets' (positional) and :attr:`kwargs_tup` (keyword) 104 | on each of :attr:`devices`. 105 | 106 | Args: 107 | modules (Module): modules to be parallelized 108 | inputs (tensor): inputs to the modules 109 | targets (tensor): targets to the modules 110 | devices (list of int or torch.device): CUDA devices 111 | :attr:`modules`, :attr:`inputs`, :attr:'targets' :attr:`kwargs_tup` (if given), and 112 | :attr:`devices` (if given) should all have same length. Moreover, each 113 | element of :attr:`inputs` can either be a single object as the only argument 114 | to a module, or a collection of positional arguments. 
115 | """ 116 | assert len(modules) == len(inputs) 117 | assert len(targets) == len(inputs) 118 | if kwargs_tup is not None: 119 | assert len(modules) == len(kwargs_tup) 120 | else: 121 | kwargs_tup = ({},) * len(modules) 122 | if devices is not None: 123 | assert len(modules) == len(devices) 124 | else: 125 | devices = [None] * len(modules) 126 | lock = threading.Lock() 127 | results = {} 128 | grad_enabled = torch.is_grad_enabled() 129 | 130 | def _worker(i, module, input, target, kwargs, device=None): 131 | torch.set_grad_enabled(grad_enabled) 132 | if device is None: 133 | device = get_a_var(input).get_device() 134 | try: 135 | with torch.cuda.device(device): 136 | output = module(*(list(input) + target), **kwargs) 137 | with lock: 138 | results[i] = output 139 | except Exception as e: 140 | with lock: 141 | results[i] = e 142 | 143 | if len(modules) > 1: 144 | threads = [threading.Thread(target=_worker, 145 | args=(i, module, input, target, kwargs, device)) 146 | for i, (module, input, target, kwargs, device) in 147 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 148 | 149 | for thread in threads: 150 | thread.start() 151 | for thread in threads: 152 | thread.join() 153 | else: 154 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 155 | 156 | outputs = [] 157 | for i in range(len(inputs)): 158 | output = results[i] 159 | if isinstance(output, Exception): 160 | raise output 161 | outputs.append(output) 162 | return outputs 163 | -------------------------------------------------------------------------------- /core/utils/score.py: -------------------------------------------------------------------------------- 1 | """Evaluation Metrics for Semantic Segmentation""" 2 | import torch 3 | import numpy as np 4 | 5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union', 6 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score'] 7 | 8 | 9 | class SegmentationMetric(object): 10 | """Computes pixAcc and mIoU metric scores 11 | """ 12 | 13 | def __init__(self, nclass): 14 | super(SegmentationMetric, self).__init__() 15 | self.nclass = nclass 16 | self.reset() 17 | 18 | def update(self, preds, labels): 19 | """Updates the internal evaluation result. 20 | 21 | Parameters 22 | ---------- 23 | labels : 'NumpyArray' or list of `NumpyArray` 24 | The labels of the data. 25 | preds : 'NumpyArray' or list of `NumpyArray` 26 | Predicted values. 27 | """ 28 | 29 | def evaluate_worker(self, pred, label): 30 | correct, labeled = batch_pix_accuracy(pred, label) 31 | inter, union = batch_intersection_union(pred, label, self.nclass) 32 | 33 | self.total_correct += correct 34 | self.total_label += labeled 35 | if self.total_inter.device != inter.device: 36 | self.total_inter = self.total_inter.to(inter.device) 37 | self.total_union = self.total_union.to(union.device) 38 | self.total_inter += inter 39 | self.total_union += union 40 | 41 | if isinstance(preds, torch.Tensor): 42 | evaluate_worker(self, preds, labels) 43 | elif isinstance(preds, (list, tuple)): 44 | for (pred, label) in zip(preds, labels): 45 | evaluate_worker(self, pred, label) 46 | 47 | def get(self): 48 | """Gets the current evaluation result. 
49 | 50 | Returns 51 | ------- 52 | metrics : tuple of float 53 | pixAcc and mIoU 54 | """ 55 | pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label) # hard-coded machine epsilon (np.spacing(1)) to avoid division by zero 56 | IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union) 57 | mIoU = IoU.mean().item() 58 | return pixAcc, mIoU 59 | 60 | def reset(self): 61 | """Resets the internal evaluation result to initial state.""" 62 | self.total_inter = torch.zeros(self.nclass) 63 | self.total_union = torch.zeros(self.nclass) 64 | self.total_correct = 0 65 | self.total_label = 0 66 | 67 | 68 | # pytorch version 69 | def batch_pix_accuracy(output, target): 70 | """PixAcc""" 71 | # inputs are torch tensors: output is 4D (N, C, H, W), target is 3D (N, H, W) 72 | predict = torch.argmax(output.long(), 1) + 1 73 | target = target.long() + 1 74 | 75 | pixel_labeled = torch.sum(target > 0).item() 76 | pixel_correct = torch.sum((predict == target) * (target > 0)).item() 77 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 78 | return pixel_correct, pixel_labeled 79 | 80 | 81 | def batch_intersection_union(output, target, nclass): 82 | """mIoU""" 83 | # inputs are torch tensors: output is 4D (N, C, H, W), target is 3D (N, H, W) 84 | mini = 1 85 | maxi = nclass 86 | nbins = nclass 87 | predict = torch.argmax(output, 1) + 1 88 | target = target.float() + 1 89 | 90 | predict = predict.float() * (target > 0).float() 91 | intersection = predict * (predict == target).float() 92 | # areas of intersection and union 93 | # zeros in `intersection` mark pixels outside the intersection; torch.histc with min=1 ignores them, which is the main difference from np.bincount 94 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 95 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 96 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 97 | area_union = area_pred + area_lab - area_inter 98 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 99 | return area_inter.float(), area_union.float() 100 | 101 | 102 | def pixelAccuracy(imPred, imLab): 103 | """ 104 | This function takes the prediction and label of a single image, returns pixel-wise accuracy 105 | To compute over many images do: 106 | for i in range(Nimages): 107 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \ 108 | pixelAccuracy(imPred[i], imLab[i]) 109 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled)) 110 | """ 111 | # Remove classes from unlabeled pixels in gt image. 112 | # We should not penalize detections in unlabeled portions of the image. 113 | pixel_labeled = np.sum(imLab >= 0) 114 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 115 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 116 | return (pixel_accuracy, pixel_correct, pixel_labeled) 117 | 118 | 119 | def intersectionAndUnion(imPred, imLab, numClass): 120 | """ 121 | This function takes the prediction and label of a single image, 122 | returns intersection and union areas for each class 123 | To compute over many images do: 124 | for i in range(Nimages): 125 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i]) 126 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1) 127 | """ 128 | # Remove classes from unlabeled pixels in gt image. 129 | # We should not penalize detections in unlabeled portions of the image. 
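# A worked example (illustrative, not from the original code): with numClass=3,
# imPred=[1, 2, 2] and imLab=[1, 2, 3], every pixel is labeled, so:
#   intersection = [1, 2, 0] -> area_intersection = [1, 1, 0]
#   area_pred = [1, 2, 0], area_lab = [1, 1, 1]
#   area_union = area_pred + area_lab - area_intersection = [1, 2, 1]
#   per-class IoU = [1.0, 0.5, 0.0]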
130 | imPred = imPred * (imLab >= 0) 131 | 132 | # Compute area intersection: 133 | intersection = imPred * (imPred == imLab) 134 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass)) 135 | 136 | # Compute area union: 137 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 138 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 139 | area_union = area_pred + area_lab - area_intersection 140 | return (area_intersection, area_union) 141 | 142 | 143 | def hist_info(pred, label, num_cls): 144 | assert pred.shape == label.shape 145 | k = (label >= 0) & (label < num_cls) 146 | labeled = np.sum(k) 147 | correct = np.sum(pred[k] == label[k]) 148 | 149 | return np.bincount(num_cls * label[k].astype(int) + pred[k], minlength=num_cls ** 2).reshape(num_cls, 150 | num_cls), labeled, correct 151 | 152 | 153 | def compute_score(hist, correct, labeled): 154 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 155 | mean_IU = np.nanmean(iu) 156 | mean_IU_no_back = np.nanmean(iu[1:]) 157 | freq = hist.sum(1) / hist.sum() 158 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() # frequency-weighted IoU (computed but not returned) 159 | mean_pixel_acc = correct / labeled 160 | 161 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 162 | -------------------------------------------------------------------------------- /core/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color', 6 | 'show_prediction', 'show_colorful_images', 'save_colorful_images'] 7 | 8 | 9 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False): 10 | n = iu.size 11 | lines = [] 12 | for i in range(n): 13 | if class_names is None: 14 | cls = 'Class %d:' % (i + 1) 15 | else: 16 | cls = '%d %s' % (i + 1, class_names[i]) 17 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100)) 18 | mean_IU = np.nanmean(iu) 19 | mean_IU_no_back = np.nanmean(iu[1:]) 20 | if show_no_back: 21 | lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % ( 22 | mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100)) 23 | else: 24 | lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100)) 25 | lines.append('=================================================') 26 | line = "\n".join(lines) 27 | 28 | print(line) 29 | 30 | 31 | def set_img_color(img, label, colors, background=0, show255=False): 32 | for i in range(len(colors)): 33 | if i != background: 34 | img[np.where(label == i)] = colors[i] 35 | if show255: 36 | img[np.where(label == 255)] = 255 37 | 38 | return img 39 | 40 | 41 | def show_prediction(img, pred, colors, background=0): 42 | im = np.array(img, np.uint8) 43 | set_img_color(im, pred, colors, background) 44 | out = np.array(im) 45 | 46 | return out 47 | 48 | 49 | def show_colorful_images(prediction, palettes): 50 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 51 | im.show() 52 | 53 | 54 | def save_colorful_images(prediction, filename, output_dir, palettes): 55 | ''' 56 | :param prediction: [B, H, W, C] 57 | ''' 58 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 59 | fn = os.path.join(output_dir, filename) 60 | out_dir = os.path.split(fn)[0] 61 | if not os.path.exists(out_dir): 62 | os.makedirs(out_dir) 63 | im.save(fn) 64 | 65 | 66 | def get_color_pallete(npimg, dataset='pascal_voc'): 67 | """Visualize a segmentation mask with the dataset's color palette. 
68 | 69 | Parameters 70 | ---------- 71 | npimg : numpy.ndarray 72 | Single channel image with shape `H, W, 1`. 73 | dataset : str, default: 'pascal_voc' 74 | The dataset the model was pretrained on. ('pascal_voc', 'pascal_aug', 'ade20k', 'citys') 75 | Returns 76 | ------- 77 | out_img : PIL.Image 78 | Image with the color palette applied 79 | """ 80 | # map the ignored label (-1) back to the void/boundary value 255 81 | if dataset in ('pascal_voc', 'pascal_aug'): 82 | npimg[npimg == -1] = 255 83 | # apply the colormap 84 | if dataset == 'ade20k': 85 | npimg = npimg + 1 86 | out_img = Image.fromarray(npimg.astype('uint8')) 87 | out_img.putpalette(adepallete) 88 | return out_img 89 | elif dataset == 'citys': 90 | out_img = Image.fromarray(npimg.astype('uint8')) 91 | out_img.putpalette(cityspallete) 92 | return out_img 93 | out_img = Image.fromarray(npimg.astype('uint8')) 94 | out_img.putpalette(vocpallete) 95 | return out_img 96 | 97 | 98 | def _getvocpallete(num_cls): 99 | n = num_cls 100 | pallete = [0] * (n * 3) 101 | for j in range(0, n): 102 | lab = j 103 | pallete[j * 3 + 0] = 0 104 | pallete[j * 3 + 1] = 0 105 | pallete[j * 3 + 2] = 0 106 | i = 0 107 | while (lab > 0): 108 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 109 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 110 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 111 | i = i + 1 112 | lab >>= 3 113 | return pallete 114 | 115 | 116 | vocpallete = _getvocpallete(256) 117 | 118 | adepallete = [ 119 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204, 120 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82, 121 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255, 122 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6, 123 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255, 124 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15, 125 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255, 126 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163, 127 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255, 128 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0, 129 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0, 130 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255, 131 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255, 132 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0, 133 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0, 134 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0, 135 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0, 136 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255] 137 | 138 | cityspallete = [ 139 | 128, 64, 128, 140 | 244, 35, 232, 141 | 70, 70, 70, 142 | 102, 102, 
156, 143 | 190, 153, 153, 144 | 153, 153, 153, 145 | 250, 170, 30, 146 | 220, 220, 0, 147 | 107, 142, 35, 148 | 152, 251, 152, 149 | 0, 130, 180, 150 | 220, 20, 60, 151 | 255, 0, 0, 152 | 0, 0, 142, 153 | 0, 0, 70, 154 | 0, 60, 100, 155 | 0, 80, 100, 156 | 0, 0, 230, 157 | 119, 11, 32, 158 | ] 159 | -------------------------------------------------------------------------------- /datasets/ade: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/ade -------------------------------------------------------------------------------- /datasets/citys: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/citys -------------------------------------------------------------------------------- /datasets/sbu: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/SBU-shadow -------------------------------------------------------------------------------- /datasets/voc: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/VOCdevkit -------------------------------------------------------------------------------- /docs/DETAILS.md: -------------------------------------------------------------------------------- 1 | ### Model & Backbone 2 | 3 | | Model | Scratch | VGG16 | ResNet18 | ResNet50 | ResNet101 | ResNet152 | DenseNet121 | DenseNet169 | 4 | | :-------: | :-----: | :---: | :------: | :------: | :-------: | :-------: | :---------: | :---------: | 5 | | FCN32s | ✘ | ✓ | | | | | | | 6 | | FCN16s | | ✓ | | | | | | | 7 | | FCN8s | | ✓ | | | | | | | 8 | | FCNv2 | | | | ✓ | ✓ | ✓ | | | 9 | | PSPNet | | | | ✓ | ✓ | ✓ | | | 10 | | DeepLabv3 | | | | ✓ | ✓ | ✓ | | | 11 | | DenseASPP | | | | | | | ✓ | ✓ | 12 | | DANet | | | | ✓ | ✓ | ✓ | | | 13 | | BiSeNet | | | ✓ | | | | | | 14 | | EncNet | | | | ✓ | ✓ | ✓ | | | 15 | | ICNet | | | | ✓ | ✓ | ✓ | | | 16 | | DUNet | | | | ✓ | ✓ | ✓ | | | 17 | | ENet | ✓ | | | | | | | | 18 | | OCNet | | | | ✓ | ✓ | ✓ | | | 19 | | CCNet | | | | ✓ | ✓ | ✓ | | | 20 | | PSANet | | | | ✓ | ✓ | ✓ | | | 21 | | CGNet | ✓ | | | | | | | | 22 | | ESPNet | ✓ | | | | | | | | 23 | | LEDNet | ✓ | | | | | | | | 24 | | DFANet | ✓ | | | | | | | | 25 | -------------------------------------------------------------------------------- /docs/QQ.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/QQ.jpg -------------------------------------------------------------------------------- /docs/WeChat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/WeChat.jpeg -------------------------------------------------------------------------------- /docs/requirements.yml: -------------------------------------------------------------------------------- 1 | name: seg_requirements 2 | dependencies: 3 | - python=3 4 | - numpy 5 | - cuda 6 | - pip: 7 | - Pillow 8 | - tqdm 9 | - requests 10 | - torch>=1.0 11 | - torchvision 12 | -------------------------------------------------------------------------------- /docs/weimar_000091_000019_gtFine_color.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/docs/weimar_000091_000019_gtFine_color.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch>=1.1.0 3 | torchvision>=0.3.0 4 | Pillow 5 | -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import torch 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | from torchvision import transforms 11 | from PIL import Image 12 | from core.utils.visualize import get_color_pallete 13 | from core.models import get_model 14 | 15 | parser = argparse.ArgumentParser( 16 | description='Predict segmentation result from a given image') 17 | parser.add_argument('--model', type=str, default='fcn32s_vgg16_voc', 18 | help='model name (default: fcn32s_vgg16_voc)') 19 | parser.add_argument('--dataset', type=str, default='pascal_aug', choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys'], 20 | help='dataset name (default: pascal_aug)') 21 | parser.add_argument('--save-folder', default='~/.torch/models', 22 | help='Directory for saving checkpoint models') 23 | parser.add_argument('--input-pic', type=str, default='../datasets/voc/VOC2012/JPEGImages/2007_000032.jpg', 24 | help='path to the input picture') 25 | parser.add_argument('--outdir', default='./eval', type=str, 26 | help='path to save the prediction result') 27 | parser.add_argument('--local_rank', type=int, default=0) 28 | args = parser.parse_args() 29 | 30 | 31 | def demo(config): 32 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 33 | # output folder 34 | if not os.path.exists(config.outdir): 35 | os.makedirs(config.outdir) 36 | 37 | # image transform 38 | transform = transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 41 | ]) 42 | image = Image.open(config.input_pic).convert('RGB') 43 | images = transform(image).unsqueeze(0).to(device) 44 | 45 | model = get_model(config.model, local_rank=config.local_rank, pretrained=True, root=config.save_folder).to(device) 46 | print('Finished loading model!') 47 | 48 | model.eval() 49 | with torch.no_grad(): 50 | output = model(images) 51 | 52 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 53 | mask = get_color_pallete(pred, config.dataset) 54 | outname = os.path.splitext(os.path.split(config.input_pic)[-1])[0] + '.png' 55 | mask.save(os.path.join(config.outdir, outname)) 56 | 57 | 58 | if __name__ == '__main__': 59 | demo(args) 60 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.data as data 13 | import torch.backends.cudnn as cudnn 14 | 15 | from torchvision import transforms 16 | from core.data.dataloader import get_segmentation_dataset 17 | 
from core.models.model_zoo import get_segmentation_model 18 | from core.utils.score import SegmentationMetric 19 | from core.utils.visualize import get_color_pallete 20 | from core.utils.logger import setup_logger 21 | from core.utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler 22 | 23 | from train import parse_args 24 | 25 | 26 | class Evaluator(object): 27 | def __init__(self, args): 28 | self.args = args 29 | self.device = torch.device(args.device) 30 | 31 | # image transform 32 | input_transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 35 | ]) 36 | 37 | # dataset and dataloader 38 | val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='testval', transform=input_transform) 39 | val_sampler = make_data_sampler(val_dataset, False, args.distributed) 40 | val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1) 41 | self.val_loader = data.DataLoader(dataset=val_dataset, 42 | batch_sampler=val_batch_sampler, 43 | num_workers=args.workers, 44 | pin_memory=True) 45 | 46 | # create network 47 | BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d 48 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone, 49 | aux=args.aux, pretrained=True, pretrained_base=False, 50 | local_rank=args.local_rank, 51 | norm_layer=BatchNorm2d).to(self.device) 52 | if args.distributed: 53 | self.model = nn.parallel.DistributedDataParallel(self.model, 54 | device_ids=[args.local_rank], output_device=args.local_rank) 55 | self.model.to(self.device) 56 | 57 | self.metric = SegmentationMetric(val_dataset.num_class) 58 | 59 | def eval(self): 60 | self.metric.reset() 61 | self.model.eval() 62 | if self.args.distributed: 63 | model = self.model.module 64 | else: 65 | model = self.model 66 | logger.info("Start validation, Total sample: {:d}".format(len(self.val_loader))) 67 | for i, (image, target, filename) in enumerate(self.val_loader): 68 | image = image.to(self.device) 69 | target = target.to(self.device) 70 | 71 | with torch.no_grad(): 72 | outputs = model(image) 73 | self.metric.update(outputs[0], target) 74 | pixAcc, mIoU = self.metric.get() 75 | logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format( 76 | i + 1, pixAcc * 100, mIoU * 100)) 77 | 78 | if self.args.save_pred: 79 | pred = torch.argmax(outputs[0], 1) 80 | pred = pred.cpu().data.numpy() 81 | 82 | predict = pred.squeeze(0) 83 | mask = get_color_pallete(predict, self.args.dataset) 84 | mask.save(os.path.join(outdir, os.path.splitext(filename[0])[0] + '.png')) 85 | synchronize() 86 | 87 | 88 | if __name__ == '__main__': 89 | args = parse_args() 90 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 91 | args.distributed = num_gpus > 1 92 | if not args.no_cuda and torch.cuda.is_available(): 93 | cudnn.benchmark = True 94 | args.device = "cuda" 95 | else: 96 | args.distributed = False 97 | args.device = "cpu" 98 | if args.distributed: 99 | torch.cuda.set_device(args.local_rank) 100 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 101 | synchronize() 102 | 103 | # TODO: optim code 104 | args.save_pred = True 105 | if args.save_pred: 106 | outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset) 107 | if not os.path.exists(outdir): 108 | os.makedirs(outdir) 109 | 110 | logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(), 111 
| filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+') 112 | 113 | evaluator = Evaluator(args) 114 | evaluator.eval() 115 | torch.cuda.empty_cache() 116 | -------------------------------------------------------------------------------- /scripts/fcn32s_vgg16_pascal_voc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # train 4 | CUDA_VISIBLE_DEVICES=0 python train.py --model fcn32s \ 5 | --backbone vgg16 --dataset pascal_voc \ 6 | --lr 0.0001 --epochs 80 -------------------------------------------------------------------------------- /scripts/fcn32s_vgg16_pascal_voc_dist.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # distributed train 4 | export NGPUS=4 5 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py --model fcn32s \ 6 | --backbone vgg16 --dataset pascal_voc \ 7 | --lr 0.01 --epochs 80 --batch_size 16 -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Overfitting Test 2 | 3 | In order to ensure the correctness of the models, the project provides an overfitting test script (a trick that makes the train set and the val set include the same images). 4 | Observing the convergence process of different models is quite interesting :joy: 5 | 6 | ### Usage 7 | 8 | See the example command at the end of this page. 9 | 10 |    (a) img: 2007_000033.jpg        (b) mask: 2007_000033.png 11 | 12 | ### Test Result 13 | | Model | backbone | epoch | mIoU | pixAcc | 14 | | :-----: | :----: | :-----: | :-----: | :------: | 15 | | FCN32s | vgg16 | 200 | 94.0% | 98.2% | 16 | | FCN16s | vgg16 | 200 | 99.2% | 99.8% | 17 | | FCN8s | vgg16 | 100 | 99.8% | 99.9% | 18 | | DANet | resnet50 | 100 | 99.5% | 99.9% | 19 | | EncNet | resnet50 | 100 | 99.7% | 99.9% | 20 | | DUNet | resnet50 | 100 | 98.8% | 99.6% | 21 | | PSPNet | resnet50 | 100 | 99.8% | 99.9% | 22 | | BiSeNet | resnet18 | 100 | 99.6% | 99.9% | 23 | | DenseASPP | densenet121 | 40 | 100% | 100% | 24 | | ICNet | resnet50 | 100 | 98.8% | 99.6% | 25 | | ENet | scratch | 100 | 99.9% | 100% | 26 | | OCNet | resnet50 | 100 | 99.8% | 100% | 27 | 28 | ### Visualization 29 | 30 | 31 | 32 | 33 | 34 | 35 |   FCN32s  FCN16s   FCN8s   DANet   EncNet    DUNet   PSPNet   BiSeNet   DenseASPP 36 | 37 | 38 | 39 | 40 |   ICNet   ENet   OCNet 41 | 42 | ### Conclusion 43 | - The result of FCN32s is the worst. 44 | - There are gridding artifacts in the DUNet results. 45 | - The result of BiSeNet is bad when `lr=1e-3`; the lr needs to be set to `1e-2`. 46 | - DenseASPP converges fastest and reaches 100%. 47 | - The lr of ENet needs to be set to `1e-2`; the edges of its results are not smooth. 
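For example, a minimal overfitting run for PSPNet looks like the following (an illustrative command: the flags come from `parse_args` in `test_model.py`, and the script is assumed to be launched from the `tests` directory because it loads `test_img.jpg`/`test_mask.png` by relative path):

```bash
cd tests
python test_model.py --model psp --backbone resnet50 --epochs 100 --lr 1e-3
```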
-------------------------------------------------------------------------------- /tests/runs/bisenet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/bisenet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/danet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/danet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/denseaspp_epoch_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/denseaspp_epoch_40.png -------------------------------------------------------------------------------- /tests/runs/dunet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/dunet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/encnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/encnet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/enet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/enet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/fcn16s_epoch_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn16s_epoch_200.png -------------------------------------------------------------------------------- /tests/runs/fcn32s_epoch_300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn32s_epoch_300.png -------------------------------------------------------------------------------- /tests/runs/fcn8s_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/fcn8s_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/icnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/icnet_epoch_100.png -------------------------------------------------------------------------------- 
/tests/runs/ocnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/ocnet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/psp_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/runs/psp_epoch_100.png -------------------------------------------------------------------------------- /tests/test_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/test_img.jpg -------------------------------------------------------------------------------- /tests/test_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tramac/awesome-semantic-segmentation-pytorch/d37d2a17221d2681ad454958cf06a1065e9b1f7f/tests/test_mask.png -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | """Model overfitting test""" 2 | import argparse 3 | import time 4 | import os 5 | import sys 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.backends.cudnn as cudnn 10 | import numpy as np 11 | 12 | cur_path = os.path.abspath(os.path.dirname(__file__)) 13 | root_path = os.path.split(cur_path)[0] 14 | sys.path.append(root_path) 15 | 16 | from torchvision import transforms 17 | from core.models.model_zoo import get_segmentation_model 18 | from core.utils.loss import MixSoftmaxCrossEntropyLoss, EncNetLoss, ICNetLoss 19 | from core.utils.lr_scheduler import LRScheduler 20 | from core.utils.score import hist_info, compute_score 21 | from core.utils.visualize import get_color_pallete 22 | from PIL import Image 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Semantic Segmentation Overfitting Test') 27 | # model 28 | parser.add_argument('--model', type=str, default='fcn32s', 29 | choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn', 'psp', 30 | 'deeplabv3', 'danet', 'denseaspp', 'bisenet', 'encnet', 31 | 'dunet', 'icnet', 'enet', 'ocnet'], 32 | help='model name (default: fcn32s)') 33 | parser.add_argument('--backbone', type=str, default='vgg16', 34 | choices=['vgg16', 'resnet18', 'resnet50', 'resnet101', 35 | 'resnet152', 'densenet121', '161', '169', '201'], 36 | help='backbone name (default: vgg16)') 37 | parser.add_argument('--dataset', type=str, default='pascal_voc', 38 | choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys', 39 | 'sbu'], 40 | help='dataset name (default: pascal_voc)') 41 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 42 | help='number of epochs to train (default: 100)') 43 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 44 | help='learning rate (default: 1e-3)') 45 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 46 | help='momentum (default: 0.9)') 47 | parser.add_argument('--weight-decay', type=float, default=1e-4, metavar='M', 48 | help='w-decay (default: 1e-4)') 49 | args = parser.parse_args() 50 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else "cpu") 51 | cudnn.benchmark = True 52 | args.device = device 53 | print(args) 54 | return args 55 | 56 | 57 | class VOCSegmentation(object): 58 | def __init__(self): 59 | super(VOCSegmentation, self).__init__() 60 | self.img = Image.open('test_img.jpg').convert('RGB') 61 | self.mask = Image.open('test_mask.png') 62 | 63 | self.img = self.img.resize((504, 368), Image.BILINEAR) 64 | self.mask = self.mask.resize((504, 368), Image.NEAREST) 65 | 66 | def get(self): 67 | img, mask = self._img_transform(self.img), self._mask_transform(self.mask) 68 | return img, mask 69 | 70 | def _img_transform(self, img): 71 | input_transform = transforms.Compose([ 72 | transforms.ToTensor(), 73 | transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) 74 | img = input_transform(img) 75 | img = img.unsqueeze(0) 76 | 77 | # For adaptive pooling 78 | # img = torch.cat([img, img], dim=0) 79 | return img 80 | 81 | def _mask_transform(self, mask): 82 | target = np.array(mask).astype('int32') 83 | target[target == 255] = -1 84 | target = torch.from_numpy(target).long() 85 | target = target.unsqueeze(0) 86 | 87 | # For adaptive pooling 88 | # target = torch.cat([target, target], dim=0) 89 | return target 90 | 91 | 92 | class Trainer(object): 93 | def __init__(self, args): 94 | self.args = args 95 | 96 | self.img, self.target = VOCSegmentation().get() 97 | 98 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone, 99 | aux=False, norm_layer=nn.BatchNorm2d).to(args.device) 100 | 101 | self.criterion = MixSoftmaxCrossEntropyLoss(False, 0., ignore_label=-1).to(args.device) 102 | 103 | # for EncNet 104 | # self.criterion = EncNetLoss(nclass=21, ignore_label=-1).to(args.device) 105 | # for ICNet 106 | # self.criterion = ICNetLoss(nclass=21, ignore_index=-1).to(args.device) 107 | 108 | self.optimizer = torch.optim.Adam(self.model.parameters(), 109 | lr=args.lr, 110 | weight_decay=args.weight_decay) 111 | self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs, 112 | iters_per_epoch=1, power=0.9) 113 | 114 | def train(self): 115 | self.model.train() 116 | start_time = time.time() 117 | for epoch in range(self.args.epochs): 118 | self.lr_scheduler(self.optimizer, epoch) 119 | cur_lr = self.lr_scheduler.learning_rate 120 | # self.lr_scheduler(self.optimizer, epoch) 121 | for param_group in self.optimizer.param_groups: 122 | param_group['lr'] = cur_lr 123 | 124 | images = self.img.to(self.args.device) 125 | targets = self.target.to(self.args.device) 126 | 127 | outputs = self.model(images) 128 | loss = self.criterion(outputs, targets) 129 | 130 | self.optimizer.zero_grad() 131 | loss['loss'].backward() 132 | self.optimizer.step() 133 | 134 | pred = torch.argmax(outputs[0], 1).cpu().data.numpy() 135 | mask = get_color_pallete(pred.squeeze(0), self.args.dataset) 136 | save_pred(self.args, epoch, mask) 137 | hist, labeled, correct = hist_info(pred, targets.cpu().numpy(), 21) 138 | _, mIoU, _, pixAcc = compute_score(hist, correct, labeled) 139 | 140 | print('Epoch: [%2d/%2d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f || pixAcc: %.3f || mIoU: %.3f' % ( 141 | epoch, self.args.epochs, time.time() - start_time, cur_lr, loss['loss'].item(), pixAcc, mIoU)) 142 | 143 | 144 | def save_pred(args, epoch, mask): 145 | directory = "runs/%s/" % (args.model) 146 | if not os.path.exists(directory): 147 | os.makedirs(directory) 148 | filename = directory + '{}_epoch_{}.png'.format(args.model, epoch + 1) 149 | mask.save(filename) 150 | 151 
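# A sketch of an alternative metric computation (illustrative, not part of the
# original script; it assumes the 21 VOC classes): core.utils.score.SegmentationMetric
# accumulates the same pixAcc/mIoU directly from torch tensors, e.g.
#
#     from core.utils.score import SegmentationMetric
#     metric = SegmentationMetric(nclass=21)
#     metric.update(outputs[0], targets)  # outputs[0]: (N, C, H, W), targets: (N, H, W)
#     pixAcc, mIoU = metric.get()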
| 152 | if __name__ == '__main__': 153 | args = parse_args() 154 | trainer = Trainer(args) 155 | print('Test model: ', args.model) 156 | trainer.train() 157 | -------------------------------------------------------------------------------- /tests/test_module.py: -------------------------------------------------------------------------------- 1 | import core 2 | import torch 3 | import numpy as np 4 | 5 | # torch.autograd.Variable is deprecated; plain tensors with requires_grad are used below 6 | 7 | EPS = 1e-3 8 | ATOL = 1e-3 9 | 10 | 11 | def _assert_tensor_close(a, b, atol=ATOL, rtol=EPS): 12 | npa, npb = a.cpu().numpy(), b.cpu().numpy() 13 | assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ 14 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( 15 | a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 16 | 17 | 18 | def testSyncBN(): 19 | def _check_batchnorm_result(bn1, bn2, input, is_train, cuda=False): 20 | def _find_bn(module): 21 | for m in module.modules(): 22 | if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, 23 | core.nn.SyncBatchNorm)): 24 | return m 25 | 26 | def _syncParameters(bn1, bn2): 27 | bn1.reset_parameters() 28 | bn2.reset_parameters() 29 | if bn1.affine and bn2.affine: 30 | bn2.weight.data.copy_(bn1.weight.data) 31 | bn2.bias.data.copy_(bn1.bias.data) 32 | bn2.running_mean.copy_(bn1.running_mean) 33 | bn2.running_var.copy_(bn1.running_var) 34 | 35 | bn1.train(mode=is_train) 36 | bn2.train(mode=is_train) 37 | 38 | if cuda: 39 | input = input.cuda() 40 | # using the same values for gamma and beta 41 | _syncParameters(_find_bn(bn1), _find_bn(bn2)) 42 | 43 | input1 = input.clone().detach().requires_grad_(True) 44 | input2 = input.clone().detach().requires_grad_(True) 45 | if is_train: 46 | bn1.train() 47 | bn2.train() 48 | output1 = bn1(input1) 49 | output2 = bn2(input2) 50 | else: 51 | bn1.eval() 52 | bn2.eval() 53 | with torch.no_grad(): 54 | output1 = bn1(input1) 55 | output2 = bn2(input2) 56 | # assert forwarding 57 | # _assert_tensor_close(input1.data, input2.data) 58 | _assert_tensor_close(output1.data, output2.data) 59 | if not is_train: 60 | return 61 | (output1 ** 2).sum().backward() 62 | (output2 ** 2).sum().backward() 63 | _assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data) 64 | _assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data) 65 | _assert_tensor_close(input1.grad.data, input2.grad.data) 66 | _assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 67 | # _assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 68 | 69 | bn = torch.nn.BatchNorm2d(10).cuda().double() 70 | sync_bn = core.nn.SyncBatchNorm(10, inplace=True, sync=True).cuda().double() 71 | sync_bn = torch.nn.DataParallel(sync_bn).cuda() 72 | # check with unsync version 73 | # _check_batchnorm_result(bn, sync_bn, torch.rand(2, 1, 2, 2).double(), True, cuda=True) 74 | for i in range(10): 75 | print(i) 76 | _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), True, cuda=True) 77 | # _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True) 78 | 79 | 80 | if __name__ == '__main__': 81 | import nose 82 | 83 | nose.runmodule() 84 | --------------------------------------------------------------------------------