├── .gitignore ├── LICENSE ├── README.md ├── core ├── __init__.py ├── data │ ├── __init__.py │ ├── dataloader │ │ ├── __init__.py │ │ ├── ade.py │ │ ├── cityscapes.py │ │ ├── lip_parsing.py │ │ ├── mscoco.py │ │ ├── pascal_aug.py │ │ ├── pascal_voc.py │ │ ├── sbu_shadow.py │ │ ├── segbase.py │ │ └── utils.py │ └── downloader │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── cityscapes.py │ │ ├── mscoco.py │ │ ├── pascal_voc.py │ │ └── sbu_shadow.py ├── models │ ├── __init__.py │ ├── base_models │ │ ├── __init__.py │ │ ├── densenet.py │ │ ├── eespnet.py │ │ ├── hrnet.py │ │ ├── mobilenetv2.py │ │ ├── resnet.py │ │ ├── resnetv1b.py │ │ ├── resnext.py │ │ ├── vgg.py │ │ └── xception.py │ ├── bisenet.py │ ├── ccnet.py │ ├── cgnet.py │ ├── danet.py │ ├── deeplabv3.py │ ├── deeplabv3_plus.py │ ├── denseaspp.py │ ├── dfanet.py │ ├── dunet.py │ ├── encnet.py │ ├── enet.py │ ├── espnet.py │ ├── fcn.py │ ├── fcnv2.py │ ├── hrnet.py │ ├── icnet.py │ ├── lednet.py │ ├── model_store.py │ ├── model_zoo.py │ ├── ocnet.py │ ├── psanet.py │ ├── psanet_old.py │ ├── pspnet.py │ └── segbase.py ├── nn │ ├── __init__.py │ ├── basic.py │ ├── ca_block.py │ ├── csrc │ │ ├── ca.h │ │ ├── cpu │ │ │ ├── ca_cpu.cpp │ │ │ ├── psa_cpu.cpp │ │ │ ├── syncbn_cpu.cpp │ │ │ └── vision.h │ │ ├── cuda │ │ │ ├── ca_cuda.cu │ │ │ ├── helper.h │ │ │ ├── psa_cuda.cu │ │ │ ├── syncbn_cuda.cu │ │ │ └── vision.h │ │ ├── psa.h │ │ ├── syncbn.h │ │ └── vision.cpp │ ├── jpu.py │ ├── psa_block.py │ ├── setup.py │ ├── sync_bn │ │ ├── __init__.py │ │ ├── functions.py │ │ ├── lib │ │ │ ├── __init__.py │ │ │ ├── cpu │ │ │ │ ├── .ninja_deps │ │ │ │ ├── .ninja_log │ │ │ │ ├── build.ninja │ │ │ │ ├── operator.cpp │ │ │ │ ├── operator.h │ │ │ │ ├── operator.o │ │ │ │ ├── setup.py │ │ │ │ ├── syncbn_cpu.cpp │ │ │ │ └── syncbn_cpu.o │ │ │ └── gpu │ │ │ │ ├── __init__.py │ │ │ │ ├── activation_kernel.cu │ │ │ │ ├── common.h │ │ │ │ ├── device_tensor.h │ │ │ │ ├── operator.cpp │ │ │ │ ├── operator.h │ │ │ │ ├── setup.py │ │ │ │ └── syncbn_kernel.cu │ │ └── syncbn.py │ └── syncbn.py └── utils │ ├── __init__.py │ ├── distributed.py │ ├── download.py │ ├── filesystem.py │ ├── logger.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── parallel.py │ ├── score.py │ └── visualize.py ├── datasets ├── ade ├── citys ├── sbu └── voc ├── docs ├── DETAILS.md ├── requirements.yml └── weimar_000091_000019_gtFine_color.png ├── scripts ├── demo.py ├── eval.py ├── fcn32s_vgg16_pascal_voc.sh ├── fcn32s_vgg16_pascal_voc_dist.sh └── train.py └── tests ├── README.md ├── runs ├── bisenet_epoch_100.png ├── danet_epoch_100.png ├── denseaspp_epoch_40.png ├── dunet_epoch_100.png ├── encnet_epoch_100.png ├── enet_epoch_100.png ├── fcn16s_epoch_200.png ├── fcn32s_epoch_300.png ├── fcn8s_epoch_100.png ├── icnet_epoch_100.png ├── ocnet_epoch_100.png └── psp_epoch_100.png ├── test_img.jpg ├── test_mask.png ├── test_model.py └── test_module.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | *.idea 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # pycharm 107 | 108 | # premodel 109 | weights/ 110 | *.pkl 111 | *.pth 112 | 113 | # dataset 114 | datasets/ 115 | VOCdevkit/ 116 | eval/ 117 | 118 | # overfitting test 119 | 120 | # run result 121 | 122 | # model 123 | /models/hrnet.py 124 | /models/psanet_old.py 125 | /scripts/debug.py 126 | 127 | # nn 128 | nn/sync_bn/ -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from . import nn, models, utils, data -------------------------------------------------------------------------------- /core/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/data/__init__.py -------------------------------------------------------------------------------- /core/data/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides data loaders and transformers for popular vision datasets. 3 | """ 4 | from .mscoco import COCOSegmentation 5 | from .cityscapes import CitySegmentation 6 | from .ade import ADE20KSegmentation 7 | from .pascal_voc import VOCSegmentation 8 | from .pascal_aug import VOCAugSegmentation 9 | from .sbu_shadow import SBUSegmentation 10 | 11 | datasets = { 12 | 'ade20k': ADE20KSegmentation, 13 | 'pascal_voc': VOCSegmentation, 14 | 'pascal_aug': VOCAugSegmentation, 15 | 'coco': COCOSegmentation, 16 | 'citys': CitySegmentation, 17 | 'sbu': SBUSegmentation, 18 | } 19 | 20 | 21 | def get_segmentation_dataset(name, **kwargs): 22 | """Segmentation Datasets""" 23 | return datasets[name.lower()](**kwargs) 24 | -------------------------------------------------------------------------------- /core/data/dataloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class CitySegmentation(SegmentationDataset): 11 | """Cityscapes Semantic Segmentation Dataset. 12 | 13 | Parameters 14 | ---------- 15 | root : string 16 | Path to Cityscapes folder.
Default is './datasets/citys' 17 | split: string 18 | 'train', 'val' or 'test' 19 | transform : callable, optional 20 | A function that transforms the image 21 | Examples 22 | -------- 23 | >>> from torchvision import transforms 24 | >>> import torch.utils.data as data 25 | >>> # Transforms for Normalization 26 | >>> input_transform = transforms.Compose([ 27 | >>> transforms.ToTensor(), 28 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), 29 | >>> ]) 30 | >>> # Create Dataset 31 | >>> trainset = CitySegmentation(split='train', transform=input_transform) 32 | >>> # Create Training Loader 33 | >>> train_data = data.DataLoader( 34 | >>> trainset, 4, shuffle=True, 35 | >>> num_workers=4) 36 | """ 37 | BASE_DIR = 'cityscapes' 38 | NUM_CLASS = 19 39 | 40 | def __init__(self, root='../datasets/citys', split='train', mode=None, transform=None, **kwargs): 41 | super(CitySegmentation, self).__init__(root, split, mode, transform, **kwargs) 42 | # self.root = os.path.join(root, self.BASE_DIR) 43 | assert os.path.exists(self.root), "Please download the dataset with core/data/downloader/cityscapes.py" 44 | self.images, self.mask_paths = _get_city_pairs(self.root, self.split) 45 | assert (len(self.images) == len(self.mask_paths)) 46 | if len(self.images) == 0: 47 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 48 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 49 | 23, 24, 25, 26, 27, 28, 31, 32, 33] 50 | self._key = np.array([-1, -1, -1, -1, -1, -1, 51 | -1, -1, 0, 1, -1, -1, 52 | 2, 3, 4, -1, -1, -1, 53 | 5, -1, 6, 7, 8, 9, 54 | 10, 11, 12, 13, 14, 15, 55 | -1, -1, 16, 17, 18]) 56 | self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') 57 | 58 | def _class_to_index(self, mask): 59 | # assert the value 60 | values = np.unique(mask) 61 | for value in values: 62 | assert (value in self._mapping) 63 | index = np.digitize(mask.ravel(), self._mapping, right=True) 64 | return self._key[index].reshape(mask.shape) 65 | 66 | def __getitem__(self, index): 67 | img = Image.open(self.images[index]).convert('RGB') 68 | if self.mode == 'test': 69 | if self.transform is not None: 70 | img = self.transform(img) 71 | return img, os.path.basename(self.images[index]) 72 | mask = Image.open(self.mask_paths[index]) 73 | # synchronized transform 74 | if self.mode == 'train': 75 | img, mask = self._sync_transform(img, mask) 76 | elif self.mode == 'val': 77 | img, mask = self._val_sync_transform(img, mask) 78 | else: 79 | assert self.mode == 'testval' 80 | img, mask = self._img_transform(img), self._mask_transform(mask) 81 | # general resize, normalize and toTensor 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | return img, mask, os.path.basename(self.images[index]) 85 | 86 | def _mask_transform(self, mask): 87 | target = self._class_to_index(np.array(mask).astype('int32')) 88 | return torch.LongTensor(np.array(target).astype('int32')) 89 | 90 | def __len__(self): 91 | return len(self.images) 92 | 93 | @property 94 | def pred_offset(self): 95 | return 0 96 | 97 | 98 | def _get_city_pairs(folder, split='train'): 99 | def get_path_pairs(img_folder, mask_folder): 100 | img_paths = [] 101 | mask_paths = [] 102 | for root, _, files in os.walk(img_folder): 103 | for filename in files: 104 | if filename.endswith('.png'): 105 | imgpath = os.path.join(root, filename) 106 | foldername = os.path.basename(os.path.dirname(imgpath)) 107 | maskname = filename.replace('leftImg8bit', 'gtFine_labelIds') 108 | maskpath = os.path.join(mask_folder, foldername,
maskname) 109 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 110 | img_paths.append(imgpath) 111 | mask_paths.append(maskpath) 112 | else: 113 | print('cannot find the mask or image:', imgpath, maskpath) 114 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 115 | return img_paths, mask_paths 116 | 117 | if split in ('train', 'val'): 118 | img_folder = os.path.join(folder, 'leftImg8bit/' + split) 119 | mask_folder = os.path.join(folder, 'gtFine/' + split) 120 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 121 | return img_paths, mask_paths 122 | else: 123 | assert split == 'trainval' 124 | print('trainval set') 125 | train_img_folder = os.path.join(folder, 'leftImg8bit/train') 126 | train_mask_folder = os.path.join(folder, 'gtFine/train') 127 | val_img_folder = os.path.join(folder, 'leftImg8bit/val') 128 | val_mask_folder = os.path.join(folder, 'gtFine/val') 129 | train_img_paths, train_mask_paths = get_path_pairs(train_img_folder, train_mask_folder) 130 | val_img_paths, val_mask_paths = get_path_pairs(val_img_folder, val_mask_folder) 131 | img_paths = train_img_paths + val_img_paths 132 | mask_paths = train_mask_paths + val_mask_paths 133 | return img_paths, mask_paths 134 | 135 | 136 | if __name__ == '__main__': 137 | dataset = CitySegmentation() 138 | -------------------------------------------------------------------------------- /core/data/dataloader/lip_parsing.py: -------------------------------------------------------------------------------- 1 | """Look into Person Dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from core.data.dataloader.segbase import SegmentationDataset 8 | 9 | 10 | class LIPSegmentation(SegmentationDataset): 11 | """Look into person parsing dataset """ 12 | 13 | BASE_DIR = 'LIP' 14 | NUM_CLASS = 20 15 | 16 | def __init__(self, root='../datasets/LIP', split='train', mode=None, transform=None, **kwargs): 17 | super(LIPSegmentation, self).__init__(root, split, mode, transform, **kwargs) 18 | _trainval_image_dir = os.path.join(root, 'TrainVal_images') 19 | _testing_image_dir = os.path.join(root, 'Testing_images') 20 | _trainval_mask_dir = os.path.join(root, 'TrainVal_parsing_annotations') 21 | if split == 'train': 22 | _image_dir = os.path.join(_trainval_image_dir, 'train_images') 23 | _mask_dir = os.path.join(_trainval_mask_dir, 'train_segmentations') 24 | _split_f = os.path.join(_trainval_image_dir, 'train_id.txt') 25 | elif split == 'val': 26 | _image_dir = os.path.join(_trainval_image_dir, 'val_images') 27 | _mask_dir = os.path.join(_trainval_mask_dir, 'val_segmentations') 28 | _split_f = os.path.join(_trainval_image_dir, 'val_id.txt') 29 | elif split == 'test': 30 | _image_dir = os.path.join(_testing_image_dir, 'testing_images') 31 | _split_f = os.path.join(_testing_image_dir, 'test_id.txt') 32 | else: 33 | raise RuntimeError('Unknown dataset split.') 34 | 35 | self.images = [] 36 | self.masks = [] 37 | with open(os.path.join(_split_f), 'r') as lines: 38 | for line in lines: 39 | _image = os.path.join(_image_dir, line.rstrip('\n') + '.jpg') 40 | assert os.path.isfile(_image) 41 | self.images.append(_image) 42 | if split != 'test': 43 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + '.png') 44 | assert os.path.isfile(_mask) 45 | self.masks.append(_mask) 46 | 47 | if split != 'test': 48 | assert (len(self.images) == len(self.masks)) 49 | print('Found {} {} images in the folder {}'.format(len(self.images), split, root)) 50 | 51 | def 
__getitem__(self, index): 52 | img = Image.open(self.images[index]).convert('RGB') 53 | if self.mode == 'test': 54 | img = self._img_transform(img) 55 | if self.transform is not None: 56 | img = self.transform(img) 57 | return img, os.path.basename(self.images[index]) 58 | mask = Image.open(self.masks[index]) 59 | # synchronized transform 60 | if self.mode == 'train': 61 | img, mask = self._sync_transform(img, mask) 62 | elif self.mode == 'val': 63 | img, mask = self._val_sync_transform(img, mask) 64 | else: 65 | assert self.mode == 'testval' 66 | img, mask = self._img_transform(img), self._mask_transform(mask) 67 | # general resize, normalize and toTensor 68 | if self.transform is not None: 69 | img = self.transform(img) 70 | 71 | return img, mask, os.path.basename(self.images[index]) 72 | 73 | def __len__(self): 74 | return len(self.images) 75 | 76 | def _mask_transform(self, mask): 77 | target = np.array(mask).astype('int32') 78 | return torch.from_numpy(target).long() 79 | 80 | @property 81 | def classes(self): 82 | """Category names.""" 83 | return ('background', 'hat', 'hair', 'glove', 'sunglasses', 'upperclothes', 84 | 'dress', 'coat', 'socks', 'pants', 'jumpsuits', 'scarf', 'skirt', 85 | 'face', 'leftArm', 'rightArm', 'leftLeg', 'rightLeg', 'leftShoe', 86 | 'rightShoe') 87 | 88 | 89 | if __name__ == '__main__': 90 | dataset = LIPSegmentation(base_size=280, crop_size=256) -------------------------------------------------------------------------------- /core/data/dataloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """MSCOCO Semantic Segmentation pretraining for VOC.""" 2 | import os 3 | import pickle 4 | import torch 5 | import numpy as np 6 | 7 | from tqdm import trange 8 | from PIL import Image 9 | from .segbase import SegmentationDataset 10 | 11 | 12 | class COCOSegmentation(SegmentationDataset): 13 | """COCO Semantic Segmentation Dataset for VOC Pre-training. 14 | 15 | Parameters 16 | ---------- 17 | root : string 18 | Path to COCO folder.
Default is './datasets/coco' 19 | split: string 20 | 'train', 'val' or 'test' 21 | transform : callable, optional 22 | A function that transforms the image 23 | Examples 24 | -------- 25 | >>> from torchvision import transforms 26 | >>> import torch.utils.data as data 27 | >>> # Transforms for Normalization 28 | >>> input_transform = transforms.Compose([ 29 | >>> transforms.ToTensor(), 30 | >>> transforms.Normalize((.485, .456, .406), (.229, .224, .225)), 31 | >>> ]) 32 | >>> # Create Dataset 33 | >>> trainset = COCOSegmentation(split='train', transform=input_transform) 34 | >>> # Create Training Loader 35 | >>> train_data = data.DataLoader( 36 | >>> trainset, 4, shuffle=True, 37 | >>> num_workers=4) 38 | """ 39 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 40 | 1, 64, 20, 63, 7, 72] 41 | NUM_CLASS = 21 42 | 43 | def __init__(self, root='../datasets/coco', split='train', mode=None, transform=None, **kwargs): 44 | super(COCOSegmentation, self).__init__(root, split, mode, transform, **kwargs) 45 | # lazy import pycocotools 46 | from pycocotools.coco import COCO 47 | from pycocotools import mask 48 | if split == 'train': 49 | print('train set') 50 | ann_file = os.path.join(root, 'annotations/instances_train2017.json') 51 | ids_file = os.path.join(root, 'annotations/train_ids.mx') 52 | self.root = os.path.join(root, 'train2017') 53 | else: 54 | print('val set') 55 | ann_file = os.path.join(root, 'annotations/instances_val2017.json') 56 | ids_file = os.path.join(root, 'annotations/val_ids.mx') 57 | self.root = os.path.join(root, 'val2017') 58 | self.coco = COCO(ann_file) 59 | self.coco_mask = mask 60 | if os.path.exists(ids_file): 61 | with open(ids_file, 'rb') as f: 62 | self.ids = pickle.load(f) 63 | else: 64 | ids = list(self.coco.imgs.keys()) 65 | self.ids = self._preprocess(ids, ids_file) 66 | self.transform = transform 67 | 68 | def __getitem__(self, index): 69 | coco = self.coco 70 | img_id = self.ids[index] 71 | img_metadata = coco.loadImgs(img_id)[0] 72 | path = img_metadata['file_name'] 73 | img = Image.open(os.path.join(self.root, path)).convert('RGB') 74 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) 75 | mask = Image.fromarray(self._gen_seg_mask( 76 | cocotarget, img_metadata['height'], img_metadata['width'])) 77 | # synchronized transform 78 | if self.mode == 'train': 79 | img, mask = self._sync_transform(img, mask) 80 | elif self.mode == 'val': 81 | img, mask = self._val_sync_transform(img, mask) 82 | else: 83 | assert self.mode == 'testval' 84 | img, mask = self._img_transform(img), self._mask_transform(mask) 85 | # general resize, normalize and toTensor 86 | if self.transform is not None: 87 | img = self.transform(img) 88 | return img, mask, os.path.basename(path) 89 | 90 | def _mask_transform(self, mask): 91 | return torch.LongTensor(np.array(mask).astype('int32')) 92 | 93 | def _gen_seg_mask(self, target, h, w): 94 | mask = np.zeros((h, w), dtype=np.uint8) 95 | coco_mask = self.coco_mask 96 | for instance in target: 97 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 98 | m = coco_mask.decode(rle) 99 | cat = instance['category_id'] 100 | if cat in self.CAT_LIST: 101 | c = self.CAT_LIST.index(cat) 102 | else: 103 | continue 104 | if len(m.shape) < 3: 105 | mask[:, :] += (mask == 0) * (m * c) 106 | else: 107 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 108 | return mask 109 |
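    # How _gen_seg_mask above builds the target: each annotation is rasterized
    # with pycocotools, and COCO category ids are remapped to the 21 VOC-style
    # labels through CAT_LIST (the index of a category id in CAT_LIST is its
    # output label). The `(mask == 0)` factor means a pixel keeps the label of
    # the first instance drawn onto it instead of accumulating values.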
110 | def _preprocess(self, ids, ids_file): 111 | print("Preprocessing mask, this will take a while." + \ 112 | "But don't worry, it only runs once for each split.") 113 | tbar = trange(len(ids)) 114 | new_ids = [] 115 | for i in tbar: 116 | img_id = ids[i] 117 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 118 | img_metadata = self.coco.loadImgs(img_id)[0] 119 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], img_metadata['width']) 120 | # more than 1k pixels 121 | if (mask > 0).sum() > 1000: 122 | new_ids.append(img_id) 123 | tbar.set_description('Doing: {}/{}, got {} qualified images'. \ 124 | format(i, len(ids), len(new_ids))) 125 | print('Found number of qualified images: ', len(new_ids)) 126 | with open(ids_file, 'wb') as f: 127 | pickle.dump(new_ids, f) 128 | return new_ids 129 | 130 | @property 131 | def classes(self): 132 | """Category names.""" 133 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 134 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 135 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 136 | 'tv') 137 | -------------------------------------------------------------------------------- /core/data/dataloader/pascal_aug.py: -------------------------------------------------------------------------------- 1 | """Pascal Augmented VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import scipy.io as sio 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from .segbase import SegmentationDataset 9 | 10 | 11 | class VOCAugSegmentation(SegmentationDataset): 12 | """Pascal VOC Augmented Semantic Segmentation Dataset. 13 | 14 | Parameters 15 | ---------- 16 | root : string 17 | Path to VOCdevkit folder. Default is './datasets/voc' 18 | split: string 19 | 'train', 'val' or 'test' 20 | transform : callable, optional 21 | A function that transforms the image 22 | Examples 23 | -------- 24 | >>> from torchvision import transforms 25 | >>> import torch.utils.data as data 26 | >>> # Transforms for Normalization 27 | >>> input_transform = transforms.Compose([ 28 | >>> transforms.ToTensor(), 29 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 30 | >>> ]) 31 | >>> # Create Dataset 32 | >>> trainset = VOCAugSegmentation(split='train', transform=input_transform) 33 | >>> # Create Training Loader 34 | >>> train_data = data.DataLoader( 35 | >>> trainset, 4, shuffle=True, 36 | >>> num_workers=4) 37 | """ 38 | BASE_DIR = 'VOCaug/dataset/' 39 | NUM_CLASS = 21 40 | 41 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs): 42 | super(VOCAugSegmentation, self).__init__(root, split, mode, transform, **kwargs) 43 | # train/val/test splits are pre-cut 44 | _voc_root = os.path.join(root, self.BASE_DIR) 45 | _mask_dir = os.path.join(_voc_root, 'cls') 46 | _image_dir = os.path.join(_voc_root, 'img') 47 | if split == 'train': 48 | _split_f = os.path.join(_voc_root, 'trainval.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_voc_root, 'val.txt') 51 | else: 52 | raise RuntimeError('Unknown dataset split: {}'.format(split)) 53 | 54 | self.images = [] 55 | self.masks = [] 56 | with open(os.path.join(_split_f), "r") as lines: 57 | for line in lines: 58 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 59 | assert os.path.isfile(_image) 60 | self.images.append(_image) 61 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".mat") 62 | assert os.path.isfile(_mask) 63 | self.masks.append(_mask) 64 | 65 | assert (len(self.images) == len(self.masks)) 66 | print('Found {} images in the folder
{}'.format(len(self.images), _voc_root)) 67 | 68 | def __getitem__(self, index): 69 | img = Image.open(self.images[index]).convert('RGB') 70 | target = self._load_mat(self.masks[index]) 71 | # synchronized transform 72 | if self.mode == 'train': 73 | img, target = self._sync_transform(img, target) 74 | elif self.mode == 'val': 75 | img, target = self._val_sync_transform(img, target) 76 | else: 77 | raise RuntimeError('unknown mode for dataloader: {}'.format(self.mode)) 78 | # general resize, normalize and toTensor 79 | if self.transform is not None: 80 | img = self.transform(img) 81 | return img, target, os.path.basename(self.images[index]) 82 | 83 | def _mask_transform(self, mask): 84 | return torch.LongTensor(np.array(mask).astype('int32')) 85 | 86 | def _load_mat(self, filename): 87 | mat = sio.loadmat(filename, mat_dtype=True, squeeze_me=True, struct_as_record=False) 88 | mask = mat['GTcls'].Segmentation 89 | return Image.fromarray(mask) 90 | 91 | def __len__(self): 92 | return len(self.images) 93 | 94 | @property 95 | def classes(self): 96 | """Category names.""" 97 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 98 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 99 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 100 | 'tv') 101 | 102 | 103 | if __name__ == '__main__': 104 | dataset = VOCAugSegmentation() -------------------------------------------------------------------------------- /core/data/dataloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Pascal VOC Semantic Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class VOCSegmentation(SegmentationDataset): 11 | """Pascal VOC Semantic Segmentation Dataset. 12 | 13 | Parameters 14 | ---------- 15 | root : string 16 | Path to VOCdevkit folder.
Default is './datasets/VOCdevkit' 17 | split: string 18 | 'train', 'val' or 'test' 19 | transform : callable, optional 20 | A function that transforms the image 21 | Examples 22 | -------- 23 | >>> from torchvision import transforms 24 | >>> import torch.utils.data as data 25 | >>> # Transforms for Normalization 26 | >>> input_transform = transforms.Compose([ 27 | >>> transforms.ToTensor(), 28 | >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 29 | >>> ]) 30 | >>> # Create Dataset 31 | >>> trainset = VOCSegmentation(split='train', transform=input_transform) 32 | >>> # Create Training Loader 33 | >>> train_data = data.DataLoader( 34 | >>> trainset, 4, shuffle=True, 35 | >>> num_workers=4) 36 | """ 37 | BASE_DIR = 'VOC2012' 38 | NUM_CLASS = 21 39 | 40 | def __init__(self, root='../datasets/voc', split='train', mode=None, transform=None, **kwargs): 41 | super(VOCSegmentation, self).__init__(root, split, mode, transform, **kwargs) 42 | _voc_root = os.path.join(root, self.BASE_DIR) 43 | _mask_dir = os.path.join(_voc_root, 'SegmentationClass') 44 | _image_dir = os.path.join(_voc_root, 'JPEGImages') 45 | # train/val/test splits are pre-cut 46 | _splits_dir = os.path.join(_voc_root, 'ImageSets/Segmentation') 47 | if split == 'train': 48 | _split_f = os.path.join(_splits_dir, 'train.txt') 49 | elif split == 'val': 50 | _split_f = os.path.join(_splits_dir, 'val.txt') 51 | elif split == 'test': 52 | _split_f = os.path.join(_splits_dir, 'test.txt') 53 | else: 54 | raise RuntimeError('Unknown dataset split.') 55 | 56 | self.images = [] 57 | self.masks = [] 58 | with open(os.path.join(_split_f), "r") as lines: 59 | for line in lines: 60 | _image = os.path.join(_image_dir, line.rstrip('\n') + ".jpg") 61 | assert os.path.isfile(_image) 62 | self.images.append(_image) 63 | if split != 'test': 64 | _mask = os.path.join(_mask_dir, line.rstrip('\n') + ".png") 65 | assert os.path.isfile(_mask) 66 | self.masks.append(_mask) 67 | 68 | if split != 'test': 69 | assert (len(self.images) == len(self.masks)) 70 | print('Found {} images in the folder {}'.format(len(self.images), _voc_root)) 71 | 72 | def __getitem__(self, index): 73 | img = Image.open(self.images[index]).convert('RGB') 74 | if self.mode == 'test': 75 | img = self._img_transform(img) 76 | if self.transform is not None: 77 | img = self.transform(img) 78 | return img, os.path.basename(self.images[index]) 79 | mask = Image.open(self.masks[index]) 80 | # synchronized transform 81 | if self.mode == 'train': 82 | img, mask = self._sync_transform(img, mask) 83 | elif self.mode == 'val': 84 | img, mask = self._val_sync_transform(img, mask) 85 | else: 86 | assert self.mode == 'testval' 87 | img, mask = self._img_transform(img), self._mask_transform(mask) 88 | # general resize, normalize and toTensor 89 | if self.transform is not None: 90 | img = self.transform(img) 91 | 92 | return img, mask, os.path.basename(self.images[index]) 93 | 94 | def __len__(self): 95 | return len(self.images) 96 | 97 | def _mask_transform(self, mask): 98 | target = np.array(mask).astype('int32') 99 | target[target == 255] = -1 100 | return torch.from_numpy(target).long() 101 | 102 | @property 103 | def classes(self): 104 | """Category names.""" 105 | return ('background', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 106 | 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 107 | 'motorcycle', 'person', 'potted-plant', 'sheep', 'sofa', 'train', 108 | 'tv') 109 | 110 | 111 | if __name__ == '__main__': 112 | dataset = VOCSegmentation() 
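Putting the loaders above together, here is a minimal end-to-end sketch of the documented API (it follows the docstring examples and assumes the VOC data is already prepared under the default root '../datasets/voc'):

    import torch.utils.data as data
    from torchvision import transforms
    from core.data.dataloader import get_segmentation_dataset

    # normalization transform, as in the docstring examples
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    # 'pascal_voc' is a key of the `datasets` registry in dataloader/__init__.py;
    # extra kwargs such as base_size/crop_size are forwarded to SegmentationDataset
    trainset = get_segmentation_dataset('pascal_voc', split='train', mode='train',
                                        transform=input_transform)
    train_loader = data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)
    for img, mask, filename in train_loader:
        pass  # img: normalized float tensor, mask: LongTensor of class indices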
-------------------------------------------------------------------------------- /core/data/dataloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """SBU Shadow Segmentation Dataset.""" 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from PIL import Image 7 | from .segbase import SegmentationDataset 8 | 9 | 10 | class SBUSegmentation(SegmentationDataset): 11 | """SBU Shadow Segmentation Dataset 12 | """ 13 | NUM_CLASS = 2 14 | 15 | def __init__(self, root='../datasets/sbu', split='train', mode=None, transform=None, **kwargs): 16 | super(SBUSegmentation, self).__init__(root, split, mode, transform, **kwargs) 17 | assert os.path.exists(self.root) 18 | self.images, self.masks = _get_sbu_pairs(self.root, self.split) 19 | assert (len(self.images) == len(self.masks)) 20 | if len(self.images) == 0: 21 | raise RuntimeError("Found 0 images in subfolders of:" + root + "\n") 22 | 23 | def __getitem__(self, index): 24 | img = Image.open(self.images[index]).convert('RGB') 25 | if self.mode == 'test': 26 | if self.transform is not None: 27 | img = self.transform(img) 28 | return img, os.path.basename(self.images[index]) 29 | mask = Image.open(self.masks[index]) 30 | # synchronized transform 31 | if self.mode == 'train': 32 | img, mask = self._sync_transform(img, mask) 33 | elif self.mode == 'val': 34 | img, mask = self._val_sync_transform(img, mask) 35 | else: 36 | assert self.mode == 'testval' 37 | img, mask = self._img_transform(img), self._mask_transform(mask) 38 | # general resize, normalize and toTensor 39 | if self.transform is not None: 40 | img = self.transform(img) 41 | return img, mask, os.path.basename(self.images[index]) 42 | 43 | def _mask_transform(self, mask): 44 | target = np.array(mask).astype('int32') 45 | target[target > 0] = 1 46 | return torch.from_numpy(target).long() 47 | 48 | def __len__(self): 49 | return len(self.images) 50 | 51 | @property 52 | def pred_offset(self): 53 | return 0 54 | 55 | 56 | def _get_sbu_pairs(folder, split='train'): 57 | def get_path_pairs(img_folder, mask_folder): 58 | img_paths = [] 59 | mask_paths = [] 60 | for root, _, files in os.walk(img_folder): 61 | print(root) 62 | for filename in files: 63 | if filename.endswith('.jpg'): 64 | imgpath = os.path.join(root, filename) 65 | maskname = filename.replace('.jpg', '.png') 66 | maskpath = os.path.join(mask_folder, maskname) 67 | if os.path.isfile(imgpath) and os.path.isfile(maskpath): 68 | img_paths.append(imgpath) 69 | mask_paths.append(maskpath) 70 | else: 71 | print('cannot find the mask or image:', imgpath, maskpath) 72 | print('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 73 | return img_paths, mask_paths 74 | 75 | if split == 'train': 76 | img_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowImages') 77 | mask_folder = os.path.join(folder, 'SBUTrain4KRecoveredSmall/ShadowMasks') 78 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 79 | else: 80 | assert split in ('val', 'test') 81 | img_folder = os.path.join(folder, 'SBU-Test/ShadowImages') 82 | mask_folder = os.path.join(folder, 'SBU-Test/ShadowMasks') 83 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 84 | return img_paths, mask_paths 85 | 86 | 87 | if __name__ == '__main__': 88 | dataset = SBUSegmentation(base_size=280, crop_size=256) -------------------------------------------------------------------------------- /core/data/dataloader/segbase.py:
-------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | __all__ = ['SegmentationDataset'] 8 | 9 | 10 | class SegmentationDataset(object): 11 | """Segmentation Base Dataset""" 12 | 13 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 14 | super(SegmentationDataset, self).__init__() 15 | self.root = root 16 | self.transform = transform 17 | self.split = split 18 | self.mode = mode if mode is not None else split 19 | self.base_size = base_size 20 | self.crop_size = crop_size 21 | 22 | def _val_sync_transform(self, img, mask): 23 | outsize = self.crop_size 24 | short_size = outsize 25 | w, h = img.size 26 | if w > h: 27 | oh = short_size 28 | ow = int(1.0 * w * oh / h) 29 | else: 30 | ow = short_size 31 | oh = int(1.0 * h * ow / w) 32 | img = img.resize((ow, oh), Image.BILINEAR) 33 | mask = mask.resize((ow, oh), Image.NEAREST) 34 | # center crop 35 | w, h = img.size 36 | x1 = int(round((w - outsize) / 2.)) 37 | y1 = int(round((h - outsize) / 2.)) 38 | img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) 39 | mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) 40 | # final transform 41 | img, mask = self._img_transform(img), self._mask_transform(mask) 42 | return img, mask 43 | 44 | def _sync_transform(self, img, mask): 45 | # random mirror 46 | if random.random() < 0.5: 47 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 48 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 49 | crop_size = self.crop_size 50 | # random scale (short edge) 51 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 52 | w, h = img.size 53 | if h > w: 54 | ow = short_size 55 | oh = int(1.0 * h * ow / w) 56 | else: 57 | oh = short_size 58 | ow = int(1.0 * w * oh / h) 59 | img = img.resize((ow, oh), Image.BILINEAR) 60 | mask = mask.resize((ow, oh), Image.NEAREST) 61 | # pad crop 62 | if short_size < crop_size: 63 | padh = crop_size - oh if oh < crop_size else 0 64 | padw = crop_size - ow if ow < crop_size else 0 65 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 66 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) 67 | # random crop crop_size 68 | w, h = img.size 69 | x1 = random.randint(0, w - crop_size) 70 | y1 = random.randint(0, h - crop_size) 71 | img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 72 | mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) 73 | # gaussian blur as in PSP 74 | if random.random() < 0.5: 75 | img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) 76 | # final transform 77 | img, mask = self._img_transform(img), self._mask_transform(mask) 78 | return img, mask 79 | 80 | def _img_transform(self, img): 81 | return np.array(img) 82 | 83 | def _mask_transform(self, mask): 84 | return np.array(mask).astype('int32') 85 | 86 | @property 87 | def num_class(self): 88 | """Number of categories.""" 89 | return self.NUM_CLASS 90 | 91 | @property 92 | def pred_offset(self): 93 | return 0 94 | -------------------------------------------------------------------------------- /core/data/dataloader/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import errno 4 | import tarfile 5 | from six.moves import urllib 6 | from torch.utils.model_zoo import tqdm 7 | 8 | def gen_bar_updater(): 9 | pbar = tqdm(total=None) 10 | 11 | def bar_update(count, block_size, 
total_size): 12 | if pbar.total is None and total_size: 13 | pbar.total = total_size 14 | progress_bytes = count * block_size 15 | pbar.update(progress_bytes - pbar.n) 16 | 17 | return bar_update 18 | 19 | def check_integrity(fpath, md5=None): 20 | if md5 is None: 21 | return True 22 | if not os.path.isfile(fpath): 23 | return False 24 | md5o = hashlib.md5() 25 | with open(fpath, 'rb') as f: 26 | # read in 1MB chunks 27 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 28 | md5o.update(chunk) 29 | md5c = md5o.hexdigest() 30 | if md5c != md5: 31 | return False 32 | return True 33 | 34 | def makedir_exist_ok(dirpath): 35 | try: 36 | os.makedirs(dirpath) 37 | except OSError as e: 38 | if e.errno == errno.EEXIST: 39 | pass 40 | else: 41 | raise 42 | 43 | def download_url(url, root, filename=None, md5=None): 44 | """Download a file from a url and place it in root.""" 45 | root = os.path.expanduser(root) 46 | if not filename: 47 | filename = os.path.basename(url) 48 | fpath = os.path.join(root, filename) 49 | 50 | makedir_exist_ok(root) 51 | 52 | # downloads file 53 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 54 | print('Using downloaded and verified file: ' + fpath) 55 | else: 56 | try: 57 | print('Downloading ' + url + ' to ' + fpath) 58 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 59 | except OSError: 60 | if url[:5] == 'https': 61 | url = url.replace('https:', 'http:') 62 | print('Failed download. Trying https -> http instead.' 63 | ' Downloading ' + url + ' to ' + fpath) 64 | urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater()) 65 | 66 | def download_extract(url, root, filename, md5): 67 | download_url(url, root, filename, md5) 68 | with tarfile.open(os.path.join(root, filename), "r") as tar: 69 | tar.extractall(path=root) -------------------------------------------------------------------------------- /core/data/downloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/data/downloader/__init__.py -------------------------------------------------------------------------------- /core/data/downloader/ade20k.py: -------------------------------------------------------------------------------- 1 | """Prepare ADE20K dataset""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/ade') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize ADE20K dataset.', 20 | epilog='Example: python ade20k.py', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def download_ade(path, overwrite=False): 28 | _AUG_DOWNLOAD_URLS = [ 29 | ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', 30 | '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'), 31 | ( 32 | 'http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 33 | 'e05747892219d10e9243933371a497e905a4860c'), ] 34 | download_dir = os.path.join(path, 'downloads') 35 |
makedirs(download_dir) 36 | for url, checksum in _AUG_DOWNLOAD_URLS: 37 | filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum) 38 | # extract 39 | with zipfile.ZipFile(filename, "r") as zip_ref: 40 | zip_ref.extractall(path=path) 41 | 42 | 43 | if __name__ == '__main__': 44 | args = parse_args() 45 | makedirs(os.path.expanduser('~/.torch/datasets')) 46 | if args.download_dir is not None: 47 | if os.path.isdir(_TARGET_DIR): 48 | os.remove(_TARGET_DIR) 49 | # make symlink 50 | os.symlink(args.download_dir, _TARGET_DIR) 51 | download_ade(_TARGET_DIR, overwrite=False) 52 | -------------------------------------------------------------------------------- /core/data/downloader/cityscapes.py: -------------------------------------------------------------------------------- 1 | """Prepare Cityscapes dataset""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs, check_sha1 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/citys') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize Cityscapes dataset.', 20 | epilog='Example: python cityscapes.py', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', default=None, help='dataset directory on disk') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def download_city(path, overwrite=False): 28 | _CITY_DOWNLOAD_URLS = [ 29 | ('gtFine_trainvaltest.zip', '99f532cb1af174f5fcc4c5bc8feea8c66246ddbc'), 30 | ('leftImg8bit_trainvaltest.zip', '2c0b77ce9933cc635adda307fbba5566f5d9d404')] 31 | download_dir = os.path.join(path, 'downloads') 32 | makedirs(download_dir) 33 | for filename, checksum in _CITY_DOWNLOAD_URLS: 34 | if not check_sha1(os.path.join(download_dir, filename), checksum): 35 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 36 | 'The repo may be outdated or download may be incomplete. ' \ 37 | 'If the "repo_url" is overridden, consider switching to ' \ 38 | 'the default repo.'.format(filename)) 39 | # extract 40 | with zipfile.ZipFile(os.path.join(download_dir, filename), "r") as zip_ref: 41 | zip_ref.extractall(path=path) 42 | print("Extracted", filename) 43 | 44 |
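# Note: unlike the other prepare scripts, download_city lists only filenames and
# checksums -- the Cityscapes archives require a registered account to download,
# so gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip are expected to be
# placed in the 'downloads' folder by hand; the script then verifies their SHA-1
# hashes and extracts them.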
45 | if __name__ == '__main__': 46 | args = parse_args() 47 | makedirs(os.path.expanduser('~/.torch/datasets')) 48 | if args.download_dir is not None: 49 | if os.path.isdir(_TARGET_DIR): 50 | os.remove(_TARGET_DIR) 51 | # make symlink 52 | os.symlink(args.download_dir, _TARGET_DIR) 53 | else: 54 | download_city(_TARGET_DIR, overwrite=False) 55 | -------------------------------------------------------------------------------- /core/data/downloader/mscoco.py: -------------------------------------------------------------------------------- 1 | """Prepare MS COCO datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs, try_import_pycocotools 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/coco') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize MS COCO dataset.', 20 | epilog='Example: python mscoco.py --download-dir ~/mscoco', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', type=str, default='~/mscoco/', help='dataset directory on disk') 23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 24 | parser.add_argument('--overwrite', action='store_true', 25 | help='overwrite downloaded files if set, in case they are corrupted') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def download_coco(path, overwrite=False): 31 | _DOWNLOAD_URLS = [ 32 | ('http://images.cocodataset.org/zips/train2017.zip', 33 | '10ad623668ab00c62c096f0ed636d6aff41faca5'), 34 | ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip', 35 | '8551ee4bb5860311e79dace7e79cb91e432e78b3'), 36 | ('http://images.cocodataset.org/zips/val2017.zip', 37 | '4950dc9d00dbe1c933ee0170f5797584351d2a41'), 38 | # ('http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip', 39 | # '46cdcf715b6b4f67e980b529534e79c2edffe084'), 40 | # test2017.zip, for those who want to attend the competition. 41 | # ('http://images.cocodataset.org/zips/test2017.zip', 42 | # '4e443f8a2eca6b1dac8a6c57641b67dd40621a49'), 43 | ] 44 | makedirs(path) 45 | for url, checksum in _DOWNLOAD_URLS: 46 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 47 | # extract 48 | with zipfile.ZipFile(filename) as zf: 49 | zf.extractall(path=path) 50 | 51 | 52 | if __name__ == '__main__': 53 | args = parse_args() 54 | path = os.path.expanduser(args.download_dir) 55 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'train2017')) \ 56 | or not os.path.isdir(os.path.join(path, 'val2017')) \ 57 | or not os.path.isdir(os.path.join(path, 'annotations')): 58 | if args.no_download: 59 | raise ValueError(('{} is not a valid directory, make sure it is present.'
60 | ' Or run without --no-download to download it automatically'.format(path))) 61 | else: 62 | download_coco(path, overwrite=args.overwrite) 63 | 64 | # make symlink 65 | makedirs(os.path.expanduser('~/.torch/datasets')) 66 | if os.path.isdir(_TARGET_DIR): 67 | os.remove(_TARGET_DIR) 68 | os.symlink(path, _TARGET_DIR) 69 | try_import_pycocotools() 70 | -------------------------------------------------------------------------------- /core/data/downloader/pascal_voc.py: -------------------------------------------------------------------------------- 1 | """Prepare PASCAL VOC datasets""" 2 | import os 3 | import sys 4 | import shutil 5 | import argparse 6 | import tarfile 7 | 8 | # TODO: optim code 9 | cur_path = os.path.abspath(os.path.dirname(__file__)) 10 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 11 | sys.path.append(root_path) 12 | 13 | from core.utils import download, makedirs 14 | 15 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/voc') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='Initialize PASCAL VOC dataset.', 21 | epilog='Example: python pascal_voc.py --download-dir ~/VOCdevkit', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | parser.add_argument('--download-dir', type=str, default='~/VOCdevkit/', help='dataset directory on disk') 24 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 25 | parser.add_argument('--overwrite', action='store_true', 26 | help='overwrite downloaded files if set, in case they are corrupted') 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | ##################################################################################### 32 | # Download and extract VOC datasets into ``path`` 33 | 34 | def download_voc(path, overwrite=False): 35 | _DOWNLOAD_URLS = [ 36 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', 37 | '34ed68851bce2a36e2a223fa52c661d592c66b3c'), 38 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', 39 | '41a8d6e12baa5ab18ee7f8f8029b9e11805b4ef1'), 40 | ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', 41 | '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')] 42 | makedirs(path) 43 | for url, checksum in _DOWNLOAD_URLS: 44 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 45 | # extract 46 | with tarfile.open(filename) as tar: 47 | tar.extractall(path=path) 48 | 49 | 50 | ##################################################################################### 51 | # Download and extract the VOC augmented segmentation dataset into ``path`` 52 | 53 | def download_aug(path, overwrite=False): 54 | _AUG_DOWNLOAD_URLS = [ 55 | ('http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz', 56 | '7129e0a480c2d6afb02b517bb18ac54283bfaa35')] 57 | makedirs(path) 58 | for url, checksum in _AUG_DOWNLOAD_URLS: 59 | filename = download(url, path=path, overwrite=overwrite, sha1_hash=checksum) 60 | # extract 61 | with tarfile.open(filename) as tar: 62 | tar.extractall(path=path) 63 | shutil.move(os.path.join(path, 'benchmark_RELEASE'), 64 | os.path.join(path, 'VOCaug')) 65 | filenames = ['VOCaug/dataset/train.txt', 'VOCaug/dataset/val.txt'] 66 | # generate trainval.txt 67 | with open(os.path.join(path, 'VOCaug/dataset/trainval.txt'), 'w') as outfile: 68 | for fname in filenames: 69 | fname = os.path.join(path, fname) 70 | with open(fname) as infile: 71 | for line in infile: 72 | outfile.write(line) 73 | 74 |
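# Note: the 'VOCaug/dataset/trainval.txt' file generated above is exactly what
# VOCAugSegmentation (core/data/dataloader/pascal_aug.py) reads for its 'train'
# split, so download_aug() must have run before that dataset is instantiated.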
75 | if __name__ == '__main__': 76 | args = parse_args() 77 | path = os.path.expanduser(args.download_dir) 78 | if not os.path.isdir(path) or not os.path.isdir(os.path.join(path, 'VOC2007')) \ 79 | or not os.path.isdir(os.path.join(path, 'VOC2012')): 80 | if args.no_download: 81 | raise ValueError(('{} is not a valid directory, make sure it is present.' 82 | ' Or run without --no-download to download it automatically'.format(path))) 83 | else: 84 | download_voc(path, overwrite=args.overwrite) 85 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2007'), os.path.join(path, 'VOC2007')) 86 | shutil.move(os.path.join(path, 'VOCdevkit', 'VOC2012'), os.path.join(path, 'VOC2012')) 87 | shutil.rmtree(os.path.join(path, 'VOCdevkit')) 88 | 89 | if not os.path.isdir(os.path.join(path, 'VOCaug')): 90 | if args.no_download: 91 | raise ValueError(('{} is not a valid directory, make sure it is present.' 92 | ' Or run without --no-download to download it automatically'.format(path))) 93 | else: 94 | download_aug(path, overwrite=args.overwrite) 95 | 96 | # make symlink 97 | makedirs(os.path.expanduser('~/.torch/datasets')) 98 | if os.path.isdir(_TARGET_DIR): 99 | os.remove(_TARGET_DIR) 100 | os.symlink(path, _TARGET_DIR) 101 | -------------------------------------------------------------------------------- /core/data/downloader/sbu_shadow.py: -------------------------------------------------------------------------------- 1 | """Prepare SBU Shadow datasets""" 2 | import os 3 | import sys 4 | import argparse 5 | import zipfile 6 | 7 | # TODO: optim code 8 | cur_path = os.path.abspath(os.path.dirname(__file__)) 9 | root_path = os.path.split(os.path.split(os.path.split(cur_path)[0])[0])[0] 10 | sys.path.append(root_path) 11 | 12 | from core.utils import download, makedirs 13 | 14 | _TARGET_DIR = os.path.expanduser('~/.torch/datasets/sbu') 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description='Initialize SBU Shadow dataset.', 20 | epilog='Example: python sbu_shadow.py --download-dir ~/SBU-shadow', 21 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 22 | parser.add_argument('--download-dir', type=str, default=None, help='dataset directory on disk') 23 | parser.add_argument('--no-download', action='store_true', help='disable automatic download if set') 24 | parser.add_argument('--overwrite', action='store_true', 25 | help='overwrite downloaded files if set, in case they are corrupted') 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | ##################################################################################### 31 | # Download and extract SBU shadow datasets into ``path`` 32 | 33 | def download_sbu(path, overwrite=False): 34 | _DOWNLOAD_URLS = [ 35 | ('http://www3.cs.stonybrook.edu/~cvl/content/datasets/shadow_db/SBU-shadow.zip'), 36 | ] 37 | download_dir = os.path.join(path, 'downloads') 38 | makedirs(download_dir) 39 | for url in _DOWNLOAD_URLS: 40 | filename = download(url, path=path, overwrite=overwrite) 41 | # extract 42 | with zipfile.ZipFile(filename, "r") as zf: 43 | zf.extractall(path=path) 44 | print("Extracted", filename) 45 | 46 | 47 | if __name__ == '__main__': 48 | args = parse_args() 49 | makedirs(os.path.expanduser('~/.torch/datasets')) 50 | if args.download_dir is not None: 51 | if os.path.isdir(_TARGET_DIR): 52 | os.remove(_TARGET_DIR) 53 | # make symlink 54 | os.symlink(args.download_dir, _TARGET_DIR) 55 | else: 56 | download_sbu(_TARGET_DIR, overwrite=False) 57 |
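For reference, the prepare scripts above are standalone; typical invocations (paths are only examples, mirroring each script's epilog) are:

    python core/data/downloader/ade20k.py
    python core/data/downloader/pascal_voc.py --download-dir ~/VOCdevkit
    python core/data/downloader/mscoco.py --download-dir ~/mscoco
    python core/data/downloader/sbu_shadow.py --download-dir ~/SBU-shadow

Each script downloads (or, for Cityscapes, verifies manually placed) archives, extracts them, and symlinks the data directory into ~/.torch/datasets. Note that the dataloaders in core/data/dataloader default to relative roots such as '../datasets/voc', so either pass root= explicitly or lay out a datasets/ folder to match.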
-------------------------------------------------------------------------------- /core/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Model Zoo""" 2 | from .model_zoo import get_model, get_model_list -------------------------------------------------------------------------------- /core/models/base_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import * 2 | from .resnet import * 3 | from .resnetv1b import * 4 | from .vgg import * 5 | from .eespnet import * 6 | from .xception import * 7 | -------------------------------------------------------------------------------- /core/models/base_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """MobileNet and MobileNetV2.""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from core.nn import _ConvBNReLU, _DepthwiseConv, InvertedResidual 6 | 7 | __all__ = ['MobileNet', 'MobileNetV2', 'get_mobilenet', 'get_mobilenet_v2', 8 | 'mobilenet1_0', 'mobilenet_v2_1_0', 'mobilenet0_75', 'mobilenet_v2_0_75', 9 | 'mobilenet0_5', 'mobilenet_v2_0_5', 'mobilenet0_25', 'mobilenet_v2_0_25'] 10 | 11 | 12 | class MobileNet(nn.Module): 13 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs): 14 | super(MobileNet, self).__init__() 15 | conv_dw_setting = [ 16 | [64, 1, 1], 17 | [128, 2, 2], 18 | [256, 2, 2], 19 | [512, 6, 2], 20 | [1024, 2, 2]] 21 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 22 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)] 23 | 24 | for c, n, s in conv_dw_setting: 25 | out_channels = int(c * multiplier) 26 | for i in range(n): 27 | stride = s if i == 0 else 1 28 | features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer)) 29 | input_channels = out_channels 30 | features.append(nn.AdaptiveAvgPool2d(1)) 31 | self.features = nn.Sequential(*features) 32 | 33 | self.classifier = nn.Linear(int(1024 * multiplier), num_classes) 34 | 35 | # weight initialization 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 39 | if m.bias is not None: 40 | nn.init.zeros_(m.bias) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | nn.init.ones_(m.weight) 43 | nn.init.zeros_(m.bias) 44 | elif isinstance(m, nn.Linear): 45 | nn.init.normal_(m.weight, 0, 0.01) 46 | nn.init.zeros_(m.bias) 47 | 48 | def forward(self, x): 49 | x = self.features(x) 50 | x = self.classifier(x.view(x.size(0), x.size(1))) 51 | return x 52 | 53 | 54 | class MobileNetV2(nn.Module): 55 | def __init__(self, num_classes=1000, multiplier=1.0, norm_layer=nn.BatchNorm2d, **kwargs): 56 | super(MobileNetV2, self).__init__() 57 | inverted_residual_setting = [ 58 | # t, c, n, s 59 | [1, 16, 1, 1], 60 | [6, 24, 2, 2], 61 | [6, 32, 3, 2], 62 | [6, 64, 4, 2], 63 | [6, 96, 3, 1], 64 | [6, 160, 3, 2], 65 | [6, 320, 1, 1]] 66 | # building first layer 67 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 68 | last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 69 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer)] 70 | 71 | # building inverted residual blocks 72 | for t, c, n, s in inverted_residual_setting: 73 | out_channels = int(c * multiplier) 74 | for i in range(n): 75 | stride = s if i == 0 else 1 76 | features.append(InvertedResidual(input_channels, out_channels, stride, t, 
norm_layer)) 77 | input_channels = out_channels 78 | 79 | # building last several layers 80 | features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer)) 81 | features.append(nn.AdaptiveAvgPool2d(1)) 82 | self.features = nn.Sequential(*features) 83 | 84 | self.classifier = nn.Sequential( 85 | nn.Dropout2d(0.2), 86 | nn.Linear(last_channels, num_classes)) 87 | 88 | # weight initialization 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 92 | if m.bias is not None: 93 | nn.init.zeros_(m.bias) 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.ones_(m.weight) 96 | nn.init.zeros_(m.bias) 97 | elif isinstance(m, nn.Linear): 98 | nn.init.normal_(m.weight, 0, 0.01) 99 | if m.bias is not None: 100 | nn.init.zeros_(m.bias) 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | x = self.classifier(x.view(x.size(0), x.size(1))) 105 | return x 106 | 107 | 108 | # Constructor 109 | def get_mobilenet(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs): 110 | model = MobileNet(multiplier=multiplier, **kwargs) 111 | 112 | if pretrained: 113 | raise ValueError("Pretrained models are not supported yet") 114 | return model 115 | 116 | 117 | def get_mobilenet_v2(multiplier=1.0, pretrained=False, root='~/.torch/models', **kwargs): 118 | model = MobileNetV2(multiplier=multiplier, **kwargs) 119 | 120 | if pretrained: 121 | raise ValueError("Pretrained models are not supported yet") 122 | return model 123 | 124 | 125 | def mobilenet1_0(**kwargs): 126 | return get_mobilenet(1.0, **kwargs) 127 | 128 | 129 | def mobilenet_v2_1_0(**kwargs): 130 | return get_mobilenet_v2(1.0, **kwargs) 131 | 132 | 133 | def mobilenet0_75(**kwargs): 134 | return get_mobilenet(0.75, **kwargs) 135 | 136 | 137 | def mobilenet_v2_0_75(**kwargs): 138 | return get_mobilenet_v2(0.75, **kwargs) 139 | 140 | 141 | def mobilenet0_5(**kwargs): 142 | return get_mobilenet(0.5, **kwargs) 143 | 144 | 145 | def mobilenet_v2_0_5(**kwargs): 146 | return get_mobilenet_v2(0.5, **kwargs) 147 | 148 | 149 | def mobilenet0_25(**kwargs): 150 | return get_mobilenet(0.25, **kwargs) 151 | 152 | 153 | def mobilenet_v2_0_25(**kwargs): 154 | return get_mobilenet_v2(0.25, **kwargs) 155 | 156 | 157 | if __name__ == '__main__': 158 | model = mobilenet0_5() 159 | -------------------------------------------------------------------------------- /core/models/base_models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | __all__ = ['ResNext', 'resnext50_32x4d', 'resnext101_32x8d'] 5 | 6 | model_urls = { 7 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 8 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 9 | } 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 16 | base_width=64, dilation=1, norm_layer=None, **kwargs): 17 | super(Bottleneck, self).__init__() 18 | width = int(planes * (base_width / 64.)) * groups 19 | 20 | self.conv1 = nn.Conv2d(inplanes, width, 1, bias=False) 21 | self.bn1 = norm_layer(width) 22 | self.conv2 = nn.Conv2d(width, width, 3, stride, dilation, dilation, groups, bias=False) 23 | self.bn2 = norm_layer(width) 24 | self.conv3 = nn.Conv2d(width, planes * self.expansion, 1, bias=False) 25 | self.bn3 = norm_layer(planes * self.expansion) 26 | self.relu =
nn.ReLU(True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | identity = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out += identity 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | 53 | class ResNext(nn.Module): 54 | 55 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, 56 | width_per_group=64, dilated=False, norm_layer=nn.BatchNorm2d, **kwargs): 57 | super(ResNext, self).__init__() 58 | self.inplanes = 64 59 | self.groups = groups 60 | self.base_width = width_per_group 61 | 62 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False) 63 | self.bn1 = norm_layer(self.inplanes) 64 | self.relu = nn.ReLU(True) 65 | self.maxpool = nn.MaxPool2d(3, 2, 1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 69 | if dilated: 70 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2, norm_layer=norm_layer) 71 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, norm_layer=norm_layer) 72 | else: 73 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 74 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 75 | 76 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 77 | self.fc = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 82 | elif isinstance(m, nn.BatchNorm2d): 83 | nn.init.constant_(m.weight, 1) 84 | nn.init.constant_(m.bias, 0) 85 | 86 | if zero_init_residual: 87 | for m in self.modules(): 88 | if isinstance(m, Bottleneck): 89 | nn.init.constant_(m.bn3.weight, 0) 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d): 92 | downsample = None 93 | if stride != 1 or self.inplanes != planes * block.expansion: 94 | downsample = nn.Sequential( 95 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 96 | norm_layer(planes * block.expansion) 97 | ) 98 | 99 | layers = list() 100 | if dilation in (1, 2): 101 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 102 | self.base_width, norm_layer=norm_layer)) 103 | elif dilation == 4: 104 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 105 | self.base_width, dilation=2, norm_layer=norm_layer)) 106 | else: 107 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 108 | self.inplanes = planes * block.expansion 109 | for _ in range(1, blocks): 110 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, 111 | dilation=dilation, norm_layer=norm_layer)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | x = self.relu(x) 119 | x = self.maxpool(x) 120 | 121 | x = self.layer1(x) 122 | x = self.layer2(x) 123 | x = self.layer3(x) 124 | x = self.layer4(x) 125 | 126 | x = self.avgpool(x) 127 | x = x.view(x.size(0), -1) 128 | x = self.fc(x) 129 | 130 | return x 131 | 132 | 133 | 
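# The `dilated` flag above trades stride for dilation in layer3/layer4 so the
# backbone keeps an output stride of 8 instead of 32, which dense-prediction
# heads expect. A minimal sanity-check sketch (hypothetical helper, not part of
# the original file; `resnext50_32x4d` is defined just below):
def _check_dilated_output_stride():
    import torch
    model = resnext50_32x4d(dilated=True)
    x = torch.randn(1, 3, 224, 224)
    x = model.maxpool(model.relu(model.bn1(model.conv1(x))))
    x = model.layer4(model.layer3(model.layer2(model.layer1(x))))
    # conv1, maxpool and layer2 each halve the resolution; the dilated
    # layer3/layer4 keep 28x28 (224 / 8) instead of reducing to 7x7 (224 / 32).
    assert x.shape[-2:] == (28, 28)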
def resnext50_32x4d(pretrained=False, **kwargs): 134 | kwargs['groups'] = 32 135 | kwargs['width_per_group'] = 4 136 | model = ResNext(Bottleneck, [3, 4, 6, 3], **kwargs) 137 | if pretrained: 138 | state_dict = model_zoo.load_url(model_urls['resnext50_32x4d']) 139 | model.load_state_dict(state_dict) 140 | return model 141 | 142 | 143 | def resnext101_32x8d(pretrained=False, **kwargs): 144 | kwargs['groups'] = 32 145 | kwargs['width_per_group'] = 8 146 | model = ResNext(Bottleneck, [3, 4, 23, 3], **kwargs) 147 | if pretrained: 148 | state_dict = model_zoo.load_url(model_urls['resnext101_32x8d']) 149 | model.load_state_dict(state_dict) 150 | return model 151 | 152 | 153 | if __name__ == '__main__': 154 | model = resnext101_32x8d() 155 | -------------------------------------------------------------------------------- /core/models/ccnet.py: -------------------------------------------------------------------------------- 1 | """Criss-Cross Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.nn import CrissCrossAttention 7 | from .segbase import SegBaseModel 8 | from .fcn import _FCNHead 9 | 10 | __all__ = ['CCNet', 'get_ccnet', 'get_ccnet_resnet50_citys', 'get_ccnet_resnet101_citys', 11 | 'get_ccnet_resnet152_citys', 'get_ccnet_resnet50_ade', 'get_ccnet_resnet101_ade', 12 | 'get_ccnet_resnet152_ade'] 13 | 14 | 15 | class CCNet(SegBaseModel): 16 | r"""CCNet 17 | 18 | Parameters 19 | ---------- 20 | nclass : int 21 | Number of categories for the training dataset. 22 | backbone : string 23 | Pre-trained dilated backbone network type (default: 'resnet50'; 'resnet50', 24 | 'resnet101' or 'resnet152'). 25 | norm_layer : object 26 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm2d`; 27 | use a synchronized cross-GPU BatchNormalization layer for multi-GPU training). 28 | aux : bool 29 | Auxiliary loss. 30 | 31 | Reference: 32 | Zilong Huang, et al. "CCNet: Criss-Cross Attention for Semantic Segmentation." 33 | arXiv preprint arXiv:1811.11721 (2018).
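Examples
--------
A minimal forward-pass check (mirrors the ``__main__`` block at the bottom of this file):

>>> model = get_ccnet_resnet50_citys()
>>> img = torch.randn(1, 3, 480, 480)
>>> outputs = model(img)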
34 | """ 35 | 36 | def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs): 37 | super(CCNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 38 | self.head = _CCHead(nclass, **kwargs) 39 | if aux: 40 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 41 | 42 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 43 | 44 | def forward(self, x): 45 | size = x.size()[2:] 46 | _, _, c3, c4 = self.base_forward(x) 47 | outputs = list() 48 | x = self.head(c4) 49 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 50 | outputs.append(x) 51 | 52 | if self.aux: 53 | auxout = self.auxlayer(c3) 54 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 55 | outputs.append(auxout) 56 | return tuple(outputs) 57 | 58 | 59 | class _CCHead(nn.Module): 60 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 61 | super(_CCHead, self).__init__() 62 | self.rcca = _RCCAModule(2048, 512, norm_layer, **kwargs) 63 | self.out = nn.Conv2d(512, nclass, 1) 64 | 65 | def forward(self, x): 66 | x = self.rcca(x) 67 | x = self.out(x) 68 | return x 69 | 70 | 71 | class _RCCAModule(nn.Module): 72 | def __init__(self, in_channels, out_channels, norm_layer, **kwargs): 73 | super(_RCCAModule, self).__init__() 74 | inter_channels = in_channels // 4 75 | self.conva = nn.Sequential( 76 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 77 | norm_layer(inter_channels), 78 | nn.ReLU(True)) 79 | self.cca = CrissCrossAttention(inter_channels) 80 | self.convb = nn.Sequential( 81 | nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False), 82 | norm_layer(inter_channels), 83 | nn.ReLU(True)) 84 | 85 | self.bottleneck = nn.Sequential( 86 | nn.Conv2d(in_channels + inter_channels, out_channels, 3, padding=1, bias=False), 87 | norm_layer(out_channels), 88 | nn.Dropout2d(0.1)) 89 | 90 | def forward(self, x, recurrence=1): 91 | out = self.conva(x) 92 | for i in range(recurrence): 93 | out = self.cca(out) 94 | out = self.convb(out) 95 | out = torch.cat([x, out], dim=1) 96 | out = self.bottleneck(out) 97 | 98 | return out 99 | 100 | 101 | def get_ccnet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 102 | pretrained_base=True, **kwargs): 103 | acronyms = { 104 | 'pascal_voc': 'pascal_voc', 105 | 'pascal_aug': 'pascal_aug', 106 | 'ade20k': 'ade', 107 | 'coco': 'coco', 108 | 'citys': 'citys', 109 | } 110 | from ..data.dataloader import datasets 111 | model = CCNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 112 | if pretrained: 113 | from .model_store import get_model_file 114 | device = torch.device(kwargs['local_rank']) 115 | model.load_state_dict(torch.load(get_model_file('ccnet_%s_%s' % (backbone, acronyms[dataset]), root=root), 116 | map_location=device)) 117 | return model 118 | 119 | 120 | def get_ccnet_resnet50_citys(**kwargs): 121 | return get_ccnet('citys', 'resnet50', **kwargs) 122 | 123 | 124 | def get_ccnet_resnet101_citys(**kwargs): 125 | return get_ccnet('citys', 'resnet101', **kwargs) 126 | 127 | 128 | def get_ccnet_resnet152_citys(**kwargs): 129 | return get_ccnet('citys', 'resnet152', **kwargs) 130 | 131 | 132 | def get_ccnet_resnet50_ade(**kwargs): 133 | return get_ccnet('ade20k', 'resnet50', **kwargs) 134 | 135 | 136 | def get_ccnet_resnet101_ade(**kwargs): 137 | return get_ccnet('ade20k', 'resnet101', **kwargs) 138 | 139 | 140 | def get_ccnet_resnet152_ade(**kwargs): 141 | return 
get_ccnet('ade20k', 'resnet152', **kwargs) 142 | 143 | 144 | if __name__ == '__main__': 145 | model = get_ccnet_resnet50_citys() 146 | img = torch.randn(1, 3, 480, 480) 147 | outputs = model(img) 148 | -------------------------------------------------------------------------------- /core/models/deeplabv3_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_models.xception import get_xception 6 | from .deeplabv3 import _ASPP 7 | from .fcn import _FCNHead 8 | from ..nn import _ConvBNReLU 9 | 10 | __all__ = ['DeepLabV3Plus', 'get_deeplabv3_plus', 'get_deeplabv3_plus_xception_voc'] 11 | 12 | 13 | class DeepLabV3Plus(nn.Module): 14 | r"""DeepLabV3Plus 15 | Parameters 16 | ---------- 17 | nclass : int 18 | Number of categories for the training dataset. 19 | backbone : string 20 | Pre-trained dilated backbone network type (default:'xception'). 21 | norm_layer : object 22 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 23 | for Synchronized Cross-GPU BachNormalization). 24 | aux : bool 25 | Auxiliary loss. 26 | 27 | Reference: 28 | Chen, Liang-Chieh, et al. "Encoder-Decoder with Atrous Separable Convolution for Semantic 29 | Image Segmentation." 30 | """ 31 | 32 | def __init__(self, nclass, backbone='xception', aux=True, pretrained_base=True, dilated=True, **kwargs): 33 | super(DeepLabV3Plus, self).__init__() 34 | self.aux = aux 35 | self.nclass = nclass 36 | output_stride = 8 if dilated else 32 37 | 38 | self.pretrained = get_xception(pretrained=pretrained_base, output_stride=output_stride, **kwargs) 39 | 40 | # deeplabv3 plus 41 | self.head = _DeepLabHead(nclass, **kwargs) 42 | if aux: 43 | self.auxlayer = _FCNHead(728, nclass, **kwargs) 44 | 45 | def base_forward(self, x): 46 | # Entry flow 47 | x = self.pretrained.conv1(x) 48 | x = self.pretrained.bn1(x) 49 | x = self.pretrained.relu(x) 50 | 51 | x = self.pretrained.conv2(x) 52 | x = self.pretrained.bn2(x) 53 | x = self.pretrained.relu(x) 54 | 55 | x = self.pretrained.block1(x) 56 | # add relu here 57 | x = self.pretrained.relu(x) 58 | low_level_feat = x 59 | 60 | x = self.pretrained.block2(x) 61 | x = self.pretrained.block3(x) 62 | 63 | # Middle flow 64 | x = self.pretrained.midflow(x) 65 | mid_level_feat = x 66 | 67 | # Exit flow 68 | x = self.pretrained.block20(x) 69 | x = self.pretrained.relu(x) 70 | x = self.pretrained.conv3(x) 71 | x = self.pretrained.bn3(x) 72 | x = self.pretrained.relu(x) 73 | 74 | x = self.pretrained.conv4(x) 75 | x = self.pretrained.bn4(x) 76 | x = self.pretrained.relu(x) 77 | 78 | x = self.pretrained.conv5(x) 79 | x = self.pretrained.bn5(x) 80 | x = self.pretrained.relu(x) 81 | return low_level_feat, mid_level_feat, x 82 | 83 | def forward(self, x): 84 | size = x.size()[2:] 85 | c1, c3, c4 = self.base_forward(x) 86 | outputs = list() 87 | x = self.head(c4, c1) 88 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 89 | outputs.append(x) 90 | if self.aux: 91 | auxout = self.auxlayer(c3) 92 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 93 | outputs.append(auxout) 94 | return tuple(outputs) 95 | 96 | 97 | class _DeepLabHead(nn.Module): 98 | def __init__(self, nclass, c1_channels=128, norm_layer=nn.BatchNorm2d, **kwargs): 99 | super(_DeepLabHead, self).__init__() 100 | self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, **kwargs) 101 | self.c1_block = _ConvBNReLU(c1_channels, 48, 3, padding=1, 
norm_layer=norm_layer) 102 | self.block = nn.Sequential( 103 | _ConvBNReLU(304, 256, 3, padding=1, norm_layer=norm_layer), 104 | nn.Dropout(0.5), 105 | _ConvBNReLU(256, 256, 3, padding=1, norm_layer=norm_layer), 106 | nn.Dropout(0.1), 107 | nn.Conv2d(256, nclass, 1)) 108 | 109 | def forward(self, x, c1): 110 | size = c1.size()[2:] 111 | c1 = self.c1_block(c1) 112 | x = self.aspp(x) 113 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 114 | return self.block(torch.cat([x, c1], dim=1)) 115 | 116 | 117 | def get_deeplabv3_plus(dataset='pascal_voc', backbone='xception', pretrained=False, root='~/.torch/models', 118 | pretrained_base=True, **kwargs): 119 | acronyms = { 120 | 'pascal_voc': 'pascal_voc', 121 | 'pascal_aug': 'pascal_aug', 122 | 'ade20k': 'ade', 123 | 'coco': 'coco', 124 | 'citys': 'citys', 125 | } 126 | from ..data.dataloader import datasets 127 | model = DeepLabV3Plus(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 128 | if pretrained: 129 | from .model_store import get_model_file 130 | device = torch.device(kwargs['local_rank']) 131 | model.load_state_dict( 132 | torch.load(get_model_file('deeplabv3_plus_%s_%s' % (backbone, acronyms[dataset]), root=root), 133 | map_location=device)) 134 | return model 135 | 136 | 137 | def get_deeplabv3_plus_xception_voc(**kwargs): 138 | return get_deeplabv3_plus('pascal_voc', 'xception', **kwargs) 139 | 140 | 141 | if __name__ == '__main__': 142 | model = get_deeplabv3_plus_xception_voc() 143 | -------------------------------------------------------------------------------- /core/models/dfanet.py: -------------------------------------------------------------------------------- 1 | """ Deep Feature Aggregation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.models.base_models import Enc, FCAttention, get_xception_a 7 | from core.nn import _ConvBNReLU 8 | 9 | __all__ = ['DFANet', 'get_dfanet', 'get_dfanet_citys'] 10 | 11 | 12 | class DFANet(nn.Module): 13 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs): 14 | super(DFANet, self).__init__() 15 | self.pretrained = get_xception_a(pretrained_base, **kwargs) 16 | 17 | self.enc2_2 = Enc(240, 48, 4, **kwargs) 18 | self.enc3_2 = Enc(144, 96, 6, **kwargs) 19 | self.enc4_2 = Enc(288, 192, 4, **kwargs) 20 | self.fca_2 = FCAttention(192, **kwargs) 21 | 22 | self.enc2_3 = Enc(240, 48, 4, **kwargs) 23 | self.enc3_3 = Enc(144, 96, 6, **kwargs) 24 | self.enc3_4 = Enc(288, 192, 4, **kwargs) 25 | self.fca_3 = FCAttention(192, **kwargs) 26 | 27 | self.enc2_1_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 28 | self.enc2_2_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 29 | self.enc2_3_reduce = _ConvBNReLU(48, 32, 1, **kwargs) 30 | self.conv_fusion = _ConvBNReLU(32, 32, 1, **kwargs) 31 | 32 | self.fca_1_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 33 | self.fca_2_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 34 | self.fca_3_reduce = _ConvBNReLU(192, 32, 1, **kwargs) 35 | self.conv_out = nn.Conv2d(32, nclass, 1) 36 | 37 | self.__setattr__('exclusive', ['enc2_2', 'enc3_2', 'enc4_2', 'fca_2', 'enc2_3', 'enc3_3', 'enc3_4', 'fca_3', 38 | 'enc2_1_reduce', 'enc2_2_reduce', 'enc2_3_reduce', 'conv_fusion', 'fca_1_reduce', 39 | 'fca_2_reduce', 'fca_3_reduce', 'conv_out']) 40 | 41 | def forward(self, x): 42 | # backbone 43 | stage1_conv1 = self.pretrained.conv1(x) 44 | stage1_enc2 = self.pretrained.enc2(stage1_conv1) 45 | stage1_enc3 = self.pretrained.enc3(stage1_enc2) 46 | stage1_enc4 
= self.pretrained.enc4(stage1_enc3) 47 | stage1_fca = self.pretrained.fca(stage1_enc4) 48 | stage1_out = F.interpolate(stage1_fca, scale_factor=4, mode='bilinear', align_corners=True) 49 | 50 | # stage2 51 | stage2_enc2 = self.enc2_2(torch.cat([stage1_enc2, stage1_out], dim=1)) 52 | stage2_enc3 = self.enc3_2(torch.cat([stage1_enc3, stage2_enc2], dim=1)) 53 | stage2_enc4 = self.enc4_2(torch.cat([stage1_enc4, stage2_enc3], dim=1)) 54 | stage2_fca = self.fca_2(stage2_enc4) 55 | stage2_out = F.interpolate(stage2_fca, scale_factor=4, mode='bilinear', align_corners=True) 56 | 57 | # stage3 58 | stage3_enc2 = self.enc2_3(torch.cat([stage2_enc2, stage2_out], dim=1)) 59 | stage3_enc3 = self.enc3_3(torch.cat([stage2_enc3, stage3_enc2], dim=1)) 60 | stage3_enc4 = self.enc3_4(torch.cat([stage2_enc4, stage3_enc3], dim=1)) 61 | stage3_fca = self.fca_3(stage3_enc4) 62 | 63 | stage1_enc2_decoder = self.enc2_1_reduce(stage1_enc2) 64 | stage2_enc2_decoder = F.interpolate(self.enc2_2_reduce(stage2_enc2), scale_factor=2, 65 | mode='bilinear', align_corners=True) 66 | stage3_enc2_decoder = F.interpolate(self.enc2_3_reduce(stage3_enc2), scale_factor=4, 67 | mode='bilinear', align_corners=True) 68 | fusion = stage1_enc2_decoder + stage2_enc2_decoder + stage3_enc2_decoder 69 | fusion = self.conv_fusion(fusion) 70 | 71 | stage1_fca_decoder = F.interpolate(self.fca_1_reduce(stage1_fca), scale_factor=4, 72 | mode='bilinear', align_corners=True) 73 | stage2_fca_decoder = F.interpolate(self.fca_2_reduce(stage2_fca), scale_factor=8, 74 | mode='bilinear', align_corners=True) 75 | stage3_fca_decoder = F.interpolate(self.fca_3_reduce(stage3_fca), scale_factor=16, 76 | mode='bilinear', align_corners=True) 77 | fusion = fusion + stage1_fca_decoder + stage2_fca_decoder + stage3_fca_decoder 78 | 79 | outputs = list() 80 | out = self.conv_out(fusion) 81 | out = F.interpolate(out, scale_factor=4, mode='bilinear', align_corners=True) 82 | outputs.append(out) 83 | 84 | return tuple(outputs) 85 | 86 | 87 | def get_dfanet(dataset='citys', backbone='', pretrained=False, root='~/.torch/models', 88 | pretrained_base=True, **kwargs): 89 | acronyms = { 90 | 'pascal_voc': 'pascal_voc', 91 | 'pascal_aug': 'pascal_aug', 92 | 'ade20k': 'ade', 93 | 'coco': 'coco', 94 | 'citys': 'citys', 95 | } 96 | from ..data.dataloader import datasets 97 | model = DFANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 98 | if pretrained: 99 | from .model_store import get_model_file 100 | device = torch.device(kwargs['local_rank']) 101 | model.load_state_dict(torch.load(get_model_file('dfanet_%s' % (acronyms[dataset]), root=root), 102 | map_location=device)) 103 | return model 104 | 105 | 106 | def get_dfanet_citys(**kwargs): 107 | return get_dfanet('citys', **kwargs) 108 | 109 | 110 | if __name__ == '__main__': 111 | model = get_dfanet_citys() 112 | -------------------------------------------------------------------------------- /core/models/dunet.py: -------------------------------------------------------------------------------- 1 | """Decoders Matter for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | from .fcn import _FCNHead 8 | 9 | __all__ = ['DUNet', 'get_dunet', 'get_dunet_resnet50_pascal_voc', 10 | 'get_dunet_resnet101_pascal_voc', 'get_dunet_resnet152_pascal_voc'] 11 | 12 | 13 | # The model may be inaccurate because many details are missing from the paper.
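# DUpsampling (defined below) replaces bilinear upsampling with a learned,
# data-dependent decoder: a 1x1 convolution expands C channels to
# nclass * scale^2, and the tensor is then rearranged into an nclass map at
# scale-times the input resolution. Up to a fixed permutation of the conv's
# output channels, the rearrangement is the same operation as F.pixel_shuffle,
# e.g. (sketch for intuition only, assuming pixel_shuffle's channel order):
#
#   x = self.conv_w(x)                        # N, nclass * s * s, H, W
#   x = F.pixel_shuffle(x, upscale_factor=s)  # N, nclass, H * s, W * s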
14 | class DUNet(SegBaseModel): 15 | """Decoders Matter for Semantic Segmentation 16 | 17 | Reference: 18 | Zhi Tian, Tong He, Chunhua Shen, and Youliang Yan. 19 | "Decoders Matter for Semantic Segmentation: 20 | Data-Dependent Decoding Enables Flexible Feature Aggregation." CVPR, 2019 21 | """ 22 | 23 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs): 24 | super(DUNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 25 | self.head = _DUHead(2144, **kwargs) 26 | self.dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) 27 | if aux: 28 | self.auxlayer = _FCNHead(1024, 256, **kwargs) 29 | self.aux_dupsample = DUpsampling(256, nclass, scale_factor=8, **kwargs) 30 | 31 | self.__setattr__('exclusive', 32 | ['dupsample', 'head', 'auxlayer', 'aux_dupsample'] if aux else ['dupsample', 'head']) 33 | 34 | def forward(self, x): 35 | c1, c2, c3, c4 = self.base_forward(x) 36 | outputs = [] 37 | x = self.head(c2, c3, c4) 38 | x = self.dupsample(x) 39 | outputs.append(x) 40 | 41 | if self.aux: 42 | auxout = self.auxlayer(c3) 43 | auxout = self.aux_dupsample(auxout) 44 | outputs.append(auxout) 45 | return tuple(outputs) 46 | 47 | 48 | class FeatureFused(nn.Module): 49 | """Module for fused features""" 50 | 51 | def __init__(self, inter_channels=48, norm_layer=nn.BatchNorm2d, **kwargs): 52 | super(FeatureFused, self).__init__() 53 | self.conv2 = nn.Sequential( 54 | nn.Conv2d(512, inter_channels, 1, bias=False), 55 | norm_layer(inter_channels), 56 | nn.ReLU(True) 57 | ) 58 | self.conv3 = nn.Sequential( 59 | nn.Conv2d(1024, inter_channels, 1, bias=False), 60 | norm_layer(inter_channels), 61 | nn.ReLU(True) 62 | ) 63 | 64 | def forward(self, c2, c3, c4): 65 | size = c4.size()[2:] 66 | c2 = self.conv2(F.interpolate(c2, size, mode='bilinear', align_corners=True)) 67 | c3 = self.conv3(F.interpolate(c3, size, mode='bilinear', align_corners=True)) 68 | fused_feature = torch.cat([c4, c3, c2], dim=1) 69 | return fused_feature 70 | 71 | 72 | class _DUHead(nn.Module): 73 | def __init__(self, in_channels, norm_layer=nn.BatchNorm2d, **kwargs): 74 | super(_DUHead, self).__init__() 75 | self.fuse = FeatureFused(norm_layer=norm_layer, **kwargs) 76 | self.block = nn.Sequential( 77 | nn.Conv2d(in_channels, 256, 3, padding=1, bias=False), 78 | norm_layer(256), 79 | nn.ReLU(True), 80 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 81 | norm_layer(256), 82 | nn.ReLU(True) 83 | ) 84 | 85 | def forward(self, c2, c3, c4): 86 | fused_feature = self.fuse(c2, c3, c4) 87 | out = self.block(fused_feature) 88 | return out 89 | 90 | 91 | class DUpsampling(nn.Module): 92 | """DUsampling module""" 93 | 94 | def __init__(self, in_channels, out_channels, scale_factor=2, **kwargs): 95 | super(DUpsampling, self).__init__() 96 | self.scale_factor = scale_factor 97 | self.conv_w = nn.Conv2d(in_channels, out_channels * scale_factor * scale_factor, 1, bias=False) 98 | 99 | def forward(self, x): 100 | x = self.conv_w(x) 101 | n, c, h, w = x.size() 102 | 103 | # N, C, H, W --> N, W, H, C 104 | x = x.permute(0, 3, 2, 1).contiguous() 105 | 106 | # N, W, H, C --> N, W, H * scale, C // scale 107 | x = x.view(n, w, h * self.scale_factor, c // self.scale_factor) 108 | 109 | # N, W, H * scale, C // scale --> N, H * scale, W, C // scale 110 | x = x.permute(0, 2, 1, 3).contiguous() 111 | 112 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) 113 | x = x.view(n, h * self.scale_factor, w * self.scale_factor, c // (self.scale_factor * 
self.scale_factor)) 114 | 115 | # N, H * scale, W * scale, C // (scale ** 2) -- > N, C // (scale ** 2), H * scale, W * scale 116 | x = x.permute(0, 3, 1, 2) 117 | 118 | return x 119 | 120 | 121 | def get_dunet(dataset='pascal_voc', backbone='resnet50', pretrained=False, 122 | root='~/.torch/models', pretrained_base=True, **kwargs): 123 | acronyms = { 124 | 'pascal_voc': 'pascal_voc', 125 | 'pascal_aug': 'pascal_aug', 126 | 'ade20k': 'ade', 127 | 'coco': 'coco', 128 | 'citys': 'citys', 129 | } 130 | from ..data.dataloader import datasets 131 | model = DUNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 132 | if pretrained: 133 | from .model_store import get_model_file 134 | device = torch.device(kwargs['local_rank']) 135 | model.load_state_dict(torch.load(get_model_file('dunet_%s_%s' % (backbone, acronyms[dataset]), root=root), 136 | map_location=device)) 137 | return model 138 | 139 | 140 | def get_dunet_resnet50_pascal_voc(**kwargs): 141 | return get_dunet('pascal_voc', 'resnet50', **kwargs) 142 | 143 | 144 | def get_dunet_resnet101_pascal_voc(**kwargs): 145 | return get_dunet('pascal_voc', 'resnet101', **kwargs) 146 | 147 | 148 | def get_dunet_resnet152_pascal_voc(**kwargs): 149 | return get_dunet('pascal_voc', 'resnet152', **kwargs) 150 | 151 | 152 | if __name__ == '__main__': 153 | img = torch.randn(2, 3, 256, 256) 154 | model = get_dunet_resnet50_pascal_voc() 155 | outputs = model(img) 156 | -------------------------------------------------------------------------------- /core/models/espnet.py: -------------------------------------------------------------------------------- 1 | "ESPNetv2: A Light-weight, Power Efficient, and General Purpose for Semantic Segmentation" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.models.base_models import eespnet, EESP 7 | from core.nn import _ConvBNPReLU, _BNPReLU 8 | 9 | 10 | class ESPNetV2(nn.Module): 11 | r"""ESPNetV2 12 | 13 | Parameters 14 | ---------- 15 | nclass : int 16 | Number of categories for the training dataset. 17 | backbone : string 18 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 19 | 'resnet101' or 'resnet152'). 20 | norm_layer : object 21 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 22 | for Synchronized Cross-GPU BachNormalization). 23 | aux : bool 24 | Auxiliary loss. 25 | 26 | Reference: 27 | Sachin Mehta, et al. "ESPNetv2: A Light-weight, Power Efficient, and General Purpose Convolutional Neural Network." 28 | arXiv preprint arXiv:1811.11431 (2018). 
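Examples
--------
A minimal construction check (mirrors the ``__main__`` block at the bottom of this file):

>>> model = get_espnet_citys()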
29 | """ 30 | 31 | def __init__(self, nclass, backbone='', aux=False, jpu=False, pretrained_base=False, **kwargs): 32 | super(ESPNetV2, self).__init__() 33 | self.pretrained = eespnet(pretrained=pretrained_base, **kwargs) 34 | self.proj_L4_C = _ConvBNPReLU(256, 128, 1, **kwargs) 35 | self.pspMod = nn.Sequential( 36 | EESP(256, 128, stride=1, k=4, r_lim=7, **kwargs), 37 | _PSPModule(128, 128, **kwargs)) 38 | self.project_l3 = nn.Sequential( 39 | nn.Dropout2d(0.1), 40 | nn.Conv2d(128, nclass, 1, bias=False)) 41 | self.act_l3 = _BNPReLU(nclass, **kwargs) 42 | self.project_l2 = _ConvBNPReLU(64 + nclass, nclass, 1, **kwargs) 43 | self.project_l1 = nn.Sequential( 44 | nn.Dropout2d(0.1), 45 | nn.Conv2d(32 + nclass, nclass, 1, bias=False)) 46 | 47 | self.aux = aux 48 | 49 | self.__setattr__('exclusive', ['proj_L4_C', 'pspMod', 'project_l3', 'act_l3', 'project_l2', 'project_l1']) 50 | 51 | def forward(self, x): 52 | size = x.size()[2:] 53 | out_l1, out_l2, out_l3, out_l4 = self.pretrained(x, seg=True) 54 | out_l4_proj = self.proj_L4_C(out_l4) 55 | up_l4_to_l3 = F.interpolate(out_l4_proj, scale_factor=2, mode='bilinear', align_corners=True) 56 | merged_l3_upl4 = self.pspMod(torch.cat([out_l3, up_l4_to_l3], 1)) 57 | proj_merge_l3_bef_act = self.project_l3(merged_l3_upl4) 58 | proj_merge_l3 = self.act_l3(proj_merge_l3_bef_act) 59 | out_up_l3 = F.interpolate(proj_merge_l3, scale_factor=2, mode='bilinear', align_corners=True) 60 | merge_l2 = self.project_l2(torch.cat([out_l2, out_up_l3], 1)) 61 | out_up_l2 = F.interpolate(merge_l2, scale_factor=2, mode='bilinear', align_corners=True) 62 | merge_l1 = self.project_l1(torch.cat([out_l1, out_up_l2], 1)) 63 | 64 | outputs = list() 65 | merge1_l1 = F.interpolate(merge_l1, scale_factor=2, mode='bilinear', align_corners=True) 66 | outputs.append(merge1_l1) 67 | if self.aux: 68 | # different from paper 69 | auxout = F.interpolate(proj_merge_l3_bef_act, size, mode='bilinear', align_corners=True) 70 | outputs.append(auxout) 71 | 72 | return tuple(outputs) 73 | 74 | 75 | # different from PSPNet 76 | class _PSPModule(nn.Module): 77 | def __init__(self, in_channels, out_channels=1024, sizes=(1, 2, 4, 8), **kwargs): 78 | super(_PSPModule, self).__init__() 79 | self.stages = nn.ModuleList( 80 | [nn.Conv2d(in_channels, in_channels, 3, 1, 1, groups=in_channels, bias=False) for _ in sizes]) 81 | self.project = _ConvBNPReLU(in_channels * (len(sizes) + 1), out_channels, 1, 1, **kwargs) 82 | 83 | def forward(self, x): 84 | size = x.size()[2:] 85 | feats = [x] 86 | for stage in self.stages: 87 | x = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1) 88 | upsampled = F.interpolate(stage(x), size, mode='bilinear', align_corners=True) 89 | feats.append(upsampled) 90 | return self.project(torch.cat(feats, dim=1)) 91 | 92 | 93 | def get_espnet(dataset='pascal_voc', backbone='', pretrained=False, root='~/.torch/models', 94 | pretrained_base=False, **kwargs): 95 | acronyms = { 96 | 'pascal_voc': 'pascal_voc', 97 | 'pascal_aug': 'pascal_aug', 98 | 'ade20k': 'ade', 99 | 'coco': 'coco', 100 | 'citys': 'citys', 101 | } 102 | from core.data.dataloader import datasets 103 | model = ESPNetV2(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 104 | if pretrained: 105 | from .model_store import get_model_file 106 | device = torch.device(kwargs['local_rank']) 107 | model.load_state_dict(torch.load(get_model_file('espnet_%s_%s' % (backbone, acronyms[dataset]), root=root), 108 | map_location=device)) 109 | return model 110 | 111 | 112 | def 
get_espnet_citys(**kwargs): 113 | return get_espnet('citys', **kwargs) 114 | 115 | 116 | if __name__ == '__main__': 117 | model = get_espnet_citys() 118 | -------------------------------------------------------------------------------- /core/models/fcnv2.py: -------------------------------------------------------------------------------- 1 | """Fully Convolutional Network with Stride of 8""" 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .segbase import SegBaseModel 9 | 10 | __all__ = ['FCN', 'get_fcn', 'get_fcn_resnet50_voc', 11 | 'get_fcn_resnet101_voc', 'get_fcn_resnet152_voc'] 12 | 13 | 14 | class FCN(SegBaseModel): 15 | def __init__(self, nclass, backbone='resnet50', aux=True, pretrained_base=True, **kwargs): 16 | super(FCN, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 17 | self.head = _FCNHead(2048, nclass, **kwargs) 18 | if aux: 19 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 20 | 21 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 22 | 23 | def forward(self, x): 24 | size = x.size()[2:] 25 | _, _, c3, c4 = self.base_forward(x) 26 | 27 | outputs = [] 28 | x = self.head(c4) 29 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 30 | outputs.append(x) 31 | if self.aux: 32 | auxout = self.auxlayer(c3) 33 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 34 | outputs.append(auxout) 35 | return tuple(outputs) 36 | 37 | 38 | class _FCNHead(nn.Module): 39 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs): 40 | super(_FCNHead, self).__init__() 41 | inter_channels = in_channels // 4 42 | self.block = nn.Sequential( 43 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 44 | norm_layer(inter_channels, **({} if norm_kwargs is None else norm_kwargs)), 45 | nn.ReLU(True), 46 | nn.Dropout(0.1), 47 | nn.Conv2d(inter_channels, channels, 1) 48 | ) 49 | 50 | def forward(self, x): 51 | return self.block(x) 52 | 53 | 54 | def get_fcn(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 55 | pretrained_base=True, **kwargs): 56 | acronyms = { 57 | 'pascal_voc': 'pascal_voc', 58 | 'pascal_aug': 'pascal_aug', 59 | 'ade20k': 'ade', 60 | 'coco': 'coco', 61 | 'citys': 'citys', 62 | } 63 | from ..data.dataloader import datasets 64 | model = FCN(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 65 | if pretrained: 66 | from .model_store import get_model_file 67 | device = torch.device(kwargs['local_rank']) 68 | model.load_state_dict(torch.load(get_model_file('fcn_%s_%s' % (backbone, acronyms[dataset]), root=root), 69 | map_location=device)) 70 | return model 71 | 72 | 73 | def get_fcn_resnet50_voc(**kwargs): 74 | return get_fcn('pascal_voc', 'resnet50', **kwargs) 75 | 76 | 77 | def get_fcn_resnet101_voc(**kwargs): 78 | return get_fcn('pascal_voc', 'resnet101', **kwargs) 79 | 80 | 81 | def get_fcn_resnet152_voc(**kwargs): 82 | return get_fcn('pascal_voc', 'resnet152', **kwargs) 83 | -------------------------------------------------------------------------------- /core/models/hrnet.py: -------------------------------------------------------------------------------- 1 | """High-Resolution Representations for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class HRNet(nn.Module): 7 | """HRNet 8 | 9 | Parameters 10 | ---------- 
11 | nclass : int 12 | Number of categories for the training dataset. 13 | backbone : string 14 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 15 | 'resnet101' or 'resnet152'). 16 | norm_layer : object 17 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 18 | for Synchronized Cross-GPU BachNormalization). 19 | aux : bool 20 | Auxiliary loss. 21 | Reference: 22 | Ke Sun. "High-Resolution Representations for Labeling Pixels and Regions." 23 | arXiv preprint arXiv:1904.04514 (2019). 24 | """ 25 | def __init__(self, nclass, backbone='', aux=False, pretrained_base=False, **kwargs): 26 | super(HRNet, self).__init__() 27 | 28 | def forward(self, x): 29 | pass -------------------------------------------------------------------------------- /core/models/icnet.py: -------------------------------------------------------------------------------- 1 | """Image Cascade Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .segbase import SegBaseModel 7 | 8 | __all__ = ['ICNet', 'get_icnet', 'get_icnet_resnet50_citys', 9 | 'get_icnet_resnet101_citys', 'get_icnet_resnet152_citys'] 10 | 11 | 12 | class ICNet(SegBaseModel): 13 | """Image Cascade Network""" 14 | 15 | def __init__(self, nclass, backbone='resnet50', aux=False, jpu=False, pretrained_base=True, **kwargs): 16 | super(ICNet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 17 | self.conv_sub1 = nn.Sequential( 18 | _ConvBNReLU(3, 32, 3, 2, **kwargs), 19 | _ConvBNReLU(32, 32, 3, 2, **kwargs), 20 | _ConvBNReLU(32, 64, 3, 2, **kwargs) 21 | ) 22 | 23 | self.head = _ICHead(nclass, **kwargs) 24 | 25 | self.__setattr__('exclusive', ['conv_sub1', 'head']) 26 | 27 | def forward(self, x): 28 | # sub 1 29 | x_sub1 = self.conv_sub1(x) 30 | 31 | # sub 2 32 | x_sub2 = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True) 33 | _, x_sub2, _, _ = self.base_forward(x_sub2) 34 | 35 | # sub 4 36 | x_sub4 = F.interpolate(x, scale_factor=0.25, mode='bilinear', align_corners=True) 37 | _, _, _, x_sub4 = self.base_forward(x_sub4) 38 | 39 | outputs = self.head(x_sub1, x_sub2, x_sub4) 40 | 41 | return tuple(outputs) 42 | 43 | 44 | class _ICHead(nn.Module): 45 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 46 | super(_ICHead, self).__init__() 47 | self.cff_12 = CascadeFeatureFusion(512, 64, 128, nclass, norm_layer, **kwargs) 48 | self.cff_24 = CascadeFeatureFusion(2048, 512, 128, nclass, norm_layer, **kwargs) 49 | 50 | self.conv_cls = nn.Conv2d(128, nclass, 1, bias=False) 51 | 52 | def forward(self, x_sub1, x_sub2, x_sub4): 53 | outputs = list() 54 | x_cff_24, x_24_cls = self.cff_24(x_sub4, x_sub2) 55 | outputs.append(x_24_cls) 56 | x_cff_12, x_12_cls = self.cff_12(x_sub2, x_sub1) 57 | outputs.append(x_12_cls) 58 | 59 | up_x2 = F.interpolate(x_cff_12, scale_factor=2, mode='bilinear', align_corners=True) 60 | up_x2 = self.conv_cls(up_x2) 61 | outputs.append(up_x2) 62 | up_x8 = F.interpolate(up_x2, scale_factor=4, mode='bilinear', align_corners=True) 63 | outputs.append(up_x8) 64 | # 1 -> 1/4 -> 1/8 -> 1/16 65 | outputs.reverse() 66 | 67 | return outputs 68 | 69 | 70 | class _ConvBNReLU(nn.Module): 71 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, 72 | groups=1, norm_layer=nn.BatchNorm2d, bias=False, **kwargs): 73 | super(_ConvBNReLU, self).__init__() 74 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, 
bias) 75 | self.bn = norm_layer(out_channels) 76 | self.relu = nn.ReLU(True) 77 | 78 | def forward(self, x): 79 | x = self.conv(x) 80 | x = self.bn(x) 81 | x = self.relu(x) 82 | return x 83 | 84 | 85 | class CascadeFeatureFusion(nn.Module): 86 | """CFF Unit""" 87 | 88 | def __init__(self, low_channels, high_channels, out_channels, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 89 | super(CascadeFeatureFusion, self).__init__() 90 | self.conv_low = nn.Sequential( 91 | nn.Conv2d(low_channels, out_channels, 3, padding=2, dilation=2, bias=False), 92 | norm_layer(out_channels) 93 | ) 94 | self.conv_high = nn.Sequential( 95 | nn.Conv2d(high_channels, out_channels, 1, bias=False), 96 | norm_layer(out_channels) 97 | ) 98 | self.conv_low_cls = nn.Conv2d(out_channels, nclass, 1, bias=False) 99 | 100 | def forward(self, x_low, x_high): 101 | x_low = F.interpolate(x_low, size=x_high.size()[2:], mode='bilinear', align_corners=True) 102 | x_low = self.conv_low(x_low) 103 | x_high = self.conv_high(x_high) 104 | x = x_low + x_high 105 | x = F.relu(x, inplace=True) 106 | x_low_cls = self.conv_low_cls(x_low) 107 | 108 | return x, x_low_cls 109 | 110 | 111 | def get_icnet(dataset='citys', backbone='resnet50', pretrained=False, root='~/.torch/models', 112 | pretrained_base=True, **kwargs): 113 | acronyms = { 114 | 'pascal_voc': 'pascal_voc', 115 | 'pascal_aug': 'pascal_aug', 116 | 'ade20k': 'ade', 117 | 'coco': 'coco', 118 | 'citys': 'citys', 119 | } 120 | from ..data.dataloader import datasets 121 | model = ICNet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 122 | if pretrained: 123 | from .model_store import get_model_file 124 | device = torch.device(kwargs['local_rank']) 125 | model.load_state_dict(torch.load(get_model_file('icnet_%s_%s' % (backbone, acronyms[dataset]), root=root), 126 | map_location=device)) 127 | return model 128 | 129 | 130 | def get_icnet_resnet50_citys(**kwargs): 131 | return get_icnet('citys', 'resnet50', **kwargs) 132 | 133 | 134 | def get_icnet_resnet101_citys(**kwargs): 135 | return get_icnet('citys', 'resnet101', **kwargs) 136 | 137 | 138 | def get_icnet_resnet152_citys(**kwargs): 139 | return get_icnet('citys', 'resnet152', **kwargs) 140 | 141 | 142 | if __name__ == '__main__': 143 | img = torch.randn(1, 3, 256, 256) 144 | model = get_icnet_resnet50_citys() 145 | outputs = model(img) 146 | -------------------------------------------------------------------------------- /core/models/model_store.py: -------------------------------------------------------------------------------- 1 | """Model store which provides pretrained models.""" 2 | from __future__ import print_function 3 | 4 | import os 5 | import zipfile 6 | 7 | from ..utils.download import download, check_sha1 8 | 9 | __all__ = ['get_model_file', 'get_resnet_file'] 10 | 11 | _model_sha1 = {name: checksum for checksum, name in [ 12 | ('25c4b50959ef024fcc050213a06b614899f94b3d', 'resnet50'), 13 | ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'), 14 | ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'), 15 | ]} 16 | 17 | encoding_repo_url = 'https://hangzh.s3.amazonaws.com/' 18 | _url_format = '{repo_url}encoding/models/{file_name}.zip' 19 | 20 | 21 | def short_hash(name): 22 | if name not in _model_sha1: 23 | raise ValueError('Pretrained model for {name} is not available.'.format(name=name)) 24 | return _model_sha1[name][:8] 25 | 26 | 27 | def get_resnet_file(name, root='~/.torch/models'): 28 | file_name = '{name}-{short_hash}'.format(name=name, 
short_hash=short_hash(name)) 29 | root = os.path.expanduser(root) 30 | 31 | file_path = os.path.join(root, file_name + '.pth') 32 | sha1_hash = _model_sha1[name] 33 | if os.path.exists(file_path): 34 | if check_sha1(file_path, sha1_hash): 35 | return file_path 36 | else: 37 | print('Mismatch in the content of model file {} detected.' 38 | ' Downloading again.'.format(file_path)) 39 | else: 40 | print('Model file {} is not found. Downloading.'.format(file_path)) 41 | 42 | if not os.path.exists(root): 43 | os.makedirs(root) 44 | 45 | zip_file_path = os.path.join(root, file_name + '.zip') 46 | repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url) 47 | if repo_url[-1] != '/': 48 | repo_url = repo_url + '/' 49 | download(_url_format.format(repo_url=repo_url, file_name=file_name), 50 | path=zip_file_path, 51 | overwrite=True) 52 | with zipfile.ZipFile(zip_file_path) as zf: 53 | zf.extractall(root) 54 | os.remove(zip_file_path) 55 | 56 | if check_sha1(file_path, sha1_hash): 57 | return file_path 58 | else: 59 | raise ValueError('Downloaded file has different hash. Please try again.') 60 | 61 | 62 | def get_model_file(name, root='~/.torch/models'): 63 | root = os.path.expanduser(root) 64 | file_path = os.path.join(root, name + '.pth') 65 | if os.path.exists(file_path): 66 | return file_path 67 | else: 68 | raise ValueError('Model file {} is not found. Please download or train it first.'.format(file_path)) 69 | -------------------------------------------------------------------------------- /core/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | """Model store which handles pretrained models.""" 2 | from .fcn import * 3 | from .fcnv2 import * 4 | from .pspnet import * 5 | from .deeplabv3 import * 6 | from .deeplabv3_plus import * 7 | from .danet import * 8 | from .denseaspp import * 9 | from .bisenet import * 10 | from .encnet import * 11 | from .dunet import * 12 | from .icnet import * 13 | from .enet import * 14 | from .ocnet import * 15 | from .ccnet import * 16 | from .psanet import * 17 | from .cgnet import * 18 | from .espnet import * 19 | from .lednet import * 20 | from .dfanet import * 21 | 22 | __all__ = ['get_model', 'get_model_list', 'get_segmentation_model'] 23 | 24 | _models = { 25 | 'fcn32s_vgg16_voc': get_fcn32s_vgg16_voc, 26 | 'fcn16s_vgg16_voc': get_fcn16s_vgg16_voc, 27 | 'fcn8s_vgg16_voc': get_fcn8s_vgg16_voc, 28 | 'fcn_resnet50_voc': get_fcn_resnet50_voc, 29 | 'fcn_resnet101_voc': get_fcn_resnet101_voc, 30 | 'fcn_resnet152_voc': get_fcn_resnet152_voc, 31 | 'psp_resnet50_voc': get_psp_resnet50_voc, 32 | 'psp_resnet50_ade': get_psp_resnet50_ade, 33 | 'psp_resnet101_voc': get_psp_resnet101_voc, 34 | 'psp_resnet101_ade': get_psp_resnet101_ade, 35 | 'psp_resnet101_citys': get_psp_resnet101_citys, 36 | 'psp_resnet101_coco': get_psp_resnet101_coco, 37 | 'deeplabv3_resnet50_voc': get_deeplabv3_resnet50_voc, 38 | 'deeplabv3_resnet101_voc': get_deeplabv3_resnet101_voc, 39 | 'deeplabv3_resnet152_voc': get_deeplabv3_resnet152_voc, 40 | 'deeplabv3_resnet50_ade': get_deeplabv3_resnet50_ade, 41 | 'deeplabv3_resnet101_ade': get_deeplabv3_resnet101_ade, 42 | 'deeplabv3_resnet152_ade': get_deeplabv3_resnet152_ade, 43 | 'deeplabv3_plus_xception_voc': get_deeplabv3_plus_xception_voc, 44 | 'danet_resnet50_citys': get_danet_resnet50_citys, 45 | 'danet_resnet101_citys': get_danet_resnet101_citys, 46 | 'danet_resnet152_citys': get_danet_resnet152_citys, 47 | 'denseaspp_densenet121_citys': get_denseaspp_densenet121_citys, 48 | 'denseaspp_densenet161_citys':
get_denseaspp_densenet161_citys, 49 | 'denseaspp_densenet169_citys': get_denseaspp_densenet169_citys, 50 | 'denseaspp_densenet201_citys': get_denseaspp_densenet201_citys, 51 | 'bisenet_resnet18_citys': get_bisenet_resnet18_citys, 52 | 'encnet_resnet50_ade': get_encnet_resnet50_ade, 53 | 'encnet_resnet101_ade': get_encnet_resnet101_ade, 54 | 'encnet_resnet152_ade': get_encnet_resnet152_ade, 55 | 'dunet_resnet50_pascal_voc': get_dunet_resnet50_pascal_voc, 56 | 'dunet_resnet101_pascal_voc': get_dunet_resnet101_pascal_voc, 57 | 'dunet_resnet152_pascal_voc': get_dunet_resnet152_pascal_voc, 58 | 'icnet_resnet50_citys': get_icnet_resnet50_citys, 59 | 'icnet_resnet101_citys': get_icnet_resnet101_citys, 60 | 'icnet_resnet152_citys': get_icnet_resnet152_citys, 61 | 'enet_citys': get_enet_citys, 62 | 'base_ocnet_resnet101_citys': get_base_ocnet_resnet101_citys, 63 | 'pyramid_ocnet_resnet101_citys': get_pyramid_ocnet_resnet101_citys, 64 | 'asp_ocnet_resnet101_citys': get_asp_ocnet_resnet101_citys, 65 | 'ccnet_resnet50_citys': get_ccnet_resnet50_citys, 66 | 'ccnet_resnet101_citys': get_ccnet_resnet101_citys, 67 | 'ccnet_resnet152_citys': get_ccnet_resnet152_citys, 68 | 'ccnet_resnet50_ade': get_ccnet_resnet50_ade, 69 | 'ccnet_resnet101_ade': get_ccnet_resnet101_ade, 70 | 'ccnet_resnet152_ade': get_ccnet_resnet152_ade, 71 | 'psanet_resnet50_voc': get_psanet_resnet50_voc, 72 | 'psanet_resnet101_voc': get_psanet_resnet101_voc, 73 | 'psanet_resnet152_voc': get_psanet_resnet152_voc, 74 | 'psanet_resnet50_citys': get_psanet_resnet50_citys, 75 | 'psanet_resnet101_citys': get_psanet_resnet101_citys, 76 | 'psanet_resnet152_citys': get_psanet_resnet152_citys, 77 | 'cgnet_citys': get_cgnet_citys, 78 | 'espnet_citys': get_espnet_citys, 79 | 'lednet_citys': get_lednet_citys, 80 | 'dfanet_citys': get_dfanet_citys, 81 | } 82 | 83 | 84 | def get_model(name, **kwargs): 85 | name = name.lower() 86 | if name not in _models: 87 | err_str = '"%s" is not among the following model list:\n\t' % (name) 88 | err_str += '%s' % ('\n\t'.join(sorted(_models.keys()))) 89 | raise ValueError(err_str) 90 | net = _models[name](**kwargs) 91 | return net 92 | 93 | 94 | def get_model_list(): 95 | return _models.keys() 96 | 97 | 98 | def get_segmentation_model(model, **kwargs): 99 | models = { 100 | 'fcn32s': get_fcn32s, 101 | 'fcn16s': get_fcn16s, 102 | 'fcn8s': get_fcn8s, 103 | 'fcn': get_fcn, 104 | 'psp': get_psp, 105 | 'deeplabv3': get_deeplabv3, 106 | 'deeplabv3_plus': get_deeplabv3_plus, 107 | 'danet': get_danet, 108 | 'denseaspp': get_denseaspp, 109 | 'bisenet': get_bisenet, 110 | 'encnet': get_encnet, 111 | 'dunet': get_dunet, 112 | 'icnet': get_icnet, 113 | 'enet': get_enet, 114 | 'ocnet': get_ocnet, 115 | 'ccnet': get_ccnet, 116 | 'psanet': get_psanet, 117 | 'cgnet': get_cgnet, 118 | 'espnet': get_espnet, 119 | 'lednet': get_lednet, 120 | 'dfanet': get_dfanet, 121 | } 122 | return models[model](**kwargs) 123 | -------------------------------------------------------------------------------- /core/models/psanet.py: -------------------------------------------------------------------------------- 1 | """Point-wise Spatial Attention Network""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from core.nn import _ConvBNReLU 7 | from core.models.segbase import SegBaseModel 8 | from core.models.fcn import _FCNHead 9 | 10 | __all__ = ['PSANet', 'get_psanet', 'get_psanet_resnet50_voc', 'get_psanet_resnet101_voc', 11 | 'get_psanet_resnet152_voc', 'get_psanet_resnet50_citys', 'get_psanet_resnet101_citys', 
12 | 'get_psanet_resnet152_citys'] 13 | 14 | 15 | class PSANet(SegBaseModel): 16 | r"""PSANet 17 | 18 | Parameters 19 | ---------- 20 | nclass : int 21 | Number of categories for the training dataset. 22 | backbone : string 23 | Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50', 24 | 'resnet101' or 'resnet152'). 25 | norm_layer : object 26 | Normalization layer used in backbone network (default: :class:`nn.BatchNorm`; 27 | for Synchronized Cross-GPU BachNormalization). 28 | aux : bool 29 | Auxiliary loss. 30 | 31 | Reference: 32 | Hengshuang Zhao, et al. "PSANet: Point-wise Spatial Attention Network for Scene Parsing." 33 | ECCV-2018. 34 | """ 35 | 36 | def __init__(self, nclass, backbone='resnet', aux=False, pretrained_base=True, **kwargs): 37 | super(PSANet, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs) 38 | self.head = _PSAHead(nclass, **kwargs) 39 | if aux: 40 | self.auxlayer = _FCNHead(1024, nclass, **kwargs) 41 | 42 | self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head']) 43 | 44 | def forward(self, x): 45 | size = x.size()[2:] 46 | _, _, c3, c4 = self.base_forward(x) 47 | outputs = list() 48 | x = self.head(c4) 49 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 50 | outputs.append(x) 51 | 52 | if self.aux: 53 | auxout = self.auxlayer(c3) 54 | auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True) 55 | outputs.append(auxout) 56 | return tuple(outputs) 57 | 58 | 59 | class _PSAHead(nn.Module): 60 | def __init__(self, nclass, norm_layer=nn.BatchNorm2d, **kwargs): 61 | super(_PSAHead, self).__init__() 62 | # psa_out_channels = crop_size // 8 ** 2 63 | self.psa = _PointwiseSpatialAttention(2048, 3600, norm_layer) 64 | 65 | self.conv_post = _ConvBNReLU(1024, 2048, 1, norm_layer=norm_layer) 66 | self.project = nn.Sequential( 67 | _ConvBNReLU(4096, 512, 3, padding=1, norm_layer=norm_layer), 68 | nn.Dropout2d(0.1, False), 69 | nn.Conv2d(512, nclass, 1)) 70 | 71 | def forward(self, x): 72 | global_feature = self.psa(x) 73 | out = self.conv_post(global_feature) 74 | out = torch.cat([x, out], dim=1) 75 | out = self.project(out) 76 | 77 | return out 78 | 79 | 80 | class _PointwiseSpatialAttention(nn.Module): 81 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): 82 | super(_PointwiseSpatialAttention, self).__init__() 83 | reduced_channels = 512 84 | self.collect_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 85 | self.distribute_attention = _AttentionGeneration(in_channels, reduced_channels, out_channels, norm_layer) 86 | 87 | def forward(self, x): 88 | collect_fm = self.collect_attention(x) 89 | distribute_fm = self.distribute_attention(x) 90 | psa_fm = torch.cat([collect_fm, distribute_fm], dim=1) 91 | return psa_fm 92 | 93 | 94 | class _AttentionGeneration(nn.Module): 95 | def __init__(self, in_channels, reduced_channels, out_channels, norm_layer, **kwargs): 96 | super(_AttentionGeneration, self).__init__() 97 | self.conv_reduce = _ConvBNReLU(in_channels, reduced_channels, 1, norm_layer=norm_layer) 98 | self.attention = nn.Sequential( 99 | _ConvBNReLU(reduced_channels, reduced_channels, 1, norm_layer=norm_layer), 100 | nn.Conv2d(reduced_channels, out_channels, 1, bias=False)) 101 | 102 | self.reduced_channels = reduced_channels 103 | 104 | def forward(self, x): 105 | reduce_x = self.conv_reduce(x) 106 | attention = self.attention(reduce_x) 107 | n, c, h, w = attention.size() 108 | attention = 
attention.view(n, c, -1) 109 | reduce_x = reduce_x.view(n, self.reduced_channels, -1) 110 | fm = torch.bmm(reduce_x, torch.softmax(attention, dim=1)) 111 | fm = fm.view(n, self.reduced_channels, h, w) 112 | 113 | return fm 114 | 115 | 116 | def get_psanet(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models', 117 | pretrained_base=True, **kwargs): 118 | acronyms = { 119 | 'pascal_voc': 'pascal_voc', 120 | 'pascal_aug': 'pascal_aug', 121 | 'ade20k': 'ade', 122 | 'coco': 'coco', 123 | 'citys': 'citys', 124 | } 125 | from core.data.dataloader import datasets 126 | model = PSANet(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs) 127 | if pretrained: 128 | from .model_store import get_model_file 129 | device = torch.device(kwargs['local_rank']) 130 | model.load_state_dict(torch.load(get_model_file('psanet_%s_%s' % (backbone, acronyms[dataset]), root=root), 131 | map_location=device)) 132 | return model 133 | 134 | 135 | def get_psanet_resnet50_voc(**kwargs): 136 | return get_psanet('pascal_voc', 'resnet50', **kwargs) 137 | 138 | 139 | def get_psanet_resnet101_voc(**kwargs): 140 | return get_psanet('pascal_voc', 'resnet101', **kwargs) 141 | 142 | 143 | def get_psanet_resnet152_voc(**kwargs): 144 | return get_psanet('pascal_voc', 'resnet152', **kwargs) 145 | 146 | 147 | def get_psanet_resnet50_citys(**kwargs): 148 | return get_psanet('citys', 'resnet50', **kwargs) 149 | 150 | 151 | def get_psanet_resnet101_citys(**kwargs): 152 | return get_psanet('citys', 'resnet101', **kwargs) 153 | 154 | 155 | def get_psanet_resnet152_citys(**kwargs): 156 | return get_psanet('citys', 'resnet152', **kwargs) 157 | 158 | 159 | if __name__ == '__main__': 160 | model = get_psanet_resnet50_voc() 161 | img = torch.randn(1, 3, 480, 480) 162 | output = model(img) 163 | -------------------------------------------------------------------------------- /core/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import torch.nn as nn 3 | 4 | from ..nn import JPU 5 | from .base_models.resnetv1b import resnet50_v1s, resnet101_v1s, resnet152_v1s 6 | 7 | __all__ = ['SegBaseModel'] 8 | 9 | 10 | class SegBaseModel(nn.Module): 11 | r"""Base Model for Semantic Segmentation 12 | 13 | Parameters 14 | ---------- 15 | backbone : string 16 | Pre-trained dilated backbone network type (default: 'resnet50'; 'resnet50', 17 | 'resnet101' or 'resnet152').
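nclass : int
    Number of categories for the training dataset.
aux : bool
    Auxiliary loss.
jpu : bool
    If True, keep a stride-32 backbone and recover high-resolution features
    with a Joint Pyramid Upsampling (JPU) unit instead of using dilated
    convolutions in layer3/layer4 (default: False).
pretrained_base : bool
    Whether to load ImageNet-pretrained weights into the backbone (default: True).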
18 | """ 19 | 20 | def __init__(self, nclass, aux, backbone='resnet50', jpu=False, pretrained_base=True, **kwargs): 21 | super(SegBaseModel, self).__init__() 22 | dilated = False if jpu else True 23 | self.aux = aux 24 | self.nclass = nclass 25 | if backbone == 'resnet50': 26 | self.pretrained = resnet50_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 27 | elif backbone == 'resnet101': 28 | self.pretrained = resnet101_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 29 | elif backbone == 'resnet152': 30 | self.pretrained = resnet152_v1s(pretrained=pretrained_base, dilated=dilated, **kwargs) 31 | else: 32 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 33 | 34 | self.jpu = JPU([512, 1024, 2048], width=512, **kwargs) if jpu else None 35 | 36 | def base_forward(self, x): 37 | """forwarding pre-trained network""" 38 | x = self.pretrained.conv1(x) 39 | x = self.pretrained.bn1(x) 40 | x = self.pretrained.relu(x) 41 | x = self.pretrained.maxpool(x) 42 | c1 = self.pretrained.layer1(x) 43 | c2 = self.pretrained.layer2(c1) 44 | c3 = self.pretrained.layer3(c2) 45 | c4 = self.pretrained.layer4(c3) 46 | 47 | if self.jpu: 48 | return self.jpu(c1, c2, c3, c4) 49 | else: 50 | return c1, c2, c3, c4 51 | 52 | def evaluate(self, x): 53 | """evaluating network with inputs and targets""" 54 | return self.forward(x)[0] 55 | 56 | def demo(self, x): 57 | pred = self.forward(x) 58 | if self.aux: 59 | pred = pred[0] 60 | return pred 61 | -------------------------------------------------------------------------------- /core/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | # from .sync_bn.syncbn import * 3 | # from .syncbn import * 4 | from .ca_block import * 5 | from .psa_block import * 6 | from .jpu import * 7 | from .basic import * -------------------------------------------------------------------------------- /core/nn/basic.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual'] 7 | 8 | 9 | class _ConvBNReLU(nn.Module): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 11 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d, **kwargs): 12 | super(_ConvBNReLU, self).__init__() 13 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 14 | self.bn = norm_layer(out_channels) 15 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.relu(x) 21 | return x 22 | 23 | 24 | class _ConvBNPReLU(nn.Module): 25 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 26 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 27 | super(_ConvBNPReLU, self).__init__() 28 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 29 | self.bn = norm_layer(out_channels) 30 | self.prelu = nn.PReLU(out_channels) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | x = self.bn(x) 35 | x = self.prelu(x) 36 | return x 37 | 38 | 39 | class _ConvBN(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 41 | dilation=1, groups=1, 
norm_layer=nn.BatchNorm2d, **kwargs): 42 | super(_ConvBN, self).__init__() 43 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 44 | self.bn = norm_layer(out_channels) 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | x = self.bn(x) 49 | return x 50 | 51 | 52 | class _BNPReLU(nn.Module): 53 | def __init__(self, out_channels, norm_layer=nn.BatchNorm2d, **kwargs): 54 | super(_BNPReLU, self).__init__() 55 | self.bn = norm_layer(out_channels) 56 | self.prelu = nn.PReLU(out_channels) 57 | 58 | def forward(self, x): 59 | x = self.bn(x) 60 | x = self.prelu(x) 61 | return x 62 | 63 | 64 | # ----------------------------------------------------------------- 65 | # For PSPNet 66 | # ----------------------------------------------------------------- 67 | class _PSPModule(nn.Module): 68 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), **kwargs): 69 | super(_PSPModule, self).__init__() 70 | out_channels = int(in_channels / 4) 71 | self.avgpools = nn.ModuleList() 72 | self.convs = nn.ModuleList() 73 | for size in sizes: 74 | self.avgpools.append(nn.AdaptiveAvgPool2d(size)) 75 | self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, **kwargs)) 76 | 77 | def forward(self, x): 78 | size = x.size()[2:] 79 | feats = [x] 80 | for avgpool, conv in zip(self.avgpools, self.convs): 81 | feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True)) 82 | return torch.cat(feats, dim=1) 83 | 84 | 85 | # ----------------------------------------------------------------- 86 | # For MobileNet 87 | # ----------------------------------------------------------------- 88 | class _DepthwiseConv(nn.Module): 89 | """conv_dw in MobileNet""" 90 | 91 | def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs): 92 | super(_DepthwiseConv, self).__init__() 93 | self.conv = nn.Sequential( 94 | _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer), 95 | _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)) 96 | 97 | def forward(self, x): 98 | return self.conv(x) 99 | 100 | 101 | # ----------------------------------------------------------------- 102 | # For MobileNetV2 103 | # ----------------------------------------------------------------- 104 | class InvertedResidual(nn.Module): 105 | def __init__(self, in_channels, out_channels, stride, expand_ratio, norm_layer=nn.BatchNorm2d, **kwargs): 106 | super(InvertedResidual, self).__init__() 107 | assert stride in [1, 2] 108 | self.use_res_connect = stride == 1 and in_channels == out_channels 109 | 110 | layers = list() 111 | inter_channels = int(round(in_channels * expand_ratio)) 112 | if expand_ratio != 1: 113 | # pw 114 | layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer)) 115 | layers.extend([ 116 | # dw 117 | _ConvBNReLU(inter_channels, inter_channels, 3, stride, 1, 118 | groups=inter_channels, relu6=True, norm_layer=norm_layer), 119 | # pw-linear 120 | nn.Conv2d(inter_channels, out_channels, 1, bias=False), 121 | norm_layer(out_channels)]) 122 | self.conv = nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | if self.use_res_connect: 126 | return x + self.conv(x) 127 | else: 128 | return self.conv(x) 129 | 130 | 131 | if __name__ == '__main__': 132 | x = torch.randn(1, 32, 64, 64) 133 | model = InvertedResidual(32, 64, 2, 1) 134 | out = model(x) 135 | --------------------------------------------------------------------------------
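A quick usage sketch for the blocks above (illustrative only; it extends the `__main__` check at the bottom of basic.py, and the printed shapes follow directly from the definitions):

    import torch
    from core.nn import _ConvBNReLU, InvertedResidual

    x = torch.randn(1, 32, 64, 64)
    down = InvertedResidual(32, 64, stride=2, expand_ratio=6)  # stride 2: no residual branch
    print(down(x).shape)   # torch.Size([1, 64, 32, 32])
    keep = InvertedResidual(32, 32, stride=1, expand_ratio=6)  # stride 1, same channels: x + conv(x)
    print(keep(x).shape)   # torch.Size([1, 32, 64, 64])
    conv = _ConvBNReLU(32, 16, 3, padding=1)                   # kernel 3, padding 1: resolution preserved
    print(conv(x).shape)   # torch.Size([1, 16, 64, 64])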
/core/nn/ca_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd.function import once_differentiable 6 | from core.nn import _C 7 | 8 | __all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map'] 9 | 10 | 11 | class _CAWeight(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, t, f): 14 | weight = _C.ca_forward(t, f) 15 | 16 | ctx.save_for_backward(t, f) 17 | 18 | return weight 19 | 20 | @staticmethod 21 | @once_differentiable 22 | def backward(ctx, dw): 23 | t, f = ctx.saved_tensors 24 | 25 | dt, df = _C.ca_backward(dw, t, f) 26 | return dt, df 27 | 28 | 29 | class _CAMap(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, weight, g): 32 | out = _C.ca_map_forward(weight, g) 33 | 34 | ctx.save_for_backward(weight, g) 35 | 36 | return out 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, dout): 41 | weight, g = ctx.saved_tensors 42 | 43 | dw, dg = _C.ca_map_backward(dout, weight, g) 44 | 45 | return dw, dg 46 | 47 | 48 | ca_weight = _CAWeight.apply 49 | ca_map = _CAMap.apply 50 | 51 | 52 | class CrissCrossAttention(nn.Module): 53 | """Criss-Cross Attention Module""" 54 | 55 | def __init__(self, in_channels): 56 | super(CrissCrossAttention, self).__init__() 57 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 58 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 59 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1) 60 | self.gamma = nn.Parameter(torch.zeros(1)) 61 | 62 | def forward(self, x): 63 | proj_query = self.query_conv(x) 64 | proj_key = self.key_conv(x) 65 | proj_value = self.value_conv(x) 66 | 67 | energy = ca_weight(proj_query, proj_key) 68 | attention = F.softmax(energy, 1) 69 | out = ca_map(attention, proj_value) 70 | out = self.gamma * out + x 71 | 72 | return out 73 | --------------------------------------------------------------------------------
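A short usage sketch for `CrissCrossAttention` (assumes the `core.nn._C` extension has been built via `core/nn/setup.py` and a CUDA device is available; the residual connection keeps the input shape):

# Sketch: criss-cross attention preserves (N, C, H, W); needs the compiled _C extension and a GPU.
import torch
from core.nn.ca_block import CrissCrossAttention

x = torch.randn(1, 64, 32, 32).cuda()
cca = CrissCrossAttention(64).cuda()
y = cca(x)                 # each position attends to its full row and column
assert y.shape == x.shape  # gamma-weighted residual keeps the shape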
/core/nn/csrc/ca.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | // Interface for Python 10 | at::Tensor ca_forward(const at::Tensor& t, 11 | const at::Tensor& f) { 12 | if (t.type().is_cuda()) { 13 | #ifdef WITH_CUDA 14 | return ca_forward_cuda(t, f); 15 | #else 16 | AT_ERROR("Not compiled with GPU support"); 17 | #endif 18 | } 19 | return ca_forward_cpu(t, f); 20 | } 21 | 22 | std::tuple<at::Tensor, at::Tensor> ca_backward(const at::Tensor& dw, 23 | const at::Tensor& t, 24 | const at::Tensor& f) { 25 | if (dw.type().is_cuda()) { 26 | #ifdef WITH_CUDA 27 | return ca_backward_cuda(dw, t, f); 28 | #else 29 | AT_ERROR("Not compiled with GPU support"); 30 | #endif 31 | } 32 | return ca_backward_cpu(dw, t, f); 33 | } 34 | 35 | at::Tensor ca_map_forward(const at::Tensor& weight, 36 | const at::Tensor& g) { 37 | if (weight.type().is_cuda()) { 38 | #ifdef WITH_CUDA 39 | return ca_map_forward_cuda(weight, g); 40 | #else 41 | AT_ERROR("Not compiled with GPU support"); 42 | #endif 43 | } 44 | return ca_map_forward_cpu(weight, g); 45 | } 46 | 47 | std::tuple<at::Tensor, at::Tensor> ca_map_backward(const at::Tensor& dout, 48 | const at::Tensor& weight, 49 | const at::Tensor& g) { 50 | if (dout.type().is_cuda()) { 51 | #ifdef WITH_CUDA 52 | return ca_map_backward_cuda(dout, weight, g); 53 | #else 54 | AT_ERROR("Not compiled with GPU support"); 55 | #endif 56 | } 57 | return ca_map_backward_cpu(dout, weight, g); 58 | } -------------------------------------------------------------------------------- /core/nn/csrc/cpu/ca_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/vision.h" 2 | 3 | 4 | at::Tensor ca_forward_cpu( 5 | const torch::Tensor& t, 6 | const torch::Tensor& f) { 7 | AT_ERROR("Not implemented on the CPU");} 8 | 9 | std::tuple<at::Tensor, at::Tensor> ca_backward_cpu( 10 | const at::Tensor& dw, 11 | const at::Tensor& t, 12 | const at::Tensor& f) { 13 | AT_ERROR("Not implemented on the CPU");} 14 | 15 | at::Tensor ca_map_forward_cpu( 16 | const at::Tensor& weight, 17 | const at::Tensor& g) { 18 | AT_ERROR("Not implemented on the CPU");} 19 | 20 | std::tuple<at::Tensor, at::Tensor> ca_map_backward_cpu( 21 | const at::Tensor& dout, 22 | const at::Tensor& weight, 23 | const at::Tensor& g) { 24 | AT_ERROR("Not implemented on the CPU");} -------------------------------------------------------------------------------- /core/nn/csrc/cpu/psa_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "cpu/vision.h" 2 | 3 | 4 | at::Tensor psa_forward_cpu( 5 | const torch::Tensor& hc, 6 | const int forward_type) { 7 | AT_ERROR("Not implemented on the CPU");} 8 | 9 | at::Tensor psa_backward_cpu( 10 | const at::Tensor& dout, 11 | const at::Tensor& hc, 12 | const int forward_type) { 13 | AT_ERROR("Not implemented on the CPU");} -------------------------------------------------------------------------------- /core/nn/csrc/cpu/syncbn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/ATen.h> 3 | #include <vector> 4 | 5 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 6 | if (x.ndimension() == 2) { 7 | return v; 8 | } else { 9 | std::vector<int64_t> broadcast_size = {1, -1}; 10 | for (int64_t i = 2; i < x.ndimension(); ++i) 11 | broadcast_size.push_back(1); 12 | 13 | return v.view(broadcast_size); 14 | } 15 | } 16 | 17 | at::Tensor batchnorm_forward_cpu( 18 | const at::Tensor input_, 19 | const at::Tensor ex_, 20 | const at::Tensor exs_, 21 | const at::Tensor gamma_, 22 | const at::Tensor beta_, 23 | float eps) { 24 | auto output = (input_ - broadcast_to(ex_, input_)) / broadcast_to(exs_, input_); 25 | output = output * broadcast_to(gamma_, input_) + broadcast_to(beta_, input_); 26 | return output; 27 | } 28 | 29 | // Not implementing CPU backward for now 30 | std::vector<at::Tensor> batchnorm_backward_cpu( 31 | const at::Tensor gradoutput_, 32 | const at::Tensor input_, 33 | const at::Tensor ex_, 34 | const at::Tensor exs_, 35 | const at::Tensor gamma_, 36 | const at::Tensor beta_, 37 | float eps) { 38 | /* outputs*/ 39 | at::Tensor gradinput = at::zeros_like(input_); 40 | at::Tensor gradgamma = at::zeros_like(gamma_); 41 | at::Tensor gradbeta = at::zeros_like(beta_); 42 | at::Tensor gradMean = at::zeros_like(ex_); 43 | at::Tensor gradStd = at::zeros_like(exs_); 44 | return {gradinput, gradMean, gradStd, gradgamma, gradbeta}; 45 | } -------------------------------------------------------------------------------- /core/nn/csrc/cpu/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | 4 | 5 | at::Tensor ca_forward_cpu( 6 | const at::Tensor& t, 7 | const at::Tensor& f); 8 | 9 | std::tuple<at::Tensor, at::Tensor> ca_backward_cpu( 10 | const at::Tensor& dw, 11 | const at::Tensor& t, 12 | const at::Tensor& f); 13 | 14 | at::Tensor ca_map_forward_cpu( 15 | const at::Tensor& weight, 16 | const at::Tensor& g); 17 | 18 | std::tuple<at::Tensor, at::Tensor> ca_map_backward_cpu( 19 | const at::Tensor& dout, 20 |
const at::Tensor& weight, 21 | const at::Tensor& g); 22 | 23 | at::Tensor psa_forward_cpu( 24 | const at::Tensor& hc, 25 | const int forward_type); 26 | 27 | at::Tensor psa_backward_cpu( 28 | const at::Tensor& dout, 29 | const at::Tensor& hc, 30 | const int forward_type); 31 | 32 | at::Tensor batchnorm_forward_cpu( 33 | const at::Tensor input_, 34 | const at::Tensor mean_, 35 | const at::Tensor std_, 36 | const at::Tensor gamma_, 37 | const at::Tensor beta_, 38 | float eps); 39 | 40 | std::vector<at::Tensor> batchnorm_backward_cpu( 41 | const at::Tensor gradoutput_, 42 | const at::Tensor input_, 43 | const at::Tensor ex_, 44 | const at::Tensor exs_, 45 | const at::Tensor gamma_, 46 | const at::Tensor beta_, 47 | float eps); -------------------------------------------------------------------------------- /core/nn/csrc/cuda/vision.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/extension.h> 3 | #include <ATen/ATen.h> 4 | 5 | 6 | at::Tensor ca_forward_cuda( 7 | const at::Tensor& t, 8 | const at::Tensor& f); 9 | 10 | std::tuple<at::Tensor, at::Tensor> ca_backward_cuda( 11 | const at::Tensor& dw, 12 | const at::Tensor& t, 13 | const at::Tensor& f); 14 | 15 | at::Tensor ca_map_forward_cuda( 16 | const at::Tensor& weight, 17 | const at::Tensor& g); 18 | 19 | std::tuple<at::Tensor, at::Tensor> ca_map_backward_cuda( 20 | const at::Tensor& dout, 21 | const at::Tensor& weight, 22 | const at::Tensor& g); 23 | 24 | at::Tensor psa_forward_cuda( 25 | const at::Tensor& hc, 26 | const int forward_type); 27 | 28 | at::Tensor psa_backward_cuda( 29 | const at::Tensor& dout, 30 | const at::Tensor& hc, 31 | const int forward_type); 32 | 33 | at::Tensor batchnorm_forward_cuda( 34 | const at::Tensor input_, 35 | const at::Tensor ex_, 36 | const at::Tensor exs_, 37 | const at::Tensor gamma_, 38 | const at::Tensor beta_, 39 | float eps); 40 | 41 | at::Tensor inp_batchnorm_forward_cuda( 42 | const at::Tensor input_, 43 | const at::Tensor ex_, 44 | const at::Tensor exs_, 45 | const at::Tensor gamma_, 46 | const at::Tensor beta_, 47 | float eps); 48 | 49 | std::vector<at::Tensor> batchnorm_backward_cuda( 50 | const at::Tensor gradoutput_, 51 | const at::Tensor input_, 52 | const at::Tensor ex_, 53 | const at::Tensor exs_, 54 | const at::Tensor gamma_, 55 | const at::Tensor beta_, 56 | float eps); 57 | 58 | std::vector<at::Tensor> inp_batchnorm_backward_cuda( 59 | const at::Tensor gradoutput_, 60 | const at::Tensor output_, 61 | const at::Tensor ex_, 62 | const at::Tensor exs_, 63 | const at::Tensor gamma_, 64 | const at::Tensor beta_, 65 | float eps); 66 | 67 | std::vector<at::Tensor> expectation_forward_cuda( 68 | const at::Tensor input_); 69 | 70 | at::Tensor expectation_backward_cuda( 71 | const at::Tensor input_, 72 | const at::Tensor gradEx_, 73 | const at::Tensor gradExs_); 74 | 75 | at::Tensor inp_expectation_backward_cuda( 76 | const at::Tensor gradInput_, 77 | const at::Tensor output_, 78 | const at::Tensor gradEx_, 79 | const at::Tensor gradExs_, 80 | const at::Tensor ex_, 81 | const at::Tensor exs_, 82 | const at::Tensor gamma_, 83 | const at::Tensor beta_, 84 | float eps); -------------------------------------------------------------------------------- /core/nn/csrc/psa.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu/vision.h" 4 | 5 | #ifdef WITH_CUDA 6 | #include "cuda/vision.h" 7 | #endif 8 | 9 | // Interface for Python 10 | at::Tensor psa_forward(const at::Tensor& hc, 11 | const int forward_type) { 12 | if (hc.type().is_cuda()) { 13 | #ifdef WITH_CUDA 14 | return psa_forward_cuda(hc,
forward_type); 15 | #else 16 | AT_ERROR("Not compiled with GPU support"); 17 | #endif 18 | } 19 | return psa_forward_cpu(hc, forward_type); 20 | } 21 | 22 | at::Tensor psa_backward(const at::Tensor& dout, 23 | const at::Tensor& hc, 24 | const int forward_type) { 25 | if (hc.type().is_cuda()) { 26 | #ifdef WITH_CUDA 27 | return psa_backward_cuda(dout, hc, forward_type); 28 | #else 29 | AT_ERROR("Not compiled with GPU support"); 30 | #endif 31 | } 32 | return psa_backward_cpu(dout, hc, forward_type); 33 | } -------------------------------------------------------------------------------- /core/nn/csrc/syncbn.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <vector> 4 | #include "cpu/vision.h" 5 | 6 | #ifdef WITH_CUDA 7 | #include "cuda/vision.h" 8 | #endif 9 | 10 | // Interface for Python 11 | at::Tensor batchnorm_forward(const at::Tensor input_, 12 | const at::Tensor ex_, 13 | const at::Tensor exs_, 14 | const at::Tensor gamma_, 15 | const at::Tensor beta_, 16 | float eps) { 17 | if (input_.type().is_cuda()) { 18 | #ifdef WITH_CUDA 19 | return batchnorm_forward_cuda(input_, ex_, exs_, gamma_, beta_, eps); 20 | #else 21 | AT_ERROR("Not compiled with GPU support"); 22 | #endif 23 | } 24 | return batchnorm_forward_cpu(input_, ex_, exs_, gamma_, beta_, eps); 25 | } 26 | 27 | at::Tensor inp_batchnorm_forward(const at::Tensor input_, 28 | const at::Tensor ex_, 29 | const at::Tensor exs_, 30 | const at::Tensor gamma_, 31 | const at::Tensor beta_, 32 | float eps) { 33 | if (input_.type().is_cuda()) { 34 | #ifdef WITH_CUDA 35 | return inp_batchnorm_forward_cuda(input_, ex_, exs_, gamma_, beta_, eps); 36 | #else 37 | AT_ERROR("Not compiled with GPU support"); 38 | #endif 39 | } 40 | AT_ERROR("Not implemented on the CPU"); 41 | } 42 | 43 | std::vector<at::Tensor> batchnorm_backward(const at::Tensor gradoutput_, 44 | const at::Tensor input_, 45 | const at::Tensor ex_, 46 | const at::Tensor exs_, 47 | const at::Tensor gamma_, 48 | const at::Tensor beta_, 49 | float eps) { 50 | if (gradoutput_.type().is_cuda()) { 51 | #ifdef WITH_CUDA 52 | return batchnorm_backward_cuda(gradoutput_, input_, ex_, exs_, gamma_, beta_, eps); 53 | #else 54 | AT_ERROR("Not compiled with GPU support"); 55 | #endif 56 | } 57 | return batchnorm_backward_cpu(gradoutput_, input_, ex_, exs_, gamma_, beta_, eps); 58 | } 59 | 60 | std::vector<at::Tensor> inp_batchnorm_backward(const at::Tensor gradoutput_, 61 | const at::Tensor input_, 62 | const at::Tensor ex_, 63 | const at::Tensor exs_, 64 | const at::Tensor gamma_, 65 | const at::Tensor beta_, 66 | float eps) { 67 | if (gradoutput_.type().is_cuda()) { 68 | #ifdef WITH_CUDA 69 | return inp_batchnorm_backward_cuda(gradoutput_, input_, ex_, exs_, gamma_, beta_, eps); 70 | #else 71 | AT_ERROR("Not compiled with GPU support"); 72 | #endif 73 | } 74 | AT_ERROR("Not implemented on the CPU"); 75 | } 76 | 77 | std::vector<at::Tensor> expectation_forward(const at::Tensor input_) { 78 | if (input_.type().is_cuda()) { 79 | #ifdef WITH_CUDA 80 | return expectation_forward_cuda(input_); 81 | #else 82 | AT_ERROR("Not compiled with GPU support"); 83 | #endif 84 | } 85 | AT_ERROR("Not implemented on the CPU"); 86 | } 87 | 88 | at::Tensor expectation_backward(const at::Tensor input_, 89 | const at::Tensor gradEx_, 90 | const at::Tensor gradExs_) { 91 | if (input_.type().is_cuda()) { 92 | #ifdef WITH_CUDA 93 | return expectation_backward_cuda(input_, gradEx_, gradExs_); 94 | #else 95 | AT_ERROR("Not compiled with GPU support"); 96 | #endif 97 | } 98 | AT_ERROR("Not
implemented on the CPU"); 99 | } 100 | 101 | at::Tensor inp_expectation_backward(const at::Tensor gradInput_, 102 | const at::Tensor output_, 103 | const at::Tensor gradEx_, 104 | const at::Tensor gradExs_, 105 | const at::Tensor ex_, 106 | const at::Tensor exs_, 107 | const at::Tensor gamma_, 108 | const at::Tensor beta_, 109 | float eps) { 110 | if (output_.type().is_cuda()) { 111 | #ifdef WITH_CUDA 112 | return inp_expectation_backward_cuda(gradInput_, output_, gradEx_, gradExs_, ex_, exs_, gamma_, beta_, eps); 113 | #else 114 | AT_ERROR("Not compiled with GPU support"); 115 | #endif 116 | } 117 | AT_ERROR("Not implemented on the CPU"); 118 | } -------------------------------------------------------------------------------- /core/nn/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | #include "ca.h" 2 | #include "psa.h" 3 | #include "syncbn.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("ca_forward", &ca_forward, "ca_forward"); 7 | m.def("ca_backward", &ca_backward, "ca_backward"); 8 | m.def("ca_map_forward", &ca_map_forward, "ca_map_forward"); 9 | m.def("ca_map_backward", &ca_map_backward, "ca_map_backward"); 10 | m.def("psa_forward", &psa_forward, "psa_forward"); 11 | m.def("psa_backward", &psa_backward, "psa_backward"); 12 | m.def("batchnorm_forward", &batchnorm_forward, "batchnorm_forward"); 13 | m.def("inp_batchnorm_forward", &inp_batchnorm_forward, "inp_batchnorm_forward"); 14 | m.def("batchnorm_backward", &batchnorm_backward, "batchnorm_backward"); 15 | m.def("inp_batchnorm_backward", &inp_batchnorm_backward, "inp_batchnorm_backward"); 16 | m.def("expectation_forward", &expectation_forward, "expectation_forward"); 17 | m.def("expectation_backward", &expectation_backward, "expectation_backward"); 18 | m.def("inp_expectation_backward", &inp_expectation_backward, "inp_expectation_backward"); 19 | } -------------------------------------------------------------------------------- /core/nn/jpu.py: -------------------------------------------------------------------------------- 1 | """Joint Pyramid Upsampling""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['JPU'] 7 | 8 | 9 | class SeparableConv2d(nn.Module): 10 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, 11 | dilation=1, bias=False, norm_layer=nn.BatchNorm2d): 12 | super(SeparableConv2d, self).__init__() 13 | self.conv = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) 14 | self.bn = norm_layer(inplanes) 15 | self.pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) 16 | 17 | def forward(self, x): 18 | x = self.conv(x) 19 | x = self.bn(x) 20 | x = self.pointwise(x) 21 | return x 22 | 23 | 24 | # copy from: https://github.com/wuhuikai/FastFCN/blob/master/encoding/nn/customize.py 25 | class JPU(nn.Module): 26 | def __init__(self, in_channels, width=512, norm_layer=nn.BatchNorm2d, **kwargs): 27 | super(JPU, self).__init__() 28 | 29 | self.conv5 = nn.Sequential( 30 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 31 | norm_layer(width), 32 | nn.ReLU(True)) 33 | self.conv4 = nn.Sequential( 34 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 35 | norm_layer(width), 36 | nn.ReLU(True)) 37 | self.conv3 = nn.Sequential( 38 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), 39 | norm_layer(width), 40 | nn.ReLU(True)) 41 | 42 | self.dilation1 = nn.Sequential( 43 | SeparableConv2d(3 * width, width, 3, padding=1, 
dilation=1, bias=False), 44 | norm_layer(width), 45 | nn.ReLU(True)) 46 | self.dilation2 = nn.Sequential( 47 | SeparableConv2d(3 * width, width, 3, padding=2, dilation=2, bias=False), 48 | norm_layer(width), 49 | nn.ReLU(True)) 50 | self.dilation3 = nn.Sequential( 51 | SeparableConv2d(3 * width, width, 3, padding=4, dilation=4, bias=False), 52 | norm_layer(width), 53 | nn.ReLU(True)) 54 | self.dilation4 = nn.Sequential( 55 | SeparableConv2d(3 * width, width, 3, padding=8, dilation=8, bias=False), 56 | norm_layer(width), 57 | nn.ReLU(True)) 58 | 59 | def forward(self, *inputs): 60 | feats = [self.conv5(inputs[-1]), self.conv4(inputs[-2]), self.conv3(inputs[-3])] 61 | size = feats[-1].size()[2:] 62 | feats[-2] = F.interpolate(feats[-2], size, mode='bilinear', align_corners=True) 63 | feats[-3] = F.interpolate(feats[-3], size, mode='bilinear', align_corners=True) 64 | feat = torch.cat(feats, dim=1) 65 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], 66 | dim=1) 67 | 68 | return inputs[0], inputs[1], inputs[2], feat 69 | -------------------------------------------------------------------------------- /core/nn/psa_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torch.autograd.function import once_differentiable 5 | from core.nn import _C 6 | 7 | __all__ = ['CollectAttention', 'DistributeAttention', 'psa_collect', 'psa_distribute'] 8 | 9 | 10 | class _PSACollect(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, hc): 13 | out = _C.psa_forward(hc, 1) 14 | 15 | ctx.save_for_backward(hc) 16 | 17 | return out 18 | 19 | @staticmethod 20 | @once_differentiable 21 | def backward(ctx, dout): 22 | hc = ctx.saved_tensors 23 | 24 | dhc = _C.psa_backward(dout, hc[0], 1) 25 | 26 | return dhc 27 | 28 | 29 | class _PSADistribute(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, hc): 32 | out = _C.psa_forward(hc, 2) 33 | 34 | ctx.save_for_backward(hc) 35 | 36 | return out 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, dout): 41 | hc = ctx.saved_tensors 42 | 43 | dhc = _C.psa_backward(dout, hc[0], 2) 44 | 45 | return dhc 46 | 47 | 48 | psa_collect = _PSACollect.apply 49 | psa_distribute = _PSADistribute.apply 50 | 51 | 52 | class CollectAttention(nn.Module): 53 | """Collect Attention Generation Module""" 54 | 55 | def __init__(self): 56 | super(CollectAttention, self).__init__() 57 | 58 | def forward(self, x): 59 | out = psa_collect(x) 60 | return out 61 | 62 | 63 | class DistributeAttention(nn.Module): 64 | """Distribute Attention Generation Module""" 65 | 66 | def __init__(self): 67 | super(DistributeAttention, self).__init__() 68 | 69 | def forward(self, x): 70 | out = psa_distribute(x) 71 | return out 72 | -------------------------------------------------------------------------------- /core/nn/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | # !/usr/bin/env python 3 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/90c226cf10e098263d1df28bda054a5f22513b4f/setup.py 4 | 5 | import os 6 | import glob 7 | import torch 8 | 9 | from setuptools import setup 10 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME 11 | 12 | requirements = ["torch"] 13 | 14 | 15 | def get_extension(): 16 | this_dir = os.path.dirname(os.path.abspath(__file__)) 17 | extensions_dir = os.path.join(this_dir, "csrc") 18 | 19 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 20 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 21 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 22 | 23 | sources = main_file + source_cpu 24 | extension = CppExtension 25 | 26 | define_macros = [] 27 | 28 | if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1": 29 | extension = CUDAExtension 30 | sources += source_cuda 31 | define_macros += [("WITH_CUDA", None)] 32 | 33 | sources = [os.path.join(extensions_dir, s) for s in sources] 34 | 35 | include_dirs = [extensions_dir] 36 | 37 | ext_modules = [ 38 | extension( 39 | "._C", 40 | sources, 41 | include_dirs=include_dirs, 42 | define_macros=define_macros, 43 | ) 44 | ] 45 | 46 | return ext_modules 47 | 48 | 49 | setup( 50 | name="semantic_segmentation", 51 | version="0.1", 52 | author="tramac", 53 | description="semantic segmentation in pytorch", 54 | ext_modules=get_extension(), 55 | cmdclass={"build_ext": BuildExtension} 56 | ) -------------------------------------------------------------------------------- /core/nn/sync_bn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/nn/sync_bn/__init__.py -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.cpp_extension import load 4 | 5 | cwd = os.path.dirname(os.path.realpath(__file__)) 6 | cpu_path = os.path.join(cwd, 'cpu') 7 | gpu_path = os.path.join(cwd, 'gpu') 8 | 9 | cpu = load('sync_cpu', [ 10 | os.path.join(cpu_path, 'operator.cpp'), 11 | os.path.join(cpu_path, 'syncbn_cpu.cpp'), 12 | ], build_directory=cpu_path, verbose=False) 13 | 14 | if torch.cuda.is_available(): 15 | gpu = load('sync_gpu', [ 16 | os.path.join(gpu_path, 'operator.cpp'), 17 | os.path.join(gpu_path, 'activation_kernel.cu'), 18 | os.path.join(gpu_path, 'syncbn_kernel.cu'), 19 | ], extra_cuda_cflags=["--expt-extended-lambda"], 20 | build_directory=gpu_path, verbose=False) 21 | -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/.ninja_deps: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/nn/sync_bn/lib/cpu/.ninja_deps -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/.ninja_log: -------------------------------------------------------------------------------- 1 | # ninja log v5 2 | 0 6679 1555417150 syncbn_cpu.o b884354b4810778d 3 | 0 7702 1555417151 operator.o df6e270344a1d164 4 | 7703 8115 1555417151 sync_cpu.so 
d148b4e40b0af67e 5 | 0 5172 1557113015 syncbn_cpu.o 9052547bb175072 6 | 0 6447 1557113016 operator.o 209836e0b0c1e97e 7 | 6447 6613 1557113016 sync_cpu.so d148b4e40b0af67e 8 | -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | 4 | cflags = -DTORCH_EXTENSION_NAME=sync_cpu -DTORCH_API_INCLUDE_EXTENSION_H -isystem /home/tramac/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/torch/include -isystem /home/tramac/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -isystem /home/tramac/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/torch/include/TH -isystem /home/tramac/.pyenv/versions/anaconda3-4.4.0/lib/python3.6/site-packages/torch/include/THC -isystem /home/tramac/.pyenv/versions/anaconda3-4.4.0/include/python3.6m -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++11 5 | ldflags = -shared 6 | 7 | rule compile 8 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out 9 | depfile = $out.d 10 | deps = gcc 11 | 12 | rule link 13 | command = $cxx $in $ldflags -o $out 14 | 15 | build operator.o: compile /home/tramac/PycharmProjects/awesome-semantic-segmentation-pytorch/core/nn/sync_bn/lib/cpu/operator.cpp 16 | build syncbn_cpu.o: compile /home/tramac/PycharmProjects/awesome-semantic-segmentation-pytorch/core/nn/sync_bn/lib/cpu/syncbn_cpu.cpp 17 | 18 | build sync_cpu.so: link operator.o syncbn_cpu.o 19 | 20 | default sync_cpu.so 21 | 22 | -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/operator.cpp: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("batchnorm_forward", &BatchNorm_Forward_CPU, "BatchNorm forward (CPU)"); 5 | m.def("batchnorm_backward", &BatchNorm_Backward_CPU, "BatchNorm backward (CPU)"); 6 | m.def("sumsquare_forward", &Sum_Square_Forward_CPU, "SumSqu forward (CPU)"); 7 | m.def("sumsquare_backward", &Sum_Square_Backward_CPU, "SumSqu backward (CPU)"); 8 | } -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/operator.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | at::Tensor BatchNorm_Forward_CPU( 5 | const at::Tensor input_, 6 | const at::Tensor mean_, 7 | const at::Tensor std_, 8 | const at::Tensor gamma_, 9 | const at::Tensor beta_); 10 | 11 | std::vector<at::Tensor> BatchNorm_Backward_CPU( 12 | const at::Tensor gradoutput_, 13 | const at::Tensor input_, 14 | const at::Tensor mean_, 15 | const at::Tensor std_, 16 | const at::Tensor gamma_, 17 | const at::Tensor beta_, 18 | bool train); 19 | 20 | std::vector<at::Tensor> Sum_Square_Forward_CPU( 21 | const at::Tensor input_); 22 | 23 | at::Tensor Sum_Square_Backward_CPU( 24 | const at::Tensor input_, 25 | const at::Tensor gradSum_, 26 | const at::Tensor gradSquare_); -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/operator.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/nn/sync_bn/lib/cpu/operator.o --------------------------------------------------------------------------------
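The `sumsquare_forward`/`sumsquare_backward` bindings above exist so each device can report per-channel sums of x and x²; synchronized BN then derives global statistics from the reduced sums via Var[x] = E[x²] − E[x]². A plain-PyTorch sketch of that reduction (illustrative only; `reduce_bn_stats` is a hypothetical helper, and the real reduction runs through the extension's device queues):

# Sketch: reducing per-device (sum, sum-of-squares) into synchronized BN statistics.
import torch

def reduce_bn_stats(sums, squares, counts):
    # sums, squares: per-device (C,) tensors; counts: per-device element counts per channel
    n = float(sum(counts))
    ex = torch.stack(sums).sum(0) / n      # E[x] across all devices
    exs = torch.stack(squares).sum(0) / n  # E[x^2] across all devices
    var = exs - ex * ex                    # Var[x] = E[x^2] - E[x]^2
    return ex, var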
/core/nn/sync_bn/lib/cpu/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | setup( 5 | name='syncbn_cpu', 6 | ext_modules=[ 7 | CppExtension('syncbn_cpu', [ 8 | 'operator.cpp', 9 | 'syncbn_cpu.cpp', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/syncbn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/ATen.h> 3 | #include <vector> 4 | 5 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 6 | if (x.ndimension() == 2) { 7 | return v; 8 | } else { 9 | std::vector<int64_t> broadcast_size = {1, -1}; 10 | for (int64_t i = 2; i < x.ndimension(); ++i) 11 | broadcast_size.push_back(1); 12 | 13 | return v.view(broadcast_size); 14 | } 15 | } 16 | 17 | at::Tensor BatchNorm_Forward_CPU( 18 | const at::Tensor input, 19 | const at::Tensor mean, 20 | const at::Tensor std, 21 | const at::Tensor gamma, 22 | const at::Tensor beta) { 23 | auto output = (input - broadcast_to(mean, input)) / broadcast_to(std, input); 24 | output = output * broadcast_to(gamma, input) + broadcast_to(beta, input); 25 | return output; 26 | } 27 | 28 | // Not implementing CPU backward for now 29 | std::vector<at::Tensor> BatchNorm_Backward_CPU( 30 | const at::Tensor gradoutput, 31 | const at::Tensor input, 32 | const at::Tensor mean, 33 | const at::Tensor std, 34 | const at::Tensor gamma, 35 | const at::Tensor beta, 36 | bool train) { 37 | /* outputs*/ 38 | at::Tensor gradinput = at::zeros_like(input); 39 | at::Tensor gradgamma = at::zeros_like(gamma); 40 | at::Tensor gradbeta = at::zeros_like(beta); 41 | at::Tensor gradMean = at::zeros_like(mean); 42 | at::Tensor gradStd = at::zeros_like(std); 43 | return {gradinput, gradMean, gradStd, gradgamma, gradbeta}; 44 | } 45 | 46 | std::vector<at::Tensor> Sum_Square_Forward_CPU( 47 | const at::Tensor input) { 48 | /* outputs */ 49 | at::Tensor sum = torch::zeros({input.size(1)}, input.options()); 50 | at::Tensor square = torch::zeros({input.size(1)}, input.options()); 51 | return {sum, square}; 52 | } 53 | 54 | at::Tensor Sum_Square_Backward_CPU( 55 | const at::Tensor input, 56 | const at::Tensor gradSum, 57 | const at::Tensor gradSquare) { 58 | /* outputs */ 59 | at::Tensor gradInput = at::zeros_like(input); 60 | return gradInput; 61 | } -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/cpu/syncbn_cpu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/nn/sync_bn/lib/cpu/syncbn_cpu.o -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/gpu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/core/nn/sync_bn/lib/gpu/__init__.py --------------------------------------------------------------------------------
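The `activation_kernel.cu` file below fuses a leaky ReLU with the in-place sync BN: in the backward pass, wherever the stored output `z` is negative the incoming gradient is scaled by `slope`, and `z` itself is divided by `slope` to recover the pre-activation value. The same arithmetic in plain PyTorch (a reference sketch, not the CUDA path):

# Sketch of the backward arithmetic implemented by the thrust kernel below.
import torch

def leaky_relu_backward_reference(z, dz, slope):
    neg = z < 0
    dz = torch.where(neg, dz * slope, dz)  # chain rule through the negative branch
    z = torch.where(neg, z / slope, z)     # undo the in-place activation to recover the input
    return z, dz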
/core/nn/sync_bn/lib/gpu/activation_kernel.cu: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | // #include <ATen/cuda/CUDAContext.h> 3 | #include <cuda.h> 4 | #include <cuda_runtime.h> 5 | #include <vector> 6 | 7 | #include <thrust/device_ptr.h> 8 | 9 | #include <thrust/transform.h> 10 | #include <thrust/transform_reduce.h> 11 | 12 | 13 | namespace { 14 | 15 | template <typename T> 16 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) { 17 | // Create thrust pointers 18 | thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z); 19 | thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz); 20 | 21 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz, 22 | [slope] __device__ (const T& dz) { return dz * slope; }, 23 | [] __device__ (const T& z) { return z < 0; }); 24 | thrust::transform_if(th_z, th_z + count, th_z, 25 | [slope] __device__ (const T& z) { return z / slope; }, 26 | [] __device__ (const T& z) { return z < 0; }); 27 | } 28 | 29 | } 30 | 31 | void LeakyRelu_Forward_CUDA(at::Tensor z, float slope) { 32 | at::leaky_relu_(z, slope); 33 | } 34 | 35 | void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope) { 36 | int64_t count = z.numel(); 37 | 38 | AT_DISPATCH_FLOATING_TYPES(z.type(), "LeakyRelu_Backward_CUDA", ([&] { 39 | leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count); 40 | })); 41 | /* 42 | // unstable after scaling 43 | at::leaky_relu_(z, 1.0 / slope); 44 | at::leaky_relu_backward(dz, z, slope); 45 | */ 46 | } -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/gpu/device_tensor.h: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | 3 | template <typename DType, int Dim> 4 | struct DeviceTensor { 5 | public: 6 | inline __device__ __host__ DeviceTensor(DType *p, const int *size) 7 | : dptr_(p) { 8 | for (int i = 0; i < Dim; ++i) { 9 | size_[i] = size ? size[i] : 0; 10 | } 11 | } 12 | 13 | inline __device__ __host__ unsigned getSize(const int i) const { 14 | assert(i < Dim); 15 | return size_[i]; 16 | } 17 | 18 | inline __device__ __host__ int numElements() const { 19 | int n = 1; 20 | for (int i = 0; i < Dim; ++i) { 21 | n *= size_[i]; 22 | } 23 | return n; 24 | } 25 | 26 | inline __device__ __host__ DeviceTensor<DType, Dim-1> select(const size_t x) const { 27 | assert(Dim > 1); 28 | int offset = x; 29 | for (int i = 1; i < Dim; ++i) { 30 | offset *= size_[i]; 31 | } 32 | DeviceTensor<DType, Dim-1> tensor(dptr_ + offset, nullptr); 33 | for (int i = 0; i < Dim - 1; ++i) { 34 | tensor.size_[i] = this->size_[i+1]; 35 | } 36 | return tensor; 37 | } 38 | 39 | inline __device__ __host__ DeviceTensor<DType, Dim-1> operator[](const size_t x) const { 40 | assert(Dim > 1); 41 | int offset = x; 42 | for (int i = 1; i < Dim; ++i) { 43 | offset *= size_[i]; 44 | } 45 | DeviceTensor<DType, Dim-1> tensor(dptr_ + offset, nullptr); 46 | for (int i = 0; i < Dim - 1; ++i) { 47 | tensor.size_[i] = this->size_[i+1]; 48 | } 49 | return tensor; 50 | } 51 | 52 | inline __device__ __host__ size_t InnerSize() const { 53 | assert(Dim >= 3); 54 | size_t sz = 1; 55 | for (size_t i = 2; i < Dim; ++i) { 56 | sz *= size_[i]; 57 | } 58 | return sz; 59 | } 60 | 61 | inline __device__ __host__ size_t ChannelCount() const { 62 | assert(Dim >= 3); 63 | return size_[1]; 64 | } 65 | 66 | inline __device__ __host__ DType* data_ptr() const { 67 | return dptr_; 68 | } 69 | 70 | DType *dptr_; 71 | int size_[Dim]; 72 | }; 73 | 74 | template <typename DType> 75 | struct DeviceTensor<DType, 1> { 76 | inline __device__ __host__ DeviceTensor(DType *p, const int *size) 77 | : dptr_(p) { 78 | size_[0] = size ?
size[0] : 0; 79 | } 80 | 81 | inline __device__ __host__ unsigned getSize(const int i) const { 82 | assert(i == 0); 83 | return size_[0]; 84 | } 85 | 86 | inline __device__ __host__ int numElements() const { 87 | return size_[0]; 88 | } 89 | 90 | inline __device__ __host__ DType &operator[](const size_t x) const { 91 | return *(dptr_ + x); 92 | } 93 | 94 | inline __device__ __host__ DType* data_ptr() const { 95 | return dptr_; 96 | } 97 | 98 | DType *dptr_; 99 | int size_[1]; 100 | }; 101 | 102 | template <typename DType, int Dim> 103 | static DeviceTensor<DType, Dim> devicetensor(const at::Tensor &blob) { 104 | DType *data = blob.data<DType>(); 105 | DeviceTensor<DType, Dim> tensor(data, nullptr); 106 | for (int i = 0; i < Dim; ++i) { 107 | tensor.size_[i] = blob.size(i); 108 | } 109 | return tensor; 110 | } -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/gpu/operator.cpp: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("batchnorm_forward", &BatchNorm_Forward_CUDA, "BatchNorm forward (CUDA)"); 5 | m.def("batchnorm_inp_forward", &BatchNorm_Forward_Inp_CUDA, "BatchNorm forward (CUDA)"); 6 | m.def("batchnorm_backward", &BatchNorm_Backward_CUDA, "BatchNorm backward (CUDA)"); 7 | m.def("batchnorm_inp_backward", &BatchNorm_Inp_Backward_CUDA, "BatchNorm backward (CUDA)"); 8 | m.def("expectation_forward", &Expectation_Forward_CUDA, "Expectation forward (CUDA)"); 9 | m.def("expectation_backward", &Expectation_Backward_CUDA, "Expectation backward (CUDA)"); 10 | m.def("expectation_inp_backward", &Expectation_Inp_Backward_CUDA, "Inplace Expectation backward (CUDA)"); 11 | m.def("leaky_relu_forward", &LeakyRelu_Forward_CUDA, "Leaky ReLU forward (CUDA)"); 12 | m.def("leaky_relu_backward", &LeakyRelu_Backward_CUDA, "Leaky ReLU backward (CUDA)"); 13 | } -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/gpu/operator.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | at::Tensor BatchNorm_Forward_CUDA( 5 | const at::Tensor input_, 6 | const at::Tensor mean_, 7 | const at::Tensor std_, 8 | const at::Tensor gamma_, 9 | const at::Tensor beta_, 10 | float eps); 11 | 12 | at::Tensor BatchNorm_Forward_Inp_CUDA( 13 | const at::Tensor input_, 14 | const at::Tensor ex_, 15 | const at::Tensor exs_, 16 | const at::Tensor gamma_, 17 | const at::Tensor beta_, 18 | float eps); 19 | 20 | std::vector<at::Tensor> BatchNorm_Backward_CUDA( 21 | const at::Tensor gradoutput_, 22 | const at::Tensor input_, 23 | const at::Tensor ex_, 24 | const at::Tensor exs_, 25 | const at::Tensor gamma_, 26 | const at::Tensor beta_, 27 | float eps); 28 | 29 | std::vector<at::Tensor> BatchNorm_Inp_Backward_CUDA( 30 | const at::Tensor gradoutput_, 31 | const at::Tensor output_, 32 | const at::Tensor ex_, 33 | const at::Tensor exs_, 34 | const at::Tensor gamma_, 35 | const at::Tensor beta_, 36 | float eps); 37 | 38 | std::vector<at::Tensor> Expectation_Forward_CUDA( 39 | const at::Tensor input_); 40 | 41 | at::Tensor Expectation_Backward_CUDA( 42 | const at::Tensor input_, 43 | const at::Tensor gradEx_, 44 | const at::Tensor gradExs_); 45 | 46 | at::Tensor Expectation_Inp_Backward_CUDA( 47 | const at::Tensor gradInput_, 48 | const at::Tensor output_, 49 | const at::Tensor gradEx_, 50 | const at::Tensor gradExs_, 51 | const at::Tensor ex_, 52 | const at::Tensor exs_, 53 | const at::Tensor gamma_, 54 | const at::Tensor beta_, 55 | float eps); 56 | 57
| void LeakyRelu_Forward_CUDA(at::Tensor z, float slope); 58 | 59 | void LeakyRelu_Backward_CUDA(at::Tensor z, at::Tensor dz, float slope); -------------------------------------------------------------------------------- /core/nn/sync_bn/lib/gpu/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='syncbn_gpu', 6 | ext_modules=[ 7 | CUDAExtension('sync_gpu', [ 8 | 'operator.cpp', 9 | 'activation_kernel.cu', 10 | 'syncbn_kernel.cu', 11 | ]), 12 | ], 13 | cmdclass={ 14 | 'build_ext': BuildExtension 15 | }) -------------------------------------------------------------------------------- /core/nn/sync_bn/syncbn.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Synchronized Cross-GPU Batch Normalization Module""" 12 | import warnings 13 | import torch 14 | 15 | from torch.nn.modules.batchnorm import _BatchNorm 16 | from queue import Queue 17 | from .functions import * 18 | 19 | __all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'] 20 | 21 | 22 | # Adopt from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py 23 | class SyncBatchNorm(_BatchNorm): 24 | """Cross-GPU Synchronized Batch normalization (SyncBN) 25 | 26 | Parameters: 27 | num_features: num_features from an expected input of 28 | size batch_size x num_features x height x width 29 | eps: a value added to the denominator for numerical stability. 30 | Default: 1e-5 31 | momentum: the value used for the running_mean and running_var 32 | computation. Default: 0.1 33 | sync: a boolean value that when set to ``True``, synchronize across 34 | different gpus. Default: ``True`` 35 | activation : str 36 | Name of the activation functions, one of: `leaky_relu` or `none`. 37 | slope : float 38 | Negative slope for the `leaky_relu` activation. 39 | 40 | Shape: 41 | - Input: :math:`(N, C, H, W)` 42 | - Output: :math:`(N, C, H, W)` (same shape as input) 43 | Reference: 44 | .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* 45 | .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." 
*CVPR 2018* 46 | Examples: 47 | >>> m = SyncBatchNorm(100) 48 | >>> net = torch.nn.DataParallel(m) 49 | >>> output = net(input) 50 | """ 51 | 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01, inplace=True): 53 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True) 54 | self.activation = activation 55 | self.inplace = False if activation == 'none' else inplace 56 | self.slope = slope 57 | self.devices = list(range(torch.cuda.device_count())) 58 | self.sync = sync if len(self.devices) > 1 else False 59 | # Initialize queues 60 | self.worker_ids = self.devices[1:] 61 | self.master_queue = Queue(len(self.worker_ids)) 62 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 63 | 64 | def forward(self, x): 65 | # resize the input to (B, C, -1) 66 | input_shape = x.size() 67 | x = x.view(input_shape[0], self.num_features, -1) 68 | if x.get_device() == self.devices[0]: 69 | # Master mode 70 | extra = { 71 | "is_master": True, 72 | "master_queue": self.master_queue, 73 | "worker_queues": self.worker_queues, 74 | "worker_ids": self.worker_ids 75 | } 76 | else: 77 | # Worker mode 78 | extra = { 79 | "is_master": False, 80 | "master_queue": self.master_queue, 81 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 82 | } 83 | if self.inplace: 84 | return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 85 | extra, self.sync, self.training, self.momentum, self.eps, 86 | self.activation, self.slope).view(input_shape) 87 | else: 88 | return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 89 | extra, self.sync, self.training, self.momentum, self.eps, 90 | self.activation, self.slope).view(input_shape) 91 | 92 | def extra_repr(self): 93 | if self.activation == 'none': 94 | return 'sync={}'.format(self.sync) 95 | else: 96 | return 'sync={}, act={}, slope={}, inplace={}'.format( 97 | self.sync, self.activation, self.slope, self.inplace) 98 | 99 | 100 | class BatchNorm1d(SyncBatchNorm): 101 | """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 102 | 103 | def __init__(self, *args, **kwargs): 104 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 105 | .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning) 106 | super(BatchNorm1d, self).__init__(*args, **kwargs) 107 | 108 | 109 | class BatchNorm2d(SyncBatchNorm): 110 | """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 111 | 112 | def __init__(self, *args, **kwargs): 113 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 114 | .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning) 115 | super(BatchNorm2d, self).__init__(*args, **kwargs) 116 | 117 | 118 | class BatchNorm3d(SyncBatchNorm): 119 | """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`.""" 120 | 121 | def __init__(self, *args, **kwargs): 122 | warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}." 123 | .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning) 124 | super(BatchNorm3d, self).__init__(*args, **kwargs) 125 | --------------------------------------------------------------------------------
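A brief usage sketch for `SyncBatchNorm` as defined above (cross-GPU synchronization only activates under `DataParallel` with more than one visible device):

# Sketch: SyncBatchNorm as a drop-in norm layer with an optional fused activation.
import torch
from core.nn.sync_bn.syncbn import SyncBatchNorm

bn = SyncBatchNorm(64, activation='leaky_relu', slope=0.01).cuda()
net = torch.nn.DataParallel(bn)
x = torch.randn(8, 64, 32, 32).cuda()
y = net(x)  # statistics are synchronized across the visible GPUs during training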
/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__ import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs, try_import_pycocotools 6 | -------------------------------------------------------------------------------- /core/utils/download.py: -------------------------------------------------------------------------------- 1 | """Download files with progress bar.""" 2 | import os 3 | import hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | def check_sha1(filename, sha1_hash): 8 | """Check whether the sha1 hash of the file content matches the expected hash. 9 | Parameters 10 | ---------- 11 | filename : str 12 | Path to the file. 13 | sha1_hash : str 14 | Expected sha1 hash in hexadecimal digits. 15 | Returns 16 | ------- 17 | bool 18 | Whether the file content matches the expected hash. 19 | """ 20 | sha1 = hashlib.sha1() 21 | with open(filename, 'rb') as f: 22 | while True: 23 | data = f.read(1048576) 24 | if not data: 25 | break 26 | sha1.update(data) 27 | 28 | sha1_file = sha1.hexdigest() 29 | l = min(len(sha1_file), len(sha1_hash)) 30 | return sha1_file[0:l] == sha1_hash[0:l] 31 | 32 | def download(url, path=None, overwrite=False, sha1_hash=None): 33 | """Download a given URL 34 | Parameters 35 | ---------- 36 | url : str 37 | URL to download 38 | path : str, optional 39 | Destination path to store downloaded file. By default stores to the 40 | current directory with same name as in url. 41 | overwrite : bool, optional 42 | Whether to overwrite the destination file if it already exists. 43 | sha1_hash : str, optional 44 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 45 | but doesn't match. 46 | Returns 47 | ------- 48 | str 49 | The file path of the downloaded file. 50 | """ 51 | if path is None: 52 | fname = url.split('/')[-1] 53 | else: 54 | path = os.path.expanduser(path) 55 | if os.path.isdir(path): 56 | fname = os.path.join(path, url.split('/')[-1]) 57 | else: 58 | fname = path 59 | 60 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 61 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 62 | if not os.path.exists(dirname): 63 | os.makedirs(dirname) 64 | 65 | print('Downloading %s from %s...'%(fname, url)) 66 | r = requests.get(url, stream=True) 67 | if r.status_code != 200: 68 | raise RuntimeError("Failed downloading url %s"%url) 69 | total_length = r.headers.get('content-length') 70 | with open(fname, 'wb') as f: 71 | if total_length is None: # no content length header 72 | for chunk in r.iter_content(chunk_size=1024): 73 | if chunk: # filter out keep-alive new chunks 74 | f.write(chunk) 75 | else: 76 | total_length = int(total_length) 77 | for chunk in tqdm(r.iter_content(chunk_size=1024), 78 | total=int(total_length / 1024. + 0.5), 79 | unit='KB', unit_scale=False, dynamic_ncols=True): 80 | f.write(chunk) 81 | 82 | if sha1_hash and not check_sha1(fname, sha1_hash): 83 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 84 | 'The repo may be outdated or download may be incomplete. ' \ 85 | 'If the "repo_url" is overridden, consider switching to ' \ 86 | 'the default repo.'.format(fname)) 87 | 88 | return fname --------------------------------------------------------------------------------
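A quick sketch of the two helpers above working together (the URL and sha1 below are placeholders, not real artifacts):

# Sketch: download once and verify integrity; URL and hash are hypothetical.
from core.utils.download import download, check_sha1

url = 'https://example.com/weights/fcn32s_vgg16.pth'   # placeholder URL
sha1 = '2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d'      # placeholder 40-hex sha1
path = download(url, path='weights/', sha1_hash=sha1)  # skips re-download if the hash matches
assert check_sha1(path, sha1)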
/core/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | 6 | 7 | def makedirs(path): 8 | """Create directory recursively if not exists. 9 | Similar to `mkdir -p`, you can skip checking existence before this function. 10 | Parameters 11 | ---------- 12 | path : str 13 | Path of the desired dir 14 | """ 15 | try: 16 | os.makedirs(path) 17 | except OSError as exc: 18 | if exc.errno != errno.EEXIST: 19 | raise 20 | 21 | 22 | def try_import(package, message=None): 23 | """Try import specified package, with custom message support. 24 | Parameters 25 | ---------- 26 | package : str 27 | The name of the targeting package. 28 | message : str, default is None 29 | If not None, this function will raise customized error message when import error is found. 30 | Returns 31 | ------- 32 | module if found, raise ImportError otherwise 33 | """ 34 | try: 35 | return __import__(package) 36 | except ImportError as e: 37 | if not message: 38 | raise e 39 | raise ImportError(message) 40 | 41 | 42 | def try_import_cv2(): 43 | """Try import cv2 at runtime. 44 | Returns 45 | ------- 46 | cv2 module if found. Raise ImportError otherwise 47 | """ 48 | msg = "cv2 is required, you can install by package manager, e.g. 'apt-get', \ 49 | or `pip install opencv-python --user` (note that this is unofficial PYPI package)." 50 | return try_import('cv2', msg) 51 | 52 | 53 | def import_try_install(package, extern_url=None): 54 | """Try import the specified package. 55 | If the package is not installed, try to install it with pip and import it on success. 56 | Parameters 57 | ---------- 58 | package : str 59 | The name of the package trying to import. 60 | extern_url : str or None, optional 61 | The external url if package is not hosted on PyPI. 62 | For example, you can install a package using: 63 | "pip install git+http://github.com/user/repo/tarball/master/egginfo=xxx". 64 | In this case, you can pass the url to the extern_url. 65 | Returns 66 | ------- 67 | 68 | The imported python module. 69 | """ 70 | try: 71 | return __import__(package) 72 | except ImportError: 73 | try: 74 | from pip import main as pipmain 75 | except ImportError: 76 | from pip._internal import main as pipmain 77 | 78 | # trying to install package 79 | url = package if extern_url is None else extern_url 80 | pipmain(['install', '--user', url]) # will raise SystemExit Error if fails 81 | 82 | # trying to load again 83 | try: 84 | return __import__(package) 85 | except ImportError: 86 | import sys 87 | import site 88 | user_site = site.getusersitepackages() 89 | if user_site not in sys.path: 90 | sys.path.append(user_site) 91 | return __import__(package) 92 | return __import__(package) 93 | 94 | 95 | """Import helper for pycocotools""" 96 | 97 | 98 | # NOTE: for developers 99 | # please do not import any pycocotools in __init__ because we are trying to lazy 100 | # import pycocotools to avoid installing it for other users who may not use it. 101 | # only import when you actually use it 102 | 103 | 104 | def try_import_pycocotools(): 105 | """Tricks to optionally install and import pycocotools""" 106 | # first we can try import pycocotools 107 | try: 108 | import pycocotools as _ 109 | except ImportError: 110 | import os 111 | # we need to install pycocotools, which is a bit tricky 112 | # pycocotools sdist requires Cython, numpy(already met) 113 | import_try_install('cython') 114 | # pypi pycocotools is not compatible with windows 115 | win_url = 'git+https://github.com/zhreshold/cocoapi.git#subdirectory=PythonAPI' 116 | try: 117 | if os.name == 'nt': 118 | import_try_install('pycocotools', win_url) 119 | else: 120 | import_try_install('pycocotools') 121 | except ImportError: 122 | faq = 'cocoapi FAQ' 123 | raise ImportError('Cannot import or install pycocotools, please refer to %s.' % faq) 124 | --------------------------------------------------------------------------------
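A sketch of the lazy-install pattern the helpers above enable (pycocotools is only pulled in when actually requested):

# Sketch: filesystem helpers in use; try_import_pycocotools installs on first need.
from core.utils.filesystem import makedirs, try_import_pycocotools

makedirs('runs/logs')         # no error if the directory already exists
try_import_pycocotools()      # installs cython/pycocotools if they are missing
from pycocotools import mask  # safe to import afterwards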
/core/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | 6 | __all__ = ['setup_logger'] 7 | 8 | 9 | # reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py 10 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 11 | logger = logging.getLogger(name) 12 | logger.setLevel(logging.DEBUG) 13 | # don't log results for the non-master process 14 | if distributed_rank > 0: 15 | return logger 16 | ch = logging.StreamHandler(stream=sys.stdout) 17 | ch.setLevel(logging.DEBUG) 18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 19 | ch.setFormatter(formatter) 20 | logger.addHandler(ch) 21 | 22 | if save_dir: 23 | if not os.path.exists(save_dir): 24 | os.makedirs(save_dir) 25 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /core/utils/parallel.py: -------------------------------------------------------------------------------- 1 | """Utils for Semantic Segmentation""" 2 | import threading 3 | import torch 4 | import torch.cuda.comm as comm 5 | from torch.nn.parallel.data_parallel import DataParallel 6 | from torch.nn.parallel._functions import Broadcast 7 | from torch.autograd import Function 8 | 9 | __all__ = ['DataParallelModel', 'DataParallelCriterion'] 10 | 11 | 12 | class Reduce(Function): 13 | @staticmethod 14 | def forward(ctx, *inputs): 15 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 16 | inputs = sorted(inputs, key=lambda i: i.get_device()) 17 | return comm.reduce_add(inputs) 18 | 19 | @staticmethod 20 | def backward(ctx, gradOutputs): 21 | return Broadcast.apply(ctx.target_gpus, gradOutputs) 22 | 23 | 24 | class DataParallelModel(DataParallel): 25 | """Data parallelism 26 | 27 | Hides the difference between single and multiple GPUs from the user. 28 | In the forward pass, the module is replicated on each device, 29 | and each replica handles a portion of the input. During the backwards 30 | pass, gradients from each replica are summed into the original module. 31 | 32 | The batch size should be larger than the number of GPUs used.
33 | 34 | Parameters 35 | ---------- 36 | module : object 37 | Network to be parallelized. 38 | sync : bool 39 | enable synchronization (default: False). 40 | Inputs: 41 | - **inputs**: list of input 42 | Outputs: 43 | - **outputs**: list of output 44 | Example:: 45 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 46 | >>> output = net(input_var) # input_var can be on any device, including CPU 47 | """ 48 | 49 | def gather(self, outputs, output_device): 50 | return outputs 51 | 52 | def replicate(self, module, device_ids): 53 | modules = super(DataParallelModel, self).replicate(module, device_ids) 54 | return modules 55 | 56 | 57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py 58 | class DataParallelCriterion(DataParallel): 59 | """ 60 | Calculates the loss on multiple GPUs, which balances memory usage for 61 | semantic segmentation. 62 | 63 | The targets are split across the specified devices by chunking in 64 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 65 | 66 | Example:: 67 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 68 | >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 69 | >>> y = net(x) 70 | >>> loss = criterion(y, target) 71 | """ 72 | 73 | def forward(self, inputs, *targets, **kwargs): 74 | # the inputs should be the outputs of DataParallelModel 75 | if not self.device_ids: 76 | return self.module(inputs, *targets, **kwargs) 77 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 78 | if len(self.device_ids) == 1: 79 | return self.module(inputs, *targets[0], **kwargs[0]) 80 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 81 | outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs) 82 | return Reduce.apply(*outputs) / len(outputs) 83 | 84 | 85 | def get_a_var(obj): 86 | if isinstance(obj, torch.Tensor): 87 | return obj 88 | 89 | if isinstance(obj, list) or isinstance(obj, tuple): 90 | for result in map(get_a_var, obj): 91 | if isinstance(result, torch.Tensor): 92 | return result 93 | 94 | if isinstance(obj, dict): 95 | for result in map(get_a_var, obj.items()): 96 | if isinstance(result, torch.Tensor): 97 | return result 98 | return None 99 | 100 | 101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 102 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 103 | contained in :attr:`inputs` (positional), :attr:`targets` (positional) and :attr:`kwargs_tup` (keyword) 104 | on each of :attr:`devices`. 105 | 106 | Args: 107 | modules (Module): modules to be parallelized 108 | inputs (tensor): inputs to the modules 109 | targets (tensor): targets to the modules 110 | devices (list of int or torch.device): CUDA devices 111 | :attr:`modules`, :attr:`inputs`, :attr:`targets`, :attr:`kwargs_tup` (if given), and 112 | :attr:`devices` (if given) should all have same length. Moreover, each 113 | element of :attr:`inputs` can either be a single object as the only argument 114 | to a module, or a collection of positional arguments.
115 | """ 116 | assert len(modules) == len(inputs) 117 | assert len(targets) == len(inputs) 118 | if kwargs_tup is not None: 119 | assert len(modules) == len(kwargs_tup) 120 | else: 121 | kwargs_tup = ({},) * len(modules) 122 | if devices is not None: 123 | assert len(modules) == len(devices) 124 | else: 125 | devices = [None] * len(modules) 126 | lock = threading.Lock() 127 | results = {} 128 | grad_enabled = torch.is_grad_enabled() 129 | 130 | def _worker(i, module, input, target, kwargs, device=None): 131 | torch.set_grad_enabled(grad_enabled) 132 | if device is None: 133 | device = get_a_var(input).get_device() 134 | try: 135 | with torch.cuda.device(device): 136 | output = module(*(list(input) + target), **kwargs) 137 | with lock: 138 | results[i] = output 139 | except Exception as e: 140 | with lock: 141 | results[i] = e 142 | 143 | if len(modules) > 1: 144 | threads = [threading.Thread(target=_worker, 145 | args=(i, module, input, target, kwargs, device)) 146 | for i, (module, input, target, kwargs, device) in 147 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 148 | 149 | for thread in threads: 150 | thread.start() 151 | for thread in threads: 152 | thread.join() 153 | else: 154 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 155 | 156 | outputs = [] 157 | for i in range(len(inputs)): 158 | output = results[i] 159 | if isinstance(output, Exception): 160 | raise output 161 | outputs.append(output) 162 | return outputs 163 | -------------------------------------------------------------------------------- /core/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color', 6 | 'show_prediction', 'show_colorful_images', 'save_colorful_images'] 7 | 8 | 9 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False): 10 | n = iu.size 11 | lines = [] 12 | for i in range(n): 13 | if class_names is None: 14 | cls = 'Class %d:' % (i + 1) 15 | else: 16 | cls = '%d %s' % (i + 1, class_names[i]) 17 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100)) 18 | mean_IU = np.nanmean(iu) 19 | mean_IU_no_back = np.nanmean(iu[1:]) 20 | if show_no_back: 21 | lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % ( 22 | mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100)) 23 | else: 24 | lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100)) 25 | lines.append('=================================================') 26 | line = "\n".join(lines) 27 | 28 | print(line) 29 | 30 | 31 | def set_img_color(img, label, colors, background=0, show255=False): 32 | for i in range(len(colors)): 33 | if i != background: 34 | img[np.where(label == i)] = colors[i] 35 | if show255: 36 | img[np.where(label == 255)] = 255 37 | 38 | return img 39 | 40 | 41 | def show_prediction(img, pred, colors, background=0): 42 | im = np.array(img, np.uint8) 43 | set_img_color(im, pred, colors, background) 44 | out = np.array(im) 45 | 46 | return out 47 | 48 | 49 | def show_colorful_images(prediction, palettes): 50 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 51 | im.show() 52 | 53 | 54 | def save_colorful_images(prediction, filename, output_dir, palettes): 55 | ''' 56 | :param prediction: [B, H, W, C] 57 | ''' 58 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 59 | fn = 
os.path.join(output_dir, filename) 60 | out_dir = os.path.split(fn)[0] 61 | if not os.path.exists(out_dir): 62 | os.makedirs(out_dir) # create intermediate directories if needed 63 | im.save(fn) 64 | 65 | 66 | def get_color_pallete(npimg, dataset='pascal_voc'): 67 | """Visualize a segmentation mask with a dataset color palette. 68 | 69 | Parameters 70 | ---------- 71 | npimg : numpy.ndarray 72 | Single channel image with shape `H, W, 1`. 73 | dataset : str, default: 'pascal_voc' 74 | The dataset the model was pretrained on. ('pascal_voc', 'ade20k') 75 | Returns 76 | ------- 77 | out_img : PIL.Image 78 | Image with color palette 79 | """ 80 | # recover the boundary (ignore) label 81 | if dataset in ('pascal_voc', 'pascal_aug'): 82 | npimg[npimg == -1] = 255 83 | # apply the colormap 84 | if dataset == 'ade20k': 85 | npimg = npimg + 1 86 | out_img = Image.fromarray(npimg.astype('uint8')) 87 | out_img.putpalette(adepallete) 88 | return out_img 89 | elif dataset == 'citys': 90 | out_img = Image.fromarray(npimg.astype('uint8')) 91 | out_img.putpalette(cityspallete) 92 | return out_img 93 | out_img = Image.fromarray(npimg.astype('uint8')) 94 | out_img.putpalette(vocpallete) 95 | return out_img 96 | 97 | 98 | def _getvocpallete(num_cls): 99 | n = num_cls 100 | pallete = [0] * (n * 3) 101 | for j in range(0, n): 102 | lab = j 103 | pallete[j * 3 + 0] = 0 104 | pallete[j * 3 + 1] = 0 105 | pallete[j * 3 + 2] = 0 106 | i = 0 107 | while lab > 0: 108 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 109 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 110 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 111 | i = i + 1 112 | lab >>= 3 113 | return pallete 114 | 115 | 116 | vocpallete = _getvocpallete(256) 117 | 118 | adepallete = [ 119 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204, 120 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82, 121 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255, 122 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6, 123 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255, 124 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15, 125 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255, 126 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163, 127 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255, 128 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0, 129 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0, 130 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255, 131 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255, 132 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0, 133 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0, 134 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0, 135 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 
255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0, 136 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255] 137 | 138 | cityspallete = [ 139 | 128, 64, 128, 140 | 244, 35, 232, 141 | 70, 70, 70, 142 | 102, 102, 156, 143 | 190, 153, 153, 144 | 153, 153, 153, 145 | 250, 170, 30, 146 | 220, 220, 0, 147 | 107, 142, 35, 148 | 152, 251, 152, 149 | 0, 130, 180, 150 | 220, 20, 60, 151 | 255, 0, 0, 152 | 0, 0, 142, 153 | 0, 0, 70, 154 | 0, 60, 100, 155 | 0, 80, 100, 156 | 0, 0, 230, 157 | 119, 11, 32, 158 | ] 159 | -------------------------------------------------------------------------------- /datasets/ade: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/ade -------------------------------------------------------------------------------- /datasets/citys: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/citys -------------------------------------------------------------------------------- /datasets/sbu: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/SBU-shadow -------------------------------------------------------------------------------- /datasets/voc: -------------------------------------------------------------------------------- 1 | /home/tramac/PycharmProjects/Data_zoo/VOCdevkit -------------------------------------------------------------------------------- /docs/DETAILS.md: -------------------------------------------------------------------------------- 1 | ### Model & Backbone 2 | 3 | | Model | Scratch | VGG16 | ResNet18 | ResNet50 | ResNet101 | ResNet152 | DenseNet121 | DenseNet169 | 4 | | :-------: | :-----: | :---: | :------: | :------: | :-------: | :-------: | :---------: | :---------: | 5 | | FCN32s | ✘ | ✓ | | | | | | | 6 | | FCN16s | | ✓ | | | | | | | 7 | | FCN8s | | ✓ | | | | | | | 8 | | FCNv2 | | | | ✓ | ✓ | ✓ | | | 9 | | PSPNet | | | | ✓ | ✓ | ✓ | | | 10 | | DeepLabv3 | | | | ✓ | ✓ | ✓ | | | 11 | | DenseASPP | | | | | | | ✓ | ✓ | 12 | | DANet | | | | ✓ | ✓ | ✓ | | | 13 | | BiSeNet | | | ✓ | | | | | | 14 | | EncNet | | | | ✓ | ✓ | ✓ | | | 15 | | ICNet | | | | ✓ | ✓ | ✓ | | | 16 | | DUNet | | | | ✓ | ✓ | ✓ | | | 17 | | ENet | ✓ | | | | | | | | 18 | | OCNet | | | | ✓ | ✓ | ✓ | | | 19 | | CCNet | | | | ✓ | ✓ | ✓ | | | 20 | | PSANet | | | | ✓ | ✓ | ✓ | | | 21 | | CGNet | ✓ | | | | | | | | 22 | | ESPNet | ✓ | | | | | | | | 23 | | LEDNet | ✓ | | | | | | | | 24 | | DFANet | ✓ | | | | | | | | 25 | -------------------------------------------------------------------------------- /docs/requirements.yml: -------------------------------------------------------------------------------- 1 | name: seg_requirements 2 | dependencies: 3 | - python=3 4 | - numpy 5 | - cudatoolkit 6 | - pip: 7 | - Pillow # provides PIL 8 | - tqdm 9 | - requests 10 | - torch>=1.0 # PyPI name of the PyTorch package 11 | - torchvision 12 | -------------------------------------------------------------------------------- /docs/weimar_000091_000019_gtFine_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/docs/weimar_000091_000019_gtFine_color.png -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import 
argparse 4 | import torch 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | from torchvision import transforms 11 | from PIL import Image 12 | from core.utils.visualize import get_color_pallete 13 | from core.models import get_model 14 | 15 | parser = argparse.ArgumentParser( 16 | description='Predict segmentation result from a given image') 17 | parser.add_argument('--model', type=str, default='fcn32s_vgg16_voc', 18 | help='model name (default: fcn32s_vgg16_voc)') 19 | parser.add_argument('--dataset', type=str, default='pascal_aug', choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys'], 20 | help='dataset name (default: pascal_aug)') 21 | parser.add_argument('--save-folder', default='~/.torch/models', 22 | help='Directory for saving checkpoint models') 23 | parser.add_argument('--input-pic', type=str, default='../datasets/voc/VOC2012/JPEGImages/2007_000032.jpg', 24 | help='path to the input picture') 25 | parser.add_argument('--outdir', default='./eval', type=str, 26 | help='path to save the prediction result') 27 | args = parser.parse_args() 28 | 29 | 30 | def demo(config): 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | # output folder 33 | if not os.path.exists(config.outdir): 34 | os.makedirs(config.outdir) 35 | 36 | # image transform 37 | transform = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 40 | ]) 41 | image = Image.open(config.input_pic).convert('RGB') 42 | images = transform(image).unsqueeze(0).to(device) 43 | 44 | model = get_model(config.model, pretrained=True, root=config.save_folder).to(device) 45 | print('Finished loading model!') 46 | 47 | model.eval() 48 | with torch.no_grad(): 49 | output = model(images) 50 | 51 | pred = torch.argmax(output[0], 1).squeeze(0).cpu().data.numpy() 52 | mask = get_color_pallete(pred, config.dataset) 53 | outname = os.path.splitext(os.path.split(config.input_pic)[-1])[0] + '.png' 54 | mask.save(os.path.join(config.outdir, outname)) 55 | 56 | 57 | if __name__ == '__main__': 58 | demo(args) 59 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import sys 5 | 6 | cur_path = os.path.abspath(os.path.dirname(__file__)) 7 | root_path = os.path.split(cur_path)[0] 8 | sys.path.append(root_path) 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.data as data 13 | import torch.backends.cudnn as cudnn 14 | 15 | from torchvision import transforms 16 | from core.data.dataloader import get_segmentation_dataset 17 | from core.models.model_zoo import get_segmentation_model 18 | from core.utils.score import SegmentationMetric 19 | from core.utils.visualize import get_color_pallete 20 | from core.utils.logger import setup_logger 21 | from core.utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler 22 | 23 | from train import parse_args 24 | 25 | 26 | class Evaluator(object): 27 | def __init__(self, args): 28 | self.args = args 29 | self.device = torch.device(args.device) 30 | 31 | # image transform 32 | input_transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 35 | ]) 36 | 37 | # dataset and dataloader 38 | val_dataset = get_segmentation_dataset(args.dataset, 
split='val', mode='testval', transform=input_transform) 39 | val_sampler = make_data_sampler(val_dataset, False, args.distributed) 40 | val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1) 41 | self.val_loader = data.DataLoader(dataset=val_dataset, 42 | batch_sampler=val_batch_sampler, 43 | num_workers=args.workers, 44 | pin_memory=True) 45 | 46 | # create network 47 | BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d 48 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone, 49 | aux=args.aux, pretrained=True, pretrained_base=False, 50 | local_rank=args.local_rank, 51 | norm_layer=BatchNorm2d).to(self.device) 52 | if args.distributed: 53 | self.model = nn.parallel.DistributedDataParallel(self.model, 54 | device_ids=[args.local_rank], output_device=args.local_rank) 55 | self.model.to(self.device) 56 | 57 | self.metric = SegmentationMetric(val_dataset.num_class) 58 | 59 | def eval(self): 60 | self.metric.reset() 61 | self.model.eval() 62 | if self.args.distributed: 63 | model = self.model.module 64 | else: 65 | model = self.model 66 | logger.info("Start validation, Total sample: {:d}".format(len(self.val_loader))) 67 | for i, (image, target, filename) in enumerate(self.val_loader): 68 | image = image.to(self.device) 69 | target = target.to(self.device) 70 | 71 | with torch.no_grad(): 72 | outputs = model(image) 73 | self.metric.update(outputs[0], target) 74 | pixAcc, mIoU = self.metric.get() 75 | logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format( 76 | i + 1, pixAcc * 100, mIoU * 100)) 77 | 78 | if self.args.save_pred: 79 | pred = torch.argmax(outputs[0], 1) 80 | pred = pred.cpu().data.numpy() 81 | 82 | predict = pred.squeeze(0) 83 | mask = get_color_pallete(predict, self.args.dataset) 84 | mask.save(os.path.join(outdir, os.path.splitext(filename[0])[0] + '.png')) 85 | synchronize() 86 | 87 | 88 | if __name__ == '__main__': 89 | args = parse_args() 90 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 91 | args.distributed = num_gpus > 1 92 | if not args.no_cuda and torch.cuda.is_available(): 93 | cudnn.benchmark = True 94 | args.device = "cuda" 95 | else: 96 | args.distributed = False 97 | args.device = "cpu" 98 | if args.distributed: 99 | torch.cuda.set_device(args.local_rank) 100 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 101 | synchronize() 102 | 103 | # TODO: optim code 104 | args.save_pred = True 105 | if args.save_pred: 106 | outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset) 107 | if not os.path.exists(outdir): 108 | os.makedirs(outdir) 109 | 110 | logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(), 111 | filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+') 112 | 113 | evaluator = Evaluator(args) 114 | evaluator.eval() 115 | torch.cuda.empty_cache() 116 | -------------------------------------------------------------------------------- /scripts/fcn32s_vgg16_pascal_voc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # train 4 | CUDA_VISIBLE_DEVICES=0 python train.py --model fcn32s \ 5 | --backbone vgg16 --dataset pascal_voc \ 6 | --lr 0.0001 --epochs 80 -------------------------------------------------------------------------------- /scripts/fcn32s_vgg16_pascal_voc_dist.sh: -------------------------------------------------------------------------------- 
1 | #!/usr/bin/env bash 2 | 3 | # distributed training: launch one process per GPU; the launcher sets 4 | # WORLD_SIZE and passes --local_rank, which train.py reads 5 | export NGPUS=4 6 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=$NGPUS train.py --model fcn32s \ 7 | --backbone vgg16 --dataset pascal_voc \ 8 | --lr 0.01 --epochs 80 --batch_size 16 -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Overfitting Test 2 | 3 | In order to ensure the correctness of the models, the project provides an overfitting test script (a trick in which the train set and the val set contain the same images). 4 | Observing the convergence process of the different models is quite interesting :joy: 5 | 6 | ### Usage 7 | 8 | 9 | 10 |    (a) img: 2007_000033.jpg        (b) mask: 2007_000033.png 11 | 12 | ### Test Result 13 | | Model | Backbone | Epochs | mIoU | pixAcc | 14 | | :-----: | :----: | :-----: | :-----: | :------: | 15 | | FCN32s | vgg16 | 200 | 94.0% | 98.2% | 16 | | FCN16s | vgg16 | 200 | 99.2% | 99.8% | 17 | | FCN8s | vgg16 | 100 | 99.8% | 99.9% | 18 | | DANet | resnet50 | 100 | 99.5% | 99.9% | 19 | | EncNet | resnet50 | 100 | 99.7% | 99.9% | 20 | | DUNet | resnet50 | 100 | 98.8% | 99.6% | 21 | | PSPNet | resnet50 | 100 | 99.8% | 99.9% | 22 | | BiSeNet | resnet18 | 100 | 99.6% | 99.9% | 23 | | DenseASPP | densenet121 | 40 | 100% | 100% | 24 | | ICNet | resnet50 | 100 | 98.8% | 99.6% | 25 | | ENet | scratch | 100 | 99.9% | 100% | 26 | | OCNet | resnet50 | 100 | 99.8% | 100% | 27 | 28 | ### Visualization 29 | 30 | 31 | 32 | 33 | 34 | 35 |   FCN32s  FCN16s   FCN8s   DANet   EncNet    DUNet   PSPNet   BiSeNet   DenseASPP 36 | 37 | 38 | 39 | 40 |   ICNet   ENet   OCNet 41 | 42 | ### Conclusion 43 | - The result of FCN32s is the worst. 44 | - There are gridding artifacts in the DUNet results. 45 | - The result of BiSeNet is bad with `lr=1e-3`; the lr needs to be set to `1e-2`. 46 | - DenseASPP has the fastest convergence, reaching 100%. 47 | - The lr of ENet needs to be set to `1e-2`; the edges of the result are not smooth. 
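48 | 
49 | ### Run
50 | 
51 | A minimal invocation sketch (the flags mirror `parse_args()` in `tests/test_model.py`; run from `tests/` with the repo root on `PYTHONPATH` so that `core`, `test_img.jpg` and `test_mask.png` are found):
52 | 
53 | ```bash
54 | cd tests
55 | PYTHONPATH=.. python test_model.py --model fcn32s --backbone vgg16 --dataset pascal_voc --epochs 200 --lr 1e-3
56 | ```
57 | 
58 | Per-epoch prediction masks are written to `runs/<model>/` by `save_pred()`.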
-------------------------------------------------------------------------------- /tests/runs/bisenet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/bisenet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/danet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/danet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/denseaspp_epoch_40.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/denseaspp_epoch_40.png -------------------------------------------------------------------------------- /tests/runs/dunet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/dunet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/encnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/encnet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/enet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/enet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/fcn16s_epoch_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/fcn16s_epoch_200.png -------------------------------------------------------------------------------- /tests/runs/fcn32s_epoch_300.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/fcn32s_epoch_300.png -------------------------------------------------------------------------------- /tests/runs/fcn8s_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/fcn8s_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/icnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/icnet_epoch_100.png 
-------------------------------------------------------------------------------- /tests/runs/ocnet_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/ocnet_epoch_100.png -------------------------------------------------------------------------------- /tests/runs/psp_epoch_100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/runs/psp_epoch_100.png -------------------------------------------------------------------------------- /tests/test_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/test_img.jpg -------------------------------------------------------------------------------- /tests/test_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xautdestiny/awesome-semantic-segmentation-pytorch/7fbe397a9add570fe1ebee8654898f2b3ba1942f/tests/test_mask.png -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | """Model overfitting test""" 2 | import argparse 3 | import time 4 | import os 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import numpy as np 10 | 11 | from torchvision import transforms 12 | from core.models.model_zoo import get_segmentation_model 13 | from core.utils.loss import MixSoftmaxCrossEntropyLoss, EncNetLoss, ICNetLoss 14 | from core.utils.lr_scheduler import LRScheduler 15 | from core.utils.score import hist_info, compute_score 16 | from core.utils.visualize import get_color_pallete 17 | from PIL import Image 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Semantic Segmentation Overfitting Test') 22 | # model 23 | parser.add_argument('--model', type=str, default='ocnet', 24 | choices=['fcn32s', 'fcn16s', 'fcn8s', 'fcn', 'psp', 'deeplabv3', 'danet', 'denseaspp', 'bisenet', 'encnet', 'dunet', 'icnet', 'enet', 'ocnet'], 25 | help='model name (default: ocnet)') 26 | parser.add_argument('--backbone', type=str, default='resnet50', 27 | choices=['vgg16', 'resnet18', 'resnet50', 'resnet101', 'resnet152', 'densenet121', 'densenet161', 'densenet169', 'densenet201'], 28 | help='backbone name (default: resnet50)') 29 | parser.add_argument('--dataset', type=str, default='pascal_voc', 30 | choices=['pascal_voc', 'pascal_aug', 'ade20k', 'citys', 'sbu'], 31 | help='dataset name (default: pascal_voc)') 32 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 33 | help='number of epochs to train (default: 100)') 34 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 35 | help='learning rate (default: 1e-3)') 36 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 37 | help='momentum (default: 0.9)') 38 | parser.add_argument('--weight-decay', type=float, default=1e-4, metavar='M', 39 | help='weight decay (default: 1e-4)') 40 | args = parser.parse_args() 41 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 42 | cudnn.benchmark = True 43 | args.device = device 44 | print(args) 45 | return args 46 | 47 | 48 | class 
VOCSegmentation(object): 49 | def __init__(self): 50 | super(VOCSegmentation, self).__init__() 51 | self.img = Image.open('test_img.jpg').convert('RGB') 52 | self.mask = Image.open('test_mask.png') 53 | 54 | self.img = self.img.resize((504, 368), Image.BILINEAR) 55 | self.mask = self.mask.resize((504, 368), Image.NEAREST) 56 | 57 | def get(self): 58 | img, mask = self._img_transform(self.img), self._mask_transform(self.mask) 59 | return img, mask 60 | 61 | def _img_transform(self, img): 62 | input_transform = transforms.Compose([ 63 | transforms.ToTensor(), 64 | transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) 65 | img = input_transform(img) 66 | img = img.unsqueeze(0) 67 | 68 | # For adaptive pooling 69 | # img = torch.cat([img, img], dim=0) 70 | return img 71 | 72 | def _mask_transform(self, mask): 73 | target = np.array(mask).astype('int32') 74 | target[target == 255] = -1 75 | target = torch.from_numpy(target).long() 76 | target = target.unsqueeze(0) 77 | 78 | # For adaptive pooling 79 | # target = torch.cat([target, target], dim=0) 80 | return target 81 | 82 | 83 | class Trainer(object): 84 | def __init__(self, args): 85 | self.args = args 86 | 87 | self.img, self.target = VOCSegmentation().get() 88 | 89 | self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone, 90 | aux=False, norm_layer=nn.BatchNorm2d).to(args.device) 91 | 92 | self.criterion = MixSoftmaxCrossEntropyLoss(False, 0., ignore_label=-1).to(args.device) 93 | 94 | # for EncNet 95 | # self.criterion = EncNetLoss(nclass=21, ignore_label=-1).to(args.device) 96 | # for ICNet 97 | # self.criterion = ICNetLoss(nclass=21, ignore_index=-1).to(args.device) 98 | 99 | self.optimizer = torch.optim.Adam(self.model.parameters(), 100 | lr=args.lr, 101 | weight_decay=args.weight_decay) 102 | self.lr_scheduler = LRScheduler(mode='poly', base_lr=args.lr, nepochs=args.epochs, 103 | iters_per_epoch=1, power=0.9) 104 | 105 | def train(self): 106 | self.model.train() 107 | start_time = time.time() 108 | for epoch in range(self.args.epochs): 109 | cur_lr = self.lr_scheduler(epoch) 110 | for param_group in self.optimizer.param_groups: 111 | param_group['lr'] = cur_lr 112 | 113 | images = self.img.to(self.args.device) 114 | targets = self.target.to(self.args.device) 115 | 116 | outputs = self.model(images) 117 | loss = self.criterion(outputs, targets) 118 | 119 | self.optimizer.zero_grad() 120 | loss.backward() 121 | self.optimizer.step() 122 | 123 | pred = torch.argmax(outputs[0], 1).cpu().data.numpy() 124 | mask = get_color_pallete(pred.squeeze(0), self.args.dataset) 125 | save_pred(self.args, epoch, mask) 126 | hist, labeled, correct = hist_info(pred, targets.cpu().numpy(), 21) # move targets back to the CPU before scoring 127 | _, mIoU, _, pixAcc = compute_score(hist, correct, labeled) 128 | 129 | print('Epoch: [%2d/%2d] || Time: %4.4f sec || lr: %.8f || Loss: %.4f || pixAcc: %.3f || mIoU: %.3f' % ( 130 | epoch, self.args.epochs, time.time() - start_time, cur_lr, loss.item(), pixAcc, mIoU)) 131 | 132 | 133 | def save_pred(args, epoch, mask): 134 | directory = "runs/%s/" % (args.model) 135 | if not os.path.exists(directory): 136 | os.makedirs(directory) 137 | filename = directory + '{}_epoch_{}.png'.format(args.model, epoch + 1) 138 | mask.save(filename) 139 | 140 | 141 | if __name__ == '__main__': 142 | args = parse_args() 143 | trainer = Trainer(args) 144 | print('Test model: ', args.model) 145 | trainer.train() 146 | -------------------------------------------------------------------------------- /tests/test_module.py: 
-------------------------------------------------------------------------------- 1 | import core 2 | import torch 3 | import numpy as np 4 | 5 | from torch.autograd import Variable 6 | 7 | EPS = 1e-3 8 | ATOL = 1e-3 9 | 10 | 11 | def _assert_tensor_close(a, b, atol=ATOL, rtol=EPS): 12 | npa, npb = a.cpu().numpy(), b.cpu().numpy() 13 | assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ 14 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( 15 | a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 16 | 17 | 18 | def testSyncBN(): 19 | def _check_batchnorm_result(bn1, bn2, input, is_train, cuda=False): 20 | def _find_bn(module): 21 | for m in module.modules(): 22 | if isinstance(m, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, 23 | core.nn.SyncBatchNorm)): 24 | return m 25 | 26 | def _syncParameters(bn1, bn2): 27 | bn1.reset_parameters() 28 | bn2.reset_parameters() 29 | if bn1.affine and bn2.affine: 30 | bn2.weight.data.copy_(bn1.weight.data) 31 | bn2.bias.data.copy_(bn1.bias.data) 32 | bn2.running_mean.copy_(bn1.running_mean) 33 | bn2.running_var.copy_(bn1.running_var) 34 | 35 | bn1.train(mode=is_train) 36 | bn2.train(mode=is_train) 37 | 38 | if cuda: 39 | input = input.cuda() 40 | # using the same values for gamma and beta 41 | _syncParameters(_find_bn(bn1), _find_bn(bn2)) 42 | 43 | input1 = Variable(input.clone().detach(), requires_grad=True) 44 | input2 = Variable(input.clone().detach(), requires_grad=True) 45 | if is_train: 46 | bn1.train() 47 | bn2.train() 48 | output1 = bn1(input1) 49 | output2 = bn2(input2) 50 | else: 51 | bn1.eval() 52 | bn2.eval() 53 | with torch.no_grad(): 54 | output1 = bn1(input1) 55 | output2 = bn2(input2) 56 | # assert forwarding 57 | # _assert_tensor_close(input1.data, input2.data) 58 | _assert_tensor_close(output1.data, output2.data) 59 | if not is_train: 60 | return 61 | (output1 ** 2).sum().backward() 62 | (output2 ** 2).sum().backward() 63 | _assert_tensor_close(_find_bn(bn1).bias.grad.data, _find_bn(bn2).bias.grad.data) 64 | _assert_tensor_close(_find_bn(bn1).weight.grad.data, _find_bn(bn2).weight.grad.data) 65 | _assert_tensor_close(input1.grad.data, input2.grad.data) 66 | _assert_tensor_close(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 67 | # _assert_tensor_close(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 68 | 69 | bn = torch.nn.BatchNorm2d(10).cuda().double() 70 | sync_bn = core.nn.SyncBatchNorm(10, inplace=True, sync=True).cuda().double() 71 | sync_bn = torch.nn.DataParallel(sync_bn).cuda() 72 | # check with unsync version 73 | # _check_batchnorm_result(bn, sync_bn, torch.rand(2, 1, 2, 2).double(), True, cuda=True) 74 | for i in range(10): 75 | print(i) 76 | _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), True, cuda=True) 77 | # _check_batchnorm_result(bn, sync_bn, torch.rand(16, 10, 16, 16).double(), False, cuda=True) 78 | 79 | 80 | if __name__ == '__main__': 81 | import nose 82 | 83 | nose.runmodule() 84 | --------------------------------------------------------------------------------