├── dataloaders
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── camvid.py
│   │   ├── pascal.py
│   │   └── cityscapes.py
│   ├── __init__.py
│   ├── utils.py
│   ├── factory.py
│   └── custom_transforms.py
├── model
│   ├── sync_bn
│   │   ├── src
│   │   │   ├── cpu
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setup.py
│   │   │   │   ├── operator.cpp
│   │   │   │   ├── operator.h
│   │   │   │   └── syncbn_cpu.cpp
│   │   │   ├── gpu
│   │   │   │   ├── __init__.py
│   │   │   │   ├── setup.py
│   │   │   │   ├── operator.cpp
│   │   │   │   ├── operator.h
│   │   │   │   ├── device_tensor.h
│   │   │   │   └── common.h
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── functions.py
│   │   ├── parallel_apply.py
│   │   ├── comm.py
│   │   ├── parallel.py
│   │   └── syncbn.py
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   └── resnet_v1.py
│   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── unittest.py
│   │   ├── batchnorm_reimpl.py
│   │   ├── replicate.py
│   │   └── comm.py
│   ├── net_factory.py
│   └── psp.py
├── crf
│   ├── crf.png
│   ├── gt.png
│   ├── img.png
│   ├── 01_bg.png
│   ├── pred.png
│   ├── 02_dog.png
│   ├── 03_sofa.png
│   ├── crf_eval.sh
│   ├── crf.py
│   ├── crf_refine_test.py
│   └── crf_refine.py
├── img
│   ├── intro.png
│   ├── res_01.png
│   ├── res_02.png
│   ├── res_03.png
│   ├── zju_cad.jpg
│   └── tensorboard.png
├── .editorconfig
├── LICENSE
├── script
│   ├── docker.sh
│   ├── inference.sh
│   ├── eval.sh
│   └── train.sh
├── utils
│   ├── loss.py
│   ├── metrics.py
│   ├── model_init.py
│   ├── model_store.py
│   ├── files.py
│   └── train_utils.py
├── requirements.txt
├── losses
│   ├── normal_loss.py
│   ├── pyramid_loss.py
│   ├── loss_factory.py
│   ├── affinity
│   │   ├── utils.py
│   │   └── aaf.py
│   └── rmi
│       ├── rmi_utils.py
│       └── rmi.py
├── full_model.py
├── inference.py
├── eval.py
└── README.md
/dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/sync_bn/src/cpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/sync_bn/src/gpu/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | -------------------------------------------------------------------------------- /model/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # coding = utf-8 2 | -------------------------------------------------------------------------------- /crf/crf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/crf.png -------------------------------------------------------------------------------- /crf/gt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/gt.png -------------------------------------------------------------------------------- /crf/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/img.png -------------------------------------------------------------------------------- /crf/01_bg.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/01_bg.png -------------------------------------------------------------------------------- /crf/pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/pred.png -------------------------------------------------------------------------------- /img/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/intro.png -------------------------------------------------------------------------------- /crf/02_dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/02_dog.png -------------------------------------------------------------------------------- /crf/03_sofa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/crf/03_sofa.png -------------------------------------------------------------------------------- /img/res_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/res_01.png -------------------------------------------------------------------------------- /img/res_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/res_02.png -------------------------------------------------------------------------------- /img/res_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/res_03.png -------------------------------------------------------------------------------- /img/zju_cad.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/zju_cad.jpg -------------------------------------------------------------------------------- /img/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/RMI/HEAD/img/tensorboard.png -------------------------------------------------------------------------------- /model/sync_bn/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # encoding: utf-8 3 | # @Time : 2018/10/3 2:10 PM 4 | # @Author : yuchangqian 5 | # @Contact : changqian_yu@163.com 6 | # @File : __init__.py 7 | 8 | from .syncbn import * 9 | from .parallel import * -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | 7 | # Matches multiple files with brace expansion notation 8 | [*.{js,jsx,html,sass,py,md}] 9 | charset = utf-8 10 | indent_style = tab 11 | indent_size = 4 12 | trim_trailing_whitespace = true 13 | 14 | [*.md] 15 | trim_trailing_whitespace = false 16 | -------------------------------------------------------------------------------- /model/sync_bn/src/cpu/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools
import setup 2 | from torch.utils.cpp_extension import BuildExtension, CppExtension 3 | 4 | setup( 5 | name='syncbn_cpu', 6 | ext_modules=[ 7 | CppExtension('syncbn_cpu', [ 8 | 'operator.cpp', 9 | 'syncbn_cpu.cpp', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /model/sync_bn/src/gpu/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='syncbn_gpu', 6 | ext_modules=[ 7 | CUDAExtension('syncbn_gpu', [ 8 | 'operator.cpp', 9 | 'syncbn_kernel.cu', 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | -------------------------------------------------------------------------------- /model/sync_bn/src/cpu/operator.cpp: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("batchnorm_forward", &BatchNorm_Forward_CPU, "BatchNorm forward (CPU)"); 5 | m.def("batchnorm_backward", &BatchNorm_Backward_CPU, "BatchNorm backward (CPU)"); 6 | m.def("sumsquare_forward", &Sum_Square_Forward_CPU, "SumSqu forward (CPU)"); 7 | m.def("sumsquare_backward", &Sum_Square_Backward_CPU, "SumSqu backward (CPU)"); 8 | } 9 | -------------------------------------------------------------------------------- /model/sync_bn/src/gpu/operator.cpp: -------------------------------------------------------------------------------- 1 | #include "operator.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("batchnorm_forward", &BatchNorm_Forward_CUDA, "BatchNorm forward (CUDA)"); 5 | m.def("batchnorm_backward", &BatchNorm_Backward_CUDA, "BatchNorm backward (CUDA)"); 6 | m.def("sumsquare_forward", &Sum_Square_Forward_CUDA, "SumSqu forward (CUDA)"); 7 | m.def("sumsquare_backward", &Sum_Square_Backward_CUDA, "SumSqu backward (CUDA)"); 8 | 9 | } 10 | -------------------------------------------------------------------------------- /model/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
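#
# A minimal usage sketch (illustrative, assuming the package is importable as
# model.sync_batchnorm; the variable names below are hypothetical):
#   from model.sync_batchnorm import convert_model, DataParallelWithCallback
#   net = convert_model(net)  # swap nn.BatchNorm* layers for SynchronizedBatchNorm*
#   net = DataParallelWithCallback(net, device_ids=[0, 1])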
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /model/sync_bn/src/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.cpp_extension import load 4 | 5 | cwd = os.path.dirname(os.path.realpath(__file__)) 6 | cpu_path = os.path.join(cwd, 'cpu') 7 | gpu_path = os.path.join(cwd, 'gpu') 8 | 9 | cpu = load('syncbn_cpu', [ 10 | os.path.join(cpu_path, 'operator.cpp'), 11 | os.path.join(cpu_path, 'syncbn_cpu.cpp'), 12 | ], build_directory=cpu_path, verbose=False) 13 | 14 | if torch.cuda.is_available(): 15 | gpu = load('syncbn_gpu', [ 16 | os.path.join(gpu_path, 'operator.cpp'), 17 | os.path.join(gpu_path, 'syncbn_kernel.cu'), 18 | ], build_directory=gpu_path, verbose=False) 19 | -------------------------------------------------------------------------------- /model/sync_bn/src/cpu/operator.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | at::Tensor BatchNorm_Forward_CPU( 5 | const at::Tensor input_, 6 | const at::Tensor mean_, 7 | const at::Tensor std_, 8 | const at::Tensor gamma_, 9 | const at::Tensor beta_); 10 | 11 | std::vector<at::Tensor> BatchNorm_Backward_CPU( 12 | const at::Tensor gradoutput_, 13 | const at::Tensor input_, 14 | const at::Tensor mean_, 15 | const at::Tensor std_, 16 | const at::Tensor gamma_, 17 | const at::Tensor beta_, 18 | bool train); 19 | 20 | std::vector<at::Tensor> Sum_Square_Forward_CPU( 21 | const at::Tensor input_); 22 | 23 | at::Tensor Sum_Square_Backward_CPU( 24 | const at::Tensor input_, 25 | const at::Tensor gradSum_, 26 | const at::Tensor gradSquare_); -------------------------------------------------------------------------------- /model/sync_bn/src/gpu/operator.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | 5 | at::Tensor BatchNorm_Forward_CUDA( 6 | const at::Tensor input_, 7 | const at::Tensor mean_, 8 | const at::Tensor std_, 9 | const at::Tensor gamma_, 10 | const at::Tensor beta_); 11 | 12 | std::vector<at::Tensor> BatchNorm_Backward_CUDA( 13 | const at::Tensor gradoutput_, 14 | const at::Tensor input_, 15 | const at::Tensor mean_, 16 | const at::Tensor std_, 17 | const at::Tensor gamma_, 18 | const at::Tensor beta_, 19 | bool train); 20 | 21 | std::vector<at::Tensor> Sum_Square_Forward_CUDA( 22 | const at::Tensor input_); 23 | 24 | at::Tensor Sum_Square_Backward_CUDA( 25 | const at::Tensor input_, 26 | const at::Tensor gradSum_, 27 | const at::Tensor gradSquare_); 28 | -------------------------------------------------------------------------------- /model/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
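#
# A minimal sketch of how the TorchTestCase class below can be used
# (hypothetical test, not part of this file):
#   class MyTest(TorchTestCase):
#       def test_identity(self):
#           x = torch.randn(4)
#           self.assertTensorClose(x, x.clone())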
10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Shuai Zhao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /script/docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run the docker image 4 | DIR_NOW=$(pwd) 5 | 6 | cd ~ 7 | echo "current user : ${USER}" 8 | 9 | echo "" 10 | echo -n "input the docker image tag:" 11 | read docker_image_tag 12 | 13 | echo "" 14 | echo -n "input the mapping port:" 15 | read docker_image_port 16 | 17 | docker_image="zhaosssss/torch_lab:" 18 | docker_final_image="${docker_image}${docker_image_tag}" 19 | 20 | echo "The docker image is ${docker_final_image}" 21 | echo "run docker image..." 
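# Two run variants follow: the commented-out one runs as the container's
# default (root) user; the active one maps the host user/group and mounts
# /etc/passwd and /etc/group read-only, so files written inside the container
# keep the caller's ownership.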
22 | 23 | 24 | #/usr/bin/docker run --runtime=nvidia --rm -it --memory-reservation 32G \ 25 | # --shm-size 8G \ 26 | # -v /home/${USER}:/home/${USER} -w ${DIR_NOW} \ 27 | # -p $docker_image_port:$docker_image_port $docker_final_image bash 28 | 29 | /usr/bin/docker run --runtime=nvidia --rm -it --memory-reservation 32G \ 30 | --shm-size 8G \ 31 | -v /home/${USER}:/home/${USER} --user=${UID}:${GID} -w ${DIR_NOW} \ 32 | -v /etc/group:/etc/group:ro -v /etc/passwd:/etc/passwd:ro \ 33 | -v /mnt/disk2/wy:/mnt/disk2/wy \ 34 | -v /mnt/disk1/wy:/mnt/disk1/wy \ 35 | -p $docker_image_port:$docker_image_port $docker_final_image bash 36 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CrossEntropyLoss(object): 8 | """the normal cross entropy loss""" 9 | def __init__(self, ignore_index=255, accumulation_steps=1): 10 | self.ignore_index = ignore_index 11 | self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=self.ignore_index, 12 | reduction='mean') 13 | self.accumulation_steps = accumulation_steps 14 | 15 | def __call__(self, logit, target): 16 | """call method""" 17 | #n, c, h, w = logit.size() 18 | loss = self.criterion(logit, target.long()) 19 | #loss = torch.div(loss, accumulation_steps * 1.0) 20 | return loss 21 | 22 | def cuda(self, main_gpu=0): 23 | self.criterion = self.criterion.cuda(main_gpu) 24 | 25 | 26 | class SegmentationLosses(nn.CrossEntropyLoss): 27 | """2D Cross Entropy Loss with Auxiliary Loss""" 28 | def __init__(self, weight=None, ignore_index=255, reduction='mean'): 29 | super(SegmentationLosses, self).__init__(weight, None, ignore_index, reduction=reduction) 30 | 31 | def forward(self, *inputs): 32 | return super(SegmentationLosses, self).forward(*inputs) 33 | 34 | 35 | if __name__ == "__main__": 36 | pass 37 | -------------------------------------------------------------------------------- /model/net_factory.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch.nn as nn 8 | from RMI.model.backbone import resnet_v1 9 | #from RMI.model.backbone import resnet_v1_beta 10 | 11 | 12 | __all__ = ['get_backbone_net'] 13 | 14 | 15 | def get_backbone_net(backbone='resnet101', 16 | output_stride=16, 17 | pretrained=True, 18 | norm_layer=nn.BatchNorm2d, 19 | bn_mom=0.01, 20 | root_beta=True): 21 | """get the backbone net of the segmentation model""" 22 | # A map from network name to network object.
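# Only the ResNet-v1 variants are wired up here; the commented entries are
# placeholders for backbones this file does not import.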
23 | networks_obj_dict = { 24 | #'mobilenet_v2': _mobilenet_v2, 25 | 'resnet50': resnet_v1.resnet50, 26 | 'resnet101': resnet_v1.resnet101, 27 | 'resnet152': resnet_v1.resnet152, 28 | #'resnet50_beta': resnet_v1_beta.resnet50_beta, 29 | #'resnet101_beta': resnet_v1_beta.resnet101_beta, 30 | #'resnet152_beta': resnet_v1_beta.resnet152_beta, 31 | #'xception_41': xception.xception_41, 32 | #'xception_65': xception.xception_65, 33 | } 34 | assert backbone in networks_obj_dict.keys() 35 | if 'resnet' in backbone: 36 | backbone_net = networks_obj_dict[backbone](output_stride=output_stride, 37 | pretrained=pretrained, 38 | norm_layer=norm_layer, 39 | bn_mom=bn_mom, 40 | root_beta=root_beta) 41 | return backbone_net 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | asn1crypto==0.24.0 3 | astor==0.7.1 4 | backcall==0.1.0 5 | beautifulsoup4==4.7.1 6 | certifi==2019.3.9 7 | cffi==1.12.3 8 | chardet==3.0.4 9 | conda==4.6.14 10 | conda-build==3.17.8 11 | cryptography==2.6.1 12 | cycler==0.10.0 13 | Cython==0.29.7 14 | decorator==4.4.0 15 | filelock==3.0.10 16 | gast==0.2.2 17 | glob2==0.6 18 | grpcio==1.20.1 19 | h5py==2.9.0 20 | idna==2.8 21 | ipython==7.5.0 22 | ipython-genutils==0.2.0 23 | jedi==0.13.3 24 | Jinja2==2.10.1 25 | Keras-Applications==1.0.7 26 | Keras-Preprocessing==1.0.9 27 | kiwisolver==1.1.0 28 | libarchive-c==2.8 29 | lief==0.9.0 30 | Markdown==3.1 31 | MarkupSafe==1.1.1 32 | matplotlib==3.0.3 33 | mkl-fft==1.0.12 34 | mkl-random==1.0.2 35 | mock==3.0.5 36 | nose==1.3.7 37 | numpy==1.16.3 38 | olefile==0.46 39 | opencv-python==3.3.0.9 40 | parso==0.4.0 41 | pexpect==4.7.0 42 | pickleshare==0.7.5 43 | Pillow==6.0.0 44 | pkginfo==1.5.0.1 45 | prompt-toolkit==2.0.9 46 | protobuf==3.7.1 47 | psutil==5.6.2 48 | ptyprocess==0.6.0 49 | pycocotools==2.0.0 50 | pycosat==0.6.3 51 | pycparser==2.19 52 | pydensecrf==1.0rc3 53 | Pygments==2.3.1 54 | pyOpenSSL==19.0.0 55 | pyparsing==2.4.0 56 | PySocks==1.6.8 57 | python-dateutil==2.8.0 58 | pytz==2019.1 59 | PyYAML==5.1 60 | requests==2.21.0 61 | ruamel-yaml==0.15.46 62 | scipy==1.2.1 63 | setproctitle==1.1.10 64 | six==1.12.0 65 | soupsieve==1.8 66 | tensorboard==1.13.1 67 | tensorboardX==1.6 68 | tensorflow==1.13.1 69 | tensorflow-estimator==1.13.0 70 | termcolor==1.1.0 71 | torch==1.1.0 72 | torch-encoding==1.0.1 73 | torchvision==0.2.2 74 | tqdm==4.31.1 75 | traitlets==4.3.2 76 | urllib3==1.24.2 77 | wcwidth==0.1.7 78 | Werkzeug==0.15.2 79 | -------------------------------------------------------------------------------- /losses/normal_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | """ 4 | Implementation of some commonly used losses. 5 | """ 6 | 7 | # python 2.X, 3.X compatibility 8 | from __future__ import print_function 9 | from __future__ import division 10 | from __future__ import absolute_import 11 | 12 | #import os 13 | #import numpy as np 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | 20 | class BCECrossEntropyLoss(nn.Module): 21 | """ 22 | Sigmoid with binary cross-entropy loss. 23 | Treats the multi-class task as multiple binary classification problems 24 | in a one-vs-rest way. 25 | The loss is summed over the channel dimension.
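    Illustrative shapes (not from this file): with logits of shape
    [N, 21, H, W] and labels of shape [N, H, W], pixels whose label is not a
    valid class (e.g. the 255 ignore label) are masked out, and the summed
    BCE is divided by (number of valid pixels + 1).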
26 | """ 27 | def __init__(self, 28 | num_classes=21, 29 | ignore_index=255): 30 | super(BCECrossEntropyLoss, self).__init__() 31 | self.num_classes = num_classes 32 | self.ignore_index = ignore_index 33 | 34 | def forward(self, logits_4D, labels_4D): 35 | """ 36 | Args: 37 | logits_4D : [N, C, H, W], dtype=float32 38 | labels_4D : [N, H, W], dtype=long 39 | """ 40 | label_flat = labels_4D.view(-1).requires_grad_(False) 41 | label_mask_flat = label_flat < self.num_classes 42 | onehot_label_flat = F.one_hot(label_flat * label_mask_flat.long(), num_classes=self.num_classes).float() 43 | onehot_label_flat = onehot_label_flat.requires_grad_(False) 44 | logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) 45 | 46 | # binary loss, multiplied by the not_ignore_mask 47 | label_mask_flat = label_mask_flat.float() 48 | valid_pixels = torch.sum(label_mask_flat) 49 | binary_loss = F.binary_cross_entropy_with_logits(logits_flat, 50 | target=onehot_label_flat, 51 | weight=label_mask_flat.unsqueeze(dim=1), 52 | reduction='sum') 53 | bce_loss = torch.div(binary_loss, valid_pixels + 1.0) 54 | return bce_loss 55 | -------------------------------------------------------------------------------- /losses/pyramid_loss.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # python 2.X, 3.X compatibility 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | #import os 9 | #import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | class PyramidLoss(nn.Module): 16 | """ 17 | Pyramid Loss. 18 | """ 19 | def __init__(self, 20 | num_classes=21, 21 | ignore_index=255, 22 | scales=(0.25, 0.5, 0.75, 1.0)): 23 | super(PyramidLoss, self).__init__() 24 | self.num_classes = num_classes 25 | # ignore class 26 | self.ignore_index = ignore_index 27 | self.scales = scales 28 | 29 | def forward(self, logits_4D, labels_3D): 30 | """ 31 | Using both softmax and sigmoid operations. 
32 | Args: 33 | logits_4D : [N, C, H, W], dtype=float32 34 | labels_3D : [N, H, W], dtype=long 35 | """ 36 | h, w = labels_3D.shape[-2], labels_3D.shape[-1] 37 | total_loss = F.cross_entropy(input=logits_4D, 38 | target=labels_3D.long(), 39 | ignore_index=self.ignore_index, 40 | reduction='mean') 41 | labels_4D = labels_3D.unsqueeze(dim=1) 42 | for scale in self.scales: 43 | if scale == 1.0: 44 | continue 45 | assert scale <= 1.0 46 | now_h, now_w = int(scale * h), int(scale * w) 47 | now_logits = F.interpolate(logits_4D, size=(now_h, now_w), mode='bilinear') 48 | now_labels = F.interpolate(labels_4D, size=(now_h, now_w), mode='nearest') 49 | now_loss = F.cross_entropy(input=now_logits, 50 | target=now_labels.squeeze(dim=1).long(), 51 | ignore_index=self.ignore_index, 52 | reduction='mean') 53 | total_loss += now_loss 54 | final_loss = total_loss / len(self.scales) 55 | return final_loss 56 | 57 | 58 | if __name__ == '__main__': 59 | pass 60 | -------------------------------------------------------------------------------- /losses/loss_factory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | # python 2.X, 3.X compatibility 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | #import torch 9 | import torch.nn as nn 10 | 11 | from RMI.losses import normal_loss 12 | from RMI.losses import pyramid_loss 13 | from RMI.losses.rmi import rmi 14 | from RMI.losses.affinity import aaf 15 | 16 | def criterion_choose(num_classes=21, 17 | loss_type=0, 18 | weight=None, 19 | ignore_index=255, 20 | reduction='mean', 21 | max_iter=30000, 22 | args=None): 23 | """choose the criterion to use""" 24 | info_dict = { 25 | 0: "Normal Softmax Cross Entropy Loss", 26 | 1: "Normal Sigmoid Cross Entropy Loss", 27 | 2: "Region Mutual Information Loss", 28 | 3: "Affinity field Loss", 29 | 5: "Pyramid Loss" 30 | } 31 | print("INFO:PyTorch: Using {}.".format(info_dict[loss_type])) 32 | if loss_type == 0: 33 | return nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction=reduction) 34 | elif loss_type == 1: 35 | return normal_loss.BCECrossEntropyLoss(num_classes=num_classes, ignore_index=ignore_index) 36 | elif loss_type == 2: 37 | return rmi.RMILoss(num_classes=num_classes, 38 | rmi_radius=args.rmi_radius, 39 | rmi_pool_way=args.rmi_pool_way, 40 | rmi_pool_size=args.rmi_pool_size, 41 | rmi_pool_stride=args.rmi_pool_stride, 42 | loss_weight_lambda=args.loss_weight_lambda) 43 | elif loss_type == 3: 44 | return aaf.AffinityLoss(num_classes=num_classes, 45 | init_step=args.init_global_step, 46 | max_iter=max_iter) 47 | elif loss_type == 5: 48 | return pyramid_loss.PyramidLoss(num_classes=num_classes, ignore_index=ignore_index) 49 | 50 | else: 51 | raise NotImplementedError("The loss type {} is not implemented.".format(loss_type)) 52 | -------------------------------------------------------------------------------- /model/sync_bn/src/cpu/syncbn_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <ATen/ATen.h> 3 | #include <vector> 4 | 5 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) { 6 | if (x.ndimension() == 2) { 7 | return v; 8 | } else { 9 | std::vector<int64_t> broadcast_size = {1, -1}; 10 | for (int64_t i = 2; i < x.ndimension(); ++i) 11 | broadcast_size.push_back(1); 12 | 13 | return v.view(broadcast_size); 14 | } 15 | } 16 | 17 | at::Tensor BatchNorm_Forward_CPU( 18 | const at::Tensor input, 19 | const at::Tensor mean, 20 |
const at::Tensor std, 21 | const at::Tensor gamma, 22 | const at::Tensor beta) { 23 | auto output = (input - broadcast_to(mean, input)) / broadcast_to(std, input); 24 | output = output * broadcast_to(gamma, input) + broadcast_to(beta, input); 25 | return output; 26 | } 27 | 28 | // Not implementing CPU backward for now 29 | std::vector<at::Tensor> BatchNorm_Backward_CPU( 30 | const at::Tensor gradoutput, 31 | const at::Tensor input, 32 | const at::Tensor mean, 33 | const at::Tensor std, 34 | const at::Tensor gamma, 35 | const at::Tensor beta, 36 | bool train) { 37 | /* outputs*/ 38 | at::Tensor gradinput = at::zeros_like(input); 39 | at::Tensor gradgamma = at::zeros_like(gamma); 40 | at::Tensor gradbeta = at::zeros_like(beta); 41 | at::Tensor gradMean = at::zeros_like(mean); 42 | at::Tensor gradStd = at::zeros_like(std); 43 | return {gradinput, gradMean, gradStd, gradgamma, gradbeta}; 44 | } 45 | 46 | std::vector<at::Tensor> Sum_Square_Forward_CPU( 47 | const at::Tensor input) { 48 | /* outputs */ 49 | at::Tensor sum = torch::zeros({input.size(1)}, input.options()); 50 | at::Tensor square = torch::zeros({input.size(1)}, input.options()); 51 | return {sum, square}; 52 | } 53 | 54 | at::Tensor Sum_Square_Backward_CPU( 55 | const at::Tensor input, 56 | const at::Tensor gradSum, 57 | const at::Tensor gradSquare) { 58 | /* outputs */ 59 | at::Tensor gradInput = at::zeros_like(input); 60 | return gradInput; 61 | } 62 | -------------------------------------------------------------------------------- /full_model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | #import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | _PSP_AUX_WEIGHT = 0.4 # the weight of the auxiliary loss in PSPNet 11 | 12 | 13 | class FullModel(nn.Module): 14 | """The full model wrapper.""" 15 | def __init__(self, seg_model='deeplabv3', 16 | model=None, 17 | loss_type=None, 18 | criterion=None): 19 | super(FullModel, self).__init__() 20 | assert seg_model in ['pspnet', 'deeplabv3', 'deeplabv3+'] 21 | self.seg_model = seg_model 22 | self.model = model 23 | self.loss_type = loss_type 24 | self.criterion = criterion 25 | 26 | def forward(self, inputs=None, target=None, global_step=0, mode='train'): 27 | """forward step""" 28 | # output of the model 29 | output = self.model(inputs) 30 | 31 | # do not calculate the loss during validation or testing 32 | if 'val' in mode or 'test' in mode: 33 | if self.seg_model == 'pspnet': 34 | output = output[0] 35 | return output 36 | 37 | # PSPNet has an auxiliary branch 38 | if self.loss_type == 2: 39 | if self.seg_model == 'pspnet': 40 | #loss = self.criterion(output[0], target) + _PSP_AUX_WEIGHT * self.criterion(output[1], target) 41 | #loss = loss / (1.0 + _PSP_AUX_WEIGHT) 42 | loss = self.criterion(output[0], target) + _PSP_AUX_WEIGHT * F.cross_entropy(input=output[1], 43 | target=target.long(), 44 | ignore_index=255, 45 | reduction='mean') 46 | output = output[0] 47 | else: 48 | loss = self.criterion(output, target) 49 | elif self.loss_type == 3: 50 | if self.seg_model == 'pspnet': 51 | loss = (self.criterion(output[0], target, global_step=global_step) + 52 | _PSP_AUX_WEIGHT * self.criterion(output[1], target, global_step=global_step)) 53 | output = output[0] 54 | else: 55 | loss = self.criterion(output, target, global_step=global_step) 56 | elif self.loss_type == 5: 57 | if self.seg_model == 'pspnet': 58 |
loss = self.criterion(output[0], target) + _PSP_AUX_WEIGHT * self.criterion(output[1], target) 59 | output = output[0] 60 | else: 61 | loss = self.criterion(output, target) 62 | else: 63 | if self.seg_model == 'pspnet': 64 | loss = (self.criterion(output[0], target.long()) + _PSP_AUX_WEIGHT * self.criterion(output[1], target.long())) 65 | output = output[0] 66 | else: 67 | loss = self.criterion(output, target.long()) 68 | #loss = loss.unsqueeze(dim=0) 69 | return output, loss 70 | -------------------------------------------------------------------------------- /model/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /model/sync_bn/functions.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## Email: zhanghang0704@gmail.com 4 | ## Copyright (c) 2018 5 | ## 6 | ## This source code is licensed under the MIT-style license found in the 7 | ## LICENSE file in the root directory of this source tree
8 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | 10 | """the same as the functions.syncbn in torch-encoding""" 11 | """Synchronized Cross-GPU Batch Normalization functions""" 12 | import torch 13 | from torch.autograd import Variable, Function 14 | from .src import * 15 | 16 | __all__ = ['sum_square', 'batchnormtrain'] 17 | 18 | def sum_square(input): 19 | r"""Calculate sum of elements and sum of squares for Batch Normalization""" 20 | return _sum_square.apply(input) 21 | 22 | 23 | class _sum_square(Function): 24 | @staticmethod 25 | def forward(ctx, input): 26 | ctx.save_for_backward(input) 27 | if input.is_cuda: 28 | xsum, xsqusum = gpu.sumsquare_forward(input) 29 | else: 30 | xsum, xsqusum = cpu.sumsquare_forward(input) 31 | return xsum, xsqusum 32 | 33 | @staticmethod 34 | def backward(ctx, gradSum, gradSquare): 35 | input, = ctx.saved_variables 36 | if input.is_cuda: 37 | gradInput = gpu.sumsquare_backward(input, gradSum, gradSquare) 38 | else: 39 | raise NotImplementedError 40 | return gradInput 41 | 42 | 43 | class _batchnormtrain(Function): 44 | @staticmethod 45 | def forward(ctx, input, mean, std, gamma, beta): 46 | ctx.save_for_backward(input, mean, std, gamma, beta) 47 | if input.is_cuda: 48 | output = gpu.batchnorm_forward(input, mean, std, gamma, beta) 49 | else: 50 | output = cpu.batchnorm_forward(input, mean, std, gamma, beta) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, gradOutput): 55 | input, mean, std, gamma, beta = ctx.saved_variables 56 | if gradOutput.is_cuda: 57 | gradInput, gradMean, gradStd, gradGamma, gradBeta = \ 58 | gpu.batchnorm_backward(gradOutput, input, mean, 59 | std, gamma, beta, True) 60 | else: 61 | raise NotImplementedError 62 | return gradInput, gradMean, gradStd, gradGamma, gradBeta 63 | 64 | 65 | def batchnormtrain(input, mean, std, gamma, beta): 66 | r"""Applies Batch Normalization over a 3d input that is seen as a 67 | mini-batch. 68 | 69 | .. _encoding.batchnormtrain: 70 | 71 | .. math:: 72 | 73 | y = \frac{x - \mu[x]}{ \sqrt{var[x] + \epsilon}} * \gamma + \beta 74 | 75 | Shape: 76 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 77 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 78 | 79 | """ 80 | return _batchnormtrain.apply(input, mean, std, gamma, beta) 81 | -------------------------------------------------------------------------------- /losses/affinity/utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | """ 4 | The pytorch implementation of the paper: 5 | @inproceedings{aaf2018, 6 | author = {Ke, Tsung-Wei and Hwang, Jyh-Jing and Liu, Ziwei and Yu, Stella X.}, 7 | title = {Adaptive Affinity Fields for Semantic Segmentation}, 8 | booktitle = {European Conference on Computer Vision (ECCV)}, 9 | month = {September}, 10 | year = {2018} 11 | } 12 | """ 13 | 14 | # python 2.X, 3.X compatibility 15 | from __future__ import print_function 16 | from __future__ import division 17 | from __future__ import absolute_import 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | 23 | __all__ = ['edges_from_label', 'eightcorner_activation'] 24 | 25 | 26 | def edges_from_label(labels, size=1, ignore_class=255): 27 | """Retrieves edge positions from the ground-truth labels. 28 | This function computes the edge map by considering if the pixel values 29 | are equal between the center and the neighboring pixels on the eight 30 | corners from a (2 * size + 1) * (2 * size + 1) patch.
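    A toy illustration (hypothetical values): with size=1, the label at (y, x)
    is compared against its 8 neighbours at offsets {-1, 0, +1} x {-1, 0, +1}
    excluding (0, 0); any neighbour with a different label marks an edge in
    that direction.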
31 | Ignore edges where any of the paired pixels has a label value >= num_classes. 32 | 33 | Args: 34 | labels: A tensor of size [N, C, H, W], 35 | indicating semantic segmentation ground-truth labels. 36 | size: A number indicating the half size of a patch. 37 | ignore_class: A number indicating the label value to ignore. 38 | Return: 39 | A tensor of size [N, C, 8, H, W] 40 | """ 41 | # Get the number of channels in the input. 42 | shape_lab = labels.size() 43 | assert len(shape_lab) == 4 44 | n, c, h, w = shape_lab 45 | 46 | # Pad at the margin. 47 | labels_pad = F.pad(labels, (size, size, size, size), mode='constant', value=0) 48 | 49 | # Get the edge by comparing the label value of the center and its paired pixels. 50 | edge_groups = [] 51 | for st_y in range(0, 2 * size + 1, size): 52 | for st_x in range(0, 2 * size + 1, size): 53 | if st_y == size and st_x == size: 54 | continue 55 | edge_groups.append(labels_pad[:, :, st_y:st_y+h, st_x:st_x+w] != labels) 56 | # shape [N, C, 8, H, W] 57 | return torch.stack(edge_groups, dim=2) 58 | 59 | 60 | def eightcorner_activation(x, size=1): 61 | """Retrieves neighboring pixels on the eight corners from a 62 | (2 * size + 1) x (2 * size + 1) patch. 63 | Args: 64 | x: A tensor with shape [N, C, H, W] 65 | size: A number indicating the half size of a patch. 66 | Returns: 67 | A tensor with shape [N, C, 8, H, W] 68 | """ 69 | # Get the number of channels in the input. 70 | shape_lab = x.size() 71 | assert len(shape_lab) == 4 72 | n, c, h, w = shape_lab 73 | 74 | # Pad at the margin. 75 | x_pad = F.pad(x, (size, size, size, size), mode='constant', value=0) 76 | 77 | # Get eight corner pixels/features in the patch. 78 | x_groups = [] 79 | for st_y in range(0, 2 * size + 1, size): 80 | for st_x in range(0, 2 * size + 1, size): 81 | if st_y == size and st_x == size: 82 | # Ignore the center pixel/feature. 83 | continue 84 | x_groups.append(x_pad[:, :, st_y:st_y+h, st_x:st_x+w]) 85 | 86 | # shape [N, C, 8, H, W] 87 | return torch.stack(x_groups, dim=2) 88 | -------------------------------------------------------------------------------- /model/sync_bn/src/gpu/device_tensor.h: -------------------------------------------------------------------------------- 1 | #include <ATen/ATen.h> 2 | 3 | template <typename DType, int Dim> 4 | struct DeviceTensor { 5 | public: 6 | inline __device__ __host__ DeviceTensor(DType *p, const int *size) 7 | : dptr_(p) { 8 | for (int i = 0; i < Dim; ++i) { 9 | size_[i] = size ?
size[i] : 0; 10 | } 11 | } 12 | 13 | inline __device__ __host__ unsigned getSize(const int i) const { 14 | assert(i < Dim); 15 | return size_[i]; 16 | } 17 | 18 | inline __device__ __host__ int numElements() const { 19 | int n = 1; 20 | for (int i = 0; i < Dim; ++i) { 21 | n *= size_[i]; 22 | } 23 | return n; 24 | } 25 | 26 | inline __device__ __host__ DeviceTensor<DType, Dim-1> select(const size_t x) const { 27 | assert(Dim > 1); 28 | int offset = x; 29 | for (int i = 1; i < Dim; ++i) { 30 | offset *= size_[i]; 31 | } 32 | DeviceTensor<DType, Dim-1> tensor(dptr_ + offset, nullptr); 33 | for (int i = 0; i < Dim - 1; ++i) { 34 | tensor.size_[i] = this->size_[i+1]; 35 | } 36 | return tensor; 37 | } 38 | 39 | inline __device__ __host__ DeviceTensor<DType, Dim-1> operator[](const size_t x) const { 40 | assert(Dim > 1); 41 | int offset = x; 42 | for (int i = 1; i < Dim; ++i) { 43 | offset *= size_[i]; 44 | } 45 | DeviceTensor<DType, Dim-1> tensor(dptr_ + offset, nullptr); 46 | for (int i = 0; i < Dim - 1; ++i) { 47 | tensor.size_[i] = this->size_[i+1]; 48 | } 49 | return tensor; 50 | } 51 | 52 | inline __device__ __host__ size_t InnerSize() const { 53 | assert(Dim >= 3); 54 | size_t sz = 1; 55 | for (size_t i = 2; i < Dim; ++i) { 56 | sz *= size_[i]; 57 | } 58 | return sz; 59 | } 60 | 61 | inline __device__ __host__ size_t ChannelCount() const { 62 | assert(Dim >= 3); 63 | return size_[1]; 64 | } 65 | 66 | inline __device__ __host__ DType* data_ptr() const { 67 | return dptr_; 68 | } 69 | 70 | DType *dptr_; 71 | int size_[Dim]; 72 | }; 73 | 74 | template <typename DType> 75 | struct DeviceTensor<DType, 1> { 76 | inline __device__ __host__ DeviceTensor(DType *p, const int *size) 77 | : dptr_(p) { 78 | size_[0] = size ? size[0] : 0; 79 | } 80 | 81 | inline __device__ __host__ unsigned getSize(const int i) const { 82 | assert(i == 0); 83 | return size_[0]; 84 | } 85 | 86 | inline __device__ __host__ int numElements() const { 87 | return size_[0]; 88 | } 89 | 90 | inline __device__ __host__ DType &operator[](const size_t x) const { 91 | return *(dptr_ + x); 92 | } 93 | 94 | inline __device__ __host__ DType* data_ptr() const { 95 | return dptr_; 96 | } 97 | 98 | DType *dptr_; 99 | int size_[1]; 100 | }; 101 | 102 | template <typename DType, int Dim> 103 | static DeviceTensor<DType, Dim> devicetensor(const at::Tensor &blob) { 104 | DType *data = blob.data<DType>(); 105 | DeviceTensor<DType, Dim> tensor(data, nullptr); 106 | for (int i = 0; i < Dim; ++i) { 107 | tensor.size_[i] = blob.size(i); 108 | } 109 | return tensor; 110 | } 111 | -------------------------------------------------------------------------------- /model/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /crf/crf_eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # python PATH 4 | export PYTHONPATH="${PYTHONPATH}:${HOME}/github" 5 | 6 | # hyperparameter 7 | echo -n "input the gpu (separate by comma (,) ): " 8 | read gpus 9 | export CUDA_VISIBLE_DEVICES=${gpus} 10 | echo "using gpus ${gpus}" 11 | # replace comma(,) with empty 12 | #gpus=${gpus//,/} 13 | # the number of characters 14 | #num_gpus=${#gpus} 15 | #echo "the number of gpus is ${num_gpus}" 16 | 17 | # choose the base model 18 | echo "" 19 | echo "0 -- deeplabv3" 20 | echo "1 -- deeplabv3+" 21 | echo "2 -- pspnet" 22 | echo -n "choose the base model: " 23 | read model_choose 24 | case ${model_choose} in 25 | 0 ) 26 | base_model="deeplabv3" 27 | ;; 28 | 1 ) 29 | base_model="deeplabv3+" 30 | ;; 31 | 2 ) 32 | base_model="pspnet" 33 | ;; 34 | * ) 35 | echo "The choice of the segmentation model is illegal!" 36 | exit 1 37 | ;; 38 | esac 39 | 40 | # choose the backbone 41 | echo "" 42 | echo "0 -- resnet_v1_50" 43 | echo "1 -- resnet_v1_101" 44 | echo "2 -- resnet_v1_152" 45 | echo -n "choose the base network: " 46 | read base_network 47 | #base_network=1 48 | 49 | case ${base_network} in 50 | 0 ) 51 | backbone="resnet50";; 52 | 1 ) 53 | backbone="resnet101";; 54 | 2 ) 55 | backbone="resnet152";; 56 | * ) 57 | echo "The choice of the base network is illegal!" 58 | exit 1 59 | ;; 60 | esac 61 | echo "The backbone is ${backbone}" 62 | echo "The base model is ${base_model}" 63 | 64 | # choose the batch size 65 | batch_size=1 66 | 67 | # choose the dataset 68 | echo "" 69 | echo "0 -- PASCAL VOC2012 dataset" 70 | echo "1 -- Cityscapes" 71 | echo "2 -- CamVid" 72 | echo -n "input the dataset: " 73 | read dataset 74 | 75 | if [ ${dataset} = 0 ] 76 | then 77 | # data dir 78 | data_dir="${HOME}/dataset/VOCdevkit/VOC2012" 79 | checkpoint_name="deeplab-resnet_ckpt_30406.pth" 80 | train_split='val' 81 | dataset=pascal 82 | elif [ ${dataset} = 1 ] 83 | then 84 | data_dir="${HOME}/dataset/Cityscapes/" 85 | dataset=cityscapes 86 | elif [ ${dataset} = 2 ] 87 | then 88 | data_dir="${HOME}/dataset/CamVid/" 89 | checkpoint_name="deeplab-resnet_ckpt_5800.pth" 90 | dataset=camvid 91 | train_split='test' 92 | else 93 | echo "The choice of the dataset is illegal!" 94 | exit 1 95 | fi 96 | echo "The data dir is ${data_dir}, the batch size is ${batch_size}."
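# The interactive prompts in this script read, in order: gpu ids, base model,
# backbone, dataset, and (further below) the CRF iteration steps, so it can
# also be driven non-interactively, e.g. (hypothetical answers):
#   printf '0\n0\n1\n0\n5\n' | bash crf/crf_eval.sh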
97 | 98 | 99 | 100 | 101 | # set the work dir 102 | work_dir="${HOME}/github/RMI/crf" 103 | 104 | # ckpt directory 105 | ##################################################### 106 | # SET YOUR CHECKPOINT FILE HERE 107 | ##################################################### 108 | model_name=TBD 109 | 110 | resume=TBD 111 | 112 | # model dir and output dir 113 | model_dir=TBD 114 | output_dir=TBD 115 | if [ -d ${output_dir} ] 116 | then 117 | echo "save outputs into ${output_dir}" 118 | else 119 | mkdir -p ${output_dir} 120 | echo "make the directory ${output_dir}" 121 | fi 122 | 123 | # crf steps 124 | # choose the CRF iteration steps 125 | echo "" 126 | echo -n "input the iteration step of CRF (1 ~ 10):" 127 | read crf_iter_steps 128 | # run the CRF refinement 129 | #for crf_iter_steps in 5 130 | #do 131 | python ${work_dir}/crf_refine.py --resume ${resume} \ 132 | --seg_model ${base_model} \ 133 | --backbone ${backbone} \ 134 | --model_dir ${model_dir} \ 135 | --train_split ${train_split} \ 136 | --gpu_ids ${gpus} \ 137 | --checkname deeplab-resnet \ 138 | --dataset ${dataset} \ 139 | --data_dir ${data_dir} \ 140 | --crf_iter_steps ${crf_iter_steps} \ 141 | --output_dir ${output_dir} 142 | #done 143 | echo "Test Finished!!!" 144 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | evaluation during training 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class Evaluator(object): 11 | def __init__(self, num_class): 12 | """initialize the evaluator""" 13 | self.num_class = num_class 14 | self.confusion_matrix = np.zeros((self.num_class, self.num_class)) 15 | 16 | def pixel_accuracy_np(self): 17 | """calculate the pixel accuracy with numpy""" 18 | denominator = self.confusion_matrix.sum().astype(float) 19 | cm_diag_sum = np.diagonal(self.confusion_matrix).sum().astype(float) 20 | 21 | # If the number of valid entries is 0 (no classes) we return 0. 22 | accuracy = np.where(denominator > 0, cm_diag_sum / denominator, 0) 23 | accuracy = float(accuracy) 24 | #print('Pixel Accuracy: {:.4f}'.format(float(accuracy))) 25 | return accuracy 26 | 27 | def Pixel_Accuracy_Class(self): 28 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 29 | Acc = np.nanmean(Acc) 30 | return Acc 31 | 32 | def mean_iou_np(self, is_show_per_class=False): 33 | """compute mean iou with numpy""" 34 | sum_over_row = np.sum(self.confusion_matrix, axis=0).astype(float) 35 | sum_over_col = np.sum(self.confusion_matrix, axis=1).astype(float) 36 | cm_diag = np.diagonal(self.confusion_matrix).astype(float) 37 | denominator = sum_over_row + sum_over_col - cm_diag 38 | 39 | # The mean is only computed over classes that appear in the 40 | # label or prediction tensor. If the denominator is 0, we need to 41 | # ignore the class. 42 | num_valid_entries = np.sum((denominator != 0).astype(float)) 43 | 44 | # If the value of the denominator is 0, set it to 1 to avoid 45 | # zero division. 46 | denominator = np.where(denominator > 0, 47 | denominator, 48 | np.ones_like(denominator)) 49 | ious = cm_diag / denominator 50 | 51 | if is_show_per_class: 52 | print('\nIntersection over Union for each class:') 53 | for i, iou in enumerate(ious): 54 | print(' class {}: {:.4f}'.format(i, iou)) 55 | 56 | # If the number of valid entries is 0 (no classes) we return 0.
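        # Toy example of the computation above (hypothetical numbers): with
        # diag = [3, 2], row sums = [4, 3] and column sums = [3, 4], the
        # denominators are [4, 5], the per-class IoUs are 3/4 and 2/5, so
        # mIoU = (0.75 + 0.40) / 2 = 0.575.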
57 | m_iou = np.where(num_valid_entries > 0, 58 | np.sum(ious) / num_valid_entries, 59 | 0) 60 | m_iou = float(m_iou) 61 | if is_show_per_class: 62 | print('mean Intersection over Union: {:.4f}'.format(float(m_iou))) 63 | return m_iou 64 | 65 | def Frequency_Weighted_Intersection_over_Union(self): 66 | """frequency weighted mIoU""" 67 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 68 | iu = np.diag(self.confusion_matrix) / (np.sum(self.confusion_matrix, axis=1) 69 | + np.sum(self.confusion_matrix, axis=0) - np.diag(self.confusion_matrix)) 70 | 71 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 72 | return FWIoU 73 | 74 | def _generate_matrix(self, gt_image, pre_image): 75 | """calculate confusion matrix""" 76 | mask = (gt_image >= 0) & (gt_image < self.num_class) 77 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 78 | count = np.bincount(label, minlength=self.num_class**2) 79 | confusion_matrix = count.reshape(self.num_class, self.num_class) 80 | return confusion_matrix 81 | 82 | def add_batch(self, gt_image, pre_image): 83 | """add the evaluation result to the confusion matrix""" 84 | assert gt_image.shape == pre_image.shape 85 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 86 | 87 | def reset(self): 88 | """set the confusion matrix to 0""" 89 | self.confusion_matrix = np.zeros((self.num_class, self.num_class)) 90 | 91 | 92 | if __name__ == '__main__': 93 | eval = Evaluator(num_class=5) 94 | gt = np.array([0, 1, 2, 3, 4, 6]) 95 | pre = np.array([0, 1, 2, 3, 4, 1]) 96 | eval.add_batch(gt, pre) 97 | print(eval.confusion_matrix) 98 | -------------------------------------------------------------------------------- /utils/model_init.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | """ 3 | some training utils. 4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | __all__ = ['init_weights', 'group_weight', 'seg_model_get_optim_params'] 16 | 17 | 18 | def init_weights(modules, norm_layer=nn.BatchNorm2d, bn_momentum=0.1): 19 | """ 20 | as for he_init, normal with std = sqrt(2 / (Cin * k * k)) 21 | """ 22 | if not isinstance(modules, (list, tuple)): 23 | modules = (modules,) 24 | for module in modules: 25 | __init_weights(module, norm_layer, bn_momentum) 26 | 27 | 28 | def __init_weights(module, norm_layer=nn.BatchNorm2d, bn_momentum=0.1): 29 | """ 30 | The default init for conv weight and bias is uniform with stdv = 1 / sqrt(Cin * k * k). 31 | As for he_init, normal with std = sqrt(2 / (Cin * k * k)).
32 | """ 33 | for m in module.modules(): 34 | if isinstance(m, (nn.Linear, nn.Conv2d)): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 36 | if m.bias is not None: 37 | m.bias.data.zero_() 38 | elif isinstance(m, norm_layer) or isinstance(m, nn.BatchNorm2d): 39 | m.momentum = bn_momentum 40 | m.weight.data.fill_(1) 41 | m.bias.data.zero_() 42 | 43 | 44 | def __group_weight(group_decay, group_no_decay, module, norm_layer): 45 | for m in module.modules(): 46 | if isinstance(m, (nn.Linear, nn.Conv2d)): 47 | group_decay.append(m.weight) 48 | if m.bias is not None: 49 | group_no_decay.append(m.bias) 50 | elif isinstance(m, (norm_layer, nn.GroupNorm, nn.BatchNorm2d)): 51 | group_no_decay.append(m.weight) 52 | group_no_decay.append(m.bias) 53 | return group_decay, group_no_decay 54 | 55 | 56 | def group_weight(params_list, modules, norm_layer, lr, weight_decay): 57 | """group the weights. 58 | no weight decay for the biases, and alpha and gamma of the bn layers. 59 | 60 | ref: 61 | Bag of Tricks for Image Classification with Convolutional Neural Networks, 2018. 62 | """ 63 | group_decay = [] 64 | group_no_decay = [] 65 | params_length = 0 66 | 67 | if not isinstance(modules, (list, tuple)): 68 | modules = (modules, ) 69 | for module in modules: 70 | params_length += len(list(module.parameters())) 71 | group_decay, group_no_decay = __group_weight(group_decay, group_no_decay, module, norm_layer) 72 | 73 | assert params_length == len(group_decay) + len(group_no_decay) 74 | params_list.append(dict(params=group_decay, weight_decay=weight_decay, lr=lr)) 75 | params_list.append(dict(params=group_no_decay, lr=lr)) 76 | return params_list 77 | 78 | 79 | def seg_model_get_optim_params(params_list, model, 80 | norm_layer=nn.BatchNorm2d, 81 | seg_model='pspnet', 82 | base_lr=0.007, 83 | lr_multiplier=1.0, 84 | weight_decay=4e-5): 85 | """ 86 | get the params of the segmentation models. 
87 | """ 88 | # group weight and config optimizer 89 | modules_list1 = (model.backbone, ) 90 | if seg_model == 'deeplabv3': 91 | modules_list2 = (model.aspp, model.last_conv) 92 | elif seg_model == 'deeplabv3+': 93 | modules_list2 = (model.aspp, model.decoder) 94 | elif seg_model == 'pspnet': 95 | modules_list2 = (model.psp_module, model.main_branch, model.aux_branch) 96 | else: 97 | raise 98 | # get the param list 99 | params_list = group_weight(params_list, modules_list1, 100 | norm_layer, base_lr, weight_decay=weight_decay) 101 | params_list = group_weight(params_list, modules_list2, 102 | norm_layer, base_lr * lr_multiplier, weight_decay=weight_decay) 103 | return params_list 104 | -------------------------------------------------------------------------------- /model/sync_bn/parallel_apply.py: -------------------------------------------------------------------------------- 1 | # import threading 2 | import queue 3 | import torch 4 | import torch.multiprocessing as mp 5 | # from pathos.multiprocessing import ProcessPool as Pool 6 | from torch.cuda._utils import _get_device_index 7 | 8 | #######貌似没什么用 9 | 10 | def get_a_var(obj): 11 | if isinstance(obj, torch.Tensor): 12 | return obj 13 | 14 | if isinstance(obj, list) or isinstance(obj, tuple): 15 | for result in map(get_a_var, obj): 16 | if isinstance(result, torch.Tensor): 17 | return result 18 | if isinstance(obj, dict): 19 | for result in map(get_a_var, obj.items()): 20 | if isinstance(result, torch.Tensor): 21 | return result 22 | return None 23 | 24 | 25 | def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): 26 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 27 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) 28 | on each of :attr:`devices`. 29 | Args: 30 | modules (Module): modules to be parallelized 31 | inputs (tensor): inputs to the modules 32 | devices (list of int or torch.device): CUDA devices 33 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and 34 | :attr:`devices` (if given) should all have same length. Moreover, each 35 | element of :attr:`inputs` can either be a single object as the only argument 36 | to a module, or a collection of positional arguments. 
37 | """ 38 | assert len(modules) == len(inputs) 39 | if kwargs_tup is not None: 40 | assert len(modules) == len(kwargs_tup) 41 | else: 42 | kwargs_tup = ({},) * len(modules) 43 | if devices is not None: 44 | assert len(modules) == len(devices) 45 | else: 46 | devices = [None] * len(modules) 47 | devices = list(map(lambda x: _get_device_index(x, True), devices)) 48 | context = mp.get_context('spawn') 49 | # lock = threading.Lock() 50 | # results = {} 51 | # results = [] 52 | # pool = context.Pool(len(devices)) 53 | results_queue = queue.Queue(len(devices)) 54 | grad_enabled = torch.is_grad_enabled() 55 | 56 | def _worker(module, input, kwargs, device=None): 57 | torch.set_grad_enabled(grad_enabled) 58 | if device is None: 59 | device = get_a_var(input).get_device() 60 | try: 61 | with torch.cuda.device(device): 62 | # this also avoids accidental slicing of `input` if it is a Tensor 63 | if not isinstance(input, (list, tuple)): 64 | input = (input,) 65 | output = module(*input, **kwargs) 66 | results_queue.put(output) 67 | # with lock: 68 | # results[i] = output 69 | except Exception as e: 70 | results_queue.put(e) 71 | # with lock: 72 | # results[i] = e 73 | 74 | if len(modules) > 1: 75 | # pool.map(_worker, [modules, inputs, kwargs_tup, devices]) 76 | processes = [context.Process(target=_worker, 77 | args=(i, module, input, kwargs, device)) 78 | for i, (module, input, kwargs, device) in 79 | enumerate(zip(modules, inputs, kwargs_tup, devices))] 80 | 81 | for process in processes: 82 | process.start() 83 | for process in processes: 84 | process.join() 85 | else: 86 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 87 | 88 | outputs = [] 89 | for i in range(len(inputs)): 90 | output = results_queue.get() 91 | if isinstance(output, Exception): 92 | raise output 93 | outputs.append(output) 94 | return outputs 95 | -------------------------------------------------------------------------------- /dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 9 | rgb_masks = [] 10 | for label_mask in label_masks: 11 | rgb_mask = decode_segmap(label_mask, dataset) 12 | rgb_masks.append(rgb_mask) 13 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 14 | return rgb_masks 15 | 16 | 17 | def decode_segmap(label_mask, dataset, plot=False): 18 | """Decode segmentation class labels into a color image 19 | Args: 20 | label_mask (np.ndarray): an (M,N) array of integer values denoting 21 | the class label at each spatial location. 22 | plot (bool, optional): whether to show the resulting color image 23 | in a figure. 24 | Returns: 25 | (np.ndarray, optional): the resulting decoded color image. 
26 | """ 27 | if dataset == 'pascal' or dataset == 'coco': 28 | n_classes = 21 29 | label_colours = get_pascal_labels() 30 | elif dataset == 'cityscapes': 31 | n_classes = 19 32 | label_colours = get_cityscapes_labels() 33 | else: 34 | raise NotImplementedError 35 | 36 | r = label_mask.copy() 37 | g = label_mask.copy() 38 | b = label_mask.copy() 39 | for ll in range(0, n_classes): 40 | r[label_mask == ll] = label_colours[ll, 0] 41 | g[label_mask == ll] = label_colours[ll, 1] 42 | b[label_mask == ll] = label_colours[ll, 2] 43 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 44 | rgb[:, :, 0] = r #/ 255.0 45 | rgb[:, :, 1] = g #/ 255.0 46 | rgb[:, :, 2] = b #/ 255.0 47 | if plot: 48 | plt.imshow(rgb) 49 | plt.show() 50 | else: 51 | return rgb 52 | 53 | 54 | def encode_segmap(mask): 55 | """Encode segmentation label images as pascal classes 56 | Args: 57 | mask (np.ndarray): raw segmentation label image of dimension 58 | (M, N, 3), in which the Pascal classes are encoded as colours. 59 | Returns: 60 | (np.ndarray): class map with dimensions (M,N), where the value at 61 | a given location is the integer denoting the class index. 62 | """ 63 | mask = mask.astype(int) 64 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 65 | for ii, label in enumerate(get_pascal_labels()): 66 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 67 | label_mask = label_mask.astype(int) 68 | return label_mask 69 | 70 | 71 | def get_cityscapes_labels(): 72 | return np.array([ 73 | [128, 64, 128], 74 | [244, 35, 232], 75 | [70, 70, 70], 76 | [102, 102, 156], 77 | [190, 153, 153], 78 | [153, 153, 153], 79 | [250, 170, 30], 80 | [220, 220, 0], 81 | [107, 142, 35], 82 | [152, 251, 152], 83 | [0, 130, 180], 84 | [220, 20, 60], 85 | [255, 0, 0], 86 | [0, 0, 142], 87 | [0, 0, 70], 88 | [0, 60, 100], 89 | [0, 80, 100], 90 | [0, 0, 230], 91 | [119, 11, 32]]) 92 | 93 | 94 | def get_pascal_labels(): 95 | """Load the mapping that associates pascal classes with label colors 96 | Returns: 97 | np.ndarray with dimensions (21, 3) 98 | """ 99 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 100 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 101 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 102 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 103 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 104 | [0, 64, 128]]) -------------------------------------------------------------------------------- /script/inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # python PATH 4 | export PYTHONPATH="${PYTHONPATH}:${HOME}/github" 5 | 6 | # hyperparameter 7 | echo -n "input the gpu (seperate by comma (,) ): " 8 | read gpus 9 | export CUDA_VISIBLE_DEVICES=${gpus} 10 | echo "using gpus ${gpus}" 11 | # replace comma(,) with empty 12 | #gpus=${gpus//,/} 13 | # the number of characters 14 | #num_gpus=${#gpus} 15 | #echo "the number of gpus is ${num_gpus}" 16 | 17 | # choose the base model 18 | echo "" 19 | echo "0 -- deeplabv3" 20 | echo "1 -- deeplabv3+" 21 | echo "2 -- pspnet" 22 | echo -n "choose the base model: " 23 | read model_choose 24 | case ${model_choose} in 25 | 0 ) 26 | base_model="deeplabv3" 27 | ;; 28 | 1 ) 29 | base_model="deeplabv3+" 30 | ;; 31 | 2 ) 32 | base_model="pspnet" 33 | ;; 34 | * ) 35 | echo "The choice of the segmentation model is illegal!" 
        exit 1
        ;;
esac

# choose the backbone
#echo ""
#echo "0 -- resnet_v1_50"
#echo "1 -- resnet_v1_101"
#echo "2 -- resnet_v1_152"
#echo -n "choose the base network: "
#read base_network
base_network=1
case ${base_network} in
    0 )
        backbone="resnet50";;
    1 )
        backbone="resnet101";;
    2 )
        backbone="resnet152";;
    * )
        echo "The choice of the base network is illegal!"
        exit 1
        ;;
esac
echo "The backbone is ${backbone}"
echo "The base model is ${base_model}"

# choose the batch size
batch_size=1

# choose the dataset
#echo ""
#echo "0 -- PASCAL VOC2012 dataset"
#echo -n "input the dataset: "
#read dataset
dataset=0
if [ ${dataset} = 0 ]
then
    #####################################################
    # SET YOUR DATA DIR HERE
    #####################################################
    data_dir="${HOME}/dataset/VOCtest"
    dataset=pascal
elif [ ${dataset} = 1 ]
then
    #####################################################
    # SET YOUR DATA DIR HERE
    #####################################################
    data_dir="${HOME}/dataset/Cityscapes/"
    dataset=cityscapes
else
    echo "The choice of the dataset is illegal!"
    exit 1
fi
echo "The data dir is ${data_dir}, the batch size is ${batch_size}."


train_split='test'

# set the work dir
work_dir="${HOME}/github/RMI"

# ckpt directory
#####################################################
# SET YOUR CHECKPOINT FILE HERE
#####################################################
#model_name=TBD
#checkpoint_name="deeplab-resnet_ckpt_30406.pth"

#####################################################
# SET YOUR RESUME CKPT HERE
#####################################################
resume=TBD

# model dir and output dir
#####################################################
# SET YOUR MODEL DIR AND OUTPUT DIR HERE
#####################################################
model_dir=TBD
output_dir=TBD


if [ -d ${output_dir} ]
then
    rm -r ${output_dir}
    mkdir -p ${output_dir}
    echo "delete and make the directory ${output_dir}"
else
    mkdir -p ${output_dir}
    echo "make the directory ${output_dir}"
fi

# run inference with the model

python ${work_dir}/inference.py --resume ${resume} \
    --seg_model ${base_model} \
    --backbone ${backbone} \
    --model_dir ${model_dir} \
    --train_split ${train_split} \
    --gpu_ids ${gpus} \
    --checkname deeplab-resnet \
    --dataset ${dataset} \
    --data_dir ${data_dir} \
    --output_dir ${output_dir}
echo "Test Finished!!!"
--------------------------------------------------------------------------------
/crf/crf.py:
--------------------------------------------------------------------------------
# coding=utf-8
"""
Function which returns the labelled image after applying CRF.
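Typical use (an editorial sketch): given an HxWx3 uint8 image `img` and an
HxWxC softmax map `probs`, `labels = dense_crf(img, probs, iter_steps=10)`
returns an HxW array of refined class indices.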
4 | """ 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | #import cv2 10 | import numpy as np 11 | from PIL import Image 12 | import pydensecrf.densecrf as dcrf 13 | from pydensecrf.utils import unary_from_softmax 14 | 15 | 16 | def dense_crf(real_image, probs, iter_steps=10): 17 | """ 18 | Args: 19 | real_image : the real world RGB image, numpy array, shape [H, W, 3] 20 | probs : the predicted probability in (0, 1), shape [H, W, C] 21 | iter_steps : the iterative steps 22 | Returns: 23 | return the refined segmentation map in [0,1,2,...,N_label] 24 | ref: 25 | https://github.com/milesial/Pytorch-UNet/blob/master/utils/crf.py 26 | https://github.com/lucasb-eyer/pydensecrf/blob/master/examples/Non%20RGB%20Example.ipynb 27 | """ 28 | # converting real -world image to RGB if it is gray 29 | if(len(real_image.shape) < 3): 30 | #real_image = cv2.cvtColor(real_image, cv2.COLOR_GRAY2RGB) 31 | raise ValueError("The input image should be RGB image.") 32 | # shape, and transpose to [C, H, W] 33 | H, W, N_classes = probs.shape[0], probs.shape[1], probs.shape[2] 34 | probs = probs.transpose((2, 0, 1)) 35 | # get unary potentials from the probability distribution 36 | unary = unary_from_softmax(probs) 37 | #unary = np.ascontiguousarray(unary) 38 | # CRF 39 | d = dcrf.DenseCRF2D(W, H, N_classes) 40 | d.setUnaryEnergy(unary) 41 | # add pairwise potentials 42 | #real_image = np.ascontiguousarray(real_image) 43 | d.addPairwiseGaussian(sxy=3, compat=3) 44 | d.addPairwiseBilateral(sxy=30, srgb=13, rgbim=real_image, compat=10) 45 | #d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 46 | #d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=real_image, 47 | # compat=10, kernel=dcrf.DIAG_KERNEL, normalization=dcrf.NORMALIZE_SYMMETRIC) 48 | # inference 49 | Q = d.inference(iter_steps) 50 | Q = np.argmax(np.array(Q), axis=0).reshape((H, W)) 51 | return Q 52 | 53 | 54 | if __name__ == '__main__': 55 | import cv2 56 | 57 | def decode_labels(mask, num_images=1, num_classes=21, color_list=None): 58 | """Decode batch of segmentation masks. 59 | Args: 60 | mask: result of inference after taking argmax. 61 | num_images: number of images to decode from the batch. 62 | num_classes: number of classes to predict (including background). 63 | Returns: 64 | A batch with num_images RGB images of the same size as the input. 
65 | """ 66 | n, h, w, c = mask.shape 67 | assert (n >= num_images) 68 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) 69 | for i in range(num_images): 70 | img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) 71 | pixels = img.load() 72 | for j_, j in enumerate(mask[i, :, :, 0]): 73 | for k_, k in enumerate(j): 74 | if k < num_classes: 75 | pixels[k_, j_] = color_dict[k] 76 | outputs[i] = np.array(img) 77 | return outputs 78 | 79 | img = cv2.cvtColor(cv2.imread('img.png'), cv2.COLOR_BGR2RGB) 80 | prob_01 = cv2.cvtColor(cv2.imread('01_bg.png'), cv2.COLOR_BGR2GRAY) 81 | prob_02 = cv2.cvtColor(cv2.imread('02_dog.png'), cv2.COLOR_BGR2GRAY) 82 | prob_03 = cv2.cvtColor(cv2.imread('03_sofa.png'), cv2.COLOR_BGR2GRAY) 83 | prob = np.stack([prob_01, prob_02, prob_03], axis=-1) / 255.0 84 | H, W = prob.shape[0], prob.shape[1] 85 | img = cv2.resize(img, (W, H)) 86 | pred = dense_crf(img.astype(np.uint8), prob.astype(np.float32), iter_steps=1) 87 | print(prob.shape, np.min(prob), np.max(prob)) 88 | print(img.shape) 89 | print(pred.shape) 90 | # background, dog, sofa 91 | color_dict = [(0, 0, 0), (64, 0, 128), (0, 192, 0)] 92 | pred = np.expand_dims(pred, axis=0) 93 | pred = np.expand_dims(pred, axis=-1) 94 | out = decode_labels(pred, num_images=1, num_classes=3, color_list=color_dict) 95 | out = np.squeeze(out, axis=0) 96 | cv2.imwrite('./crf.png', out) 97 | -------------------------------------------------------------------------------- /script/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # python PATH 4 | export PYTHONPATH="${PYTHONPATH}:${HOME}/github" 5 | #export LD_PRELOAD="/usr/lib/libtcmalloc_minimal.so.4" 6 | 7 | # hyperparameter 8 | echo -n "input the gpu (seperate by comma (,) ): " 9 | read gpus 10 | export CUDA_VISIBLE_DEVICES=${gpus} 11 | echo "using gpus ${gpus}" 12 | # replace comma(,) with empty 13 | #gpus=${gpus//,/} 14 | # the number of characters 15 | #num_gpus=${#gpus} 16 | #echo "the number of gpus is ${num_gpus}" 17 | 18 | # choose the base model 19 | echo "" 20 | echo "0 -- deeplabv3" 21 | echo "1 -- deeplabv3+" 22 | echo "2 -- pspnet" 23 | echo -n "choose the base model: " 24 | read model_choose 25 | case ${model_choose} in 26 | 0 ) 27 | base_model="deeplabv3" 28 | ;; 29 | 1 ) 30 | base_model="deeplabv3+" 31 | ;; 32 | 2 ) 33 | base_model="pspnet" 34 | ;; 35 | * ) 36 | echo "The choice of the segmentation model is illegal!" 37 | exit 1 38 | ;; 39 | esac 40 | 41 | # choose the backbone 42 | echo "" 43 | echo "0 -- resnet_v1_50" 44 | echo "1 -- resnet_v1_101" 45 | echo "2 -- resnet_v1_152" 46 | echo -n "choose the base network: " 47 | read base_network 48 | #base_network=1 49 | 50 | case ${base_network} in 51 | 0 ) 52 | backbone="resnet50";; 53 | 1 ) 54 | backbone="resnet101";; 55 | 2 ) 56 | backbone="resnet152";; 57 | * ) 58 | echo "The choice of the base network is illegal!" 
        exit 1
        ;;
esac
echo "The backbone is ${backbone}"
echo "The base model is ${base_model}"

# choose the batch size
batch_size=1

# choose the dataset
echo ""
echo "0 -- PASCAL VOC2012 dataset"
echo "1 -- Cityscapes"
echo "2 -- CamVid"
echo -n "input the dataset: "
read dataset

if [ ${dataset} = 0 ]
then
    #####################################################
    # SET YOUR DATA DIR HERE
    #####################################################
    data_dir="${HOME}/dataset/VOCdevkit/VOC2012"
    checkpoint_name="deeplab-resnet_ckpt_30406.pth"
    train_split='val'
    dataset=pascal
elif [ ${dataset} = 1 ]
then
    data_dir="${HOME}/dataset/Cityscapes/"
    dataset=cityscapes
elif [ ${dataset} = 2 ]
then
    data_dir="${HOME}/dataset/CamVid/"
    checkpoint_name="deeplab-resnet_ckpt_5800.pth"
    dataset=camvid
    train_split='test'
else
    echo "The choice of the dataset is illegal!"
    exit 1
fi
echo "The data dir is ${data_dir}, the batch size is ${batch_size}."


# set the work dir
work_dir="${HOME}/github/RMI/"
# default split if the dataset branch above did not set one
train_split=${train_split:-val}

# ckpt directory
#####################################################
# SET YOUR CHECKPOINT FILE HERE
#####################################################
# model_name=TBD

#####################################################
# SET YOUR RESUME CHECKPOINT HERE
#####################################################
resume=TBD

# model dir and output dir
#####################################################
# SET YOUR MODEL DIR AND OUTPUT DIR HERE
#####################################################
model_dir=TBD
output_dir=TBD


if [ -d ${output_dir} ]
then
    rm -r ${output_dir}
    mkdir -p ${output_dir}
    echo "delete and make the directory ${output_dir}"
else
    mkdir -p ${output_dir}
    echo "make the directory ${output_dir}"
fi

#do
python ${work_dir}/eval.py --resume ${resume} \
    --seg_model ${base_model} \
    --backbone ${backbone} \
    --model_dir ${model_dir} \
    --train_split ${train_split} \
    --gpu_ids ${gpus} \
    --checkname deeplab-resnet \
    --dataset ${dataset} \
    --data_dir ${data_dir} \
    --output_dir ${output_dir}

echo "Test Finished!!!"
--------------------------------------------------------------------------------
/utils/model_store.py:
--------------------------------------------------------------------------------
# coding = utf-8
"""
Model store which provides pretrained models.
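Typical use (an editorial sketch): `path = get_model_file('resnet101')`
downloads and unpacks the checkpoint into ~/.encoding/models when it is
missing or its sha1 does not match, then returns the local file path.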
reference:
    https://github.com/zhanghang1989/PyTorch-Encoding
"""


from __future__ import print_function
__all__ = ['get_model_file', 'purge']


import os
import zipfile

from RMI.utils.files import download, check_sha1

_model_sha1 = {name: checksum for checksum, name in [
    ('ebb6acbbd1d1c90b7f446ae59d30bf70c74febc1', 'resnet50'),
    ('2a57e44de9c853fa015b172309a1ee7e2d0e4e2a', 'resnet101'),
    ('0d43d698c66aceaa2bc0309f55efdd7ff4b143af', 'resnet152'),
    ('2e22611a7f3992ebdee6726af169991bc26d7363', 'deepten_minc'),
    ('662e979de25a389f11c65e9f1df7e06c2c356381', 'fcn_resnet50_ade'),
    ('eeed8e582f0fdccdba8579e7490570adc6d85c7c', 'fcn_resnet50_pcontext'),
    ('54f70c772505064e30efd1ddd3a14e1759faa363', 'psp_resnet50_ade'),
    ('075195c5237b778c718fd73ceddfa1376c18dfd0', 'deeplab_resnet50_ade'),
    ('5ee47ee28b480cc781a195d13b5806d5bbc616bf', 'encnet_resnet101_coco'),
    ('4de91d5922d4d3264f678b663f874da72e82db00', 'encnet_resnet50_pcontext'),
    ('9f27ea13d514d7010e59988341bcbd4140fcc33d', 'encnet_resnet101_pcontext'),
    ('07ac287cd77e53ea583f37454e17d30ce1509a4a', 'encnet_resnet50_ade'),
    ('3f54fa3b67bac7619cd9b3673f5c8227cf8f4718', 'encnet_resnet101_ade'),
]}


encoding_repo_url = 'https://hangzh.s3.amazonaws.com/'
_url_format = '{repo_url}encoding/models/{file_name}.zip'

def short_hash(name):
    if name not in _model_sha1:
        raise ValueError('Pretrained model for {name} is not available.'.format(name=name))
    return _model_sha1[name][:8]


def get_model_file(name, root=os.path.join('~', '.encoding', 'models')):
    r"""Return the location of a pretrained model on the local file system.

    This function will download the model from the online model zoo when it
    cannot be found locally or its hash does not match.
    The root directory will be created if it doesn't exist.

    Parameters
    ----------
    name : str
        Name of the model.
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.

    Returns
    -------
    file_path
        Path to the requested pretrained model file.
    """
    file_name = '{name}-{short_hash}'.format(name=name, short_hash=short_hash(name))
    root = os.path.expanduser(root)
    file_path = os.path.join(root, file_name + '.pth')
    sha1_hash = _model_sha1[name]
    if os.path.exists(file_path):
        if check_sha1(file_path, sha1_hash):
            print("Restore parameters from {}".format(file_path))
            return file_path
        else:
            print('Mismatch in the content of model file {} detected.'.format(file_path))
            print('Downloading again...')
    else:
        print('Model file {} is not found. Downloading.'.format(file_path))

    if not os.path.exists(root):
        os.makedirs(root)

    zip_file_path = os.path.join(root, file_name + '.zip')
    repo_url = os.environ.get('ENCODING_REPO', encoding_repo_url)
    if repo_url[-1] != '/':
        repo_url = repo_url + '/'
    download(_url_format.format(repo_url=repo_url, file_name=file_name),
             path=zip_file_path,
             overwrite=True)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
    os.remove(zip_file_path)

    if check_sha1(file_path, sha1_hash):
        return file_path
    else:
        raise ValueError('Downloaded file has different hash. Please try again.')


def purge(root=os.path.join('~', '.encoding', 'models')):
    r"""Purge all pretrained model files in the local file store.

    Parameters
    ----------
    root : str, default '~/.encoding/models'
        Location for keeping the model parameters.
    """
    root = os.path.expanduser(root)
    files = os.listdir(root)
    for f in files:
        if f.endswith(".pth"):
            os.remove(os.path.join(root, f))


def pretrained_model_list():
    return list(_model_sha1.keys())


if __name__ == '__main__':
    #get_model_file('resnet101', root='~/.encoding/models')
    #get_model_file('resnet50', root='~/.encoding/models')
    get_model_file('resnet152', root='~/.encoding/models')
--------------------------------------------------------------------------------
/utils/files.py:
--------------------------------------------------------------------------------
# coding=utf-8

import os
import requests
import errno
import shutil
import hashlib
from tqdm import tqdm
import torch

__all__ = ['save_checkpoint', 'download', 'mkdir', 'check_sha1']


def save_checkpoint(state, args, is_best, filename='checkpoint.pth.tar'):
    """Saves checkpoint to disk"""
    directory = "runs/%s/%s/%s/" % (args.dataset, args.model, args.checkname)
    if not os.path.exists(directory):
        os.makedirs(directory)
    filename = directory + filename
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, directory + 'model_best.pth.tar')


def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL
    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in the url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.
    Returns
    -------
    str
        The file path of the downloaded file.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        if not os.path.exists(dirname):
            os.makedirs(dirname)

        print('Downloading %s from %s...' % (fname, url))
        r = requests.get(url, stream=True)
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s" % url)
        total_length = r.headers.get('content-length')
        with open(fname, 'wb') as f:
            if total_length is None:  # no content length header
                for chunk in r.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            else:
                total_length = int(total_length)
                for chunk in tqdm(r.iter_content(chunk_size=1024),
                                  total=int(total_length / 1024. + 0.5),
                                  unit='KB', unit_scale=False, dynamic_ncols=True):
                    f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))

    return fname


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.
    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.
    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)

    return sha1.hexdigest() == sha1_hash


def mkdir(path):
    """Make a directory; do not raise an error if it already exists."""
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
#coding=utf-8

import os
import time
import timeit
import argparse
import numpy as np

#import cv2
from PIL import Image

import torch
#import torch.nn.functional as F
from RMI import parser_params, full_model

from RMI.model import psp, deeplab
from RMI.dataloaders import factory
from RMI.utils.metrics import Evaluator


# A map from segmentation name to model object.
seg_model_obj_dict = {
    'pspnet': psp.PSPNet,
    'deeplabv3': deeplab.DeepLabv3,
    'deeplabv3+': deeplab.DeepLabv3Plus,
}


class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.crf_iter_steps = args.crf_iter_steps
        self.output_dir = args.output_dir
        self.mode = 'test'

        # define dataloader
        self.val_loader = factory.get_dataset(args.data_dir,
                                              batch_size=1,
                                              dataset=args.dataset,
                                              split=args.train_split)
        self.nclass = self.val_loader.NUM_CLASSES
        # define network; keep the model name in self.seg_model so that the
        # string comparison in validation() works
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        seg_net = seg_model_obj_dict[self.seg_model](num_classes=self.nclass,
                                                     backbone=args.backbone,
                                                     output_stride=args.out_stride,
                                                     norm_layer=torch.nn.BatchNorm2d,
                                                     bn_mom=args.bn_mom,
                                                     freeze_bn=True)

        # define criterion
        #self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean')
        self.model = full_model.FullModel(seg_model=seg_net,
                                          model=self.mode)
        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()
            #self.criterion = self.criterion.cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            print('Restore parameters from {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            self.global_step = checkpoint['global_step']

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

    def validation(self):
        """validation procedure
        """
        # set validation mode
        self.model.eval()
        self.evaluator.reset()
        start = timeit.default_timer()
        for i in range(len(self.val_loader)):
            sample = self.val_loader[i]
            image = sample['image']
            if self.cuda:
                image = image.cuda()
            image = image.unsqueeze(dim=0)
            # forward
            with torch.no_grad():
                output = self.model(image)
            # the output of the pspnet is a tuple
            if self.seg_model == 'pspnet':
                output = output[0]

            output = output.squeeze_()
            pred = output.data.cpu().numpy()
            # save output
            pred = np.argmax(pred, axis=0)
            path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png')
            result = Image.fromarray(pred.astype(np.uint8))
            result.save(path_to_output)
            #cv2.imwrite(path_to_output, pred)
            # report the running time
            if not i % 100:
                stop = timeit.default_timer()
                print("current step = {} ({:.3f} sec)".format(i, stop - start))
                start = timeit.default_timer()


def main():
    # get the parameters
    parser = argparse.ArgumentParser(description="PyTorch Segmentation Model Testing")
    args = parser_params.add_parser_params(parser)
    print(args)

    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    start_time = time.time()
    trainer.validation()
    total_time = time.time() - start_time
    print("The validation time is {:.5f} sec".format(total_time))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
#coding=utf-8
"""
some training utils.
reference:
    https://github.com/zhanghang1989/PyTorch-Encoding
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

#import os
import math
import torch
from torchvision.utils import make_grid
#from tensorboardX import SummaryWriter
from RMI.dataloaders.utils import decode_seg_map_sequence


class lr_scheduler(object):
    """learning rate scheduler
    step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}``
    cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter/maxiter))``
    poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``

    Args:
        init_lr: initial learning rate;
        mode: ['cos', 'poly', 'step'];
        num_epochs: training epochs;
        max_iter: max iterations of training;
        lr_step: epoch step size for the 'step' mode (discouraged);
        slow_start_steps: slow start (warm up) steps of training;
        slow_start_lr: slow start learning rate for slow_start_steps;
        end_lr: minimum learning rate.
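
    Example (an editorial sketch, not from the original source):
        scheduler = lr_scheduler(0.007, mode='poly', max_iter=30000,
                                 slow_start_steps=1500, slow_start_lr=1e-4)
        scheduler(optimizer, global_step)   # adjusts the optimizer lrs in place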
35 | """ 36 | def __init__(self, init_lr, 37 | mode='poly', 38 | num_epochs=30, 39 | max_iter=30000, 40 | lr_step=1, 41 | slow_start_steps=0, 42 | slow_start_lr=1e-4, 43 | end_lr=1e-6, 44 | multiplier=1.0): 45 | self.init_lr = init_lr 46 | self.now_lr = self.init_lr 47 | self.mode = mode 48 | self.num_epochs = num_epochs 49 | self.max_iter = max_iter 50 | self.slow_start_steps = slow_start_steps 51 | self.slow_start_lr = slow_start_lr 52 | self.slow_max_iter = self.max_iter - self.slow_start_steps 53 | self.end_lr = end_lr 54 | self.multiplier = multiplier 55 | # step mode 56 | if self.mode == 'step': 57 | assert lr_step 58 | self.lr_step = lr_step 59 | # log info 60 | print('INFO:PyTorch: Using {} learning rate scheduler!'.format(self.mode)) 61 | 62 | def __call__(self, optimizer, global_step, epoch=1.0): 63 | """call method""" 64 | step_now = 1.0 * global_step 65 | 66 | if global_step <= self.slow_start_steps: 67 | # slow start strategy -- warm up 68 | # see https://arxiv.org/pdf/1812.01187.pdf 69 | # Bag of Tricks for Image Classification with Convolutional Neural Networks 70 | # for details. 71 | lr = (step_now / self.slow_start_steps) * (self.init_lr - self.slow_start_lr) 72 | lr = lr + self.slow_start_lr 73 | lr = min(lr, self.init_lr) 74 | else: 75 | step_now = step_now - self.slow_start_steps 76 | # calculate the learning rate 77 | if self.mode == 'cos': 78 | lr = 0.5 * self.init_lr * (1.0 + math.cos(step_now / self.slow_max_iter * math.pi)) 79 | elif self.mode == 'poly': 80 | lr = self.init_lr * pow(1.0 - step_now / self.slow_max_iter, 0.9) 81 | #elif self.mode == 'step': 82 | # lr = self.init_lr * (0.1 ** (epoch // self.lr_step)) 83 | else: 84 | raise NotImplementedError 85 | lr = max(lr, self.end_lr) 86 | 87 | self.now_lr = lr 88 | # adjust learning rate 89 | self._adjust_learning_rate(optimizer, lr) 90 | 91 | def _adjust_learning_rate(self, optimizer, lr): 92 | """adjust the leaning rate""" 93 | if len(optimizer.param_groups) == 1: 94 | optimizer.param_groups[0]['lr'] = lr 95 | else: 96 | # BE CAREFUL HERE!!! 97 | # 0 -- the backbone conv weights with weight decay 98 | # 1 -- the bn params and bias of backbone without weight decay 99 | # 2 -- the weights of other layers with weight decay 100 | # 3 -- the bn params and bias of other layers without weigth decay 101 | optimizer.param_groups[0]['lr'] = lr 102 | optimizer.param_groups[1]['lr'] = lr 103 | for i in range(2, len(optimizer.param_groups)): 104 | optimizer.param_groups[i]['lr'] = lr * self.multiplier 105 | 106 | 107 | def visualize_image(writer, dataset, image, target, output, global_step): 108 | """summary image during training. 
109 | """ 110 | grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) 111 | writer.add_image('Image', grid_image, global_step) 112 | grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), 113 | dataset=dataset), 3, normalize=False, range=(0, 255)) 114 | writer.add_image('Predicted label', grid_image, global_step) 115 | grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), 116 | dataset=dataset), 3, normalize=False, range=(0, 255)) 117 | writer.add_image('Groundtruth label', grid_image, global_step) 118 | -------------------------------------------------------------------------------- /losses/affinity/aaf.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | """ 4 | The pytorch implementation of the paper: 5 | @inproceedings{aaf2018, 6 | author = {Ke, Tsung-Wei and Hwang, Jyh-Jing and Liu, Ziwei and Yu, Stella X.}, 7 | title = {Adaptive Affinity Fields for Semantic Segmentation}, 8 | booktitle = {European Conference on Computer Vision (ECCV)}, 9 | month = {September}, 10 | year = {2018} 11 | } 12 | """ 13 | 14 | # python 2.X, 3.X compatibility 15 | from __future__ import print_function 16 | from __future__ import division 17 | from __future__ import absolute_import 18 | 19 | import math 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.nn.functional as F 24 | from RMI.losses.affinity import utils as aaf_utils 25 | 26 | 27 | __all__ = ['AffinityLoss'] 28 | 29 | 30 | _BOT_EPSILON = 1e-4 31 | _TOP_EPSILON = 1.0 32 | 33 | 34 | class AffinityLoss(nn.Module): 35 | """ 36 | The affinity field loss. 37 | """ 38 | def __init__(self, 39 | num_classes=21, 40 | ignore_index=255, 41 | kld_lambda_1=1.0, 42 | kld_lambda_2=1.0, 43 | kld_margin=3.0, 44 | init_step=0, 45 | max_iter=30000): 46 | super(AffinityLoss, self).__init__() 47 | self.num_classes = num_classes 48 | # factor of aaf 49 | self.kld_lambda_1 = kld_lambda_1 50 | self.kld_lambda_2 = kld_lambda_2 51 | self.kld_margin = kld_margin 52 | self.ignore_index = ignore_index 53 | self.reduction = 'mean' 54 | self.down_stride = 8 55 | self.init_step = init_step 56 | self.max_iter = max_iter 57 | 58 | def forward(self, logits_4D, labels_3D, global_step=0): 59 | """ 60 | Args: 61 | logits_4D : [N, C, H, W], dtype=float32 62 | labels_4D : [N, H, W], dtype=long 63 | """ 64 | # PART I -- get the normal cross entropy loss 65 | normal_loss = F.cross_entropy(input=logits_4D, 66 | target=labels_3D.long(), 67 | ignore_index=self.ignore_index, 68 | reduction=self.reduction) 69 | 70 | # PART II -- get the affinity field loss 71 | # downsample the logits and labels to save memory 72 | shape = logits_4D.size() 73 | new_h, new_w = shape[2] // (self.down_stride), shape[3] // (self.down_stride) 74 | labels_3D = F.interpolate(labels_3D.unsqueeze(dim=1), size=(new_h, new_w), mode='nearest') 75 | labels_3D = labels_3D.squeeze(dim=1) 76 | logits_4D = F.interpolate(logits_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) 77 | 78 | # get the valid label and logits 79 | # valid label, [N, C, H, W] 80 | label_mask_3D = labels_3D < self.num_classes 81 | valid_onehot_labels_4D = F.one_hot(labels_3D.long() * label_mask_3D.long(), num_classes=self.num_classes).float() 82 | label_mask_3D = label_mask_3D.float() 83 | valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) 84 | valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) 
        # valid probs
        probs_4D = F.softmax(logits_4D, dim=1) * label_mask_3D.unsqueeze(dim=1)
        probs_4D = probs_4D.clamp(min=_BOT_EPSILON, max=_TOP_EPSILON)

        # decay as https://github.com/twke18/Adaptive_Affinity_Fields
        aff_decay = math.pow(20.0, 0.0 - 1.0 * (global_step - self.init_step) / float(self.max_iter))
        aaf_loss = aff_decay * self.affinity_loss(probs_4D, labels_4D=valid_onehot_labels_4D)
        # the final loss
        final_loss = normal_loss + aaf_loss
        return final_loss

    def affinity_loss(self, probs_4D, labels_4D=None, size=1):
        """
        Args:
            probs_4D : [N, C, H, W], dtype=float32
            labels_4D : [N, C, H, W], dtype=float32
            size : default 1.
        """
        # edge, shape [N, C, 8, H, W]
        edge = aaf_utils.edges_from_label(labels_4D, size=size)
        edge = edge.view(-1)

        # neighbour points, [N, C, 8, H, W]
        probs_paired = aaf_utils.eightcorner_activation(probs_4D, size=size)
        probs_paired = torch.clamp(probs_paired, min=_BOT_EPSILON, max=_TOP_EPSILON)
        probs_4D = probs_4D.unsqueeze(dim=2)
        neg_probs_4D = 1.0 - probs_4D + _BOT_EPSILON
        neg_probs_paired = 1.0 - probs_paired + _BOT_EPSILON

        # compute KL-Divergence
        KL_div = (probs_paired * (probs_paired.log() - probs_4D.log()) +
                  neg_probs_paired * (neg_probs_paired.log() - neg_probs_4D.log()))
        KL_div = KL_div.view(-1)
        edge_loss = torch.max(torch.zeros(1).type_as(KL_div), self.kld_margin - KL_div)

        # average
        edge_loss = torch.mean(edge_loss[edge])
        not_edge_loss = torch.mean(KL_div[~edge])
        aaf_loss = edge_loss * self.kld_lambda_1 + not_edge_loss * self.kld_lambda_2
        return aaf_loss
--------------------------------------------------------------------------------
/model/sync_bn/comm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading


__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as data parallel will trigger a callback for each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
      to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
              message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
--------------------------------------------------------------------------------
/model/psp.py:
--------------------------------------------------------------------------------
#coding=utf-8
"""
The PSPNet model and its pyramid pooling module.
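A minimal usage sketch (editorial, not from the original source; assumes
PASCAL VOC with 21 classes):
    net = PSPNet(num_classes=21, backbone='resnet101', output_stride=16)
    x, x_aux, x_small, x_aux_small = net(torch.randn(1, 3, 513, 513))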
reference:
    https://github.com/zhanghang1989/PyTorch-Encoding
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from RMI.model import net_factory
from RMI.utils import model_init


__all__ = ['PSPModule', 'PSPNet']

# the feature map used to calculate the auxiliary loss
pspnet_aux_end_point_dict = {
    'resnet50': 'layer3',
    'resnet101': 'layer3',
}

# https://discuss.pytorch.org/t/whats-the-difference-between-nn-relu-and-nn-relu-inplace-true/948
# inplace ReLU saves memory.
_IS_ReLU_INPLACE = True


class PSPModule(nn.Module):
    """The pyramid pooling module of the PSPNet."""
    def __init__(self,
                 in_channels,
                 depth=512,
                 pool_sizes=[1, 2, 3, 6],
                 norm_layer=nn.BatchNorm2d,
                 bn_mom=0.05):
        super(PSPModule, self).__init__()
        self.in_channels = in_channels
        self.depth = depth
        self.norm_layer = norm_layer
        self.pool_sizes = pool_sizes
        self.bn_mom = bn_mom
        self.pools = nn.ModuleList([self._pooling(size) for size in self.pool_sizes])

        # fused conv layers, # 2048 + 4 * depth
        self.fuse_conv = nn.Sequential(
            nn.Conv2d(self.in_channels + 4 * depth, out_channels=depth, kernel_size=3, padding=1, bias=False),
            norm_layer(depth, momentum=self.bn_mom),
            nn.ReLU(inplace=_IS_ReLU_INPLACE),
            nn.Dropout2d(0.1, inplace=False)
        )

    def _pooling(self, size):
        return nn.Sequential(nn.AdaptiveAvgPool2d(output_size=(size, size)),
                             nn.Conv2d(self.in_channels, self.depth, kernel_size=1, bias=False),
                             self.norm_layer(self.depth, momentum=self.bn_mom),
                             nn.ReLU(inplace=_IS_ReLU_INPLACE)
                             )

    def forward(self, inputs):
        h, w = inputs.shape[2:]
        output_slices = [inputs]
        # pyramid pooling
        for i, size in enumerate(self.pool_sizes):
            pool = self.pools[i](inputs)
            out = F.interpolate(pool, size=(h, w), mode='bilinear', align_corners=True)
            output_slices.append(out)
        # concat and fuse
        outputs = torch.cat(output_slices, dim=1)
        outputs = self.fuse_conv(outputs)

        return outputs


class PSPNet(nn.Module):
    def __init__(self, num_classes=21,
                 output_stride=16,
                 backbone='resnet50',
                 norm_layer=nn.BatchNorm2d,
                 bn_mom=0.01,
                 depth_aux_branch=256,
                 pretrained=True,
                 freeze_bn=False):
        super(PSPNet, self).__init__()
        self.num_classes = num_classes
        self.bn_mom = bn_mom
        self.norm_layer = norm_layer    # kept for freeze_bn()
        self.aux_key = pspnet_aux_end_point_dict[backbone]
        # backbone
        self.backbone = net_factory.get_backbone_net(output_stride=output_stride,
                                                     pretrained=pretrained,
                                                     norm_layer=norm_layer,
                                                     bn_mom=bn_mom,
                                                     root_beta=True)
        # pyramid pooling module
        self.psp_module = PSPModule(in_channels=2048,
                                    depth=512,
                                    pool_sizes=[1, 2, 3, 6],
                                    norm_layer=norm_layer,
                                    bn_mom=bn_mom)
        # main branch
        self.main_branch = nn.Conv2d(in_channels=512, out_channels=self.num_classes, kernel_size=1)

        # auxiliary branch
        self.aux_branch = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=depth_aux_branch, kernel_size=3, padding=1, bias=False),
            norm_layer(depth_aux_branch, momentum=self.bn_mom),
            nn.ReLU(inplace=_IS_ReLU_INPLACE),
            nn.Dropout2d(0.1, inplace=False),
            nn.Conv2d(in_channels=depth_aux_branch, out_channels=self.num_classes, kernel_size=1)
        )

        # initialize weights
        model_init.init_weights([self.psp_module, self.main_branch, self.aux_branch],
                                norm_layer=norm_layer,
                                bn_momentum=bn_mom)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, inputs):
        h, w = inputs.shape[2:]
        x, end_points = self.backbone(inputs)

        # main branch
        x = self.psp_module(x)
        x_small = self.main_branch(x)
        x = F.interpolate(x_small, size=(h, w), mode='bilinear', align_corners=True)

        # auxiliary out for training
        x_aux_small = self.aux_branch(end_points[self.aux_key])
        x_aux = F.interpolate(x_aux_small, size=(h, w), mode='bilinear', align_corners=True)
        return x, x_aux, x_small, x_aux_small

    def freeze_bn(self):
        """freeze bn"""
        for m in self.modules():
            if isinstance(m, (self.norm_layer, nn.BatchNorm2d)):
                m.eval()
--------------------------------------------------------------------------------
/crf/crf_refine_test.py:
--------------------------------------------------------------------------------
#coding=utf-8

import os
import time
import timeit
import argparse
import numpy as np

#import cv2
from PIL import Image

import torch
import torch.nn.functional as F
from RMI import parser_params

from RMI.crf import crf
from RMI.model import psp, deeplab
from RMI.dataloaders import factory
from RMI.utils.metrics import Evaluator


# A map from segmentation name to model object.
seg_model_obj_dict = {
    'pspnet': psp.PSPNet,
    'deeplabv3': deeplab.DeepLabv3,
    'deeplabv3+': deeplab.DeepLabv3Plus,
}


class Trainer(object):
    def __init__(self, args):
        """initialize the Trainer"""
        # about gpus
        self.cuda = args.cuda
        self.gpu_ids = args.gpu_ids
        self.num_gpus = len(self.gpu_ids)
        self.crf_iter_steps = args.crf_iter_steps
        self.output_dir = args.output_dir
        # define dataloader
        self.val_loader = factory.get_dataset(args.data_dir,
                                              batch_size=1,
                                              dataset=args.dataset,
                                              split=args.train_split)
        self.nclass = self.val_loader.NUM_CLASSES
        # define network
        assert args.seg_model in seg_model_obj_dict.keys()
        self.seg_model = args.seg_model
        self.model = seg_model_obj_dict[self.seg_model](num_classes=self.nclass,
                                                        backbone=args.backbone,
                                                        output_stride=args.out_stride,
                                                        norm_layer=torch.nn.BatchNorm2d,
                                                        bn_mom=args.bn_mom,
                                                        freeze_bn=True)

        # define evaluator
        self.evaluator = Evaluator(self.nclass)

        # using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            print('Restore parameters from {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            self.global_step = checkpoint['global_step']

            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

    def validation(self):
        """validation procedure
        """
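        # Editorial note: the loop below runs the network one image at a
        # time, refines each softmax map with dense CRF, and writes the
        # refined label map to output_dir.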
        # set validation mode
        self.model.eval()
        self.evaluator.reset()
        crf_100_steps = 0.0
        start = timeit.default_timer()
        for i in range(len(self.val_loader)):
            #for i, sample in enumerate(self.val_loader):
            sample = self.val_loader[i]
            image = sample['image']
            #image = image.repeat(self.num_gpus, 1, 1, 1)
            #print("{}-th sample, Image shape {}, label shape {}".format(i + 1, image.size(), target.size()))
            if self.cuda:
                image = image.cuda()
            image = image.unsqueeze(dim=0)
            # forward
            with torch.no_grad():
                output = self.model(image)
            # the output of the pspnet is a tuple
            if self.seg_model == 'pspnet':
                output = output[0]

            # get probs, shape [N, C, H, W] --> [N, H, W, C]
            probs = F.softmax(output, dim=1).permute(0, 2, 3, 1).squeeze_()
            probs_np = probs.data.cpu().numpy()
            #pred = output.data.cpu().numpy()

            # CRF post-processing
            image_name = self.val_loader.image_lists[i]
            #real_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB)

            real_image = Image.open(image_name).convert('RGB')
            real_image = np.array(real_image).astype(np.uint8)

            crf_start = timeit.default_timer()
            pred = crf.dense_crf(real_image=real_image, probs=probs_np, iter_steps=self.crf_iter_steps)
            crf_end = timeit.default_timer()
            crf_100_steps += (crf_end - crf_start)
            # save output
            path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png')
            result = Image.fromarray(pred.astype(np.uint8))
            result.save(path_to_output)
            #cv2.imwrite(path_to_output, pred)
            # report time of CRF
            if not i % 100:
                stop = timeit.default_timer()
                print("current step = {} ({:.3f} sec), crf time {:.3f} sec".
                      format(i, stop - start, crf_100_steps))
                crf_100_steps = 0.0
                start = timeit.default_timer()


def main():
    # get the parameters
    parser = argparse.ArgumentParser(description="PyTorch Segmentation Model Testing")
    args = parser_params.add_parser_params(parser)
    print(args)

    torch.manual_seed(args.seed)
    trainer = Trainer(args)
    start_time = time.time()
    trainer.validation()
    total_time = time.time() - start_time
    print("The inference time is {:.5f} sec".format(total_time))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/model/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as data parallel will trigger a callback for each module, all slave devices should
      call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be
      collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
      to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually is the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback will be invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
              message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'
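        # Editorial note: results come back keyed by the identifiers used in
        # register_slave(); index 0 is always the master's own entry and is
        # skipped below.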
124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | 
-------------------------------------------------------------------------------- /dataloaders/datasets/camvid.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | """ 4 | dataloader for CamVid dataset 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import numpy as np 13 | from PIL import Image 14 | from torch.utils import data 15 | from torchvision import transforms 16 | from RMI.dataloaders import custom_transforms as tr 17 | 18 | 
19 | __all__ = ['CamVidSegmentation'] 20 | 21 | 
22 | # CamVid dataset statistics 23 | _CamVid_R_MEAN = 100 24 | _CamVid_G_MEAN = 103 25 | _CamVid_B_MEAN = 106 26 | 27 | _CamVid_R_STD = 75.61 28 | _CamVid_G_STD = 77.81 29 | _CamVid_B_STD = 76.70 30 | 31 | 
32 | # CamVid 33 | camvid_label_colours = [(128, 128, 128), #0=Sky 34 | # 1=Building, 2=Pole, 3=Road, 4=Pavement, 5=Tree 35 | (128, 0, 0), (192, 192, 128), (128, 64, 128), (60, 40, 222), (128, 128, 0), 36 | # 6=SignSymbol, 7=Fence, 8=Car, 9=Pedestrian, 10=Bicyclist 37 | (192, 128, 128), (64, 64, 128), (64, 0, 128), (64, 64, 0), (0, 128, 192), 38 | # 11=Unlabelled 39 | (0, 0, 0)] 40 | 41 | 
42 | class CamVidSegmentation(data.Dataset): 43 | NUM_CLASSES = 12 44 | 45 | def __init__(self, 46 | data_dir, 47 | crop_size=479, 48 | split="train", 49 | min_scale=0.75, 50 | max_scale=1.25, 51 | step_size=0.0): 52 | """ 53 | 54 | Args: 55 | data_dir: path to the CamVid dataset directory. 56 | crop_size: the crop size. 57 | split: one of ["train", "trainval", "val", "test"].
58 | """ 59 | super().__init__() 60 | # dataset dir 61 | self.data_dir = data_dir 62 | self.split = split 63 | 64 | assert self.split in ['train', 'val', 'test', 'trainval'] 65 | self.data_list_file = os.path.join(self.data_dir, '{}.txt'.format(self.split)) 66 | self.iamge_dir = os.path.join(self.data_dir, self.split) 67 | self.label_dir = os.path.join(self.data_dir, '{}annot'.format(self.split)) 68 | 69 | # crop size and scales 70 | self.crop_size = crop_size 71 | self.min_scale = min_scale 72 | self.max_scale = max_scale 73 | self.step_size = step_size 74 | 75 | # dataset info 76 | self.mean = (_CamVid_R_MEAN, _CamVid_G_MEAN, _CamVid_B_MEAN) 77 | self.std = (_CamVid_R_STD, _CamVid_G_STD, _CamVid_B_STD) 78 | self.ignore_label = 255 79 | 80 | # read file list 81 | with open(self.data_list_file, "r") as f: 82 | lines = f.read().splitlines() 83 | lines = [line.strip().split(' ')[0] for line in lines] 84 | 85 | # extract the id_now 86 | self.image_ids = [id_now.split('/')[-1] for id_now in lines] 87 | # the file list, all are *.png files 88 | self.image_lists = [os.path.join(self.iamge_dir, filename) for filename in self.image_ids] 89 | self.label_lists = [os.path.join(self.label_dir, filename) for filename in self.image_ids] 90 | 91 | assert (len(self.image_lists) == len(self.label_lists)) 92 | 93 | # print the dataset info 94 | print('Number of image_lists in {}: {:d}'.format(split, len(self.image_lists))) 95 | 96 | def __len__(self): 97 | """len() method""" 98 | return len(self.image_lists) 99 | 100 | def __getitem__(self, index): 101 | """how to get the data""" 102 | _image, _label = self._make_img_gt_point_pair(index) 103 | sample = {'image': _image, 'label': _label} 104 | 105 | if 'train' in self.split: 106 | return self.transform_train(sample) 107 | elif 'val' in self.split or 'test' in self.split: 108 | return self.transform_val(sample) 109 | else: 110 | raise NotImplementedError 111 | 112 | def _make_img_gt_point_pair(self, index): 113 | """open the image and the gorund truth""" 114 | _image = Image.open(self.image_lists[index]).convert('RGB') 115 | _label = Image.open(self.label_lists[index]) 116 | return _image, _label 117 | 118 | def transform_train(self, sample): 119 | composed_transforms = transforms.Compose([ 120 | tr.RandomRescale(self.min_scale, self.max_scale, self.step_size), 121 | tr.RandomPadOrCrop(crop_height=self.crop_size, crop_width=self.crop_size, 122 | ignore_label=self.ignore_label, mean=self.mean), 123 | tr.RandomHorizontalFlip(), 124 | tr.Normalize(mean=self.mean, std=self.std), 125 | tr.ToTensor()]) 126 | 127 | return composed_transforms(sample) 128 | 129 | def transform_val(self, sample): 130 | """transform for validation""" 131 | composed_transforms = transforms.Compose([ 132 | tr.Normalize(mean=self.mean, std=self.std), 133 | tr.ToTensor()]) 134 | 135 | return composed_transforms(sample) 136 | 137 | def __str__(self): 138 | return 'CamVid(split=' + str(self.split) + ')' 139 | 140 | 141 | if __name__ == '__main__': 142 | # data dir 143 | data_dir = os.path.join("/home/zhaoshuai/dataset/CamVid") 144 | print(data_dir) 145 | dataset = CamVidSegmentation(data_dir=data_dir, split='trainval') 146 | #print(dataset.image_lists) 147 | image_mean = np.array([0.0, 0.0, 0.0]) 148 | cov_sum = np.array([0.0, 0.0, 0.0]) 149 | pixel_nums = 0.0 150 | # mean 151 | for filename in dataset.image_lists: 152 | image = Image.open(filename).convert('RGB') 153 | image = np.array(image).astype(np.float32) 154 | pixel_nums += image.shape[0] * image.shape[1] 155 | image_mean += 
np.sum(image, axis=(0, 1)) 156 | image_mean = image_mean / pixel_nums 157 | print(image_mean) 158 | # covariance 159 | for filename in dataset.image_lists: 160 | image = Image.open(filename).convert('RGB') 161 | image = np.array(image).astype(np.float32) 162 | cov_sum += np.sum(np.square(image - image_mean), axis=(0, 1)) 163 | image_cov = np.sqrt(cov_sum / (pixel_nums - 1)) 164 | print(image_cov) 165 | 
-------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import time 5 | import timeit 6 | import argparse 7 | import numpy as np 8 | 9 | #import cv2 10 | from PIL import Image 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from RMI import parser_params, full_model 15 | 16 | from RMI.model import psp, deeplab 17 | from RMI.dataloaders import factory 18 | from RMI.utils.metrics import Evaluator 19 | from RMI.dataloaders import utils 20 | 
21 | # A map from segmentation name to model object. 22 | seg_model_obj_dict = { 23 | 'pspnet': psp.PSPNet, 24 | 'deeplabv3': deeplab.DeepLabv3, 25 | 'deeplabv3+': deeplab.DeepLabv3Plus, 26 | } 27 | 28 | 
29 | class Trainer(object): 30 | def __init__(self, args): 31 | """initialize the Trainer""" 32 | # about gpus 33 | self.cuda = args.cuda 34 | self.gpu_ids = args.gpu_ids 35 | self.num_gpus = len(self.gpu_ids) 36 | self.crf_iter_steps = args.crf_iter_steps 37 | self.output_dir = args.output_dir 38 | self.model = 'val'  # running mode, passed to FullModel below 39 | # define dataloader 40 | self.val_loader = factory.get_dataset(args.data_dir, 41 | batch_size=1, 42 | dataset=args.dataset, 43 | split=args.train_split) 44 | self.nclass = self.val_loader.NUM_CLASSES 45 | # define network 46 | assert args.seg_model in seg_model_obj_dict.keys() 47 | self.seg_model_name = args.seg_model 48 | self.seg_model = seg_model_obj_dict[self.seg_model_name](num_classes=self.nclass, 49 | backbone=args.backbone, 50 | output_stride=args.out_stride, 51 | norm_layer=torch.nn.BatchNorm2d, 52 | bn_mom=args.bn_mom, 53 | freeze_bn=True) 54 | 
55 | # define criterion 56 | self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean') 57 | self.model = full_model.FullModel(seg_model=self.seg_model, 58 | model=self.model, 59 | criterion=self.criterion) 60 | 
61 | # define evaluator 62 | self.evaluator = Evaluator(self.nclass) 63 | 
64 | # using cuda 65 | if args.cuda: 66 | self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids) 67 | #patch_replication_callback(self.model) 68 | self.model = self.model.cuda() 69 | self.criterion = self.criterion.cuda() 70 | 
71 | # resuming checkpoint 72 | if args.resume is not None: 73 | if not os.path.isfile(args.resume): 74 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 75 | print('Restore parameters from the {}'.format(args.resume)) 76 | checkpoint = torch.load(args.resume) 77 | self.global_step = checkpoint['global_step'] 78 | 79 | if args.cuda: 80 | self.model.module.load_state_dict(checkpoint['state_dict']) 81 | else: 82 | self.model.load_state_dict(checkpoint['state_dict']) 83 | 
84 | def validation(self): 85 | """validation procedure 86 | """ 87 | # set validation mode 88 | self.model.eval() 89 | self.evaluator.reset() 90 | test_loss = 0.0 91 | start = timeit.default_timer() 92 | for i in range(len(self.val_loader)): 93 | #for i, sample in enumerate(self.val_loader): 94 | sample = self.val_loader[i] 95 | image, target = sample['image'], sample['label'] 96 | 
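# tile the single validation sample across GPUs so that DataParallel can split the batch evenly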
image, target = image.repeat(self.num_gpus, 1, 1, 1), target.repeat(self.num_gpus, 1, 1) 97 | #print("{}-th sample, Image shape {}, label shape {}".format(i + 1, image.size(), target.size())) 98 | if self.cuda: 99 | image, target = image.cuda(), target.cuda() 100 | # forward 101 | with torch.no_grad(): 102 | output = self.model(image) 103 | # the output of the pspnet is a tuple 104 | if self.seg_model_name == 'pspnet': 105 | output = output[0] 106 | loss = self.criterion(output, target.long()) 107 | test_loss += loss.item() 108 | 
109 | # get probs, shape [N, C, H, W] --> [N, H, W, C] 110 | output = output.squeeze_() 111 | pred = output.data.cpu().numpy() 112 | pred = np.argmax(pred, axis=0) 113 | target = target.squeeze_().cpu().numpy() 114 | 
115 | # save output 116 | color_img = True 117 | path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png') 118 | pred = pred.astype(np.uint8) 119 | if color_img: 120 | # colorize the prediction before saving 121 | pred_color = utils.decode_segmap(pred, dataset='pascal') 122 | result = Image.fromarray(pred_color.astype(np.uint8)) 123 | result.save(path_to_output) 124 | else: 125 | result = Image.fromarray(pred) 126 | result.save(path_to_output) 127 | # report time 128 | if not i % 100: 129 | stop = timeit.default_timer() 130 | print("current step = {} ({:.3f} sec)".format(i, stop - start)) 131 | start = timeit.default_timer() 132 | 
133 | # Add batch sample into evaluator 134 | self.evaluator.add_batch(target, pred) 135 | 136 | # log and summarize the validation results 137 | 138 | px_acc = self.evaluator.pixel_accuracy_np() 139 | val_miou = self.evaluator.mean_iou_np(is_show_per_class=True) 140 | print("\nINFO:PyTorch: validation results: miou={:.5f}, px_acc={:.5f}, loss={:.5f} \n". 141 | format(val_miou, px_acc, test_loss)) 142 | 143 | 
144 | def main(): 145 | # get the parameters 146 | parser = argparse.ArgumentParser(description="PyTorch Segmentation Model Training") 147 | args = parser_params.add_parser_params(parser) 148 | print(args) 149 | 150 | torch.manual_seed(args.seed) 151 | trainer = Trainer(args) 152 | start_time = time.time() 153 | trainer.validation() 154 | total_time = time.time() - start_time 155 | print("The validation time is {:.5f} sec".format(total_time)) 156 | 157 | 158 | if __name__ == "__main__": 159 | main() 160 | 
-------------------------------------------------------------------------------- /dataloaders/datasets/pascal.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | """ 4 | dataloader for PASCAL VOC 2012 dataset 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import numpy as np 13 | from PIL import Image 14 | from torchvision import transforms 15 | from torch.utils.data import Dataset 16 | 17 | from RMI.dataloaders import custom_transforms as tr 18 | 19 | 
20 | # PASCAL VOC 2012 dataset statistics 21 | _PASCAL_R_MEAN = 116 22 | _PASCAL_G_MEAN = 113 23 | _PASCAL_B_MEAN = 104 24 | 25 | _PASCAL_R_STD = 69.58 26 | _PASCAL_G_STD = 68.68 27 | _PASCAL_B_STD = 72.67 28 | 29 | 
30 | class VOCSegmentation(Dataset): 31 | """PASCAL VOC 2012 dataset 32 | """ 33 | NUM_CLASSES = 21 34 | 35 | def __init__(self, 36 | data_dir, 37 | crop_size=513, 38 | split='train', 39 | min_scale=0.5, 40 | max_scale=2.0, 41 | step_size=0.25): 42 | """ 43 | Args: 44 | data_dir: path to the VOC dataset directory. 45 | crop_size: the crop size.
46 | split: one of ["trainaug", "train", "trainval", "val", "test"]. 47 | """ 48 | super().__init__() 49 | # dataset dir 50 | self.data_dir = data_dir 51 | self.image_dir = os.path.join(self.data_dir, 'JPEGImages') 52 | self.label_dir = os.path.join(self.data_dir, 'SegmentationClassAug') 53 | 54 | assert split in ["trainaug", "train", "trainval", "val", "test"] 55 | self.split = split 56 | # txt lists of images 57 | list_file_dir = os.path.join(self.data_dir, 'ImageSets/Segmentation') 58 | 
59 | # crop size and scales 60 | self.crop_size = crop_size 61 | self.min_scale = min_scale 62 | self.max_scale = max_scale 63 | self.step_size = step_size 64 | 
65 | # dataset info 66 | self.mean = (_PASCAL_R_MEAN, _PASCAL_G_MEAN, _PASCAL_B_MEAN) 67 | self.std = (_PASCAL_R_STD, _PASCAL_G_STD, _PASCAL_B_STD) 68 | self.ignore_label = 255 69 | self.image_ids = [] 70 | self.image_lists = [] 71 | self.label_lists = [] 72 | 
73 | # read the dataset file 74 | with open(os.path.join(list_file_dir, self.split + '.txt'), "r") as f: 75 | lines = f.read().splitlines() 76 | 77 | for line in lines: 78 | image_filename = os.path.join(self.image_dir, line + ".jpg") 79 | label_filename = os.path.join(self.label_dir, line + ".png") 80 | assert os.path.isfile(image_filename) 81 | if 'test' not in self.split: 82 | assert os.path.isfile(label_filename) 83 | self.image_ids.append(line) 84 | self.image_lists.append(image_filename) 85 | self.label_lists.append(label_filename) 86 | 87 | assert (len(self.image_lists) == len(self.label_lists)) 88 | 89 | # print the dataset info 90 | print('Number of images in {}: {:d}'.format(split, len(self.image_lists))) 91 | 
92 | def __len__(self): 93 | """len() method""" 94 | return len(self.image_lists) 95 | 
96 | def __getitem__(self, index): 97 | """index method""" 98 | _image, _label = self._make_img_gt_point_pair(index) 99 | 100 | # different transforms for different splits 101 | if 'train' in self.split: 102 | sample = {'image': _image, 'label': _label} 103 | return self.transform_train(sample) 104 | elif 'val' in self.split: 105 | sample = {'image': _image, 'label': _label} 106 | return self.transform_val(sample) 107 | elif 'test' in self.split: 108 | sample = {'image': _image} 109 | return self.transform_test(sample) 110 | else: 111 | raise NotImplementedError 112 | 
113 | def _make_img_gt_point_pair(self, index): 114 | """open the image and the ground truth""" 115 | _image = Image.open(self.image_lists[index]).convert('RGB') 116 | if 'test' not in self.split: 117 | _label = Image.open(self.label_lists[index]) 118 | else: 119 | _label = None 120 | return _image, _label 121 | 
122 | def transform_train(self, sample): 123 | composed_transforms = transforms.Compose([ 124 | tr.RandomRescale(self.min_scale, self.max_scale, self.step_size), 125 | tr.RandomPadOrCrop(crop_height=self.crop_size, crop_width=self.crop_size, 126 | ignore_label=self.ignore_label, mean=self.mean), 127 | tr.RandomHorizontalFlip(), 128 | tr.Normalize(mean=self.mean, std=self.std), 129 | tr.ToTensor()]) 130 | 131 | return composed_transforms(sample) 132 | 
133 | def transform_val(self, sample): 134 | """transform for validation""" 135 | composed_transforms = transforms.Compose([ 136 | tr.Normalize(mean=self.mean, std=self.std), 137 | tr.ToTensor()]) 138 | 139 | return composed_transforms(sample) 140 | 
141 | def transform_test(self, sample): 142 | """transform for test""" 143 | composed_transforms = transforms.Compose([ 144 | tr.Normalize_Image(mean=self.mean, std=self.std), 145 | tr.ToTensor_Image()]) 146
| 147 | return composed_transforms(sample) 148 | 149 | def __str__(self): 150 | return 'VOC2012(split=' + str(self.split) + ')' 151 | 152 | 153 | if __name__ == '__main__': 154 | # data dir 155 | data_dir = os.path.join("/home/zhaoshuai/dataset/VOCdevkit/VOC2012") 156 | print(data_dir) 157 | dataset = VOCSegmentation(data_dir) 158 | #print(dataset.image_lists) 159 | image_mean = np.array([0.0, 0.0, 0.0]) 160 | cov_sum = np.array([0.0, 0.0, 0.0]) 161 | pixel_nums = 0.0 162 | # mean 163 | for filename in dataset.image_lists: 164 | image = Image.open(filename).convert('RGB') 165 | image = np.array(image).astype(np.float32) 166 | pixel_nums += image.shape[0] * image.shape[1] 167 | image_mean += np.sum(image, axis=(0, 1)) 168 | image_mean = image_mean / pixel_nums 169 | print(image_mean) 170 | # covariance 171 | for filename in dataset.image_lists: 172 | image = Image.open(filename).convert('RGB') 173 | image = np.array(image).astype(np.float32) 174 | cov_sum += np.sum(np.square(image - image_mean), axis=(0, 1)) 175 | image_cov = np.sqrt(cov_sum / (pixel_nums - 1)) 176 | print(image_cov) 177 | -------------------------------------------------------------------------------- /crf/crf_refine.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import time 5 | import timeit 6 | import argparse 7 | import numpy as np 8 | 9 | #import cv2 10 | from PIL import Image 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from RMI import parser_params 15 | 16 | from RMI.crf import crf 17 | from RMI.model import psp, deeplab 18 | from RMI.dataloaders import factory 19 | from RMI.utils.metrics import Evaluator 20 | 21 | 22 | # A map from segmentation name to model object. 23 | seg_model_obj_dict = { 24 | 'pspnet': psp.PSPNet, 25 | 'deeplabv3': deeplab.DeepLabv3, 26 | 'deeplabv3+': deeplab.DeepLabv3Plus, 27 | } 28 | 29 | 30 | class Trainer(object): 31 | def __init__(self, args): 32 | """initialize the Trainer""" 33 | # about gpus 34 | self.cuda = args.cuda 35 | self.gpu_ids = args.gpu_ids 36 | self.num_gpus = len(self.gpu_ids) 37 | self.crf_iter_steps = args.crf_iter_steps 38 | self.output_dir = args.output_dir 39 | # define dataloader 40 | self.val_loader = factory.get_dataset(args.data_dir, 41 | batch_size=1, 42 | dataset=args.dataset, 43 | split=args.train_split) 44 | self.nclass = self.val_loader.NUM_CLASSES 45 | # define network 46 | assert args.seg_model in seg_model_obj_dict.keys() 47 | self.seg_model = args.seg_model 48 | self.model = seg_model_obj_dict[self.seg_model](num_classes=self.nclass, 49 | backbone=args.backbone, 50 | output_stride=args.out_stride, 51 | norm_layer=torch.nn.BatchNorm2d, 52 | bn_mom=args.bn_mom, 53 | freeze_bn=True) 54 | 55 | # define criterion 56 | self.criterion = torch.nn.CrossEntropyLoss(weight=None, ignore_index=255, reduction='mean') 57 | 58 | # define evaluator 59 | self.evaluator = Evaluator(self.nclass) 60 | 61 | # using cuda 62 | if args.cuda: 63 | self.model = torch.nn.DataParallel(self.model, device_ids=self.gpu_ids) 64 | #patch_replication_callback(self.model) 65 | self.model = self.model.cuda() 66 | self.criterion = self.criterion.cuda() 67 | 68 | # resuming checkpoint 69 | if args.resume is not None: 70 | if not os.path.isfile(args.resume): 71 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 72 | print('Restore parameters from the {}'.format(args.resume)) 73 | checkpoint = torch.load(args.resume) 74 | self.global_step = checkpoint['global_step'] 75 | 76 
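| # DataParallel stores the wrapped network under `.module`, so the checkpoint is restored into model.module on GPU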
if args.cuda: 77 | self.model.module.load_state_dict(checkpoint['state_dict']) 78 | else: 79 | self.model.load_state_dict(checkpoint['state_dict']) 80 | 
81 | def validation(self): 82 | """validation procedure 83 | """ 84 | # set validation mode 85 | self.model.eval() 86 | self.evaluator.reset() 87 | test_loss = 0.0 88 | crf_100_steps = 0.0 89 | start = timeit.default_timer() 90 | for i in range(len(self.val_loader)): 91 | #for i, sample in enumerate(self.val_loader): 92 | sample = self.val_loader[i] 93 | image, target = sample['image'], sample['label'] 94 | image, target = image.repeat(self.num_gpus, 1, 1, 1), target.repeat(self.num_gpus, 1, 1) 95 | #print("{}-th sample, Image shape {}, label shape {}".format(i + 1, image.size(), target.size())) 96 | if self.cuda: 97 | image, target = image.cuda(), target.cuda() 98 | # forward 99 | with torch.no_grad(): 100 | output = self.model(image) 101 | # the output of the pspnet is a tuple 102 | if self.seg_model == 'pspnet': 103 | output = output[0] 104 | loss = self.criterion(output, target.long()) 105 | test_loss += loss.item() 106 | 
107 | # get probs, shape [N, C, H, W] --> [N, H, W, C] 108 | probs = F.softmax(output, dim=1).permute(0, 2, 3, 1).squeeze_() 109 | probs_np = probs.data.cpu().numpy() 110 | #pred = output.data.cpu().numpy() 111 | target = target.squeeze_().cpu().numpy() 112 | 
113 | # CRF post-processing 114 | image_name = self.val_loader.image_lists[i] 115 | #real_image = cv2.cvtColor(cv2.imread(image_name), cv2.COLOR_BGR2RGB) 116 | 117 | real_image = Image.open(image_name).convert('RGB') 118 | real_image = np.array(real_image).astype(np.uint8) 119 | 
120 | crf_start = timeit.default_timer() 121 | pred = crf.dense_crf(real_image=real_image, probs=probs_np, iter_steps=self.crf_iter_steps) 122 | crf_end = timeit.default_timer() 123 | crf_100_steps += (crf_end - crf_start) 124 | # save output 125 | path_to_output = os.path.join(self.output_dir, self.val_loader.image_ids[i] + '.png') 126 | result = Image.fromarray(pred.astype(np.uint8)) 127 | result.save(path_to_output) 128 | #cv2.imwrite(path_to_output, pred) 129 | # report time of CRF 130 | if not i % 100: 131 | stop = timeit.default_timer() 132 | print("current step = {} ({:.3f} sec), CRF time over the last 100 steps = {:.3f} sec". 133 | format(i, stop - start, crf_100_steps)) 134 | crf_100_steps = 0.0 135 | start = timeit.default_timer() 136 | #pred = np.argmax(pred, axis=1) 137 | #pred = np.argmax(pred, axis=2) 138 | # Add batch sample into evaluator 139 | self.evaluator.add_batch(target, pred) 140 | 
141 | # log and summarize the validation results 142 | 143 | px_acc = self.evaluator.pixel_accuracy_np() 144 | val_miou = self.evaluator.mean_iou_np(is_show_per_class=True) 145 | print("\nINFO:PyTorch: validation results: miou={:.5f}, px_acc={:.5f}, loss={:.5f} \n".
146 | format(val_miou, px_acc, test_loss)) 147 | print("The number of CRF iteration steps is {}.".format(self.crf_iter_steps)) 148 | 149 | 
150 | def main(): 151 | # get the parameters 152 | parser = argparse.ArgumentParser(description="PyTorch Segmentation Model Training") 153 | args = parser_params.add_parser_params(parser) 154 | print(args) 155 | 156 | torch.manual_seed(args.seed) 157 | trainer = Trainer(args) 158 | start_time = time.time() 159 | trainer.validation() 160 | total_time = time.time() - start_time 161 | print("The validation time is {:.5f} sec".format(total_time)) 162 | 163 | 164 | if __name__ == "__main__": 165 | main() 166 | 
-------------------------------------------------------------------------------- /dataloaders/factory.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from RMI.dataloaders.datasets import cityscapes, pascal, camvid 6 | 7 | __all__ = ['get_data_loader', 'get_dataset'] 8 | 
9 | def get_data_loader(data_dir, 10 | batch_size=16, 11 | crop_size=513, 12 | dataset='pascal', 13 | split="train", 14 | num_workers=4, 15 | pin_memory=True, 16 | distributed=False): 17 | """get the dataset loader""" 18 | kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory} 19 | if dataset == 'pascal': 20 | """PASCAL VOC dataset""" 21 | assert split in ['trainaug', 'trainval', 'train', 'val', 'test'] 22 | if 'train' in split: 23 | print("INFO:PyTorch: Using PASCAL VOC dataset, the training batch size is {} and the crop size is {}.". 24 | format(batch_size, crop_size)) 25 | train_set = pascal.VOCSegmentation(data_dir, crop_size, split, 26 | min_scale=0.5, 27 | max_scale=2.0, 28 | step_size=0.25) 29 | num_class = train_set.NUM_CLASSES 30 | # distributed training 31 | if distributed: 32 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 33 | else: 34 | train_sampler = None 35 | train_loader = DataLoader(train_set, 36 | batch_size=batch_size, 37 | shuffle=(train_sampler is None), 38 | sampler=train_sampler, 39 | drop_last=True, 40 | **kwargs) 41 | return train_loader, num_class 42 | else: 43 | val_set = pascal.VOCSegmentation(data_dir, crop_size, split) 44 | num_class = val_set.NUM_CLASSES 45 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, **kwargs) 46 | return val_loader, num_class 47 | elif dataset == 'cityscapes': 48 | """Cityscapes dataset""" 49 | assert split in ['train', 'val', 'test'] 50 | if 'train' in split: 51 | print("INFO:PyTorch: Using cityscapes dataset, the training batch size is {} and the crop size is {}.". 52 | format(batch_size, crop_size)) 53 | train_set = cityscapes.CityscapesSegmentation(data_dir, 54 | crop_size=crop_size, 55 | split=split, 56 | min_scale=0.75, 57 | max_scale=1.25, 58 | step_size=0.0) 59 | num_class = train_set.NUM_CLASSES 60 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 61 | return train_loader, num_class 62 | else: 63 | val_set = cityscapes.CityscapesSegmentation(data_dir, crop_size, split) 64 | num_class = val_set.NUM_CLASSES 65 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, **kwargs) 66 | return val_loader, num_class 67 | elif dataset == 'camvid': 68 | """CamVid dataset""" 69 | assert split in ['train', 'trainval', 'val', 'test'] 70 | if 'train' in split: 71 | print("INFO:PyTorch: Using camvid dataset, the training batch size is {} and the crop size is {}.".
72 | format(batch_size, crop_size)) 73 | train_set = camvid.CamVidSegmentation(data_dir, 74 | crop_size=crop_size, 75 | split=split, 76 | min_scale=0.75, 77 | max_scale=1.25, 78 | step_size=0.0) 79 | num_class = train_set.NUM_CLASSES 80 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) 81 | return train_loader, num_class 82 | else: 83 | val_set = camvid.CamVidSegmentation(data_dir, crop_size, split) 84 | num_class = val_set.NUM_CLASSES 85 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, **kwargs) 86 | return val_loader, num_class 87 | else: 88 | """raise error""" 89 | raise NotImplementedError("The DataLoader for {} is not implemented.".format(dataset)) 90 | 91 | 92 | 93 | 
94 | def get_dataset(data_dir, 95 | batch_size=16, 96 | crop_size=513, 97 | dataset='pascal', 98 | split="train"): 99 | """get the dataset""" 100 | if dataset == 'pascal': 101 | """PASCAL VOC dataset""" 102 | assert split in ['trainaug', 'trainval', 'train', 'val', 'test'] 103 | if 'train' in split: 104 | print("INFO:PyTorch: Using PASCAL VOC dataset, the training batch size is {} and the crop size is {}.". 105 | format(batch_size, crop_size)) 106 | train_set = pascal.VOCSegmentation(data_dir, crop_size, split, 107 | min_scale=0.5, 108 | max_scale=2.0, 109 | step_size=0.25) 110 | return train_set 111 | else: 112 | val_set = pascal.VOCSegmentation(data_dir, crop_size, split=split) 113 | return val_set 114 | elif dataset == 'cityscapes': 115 | """Cityscapes dataset""" 116 | assert split in ['train', 'val', 'test'] 117 | if 'train' in split: 118 | print("INFO:PyTorch: Using cityscapes dataset, the training batch size is {} and the crop size is {}.". 119 | format(batch_size, crop_size)) 120 | train_set = cityscapes.CityscapesSegmentation(data_dir, 121 | crop_size=crop_size, 122 | split=split, 123 | min_scale=0.75, 124 | max_scale=1.25, 125 | step_size=0.0) 126 | return train_set 127 | else: 128 | val_set = cityscapes.CityscapesSegmentation(data_dir, crop_size, split) 129 | return val_set 130 | elif dataset == 'camvid': 131 | """CamVid dataset""" 132 | assert split in ['train', 'trainval', 'val', 'test'] 133 | if 'train' in split: 134 | print("INFO:PyTorch: Using camvid dataset, the training batch size is {} and the crop size is {}.".
135 | format(batch_size, crop_size)) 136 | train_set = camvid.CamVidSegmentation(data_dir, 137 | crop_size=crop_size, 138 | split=split, 139 | min_scale=0.75, 140 | max_scale=1.25, 141 | step_size=0.0) 142 | return train_set 143 | else: 144 | val_set = camvid.CamVidSegmentation(data_dir, crop_size, split) 145 | return val_set 146 | else: 147 | """raise error""" 148 | raise NotImplementedError("The DataLoader for {} is not implemented.".format(dataset)) 149 | 
-------------------------------------------------------------------------------- /losses/rmi/rmi_utils.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | # python 2.X, 3.X compatibility 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | #import os 9 | #import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 
14 | __all__ = ['map_get_pairs', 'log_det_by_cholesky'] 15 | 16 | 
17 | def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True): 18 | """get map pairs 19 | Args: 20 | labels_4D : labels, shape [N, C, H, W] 21 | probs_4D : probabilities, shape [N, C, H, W] 22 | radius : the square radius 23 | Return: 24 | tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)] 25 | """ 26 | # pad to ensure the following slice operation is valid 27 | #pad_beg = int(radius // 2) 28 | #pad_end = radius - pad_beg 29 | 
30 | # the original height and width 31 | label_shape = labels_4D.size() 32 | h, w = label_shape[2], label_shape[3] 33 | new_h, new_w = h - (radius - 1), w - (radius - 1) 34 | # https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad 35 | #padding = (pad_beg, pad_end, pad_beg, pad_end) 36 | #labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding) 37 | 
38 | # get the neighbors 39 | la_ns = [] 40 | pr_ns = [] 41 | #for x in range(0, radius, 1): 42 | for y in range(0, radius, 1): 43 | for x in range(0, radius, 1): 44 | la_now = labels_4D[:, :, y:y + new_h, x:x + new_w] 45 | pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w] 46 | la_ns.append(la_now) 47 | pr_ns.append(pr_now) 48 | 
49 | if is_combine: 50 | # for calculating RMI 51 | pair_ns = la_ns + pr_ns 52 | p_vectors = torch.stack(pair_ns, dim=2) 53 | return p_vectors 54 | else: 55 | # for other purposes 56 | la_vectors = torch.stack(la_ns, dim=2) 57 | pr_vectors = torch.stack(pr_ns, dim=2) 58 | return la_vectors, pr_vectors 59 | 60 | 
61 | def map_get_pairs_region(labels_4D, probs_4D, radius=3, is_combine=0, num_classes=21): 62 | """get map pairs 63 | Args: 64 | labels_4D : labels, shape [N, C, H, W]. 65 | probs_4D : probabilities, shape [N, C, H, W]. 66 | radius : The side length of the square region.
67 | Return: 68 | A tensor with shape [N, C, radius * radius, H // radius, W // radius] 69 | """ 70 | kernel = torch.zeros([num_classes, 1, radius, radius]).type_as(probs_4D) 71 | padding = radius // 2 72 | # get the neighbours 73 | la_ns = [] 74 | pr_ns = [] 75 | for y in range(0, radius, 1): 76 | for x in range(0, radius, 1): 77 | kernel_now = kernel.clone() 78 | kernel_now[:, :, y, x] = 1.0 79 | la_now = F.conv2d(labels_4D, kernel_now, stride=radius, padding=padding, groups=num_classes) 80 | pr_now = F.conv2d(probs_4D, kernel_now, stride=radius, padding=padding, groups=num_classes) 81 | la_ns.append(la_now) 82 | pr_ns.append(pr_now) 83 | 
84 | if is_combine: 85 | # for calculating RMI 86 | pair_ns = la_ns + pr_ns 87 | p_vectors = torch.stack(pair_ns, dim=2) 88 | return p_vectors 89 | else: 90 | # for other purposes 91 | la_vectors = torch.stack(la_ns, dim=2) 92 | pr_vectors = torch.stack(pr_ns, dim=2) 93 | return la_vectors, pr_vectors 94 | 95 | 96 | 
97 | def log_det_by_cholesky(matrix): 98 | """ 99 | Args: 100 | matrix: matrix must be a positive definite matrix. 101 | shape [N, C, D, D]. 102 | Ref: 103 | https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py 104 | """ 105 | # This uses the property that log det(A) = 2 * sum(log(real(diag(C)))) 106 | # where C is the cholesky decomposition of A. 107 | chol = torch.cholesky(matrix) 108 | #return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1) 109 | return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1) 110 | 111 | 
112 | def batch_cholesky_inverse(matrix): 113 | """ 114 | Args: matrix, 4-D tensor, [N, C, M, M]. 115 | matrix must be a symmetric positive definite matrix. 116 | """ 117 | chol_low = torch.cholesky(matrix, upper=False) 118 | chol_low_inv = batch_low_tri_inv(chol_low) 119 | return torch.matmul(chol_low_inv.transpose(-2, -1), chol_low_inv) 120 | 121 | 
122 | def batch_low_tri_inv(L): 123 | """ 124 | Batched inverse of lower triangular matrices 125 | Args: 126 | L : a lower triangular matrix 127 | Ref: 128 | https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing 129 | """ 130 | n = L.shape[-1] 131 | invL = torch.zeros_like(L) 132 | for j in range(0, n): 133 | invL[..., j, j] = 1.0 / L[..., j, j] 134 | for i in range(j + 1, n): 135 | S = 0.0 136 | for k in range(0, i + 1): 137 | S = S - L[..., i, k] * invL[..., k, j].clone() 138 | invL[..., i, j] = S / L[..., i, i] 139 | return invL 140 | 141 | 
142 | def log_det_by_cholesky_test(): 143 | """ 144 | test for function log_det_by_cholesky() 145 | """ 146 | a = torch.randn(1, 4, 4) 147 | a = torch.matmul(a, a.transpose(2, 1)) 148 | print(a) 149 | res_1 = torch.logdet(torch.squeeze(a)) 150 | res_2 = log_det_by_cholesky(a) 151 | print(res_1, res_2) 152 | 153 | 
154 | def batch_inv_test(): 155 | """ 156 | test for function batch_cholesky_inverse() 157 | """ 158 | a = torch.randn(1, 1, 4, 4) 159 | a = torch.matmul(a, a.transpose(-2, -1)) 160 | print(a) 161 | res_1 = torch.inverse(a) 162 | res_2 = batch_cholesky_inverse(a) 163 | print(res_1, '\n', res_2) 164 | 165 | 
166 | def mean_var_test(): 167 | x = torch.randn(3, 4) 168 | y = torch.randn(3, 4) 169 | 170 | x_mean = x.mean(dim=1, keepdim=True) 171 | x_sum = x.sum(dim=1, keepdim=True) / 2.0 172 | y_mean = y.mean(dim=1, keepdim=True) 173 | y_sum = y.sum(dim=1, keepdim=True) / 2.0 174 | 175 | x_var_1 = torch.matmul(x - x_mean, (x - x_mean).t()) 176 | x_var_2 = torch.matmul(x, x.t()) - torch.matmul(x_sum,
x_sum.t()) 177 | xy_cov = torch.matmul(x - x_mean, (y - y_mean).t()) 178 | xy_cov_1 = torch.matmul(x, y.t()) - x_sum.matmul(y_sum.t()) 179 | 180 | print(x_var_1) 181 | print(x_var_2) 182 | 183 | print(xy_cov, '\n', xy_cov_1) 184 | 185 | 
186 | if __name__ == '__main__': 187 | batch_inv_test() 188 | 
-------------------------------------------------------------------------------- /dataloaders/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | """ 4 | dataloader for Cityscapes dataset 5 | """ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import numpy as np 13 | from PIL import Image 14 | from torch.utils import data 15 | from torchvision import transforms 16 | from RMI.dataloaders import custom_transforms as tr 17 | 18 | __all__ = ['CityscapesSegmentation'] 19 | 20 | 
21 | # Cityscapes dataset statistics 22 | _City_R_MEAN = 73 23 | _City_G_MEAN = 83 24 | _City_B_MEAN = 72 25 | 26 | _City_R_STD = 47.67 27 | _City_G_STD = 48.49 28 | _City_B_STD = 47.74 29 | 30 | 
31 | class CityscapesSegmentation(data.Dataset): 32 | NUM_CLASSES = 19 33 | 34 | def __init__(self, 35 | data_dir, 36 | crop_size=769, 37 | split="train", 38 | min_scale=0.75, 39 | max_scale=1.25, 40 | step_size=0.0): 41 | """ 42 | Only the gtFine annotations are supported. 43 | Args: 44 | data_dir: path to the Cityscapes dataset directory. 45 | crop_size: the crop size. 46 | split: one of ["train", "val", "test"]. 47 | """ 48 | super().__init__() 49 | # dataset dir 50 | self.data_dir = data_dir 51 | self.image_dir = os.path.join(self.data_dir, 'leftImg8bit') 52 | self.label_dir = os.path.join(self.data_dir, 'gtFine') 53 | self.split = split 54 | 55 | assert self.split in ['train', 'val', 'test'] 56 | if self.split == 'train': 57 | self.data_list_file = os.path.join(self.image_dir, 'train_images.txt') 58 | elif self.split == 'val': 59 | self.data_list_file = os.path.join(self.image_dir, 'val_images.txt') 60 | elif self.split == 'test': 61 | self.data_list_file = os.path.join(self.image_dir, 'test_images.txt') 62 | 
63 | # crop size and scales 64 | self.crop_size = crop_size 65 | self.min_scale = min_scale 66 | self.max_scale = max_scale 67 | self.step_size = step_size 68 | 
69 | # dataset info 70 | self.mean = (_City_R_MEAN, _City_G_MEAN, _City_B_MEAN) 71 | self.std = (_City_R_STD, _City_G_STD, _City_B_STD) 72 | self.ignore_label = 255 73 | 
74 | # We assume that the label is already converted to train ids.
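# (For reference -- a hypothetical offline conversion from raw label ids to train ids
#  could use a lookup table over the class lists commented out below:
#      lut = np.full(256, 255, dtype=np.uint8)
#      for train_id, class_id in enumerate(valid_classes):
#          lut[class_id] = train_id
#      train_ids = lut[np.array(label_img)]
#  This sketch is not part of the original code.)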
75 | #self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 76 | #self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 77 | #self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', 78 | # 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', 79 | # 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 80 | # 'motorcycle', 'bicycle'] 81 | 
82 | with open(self.data_list_file, "r") as f: 83 | lines = f.read().splitlines() 84 | lines = [line.strip().split(' ')[0] for line in lines] 85 | 86 | # extract the id_now 87 | #image_ids = [id_now.strip() for id_now in lines] 88 | self.image_ids = [id_now.split('/')[-1] for id_now in lines] 89 | self.image_ids = [id_now.replace('_leftImg8bit.png', '') for id_now in self.image_ids] 90 | 
91 | # the file list 92 | image_base_dir = os.path.join(self.image_dir, self.split) 93 | label_base_dir = os.path.join(self.label_dir, self.split) 94 | self.image_lists = [os.path.join(image_base_dir, filename.split('_')[0], filename + '_leftImg8bit.png') 95 | for filename in self.image_ids] 96 | self.label_lists = [os.path.join(label_base_dir, filename.split('_')[0], filename + '_gtFine_trainIds.png') 97 | for filename in self.image_ids] 98 | 99 | assert (len(self.image_lists) == len(self.label_lists)) 100 | 101 | # print the dataset info 102 | print('Number of images in {}: {:d}'.format(split, len(self.image_lists))) 103 | 
104 | def __len__(self): 105 | """len() method""" 106 | return len(self.image_lists) 107 | 
108 | def __getitem__(self, index): 109 | """how to get the data""" 110 | _image, _label = self._make_img_gt_point_pair(index) 111 | sample = {'image': _image, 'label': _label} 112 | 113 | if 'train' in self.split: 114 | return self.transform_train(sample) 115 | elif 'val' in self.split or 'test' in self.split: 116 | return self.transform_val(sample) 117 | else: 118 | raise NotImplementedError 119 | 
120 | def _make_img_gt_point_pair(self, index): 121 | """open the image and the ground truth""" 122 | _image = Image.open(self.image_lists[index]).convert('RGB') 123 | _label = Image.open(self.label_lists[index]) 124 | return _image, _label 125 | 
126 | def transform_train(self, sample): 127 | composed_transforms = transforms.Compose([ 128 | tr.RandomRescale(self.min_scale, self.max_scale, self.step_size), 129 | tr.RandomPadOrCrop(crop_height=self.crop_size, crop_width=self.crop_size, 130 | ignore_label=self.ignore_label, mean=self.mean), 131 | tr.RandomHorizontalFlip(), 132 | tr.Normalize(mean=self.mean, std=self.std), 133 | tr.ToTensor()]) 134 | 135 | return composed_transforms(sample) 136 | 
137 | def transform_val(self, sample): 138 | """transform for validation""" 139 | composed_transforms = transforms.Compose([ 140 | tr.Normalize(mean=self.mean, std=self.std), 141 | tr.ToTensor()]) 142 | 143 | return composed_transforms(sample) 144 | 
145 | def __str__(self): 146 | return 'Cityscapes(split=' + str(self.split) + ')' 147 | 148 | 
149 | if __name__ == '__main__': 150 | # data dir 151 | data_dir = os.path.join("/home/zhaoshuai/dataset/Cityscapes") 152 | print(data_dir) 153 | dataset = CityscapesSegmentation(data_dir=data_dir) 154 | #print(dataset.image_lists) 155 | image_mean = np.array([0.0, 0.0, 0.0]) 156 | cov_sum = np.array([0.0, 0.0, 0.0]) 157 | pixel_nums = 0.0 158 | # mean 159 | for filename in dataset.image_lists: 160 | image = Image.open(filename).convert('RGB') 161 | image = np.array(image).astype(np.float32) 162 | pixel_nums += image.shape[0] *
image.shape[1] 163 | image_mean += np.sum(image, axis=(0, 1)) 164 | image_mean = image_mean / pixel_nums 165 | print(image_mean) 166 | # covariance 167 | for filename in dataset.image_lists: 168 | image = Image.open(filename).convert('RGB') 169 | image = np.array(image).astype(np.float32) 170 | cov_sum += np.sum(np.square(image - image_mean), axis=(0, 1)) 171 | image_cov = np.sqrt(cov_sum / (pixel_nums - 1)) 172 | print(image_cov) 173 | 
-------------------------------------------------------------------------------- /model/sync_bn/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 
11 | import functools 12 | 13 | import torch 14 | from torch.nn.parallel.data_parallel import DataParallel 15 | from torch.autograd import Variable, Function 16 | import torch.cuda.comm as comm 17 | from torch.nn.parallel._functions import Broadcast 18 | 19 | # from .parallel_apply import parallel_apply 20 | 21 | torch_ver = torch.__version__[:3] 22 | 23 | __all__ = ['allreduce', 'Reduce', 'DataParallelModel', 'patch_replication_callback'] 24 | 25 | 
26 | def allreduce(*inputs): 27 | """Cross-GPU all-reduce autograd operation for calculating the mean and 28 | variance in SyncBN. 29 | """ 30 | return AllReduce.apply(*inputs) 31 | 
32 | class AllReduce(Function): 33 | @staticmethod 34 | def forward(ctx, num_inputs, *inputs): 35 | ctx.num_inputs = num_inputs 36 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 37 | inputs = [inputs[i:i + num_inputs] 38 | for i in range(0, len(inputs), num_inputs)] 39 | # sort before reduce sum 40 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 41 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 42 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 43 | return tuple([t for tensors in outputs for t in tensors]) 44 | 
45 | @staticmethod 46 | def backward(ctx, *inputs): 47 | inputs = [i.data for i in inputs] 48 | inputs = [inputs[i:i + ctx.num_inputs] 49 | for i in range(0, len(inputs), ctx.num_inputs)] 50 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 51 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 52 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 53 | 54 | 
55 | class Reduce(Function): 56 | @staticmethod 57 | def forward(ctx, *inputs): 58 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 59 | inputs = sorted(inputs, key=lambda i: i.get_device()) 60 | return comm.reduce_add(inputs) 61 | 
62 | @staticmethod 63 | def backward(ctx, gradOutput): 64 | return Broadcast.apply(ctx.target_gpus, gradOutput) 65 | 66 | 
67 | class CallbackContext(object): 68 | pass 69 | 70 | 
71 | def execute_replication_callbacks(modules): 72 | """ 73 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 74 | 75 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 76 | 77 | Note that, as all modules are isomorphic, we assign each sub-module with a context 78 | (shared among multiple copies of this module on different devices). 79 | Through this context, different copies can share some information.
80 | 81 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 82 | of any slave copies. 83 | """ 84 | master_copy = modules[0] 85 | nr_modules = len(list(master_copy.modules())) 86 | ctxs = [CallbackContext() for _ in range(nr_modules)] 87 | 
88 | for i, module in enumerate(modules): 89 | for j, m in enumerate(module.modules()): 90 | if hasattr(m, '__data_parallel_replicate__'): 91 | m.__data_parallel_replicate__(ctxs[j], i) 92 | 93 | 
94 | class DataParallelModel(DataParallel): 95 | """Implements data parallelism at the module level. 96 | 97 | This container parallelizes the application of the given module by 98 | splitting the input across the specified devices by chunking in the 99 | batch dimension. 100 | In the forward pass, the module is replicated on each device, 101 | and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module. 102 | Note that the outputs are not gathered; please use the compatible 103 | :class:`encoding.parallel.DataParallelCriterion`. 104 | 
105 | The batch size should be larger than the number of GPUs used. It should 106 | also be an integer multiple of the number of GPUs so that each chunk is 107 | the same size (so that each GPU processes the same number of samples). 108 | 
109 | Args: 110 | module: module to be parallelized 111 | device_ids: CUDA devices (default: all devices) 112 | 
113 | Reference: 114 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 115 | Amit Agrawal. "Context Encoding for Semantic Segmentation." 116 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 117 | 
118 | Example:: 119 | 120 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 121 | >>> y = net(x) 122 | """ 123 | def gather(self, outputs, output_device): 124 | return outputs 125 | 
126 | def replicate(self, module, device_ids): 127 | modules = super(DataParallelModel, self).replicate(module, device_ids) 128 | execute_replication_callbacks(modules) 129 | return modules 130 | 
131 | # def parallel_apply(self, replicas, inputs, kwargs): 132 | # return parallel_apply(replicas, inputs, kwargs, 133 | # self.device_ids[:len(replicas)]) 134 | 135 | 
136 | def patch_replication_callback(data_parallel): 137 | """ 138 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 139 | Useful when you have a customized `DataParallel` implementation.
140 | 141 | Examples: 142 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 143 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 144 | > patch_replication_callback(sync_bn) 145 | # this is equivalent to 146 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 147 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 148 | """ 149 | 150 | assert isinstance(data_parallel, DataParallel) 151 | 152 | old_replicate = data_parallel.replicate 153 | 
154 | @functools.wraps(old_replicate) 155 | def new_replicate(module, device_ids): 156 | modules = old_replicate(module, device_ids) 157 | execute_replication_callbacks(modules) 158 | return modules 159 | 160 | data_parallel.replicate = new_replicate 161 | 
-------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # python PATH 3 | export PYTHONPATH="${PYTHONPATH}:${HOME}/github" 4 | 
5 | # hyperparameters 6 | echo -n "input the gpus (separated by commas (,) ): " 7 | read gpus 8 | export CUDA_VISIBLE_DEVICES=${gpus} 9 | echo "using gpus ${gpus}" 10 | # replace comma(,) with empty 11 | #gpus=${gpus//,/} 12 | # the number of characters 13 | #num_gpus=${#gpus} 14 | #echo "the number of gpus is ${num_gpus}" 15 | 
16 | # choose the base model 17 | echo "" 18 | echo "0 -- deeplabv3" 19 | echo "1 -- deeplabv3+" 20 | echo "2 -- pspnet" 21 | echo -n "choose the segmentation model: " 22 | read model_choose 23 | case ${model_choose} in 24 | 0 ) 25 | base_model="deeplabv3" 26 | ;; 27 | 1 ) 28 | base_model="deeplabv3+" 29 | ;; 30 | 2 ) 31 | base_model="pspnet" 32 | ;; 33 | * ) 34 | echo "The choice of the segmentation model is illegal!" 35 | exit 1 36 | ;; 37 | esac 38 | 
39 | # choose the backbone 40 | echo "" 41 | echo "0 -- resnet_v1_50" 42 | echo "1 -- resnet_v1_101" 43 | echo "2 -- resnet_v1_152" 44 | echo -n "choose the backbone network: " 45 | read base_network 46 | 
47 | case ${base_network} in 48 | 0 ) 49 | backbone="resnet50";; 50 | 1 ) 51 | backbone="resnet101";; 52 | 2 ) 53 | backbone="resnet152";; 54 | * ) 55 | echo "The choice of the base network is illegal!" 56 | exit 1 57 | ;; 58 | esac 59 | echo "The backbone is ${backbone}" 60 | echo "The base model is ${base_model}" 61 | 62 | 
63 | # choose the loss 64 | echo "" 65 | echo "0 -- softmax cross entropy loss." 66 | echo "1 -- sigmoid binary cross entropy loss." 67 | echo "2 -- bce and RMI loss." 68 | echo "3 -- Affinity field loss." 69 | echo "5 -- Pyramid loss." 70 | echo -n "input the loss type of the first stage: " 71 | read loss_type 72 | 
73 | # choose the dataset 74 | echo "" 75 | echo "0 -- PASCAL VOC2012 dataset" 76 | echo "1 -- Cityscapes" 77 | echo "2 -- CamVid" 78 | echo -n "input the dataset: " 79 | read dataset 80 | 
81 | # choose the batch size 82 | echo "" 83 | echo -n "input the batch_size (4, 8, 12 or 16): " 84 | read batch_size 85 | 
86 | if [ ${dataset} = 0 ] 87 | then 88 | ##################################################### 89 | # SET YOUR DATA DIR HERE 90 | ##################################################### 91 | data_dir="${HOME}/dataset/VOCdevkit/VOC2012" 92 | dataset=pascal 93 | # !!! train epochs change with batch size !!!
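# (sanity check, assuming the standard 10,582-image trainaug list:
#  30000 iterations * batch 16 / 10582 images ~ 46 epochs; with batch 12 ~ 34 epochs)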
94 | crop_size=513 95 | if [ ${batch_size} = 16 ] 96 | then 97 | # first 30K on PASCAL VOC 98 | train_epochs_1=46 99 | eval_interval=2 100 | train_split=trainaug 101 | elif [ ${batch_size} = 12 ] 102 | then 103 | # first 30K on PASCAL VOC 104 | train_epochs_1=34 105 | eval_interval=2 106 | train_split=trainaug 107 | else 108 | train_epochs_1=23 109 | eval_interval=2 110 | train_split=trainaug 111 | fi 112 | elif [ ${dataset} = 1 ] 113 | then 114 | data_dir="${HOME}/dataset/Cityscapes/" 115 | dataset=cityscapes 116 | # 90K on Cityscapes 117 | if [ ${batch_size} = 8 ] 118 | then 119 | crop_size=769 120 | train_split=train 121 | train_epochs_1=160 122 | eval_interval=10 123 | elif [ ${batch_size} = 4 ] 124 | then 125 | crop_size=769 126 | train_split=train 127 | train_epochs_1=160 128 | eval_interval=10 129 | fi 130 | 
131 | elif [ ${dataset} = 2 ] 132 | then 133 | data_dir="${HOME}/dataset/CamVid/" 134 | dataset=camvid 135 | # training schedule on CamVid 136 | if [ ${batch_size} = 16 ] 137 | then 138 | crop_size=481 139 | train_split=trainval 140 | train_epochs_1=200 141 | eval_interval=10 142 | elif [ ${batch_size} = 4 ] 143 | then 144 | crop_size=479 145 | train_split=trainval 146 | train_epochs_1=200 147 | eval_interval=10 148 | fi 149 | else 150 | echo "The choice of the dataset is illegal!" 151 | exit 1 152 | fi 153 | echo "The data dir is ${data_dir}, the batch size is ${batch_size}." 154 | 
155 | # learning rate 156 | lr_1=0.007 157 | lr_multiplier=10.0 158 | 159 | # slow start 160 | slow_start_steps=1500 161 | slow_start_lr=0.0001 162 | 163 | workers=8 164 | accumulation_steps=1 165 | 
166 | # parameters of rmi 167 | rmi_pool_way=1 168 | rmi_pool_size=4 169 | rmi_pool_stride=4 170 | #rmi_pool_size=2 171 | #rmi_pool_stride=2 172 | rmi_radius=3 173 | loss_weight_lambda=0.5 174 | 
175 | ##################################################### 176 | # SET YOUR MODEL DIR HERE 177 | ##################################################### 178 | pre_dir="rmi_model" 179 | # set the work dir 180 | work_dir="${HOME}/github/RMI" 181 | 
182 | ##################################################### 183 | # SET YOUR RESUME CHECKPOINT HERE 184 | ##################################################### 185 | resume=None 186 | 
187 | # create PID 188 | case ${loss_type} in 189 | 0 ) 190 | SPID="${pre_dir}/CE_${dataset}_pb${crop_size}-${batch_size}_net${model_choose}-${base_network}_n${num}" 191 | ;; 192 | 1 ) 193 | SPID="${pre_dir}/bce_${dataset}_pb${crop_size}-${batch_size}" 194 | SPID="${SPID}_net${model_choose}-${base_network}_n${num}" 195 | ;; 196 | 2 ) 197 | SPID="${pre_dir}/rmi_re_${dataset}_r${rmi_radius}_pw${rmi_pool_way}_st${rmi_pool_stride}_si${rmi_pool_size}" 198 | SPID="${SPID}_bp${crop_size}-${batch_size}" 199 | SPID="${SPID}_net${model_choose}-${base_network}-${loss_weight_lambda}_n${num}" 200 | ;; 201 | 3 ) 202 | SPID="${pre_dir}/affinity_${dataset}_bp${crop_size}-${batch_size}_net${model_choose}-${base_network}_n${num}" 203 | ;; 204 | 5 ) 205 | SPID="${pre_dir}/pyramid_${dataset}_pb${crop_size}-${batch_size}_net${model_choose}-${base_network}_n${num}" 206 | ;; 207 | esac 208 | 209 | 
210 | model_dir=${HOME}/${SPID} 211 | proc_name=${SPID} 212 | 213 | 
214 | # detect the directory 215 | if [ -d ${model_dir} ] 216 | then 217 | echo "save model into ${model_dir}" 218 | else 219 | mkdir -p ${model_dir} 220 | echo "make the directory ${model_dir}" 221 | fi 222 | 223 | 
224 | # train the model 225 | python ${work_dir}/train.py --backbone ${backbone} \ 226 | --seg_model ${base_model} \ 227 | --slow_start_steps
${slow_start_steps} \ 228 | --slow_start_lr ${slow_start_lr} \ 229 | --init_lr ${lr_1} \ 230 | --lr_multiplier ${lr_multiplier} \ 231 | --model_dir ${model_dir} \ 232 | --workers ${workers} \ 233 | --epochs ${train_epochs_1} \ 234 | --batch_size ${batch_size} \ 235 | --crop_size ${crop_size} \ 236 | --gpu_ids ${gpus} \ 237 | --checkname deeplab-resnet \ 238 | --dataset ${dataset} \ 239 | --data_dir ${data_dir} \ 240 | --train_split ${train_split} \ 241 | --proc_name ${proc_name} \ 242 | --accumulation_steps ${accumulation_steps} \ 243 | --eval_interval ${eval_interval} \ 244 | --loss_type ${loss_type} \ 245 | --rmi_pool_way ${rmi_pool_way} \ 246 | --rmi_pool_size ${rmi_pool_size} \ 247 | --rmi_radius ${rmi_radius} \ 248 | --rmi_pool_stride ${rmi_pool_stride} \ 249 | --resume ${resume} \ 250 | --loss_weight_lambda ${loss_weight_lambda} 251 | 252 | echo "Training Finished!!!" 253 | 
-------------------------------------------------------------------------------- /model/sync_bn/src/gpu/common.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 
4 | static const unsigned WARP_SIZE = 32; 5 | 6 | // The maximum number of threads in a block 7 | static const unsigned MAX_BLOCK_SIZE = 512U; 8 | 
9 | template <typename In, typename Out> 10 | struct ScalarConvert { 11 | static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; } 12 | }; 13 | 
14 | // Number of threads in a block given an input size up to MAX_BLOCK_SIZE 15 | static int getNumThreads(int nElem) { 16 | int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; 17 | for (int i = 0; i != 5; ++i) { 18 | if (nElem <= threadSizes[i]) { 19 | return threadSizes[i]; 20 | } 21 | } 22 | return MAX_BLOCK_SIZE; 23 | } 24 | 
25 | // Returns the index of the most significant 1 bit in `val`.
26 | __device__ __forceinline__ int getMSB(int val) { 27 | return 31 - __clz(val); 28 | } 29 | 
30 | template <typename T> 31 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) 32 | { 33 | #if CUDA_VERSION >= 9000 34 | return __shfl_xor_sync(mask, value, laneMask, width); 35 | #else 36 | return __shfl_xor(value, laneMask, width); 37 | #endif 38 | } 39 | 
40 | // Sum across all threads within a warp 41 | template <typename T> 42 | static __device__ __forceinline__ T warpSum(T val) { 43 | #if __CUDA_ARCH__ >= 300 44 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) { 45 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); 46 | } 47 | #else 48 | __shared__ T values[MAX_BLOCK_SIZE]; 49 | values[threadIdx.x] = val; 50 | __threadfence_block(); 51 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE; 52 | for (int i = 1; i < WARP_SIZE; i++) { 53 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)]; 54 | } 55 | #endif 56 | return val; 57 | } 58 | 
59 | template <typename DType, typename Acctype> 60 | struct Float2 { 61 | Acctype v1, v2; 62 | __device__ Float2() {} 63 | __device__ Float2(DType v1, DType v2) : v1(ScalarConvert<DType, Acctype>::to(v1)), v2(ScalarConvert<DType, Acctype>::to(v2)) {} 64 | __device__ Float2(DType v) : v1(ScalarConvert<DType, Acctype>::to(v)), v2(ScalarConvert<DType, Acctype>::to(v)) {} 65 | __device__ Float2(int v) : v1(ScalarConvert<int, Acctype>::to(v)), v2(ScalarConvert<int, Acctype>::to(v)) {} 66 | __device__ Float2& operator+=(const Float2& a) { 67 | v1 += a.v1; 68 | v2 += a.v2; 69 | return *this; 70 | } 71 | }; 72 | 
73 | template <typename DType, typename Acctype> 74 | static __device__ __forceinline__ Float2<DType, Acctype> warpSum(Float2<DType, Acctype> value) { 75 | value.v1 = warpSum(value.v1); 76 | value.v2 = warpSum(value.v2); 77 | return value; 78 | } 79 | 
80 | template <typename T, typename Op> 81 | __device__ T reduceD( 82 | Op op, int b, int i, int k, int D) { 83 | T sum = 0; 84 | for (int x = threadIdx.x; x < D; x += blockDim.x) { 85 | sum += op(b,i,k,x); 86 | } 87 | // sum over NumThreads within a warp 88 | sum = warpSum(sum); 89 | 
90 | // 'transpose', and reduce within warp again 91 | __shared__ T shared[32]; 92 | 93 | __syncthreads(); 94 | if (threadIdx.x % WARP_SIZE == 0) { 95 | if (threadIdx.x / WARP_SIZE < 32) { 96 | shared[threadIdx.x / WARP_SIZE] = sum; 97 | } 98 | } 99 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 100 | // zero out the other entries in shared 101 | shared[threadIdx.x] = (T) 0; 102 | } 103 | __syncthreads(); 104 | if (threadIdx.x / WARP_SIZE == 0) { 105 | sum = warpSum(shared[threadIdx.x]); 106 | if (threadIdx.x == 0) { 107 | shared[0] = sum; 108 | } 109 | } 110 | __syncthreads(); 111 | 112 | // Everyone picks it up, should be broadcast into the whole gradInput 113 | return shared[0]; 114 | } 115 | 
116 | template <typename T, typename Op> 117 | __device__ T reduceN( 118 | Op op, int b, int k, int d, int N) { 119 | T sum = 0; 120 | for (int x = threadIdx.x; x < N; x += blockDim.x) { 121 | sum += op(b,x,k,d); 122 | } 123 | // sum over NumThreads within a warp 124 | sum = warpSum(sum); 125 | 
126 | // 'transpose', and reduce within warp again 127 | __shared__ T shared[32]; 128 | 129 | __syncthreads(); 130 | if (threadIdx.x % WARP_SIZE == 0) { 131 | if (threadIdx.x / WARP_SIZE < 32) { 132 | shared[threadIdx.x / WARP_SIZE] = sum; 133 | } 134 | } 135 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) { 136 | // zero out the other entries in shared 137 | shared[threadIdx.x] = (T) 0; 138 | } 139 | __syncthreads(); 140 | if (threadIdx.x / WARP_SIZE == 0) { 141 | sum = warpSum(shared[threadIdx.x]); 142 | if (threadIdx.x == 0) { 143 | shared[0] = sum; 144 | } 145 | } 146 | __syncthreads(); 147 |
148 |   // Everyone picks it up, should be broadcast into the whole gradInput
149 |   return shared[0];
150 | }
151 | 
152 | template <typename T, typename Op>
153 | __device__ T reduceK(
154 |     Op op, int b, int i, int d, int K) {
155 |   T sum = 0;
156 |   for (int x = threadIdx.x; x < K; x += blockDim.x) {
157 |     sum += op(b,i,x,d);
158 |   }
159 |   // sum over NumThreads within a warp
160 |   sum = warpSum(sum);
161 | 
162 |   // 'transpose', and reduce within warp again
163 |   __shared__ T shared[32];
164 | 
165 |   __syncthreads();
166 |   if (threadIdx.x % WARP_SIZE == 0) {
167 |     if (threadIdx.x / WARP_SIZE < 32) {
168 |       shared[threadIdx.x / WARP_SIZE] = sum;
169 |     }
170 |   }
171 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
172 |     // zero out the other entries in shared
173 |     shared[threadIdx.x] = (T) 0;
174 |   }
175 |   __syncthreads();
176 |   if (threadIdx.x / WARP_SIZE == 0) {
177 |     sum = warpSum(shared[threadIdx.x]);
178 |     if (threadIdx.x == 0) {
179 |       shared[0] = sum;
180 |     }
181 |   }
182 |   __syncthreads();
183 | 
184 |   // Everyone picks it up, should be broadcast into the whole gradInput
185 |   return shared[0];
186 | }
187 | 
188 | template <typename T, typename Op>
189 | __device__ T reduceBN(
190 |     Op op,
191 |     int k, int d, int B, int N) {
192 |   T sum = 0;
193 |   for (int batch = 0; batch < B; ++batch) {
194 |     for (int x = threadIdx.x; x < N; x += blockDim.x) {
195 |       sum += op(batch,x,k,d);
196 |     }
197 |   }
198 |   // sum over NumThreads within a warp
199 |   sum = warpSum(sum);
200 |   // 'transpose', and reduce within warp again
201 |   __shared__ T shared[32];
202 | 
203 |   __syncthreads();
204 |   if (threadIdx.x % WARP_SIZE == 0) {
205 |     if (threadIdx.x / WARP_SIZE < 32) {
206 |       shared[threadIdx.x / WARP_SIZE] = sum;
207 |     }
208 |   }
209 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
210 |     // zero out the other entries in shared
211 |     shared[threadIdx.x] = (T) 0;
212 |   }
213 |   __syncthreads();
214 |   if (threadIdx.x / WARP_SIZE == 0) {
215 |     sum = warpSum(shared[threadIdx.x]);
216 |     if (threadIdx.x == 0) {
217 |       shared[0] = sum;
218 |     }
219 |   }
220 |   __syncthreads();
221 | 
222 |   // Everyone picks it up, should be broadcast into the whole gradInput
223 |   return shared[0];
224 | }
225 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Region Mutual Information Loss for Semantic Segmentation
2 | 
3 | ## Table of Contents
4 | 
5 | 
6 | * [Introduction](#Introduction)
7 | * [Features and TODO](#Features-and-TODO)
8 | * [Installation](#Installation)
9 | * [Training](#Training)
10 | * [Evaluation and Inference](#Evaluation-and-Inference)
11 | * [Experiments](#Experiments)
12 | * [Citations](#Citations)
13 | * [Acknowledgements](#Acknowledgements)
14 | 
15 | 
16 | ## Introduction
17 | 
18 | This is the code for the NeurIPS 2019 paper [Region Mutual Information Loss for Semantic Segmentation](https://arxiv.org/abs/1910.12037).
19 | 
20 | This paper proposes a region mutual information (RMI) loss to model the dependencies among pixels. RMI represents each pixel by the pixel itself together with its neighbors, so every pixel in an image yields a multi-dimensional point that encodes the relationships within its local region, and the image is cast into a multi-dimensional distribution of these points. The prediction and the ground truth can thus achieve high-order consistency by maximizing the mutual information (MI) between their multi-dimensional distributions.
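To make this construction concrete, the sketch below shows one way to turn an `[N, C, H, W]` map into such multi-dimensional points: every spatial location is stacked with its `radius x radius` neighborhood. This is only a minimal illustration of the idea, not the repository's code; the actual implementation is `map_get_pairs` in `losses/rmi/rmi_utils.py`, and the function name here is ours.

```
import torch

def neighborhood_points(x, radius=3):
    """Illustrative only. x: [N, C, H, W] probability or one-hot label map.

    Returns [N, C, radius * radius, (H - radius + 1) * (W - radius + 1)]:
    each valid spatial location becomes one (radius * radius)-dimensional
    point built from its neighborhood.
    """
    n, c, h, w = x.shape
    points = []
    for i in range(radius):
        for j in range(radius):
            # the (i, j)-th neighborhood element of every location, in one slice
            points.append(x[:, :, i:h - radius + 1 + i, j:w - radius + 1 + j])
    return torch.stack(points, dim=2).reshape(n, c, radius * radius, -1)
```

Applying this to both the predicted probabilities and the one-hot ground truth yields the two point sets whose mutual information the RMI loss maximizes.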
21 | 
22 | ![img_intro](img/intro.png)
23 | 
24 | ## Features and TODO
25 | 
26 | - [x] Support different segmentation models, i.e., DeepLabv3, DeepLabv3+, PSPNet
27 | - [x] Multi-GPU training
28 | - [x] Multi-GPU Synchronized BatchNorm
29 | - [ ] Support different backbones, e.g., MobileNet, Xception
30 | - [ ] Model pretrained on MS-COCO
31 | - [ ] Distributed training
32 | 
33 | We are open to pull requests.
34 | 
35 | ## Installation
36 | 
37 | ### Install dependencies
38 | 
39 | Please install PyTorch 1.1.0 and Python 3.6.5.
40 | We highly recommend using our pre-built PyTorch docker image - [zhaosssss/torch_lab](https://hub.docker.com/r/zhaosssss/torch_lab).
41 | ```
42 | docker pull zhaosssss/torch_lab:1.1.0
43 | ```
44 | If you have not installed docker, see https://docs.docker.com/.
45 | 
46 | After you install docker and pull our image, you can `cd` to the `script` directory and run
47 | ```
48 | ./docker.sh
49 | ```
50 | to create a running docker container.
51 | 
52 | If you do not want to use docker, try
53 | ```
54 | pip install -r requirements.txt
55 | ```
56 | However, this is not recommended.
57 | 
58 | 
59 | ### Prepare data
60 | 
61 | Generally, directories are organized as follows:
62 | ```
63 | |
64 | |--dataset (save the dataset)
65 | |--models (save the output checkpoints)
66 | |--github (save the code)
67 | |--|
68 | |--|--RMI (the RMI code repository)
69 | |--|--|--crf
70 | |--|--|--dataloaders
71 | |--|--|--losses
72 | ...
73 | ```
74 | 
75 | 
76 | - Download the [PASCAL VOC training/validation data](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar)
77 | (2GB tar file) and the [augmented segmentation data](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0), extract them, and put them in the `dataset` directory.
78 | 
79 | - `cd` to the `github` directory and clone the RMI repo.
80 | 
81 | As for the CamVid dataset, you can download it from [SegNet-Tutorial](https://github.com/alexgkendall/SegNet-Tutorial), which provides a processed version of the [original CamVid dataset](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/).
82 | 
83 | ## Training
84 | 
85 | See `script/train.sh` for detailed information.
86 | Before starting training, you should specify some variables in `script/train.sh`.
87 | 
88 | - `pre_dir`, where you save your output checkpoints. If you organize the directories as we suggest, it should be `pre_dir=models`.
89 | 
90 | - `data_dir`, where you save your dataset. Besides, you should put the lists of the images in the dataset in a certain directory; check `dataloaders/datasets/pascal.py` to see how we organize the input pipeline.
91 | 
92 | You can find more information about the arguments of the code in `parser_params.py`.
93 | ```
94 | python parser_params.py --help
95 | 
96 | usage: parser_params.py [-h] [--resume RESUME] [--checkname CHECKNAME]
97 |                         [--save_ckpt_steps SAVE_CKPT_STEPS]
98 |                         [--max_ckpt_nums MAX_CKPT_NUMS]
99 |                         [--model_dir MODEL_DIR] [--output_dir OUTPUT_DIR]
100 |                         [--seg_model {deeplabv3,deeplabv3+,pspnet}]
101 |                         [--backbone {resnet50,resnet101,resnet152,resnet50_beta,resnet101_beta,resnet152_beta}]
102 |                         [--out_stride OUT_STRIDE] [--batch_size N]
103 |                         [--accumulation_steps N] [--test_batch_size N]
104 |                         [--dataset {pascal,coco,cityscapes,camvid}]
105 |                         [--train_split {train,trainaug,trainval,val,test}]
106 |                         [--data_dir DATA_DIR] [--use_sbd] [--workers N]
107 | ...
108 |                         [--rmi_pool_size RMI_POOL_SIZE]
109 |                         [--rmi_pool_stride RMI_POOL_STRIDE]
110 |                         [--rmi_radius RMI_RADIUS]
111 |                         [--crf_iter_steps CRF_ITER_STEPS]
112 |                         [--local_rank LOCAL_RANK] [--world_size WORLD_SIZE]
113 |                         [--dist_backend DIST_BACKEND]
114 |                         [--multiprocessing_distributed]
115 | ```
116 | 
117 | 
118 | After you set all the arguments properly, you can simply `cd` to `RMI/script` and run
119 | ```
120 | ./train.sh
121 | ```
122 | to start training.
123 | 
124 | * Monitor the training process through TensorBoard:
125 | 
126 | ```
127 | tensorboard --logdir=your_logdir --port=your_port
128 | ```
129 | 
130 | ![img_ten](img/tensorboard.png)
131 | 
132 | * GPU memory usage
133 | 
134 | Training a DeepLabv3 model with `output_stride=16`, `crop_size=513`, and `batch_size=16` needs 4 GTX 1080 GPUs (8 GB),
135 | 2 GTX TITAN X GPUs (12 GB), or 1 TITAN RTX GPU (24 GB).
136 | 
137 | 
138 | ## Evaluation and Inference
139 | 
140 | See `script/eval.sh` and `script/inference.sh` for detailed information.
141 | 
142 | You should also specify some variables in the scripts.
143 | 
144 | - `data_dir`, where you save your dataset.
145 | 
146 | - `resume`, where your checkpoints are located.
147 | 
148 | - `output_dir`, where the output data will be saved.
149 | 
150 | Then run
151 | ```
152 | ./eval.sh
153 | ```
154 | or
155 | ```
156 | ./inference.sh
157 | ```
158 | 
159 | 
160 | ## Experiments
161 | 
162 | ![img_res01](img/res_01.png)
163 | ![img_res02](img/res_02.png)
164 | 
165 | ![img_res03](img/res_03.png)
166 | 
167 | Some selected qualitative results on the PASCAL VOC 2012 val set.
168 | Segmentation results of DeepLabv3+&RMI have richer details than those of DeepLabv3+&CE, e.g., small bumps of the airplane wing, branches of plants, limbs of cows and sheep, and so on.
169 | 
170 | 
171 | ## Citations
172 | 
173 | If our paper and code are beneficial to your work, please cite:
174 | ```
175 | @inproceedings{2019_zhao_rmi,
176 |   author    = {Shuai Zhao and
177 |                Yang Wang and
178 |                Zheng Yang and
179 |                Deng Cai},
180 |   title     = {Region Mutual Information Loss for Semantic Segmentation},
181 |   booktitle = {NeurIPS},
182 |   year      = {2019},
183 | }
184 | ```
185 | 
186 | If other related work in our code or paper also helps you, please cite the corresponding papers.
187 | 
188 | ## Acknowledgements
189 | 
190 | 
191 | * [PyTorch-Encoding](https://github.com/zhanghang1989/PyTorch-Encoding)
192 | * [official-tensorflow-deeplab](https://github.com/tensorflow/models/tree/master/research/deeplab)
193 | * [Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch)
194 | * [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception)
195 | * [pydensecrf](https://github.com/lucasb-eyer/pydensecrf)
196 | * [Adaptive_Affinity_Fields](https://github.com/twke18/Adaptive_Affinity_Fields)
197 | * [rishizek-tensorflow-deeplab-v3-plus](https://github.com/rishizek/tensorflow-deeplab-v3-plus)
198 | * [SegNet-Tutorial](https://github.com/alexgkendall/SegNet-Tutorial)
199 | 
200 | 
201 | ![img_cad](img/zju_cad.jpg)
202 | 
--------------------------------------------------------------------------------
/model/backbone/resnet_v1.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | 
3 | """
4 | Reference:
5 |     https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
6 | """
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | 
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import torch.utils.model_zoo as model_zoo
15 | from RMI.utils import model_store
16 | 
17 | __all__ = ['ResNet', 'resnet50', 'resnet101', 'resnet152']
18 | 
19 | 
20 | zhanghang_dir = '~/.encoding/models'
21 | 
22 | model_urls = {
23 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
24 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
25 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
26 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
27 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
28 | }
29 | 
30 | 
31 | model_dirs = {
32 |     'resnet50': '~/.torch/models/resnet50-19c8e357.pth',
33 |     'resnet101': '~/.torch/models/resnet101-5d3b4d8f.pth',
34 |     #'resnet101': '/home/zhaoshuai/pretrained/resnet_v1_101_20160828/resnet.pth',
35 | }
36 | 
37 | # https://discuss.pytorch.org/t/whats-the-difference-between-nn-relu-and-nn-relu-inplace-true/948
38 | # inplace ReLU saves more memory.
39 | _IS_ReLU_INPLACE = True
40 | 
41 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
42 |     """3x3 convolution with padding"""
43 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
44 |                      padding=dilation, groups=groups, bias=False, dilation=dilation)
45 | 
46 | 
47 | def conv1x1(in_planes, out_planes, stride=1):
48 |     """1x1 convolution"""
49 |     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
50 | 
51 | 
52 | class Bottleneck(nn.Module):
53 |     """resnet v1 bottleneck block"""
54 |     expansion = 4
55 | 
56 |     def __init__(self, inplanes,
57 |                  planes,
58 |                  stride=1,
59 |                  downsample=None,
60 |                  groups=1,
61 |                  base_width=64,
62 |                  dilation=1,
63 |                  norm_layer=None,
64 |                  bn_mom=0.05):
65 |         super(Bottleneck, self).__init__()
66 |         if norm_layer is None:
67 |             norm_layer = nn.BatchNorm2d
68 |         width = int(planes * (base_width / 64.)) * groups
69 |         # Both self.conv2 and self.downsample layers downsample the input when stride != 1
70 |         self.conv1 = conv1x1(inplanes, width)
71 |         self.bn1 = norm_layer(width, momentum=bn_mom)
72 | 
73 |         self.conv2 = conv3x3(width, width, stride, groups, dilation)
74 |         self.bn2 = norm_layer(width, momentum=bn_mom)
75 | 
76 |         self.conv3 = conv1x1(width, planes * self.expansion)
77 |         self.bn3 = norm_layer(planes * self.expansion, momentum=bn_mom)
78 | 
79 |         self.relu = nn.ReLU(inplace=True)
80 |         self.downsample = downsample
81 |         self.stride = stride
82 | 
83 |     def forward(self, x):
84 |         identity = x
85 | 
86 |         out = self.conv1(x)
87 |         out = self.bn1(out)
88 |         out = self.relu(out)
89 | 
90 |         out = self.conv2(out)
91 |         out = self.bn2(out)
92 |         out = self.relu(out)
93 | 
94 |         out = self.conv3(out)
95 |         out = self.bn3(out)
96 | 
97 |         if self.downsample is not None:
98 |             identity = self.downsample(x)
99 | 
100 |         out += identity
101 |         out = self.relu(out)
102 | 
103 |         return out
104 | 
105 | 
106 | class ResNet(nn.Module):
107 | 
108 |     def __init__(self, block,
109 |                  layers,
110 |                  output_stride=16,
111 |                  zero_init_residual=True,
112 |                  groups=1,
113 |                  width_per_group=64,
114 |                  norm_layer=nn.BatchNorm2d,
115 |                  bn_mom=0.05,
116 |                  root_beta=True):
117 |         super(ResNet, self).__init__()
118 |         self._norm_layer = norm_layer
119 |         self.inplanes = 128 if root_beta else 64
120 |         self.dilation = 1
121 |         self.bn_mom = bn_mom
122 | 
123 |         # stride and dilations
124 |         assert output_stride in [8, 16]
125 |         self.strides = [1, 2, 2 if output_stride == 16 else 1, 1]
126 |         # slightly different from the official implementation
127 |         self.dilations = [1, 1, 1, 1]
128 | 
129 |         self.groups = groups
130 |         self.base_width = width_per_group
131 | 
132 |         # the network modules; use three 3x3 convolutions to replace the single 7x7 convolution
133 |         if root_beta:
134 |             self.conv1 = nn.Sequential(
135 |                 nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
136 |                 norm_layer(64, momentum=bn_mom),
137 |                 nn.ReLU(inplace=True),
138 |                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
139 |                 norm_layer(64, momentum=bn_mom),
140 |                 nn.ReLU(inplace=True),
141 |                 nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
142 |             )
143 |         else:
144 |             self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
145 | 
146 |         self.bn1 = norm_layer(self.inplanes, momentum=self.bn_mom)
147 |         self.relu = nn.ReLU(inplace=True)
148 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
149 |         # four stacked blocks
150 |         self.layer1 = self._make_layer(block, 64, layers[0], stride=self.strides[0], dilation=self.dilations[0])
151 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[1], dilation=self.dilations[1])
152 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[2], dilation=self.dilations[2])
153 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[3], dilation=self.dilations[3])
154 | 
155 |         for m in self.modules():
156 |             if isinstance(m, nn.Conv2d):
157 |                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158 |                 if m.bias is not None:
159 |                     m.bias.data.zero_()
160 |             elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, norm_layer)):
161 |                 nn.init.constant_(m.weight, 1)
162 |                 nn.init.constant_(m.bias, 0)
163 | 
164 |         # Zero-initialize the last BN in each residual branch,
165 |         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
166 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
167 |         if zero_init_residual:
168 |             for m in self.modules():
169 |                 if isinstance(m, Bottleneck):
170 |                     nn.init.constant_(m.bn3.weight, 0)
171 | 
172 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
173 |         """construct layers"""
174 |         norm_layer = self._norm_layer
175 |         downsample = None
176 | 
177 |         if stride != 1 or self.inplanes != planes * block.expansion:
178 |             downsample = nn.Sequential(
179 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
180 |                 norm_layer(planes * block.expansion, momentum=self.bn_mom),
181 |             )
182 | 
183 |         layers = []
184 |         # the dilation of the first layer
185 |         dilation_first = 1 if dilation in [1, 2] else 2
186 |         layers.append(block(self.inplanes, planes, stride,
187 |                             downsample,
188 |                             self.groups,
189 |                             self.base_width,
190 |                             dilation_first,
191 |                             norm_layer,
192 |                             bn_mom=self.bn_mom))
193 |         self.inplanes = planes * block.expansion
194 |         for _ in range(1, blocks):
195 |             layers.append(block(self.inplanes, planes,
196 |                                 groups=self.groups,
197 |                                 base_width=self.base_width,
198 |                                 dilation=dilation,
199 |                                 norm_layer=norm_layer,
200 |                                 bn_mom=self.bn_mom))
201 | 
202 |         return nn.Sequential(*layers)
203 | 
204 |     def forward(self, x):
205 |         end_points = {}
206 |         x = self.conv1(x)
207 |         x = self.bn1(x)
208 |         x = self.relu(x)
209 |         x = self.maxpool(x)
210 | 
211 |         x = self.layer1(x)
212 |         end_points['layer1'] = x
213 | 
214 |         x = self.layer2(x)
215 |         end_points['layer2'] = x
216 | 
217 |         x = self.layer3(x)
218 |         end_points['layer3'] = x
219 | 
220 |         x = self.layer4(x)
221 | 
222 |         return x, end_points
223 | 
224 | 
225 | def _resnet(arch, block, layers, output_stride=16, pretrained=True, norm_layer=None, bn_mom=0.05, root_beta=True):
226 |     model = ResNet(block, layers, output_stride=output_stride, norm_layer=norm_layer, bn_mom=bn_mom, root_beta=root_beta)
227 |     if pretrained:
228 |         if root_beta:
229 |             old_dict = torch.load(model_store.get_model_file(arch, root=zhanghang_dir))
230 |         else:
231 |             old_dict = model_zoo.load_url(model_urls[arch])
232 |             #old_dict = torch.load(model_dirs[arch])['state_dict']
233 |         #print(old_dict)
234 |         model_dict = model.state_dict()
235 |         old_dict = {k: v for k, v in old_dict.items() if (k in model_dict)}
236 |         model_dict.update(old_dict)
237 |         model.load_state_dict(model_dict)
238 |     return model
239 | 
240 | 
241 | def resnet50(output_stride=16, pretrained=True, norm_layer=None, bn_mom=0.05, root_beta=True):
242 |     """Constructs a ResNet-50 model.
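    (When ``root_beta`` is True, the single 7x7 stem convolution is replaced by three
    3x3 convolutions, and pretrained weights are loaded via ``model_store`` from the
    PyTorch-Encoding model zoo; see ``_resnet`` above.)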
243 |     Args:
244 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
245 |     """
246 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], output_stride, pretrained, norm_layer, bn_mom, root_beta)
247 | 
248 | 
249 | def resnet101(output_stride=16, pretrained=True, norm_layer=None, bn_mom=0.05, root_beta=True):
250 |     """Constructs a ResNet-101 model.
251 |     Args:
252 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
253 |     """
254 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], output_stride, pretrained, norm_layer, bn_mom, root_beta)
255 | 
256 | 
257 | def resnet152(output_stride=16, pretrained=True, norm_layer=None, bn_mom=0.05, root_beta=True):
258 |     """Constructs a ResNet-152 model.
259 |     Args:
260 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
261 |     """
262 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], output_stride, pretrained, norm_layer, bn_mom, root_beta)
263 | 
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | 
3 | """
4 | some custom transforms
5 | """
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 | 
10 | import torch
11 | #from torchvision import transforms
12 | 
13 | import random
14 | import numpy as np
15 | from PIL import Image, ImageOps, ImageFilter
16 | 
17 | 
18 | class RandomRescale(object):
19 |     """Rescale an image and its label within a target scale range.
20 |     PIL image version"""
21 |     def __init__(self, min_scale=0.5, max_scale=2.0, step_size=0.25):
22 |         """initialize
23 |         Args:
24 |             min_scale: Min target scale.
25 |             max_scale: Max target scale.
26 |         """
27 |         self.min_scale = min_scale
28 |         self.max_scale = max_scale
29 |         self.step_size = step_size
30 |         # discrete scales
31 |         if (max_scale - min_scale) > step_size and step_size > 0.05:
32 |             self.num_steps = int((max_scale - min_scale) / step_size + 1)
33 |             self.scale_steps = np.linspace(self.min_scale, self.max_scale, self.num_steps)
34 |         elif (max_scale - min_scale) > step_size and step_size < 0.05:
35 |             self.num_steps = 0  # the scale will be sampled continuously in __call__
36 |             self.scale_steps = np.array([min_scale])
37 |         else:
38 |             self.num_steps = 1
39 |             self.scale_steps = np.array([min_scale])
40 | 
41 |     def __call__(self, sample):
42 |         """call method"""
43 |         image, label = sample['image'], sample['label']
44 |         width, height = image.size
45 |         # random scale
46 |         if self.num_steps > 0:
47 |             index = random.randint(0, self.num_steps - 1)
48 |             scale_now = self.scale_steps[index]
49 |         else:
50 |             scale_now = random.uniform(self.min_scale, self.max_scale)
51 |         new_width, new_height = int(scale_now * width), int(scale_now * height)
52 |         # resize
53 |         #image = image.resize(self.size, Image.BILINEAR)
54 |         image = image.resize((new_width, new_height), Image.BICUBIC)
55 |         label = label.resize((new_width, new_height), Image.NEAREST)
56 | 
57 |         return {'image': image,
58 |                 'label': label}
59 | 
60 | 
61 | class RandomPadOrCrop(object):
62 |     """Crops and/or pads an image to a target width and height.
63 |     PIL image version
64 |     """
65 |     def __init__(self, crop_height, crop_width, ignore_label=255, mean=(125, 125, 125)):
66 |         """
67 |         Args:
68 |             crop_height: The new height.
69 |             crop_width: The new width.
70 |             ignore_label: Label class to be ignored.
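            mean: RGB fill value used when padding the image.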
71 |         """
72 |         self.crop_height = crop_height
73 |         self.crop_width = crop_width
74 |         self.ignore_label = ignore_label
75 |         self.mean = mean
76 | 
77 |     def __call__(self, sample):
78 |         """call method"""
79 |         image, label = sample['image'], sample['label']
80 |         width, height = image.size
81 |         # amount of padding (if any) needed to reach the crop size
82 |         pad_width = self.crop_width - width if width < self.crop_width else 0
83 |         pad_height = self.crop_height - height if height < self.crop_height else 0
84 |         # pad the image with constant
85 |         image = ImageOps.expand(image, border=(0, 0, pad_width, pad_height), fill=self.mean)
86 |         label = ImageOps.expand(label, border=(0, 0, pad_width, pad_height), fill=self.ignore_label)
87 |         # random crop image to crop_size
88 |         new_w, new_h = image.size
89 |         x1 = random.randint(0, new_w - self.crop_width)
90 |         y1 = random.randint(0, new_h - self.crop_height)
91 |         image = image.crop((x1, y1, x1 + self.crop_width, y1 + self.crop_height))
92 |         label = label.crop((x1, y1, x1 + self.crop_width, y1 + self.crop_height))
93 | 
94 |         return {'image': image,
95 |                 'label': label}
96 | 
97 | 
98 | class RandomHorizontalFlip(object):
99 |     """Randomly flip an image and label horizontally (left to right).
100 |     PIL image version"""
101 |     def __call__(self, sample):
102 |         """call method"""
103 |         image, label = sample['image'], sample['label']
104 |         if random.random() < 0.5:
105 |             image = image.transpose(Image.FLIP_LEFT_RIGHT)
106 |             label = label.transpose(Image.FLIP_LEFT_RIGHT)
107 | 
108 |         return {'image': image,
109 |                 'label': label}
110 | 
111 | 
112 | class Normalize(object):
113 |     """Normalize a tensor image with mean and standard deviation.
114 |     PIL image version.
115 |     Args:
116 |         mean (tuple): means for each channel.
117 |         std (tuple): standard deviations for each channel.
118 |     """
119 |     def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
120 |         self.mean = mean
121 |         self.std = std
122 | 
123 |     def __call__(self, sample):
124 |         """call method"""
125 |         image, label = sample['image'], sample['label']
126 |         image = np.array(image).astype(np.float32)
127 |         label = np.array(label).astype(np.float32)
128 |         #image /= 255.0
129 |         image -= self.mean
130 |         image /= self.std
131 | 
132 |         return {'image': image,
133 |                 'label': label}
134 | 
135 | 
136 | class Normalize_Image(object):
137 |     """Normalize a tensor image with mean and standard deviation.
138 |     PIL image version.
139 |     Args:
140 |         mean (tuple): means for each channel.
141 |         std (tuple): standard deviations for each channel.
142 |     """
143 |     def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
144 |         self.mean = mean
145 |         self.std = std
146 | 
147 |     def __call__(self, sample):
148 |         """call method"""
149 |         image = sample['image']
150 |         image = np.array(image).astype(np.float32)
151 |         #image /= 255.0
152 |         image -= self.mean
153 |         image /= self.std
154 | 
155 |         return {'image': image}
156 | 
157 | 
158 | class ToTensor(object):
159 |     """Convert ndarrays in sample to Tensors."""
160 |     def __call__(self, sample):
161 |         # swap color axis because
162 |         # PIL image  : W x H x C
163 |         # numpy image: H x W x C
164 |         # torch image: C x H x W
165 |         image, label = sample['image'], sample['label']
166 |         # H x W x C -> C x H x W
167 |         image = np.array(image).astype(np.float32).transpose((2, 0, 1))
168 |         label = np.array(label).astype(np.float32)
169 |         # convert to torch tensor
170 |         image = torch.from_numpy(image).float()
171 |         label = torch.from_numpy(label).float()
172 | 
173 |         return {'image': image,
174 |                 'label': label}
175 | 
176 | 
177 | class ToTensor_Image(object):
178 |     """Convert ndarrays in sample to Tensors."""
179 |     def __call__(self, sample):
180 |         # swap color axis because
181 |         # PIL image  : W x H x C
182 |         # numpy image: H x W x C
183 |         # torch image: C x H x W
184 |         image = sample['image']
185 |         # H x W x C -> C x H x W
186 |         image = np.array(image).astype(np.float32).transpose((2, 0, 1))
187 |         # convert to torch tensor
188 |         image = torch.from_numpy(image).float()
189 | 
190 |         return {'image': image}
191 | 
192 | 
193 | class RandomRotate(object):
194 |     def __init__(self, degree):
195 |         self.degree = degree
196 | 
197 |     def __call__(self, sample):
198 |         image, label = sample['image'], sample['label']
199 |         rotate_degree = random.uniform(-1*self.degree, self.degree)
200 |         image = image.rotate(rotate_degree, Image.BILINEAR)
201 |         label = label.rotate(rotate_degree, Image.NEAREST)
202 | 
203 |         return {'image': image,
204 |                 'label': label}
205 | 
206 | 
207 | class RandomGaussianBlur(object):
208 |     def __call__(self, sample):
209 |         image = sample['image']
210 |         label = sample['label']
211 |         if random.random() < 0.5:
212 |             image = image.filter(ImageFilter.GaussianBlur(
213 |                 radius=random.random()))
214 | 
215 |         return {'image': image,
216 |                 'label': label}
217 | 
218 | 
219 | class FixScaleCrop(object):
220 |     def __init__(self, crop_size):
221 |         self.crop_size = crop_size
222 | 
223 |     def __call__(self, sample):
224 |         image = sample['image']
225 |         label = sample['label']
226 |         w, h = image.size
227 |         if w > h:
228 |             oh = self.crop_size
229 |             ow = int(1.0 * w * oh / h)
230 |         else:
231 |             ow = self.crop_size
232 |             oh = int(1.0 * h * ow / w)
233 |         image = image.resize((ow, oh), Image.BILINEAR)
234 |         label = label.resize((ow, oh), Image.NEAREST)
235 |         # center crop
236 |         w, h = image.size
237 |         x1 = int(round((w - self.crop_size) / 2.))
238 |         y1 = int(round((h - self.crop_size) / 2.))
239 |         image = image.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
240 |         label = label.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
241 | 
242 |         return {'image': image,
243 |                 'label': label}
244 | 
245 | 
246 | class FixedResize(object):
247 |     """resize the image and label to a fixed size"""
248 |     def __init__(self, size):
249 |         self.size = (size, size)  # size: (h, w)
250 | 
251 |     def __call__(self, sample):
252 |         image = sample['image']
253 |         label = sample['label']
254 | 
255 |         assert image.size == label.size
256 | 
257 |         image = image.resize(self.size, Image.BILINEAR)
258 |         label = label.resize(self.size, Image.NEAREST)
259 | 
260 |         return {'image': image,
261 |                 'label': label}
262 | 
--------------------------------------------------------------------------------
/losses/rmi/rmi.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | 
3 | """
4 | The implementation of the paper:
5 | Region Mutual Information Loss for Semantic Segmentation.
6 | """
7 | 
8 | # python 2.X, 3.X compatibility
9 | from __future__ import print_function
10 | from __future__ import division
11 | from __future__ import absolute_import
12 | 
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | 
17 | from RMI.losses.rmi import rmi_utils
18 | 
19 | 
20 | _euler_num = 2.718281828    # euler number
21 | _pi = 3.14159265            # pi
22 | _ln_2_pi = 1.837877         # ln(2 * pi)
23 | _CLIP_MIN = 1e-6            # min clip value after softmax or sigmoid operations
24 | _CLIP_MAX = 1.0             # max clip value after softmax or sigmoid operations
25 | _POS_ALPHA = 5e-4           # add this factor to ensure the AA^T is positive definite
26 | _IS_SUM = 1                 # sum the loss per channel
27 | 
28 | 
29 | __all__ = ['RMILoss']
30 | 
31 | 
32 | class RMILoss(nn.Module):
33 |     """
34 |     region mutual information
35 |     I(A, B) = H(A) + H(B) - H(A, B)
36 |     This version needs a lot of memory if you do not downsample.
37 |     """
38 |     def __init__(self,
39 |                  num_classes=21,
40 |                  rmi_radius=3,
41 |                  rmi_pool_way=0,
42 |                  rmi_pool_size=3,
43 |                  rmi_pool_stride=3,
44 |                  loss_weight_lambda=0.5,
45 |                  lambda_way=1):
46 |         super(RMILoss, self).__init__()
47 |         self.num_classes = num_classes
48 |         # radius choices
49 |         assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
50 |         self.rmi_radius = rmi_radius
51 |         assert rmi_pool_way in [0, 1, 2, 3]
52 |         self.rmi_pool_way = rmi_pool_way
53 | 
54 |         # set the pool_size = rmi_pool_stride
55 |         assert rmi_pool_size == rmi_pool_stride
56 |         self.rmi_pool_size = rmi_pool_size
57 |         self.rmi_pool_stride = rmi_pool_stride
58 |         self.weight_lambda = loss_weight_lambda
59 |         self.lambda_way = lambda_way
60 | 
61 |         # dimension of the distribution
62 |         self.half_d = self.rmi_radius * self.rmi_radius
63 |         self.d = 2 * self.half_d
64 |         self.kernel_padding = self.rmi_pool_size // 2
65 |         # ignore class
66 |         self.ignore_index = 255
67 | 
68 |     def forward(self, logits_4D, labels_4D):
69 |         loss = self.forward_sigmoid(logits_4D, labels_4D)
70 |         #loss = self.forward_softmax_sigmoid(logits_4D, labels_4D)
71 |         return loss
72 | 
73 |     def forward_softmax_sigmoid(self, logits_4D, labels_4D):
74 |         """
75 |         Using both softmax and sigmoid operations.
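        (Kept for reference: ``forward`` calls ``forward_sigmoid`` by default.)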
76 |         Args:
77 |             logits_4D : [N, C, H, W], dtype=float32
78 |             labels_4D : [N, H, W], dtype=long
79 |         """
80 |         # PART I -- get the normal cross entropy loss
81 |         normal_loss = F.cross_entropy(input=logits_4D,
82 |                                       target=labels_4D.long(),
83 |                                       ignore_index=self.ignore_index,
84 |                                       reduction='mean')
85 | 
86 |         # PART II -- get the lower bound of the region mutual information
87 |         # get the valid label and logits
88 |         # valid label, [N, C, H, W]
89 |         label_mask_3D = labels_4D < self.num_classes
90 |         valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), num_classes=self.num_classes).float()
91 |         label_mask_3D = label_mask_3D.float()
92 |         valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3)
93 |         valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False)
94 |         # valid probs
95 |         probs_4D = torch.sigmoid(logits_4D) * label_mask_3D.unsqueeze(dim=1)
96 |         probs_4D = probs_4D.clamp(min=_CLIP_MIN, max=_CLIP_MAX)
97 | 
98 |         # get region mutual information
99 |         rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)
100 | 
101 |         # add together
102 |         final_loss = (self.weight_lambda * normal_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way
103 |                       else normal_loss + rmi_loss * self.weight_lambda)
104 | 
105 |         return final_loss
106 | 
107 |     def forward_sigmoid(self, logits_4D, labels_4D):
108 |         """
109 |         Using only the sigmoid operation.
110 |         Args:
111 |             logits_4D : [N, C, H, W], dtype=float32
112 |             labels_4D : [N, H, W], dtype=long
113 |         """
114 |         # label mask -- [N, H, W]
115 |         label_mask_3D = labels_4D < self.num_classes
116 | 
117 |         # valid label
118 |         valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), num_classes=self.num_classes).float()
119 |         label_mask_3D = label_mask_3D.float()
120 |         label_mask_flat = label_mask_3D.view([-1, ])
121 |         valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3)
122 |         valid_onehot_labels_4D.requires_grad_(False)
123 | 
124 |         # PART I -- calculate the sigmoid binary cross entropy loss
125 |         valid_onehot_label_flat = valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False)
126 |         logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes])
127 | 
128 |         # binary loss, multiplied by the not_ignore_mask
129 |         valid_pixels = torch.sum(label_mask_flat)
130 |         binary_loss = F.binary_cross_entropy_with_logits(logits_flat,
131 |                                                          target=valid_onehot_label_flat,
132 |                                                          weight=label_mask_flat.unsqueeze(dim=1),
133 |                                                          reduction='sum')
134 |         bce_loss = torch.div(binary_loss, valid_pixels + 1.0)
135 | 
136 |         # PART II -- get rmi loss
137 |         # onehot_labels_4D -- [N, C, H, W]
138 |         probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN
139 |         valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False)
140 | 
141 |         # get region mutual information
142 |         rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)
143 | 
144 |         # add together
145 |         final_loss = (self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way
146 |                       else bce_loss + rmi_loss * self.weight_lambda)
147 | 
148 |         return final_loss
149 | 
150 |     def rmi_lower_bound(self, labels_4D, probs_4D):
151 |         """
152 |         calculate the lower bound of the region mutual information.
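        Minimizing the returned value (0.5 * log det(Sigma_{Y|P}) per class,
        averaged over the batch and scaled by 1 / half_d) maximizes the MI lower bound.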
153 |         Args:
154 |             labels_4D : [N, C, H, W], dtype=float32
155 |             probs_4D : [N, C, H, W], dtype=float32
156 |         """
157 |         assert labels_4D.size() == probs_4D.size()
158 | 
159 |         p, s = self.rmi_pool_size, self.rmi_pool_stride
160 |         if self.rmi_pool_stride > 1:
161 |             if self.rmi_pool_way == 0:
162 |                 labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
163 |                 probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
164 |             elif self.rmi_pool_way == 1:
165 |                 labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
166 |                 probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding)
167 |             elif self.rmi_pool_way == 2:
168 |                 # interpolation
169 |                 shape = labels_4D.size()
170 |                 new_h, new_w = shape[2] // s, shape[3] // s
171 |                 labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest')
172 |                 probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True)
173 |             else:
174 |                 raise NotImplementedError("Pool way of RMI is not defined!")
175 |         # we do not need the gradient of label.
176 |         label_shape = labels_4D.size()
177 |         n, c = label_shape[0], label_shape[1]
178 | 
179 |         # combine the high dimension points from label and probability map. new shape [N, C, radius * radius, H, W]
180 |         la_vectors, pr_vectors = rmi_utils.map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0)
181 | 
182 |         la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False)
183 |         pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor)
184 | 
185 |         # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius]
186 |         diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0)
187 | 
188 |         # the mean and covariance of these high dimension points
189 |         # Cov(X) = E(X X^T) - E(X) E(X)^T; we keep N * Cov(X), the sum of centered outer products
190 |         la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True)
191 |         la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3))
192 | 
193 |         pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True)
194 |         pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3))
195 |         # https://github.com/pytorch/pytorch/issues/7500
196 |         # waiting for batched torch.cholesky_inverse()
197 |         pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA)
198 |         # if the dimension of the points is less than 9, you can use the function below
199 |         # to accelerate the computation.
200 |         #pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA)
201 | 
202 |         la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3))
203 |         # the approximation of the variance: det(c A) = c^n det(A) for an n x n matrix A,
204 |         # so log det(c A) = n log(c) + log det(A).
205 |         # (appro_var would be appro_var / n_points; we do not divide by the number of points
206 |         # here, in order to avoid underflow issues.)
207 |         # If A = A^T, then A^-1 = (A^-1)^T.
208 |         appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1))
209 |         #appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1))
210 |         #appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6
211 | 
212 |         # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ).
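        # Why this is a lower bound: the Gaussian maximizes entropy for a fixed covariance,
        # so H(Y|P) <= 0.5 * log((2 * pi * e)^d * det(Sigma_{Y|P})), and therefore
        # I(Y; P) = H(Y) - H(Y|P) >= H(Y) - 0.5 * log((2 * pi * e)^d * det(Sigma_{Y|P})),
        # where Sigma_{Y|P} = Sigma_Y - Sigma_{YP} Sigma_P^{-1} Sigma_{YP}^T is the
        # `appro_var` computed above. H(Y) does not depend on the prediction, so maximizing
        # the bound reduces to minimizing log det(Sigma_{Y|P}), the term computed below.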
213 |         rmi_now = 0.5 * rmi_utils.log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA)
214 |         #rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA)
215 | 
216 |         # mean over N samples, sum over classes.
217 |         rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float()
218 |         #is_half = False
219 |         #if is_half:
220 |         #    rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0))
221 |         #else:
222 |         rmi_per_class = torch.div(rmi_per_class, float(self.half_d))
223 | 
224 |         rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class)
225 |         return rmi_loss
226 | 
--------------------------------------------------------------------------------
/model/sync_bn/syncbn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # encoding: utf-8
3 | # @Time    : 2018/10/3 1:45 PM
4 | # @Author  : yuchangqian
5 | # @Contact : changqian_yu@163.com
6 | # @File    : syncbn.py (adapted from torch-encoding's nn.syncbn)
7 | 
8 | """Synchronized Cross-GPU Batch Normalization Module"""
9 | import collections
10 | import threading
11 | 
12 | import torch
13 | from torch.nn.modules.batchnorm import _BatchNorm
14 | from torch.nn.functional import batch_norm
15 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
16 | 
17 | from .comm import SyncMaster
18 | from .parallel import allreduce
19 | from .functions import sum_square, batchnormtrain
20 | 
21 | 
22 | __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d']
23 | 
24 | 
25 | class _SyncBatchNorm(_BatchNorm):
26 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
27 |         super(_SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
28 | 
29 |         self._sync_master = SyncMaster(self._data_parallel_master)
30 |         self._parallel_id = None
31 |         self._slave_pipe = None
32 | 
33 |     def forward(self, input):
34 |         if not self.training:
35 |             return batch_norm(
36 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
37 |                 self.training, self.momentum, self.eps)
38 | 
39 |         # Resize the input to (B, C, -1).
40 |         input_shape = input.size()
41 |         input = input.view(input_shape[0], self.num_features, -1)
42 | 
43 |         # sum(x) and sum(x^2)
44 |         N = input.size(0) * input.size(2)
45 |         xsum, xsqsum = sum_square(input)
46 | 
47 |         # all-reduce for global sum(x) and sum(x^2)
48 |         if self._parallel_id == 0:
49 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(xsum, xsqsum, N))
50 |         else:
51 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(xsum, xsqsum, N))
52 |         # forward
53 |         return batchnormtrain(input, mean, 1.0/inv_std, self.weight, self.bias).view(input_shape)
54 | 
55 |     def extra_repr(self):
56 |         return '{}, eps={}, momentum={}, sync={}'.format(
57 |             self.num_features, self.eps, self.momentum, True)
58 | 
59 |     def __data_parallel_replicate__(self, ctx, copy_id):
60 |         self._parallel_id = copy_id
61 | 
62 |         # parallel_id == 0 means master device.
63 |         if self._parallel_id == 0:
64 |             ctx.sync_master = self._sync_master
65 |         else:
66 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
67 | 
68 |     def _data_parallel_master(self, intermediates):
69 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
70 | 
71 |         # Always using the same "device order" makes the ReduceAdd operation faster.
72 |         # Thanks to: Tete Xiao (http://tetexiao.com/)
73 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
74 | 
75 |         to_reduce = [i[1][:2] for i in intermediates]
76 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
77 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
78 | 
79 |         sum_size = sum([i[1].sum_size for i in intermediates])
80 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
81 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
82 | 
83 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
84 | 
85 |         outputs = []
86 |         for i, rec in enumerate(intermediates):
87 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
88 | 
89 |         return outputs
90 | 
91 |     def _compute_mean_std(self, sum_, ssum, size):
92 |         """Compute the mean and standard-deviation with sum and square-sum. This method
93 |         also maintains the moving average on the master device."""
94 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
95 |         mean = sum_ / size
96 |         sumvar = ssum - sum_ * mean
97 |         unbias_var = sumvar / (size - 1)
98 |         bias_var = sumvar / size
99 | 
100 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
101 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
102 | 
103 |         return mean, (bias_var + self.eps) ** -0.5
104 | 
105 | 
106 | # API adapted from https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
107 | _ChildMessage = collections.namedtuple('Message', ['sum', 'ssum', 'sum_size'])
108 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
109 | 
110 | 
111 | class BatchNorm1d(_SyncBatchNorm):
112 |     r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`"""
113 |     def _check_input_dim(self, input):
114 |         if input.dim() != 2 and input.dim() != 3:
115 |             raise ValueError('expected 2D or 3D input (got {}D input)'
116 |                              .format(input.dim()))
117 |         super(BatchNorm1d, self)._check_input_dim(input)
118 | 
119 | 
120 | class BatchNorm2d(_SyncBatchNorm):
121 |     r"""Cross-GPU Synchronized Batch normalization (SyncBN)
122 | 
123 |     Standard BN [1]_ implementation only normalizes the data within each device (GPU).
124 |     SyncBN normalizes the input within the whole mini-batch.
125 |     We follow the sync-once implementation described in the paper [2]_ .
126 |     Please see the design idea in the `notes <./notes/syncbn.html>`_.
127 | 
128 |     .. note::
129 |         We adapt the awesome python API from another `PyTorch SyncBN Implementation
130 |         <https://github.com/vacancy/Synchronized-BatchNorm-PyTorch>`_ and provide an
131 |         efficient CUDA backend.
132 | 
133 |     .. math::
134 | 
135 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
136 | 
137 |     The mean and standard-deviation are calculated per-channel over
138 |     the mini-batches and gamma and beta are learnable parameter vectors
139 |     of size C (where C is the input size).
140 | 
141 |     During training, this layer keeps a running estimate of its computed mean
142 |     and variance. The running sum is kept with a default momentum of 0.1.
143 | 
144 |     During evaluation, this running mean/variance is used for normalization.
145 | 
146 |     Because the BatchNorm is done over the `C` dimension, computing statistics
147 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
148 | 
149 |     Args:
150 |         num_features: num_features from an expected input of
151 |             size batch_size x num_features x height x width
152 |         eps: a value added to the denominator for numerical stability.
153 | Default: 1e-5 154 | momentum: the value used for the running_mean and running_var 155 | computation. Default: 0.1 156 | affine: a boolean value that when set to ``True``, gives the layer learnable 157 | affine parameters. Default: ``True`` 158 | 159 | Shape: 160 | - Input: :math:`(N, C, H, W)` 161 | - Output: :math:`(N, C, H, W)` (same shape as input) 162 | 163 | Reference: 164 | .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* 165 | .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* 166 | 167 | Examples: 168 | >>> m = BatchNorm2d(100) 169 | >>> net = torch.nn.DataParallel(m) 170 | >>> encoding.parallel.patch_replication_callback(net) 171 | >>> output = net(input) 172 | """ 173 | def _check_input_dim(self, input): 174 | if input.dim() != 4: 175 | raise ValueError('expected 4D input (got {}D input)' 176 | .format(input.dim())) 177 | super(BatchNorm2d, self)._check_input_dim(input) 178 | 179 | 180 | class BatchNorm3d(_SyncBatchNorm): 181 | r"""Please see the docs in :class:`encoding.nn.BatchNorm2d`""" 182 | def _check_input_dim(self, input): 183 | if input.dim() != 5: 184 | raise ValueError('expected 5D input (got {}D input)' 185 | .format(input.dim())) 186 | super(BatchNorm3d, self)._check_input_dim(input) 187 | 188 | 189 | class SharedTensor(object): 190 | """Shared Tensor for cross GPU all reduce operation""" 191 | def __init__(self, nGPUs): 192 | self.mutex = threading.Lock() 193 | self.all_tasks_done = threading.Condition(self.mutex) 194 | self.nGPUs = nGPUs 195 | self._clear() 196 | 197 | def _clear(self): 198 | self.N = 0 199 | self.dict = {} 200 | self.push_tasks = self.nGPUs 201 | self.reduce_tasks = self.nGPUs 202 | 203 | def push(self, *inputs): 204 | # push from device 205 | with self.mutex: 206 | if self.push_tasks == 0: 207 | self._clear() 208 | self.N += inputs[0] 209 | igpu = inputs[1] 210 | self.dict[igpu] = inputs[2:] 211 | #idx = self.nGPUs - self.push_tasks 212 | self.push_tasks -= 1 213 | with self.all_tasks_done: 214 | if self.push_tasks == 0: 215 | self.all_tasks_done.notify_all() 216 | while self.push_tasks: 217 | self.all_tasks_done.wait() 218 | 219 | def pull(self, igpu): 220 | # pull from device 221 | with self.mutex: 222 | if igpu == 0: 223 | assert(len(self.dict) == self.nGPUs) 224 | # flatten the tensors 225 | self.list = [t for i in range(len(self.dict)) for t in self.dict[i]] 226 | self.outlist = allreduce(2, *self.list) 227 | self.reduce_tasks -= 1 228 | else: 229 | self.reduce_tasks -= 1 230 | with self.all_tasks_done: 231 | if self.reduce_tasks == 0: 232 | self.all_tasks_done.notify_all() 233 | while self.reduce_tasks: 234 | self.all_tasks_done.wait() 235 | # all reduce done 236 | return self.N, self.outlist[2*igpu], self.outlist[2*igpu+1] 237 | 238 | def __len__(self): 239 | return self.nGPUs 240 | 241 | def __repr__(self): 242 | return ('SharedTensor') 243 | --------------------------------------------------------------------------------