├── .gitignore ├── README.md ├── docs └── imgs │ ├── MP_SEL_SUR_001453_out.jpg │ ├── MP_SEL_SUR_001456_out.jpg │ ├── MP_SEL_SUR_001457.jpg │ ├── MP_SEL_SUR_001457.png │ ├── MP_SEL_SUR_001503_out.jpg │ ├── MP_SEL_SUR_001563_out.jpg │ ├── deeplabv3p.jpg │ └── logs.png ├── evaluate.py ├── modules ├── __init__.py ├── dataloaders │ ├── __init__.py │ ├── custom_transforms.py │ ├── datasets │ │ ├── __init__.py │ │ ├── cityscapes.py │ │ ├── coco.py │ │ ├── combine_dbs.py │ │ ├── pascal.py │ │ ├── sbd.py │ │ └── surface.py │ └── utils.py ├── models │ ├── deeplab_xception.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py └── utils │ ├── __init__.py │ ├── calculate_weights.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── metrics.py │ ├── saver.py │ ├── summaries.py │ ├── surface_dataset_tools │ ├── split_dataset.py │ └── surface_polygon.py │ └── torch_logger.py ├── predict.py ├── settings.py ├── test └── jpgs │ ├── MP_SEL_SUR_001453.jpg │ ├── MP_SEL_SUR_001456.jpg │ ├── MP_SEL_SUR_001503.jpg │ └── MP_SEL_SUR_001563.jpg └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/ 107 | 108 | output/ 109 | run/ 110 | test/test.mp4 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # segmentation-selectstar 2 | 3 | ![](https://media.giphy.com/media/S7KnEAj0ZYpXeDLLuJ/giphy.gif) 4 | 5 | # Introduction 6 | 7 | This is a prototype model of **Pedestrian zone detection for blind people**. 8 | 9 | Separates **sidewalk** and **driveway** areas using **Semantic Segmentation**. 
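Under the hood this is a single DeepLab v3+ network that labels every pixel with one of the surface classes listed in the training section below. The sketch that follows is a minimal, illustrative example of the inference flow, assuming the module layout of this repository and the trained checkpoint from the "Download Trained Weights" section; ```predict.py``` is the full implementation.
```
import numpy as np
import torch
from PIL import Image

import settings
from modules.models.deeplab_xception import DeepLabv3_plus
from modules.dataloaders.utils import decode_segmap

# Build DeepLab v3+ and load the trained weights (CPU here for simplicity).
model = DeepLabv3_plus(nInputChannels=3, n_classes=settings.num_classes, os=16,
                       pretrained=False, _print=True)
checkpoint = torch.load('run/surface/deeplab/model_iou_77.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Preprocess one image the way the validation dataloader does:
# resize, scale to [0, 1], normalize with ImageNet statistics, HWC -> CHW.
image = Image.open('test/jpgs/MP_SEL_SUR_001453.jpg').convert('RGB')
image = image.resize((settings.resize_width, settings.resize_height), Image.BILINEAR)
x = np.array(image, dtype=np.float32) / 255.0
x = (x - (0.485, 0.456, 0.406)) / (0.229, 0.224, 0.225)
x = torch.from_numpy(x.transpose(2, 0, 1)).unsqueeze(0).float()

# Per-pixel class prediction, colorized with the class color map from settings.py.
with torch.no_grad():
    pred = model(x).argmax(dim=1)[0].numpy()
rgb = decode_segmap(pred, dataset='surface')  # float RGB image in [0, 1]
```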
10 | 11 | Sample Results | - 12 | --- | --- 13 | ![](docs/imgs/MP_SEL_SUR_001453_out.jpg) | ![](docs/imgs/MP_SEL_SUR_001456_out.jpg) 14 | ![](docs/imgs/MP_SEL_SUR_001503_out.jpg) | ![](docs/imgs/MP_SEL_SUR_001563_out.jpg) 15 | 16 | Original repository: https://github.com/jfzhang95/pytorch-deeplab-xception 17 | 18 | Modified to train on the NIA SurfaceMasking dataset by yoongi@selectstar.ai 19 | 20 | 21 | # Model 22 | 23 | ### DeepLab v3+ 24 | ![](docs/imgs/deeplabv3p.jpg) 25 | 26 | [Paper] https://arxiv.org/abs/1802.02611 27 | 28 | 29 | # Training on the Surface Masking Dataset 30 | 31 | 1. **Download the NIA Surface Masking dataset from AIhub.** (Not yet published) 32 | 33 | The original NIA Surface Masking dataset consists of the following classes: 34 | 35 | class@attribute | Meaning 36 | --- | --- 37 | alley@crosswalk|Alley - crosswalk 38 | alley@damaged|Alley - damaged 39 | alley@normal|Alley - no attribute 40 | alley@speed_bump|Alley - speed bump 41 | bike_lane|Bike lane 42 | braille_guide_blocks@damaged|Braille guide blocks - damaged 43 | braille_guide_blocks@normal|Braille guide blocks - no attribute 44 | caution_zone@grating|Caution zone - grating 45 | caution_zone@manhole|Caution zone - manhole 46 | caution_zone@repair_zone|Caution zone - repair zone 47 | caution_zone@stairs|Caution zone - stairs 48 | caution_zone@tree_zone|Caution zone - street-tree zone 49 | roadway@crosswalk|Roadway - crosswalk 50 | roadway@normal|Roadway - no attribute 51 | sidewalk@asphalt|Sidewalk - asphalt 52 | sidewalk@blocks|Sidewalk - paving blocks 53 | sidewalk@cement|Sidewalk - cement 54 | sidewalk@damaged|Sidewalk - damaged 55 | sidewalk@other|Sidewalk - other 56 | sidewalk@soil_stone|Sidewalk - soil / stone / unpaved 57 | sidewalk@urethane|Sidewalk - urethane 58 | 59 | **There are too many classes to segment reliably, so I reduced them to six surface classes plus background:** 60 | 61 | New Class | Label | RGB Color 62 | --- | --- | --- 63 | background|0|[0, 0, 0] 64 | bike_lane|1|[255, 128, 0] 65 | caution_zone|2|[255, 0, 0] 66 | crosswalk|3|[255, 0, 255] 67 | guide_block|4|[255, 255, 0] 68 | roadway|5|[0, 0, 255] 69 | sidewalk|6|[0, 255, 0] 70 | 71 | **See ```settings.py``` for the detailed class definitions.** 72 | 73 | 74 | 2. **Generate mask images by running:** 75 | 1. ```modules/utils/surface_dataset_tools/surface_polygon.py``` 76 | 2. ```modules/utils/surface_dataset_tools/split_dataset.py``` 77 | 78 | image | mask 79 | --- | --- 80 | ![](docs/imgs/MP_SEL_SUR_001457.jpg)|![](docs/imgs/MP_SEL_SUR_001457.png) 81 | 82 | 3. **The dataset directory structure should look like this:** 83 | ``` 84 | surface6 85 | ├── annotations 86 | │ ├── *.xml 87 | ├── images 88 | │ ├── *.jpg 89 | ├── masks 90 | │ ├── *.png 91 | ├── train.txt 92 | └── valid.txt 93 | ``` 94 | 4. **Install Python packages** 95 | ``` 96 | Install Anaconda3 [https://www.anaconda.com/distribution/] 97 | conda create -n ml 98 | conda activate ml 99 | conda install conda 100 | conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 101 | pip install tensorboardX matplotlib 102 | ``` 103 | 5. **Edit the training options in ```settings.py```** 104 | ``` 105 | Designate the dataset directory: 106 | ... 107 | elif dataset == 'surface': 108 | root_dir = '/home/super/Projects/dataset/surface6' 109 | ... 110 | ``` 111 | 6. **Run ```train.py```** 112 | 1. On Windows: ```python train.py``` 113 | 2. On Linux: ```python3 train.py``` 114 | 3.
On multi-gpu: ```CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train.py``` 115 | 116 | 117 | # Download Trained Weights 118 | [Model Weights](https://drive.google.com/file/d/1Y8RhV3hWEoE4mqbriGbAyMDQMIaQdrnb/view?usp=sharing) 119 | 120 | Download and put it into ```./run/surface/deeplab/model_iou_77.pth.tar``` 121 | 122 | (Just create directory or Edit settings.py -> checkpoint, predict.py -> MODEL_PATH) 123 | 124 | [Settings for Reproduction](https://drive.google.com/drive/folders/16Pu_N7TOJN6NA9d92ohREWsVy9cWRH1i?usp=sharing) 125 | 126 | ![](docs/imgs/logs.png) 127 | Trained on TitanXP x 4 128 | 129 | 130 | # Predict 131 | 1. Prepare 'mp4 video' or 'jpg images' to predict. And put it into 'test' directory. 132 | 2. Prepare trained model like ```model_iou_77.pth.tar``` 133 | 2. Edit ```RUN OPTIONS``` on predict.py 134 | ``` 135 | MODEL_PATH, MODE, DATA_PATH, OUTPUT_PATH 136 | ``` 137 | 3. Run ```predict.py``` 138 | 4. Output result will be saved to OUTPUT_PATH 139 | 140 | 141 | # Evaluate 142 | 1. Prepare dataset and trained model file. 143 | 2. Check settings.py options. 144 | 3. Run evaluate.py 145 | 146 | ### Performance 147 | 148 | Result of 2000 random selected validation set. 149 | 150 | (fwIoU: Frequency Weighted Intersection over Union) 151 | 152 | Acc | Acc_class | mIoU | fwIoU 153 | --- | --- | --- | --- 154 | 91.46% | 84.74% | 77.29% | 84.34% 155 | 156 | IoU of each class 157 | 158 | Class | IoU 159 | --- | --- 160 | background|85.40% 161 | bike_lane|64.78% 162 | caution_zone|57.19% 163 | crosswalk|80.21% 164 | guide_block|81.34% 165 | roadway|85.69% 166 | sidewalk|86.45% 167 | -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001453_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001453_out.jpg -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001456_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001456_out.jpg -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001457.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001457.jpg -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001457.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001457.png -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001503_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001503_out.jpg -------------------------------------------------------------------------------- /docs/imgs/MP_SEL_SUR_001563_out.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/MP_SEL_SUR_001563_out.jpg -------------------------------------------------------------------------------- /docs/imgs/deeplabv3p.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/deeplabv3p.jpg -------------------------------------------------------------------------------- /docs/imgs/logs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/docs/imgs/logs.png -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import settings 7 | from modules.dataloaders import make_data_loader 8 | from modules.models.sync_batchnorm.replicate import patch_replication_callback 9 | from modules.models.deeplab_xception import DeepLabv3_plus 10 | from modules.utils.loss import SegmentationLosses 11 | from modules.utils.calculate_weights import calculate_weigths_labels 12 | from modules.utils.metrics import Evaluator 13 | 14 | """ 15 | Running this program requires settings.py options below 16 | 17 | settings.resume = True 18 | settings.checkpoint = 'PathToCheckpointModel.pth.tar' 19 | settings.dataset = 'surface' 20 | settings.root_dir = '/path/to/surface6' 21 | settings.num_classes 22 | settings.resize_height 23 | settings.resize_width 24 | settings.batch_size 25 | settings.workers 26 | (settings.use_sbd=False) 27 | settings.use_balanced_weights 28 | settings.cuda 29 | settings.loss_type 30 | settings.gpu_ids 31 | settings.labels 32 | """ 33 | 34 | 35 | class Trainer(object): 36 | def __init__(self, ): 37 | # Define Dataloader 38 | kwargs = {'num_workers': settings.workers, 'pin_memory': True} 39 | self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(**kwargs) 40 | 41 | # Define network 42 | self.model = DeepLabv3_plus(nInputChannels=3, n_classes=self.nclass, os=16, pretrained=settings.pretrained, 43 | _print=True) 44 | 45 | # Define Criterion 46 | # whether to use class balanced weights 47 | if settings.use_balanced_weights: 48 | classes_weights_path = os.path.join(settings.root_dir, settings.dataset + '_classes_weights.npy') 49 | if os.path.isfile(classes_weights_path): 50 | weight = np.load(classes_weights_path) 51 | else: 52 | weight = calculate_weigths_labels(settings.dataset, self.train_loader, self.nclass) 53 | weight = torch.from_numpy(weight.astype(np.float32)) 54 | else: 55 | weight = None 56 | self.criterion = SegmentationLosses(weight=weight, cuda=settings.cuda).build_loss(mode=settings.loss_type) 57 | 58 | # Define Evaluator 59 | self.evaluator = Evaluator(self.nclass) 60 | 61 | # Using cuda 62 | if settings.cuda: 63 | self.model = torch.nn.DataParallel(self.model, device_ids=settings.gpu_ids) 64 | patch_replication_callback(self.model) 65 | self.model = self.model.cuda() 66 | 67 | # Resuming checkpoint 68 | self.best_pred = 0.0 69 | if settings.resume is False: 70 | print("settings.resume is False but ignoring...") 71 | if not 
os.path.isfile(settings.checkpoint): 72 | raise RuntimeError("=> no checkpoint found at '{}'.\ 73 | Please designate pretrained weights file to settings.checkpoint='~.pth.tar'.".format(settings.checkpoint)) 74 | checkpoint = torch.load(settings.checkpoint) 75 | settings.start_epoch = checkpoint['epoch'] 76 | if settings.cuda: 77 | self.model.module.load_state_dict(checkpoint['state_dict']) 78 | else: 79 | self.model.load_state_dict(checkpoint['state_dict']) 80 | # if not settings.ft: 81 | # self.optimizer.load_state_dict(checkpoint['optimizer']) 82 | self.best_pred = checkpoint['best_pred'] 83 | print("=> loaded checkpoint '{}' (epoch {})" 84 | .format(settings.checkpoint, checkpoint['epoch'])) 85 | 86 | def validation(self): 87 | self.model.eval() 88 | self.evaluator.reset() 89 | tbar = tqdm(self.val_loader, desc='\r') 90 | test_loss = 0.0 91 | for i, sample in enumerate(tbar): 92 | image, target = sample['image'], sample['label'] 93 | if settings.cuda: 94 | image, target = image.cuda(), target.cuda() 95 | with torch.no_grad(): 96 | output = self.model(image) 97 | loss = self.criterion(output, target) 98 | test_loss += loss.item() 99 | tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) 100 | pred = output.data.cpu().numpy() 101 | target = target.cpu().numpy() 102 | pred = np.argmax(pred, axis=1) 103 | # Add batch sample into evaluator 104 | self.evaluator.add_batch(target, pred) 105 | 106 | # Fast test during the training 107 | Acc = self.evaluator.Pixel_Accuracy() 108 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 109 | mIoU = self.evaluator.Mean_Intersection_over_Union() 110 | FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 111 | print('Validation:') 112 | print('numImages: %5d' % (i * settings.batch_size + image.data.shape[0])) 113 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 114 | print('Loss: %.3f' % test_loss) 115 | ClassIoU = self.evaluator.Intersection_over_Union() 116 | 117 | print('IoU of each class') 118 | for index, label in enumerate(settings.labels): 119 | print('{}: {}'.format(label, ClassIoU[index])) 120 | 121 | if __name__ == "__main__": 122 | trainer = Trainer() 123 | trainer.validation() 124 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/modules/__init__.py -------------------------------------------------------------------------------- /modules/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import settings 3 | 4 | def make_data_loader(**kwargs): 5 | 6 | if settings.dataset == 'pascal': 7 | from modules.dataloaders.datasets import pascal 8 | 9 | train_set = pascal.VOCSegmentation(settings, split='train') 10 | val_set = pascal.VOCSegmentation(settings, split='val') 11 | if settings.use_sbd: 12 | from modules.dataloaders.datasets import sbd, combine_dbs 13 | sbd_train = sbd.SBDSegmentation(settings, split=['train', 'val']) 14 | train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 15 | 16 | num_class = train_set.NUM_CLASSES 17 | train_loader = DataLoader(train_set, batch_size=settings.batch_size, shuffle=True, **kwargs) 18 | val_loader = DataLoader(val_set, batch_size=settings.batch_size, shuffle=False, 
**kwargs) 19 | test_loader = None 20 | 21 | return train_loader, val_loader, test_loader, num_class 22 | 23 | elif settings.dataset == 'cityscapes': 24 | from modules.dataloaders.datasets import cityscapes 25 | 26 | train_set = cityscapes.CityscapesSegmentation(settings, split='train') 27 | val_set = cityscapes.CityscapesSegmentation(settings, split='val') 28 | test_set = cityscapes.CityscapesSegmentation(settings, split='test') 29 | num_class = train_set.NUM_CLASSES 30 | train_loader = DataLoader(train_set, batch_size=settings.batch_size, shuffle=True, **kwargs) 31 | val_loader = DataLoader(val_set, batch_size=settings.batch_size, shuffle=False, **kwargs) 32 | test_loader = DataLoader(test_set, batch_size=settings.batch_size, shuffle=False, **kwargs) 33 | 34 | return train_loader, val_loader, test_loader, num_class 35 | 36 | elif settings.dataset == 'coco': 37 | from modules.dataloaders.datasets import coco 38 | 39 | train_set = coco.COCOSegmentation(split='train') 40 | val_set = coco.COCOSegmentation(settings.root_dir, split='valid') 41 | num_class = train_set.NUM_CLASSES 42 | train_loader = DataLoader(train_set, batch_size=settings.batch_size, shuffle=True, **kwargs) 43 | val_loader = DataLoader(val_set, batch_size=settings.batch_size, shuffle=False, **kwargs) 44 | test_loader = None 45 | return train_loader, val_loader, test_loader, num_class 46 | 47 | elif settings.dataset == 'surface': 48 | from modules.dataloaders.datasets import surface 49 | 50 | train_set = surface.SurfaceSegmentation(split='train') 51 | val_set = surface.SurfaceSegmentation(split='valid') 52 | num_class = train_set.NUM_CLASSES 53 | train_loader = DataLoader(train_set, batch_size=settings.batch_size, shuffle=True, **kwargs) 54 | val_loader = DataLoader(val_set, batch_size=settings.batch_size, shuffle=False, **kwargs) 55 | test_loader = None 56 | return train_loader, val_loader, test_loader, num_class 57 | 58 | else: 59 | raise NotImplementedError 60 | 61 | -------------------------------------------------------------------------------- /modules/dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | class Normalize(object): 8 | """Normalize a tensor image with mean and standard deviation. 9 | Args: 10 | mean (tuple): means for each channel. 11 | std (tuple): standard deviations for each channel. 
12 | """ 13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | img = sample['image'] 19 | mask = sample['label'] 20 | img = np.array(img).astype(np.float32) 21 | mask = np.array(mask).astype(np.float32) 22 | img /= 255.0 23 | img -= self.mean 24 | img /= self.std 25 | 26 | return {'image': img, 27 | 'label': mask} 28 | 29 | 30 | class ToTensor(object): 31 | """Convert ndarrays in sample to Tensors.""" 32 | 33 | def __call__(self, sample): 34 | # swap color axis because 35 | # numpy image: H x W x C 36 | # torch image: C X H X W 37 | img = sample['image'] 38 | mask = sample['label'] 39 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 40 | mask = np.array(mask).astype(np.float32) 41 | 42 | img = torch.from_numpy(img).float() 43 | mask = torch.from_numpy(mask).float() 44 | 45 | return {'image': img, 46 | 'label': mask} 47 | 48 | 49 | class RandomHorizontalFlip(object): 50 | def __call__(self, sample): 51 | img = sample['image'] 52 | mask = sample['label'] 53 | if random.random() < 0.5: 54 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 56 | 57 | return {'image': img, 58 | 'label': mask} 59 | 60 | 61 | class RandomRotate(object): 62 | def __init__(self, degree): 63 | self.degree = degree 64 | 65 | def __call__(self, sample): 66 | img = sample['image'] 67 | mask = sample['label'] 68 | rotate_degree = random.uniform(-1*self.degree, self.degree) 69 | img = img.rotate(rotate_degree, Image.BILINEAR) 70 | mask = mask.rotate(rotate_degree, Image.NEAREST) 71 | 72 | return {'image': img, 73 | 'label': mask} 74 | 75 | 76 | class RandomGaussianBlur(object): 77 | def __call__(self, sample): 78 | img = sample['image'] 79 | mask = sample['label'] 80 | if random.random() < 0.5: 81 | img = img.filter(ImageFilter.GaussianBlur( 82 | radius=random.random())) 83 | 84 | return {'image': img, 85 | 'label': mask} 86 | 87 | 88 | class RandomScaleCrop(object): 89 | def __init__(self, base_size, crop_size, fill=0): 90 | self.base_size = base_size 91 | self.crop_size = crop_size 92 | self.fill = fill 93 | 94 | def __call__(self, sample): 95 | img = sample['image'] 96 | mask = sample['label'] 97 | # random scale (short edge) 98 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 99 | w, h = img.size 100 | if h > w: 101 | ow = short_size 102 | oh = int(1.0 * h * ow / w) 103 | else: 104 | oh = short_size 105 | ow = int(1.0 * w * oh / h) 106 | img = img.resize((ow, oh), Image.BILINEAR) 107 | mask = mask.resize((ow, oh), Image.NEAREST) 108 | # pad crop 109 | if short_size < self.crop_size: 110 | padh = self.crop_size - oh if oh < self.crop_size else 0 111 | padw = self.crop_size - ow if ow < self.crop_size else 0 112 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 113 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 114 | # random crop crop_size 115 | w, h = img.size 116 | x1 = random.randint(0, w - self.crop_size) 117 | y1 = random.randint(0, h - self.crop_size) 118 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 119 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 120 | 121 | return {'image': img, 122 | 'label': mask} 123 | 124 | 125 | class FixScaleCrop(object): 126 | def __init__(self, crop_size): 127 | self.crop_size = crop_size 128 | 129 | def __call__(self, sample): 130 | img = sample['image'] 131 | mask = sample['label'] 132 | w, h = img.size 133 | if w 
> h: 134 | oh = self.crop_size 135 | ow = int(1.0 * w * oh / h) 136 | else: 137 | ow = self.crop_size 138 | oh = int(1.0 * h * ow / w) 139 | img = img.resize((ow, oh), Image.BILINEAR) 140 | mask = mask.resize((ow, oh), Image.NEAREST) 141 | # center crop 142 | w, h = img.size 143 | x1 = int(round((w - self.crop_size) / 2.)) 144 | y1 = int(round((h - self.crop_size) / 2.)) 145 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 146 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 147 | 148 | return {'image': img, 149 | 'label': mask} 150 | 151 | class FixedResize(object): 152 | def __init__(self, height, width): 153 | self.size = (width, height) # size: (h, w) 154 | 155 | def __call__(self, sample): 156 | img = sample['image'] 157 | mask = sample['label'] 158 | 159 | assert img.size == mask.size 160 | 161 | img = img.resize(self.size, Image.BILINEAR) 162 | mask = mask.resize(self.size, Image.NEAREST) 163 | 164 | return {'image': img, 165 | 'label': mask} -------------------------------------------------------------------------------- /modules/dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/modules/dataloaders/datasets/__init__.py -------------------------------------------------------------------------------- /modules/dataloaders/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.misc as m 4 | from PIL import Image 5 | from torch.utils import data 6 | from torchvision import transforms 7 | from modules.dataloaders import custom_transforms as tr 8 | import settings 9 | 10 | class CityscapesSegmentation(data.Dataset): 11 | NUM_CLASSES = 19 12 | 13 | def __init__(self, args, root=settings.root_dir, split="train"): 14 | 15 | self.root = root 16 | self.split = split 17 | self.args = args 18 | self.files = {} 19 | 20 | self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) 21 | self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split) 22 | 23 | self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') 24 | 25 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 26 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 27 | self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \ 28 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \ 29 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 30 | 'motorcycle', 'bicycle'] 31 | 32 | self.ignore_index = 255 33 | self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES))) 34 | 35 | if not self.files[split]: 36 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 37 | 38 | print("Found %d %s images" % (len(self.files[split]), split)) 39 | 40 | def __len__(self): 41 | return len(self.files[self.split]) 42 | 43 | def __getitem__(self, index): 44 | 45 | img_path = self.files[self.split][index].rstrip() 46 | lbl_path = os.path.join(self.annotations_base, 47 | img_path.split(os.sep)[-2], 48 | os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png') 49 | 50 | _img = Image.open(img_path).convert('RGB') 51 | _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) 52 | _tmp = self.encode_segmap(_tmp) 53 
| _target = Image.fromarray(_tmp) 54 | 55 | sample = {'image': _img, 'label': _target} 56 | 57 | if self.split == 'train': 58 | return self.transform_tr(sample) 59 | elif self.split == 'val': 60 | return self.transform_val(sample) 61 | elif self.split == 'test': 62 | return self.transform_ts(sample) 63 | 64 | def encode_segmap(self, mask): 65 | # Put all void classes to zero 66 | for _voidc in self.void_classes: 67 | mask[mask == _voidc] = self.ignore_index 68 | for _validc in self.valid_classes: 69 | mask[mask == _validc] = self.class_map[_validc] 70 | return mask 71 | 72 | def recursive_glob(self, rootdir='.', suffix=''): 73 | """Performs recursive glob with given suffix and rootdir 74 | :param rootdir is the root directory 75 | :param suffix is the suffix to be searched 76 | """ 77 | return [os.path.join(looproot, filename) 78 | for looproot, _, filenames in os.walk(rootdir) 79 | for filename in filenames if filename.endswith(suffix)] 80 | 81 | def transform_tr(self, sample): 82 | composed_transforms = transforms.Compose([ 83 | tr.RandomHorizontalFlip(), 84 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), 85 | tr.RandomGaussianBlur(), 86 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 87 | tr.ToTensor()]) 88 | 89 | return composed_transforms(sample) 90 | 91 | def transform_val(self, sample): 92 | 93 | composed_transforms = transforms.Compose([ 94 | tr.FixScaleCrop(crop_size=self.args.crop_size), 95 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 96 | tr.ToTensor()]) 97 | 98 | return composed_transforms(sample) 99 | 100 | def transform_ts(self, sample): 101 | 102 | composed_transforms = transforms.Compose([ 103 | tr.FixedResize(size=self.args.crop_size), 104 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 105 | tr.ToTensor()]) 106 | 107 | return composed_transforms(sample) 108 | 109 | if __name__ == '__main__': 110 | from modules.dataloaders.utils import decode_segmap 111 | from torch.utils.data import DataLoader 112 | import matplotlib.pyplot as plt 113 | import argparse 114 | 115 | parser = argparse.ArgumentParser() 116 | args = parser.parse_args() 117 | args.base_size = 513 118 | args.crop_size = 513 119 | 120 | cityscapes_train = CityscapesSegmentation(args, split='train') 121 | 122 | dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) 123 | 124 | for ii, sample in enumerate(dataloader): 125 | for jj in range(sample["image"].size()[0]): 126 | img = sample['image'].numpy() 127 | gt = sample['label'].numpy() 128 | tmp = np.array(gt[jj]).astype(np.uint8) 129 | segmap = decode_segmap(tmp, dataset='cityscapes') 130 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 131 | img_tmp *= (0.229, 0.224, 0.225) 132 | img_tmp += (0.485, 0.456, 0.406) 133 | img_tmp *= 255.0 134 | img_tmp = img_tmp.astype(np.uint8) 135 | plt.figure() 136 | plt.title('display') 137 | plt.subplot(211) 138 | plt.imshow(img_tmp) 139 | plt.subplot(212) 140 | plt.imshow(segmap) 141 | 142 | if ii == 1: 143 | break 144 | 145 | plt.show(block=True) 146 | 147 | -------------------------------------------------------------------------------- /modules/dataloaders/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from tqdm import trange 5 | import os 6 | from pycocotools.coco import COCO 7 | from pycocotools import mask 8 | from torchvision import 
transforms 9 | from modules.dataloaders import custom_transforms as tr 10 | from PIL import Image, ImageFile 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | import settings 13 | 14 | 15 | class COCOSegmentation(Dataset): 16 | NUM_CLASSES = 21 17 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 18 | 1, 64, 20, 63, 7, 72] 19 | 20 | def __init__(self, base_dir=settings.root_dir, split='train', year='2017'): 21 | super().__init__() 22 | ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year)) 23 | ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year)) 24 | self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year)) 25 | self.split = split 26 | self.coco = COCO(ann_file) 27 | self.coco_mask = mask 28 | if os.path.exists(ids_file): 29 | self.ids = torch.load(ids_file) 30 | else: 31 | ids = list(self.coco.imgs.keys()) 32 | self.ids = self._preprocess(ids, ids_file) 33 | 34 | def __getitem__(self, index): 35 | _img, _target = self._make_img_gt_point_pair(index) 36 | sample = {'image': _img, 'label': _target} 37 | 38 | if self.split == "train": 39 | return self.transform_tr(sample) 40 | elif self.split == 'val': 41 | return self.transform_val(sample) 42 | 43 | def _make_img_gt_point_pair(self, index): 44 | coco = self.coco 45 | img_id = self.ids[index] 46 | img_metadata = coco.loadImgs(img_id)[0] 47 | path = img_metadata['file_name'] 48 | _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB') 49 | cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id)) 50 | _target = Image.fromarray(self._gen_seg_mask( 51 | cocotarget, img_metadata['height'], img_metadata['width'])) 52 | 53 | return _img, _target 54 | 55 | def _preprocess(self, ids, ids_file): 56 | print("Preprocessing mask, this will take a while. " + \ 57 | "But don't worry, it only run once for each split.") 58 | tbar = trange(len(ids)) 59 | new_ids = [] 60 | for i in tbar: 61 | img_id = ids[i] 62 | cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 63 | img_metadata = self.coco.loadImgs(img_id)[0] 64 | mask = self._gen_seg_mask(cocotarget, img_metadata['height'], 65 | img_metadata['width']) 66 | # more than 1k pixels 67 | if (mask > 0).sum() > 1000: 68 | new_ids.append(img_id) 69 | tbar.set_description('Doing: {}/{}, got {} qualified images'. 
\ 70 | format(i, len(ids), len(new_ids))) 71 | print('Found number of qualified images: ', len(new_ids)) 72 | torch.save(new_ids, ids_file) 73 | return new_ids 74 | 75 | def _gen_seg_mask(self, target, h, w): 76 | mask = np.zeros((h, w), dtype=np.uint8) 77 | coco_mask = self.coco_mask 78 | for instance in target: 79 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 80 | m = coco_mask.decode(rle) 81 | cat = instance['category_id'] 82 | if cat in self.CAT_LIST: 83 | c = self.CAT_LIST.index(cat) 84 | else: 85 | continue 86 | if len(m.shape) < 3: 87 | mask[:, :] += (mask == 0) * (m * c) 88 | else: 89 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 90 | return mask 91 | 92 | def transform_tr(self, sample): 93 | composed_transforms = transforms.Compose([ 94 | tr.RandomHorizontalFlip(), 95 | tr.RandomScaleCrop(base_size=settings.base_size, crop_size=settings.crop_size), 96 | tr.RandomGaussianBlur(), 97 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 98 | tr.ToTensor()]) 99 | 100 | return composed_transforms(sample) 101 | 102 | def transform_val(self, sample): 103 | 104 | composed_transforms = transforms.Compose([ 105 | tr.FixScaleCrop(crop_size=settings.crop_size), 106 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 107 | tr.ToTensor()]) 108 | 109 | return composed_transforms(sample) 110 | 111 | 112 | def __len__(self): 113 | return len(self.ids) 114 | 115 | 116 | 117 | if __name__ == "__main__": 118 | from modules.dataloaders import custom_transforms as tr 119 | from modules.dataloaders.utils import decode_segmap 120 | from torch.utils.data import DataLoader 121 | from torchvision import transforms 122 | import matplotlib.pyplot as plt 123 | import argparse 124 | 125 | coco_val = COCOSegmentation(split='val', year='2017') 126 | 127 | dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0) 128 | 129 | for ii, sample in enumerate(dataloader): 130 | for jj in range(sample["image"].size()[0]): 131 | img = sample['image'].numpy() 132 | gt = sample['label'].numpy() 133 | tmp = np.array(gt[jj]).astype(np.uint8) 134 | segmap = decode_segmap(tmp, dataset='coco') 135 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 136 | img_tmp *= (0.229, 0.224, 0.225) 137 | img_tmp += (0.485, 0.456, 0.406) 138 | img_tmp *= 255.0 139 | img_tmp = img_tmp.astype(np.uint8) 140 | plt.figure() 141 | plt.title('display') 142 | plt.subplot(211) 143 | plt.imshow(img_tmp) 144 | plt.subplot(212) 145 | plt.imshow(segmap) 146 | 147 | if ii == 1: 148 | break 149 | 150 | plt.show(block=True) -------------------------------------------------------------------------------- /modules/dataloaders/datasets/combine_dbs.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | NUM_CLASSES = 21 6 | def __init__(self, dataloaders, excluded=None): 7 | self.dataloaders = dataloaders 8 | self.excluded = excluded 9 | self.im_ids = [] 10 | 11 | # Combine object lists 12 | for dl in dataloaders: 13 | for elem in dl.im_ids: 14 | if elem not in self.im_ids: 15 | self.im_ids.append(elem) 16 | 17 | # Exclude 18 | if excluded: 19 | for dl in excluded: 20 | for elem in dl.im_ids: 21 | if elem in self.im_ids: 22 | self.im_ids.remove(elem) 23 | 24 | # Get object pointers 25 | self.cat_list = [] 26 | self.im_list = [] 27 | new_im_ids = [] 28 | num_images = 0 29 | for ii, dl in enumerate(dataloaders): 30 | for jj, curr_im_id in 
enumerate(dl.im_ids): 31 | if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids): 32 | num_images += 1 33 | new_im_ids.append(curr_im_id) 34 | self.cat_list.append({'db_ii': ii, 'cat_ii': jj}) 35 | 36 | self.im_ids = new_im_ids 37 | print('Combined number of images: {:d}'.format(num_images)) 38 | 39 | def __getitem__(self, index): 40 | 41 | _db_ii = self.cat_list[index]["db_ii"] 42 | _cat_ii = self.cat_list[index]['cat_ii'] 43 | sample = self.dataloaders[_db_ii].__getitem__(_cat_ii) 44 | 45 | if 'meta' in sample.keys(): 46 | sample['meta']['db'] = str(self.dataloaders[_db_ii]) 47 | 48 | return sample 49 | 50 | def __len__(self): 51 | return len(self.cat_list) 52 | 53 | def __str__(self): 54 | include_db = [str(db) for db in self.dataloaders] 55 | exclude_db = [str(db) for db in self.excluded] 56 | return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db) 57 | 58 | 59 | if __name__ == "__main__": 60 | import matplotlib.pyplot as plt 61 | from dataloaders.datasets import pascal, sbd 62 | from dataloaders import sbd 63 | import torch 64 | import numpy as np 65 | from dataloaders.utils import decode_segmap 66 | import argparse 67 | 68 | parser = argparse.ArgumentParser() 69 | args = parser.parse_args() 70 | args.base_size = 513 71 | args.crop_size = 513 72 | 73 | pascal_voc_val = pascal.VOCSegmentation(args, split='val') 74 | sbd = sbd.SBDSegmentation(args, split=['train', 'val']) 75 | pascal_voc_train = pascal.VOCSegmentation(args, split='train') 76 | 77 | dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val]) 78 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0) 79 | 80 | for ii, sample in enumerate(dataloader): 81 | for jj in range(sample["image"].size()[0]): 82 | img = sample['image'].numpy() 83 | gt = sample['label'].numpy() 84 | tmp = np.array(gt[jj]).astype(np.uint8) 85 | segmap = decode_segmap(tmp, dataset='pascal') 86 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 87 | img_tmp *= (0.229, 0.224, 0.225) 88 | img_tmp += (0.485, 0.456, 0.406) 89 | img_tmp *= 255.0 90 | img_tmp = img_tmp.astype(np.uint8) 91 | plt.figure() 92 | plt.title('display') 93 | plt.subplot(211) 94 | plt.imshow(img_tmp) 95 | plt.subplot(212) 96 | plt.imshow(segmap) 97 | 98 | if ii == 1: 99 | break 100 | plt.show(block=True) -------------------------------------------------------------------------------- /modules/dataloaders/datasets/pascal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | from modules.dataloaders import custom_transforms as tr 8 | import settings 9 | 10 | class VOCSegmentation(Dataset): 11 | """ 12 | PascalVoc dataset 13 | """ 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=settings.root_dir, 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._image_dir = os.path.join(self._base_dir, 'JPEGImages') 29 | self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass') 30 | 31 | if isinstance(split, str): 32 | self.split = [split] 33 | else: 34 | split.sort() 35 | self.split = split 36 | 37 | self.args = args 38 | 39 | _splits_dir = os.path.join(self._base_dir, 
'ImageSets', 'Segmentation') 40 | 41 | self.im_ids = [] 42 | self.images = [] 43 | self.categories = [] 44 | 45 | for splt in self.split: 46 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for ii, line in enumerate(lines): 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _cat = os.path.join(self._cat_dir, line + ".png") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_cat) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_cat) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images in {}: {:d}'.format(split, len(self.images))) 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | 67 | def __getitem__(self, index): 68 | _img, _target = self._make_img_gt_point_pair(index) 69 | sample = {'image': _img, 'label': _target} 70 | 71 | for split in self.split: 72 | if split == "train": 73 | return self.transform_tr(sample) 74 | elif split == 'val': 75 | return self.transform_val(sample) 76 | 77 | 78 | def _make_img_gt_point_pair(self, index): 79 | _img = Image.open(self.images[index]).convert('RGB') 80 | _target = Image.open(self.categories[index]) 81 | 82 | return _img, _target 83 | 84 | def transform_tr(self, sample): 85 | composed_transforms = transforms.Compose([ 86 | tr.RandomHorizontalFlip(), 87 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 88 | tr.RandomGaussianBlur(), 89 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 90 | tr.ToTensor()]) 91 | 92 | return composed_transforms(sample) 93 | 94 | def transform_val(self, sample): 95 | 96 | composed_transforms = transforms.Compose([ 97 | tr.FixScaleCrop(crop_size=self.args.crop_size), 98 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 99 | tr.ToTensor()]) 100 | 101 | return composed_transforms(sample) 102 | 103 | def __str__(self): 104 | return 'VOC2012(split=' + str(self.split) + ')' 105 | 106 | 107 | if __name__ == '__main__': 108 | from dataloaders.utils import decode_segmap 109 | from torch.utils.data import DataLoader 110 | import matplotlib.pyplot as plt 111 | import argparse 112 | 113 | parser = argparse.ArgumentParser() 114 | args = parser.parse_args() 115 | args.base_size = 513 116 | args.crop_size = 513 117 | 118 | voc_train = VOCSegmentation(args, split='train') 119 | 120 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0) 121 | 122 | for ii, sample in enumerate(dataloader): 123 | for jj in range(sample["image"].size()[0]): 124 | img = sample['image'].numpy() 125 | gt = sample['label'].numpy() 126 | tmp = np.array(gt[jj]).astype(np.uint8) 127 | segmap = decode_segmap(tmp, dataset='pascal') 128 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 129 | img_tmp *= (0.229, 0.224, 0.225) 130 | img_tmp += (0.485, 0.456, 0.406) 131 | img_tmp *= 255.0 132 | img_tmp = img_tmp.astype(np.uint8) 133 | plt.figure() 134 | plt.title('display') 135 | plt.subplot(211) 136 | plt.imshow(img_tmp) 137 | plt.subplot(212) 138 | plt.imshow(segmap) 139 | 140 | if ii == 1: 141 | break 142 | 143 | plt.show(block=True) 144 | 145 | 146 | -------------------------------------------------------------------------------- /modules/dataloaders/datasets/sbd.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | 4 | import numpy as np 5 | import 
scipy.io 6 | import torch.utils.data as data 7 | from PIL import Image 8 | 9 | from torchvision import transforms 10 | from modules.dataloaders import custom_transforms as tr 11 | import settings 12 | 13 | class SBDSegmentation(data.Dataset): 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=settings.root_dir, 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._dataset_dir = os.path.join(self._base_dir, 'dataset') 29 | self._image_dir = os.path.join(self._dataset_dir, 'img') 30 | self._cat_dir = os.path.join(self._dataset_dir, 'cls') 31 | 32 | 33 | if isinstance(split, str): 34 | self.split = [split] 35 | else: 36 | split.sort() 37 | self.split = split 38 | 39 | self.args = args 40 | 41 | # Get list of all images from the split and check that the files exist 42 | self.im_ids = [] 43 | self.images = [] 44 | self.categories = [] 45 | for splt in self.split: 46 | with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for line in lines: 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _categ= os.path.join(self._cat_dir, line + ".mat") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_categ) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_categ) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images: {:d}'.format(len(self.images))) 62 | 63 | 64 | def __getitem__(self, index): 65 | _img, _target = self._make_img_gt_point_pair(index) 66 | sample = {'image': _img, 'label': _target} 67 | 68 | return self.transform(sample) 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | def _make_img_gt_point_pair(self, index): 74 | _img = Image.open(self.images[index]).convert('RGB') 75 | _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0]) 76 | 77 | return _img, _target 78 | 79 | def transform(self, sample): 80 | composed_transforms = transforms.Compose([ 81 | tr.RandomHorizontalFlip(), 82 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 83 | tr.RandomGaussianBlur(), 84 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 85 | tr.ToTensor()]) 86 | 87 | return composed_transforms(sample) 88 | 89 | 90 | def __str__(self): 91 | return 'SBDSegmentation(split=' + str(self.split) + ')' 92 | 93 | 94 | if __name__ == '__main__': 95 | from modules.dataloaders.utils import decode_segmap 96 | from torch.utils.data import DataLoader 97 | import matplotlib.pyplot as plt 98 | import argparse 99 | 100 | parser = argparse.ArgumentParser() 101 | args = parser.parse_args() 102 | args.base_size = 513 103 | args.crop_size = 513 104 | 105 | sbd_train = SBDSegmentation(args, split='train') 106 | dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2) 107 | 108 | for ii, sample in enumerate(dataloader): 109 | for jj in range(sample["image"].size()[0]): 110 | img = sample['image'].numpy() 111 | gt = sample['label'].numpy() 112 | tmp = np.array(gt[jj]).astype(np.uint8) 113 | segmap = decode_segmap(tmp, dataset='pascal') 114 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 115 | img_tmp *= (0.229, 0.224, 0.225) 116 | img_tmp += (0.485, 0.456, 0.406) 117 | img_tmp *= 255.0 118 | img_tmp = 
img_tmp.astype(np.uint8) 119 | plt.figure() 120 | plt.title('display') 121 | plt.subplot(211) 122 | plt.imshow(img_tmp) 123 | plt.subplot(212) 124 | plt.imshow(segmap) 125 | 126 | if ii == 1: 127 | break 128 | 129 | plt.show(block=True) -------------------------------------------------------------------------------- /modules/dataloaders/datasets/surface.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | from glob import glob 4 | import settings 5 | from PIL import Image, ImageFile 6 | ImageFile.LOAD_TRUNCATED_IMAGES = True 7 | import numpy as np 8 | from torchvision import transforms 9 | from modules.dataloaders import custom_transforms as tr 10 | from modules.dataloaders.utils import decode_segmap, encode_segmap 11 | 12 | """ 13 | base_dir 14 | ├── annotations 15 | │ ├── *.xml 16 | ├── images 17 | │ ├── *.jpg 18 | ├── masks 19 | │ ├── *.png 20 | ├── train.txt 21 | ├── valid.txt 22 | └── labels.xlsx 23 | """ 24 | 25 | 26 | class SurfaceSegmentation(Dataset): 27 | NUM_CLASSES = settings.num_classes 28 | 29 | def __init__(self, base_dir=settings.root_dir, split='train'): 30 | super().__init__() 31 | self.split = split 32 | self.images = [] 33 | self.masks = [] 34 | 35 | file_list = Path(base_dir) / f'{split}.txt' 36 | with open(file_list, 'r') as f: 37 | lines = f.read().splitlines() 38 | for line in lines: 39 | image, mask = line.split(',') 40 | image = Path(base_dir) / image 41 | mask = Path(base_dir) / mask 42 | assert image.exists(), f'File not found: {image}' 43 | assert mask.exists(), f'File not found: {mask}' 44 | self.images.append(image) 45 | self.masks.append(mask) 46 | print(f'{split}| images: {len(self.images)}, masks: {len(self.masks)}') 47 | 48 | if split == 'train': 49 | self.composed_transform = transforms.Compose([ 50 | tr.FixedResize(settings.resize_height, settings.resize_width), 51 | tr.RandomHorizontalFlip(), 52 | # tr.RandomRotate(90), 53 | # tr.RandomScaleCrop(base_size=settings.base_size, crop_size=settings.crop_size), 54 | tr.RandomGaussianBlur(), 55 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 56 | tr.ToTensor()]) 57 | elif split == 'valid': 58 | self.composed_transform = transforms.Compose([ 59 | # tr.FixScaleCrop(crop_size=settings.crop_size), 60 | tr.FixedResize(settings.resize_height, settings.resize_width), 61 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 62 | tr.ToTensor()]) 63 | else: 64 | raise KeyError("split must be one of 'train' or 'valid'.") 65 | 66 | def __getitem__(self, index): 67 | image_file, mask_file = self.images[index], self.masks[index] 68 | _img = self.preprocess_image(image_file) 69 | _mask = self.preprocess_mask(mask_file) 70 | sample = {'image': _img, 'label': _mask} 71 | 72 | return self.composed_transform(sample) 73 | 74 | def __len__(self): 75 | return len(self.masks) 76 | 77 | @staticmethod 78 | def preprocess_image(jpg_file): 79 | image = Image.open(jpg_file).convert('RGB') 80 | return image 81 | 82 | @staticmethod 83 | def preprocess_mask(png_file): 84 | image = np.array(Image.open(png_file), dtype=np.uint8) 85 | 86 | h, w, c = image.shape 87 | assert c == 3, f"Invalid channel number: {c}. 
{png_file}" 88 | 89 | new_mask = encode_segmap(image, dataset='surface') 90 | 91 | return Image.fromarray(new_mask) 92 | 93 | 94 | if __name__ == '__main__': 95 | from torch.utils.data import DataLoader 96 | import matplotlib.pyplot as plt 97 | 98 | dataset = SurfaceSegmentation(base_dir=settings.root_dir, split='train') 99 | 100 | dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0) 101 | 102 | for ii, sample in enumerate(dataloader): 103 | for jj in range(sample["image"].size()[0]): 104 | img = sample['image'].numpy() 105 | gt = sample['label'].numpy() 106 | tmp = np.array(gt[jj]).astype(np.uint8) 107 | segmap = decode_segmap(tmp, dataset='surface') 108 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 109 | img_tmp *= (0.229, 0.224, 0.225) 110 | img_tmp += (0.485, 0.456, 0.406) 111 | img_tmp *= 255.0 112 | img_tmp = img_tmp.astype(np.uint8) 113 | plt.figure() 114 | plt.title('display') 115 | plt.subplot(211) 116 | plt.imshow(img_tmp) 117 | plt.subplot(212) 118 | plt.imshow(segmap) 119 | 120 | if ii == 1: 121 | break 122 | 123 | plt.show(block=True) 124 | -------------------------------------------------------------------------------- /modules/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import settings 5 | 6 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 7 | rgb_masks = [] 8 | for label_mask in label_masks: 9 | rgb_mask = decode_segmap(label_mask, dataset) 10 | rgb_masks.append(rgb_mask) 11 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 12 | return rgb_masks 13 | 14 | 15 | def decode_segmap(label_mask, dataset, plot=False, **kwargs): 16 | """Decode segmentation class labels into a color image 17 | Args: 18 | label_mask (np.ndarray): an (M,N) array of integer values denoting 19 | the class label at each spatial location. 20 | plot (bool, optional): whether to show the resulting color image 21 | in a figure. 22 | kwargs (dict, optional): if dataset == 'custom', then uses kwargs['n_classes'], kwargs['label_colors']. 23 | Returns: 24 | (np.ndarray, optional): the resulting decoded color image. 25 | """ 26 | if dataset == 'pascal' or dataset == 'coco': 27 | n_classes = 21 28 | label_colors = get_pascal_labels() 29 | elif dataset == 'cityscapes': 30 | n_classes = 19 31 | label_colors = get_cityscapes_labels() 32 | elif dataset == 'surface': 33 | n_classes = settings.num_classes 34 | label_colors = get_surface_labels() 35 | elif dataset == 'custom': 36 | assert 'n_classes' in kwargs and 'label_colors' in kwargs, "Please specify custom color map and n_classes." 
37 | n_classes = int(kwargs['n_classes']) 38 | label_colors = np.asarray(kwargs['label_colors']) 39 | else: 40 | raise NotImplementedError 41 | 42 | r = label_mask.copy() 43 | g = label_mask.copy() 44 | b = label_mask.copy() 45 | for ll in range(0, n_classes): 46 | r[label_mask == ll] = label_colors[ll, 0] 47 | g[label_mask == ll] = label_colors[ll, 1] 48 | b[label_mask == ll] = label_colors[ll, 2] 49 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3), dtype=np.float32) 50 | rgb[:, :, 0] = r / 255.0 51 | rgb[:, :, 1] = g / 255.0 52 | rgb[:, :, 2] = b / 255.0 53 | if plot: 54 | plt.imshow(rgb) 55 | plt.show() 56 | else: 57 | return rgb 58 | 59 | 60 | def encode_segmap(mask, dataset): 61 | """Encode segmentation label images as pascal classes 62 | Args: 63 | mask (np.ndarray): raw segmentation label image of dimension 64 | (M, N, 3), in which the Pascal classes are encoded as colours. 65 | Returns: 66 | (np.ndarray): class map with dimensions (M,N), where the value at 67 | a given location is the integer denoting the class index. 68 | """ 69 | mask = mask.astype(np.uint8) 70 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8) 71 | 72 | if dataset == 'pascal' or dataset == 'coco': 73 | n_classes = 21 74 | label_colors = get_pascal_labels() 75 | elif dataset == 'cityscapes': 76 | n_classes = 19 77 | label_colors = get_cityscapes_labels() 78 | elif dataset == 'surface': 79 | n_classes = settings.num_classes 80 | label_colors = get_surface_labels() 81 | else: 82 | raise NotImplementedError 83 | assert n_classes <= np.iinfo(np.uint8).max, "assert n_classes <= uint8 max" 84 | 85 | for ii, label in enumerate(label_colors): 86 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 87 | 88 | return label_mask 89 | 90 | 91 | def get_cityscapes_labels(): 92 | return np.array([ 93 | [128, 64, 128], 94 | [244, 35, 232], 95 | [70, 70, 70], 96 | [102, 102, 156], 97 | [190, 153, 153], 98 | [153, 153, 153], 99 | [250, 170, 30], 100 | [220, 220, 0], 101 | [107, 142, 35], 102 | [152, 251, 152], 103 | [0, 130, 180], 104 | [220, 20, 60], 105 | [255, 0, 0], 106 | [0, 0, 142], 107 | [0, 0, 70], 108 | [0, 60, 100], 109 | [0, 80, 100], 110 | [0, 0, 230], 111 | [119, 11, 32]]) 112 | 113 | 114 | def get_pascal_labels(): 115 | """Load the mapping that associates pascal classes with label colors 116 | Returns: 117 | np.ndarray with dimensions (21, 3) 118 | """ 119 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 120 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 121 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 122 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 123 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 124 | [0, 64, 128]]) 125 | 126 | def get_surface_labels(): 127 | return np.array(settings.colors) 128 | -------------------------------------------------------------------------------- /modules/models/deeplab_xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | BatchNorm2d = SynchronizedBatchNorm2d 9 | 10 | class SeparableConv2d(nn.Module): 11 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=0, dilation=1, bias=False): 12 | super(SeparableConv2d, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(inplanes, inplanes, 
kernel_size, stride, padding, dilation, 15 | groups=inplanes, bias=bias) 16 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 17 | 18 | def forward(self, x): 19 | x = self.conv1(x) 20 | x = self.pointwise(x) 21 | return x 22 | 23 | 24 | def fixed_padding(inputs, kernel_size, dilation): 25 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 26 | pad_total = kernel_size_effective - 1 27 | pad_beg = pad_total // 2 28 | pad_end = pad_total - pad_beg 29 | padded_inputs = F.pad(inputs, [pad_beg, pad_end, pad_beg, pad_end]) 30 | return padded_inputs 31 | 32 | 33 | class SeparableConv2d_same(nn.Module): 34 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False): 35 | super(SeparableConv2d_same, self).__init__() 36 | 37 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 38 | groups=inplanes, bias=bias) 39 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 40 | 41 | def forward(self, x): 42 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 43 | x = self.conv1(x) 44 | x = self.pointwise(x) 45 | return x 46 | 47 | 48 | class Block(nn.Module): 49 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, start_with_relu=True, grow_first=True, is_last=False): 50 | super(Block, self).__init__() 51 | 52 | if planes != inplanes or stride != 1: 53 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 54 | self.skipbn = BatchNorm2d(planes) 55 | else: 56 | self.skip = None 57 | 58 | self.relu = nn.ReLU(inplace=True) 59 | rep = [] 60 | 61 | filters = inplanes 62 | if grow_first: 63 | rep.append(self.relu) 64 | rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) 65 | rep.append(BatchNorm2d(planes)) 66 | filters = planes 67 | 68 | for i in range(reps - 1): 69 | rep.append(self.relu) 70 | rep.append(SeparableConv2d_same(filters, filters, 3, stride=1, dilation=dilation)) 71 | rep.append(BatchNorm2d(filters)) 72 | 73 | if not grow_first: 74 | rep.append(self.relu) 75 | rep.append(SeparableConv2d_same(inplanes, planes, 3, stride=1, dilation=dilation)) 76 | rep.append(BatchNorm2d(planes)) 77 | 78 | if not start_with_relu: 79 | rep = rep[1:] 80 | 81 | if stride != 1: 82 | rep.append(SeparableConv2d_same(planes, planes, 3, stride=2)) 83 | 84 | if stride == 1 and is_last: 85 | rep.append(SeparableConv2d_same(planes, planes, 3, stride=1)) 86 | 87 | 88 | self.rep = nn.Sequential(*rep) 89 | 90 | def forward(self, inp): 91 | x = self.rep(inp) 92 | 93 | if self.skip is not None: 94 | skip = self.skip(inp) 95 | skip = self.skipbn(skip) 96 | else: 97 | skip = inp 98 | 99 | x += skip 100 | 101 | return x 102 | 103 | 104 | class Xception(nn.Module): 105 | """ 106 | Modified Alighed Xception 107 | """ 108 | def __init__(self, inplanes=3, os=16, pretrained=False): 109 | super(Xception, self).__init__() 110 | 111 | if os == 16: 112 | entry_block3_stride = 2 113 | middle_block_dilation = 1 114 | exit_block_dilations = (1, 2) 115 | elif os == 8: 116 | entry_block3_stride = 1 117 | middle_block_dilation = 2 118 | exit_block_dilations = (2, 4) 119 | else: 120 | raise NotImplementedError 121 | 122 | 123 | # Entry flow 124 | self.conv1 = nn.Conv2d(inplanes, 32, 3, stride=2, padding=1, bias=False) 125 | self.bn1 = BatchNorm2d(32) 126 | self.relu = nn.ReLU(inplace=True) 127 | 128 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 129 | self.bn2 = BatchNorm2d(64) 130 | 131 | self.block1 = Block(64, 128, 
reps=2, stride=2, start_with_relu=False) 132 | self.block2 = Block(128, 256, reps=2, stride=2, start_with_relu=True, grow_first=True) 133 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, start_with_relu=True, grow_first=True, 134 | is_last=True) 135 | 136 | # Middle flow 137 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 138 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 139 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 140 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 141 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 142 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 143 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 144 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 146 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 147 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 148 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 149 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 150 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 151 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 152 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, start_with_relu=True, grow_first=True) 153 | 154 | # Exit flow 155 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 156 | start_with_relu=True, grow_first=False, is_last=True) 157 | 158 | self.conv3 = SeparableConv2d_same(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 159 | self.bn3 = BatchNorm2d(1536) 160 | 161 | self.conv4 = SeparableConv2d_same(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1]) 162 | self.bn4 = BatchNorm2d(1536) 163 | 164 | self.conv5 = SeparableConv2d_same(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1]) 165 | self.bn5 = BatchNorm2d(2048) 166 | 167 | # Init weights 168 | self._init_weight() 169 | 170 | # Load pretrained model 171 | if pretrained: 172 | self._load_xception_pretrained() 173 | 174 | def forward(self, x): 175 | # Entry flow 176 | x = self.conv1(x) 177 | x = self.bn1(x) 178 | x = self.relu(x) 179 | 180 | x = self.conv2(x) 181 | x = self.bn2(x) 182 | x = self.relu(x) 183 | 184 | x = self.block1(x) 185 | low_level_feat = x 186 | x = self.block2(x) 187 | x = self.block3(x) 188 | 189 | # Middle flow 190 | x = self.block4(x) 191 | x = self.block5(x) 192 | x = self.block6(x) 193 | x = self.block7(x) 194 | x = self.block8(x) 195 | x = self.block9(x) 196 | x = self.block10(x) 197 | 
x = self.block11(x) 198 | x = self.block12(x) 199 | x = self.block13(x) 200 | x = self.block14(x) 201 | x = self.block15(x) 202 | x = self.block16(x) 203 | x = self.block17(x) 204 | x = self.block18(x) 205 | x = self.block19(x) 206 | 207 | # Exit flow 208 | x = self.block20(x) 209 | x = self.conv3(x) 210 | x = self.bn3(x) 211 | x = self.relu(x) 212 | 213 | x = self.conv4(x) 214 | x = self.bn4(x) 215 | x = self.relu(x) 216 | 217 | x = self.conv5(x) 218 | x = self.bn5(x) 219 | x = self.relu(x) 220 | 221 | return x, low_level_feat 222 | 223 | def _init_weight(self): 224 | for m in self.modules(): 225 | if isinstance(m, nn.Conv2d): 226 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 227 | m.weight.data.normal_(0, math.sqrt(2. / n)) 228 | elif isinstance(m, BatchNorm2d): 229 | m.weight.data.fill_(1) 230 | m.bias.data.zero_() 231 | 232 | def _load_xception_pretrained(self): 233 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 234 | model_dict = {} 235 | state_dict = self.state_dict() 236 | 237 | for k, v in pretrain_dict.items(): 238 | if k in model_dict: 239 | if 'pointwise' in k: 240 | v = v.unsqueeze(-1).unsqueeze(-1) 241 | if k.startswith('block11'): 242 | model_dict[k] = v 243 | model_dict[k.replace('block11', 'block12')] = v 244 | model_dict[k.replace('block11', 'block13')] = v 245 | model_dict[k.replace('block11', 'block14')] = v 246 | model_dict[k.replace('block11', 'block15')] = v 247 | model_dict[k.replace('block11', 'block16')] = v 248 | model_dict[k.replace('block11', 'block17')] = v 249 | model_dict[k.replace('block11', 'block18')] = v 250 | model_dict[k.replace('block11', 'block19')] = v 251 | elif k.startswith('block12'): 252 | model_dict[k.replace('block12', 'block20')] = v 253 | elif k.startswith('bn3'): 254 | model_dict[k] = v 255 | model_dict[k.replace('bn3', 'bn4')] = v 256 | elif k.startswith('conv4'): 257 | model_dict[k.replace('conv4', 'conv5')] = v 258 | elif k.startswith('bn4'): 259 | model_dict[k.replace('bn4', 'bn5')] = v 260 | else: 261 | model_dict[k] = v 262 | state_dict.update(model_dict) 263 | self.load_state_dict(state_dict) 264 | 265 | class ASPP_module(nn.Module): 266 | def __init__(self, inplanes, planes, dilation): 267 | super(ASPP_module, self).__init__() 268 | if dilation == 1: 269 | kernel_size = 1 270 | padding = 0 271 | else: 272 | kernel_size = 3 273 | padding = dilation 274 | self.atrous_convolution = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 275 | stride=1, padding=padding, dilation=dilation, bias=False) 276 | self.bn = BatchNorm2d(planes) 277 | self.relu = nn.ReLU() 278 | 279 | self._init_weight() 280 | 281 | def forward(self, x): 282 | x = self.atrous_convolution(x) 283 | x = self.bn(x) 284 | 285 | return self.relu(x) 286 | 287 | def _init_weight(self): 288 | for m in self.modules(): 289 | if isinstance(m, nn.Conv2d): 290 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 291 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 292 | elif isinstance(m, BatchNorm2d): 293 | m.weight.data.fill_(1) 294 | m.bias.data.zero_() 295 | 296 | 297 | class DeepLabv3_plus(nn.Module): 298 | def __init__(self, nInputChannels=3, n_classes=21, os=16, pretrained=False, freeze_bn=False, _print=True): 299 | if _print: 300 | print("Constructing DeepLabv3+ model...") 301 | print("Backbone: Xception") 302 | print("Number of classes: {}".format(n_classes)) 303 | print("Output stride: {}".format(os)) 304 | print("Number of Input Channels: {}".format(nInputChannels)) 305 | super(DeepLabv3_plus, self).__init__() 306 | 307 | # Atrous Conv 308 | self.xception_features = Xception(nInputChannels, os, pretrained) 309 | 310 | # ASPP 311 | if os == 16: 312 | dilations = [1, 6, 12, 18] 313 | elif os == 8: 314 | dilations = [1, 12, 24, 36] 315 | else: 316 | raise NotImplementedError 317 | 318 | self.aspp1 = ASPP_module(2048, 256, dilation=dilations[0]) 319 | self.aspp2 = ASPP_module(2048, 256, dilation=dilations[1]) 320 | self.aspp3 = ASPP_module(2048, 256, dilation=dilations[2]) 321 | self.aspp4 = ASPP_module(2048, 256, dilation=dilations[3]) 322 | 323 | self.relu = nn.ReLU() 324 | 325 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 326 | nn.Conv2d(2048, 256, 1, stride=1, bias=False), 327 | BatchNorm2d(256), 328 | nn.ReLU()) 329 | 330 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 331 | self.bn1 = BatchNorm2d(256) 332 | 333 | # adopt [1x1, 48] for channel reduction. 334 | self.conv2 = nn.Conv2d(128, 48, 1, bias=False) 335 | self.bn2 = BatchNorm2d(48) 336 | 337 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 338 | BatchNorm2d(256), 339 | nn.ReLU(), 340 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 341 | BatchNorm2d(256), 342 | nn.ReLU(), 343 | nn.Conv2d(256, n_classes, kernel_size=1, stride=1)) 344 | if freeze_bn: 345 | self._freeze_bn() 346 | 347 | def forward(self, input): 348 | x, low_level_features = self.xception_features(input) 349 | x1 = self.aspp1(x) 350 | x2 = self.aspp2(x) 351 | x3 = self.aspp3(x) 352 | x4 = self.aspp4(x) 353 | x5 = self.global_avg_pool(x) 354 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 355 | 356 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 357 | 358 | x = self.conv1(x) 359 | x = self.bn1(x) 360 | x = self.relu(x) 361 | x = F.interpolate(x, size=(int(math.ceil(input.size()[-2]/4)), 362 | int(math.ceil(input.size()[-1]/4))), mode='bilinear', align_corners=True) 363 | 364 | low_level_features = self.conv2(low_level_features) 365 | low_level_features = self.bn2(low_level_features) 366 | low_level_features = self.relu(low_level_features) 367 | 368 | 369 | x = torch.cat((x, low_level_features), dim=1) 370 | x = self.last_conv(x) 371 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 372 | 373 | return x 374 | 375 | def _freeze_bn(self): 376 | for m in self.modules(): 377 | if isinstance(m, BatchNorm2d): 378 | m.eval() 379 | 380 | def _init_weight(self): 381 | for m in self.modules(): 382 | if isinstance(m, nn.Conv2d): 383 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 384 | m.weight.data.normal_(0, math.sqrt(2. / n)) 385 | elif isinstance(m, BatchNorm2d): 386 | m.weight.data.fill_(1) 387 | m.bias.data.zero_() 388 | 389 | def get_1x_lr_params(model): 390 | """ 391 | This generator returns all the parameters of the net except for 392 | the last classification layer. 
Note that for each batchnorm layer, 393 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 394 | any batchnorm parameter 395 | """ 396 | b = [model.xception_features] 397 | for i in range(len(b)): 398 | for k in b[i].parameters(): 399 | if k.requires_grad: 400 | yield k 401 | 402 | 403 | def get_10x_lr_params(model): 404 | """ 405 | This generator returns all the parameters for the last layer of the net, 406 | which does the classification of pixel into classes 407 | """ 408 | b = [model.aspp1, model.aspp2, model.aspp3, model.aspp4, model.conv1, model.conv2, model.last_conv] 409 | for j in range(len(b)): 410 | for k in b[j].parameters(): 411 | if k.requires_grad: 412 | yield k 413 | 414 | 415 | if __name__ == "__main__": 416 | model = DeepLabv3_plus(nInputChannels=3, n_classes=21, os=16, pretrained=True, _print=True) 417 | model.eval() 418 | image = torch.randn(1, 3, 512, 512) 419 | with torch.no_grad(): 420 | output = model.forward(image) 421 | print(output.size()) -------------------------------------------------------------------------------- /modules/models/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /modules/models/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | 142 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 143 | as the built-in PyTorch implementation. 144 | The mean and standard-deviation are calculated per-dimension over 145 | the mini-batches and gamma and beta are learnable parameter vectors 146 | of size C (where C is the input size). 147 | During training, this layer keeps a running estimate of its computed mean 148 | and variance. The running sum is kept with a default momentum of 0.1. 149 | During evaluation, this running mean/variance is used for normalization. 150 | Because the BatchNorm is done over the `C` dimension, computing statistics 151 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 152 | Args: 153 | num_features: num_features from an expected input of size 154 | `batch_size x num_features [x width]` 155 | eps: a value added to the denominator for numerical stability. 156 | Default: 1e-5 157 | momentum: the value used for the running_mean and running_var 158 | computation. Default: 0.1 159 | affine: a boolean value that when set to ``True``, gives the layer learnable 160 | affine parameters. 
Default: ``True`` 161 | Shape: 162 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 163 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 164 | Examples: 165 | >>> # With Learnable Parameters 166 | >>> m = SynchronizedBatchNorm1d(100) 167 | >>> # Without Learnable Parameters 168 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 169 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 170 | >>> output = m(input) 171 | """ 172 | 173 | def _check_input_dim(self, input): 174 | if input.dim() != 2 and input.dim() != 3: 175 | raise ValueError('expected 2D or 3D input (got {}D input)' 176 | .format(input.dim())) 177 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 178 | 179 | 180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 181 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 182 | of 3d inputs 183 | .. math:: 184 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 185 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 186 | standard-deviation are reduced across all devices during training. 187 | For example, when one uses `nn.DataParallel` to wrap the network during 188 | training, PyTorch's implementation normalize the tensor on each device using 189 | the statistics only on that device, which accelerated the computation and 190 | is also easy to implement, but the statistics might be inaccurate. 191 | Instead, in this synchronized version, the statistics will be computed 192 | over all training samples distributed on multiple devices. 193 | 194 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 195 | as the built-in PyTorch implementation. 196 | The mean and standard-deviation are calculated per-dimension over 197 | the mini-batches and gamma and beta are learnable parameter vectors 198 | of size C (where C is the input size). 199 | During training, this layer keeps a running estimate of its computed mean 200 | and variance. The running sum is kept with a default momentum of 0.1. 201 | During evaluation, this running mean/variance is used for normalization. 202 | Because the BatchNorm is done over the `C` dimension, computing statistics 203 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 204 | Args: 205 | num_features: num_features from an expected input of 206 | size batch_size x num_features x height x width 207 | eps: a value added to the denominator for numerical stability. 208 | Default: 1e-5 209 | momentum: the value used for the running_mean and running_var 210 | computation. Default: 0.1 211 | affine: a boolean value that when set to ``True``, gives the layer learnable 212 | affine parameters. 
Default: ``True`` 213 | Shape: 214 | - Input: :math:`(N, C, H, W)` 215 | - Output: :math:`(N, C, H, W)` (same shape as input) 216 | Examples: 217 | >>> # With Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100) 219 | >>> # Without Learnable Parameters 220 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 221 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 222 | >>> output = m(input) 223 | """ 224 | 225 | def _check_input_dim(self, input): 226 | if input.dim() != 4: 227 | raise ValueError('expected 4D input (got {}D input)' 228 | .format(input.dim())) 229 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 230 | 231 | 232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 233 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 234 | of 4d inputs 235 | .. math:: 236 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 237 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 238 | standard-deviation are reduced across all devices during training. 239 | For example, when one uses `nn.DataParallel` to wrap the network during 240 | training, PyTorch's implementation normalize the tensor on each device using 241 | the statistics only on that device, which accelerated the computation and 242 | is also easy to implement, but the statistics might be inaccurate. 243 | Instead, in this synchronized version, the statistics will be computed 244 | over all training samples distributed on multiple devices. 245 | 246 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 247 | as the built-in PyTorch implementation. 248 | The mean and standard-deviation are calculated per-dimension over 249 | the mini-batches and gamma and beta are learnable parameter vectors 250 | of size C (where C is the input size). 251 | During training, this layer keeps a running estimate of its computed mean 252 | and variance. The running sum is kept with a default momentum of 0.1. 253 | During evaluation, this running mean/variance is used for normalization. 254 | Because the BatchNorm is done over the `C` dimension, computing statistics 255 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 256 | or Spatio-temporal BatchNorm 257 | Args: 258 | num_features: num_features from an expected input of 259 | size batch_size x num_features x depth x height x width 260 | eps: a value added to the denominator for numerical stability. 261 | Default: 1e-5 262 | momentum: the value used for the running_mean and running_var 263 | computation. Default: 0.1 264 | affine: a boolean value that when set to ``True``, gives the layer learnable 265 | affine parameters. 
Default: ``True`` 266 | Shape: 267 | - Input: :math:`(N, C, D, H, W)` 268 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 269 | Examples: 270 | >>> # With Learnable Parameters 271 | >>> m = SynchronizedBatchNorm3d(100) 272 | >>> # Without Learnable Parameters 273 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 274 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 275 | >>> output = m(input) 276 | """ 277 | 278 | def _check_input_dim(self, input): 279 | if input.dim() != 5: 280 | raise ValueError('expected 5D input (got {}D input)' 281 | .format(input.dim())) 282 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /modules/models/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine to message passed 63 | back to each slave devices. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register an slave device. 
85 | Args: 86 | identifier: an identifier, usually is the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages were first collected from each devices (including the master device), and then 101 | an callback will be invoked to compute the message to be sent back to each devices 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belongs to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /modules/models/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphism, we assign each sub-module with a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /modules/models/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /modules/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/modules/utils/__init__.py -------------------------------------------------------------------------------- /modules/utils/calculate_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import settings 5 | 6 | def calculate_weigths_labels(dataset, dataloader, num_classes): 7 | # Create an instance from the data loader 8 | z = np.zeros((num_classes,)) 9 | # Initialize tqdm 10 | tqdm_batch = tqdm(dataloader) 11 | print('Calculating classes weights') 12 | for sample in tqdm_batch: 13 | y = sample['label'] 14 | y = y.detach().cpu().numpy() 15 | mask = (y >= 0) & (y < num_classes) 16 | labels = y[mask].astype(np.uint8) 17 | count_l = np.bincount(labels, minlength=num_classes) 18 | z += count_l 19 | tqdm_batch.close() 20 | total_frequency = np.sum(z) 21 | class_weights = [] 22 | for frequency in z: 23 | class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) 24 | class_weights.append(class_weight) 25 | ret = np.array(class_weights) 26 | classes_weights_path = os.path.join(settings.root_dir, dataset+'_classes_weights.npy') 27 | np.save(classes_weights_path, ret) 28 | 29 | return ret -------------------------------------------------------------------------------- /modules/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SegmentationLosses(object): 5 | def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False): 6 | self.ignore_index = ignore_index 7 | self.weight = weight 8 | self.size_average = size_average 9 | self.batch_average = batch_average 10 | self.cuda = cuda 11 | 12 | def build_loss(self, mode='ce'): 13 | """Choices: ['ce' or 'focal']""" 14 | if mode == 'ce': 15 | return self.CrossEntropyLoss 16 | elif mode == 'focal': 17 | return self.FocalLoss 18 | else: 19 | raise NotImplementedError 20 | 21 | def CrossEntropyLoss(self, logit, target): 22 | n, c, h, w = logit.size() 23 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 24 | size_average=self.size_average) 25 | if self.cuda: 26 | criterion = criterion.cuda() 27 | 28 | loss = criterion(logit, target.long()) 29 | 30 | if self.batch_average: 31 | loss /= n 32 | 33 | return loss 34 | 35 | def FocalLoss(self, logit, target, gamma=2, alpha=0.5): 36 | n, c, h, w = logit.size() 37 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 38 | size_average=self.size_average) 39 | if self.cuda: 40 | criterion = criterion.cuda() 41 | 42 | logpt = 
-criterion(logit, target.long()) 43 | pt = torch.exp(logpt) 44 | if alpha is not None: 45 | logpt *= alpha 46 | loss = -((1 - pt) ** gamma) * logpt 47 | 48 | if self.batch_average: 49 | loss /= n 50 | 51 | return loss 52 | 53 | if __name__ == "__main__": 54 | loss = SegmentationLosses(cuda=True) 55 | a = torch.rand(1, 3, 7, 7).cuda() 56 | b = torch.rand(1, 7, 7).cuda() 57 | print(loss.CrossEntropyLoss(a, b).item()) 58 | print(loss.FocalLoss(a, b, gamma=0, alpha=None).item()) 59 | print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item()) 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /modules/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | class LR_Scheduler(object): 14 | """Learning Rate Scheduler 15 | 16 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 17 | 18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 19 | 20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 21 | 22 | Args: 23 | args: 24 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 25 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 26 | :attr:`args.lr_step` 27 | 28 | iters_per_epoch: number of iterations per epoch 29 | """ 30 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 31 | lr_step=0, warmup_epochs=0): 32 | self.mode = mode 33 | print('Using {} LR Scheduler!'.format(self.mode)) 34 | self.lr = base_lr 35 | if mode == 'step': 36 | assert lr_step 37 | self.lr_step = lr_step 38 | self.iters_per_epoch = iters_per_epoch 39 | self.N = num_epochs * iters_per_epoch 40 | self.epoch = -1 41 | self.warmup_iters = warmup_epochs * iters_per_epoch 42 | 43 | def __call__(self, optimizer, i, epoch, best_pred): 44 | T = epoch * self.iters_per_epoch + i 45 | if self.mode == 'cos': 46 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 47 | elif self.mode == 'poly': 48 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 49 | elif self.mode == 'step': 50 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 51 | else: 52 | raise NotImplemented 53 | # warm up lr schedule 54 | if self.warmup_iters > 0 and T < self.warmup_iters: 55 | lr = lr * 1.0 * T / self.warmup_iters 56 | if epoch > self.epoch: 57 | print('\n=>Epoches %i, learning rate = %.4f, \ 58 | previous best = %.4f' % (epoch, lr, best_pred)) 59 | self.epoch = epoch 60 | assert lr >= 0 61 | self._adjust_learning_rate(optimizer, lr) 62 | 63 | def _adjust_learning_rate(self, optimizer, lr): 64 | if len(optimizer.param_groups) == 1: 65 | optimizer.param_groups[0]['lr'] = lr 66 | else: 67 | # enlarge the lr at the head 68 | optimizer.param_groups[0]['lr'] = lr 69 | for i in range(1, len(optimizer.param_groups)): 70 | optimizer.param_groups[i]['lr'] = lr * 10 71 | -------------------------------------------------------------------------------- /modules/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def 
__init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,)*2) 8 | 9 | def Pixel_Accuracy(self): 10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 11 | return Acc 12 | 13 | def Pixel_Accuracy_Class(self): 14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 15 | Acc = np.nanmean(Acc) 16 | return Acc 17 | 18 | def Intersection_over_Union(self): 19 | IoU = np.diag(self.confusion_matrix) / ( 20 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 21 | np.diag(self.confusion_matrix)) 22 | return IoU 23 | 24 | def Mean_Intersection_over_Union(self): 25 | IoU = self.Intersection_over_Union() 26 | MIoU = np.nanmean(IoU) 27 | return MIoU 28 | 29 | def Frequency_Weighted_Intersection_over_Union(self): 30 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 31 | iu = np.diag(self.confusion_matrix) / ( 32 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 33 | np.diag(self.confusion_matrix)) 34 | 35 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 36 | return FWIoU 37 | 38 | def _generate_matrix(self, gt_image, pre_image): 39 | mask = (gt_image >= 0) & (gt_image < self.num_class) 40 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 41 | count = np.bincount(label, minlength=self.num_class**2) 42 | confusion_matrix = count.reshape(self.num_class, self.num_class) 43 | return confusion_matrix 44 | 45 | def add_batch(self, gt_image, pre_image): 46 | assert gt_image.shape == pre_image.shape 47 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 48 | 49 | def reset(self): 50 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /modules/utils/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | from collections import OrderedDict 5 | import glob 6 | from shutil import copy2 7 | import settings 8 | 9 | class Saver(object): 10 | 11 | def __init__(self): 12 | self.directory = os.path.join('run', settings.dataset, settings.checkname) 13 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 14 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 15 | 16 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 17 | if not os.path.exists(self.experiment_dir): 18 | os.makedirs(self.experiment_dir) 19 | 20 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 21 | """Saves checkpoint to disk""" 22 | filename = os.path.join(self.experiment_dir, filename) 23 | torch.save(state, filename) 24 | if is_best: 25 | best_pred = state['best_pred'] 26 | with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: 27 | f.write(str(best_pred)) 28 | if self.runs: 29 | previous_miou = [0.0] 30 | for run in self.runs: 31 | run_id = run.split('_')[-1] 32 | path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') 33 | if os.path.exists(path): 34 | with open(path, 'r') as f: 35 | miou = float(f.readline()) 36 | previous_miou.append(miou) 37 | else: 38 | continue 39 | max_miou = max(previous_miou) 40 | if best_pred > max_miou: 41 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 42 | else: 43 | shutil.copyfile(filename, 
os.path.join(self.directory, 'model_best.pth.tar')) 44 | 45 | def save_experiment_config(self): 46 | if not os.path.exists('settings.py'): 47 | print("settings.py couldn't be found. save_experiment_config failed.") 48 | copy2('settings.py', self.experiment_dir) 49 | -------------------------------------------------------------------------------- /modules/utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.utils import make_grid 4 | from tensorboardX import SummaryWriter 5 | from modules.dataloaders.utils import decode_seg_map_sequence 6 | 7 | class TensorboardSummary(object): 8 | def __init__(self, directory): 9 | self.directory = directory 10 | 11 | def create_summary(self): 12 | writer = SummaryWriter(log_dir=os.path.join(self.directory)) 13 | return writer 14 | 15 | def visualize_image(self, writer, dataset, image, target, output, global_step): 16 | grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) 17 | writer.add_image('Image', grid_image, global_step) 18 | grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), 19 | dataset=dataset), 3, normalize=False, range=(0, 255)) 20 | writer.add_image('Predicted label', grid_image, global_step) 21 | grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), 22 | dataset=dataset), 3, normalize=False, range=(0, 255)) 23 | writer.add_image('Groundtruth label', grid_image, global_step) -------------------------------------------------------------------------------- /modules/utils/surface_dataset_tools/split_dataset.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from pathlib import Path 3 | import random 4 | from os import path as osp 5 | 6 | """ 7 | By running this code train.txt, valid.txt will be generated. 8 | 9 | base_dir 10 | ├── annotations 11 | │ ├── *.xml 12 | ├── images 13 | │ ├── *.jpg 14 | ├── masks 15 | │ ├── *.png 16 | ├── train.txt <-- 17 | ├── valid.txt <-- 18 | └── color_index.xlsx 19 | """ 20 | 21 | ### RUN OPTIONS ### 22 | BASE_DIR = Path('/home/super/Projects/dataset/surface6') 23 | VALID_COUNT = 2000 24 | ################### 25 | 26 | images = sorted(glob(str(BASE_DIR/'**/*.jpg'), recursive=True)) 27 | masks = sorted(glob(str(BASE_DIR/'**/*.png'), recursive=True)) 28 | images = [s for s in images if not '.ipynb_checkpoints' in s] 29 | masks = [s for s in masks if not '.ipynb_checkpoints' in s] 30 | print(len(images), len(masks)) 31 | 32 | masks = sorted(masks) 33 | random.shuffle(masks) 34 | 35 | train_lines = [] 36 | for mask in masks[VALID_COUNT:]: 37 | mask = Path(mask) 38 | name = mask.name.replace('.png', '') 39 | mask_file = f'masks/{name}.png' 40 | image_file = f'images/{name}.jpg' 41 | assert osp.exists(BASE_DIR / image_file), f"{image_file} not exists." 42 | 43 | line = f'{image_file},{mask_file}' 44 | train_lines.append(line) 45 | 46 | with open(osp.join(BASE_DIR, 'train.txt'), 'w+') as f: 47 | f.write('\n'.join(train_lines)) 48 | 49 | valid_lines = [] 50 | for mask in masks[:VALID_COUNT]: 51 | mask = Path(mask) 52 | name = mask.name.replace('.png', '') 53 | mask_file = f'masks/{name}.png' 54 | image_file = f'images/{name}.jpg' 55 | assert osp.exists(BASE_DIR / image_file), f"{image_file} not exists." 
56 | 57 | line = f'{image_file},{mask_file}' 58 | valid_lines.append(line) 59 | 60 | with open(osp.join(BASE_DIR / 'valid.txt'), 'w+') as f: 61 | f.write('\n'.join(valid_lines)) 62 | 63 | print(f'train: {len(train_lines)}, valid: {len(valid_lines)}') 64 | -------------------------------------------------------------------------------- /modules/utils/surface_dataset_tools/surface_polygon.py: -------------------------------------------------------------------------------- 1 | # pip install xmltodict 2 | import json 3 | import xmltodict 4 | from glob import glob 5 | from os import path as osp 6 | from pathlib import Path 7 | import cv2 8 | import numpy as np 9 | from shutil import copy2 10 | from PIL import Image 11 | 12 | """ 13 | Input Dataset Structure 14 | 15 | BASE_DIR 16 | ├── SM0915_13 17 | │ ├── 27_SM0915_13.xml 18 | │ ├── MP_SEL_SUR_001441.jpg 19 | │ ├── MP_SEL_SUR_001442.jpg 20 | │ ├── MP_SEL_SUR_001443.jpg 21 | │ ├── ... 22 | ├── SM0915_14 23 | │ ├── 28_SM0915_14.xml 24 | │ ├── MP_SEL_SUR_001562.jpg 25 | │ ├── MP_SEL_SUR_001563.jpg 26 | │ ├── MP_SEL_SUR_001564.jpg 27 | │ ├── ... 28 | ├── ... 29 | └── color_index.xlsx 30 | 31 | Output Dataset Structure 32 | 33 | OUTPUT_DIR 34 | ├── annotations 35 | │ ├── *.xml 36 | ├── images 37 | │ ├── *.jpg 38 | ├── masks 39 | │ ├── *.png 40 | └── color_index.xlsx 41 | """ 42 | 43 | ### RUN OPTIONS ### 44 | BASE_DIR = Path('/home/super/Projects/dataset/surface_org') 45 | XML_GLOB = Path(BASE_DIR) / '**/*.xml' 46 | OUTPUT_DIR = Path('/home/super/Projects/dataset/surface6') 47 | JSON_OUTPUT = OUTPUT_DIR / 'annotations' 48 | IMAGE_OUTPUT = OUTPUT_DIR / 'images' 49 | MASK_OUTPUT = OUTPUT_DIR / 'masks' 50 | 51 | # [0, 0, 0] for ignore mask 52 | color_map = { 53 | "sidewalk@blocks": [0, 255, 0], # sidewalk 54 | "sidewalk@cement": [0, 255, 0], # sidewalk 55 | "sidewalk@urethane": [255, 128, 0], # bike_lane 56 | "sidewalk@asphalt": [255, 128, 0], # bike_lane 57 | "sidewalk@soil_stone": [0, 255, 0], # sidewalk 58 | "sidewalk@damaged": [0, 255, 0], # sidewalk 59 | "sidewalk@other": [0, 255, 0], # sidewalk 60 | "braille_guide_blocks@normal": [255, 255, 0], # guide_block 61 | "braille_guide_blocks@damaged": [255, 255, 0], # guide_block 62 | "roadway@normal": [0, 0, 255], # roadway 63 | "roadway@crosswalk": [255, 0, 255], # crosswalk 64 | "alley@normal": [0, 0, 255], # roadway 65 | "alley@crosswalk": [255, 0, 255], # crosswalk 66 | "alley@speed_bump": [0, 0, 255], # roadway 67 | "alley@damaged": [0, 0, 255], # roadway 68 | "bike_lane@normal": [255, 128, 0], # bike_lane 69 | "caution_zone@stairs": [255, 0, 0], # caution_zone 70 | "caution_zone@manhole": [0, 0, 0], # background 71 | "caution_zone@tree_zone": [255, 0, 0], # caution_zone 72 | "caution_zone@grating": [255, 0, 0], # caution_zone 73 | "caution_zone@repair_zone": [255, 0, 0], # caution_zone 74 | } 75 | ################### 76 | 77 | error_logs = [] 78 | image2path = {} 79 | statistics = {} 80 | 81 | def convert_to_json(): 82 | xml_files = sorted(glob(str(XML_GLOB), recursive=True)) 83 | JSON_OUTPUT.mkdir(exist_ok=False, parents=True) 84 | 85 | for i, xml in enumerate(xml_files): 86 | xml = Path(xml) 87 | out_json = JSON_OUTPUT / xml.name.replace('.xml', '.json') 88 | print(f'[{i+1}/{len(xml_files)}] Converting {xml} -> {out_json}') 89 | 90 | with open(xml, 'r') as f: 91 | xml_string = f.read() 92 | 93 | xml_dict = xmltodict.parse(xml_string) 94 | json_string = json.dumps(xml_dict, indent=4) 95 | 96 | with open(out_json, 'w+') as f: 97 | f.write(json_string) 98 | 99 | def index_images(): 100 | global 
image2path 101 | for image in BASE_DIR.rglob('*.jpg'): 102 | image2path[image.name] = image 103 | print(f"Found {len(list(image2path.keys()))} images from dataset.") 104 | 105 | def add_count(key: str): 106 | if key not in statistics: 107 | statistics[key] = 0 108 | statistics[key] += 1 109 | 110 | def draw_polygon(points: list, mask: np.array, rgb): 111 | pts = np.asarray(points).reshape((-1, 1, 2)) 112 | r, g, b = rgb 113 | mask = cv2.fillPoly(mask, [pts], color=(r,g,b)) 114 | return mask 115 | 116 | def parse_polygon(polygon: dict, mask: np.array): 117 | label = polygon['@label'] # str 'sidewalk' 118 | # occluded = polygon['@occluded'] # str '0' I think this is always zero 119 | points = str(polygon['@points']) # str '778.28,0.00;0.56,1080.00;1920.00,1080.00;...' 120 | # z_order = polygon['@z_order'] # str '2' 121 | 122 | if 'attribute' in polygon: 123 | attribute = polygon['attribute'] # dict 124 | # attribute_name = attribute['@name'] # str 'attribute' useless... 125 | attribute_text = attribute['#text'] # str 'blocks', 'normal', ...' 126 | else: 127 | attribute_text = 'normal' 128 | 129 | cls = str(label) + '@' + str(attribute_text) 130 | assert cls in color_map, f"Invalid label@attribute: {cls}" 131 | add_count(cls) 132 | 133 | rgb = color_map[cls] 134 | if rgb == [0, 0, 0]: # Ignore polygon 135 | return mask 136 | 137 | xy_points = [[round(float(x)), round(float(y))] for x, y in [str(xy_str).split(',') for xy_str in points.split(';')]] 138 | mask = draw_polygon(xy_points, mask, rgb) 139 | return mask 140 | 141 | def parse_image(image: dict, json_file: Path): 142 | global image2path 143 | 144 | id = image['@id'] # str '0' 145 | name = str(image['@name']) # str 'MP_SEL_SUR_007985.jpg' 146 | width = int(image['@width']) # str '1920' 147 | height = int(image['@height']) # str '1080' 148 | 149 | if 'polygon' in image: 150 | polygon_list = image['polygon'] # list or dict(if only one instance) 151 | elif 'polyline' in image: 152 | polygon_list = image['polyline'] 153 | else: 154 | raise KeyError(f"Couldn't find polygon key from dict.") 155 | 156 | ### Masking polygon ### 157 | mask = np.zeros((height, width, 3), dtype=np.uint8) 158 | if isinstance(polygon_list, dict): 159 | polygon_list = [polygon_list] 160 | polygon_list = sorted(polygon_list, key=lambda poly: int(poly['@z_order'])) 161 | for polygon in polygon_list: 162 | mask = parse_polygon(polygon, mask) 163 | 164 | ### Saving mask ### 165 | mask_file = MASK_OUTPUT / str(name).replace('.jpg', '.png') 166 | Image.fromarray(mask).save(mask_file) 167 | add_count('mask') 168 | 169 | ### Saving image ### 170 | if not name in image2path: 171 | raise FileNotFoundError(f"Cannot find {name} from indexed image paths.") 172 | copy2(image2path[name], IMAGE_OUTPUT) 173 | add_count('image') 174 | 175 | def parse_json(json_file, **kwargs): 176 | # root > annotations > image[] > {'@id', '@name', '@width', '@height', 'polygon[]'} 177 | global error_logs 178 | 179 | json_file = Path(json_file) 180 | with open(json_file, 'r') as f: 181 | ann = json.load(f) 182 | annotations = ann['annotations'] 183 | # version = annotations['version'] # str 184 | # meta = annotations['meta'] # dict 185 | image_list = annotations['image'] # list 186 | 187 | for i, image in enumerate(image_list): 188 | print(f"[{kwargs['json_count']}/{kwargs['json_total']}] [{i + 1}/{len(image_list)}] {image['@name']}") 189 | try: 190 | parse_image(image, json_file) 191 | except KeyError as e: 192 | print(e) 193 | error_logs.append(f"{json_file.name}-{image['@name']}-{e}") 194 | 195 | 
def generate_masks(): 196 | global error_logs 197 | 198 | MASK_OUTPUT.mkdir(exist_ok=False, parents=True) 199 | IMAGE_OUTPUT.mkdir(exist_ok=False, parents=True) 200 | 201 | index_images() 202 | 203 | json_files = sorted(glob(str(JSON_OUTPUT / '**/*.json'), recursive=True)) 204 | for i, json_file in enumerate(json_files): 205 | print(f'[{i+1}/{len(json_files)}] {json_file}') 206 | parse_json(json_file, json_count=i+1, json_total=len(json_files)) 207 | 208 | print(f'Error count: {len(error_logs)}') 209 | print(error_logs) 210 | print(statistics) 211 | 212 | if __name__ == '__main__': 213 | if not JSON_OUTPUT.exists(): 214 | convert_to_json() 215 | 216 | generate_masks() 217 | -------------------------------------------------------------------------------- /modules/utils/torch_logger.py: -------------------------------------------------------------------------------- 1 | try: 2 | from tqdm import tqdm 3 | TQDM_ENABLED = True 4 | except ImportError: 5 | TQDM_ENABLED = False 6 | 7 | 8 | class TorchLogger: 9 | def __init__(self, total_epoch, total_step): 10 | self.total_epoch = total_epoch 11 | self.total_step = total_step 12 | 13 | self.sum_dict = {} 14 | self.count_dict = {} 15 | self.avg_dict = {} 16 | 17 | def reset(self): 18 | self.sum_dict = {} 19 | self.count_dict = {} 20 | self.avg_dict = {} 21 | 22 | def log(self, epoch, step, etc_str='', **kwargs): 23 | """ 24 | logger.log(loss=loss.item(), acc=acc, ...) 25 | :param kwargs: 26 | :return: 27 | """ 28 | self.update(**kwargs) 29 | print(self.get_log_string(epoch, step, etc_str)) 30 | 31 | def update(self, **kwargs): 32 | """ 33 | avg_loss, avg_acc, ... = logger.update(loss=loss.item(), acc=acc, ...) 34 | :param kwargs: 35 | :return: 36 | """ 37 | value_dict = kwargs 38 | 39 | for key, value in value_dict.items(): 40 | if key in self.sum_dict: 41 | self.sum_dict[key] += value 42 | else: 43 | self.sum_dict[key] = value 44 | if key in self.count_dict: 45 | self.count_dict[key] += 1 46 | else: 47 | self.count_dict[key] = 1 48 | 49 | for key, value in self.sum_dict.items(): 50 | if self.count_dict[key] == 0: 51 | self.count_dict[key] = 1 52 | self.avg_dict[key] = self.sum_dict[key] / self.count_dict[key] 53 | 54 | return self.avg_dict 55 | 56 | def get_log_string(self, epoch, step, etc_str=''): 57 | msg = 'Epoch: [{}/{}]\tStep: [{}/{}]\t'.format(epoch, self.total_epoch, step, self.total_step) 58 | 59 | for i, (key, value) in enumerate(self.avg_dict.items()): 60 | msg += '{}: {:.3f}\t'.format(key, value) 61 | 62 | msg += etc_str 63 | 64 | return msg 65 | 66 | def print_log(self, epoch, step): 67 | print(self.get_log_string(epoch, step)) 68 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # predict on jpg files or mp4 video 2 | 3 | import cv2 4 | import torch 5 | from glob import glob 6 | import os 7 | import os.path as osp 8 | from pathlib import Path 9 | from torchvision import transforms 10 | from modules.dataloaders.utils import decode_segmap 11 | from modules.models.deeplab_xception import DeepLabv3_plus 12 | from modules.models.sync_batchnorm.replicate import patch_replication_callback 13 | import numpy as np 14 | from PIL import Image 15 | from tqdm import tqdm 16 | 17 | ### RUN OPTIONS ### 18 | MODEL_PATH = "./run/surface/deeplab/model_iou_77.pth.tar" 19 | ORIGINAL_HEIGHT = 720 20 | ORIGINAL_WIDTH = 1280 21 | MODEL_HEIGHT = 512 22 | MODEL_WIDTH = 1024 23 | NUM_CLASSES = 7 # including background 24 | CUDA = 
torch.cuda.is_available()
25 | 
26 | MODE = 'jpg' # 'mp4' or 'jpg'
27 | DATA_PATH = './test/jpgs' # .mp4 path or folder containing jpg images
28 | OUTPUT_PATH = './output/jpgs' # where the output video file or jpg frame folder should be saved
29 | 
30 | # MODE = 'mp4'
31 | # DATA_PATH = './test/test.mp4'
32 | # OUTPUT_PATH = './output/test.avi'
33 | 
34 | SHOW_OUTPUT = 'DISPLAY' in os.environ # whether to show results with cv2.imshow()
35 | 
36 | OVERLAPPING = True # whether to blend the segmentation map with the original image
37 | FPS_OVERRIDE = 60 # None to use the original video fps
38 | 
39 | CUSTOM_COLOR_MAP = [
40 |     [0, 0, 0], # background
41 |     [255, 128, 0], # bike_lane
42 |     [255, 0, 0], # caution_zone
43 |     [255, 0, 255], # crosswalk
44 |     [255, 255, 0], # guide_block
45 |     [0, 0, 255], # roadway
46 |     [0, 255, 0], # sidewalk
47 | ] # To ignore unused classes while predicting
48 | 
49 | CUSTOM_N_CLASSES = len(CUSTOM_COLOR_MAP)
50 | ######
51 | 
52 | 
53 | class FrameGeneratorMP4:
54 |     def __init__(self, mp4_file: str, output_path=None, show=True):
55 |         assert osp.isfile(mp4_file), "DATA_PATH should be an existing mp4 file path."
56 |         self.vidcap = cv2.VideoCapture(mp4_file)
57 |         self.fps = int(self.vidcap.get(cv2.CAP_PROP_FPS))
58 |         self.total = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
59 |         self.show = show
60 |         self.output_path = output_path
61 | 
62 |         if self.output_path is not None:
63 |             os.makedirs(osp.dirname(output_path), exist_ok=True)
64 |             self.fourcc = cv2.VideoWriter_fourcc(*'DIVX')
65 | 
66 |             if FPS_OVERRIDE is not None:
67 |                 self.fps = int(FPS_OVERRIDE)
68 |             self.out = cv2.VideoWriter(self.output_path, self.fourcc, self.fps, (ORIGINAL_WIDTH, ORIGINAL_HEIGHT))
69 | 
70 |     def __iter__(self):
71 |         success, image = self.vidcap.read()
72 |         for i in range(0, self.total):
73 |             if success:
74 |                 img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
75 |                 yield np.array(img)
76 | 
77 |             success, image = self.vidcap.read()
78 | 
79 |     def __len__(self):
80 |         return self.total
81 | 
82 |     def write(self, rgb_img):
83 |         bgr = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
84 | 
85 |         if self.show:
86 |             cv2.imshow('output', bgr)
87 |             if cv2.waitKey(1) & 0xFF == ord('q'):
88 |                 print('User Interrupted')
89 |                 self.close()
90 |                 exit(1)
91 | 
92 |         if self.output_path is not None:
93 |             self.out.write(bgr)
94 | 
95 |     def close(self):
96 |         cv2.destroyAllWindows()
97 |         self.vidcap.release()
98 |         if self.output_path is not None:
99 |             self.out.release()
100 | 
101 | 
102 | class FrameGeneratorJpg:
103 |     def __init__(self, jpg_folder: str, output_folder=None, show=True):
104 |         assert osp.isdir(jpg_folder), "DATA_PATH should be a directory containing jpg files."
105 |         self.files = sorted(glob(osp.join(jpg_folder, '*.jpg'), recursive=False))
106 |         self.show = show
107 |         self.output_folder = output_folder
108 |         self.last_file_name = ""
109 | 
110 |         if self.output_folder is not None:
111 |             os.makedirs(output_folder, exist_ok=True)
112 | 
113 |     def __iter__(self):
114 |         for file in self.files:
115 |             img = cv2.imread(file, cv2.IMREAD_COLOR)
116 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
117 |             self.last_file_name = str(Path(file).name)
118 |             yield np.array(img)
119 | 
120 |     def __len__(self):
121 |         return len(self.files)
122 | 
123 |     def write(self, rgb_img):
124 |         bgr = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
125 | 
126 |         if self.show:
127 |             cv2.imshow('output', bgr)
128 |             if cv2.waitKey(1) & 0xFF == ord('q'):
129 |                 print('User Interrupted')
130 |                 self.close()
131 |                 exit(1)
132 | 
133 |         if self.output_folder is not None:
134 |             path = osp.join(self.output_folder, self.last_file_name)
135 |             cv2.imwrite(path, bgr)
136 | 
137 |     def close(self):
138 |         cv2.destroyAllWindows()
139 | 
140 | 
141 | class ModelWrapper:
142 |     def __init__(self):
143 |         self.composed_transform = transforms.Compose([
144 |             transforms.Resize((MODEL_HEIGHT, MODEL_WIDTH), interpolation=Image.BILINEAR),
145 |             transforms.ToTensor(),
146 |             transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])
147 | 
148 |         self.model = self.load_model(MODEL_PATH)
149 | 
150 |     @staticmethod
151 |     def load_model(model_path):
152 |         model = DeepLabv3_plus(nInputChannels=3, n_classes=NUM_CLASSES, os=16)
153 |         if CUDA:
154 |             model = torch.nn.DataParallel(model, device_ids=[0])
155 |             patch_replication_callback(model)
156 |             model = model.cuda()
157 |         if not osp.isfile(model_path):
158 |             raise RuntimeError("=> no checkpoint found at '{}'".format(model_path))
159 |         checkpoint = torch.load(model_path, map_location=None if CUDA else 'cpu') # allow GPU-trained checkpoints to load on CPU-only machines
160 |         if CUDA:
161 |             model.module.load_state_dict(checkpoint['state_dict'])
162 |         else:
163 |             model.load_state_dict(checkpoint['state_dict'])
164 |         print("=> loaded checkpoint '{}' (epoch: {}, best_pred: {})"
165 |               .format(model_path, checkpoint['epoch'], checkpoint['best_pred']))
166 |         model.eval()
167 |         return model
168 | 
169 |     def predict(self, rgb_img: np.ndarray):
170 |         x = self.composed_transform(Image.fromarray(rgb_img))
171 |         x = x.unsqueeze(0)
172 | 
173 |         if CUDA:
174 |             x = x.cuda()
175 |         with torch.no_grad():
176 |             output = self.model(x)
177 |         pred = output.data.detach().cpu().numpy()
178 |         pred = np.argmax(pred, axis=1).squeeze(0)
179 |         segmap = decode_segmap(pred, dataset='custom', label_colors=CUSTOM_COLOR_MAP, n_classes=CUSTOM_N_CLASSES)
180 |         segmap = np.array(segmap * 255).astype(np.uint8)
181 | 
182 |         resized = cv2.resize(segmap, (ORIGINAL_WIDTH, ORIGINAL_HEIGHT),
183 |                              interpolation=cv2.INTER_NEAREST)
184 |         return resized
185 | 
186 | 
187 | def main():
188 |     print('Loading model...')
189 |     model_wrapper = ModelWrapper()
190 | 
191 |     if MODE == 'mp4':
192 |         generator = FrameGeneratorMP4(DATA_PATH, OUTPUT_PATH, show=SHOW_OUTPUT)
193 |     elif MODE == 'jpg':
194 |         generator = FrameGeneratorJpg(DATA_PATH, OUTPUT_PATH, show=SHOW_OUTPUT)
195 |     else:
196 |         raise NotImplementedError('MODE should be "mp4" or "jpg".')
197 | 
198 |     for index, img in enumerate(tqdm(generator)):
199 |         segmap = model_wrapper.predict(img)
200 |         if OVERLAPPING:
201 |             h, w, _ = np.array(segmap).shape
202 |             img_resized = cv2.resize(img, (w, h))
203 |             result = (img_resized * 0.5 + segmap * 0.5).astype(np.uint8)
204 |         else:
205 |             result = segmap
206 |         generator.write(result)
207 | 
208 |     generator.close()
209 |     print('Done.')
210 | 
211 | 
212 | if __name__ == '__main__':
213 |     main()
214 | 
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
1 | """CUDA_VISIBLE_DEVICES=0,1,2,3 python3 train.py"""
2 | 
3 | backbone = 'xception'
4 | out_stride = 8 # network output stride (default: 8)
5 | workers = 16
6 | pretrained = True # whether to use a pretrained Xception backbone
7 | 
8 | resize_height = 512 # model input shape
9 | resize_width = 1024
10 | 
11 | cuda = True
12 | 
13 | # If you want to use gpus 1,2,3, run CUDA_VISIBLE_DEVICES=1,2,3 python3 ...
14 | # with gpu_ids = [0, 1, 2] (indices start from zero)
15 | gpu_ids = [0, 1, 2, 3] # which GPUs to use for training
16 | 
17 | sync_bn = len(gpu_ids) > 1 # whether to use synchronized batch norm
18 | freeze_bn = False # whether to freeze bn parameters (default: False)
19 | 
20 | epochs = 200
21 | start_epoch = 0
22 | batch_size = 2 * len(gpu_ids)
23 | test_batch_size = 2 * len(gpu_ids)
24 | 
25 | loss_type = 'ce' # 'ce': CrossEntropy, 'focal': Focal Loss
26 | use_balanced_weights = False # whether to use balanced class weights (default: False)
27 | lr = 1e-3
28 | 
29 | # Adam optimizer performed far better.
30 | # lr_scheduler = 'poly' # lr scheduler mode: ['poly', 'step', 'cos']
31 | # momentum = 0.9
32 | # weight_decay = 5e-4
33 | # nesterov = False
34 | 
35 | resume = False # True: load checkpoint model. False: train from scratch
36 | checkpoint = './run/surface/deeplab/model_iou_77.pth.tar'
37 | 
38 | checkname = "deeplab" # set the checkpoint name
39 | 
40 | ft = False # fine-tuning on a different dataset
41 | eval_interval = 1 # evaluation interval (default: 1)
42 | no_val = False # skip validation during training
43 | 
44 | dataset = 'surface'
45 | root_dir = ''
46 | if dataset == 'pascal':
47 |     use_sbd = False # whether to use the SBD dataset
48 |     root_dir = '/path/to/datasets/VOCdevkit/VOC2012/' # folder that contains VOCdevkit/.
49 | elif dataset == 'sbd':
50 |     root_dir = '/path/to/datasets/benchmark_RELEASE/' # folder that contains dataset/.
51 | elif dataset == 'cityscapes':
52 |     root_dir = '/path/to/datasets/cityscapes/' # folder that contains leftImg8bit/
53 | elif dataset == 'coco':
54 |     root_dir = '/home/super/Projects/dataset/coco/'
55 | elif dataset == 'surface':
56 |     root_dir = '/home/super/Projects/dataset/surface6'
57 | else:
58 |     print('Dataset {} not available.'.format(dataset))
59 |     raise NotImplementedError
60 | 
61 | """
62 | background 0 [0, 0, 0]
63 | bike_lane 1 [255, 128, 0]
64 | caution_zone 2 [255, 0, 0]
65 | crosswalk 3 [255, 0, 255]
66 | guide_block 4 [255, 255, 0]
67 | roadway 5 [0, 0, 255]
68 | sidewalk 6 [0, 255, 0]
69 | """
70 | """
71 | Class Attr Unique Label R G B
72 | background background 0 0 0 0
73 | sidewalk blocks sidewalk 6 0 0 255
74 | sidewalk cement sidewalk 6 217 217 217
75 | sidewalk urethane bike_lane 1 198 89 17
76 | sidewalk asphalt background 1 128 128 128
77 | sidewalk soil_stone sidewalk 6 255 230 153
78 | sidewalk damaged sidewalk 6 55 86 35
79 | sidewalk other sidewalk 6 110 168 70
80 | braille_guide_blocks normal guide_block 4 255 255 0
81 | braille_guide_blocks damaged guide_block 4 128 96 0
82 | roadway normal roadway 5 255 128 255
83 | roadway crosswalk crosswalk 3 255 0 255
84 | alley normal roadway 5 230 170 255
85 | alley crosswalk crosswalk 3 208 88 255
86 | alley speed_bump roadway 5 138 60 200
87 | alley damaged roadway 5 88 38 128
88 | bike_lane normal bike_lane 1 255 155 155
89 | caution_zone stairs caution_zone 2 255 192 0
90 | caution_zone manhole caution_zone 2 255 0 0
91 | caution_zone tree_zone caution_zone 2 0 255 0
92 | caution_zone grating caution_zone 2 255 128 0
93 | caution_zone repair_zone caution_zone 2 105 105 255
94 | 
95 | See more info at
96 | modules/utils/surface_dataset_tools/surface_polygon.py
97 | modules/utils/surface_dataset_tools/split_dataset.py
98 | modules/dataloaders/datasets/surface.py
99 | """
100 | labels = [
101 |     'background',
102 |     'bike_lane',
103 |     'caution_zone',
104 |     'crosswalk',
105 |     'guide_block',
106 |     'roadway',
107 |     'sidewalk',
108 | ]
109 | # RGB
110 | colors = [
111 |     [0, 0, 0],
112 |     [255, 128, 0],
113 |     [255, 0, 0],
114 |     [255, 0, 255],
115 |     [255, 255, 0],
116 |     [0, 0, 255],
117 |     [0, 255, 0],
118 | ]
119 | 
120 | num_classes = len(colors) # 7
121 | 
--------------------------------------------------------------------------------
/test/jpgs/MP_SEL_SUR_001453.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/test/jpgs/MP_SEL_SUR_001453.jpg
--------------------------------------------------------------------------------
/test/jpgs/MP_SEL_SUR_001456.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/test/jpgs/MP_SEL_SUR_001456.jpg
--------------------------------------------------------------------------------
/test/jpgs/MP_SEL_SUR_001503.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/test/jpgs/MP_SEL_SUR_001503.jpg
--------------------------------------------------------------------------------
/test/jpgs/MP_SEL_SUR_001563.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/selectstarofficial/segmentation-selectstar/c3b1cf608adbf58f2412b1874ce99305898d6638/test/jpgs/MP_SEL_SUR_001563.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | import settings 7 | from modules.dataloaders import make_data_loader 8 | from modules.models.sync_batchnorm.replicate import patch_replication_callback 9 | from modules.models.deeplab_xception import DeepLabv3_plus, get_1x_lr_params, get_10x_lr_params 10 | from modules.utils.loss import SegmentationLosses 11 | from modules.utils.calculate_weights import calculate_weigths_labels 12 | # from modules.utils.lr_scheduler import LR_Scheduler 13 | from modules.utils.saver import Saver 14 | from modules.utils.summaries import TensorboardSummary 15 | from modules.utils.metrics import Evaluator 16 | 17 | class Trainer(object): 18 | def __init__(self,): 19 | # Define Saver 20 | self.saver = Saver() 21 | self.saver.save_experiment_config() 22 | # Define Tensorboard Summary 23 | self.summary = TensorboardSummary(self.saver.experiment_dir) 24 | self.writer = self.summary.create_summary() 25 | 26 | # Define Dataloader 27 | kwargs = {'num_workers': settings.workers, 'pin_memory': True} 28 | self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(**kwargs) 29 | 30 | # Define network 31 | model = DeepLabv3_plus(nInputChannels=3, n_classes=self.nclass, os=16, pretrained=settings.pretrained, _print=True) 32 | 33 | train_params = [{'params': get_1x_lr_params(model), 'lr': settings.lr}, 34 | {'params': get_10x_lr_params(model), 'lr': settings.lr}] 35 | 36 | # Define Optimizer 37 | # optimizer = torch.optim.SGD(train_params, momentum=settings.momentum, 38 | # weight_decay=settings.weight_decay, nesterov=settings.nesterov) 39 | optimizer = torch.optim.Adam(train_params) 40 | 41 | # Define Criterion 42 | # whether to use class balanced weights 43 | if settings.use_balanced_weights: 44 | classes_weights_path = os.path.join(settings.root_dir, settings.dataset+'_classes_weights.npy') 45 | if os.path.isfile(classes_weights_path): 46 | weight = np.load(classes_weights_path) 47 | else: 48 | weight = calculate_weigths_labels(settings.dataset, self.train_loader, self.nclass) 49 | weight = torch.from_numpy(weight.astype(np.float32)) 50 | else: 51 | weight = None 52 | self.criterion = SegmentationLosses(weight=weight, cuda=settings.cuda).build_loss(mode=settings.loss_type) 53 | self.model, self.optimizer = model, optimizer 54 | 55 | # Define Evaluator 56 | self.evaluator = Evaluator(self.nclass) 57 | # Define lr scheduler 58 | # self.scheduler = LR_Scheduler(settings.lr_scheduler, settings.lr, 59 | # settings.epochs, len(self.train_loader)) 60 | 61 | # Using cuda 62 | if settings.cuda: 63 | self.model = torch.nn.DataParallel(self.model, device_ids=settings.gpu_ids) 64 | patch_replication_callback(self.model) 65 | self.model = self.model.cuda() 66 | 67 | # Resuming checkpoint 68 | self.best_pred = 0.0 69 | if settings.resume: 70 | if not os.path.isfile(settings.checkpoint): 71 | raise RuntimeError("=> no checkpoint found at '{}'" .format(settings.checkpoint)) 72 | checkpoint = torch.load(settings.checkpoint) 73 | settings.start_epoch = checkpoint['epoch'] 74 | if settings.cuda: 75 | self.model.module.load_state_dict(checkpoint['state_dict']) 76 | else: 77 | 
self.model.load_state_dict(checkpoint['state_dict'])
78 |             if not settings.ft:
79 |                 self.optimizer.load_state_dict(checkpoint['optimizer'])
80 |             self.best_pred = checkpoint['best_pred']
81 |             print("=> loaded checkpoint '{}' (epoch {})"
82 |                   .format(settings.checkpoint, checkpoint['epoch']))
83 | 
84 |         # Clear start epoch if fine-tuning
85 |         if settings.ft:
86 |             settings.start_epoch = 0
87 | 
88 |     def training(self, epoch):
89 |         train_loss = 0.0
90 |         self.model.train()
91 |         tbar = tqdm(self.train_loader)
92 |         num_img_tr = len(self.train_loader)
93 |         for i, sample in enumerate(tbar):
94 |             image, target = sample['image'], sample['label']
95 |             if settings.cuda:
96 |                 image, target = image.cuda(), target.cuda()
97 |             # self.scheduler(self.optimizer, i, epoch, self.best_pred)
98 |             self.optimizer.zero_grad()
99 |             output = self.model(image)
100 |             loss = self.criterion(output, target)
101 |             loss.backward()
102 |             self.optimizer.step()
103 |             train_loss += loss.item()
104 |             tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
105 |             self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)
106 | 
107 |             # Show 10 * 3 inference results each epoch
108 |             if i % max(1, num_img_tr // 10) == 0: # max() avoids ZeroDivisionError when the loader has fewer than 10 batches
109 |                 global_step = i + num_img_tr * epoch
110 |                 self.summary.visualize_image(self.writer, settings.dataset, image, target, output, global_step)
111 | 
112 |         self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
113 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * settings.batch_size + image.data.shape[0]))
114 |         print('Loss: %.3f' % train_loss)
115 | 
116 |         if settings.no_val:
117 |             # save checkpoint every epoch
118 |             is_best = False
119 |             self.saver.save_checkpoint({
120 |                 'epoch': epoch + 1,
121 |                 'state_dict': self.model.module.state_dict(),
122 |                 'optimizer': self.optimizer.state_dict(),
123 |                 'best_pred': self.best_pred,
124 |             }, is_best)
125 | 
126 | 
127 |     def validation(self, epoch):
128 |         self.model.eval()
129 |         self.evaluator.reset()
130 |         tbar = tqdm(self.val_loader, desc='\r')
131 |         test_loss = 0.0
132 |         for i, sample in enumerate(tbar):
133 |             image, target = sample['image'], sample['label']
134 |             if settings.cuda:
135 |                 image, target = image.cuda(), target.cuda()
136 |             with torch.no_grad():
137 |                 output = self.model(image)
138 |             loss = self.criterion(output, target)
139 |             test_loss += loss.item()
140 |             tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
141 |             pred = output.data.cpu().numpy()
142 |             target = target.cpu().numpy()
143 |             pred = np.argmax(pred, axis=1)
144 |             # Add batch sample into evaluator
145 |             self.evaluator.add_batch(target, pred)
146 | 
147 |         # Fast test during the training
148 |         Acc = self.evaluator.Pixel_Accuracy()
149 |         Acc_class = self.evaluator.Pixel_Accuracy_Class()
150 |         mIoU = self.evaluator.Mean_Intersection_over_Union()
151 |         FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
152 |         self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
153 |         self.writer.add_scalar('val/mIoU', mIoU, epoch)
154 |         self.writer.add_scalar('val/Acc', Acc, epoch)
155 |         self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
156 |         self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
157 |         print('Validation:')
158 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * settings.batch_size + image.data.shape[0]))
159 |         print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
160 |         print('Loss: %.3f' % test_loss)
161 | 
162 |         new_pred = mIoU
163 |         if new_pred > self.best_pred:
164 |             is_best = True
165 | 
self.best_pred = new_pred
166 |             self.saver.save_checkpoint({
167 |                 'epoch': epoch + 1,
168 |                 'state_dict': self.model.module.state_dict(),
169 |                 'optimizer': self.optimizer.state_dict(),
170 |                 'best_pred': self.best_pred,
171 |             }, is_best)
172 | 
173 | if __name__ == "__main__":
174 |     trainer = Trainer()
175 |     print('Starting Epoch:', settings.start_epoch)
176 |     print('Total Epochs:', settings.epochs)
177 |     for epoch in range(settings.start_epoch, settings.epochs):
178 |         trainer.training(epoch)
179 |         if not settings.no_val and epoch % settings.eval_interval == (settings.eval_interval - 1):
180 |             trainer.validation(epoch)
181 | 
182 |     trainer.writer.close()
183 | 
--------------------------------------------------------------------------------
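
Note on the checkpoint format: `Saver.save_checkpoint` in `train.py` writes a plain `torch.save` dictionary with the keys `'epoch'`, `'state_dict'`, `'optimizer'`, and `'best_pred'`, and both the resume path in `train.py` and `MODEL_PATH` in `predict.py` read those same keys back. As a minimal sketch (the path below is just the example checkpoint path from `settings.py`; point it at a file produced by your own run), such a checkpoint can be inspected on a CPU-only machine like this:

```python
import torch

# Example checkpoint path taken from settings.py; replace with your own run output.
CHECKPOINT = './run/surface/deeplab/model_iou_77.pth.tar'

# map_location='cpu' lets a checkpoint saved on GPU load without CUDA available.
ckpt = torch.load(CHECKPOINT, map_location='cpu')

print('epoch:', ckpt['epoch'])                     # epoch + 1 at save time
print('best mIoU:', ckpt['best_pred'])             # best validation mIoU so far
print('weight tensors:', len(ckpt['state_dict']))  # weights saved via model.module.state_dict()
print('optimizer state keys:', list(ckpt['optimizer'].keys()))
```

The same dictionary is what `settings.resume = True` expects at `settings.checkpoint` before training continues.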