├── .gitignore
├── README.md
├── configs
│   ├── hrnet_config.py
│   └── segformer_config.py
├── create_video.ipynb
├── models
│   ├── __init__.py
│   ├── hrnet.py
│   ├── segformer.py
│   ├── segformer_simple.py
│   └── segformer_utils
│       ├── __init__.py
│       ├── encoder_decoder.py
│       ├── logger.py
│       ├── mix_transformer.py
│       ├── segformer_build.py
│       └── segformer_head.py
├── requirements.txt
├── segmentation_hrnet.ipynb
├── segmentation_segformer.ipynb
├── sliding_window_test.ipynb
├── src
│   ├── hrnet_w48_graph.png
│   ├── segformer_b0_graph.png
│   ├── segformer_simple_b0_graph.png
│   ├── stuttgart_hrnet_w48_sample.gif
│   └── stuttgart_segformer_sample.gif
└── utils
    ├── __init__.py
    ├── data_utils.py
    ├── label_utils.py
    ├── lr_schedule.py
    ├── modelsummary.py
    ├── runners.py
    ├── train_utils.py
    ├── transformation_pipelines.py
    ├── transformations.py
    └── visual.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # other 132 | outputs/ 133 | data/ 134 | logs/ 135 | weights/ 136 | .DS_Store 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SegFormer and HRNet Comparison for Semantic Segmentation 2 | 3 | This repo contains an image segmentation pipeline for the Cityscapes dataset, using [HRNet](https://github.com/HRNet/HRNet-Semantic-Segmentation) and a powerful new transformer-based architecture called [SegFormer](https://github.com/NVlabs/SegFormer). The scripts for data preprocessing, training, and inference are written mainly from scratch. The model construction code for HRNet (`models/hrnet.py`) and SegFormer (`models/segformer.py`) has been adapted from the official mmseg implementations, whereas `models/segformer_simple.py` contains a much cleaner SegFormer implementation that has not been verified for correctness. 4 | 5 | HRNet and SegFormer are useful architectures to compare, because they represent fundamentally different approaches to image understanding. HRNet - like most other vision architectures - is at its core a series of convolution operations that are stacked, fused, and connected in a very efficient manner. SegFormer, on the other hand, is built around self-attention rather than stacked convolutions: it treats each image as a sequence of tokens, where each token represents a 4x4 pixel patch of the image, and processes that sequence with hierarchical transformer layers. 6 | 7 | For training, the implementation details of the original papers are followed as closely as possible. 8 | 9 | Due to memory limitations (a single RTX 3090 GPU with 24 GB), gradient accumulation was used for training the SegFormer model.
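
Gradient accumulation simply splits each optimizer step across several smaller forward/backward passes whose gradients are summed before the weights are updated. The snippet below is an illustrative sketch only (the actual training loop lives in the `utils` modules and is not reproduced here); it assumes the `ACCUM_STEPS = 4` setting from `configs/segformer_config.py`:

```python
import torch

def train_one_epoch(model, loader, optimizer, criterion, accum_steps=4, device="cuda"):
    """Illustrative gradient-accumulation loop (not the repo's actual runner)."""
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        # Scale the loss so the accumulated gradient matches a full-batch update.
        loss = criterion(logits, labels) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()       # one "real" update per accum_steps micro-batches
            optimizer.zero_grad()
```

With `BATCH_SIZE = 8` and `ACCUM_STEPS = 4`, each forward pass sees only 2 images (`ADJ_BATCH_SIZE` in the config), but the weight update approximates a batch of 8 (BatchNorm layers in the decoder still see the smaller micro-batch).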
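
As a quick check of the token view described above, the first-stage `OverlapPatchEmbed` module defined in `models/segformer.py` can be run on a Cityscapes-sized crop. This is a sketch that assumes the repo's dependencies (PyTorch, timm, mmcv) are installed and uses the B0 embedding width of 32:

```python
import torch
from models.segformer import OverlapPatchEmbed

# First-stage patch embedding: 7x7 overlapping windows with stride 4,
# so a 1024x1024 crop becomes a 256x256 grid of tokens.
embed = OverlapPatchEmbed(img_size=1024, patch_size=7, stride=4, in_chans=3, embed_dim=32)
x = torch.randn(1, 3, 1024, 1024)
tokens, H, W = embed(x)
print(tokens.shape, H, W)  # torch.Size([1, 65536, 32]) 256 256
```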
10 | 11 | 12 | # HRNet 13 | ---------------------------------------------------------------------------------------------------- 14 | 15 | ![](src/stuttgart_hrnet_w48_sample.gif) 16 | 17 | 18 | 19 | # SegFormer 20 | ---------------------------------------------------------------------------------------------------- 21 | 22 | 23 | ![](src/stuttgart_segformer_sample.gif) 24 | 25 | -------------------------------------------------------------------------------- /configs/hrnet_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | 8 | config = CN() 9 | 10 | config.NAME = 'hrnet_w48train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484' 11 | config.OUTPUT_DIR = 'outputs' 12 | config.LOG_DIR = 'logs' 13 | 14 | 15 | config.DATASET = CN() 16 | config.DATASET.NAME = 'cityscapes' 17 | config.DATASET.DATA_DIR = 'data/cityscapes' 18 | config.DATASET.INPUT_PATTERN = '*_leftImg8bit.png' 19 | config.DATASET.ANNOT_PATTERN = '*_gtFine_labelIds.png' 20 | config.DATASET.IMAGE_DIR = 'leftImg8bit' 21 | config.DATASET.LABEL_DIR = 'gtFine' 22 | config.DATASET.NUM_CLASSES = 19 23 | config.DATASET.IGNORE_LABEL = 255 24 | config.DATASET.MEAN = [0.485, 0.456, 0.406] 25 | config.DATASET.STD = [0.229, 0.224, 0.225] 26 | config.DATASET.BASE_SIZE = (1024, 2048) 27 | config.DATASET.CROP_SIZE = (512, 1024) 28 | 29 | config.TRAIN = CN() 30 | config.TRAIN.EPOCHS = 484 31 | config.TRAIN.DECAY_STEPS = 120000 32 | config.TRAIN.BATCH_SIZE = 12 33 | 34 | config.TRAIN.BASE_LR = 1e-2 35 | config.TRAIN.END_LR = 1e-5 36 | config.TRAIN.OPTIMIZER = 'sgd' 37 | config.TRAIN.WD = 0.0005 38 | config.TRAIN.MOMENTUM = 0.9 39 | 40 | config.MODEL = CN() 41 | config.MODEL.NAME = 'hrnet_w48' 42 | config.MODEL.PRETRAINED = 'weights/HRNet_W48_C_pretrained.pth' 43 | config.MODEL.W = 48 44 | 45 | config.MODEL.STAGE_1 = CN() 46 | config.MODEL.STAGE_1.NUM_MODULES = 1 47 | config.MODEL.STAGE_1.NUM_BRANCHES = 1 48 | config.MODEL.STAGE_1.BLOCK = 'BOTTLENECK' 49 | config.MODEL.STAGE_1.NUM_BLOCKS = [4] 50 | config.MODEL.STAGE_1.NUM_CHANNELS = [64] 51 | 52 | config.MODEL.STAGE_2 = CN() 53 | config.MODEL.STAGE_2.NUM_MODULES = 1 54 | config.MODEL.STAGE_2.NUM_BRANCHES = 2 55 | config.MODEL.STAGE_2.BLOCK = 'BASIC' 56 | config.MODEL.STAGE_2.NUM_BLOCKS = [4, 4] 57 | config.MODEL.STAGE_2.NUM_CHANNELS = [48, 96] 58 | 59 | config.MODEL.STAGE_3 = CN() 60 | config.MODEL.STAGE_3.NUM_MODULES = 4 61 | config.MODEL.STAGE_3.NUM_BRANCHES = 3 62 | config.MODEL.STAGE_3.BLOCK = 'BASIC' 63 | config.MODEL.STAGE_3.NUM_BLOCKS = [4, 4, 4] 64 | config.MODEL.STAGE_3.NUM_CHANNELS = [48, 96, 192] 65 | 66 | config.MODEL.STAGE_4 = CN() 67 | config.MODEL.STAGE_4.NUM_MODULES = 3 68 | config.MODEL.STAGE_4.NUM_BRANCHES = 4 69 | config.MODEL.STAGE_4.BLOCK = 'BASIC' 70 | config.MODEL.STAGE_4.NUM_BLOCKS = [4, 4, 4, 4] 71 | config.MODEL.STAGE_4.NUM_CHANNELS = [48, 96, 192, 384] 72 | 73 | -------------------------------------------------------------------------------- /configs/segformer_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from yacs.config import CfgNode as CN 6 | 7 | config = CN() 8 | 9 | config.NAME = 'segformer_train_1024x1024_adamw_lr6e-6_wd1e-2_bs_8_epoch400' 10 | config.OUTPUT_DIR = 'outputs' 11 | 
config.LOG_DIR = 'logs' 12 | 13 | config.DATASET = CN() 14 | config.DATASET.NAME = 'cityscapes' 15 | config.DATASET.DATA_DIR = 'data/cityscapes' 16 | config.DATASET.INPUT_PATTERN = '*_leftImg8bit.png' 17 | config.DATASET.ANNOT_PATTERN = '*_gtFine_labelIds.png' 18 | config.DATASET.IMAGE_DIR = 'leftImg8bit' 19 | config.DATASET.LABEL_DIR = 'gtFine' 20 | config.DATASET.NUM_CLASSES = 19 21 | config.DATASET.IGNORE_LABEL = 255 22 | config.DATASET.MEAN = [0.485, 0.456, 0.406] 23 | config.DATASET.STD = [0.229, 0.224, 0.225] 24 | config.DATASET.BASE_SIZE = (1024, 2048) 25 | config.DATASET.CROP_SIZE = (1024, 1024) # (768, 768) 26 | 27 | 28 | config.TRAIN = CN() 29 | config.TRAIN.EPOCHS = 400 30 | config.TRAIN.DECAY_STEPS = 160000 31 | config.TRAIN.BATCH_SIZE = 8 32 | config.TRAIN.ACCUM_STEPS = 4 33 | config.TRAIN.ADJ_BATCH_SIZE = config.TRAIN.BATCH_SIZE // config.TRAIN.ACCUM_STEPS 34 | config.TRAIN.POWER = 1.0 35 | config.TRAIN.WARMUP_ITERS = 1500 36 | config.TRAIN.WARMUP_RATIO = 1e-6 37 | config.TRAIN.BY_EPOCH = False 38 | config.TRAIN.BASE_LR = 0.00006 39 | config.TRAIN.MIN_LR = 0.0 40 | config.TRAIN.WARMUP = "linear" 41 | config.TRAIN.OPTIMIZER = 'AdamW' 42 | config.TRAIN.WD = 0.01 43 | 44 | 45 | config.MODEL = CN() 46 | config.MODEL.NAME = 'segformer' 47 | config.MODEL.PATCH_SIZE = 4 48 | 49 | 50 | ##### B5 ##### 51 | 52 | config.MODEL.B5 = CN() 53 | config.MODEL.B5.PRETRAINED = 'weights/mit_b5.pth' 54 | config.MODEL.B5.DECODER_DIM = 768 55 | config.MODEL.B5.CHANNEL_DIMS = (64, 128, 320, 512) 56 | config.MODEL.B5.SR_RATIOS = (8, 4, 2, 1) 57 | config.MODEL.B5.NUM_HEADS = (1, 2, 5, 8) 58 | config.MODEL.B5.MLP_RATIOS = (4, 4, 4, 4) 59 | config.MODEL.B5.DEPTHS = (3, 6, 40, 3) 60 | config.MODEL.B5.QKV_BIAS = True 61 | config.MODEL.B5.DROP_RATE = 0.0 62 | config.MODEL.B5.DROP_PATH_RATE = 0.1 63 | 64 | ##### B3 ##### 65 | config.MODEL.B3 = CN() 66 | config.MODEL.B3.PRETRAINED = 'weights/mit_b3.pth' 67 | config.MODEL.B3.DECODER_DIM = 768 68 | config.MODEL.B3.CHANNEL_DIMS = (64, 128, 320, 512) 69 | config.MODEL.B3.SR_RATIOS = (8, 4, 2, 1) 70 | config.MODEL.B3.NUM_HEADS = (1, 2, 5, 8) 71 | config.MODEL.B3.MLP_RATIOS = (4, 4, 4, 4) 72 | config.MODEL.B3.DEPTHS = (3, 4, 18, 3) 73 | config.MODEL.B3.QKV_BIAS = True 74 | config.MODEL.B3.DROP_RATE = 0.0 75 | config.MODEL.B3.DROP_PATH_RATE = 0.1 76 | 77 | ##### B2 ##### 78 | config.MODEL.B2 = CN() 79 | config.MODEL.B2.PRETRAINED = 'weights/mit_b2.pth' 80 | config.MODEL.B2.DECODER_DIM = 768 81 | config.MODEL.B2.CHANNEL_DIMS = (64, 128, 320, 512) 82 | config.MODEL.B2.SR_RATIOS = (8, 4, 2, 1) 83 | config.MODEL.B2.NUM_HEADS = (1, 2, 5, 8) 84 | config.MODEL.B2.MLP_RATIOS = (4, 4, 4, 4) 85 | config.MODEL.B2.DEPTHS = (3, 4, 6, 3) 86 | config.MODEL.B2.QKV_BIAS = True 87 | config.MODEL.B2.DROP_RATE = 0.0 88 | config.MODEL.B2.DROP_PATH_RATE = 0.1 89 | 90 | 91 | ##### B1 ##### 92 | config.MODEL.B1 = CN() 93 | config.MODEL.B1.PRETRAINED = 'weights/mit_b1.pth' 94 | config.MODEL.B1.DECODER_DIM = 768 95 | config.MODEL.B1.CHANNEL_DIMS = (64, 128, 320, 512) 96 | config.MODEL.B1.SR_RATIOS = (8, 4, 2, 1) 97 | config.MODEL.B1.NUM_HEADS = (1, 2, 5, 8) 98 | config.MODEL.B1.MLP_RATIOS = (4, 4, 4, 4) 99 | config.MODEL.B1.DEPTHS = (2, 2, 2, 2) 100 | config.MODEL.B1.QKV_BIAS = True 101 | config.MODEL.B1.DROP_RATE = 0.0 102 | config.MODEL.B1.DROP_PATH_RATE = 0.1 103 | 104 | 105 | ##### B0 ##### 106 | config.MODEL.B0 = CN() 107 | config.MODEL.B0.PRETRAINED = 'weights/mit_b0.pth' 108 | config.MODEL.B0.DECODER_DIM = 256 109 | config.MODEL.B0.CHANNEL_DIMS = (32, 64, 160, 256) 110 | 
config.MODEL.B0.SR_RATIOS = (8, 4, 2, 1) 111 | config.MODEL.B0.NUM_HEADS = (1, 2, 5, 8) 112 | config.MODEL.B0.MLP_RATIOS = (4, 4, 4, 4) 113 | config.MODEL.B0.DEPTHS = (2, 2, 2, 2) 114 | config.MODEL.B0.QKV_BIAS = True 115 | config.MODEL.B0.DROP_RATE = 0.0 116 | config.MODEL.B0.DROP_PATH_RATE = 0.1 117 | 118 | -------------------------------------------------------------------------------- /create_video.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "2d52492a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os\n", 11 | "import ffmpeg" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "7d731948", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "video_dir = \"outputs/video_frames\"\n", 22 | "img_pattern = '*_leftImg8bit.png'\n", 23 | "PATTERN = os.path.join(video_dir, img_pattern)\n", 24 | "VIDEO_PATH = \"outputs/stuttgart_output.mp4\"" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "93f93239", 31 | "metadata": {}, 32 | "outputs": [ 33 | { 34 | "data": { 35 | "text/plain": [ 36 | "'outputs/video_frames/*_leftImg8bit.png'" 37 | ] 38 | }, 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "output_type": "execute_result" 42 | } 43 | ], 44 | "source": [ 45 | "PATTERN" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 4, 51 | "id": "324b7106", 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "(None, None)" 58 | ] 59 | }, 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "ffmpeg.input(\n", 67 | " PATTERN, \n", 68 | " pattern_type='glob', \n", 69 | " framerate=15\n", 70 | ").output(VIDEO_PATH, pix_fmt='yuv420p', vcodec='libx264').run()" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "id": "c05c49ff", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [] 80 | } 81 | ], 82 | "metadata": { 83 | "kernelspec": { 84 | "display_name": "Python 3", 85 | "language": "python", 86 | "name": "python3" 87 | }, 88 | "language_info": { 89 | "codemirror_mode": { 90 | "name": "ipython", 91 | "version": 3 92 | }, 93 | "file_extension": ".py", 94 | "mimetype": "text/x-python", 95 | "name": "python", 96 | "nbconvert_exporter": "python", 97 | "pygments_lexer": "ipython3", 98 | "version": "3.7.11" 99 | } 100 | }, 101 | "nbformat": 4, 102 | "nbformat_minor": 5 103 | } 104 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/models/__init__.py -------------------------------------------------------------------------------- /models/hrnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | import functools 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch._utils 12 | import torch.nn.functional as F 13 | 14 | BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d 15 | relu_inplace = True 16 | BN_MOMENTUM = 0.1 17 | ALIGN_CORNERS = False 18 | 19 | logger = 
logging.getLogger(__name__) 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 34 | self.relu = nn.ReLU(inplace=relu_inplace) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out = out + residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 67 | self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) 68 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 69 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) 70 | self.relu = nn.ReLU(inplace=relu_inplace) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out = out + residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class HighResolutionModule(nn.Module): 98 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 99 | num_channels, fuse_method, multi_scale_output=True): 100 | super(HighResolutionModule, self).__init__() 101 | self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) 102 | 103 | self.num_inchannels = num_inchannels 104 | self.fuse_method = fuse_method 105 | self.num_branches = num_branches 106 | 107 | self.multi_scale_output = multi_scale_output 108 | 109 | self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) 110 | self.fuse_layers = self._make_fuse_layers() 111 | self.relu = nn.ReLU(inplace=relu_inplace) 112 | 113 | def _check_branches(self, num_branches, blocks, num_blocks, 114 | num_inchannels, num_channels): 115 | if num_branches != len(num_blocks): 116 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks)) 117 | logger.error(error_msg) 118 | raise ValueError(error_msg) 119 | 120 | if num_branches != len(num_channels): 121 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels)) 122 | logger.error(error_msg) 123 | raise ValueError(error_msg) 124 | 125 | if num_branches != len(num_inchannels): 126 | error_msg = 'NUM_BRANCHES({}) <> 
NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels)) 127 | logger.error(error_msg) 128 | raise ValueError(error_msg) 129 | 130 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 131 | stride=1): 132 | downsample = None 133 | if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 134 | downsample = nn.Sequential( 135 | nn.Conv2d(self.num_inchannels[branch_index], 136 | num_channels[branch_index] * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) 143 | self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion 144 | for i in range(1, num_blocks[branch_index]): 145 | layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 150 | branches = [] 151 | 152 | for i in range(num_branches): 153 | branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) 154 | 155 | return nn.ModuleList(branches) 156 | 157 | def _make_fuse_layers(self): 158 | if self.num_branches == 1: 159 | return None 160 | 161 | num_branches = self.num_branches 162 | num_inchannels = self.num_inchannels 163 | fuse_layers = [] 164 | for i in range(num_branches if self.multi_scale_output else 1): 165 | fuse_layer = [] 166 | for j in range(num_branches): 167 | if j > i: 168 | fuse_layer.append(nn.Sequential( 169 | nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), 170 | BatchNorm2d(num_inchannels[i], momentum=BN_MOMENTUM))) 171 | elif j == i: 172 | fuse_layer.append(None) 173 | else: 174 | conv3x3s = [] 175 | for k in range(i-j): 176 | if k == i - j - 1: 177 | num_outchannels_conv3x3 = num_inchannels[i] 178 | conv3x3s.append(nn.Sequential( 179 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 180 | BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM))) 181 | else: 182 | num_outchannels_conv3x3 = num_inchannels[j] 183 | conv3x3s.append(nn.Sequential( 184 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 185 | BatchNorm2d(num_outchannels_conv3x3, momentum=BN_MOMENTUM), 186 | nn.ReLU(inplace=relu_inplace))) 187 | fuse_layer.append(nn.Sequential(*conv3x3s)) 188 | fuse_layers.append(nn.ModuleList(fuse_layer)) 189 | 190 | return nn.ModuleList(fuse_layers) 191 | 192 | def get_num_inchannels(self): 193 | return self.num_inchannels 194 | 195 | def forward(self, x): 196 | if self.num_branches == 1: 197 | return [self.branches[0](x[0])] 198 | 199 | for i in range(self.num_branches): 200 | x[i] = self.branches[i](x[i]) 201 | 202 | x_fuse = [] 203 | for i in range(len(self.fuse_layers)): 204 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 205 | for j in range(1, self.num_branches): 206 | if i == j: 207 | y = y + x[j] 208 | elif j > i: 209 | width_output = x[i].shape[-1] 210 | height_output = x[i].shape[-2] 211 | y = y + F.interpolate( 212 | self.fuse_layers[i][j](x[j]), 213 | size=[height_output, width_output], 214 | mode='bilinear', align_corners=ALIGN_CORNERS) 215 | else: 216 | y = y + self.fuse_layers[i][j](x[j]) 217 | x_fuse.append(self.relu(y)) 218 | 219 | return x_fuse 220 | 221 | 222 | blocks_dict = { 223 | 'BASIC': BasicBlock, 224 | 
'BOTTLENECK': Bottleneck 225 | } 226 | 227 | 228 | class HRNet(nn.Module): 229 | def __init__(self, cfg): 230 | super(HRNet, self).__init__() 231 | 232 | self.cfg = cfg 233 | self.inplanes = 64 234 | 235 | 236 | # stem net 237 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) 238 | self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) 239 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False) 240 | self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) 241 | self.relu = nn.ReLU(inplace=relu_inplace) 242 | 243 | # STAGE 1 244 | num_channels = cfg.MODEL.STAGE_1.NUM_CHANNELS[0] 245 | block = blocks_dict[cfg.MODEL.STAGE_1.BLOCK] 246 | num_blocks = cfg.MODEL.STAGE_1.NUM_BLOCKS[0] 247 | self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) 248 | stage1_out_channel = block.expansion * num_channels 249 | 250 | # STAGE 2 251 | num_channels = cfg.MODEL.STAGE_2.NUM_CHANNELS 252 | block = blocks_dict[cfg.MODEL.STAGE_2.BLOCK] 253 | num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] 254 | self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels) 255 | self.stage2, pre_stage_channels = self._make_stage(cfg.MODEL.STAGE_2, num_channels) 256 | 257 | # STAGE 3 258 | num_channels = cfg.MODEL.STAGE_3.NUM_CHANNELS 259 | block = blocks_dict[cfg.MODEL.STAGE_3.BLOCK] 260 | num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] 261 | self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) 262 | self.stage3, pre_stage_channels = self._make_stage(cfg.MODEL.STAGE_3, num_channels) 263 | 264 | # STAGE 4 265 | num_channels = cfg.MODEL.STAGE_4.NUM_CHANNELS 266 | block = blocks_dict[cfg.MODEL.STAGE_4.BLOCK] 267 | num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] 268 | self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) 269 | self.stage4, pre_stage_channels = self._make_stage(cfg.MODEL.STAGE_4, num_channels, multi_scale_output=True) 270 | 271 | last_inp_channels = np.int(np.sum(pre_stage_channels)) 272 | 273 | self.last_layer = nn.Sequential( 274 | nn.Conv2d( 275 | in_channels=last_inp_channels, 276 | out_channels=last_inp_channels, 277 | kernel_size=1, 278 | stride=1, 279 | padding=0), 280 | BatchNorm2d(last_inp_channels, momentum=BN_MOMENTUM), 281 | nn.ReLU(inplace=relu_inplace), 282 | nn.Conv2d( 283 | in_channels=last_inp_channels, 284 | out_channels=cfg.DATASET.NUM_CLASSES, 285 | kernel_size=1, 286 | stride=1, 287 | padding=0), 288 | ) 289 | 290 | 291 | def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): 292 | num_branches_cur = len(num_channels_cur_layer) 293 | num_branches_pre = len(num_channels_pre_layer) 294 | 295 | transition_layers = [] 296 | for i in range(num_branches_cur): 297 | if i < num_branches_pre: 298 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 299 | transition_layers.append(nn.Sequential( 300 | nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i],3,1,1,bias=False), 301 | BatchNorm2d(num_channels_cur_layer[i], momentum=BN_MOMENTUM), 302 | nn.ReLU(inplace=relu_inplace))) 303 | else: 304 | transition_layers.append(None) 305 | else: 306 | conv3x3s = [] 307 | for j in range(i+1-num_branches_pre): 308 | inchannels = num_channels_pre_layer[-1] 309 | outchannels = num_channels_cur_layer[i] if j == i-num_branches_pre else inchannels 310 | conv3x3s.append(nn.Sequential( 311 | nn.Conv2d(inchannels, outchannels, 3, 2, 1, 
bias=False), 312 | BatchNorm2d(outchannels, momentum=BN_MOMENTUM), 313 | nn.ReLU(inplace=relu_inplace))) 314 | transition_layers.append(nn.Sequential(*conv3x3s)) 315 | 316 | return nn.ModuleList(transition_layers) 317 | 318 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 319 | downsample = None 320 | if stride != 1 or inplanes != planes * block.expansion: 321 | downsample = nn.Sequential( 322 | nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 323 | BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 324 | ) 325 | 326 | layers = [] 327 | layers.append(block(inplanes, planes, stride, downsample)) 328 | inplanes = planes * block.expansion 329 | for i in range(1, blocks): 330 | layers.append(block(inplanes, planes)) 331 | 332 | return nn.Sequential(*layers) 333 | 334 | def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): 335 | num_modules = layer_config.NUM_MODULES 336 | num_branches = layer_config.NUM_BRANCHES 337 | num_blocks = layer_config.NUM_BLOCKS 338 | num_channels = layer_config.NUM_CHANNELS 339 | block = blocks_dict[layer_config.BLOCK] 340 | fuse_method = "SUM" 341 | 342 | modules = [] 343 | for i in range(num_modules): 344 | # multi_scale_output is only used last module 345 | if not multi_scale_output and i == num_modules - 1: 346 | reset_multi_scale_output = False 347 | else: 348 | reset_multi_scale_output = True 349 | modules.append( 350 | HighResolutionModule(num_branches, block, num_blocks, num_inchannels, num_channels, 351 | fuse_method, reset_multi_scale_output) 352 | ) 353 | num_inchannels = modules[-1].get_num_inchannels() 354 | 355 | return nn.Sequential(*modules), num_inchannels 356 | 357 | def forward(self, x): 358 | x = self.conv1(x) 359 | x = self.bn1(x) 360 | x = self.relu(x) 361 | x = self.conv2(x) 362 | x = self.bn2(x) 363 | x = self.relu(x) 364 | x = self.layer1(x) 365 | 366 | x_list = [] 367 | for i in range(self.cfg.MODEL.STAGE_2.NUM_BRANCHES): 368 | if self.transition1[i] is not None: 369 | x_list.append(self.transition1[i](x)) 370 | else: 371 | x_list.append(x) 372 | y_list = self.stage2(x_list) 373 | 374 | x_list = [] 375 | for i in range(self.cfg.MODEL.STAGE_3.NUM_BRANCHES): 376 | if self.transition2[i] is not None: 377 | if i < self.cfg.MODEL.STAGE_2.NUM_BRANCHES: 378 | x_list.append(self.transition2[i](y_list[i])) 379 | else: 380 | x_list.append(self.transition2[i](y_list[-1])) 381 | else: 382 | x_list.append(y_list[i]) 383 | y_list = self.stage3(x_list) 384 | 385 | x_list = [] 386 | for i in range(self.cfg.MODEL.STAGE_4.NUM_BRANCHES): 387 | if self.transition3[i] is not None: 388 | if i < self.cfg.MODEL.STAGE_3.NUM_BRANCHES: 389 | x_list.append(self.transition3[i](y_list[i])) 390 | else: 391 | x_list.append(self.transition3[i](y_list[-1])) 392 | else: 393 | x_list.append(y_list[i]) 394 | x = self.stage4(x_list) 395 | 396 | # Upsampling 397 | x0_h, x0_w = x[0].size(2), x[0].size(3) 398 | x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 399 | x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 400 | x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS) 401 | 402 | x = torch.cat([x[0], x1, x2, x3], 1) 403 | 404 | x = self.last_layer(x) 405 | 406 | x = F.interpolate(input=x, size=self.cfg.DATASET.CROP_SIZE, mode='bilinear', 407 | align_corners=ALIGN_CORNERS) 408 | 409 | x = x.type(torch.float32) 410 | 411 | return x 412 | 413 | def init_weights(self, 
pretrained='',): 414 | logger.info('=> init weights from normal distribution') 415 | for m in self.modules(): 416 | if isinstance(m, nn.Conv2d): 417 | nn.init.normal_(m.weight, std=0.001) 418 | elif isinstance(m, BatchNorm2d_class): 419 | nn.init.constant_(m.weight, 1) 420 | nn.init.constant_(m.bias, 0) 421 | if os.path.isfile(pretrained): 422 | pretrained_dict = torch.load(pretrained) 423 | logger.info('=> loading pretrained model {}'.format(pretrained)) 424 | model_dict = self.state_dict() 425 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()} 426 | for k, _ in pretrained_dict.items(): 427 | logger.info('=> loading {} pretrained model {}'.format(k, pretrained)) 428 | model_dict.update(pretrained_dict) 429 | self.load_state_dict(model_dict) 430 | 431 | def __repr__(self): 432 | attributes = {attr_key: self.__dict__[attr_key] for attr_key in self.__dict__.keys() if '_' not in attr_key[0] and 'training' not in attr_key} 433 | d = {self.__class__.__name__: attributes} 434 | return f'{d}' 435 | 436 | -------------------------------------------------------------------------------- /models/segformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from models.segformer_utils.logger import get_root_logger 8 | from mmcv.runner import load_checkpoint 9 | 10 | 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.dwconv = DWConv(hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | 23 | self.apply(self._init_weights) 24 | 25 | def _init_weights(self, m): 26 | if isinstance(m, nn.Linear): 27 | trunc_normal_(m.weight, std=.02) 28 | if isinstance(m, nn.Linear) and m.bias is not None: 29 | nn.init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.LayerNorm): 31 | nn.init.constant_(m.bias, 0) 32 | nn.init.constant_(m.weight, 1.0) 33 | elif isinstance(m, nn.Conv2d): 34 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | fan_out //= m.groups 36 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 37 | if m.bias is not None: 38 | m.bias.data.zero_() 39 | 40 | def forward(self, x, H, W): 41 | x = self.fc1(x) 42 | x = self.dwconv(x, H, W) 43 | x = self.act(x) 44 | x = self.drop(x) 45 | x = self.fc2(x) 46 | x = self.drop(x) 47 | return x 48 | 49 | 50 | class Attention(nn.Module): 51 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 52 | super().__init__() 53 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
54 | 55 | self.dim = dim 56 | self.num_heads = num_heads 57 | head_dim = dim // num_heads 58 | self.scale = qk_scale or head_dim ** -0.5 59 | 60 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 61 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | self.sr_ratio = sr_ratio 67 | if sr_ratio > 1: 68 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 69 | self.norm = nn.LayerNorm(dim) 70 | 71 | self.apply(self._init_weights) 72 | 73 | def _init_weights(self, m): 74 | if isinstance(m, nn.Linear): 75 | trunc_normal_(m.weight, std=.02) 76 | if isinstance(m, nn.Linear) and m.bias is not None: 77 | nn.init.constant_(m.bias, 0) 78 | elif isinstance(m, nn.LayerNorm): 79 | nn.init.constant_(m.bias, 0) 80 | nn.init.constant_(m.weight, 1.0) 81 | elif isinstance(m, nn.Conv2d): 82 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 83 | fan_out //= m.groups 84 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 85 | if m.bias is not None: 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x, H, W): 89 | B, N, C = x.shape 90 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 91 | 92 | if self.sr_ratio > 1: 93 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 94 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 95 | x_ = self.norm(x_) 96 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 97 | else: 98 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | k, v = kv[0], kv[1] 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | 109 | return x 110 | 111 | 112 | class Block(nn.Module): 113 | 114 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 115 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 116 | super().__init__() 117 | self.norm1 = norm_layer(dim) 118 | self.attn = Attention( 119 | dim, 120 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 121 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 122 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | 128 | self.apply(self._init_weights) 129 | 130 | def _init_weights(self, m): 131 | if isinstance(m, nn.Linear): 132 | trunc_normal_(m.weight, std=.02) 133 | if isinstance(m, nn.Linear) and m.bias is not None: 134 | nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, nn.LayerNorm): 136 | nn.init.constant_(m.bias, 0) 137 | nn.init.constant_(m.weight, 1.0) 138 | elif isinstance(m, nn.Conv2d): 139 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | fan_out //= m.groups 141 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 142 | if m.bias is not None: 143 | m.bias.data.zero_() 144 | 145 | def forward(self, x, H, W): 146 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 147 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 148 | 149 | return x 150 | 151 | 152 | class OverlapPatchEmbed(nn.Module): 153 | """ Image to Patch Embedding 154 | """ 155 | 156 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 157 | super().__init__() 158 | img_size = to_2tuple(img_size) 159 | patch_size = to_2tuple(patch_size) 160 | 161 | self.img_size = img_size 162 | self.patch_size = patch_size 163 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 164 | self.num_patches = self.H * self.W 165 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 166 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 167 | self.norm = nn.LayerNorm(embed_dim) 168 | 169 | self.apply(self._init_weights) 170 | 171 | def _init_weights(self, m): 172 | if isinstance(m, nn.Linear): 173 | trunc_normal_(m.weight, std=.02) 174 | if isinstance(m, nn.Linear) and m.bias is not None: 175 | nn.init.constant_(m.bias, 0) 176 | elif isinstance(m, nn.LayerNorm): 177 | nn.init.constant_(m.bias, 0) 178 | nn.init.constant_(m.weight, 1.0) 179 | elif isinstance(m, nn.Conv2d): 180 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 181 | fan_out //= m.groups 182 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 183 | if m.bias is not None: 184 | m.bias.data.zero_() 185 | 186 | def forward(self, x): 187 | x = self.proj(x) 188 | _, _, H, W = x.shape 189 | x = x.flatten(2).transpose(1, 2) 190 | x = self.norm(x) 191 | 192 | return x, H, W 193 | 194 | 195 | class LinearMLP(nn.Module): 196 | """ 197 | Linear Embedding 198 | """ 199 | def __init__(self, input_dim=2048, embed_dim=768): 200 | super().__init__() 201 | self.proj = nn.Linear(input_dim, embed_dim) 202 | 203 | def forward(self, x): 204 | x = x.flatten(2).transpose(1, 2) 205 | x = self.proj(x) 206 | return x 207 | 208 | 209 | 210 | class DWConv(nn.Module): 211 | def __init__(self, dim=768): 212 | super(DWConv, self).__init__() 213 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 214 | 215 | def forward(self, x, H, W): 216 | B, N, C = x.shape 217 | x = x.transpose(1, 2).view(B, C, H, W) 218 | x = self.dwconv(x) 219 | x = x.flatten(2).transpose(1, 2) 220 | 221 | return x 222 | 223 | 224 | class Segformer(nn.Module): 225 | def __init__( 226 | self, 227 | pretrained=None, 228 | img_size=1024, 229 | patch_size=4, 230 | in_chans=3, 231 | num_classes=19, 232 | embed_dims=[64, 128, 320, 512], 233 | num_heads=[1, 2, 5, 8], 234 | mlp_ratios=[4, 4, 4, 4], 235 | qkv_bias=True, 236 | qk_scale=None, 237 | drop_rate=0., 238 | attn_drop_rate=0., 239 | 
drop_path_rate=0., 240 | norm_layer=nn.LayerNorm, 241 | depths=[3, 6, 40, 3], 242 | sr_ratios=[8, 4, 2, 1], 243 | decoder_dim = 768 244 | ): 245 | super().__init__() 246 | self.num_classes = num_classes 247 | self.depths = depths 248 | 249 | # patch_embed 250 | self.patch_embed1 = OverlapPatchEmbed( 251 | img_size=img_size, 252 | patch_size=7, 253 | stride=4, 254 | in_chans=in_chans, 255 | embed_dim=embed_dims[0] 256 | ) 257 | self.patch_embed2 = OverlapPatchEmbed( 258 | img_size=img_size // 4, 259 | patch_size=3, 260 | stride=2, 261 | in_chans=embed_dims[0], 262 | embed_dim=embed_dims[1] 263 | ) 264 | self.patch_embed3 = OverlapPatchEmbed( 265 | img_size=img_size // 8, 266 | patch_size=3, 267 | stride=2, 268 | in_chans=embed_dims[1], 269 | embed_dim=embed_dims[2] 270 | ) 271 | self.patch_embed4 = OverlapPatchEmbed( 272 | img_size=img_size // 16, 273 | patch_size=3, 274 | stride=2, 275 | in_chans=embed_dims[2], 276 | embed_dim=embed_dims[3] 277 | ) 278 | 279 | # transformer encoder 280 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 281 | cur = 0 282 | self.block1 = nn.ModuleList([Block( 283 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 284 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[0]) 285 | for i in range(depths[0])]) 286 | self.norm1 = norm_layer(embed_dims[0]) 287 | 288 | cur += depths[0] 289 | self.block2 = nn.ModuleList([Block( 290 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 291 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[1]) 292 | for i in range(depths[1])]) 293 | self.norm2 = norm_layer(embed_dims[1]) 294 | 295 | cur += depths[1] 296 | self.block3 = nn.ModuleList([Block( 297 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 298 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[2]) 299 | for i in range(depths[2])]) 300 | self.norm3 = norm_layer(embed_dims[2]) 301 | 302 | cur += depths[2] 303 | self.block4 = nn.ModuleList([Block( 304 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 305 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[3]) 306 | for i in range(depths[3])]) 307 | self.norm4 = norm_layer(embed_dims[3]) 308 | 309 | # classification head 310 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 311 | 312 | 313 | # segmentation head 314 | self.linear_c4 = LinearMLP(input_dim=embed_dims[3], embed_dim=decoder_dim) 315 | self.linear_c3 = LinearMLP(input_dim=embed_dims[2], embed_dim=decoder_dim) 316 | self.linear_c2 = LinearMLP(input_dim=embed_dims[1], embed_dim=decoder_dim) 317 | self.linear_c1 = LinearMLP(input_dim=embed_dims[0], embed_dim=decoder_dim) 318 | self.linear_fuse = nn.Conv2d(4 * decoder_dim, decoder_dim, 1) 319 | self.linear_fuse_bn = nn.BatchNorm2d(decoder_dim) 320 | self.dropout = nn.Dropout2d(drop_rate) 321 | self.linear_pred = nn.Conv2d(decoder_dim, num_classes, kernel_size=1) 322 | 323 | self.apply(self._init_weights) 324 | self.init_weights(pretrained=pretrained) 325 | 326 | 327 | def _init_weights(self, m): 328 | if isinstance(m, nn.Linear): 329 | trunc_normal_(m.weight, 
std=.02) 330 | if isinstance(m, nn.Linear) and m.bias is not None: 331 | nn.init.constant_(m.bias, 0) 332 | elif isinstance(m, nn.LayerNorm): 333 | nn.init.constant_(m.bias, 0) 334 | nn.init.constant_(m.weight, 1.0) 335 | elif isinstance(m, nn.Conv2d): 336 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 337 | fan_out //= m.groups 338 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 339 | if m.bias is not None: 340 | m.bias.data.zero_() 341 | 342 | def init_weights(self, pretrained=None): 343 | if isinstance(pretrained, str): 344 | logger = get_root_logger() 345 | load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 346 | 347 | def reset_drop_path(self, drop_path_rate): 348 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 349 | cur = 0 350 | for i in range(self.depths[0]): 351 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 352 | 353 | cur += self.depths[0] 354 | for i in range(self.depths[1]): 355 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 356 | 357 | cur += self.depths[1] 358 | for i in range(self.depths[2]): 359 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 360 | 361 | cur += self.depths[2] 362 | for i in range(self.depths[3]): 363 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 364 | 365 | def freeze_patch_emb(self): 366 | self.patch_embed1.requires_grad = False 367 | 368 | @torch.jit.ignore 369 | def no_weight_decay(self): 370 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 371 | 372 | def get_classifier(self): 373 | return self.head 374 | 375 | def reset_classifier(self, num_classes, global_pool=''): 376 | self.num_classes = num_classes 377 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 378 | 379 | def forward_features(self, x): 380 | B = x.shape[0] 381 | outs = [] 382 | 383 | # stage 1 384 | x, H, W = self.patch_embed1(x) 385 | for i, blk in enumerate(self.block1): 386 | x = blk(x, H, W) 387 | x = self.norm1(x) 388 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 389 | outs.append(x) 390 | 391 | # stage 2 392 | x, H, W = self.patch_embed2(x) 393 | for i, blk in enumerate(self.block2): 394 | x = blk(x, H, W) 395 | x = self.norm2(x) 396 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 397 | outs.append(x) 398 | 399 | # stage 3 400 | x, H, W = self.patch_embed3(x) 401 | for i, blk in enumerate(self.block3): 402 | x = blk(x, H, W) 403 | x = self.norm3(x) 404 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 405 | outs.append(x) 406 | 407 | # stage 4 408 | x, H, W = self.patch_embed4(x) 409 | for i, blk in enumerate(self.block4): 410 | x = blk(x, H, W) 411 | x = self.norm4(x) 412 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 413 | outs.append(x) 414 | 415 | return outs 416 | 417 | def forward(self, x): 418 | x = self.forward_features(x) 419 | 420 | c1, c2, c3, c4 = x 421 | 422 | ############## MLP decoder on C1-C4 ########### 423 | n, _, h, w = c4.shape 424 | h_out, w_out = c1.size()[2], c1.size()[3] 425 | 426 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 427 | _c4 = F.interpolate(_c4, size = c1.size()[2:], mode = 'bilinear', align_corners = False) 428 | 429 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 430 | _c3 = F.interpolate(_c3, size = c1.size()[2:], mode = 'bilinear', align_corners = False) 431 | 432 | _c2 = 
self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 433 | _c2 = F.interpolate(_c2, size = c1.size()[2:], mode = 'bilinear', align_corners = False) 434 | 435 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 436 | 437 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim = 1)) 438 | _c = self.linear_fuse_bn(_c) 439 | 440 | x = self.dropout(_c) 441 | x = self.linear_pred(x) 442 | 443 | x = F.interpolate(input = x, size = (h_out, w_out), mode = 'bilinear', align_corners = False) 444 | x = x.type(torch.float32) 445 | 446 | return x 447 | -------------------------------------------------------------------------------- /models/segformer_simple.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from functools import partial 3 | import torch 4 | from torch import nn, einsum 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange 9 | 10 | # helpers 11 | 12 | def exists(val): 13 | return val is not None 14 | 15 | def cast_tuple(val, depth): 16 | return val if isinstance(val, tuple) else (val,) * depth 17 | 18 | # classes 19 | 20 | class DsConv2d(nn.Module): 21 | def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True): 22 | super().__init__() 23 | self.net = nn.Sequential( 24 | nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias=bias), 25 | nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias) 26 | ) 27 | def forward(self, x): 28 | return self.net(x) 29 | 30 | 31 | class LayerNorm(nn.Module): 32 | def __init__(self, dim, eps = 1e-5): 33 | super().__init__() 34 | self.eps = eps 35 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) 36 | self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) 37 | 38 | def forward(self, x): 39 | std = torch.var(x, dim = 1, unbiased = False, keepdim = True).sqrt() 40 | mean = torch.mean(x, dim = 1, keepdim = True) 41 | return (x - mean) / (std + self.eps) * self.g + self.b 42 | 43 | 44 | class PreNorm(nn.Module): 45 | def __init__(self, dim, fn): 46 | super().__init__() 47 | self.fn = fn 48 | self.norm = LayerNorm(dim) 49 | 50 | def forward(self, x): 51 | return self.fn(self.norm(x)) 52 | 53 | 54 | class EfficientSelfAttention(nn.Module): 55 | def __init__( 56 | self, 57 | *, 58 | dim, 59 | heads, 60 | reduction_ratio 61 | ): 62 | super().__init__() 63 | self.scale = (dim // heads) ** -0.5 64 | self.heads = heads 65 | 66 | self.to_q = nn.Conv2d(dim, dim, 1, bias = False) 67 | self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False) 68 | self.to_out = nn.Conv2d(dim, dim, 1, bias = False) 69 | 70 | def forward(self, x): 71 | h, w = x.shape[-2:] 72 | heads = self.heads 73 | 74 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1)) 75 | q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads), (q, k, v)) 76 | 77 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 78 | attn = sim.softmax(dim = -1) 79 | 80 | out = einsum('b i j, b j d -> b i d', attn, v) 81 | out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h = heads, x = h, y = w) 82 | return self.to_out(out) 83 | 84 | 85 | class MixFeedForward(nn.Module): 86 | def __init__( 87 | self, 88 | *, 89 | dim, 90 | expansion_factor 91 | ): 92 | super().__init__() 93 | hidden_dim = dim * expansion_factor 94 | self.net = nn.Sequential( 95 | nn.Conv2d(dim, hidden_dim, 1), 96 | 
DsConv2d(hidden_dim, hidden_dim, 3, padding = 1), 97 | nn.GELU(), 98 | nn.Conv2d(hidden_dim, dim, 1) 99 | ) 100 | 101 | def forward(self, x): 102 | return self.net(x) 103 | 104 | 105 | class MiT(nn.Module): 106 | def __init__( 107 | self, 108 | *, 109 | channels, 110 | dims, 111 | heads, 112 | ff_expansion, 113 | reduction_ratio, 114 | num_layers 115 | ): 116 | super().__init__() 117 | stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1)) 118 | 119 | dims = (channels, *dims) 120 | dim_pairs = list(zip(dims[:-1], dims[1:])) 121 | 122 | self.stages = nn.ModuleList([]) 123 | 124 | for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip( 125 | dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio): 126 | 127 | get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding) 128 | overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1) 129 | 130 | layers = nn.ModuleList([]) 131 | 132 | for _ in range(num_layers): 133 | layers.append(nn.ModuleList([ 134 | PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = heads, reduction_ratio = reduction_ratio)), 135 | PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = ff_expansion)), 136 | ])) 137 | 138 | self.stages.append(nn.ModuleList([ 139 | get_overlap_patches, 140 | overlap_patch_embed, 141 | layers 142 | ])) 143 | 144 | def forward( 145 | self, 146 | x, 147 | return_layer_outputs = False 148 | ): 149 | h, w = x.shape[-2:] 150 | 151 | layer_outputs = [] 152 | for (get_overlap_patches, overlap_embed, layers) in self.stages: 153 | x = get_overlap_patches(x) 154 | 155 | num_patches = x.shape[-1] 156 | ratio = int(sqrt((h * w) / num_patches)) 157 | x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio) 158 | 159 | x = overlap_embed(x) 160 | for (attn, ff) in layers: 161 | x = attn(x) + x 162 | x = ff(x) + x 163 | 164 | layer_outputs.append(x) 165 | 166 | ret = x if not return_layer_outputs else layer_outputs 167 | return ret 168 | 169 | 170 | class Segformer(nn.Module): 171 | def __init__( 172 | self, 173 | *, 174 | dims = (32, 64, 160, 256), 175 | heads = (1, 2, 5, 8), 176 | ff_expansion = (8, 8, 4, 4), 177 | reduction_ratio = (8, 4, 2, 1), 178 | num_layers = 2, 179 | channels = 3, 180 | decoder_dim = 256, 181 | num_classes = 19 182 | ): 183 | super().__init__() 184 | dims, heads, ff_expansion, reduction_ratio, num_layers = map( 185 | partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers)) 186 | assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), \ 187 | 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values' 188 | 189 | self.mit = MiT( 190 | channels = channels, 191 | dims = dims, 192 | heads = heads, 193 | ff_expansion = ff_expansion, 194 | reduction_ratio = reduction_ratio, 195 | num_layers = num_layers 196 | ) 197 | 198 | self.to_fused = nn.ModuleList([nn.Sequential( 199 | nn.Conv2d(dim, decoder_dim, 1), 200 | nn.Upsample(scale_factor = 2 ** i) 201 | ) for i, dim in enumerate(dims)]) 202 | 203 | self.to_segmentation = nn.Sequential( 204 | nn.Conv2d(4 * decoder_dim, decoder_dim, 1), 205 | nn.Conv2d(decoder_dim, num_classes, 1), 206 | ) 207 | 208 | def forward(self, x): 209 | layer_outputs = self.mit(x, return_layer_outputs = True) 210 | 211 | fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)] 212 | fused = torch.cat(fused, dim = 1) 213 | 
return self.to_segmentation(fused) 214 | -------------------------------------------------------------------------------- /models/segformer_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/models/segformer_utils/__init__.py -------------------------------------------------------------------------------- /models/segformer_utils/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from mmseg.core import add_prefix 6 | from mmseg.ops import resize 7 | from .. import builder 8 | from ..builder import SEGMENTORS 9 | from .base import BaseSegmentor 10 | 11 | 12 | @SEGMENTORS.register_module() 13 | class EncoderDecoder(BaseSegmentor): 14 | """Encoder Decoder segmentors. 15 | EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. 16 | Note that auxiliary_head is only used for deep supervision during training, 17 | which could be dumped during inference. 18 | """ 19 | 20 | def __init__(self, 21 | backbone, 22 | decode_head, 23 | neck=None, 24 | auxiliary_head=None, 25 | train_cfg=None, 26 | test_cfg=None, 27 | pretrained=None): 28 | super(EncoderDecoder, self).__init__() 29 | self.backbone = builder.build_backbone(backbone) 30 | if neck is not None: 31 | self.neck = builder.build_neck(neck) 32 | self._init_decode_head(decode_head) 33 | self._init_auxiliary_head(auxiliary_head) 34 | 35 | self.train_cfg = train_cfg 36 | self.test_cfg = test_cfg 37 | 38 | self.init_weights(pretrained=pretrained) 39 | 40 | assert self.with_decode_head 41 | 42 | def _init_decode_head(self, decode_head): 43 | """Initialize ``decode_head``""" 44 | self.decode_head = builder.build_head(decode_head) 45 | self.align_corners = self.decode_head.align_corners 46 | self.num_classes = self.decode_head.num_classes 47 | 48 | def _init_auxiliary_head(self, auxiliary_head): 49 | """Initialize ``auxiliary_head``""" 50 | if auxiliary_head is not None: 51 | if isinstance(auxiliary_head, list): 52 | self.auxiliary_head = nn.ModuleList() 53 | for head_cfg in auxiliary_head: 54 | self.auxiliary_head.append(builder.build_head(head_cfg)) 55 | else: 56 | self.auxiliary_head = builder.build_head(auxiliary_head) 57 | 58 | def init_weights(self, pretrained=None): 59 | """Initialize the weights in backbone and heads. 60 | Args: 61 | pretrained (str, optional): Path to pre-trained weights. 62 | Defaults to None. 
63 | """ 64 | 65 | super(EncoderDecoder, self).init_weights(pretrained) 66 | self.backbone.init_weights(pretrained=pretrained) 67 | self.decode_head.init_weights() 68 | if self.with_auxiliary_head: 69 | if isinstance(self.auxiliary_head, nn.ModuleList): 70 | for aux_head in self.auxiliary_head: 71 | aux_head.init_weights() 72 | else: 73 | self.auxiliary_head.init_weights() 74 | 75 | def extract_feat(self, img): 76 | """Extract features from images.""" 77 | x = self.backbone(img) 78 | if self.with_neck: 79 | x = self.neck(x) 80 | return x 81 | 82 | def encode_decode(self, img, img_metas): 83 | """Encode images with backbone and decode into a semantic segmentation 84 | map of the same size as input.""" 85 | x = self.extract_feat(img) 86 | out = self._decode_head_forward_test(x, img_metas) 87 | out = resize( 88 | input=out, 89 | size=img.shape[2:], 90 | mode='bilinear', 91 | align_corners=self.align_corners) 92 | return out 93 | 94 | def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg): 95 | """Run forward function and calculate loss for decode head in 96 | training.""" 97 | losses = dict() 98 | loss_decode = self.decode_head.forward_train(x, img_metas, 99 | gt_semantic_seg, 100 | self.train_cfg) 101 | 102 | losses.update(add_prefix(loss_decode, 'decode')) 103 | return losses 104 | 105 | def _decode_head_forward_test(self, x, img_metas): 106 | """Run forward function and calculate loss for decode head in 107 | inference.""" 108 | seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) 109 | return seg_logits 110 | 111 | def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): 112 | """Run forward function and calculate loss for auxiliary head in 113 | training.""" 114 | losses = dict() 115 | if isinstance(self.auxiliary_head, nn.ModuleList): 116 | for idx, aux_head in enumerate(self.auxiliary_head): 117 | loss_aux = aux_head.forward_train(x, img_metas, 118 | gt_semantic_seg, 119 | self.train_cfg) 120 | losses.update(add_prefix(loss_aux, f'aux_{idx}')) 121 | else: 122 | loss_aux = self.auxiliary_head.forward_train( 123 | x, img_metas, gt_semantic_seg, self.train_cfg) 124 | losses.update(add_prefix(loss_aux, 'aux')) 125 | 126 | return losses 127 | 128 | def forward_dummy(self, img): 129 | """Dummy forward function.""" 130 | seg_logit = self.encode_decode(img, None) 131 | 132 | return seg_logit 133 | 134 | def forward_train(self, img, img_metas, gt_semantic_seg): 135 | """Forward function for training. 136 | Args: 137 | img (Tensor): Input images. 138 | img_metas (list[dict]): List of image info dict where each dict 139 | has: 'img_shape', 'scale_factor', 'flip', and may also contain 140 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 141 | For details on the values of these keys see 142 | `mmseg/datasets/pipelines/formatting.py:Collect`. 143 | gt_semantic_seg (Tensor): Semantic segmentation masks 144 | used if the architecture supports semantic segmentation task. 
145 | Returns: 146 | dict[str, Tensor]: a dictionary of loss components 147 | """ 148 | 149 | x = self.extract_feat(img) 150 | 151 | losses = dict() 152 | 153 | loss_decode = self._decode_head_forward_train(x, img_metas, 154 | gt_semantic_seg) 155 | losses.update(loss_decode) 156 | 157 | if self.with_auxiliary_head: 158 | loss_aux = self._auxiliary_head_forward_train( 159 | x, img_metas, gt_semantic_seg) 160 | losses.update(loss_aux) 161 | 162 | return losses 163 | 164 | # TODO refactor 165 | def slide_inference(self, img, img_meta, rescale): 166 | """Inference by sliding-window with overlap. 167 | If h_crop > h_img or w_crop > w_img, the small patch will be used to 168 | decode without padding. 169 | """ 170 | 171 | h_stride, w_stride = self.test_cfg.stride 172 | h_crop, w_crop = self.test_cfg.crop_size 173 | batch_size, _, h_img, w_img = img.size() 174 | num_classes = self.num_classes 175 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 176 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 177 | preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) 178 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) 179 | for h_idx in range(h_grids): 180 | for w_idx in range(w_grids): 181 | y1 = h_idx * h_stride 182 | x1 = w_idx * w_stride 183 | y2 = min(y1 + h_crop, h_img) 184 | x2 = min(x1 + w_crop, w_img) 185 | y1 = max(y2 - h_crop, 0) 186 | x1 = max(x2 - w_crop, 0) 187 | crop_img = img[:, :, y1:y2, x1:x2] 188 | crop_seg_logit = self.encode_decode(crop_img, img_meta) 189 | preds += F.pad(crop_seg_logit, 190 | (int(x1), int(preds.shape[3] - x2), int(y1), 191 | int(preds.shape[2] - y2))) 192 | 193 | count_mat[:, :, y1:y2, x1:x2] += 1 194 | assert (count_mat == 0).sum() == 0 195 | if torch.onnx.is_in_onnx_export(): 196 | # cast count_mat to constant while exporting to ONNX 197 | count_mat = torch.from_numpy( 198 | count_mat.cpu().detach().numpy()).to(device=img.device) 199 | preds = preds / count_mat 200 | if rescale: 201 | preds = resize( 202 | preds, 203 | size=img_meta[0]['ori_shape'][:2], 204 | mode='bilinear', 205 | align_corners=self.align_corners, 206 | warning=False) 207 | return preds 208 | 209 | def whole_inference(self, img, img_meta, rescale): 210 | """Inference with full image.""" 211 | 212 | seg_logit = self.encode_decode(img, img_meta) 213 | if rescale: 214 | seg_logit = resize( 215 | seg_logit, 216 | size=img_meta[0]['ori_shape'][:2], 217 | mode='bilinear', 218 | align_corners=self.align_corners, 219 | warning=False) 220 | 221 | return seg_logit 222 | 223 | def inference(self, img, img_meta, rescale): 224 | """Inference with slide/whole style. 225 | Args: 226 | img (Tensor): The input image of shape (N, 3, H, W). 227 | img_meta (dict): Image info dict where each dict has: 'img_shape', 228 | 'scale_factor', 'flip', and may also contain 229 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 230 | For details on the values of these keys see 231 | `mmseg/datasets/pipelines/formatting.py:Collect`. 232 | rescale (bool): Whether rescale back to original shape. 233 | Returns: 234 | Tensor: The output segmentation map. 
235 | """ 236 | 237 | assert self.test_cfg.mode in ['slide', 'whole'] 238 | ori_shape = img_meta[0]['ori_shape'] 239 | assert all(_['ori_shape'] == ori_shape for _ in img_meta) 240 | if self.test_cfg.mode == 'slide': 241 | seg_logit = self.slide_inference(img, img_meta, rescale) 242 | else: 243 | seg_logit = self.whole_inference(img, img_meta, rescale) 244 | output = F.softmax(seg_logit, dim=1) 245 | flip = img_meta[0]['flip'] 246 | if flip: 247 | flip_direction = img_meta[0]['flip_direction'] 248 | assert flip_direction in ['horizontal', 'vertical'] 249 | if flip_direction == 'horizontal': 250 | output = output.flip(dims=(3, )) 251 | elif flip_direction == 'vertical': 252 | output = output.flip(dims=(2, )) 253 | 254 | return output 255 | 256 | def simple_test(self, img, img_meta, rescale=True): 257 | """Simple test with single image.""" 258 | seg_logit = self.inference(img, img_meta, rescale) 259 | seg_pred = seg_logit.argmax(dim=1) 260 | if torch.onnx.is_in_onnx_export(): 261 | # our inference backend only support 4D output 262 | seg_pred = seg_pred.unsqueeze(0) 263 | return seg_pred 264 | seg_pred = seg_pred.cpu().numpy() 265 | # unravel batch dim 266 | seg_pred = list(seg_pred) 267 | return seg_pred 268 | 269 | def aug_test(self, imgs, img_metas, rescale=True): 270 | """Test with augmentations. 271 | Only rescale=True is supported. 272 | """ 273 | # aug_test rescale all imgs back to ori_shape for now 274 | assert rescale 275 | # to save memory, we get augmented seg logit inplace 276 | seg_logit = self.inference(imgs[0], img_metas[0], rescale) 277 | for i in range(1, len(imgs)): 278 | cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) 279 | seg_logit += cur_seg_logit 280 | seg_logit /= len(imgs) 281 | seg_pred = seg_logit.argmax(dim=1) 282 | seg_pred = seg_pred.cpu().numpy() 283 | # unravel batch dim 284 | seg_pred = list(seg_pred) 285 | return seg_pred -------------------------------------------------------------------------------- /models/segformer_utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | """Get the root logger. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmseg". 12 | Args: 13 | log_file (str | None): The log filename. If specified, a FileHandler 14 | will be added to the root logger. 15 | log_level (int): The root logger level. Note that only the process of 16 | rank 0 is affected, while other processes will set the level to 17 | "Error" and be silent most of the time. 18 | Returns: 19 | logging.Logger: The root logger. 20 | """ 21 | 22 | logger = get_logger(name='mmseg', log_file=log_file, log_level=log_level) 23 | 24 | return logger 25 | 26 | def print_log(msg, logger=None, level=logging.INFO): 27 | """Print a log message. 28 | Args: 29 | msg (str): The message to be logged. 30 | logger (logging.Logger | str | None): The logger to be used. Some 31 | special loggers are: 32 | - "root": the root logger obtained with `get_root_logger()`. 33 | - "silent": no message will be printed. 34 | - None: The `print()` method will be used to print log messages. 35 | level (int): Logging level. Only available when `logger` is a Logger 36 | object or "root". 
37 | """ 38 | if logger is None: 39 | print(msg) 40 | elif logger == 'root': 41 | _logger = get_root_logger() 42 | _logger.log(level, msg) 43 | elif isinstance(logger, logging.Logger): 44 | logger.log(level, msg) 45 | elif logger != 'silent': 46 | raise TypeError( 47 | 'logger should be either a logging.Logger object, "root", ' 48 | '"silent" or None, but got {}'.format(logger)) -------------------------------------------------------------------------------- /models/segformer_utils/mix_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | # from mmseg.models.builder import BACKBONES 7 | # from mmseg.utils import get_root_logger 8 | from mmcv.runner import load_checkpoint 9 | import math 10 | 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.dwconv = DWConv(hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | 23 | self.apply(self._init_weights) 24 | 25 | def _init_weights(self, m): 26 | if isinstance(m, nn.Linear): 27 | trunc_normal_(m.weight, std=.02) 28 | if isinstance(m, nn.Linear) and m.bias is not None: 29 | nn.init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.LayerNorm): 31 | nn.init.constant_(m.bias, 0) 32 | nn.init.constant_(m.weight, 1.0) 33 | elif isinstance(m, nn.Conv2d): 34 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 35 | fan_out //= m.groups 36 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 37 | if m.bias is not None: 38 | m.bias.data.zero_() 39 | 40 | def forward(self, x, H, W): 41 | x = self.fc1(x) 42 | x = self.dwconv(x, H, W) 43 | x = self.act(x) 44 | x = self.drop(x) 45 | x = self.fc2(x) 46 | x = self.drop(x) 47 | return x 48 | 49 | 50 | class Attention(nn.Module): 51 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 52 | super().__init__() 53 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
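# --- [Editor's note: illustrative sketch, not part of the original file] ---
# This Attention module implements SegFormer's efficient ("spatial-reduction")
# self-attention: when sr_ratio > 1 the key/value tokens are shrunk by a
# strided sr_ratio x sr_ratio convolution before attention, so the attention
# matrix has shape N x (N / sr_ratio**2) instead of N x N.
# A minimal shape walk-through, assuming the mit_b0 stage-1 settings
# (dim=32, num_heads=1, sr_ratio=8) and a 512x1024 input image:
#
#   attn = Attention(dim=32, num_heads=1, qkv_bias=True, sr_ratio=8)
#   x = torch.randn(1, 128 * 256, 32)   # tokens from patch_embed1 (stride 4)
#   out = attn(x, H=128, W=256)         # keys/values reduced to 16 * 32 = 512 tokens
#   print(out.shape)                    # torch.Size([1, 32768, 32])
# ---------------------------------------------------------------------------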
54 | 55 | self.dim = dim 56 | self.num_heads = num_heads 57 | head_dim = dim // num_heads 58 | self.scale = qk_scale or head_dim ** -0.5 59 | 60 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 61 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | self.sr_ratio = sr_ratio 67 | if sr_ratio > 1: 68 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 69 | self.norm = nn.LayerNorm(dim) 70 | 71 | self.apply(self._init_weights) 72 | 73 | def _init_weights(self, m): 74 | if isinstance(m, nn.Linear): 75 | trunc_normal_(m.weight, std=.02) 76 | if isinstance(m, nn.Linear) and m.bias is not None: 77 | nn.init.constant_(m.bias, 0) 78 | elif isinstance(m, nn.LayerNorm): 79 | nn.init.constant_(m.bias, 0) 80 | nn.init.constant_(m.weight, 1.0) 81 | elif isinstance(m, nn.Conv2d): 82 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 83 | fan_out //= m.groups 84 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 85 | if m.bias is not None: 86 | m.bias.data.zero_() 87 | 88 | def forward(self, x, H, W): 89 | B, N, C = x.shape 90 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 91 | 92 | if self.sr_ratio > 1: 93 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 94 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 95 | x_ = self.norm(x_) 96 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 97 | else: 98 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 99 | k, v = kv[0], kv[1] 100 | 101 | attn = (q @ k.transpose(-2, -1)) * self.scale 102 | attn = attn.softmax(dim=-1) 103 | attn = self.attn_drop(attn) 104 | 105 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 106 | x = self.proj(x) 107 | x = self.proj_drop(x) 108 | 109 | return x 110 | 111 | 112 | class Block(nn.Module): 113 | 114 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 115 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 116 | super().__init__() 117 | self.norm1 = norm_layer(dim) 118 | self.attn = Attention( 119 | dim, 120 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 121 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 122 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | 128 | self.apply(self._init_weights) 129 | 130 | def _init_weights(self, m): 131 | if isinstance(m, nn.Linear): 132 | trunc_normal_(m.weight, std=.02) 133 | if isinstance(m, nn.Linear) and m.bias is not None: 134 | nn.init.constant_(m.bias, 0) 135 | elif isinstance(m, nn.LayerNorm): 136 | nn.init.constant_(m.bias, 0) 137 | nn.init.constant_(m.weight, 1.0) 138 | elif isinstance(m, nn.Conv2d): 139 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | fan_out //= m.groups 141 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 142 | if m.bias is not None: 143 | m.bias.data.zero_() 144 | 145 | def forward(self, x, H, W): 146 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 147 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 148 | 149 | return x 150 | 151 | 152 | class OverlapPatchEmbed(nn.Module): 153 | """ Image to Patch Embedding 154 | """ 155 | 156 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 157 | super().__init__() 158 | img_size = to_2tuple(img_size) 159 | patch_size = to_2tuple(patch_size) 160 | 161 | self.img_size = img_size 162 | self.patch_size = patch_size 163 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 164 | self.num_patches = self.H * self.W 165 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 166 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 167 | self.norm = nn.LayerNorm(embed_dim) 168 | 169 | self.apply(self._init_weights) 170 | 171 | def _init_weights(self, m): 172 | if isinstance(m, nn.Linear): 173 | trunc_normal_(m.weight, std=.02) 174 | if isinstance(m, nn.Linear) and m.bias is not None: 175 | nn.init.constant_(m.bias, 0) 176 | elif isinstance(m, nn.LayerNorm): 177 | nn.init.constant_(m.bias, 0) 178 | nn.init.constant_(m.weight, 1.0) 179 | elif isinstance(m, nn.Conv2d): 180 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 181 | fan_out //= m.groups 182 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 183 | if m.bias is not None: 184 | m.bias.data.zero_() 185 | 186 | def forward(self, x): 187 | x = self.proj(x) 188 | _, _, H, W = x.shape 189 | x = x.flatten(2).transpose(1, 2) 190 | x = self.norm(x) 191 | 192 | return x, H, W 193 | 194 | 195 | class MixVisionTransformer(nn.Module): 196 | def __init__(self, 197 | img_size=102, 198 | patch_size=4, 199 | in_chans=3, 200 | num_classes=1000, 201 | embed_dims=[64, 128, 320, 512], 202 | num_heads=[1, 2, 5, 8], 203 | mlp_ratios=[4, 4, 4, 4], 204 | qkv_bias=False, 205 | qk_scale=None, 206 | drop_rate=0., 207 | attn_drop_rate=0., 208 | drop_path_rate=0., 209 | norm_layer=nn.LayerNorm, 210 | depths=[3, 6, 40, 3], 211 | sr_ratios=[8, 4, 2, 1]): 212 | super().__init__() 213 | self.num_classes = num_classes 214 | self.depths = depths 215 | 216 | # patch_embed 217 | self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans, 218 | embed_dim=embed_dims[0]) 219 | self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], 220 | embed_dim=embed_dims[1]) 221 | self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], 222 | embed_dim=embed_dims[2]) 223 | self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, 
stride=2, in_chans=embed_dims[2], 224 | embed_dim=embed_dims[3]) 225 | 226 | # transformer encoder 227 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 228 | cur = 0 229 | self.block1 = nn.ModuleList([Block( 230 | dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 231 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 232 | sr_ratio=sr_ratios[0]) 233 | for i in range(depths[0])]) 234 | self.norm1 = norm_layer(embed_dims[0]) 235 | 236 | cur += depths[0] 237 | self.block2 = nn.ModuleList([Block( 238 | dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 239 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 240 | sr_ratio=sr_ratios[1]) 241 | for i in range(depths[1])]) 242 | self.norm2 = norm_layer(embed_dims[1]) 243 | 244 | cur += depths[1] 245 | self.block3 = nn.ModuleList([Block( 246 | dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 247 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 248 | sr_ratio=sr_ratios[2]) 249 | for i in range(depths[2])]) 250 | self.norm3 = norm_layer(embed_dims[2]) 251 | 252 | cur += depths[2] 253 | self.block4 = nn.ModuleList([Block( 254 | dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 255 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 256 | sr_ratio=sr_ratios[3]) 257 | for i in range(depths[3])]) 258 | self.norm4 = norm_layer(embed_dims[3]) 259 | 260 | # classification head 261 | # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 262 | 263 | self.apply(self._init_weights) 264 | 265 | def _init_weights(self, m): 266 | if isinstance(m, nn.Linear): 267 | trunc_normal_(m.weight, std=.02) 268 | if isinstance(m, nn.Linear) and m.bias is not None: 269 | nn.init.constant_(m.bias, 0) 270 | elif isinstance(m, nn.LayerNorm): 271 | nn.init.constant_(m.bias, 0) 272 | nn.init.constant_(m.weight, 1.0) 273 | elif isinstance(m, nn.Conv2d): 274 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 275 | fan_out //= m.groups 276 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 277 | if m.bias is not None: 278 | m.bias.data.zero_() 279 | 280 | def init_weights(self, pretrained=None): 281 | if isinstance(pretrained, str): 282 | logger = get_root_logger() 283 | load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) 284 | 285 | def reset_drop_path(self, drop_path_rate): 286 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 287 | cur = 0 288 | for i in range(self.depths[0]): 289 | self.block1[i].drop_path.drop_prob = dpr[cur + i] 290 | 291 | cur += self.depths[0] 292 | for i in range(self.depths[1]): 293 | self.block2[i].drop_path.drop_prob = dpr[cur + i] 294 | 295 | cur += self.depths[1] 296 | for i in range(self.depths[2]): 297 | self.block3[i].drop_path.drop_prob = dpr[cur + i] 298 | 299 | cur += self.depths[2] 300 | for i in range(self.depths[3]): 301 | self.block4[i].drop_path.drop_prob = dpr[cur + i] 302 | 303 | def freeze_patch_emb(self): 304 | self.patch_embed1.requires_grad = False 305 | 306 | @torch.jit.ignore 307 | def no_weight_decay(self): 308 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # 
has pos_embed may be better 309 | 310 | def get_classifier(self): 311 | return self.head 312 | 313 | def reset_classifier(self, num_classes, global_pool=''): 314 | self.num_classes = num_classes 315 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 316 | 317 | def forward_features(self, x): 318 | B = x.shape[0] 319 | outs = [] 320 | 321 | # stage 1 322 | x, H, W = self.patch_embed1(x) 323 | for i, blk in enumerate(self.block1): 324 | x = blk(x, H, W) 325 | x = self.norm1(x) 326 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 327 | outs.append(x) 328 | 329 | # stage 2 330 | x, H, W = self.patch_embed2(x) 331 | for i, blk in enumerate(self.block2): 332 | x = blk(x, H, W) 333 | x = self.norm2(x) 334 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 335 | outs.append(x) 336 | 337 | # stage 3 338 | x, H, W = self.patch_embed3(x) 339 | for i, blk in enumerate(self.block3): 340 | x = blk(x, H, W) 341 | x = self.norm3(x) 342 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 343 | outs.append(x) 344 | 345 | # stage 4 346 | x, H, W = self.patch_embed4(x) 347 | for i, blk in enumerate(self.block4): 348 | x = blk(x, H, W) 349 | x = self.norm4(x) 350 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 351 | outs.append(x) 352 | 353 | return outs 354 | 355 | def forward(self, x): 356 | x = self.forward_features(x) 357 | # x = self.head(x) 358 | 359 | return x 360 | 361 | 362 | class DWConv(nn.Module): 363 | def __init__(self, dim=768): 364 | super(DWConv, self).__init__() 365 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 366 | 367 | def forward(self, x, H, W): 368 | B, N, C = x.shape 369 | x = x.transpose(1, 2).view(B, C, H, W) 370 | x = self.dwconv(x) 371 | x = x.flatten(2).transpose(1, 2) 372 | 373 | return x 374 | 375 | 376 | 377 | class mit_b0(MixVisionTransformer): 378 | def __init__(self, **kwargs): 379 | super(mit_b0, self).__init__( 380 | patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 381 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 382 | drop_rate=0.0, drop_path_rate=0.1) 383 | 384 | 385 | class mit_b1(MixVisionTransformer): 386 | def __init__(self, **kwargs): 387 | super(mit_b1, self).__init__( 388 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 389 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 390 | drop_rate=0.0, drop_path_rate=0.1) 391 | 392 | 393 | class mit_b2(MixVisionTransformer): 394 | def __init__(self, **kwargs): 395 | super(mit_b2, self).__init__( 396 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 397 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], 398 | drop_rate=0.0, drop_path_rate=0.1) 399 | 400 | 401 | class mit_b3(MixVisionTransformer): 402 | def __init__(self, **kwargs): 403 | super(mit_b3, self).__init__( 404 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 405 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 406 | drop_rate=0.0, drop_path_rate=0.1) 407 | 408 | 409 | class mit_b4(MixVisionTransformer): 410 | def __init__(self, **kwargs): 411 | super(mit_b4, self).__init__( 412 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 
8], mlp_ratios=[4, 4, 4, 4], 413 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 414 | drop_rate=0.0, drop_path_rate=0.1) 415 | 416 | 417 | class mit_b5(MixVisionTransformer): 418 | def __init__(self, **kwargs): 419 | super(mit_b5, self).__init__( 420 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4], 421 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1], 422 | drop_rate=0.0, drop_path_rate=0.1) 423 | -------------------------------------------------------------------------------- /models/segformer_utils/segformer_build.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/segformer.py', 3 | '../../_base_/datasets/cityscapes_1024x1024_repeat.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k_adamw.py' 6 | ] 7 | 8 | # model settings 9 | norm_cfg = dict(type='SyncBN', requires_grad=True) 10 | find_unused_parameters = True 11 | model = dict( 12 | type='EncoderDecoder', 13 | pretrained='pretrained/mit_b5.pth', 14 | backbone=dict( 15 | type='mit_b5', 16 | style='pytorch'), 17 | decode_head=dict( 18 | type='SegFormerHead', 19 | in_channels=[64, 128, 320, 512], 20 | in_index=[0, 1, 2, 3], 21 | feature_strides=[4, 8, 16, 32], 22 | channels=128, 23 | dropout_ratio=0.1, 24 | num_classes=19, 25 | norm_cfg=norm_cfg, 26 | align_corners=False, 27 | decoder_params=dict(embed_dim=768), 28 | loss_decode=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 29 | # model training and testing settings 30 | train_cfg=dict(), 31 | # test_cfg=dict(mode='whole')) 32 | test_cfg=dict(mode='slide', crop_size=(1024,1024), stride=(768,768))) 33 | 34 | # data 35 | data = dict(samples_per_gpu=1) 36 | evaluation = dict(interval=4000, metric='mIoU') 37 | 38 | # optimizer 39 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 40 | paramwise_cfg=dict(custom_keys={'pos_block': dict(decay_mult=0.), 41 | 'norm': dict(decay_mult=0.), 42 | 'head': dict(lr_mult=10.) 
43 | })) 44 | 45 | lr_config = dict(_delete_=True, policy='poly', 46 | warmup='linear', 47 | warmup_iters=1500, 48 | warmup_ratio=1e-6, 49 | power=1.0, min_lr=0.0, by_epoch=False) -------------------------------------------------------------------------------- /models/segformer_utils/segformer_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule 5 | from collections import OrderedDict 6 | 7 | from mmseg.ops import resize 8 | from ..builder import HEADS 9 | from .decode_head import BaseDecodeHead 10 | from mmseg.models.utils import * 11 | import attr 12 | 13 | from IPython import embed 14 | 15 | class MLP(nn.Module): 16 | """ 17 | Linear Embedding 18 | """ 19 | def __init__(self, input_dim=2048, embed_dim=768): 20 | super().__init__() 21 | self.proj = nn.Linear(input_dim, embed_dim) 22 | 23 | def forward(self, x): 24 | x = x.flatten(2).transpose(1, 2) 25 | x = self.proj(x) 26 | return x 27 | 28 | 29 | @HEADS.register_module() 30 | class SegFormerHead(BaseDecodeHead): 31 | """ 32 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 33 | """ 34 | def __init__(self, feature_strides, **kwargs): 35 | super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs) 36 | assert len(feature_strides) == len(self.in_channels) 37 | assert min(feature_strides) == feature_strides[0] 38 | self.feature_strides = feature_strides 39 | 40 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 41 | 42 | decoder_params = kwargs['decoder_params'] 43 | embedding_dim = decoder_params['embed_dim'] 44 | 45 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 46 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 47 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 48 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 49 | 50 | self.linear_fuse = ConvModule( 51 | in_channels=embedding_dim*4, 52 | out_channels=embedding_dim, 53 | kernel_size=1, 54 | norm_cfg=dict(type='SyncBN', requires_grad=True) 55 | ) 56 | 57 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 58 | 59 | def forward(self, inputs): 60 | x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32 61 | c1, c2, c3, c4 = x 62 | 63 | ############## MLP decoder on C1-C4 ########### 64 | n, _, h, w = c4.shape 65 | 66 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 67 | _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) 68 | 69 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 70 | _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) 71 | 72 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 73 | _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) 74 | 75 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 76 | 77 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 78 | 79 | x = self.dropout(_c) 80 | x = self.linear_pred(x) 81 | 82 | return x -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | addict==2.4.0 3 | alabaster==0.7.12 4 | 
albumentations==1.0.3 5 | apex==0.1 6 | appdirs==1.4.4 7 | argon2-cffi==20.1.0 8 | asgiref==3.4.1 9 | attrs==21.2.0 10 | audioread==2.1.9 11 | Babel==2.9.1 12 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 13 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work 14 | beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1601745390275/work 15 | bleach==4.0.0 16 | blis @ file:///home/conda/feedstock_root/build_artifacts/cython-blis_1607338147605/work 17 | brotlipy==0.7.0 18 | cachetools==4.2.2 19 | catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1619778554042/work 20 | certifi==2021.5.30 21 | cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1625835293160/work 22 | chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1610093490430/work 23 | charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1626371162869/work 24 | click==7.1.2 25 | codecov==2.1.12 26 | colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1602866480661/work 27 | conda==4.10.3 28 | conda-build==3.21.4 29 | conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1618231394280/work 30 | coverage==5.5 31 | cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1616851476134/work 32 | cycler==0.10.0 33 | cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1625012228281/work 34 | Cython==0.29.24 35 | dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work 36 | debugpy==1.4.1 37 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1621187651333/work 38 | defusedxml==0.7.1 39 | Django==3.2.5 40 | docutils==0.16 41 | entrypoints==0.3 42 | expecttest==0.1.3 43 | filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1589994591731/work 44 | flake8==3.7.9 45 | Flask==2.0.1 46 | future==0.18.2 47 | glob2==0.7 48 | google-auth==1.34.0 49 | google-auth-oauthlib==0.4.5 50 | graphsurgeon @ file:///workspace/TensorRT-8.0.1.6/graphsurgeon/graphsurgeon-0.4.5-py2.py3-none-any.whl 51 | graphviz==0.17 52 | grpcio==1.39.0 53 | gunicorn==20.1.0 54 | h11==0.12.0 55 | httptools==0.2.0 56 | hypothesis==4.50.8 57 | idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1609836280497/work 58 | imageio==2.9.0 59 | imagesize==1.2.0 60 | iniconfig==1.1.1 61 | iopath==0.1.9 62 | ipykernel==6.2.0 63 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1627911641747/work 64 | ipython-genutils==0.2.0 65 | itsdangerous==2.0.1 66 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1610146791023/work 67 | Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1621419064915/work 68 | joblib==1.0.1 69 | json5==0.9.6 70 | jsonschema==3.2.0 71 | jupyter-client==6.1.12 72 | jupyter-core==4.7.1 73 | jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a 74 | jupyterlab==2.3.1 75 | jupyterlab-pygments==0.1.2 76 | jupyterlab-server==1.2.0 77 | jupytext==1.11.4 78 | kiwisolver==1.3.1 79 | libarchive-c @ file:///home/conda/feedstock_root/build_artifacts/python-libarchive-c_1622603311319/work/dist 80 | librosa==0.8.1 81 | llvmlite==0.35.0 82 | lmdb==1.2.1 83 | Markdown==3.3.4 84 | markdown-it-py==1.1.0 85 | MarkupSafe @ 
file:///home/conda/feedstock_root/build_artifacts/markupsafe_1621455677251/work 86 | matplotlib==3.4.3 87 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1618935594181/work 88 | mccabe==0.6.1 89 | mdit-py-plugins==0.2.8 90 | mistune==0.8.4 91 | mmcv==1.3.12 92 | mock @ file:///home/conda/feedstock_root/build_artifacts/mock_1610094566888/work 93 | murmurhash @ file:///home/conda/feedstock_root/build_artifacts/murmurhash_1607334246442/work 94 | nbclient==0.5.4 95 | nbconvert==6.1.0 96 | nbformat==5.1.3 97 | nest-asyncio==1.5.1 98 | networkx==2.0 99 | nltk==3.6.2 100 | notebook==6.2.0 101 | numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1607010260266/work 102 | numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1629092056723/work 103 | nvidia-dali-cuda110==1.4.0 104 | nvidia-dlprof-pytorch-nvtx @ file:///nvidia/opt/dlprof/bin/nvidia_dlprof_pytorch_nvtx-1.4.0-py3-none-any.whl 105 | nvidia-dlprofviewer @ file:///opt/dlprof_viewer_install/nvidia_dlprofviewer-1.4.0-py3-none-any.whl 106 | oauthlib==3.1.1 107 | onnx @ file:///opt/pytorch/pytorch/third_party/onnx 108 | opencv-python-headless==4.5.3.56 109 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1625323647219/work 110 | pandocfilters==1.4.3 111 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1617148930513/work 112 | pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1624897245984/work 113 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work 114 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 115 | Pillow @ file:///tmp/pillow-simd 116 | pkginfo @ file:///home/conda/feedstock_root/build_artifacts/pkginfo_1625854086923/work 117 | pluggy==0.13.1 118 | polygraphy==0.32.0 119 | pooch==1.4.0 120 | portalocker==2.3.0 121 | preshed @ file:///home/conda/feedstock_root/build_artifacts/preshed_1625048871810/work 122 | prettytable==2.1.0 123 | prometheus-client==0.11.0 124 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1623977816122/work 125 | protobuf==3.17.3 126 | psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1610127095720/work 127 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 128 | py==1.10.0 129 | pyasn1==0.4.8 130 | pyasn1-modules==0.2.8 131 | pybind11==2.7.1 132 | pycocotools @ git+https://github.com/nvidia/cocoapi.git@9a47a76980d02f70a371e12d4fad61f644a209f1#subdirectory=PythonAPI 133 | pycodestyle==2.5.0 134 | pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1610094800877/work 135 | pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1593275161868/work 136 | pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1620819961022/work 137 | pydot==1.4.2 138 | pyflakes==2.1.1 139 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1629119114968/work 140 | pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1608055815057/work 141 | pyparsing==2.4.7 142 | pyrsistent==0.18.0 143 | PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1610291447907/work 144 | pytest==6.2.4 145 | pytest-cov==2.12.1 146 | pytest-pythonpath==0.7.3 147 | python-dateutil==2.8.2 148 | python-dotenv==0.19.0 149 | python-hostlist==1.21 150 | python-nvd3==0.15.0 151 | python-slugify==5.0.2 152 | 
pytorch-quantization==2.1.0 153 | pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1612179539967/work 154 | PyWavelets==1.1.1 155 | PyYAML==5.4.1 156 | pyzmq==22.2.1 157 | regex==2021.8.3 158 | requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1626393743643/work 159 | requests-oauthlib==1.3.0 160 | resampy==0.2.2 161 | revtok @ git+git://github.com/jekbradbury/revtok.git@f1998b72a941d1e5f9578a66dc1c20b01913caab 162 | rsa==4.7.2 163 | ruamel-yaml-conda @ file:///home/conda/feedstock_root/build_artifacts/ruamel_yaml_1611943339799/work 164 | sacremoses==0.0.45 165 | scikit-image==0.18.3 166 | scikit-learn==0.24.2 167 | scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1619561901336/work 168 | Send2Trash==1.8.0 169 | shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1612179560728/work 170 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 171 | smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1623731621217/work 172 | snowballstemmer==2.1.0 173 | SoundFile==0.10.3.post1 174 | soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1597680516047/work 175 | spacy @ file:///home/conda/feedstock_root/build_artifacts/spacy_1626856035972/work 176 | spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1625687473390/work 177 | Sphinx==4.1.2 178 | sphinx-glpi-theme==0.3 179 | sphinx-rtd-theme==0.5.2 180 | sphinxcontrib-applehelp==1.0.2 181 | sphinxcontrib-devhelp==1.0.2 182 | sphinxcontrib-htmlhelp==2.0.0 183 | sphinxcontrib-jsmath==1.0.1 184 | sphinxcontrib-qthelp==1.0.3 185 | sphinxcontrib-serializinghtml==1.1.5 186 | sqlparse==0.4.1 187 | srsly @ file:///home/conda/feedstock_root/build_artifacts/srsly_1618231649431/work 188 | tabulate==0.8.9 189 | tensorboard==2.6.0 190 | tensorboard-data-server==0.6.1 191 | tensorboard-plugin-wit==1.8.0 192 | tensorboardX==2.4 193 | tensorrt @ file:///workspace/TensorRT-8.0.1.6/python/tensorrt-8.0.1.6-cp38-none-linux_x86_64.whl 194 | terminado==0.11.0 195 | testpath==0.5.0 196 | text-unidecode==1.3 197 | thinc @ file:///home/conda/feedstock_root/build_artifacts/thinc_1626699893620/work 198 | threadpoolctl==2.2.0 199 | tifffile==2021.8.30 200 | timm==0.4.12 201 | toml==0.10.2 202 | torch==1.10.0a0+3fd9dcf 203 | torchtext @ file:///opt/pytorch/text 204 | torchvision @ file:///opt/pytorch/vision 205 | torchviz==0.0.2 206 | tornado==6.1 207 | tqdm==4.62.1 208 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1602771532708/work 209 | typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1609874382867/work 210 | typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1622748266870/work 211 | uff @ file:///workspace/TensorRT-8.0.1.6/uff/uff-0.6.9-py2.py3-none-any.whl 212 | urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1624634538755/work 213 | uvicorn==0.15.0 214 | uvloop==0.16.0 215 | wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1612156086016/work 216 | watchgod==0.7 217 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work 218 | webencodings==0.5.1 219 | websockets==9.1 220 | Werkzeug==2.0.1 221 | whitenoise==5.3.0 222 | yacs==0.1.8 223 | yapf==0.31.0 224 | -------------------------------------------------------------------------------- /src/hrnet_w48_graph.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/src/hrnet_w48_graph.png -------------------------------------------------------------------------------- /src/segformer_b0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/src/segformer_b0_graph.png -------------------------------------------------------------------------------- /src/segformer_simple_b0_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/src/segformer_simple_b0_graph.png -------------------------------------------------------------------------------- /src/stuttgart_hrnet_w48_sample.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/src/stuttgart_hrnet_w48_sample.gif -------------------------------------------------------------------------------- /src/stuttgart_segformer_sample.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/src/stuttgart_segformer_sample.gif -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camlaedtke/segmentation_pytorch/3818a88d4f40b51cf149b1e04a7b94f889c358a9/utils/__init__.py -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | 3 | import os 4 | import sys 5 | import cv2 6 | import glob 7 | import torch 8 | import pathlib 9 | import numpy as np 10 | from torch import nn 11 | from PIL import Image 12 | import torch.nn.functional as F 13 | from skimage.io import imread 14 | import matplotlib.pyplot as plt 15 | from utils.label_utils import get_labels 16 | from sklearn.externals._pilutil import bytescale 17 | 18 | def re_normalize(inp: np.ndarray, low: int = 0, high: int = 255): 19 | """Normalize the data to a certain range. 
Default: [0-255]""" 20 | inp_out = bytescale(inp, low=low, high=high) 21 | return inp_out 22 | 23 | 24 | labels = get_labels() 25 | id2label = { label.id : label for label in labels } 26 | trainid2label = { label.trainId : label for label in labels } 27 | 28 | class SegmentationDataset(torch.utils.data.Dataset): 29 | def __init__(self, cfg: dict, split="train", transform=None, labels=True): 30 | self.cfg = cfg 31 | self.split = split 32 | self.labels = labels 33 | self.crop_size = cfg.CROP_SIZE 34 | self.base_size = cfg.BASE_SIZE 35 | 36 | search_image_files = os.path.join( 37 | cfg.DATA_DIR, 38 | cfg.IMAGE_DIR, 39 | split, '*', 40 | cfg.INPUT_PATTERN) 41 | 42 | if labels: 43 | search_annot_files = os.path.join( 44 | cfg.DATA_DIR, 45 | cfg.LABEL_DIR, 46 | split, '*', 47 | cfg.ANNOT_PATTERN) 48 | 49 | 50 | # root directory 51 | root = pathlib.Path.cwd() 52 | 53 | input_path = str(root / search_image_files) 54 | if labels: 55 | target_path = str(root / search_annot_files) 56 | 57 | self.inputs = [pathlib.PurePath(file) for file in sorted(glob.glob(search_image_files))] 58 | if labels: 59 | self.targets = [pathlib.PurePath(file) for file in sorted(glob.glob(search_annot_files))] 60 | 61 | print("Images: {} , Labels: {}".format(len(self.inputs), len(self.targets))) 62 | 63 | self.transform = transform 64 | self.inputs_dtype = torch.float32 65 | if labels: 66 | self.targets_dtype = torch.int64 67 | 68 | self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 69 | 70 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 71 | 1.0166, 0.9969, 0.9754, 1.0489, 72 | 0.8786, 1.0023, 0.9539, 0.9843, 73 | 1.1116, 0.9037, 1.0865, 1.0955, 74 | 1.0865, 1.1529, 1.0507]).to(self.device) 75 | 76 | 77 | def __len__(self): 78 | return len(self.inputs) 79 | 80 | def __getitem__(self, index: int): 81 | 82 | # Select the sample 83 | input_ID = self.inputs[index] 84 | if self.labels: 85 | target_ID = self.targets[index] 86 | name = os.path.splitext(os.path.basename(input_ID))[0] 87 | 88 | # Load input and target 89 | if self.labels: 90 | x, y = imread(str(input_ID)), imread(str(target_ID)) 91 | else: 92 | x = imread(str(input_ID)) 93 | size = x.shape 94 | 95 | # Preprocessing 96 | if (self.transform is not None) and self.labels: 97 | x, y = self.transform(x, y) 98 | elif self.transform is not None: 99 | x = self.transform(x) 100 | 101 | # Typecasting 102 | if self.labels: 103 | x, y = torch.from_numpy(x).type(self.inputs_dtype), torch.from_numpy(y).type(self.targets_dtype) 104 | y = y.squeeze() 105 | return x, y, np.array(size), name 106 | else: 107 | x = torch.from_numpy(x).type(self.inputs_dtype) 108 | return x, np.array(size), name 109 | 110 | 111 | def inference(self, model, image): 112 | # assume input image is channels first 113 | batch, _, ori_height, ori_width = image.size() 114 | assert batch == 1, "only supporting batchsize 1." 
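# --- [Editor's note: illustrative sketch, not part of the original file] ---
# The code below downscales the full frame to cfg.CROP_SIZE, runs one forward
# pass, and bilinearly upsamples the logits back to the original resolution.
# A hedged usage sketch (CPU placement of the model and the existence of a
# `dataset` instance are assumptions):
#
#   x, y, size, name = dataset[0]
#   pred = dataset.inference(model.cpu().eval(), x.unsqueeze(0))
#   # pred: (1, num_classes, ori_height, ori_width)
#
# Note: the constructor's print(...) above reads len(self.targets)
# unconditionally, so building the dataset with labels=False raises an
# AttributeError before this method can be used.
# ---------------------------------------------------------------------------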
115 | # convert to channels last for resizing 116 | image = image.numpy()[0].transpose((1,2,0)).copy() 117 | h, w = self.crop_size 118 | new_img = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR) 119 | # convert to channels first for inference 120 | new_img = new_img.transpose((2, 0, 1)) 121 | new_img = np.expand_dims(new_img, axis=0) 122 | pred = model(torch.from_numpy(new_img)) 123 | # resize to base size 124 | pred = F.interpolate(input=pred, size=(ori_height, ori_width), mode='bilinear', align_corners=False) 125 | # pred = pred.numpy() 126 | return pred.exp() 127 | 128 | 129 | def sliding_window(self, im, crop_size, stride): 130 | B, C, H, W = im.shape 131 | cs = crop_size 132 | 133 | windows = {"crop": [], "anchors": []} 134 | h_anchors = np.arange(0, H, stride[0]) 135 | w_anchors = np.arange(0, W, stride[1]) 136 | 137 | h_anchors = [h.item() for h in h_anchors if h < H - cs[0]] + [H - cs[0]] 138 | w_anchors = [w.item() for w in w_anchors if w < W - cs[1]] + [W - cs[1]] 139 | for ha in h_anchors: 140 | for wa in w_anchors: 141 | window = im[:, :, ha : ha + cs[0], wa : wa + cs[1]] 142 | windows["crop"].append(window) 143 | windows["anchors"].append((ha, wa)) 144 | windows["shape"] = (H, W) 145 | return windows 146 | 147 | 148 | def merge_windows(self, windows, crop_size, ori_shape): 149 | cs = crop_size 150 | im_windows = windows["crop_seg"] 151 | anchors = windows["anchors"] 152 | C = im_windows[0].shape[1] 153 | H, W = windows["shape"] 154 | 155 | logit = np.zeros((C, H, W)) 156 | count = np.zeros((1, H, W)) 157 | for window, (ha, wa) in zip(im_windows, anchors): 158 | # print("window.shape: {}, (ha, wa): ({}, {})".format(window.shape, ha, wa)) 159 | logit[:, ha : ha + cs[0], wa : wa + cs[1]] += window.squeeze() 160 | count[:, ha : ha + cs[0], wa : wa + cs[1]] += 1 161 | 162 | logit = logit / count 163 | logit = F.interpolate(torch.from_numpy(logit).unsqueeze(0), ori_shape, mode="bilinear")[0] 164 | result = F.softmax(logit, 0) 165 | return result.numpy() 166 | 167 | 168 | def sliding_inference(self, model, image): 169 | # assume input image is channels first 170 | batch, _, ori_height, ori_width = image.size() 171 | assert batch == 1, "only supporting batchsize 1." 
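# --- [Editor's note: illustrative worked example, not part of the original file] ---
# sliding_window() above tiles the frame with overlapping crops and
# merge_windows() averages the per-crop logits wherever tiles overlap.
# For example, a 1024x2048 frame with crop_size=(1024, 1024) and the
# stride=(768, 768) used below yields:
#
#   h_anchors = [0]              # arange(0, 1024, 768) filtered, plus H - 1024
#   w_anchors = [0, 768, 1024]   # arange(0, 2048, 768) filtered, plus W - 1024
#
# i.e. three 1024x1024 crops whose logits are summed into `logit`, normalised
# by `count`, resized to ori_shape and softmaxed over the class dimension.
# ---------------------------------------------------------------------------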
172 | 173 | # gather sliding windows 174 | windows = self.sliding_window(im=image, crop_size=self.cfg.CROP_SIZE, stride=(768, 768)) 175 | crop_list = windows['crop'] 176 | 177 | # make predictions on windows 178 | pred_list = [] 179 | for x_crop in crop_list: 180 | pred = model(x_crop.to(self.device)) 181 | pred = F.interpolate(pred, self.cfg.CROP_SIZE, mode="bilinear", align_corners=False) 182 | pred_list.append(pred.detach().numpy()) 183 | 184 | windows['crop_seg'] = pred_list 185 | pred = self.merge_windows(windows, crop_size=self.cfg.CROP_SIZE, ori_shape=self.cfg.BASE_SIZE) 186 | # print(pred.shape) 187 | pred = np.expand_dims(pred, axis=0) 188 | return torch.from_numpy(np.exp(pred)) 189 | 190 | 191 | def label_to_rgb(self, seg): 192 | h = seg.shape[0] 193 | w = seg.shape[1] 194 | seg_rgb = np.zeros((h, w, 3), dtype=np.uint8) 195 | for key, val in trainid2label.items(): 196 | indices = seg == key 197 | seg_rgb[indices.squeeze()] = val.color 198 | return seg_rgb 199 | 200 | 201 | def save_pred(self, image, pred, sv_path, name): 202 | # pred = np.asarray(np.argmax(pred.cpu(), axis=1), dtype=np.uint8) 203 | # pred = np.asarray(np.argmax(pred, axis=0), dtype=np.uint8) 204 | pred = np.asarray(pred, dtype=np.uint8) 205 | # convert to channels last 206 | pred = pred.transpose((1,2,0)).copy() 207 | # print(pred.shape) 208 | # PROBLEM IS HERE 209 | pred = np.argmax(pred, axis=-1) 210 | # print(pred.shape) 211 | pred = self.label_to_rgb(pred) 212 | # print(pred.shape) 213 | image = image.cpu() 214 | # print(image.shape) 215 | image = image[0].permute(1,2,0).numpy() 216 | # print(image.shape) 217 | image = re_normalize(image) 218 | 219 | blend = cv2.addWeighted(image, 0.8, pred, 0.6, 0) 220 | pil_blend = Image.fromarray(blend).convert("RGB") 221 | pil_blend.save(os.path.join(sv_path, name[0]+'.png')) 222 | 223 | 224 | 225 | 226 | def label_mapping(seg: np.ndarray, label_map: dict): 227 | seg = seg.astype(np.int32) 228 | temp = np.copy(seg) 229 | for key, val in label_map.items(): 230 | seg[temp == key] = val.trainId 231 | return seg 232 | 233 | 234 | def cityscapes_label_to_rgb(mask): 235 | h = mask.shape[0] 236 | w = mask.shape[1] 237 | mask_rgb = np.zeros((h, w, 3), dtype=np.uint8) 238 | for key, val in trainid2label.items(): 239 | indices = mask == key 240 | mask_rgb[indices.squeeze()] = val.color 241 | return mask_rgb 242 | 243 | 244 | def display(display_list): 245 | plt.figure(figsize=(15, 5), dpi=150) 246 | title = ['Input Image', 'True Mask', 'Predicted Mask'] 247 | for i in range(len(display_list)): 248 | plt.subplot(1, len(display_list), i+1) 249 | plt.title(title[i]) 250 | plt.imshow(display_list[i]) 251 | plt.axis('off') 252 | plt.tight_layout() 253 | plt.show() 254 | 255 | 256 | def display_blend(display_list): 257 | plt.figure(figsize=(10, 10), dpi=150) 258 | for i in range(len(display_list)): 259 | blend = cv2.addWeighted(display_list[i][0], 0.8, display_list[i][1], 0.6, 0) 260 | plt.subplot(1, len(display_list), i+1) 261 | plt.imshow(blend) 262 | plt.axis('off') 263 | plt.tight_layout() 264 | plt.show() 265 | -------------------------------------------------------------------------------- /utils/label_utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from typing import List, Callable, Tuple 3 | 4 | 5 | 6 | def get_labels(): 7 | 8 | # a label and all meta information 9 | Label = namedtuple( 'Label' , [ 10 | 11 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 
12 | # We use them to uniquely name a class 13 | 14 | 'id' , # An integer ID that is associated with this label. 15 | # The IDs are used to represent the label in ground truth images 16 | # An ID of -1 means that this label does not have an ID and thus 17 | # is ignored when creating ground truth images (e.g. license plate). 18 | # Do not modify these IDs, since exactly these IDs are expected by the 19 | # evaluation server. 20 | 21 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 22 | # ground truth images with train IDs, using the tools provided in the 23 | # 'preparation' folder. However, make sure to validate or submit results 24 | # to our evaluation server using the regular IDs above! 25 | # For trainIds, multiple labels might have the same ID. Then, these labels 26 | # are mapped to the same class in the ground truth images. For the inverse 27 | # mapping, we use the label that is defined first in the list below. 28 | # For example, mapping all void-type classes to the same ID in training, 29 | # might make sense for some approaches. 30 | # Max value is 255! 31 | 32 | 'category' , # The name of the category that this label belongs to 33 | 34 | 'categoryId' , # The ID of this category. Used to create ground truth images 35 | # on category level. 36 | 37 | 'hasInstances', # Whether this label distinguishes between single instances or not 38 | 39 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 40 | # during evaluations or not 41 | 42 | 'color' , # The color of this label 43 | ] ) 44 | 45 | 46 | #-------------------------------------------------------------------------------- 47 | # A list of all labels 48 | #-------------------------------------------------------------------------------- 49 | 50 | # Please adapt the train IDs as appropriate for your approach. 51 | # Note that you might want to ignore labels with ID 255 during training. 52 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 53 | # Make sure to provide your results using the original IDs and not the training IDs. 54 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
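# --- [Editor's note: illustrative sketch, not part of the original file] ---
# The table below is the standard Cityscapes label set: 19 classes keep a
# trainId in 0..18 and everything else (void, parking, rail track, ...) maps
# to the ignore index 255. utils/data_utils.py builds its lookup dicts from it:
#
#   labels = get_labels()
#   id2label      = {label.id: label for label in labels}
#   trainid2label = {label.trainId: label for label in labels}
#   id2label[7].name, id2label[7].trainId, id2label[7].color
#   # ('road', 0, (128, 64, 128))
# ---------------------------------------------------------------------------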
55 | 56 | labels = [ 57 | # name id trainId category catId hasInstances ignoreInEval color 58 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 59 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 60 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 61 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 62 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 63 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 64 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 65 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 66 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 67 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 68 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 69 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 70 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 71 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 72 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 73 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 74 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 75 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 76 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 77 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 78 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 79 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 80 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 81 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 82 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 83 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 84 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 85 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 86 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 87 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 88 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 89 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 90 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 91 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 92 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 93 | ] 94 | 95 | return labels -------------------------------------------------------------------------------- /utils/lr_schedule.py: -------------------------------------------------------------------------------- 1 | class LrUpdater(): 2 | """Modified version of LR Scheduler in MMCV. 3 | 4 | Args: 5 | by_epoch (bool): LR changes epoch by epoch 6 | warmup (string): Type of warmup used. 
It can be None(use no warmup), 7 | 'constant', 'linear' or 'exp' 8 | warmup_iters (int): The number of iterations or epochs that warmup 9 | lasts 10 | warmup_ratio (float): LR used at the beginning of warmup equals to 11 | warmup_ratio * initial_lr 12 | warmup_by_epoch (bool): When warmup_by_epoch == True, warmup_iters 13 | means the number of epochs that warmup lasts, otherwise means the 14 | number of iteration that warmup lasts 15 | """ 16 | 17 | def __init__(self, 18 | optimizer, 19 | epoch_len, 20 | by_epoch=True, 21 | warmup=None, 22 | warmup_iters=0, 23 | warmup_ratio=0.1, 24 | warmup_by_epoch=False): 25 | # validate the "warmup" argument 26 | if warmup is not None: 27 | if warmup not in ['constant', 'linear', 'exp']: 28 | raise ValueError( 29 | f'"{warmup}" is not a supported type for warming up, valid' 30 | ' types are "constant" and "linear"') 31 | if warmup is not None: 32 | assert warmup_iters > 0, \ 33 | '"warmup_iters" must be a positive integer' 34 | assert 0 < warmup_ratio <= 1.0, \ 35 | '"warmup_ratio" must be in range (0,1]' 36 | 37 | self.optimizer = optimizer 38 | self.epoch_len = epoch_len 39 | self.by_epoch = by_epoch 40 | self.warmup = warmup 41 | self.warmup_iters = warmup_iters 42 | self.warmup_ratio = warmup_ratio 43 | self.warmup_by_epoch = warmup_by_epoch 44 | 45 | if self.warmup_by_epoch: 46 | self.warmup_epochs = self.warmup_iters 47 | self.warmup_iters = None 48 | else: 49 | self.warmup_epochs = None 50 | 51 | self.base_lr = [] # initial lr for all param groups 52 | self.regular_lr = [] # expected lr if no warming up is performed 53 | 54 | def _set_lr(self, lr_groups): 55 | if isinstance(self.optimizer, dict): 56 | for k, optim in self.optimizer.items(): 57 | for param_group, lr in zip(optim.param_groups, lr_groups[k]): 58 | param_group['lr'] = lr 59 | else: 60 | for param_group, lr in zip(self.optimizer.param_groups, lr_groups): 61 | param_group['lr'] = lr 62 | 63 | def get_lr(self, base_lr): 64 | raise NotImplementedError 65 | 66 | def get_regular_lr(self): 67 | if isinstance(self.optimizer, dict): 68 | lr_groups = {} 69 | for k in self.optimizer.keys(): 70 | _lr_group = [self.get_lr(self.cur_iter, _base_lr) for _base_lr in self.base_lr[k]] 71 | lr_groups.update({k: _lr_group}) 72 | 73 | return lr_groups 74 | else: 75 | return [self.get_lr(self.cur_iter, _base_lr) for _base_lr in self.base_lr] 76 | 77 | def get_warmup_lr(self): 78 | 79 | def _get_warmup_lr(regular_lr): 80 | k = (1 - self.cur_iter / self.warmup_iters) * (1 - self.warmup_ratio) 81 | warmup_lr = [_lr * (1 - k) for _lr in regular_lr] 82 | return warmup_lr 83 | 84 | if isinstance(self.regular_lr, dict): 85 | lr_groups = {} 86 | for key, regular_lr in self.regular_lr.items(): 87 | lr_groups[key] = _get_warmup_lr(regular_lr) 88 | return lr_groups 89 | else: 90 | return _get_warmup_lr(self.regular_lr) 91 | 92 | def before_run(self): 93 | # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, 94 | # it will be set according to the optimizer params 95 | if isinstance(self.optimizer, dict): 96 | self.base_lr = {} 97 | for k, optim in self.optimizer.items(): 98 | for group in optim.param_groups: 99 | group.setdefault('initial_lr', group['lr']) 100 | _base_lr = [group['initial_lr'] for group in optim.param_groups] 101 | self.base_lr.update({k: _base_lr}) 102 | else: 103 | for group in self.optimizer.param_groups: 104 | group.setdefault('initial_lr', group['lr']) 105 | self.base_lr = [group['initial_lr'] for group in self.optimizer.param_groups] 106 | 107 | def before_train_epoch(self): 
108 | if self.warmup_iters is None: 109 | self.warmup_iters = self.warmup_epochs * self.epoch_len 110 | 111 | if not self.by_epoch: 112 | return 113 | 114 | self.regular_lr = self.get_regular_lr() 115 | self._set_lr(self.regular_lr) 116 | 117 | def before_train_iter(self): 118 | self.regular_lr = self.get_regular_lr() 119 | if self.warmup is None or self.cur_iter >= self.warmup_iters: 120 | self._set_lr(self.regular_lr) 121 | else: 122 | warmup_lr = self.get_warmup_lr() 123 | self._set_lr(warmup_lr) 124 | 125 | 126 | 127 | 128 | class PolyLrUpdater(LrUpdater): 129 | 130 | def __init__(self, max_iters, power=1., min_lr=0., **kwargs): 131 | self.power = power 132 | self.min_lr = min_lr 133 | self.max_iters = max_iters 134 | super(PolyLrUpdater, self).__init__(**kwargs) 135 | 136 | self.cur_iter = 1 137 | 138 | def get_lr(self, step, base_lr): 139 | self.cur_iter = step 140 | progress = self.cur_iter 141 | max_progress = self.max_iters 142 | coeff = (1 - progress / max_progress)**self.power 143 | return (base_lr - self.min_lr) * coeff + self.min_lr -------------------------------------------------------------------------------- /utils/modelsummary.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import logging 7 | from collections import namedtuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | def get_model_summary(model, *input_tensors, item_length=20, verbose=False): 13 | """ 14 | :param model: 15 | :param input_tensors: 16 | :param item_length: 17 | :return: 18 | """ 19 | 20 | summary = [] 21 | 22 | ModuleDetails = namedtuple("Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"]) 23 | hooks = [] 24 | layer_instances = {} 25 | 26 | def add_hooks(module): 27 | 28 | def hook(module, input, output): 29 | class_name = str(module.__class__.__name__) 30 | 31 | instance_index = 1 32 | if class_name not in layer_instances: 33 | layer_instances[class_name] = instance_index 34 | else: 35 | instance_index = layer_instances[class_name] + 1 36 | layer_instances[class_name] = instance_index 37 | 38 | layer_name = class_name + "_" + str(instance_index) 39 | 40 | params = 0 41 | 42 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or class_name.find("Linear") != -1: 43 | for param_ in module.parameters(): 44 | params += param_.view(-1).size(0) 45 | 46 | flops = "Not Available" 47 | if class_name.find("Conv") != -1 and hasattr(module, "weight"): 48 | flops = ( 49 | torch.prod(torch.LongTensor(list(module.weight.data.size()))) * 50 | torch.prod(torch.LongTensor(list(output.size())[2:]))).item() 51 | elif isinstance(module, nn.Linear): 52 | flops = (torch.prod(torch.LongTensor(list(output.size()))) * input[0].size(1)).item() 53 | 54 | if isinstance(input[0], list): 55 | input = input[0] 56 | if isinstance(output, list): 57 | output = output[0] 58 | if isinstance(output, tuple): 59 | output = output[0] 60 | 61 | 62 | summary.append( 63 | ModuleDetails( 64 | name=layer_name, 65 | input_size=list(input[0].size()), 66 | output_size=list(output.size()), 67 | num_parameters=params, 68 | multiply_adds=flops) 69 | ) 70 | 71 | if not isinstance(module, nn.ModuleList) \ 72 | and not isinstance(module, nn.Sequential) \ 73 | and module != model: 74 | hooks.append(module.register_forward_hook(hook)) 75 | 76 | model.eval() 77 | model.apply(add_hooks) 78 | 79 | space_len = item_length 
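    # A single dummy forward pass fires every registered hook once, recording each
    # layer's input/output shapes, parameter count and multiply-adds; the hooks are
    # removed again immediately afterwards.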
80 | 81 | model(*input_tensors) 82 | for hook in hooks: 83 | hook.remove() 84 | 85 | details = '' 86 | if verbose: 87 | details = "Model Summary" + \ 88 | os.linesep + \ 89 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( 90 | ' ' * (space_len - len("Name")), 91 | ' ' * (space_len - len("Input Size")), 92 | ' ' * (space_len - len("Output Size")), 93 | ' ' * (space_len - len("Parameters")), 94 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \ 95 | + os.linesep + '-' * space_len * 5 + os.linesep 96 | 97 | params_sum = 0 98 | flops_sum = 0 99 | for layer in summary: 100 | params_sum += layer.num_parameters 101 | if layer.multiply_adds != "Not Available": 102 | flops_sum += layer.multiply_adds 103 | if verbose: 104 | details += "{}{}{}{}{}{}{}{}{}{}".format( 105 | layer.name, 106 | ' ' * (space_len - len(layer.name)), 107 | layer.input_size, 108 | ' ' * (space_len - len(str(layer.input_size))), 109 | layer.output_size, 110 | ' ' * (space_len - len(str(layer.output_size))), 111 | layer.num_parameters, 112 | ' ' * (space_len - len(str(layer.num_parameters))), 113 | layer.multiply_adds, 114 | ' ' * (space_len - len(str(layer.multiply_adds)))) \ 115 | + os.linesep + '-' * space_len * 5 + os.linesep 116 | 117 | details += os.linesep \ 118 | + "Total Parameters: {:,}".format(params_sum) \ 119 | + os.linesep + '-' * space_len * 5 + os.linesep 120 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \ 121 | + os.linesep + '-' * space_len * 5 + os.linesep 122 | details += "Number of Layers" + os.linesep 123 | for layer in layer_instances: 124 | details += "{} : {} layers ".format(layer, layer_instances[layer]) 125 | 126 | return details -------------------------------------------------------------------------------- /utils/runners.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import torch 5 | import pathlib 6 | import logging 7 | import numpy as np 8 | from torch import nn 9 | from tqdm import tqdm 10 | from tensorboardX import SummaryWriter 11 | from utils.modelsummary import get_model_summary 12 | from utils.train_utils import AverageMeter, get_confusion_matrix, adjust_learning_rate, create_logger 13 | 14 | 15 | def train( 16 | cfg, 17 | dataloader, 18 | model, 19 | loss_fn, 20 | optimizer, 21 | lr_scheduler, 22 | scaler, 23 | writer_dict, 24 | epoch, 25 | ): 26 | model.train() 27 | 28 | ave_loss = AverageMeter() 29 | steps_tot = epoch*len(dataloader) 30 | writer = writer_dict['writer'] 31 | global_steps = writer_dict['train_global_steps'] 32 | 33 | for step, batch in enumerate(dataloader): 34 | X, y, _, _ = batch 35 | X, y = X.cuda(), y.long().cuda() 36 | 37 | # Compute prediction and loss 38 | with torch.cuda.amp.autocast(): 39 | pred = model(X) 40 | losses = loss_fn(pred, y) 41 | loss = losses.mean() 42 | 43 | # Backpropagation 44 | scaler.scale(loss).backward() 45 | scaler.step(optimizer) 46 | scaler.update() 47 | optimizer.zero_grad() 48 | 49 | # update average loss 50 | ave_loss.update(loss.item()) 51 | 52 | # update learning schedule 53 | lr_scheduler.before_train_iter() 54 | lr = lr_scheduler.get_lr(int(steps_tot+step), cfg.TRAIN.BASE_LR) 55 | #lr = adjust_learning_rate(optimizer, cfg['BASE_LR'], cfg['END_LR'], step+steps_tot, cfg['DECAY_STEPS']) 56 | 57 | writer.add_scalar('train_loss', ave_loss.average(), global_steps) 58 | writer_dict['train_global_steps'] = global_steps + 1 59 | 60 | 61 | 62 | def 
validate(cfg, dataloader, model, loss_fn, writer_dict):
63 |     model.eval()
64 | 
65 |     ave_loss = AverageMeter()
66 |     iter_steps = len(dataloader.dataset) // cfg.BATCH_SIZE
67 |     confusion_matrix = np.zeros((cfg.DATASET.NUM_CLASSES, cfg.DATASET.NUM_CLASSES, 1))
68 | 
69 |     with torch.no_grad():
70 |         for idx, batch in enumerate(dataloader):
71 |             X, y, _, _ = batch
72 |             size = y.size()
73 |             X, y = X.cuda(), y.long().cuda()
74 | 
75 |             pred = model(X)
76 |             losses = loss_fn(pred, y)
77 |             loss = losses.mean()
78 | 
79 |             if not isinstance(pred, (list, tuple)):
80 |                 pred = [pred]
81 |             for i, x in enumerate(pred):
82 |                 confusion_matrix[..., i] += get_confusion_matrix(
83 |                     y, x, size, cfg.DATASET.NUM_CLASSES, cfg.DATASET.NUM_CLASSES)
84 |             ave_loss.update(loss.item())
85 | 
86 |     pos = confusion_matrix[..., 0].sum(1)
87 |     res = confusion_matrix[..., 0].sum(0)
88 |     tp = np.diag(confusion_matrix[..., 0])
89 |     IoU_array = (tp / np.maximum(1.0, pos + res - tp))
90 |     mean_IoU = IoU_array.mean()
91 | 
92 |     writer = writer_dict['writer']
93 |     global_steps = writer_dict['valid_global_steps']
94 |     writer.add_scalar('valid_loss', ave_loss.average(), global_steps)
95 |     writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
    # trainId -> Label lookup for per-class mIoU logging (same mapping as built in utils/transformations.py)
    from utils.label_utils import get_labels
    trainid2label = {label.trainId: label for label in get_labels()}
96 |     for key, val in trainid2label.items():
97 |         if key != cfg.DATASET.IGNORE_LABEL and key != -1:
98 |             writer.add_scalar('valid_mIoU_{}'.format(val.name), IoU_array[key], global_steps)
99 |     writer_dict['valid_global_steps'] = global_steps + 1
100 | 
101 |     return ave_loss.average(), mean_IoU, IoU_array
102 | 
103 | 
104 | 
105 | def testval(cfg, testloader, model, sv_dir='', sv_pred=False, sliding_inf=False):
106 |     model.eval()
107 |     confusion_matrix = np.zeros((cfg.DATASET.NUM_CLASSES, cfg.DATASET.NUM_CLASSES))
108 | 
109 |     with torch.no_grad():
110 |         for index, batch in enumerate(tqdm(testloader)):
111 |             image, label, _, name, *border_padding = batch
112 |             size = label.size()
113 |             if sliding_inf:
114 |                 pred = testloader.dataset.sliding_inference(model, image)
115 |             else:
116 |                 pred = testloader.dataset.inference(model, image)
117 | 
118 |             confusion_matrix += get_confusion_matrix(
119 |                 label, pred, size, cfg.DATASET.NUM_CLASSES, cfg.DATASET.NUM_CLASSES)
120 | 
121 |             if index % 100 == 0:
122 |                 logging.info('processing: %d images' % index)
123 |                 pos = confusion_matrix.sum(1)
124 |                 res = confusion_matrix.sum(0)
125 |                 tp = np.diag(confusion_matrix)
126 |                 IoU_array = (tp / np.maximum(1.0, pos + res - tp))
127 |                 mean_IoU = IoU_array.mean()
128 |                 logging.info('mIoU: %.4f' % (mean_IoU))
129 | 
130 |             if sv_pred:
131 |                 sv_path = os.path.join(sv_dir, 'test_results')
132 |                 if not os.path.exists(sv_path):
133 |                     os.mkdir(sv_path)
134 |                 if sliding_inf:
135 |                     # print(pred.shape)
136 |                     pred = np.squeeze(pred, 0)
137 |                 testloader.dataset.save_pred(image, pred, sv_path, name)
138 | 
139 |     pos = confusion_matrix.sum(1)
140 |     res = confusion_matrix.sum(0)
141 |     tp = np.diag(confusion_matrix)
142 |     pixel_acc = tp.sum()/pos.sum()
143 |     mean_acc = (tp/np.maximum(1.0, pos)).mean()
144 |     IoU_array = (tp / np.maximum(1.0, pos + res - tp))
145 |     mean_IoU = IoU_array.mean()
146 | 
147 |     return mean_IoU, IoU_array, pixel_acc, mean_acc
148 | 
149 | 
150 | def testvideo(cfg, testloader, model, sv_dir='', sv_pred=False):
151 |     model.eval()
152 | 
153 |     with torch.no_grad():
154 |         for index, batch in enumerate(tqdm(testloader)):
155 |             image, _, name, *border_padding = batch
156 |             size = image.size()
157 |             pred = testloader.dataset.inference(model, image)
158 | 
159 |             if sv_pred:
160 |                 sv_path = os.path.join(sv_dir, 'video_frames')
161 |                 if not
os.path.exists(sv_path): 162 | os.mkdir(sv_path) 163 | testloader.dataset.save_pred(image, pred, sv_path, name) 164 | 165 | print("done!") 166 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # + 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import glob 9 | import time 10 | import torch 11 | import pathlib 12 | import logging 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | from torch import nn 17 | from torch.utils import data 18 | from skimage.util import crop 19 | from skimage.io import imread 20 | from skimage.transform import resize 21 | from sklearn.externals._pilutil import bytescale 22 | from typing import List, Callable, Tuple 23 | 24 | 25 | class AverageMeter(object): 26 | """Computes and stores the average and current value""" 27 | 28 | def __init__(self): 29 | self.initialized = False 30 | self.val = None 31 | self.avg = None 32 | self.sum = None 33 | self.count = None 34 | 35 | def initialize(self, val, weight): 36 | self.val = val 37 | self.avg = val 38 | self.sum = val * weight 39 | self.count = weight 40 | self.initialized = True 41 | 42 | def update(self, val, weight=1): 43 | if not self.initialized: 44 | self.initialize(val, weight) 45 | else: 46 | self.add(val, weight) 47 | 48 | def add(self, val, weight): 49 | self.val = val 50 | self.sum += val * weight 51 | self.count += weight 52 | self.avg = self.sum / self.count 53 | 54 | def value(self): 55 | return self.val 56 | 57 | def average(self): 58 | return self.avg 59 | 60 | 61 | 62 | class CrossEntropy(nn.Module): 63 | def __init__(self, ignore_label=-1, weight=None): 64 | super(CrossEntropy, self).__init__() 65 | self.ignore_label = ignore_label 66 | self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_label) 67 | self.class_weights = weight 68 | 69 | def _forward(self, score, target): 70 | loss = self.criterion(score, target) 71 | return loss 72 | 73 | def forward(self, score, target): 74 | score = [score] 75 | weights = [1] 76 | assert len(weights) == len(score) 77 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)]) 78 | 79 | 80 | def create_logger(cfg, cfg_name, phase='train'): 81 | root_output_dir = Path(os.path.join(os.getcwd(), cfg.OUTPUT_DIR)) 82 | root_log_dir = Path(os.path.join(os.getcwd(), cfg.LOG_DIR)) 83 | # set up logger 84 | if not root_output_dir.exists(): 85 | print('=> creating {}'.format(root_output_dir)) 86 | root_output_dir.mkdir() 87 | 88 | if not root_log_dir.exists(): 89 | print('=> creating {}'.format(root_log_dir)) 90 | root_log_dir.mkdir() 91 | 92 | dataset = cfg.DATASET.NAME 93 | model = cfg.MODEL.NAME 94 | # cfg_name = os.path.basename(cfg_name).split('.')[0] 95 | 96 | final_output_dir = root_output_dir / dataset / cfg_name 97 | 98 | print('=> creating {}'.format(final_output_dir)) 99 | final_output_dir.mkdir(parents=True, exist_ok=True) 100 | 101 | time_str = time.strftime('%Y-%m-%d-%H-%M') 102 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 103 | final_log_file = final_output_dir / log_file 104 | head = '%(asctime)-15s %(message)s' 105 | logging.basicConfig(filename=str(final_log_file), format=head) 106 | logger = logging.getLogger() 107 | logger.setLevel(logging.INFO) 108 | console = logging.StreamHandler() 109 | logging.getLogger('').addHandler(console) 110 | 111 | 
tensorboard_log_dir = root_log_dir / dataset / model / (cfg_name + '_' + time_str) 112 | print('=> creating {}'.format(tensorboard_log_dir)) 113 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True) 114 | 115 | return logger, str(final_output_dir), str(tensorboard_log_dir) 116 | 117 | 118 | 119 | def get_confusion_matrix(label, pred, size, num_class, ignore=-1): 120 | """ 121 | Calcute the confusion matrix by given label and pred 122 | """ 123 | output = pred.cpu().numpy().transpose(0, 2, 3, 1) 124 | seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8) 125 | seg_gt = np.asarray(label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int32) 126 | 127 | ignore_index = seg_gt != ignore 128 | seg_gt = seg_gt[ignore_index] 129 | seg_pred = seg_pred[ignore_index] 130 | 131 | index = (seg_gt * num_class + seg_pred).astype('int32') 132 | label_count = np.bincount(index) 133 | confusion_matrix = np.zeros((num_class, num_class)) 134 | 135 | for i_label in range(num_class): 136 | for i_pred in range(num_class): 137 | cur_index = i_label * num_class + i_pred 138 | if cur_index < len(label_count): 139 | confusion_matrix[i_label, i_pred] = label_count[cur_index] 140 | return confusion_matrix 141 | 142 | 143 | def adjust_learning_rate(optimizer, base_lr, end_lr, step, decay_steps, power=0.9): 144 | lr = ((base_lr - end_lr) * (1 - step / decay_steps)**(power)) + end_lr 145 | return lr 146 | -------------------------------------------------------------------------------- /utils/transformation_pipelines.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import albumentations 3 | from utils.label_utils import get_labels 4 | from utils.data_utils import label_mapping 5 | from utils.transformations import (normalize, ComposeSingle, ComposeDouble, re_normalize, 6 | FunctionWrapperSingle, FunctionWrapperDouble, 7 | AlbuSeg2d, random_crop, random_resize, random_brightness, scale_aug) 8 | 9 | labels = get_labels() 10 | id2label = { label.id : label for label in labels } 11 | 12 | 13 | def get_transforms_training(cfg): 14 | 15 | transforms_training = ComposeDouble([ 16 | FunctionWrapperDouble(random_resize, scale_factor=16, base_size=cfg.DATASET.BASE_SIZE[1], both=True), 17 | FunctionWrapperDouble(random_crop, crop_size=cfg.DATASET.CROP_SIZE, ignore_label=cfg.DATASET.IGNORE_LABEL, both=True), 18 | AlbuSeg2d(albumentations.HorizontalFlip(p=0.5)), 19 | FunctionWrapperDouble(label_mapping, label_map=id2label, input=False, target=True), 20 | FunctionWrapperDouble(random_brightness, input=True, target=False), 21 | FunctionWrapperDouble(normalize, mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD, input=True, target=False), 22 | FunctionWrapperDouble(np.moveaxis, input=True, target=False, source=-1, destination=0), 23 | ]) 24 | return transforms_training 25 | 26 | 27 | def get_transforms_validation(cfg): 28 | 29 | transforms_validation = ComposeDouble([ 30 | FunctionWrapperDouble(random_crop, crop_size=cfg.DATASET.CROP_SIZE, ignore_label=cfg.DATASET.IGNORE_LABEL, both=True), 31 | FunctionWrapperDouble(label_mapping, label_map=id2label, input=False, target=True), 32 | FunctionWrapperDouble(normalize, mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD, input=True, target=False), 33 | FunctionWrapperDouble(np.moveaxis, input=True, target=False, source=-1, destination=0), 34 | ]) 35 | return transforms_validation 36 | 37 | 38 | def get_transforms_evaluation(cfg): 39 | 40 | transforms_evaluation = ComposeDouble([ 41 | FunctionWrapperDouble(label_mapping, 
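                              # remap raw Cityscapes label ids using the id2label lookup built above
                              # (applied to the target mask only)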
label_map=id2label, input=False, target=True), 42 | FunctionWrapperDouble(normalize, mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD, input=True, target=False), 43 | FunctionWrapperDouble(np.moveaxis, input=True, target=False, source=-1, destination=0), 44 | ]) 45 | return transforms_evaluation 46 | 47 | 48 | def get_transforms_video(cfg): 49 | transforms_video = ComposeSingle([ 50 | FunctionWrapperSingle(normalize, mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD), 51 | FunctionWrapperSingle(np.moveaxis, source=-1, destination=0), 52 | ]) 53 | return transforms_video -------------------------------------------------------------------------------- /utils/transformations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import mmcv 3 | import torch 4 | import random 5 | import numpy as np 6 | import albumentations as A 7 | from typing import List, Callable, Tuple 8 | from skimage.util import crop 9 | from skimage.io import imread 10 | from sklearn.externals._pilutil import bytescale 11 | 12 | from utils.data_utils import get_labels 13 | labels = get_labels() 14 | trainid2label = { label.trainId : label for label in labels } 15 | 16 | 17 | def normalize_01(inp: np.ndarray): 18 | """Squash image input to the value range [0, 1] (no clipping)""" 19 | inp_out = (inp - np.min(inp)) / np.ptp(inp) 20 | return inp_out 21 | 22 | 23 | def normalize(img: np.ndarray, mean: float, std: float): 24 | """Normalize based on mean and standard deviation.""" 25 | img = img.astype(np.float32) / 255 26 | img = img - mean 27 | img = img / std 28 | return img 29 | 30 | 31 | def create_dense_target(tar: np.ndarray): 32 | classes = np.unique(tar) 33 | dummy = np.zeros_like(tar) 34 | for idx, value in enumerate(classes): 35 | mask = np.where(tar == value) 36 | dummy[mask] = idx 37 | 38 | return dummy 39 | 40 | 41 | def label_mapping(seg: np.ndarray, label_map: dict, inverse=False): 42 | temp = np.copy(seg) 43 | if inverse: 44 | for v, k in label_map.items(): 45 | seg[temp == k] = v 46 | else: 47 | for k, v in label_map.items(): 48 | seg[temp == k] = v 49 | return seg 50 | 51 | 52 | def cityscapes_label_to_rgb(mask): 53 | h = mask.shape[0] 54 | w = mask.shape[1] 55 | mask_rgb = np.zeros((h, w, 3), dtype=np.uint8) 56 | for val, key in trainid2label.items(): 57 | indices = mask == val 58 | mask_rgb[indices.squeeze()] = key.color 59 | return mask_rgb 60 | 61 | 62 | def center_crop_to_size(x: np.ndarray, size: Tuple, copy: bool = False) -> np.ndarray: 63 | """ 64 | Center crops a given array x to the size passed in the function. 65 | Expects even spatial dimensions! 66 | """ 67 | x_shape = np.array(x.shape) 68 | size = np.array(size) 69 | params_list = ((x_shape - size) / 2).astype(np.int).tolist() 70 | params_tuple = tuple([(i, i) for i in params_list]) 71 | cropped_image = crop(x, crop_width=params_tuple, copy=copy) 72 | return cropped_image 73 | 74 | 75 | def re_normalize(inp: np.ndarray, low: int = 0, high: int = 255): 76 | """Normalize the data to a certain range. 
Default: [0-255]"""
77 |     inp_out = bytescale(inp, low=low, high=high)
78 |     return inp_out
79 | 
80 | 
81 | def random_flip(inp: np.ndarray, tar: np.ndarray, ndim_spatial: int):
82 |     flip_dims = [np.random.randint(low=0, high=2) for dim in range(ndim_spatial)]
83 | 
84 |     flip_dims_inp = tuple([i + 1 for i, element in enumerate(flip_dims) if element == 1])
85 |     flip_dims_tar = tuple([i for i, element in enumerate(flip_dims) if element == 1])
86 | 
87 |     inp_flipped = np.flip(inp, axis=flip_dims_inp)
88 |     tar_flipped = np.flip(tar, axis=flip_dims_tar)
89 | 
90 |     return inp_flipped, tar_flipped
91 | 
92 | 
93 | def pad_image(img, h, w, size, padvalue):
94 |     pad_image = img.copy()
95 |     pad_h = max(size[0] - h, 0)
96 |     pad_w = max(size[1] - w, 0)
97 |     if pad_h > 0 or pad_w > 0:
98 |         top = pad_h // 2
99 |         right = pad_w // 2
100 | 
101 |         if pad_h % 2 == 0:
102 |             bottom = pad_h // 2
103 |         else:
104 |             bottom = pad_h // 2 + 1
105 | 
106 |         if pad_w % 2 == 0:
107 |             left = pad_w // 2
108 |         else:
109 |             left = pad_w // 2 + 1
110 |         pad_image = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padvalue)
111 |     return pad_image
112 | 
113 | 
114 | def pad_seg(img, h, w, size, padvalue):
115 |     pad_image = img.copy()
116 |     pad_h = max(size[0] - h, 0)
117 |     pad_w = max(size[1] - w, 0)
118 |     if pad_h > 0 or pad_w > 0:
119 | 
120 |         top = pad_h // 2
121 |         right = pad_w // 2
122 | 
123 |         if pad_h % 2 == 0:
124 |             bottom = pad_h // 2
125 |         else:
126 |             bottom = pad_h // 2 + 1
127 | 
128 |         if pad_w % 2 == 0:
129 |             left = pad_w // 2
130 |         else:
131 |             left = pad_w // 2 + 1
132 |         # print("--> pad ({},{},{},{})".format(top, bottom, left, right))
133 |         pad_image = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=padvalue)
134 |         pad_image = np.expand_dims(pad_image, 2)
135 |     return pad_image
136 | 
137 | 
138 | def random_crop(img, seg, crop_size, ignore_label):
139 | 
140 |     h, w = img.shape[:-1]
141 |     img = pad_image(img, h, w, crop_size, (0.0, 0.0, 0.0))
142 |     seg = pad_seg(seg, h, w, crop_size, (ignore_label,))
143 | 
144 |     if seg.shape[-1] != 1:
145 |         seg = np.expand_dims(seg, -1)
146 |     new_h, new_w = seg.shape[:-1]
147 | 
148 |     x = random.randint(0, new_w - crop_size[1])
149 |     y = random.randint(0, new_h - crop_size[0])
150 | 
151 |     img = img[y:y+crop_size[0], x:x+crop_size[1]]
152 |     seg = seg[y:y+crop_size[0], x:x+crop_size[1]]
153 | 
154 |     return img, seg
155 | 
156 | 
157 | 
158 | def random_resize(img, seg, scale_factor, base_size, min_scale = 0.5, max_scale=2.0):
159 | 
160 |     rand_scale = 0.5 + random.randint(0, scale_factor) / 10.0
161 |     rand_scale = np.clip(rand_scale, min_scale, max_scale)
162 |     long_size = int(base_size * rand_scale + 0.5)
163 |     h, w = img.shape[:2]
164 |     if h > w:
165 |         new_h = long_size
166 |         new_w = int(w * long_size / h + 0.5)
167 |     else:
168 |         new_w = long_size
169 |         new_h = int(h * long_size / w + 0.5)
170 |     # print(rand_scale, new_h, new_w)
171 |     img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
172 |     seg = cv2.resize(seg, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
173 |     return img, seg
174 | 
175 | 
176 | def scale_aug(img, seg=None, scale_factor=1, crop_size=(512, 1024), base_size=(1024, 2048), ignore_label=-1):
177 | 
178 |     img, seg = random_resize(img, seg, scale_factor, base_size)
179 | 
180 |     img, seg = random_crop(img, seg, crop_size=crop_size, ignore_label=ignore_label)
181 | 
182 |     return img, seg
183 | 
184 | 
185 | 
186 | def random_brightness(img, brightness_shift_value=10):
187 |     if random.random() < 0.5:
188 |         return img
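    # Otherwise shift every channel by a random integer in
    # [-brightness_shift_value, brightness_shift_value] and clip the result back
    # to the valid 8-bit range.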
189 | img = img.astype(np.float32) 190 | shift = random.randint(-brightness_shift_value, brightness_shift_value) 191 | img[:, :, :] += shift 192 | img = np.around(img) 193 | img = np.clip(img, 0, 255).astype(np.uint8) 194 | return img 195 | 196 | 197 | class Repr: 198 | """Evaluable string representation of an object""" 199 | 200 | def __repr__(self): return f'{self.__class__.__name__}: {self.__dict__}' 201 | 202 | 203 | class FunctionWrapperSingle(Repr): 204 | """A function wrapper that returns a partial for input only.""" 205 | 206 | def __init__(self, function: Callable, *args, **kwargs): 207 | from functools import partial 208 | self.function = partial(function, *args, **kwargs) 209 | 210 | def __call__(self, inp: np.ndarray): return self.function(inp) 211 | 212 | 213 | class FunctionWrapperDouble(Repr): 214 | """A function wrapper that returns a partial for an input-target pair.""" 215 | 216 | def __init__(self, function: Callable, input: bool = True, target: bool = False, both: bool = False, *args, **kwargs): 217 | from functools import partial 218 | self.function = partial(function, *args, **kwargs) 219 | self.input = input 220 | self.target = target 221 | self.both = both 222 | 223 | def __call__(self, inp: np.ndarray, tar: dict): 224 | if self.both: 225 | inp, tar = self.function(inp, tar) 226 | else: 227 | if self.input: inp = self.function(inp) 228 | if self.target: tar = self.function(tar) 229 | return inp, tar 230 | 231 | 232 | class Compose: 233 | """Baseclass - composes several transforms together.""" 234 | 235 | def __init__(self, transforms: List[Callable]): 236 | self.transforms = transforms 237 | 238 | def __repr__(self): return str([transform for transform in self.transforms]) 239 | 240 | 241 | class ComposeDouble(Compose): 242 | """Composes transforms for input-target pairs.""" 243 | 244 | def __call__(self, inp: np.ndarray, target: dict): 245 | for t in self.transforms: 246 | inp, target = t(inp, target) 247 | return inp, target 248 | 249 | 250 | class ComposeSingle(Compose): 251 | """Composes transforms for input only.""" 252 | 253 | def __call__(self, inp: np.ndarray): 254 | for t in self.transforms: 255 | inp = t(inp) 256 | return inp 257 | 258 | 259 | class AlbuSeg2d(Repr): 260 | """ 261 | Wrapper for albumentations' segmentation-compatible 2D augmentations. 262 | Wraps an augmentation so it can be used within the provided transform pipeline. 263 | See https://github.com/albu/albumentations for more information. 
264 | Expected input: (C, spatial_dims) 265 | Expected target: (spatial_dims) -> No (C)hannel dimension 266 | """ 267 | def __init__(self, albumentation: Callable): 268 | self.albumentation = albumentation 269 | 270 | def __call__(self, inp: np.ndarray, tar: np.ndarray): 271 | # input, target 272 | out_dict = self.albumentation(image=inp, mask=tar) 273 | input_out = out_dict['image'] 274 | target_out = out_dict['mask'] 275 | 276 | return input_out, target_out 277 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import napari 3 | from utils.transformations import re_normalize 4 | 5 | 6 | def enable_gui_qt(): 7 | """Performs the magic command %gui qt""" 8 | from IPython import get_ipython 9 | 10 | ipython = get_ipython() 11 | ipython.magic('gui qt') 12 | 13 | 14 | class DatasetViewer: 15 | def __init__(self, 16 | dataset): 17 | 18 | self.dataset = dataset 19 | self.index = 0 20 | 21 | # napari viewer instance 22 | self.viewer = None 23 | 24 | # current image & shape layer 25 | self.image_layer = None 26 | self.label_layer = None 27 | 28 | def napari(self): 29 | # IPython magic 30 | enable_gui_qt() 31 | 32 | # napari 33 | if self.viewer: 34 | try: 35 | del self.viewer 36 | except AttributeError: 37 | pass 38 | self.index = 0 39 | 40 | # Init napari instance 41 | self.viewer = napari.Viewer() 42 | 43 | # Show current sample 44 | self.show_sample() 45 | 46 | # Key-bindings 47 | # Press 'n' to get the next sample 48 | @self.viewer.bind_key('n') 49 | def next(viewer): 50 | self.increase_index() # Increase the index 51 | self.show_sample() # Show next sample 52 | 53 | # Press 'b' to get the previous sample 54 | @self.viewer.bind_key('b') 55 | def prev(viewer): 56 | self.decrease_index() # Decrease the index 57 | self.show_sample() # Show next sample 58 | 59 | def increase_index(self): 60 | self.index += 1 61 | if self.index >= len(self.dataset): 62 | self.index = 0 63 | 64 | def decrease_index(self): 65 | self.index -= 1 66 | if self.index < 0: 67 | self.index = len(self.dataset) - 1 68 | 69 | def show_sample(self): 70 | 71 | # Get a sample from the dataset 72 | sample = self.get_sample_dataset(self.index) 73 | x, y = sample 74 | 75 | # Get the names from the dataset 76 | names = self.get_names_dataset(self.index) 77 | x_name, y_name = names 78 | x_name, y_name = x_name.name, y_name.name # only possible if pathlib.Path 79 | 80 | # Transform the sample to numpy, cpu and correct format to visualize 81 | x = self.transform_x(x) 82 | y = self.transform_y(y) 83 | 84 | # Create or update image layer 85 | if self.image_layer not in self.viewer.layers: 86 | self.image_layer = self.create_image_layer(x, x_name) 87 | else: 88 | self.update_image_layer(self.image_layer, x, x_name) 89 | 90 | # Create or update label layer 91 | if self.label_layer not in self.viewer.layers: 92 | self.label_layer = self.create_label_layer(y, y_name) 93 | else: 94 | self.update_label_layer(self.label_layer, y, y_name) 95 | 96 | # Reset view 97 | self.viewer.reset_view() 98 | 99 | def create_image_layer(self, x, x_name): 100 | return self.viewer.add_image(x, name=str(x_name)) 101 | 102 | def update_image_layer(self, image_layer, x, x_name): 103 | """Replace the data and the name of a given image_layer""" 104 | image_layer.data = x 105 | image_layer.name = str(x_name) 106 | 107 | def create_label_layer(self, y, y_name): 108 | return self.viewer.add_labels(y, 
name=str(y_name)) 109 | 110 | def update_label_layer(self, target_layer, y, y_name): 111 | """Replace the data and the name of a given image_layer""" 112 | target_layer.data = y 113 | target_layer.name = str(y_name) 114 | 115 | def get_sample_dataset(self, index): 116 | return self.dataset[index] 117 | 118 | def get_names_dataset(self, index): 119 | return self.dataset.inputs[index], self.dataset.targets[index] 120 | 121 | def transform_x(self, x): 122 | # make sure it's a numpy.ndarray on the cpu 123 | x = x.cpu().numpy() 124 | 125 | # from [C, H, W] to [H, W, C] - only for RGB images. 126 | if self.check_if_rgb(x): 127 | x = np.moveaxis(x, source=0, destination=-1) 128 | 129 | # Re-normalize 130 | x = re_normalize(x) 131 | 132 | return x 133 | 134 | def transform_y(self, y): 135 | # make sure it's a numpy.ndarray on the cpu 136 | y = y.cpu().numpy() 137 | 138 | return y 139 | 140 | def check_if_rgb(self, x): 141 | # checks if the shape of the first dim (channel dim) is 3 142 | # TODO: Try other methods as a 3D grayscale input image can have 3 modalities -> 3 channels 143 | # TODO: Also think about RGBA images with 4 channels or a combination of a RGB and a grayscale image -> 4 channels 144 | return True if x.shape[0] == 3 else False 145 | --------------------------------------------------------------------------------