├── README.md
├── checkpoints
│   └── .gitkeep
├── configs
│   ├── base.yaml
│   ├── r101.yaml
│   └── r50.yaml
├── danet
│   ├── __init__.py
│   ├── backbone.py
│   ├── config.py
│   └── head.py
├── datasets
│   └── prepare_cityscapes.py
└── train_net.py

/README.md:
--------------------------------------------------------------------------------
# DANet
A PyTorch implementation of DANet based on the CVPR 2019 paper [Dual Attention Network for Scene Segmentation](https://arxiv.org/abs/1809.02983).

## Requirements
- [Anaconda](https://www.anaconda.com/download/)
- PyTorch
```
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch
```
- opencv
```
pip install opencv-python
```
- tensorboard
```
pip install tensorboard
```
- pycocotools
```
pip install git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
```
- fvcore
```
pip install git+https://github.com/facebookresearch/fvcore
```
- cityscapesScripts
```
pip install git+https://github.com/mcordts/cityscapesScripts.git
```
- detectron2
```
pip install git+https://github.com/facebookresearch/detectron2.git@master
```

## Datasets
For the few datasets that detectron2 natively supports, the data is assumed to exist in a directory called
`datasets/` under the directory where you launch the program, with the following directory structure:

### Expected dataset structure for Cityscapes:
```
cityscapes/
  gtFine/
    train/
      aachen/
        color.png, instanceIds.png, labelIds.png, polygons.json,
        labelTrainIds.png
      ...
    val/
    test/
  leftImg8bit/
    train/
    val/
    test/
```
Run `./datasets/prepare_cityscapes.py` to create the `labelTrainIds.png` files.

## Training
To train a model, run
```bash
python train_net.py --config-file <config-file.yaml>
```

For example, to launch end-to-end DANet training with a ResNet-50 backbone on 8 GPUs, execute:
```bash
python train_net.py --config-file configs/r50.yaml --num-gpus 8
```

## Evaluation
Model evaluation can be done similarly:
```bash
python train_net.py --config-file configs/r50.yaml --num-gpus 8 --eval-only MODEL.WEIGHTS checkpoints/model.pth
```

## Results
There are some differences between this implementation and the official implementation:
1. No `Multi-Grid` or `Multi-Scale Testing`;
2. The image sizes used for `Multi-Scale Training` are (800, 832, 864, 896, 928, 960);
3. The number of training iterations is set to `24000`;
4. The learning rate policy is `WarmupMultiStepLR`;
5. The `Position Attention Module (PAM)` uses a similar mechanism to the `Channel Attention Module (CAM)`: it simply uses the tensor and its transpose to compute attention (see the sketch below).
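
To make the last point concrete, the snippet below is a condensed sketch of what `DANetHead.forward` in `danet/head.py` computes. In the real head, the PAM and CAM branches each run on the output of their own convolutional sub-head; a single feature map `x` is used here only for brevity, and the batched `@` products mirror the `torch.matmul` calls in the code:

```python
import torch
import torch.nn.functional as F


def simplified_dual_attention(x: torch.Tensor):
    """Condensed sketch of the attention in danet/head.py: both branches build
    their affinity matrix from a tensor and its transpose, without the extra
    query/key/value convolutions used by the official DANet."""
    b, c, h, w = x.size()
    flat = x.view(b, c, h * w)                                      # (b, C, N)

    # Position attention: an N x N affinity over spatial locations.
    pam_weight = F.softmax(flat.transpose(-1, -2) @ flat, dim=-1)   # (b, N, N)
    pam_out = (flat @ pam_weight.transpose(-1, -2)).view(b, c, h, w)

    # Channel attention: a C x C affinity over channels.
    cam_weight = F.softmax(flat @ flat.transpose(-1, -2), dim=-1)   # (b, C, C)
    cam_out = (cam_weight @ flat).view(b, c, h, w)

    # Residual connections, as in DANetHead.forward; the head then sums the
    # two branches and applies the 1x1 predictor.
    return x + pam_out, x + cam_out
```
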
| Name | train time (s/iter) | inference time (s/im) | train mem (GB) | PA % | mean PA % | mean IoU % | FW IoU % | download link |
| ---- | ------------------- | --------------------- | -------------- | ----- | --------- | ---------- | -------- | ------------- |
| R50  | 0.49                | 0.12                  | 27.12          | 94.19 | 75.31     | 66.64      | 89.54    | model \| ga7k |
| R101 | 0.65                | 0.16                  | 28.81          | 94.29 | 76.08     | 67.57      | 89.69    | model \| xnvs |
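
To sanity-check that the backbone and head are registered without going through the distributed launcher, the config plumbing from `train_net.py` can also be driven directly. This is only a minimal sketch and assumes it is run from the repository root:

```python
from detectron2.config import get_cfg
from detectron2.modeling import build_model

# Importing danet registers the dilated ResNet backbone and DANetHead
# via the decorators pulled in by danet/__init__.py.
from danet import add_danet_config

# Mirror setup() in train_net.py: start from the default detectron2 config,
# add the DANet-specific keys, then load the ResNet-50 recipe.
cfg = get_cfg()
add_danet_config(cfg)
cfg.merge_from_file("configs/r50.yaml")
cfg.MODEL.DEVICE = "cpu"  # keep this quick check off the GPU

# Build the SemanticSegmentor with the dilated ResNet backbone and DANetHead.
model = build_model(cfg)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```

This is essentially the `get_cfg()` → `add_danet_config()` → `merge_from_file()` sequence that `setup()` in `train_net.py` runs before handing the frozen config to the `DefaultTrainer`.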
-------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/DANet/31e027a065603e9025ef5546743b6b6e5a8a0475/checkpoints/.gitkeep -------------------------------------------------------------------------------- /configs/base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "SemanticSegmentor" 3 | BACKBONE: 4 | NAME: "build_dilated_resnet_backbone" 5 | DILATED_RESNET: 6 | NORM: "SyncBN" 7 | SEM_SEG_HEAD: 8 | NAME: "DANetHead" 9 | COMMON_STRIDE: 8 10 | CONVS_DIM: 64 11 | IN_FEATURES: ["res5"] 12 | NUM_CLASSES: 19 13 | SOLVER: 14 | BASE_LR: 0.01 15 | STEPS: (18000,) 16 | MAX_ITER: 24000 17 | IMS_PER_BATCH: 8 18 | INPUT: 19 | MIN_SIZE_TRAIN: (800, 832, 864, 896, 928, 960) 20 | MAX_SIZE_TRAIN: 960 21 | MIN_SIZE_TEST: 800 22 | MAX_SIZE_TEST: 960 23 | DATASETS: 24 | TRAIN: ("cityscapes_fine_sem_seg_train",) 25 | TEST: ("cityscapes_fine_sem_seg_val",) 26 | TEST: 27 | EVAL_PERIOD: 2000 28 | 29 | 30 | -------------------------------------------------------------------------------- /configs/r101.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | DILATED_RESNET: 5 | DEPTH: 101 6 | OUTPUT_DIR: "./output/r101" 7 | -------------------------------------------------------------------------------- /configs/r50.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | DILATED_RESNET: 5 | DEPTH: 50 6 | OUTPUT_DIR: "./output/r50" 7 | -------------------------------------------------------------------------------- /danet/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import build_dilated_resnet_backbone 2 | from .config import add_danet_config 3 | from .head import DANetHead 4 | -------------------------------------------------------------------------------- /danet/backbone.py: -------------------------------------------------------------------------------- 1 | import fvcore.nn.weight_init as weight_init 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from detectron2.layers import ( 5 | Conv2d, 6 | FrozenBatchNorm2d, 7 | ShapeSpec, 8 | get_norm, 9 | ) 10 | from detectron2.modeling.backbone import Backbone 11 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 12 | from torch import nn 13 | 14 | 15 | class BasicStem(nn.Module): 16 | def __init__(self, in_channels=3, out_channels=64, norm="BN"): 17 | """ 18 | Args: 19 | norm (str or callable): a callable that takes the number of 20 | channels and return a `nn.Module`, or a pre-defined string 21 | (one of {"FrozenBN", "BN", "GN"}). 
22 | """ 23 | super().__init__() 24 | self.conv1 = Conv2d( 25 | in_channels, 26 | out_channels, 27 | kernel_size=7, 28 | stride=2, 29 | padding=3, 30 | bias=False, 31 | norm=get_norm(norm, out_channels), 32 | ) 33 | weight_init.c2_msra_fill(self.conv1) 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = F.relu_(x) 38 | x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) 39 | return x 40 | 41 | @property 42 | def out_channels(self): 43 | return self.conv1.out_channels 44 | 45 | @property 46 | def stride(self): 47 | return 4 # = stride 2 conv -> stride 2 max pool 48 | 49 | 50 | class ResNetBlockBase(nn.Module): 51 | def __init__(self, in_channels, out_channels, stride): 52 | """ 53 | The `__init__` method of any subclass should also contain these arguments. 54 | Args: 55 | in_channels (int): 56 | out_channels (int): 57 | stride (int): 58 | """ 59 | super().__init__() 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.stride = stride 63 | 64 | def freeze(self): 65 | for p in self.parameters(): 66 | p.requires_grad = False 67 | FrozenBatchNorm2d.convert_frozen_batchnorm(self) 68 | return self 69 | 70 | 71 | class BottleneckBlock(ResNetBlockBase): 72 | def __init__( 73 | self, 74 | in_channels, 75 | out_channels, 76 | *, 77 | bottleneck_channels, 78 | stride=1, 79 | num_groups=1, 80 | norm="BN", 81 | stride_in_1x1=False, 82 | dilation=1, 83 | ): 84 | """ 85 | Args: 86 | norm (str or callable): a callable that takes the number of 87 | channels and return a `nn.Module`, or a pre-defined string 88 | (one of {"FrozenBN", "BN", "GN"}). 89 | stride_in_1x1 (bool): when stride==2, whether to put stride in the 90 | first 1x1 convolution or the bottleneck 3x3 convolution. 91 | """ 92 | super().__init__(in_channels, out_channels, stride) 93 | 94 | if in_channels != out_channels: 95 | self.shortcut = Conv2d( 96 | in_channels, 97 | out_channels, 98 | kernel_size=1, 99 | stride=stride, 100 | bias=False, 101 | norm=get_norm(norm, out_channels), 102 | ) 103 | else: 104 | self.shortcut = None 105 | 106 | # The original MSRA ResNet models have stride in the first 1x1 conv 107 | # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have 108 | # stride in the 3x3 conv 109 | stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) 110 | 111 | self.conv1 = Conv2d( 112 | in_channels, 113 | bottleneck_channels, 114 | kernel_size=1, 115 | stride=stride_1x1, 116 | bias=False, 117 | norm=get_norm(norm, bottleneck_channels), 118 | ) 119 | 120 | self.conv2 = Conv2d( 121 | bottleneck_channels, 122 | bottleneck_channels, 123 | kernel_size=3, 124 | stride=stride_3x3, 125 | padding=1 * dilation, 126 | bias=False, 127 | groups=num_groups, 128 | dilation=dilation, 129 | norm=get_norm(norm, bottleneck_channels), 130 | ) 131 | 132 | self.conv3 = Conv2d( 133 | bottleneck_channels, 134 | out_channels, 135 | kernel_size=1, 136 | bias=False, 137 | norm=get_norm(norm, out_channels), 138 | ) 139 | 140 | for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: 141 | if layer is not None: # shortcut can be None 142 | weight_init.c2_msra_fill(layer) 143 | 144 | def forward(self, x): 145 | out = self.conv1(x) 146 | out = F.relu_(out) 147 | 148 | out = self.conv2(out) 149 | out = F.relu_(out) 150 | 151 | out = self.conv3(out) 152 | 153 | if self.shortcut is not None: 154 | shortcut = self.shortcut(x) 155 | else: 156 | shortcut = x 157 | 158 | out += shortcut 159 | out = F.relu_(out) 160 | return out 161 | 162 | 163 | def make_stage(block_class, 
num_blocks, first_stride, **kwargs): 164 | """ 165 | Create a resnet stage by creating many blocks. 166 | Args: 167 | block_class (class): a subclass of ResNetBlockBase 168 | num_blocks (int): 169 | first_stride (int): the stride of the first block. The other blocks will have stride=1. 170 | A `stride` argument will be passed to the block constructor. 171 | kwargs: other arguments passed to the block constructor. 172 | Returns: 173 | list[nn.Module]: a list of block module. 174 | """ 175 | blocks = [] 176 | for i in range(num_blocks): 177 | blocks.append(block_class(stride=first_stride if i == 0 else 1, **kwargs)) 178 | kwargs["in_channels"] = kwargs["out_channels"] 179 | return blocks 180 | 181 | 182 | class DilatedResNet(Backbone): 183 | def __init__(self, stem, stages, num_classes=None, out_features=None): 184 | """ 185 | Args: 186 | stem (nn.Module): a stem module 187 | stages (list[list[ResNetBlock]]): several (typically 4) stages, 188 | each contains multiple :class:`ResNetBlockBase`. 189 | num_classes (None or int): if None, will not perform classification. 190 | out_features (list[str]): name of the layers whose outputs should 191 | be returned in forward. Can be anything in "stem", "linear", or "res2" ... 192 | If None, will return the output of the last layer. 193 | """ 194 | super(DilatedResNet, self).__init__() 195 | self.stem = stem 196 | self.num_classes = num_classes 197 | 198 | current_stride = self.stem.stride 199 | self._out_feature_strides = {"stem": current_stride} 200 | self._out_feature_channels = {"stem": self.stem.out_channels} 201 | 202 | self.stages_and_names = [] 203 | for i, blocks in enumerate(stages): 204 | for block in blocks: 205 | assert isinstance(block, ResNetBlockBase), block 206 | curr_channels = block.out_channels 207 | stage = nn.Sequential(*blocks) 208 | name = "res" + str(i + 2) 209 | self.add_module(name, stage) 210 | self.stages_and_names.append((stage, name)) 211 | self._out_feature_strides[name] = current_stride = int( 212 | current_stride * np.prod([k.stride for k in blocks]) 213 | ) 214 | self._out_feature_channels[name] = blocks[-1].out_channels 215 | 216 | if num_classes is not None: 217 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 218 | self.linear = nn.Linear(curr_channels, num_classes) 219 | nn.init.normal_(self.linear.weight, stddev=0.01) 220 | name = "linear" 221 | 222 | if out_features is None: 223 | out_features = [name] 224 | self._out_features = out_features 225 | assert len(self._out_features) 226 | children = [x[0] for x in self.named_children()] 227 | for out_feature in self._out_features: 228 | assert out_feature in children, "Available children: {}".format(", ".join(children)) 229 | 230 | def forward(self, x): 231 | outputs = {} 232 | x = self.stem(x) 233 | if "stem" in self._out_features: 234 | outputs["stem"] = x 235 | for stage, name in self.stages_and_names: 236 | x = stage(x) 237 | if name in self._out_features: 238 | outputs[name] = x 239 | if self.num_classes is not None: 240 | x = self.avgpool(x) 241 | x = self.linear(x) 242 | if "linear" in self._out_features: 243 | outputs["linear"] = x 244 | return outputs 245 | 246 | def output_shape(self): 247 | return { 248 | name: ShapeSpec( 249 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 250 | ) 251 | for name in self._out_features 252 | } 253 | 254 | 255 | @BACKBONE_REGISTRY.register() 256 | def build_dilated_resnet_backbone(cfg, input_shape): 257 | """ 258 | Create a Dilated ResNet instance from config. 
259 | Returns: 260 | DilatedResNet: a :class:`DilatedResNet` instance. 261 | """ 262 | # need registration of new blocks/stems? 263 | norm = cfg.MODEL.DILATED_RESNET.NORM 264 | stem = BasicStem( 265 | in_channels=input_shape.channels, 266 | out_channels=cfg.MODEL.DILATED_RESNET.STEM_OUT_CHANNELS, 267 | norm=norm, 268 | ) 269 | 270 | # fmt: off 271 | out_features = cfg.MODEL.DILATED_RESNET.OUT_FEATURES 272 | depth = cfg.MODEL.DILATED_RESNET.DEPTH 273 | num_groups = cfg.MODEL.DILATED_RESNET.NUM_GROUPS 274 | width_per_group = cfg.MODEL.DILATED_RESNET.WIDTH_PER_GROUP 275 | bottleneck_channels = num_groups * width_per_group 276 | in_channels = cfg.MODEL.DILATED_RESNET.STEM_OUT_CHANNELS 277 | out_channels = cfg.MODEL.DILATED_RESNET.RES2_OUT_CHANNELS 278 | stride_in_1x1 = cfg.MODEL.DILATED_RESNET.STRIDE_IN_1X1 279 | res4_dilation = cfg.MODEL.DILATED_RESNET.RES4_DILATION 280 | res5_dilation = cfg.MODEL.DILATED_RESNET.RES5_DILATION 281 | # fmt: on 282 | 283 | num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] 284 | 285 | stages = [] 286 | 287 | # Avoid creating variables without gradients 288 | # It consumes extra memory and may cause allreduce to fail 289 | out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] 290 | max_stage_idx = max(out_stage_idx) 291 | for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): 292 | if stage_idx == 4: 293 | dilation = res4_dilation 294 | elif stage_idx == 5: 295 | dilation = res5_dilation 296 | else: 297 | dilation = 1 298 | first_stride = 1 if idx == 0 or (stage_idx == 4 and dilation == 2) or (stage_idx == 5 and dilation == 4) else 2 299 | stage_kargs = {"num_blocks": num_blocks_per_stage[idx], "first_stride": first_stride, 300 | "in_channels": in_channels, "bottleneck_channels": bottleneck_channels, 301 | "out_channels": out_channels, "num_groups": num_groups, "norm": norm, 302 | "stride_in_1x1": stride_in_1x1, "dilation": dilation, "block_class": BottleneckBlock} 303 | blocks = make_stage(**stage_kargs) 304 | in_channels = out_channels 305 | out_channels *= 2 306 | bottleneck_channels *= 2 307 | 308 | stages.append(blocks) 309 | return DilatedResNet(stem, stages, out_features=out_features) 310 | -------------------------------------------------------------------------------- /danet/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_danet_config(cfg): 5 | """ 6 | Add config for DANet. 
7 | """ 8 | _C = cfg 9 | 10 | _C.MODEL.DILATED_RESNET = CN() 11 | 12 | _C.MODEL.DILATED_RESNET.DEPTH = 50 13 | _C.MODEL.DILATED_RESNET.OUT_FEATURES = ["res5"] 14 | _C.MODEL.DILATED_RESNET.NUM_GROUPS = 1 15 | _C.MODEL.DILATED_RESNET.NORM = "FrozenBN" 16 | _C.MODEL.DILATED_RESNET.WIDTH_PER_GROUP = 64 17 | _C.MODEL.DILATED_RESNET.STRIDE_IN_1X1 = True 18 | _C.MODEL.DILATED_RESNET.RES2_OUT_CHANNELS = 256 19 | _C.MODEL.DILATED_RESNET.STEM_OUT_CHANNELS = 64 20 | _C.MODEL.DILATED_RESNET.RES4_DILATION = 2 21 | _C.MODEL.DILATED_RESNET.RES5_DILATION = 4 22 | -------------------------------------------------------------------------------- /danet/head.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import fvcore.nn.weight_init as weight_init 4 | import numpy as np 5 | import torch 6 | from detectron2.layers import Conv2d, ShapeSpec 7 | from detectron2.modeling.meta_arch.semantic_seg import SEM_SEG_HEADS_REGISTRY 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | 12 | @SEM_SEG_HEADS_REGISTRY.register() 13 | class DANetHead(nn.Module): 14 | def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): 15 | super().__init__() 16 | 17 | # fmt: off 18 | self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES 19 | feature_strides = {k: v.stride for k, v in input_shape.items()} 20 | feature_channels = {k: v.channels for k, v in input_shape.items()} 21 | self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE 22 | num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES 23 | conv_dims = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM 24 | self.common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE 25 | norm = cfg.MODEL.SEM_SEG_HEAD.NORM 26 | # fmt: on 27 | 28 | self.scale_pam_heads = [] 29 | for in_feature in self.in_features: 30 | head_ops = [] 31 | head_length = max( 32 | 1, int(np.log2(feature_strides[in_feature]) - np.log2(self.common_stride)) 33 | ) 34 | for k in range(head_length): 35 | norm_module = nn.GroupNorm(32, conv_dims) if norm == "GN" else None 36 | conv = Conv2d( 37 | feature_channels[in_feature] if k == 0 else conv_dims, 38 | conv_dims, 39 | kernel_size=3, 40 | stride=1, 41 | padding=1, 42 | bias=not norm, 43 | norm=norm_module, 44 | activation=F.relu, 45 | ) 46 | weight_init.c2_msra_fill(conv) 47 | head_ops.append(conv) 48 | if feature_strides[in_feature] != self.common_stride: 49 | head_ops.append( 50 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) 51 | ) 52 | self.scale_pam_heads.append(nn.Sequential(*head_ops)) 53 | self.add_module(in_feature + '_pam', self.scale_pam_heads[-1]) 54 | 55 | self.scale_cam_heads = [] 56 | for in_feature in self.in_features: 57 | head_ops = [] 58 | head_length = max( 59 | 1, int(np.log2(feature_strides[in_feature]) - np.log2(self.common_stride)) 60 | ) 61 | for k in range(head_length): 62 | norm_module = nn.GroupNorm(32, conv_dims) if norm == "GN" else None 63 | conv = Conv2d( 64 | feature_channels[in_feature] if k == 0 else conv_dims, 65 | conv_dims, 66 | kernel_size=3, 67 | stride=1, 68 | padding=1, 69 | bias=not norm, 70 | norm=norm_module, 71 | activation=F.relu, 72 | ) 73 | weight_init.c2_msra_fill(conv) 74 | head_ops.append(conv) 75 | if feature_strides[in_feature] != self.common_stride: 76 | head_ops.append( 77 | nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) 78 | ) 79 | self.scale_cam_heads.append(nn.Sequential(*head_ops)) 80 | self.add_module(in_feature + '_cam', self.scale_cam_heads[-1]) 81 | 82 | self.predictor = Conv2d(conv_dims, num_classes, 
kernel_size=1, stride=1, padding=0) 83 | weight_init.c2_msra_fill(self.predictor) 84 | 85 | def forward(self, features, targets=None): 86 | for i, f in enumerate(self.in_features): 87 | if i == 0: 88 | pam = self.scale_pam_heads[i](features[f]) 89 | cam = self.scale_cam_heads[i](features[f]) 90 | else: 91 | pam = pam + self.scale_pam_heads[i](features[f]) 92 | cam = cam + self.scale_cam_heads[i](features[f]) 93 | 94 | b, c, h, w = pam.size() 95 | B_T = pam.view(b, c, -1) 96 | B = B_T.transpose(-1, -2).contiguous() 97 | pam_weight = F.softmax(torch.matmul(B, B_T), dim=-1).view(b, 1, h * w, h * w) 98 | weighted_pam = torch.matmul(pam_weight, pam.view(b, c, h * w, 1)).view(b, c, h, w) 99 | sum_pam = pam + weighted_pam 100 | 101 | b, c, h, w = cam.size() 102 | A = cam.view(b, c, -1) 103 | A_T = A.transpose(-1, -2).contiguous() 104 | cam_weight = F.softmax(torch.matmul(A, A_T), dim=-1).view(b, 1, c, c) 105 | weighted_cam = torch.matmul(cam_weight, cam.view(b, c, h * w).transpose(-1, -2).contiguous() 106 | .view(b, h * w, c, 1)).view(b, h * w, c).transpose(-1, -2).contiguous() \ 107 | .view(b, c, h, w) 108 | sum_cam = cam + weighted_cam 109 | 110 | x = self.predictor(sum_pam + sum_cam) 111 | x = F.interpolate(x, scale_factor=self.common_stride, mode="bilinear", align_corners=False) 112 | 113 | if self.training: 114 | losses = {} 115 | losses["loss_sem_seg"] = (F.cross_entropy(x, targets, reduction="mean", ignore_index=self.ignore_value)) 116 | return [], losses 117 | else: 118 | return x, {} 119 | -------------------------------------------------------------------------------- /datasets/prepare_cityscapes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Converts the polygonal annotations of the Cityscapes dataset 4 | # to images, where pixel values encode ground truth classes. 5 | # 6 | # The Cityscapes downloads already include such images 7 | # a) *color.png : the class is encoded by its color 8 | # b) *labelIds.png : the class is encoded by its ID 9 | # c) *instanceIds.png : the class and the instance are encoded by an instance ID 10 | # 11 | # With this tool, you can generate option 12 | # d) *labelTrainIds.png : the class is encoded by its training ID 13 | # This encoding might come handy for training purposes. You can use 14 | # the file labels.py to define the training IDs that suit your needs. 15 | # Note however, that once you submit or evaluate results, the regular 16 | # IDs are needed. 
17 | # 18 | # Uses the converter tool in 'json2labelImg.py' 19 | # Uses the mapping defined in 'labels.py' 20 | # 21 | 22 | # python imports 23 | from __future__ import print_function, absolute_import, division 24 | 25 | import glob 26 | import os 27 | import sys 28 | 29 | # cityscapes imports 30 | from cityscapesscripts.helpers.csHelpers import printError 31 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 32 | 33 | 34 | # The main method 35 | def main(): 36 | # Where to look for Cityscapes 37 | cityscapesPath = os.path.join(os.path.dirname(__file__), "cityscapes") 38 | # how to search for all ground truth 39 | searchFine = os.path.join(cityscapesPath, "gtFine", "*", "*", "*_gt*_polygons.json") 40 | searchCoarse = os.path.join(cityscapesPath, "gtCoarse", "*", "*", "*_gt*_polygons.json") 41 | 42 | # search files 43 | filesFine = glob.glob(searchFine) 44 | filesFine.sort() 45 | filesCoarse = glob.glob(searchCoarse) 46 | filesCoarse.sort() 47 | 48 | # concatenate fine and coarse 49 | files = filesFine + filesCoarse 50 | # files = filesFine # use this line if fine is enough for now. 51 | 52 | # quit if we did not find anything 53 | if not files: 54 | printError("Did not find any files. Please consult the README.") 55 | 56 | # a bit verbose 57 | print("Processing {} annotation files".format(len(files))) 58 | 59 | # iterate through files 60 | progress = 0 61 | print("Progress: {:>3} %".format(progress * 100 / len(files)), end=' ') 62 | for f in files: 63 | # create the output filename 64 | dst = f.replace("_polygons.json", "_labelTrainIds.png") 65 | 66 | # do the conversion 67 | try: 68 | json2labelImg(f, dst, "trainIds") 69 | except: 70 | print("Failed to convert: {}".format(f)) 71 | raise 72 | 73 | # status 74 | progress += 1 75 | print("\rProgress: {:>3} %".format(progress * 100 / len(files)), end=' ') 76 | sys.stdout.flush() 77 | 78 | 79 | # call the main 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Detection Training Script. 4 | This scripts reads a given config file and runs the training or evaluation. 5 | It is an entry point that is made to train standard models in detectron2. 6 | In order to let one script support training of many models, 7 | this script contains logic that are specific to these built-in models and therefore 8 | may not be suitable for your own project. 9 | For example, your research project perhaps only needs a single "evaluator". 10 | Therefore, we recommend you to use detectron2 as an library and take 11 | this file as an example of how to use the library. 12 | You may want to write your own script with your datasets and other customizations. 
13 | """ 14 | 15 | import logging 16 | import os 17 | from collections import OrderedDict 18 | 19 | import detectron2.utils.comm as comm 20 | import torch 21 | from detectron2.checkpoint import DetectionCheckpointer 22 | from detectron2.config import get_cfg 23 | from detectron2.data import MetadataCatalog 24 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 25 | from detectron2.evaluation import ( 26 | CityscapesEvaluator, 27 | COCOEvaluator, 28 | COCOPanopticEvaluator, 29 | DatasetEvaluators, 30 | LVISEvaluator, 31 | PascalVOCDetectionEvaluator, 32 | SemSegEvaluator, 33 | verify_results, 34 | ) 35 | from detectron2.modeling import GeneralizedRCNNWithTTA 36 | 37 | from danet import add_danet_config 38 | 39 | 40 | class Trainer(DefaultTrainer): 41 | """ 42 | We use the "DefaultTrainer" which contains a number pre-defined logic for 43 | standard training workflow. They may not work for you, especially if you 44 | are working on a new research project. In that case you can use the cleaner 45 | "SimpleTrainer", or write your own training loop. 46 | """ 47 | 48 | @classmethod 49 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 50 | """ 51 | Create evaluator(s) for a given dataset. 52 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 53 | For your own dataset, you can simply create an evaluator manually in your 54 | script and do not have to worry about the hacky if-else logic here. 55 | """ 56 | if output_folder is None: 57 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 58 | evaluator_list = [] 59 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 60 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 61 | evaluator_list.append( 62 | SemSegEvaluator( 63 | dataset_name, 64 | distributed=True, 65 | num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, 66 | ignore_label=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, 67 | output_dir=output_folder, 68 | ) 69 | ) 70 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 71 | evaluator_list.append(COCOEvaluator(dataset_name, cfg, True, output_folder)) 72 | if evaluator_type == "coco_panoptic_seg": 73 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 74 | if evaluator_type == "cityscapes": 75 | assert ( 76 | torch.cuda.device_count() >= comm.get_rank() 77 | ), "CityscapesEvaluator currently do not work with multiple machines." 78 | return CityscapesEvaluator(dataset_name) 79 | if evaluator_type == "pascal_voc": 80 | return PascalVOCDetectionEvaluator(dataset_name) 81 | if evaluator_type == "lvis": 82 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 83 | if len(evaluator_list) == 0: 84 | raise NotImplementedError( 85 | "no Evaluator for the dataset {} with the type {}".format( 86 | dataset_name, evaluator_type 87 | ) 88 | ) 89 | if len(evaluator_list) == 1: 90 | return evaluator_list[0] 91 | return DatasetEvaluators(evaluator_list) 92 | 93 | @classmethod 94 | def test_with_TTA(cls, cfg, model): 95 | logger = logging.getLogger("detectron2.trainer") 96 | # In the end of training, run an evaluation with TTA 97 | # Only support some R-CNN models. 
98 | logger.info("Running inference with test-time augmentation ...") 99 | model = GeneralizedRCNNWithTTA(cfg, model) 100 | evaluators = [ 101 | cls.build_evaluator( 102 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 103 | ) 104 | for name in cfg.DATASETS.TEST 105 | ] 106 | res = cls.test(cfg, model, evaluators) 107 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 108 | return res 109 | 110 | 111 | def setup(args): 112 | """ 113 | Create configs and perform basic setups. 114 | """ 115 | cfg = get_cfg() 116 | add_danet_config(cfg) 117 | cfg.merge_from_file(args.config_file) 118 | cfg.merge_from_list(args.opts) 119 | cfg.freeze() 120 | default_setup(cfg, args) 121 | return cfg 122 | 123 | 124 | def main(args): 125 | cfg = setup(args) 126 | 127 | if args.eval_only: 128 | model = Trainer.build_model(cfg) 129 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 130 | cfg.MODEL.WEIGHTS, resume=args.resume 131 | ) 132 | res = Trainer.test(cfg, model) 133 | if comm.is_main_process(): 134 | verify_results(cfg, res) 135 | if cfg.TEST.AUG.ENABLED: 136 | res.update(Trainer.test_with_TTA(cfg, model)) 137 | return res 138 | 139 | """ 140 | If you'd like to do anything fancier than the standard training logic, 141 | consider writing your own training loop or subclassing the trainer. 142 | """ 143 | trainer = Trainer(cfg) 144 | trainer.resume_or_load(resume=args.resume) 145 | if cfg.TEST.AUG.ENABLED: 146 | trainer.register_hooks( 147 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 148 | ) 149 | return trainer.train() 150 | 151 | 152 | if __name__ == "__main__": 153 | args = default_argument_parser().parse_args() 154 | print("Command Line Args:", args) 155 | launch( 156 | main, 157 | args.num_gpus, 158 | num_machines=args.num_machines, 159 | machine_rank=args.machine_rank, 160 | dist_url=args.dist_url, 161 | args=(args,), 162 | ) 163 | --------------------------------------------------------------------------------