├── .gitignore
├── LICENSE
├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── custom_transforms.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cityscapes.py
│   │   ├── coco.py
│   │   ├── combine_dbs.py
│   │   ├── pascal.py
│   │   └── sbd.py
│   └── utils.py
├── doc
│   ├── deeplab_resnet.py
│   ├── deeplab_xception.py
│   └── results.png
├── modeling
│   ├── __init__.py
│   ├── aspp.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── drn.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── xception.py
│   ├── decoder.py
│   ├── deeplab.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── batchnorm.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── mypath.py
├── train.py
├── train_coco.sh
├── train_voc.sh
└── utils
    ├── calculate_weights.py
    ├── loss.py
    ├── lr_scheduler.py
    ├── metrics.py
    ├── saver.py
    └── summaries.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | run/ 10 | .idea/ 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # Environments 88 | .env 89 | .venv 90 | env/ 91 | venv/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Pyjcsx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-deeplab-xception 2 | 3 | **Update on 2018/12/06. Provide models trained on the VOC and SBD datasets.** 4 | 5 | **Update on 2018/11/24. Release the newest version of the code, which fixes some previous issues and adds support for new backbones and multi-GPU training. For the previous code, please see the `previous` branch.** 6 | 7 | ### TODO 8 | - [x] Support different backbones 9 | - [x] Support VOC, SBD, Cityscapes and COCO datasets 10 | - [x] Multi-GPU training 11 | 12 | 13 | 14 | | Backbone | train/eval os |mIoU on val |Pretrained Model| 15 | | :-------- | :------------: |:---------: |:--------------:| 16 | | ResNet | 16/16 | 78.43% | [google drive](https://drive.google.com/open?id=1NwcwlWqA-0HqAPk3dSNNPipGMF0iS0Zu) | 17 | | MobileNet | 16/16 | 70.81% | [google drive](https://drive.google.com/open?id=1G9mWafUAj09P4KvGSRVzIsV_U5OqFLdt) | 18 | | DRN | 16/16 | 78.87% | [google drive](https://drive.google.com/open?id=131gZN_dKEXO79NknIQazPJ-4UmRrZAfI) | 19 | 20 | 21 | 22 | ### Introduction 23 | This is a PyTorch (0.4.1) implementation of [DeepLab-V3-Plus](https://arxiv.org/pdf/1802.02611). It 24 | can use Modified Aligned Xception and ResNet as backbones. Currently, we train DeepLab V3 Plus 25 | using the Pascal VOC 2012, SBD and Cityscapes datasets. 26 | 27 | ![Results](doc/results.png) 28 | 29 | 30 | ### Installation 31 | The code was tested with Anaconda and Python 3.6. After installing the Anaconda environment: 32 | 33 | 0. Clone the repo: 34 | ```Shell 35 | git clone https://github.com/jfzhang95/pytorch-deeplab-xception.git 36 | cd pytorch-deeplab-xception 37 | ``` 38 | 39 | 1. Install dependencies: 40 | 41 | For the PyTorch dependency, see [pytorch.org](https://pytorch.org/) for more details. 42 | 43 | For custom dependencies: 44 | ```Shell 45 | pip install matplotlib pillow tensorboardX tqdm 46 | ``` 47 | ### Training 48 | Follow the steps below to train your model: 49 | 50 | 0. Configure your dataset path in [mypath.py](https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/mypath.py). 51 | 52 | 1. Input arguments (see the full list via `python train.py --help`): 53 | ```Shell 54 | usage: train.py [-h] [--backbone {resnet,xception,drn,mobilenet}] 55 | [--out-stride OUT_STRIDE] [--dataset {pascal,coco,cityscapes}] 56 | [--use-sbd] [--workers N] [--base-size BASE_SIZE] 57 | [--crop-size CROP_SIZE] [--sync-bn SYNC_BN] 58 | [--freeze-bn FREEZE_BN] [--loss-type {ce,focal}] [--epochs N] 59 | [--start_epoch N] [--batch-size N] [--test-batch-size N] 60 | [--use-balanced-weights] [--lr LR] 61 | [--lr-scheduler {poly,step,cos}] [--momentum M] 62 | [--weight-decay M] [--nesterov] [--no-cuda] 63 | [--gpu-ids GPU_IDS] [--seed S] [--resume RESUME] 64 | [--checkname CHECKNAME] [--ft] [--eval-interval EVAL_INTERVAL] 65 | [--no-val] 66 | 67 | ``` 68 | 69 | 2.
To train deeplabv3+ using Pascal VOC dataset and ResNet as backbone: 70 | ```Shell 71 | bash train_voc.sh 72 | ``` 73 | 3. To train deeplabv3+ using COCO dataset and ResNet as backbone: 74 | ```Shell 75 | bash train_coco.sh 76 | ``` 77 | 78 | ### Acknowledgement 79 | [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding) 80 | 81 | [Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) 82 | 83 | [drn](https://github.com/fyu/drn) 84 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd 2 | from torch.utils.data import DataLoader 3 | 4 | def make_data_loader(args, **kwargs): 5 | 6 | if args.dataset == 'pascal': 7 | train_set = pascal.VOCSegmentation(args, split='train') 8 | val_set = pascal.VOCSegmentation(args, split='val') 9 | if args.use_sbd: 10 | sbd_train = sbd.SBDSegmentation(args, split=['train', 'val']) 11 | train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 12 | 13 | num_class = train_set.NUM_CLASSES 14 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 15 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 16 | test_loader = None 17 | 18 | return train_loader, val_loader, test_loader, num_class 19 | 20 | elif args.dataset == 'cityscapes': 21 | train_set = cityscapes.CityscapesSegmentation(args, split='train') 22 | val_set = cityscapes.CityscapesSegmentation(args, split='val') 23 | test_set = cityscapes.CityscapesSegmentation(args, split='test') 24 | num_class = train_set.NUM_CLASSES 25 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 26 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 27 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 28 | 29 | return train_loader, val_loader, test_loader, num_class 30 | 31 | elif args.dataset == 'coco': 32 | train_set = coco.COCOSegmentation(args, split='train') 33 | val_set = coco.COCOSegmentation(args, split='val') 34 | num_class = train_set.NUM_CLASSES 35 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 36 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 37 | test_loader = None 38 | return train_loader, val_loader, test_loader, num_class 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | -------------------------------------------------------------------------------- /dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | class Normalize(object): 8 | """Normalize a tensor image with mean and standard deviation. 9 | Args: 10 | mean (tuple): means for each channel. 11 | std (tuple): standard deviations for each channel. 
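Applies, per channel: img = (img / 255.0 - mean) / std; the label mask is passed through unchanged.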
12 | """ 13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | img = sample['image'] 19 | mask = sample['label'] 20 | img = np.array(img).astype(np.float32) 21 | mask = np.array(mask).astype(np.float32) 22 | img /= 255.0 23 | img -= self.mean 24 | img /= self.std 25 | 26 | return {'image': img, 27 | 'label': mask} 28 | 29 | 30 | class ToTensor(object): 31 | """Convert ndarrays in sample to Tensors.""" 32 | 33 | def __call__(self, sample): 34 | # swap color axis because 35 | # numpy image: H x W x C 36 | # torch image: C X H X W 37 | img = sample['image'] 38 | mask = sample['label'] 39 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 40 | mask = np.array(mask).astype(np.float32) 41 | 42 | img = torch.from_numpy(img).float() 43 | mask = torch.from_numpy(mask).float() 44 | 45 | return {'image': img, 46 | 'label': mask} 47 | 48 | 49 | class RandomHorizontalFlip(object): 50 | def __call__(self, sample): 51 | img = sample['image'] 52 | mask = sample['label'] 53 | if random.random() < 0.5: 54 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 56 | 57 | return {'image': img, 58 | 'label': mask} 59 | 60 | 61 | class RandomRotate(object): 62 | def __init__(self, degree): 63 | self.degree = degree 64 | 65 | def __call__(self, sample): 66 | img = sample['image'] 67 | mask = sample['label'] 68 | rotate_degree = random.uniform(-1*self.degree, self.degree) 69 | img = img.rotate(rotate_degree, Image.BILINEAR) 70 | mask = mask.rotate(rotate_degree, Image.NEAREST) 71 | 72 | return {'image': img, 73 | 'label': mask} 74 | 75 | 76 | class RandomGaussianBlur(object): 77 | def __call__(self, sample): 78 | img = sample['image'] 79 | mask = sample['label'] 80 | if random.random() < 0.5: 81 | img = img.filter(ImageFilter.GaussianBlur( 82 | radius=random.random())) 83 | 84 | return {'image': img, 85 | 'label': mask} 86 | 87 | 88 | class RandomScaleCrop(object): 89 | def __init__(self, base_size, crop_size, fill=0): 90 | self.base_size = base_size 91 | self.crop_size = crop_size 92 | self.fill = fill 93 | 94 | def __call__(self, sample): 95 | img = sample['image'] 96 | mask = sample['label'] 97 | # random scale (short edge) 98 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 99 | w, h = img.size 100 | if h > w: 101 | ow = short_size 102 | oh = int(1.0 * h * ow / w) 103 | else: 104 | oh = short_size 105 | ow = int(1.0 * w * oh / h) 106 | img = img.resize((ow, oh), Image.BILINEAR) 107 | mask = mask.resize((ow, oh), Image.NEAREST) 108 | # pad crop 109 | if short_size < self.crop_size: 110 | padh = self.crop_size - oh if oh < self.crop_size else 0 111 | padw = self.crop_size - ow if ow < self.crop_size else 0 112 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 113 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 114 | # random crop crop_size 115 | w, h = img.size 116 | x1 = random.randint(0, w - self.crop_size) 117 | y1 = random.randint(0, h - self.crop_size) 118 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 119 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 120 | 121 | return {'image': img, 122 | 'label': mask} 123 | 124 | 125 | class FixScaleCrop(object): 126 | def __init__(self, crop_size): 127 | self.crop_size = crop_size 128 | 129 | def __call__(self, sample): 130 | img = sample['image'] 131 | mask = sample['label'] 132 | w, h = img.size 133 | if w 
> h: 134 | oh = self.crop_size 135 | ow = int(1.0 * w * oh / h) 136 | else: 137 | ow = self.crop_size 138 | oh = int(1.0 * h * ow / w) 139 | img = img.resize((ow, oh), Image.BILINEAR) 140 | mask = mask.resize((ow, oh), Image.NEAREST) 141 | # center crop 142 | w, h = img.size 143 | x1 = int(round((w - self.crop_size) / 2.)) 144 | y1 = int(round((h - self.crop_size) / 2.)) 145 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 146 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 147 | 148 | return {'image': img, 149 | 'label': mask} 150 | 151 | class FixedResize(object): 152 | def __init__(self, size): 153 | self.size = (size, size) # size: (h, w) 154 | 155 | def __call__(self, sample): 156 | img = sample['image'] 157 | mask = sample['label'] 158 | 159 | assert img.size == mask.size 160 | 161 | img = img.resize(self.size, Image.BILINEAR) 162 | mask = mask.resize(self.size, Image.NEAREST) 163 | 164 | return {'image': img, 165 | 'label': mask} -------------------------------------------------------------------------------- /dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jfzhang95/pytorch-deeplab-xception/9135e104a7a51ea9effa9c6676a2fcffe6a6a2e6/dataloaders/datasets/__init__.py -------------------------------------------------------------------------------- /dataloaders/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.misc as m 4 | from PIL import Image 5 | from torch.utils import data 6 | from mypath import Path 7 | from torchvision import transforms 8 | from dataloaders import custom_transforms as tr 9 | 10 | class CityscapesSegmentation(data.Dataset): 11 | NUM_CLASSES = 19 12 | 13 | def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train"): 14 | 15 | self.root = root 16 | self.split = split 17 | self.args = args 18 | self.files = {} 19 | 20 | self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) 21 | self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split) 22 | 23 | self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') 24 | 25 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 26 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 27 | self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \ 28 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \ 29 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 30 | 'motorcycle', 'bicycle'] 31 | 32 | self.ignore_index = 255 33 | self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES))) 34 | 35 | if not self.files[split]: 36 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 37 | 38 | print("Found %d %s images" % (len(self.files[split]), split)) 39 | 40 | def __len__(self): 41 | return len(self.files[self.split]) 42 | 43 | def __getitem__(self, index): 44 | 45 | img_path = self.files[self.split][index].rstrip() 46 | lbl_path = os.path.join(self.annotations_base, 47 | img_path.split(os.sep)[-2], 48 | os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png') 49 | 50 | _img = Image.open(img_path).convert('RGB') 51 | _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) 52 | _tmp = self.encode_segmap(_tmp) 53 | _target = 
Image.fromarray(_tmp) 54 | 55 | sample = {'image': _img, 'label': _target} 56 | 57 | if self.split == 'train': 58 | return self.transform_tr(sample) 59 | elif self.split == 'val': 60 | return self.transform_val(sample) 61 | elif self.split == 'test': 62 | return self.transform_ts(sample) 63 | 64 | def encode_segmap(self, mask): 65 | # Put all void classes to zero 66 | for _voidc in self.void_classes: 67 | mask[mask == _voidc] = self.ignore_index 68 | for _validc in self.valid_classes: 69 | mask[mask == _validc] = self.class_map[_validc] 70 | return mask 71 | 72 | def recursive_glob(self, rootdir='.', suffix=''): 73 | """Performs recursive glob with given suffix and rootdir 74 | :param rootdir is the root directory 75 | :param suffix is the suffix to be searched 76 | """ 77 | return [os.path.join(looproot, filename) 78 | for looproot, _, filenames in os.walk(rootdir) 79 | for filename in filenames if filename.endswith(suffix)] 80 | 81 | def transform_tr(self, sample): 82 | composed_transforms = transforms.Compose([ 83 | tr.RandomHorizontalFlip(), 84 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), 85 | tr.RandomGaussianBlur(), 86 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 87 | tr.ToTensor()]) 88 | 89 | return composed_transforms(sample) 90 | 91 | def transform_val(self, sample): 92 | 93 | composed_transforms = transforms.Compose([ 94 | tr.FixScaleCrop(crop_size=self.args.crop_size), 95 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 96 | tr.ToTensor()]) 97 | 98 | return composed_transforms(sample) 99 | 100 | def transform_ts(self, sample): 101 | 102 | composed_transforms = transforms.Compose([ 103 | tr.FixedResize(size=self.args.crop_size), 104 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 105 | tr.ToTensor()]) 106 | 107 | return composed_transforms(sample) 108 | 109 | if __name__ == '__main__': 110 | from dataloaders.utils import decode_segmap 111 | from torch.utils.data import DataLoader 112 | import matplotlib.pyplot as plt 113 | import argparse 114 | 115 | parser = argparse.ArgumentParser() 116 | args = parser.parse_args() 117 | args.base_size = 513 118 | args.crop_size = 513 119 | 120 | cityscapes_train = CityscapesSegmentation(args, split='train') 121 | 122 | dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) 123 | 124 | for ii, sample in enumerate(dataloader): 125 | for jj in range(sample["image"].size()[0]): 126 | img = sample['image'].numpy() 127 | gt = sample['label'].numpy() 128 | tmp = np.array(gt[jj]).astype(np.uint8) 129 | segmap = decode_segmap(tmp, dataset='cityscapes') 130 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 131 | img_tmp *= (0.229, 0.224, 0.225) 132 | img_tmp += (0.485, 0.456, 0.406) 133 | img_tmp *= 255.0 134 | img_tmp = img_tmp.astype(np.uint8) 135 | plt.figure() 136 | plt.title('display') 137 | plt.subplot(211) 138 | plt.imshow(img_tmp) 139 | plt.subplot(212) 140 | plt.imshow(segmap) 141 | 142 | if ii == 1: 143 | break 144 | 145 | plt.show(block=True) 146 | 147 | -------------------------------------------------------------------------------- /dataloaders/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from mypath import Path 5 | from tqdm import trange 6 | import os 7 | from pycocotools.coco import COCO 8 | from pycocotools import mask 9 | from torchvision import 
transforms 10 | from dataloaders import custom_transforms as tr 11 | from PIL import Image, ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | class COCOSegmentation(Dataset): 16 | NUM_CLASSES = 21 17 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 18 | 1, 64, 20, 63, 7, 72] 19 | 20 | def __init__(self, 21 | args, 22 | base_dir=Path.db_root_dir('coco'), 23 | split='train', 24 | year='2017'): 25 | super().__init__() 26 | ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year)) 27 | ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year)) 28 | self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year)) 29 | self.split = split 30 | self.coco = COCO(ann_file) 31 | self.coco_mask = mask 32 | if os.path.exists(ids_file): 33 | self.ids = torch.load(ids_file) 34 | else: 35 | ids = list(self.coco.imgs.keys()) 36 | self.ids = self._preprocess(ids, ids_file) 37 | self.args = args 38 | 39 | def __getitem__(self, index): 40 | _img, _target = self._make_img_gt_point_pair(index) 41 | sample = {'image': _img, 'label': _target} 42 | 43 | if self.split == "train": 44 | return self.transform_tr(sample) 45 | elif self.split == 'val': 46 | return self.transform_val(sample) 47 | 48 | def _make_img_gt_point_pair(self, index): 49 | coco = self.coco 50 | img_id = self.ids[index] 51 | img_metadata = coco.loadImgs(img_id)[0] 52 | path = img_metadata['file_name'] 53 | _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB') 54 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) 55 | _target = Image.fromarray(self._gen_seg_mask( 56 | cocotarget, img_metadata['height'], img_metadata['width'])) 57 | 58 | return _img, _target 59 | 60 | def _preprocess(self, ids, ids_file): 61 | print("Preprocessing masks, this will take a while. " + \ 62 | "But don't worry, it only runs once for each split.") 63 | tbar = trange(len(ids)) 64 | new_ids = [] 65 | for i in tbar: 66 | img_id = ids[i] 67 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 68 | img_metadata = self.coco.loadImgs(img_id)[0] 69 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], 70 | img_metadata['width']) 71 | # more than 1k pixels 72 | if (mask > 0).sum() > 1000: 73 | new_ids.append(img_id) 74 | tbar.set_description('Doing: {}/{}, got {} qualified images'.
\ 75 | format(i, len(ids), len(new_ids))) 76 | print('Found number of qualified images: ', len(new_ids)) 77 | torch.save(new_ids, ids_file) 78 | return new_ids 79 | 80 | def _gen_seg_mask(self, target, h, w): 81 | mask = np.zeros((h, w), dtype=np.uint8) 82 | coco_mask = self.coco_mask 83 | for instance in target: 84 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 85 | m = coco_mask.decode(rle) 86 | cat = instance['category_id'] 87 | if cat in self.CAT_LIST: 88 | c = self.CAT_LIST.index(cat) 89 | else: 90 | continue 91 | if len(m.shape) < 3: 92 | mask[:, :] += (mask == 0) * (m * c) 93 | else: 94 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 95 | return mask 96 | 97 | def transform_tr(self, sample): 98 | composed_transforms = transforms.Compose([ 99 | tr.RandomHorizontalFlip(), 100 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 101 | tr.RandomGaussianBlur(), 102 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 103 | tr.ToTensor()]) 104 | 105 | return composed_transforms(sample) 106 | 107 | def transform_val(self, sample): 108 | 109 | composed_transforms = transforms.Compose([ 110 | tr.FixScaleCrop(crop_size=self.args.crop_size), 111 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 112 | tr.ToTensor()]) 113 | 114 | return composed_transforms(sample) 115 | 116 | 117 | def __len__(self): 118 | return len(self.ids) 119 | 120 | 121 | 122 | if __name__ == "__main__": 123 | from dataloaders import custom_transforms as tr 124 | from dataloaders.utils import decode_segmap 125 | from torch.utils.data import DataLoader 126 | from torchvision import transforms 127 | import matplotlib.pyplot as plt 128 | import argparse 129 | 130 | parser = argparse.ArgumentParser() 131 | args = parser.parse_args() 132 | args.base_size = 513 133 | args.crop_size = 513 134 | 135 | coco_val = COCOSegmentation(args, split='val', year='2017') 136 | 137 | dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0) 138 | 139 | for ii, sample in enumerate(dataloader): 140 | for jj in range(sample["image"].size()[0]): 141 | img = sample['image'].numpy() 142 | gt = sample['label'].numpy() 143 | tmp = np.array(gt[jj]).astype(np.uint8) 144 | segmap = decode_segmap(tmp, dataset='coco') 145 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 146 | img_tmp *= (0.229, 0.224, 0.225) 147 | img_tmp += (0.485, 0.456, 0.406) 148 | img_tmp *= 255.0 149 | img_tmp = img_tmp.astype(np.uint8) 150 | plt.figure() 151 | plt.title('display') 152 | plt.subplot(211) 153 | plt.imshow(img_tmp) 154 | plt.subplot(212) 155 | plt.imshow(segmap) 156 | 157 | if ii == 1: 158 | break 159 | 160 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/datasets/combine_dbs.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | NUM_CLASSES = 21 6 | def __init__(self, dataloaders, excluded=None): 7 | self.dataloaders = dataloaders 8 | self.excluded = excluded 9 | self.im_ids = [] 10 | 11 | # Combine object lists 12 | for dl in dataloaders: 13 | for elem in dl.im_ids: 14 | if elem not in self.im_ids: 15 | self.im_ids.append(elem) 16 | 17 | # Exclude 18 | if excluded: 19 | for dl in excluded: 20 | for elem in dl.im_ids: 21 | if elem in self.im_ids: 22 | self.im_ids.remove(elem) 23 | 24 | # Get object pointers 25 | self.cat_list = [] 26 | self.im_list = [] 
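# Build a flat index over the combined set: for every unique image id, record
# which source dataset it came from ('db_ii') and its index there ('cat_ii'),
# so that __getitem__ below can dispatch directly to the right underlying dataset.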
27 | new_im_ids = [] 28 | num_images = 0 29 | for ii, dl in enumerate(dataloaders): 30 | for jj, curr_im_id in enumerate(dl.im_ids): 31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids): 32 | num_images += 1 33 | new_im_ids.append(curr_im_id) 34 | self.cat_list.append({'db_ii': ii, 'cat_ii': jj}) 35 | 36 | self.im_ids = new_im_ids 37 | print('Combined number of images: {:d}'.format(num_images)) 38 | 39 | def __getitem__(self, index): 40 | 41 | _db_ii = self.cat_list[index]["db_ii"] 42 | _cat_ii = self.cat_list[index]['cat_ii'] 43 | sample = self.dataloaders[_db_ii].__getitem__(_cat_ii) 44 | 45 | if 'meta' in sample.keys(): 46 | sample['meta']['db'] = str(self.dataloaders[_db_ii]) 47 | 48 | return sample 49 | 50 | def __len__(self): 51 | return len(self.cat_list) 52 | 53 | def __str__(self): 54 | include_db = [str(db) for db in self.dataloaders] 55 | exclude_db = [str(db) for db in self.excluded] 56 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db) 57 | 58 | 59 | if __name__ == "__main__": 60 | import matplotlib.pyplot as plt 61 | from dataloaders.datasets import pascal, sbd 62 | 63 | import torch 64 | import numpy as np 65 | from dataloaders.utils import decode_segmap 66 | import argparse 67 | 68 | parser = argparse.ArgumentParser() 69 | args = parser.parse_args() 70 | args.base_size = 513 71 | args.crop_size = 513 72 | 73 | pascal_voc_val = pascal.VOCSegmentation(args, split='val') 74 | sbd = sbd.SBDSegmentation(args, split=['train', 'val']) 75 | pascal_voc_train = pascal.VOCSegmentation(args, split='train') 76 | 77 | dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val]) 78 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0) 79 | 80 | for ii, sample in enumerate(dataloader): 81 | for jj in range(sample["image"].size()[0]): 82 | img = sample['image'].numpy() 83 | gt = sample['label'].numpy() 84 | tmp = np.array(gt[jj]).astype(np.uint8) 85 | segmap = decode_segmap(tmp, dataset='pascal') 86 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 87 | img_tmp *= (0.229, 0.224, 0.225) 88 | img_tmp += (0.485, 0.456, 0.406) 89 | img_tmp *= 255.0 90 | img_tmp = img_tmp.astype(np.uint8) 91 | plt.figure() 92 | plt.title('display') 93 | plt.subplot(211) 94 | plt.imshow(img_tmp) 95 | plt.subplot(212) 96 | plt.imshow(segmap) 97 | 98 | if ii == 1: 99 | break 100 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/datasets/pascal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from mypath import Path 7 | from torchvision import transforms 8 | from dataloaders import custom_transforms as tr 9 | 10 | class VOCSegmentation(Dataset): 11 | """ 12 | PascalVoc dataset 13 | """ 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=Path.db_root_dir('pascal'), 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._image_dir = os.path.join(self._base_dir, 'JPEGImages') 29 | self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass') 30 | 31 | if isinstance(split, str): 32 | self.split = [split] 33 | else:
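# `split` is a list of split names here (e.g. ['train', 'val']); sorting keeps
# the combined file list below in a deterministic order.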
split.sort() 35 | self.split = split 36 | 37 | self.args = args 38 | 39 | _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation') 40 | 41 | self.im_ids = [] 42 | self.images = [] 43 | self.categories = [] 44 | 45 | for splt in self.split: 46 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for ii, line in enumerate(lines): 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _cat = os.path.join(self._cat_dir, line + ".png") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_cat) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_cat) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images in {}: {:d}'.format(split, len(self.images))) 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | 67 | def __getitem__(self, index): 68 | _img, _target = self._make_img_gt_point_pair(index) 69 | sample = {'image': _img, 'label': _target} 70 | 71 | for split in self.split: 72 | if split == "train": 73 | return self.transform_tr(sample) 74 | elif split == 'val': 75 | return self.transform_val(sample) 76 | 77 | 78 | def _make_img_gt_point_pair(self, index): 79 | _img = Image.open(self.images[index]).convert('RGB') 80 | _target = Image.open(self.categories[index]) 81 | 82 | return _img, _target 83 | 84 | def transform_tr(self, sample): 85 | composed_transforms = transforms.Compose([ 86 | tr.RandomHorizontalFlip(), 87 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 88 | tr.RandomGaussianBlur(), 89 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 90 | tr.ToTensor()]) 91 | 92 | return composed_transforms(sample) 93 | 94 | def transform_val(self, sample): 95 | 96 | composed_transforms = transforms.Compose([ 97 | tr.FixScaleCrop(crop_size=self.args.crop_size), 98 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 99 | tr.ToTensor()]) 100 | 101 | return composed_transforms(sample) 102 | 103 | def __str__(self): 104 | return 'VOC2012(split=' + str(self.split) + ')' 105 | 106 | 107 | if __name__ == '__main__': 108 | from dataloaders.utils import decode_segmap 109 | from torch.utils.data import DataLoader 110 | import matplotlib.pyplot as plt 111 | import argparse 112 | 113 | parser = argparse.ArgumentParser() 114 | args = parser.parse_args() 115 | args.base_size = 513 116 | args.crop_size = 513 117 | 118 | voc_train = VOCSegmentation(args, split='train') 119 | 120 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0) 121 | 122 | for ii, sample in enumerate(dataloader): 123 | for jj in range(sample["image"].size()[0]): 124 | img = sample['image'].numpy() 125 | gt = sample['label'].numpy() 126 | tmp = np.array(gt[jj]).astype(np.uint8) 127 | segmap = decode_segmap(tmp, dataset='pascal') 128 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 129 | img_tmp *= (0.229, 0.224, 0.225) 130 | img_tmp += (0.485, 0.456, 0.406) 131 | img_tmp *= 255.0 132 | img_tmp = img_tmp.astype(np.uint8) 133 | plt.figure() 134 | plt.title('display') 135 | plt.subplot(211) 136 | plt.imshow(img_tmp) 137 | plt.subplot(212) 138 | plt.imshow(segmap) 139 | 140 | if ii == 1: 141 | break 142 | 143 | plt.show(block=True) 144 | 145 | 146 | -------------------------------------------------------------------------------- /dataloaders/datasets/sbd.py: -------------------------------------------------------------------------------- 
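# SBD (the Semantic Boundaries Dataset) provides extra Pascal-style class masks
# stored as MATLAB .mat files; they are decoded with scipy.io.loadmat below and
# merged with the VOC train split via CombineDBs when --use-sbd is set.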
1 | from __future__ import print_function, division 2 | import os 3 | 4 | import numpy as np 5 | import scipy.io 6 | import torch.utils.data as data 7 | from PIL import Image 8 | from mypath import Path 9 | 10 | from torchvision import transforms 11 | from dataloaders import custom_transforms as tr 12 | 13 | class SBDSegmentation(data.Dataset): 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=Path.db_root_dir('sbd'), 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._dataset_dir = os.path.join(self._base_dir, 'dataset') 29 | self._image_dir = os.path.join(self._dataset_dir, 'img') 30 | self._cat_dir = os.path.join(self._dataset_dir, 'cls') 31 | 32 | 33 | if isinstance(split, str): 34 | self.split = [split] 35 | else: 36 | split.sort() 37 | self.split = split 38 | 39 | self.args = args 40 | 41 | # Get list of all images from the split and check that the files exist 42 | self.im_ids = [] 43 | self.images = [] 44 | self.categories = [] 45 | for splt in self.split: 46 | with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for line in lines: 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _categ= os.path.join(self._cat_dir, line + ".mat") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_categ) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_categ) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images: {:d}'.format(len(self.images))) 62 | 63 | 64 | def __getitem__(self, index): 65 | _img, _target = self._make_img_gt_point_pair(index) 66 | sample = {'image': _img, 'label': _target} 67 | 68 | return self.transform(sample) 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | def _make_img_gt_point_pair(self, index): 74 | _img = Image.open(self.images[index]).convert('RGB') 75 | _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0]) 76 | 77 | return _img, _target 78 | 79 | def transform(self, sample): 80 | composed_transforms = transforms.Compose([ 81 | tr.RandomHorizontalFlip(), 82 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 83 | tr.RandomGaussianBlur(), 84 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 85 | tr.ToTensor()]) 86 | 87 | return composed_transforms(sample) 88 | 89 | 90 | def __str__(self): 91 | return 'SBDSegmentation(split=' + str(self.split) + ')' 92 | 93 | 94 | if __name__ == '__main__': 95 | from dataloaders.utils import decode_segmap 96 | from torch.utils.data import DataLoader 97 | import matplotlib.pyplot as plt 98 | import argparse 99 | 100 | parser = argparse.ArgumentParser() 101 | args = parser.parse_args() 102 | args.base_size = 513 103 | args.crop_size = 513 104 | 105 | sbd_train = SBDSegmentation(args, split='train') 106 | dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2) 107 | 108 | for ii, sample in enumerate(dataloader): 109 | for jj in range(sample["image"].size()[0]): 110 | img = sample['image'].numpy() 111 | gt = sample['label'].numpy() 112 | tmp = np.array(gt[jj]).astype(np.uint8) 113 | segmap = decode_segmap(tmp, dataset='pascal') 114 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 115 | img_tmp *= (0.229, 0.224, 
0.225) 116 | img_tmp += (0.485, 0.456, 0.406) 117 | img_tmp *= 255.0 118 | img_tmp = img_tmp.astype(np.uint8) 119 | plt.figure() 120 | plt.title('display') 121 | plt.subplot(211) 122 | plt.imshow(img_tmp) 123 | plt.subplot(212) 124 | plt.imshow(segmap) 125 | 126 | if ii == 1: 127 | break 128 | 129 | plt.show(block=True) -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 6 | rgb_masks = [] 7 | for label_mask in label_masks: 8 | rgb_mask = decode_segmap(label_mask, dataset) 9 | rgb_masks.append(rgb_mask) 10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 11 | return rgb_masks 12 | 13 | 14 | def decode_segmap(label_mask, dataset, plot=False): 15 | """Decode segmentation class labels into a color image 16 | Args: 17 | label_mask (np.ndarray): an (M,N) array of integer values denoting 18 | the class label at each spatial location. 19 | plot (bool, optional): whether to show the resulting color image 20 | in a figure. 21 | Returns: 22 | (np.ndarray, optional): the resulting decoded color image. 23 | """ 24 | if dataset == 'pascal' or dataset == 'coco': 25 | n_classes = 21 26 | label_colours = get_pascal_labels() 27 | elif dataset == 'cityscapes': 28 | n_classes = 19 29 | label_colours = get_cityscapes_labels() 30 | else: 31 | raise NotImplementedError 32 | 33 | r = label_mask.copy() 34 | g = label_mask.copy() 35 | b = label_mask.copy() 36 | for ll in range(0, n_classes): 37 | r[label_mask == ll] = label_colours[ll, 0] 38 | g[label_mask == ll] = label_colours[ll, 1] 39 | b[label_mask == ll] = label_colours[ll, 2] 40 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 41 | rgb[:, :, 0] = r / 255.0 42 | rgb[:, :, 1] = g / 255.0 43 | rgb[:, :, 2] = b / 255.0 44 | if plot: 45 | plt.imshow(rgb) 46 | plt.show() 47 | else: 48 | return rgb 49 | 50 | 51 | def encode_segmap(mask): 52 | """Encode segmentation label images as pascal classes 53 | Args: 54 | mask (np.ndarray): raw segmentation label image of dimension 55 | (M, N, 3), in which the Pascal classes are encoded as colours. 56 | Returns: 57 | (np.ndarray): class map with dimensions (M,N), where the value at 58 | a given location is the integer denoting the class index. 
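Pixels whose colour does not match any entry in get_pascal_labels() keep the default value 0 (background).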
59 | """ 60 | mask = mask.astype(int) 61 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 62 | for ii, label in enumerate(get_pascal_labels()): 63 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 64 | label_mask = label_mask.astype(int) 65 | return label_mask 66 | 67 | 68 | def get_cityscapes_labels(): 69 | return np.array([ 70 | [128, 64, 128], 71 | [244, 35, 232], 72 | [70, 70, 70], 73 | [102, 102, 156], 74 | [190, 153, 153], 75 | [153, 153, 153], 76 | [250, 170, 30], 77 | [220, 220, 0], 78 | [107, 142, 35], 79 | [152, 251, 152], 80 | [0, 130, 180], 81 | [220, 20, 60], 82 | [255, 0, 0], 83 | [0, 0, 142], 84 | [0, 0, 70], 85 | [0, 60, 100], 86 | [0, 80, 100], 87 | [0, 0, 230], 88 | [119, 11, 32]]) 89 | 90 | 91 | def get_pascal_labels(): 92 | """Load the mapping that associates pascal classes with label colors 93 | Returns: 94 | np.ndarray with dimensions (21, 3) 95 | """ 96 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 97 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 98 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 99 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 100 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 101 | [0, 64, 128]]) -------------------------------------------------------------------------------- /doc/deeplab_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | BatchNorm2d = SynchronizedBatchNorm2d 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 14 | super(Bottleneck, self).__init__() 15 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 16 | self.bn1 = BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 18 | dilation=dilation, padding=dilation, bias=False) 19 | self.bn2 = BatchNorm2d(planes) 20 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 21 | self.bn3 = BatchNorm2d(planes * 4) 22 | self.relu = nn.ReLU(inplace=True) 23 | self.downsample = downsample 24 | self.stride = stride 25 | self.dilation = dilation 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv3(out) 39 | out = self.bn3(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | class ResNet(nn.Module): 50 | 51 | def __init__(self, nInputChannels, block, layers, os=16, pretrained=False): 52 | self.inplanes = 64 53 | super(ResNet, self).__init__() 54 | if os == 16: 55 | strides = [1, 2, 2, 1] 56 | dilations = [1, 1, 1, 2] 57 | blocks = [1, 2, 4] 58 | elif os == 8: 59 | strides = [1, 2, 1, 1] 60 | dilations = [1, 1, 2, 2] 61 | blocks = [1, 2, 1] 62 | else: 63 | raise NotImplementedError 64 | 65 | # Modules 66 | self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3, 67 | bias=False) 68 | self.bn1 = BatchNorm2d(64) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 71 | 72 | self.layer1 = 
self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0]) 73 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1]) 74 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2]) 75 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3]) 76 | 77 | self._init_weight() 78 | 79 | if pretrained: 80 | self._load_pretrained_model() 81 | 82 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 83 | downsample = None 84 | if stride != 1 or self.inplanes != planes * block.expansion: 85 | downsample = nn.Sequential( 86 | nn.Conv2d(self.inplanes, planes * block.expansion, 87 | kernel_size=1, stride=stride, bias=False), 88 | BatchNorm2d(planes * block.expansion), 89 | ) 90 | 91 | layers = [] 92 | layers.append(block(self.inplanes, planes, stride, dilation, downsample)) 93 | self.inplanes = planes * block.expansion 94 | for i in range(1, blocks): 95 | layers.append(block(self.inplanes, planes)) 96 | 97 | return nn.Sequential(*layers) 98 | 99 | def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, dilation=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | downsample = nn.Sequential( 103 | nn.Conv2d(self.inplanes, planes * block.expansion, 104 | kernel_size=1, stride=stride, bias=False), 105 | BatchNorm2d(planes * block.expansion), 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, downsample=downsample)) 110 | self.inplanes = planes * block.expansion 111 | for i in range(1, len(blocks)): 112 | layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i]*dilation)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, input): 117 | x = self.conv1(input) 118 | x = self.bn1(x) 119 | x = self.relu(x) 120 | x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | low_level_feat = x 124 | x = self.layer2(x) 125 | x = self.layer3(x) 126 | x = self.layer4(x) 127 | return x, low_level_feat 128 | 129 | def _init_weight(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 134 | elif isinstance(m, BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(nInputChannels=3, os=16, pretrained=False): 149 | model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained) 150 | return model 151 | 152 | 153 | class ASPP_module(nn.Module): 154 | def __init__(self, inplanes, planes, dilation): 155 | super(ASPP_module, self).__init__() 156 | if dilation == 1: 157 | kernel_size = 1 158 | padding = 0 159 | else: 160 | kernel_size = 3 161 | padding = dilation 162 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 163 | stride=1, padding=padding, dilation=dilation, bias=False) 164 | self.bn = BatchNorm2d(planes) 165 | self.relu = nn.ReLU() 166 | 167 | self._init_weight() 168 | 169 | def forward(self, x): 170 | x = self.atrous_convolution(x) 171 | x = self.bn(x) 172 | 173 | return self.relu(x) 174 | 175 | def _init_weight(self): 176 | for m in self.modules(): 177 | if isinstance(m, nn.Conv2d): 178 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 179 | m.weight.data.normal_(0, math.sqrt(2. / n)) 180 | elif isinstance(m, BatchNorm2d): 181 | m.weight.data.fill_(1) 182 | m.bias.data.zero_() 183 | 184 | 185 | class DeepLabv3_plus(nn.Module): 186 | def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True): 187 | if _print: 188 | print("Constructing DeepLabv3+ model...") 189 | print("Backbone: Resnet-101") 190 | print("Number of classes: {}".format(n_classes)) 191 | print("Output stride: {}".format(os)) 192 | print("Number of Input Channels: {}".format(nInputChannels)) 193 | super(DeepLabv3_plus, self).__init__() 194 | 195 | # Atrous Conv 196 | self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained) 197 | 198 | # ASPP 199 | if os == 16: 200 | dilations = [1, 6, 12, 18] 201 | elif os == 8: 202 | dilations = [1, 12, 24, 36] 203 | else: 204 | raise NotImplementedError 205 | 206 | self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0]) 207 | self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1]) 208 | self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2]) 209 | self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3]) 210 | 211 | self.relu = nn.ReLU() 212 | 213 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 214 | nn.Conv2d(2048, 256, 1, stride=1, bias=False), 215 | BatchNorm2d(256), 216 | nn.ReLU()) 217 | 218 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 219 | self.bn1 = BatchNorm2d(256) 220 | 221 | # adopt [1x1, 48] for channel reduction. 
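# The low-level features from layer1 carry 256 channels; projecting them down to 48
# keeps the decoder concatenation at 256 + 48 = 304 channels, which is the input
# width of last_conv below, following the DeepLabV3+ paper.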
222 | self.conv2 = nn.Conv2d(256, 48, 1, bias=False) 223 | self.bn2 = BatchNorm2d(48) 224 | 225 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 226 | BatchNorm2d(256), 227 | nn.ReLU(), 228 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 229 | BatchNorm2d(256), 230 | nn.ReLU(), 231 | nn.Conv2d(256, n_classes, kernel_size=1, stride=1)) 232 | if freeze_bn: 233 | self._freeze_bn() 234 | 235 | def forward(self, input): 236 | x, low_level_features = self.resnet_features(input) 237 | x1 = self.aspp1(x) 238 | x2 = self.aspp2(x) 239 | x3 = self.aspp3(x) 240 | x4 = self.aspp4(x) 241 | x5 = self.global_avg_pool(x) 242 | x5 = F.upsample(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 243 | 244 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 245 | 246 | x = self.conv1(x) 247 | x = self.bn1(x) 248 | x = self.relu(x) 249 | x = F.upsample(x, size=(int(math.ceil(input.size()[-2]/4)), 250 | int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True) 251 | 252 | low_level_features = self.conv2(low_level_features) 253 | low_level_features = self.bn2(low_level_features) 254 | low_level_features = self.relu(low_level_features) 255 | 256 | 257 | x = torch.cat((x, low_level_features), dim=1) 258 | x = self.last_conv(x) 259 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 260 | 261 | return x 262 | 263 | def _freeze_bn(self): 264 | for m in self.modules(): 265 | if isinstance(m, BatchNorm2d): 266 | m.eval() 267 | 268 | def _init_weight(self): 269 | for m in self.modules(): 270 | if isinstance(m, nn.Conv2d): 271 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 272 | m.weight.data.normal_(0, math.sqrt(2. / n)) 273 | elif isinstance(m, BatchNorm2d): 274 | m.weight.data.fill_(1) 275 | m.bias.data.zero_() 276 | 277 | def get_1x_lr_params(model): 278 | """ 279 | This generator returns all the parameters of the net except for 280 | the last classification layer. 
Note that for each batchnorm layer, 281 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 282 | any batchnorm parameters 283 | """ 284 | b = [model.resnet_features] 285 | for i in range(len(b)): 286 | for k in b[i].parameters(): 287 | if k.requires_grad: 288 | yield k 289 | 290 | 291 | def get_10x_lr_params(model): 292 | """ 293 | This generator returns all the parameters for the last layer of the net, 294 | which does the classification of pixels into classes 295 | """ 296 | b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv] 297 | for j in range(len(b)): 298 | for k in b[j].parameters(): 299 | if k.requires_grad: 300 | yield k 301 | 302 | 303 | if __name__ == "__main__": 304 | model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True) 305 | model.eval() 306 | image = torch.randn(1, 3, 512, 512) 307 | with torch.no_grad(): 308 | output = model.forward(image) 309 | print(output.size()) 310 | 311 | 312 | 313 | 314 | 315 | 316 | -------------------------------------------------------------------------------- /doc/deeplab_xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | BatchNorm2d = SynchronizedBatchNorm2d 9 | 10 | class SeparableConv2d(nn.Module): 11 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False): 12 | super(SeparableConv2d, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, 15 | groups=inplanes, bias=bias) 16 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 17 | 18 | def forward(self, x): 19 | x = self.conv1(x) 20 | x = self.pointwise(x) 21 | return x 22 | 23 | 24 | def fixed_padding(inputs, kernel_size, dilation): 25 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 26 | pad_total = kernel_size_effective - 1 27 | pad_beg = pad_total // 2 28 | pad_end = pad_total - pad_beg 29 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 30 | return padded_inputs 31 | 32 | 33 | class SeparableConv2d_same(nn.Module): 34 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False): 35 | super(SeparableConv2d_same, self).__init__() 36 | 37 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 38 | groups=inplanes, bias=bias) 39 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 40 | 41 | def forward(self, x): 42 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 43 | x = self.conv1(x) 44 | x = self.pointwise(x) 45 | return x 46 | 47 | 48 | class Block(nn.Module): 49 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): 50 | super(Block, self).__init__() 51 | 52 | if planes != inplanes or stride != 1: 53 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 54 | self.skipbn = BatchNorm2d(planes) 55 | else: 56 | self.skip = None 57 | 58 | self.relu = nn.ReLU(inplace=True) 59 | rep = [] 60 | 61 | filters = inplanes 62 | if grow_first: 63 | rep.append(self.relu) 64 | rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation))
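# grow_first widens the channel count on the first separable conv of the block;
# the remaining reps below keep the width constant at `filters`.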
65 | rep.append(BatchNorm2d(planes)) 66 | filters = planes 67 | 68 | for i in range(reps - 1): 69 | rep.append(self.relu) 70 | rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) 71 | rep.append(BatchNorm2d(filters)) 72 | 73 | if not grow_first: 74 | rep.append(self.relu) 75 | rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) 76 | rep.append(BatchNorm2d(planes)) 77 | 78 | if not start_with_relu: 79 | rep = rep[1:] 80 | 81 | if stride != 1: 82 | rep.append(SeparableConv2d_same(planes, planes, 3, stride=2)) 83 | 84 | if stride == 1 and is_last: 85 | rep.append(SeparableConv2d_same(planes, planes, 3, stride=1)) 86 | 87 | 88 | self.rep = nn.Sequential(*rep) 89 | 90 | def forward(self, inp): 91 | x = self.rep(inp) 92 | 93 | if self.skip is not None: 94 | skip = self.skip(inp) 95 | skip = self.skipbn(skip) 96 | else: 97 | skip = inp 98 | 99 | x += skip 100 | 101 | return x 102 | 103 | 104 | class Xception(nn.Module): 105 | """ 106 | Modified Aligned Xception 107 | """ 108 | def __init__(self, inplanes=3, os=16, pretrained=False): 109 | super(Xception, self).__init__() 110 | 111 | if os == 16: 112 | entry_block3_stride = 2 113 | middle_block_dilation = 1 114 | exit_block_dilations = (1, 2) 115 | elif os == 8: 116 | entry_block3_stride = 1 117 | middle_block_dilation = 2 118 | exit_block_dilations = (2, 4) 119 | else: 120 | raise NotImplementedError 121 | 122 | 123 | # Entry flow 124 | self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False) 125 | self.bn1 = BatchNorm2d(32) 126 | self.relu = nn.ReLU(inplace=True) 127 | 128 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 129 | self.bn2 = BatchNorm2d(64) 130 | 131 | self.block1 = Block(64, 128, reps=2, stride=2, start_with_relu=False) 132 | self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True) 133 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True, 134 | is_last=True) 135 | 136 | # Middle flow 137 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 138 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 139 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 140 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 141 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 142 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 143 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 144 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 146 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 147 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 148 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True,
grow_first=True) 149 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 150 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 151 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 152 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 153 | 154 | # Exit flow 155 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 156 | start_with_relu=True, grow_first=False, is_last=True) 157 | 158 | self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 159 | self.bn3 = BatchNorm2d(1536) 160 | 161 | self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 162 | self.bn4 = BatchNorm2d(1536) 163 | 164 | self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1]) 165 | self.bn5 = BatchNorm2d(2048) 166 | 167 | # Init weights 168 | self._init_weight() 169 | 170 | # Load pretrained model 171 | if pretrained: 172 | self._load_xception_pretrained() 173 | 174 | def forward(self, x): 175 | # Entry flow 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | 180 | x = self.conv2(x) 181 | x = self.bn2(x) 182 | x = self.relu(x) 183 | 184 | x = self.block1(x) 185 | low_level_feat = x 186 | x = self.block2(x) 187 | x = self.block3(x) 188 | 189 | # Middle flow 190 | x = self.block4(x) 191 | x = self.block5(x) 192 | x = self.block6(x) 193 | x = self.block7(x) 194 | x = self.block8(x) 195 | x = self.block9(x) 196 | x = self.block10(x) 197 | x = self.block11(x) 198 | x = self.block12(x) 199 | x = self.block13(x) 200 | x = self.block14(x) 201 | x = self.block15(x) 202 | x = self.block16(x) 203 | x = self.block17(x) 204 | x = self.block18(x) 205 | x = self.block19(x) 206 | 207 | # Exit flow 208 | x = self.block20(x) 209 | x = self.conv3(x) 210 | x = self.bn3(x) 211 | x = self.relu(x) 212 | 213 | x = self.conv4(x) 214 | x = self.bn4(x) 215 | x = self.relu(x) 216 | 217 | x = self.conv5(x) 218 | x = self.bn5(x) 219 | x = self.relu(x) 220 | 221 | return x, low_level_feat 222 | 223 | def _init_weight(self): 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv2d): 226 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 227 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
228 |             elif isinstance(m, BatchNorm2d):
229 |                 m.weight.data.fill_(1)
230 |                 m.bias.data.zero_()
231 | 
232 |     def _load_xception_pretrained(self):
233 |         pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
234 |         model_dict = {}
235 |         state_dict = self.state_dict()
236 | 
237 |         for k, v in pretrain_dict.items():
238 |             if k in state_dict:
239 |                 if 'pointwise' in k:
240 |                     v = v.unsqueeze(-1).unsqueeze(-1)
241 |                 if k.startswith('block11'):
242 |                     model_dict[k] = v
243 |                     model_dict[k.replace('block11', 'block12')] = v
244 |                     model_dict[k.replace('block11', 'block13')] = v
245 |                     model_dict[k.replace('block11', 'block14')] = v
246 |                     model_dict[k.replace('block11', 'block15')] = v
247 |                     model_dict[k.replace('block11', 'block16')] = v
248 |                     model_dict[k.replace('block11', 'block17')] = v
249 |                     model_dict[k.replace('block11', 'block18')] = v
250 |                     model_dict[k.replace('block11', 'block19')] = v
251 |                 elif k.startswith('block12'):
252 |                     model_dict[k.replace('block12', 'block20')] = v
253 |                 elif k.startswith('bn3'):
254 |                     model_dict[k] = v
255 |                     model_dict[k.replace('bn3', 'bn4')] = v
256 |                 elif k.startswith('conv4'):
257 |                     model_dict[k.replace('conv4', 'conv5')] = v
258 |                 elif k.startswith('bn4'):
259 |                     model_dict[k.replace('bn4', 'bn5')] = v
260 |                 else:
261 |                     model_dict[k] = v
262 |         state_dict.update(model_dict)
263 |         self.load_state_dict(state_dict)
264 | 
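# Note on the remapping above: the ImageNet checkpoint is for the original
# Xception, which has 8 middle-flow blocks (block4-block11), while this
# modified backbone has 16 (block4-block19). The loader therefore reuses
# block11's weights for block12-block19 and shifts the exit-flow weights
# (block12 -> block20, conv4/bn4 -> conv5/bn5) to match the deeper layout.
# 1x1 'pointwise' kernels are stored as 2-D tensors in that checkpoint, hence
# the unsqueeze to the (out, in, 1, 1) shape nn.Conv2d expects.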
265 | class ASPP_module(nn.Module):
266 |     def __init__(self, inplanes, planes, dilation):
267 |         super(ASPP_module, self).__init__()
268 |         if dilation == 1:
269 |             kernel_size = 1
270 |             padding = 0
271 |         else:
272 |             kernel_size = 3
273 |             padding = dilation
274 |         self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
275 |                                             stride=1, padding=padding, dilation=dilation, bias=False)
276 |         self.bn = BatchNorm2d(planes)
277 |         self.relu = nn.ReLU()
278 | 
279 |         self._init_weight()
280 | 
281 |     def forward(self, x):
282 |         x = self.atrous_convolution(x)
283 |         x = self.bn(x)
284 | 
285 |         return self.relu(x)
286 | 
287 |     def _init_weight(self):
288 |         for m in self.modules():
289 |             if isinstance(m, nn.Conv2d):
290 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
291 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
292 |             elif isinstance(m, BatchNorm2d):
293 |                 m.weight.data.fill_(1)
294 |                 m.bias.data.zero_()
295 | 
296 | 
297 | class DeepLabv3_plus(nn.Module):
298 |     def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True):
299 |         if _print:
300 |             print("Constructing DeepLabv3+ model...")
301 |             print("Backbone: Xception")
302 |             print("Number of classes: {}".format(n_classes))
303 |             print("Output stride: {}".format(os))
304 |             print("Number of Input Channels: {}".format(nInputChannels))
305 |         super(DeepLabv3_plus, self).__init__()
306 | 
307 |         # Atrous Conv
308 |         self.xception_features = Xception(nInputChannels, os, pretrained)
309 | 
310 |         # ASPP
311 |         if os == 16:
312 |             dilations = [1, 6, 12, 18]
313 |         elif os == 8:
314 |             dilations = [1, 12, 24, 36]
315 |         else:
316 |             raise NotImplementedError
317 | 
318 |         self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0])
319 |         self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1])
320 |         self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2])
321 |         self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3])
322 | 
323 |         self.relu = nn.ReLU()
324 | 
325 |         self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
326 |                                              nn.Conv2d(2048, 256, 1, stride=1, bias=False),
327 |                                              BatchNorm2d(256),
328 |                                              nn.ReLU())
329 | 
330 |         self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
331 |         self.bn1 = BatchNorm2d(256)
332 | 
333 |         # adopt [1x1, 48] for channel reduction.
334 |         self.conv2 = nn.Conv2d(128, 48, 1, bias=False)
335 |         self.bn2 = BatchNorm2d(48)
336 | 
337 |         self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
338 |                                        BatchNorm2d(256),
339 |                                        nn.ReLU(),
340 |                                        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
341 |                                        BatchNorm2d(256),
342 |                                        nn.ReLU(),
343 |                                        nn.Conv2d(256, n_classes, kernel_size=1, stride=1))
344 |         if freeze_bn:
345 |             self._freeze_bn()
346 | 
347 |     def forward(self, input):
348 |         x, low_level_features = self.xception_features(input)
349 |         x1 = self.aspp1(x)
350 |         x2 = self.aspp2(x)
351 |         x3 = self.aspp3(x)
352 |         x4 = self.aspp4(x)
353 |         x5 = self.global_avg_pool(x)
354 |         x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
355 | 
356 |         x = torch.cat((x1, x2, x3, x4, x5), dim=1)
357 | 
358 |         x = self.conv1(x)
359 |         x = self.bn1(x)
360 |         x = self.relu(x)
361 |         x = F.interpolate(x, size=(int(math.ceil(input.size()[-2]/4)),
362 |                                    int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True)
363 | 
364 |         low_level_features = self.conv2(low_level_features)
365 |         low_level_features = self.bn2(low_level_features)
366 |         low_level_features = self.relu(low_level_features)
367 | 
368 | 
369 |         x = torch.cat((x, low_level_features), dim=1)
370 |         x = self.last_conv(x)
371 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
372 | 
373 |         return x
374 | 
375 |     def _freeze_bn(self):
376 |         for m in self.modules():
377 |             if isinstance(m, BatchNorm2d):
378 |                 m.eval()
379 | 
380 |     def _init_weight(self):
381 |         for m in self.modules():
382 |             if isinstance(m, nn.Conv2d):
383 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
384 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
385 |             elif isinstance(m, BatchNorm2d):
386 |                 m.weight.data.fill_(1)
387 |                 m.bias.data.zero_()
388 | 
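# Shape sketch for DeepLabv3_plus.forward above, with a 512x512 input and
# os=16 (derived from the code, as an illustration): the backbone returns x of
# size (N, 2048, 32, 32) and low_level_features of size (N, 128, 128, 128);
# each of the four ASPP branches and the pooled branch yields (N, 256, 32, 32),
# concatenating to (N, 1280, 32, 32); conv1 reduces that to 256 channels, the
# result is upsampled to 128x128 and fused with the 48-channel reduced
# low-level features (304 channels in total); last_conv plus the final
# interpolation then restore (N, n_classes, 512, 512).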
389 | def get_1x_lr_params(model):
390 |     """
391 |     This generator yields the backbone parameters (model.xception_features),
392 |     which are trained with the base learning rate. Only parameters with
393 |     requires_grad set to True are returned, so any previously frozen
394 |     parameters are skipped.
395 |     """
396 |     b = [model.xception_features]
397 |     for i in range(len(b)):
398 |         for k in b[i].parameters():
399 |             if k.requires_grad:
400 |                 yield k
401 | 
402 | 
403 | def get_10x_lr_params(model):
404 |     """
405 |     This generator yields the ASPP and decoder/classifier parameters, which
406 |     classify pixels into classes and are trained with 10x the base learning rate.
407 |     """
408 |     b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv]
409 |     for j in range(len(b)):
410 |         for k in b[j].parameters():
411 |             if k.requires_grad:
412 |                 yield k
413 | 
414 | 
415 | if __name__ == "__main__":
416 |     model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True)
417 |     model.eval()
418 |     image = torch.randn(1, 3, 512, 512)
419 |     with torch.no_grad():
420 |         output = model(image)
421 |     print(output.size())
422 | 
423 | 
424 | 
-------------------------------------------------------------------------------- /doc/results.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfzhang95/pytorch-deeplab-xception/9135e104a7a51ea9effa9c6676a2fcffe6a6a2e6/doc/results.png
-------------------------------------------------------------------------------- /modeling/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jfzhang95/pytorch-deeplab-xception/9135e104a7a51ea9effa9c6676a2fcffe6a6a2e6/modeling/__init__.py
-------------------------------------------------------------------------------- /modeling/aspp.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | class _ASPPModule(nn.Module):
8 |     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 |         super(_ASPPModule, self).__init__()
10 |         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 |                                      stride=1, padding=padding, dilation=dilation, bias=False)
12 |         self.bn = BatchNorm(planes)
13 |         self.relu = nn.ReLU()
14 | 
15 |         self._init_weight()
16 | 
17 |     def forward(self, x):
18 |         x = self.atrous_conv(x)
19 |         x = self.bn(x)
20 | 
21 |         return self.relu(x)
22 | 
23 |     def _init_weight(self):
24 |         for m in self.modules():
25 |             if isinstance(m, nn.Conv2d):
26 |                 torch.nn.init.kaiming_normal_(m.weight)
27 |             elif isinstance(m, SynchronizedBatchNorm2d):
28 |                 m.weight.data.fill_(1)
29 |                 m.bias.data.zero_()
30 |             elif isinstance(m, nn.BatchNorm2d):
31 |                 m.weight.data.fill_(1)
32 |                 m.bias.data.zero_()
33 | 
34 | class ASPP(nn.Module):
35 |     def __init__(self, backbone, output_stride, BatchNorm):
36 |         super(ASPP, self).__init__()
37 |         if backbone == 'drn':
38 |             inplanes = 512
39 |         elif backbone == 'mobilenet':
40 |             inplanes = 320
41 |         else:
42 |             inplanes = 2048
43 |         if output_stride == 16:
44 |             dilations = [1, 6, 12, 18]
45 |         elif output_stride == 8:
46 |             dilations = [1, 12, 24, 36]
47 |         else:
48 |             raise NotImplementedError
49 | 
50 |         self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 |         self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 |         self.aspp3 = _ASPPModule(inplanes,
256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /modeling/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | webroot = 'http://dl.yf.io/drn/' 7 | 8 | model_urls = { 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = 
conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 
145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | 
self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
249 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
250 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
251 |                                        dilation=2, BatchNorm=BatchNorm)
252 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
253 |                                        dilation=4, BatchNorm=BatchNorm)
254 | 
255 |         self._init_weight()
256 | 
257 |     def _init_weight(self):
258 |         for m in self.modules():
259 |             if isinstance(m, nn.Conv2d):
260 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
261 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
262 |             elif isinstance(m, SynchronizedBatchNorm2d):
263 |                 m.weight.data.fill_(1)
264 |                 m.bias.data.zero_()
265 |             elif isinstance(m, nn.BatchNorm2d):
266 |                 m.weight.data.fill_(1)
267 |                 m.bias.data.zero_()
268 | 
269 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
270 |         downsample = None
271 |         if stride != 1 or self.inplanes != planes * block.expansion:
272 |             downsample = nn.Sequential(
273 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
274 |                           kernel_size=1, stride=stride, bias=False),
275 |                 BatchNorm(planes * block.expansion),
276 |             )
277 | 
278 |         layers = []
279 |         layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
280 |         self.inplanes = planes * block.expansion
281 |         for i in range(1, blocks):
282 |             layers.append(block(self.inplanes, planes,
283 |                                 dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
284 | 
285 |         return nn.Sequential(*layers)
286 | 
287 |     def forward(self, x):
288 |         x = self.conv1(x)
289 |         x = self.bn1(x)
290 |         x = self.relu(x)
291 |         x = self.maxpool(x)
292 | 
293 |         x = self.layer1(x)
294 |         x = self.layer2(x)
295 |         x = self.layer3(x)
296 |         x = self.layer4(x)
297 | 
298 |         return x
299 | 
300 | def drn_a_50(BatchNorm, pretrained=True):
301 |     model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
302 |     if pretrained:
303 |         pretrained = model_zoo.load_url(model_urls['resnet50'])
              del pretrained['fc.weight']  # drop the ImageNet classifier keys, as the loaders below do;
              del pretrained['fc.bias']    # DRN_A defines no fc layer, so a strict load would otherwise fail
              model.load_state_dict(pretrained)
304 |     return model
305 | 
306 | 
307 | def drn_c_26(BatchNorm, pretrained=True):
308 |     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
309 |     if pretrained:
310 |         pretrained = model_zoo.load_url(model_urls['drn-c-26'])
311 |         del pretrained['fc.weight']
312 |         del pretrained['fc.bias']
313 |         model.load_state_dict(pretrained)
314 |     return model
315 | 
316 | 
317 | def drn_c_42(BatchNorm, pretrained=True):
318 |     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
319 |     if pretrained:
320 |         pretrained = model_zoo.load_url(model_urls['drn-c-42'])
321 |         del pretrained['fc.weight']
322 |         del pretrained['fc.bias']
323 |         model.load_state_dict(pretrained)
324 |     return model
325 | 
326 | 
327 | def drn_c_58(BatchNorm, pretrained=True):
328 |     model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
329 |     if pretrained:
330 |         pretrained = model_zoo.load_url(model_urls['drn-c-58'])
331 |         del pretrained['fc.weight']
332 |         del pretrained['fc.bias']
333 |         model.load_state_dict(pretrained)
334 |     return model
335 | 
336 | 
337 | def drn_d_22(BatchNorm, pretrained=True):
338 |     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
339 |     if pretrained:
340 |         pretrained = model_zoo.load_url(model_urls['drn-d-22'])
341 |         del pretrained['fc.weight']
342 |         del pretrained['fc.bias']
343 |         model.load_state_dict(pretrained)
344 |     return model
345 | 
346 | 
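# The pretrained constructors in this file all repeat one loading pattern:
# fetch an ImageNet checkpoint, drop the final classifier ('fc.*') keys that a
# segmentation backbone does not keep, then load the rest. A minimal sketch of
# that shared step (the helper name is illustrative; the functions below do
# not actually call it):
def _load_without_classifier(model, key):
    weights = model_zoo.load_url(model_urls[key])
    weights.pop('fc.weight', None)  # the classification head is unused here
    weights.pop('fc.bias', None)
    model.load_state_dict(weights)
    return model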
347 | def drn_d_24(BatchNorm, pretrained=True):
348 |     model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
349 |     if pretrained:
350 |         pretrained = model_zoo.load_url(model_urls['drn-d-24'])  # NOTE: 'drn-d-24' is not defined in model_urls above, so pretrained=True raises KeyError
351 |         del pretrained['fc.weight']
352 |         del pretrained['fc.bias']
353 |         model.load_state_dict(pretrained)
354 |     return model
355 | 
356 | 
357 | def drn_d_38(BatchNorm, pretrained=True):
358 |     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
359 |     if pretrained:
360 |         pretrained = model_zoo.load_url(model_urls['drn-d-38'])
361 |         del pretrained['fc.weight']
362 |         del pretrained['fc.bias']
363 |         model.load_state_dict(pretrained)
364 |     return model
365 | 
366 | 
367 | def drn_d_40(BatchNorm, pretrained=True):
368 |     model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
369 |     if pretrained:
370 |         pretrained = model_zoo.load_url(model_urls['drn-d-40'])  # NOTE: 'drn-d-40' is likewise missing from model_urls above
371 |         del pretrained['fc.weight']
372 |         del pretrained['fc.bias']
373 |         model.load_state_dict(pretrained)
374 |     return model
375 | 
376 | 
377 | def drn_d_54(BatchNorm, pretrained=True):
378 |     model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
379 |     if pretrained:
380 |         pretrained = model_zoo.load_url(model_urls['drn-d-54'])
381 |         del pretrained['fc.weight']
382 |         del pretrained['fc.bias']
383 |         model.load_state_dict(pretrained)
384 |     return model
385 | 
386 | 
387 | def drn_d_105(BatchNorm, pretrained=True):
388 |     model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
389 |     if pretrained:
390 |         pretrained = model_zoo.load_url(model_urls['drn-d-105'])
391 |         del pretrained['fc.weight']
392 |         del pretrained['fc.bias']
393 |         model.load_state_dict(pretrained)
394 |     return model
395 | 
396 | if __name__ == "__main__":
397 |     import torch
398 |     model = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=True)  # DRN_A.forward returns a single tensor, so drn_a_50 cannot feed the two-value unpacking below
399 |     input = torch.rand(1, 3, 512, 512)
400 |     output, low_level_feat = model(input)
401 |     print(output.size())
402 |     print(low_level_feat.size())
403 | 
-------------------------------------------------------------------------------- /modeling/backbone/mobilenet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | import torch.utils.model_zoo as model_zoo
7 | 
8 | def conv_bn(inp, oup, stride, BatchNorm):
9 |     return nn.Sequential(
10 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11 |         BatchNorm(oup),
12 |         nn.ReLU6(inplace=True)
13 |     )
14 | 
15 | 
16 | def fixed_padding(inputs, kernel_size, dilation):
17 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18 |     pad_total = kernel_size_effective - 1
19 |     pad_beg = pad_total // 2
20 |     pad_end = pad_total - pad_beg
21 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22 |     return padded_inputs
23 | 
24 | 
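# The class below implements MobileNetV2's inverted residual: a 1x1 "expand"
# convolution widens the tensor by expand_ratio, a 3x3 depthwise convolution
# (groups equal to its channel count) filters it spatially, and a linear 1x1
# "project" convolution narrows it again. The residual shortcut is only taken
# when stride == 1 and the input/output widths match; for expand_ratio == 1
# the expand convolution is skipped, since hidden_dim == inp and a 1x1 there
# would not change the width.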
25 | class InvertedResidual(nn.Module):
26 |     def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27 |         super(InvertedResidual, self).__init__()
28 |         self.stride = stride
29 |         assert stride in [1, 2]
30 | 
31 |         hidden_dim = round(inp * expand_ratio)
32 |         self.use_res_connect = self.stride == 1 and inp == oup
33 |         self.kernel_size = 3
34 |         self.dilation = dilation
35 | 
36 |         if expand_ratio == 1:
37 |             self.conv = nn.Sequential(
38 |                 # dw
39 |                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40 |                 BatchNorm(hidden_dim),
41 |                 nn.ReLU6(inplace=True),
42 |                 # pw-linear
43 |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44 |                 BatchNorm(oup),
45 |             )
46 |         else:
47 |             self.conv = nn.Sequential(
48 |                 # pw
49 |                 nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50 |                 BatchNorm(hidden_dim),
51 |                 nn.ReLU6(inplace=True),
52 |                 # dw
53 |                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54 |                 BatchNorm(hidden_dim),
55 |                 nn.ReLU6(inplace=True),
56 |                 # pw-linear
57 |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58 |                 BatchNorm(oup),
59 |             )
60 | 
61 |     def forward(self, x):
62 |         x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
63 |         if self.use_res_connect:
64 |             x = x + self.conv(x_pad)
65 |         else:
66 |             x = self.conv(x_pad)
67 |         return x
68 | 
69 | 
70 | class MobileNetV2(nn.Module):
71 |     def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
72 |         super(MobileNetV2, self).__init__()
73 |         block = InvertedResidual
74 |         input_channel = 32
75 |         current_stride = 1
76 |         rate = 1
77 |         inverted_residual_setting = [
78 |             # t, c, n, s
79 |             [1, 16, 1, 1],
80 |             [6, 24, 2, 2],
81 |             [6, 32, 3, 2],
82 |             [6, 64, 4, 2],
83 |             [6, 96, 3, 1],
84 |             [6, 160, 3, 2],
85 |             [6, 320, 1, 1],
86 |         ]
87 | 
88 |         # building first layer
89 |         input_channel = int(input_channel * width_mult)
90 |         self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91 |         current_stride *= 2
92 |         # building inverted residual blocks
93 |         for t, c, n, s in inverted_residual_setting:
94 |             if current_stride == output_stride:
95 |                 stride = 1
96 |                 dilation = rate
97 |                 rate *= s
98 |             else:
99 |                 stride = s
100 |                 dilation = 1
101 |                 current_stride *= s
102 |             output_channel = int(c * width_mult)
103 |             for i in range(n):
104 |                 if i == 0:
105 |                     self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106 |                 else:
107 |                     self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108 |                 input_channel = output_channel
109 |         self.features = nn.Sequential(*self.features)
110 |         self._initialize_weights()
111 | 
112 |         if pretrained:
113 |             self._load_pretrained_model()
114 | 
115 |         self.low_level_features = self.features[0:4]
116 |         self.high_level_features = self.features[4:]
117 | 
118 |     def forward(self, x):
119 |         low_level_feat = self.low_level_features(x)
120 |         x = self.high_level_features(low_level_feat)
121 |         return x, low_level_feat
122 | 
123 |     def _load_pretrained_model(self):
124 |         pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
125 |         model_dict = {}
126 |         state_dict = self.state_dict()
127 |         for k, v in pretrain_dict.items():
128 |             if k in state_dict:
129 |                 model_dict[k] = v
130 |         state_dict.update(model_dict)
131 |         self.load_state_dict(state_dict)
132 | 
133 |     def _initialize_weights(self):
134 |         for m in self.modules():
135 |             if isinstance(m, nn.Conv2d):
136 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 |                 # m.weight.data.normal_(0, math.sqrt(2.
/ n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * 
block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 
150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /modeling/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | skip = 
inp
88 | 
89 |         x = x + skip
90 | 
91 |         return x
92 | 
93 | 
94 | class AlignedXception(nn.Module):
95 |     """
96 |     Modified Aligned Xception
97 |     """
98 |     def __init__(self, output_stride, BatchNorm,
99 |                  pretrained=True):
100 |         super(AlignedXception, self).__init__()
101 | 
102 |         if output_stride == 16:
103 |             entry_block3_stride = 2
104 |             middle_block_dilation = 1
105 |             exit_block_dilations = (1, 2)
106 |         elif output_stride == 8:
107 |             entry_block3_stride = 1
108 |             middle_block_dilation = 2
109 |             exit_block_dilations = (2, 4)
110 |         else:
111 |             raise NotImplementedError
112 | 
113 | 
114 |         # Entry flow
115 |         self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
116 |         self.bn1 = BatchNorm(32)
117 |         self.relu = nn.ReLU(inplace=True)
118 | 
119 |         self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
120 |         self.bn2 = BatchNorm(64)
121 | 
122 |         self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
123 |         self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
124 |                             grow_first=True)
125 |         self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
126 |                             start_with_relu=True, grow_first=True, is_last=True)
127 | 
128 |         # Middle flow
129 |         self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
130 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
131 |         self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
132 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
133 |         self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
134 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
135 |         self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
136 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
137 |         self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
138 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
139 |         self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
140 |                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
141 |         self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
142 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
143 |         self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
144 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
145 |         self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
146 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
147 |         self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
148 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
149 |         self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
150 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
151 |         self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
152 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
153 |         self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
154 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
155 |         self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
156 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
157 |         self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
158 |                              BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
159 | 
self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in state_dict: 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) 289 | -------------------------------------------------------------------------------- /modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if 
isinstance(m, nn.Conv2d):
48 |                 torch.nn.init.kaiming_normal_(m.weight)
49 |             elif isinstance(m, SynchronizedBatchNorm2d):
50 |                 m.weight.data.fill_(1)
51 |                 m.bias.data.zero_()
52 |             elif isinstance(m, nn.BatchNorm2d):
53 |                 m.weight.data.fill_(1)
54 |                 m.bias.data.zero_()
55 | 
56 | def build_decoder(num_classes, backbone, BatchNorm):
57 |     return Decoder(num_classes, backbone, BatchNorm)
-------------------------------------------------------------------------------- /modeling/deeplab.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from modeling.aspp import build_aspp
6 | from modeling.decoder import build_decoder
7 | from modeling.backbone import build_backbone
8 | 
9 | class DeepLab(nn.Module):
10 |     def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
11 |                  sync_bn=True, freeze_bn=False):
12 |         super(DeepLab, self).__init__()
13 |         if backbone == 'drn':
14 |             output_stride = 8
15 | 
16 |         if sync_bn:
17 |             BatchNorm = SynchronizedBatchNorm2d
18 |         else:
19 |             BatchNorm = nn.BatchNorm2d
20 | 
21 |         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
22 |         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
23 |         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
24 | 
25 |         self._freeze_bn = freeze_bn  # stored under a private name; an attribute called freeze_bn would shadow the method below
26 | 
27 |     def forward(self, input):
28 |         x, low_level_feat = self.backbone(input)
29 |         x = self.aspp(x)
30 |         x = self.decoder(x, low_level_feat)
31 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
32 | 
33 |         return x
34 | 
35 |     def freeze_bn(self):
36 |         for m in self.modules():
37 |             if isinstance(m, SynchronizedBatchNorm2d):
38 |                 m.eval()
39 |             elif isinstance(m, nn.BatchNorm2d):
40 |                 m.eval()
41 | 
42 |     def get_1x_lr_params(self):
43 |         modules = [self.backbone]
44 |         for i in range(len(modules)):
45 |             for m in modules[i].named_modules():
46 |                 if self._freeze_bn:
47 |                     if isinstance(m[1], nn.Conv2d):
48 |                         for p in m[1].parameters():
49 |                             if p.requires_grad:
50 |                                 yield p
51 |                 else:
52 |                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
53 |                             or isinstance(m[1], nn.BatchNorm2d):
54 |                         for p in m[1].parameters():
55 |                             if p.requires_grad:
56 |                                 yield p
57 | 
58 |     def get_10x_lr_params(self):
59 |         modules = [self.aspp, self.decoder]
60 |         for i in range(len(modules)):
61 |             for m in modules[i].named_modules():
62 |                 if self._freeze_bn:
63 |                     if isinstance(m[1], nn.Conv2d):
64 |                         for p in m[1].parameters():
65 |                             if p.requires_grad:
66 |                                 yield p
67 |                 else:
68 |                     if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
69 |                             or isinstance(m[1], nn.BatchNorm2d):
70 |                         for p in m[1].parameters():
71 |                             if p.requires_grad:
72 |                                 yield p
73 | 
74 | if __name__ == "__main__":
75 |     model = DeepLab(backbone='mobilenet', output_stride=16)
76 |     model.eval()
77 |     input = torch.rand(1, 3, 513, 513)
78 |     output = model(input)
79 |     print(output.size())
80 | 
81 | 
82 | 
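# Sketch of how the two parameter groups above are typically consumed by an
# optimizer (illustrative only; the learning rates and SGD settings below are
# assumptions, not values taken from this file):
#
#     import torch.optim as optim
#     model = DeepLab(backbone='resnet', output_stride=16, num_classes=21)
#     optimizer = optim.SGD([{'params': model.get_1x_lr_params(), 'lr': 0.007},
#                            {'params': model.get_10x_lr_params(), 'lr': 0.07}],
#                           momentum=0.9, weight_decay=5e-4)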
-------------------------------------------------------------------------------- /modeling/sync_batchnorm/__init__.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
-------------------------------------------------------------------------------- /modeling/sync_batchnorm/batchnorm.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 | 
19 | from .comm import SyncMaster
20 | 
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 | 
23 | 
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 | 
28 | 
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 | 
33 | 
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | 
37 | 
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 | 
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 | 
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 | 
48 |     def forward(self, input):
49 |         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 |         if not (self._is_parallel and self.training):
51 |             return F.batch_norm(
52 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
53 |                 self.training, self.momentum, self.eps)
54 | 
55 |         # Resize the input to (B, C, -1).
56 |         input_shape = input.size()
57 |         input = input.view(input.size(0), self.num_features, -1)
58 | 
59 |         # Compute the sum and square-sum.
60 |         sum_size = input.size(0) * input.size(2)
61 |         input_sum = _sum_ft(input)
62 |         input_ssum = _sum_ft(input ** 2)
63 | 
64 |         # Reduce-and-broadcast the statistics.
65 |         if self._parallel_id == 0:
66 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 |         else:
68 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 | 
70 |         # Compute the output.
71 |         if self.affine:
72 |             # MJY:: Fuse the multiplication for speed.
73 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 |         else:
75 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 | 
77 |         # Reshape it.
78 |         return output.view(input_shape)
79 | 
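    # How the reduced sums become statistics (see _compute_mean_std below):
    # with s = sum(x) and ss = sum(x ** 2) over sum_size elements per channel,
    #   mean = s / sum_size
    #   var  = E[x ** 2] - E[x] ** 2 = ss / sum_size - mean ** 2
    # so only two small per-channel vectors (mean and inv_std) ever have to
    # cross devices, rather than the activations themselves.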
85 |         if self._parallel_id == 0:
86 |             ctx.sync_master = self._sync_master
87 |         else:
88 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 |     def _data_parallel_master(self, intermediates):
91 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 |         # Always using same "device order" makes the ReduceAdd operation faster.
94 |         # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 |         to_reduce = [i[1][:2] for i in intermediates]
98 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
99 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 |         sum_size = sum([i[1].sum_size for i in intermediates])
102 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 |         outputs = []
108 |         for i, rec in enumerate(intermediates):
109 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 |
111 |         return outputs
112 |
113 |     def _compute_mean_std(self, sum_, ssum, size):
114 |         """Compute the mean and standard-deviation with sum and square-sum. This method
115 |         also maintains the moving average on the master device."""
116 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 |         mean = sum_ / size
118 |         sumvar = ssum - sum_ * mean
119 |         unbias_var = sumvar / (size - 1)
120 |         bias_var = sumvar / size
121 |
122 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 |         return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 |     mini-batch.
131 |     .. math::
132 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 |     standard-deviation are reduced across all devices during training.
135 |     For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 |     is also easy to implement, but the statistics might be inaccurate.
139 |     Instead, in this synchronized version, the statistics will be computed
140 |     over all training samples distributed on multiple devices.
141 |
142 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 |     as the built-in PyTorch implementation.
144 |     The mean and standard-deviation are calculated per-dimension over
145 |     the mini-batches and gamma and beta are learnable parameter vectors
146 |     of size C (where C is the input size).
147 |     During training, this layer keeps a running estimate of its computed mean
148 |     and variance. The running sum is kept with a default momentum of 0.1.
149 |     During evaluation, this running mean/variance is used for normalization.
150 |     Because the BatchNorm is done over the `C` dimension, computing statistics
151 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
152 |     Args:
153 |         num_features: num_features from an expected input of size
154 |             `batch_size x num_features [x width]`
155 |         eps: a value added to the denominator for numerical stability.
156 |             Default: 1e-5
157 |         momentum: the value used for the running_mean and running_var
158 |             computation. Default: 0.1
159 |         affine: a boolean value that when set to ``True``, gives the layer learnable
160 |             affine parameters. Default: ``True``
161 |     Shape:
162 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 |     Examples:
165 |         >>> # With Learnable Parameters
166 |         >>> m = SynchronizedBatchNorm1d(100)
167 |         >>> # Without Learnable Parameters
168 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 |         >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 |         >>> output = m(input)
171 |     """
172 |
173 |     def _check_input_dim(self, input):
174 |         if input.dim() != 2 and input.dim() != 3:
175 |             raise ValueError('expected 2D or 3D input (got {}D input)'
176 |                              .format(input.dim()))
177 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 |
179 |
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 |     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 |     of 3d inputs
183 |     .. math::
184 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 |     This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 |     standard-deviation are reduced across all devices during training.
187 |     For example, when one uses `nn.DataParallel` to wrap the network during
188 |     training, PyTorch's implementation normalizes the tensor on each device using
189 |     the statistics only on that device, which accelerates the computation and
190 |     is also easy to implement, but the statistics might be inaccurate.
191 |     Instead, in this synchronized version, the statistics will be computed
192 |     over all training samples distributed on multiple devices.
193 |
194 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 |     as the built-in PyTorch implementation.
196 |     The mean and standard-deviation are calculated per-dimension over
197 |     the mini-batches and gamma and beta are learnable parameter vectors
198 |     of size C (where C is the input size).
199 |     During training, this layer keeps a running estimate of its computed mean
200 |     and variance. The running sum is kept with a default momentum of 0.1.
201 |     During evaluation, this running mean/variance is used for normalization.
202 |     Because the BatchNorm is done over the `C` dimension, computing statistics
203 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
204 |     Args:
205 |         num_features: num_features from an expected input of
206 |             size batch_size x num_features x height x width
207 |         eps: a value added to the denominator for numerical stability.
208 |             Default: 1e-5
209 |         momentum: the value used for the running_mean and running_var
210 |             computation. Default: 0.1
211 |         affine: a boolean value that when set to ``True``, gives the layer learnable
212 |             affine parameters. Default: ``True``
213 |     Shape:
214 |         - Input: :math:`(N, C, H, W)`
215 |         - Output: :math:`(N, C, H, W)` (same shape as input)
216 |     Examples:
217 |         >>> # With Learnable Parameters
218 |         >>> m = SynchronizedBatchNorm2d(100)
219 |         >>> # Without Learnable Parameters
220 |         >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 |         >>> output = m(input)
223 |     """
224 |
225 |     def _check_input_dim(self, input):
226 |         if input.dim() != 4:
227 |             raise ValueError('expected 4D input (got {}D input)'
228 |                              .format(input.dim()))
229 |         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 |
231 |
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 |     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 |     of 4d inputs
235 |     .. math::
236 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 |     This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 |     standard-deviation are reduced across all devices during training.
239 |     For example, when one uses `nn.DataParallel` to wrap the network during
240 |     training, PyTorch's implementation normalizes the tensor on each device using
241 |     the statistics only on that device, which accelerates the computation and
242 |     is also easy to implement, but the statistics might be inaccurate.
243 |     Instead, in this synchronized version, the statistics will be computed
244 |     over all training samples distributed on multiple devices.
245 |
246 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 |     as the built-in PyTorch implementation.
248 |     The mean and standard-deviation are calculated per-dimension over
249 |     the mini-batches and gamma and beta are learnable parameter vectors
250 |     of size C (where C is the input size).
251 |     During training, this layer keeps a running estimate of its computed mean
252 |     and variance. The running sum is kept with a default momentum of 0.1.
253 |     During evaluation, this running mean/variance is used for normalization.
254 |     Because the BatchNorm is done over the `C` dimension, computing statistics
255 |     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 |     or Spatio-temporal BatchNorm.
257 |     Args:
258 |         num_features: num_features from an expected input of
259 |             size batch_size x num_features x depth x height x width
260 |         eps: a value added to the denominator for numerical stability.
261 |             Default: 1e-5
262 |         momentum: the value used for the running_mean and running_var
263 |             computation. Default: 0.1
264 |         affine: a boolean value that when set to ``True``, gives the layer learnable
265 |             affine parameters. Default: ``True``
266 |     Shape:
267 |         - Input: :math:`(N, C, D, H, W)`
268 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 |     Examples:
270 |         >>> # With Learnable Parameters
271 |         >>> m = SynchronizedBatchNorm3d(100)
272 |         >>> # Without Learnable Parameters
273 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 |         >>> output = m(input)
276 |     """
277 |
278 |     def _check_input_dim(self, input):
279 |         if input.dim() != 5:
280 |             raise ValueError('expected 5D input (got {}D input)'
281 |                              .format(input.dim()))
282 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 |
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 |
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 |
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 |
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 |
55 |
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During the replication, as data parallel will trigger a callback on each module, all slave devices should
59 |       call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
61 |       and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be
63 |       passed back to each slave device.
64 |     """
65 |
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 |
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 |
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 |
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
96 |
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback will be invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 |
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 |
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 |
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 |
125 |         return results[0][1]
126 |
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 |
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 |     pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 |     Note that, as all copies are isomorphic, we assign each sub-module a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 |     of any slave copies.
36 |     """
37 |     master_copy = modules[0]
38 |     nr_modules = len(list(master_copy.modules()))
39 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 |     for i, module in enumerate(modules):
42 |         for j, m in enumerate(module.modules()):
43 |             if hasattr(m, '__data_parallel_replicate__'):
44 |                 m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 |     """
49 |     Data Parallel with a replication callback.
50 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51 |     the original `replicate` function.
52 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 |     Examples:
54 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 |         # sync_bn.__data_parallel_replicate__ will be invoked.
57 |     """
58 |
59 |     def replicate(self, module, device_ids):
60 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 |         execute_replication_callbacks(modules)
62 |         return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 |     """
67 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 |     Useful when you have a customized `DataParallel` implementation.
69 |     Examples:
70 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 |         > patch_replication_callback(sync_bn)
73 |         # this is equivalent to
74 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 |     """
77 |
78 |     assert isinstance(data_parallel, DataParallel)
79 |
80 |     old_replicate = data_parallel.replicate
81 |
82 |     @functools.wraps(old_replicate)
83 |     def new_replicate(module, device_ids):
84 |         modules = old_replicate(module, device_ids)
85 |         execute_replication_callbacks(modules)
86 |         return modules
87 |
88 |     data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 |     if isinstance(v, Variable):
19 |         v = v.data
20 |     return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 |     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 |         npa, npb = as_numpy(a), as_numpy(b)
26 |         self.assertTrue(
27 |                 np.allclose(npa, npb, atol=atol),
28 |                 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 |         )
--------------------------------------------------------------------------------
/mypath.py:
--------------------------------------------------------------------------------
1 | class Path(object):
2 |     @staticmethod
3 |     def db_root_dir(dataset):
4 |         if dataset == 'pascal':
5 |             return '/path/to/datasets/VOCdevkit/VOC2012/'  # folder that contains VOCdevkit/.
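        # (editor's note) the '/path/to/datasets/...' strings in this function are
        # placeholders; point each branch at the root of your local dataset copy.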
6 |         elif dataset == 'sbd':
7 |             return '/path/to/datasets/benchmark_RELEASE/'  # folder that contains dataset/.
8 |         elif dataset == 'cityscapes':
9 |             return '/path/to/datasets/cityscapes/'  # folder that contains leftImg8bit/
10 |         elif dataset == 'coco':
11 |             return '/path/to/datasets/coco/'
12 |         else:
13 |             print('Dataset {} not available.'.format(dataset))
14 |             raise NotImplementedError
15 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | from mypath import Path
7 | from dataloaders import make_data_loader
8 | from modeling.sync_batchnorm.replicate import patch_replication_callback
9 | from modeling.deeplab import *
10 | from utils.loss import SegmentationLosses
11 | from utils.calculate_weights import calculate_weights_labels  # renamed from calculate_weigths_labels (typo)
12 | from utils.lr_scheduler import LR_Scheduler
13 | from utils.saver import Saver
14 | from utils.summaries import TensorboardSummary
15 | from utils.metrics import Evaluator
16 |
17 | class Trainer(object):
18 |     def __init__(self, args):
19 |         self.args = args
20 |
21 |         # Define Saver
22 |         self.saver = Saver(args)
23 |         self.saver.save_experiment_config()
24 |         # Define Tensorboard Summary
25 |         self.summary = TensorboardSummary(self.saver.experiment_dir)
26 |         self.writer = self.summary.create_summary()
27 |
28 |         # Define Dataloader
29 |         kwargs = {'num_workers': args.workers, 'pin_memory': True}
30 |         self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
31 |
32 |         # Define network
33 |         model = DeepLab(num_classes=self.nclass,
34 |                         backbone=args.backbone,
35 |                         output_stride=args.out_stride,
36 |                         sync_bn=args.sync_bn,
37 |                         freeze_bn=args.freeze_bn)
38 |
39 |         train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
40 |                         {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
41 |
42 |         # Define Optimizer
43 |         optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
44 |                                     weight_decay=args.weight_decay, nesterov=args.nesterov)
45 |
46 |         # Define Criterion
47 |         # whether to use class balanced weights
48 |         if args.use_balanced_weights:
49 |             classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
50 |             if os.path.isfile(classes_weights_path):
51 |                 weight = np.load(classes_weights_path)
52 |             else:
53 |                 weight = calculate_weights_labels(args.dataset, self.train_loader, self.nclass)
54 |             weight = torch.from_numpy(weight.astype(np.float32))
55 |         else:
56 |             weight = None
57 |         self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
58 |         self.model, self.optimizer = model, optimizer
59 |
60 |         # Define Evaluator
61 |         self.evaluator = Evaluator(self.nclass)
62 |         # Define lr scheduler
63 |         self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
64 |                                       args.epochs, len(self.train_loader))
65 |
66 |         # Using cuda
67 |         if args.cuda:
68 |             self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
69 |             patch_replication_callback(self.model)
70 |             self.model = self.model.cuda()
71 |
72 |         # Resuming checkpoint
73 |         self.best_pred = 0.0
74 |         if args.resume is not None:
75 |             if not os.path.isfile(args.resume):
76 |                 raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
77 |             checkpoint = torch.load(args.resume)
78 |             args.start_epoch = checkpoint['epoch']
79 |             if args.cuda:
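                # (editor's note) under DataParallel the wrapped network lives on
                # self.model.module, hence the separate load paths in this branch.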
80 |                 self.model.module.load_state_dict(checkpoint['state_dict'])
81 |             else:
82 |                 self.model.load_state_dict(checkpoint['state_dict'])
83 |             if not args.ft:
84 |                 self.optimizer.load_state_dict(checkpoint['optimizer'])
85 |             self.best_pred = checkpoint['best_pred']
86 |             print("=> loaded checkpoint '{}' (epoch {})"
87 |                   .format(args.resume, checkpoint['epoch']))
88 |
89 |         # Clear start epoch if fine-tuning
90 |         if args.ft:
91 |             args.start_epoch = 0
92 |
93 |     def training(self, epoch):
94 |         train_loss = 0.0
95 |         self.model.train()
96 |         tbar = tqdm(self.train_loader)
97 |         num_img_tr = len(self.train_loader)
98 |         for i, sample in enumerate(tbar):
99 |             image, target = sample['image'], sample['label']
100 |             if self.args.cuda:
101 |                 image, target = image.cuda(), target.cuda()
102 |             self.scheduler(self.optimizer, i, epoch, self.best_pred)
103 |             self.optimizer.zero_grad()
104 |             output = self.model(image)
105 |             loss = self.criterion(output, target)
106 |             loss.backward()
107 |             self.optimizer.step()
108 |             train_loss += loss.item()
109 |             tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
110 |             self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
111 |
112 |             # Show 10 * 3 inference results each epoch
113 |             if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:  # guard avoids modulo-by-zero on tiny loaders
114 |                 global_step = i + num_img_tr * epoch
115 |                 self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step)
116 |
117 |         self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
118 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
119 |         print('Loss: %.3f' % train_loss)
120 |
121 |         if self.args.no_val:
122 |             # save checkpoint every epoch
123 |             is_best = False
124 |             self.saver.save_checkpoint({
125 |                 'epoch': epoch + 1,
126 |                 'state_dict': self.model.module.state_dict(),
127 |                 'optimizer': self.optimizer.state_dict(),
128 |                 'best_pred': self.best_pred,
129 |             }, is_best)
130 |
131 |
132 |     def validation(self, epoch):
133 |         self.model.eval()
134 |         self.evaluator.reset()
135 |         tbar = tqdm(self.val_loader, desc='\r')
136 |         test_loss = 0.0
137 |         for i, sample in enumerate(tbar):
138 |             image, target = sample['image'], sample['label']
139 |             if self.args.cuda:
140 |                 image, target = image.cuda(), target.cuda()
141 |             with torch.no_grad():
142 |                 output = self.model(image)
143 |             loss = self.criterion(output, target)
144 |             test_loss += loss.item()
145 |             tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
146 |             pred = output.data.cpu().numpy()
147 |             target = target.cpu().numpy()
148 |             pred = np.argmax(pred, axis=1)
149 |             # Add batch sample into evaluator
150 |             self.evaluator.add_batch(target, pred)
151 |
152 |         # Fast test during the training
153 |         Acc = self.evaluator.Pixel_Accuracy()
154 |         Acc_class = self.evaluator.Pixel_Accuracy_Class()
155 |         mIoU = self.evaluator.Mean_Intersection_over_Union()
156 |         FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
157 |         self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
158 |         self.writer.add_scalar('val/mIoU', mIoU, epoch)
159 |         self.writer.add_scalar('val/Acc', Acc, epoch)
160 |         self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
161 |         self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
162 |         print('Validation:')
163 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
164 |         print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
165 |         print('Loss: %.3f' % test_loss)
166 |
167 |         new_pred = mIoU
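        # (editor's note) validation mIoU is the sole model-selection metric here;
        # the block below only writes a new 'best' checkpoint when it improves.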
168 |         if new_pred > self.best_pred:
169 |             is_best = True
170 |             self.best_pred = new_pred
171 |             self.saver.save_checkpoint({
172 |                 'epoch': epoch + 1,
173 |                 'state_dict': self.model.module.state_dict(),
174 |                 'optimizer': self.optimizer.state_dict(),
175 |                 'best_pred': self.best_pred,
176 |             }, is_best)
177 |
178 | def main():
179 |     parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
180 |     parser.add_argument('--backbone', type=str, default='resnet',
181 |                         choices=['resnet', 'xception', 'drn', 'mobilenet'],
182 |                         help='backbone name (default: resnet)')
183 |     parser.add_argument('--out-stride', type=int, default=16,
184 |                         help='network output stride (default: 16)')
185 |     parser.add_argument('--dataset', type=str, default='pascal',
186 |                         choices=['pascal', 'coco', 'cityscapes'],
187 |                         help='dataset name (default: pascal)')
188 |     parser.add_argument('--use-sbd', action='store_true', default=True,
189 |                         help='whether to use SBD dataset (default: True)')
190 |     parser.add_argument('--workers', type=int, default=4,
191 |                         metavar='N', help='dataloader threads')
192 |     parser.add_argument('--base-size', type=int, default=513,
193 |                         help='base image size')
194 |     parser.add_argument('--crop-size', type=int, default=513,
195 |                         help='crop image size')
196 |     parser.add_argument('--sync-bn', type=bool, default=None,
197 |                         help='whether to use sync bn (default: auto); note: any non-empty value parses as True')
198 |     parser.add_argument('--freeze-bn', type=bool, default=False,
199 |                         help='whether to freeze bn parameters (default: False)')
200 |     parser.add_argument('--loss-type', type=str, default='ce',
201 |                         choices=['ce', 'focal'],
202 |                         help='loss func type (default: ce)')
203 |     # training hyper params
204 |     parser.add_argument('--epochs', type=int, default=None, metavar='N',
205 |                         help='number of epochs to train (default: auto)')
206 |     parser.add_argument('--start_epoch', type=int, default=0,
207 |                         metavar='N', help='start epochs (default: 0)')
208 |     parser.add_argument('--batch-size', type=int, default=None,
209 |                         metavar='N', help='input batch size for \
210 |                                 training (default: auto)')
211 |     parser.add_argument('--test-batch-size', type=int, default=None,
212 |                         metavar='N', help='input batch size for \
213 |                                 testing (default: auto)')
214 |     parser.add_argument('--use-balanced-weights', action='store_true', default=False,
215 |                         help='whether to use balanced weights (default: False)')
216 |     # optimizer params
217 |     parser.add_argument('--lr', type=float, default=None, metavar='LR',
218 |                         help='learning rate (default: auto)')
219 |     parser.add_argument('--lr-scheduler', type=str, default='poly',
220 |                         choices=['poly', 'step', 'cos'],
221 |                         help='lr scheduler mode: (default: poly)')
222 |     parser.add_argument('--momentum', type=float, default=0.9,
223 |                         metavar='M', help='momentum (default: 0.9)')
224 |     parser.add_argument('--weight-decay', type=float, default=5e-4,
225 |                         metavar='M', help='w-decay (default: 5e-4)')
226 |     parser.add_argument('--nesterov', action='store_true', default=False,
227 |                         help='whether use nesterov (default: False)')
228 |     # cuda, seed and logging
229 |     parser.add_argument('--no-cuda', action='store_true', default=False,
230 |                         help='disables CUDA training')
231 |     parser.add_argument('--gpu-ids', type=str, default='0',
232 |                         help='use which gpu to train, must be a \
233 |                         comma-separated list of integers only (default=0)')
234 |     parser.add_argument('--seed', type=int, default=1, metavar='S',
235 |                         help='random seed (default: 1)')
236 |     # checkpointing
237 |     parser.add_argument('--resume', type=str, default=None,
238 |                         help='put the path to resuming file if needed')
239 |     parser.add_argument('--checkname', type=str, default=None,
240 |                         help='set the checkpoint name')
241 |     # finetuning pre-trained models
242 |     parser.add_argument('--ft', action='store_true', default=False,
243 |                         help='finetuning on a different dataset')
244 |     # evaluation option
245 |     parser.add_argument('--eval-interval', type=int, default=1,
246 |                         help='evaluation interval (default: 1)')
247 |     parser.add_argument('--no-val', action='store_true', default=False,
248 |                         help='skip validation during training')
249 |
250 |     args = parser.parse_args()
251 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
252 |     if args.cuda:
253 |         try:
254 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
255 |         except ValueError:
256 |             raise ValueError('Argument --gpu-ids must be a comma-separated list of integers only')
257 |
258 |     if args.sync_bn is None:
259 |         if args.cuda and len(args.gpu_ids) > 1:
260 |             args.sync_bn = True
261 |         else:
262 |             args.sync_bn = False
263 |
264 |     # default settings for epochs, batch_size and lr
265 |     if args.epochs is None:
266 |         epochs = {
267 |             'coco': 30,
268 |             'cityscapes': 200,
269 |             'pascal': 50,
270 |         }
271 |         args.epochs = epochs[args.dataset.lower()]
272 |
273 |     if args.batch_size is None:
274 |         args.batch_size = 4 * len(args.gpu_ids)
275 |
276 |     if args.test_batch_size is None:
277 |         args.test_batch_size = args.batch_size
278 |
279 |     if args.lr is None:
280 |         lrs = {
281 |             'coco': 0.1,
282 |             'cityscapes': 0.01,
283 |             'pascal': 0.007,
284 |         }
285 |         args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size
286 |
287 |
288 |     if args.checkname is None:
289 |         args.checkname = 'deeplab-'+str(args.backbone)
290 |     print(args)
291 |     torch.manual_seed(args.seed)
292 |     trainer = Trainer(args)
293 |     print('Starting Epoch:', trainer.args.start_epoch)
294 |     print('Total Epochs:', trainer.args.epochs)
295 |     for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
296 |         trainer.training(epoch)
297 |         if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
298 |             trainer.validation(epoch)
299 |
300 |     trainer.writer.close()
301 |
302 | if __name__ == "__main__":
303 |     main()
304 |
--------------------------------------------------------------------------------
/train_coco.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --backbone resnet --lr 0.01 --workers 4 --epochs 40 --batch-size 16 --gpu-ids 0,1,2,3 --checkname deeplab-resnet --eval-interval 1 --dataset coco
2 |
--------------------------------------------------------------------------------
/train_voc.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --backbone resnet --lr 0.007 --workers 4 --use-sbd --epochs 50 --batch-size 16 --gpu-ids 0,1,2,3 --checkname deeplab-resnet --eval-interval 1 --dataset pascal
2 |
--------------------------------------------------------------------------------
/utils/calculate_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | from mypath import Path
5 |
6 | def calculate_weights_labels(dataset, dataloader, num_classes):  # renamed from calculate_weigths_labels (typo)
7 |     # Create an instance from the data loader
8 |     z = np.zeros((num_classes,))
9 |     # Initialize tqdm
10 |     tqdm_batch = tqdm(dataloader)
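    # (editor's note) the loop below accumulates per-class pixel counts; the weights
    # computed afterwards follow the ENet-style scheme w_c = 1 / ln(1.02 + f_c),
    # where f_c is a class's pixel frequency, so rare classes get weights up to
    # roughly 1/ln(1.02) ~= 50 while frequent ones approach 1/ln(2.02) ~= 1.4.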
11 |     print('Calculating class weights')
12 |     for sample in tqdm_batch:
13 |         y = sample['label']
14 |         y = y.detach().cpu().numpy()
15 |         mask = (y >= 0) & (y < num_classes)
16 |         labels = y[mask].astype(np.uint8)
17 |         count_l = np.bincount(labels, minlength=num_classes)
18 |         z += count_l
19 |     tqdm_batch.close()
20 |     total_frequency = np.sum(z)
21 |     class_weights = []
22 |     for frequency in z:
23 |         class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))
24 |         class_weights.append(class_weight)
25 |     ret = np.array(class_weights)
26 |     classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')
27 |     np.save(classes_weights_path, ret)
28 |
29 |     return ret
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class SegmentationLosses(object):
5 |     def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False):
6 |         self.ignore_index = ignore_index
7 |         self.weight = weight
8 |         self.size_average = size_average
9 |         self.batch_average = batch_average
10 |         self.cuda = cuda
11 |
12 |     def build_loss(self, mode='ce'):
13 |         """Choices: ['ce' or 'focal']"""
14 |         if mode == 'ce':
15 |             return self.CrossEntropyLoss
16 |         elif mode == 'focal':
17 |             return self.FocalLoss
18 |         else:
19 |             raise NotImplementedError
20 |
21 |     def CrossEntropyLoss(self, logit, target):
22 |         n, c, h, w = logit.size()
23 |         criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
24 |                                         size_average=self.size_average)
25 |         if self.cuda:
26 |             criterion = criterion.cuda()
27 |
28 |         loss = criterion(logit, target.long())
29 |
30 |         if self.batch_average:
31 |             loss /= n
32 |
33 |         return loss
34 |
35 |     def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
36 |         n, c, h, w = logit.size()
37 |         criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index,
38 |                                         size_average=self.size_average)
39 |         if self.cuda:
40 |             criterion = criterion.cuda()
41 |
42 |         logpt = -criterion(logit, target.long())
43 |         pt = torch.exp(logpt)
44 |         if alpha is not None:
45 |             logpt *= alpha
46 |         loss = -((1 - pt) ** gamma) * logpt
47 |
48 |         if self.batch_average:
49 |             loss /= n
50 |
51 |         return loss
52 |
53 | if __name__ == "__main__":
54 |     loss = SegmentationLosses(cuda=True)
55 |     a = torch.rand(1, 3, 7, 7).cuda()
56 |     b = torch.rand(1, 7, 7).cuda()
57 |     print(loss.CrossEntropyLoss(a, b).item())
58 |     print(loss.FocalLoss(a, b, gamma=0, alpha=None).item())
59 |     print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item())
60 |
61 |
62 |
63 |
64 |
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ##
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 |
11 | import math
12 |
13 | class LR_Scheduler(object):
14 |     """Learning Rate Scheduler
15 |
16 |     Step mode: ``lr = baselr * 0.1 ^ floor(epoch / lr_step)``
17 |
18 |     Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter / maxiter))``
19 |
20 |     Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
21 |
22 |     Args:
23 |         args:
24 |             :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`),
25 |             :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
26 |             :attr:`args.lr_step`
27 |
28 |         iters_per_epoch: number of iterations per epoch
29 |     """
30 |     def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
31 |                  lr_step=0, warmup_epochs=0):
32 |         self.mode = mode
33 |         print('Using {} LR Scheduler!'.format(self.mode))
34 |         self.lr = base_lr
35 |         if mode == 'step':
36 |             assert lr_step
37 |         self.lr_step = lr_step
38 |         self.iters_per_epoch = iters_per_epoch
39 |         self.N = num_epochs * iters_per_epoch
40 |         self.epoch = -1
41 |         self.warmup_iters = warmup_epochs * iters_per_epoch
42 |
43 |     def __call__(self, optimizer, i, epoch, best_pred):
44 |         T = epoch * self.iters_per_epoch + i
45 |         if self.mode == 'cos':
46 |             lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi))
47 |         elif self.mode == 'poly':
48 |             lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
49 |         elif self.mode == 'step':
50 |             lr = self.lr * (0.1 ** (epoch // self.lr_step))
51 |         else:
52 |             raise NotImplementedError  # was `raise NotImplemented`, which is not an exception class
53 |         # warm up lr schedule
54 |         if self.warmup_iters > 0 and T < self.warmup_iters:
55 |             lr = lr * 1.0 * T / self.warmup_iters
56 |         if epoch > self.epoch:
57 |             print('\n=> Epoch %i, learning rate = %.4f, '
58 |                   'previous best = %.4f' % (epoch, lr, best_pred))
59 |             self.epoch = epoch
60 |         assert lr >= 0
61 |         self._adjust_learning_rate(optimizer, lr)
62 |
63 |     def _adjust_learning_rate(self, optimizer, lr):
64 |         if len(optimizer.param_groups) == 1:
65 |             optimizer.param_groups[0]['lr'] = lr
66 |         else:
67 |             # enlarge the lr at the head
68 |             optimizer.param_groups[0]['lr'] = lr
69 |             for i in range(1, len(optimizer.param_groups)):
70 |                 optimizer.param_groups[i]['lr'] = lr * 10
71 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Evaluator(object):
5 |     def __init__(self, num_class):
6 |         self.num_class = num_class
7 |         self.confusion_matrix = np.zeros((self.num_class,)*2)
8 |
9 |     def Pixel_Accuracy(self):
10 |         Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
11 |         return Acc
12 |
13 |     def Pixel_Accuracy_Class(self):
14 |         Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
15 |         Acc = np.nanmean(Acc)
16 |         return Acc
17 |
18 |     def Mean_Intersection_over_Union(self):
19 |         MIoU = np.diag(self.confusion_matrix) / (
20 |                     np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
21 |                     np.diag(self.confusion_matrix))
22 |         MIoU = np.nanmean(MIoU)
23 |         return MIoU
24 |
25 |     def Frequency_Weighted_Intersection_over_Union(self):
26 |         freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
27 |         iu = np.diag(self.confusion_matrix) / (
28 |                     np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
29 |                     np.diag(self.confusion_matrix))
30 |
31 |         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
32 |         return FWIoU
33 |
34 |     def _generate_matrix(self, gt_image, pre_image):
35 |         mask = (gt_image >= 0) & (gt_image < self.num_class)
36 |         label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
37 |         count = np.bincount(label, minlength=self.num_class**2)
38 |         confusion_matrix = count.reshape(self.num_class, self.num_class)
39 |         return confusion_matrix
40 |
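    # (editor's note) _generate_matrix flattens each (gt, pred) pixel pair into a
    # single index gt * num_class + pred, so one np.bincount pass builds the full
    # confusion matrix; add_batch below simply accumulates these per-batch matrices.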
41 |     def add_batch(self, gt_image, pre_image):
42 |         assert gt_image.shape == pre_image.shape
43 |         self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
44 |
45 |     def reset(self):
46 |         self.confusion_matrix = np.zeros((self.num_class,) * 2)
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import torch
4 | from collections import OrderedDict
5 | import glob
6 |
7 | class Saver(object):
8 |
9 |     def __init__(self, args):
10 |         self.args = args
11 |         self.directory = os.path.join('run', args.dataset, args.checkname)
12 |         self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')), key=lambda p: int(p.split('_')[-1]))  # numeric sort, so experiment_10 follows experiment_9
13 |         run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0
14 |
15 |         self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
16 |         if not os.path.exists(self.experiment_dir):
17 |             os.makedirs(self.experiment_dir)
18 |
19 |     def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
20 |         """Saves checkpoint to disk"""
21 |         filename = os.path.join(self.experiment_dir, filename)
22 |         torch.save(state, filename)
23 |         if is_best:
24 |             best_pred = state['best_pred']
25 |             with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
26 |                 f.write(str(best_pred))
27 |             if self.runs:
28 |                 previous_miou = [0.0]
29 |                 for run in self.runs:
30 |                     run_id = run.split('_')[-1]
31 |                     path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt')
32 |                     if os.path.exists(path):
33 |                         with open(path, 'r') as f:
34 |                             miou = float(f.readline())
35 |                             previous_miou.append(miou)
36 |                     else:
37 |                         continue
38 |                 max_miou = max(previous_miou)
39 |                 if best_pred > max_miou:
40 |                     shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
41 |             else:
42 |                 shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
43 |
44 |     def save_experiment_config(self):
45 |         logfile = os.path.join(self.experiment_dir, 'parameters.txt')
46 |         log_file = open(logfile, 'w')
47 |         p = OrderedDict()
48 |         p['dataset'] = self.args.dataset
49 |         p['backbone'] = self.args.backbone
50 |         p['out_stride'] = self.args.out_stride
51 |         p['lr'] = self.args.lr
52 |         p['lr_scheduler'] = self.args.lr_scheduler
53 |         p['loss_type'] = self.args.loss_type
54 |         p['epoch'] = self.args.epochs
55 |         p['base_size'] = self.args.base_size
56 |         p['crop_size'] = self.args.crop_size
57 |
58 |         for key, val in p.items():
59 |             log_file.write(key + ':' + str(val) + '\n')
60 |         log_file.close()
--------------------------------------------------------------------------------
/utils/summaries.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.utils import make_grid
4 | from tensorboardX import SummaryWriter
5 | from dataloaders.utils import decode_seg_map_sequence
6 |
7 | class TensorboardSummary(object):
8 |     def __init__(self, directory):
9 |         self.directory = directory
10 |
11 |     def create_summary(self):
12 |         writer = SummaryWriter(log_dir=os.path.join(self.directory))
13 |         return writer
14 |
15 |     def visualize_image(self, writer, dataset, image, target, output, global_step):
16 |         grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
17 |         writer.add_image('Image', grid_image, global_step)
18 |         grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3],
1)[1].detach().cpu().numpy(), 19 | dataset=dataset), 3, normalize=False, range=(0, 255)) 20 | writer.add_image('Predicted label', grid_image, global_step) 21 | grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), 22 | dataset=dataset), 3, normalize=False, range=(0, 255)) 23 | writer.add_image('Groundtruth label', grid_image, global_step) --------------------------------------------------------------------------------
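Editor's usage notes. The sketches below are illustrative only: they assume the repository root is on `PYTHONPATH`, use hypothetical values (learning rates, class counts, iteration counts), and are not part of the repository files above. Note that constructing `DeepLab` may download ImageNet-pretrained backbone weights on first use.

A minimal sketch of how the pieces in `train.py` fit together: the model's two learning-rate groups (backbone at 1x, ASPP/decoder at 10x) feeding SGD, and an `Evaluator` round-trip on a toy 2-class prediction showing how mIoU is produced:

```python
import torch
import numpy as np
from modeling.deeplab import DeepLab
from utils.metrics import Evaluator

# Hypothetical single-device setup: sync_bn=False keeps plain nn.BatchNorm2d.
model = DeepLab(backbone='resnet', output_stride=16, num_classes=21, sync_bn=False)
optimizer = torch.optim.SGD(
    [{'params': model.get_1x_lr_params(), 'lr': 0.007},    # backbone
     {'params': model.get_10x_lr_params(), 'lr': 0.07}],   # ASPP + decoder
    momentum=0.9, weight_decay=5e-4)

# Evaluator on a toy 2x2 "image" with two classes.
ev = Evaluator(num_class=2)
gt = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
ev.add_batch(gt, pred)
print(ev.Pixel_Accuracy())                # 0.75
print(ev.Mean_Intersection_over_Union())  # (1/2 + 2/3) / 2 = 0.5833...
```

And the default 'poly' policy from utils/lr_scheduler.py, reduced to its bare formula (assuming base_lr=0.007, 50 epochs, 100 iterations per epoch):

```python
base_lr, N = 0.007, 50 * 100
for T in (0, N // 2, N - 1):
    print(T, base_lr * (1 - T / N) ** 0.9)  # ~0.007 at start, ~0.0037 halfway, ~0 at the end
```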