├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── semantic_segmentation.iml └── vcs.xml ├── Deeplabv3.py ├── README.md ├── accuracy_upperbound.py ├── augment.py ├── benchmark.py ├── blocks.py ├── cityscapes.py ├── coco_download.sh ├── coco_utils.py ├── configs ├── cityscapes_regnety40_160epochs_mixed_precision.yaml ├── configs_sanity_check.py ├── voc_mobilenetv2_30epochs.yaml ├── voc_regnetx40_30epochs.yaml ├── voc_regnety40_30epochs.yaml ├── voc_regnety40_30epochs_mixed_precision.yaml ├── voc_resnet50d_30epochs.yaml └── yoho.yaml ├── custom_dataset.py ├── data.py ├── data_utils.py ├── experimental_models.py ├── fov.py ├── google0ccf7212e9c814a7.html ├── model.py ├── requirements.txt ├── show.py ├── train.py ├── transforms.py └── voc12.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project exclude paths 2 | /venv/ 3 | cityscapes_dataset.zip 4 | hello.py 5 | pascal_voc_dataset.zip 6 | /cifar/ 7 | /cityscapes_dataset/ 8 | /pascal_voc_dataset/ 9 | checkpoints 10 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 22 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/semantic_segmentation.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 13 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /Deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from copy import deepcopy 5 | 6 | class AtrousSeparableConvolution(nn.Sequential): 7 | def __init__(self, in_channels, out_channels, kernel_size, 8 | stride=1, padding=0, dilation=1, bias=True,add_norm=True): 9 | modules=[] 10 | modules.append(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, 11 | stride=stride, padding=padding, dilation=dilation, 12 | bias=(not add_norm), groups=in_channels)) 13 | if add_norm: 14 | modules.append(nn.BatchNorm2d(in_channels)) 15 | modules.append(nn.ReLU(inplace=True)) 16 | modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, 17 | padding=0, bias=bias)) 18 | super().__init__(*modules) 19 | 20 | 21 | class DeepLabHead(nn.Sequential): 22 | def __init__(self, in_channels, num_classes,output_stride): 23 | base_rates=[3,6,9] 24 | mul=32//output_stride 25 | rates=[x*mul for x in 
base_rates] 26 | super().__init__( 27 | ASPP(in_channels, rates), 28 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 29 | nn.BatchNorm2d(256), 30 | nn.ReLU(inplace=True), 31 | nn.Conv2d(256, num_classes, 1) 32 | ) 33 | class DeepLabHeadNoASSP(nn.Sequential): 34 | def __init__(self, in_channels, num_classes): 35 | super().__init__( 36 | ASPP(in_channels, []), 37 | nn.Conv2d(256, num_classes, 1) 38 | ) 39 | 40 | def get_ASSP(in_channels,output_stride,output_channels=256): 41 | base_rates = [3, 6, 9] 42 | mul = 32 // output_stride 43 | rates = [x * mul for x in base_rates] 44 | return ASPP(in_channels, rates,output_channels) 45 | 46 | class ASPPConv(nn.Sequential): 47 | def __init__(self, in_channels, out_channels, dilation): 48 | modules = [ 49 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 50 | nn.BatchNorm2d(out_channels), 51 | nn.ReLU(inplace=True) 52 | ] 53 | super(ASPPConv, self).__init__(*modules) 54 | 55 | class ASPPPooling(nn.Sequential): 56 | def __init__(self, in_channels, out_channels): 57 | super(ASPPPooling, self).__init__( 58 | nn.AdaptiveAvgPool2d(1), 59 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 60 | nn.BatchNorm2d(out_channels), 61 | nn.ReLU(inplace=True)) 62 | 63 | def forward(self, x): 64 | size = x.shape[-2:] 65 | for mod in self: 66 | x = mod(x) 67 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 68 | 69 | class ASPP(nn.Module): 70 | def __init__(self, in_channels, atrous_rates, out_channels=256,intermediate_channels=256,dropout=0.5): 71 | super(ASPP, self).__init__() 72 | modules = [] 73 | modules.append(nn.Sequential( 74 | nn.Conv2d(in_channels, intermediate_channels, 1, bias=False), 75 | nn.BatchNorm2d(intermediate_channels), 76 | nn.ReLU(inplace=True))) 77 | 78 | rates = tuple(atrous_rates) 79 | for rate in rates: 80 | modules.append(ASPPConv(in_channels, intermediate_channels, rate)) 81 | 82 | modules.append(ASPPPooling(in_channels, intermediate_channels)) 83 | 84 | self.convs = nn.ModuleList(modules) 85 | num_branches=len(self.convs) 86 | self.project = nn.Sequential( 87 | nn.Conv2d(num_branches * intermediate_channels, out_channels, 1, bias=False), 88 | nn.BatchNorm2d(out_channels), 89 | nn.ReLU(inplace=True), 90 | nn.Dropout(dropout)) 91 | 92 | def forward(self, x): 93 | res = [] 94 | for conv in self.convs: 95 | res.append(conv(x)) 96 | res = torch.cat(res, dim=1) 97 | return self.project(res) 98 | 99 | 100 | def convert_to_separable_conv(module,deep_copy=True): 101 | new_module=module 102 | if deep_copy: 103 | new_module = deepcopy(module) 104 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1 and module.groups == 1: 105 | new_module = AtrousSeparableConvolution( 106 | module.in_channels, 107 | module.out_channels, 108 | module.kernel_size, 109 | module.stride, 110 | module.padding, 111 | module.dilation, 112 | module.bias is not None) 113 | for name, child in new_module.named_children(): 114 | new_module.add_module(name, convert_to_separable_conv(child,deep_copy=False)) 115 | return new_module 116 | 117 | if __name__=='__main__': 118 | module=nn.Conv2d(5,5,1,bias=False) 119 | print(module.in_channels, 120 | module.out_channels, 121 | module.kernel_size, 122 | module.stride, 123 | module.padding, 124 | module.dilation, 125 | module.bias is not None) 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch_DeepLab 2 | 3 | This repo 
is old. Go check out my new model [RegSeg](https://github.com/RolandGao/RegSeg) that achieved SOTA on real-time semantic segmentation on Cityscapes. 4 | 5 | Currently, the code supports DeepLabv3+ with many common backbones, such as Mobilenetv2, Mobilenetv3, Resnet, Resnetv2, XceptionAligned, Regnet, EfficientNet, and many more, thanks to the package [timm](https://github.com/rwightman/pytorch-image-models). The code supports 3 datasets, namely PascalVoc, Coco, and Cityscapes. 6 | 7 | I trained a few models on Cityscapes and PascalVoc, and will release the weights soon. 8 | 9 | ## Results 10 | 11 | Using separable convolution in the decoder 12 | reduces model size and the number of flops, 13 | but increases the memory requirement by 1 GB during training. 14 | 15 | #### PascalVoc 16 | To use the weights, click the link, and instantiate an object like the line below, 17 | changing the name, sc (short for "separable convolution"), and the path to the pretrained weights that you just downloaded. 18 | 19 | ``` 20 | model=Deeplab3P(name='regnetx_040',num_classes=21, 21 | sc=False,pretrained=pretrained_path).to(device) 22 | ``` 23 | name | separable convolution | mIOU | weights 24 | --- | --- | --- | --- 25 | resnet50d | yes | 77.1 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_resnet50d) 26 | regnetx_040 | yes | 77.0 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnetx40) 27 | regnety_040 | yes | 78.6 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnety40) 28 | regnetx_080 | no | 77.3 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnetx80) 29 | mobilenetv2 | no | 72.8 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_mobilenetv2) 30 | 31 | ## Installation 32 | After cloning the repository, run the following command to install all dependencies: 33 | `pip install -r requirements.txt` 34 | 35 | ## Datasets 36 | #### COCO 37 | Run the command 38 | ```shell 39 | sh coco_download.sh 40 | ``` 41 | We use the 21 classes that intersect PascalVoc's. 42 | 43 | #### Cityscapes 44 | Go to https://www.cityscapes-dataset.com, create an account, and download 45 | gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. 46 | You can delete the test images to save some space if you don't want to submit to the competition. 47 | Name the directory cityscapes_dataset. 48 | Make sure that you have downloaded the required python packages and run 49 | ``` 50 | CITYSCAPES_DATASET=cityscapes_dataset csCreateTrainIdLabelImgs 51 | ``` 52 | There are 19 classes. 53 | 54 | #### PascalVoc 55 | Download the original dataset [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar). 56 | 57 | Then download the augmented dataset [here](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0), 58 | and create a text file named train_aug.txt with [this content](https://gist.githubusercontent.com/sun11/2dbda6b31acc7c6292d14a872d0c90b7/raw/5f5a5270089239ef2f6b65b1cc55208355b5acca/trainaug.txt). 59 | 60 | Place train_aug.txt in VOCdevkit/VOC2012/ImageSets/Segmentation/train_aug.txt 61 | 62 | Place the SegmentationClassAug directory in VOCdevkit/VOC2012/SegmentationClassAug 63 | 64 | There are 21 classes.
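If you want to confirm that everything is in the right place before training, a quick check like the sketch below can help. This is only a rough sketch, not part of the repo: it assumes the extracted VOCdevkit folder sits inside a directory named pascal_voc_dataset (the dataset_dir used in the configs); adjust root if yours is named differently.

```
import os

# Rough sanity check for the augmented PascalVoc layout described above.
# Assumption: the extracted VOCdevkit folder lives inside pascal_voc_dataset.
root = "pascal_voc_dataset"
voc = os.path.join(root, "VOCdevkit", "VOC2012")

split_file = os.path.join(voc, "ImageSets", "Segmentation", "train_aug.txt")
aug_dir = os.path.join(voc, "SegmentationClassAug")

assert os.path.isfile(split_file), f"missing {split_file}"
assert os.path.isdir(aug_dir), f"missing {aug_dir}"

with open(split_file) as f:
    num_entries = sum(1 for line in f if line.strip())
print("train_aug.txt entries:", num_entries)  # 10582 per the blog post credited below
print("SegmentationClassAug masks:", len(os.listdir(aug_dir)))
```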
65 | 66 | Credits to https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/ 67 | 68 | 69 | #### Once you have downloaded the dataset 70 | do one of the following three lines in train.py 71 | ``` 72 | data_loader, data_loader_test=get_coco(root,batch_size=16) 73 | data_loader, data_loader_test=get_pascal_voc(root,batch_size=16) 74 | data_loader, data_loader_test=get_cityscapes(root,batch_size=16) 75 | ``` 76 | where the root is usually "." or the top level directory name of the dataset. 77 | 78 | ## To train a model yourself 79 | Download one of the three datasets, change save_path, and num_classes in train.py if necessary, and run the command 80 | ``` 81 | python train.py 82 | ``` 83 | 84 | ## To resume training 85 | In train.py, set resume=True, and change the resume_path to the save_path of your last train session. 86 | -------------------------------------------------------------------------------- /accuracy_upperbound.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | import torch 4 | from torch.nn import functional as F 5 | from train import ConfusionMatrix 6 | from data import get_cityscapes,get_pascal_voc 7 | 8 | class DoNothingNet(nn.Module): 9 | def __init__(self,output_stride=16,mode="bilinear"): 10 | super(DoNothingNet,self).__init__() 11 | self.os=output_stride 12 | self.mode=mode 13 | def forward(self,y): 14 | shape = y.shape[-2:] 15 | downsample_shape=((shape[0]-1)//self.os+1,(shape[1]-1)//self.os+1) 16 | if self.mode=="bilinear": 17 | y=F.interpolate(y,size=downsample_shape,mode='bilinear',align_corners=False) 18 | elif self.mode=="adaptive_avg": 19 | y=F.adaptive_avg_pool2d(y,downsample_shape) 20 | elif self.mode=="adaptive_max": 21 | y=F.adaptive_max_pool2d(y,downsample_shape) 22 | elif self.mode=="max3x3": 23 | y=F.max_pool2d(y,kernel_size=3,stride=2,padding=1) 24 | elif self.mode=="avg3x3": 25 | y=F.avg_pool2d(y,kernel_size=3,stride=2,padding=1) 26 | else: 27 | raise NotImplementedError() 28 | y=F.interpolate(y,size=shape,mode='bilinear',align_corners=False) 29 | return y 30 | 31 | def evaluate(model, data_loader, device, num_classes,eval_steps,print_every=100): 32 | model.eval() 33 | confmat = ConfusionMatrix(num_classes) 34 | with torch.no_grad(): 35 | for i,(image, target) in enumerate(data_loader): 36 | if (i+1)%print_every==0: 37 | print(i+1) 38 | if i==eval_steps: 39 | break 40 | target = target.to(device) 41 | logits=torch.zeros(target.shape[0],num_classes,target.shape[1],target.shape[2],device=device) 42 | target2=torch.unsqueeze(target,1) 43 | target2[target2==255]=0 44 | logits.scatter_(1,target2,1) 45 | output = model(logits) 46 | confmat.update(target.flatten(), output.argmax(1).flatten()) 47 | return confmat 48 | 49 | def f(os,mode,device,num_classes): 50 | net=DoNothingNet(output_stride=os,mode=mode).to(device) 51 | #data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513) 52 | data_loader, data_loader_test=get_cityscapes("cityscapes_dataset",16,train_size=480,val_size=1024,num_workers=0) 53 | confmat=evaluate(net,data_loader_test,device,num_classes,eval_steps=100,print_every=20) 54 | return confmat 55 | 56 | def experiment1(): 57 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 58 | num_classes=19 59 | print() 60 | for mode in ["adaptive_avg"]:#"bilinear","adaptive_max","max3x3", 61 | for os in [2,4,8,16,32]: 62 | confmat=f(os,mode,device,num_classes) 63 | # 
net=DoNothingNet(output_stride=os,mode=mode).to(device) 64 | # #data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513) 65 | # data_loader, data_loader_test=get_cityscapes("cityscapes_dataset",16,train_size=480,val_size=1024) 66 | # confmat=evaluate(net,data_loader_test,device,num_classes,eval_steps=300,print_every=100) 67 | print(mode,os) 68 | print(confmat) 69 | 70 | if __name__=="__main__": 71 | experiment1() 72 | -------------------------------------------------------------------------------- /augment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Lightweight and simple implementation of AutoAugment and RandAugment. 10 | 11 | AutoAugment - https://arxiv.org/abs/1805.09501 12 | RandAugment - https://arxiv.org/abs/1909.13719 13 | 14 | http://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 15 | Note that the official implementation varies substantially from the papers :-( 16 | 17 | Our AutoAugment policy should be fairly identical to the official AutoAugment policy. 18 | The main difference is we set POSTERIZE_MIN = 1, which avoids degenerate (all 0) images. 19 | Our RandAugment policy differs, and uses transforms that increase in intensity with 20 | increasing magnitude. This allows for a more natural control of the magnitude. That is, 21 | setting magnitude = 0 results in ops that leave the image unchanged, if possible. 22 | We also set the range of the magnitude to be 0 to 1 to avoid setting a "max level". 23 | 24 | Our implementation is inspired by and uses policies that are similar to those in: 25 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 26 | Specifically our implementation can be *numerically identical* to the implementation in 27 | timm if using timm's "v0" policy for AutoAugment and "inc" transforms for RandAugment 28 | and if we set POSTERIZE_MIN = 0 (although as noted our default is POSTERIZE_MIN = 1). 29 | Note that magnitude in our code ranges from 0 to 1 (compared to 0 to 10 in timm). 30 | 31 | Specifically, given the same seeds, the functions from timm: 32 | out_auto = auto_augment_transform("v0", {"interpolation": 2})(im) 33 | out_rand = rand_augment_transform("rand-inc1-n2-m05", {"interpolation": 2})(im) 34 | Are numerically equivalent to: 35 | POSTERIZE_MIN = 0 36 | out_auto = auto_augment(im) 37 | out_rand = rand_augment(im, prob=0.5, n_ops=2, magnitude=0.5) 38 | Tested as of 10/07/2020. Can alter corresponding params for both and should match.
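As a concrete example of the magnitude scale: apply_op computes v = magnitude * (max_v - min_v) + min_v over an op's range in OP_RANGES, so with magnitude = 0.5 the rotate op (range 0 to 30, randomly negated) rotates by +/-15 degrees and shear_x (range 0 to 0.3) shears by +/-0.15.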
39 | 40 | Finally, the ops and augmentations can be visualized as follows: 41 | from PIL import Image 42 | import pycls.datasets.augment as augment 43 | im = Image.open("scratch.jpg") 44 | im_ops = augment.visualize_ops(im) 45 | im_rand = augment.visualize_aug(im, augment=augment.rand_augment, magnitude=0.5) 46 | im_auto = augment.visualize_aug(im, augment=augment.auto_augment) 47 | im_ops.show() 48 | im_auto.show() 49 | im_rand.show() 50 | """ 51 | 52 | import random 53 | 54 | import numpy as np 55 | from PIL import Image, ImageEnhance, ImageOps 56 | 57 | 58 | # Minimum value for posterize (0 in EfficientNet implementation) 59 | POSTERIZE_MIN = 1 60 | 61 | # Parameters for affine warping and rotation 62 | WARP_PARAMS = {"fillcolor": (128, 128, 128), "resample": Image.BILINEAR} 63 | 64 | 65 | def affine_warp(im, data): 66 | """Applies affine transform to image.""" 67 | return im.transform(im.size, Image.AFFINE, data, **WARP_PARAMS) 68 | 69 | 70 | OP_FUNCTIONS = { 71 | # Each op takes an image x and a level v and returns an augmented image. 72 | "auto_contrast": lambda x, _: ImageOps.autocontrast(x), 73 | "equalize": lambda x, _: ImageOps.equalize(x), 74 | "invert": lambda x, _: ImageOps.invert(x), 75 | "rotate": lambda x, v: x.rotate(v, **WARP_PARAMS), 76 | "posterize": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, int(v))), 77 | "posterize_inc": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, 4 - int(v))), 78 | "solarize": lambda x, v: x.point(lambda i: i if i < int(v) else 255 - i), 79 | "solarize_inc": lambda x, v: x.point(lambda i: i if i < 256 - v else 255 - i), 80 | "solarize_add": lambda x, v: x.point(lambda i: min(255, v + i) if i < 128 else i), 81 | "color": lambda x, v: ImageEnhance.Color(x).enhance(v), 82 | "contrast": lambda x, v: ImageEnhance.Contrast(x).enhance(v), 83 | "brightness": lambda x, v: ImageEnhance.Brightness(x).enhance(v), 84 | "sharpness": lambda x, v: ImageEnhance.Sharpness(x).enhance(v), 85 | "color_inc": lambda x, v: ImageEnhance.Color(x).enhance(1 + v), 86 | "contrast_inc": lambda x, v: ImageEnhance.Contrast(x).enhance(1 + v), 87 | "brightness_inc": lambda x, v: ImageEnhance.Brightness(x).enhance(1 + v), 88 | "sharpness_inc": lambda x, v: ImageEnhance.Sharpness(x).enhance(1 + v), 89 | "shear_x": lambda x, v: affine_warp(x, (1, v, 0, 0, 1, 0)), 90 | "shear_y": lambda x, v: affine_warp(x, (1, 0, 0, v, 1, 0)), 91 | "trans_x": lambda x, v: affine_warp(x, (1, 0, v * x.size[0], 0, 1, 0)), 92 | "trans_y": lambda x, v: affine_warp(x, (1, 0, 0, 0, 1, v * x.size[1])), 93 | } 94 | affine_ops=[ 95 | "rotate","shear_x","shear_y","trans_x","trans_y" 96 | ] 97 | 98 | 99 | OP_RANGES = { 100 | # Ranges for each op in the form of a (min, max, negate). 
101 | "auto_contrast": (0, 1, False), 102 | "equalize": (0, 1, False), 103 | "invert": (0, 1, False), 104 | "rotate": (0.0, 30.0, True), 105 | "posterize": (0, 4, False), 106 | "posterize_inc": (0, 4, False), 107 | "solarize": (0, 256, False), 108 | "solarize_inc": (0, 256, False), 109 | "solarize_add": (0, 110, False), 110 | "color": (0.1, 1.9, False), 111 | "contrast": (0.1, 1.9, False), 112 | "brightness": (0.1, 1.9, False), 113 | "sharpness": (0.1, 1.9, False), 114 | "color_inc": (0, 0.9, True), 115 | "contrast_inc": (0, 0.9, True), 116 | "brightness_inc": (0, 0.9, True), 117 | "sharpness_inc": (0, 0.9, True), 118 | "shear_x": (0.0, 0.3, True), 119 | "shear_y": (0.0, 0.3, True), 120 | "trans_x": (0.0, 0.45, True), 121 | "trans_y": (0.0, 0.45, True), 122 | } 123 | 124 | 125 | AUTOAUG_POLICY = [ 126 | # AutoAugment "policy_v0" in form of (op, prob, magnitude), where magnitude <= 1. 127 | [("equalize", 0.8, 0.1), ("shear_y", 0.8, 0.4)], 128 | [("color", 0.4, 0.9), ("equalize", 0.6, 0.3)], 129 | [("color", 0.4, 0.1), ("rotate", 0.6, 0.8)], 130 | [("solarize", 0.8, 0.3), ("equalize", 0.4, 0.7)], 131 | [("solarize", 0.4, 0.2), ("solarize", 0.6, 0.2)], 132 | [("color", 0.2, 0.0), ("equalize", 0.8, 0.8)], 133 | [("equalize", 0.4, 0.8), ("solarize_add", 0.8, 0.3)], 134 | [("shear_x", 0.2, 0.9), ("rotate", 0.6, 0.8)], 135 | [("color", 0.6, 0.1), ("equalize", 1.0, 0.2)], 136 | [("invert", 0.4, 0.9), ("rotate", 0.6, 0.0)], 137 | [("equalize", 1.0, 0.9), ("shear_y", 0.6, 0.3)], 138 | [("color", 0.4, 0.7), ("equalize", 0.6, 0.0)], 139 | [("posterize", 0.4, 0.6), ("auto_contrast", 0.4, 0.7)], 140 | [("solarize", 0.6, 0.8), ("color", 0.6, 0.9)], 141 | [("solarize", 0.2, 0.4), ("rotate", 0.8, 0.9)], 142 | [("rotate", 1.0, 0.7), ("trans_y", 0.8, 0.9)], 143 | [("shear_x", 0.0, 0.0), ("solarize", 0.8, 0.4)], 144 | [("shear_y", 0.8, 0.0), ("color", 0.6, 0.4)], 145 | [("color", 1.0, 0.0), ("rotate", 0.6, 0.2)], 146 | [("equalize", 0.8, 0.4), ("equalize", 0.0, 0.8)], 147 | [("equalize", 1.0, 0.4), ("auto_contrast", 0.6, 0.2)], 148 | [("shear_y", 0.4, 0.7), ("solarize_add", 0.6, 0.7)], 149 | [("posterize", 0.8, 0.2), ("solarize", 0.6, 1.0)], 150 | [("solarize", 0.6, 0.8), ("equalize", 0.6, 0.1)], 151 | [("color", 0.8, 0.6), ("rotate", 0.4, 0.5)], 152 | ] 153 | 154 | 155 | RANDAUG_OPS = [ 156 | # RandAugment list of operations using "increasing" transforms. 
157 | "auto_contrast", 158 | "equalize", 159 | "invert", 160 | "rotate", 161 | "posterize_inc", 162 | "solarize_inc", 163 | "solarize_add", 164 | "color_inc", 165 | "contrast_inc", 166 | "brightness_inc", 167 | "sharpness_inc", 168 | "shear_x", 169 | "shear_y", 170 | "trans_x", 171 | "trans_y", 172 | ] 173 | 174 | def check_support(): 175 | mask=np.zeros((100,100)).astype("uint8") 176 | mask=Image.fromarray(mask) 177 | magnitude=1.0 178 | for op in RANDAUG_OPS: 179 | min_v, max_v, negate = OP_RANGES[op] 180 | v = magnitude * (max_v - min_v) + min_v 181 | v = -v if negate and random.random() > 0.5 else v 182 | OP_FUNCTIONS[op](mask, v) 183 | 184 | 185 | def apply_op(im, op, prob, magnitude): 186 | """Apply the selected op to image with given probability and magnitude.""" 187 | # The magnitude is converted to an absolute value v for an op (some ops use -v or v) 188 | assert 0 <= magnitude <= 1 189 | assert op in OP_RANGES and op in OP_FUNCTIONS, "unknown op " + op 190 | if prob < 1 and random.random() > prob: 191 | return im 192 | min_v, max_v, negate = OP_RANGES[op] 193 | v = magnitude * (max_v - min_v) + min_v 194 | v = -v if negate and random.random() > 0.5 else v 195 | return OP_FUNCTIONS[op](im, v) 196 | 197 | def apply_op_both(im,mask, op, prob, magnitude,fill,ignore_value=255): 198 | """Apply the selected op to image with given probability and magnitude.""" 199 | # The magnitude is converted to an absolute value v for an op (some ops use -v or v) 200 | assert 0 <= magnitude <= 1 201 | assert op in OP_RANGES and op in OP_FUNCTIONS, "unknown op " + op 202 | if prob < 1 and random.random() > prob: 203 | return im,mask 204 | min_v, max_v, negate = OP_RANGES[op] 205 | v = magnitude * (max_v - min_v) + min_v 206 | v = -v if negate and random.random() > 0.5 else v 207 | WARP_PARAMS["fillcolor"]=fill 208 | im=OP_FUNCTIONS[op](im, v) 209 | if op in affine_ops: 210 | WARP_PARAMS["fillcolor"]=ignore_value 211 | mask=OP_FUNCTIONS[op](mask, v) 212 | return im,mask 213 | 214 | def rand_augment_both(im, mask,magnitude, ops=None, n_ops=2, prob=1.0,fill=(128,128,128),ignore_value=255): 215 | """Applies random augmentation to an image.""" 216 | ops = ops if ops else RANDAUG_OPS 217 | for op in np.random.choice(ops, int(n_ops)): 218 | im,mask = apply_op_both(im,mask, op, prob, magnitude,fill,ignore_value) 219 | return im,mask 220 | 221 | def rand_augment(im, magnitude, ops=None, n_ops=2, prob=1.0): 222 | """Applies random augmentation to an image.""" 223 | ops = ops if ops else RANDAUG_OPS 224 | for op in np.random.choice(ops, int(n_ops)): 225 | im = apply_op(im, op, prob, magnitude) 226 | return im 227 | 228 | 229 | def auto_augment(im, policy=None): 230 | """Apply auto augmentation to an image.""" 231 | policy = policy if policy else AUTOAUG_POLICY 232 | for op, prob, magnitude in random.choice(policy): 233 | im = apply_op(im, op, prob, magnitude) 234 | return im 235 | 236 | 237 | def make_augment(augment_str): 238 | """Generate augmentation function from separated parameter string. 239 | The parameter string augment_str may be either "AutoAugment" or "RandAugment". 240 | Undocumented use allows for specifying extra params, e.g. 
"RandAugment_N2_M0.5".""" 241 | params = augment_str.split("_") 242 | names = {"N": "n_ops", "M": "magnitude", "P": "prob"} 243 | assert params[0] in ["RandAugment", "AutoAugment"] 244 | assert all(p[0] in names for p in params[1:]) 245 | keys = [names[p[0]] for p in params[1:]] 246 | vals = [float(p[1:]) for p in params[1:]] 247 | augment = rand_augment if params[0] == "RandAugment" else auto_augment 248 | return lambda im: augment(im, **dict(zip(keys, vals))) 249 | 250 | 251 | def visualize_ops(im, ops=None, num_steps=10): 252 | """Visualize ops by applying each op by varying amounts.""" 253 | ops = ops if ops else RANDAUG_OPS 254 | w, h, magnitudes = im.size[0], im.size[1], np.linspace(0, 1, num_steps) 255 | output = Image.new("RGB", (w * num_steps, h * len(ops))) 256 | for i, op in enumerate(ops): 257 | for j, m in enumerate(magnitudes): 258 | out = apply_op(im, op, prob=1.0, magnitude=m) 259 | output.paste(out, (j * w, i * h)) 260 | return output 261 | 262 | 263 | def visualize_aug(im, augment=rand_augment, num_trials=10, **kwargs): 264 | """Visualize augmentation by applying random augmentations.""" 265 | w, h = im.size[0], im.size[1] 266 | output = Image.new("RGB", (w * num_trials, h * num_trials)) 267 | for i in range(num_trials): 268 | for j in range(num_trials): 269 | output.paste(augment(im, **kwargs), (j * w, i * h)) 270 | return output 271 | 272 | if __name__=="__main__": 273 | im=Image.open("prima_4class/Images/pc-00000085.jpg") 274 | w, h = im.size[0], im.size[1] 275 | im=im.resize((w//4,h//4)) 276 | output=visualize_ops(im,num_steps=10) 277 | #output=visualize_aug(im,magnitude=1/3) 278 | output.show() 279 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | import time 4 | import torch.cuda.amp as amp 5 | import torch.nn.functional 6 | 7 | @torch.no_grad() 8 | def compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision): 9 | model.eval() 10 | x=torch.randn(batch_size,3,crop_size,crop_size).to(device) 11 | times=[] 12 | for cur_iter in range(warmup_iter+num_iter): 13 | if cur_iter == warmup_iter: 14 | times.clear() 15 | t1=time.time() 16 | with amp.autocast(enabled=mixed_precision): 17 | output = model(x) 18 | torch.cuda.synchronize() 19 | t2=time.time() 20 | times.append(t2-t1) 21 | return average(times) 22 | 23 | def average(v): 24 | return sum(v)/len(v) 25 | def compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision): 26 | model.train() 27 | x=torch.randn(batch_size, 3, crop_size, crop_size).cuda(non_blocking=False) 28 | target=torch.randint(0,num_classes,(batch_size, crop_size, crop_size)).cuda(non_blocking=False) 29 | fw_times=[] 30 | bw_times=[] 31 | scaler = amp.GradScaler(enabled=mixed_precision) 32 | for cur_iter in range(warmup_iter+num_iter): 33 | if cur_iter == warmup_iter: 34 | fw_times.clear() 35 | bw_times.clear() 36 | t1=time.time() 37 | with amp.autocast(enabled=mixed_precision): 38 | output = model(x) 39 | loss = nn.functional.cross_entropy(output,target,ignore_index=255) 40 | torch.cuda.synchronize() 41 | t2=time.time() 42 | scaler.scale(loss).backward() 43 | torch.cuda.synchronize() 44 | t3=time.time() 45 | fw_times.append(t2-t1) 46 | bw_times.append(t3-t2) 47 | return average(fw_times),average(bw_times) 48 | 49 | def compute_loader_time(data_loader,warmup_iter,num_iter): 50 | times=[] 51 | 
data_loader_iter=iter(data_loader) 52 | for cur_iter in range(warmup_iter+num_iter): 53 | if cur_iter == warmup_iter: 54 | times.clear() 55 | t1=time.time() 56 | next(data_loader_iter) 57 | t2=time.time() 58 | times.append(t2-t1) 59 | return average(times) 60 | 61 | 62 | def memory_used(device): 63 | x=torch.cuda.memory_allocated(device) 64 | return round(x/1024/1024) 65 | def max_memory_used(device): 66 | x=torch.cuda.max_memory_allocated(device) 67 | return round(x/1024/1024) 68 | def memory_test_helper(model,device,crop_size,batch_size,num_classes,mixed_precision): 69 | model.train() 70 | scaler = amp.GradScaler(enabled=mixed_precision) 71 | x=torch.randn(batch_size, 3, crop_size, crop_size).to(device) 72 | target=torch.randint(0,num_classes,(batch_size, crop_size, crop_size)).to(device) 73 | t1=memory_used(device) 74 | with amp.autocast(enabled=mixed_precision): 75 | output = model(x) 76 | loss = nn.functional.cross_entropy(output,target,ignore_index=255) 77 | scaler.scale(loss).backward() 78 | torch.cuda.synchronize() 79 | t2=max_memory_used(device) 80 | torch.cuda.reset_peak_memory_stats(device) 81 | return t2-t1 82 | 83 | def compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision): 84 | for p in model.parameters(): 85 | p.grad=None 86 | try: 87 | t=memory_test_helper(model,device,crop_size,batch_size,num_classes,mixed_precision) 88 | print() 89 | except: 90 | t=-1 91 | print("out of memory") 92 | for p in model.parameters(): 93 | p.grad=None 94 | return t 95 | 96 | def compute_time_no_loader(model,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision): 97 | model=model.to(device) 98 | print("benchmarking eval time") 99 | eval_time=compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision) 100 | print("benchmarking train time") 101 | train_fw_time,train_bw_time=compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision) 102 | train_time=train_fw_time+train_bw_time 103 | print("benchmarking memory usage") 104 | memory_usage=compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision) 105 | print("benchmarking loader time") 106 | dic1={ 107 | "eval_time":eval_time, 108 | "train_time":train_time, 109 | "memory_usage":memory_usage 110 | } 111 | return dic1 112 | 113 | def compute_time_full(model,data_loader,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision): 114 | model=model.to(device) 115 | print("benchmarking eval time") 116 | eval_time=compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision) 117 | print("benchmarking train time") 118 | train_fw_time,train_bw_time=compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision) 119 | train_time=train_fw_time+train_bw_time 120 | print("benchmarking memory usage") 121 | memory_usage=compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision) 122 | print("benchmarking loader time") 123 | loader_time=compute_loader_time(data_loader,warmup_iter,num_iter) 124 | loader_overhead=max(0,loader_time-train_time)/train_time 125 | dic1={ 126 | "eval_time":eval_time, 127 | "train_time":train_time, 128 | "memory_usage":memory_usage, 129 | "loader_time":loader_time, 130 | "loader_overhead":loader_overhead 131 | } 132 | dic2={ 133 | "eval_time":eval_time*len(data_loader), 134 | "train_time":train_time*len(data_loader), 135 | "memory_usage":memory_usage, 136 | "loader_time":loader_time, 137 | 
"loader_overhead":loader_overhead 138 | } 139 | return dic1 140 | -------------------------------------------------------------------------------- /blocks.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | class XBlock(nn.Module): # From figure 4 6 | def __init__(self, in_channels, out_channels, bottleneck_ratio, group_width, stride): 7 | super(XBlock, self).__init__() 8 | inter_channels = out_channels // bottleneck_ratio 9 | groups = inter_channels // group_width 10 | 11 | self.conv_block_1 = nn.Sequential( 12 | nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False), 13 | nn.BatchNorm2d(inter_channels), 14 | nn.ReLU(inplace=True) 15 | ) 16 | self.conv_block_2 = nn.Sequential( 17 | nn.Conv2d(inter_channels, inter_channels, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False), 18 | nn.BatchNorm2d(inter_channels), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | self.conv_block_3 = nn.Sequential( 23 | nn.Conv2d(inter_channels, out_channels, kernel_size=1, bias=False), 24 | nn.BatchNorm2d(out_channels) 25 | ) 26 | if stride != 1 or in_channels != out_channels: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), 29 | nn.BatchNorm2d(out_channels) 30 | ) 31 | else: 32 | self.shortcut = None 33 | self.rl = nn.ReLU(inplace=True) 34 | 35 | def forward(self, x): 36 | shortcut=self.shortcut(x) if self.shortcut else x 37 | x = self.conv_block_1(x) 38 | x = self.conv_block_2(x) 39 | x = self.conv_block_3(x) 40 | x = self.rl(x + shortcut) 41 | return x 42 | 43 | 44 | class VBlock(nn.Module): 45 | def __init__(self, in_channels, out_channels, stride): 46 | super(VBlock, self).__init__() 47 | self.conv_block_1 = nn.Sequential( 48 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU(inplace=True) 51 | ) 52 | self.conv_block_2 = nn.Sequential( 53 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 54 | nn.BatchNorm2d(out_channels), 55 | ) 56 | if stride != 1 or in_channels != out_channels: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(out_channels) 60 | ) 61 | else: 62 | self.shortcut = None 63 | self.rl = nn.ReLU(inplace=True) 64 | 65 | def forward(self, x): 66 | shortcut=self.shortcut(x) if self.shortcut else x 67 | x = self.conv_block_1(x) 68 | x = self.conv_block_2(x) 69 | x = self.rl(x + shortcut) 70 | return x 71 | 72 | class AtrousConcat(nn.Module): 73 | def __init__(self, in_channels, out_channels, group_width,rates): 74 | super(AtrousConcat, self).__init__() 75 | modules = [] 76 | groups = out_channels // group_width 77 | for rate in rates: 78 | modules.append(nn.Sequential( 79 | nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False,groups=groups), 80 | nn.BatchNorm2d(out_channels), 81 | nn.ReLU(inplace=True) 82 | )) 83 | self.convs = nn.ModuleList(modules) 84 | 85 | def forward(self, x): 86 | res = [] 87 | for conv in self.convs: 88 | res.append(conv(x)) 89 | res = torch.cat(res, dim=1) 90 | return res 91 | class SpatialConcat(nn.Module): 92 | def __init__(self, in_channels, out_channels,group_width, bin_sizes): 93 | super(SpatialConcat, self).__init__() 94 | modules = [] 95 | groups = out_channels // group_width 96 | self.reg_conv=nn.Sequential( 97 | 
nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False,groups=groups), 98 | nn.BatchNorm2d(out_channels), 99 | nn.ReLU(inplace=True), 100 | ) 101 | for size in bin_sizes: 102 | modules.append(nn.Sequential( 103 | nn.AdaptiveAvgPool2d(output_size=(size,size)), 104 | nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False,groups=groups), 105 | nn.BatchNorm2d(out_channels), 106 | nn.ReLU(inplace=True), 107 | )) 108 | self.convs = nn.ModuleList(modules) 109 | 110 | def forward(self, x): 111 | input_shape = x.shape[-2:] 112 | res = [self.reg_conv(x)] 113 | for conv in self.convs: 114 | res.append(F.interpolate(conv(x), size=input_shape, mode='bilinear',align_corners=False)) 115 | res = torch.cat(res, dim=1) 116 | return res 117 | class AtrousSum(nn.Module): 118 | def __init__(self, in_channels, out_channels,group_width, rates): 119 | super(AtrousSum, self).__init__() 120 | modules = [] 121 | groups = out_channels // group_width 122 | for rate in rates: 123 | modules.append(nn.Sequential( 124 | nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False,groups=groups), 125 | nn.BatchNorm2d(out_channels), 126 | nn.ReLU(inplace=True) 127 | )) 128 | self.convs = nn.ModuleList(modules) 129 | 130 | def forward(self, x): 131 | out=self.convs[0](x) 132 | for conv in self.convs[1:]: 133 | out=out+conv(x) 134 | return out 135 | 136 | class ABlock(nn.Module): 137 | def __init__(self, in_channels, conv1_channels,conv2_channels, out_channels,group_width,rates=(1,6,12,18),mode="concat"): 138 | super(ABlock, self).__init__() 139 | self.conv1 = nn.Sequential( 140 | nn.Conv2d(in_channels, conv1_channels, kernel_size=1, bias=False), 141 | nn.BatchNorm2d(conv1_channels), 142 | nn.ReLU(inplace=True) 143 | ) 144 | if mode=="concat": 145 | self.conv2=AtrousConcat(conv1_channels, conv2_channels, group_width,rates) 146 | conv2_out_channels=conv2_channels * len(rates) 147 | elif mode=="sum": 148 | self.conv2=AtrousSum(conv1_channels, conv2_channels, group_width,rates) 149 | conv2_out_channels=conv2_channels 150 | else: 151 | raise NotImplementedError() 152 | self.conv3 = nn.Sequential( 153 | nn.Conv2d(conv2_out_channels, out_channels, 1, bias=False), 154 | nn.BatchNorm2d(out_channels), 155 | ) 156 | self.shortcut=None 157 | if in_channels != out_channels: 158 | self.shortcut = nn.Sequential( 159 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), 160 | nn.BatchNorm2d(out_channels) 161 | ) 162 | self.rl = nn.Sequential( 163 | nn.ReLU(inplace=True) 164 | ) 165 | 166 | def forward(self, x): 167 | shortcut=self.shortcut(x) if self.shortcut else x 168 | x = self.conv1(x) 169 | x = self.conv2(x) 170 | x = self.conv3(x) 171 | x = self.rl(x + shortcut) 172 | return x 173 | 174 | class SBlock(nn.Module): 175 | def __init__(self, in_channels, conv1_channels,conv2_channels, out_channels,group_width,bin_sizes=(1,2,3,6),mode="concat"): 176 | super(SBlock, self).__init__() 177 | self.conv1 = nn.Sequential( 178 | nn.Conv2d(in_channels, conv1_channels, kernel_size=1, bias=False), 179 | nn.BatchNorm2d(conv1_channels), 180 | nn.ReLU(inplace=True) 181 | ) 182 | if mode=="concat": 183 | self.conv2=SpatialConcat(conv1_channels, conv2_channels, group_width,bin_sizes) 184 | conv2_out_channels=conv2_channels * (len(bin_sizes)+1) 185 | else: 186 | raise NotImplementedError() 187 | self.conv3 = nn.Sequential( 188 | nn.Conv2d(conv2_out_channels, out_channels, 1, bias=False), 189 | nn.BatchNorm2d(out_channels), 190 | ) 191 | self.shortcut=None 192 | if in_channels != out_channels: 193 | 
self.shortcut = nn.Sequential( 194 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False), 195 | nn.BatchNorm2d(out_channels) 196 | ) 197 | self.rl = nn.Sequential( 198 | nn.ReLU(inplace=True) 199 | ) 200 | 201 | def forward(self, x): 202 | shortcut=self.shortcut(x) if self.shortcut else x 203 | x = self.conv1(x) 204 | x = self.conv2(x) 205 | x = self.conv3(x) 206 | x = self.rl(x + shortcut) 207 | return x 208 | 209 | def profile(model,x,device,num_iter=15): 210 | import time 211 | model=model.to(device) 212 | model.eval() 213 | t1=time.time() 214 | for i in range(num_iter): 215 | y=model(x) 216 | t2=time.time() 217 | return (t2-t1)/num_iter 218 | 219 | def profiler(models,x): 220 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 221 | print("warming up") 222 | profile(models[0],x, device,num_iter=15) 223 | for model in models: 224 | seconds=profile(model,x,device,num_iter=30) 225 | print(round(seconds,3)) 226 | 227 | def try_models(): 228 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 229 | x=torch.randn(2,128,128,128) 230 | group_width=128 231 | w=128 232 | model1=ABlock(w,w,w,w,group_width,rates=(1,6)) 233 | model2=SBlock(w,w,w,w,group_width,bin_sizes=(4,16)) 234 | model3=XBlock(w,w,1,group_width,1) 235 | models=[model1,model2,model3] 236 | profiler(models,x) 237 | 238 | 239 | if __name__=="__main__": 240 | try_models() 241 | -------------------------------------------------------------------------------- /cityscapes.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import namedtuple 4 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 5 | import torch.utils.data as data 6 | from PIL import Image 7 | 8 | 9 | class Cityscapes(data.Dataset): 10 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', 11 | 'has_instances', 'ignore_in_eval', 'color']) 12 | 13 | classes = [ 14 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 15 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 16 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 17 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 18 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 19 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 20 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 21 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 22 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 23 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 24 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 25 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 26 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 27 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), 28 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 29 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 30 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 31 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 32 
| CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 33 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 34 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 35 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 36 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 37 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 38 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 39 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 40 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 41 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 42 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 43 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 44 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 45 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 46 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 47 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 48 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 49 | ] 50 | 51 | def __init__( 52 | self, 53 | root: str, 54 | split: str = "train", 55 | mode: str = "fine", 56 | target_type: Union[List[str], str] = "semantic", 57 | transforms: Optional[Callable] = None, 58 | ) -> None: 59 | self.root=root 60 | self.transforms=transforms 61 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 62 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split) 63 | self.targets_dir = os.path.join(self.root, self.mode, split) 64 | self.target_type = target_type 65 | self.split = split 66 | self.images = [] 67 | self.targets = [] 68 | 69 | if not isinstance(target_type, list): 70 | self.target_type = [target_type] 71 | 72 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): 73 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' 74 | ' specified "split" and "mode" are inside the "root" directory') 75 | 76 | for city in os.listdir(self.images_dir): 77 | if city[0]==".": 78 | continue 79 | img_dir = os.path.join(self.images_dir, city) 80 | target_dir = os.path.join(self.targets_dir, city) 81 | for file_name in os.listdir(img_dir): 82 | target_types = [] 83 | for t in self.target_type: 84 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 85 | self._get_target_suffix(self.mode, t)) 86 | target_types.append(os.path.join(target_dir, target_name)) 87 | 88 | self.images.append(os.path.join(img_dir, file_name)) 89 | self.targets.append(target_types) 90 | 91 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 92 | """ 93 | Args: 94 | index (int): Index 95 | Returns: 96 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 97 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
98 | """ 99 | 100 | image = Image.open(self.images[index]).convert('RGB') 101 | 102 | targets: Any = [] 103 | for i, t in enumerate(self.target_type): 104 | if t == 'polygon': 105 | target = self._load_json(self.targets[index][i]) 106 | else: 107 | target = Image.open(self.targets[index][i]) 108 | 109 | targets.append(target) 110 | 111 | target = tuple(targets) if len(targets) > 1 else targets[0] 112 | 113 | if self.transforms is not None: 114 | image, target = self.transforms(image, target) 115 | 116 | return image, target 117 | 118 | def __len__(self) -> int: 119 | return len(self.images) 120 | 121 | def extra_repr(self) -> str: 122 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] 123 | return '\n'.join(lines).format(**self.__dict__) 124 | 125 | def _load_json(self, path: str) -> Dict[str, Any]: 126 | with open(path, 'r') as file: 127 | data = json.load(file) 128 | return data 129 | 130 | def _get_target_suffix(self, mode: str, target_type: str) -> str: 131 | if target_type == 'instance': 132 | return '{}_instanceIds.png'.format(mode) 133 | elif target_type == 'semantic': 134 | return '{}_labelTrainIds.png'.format(mode) 135 | elif target_type == 'color': 136 | return '{}_color.png'.format(mode) 137 | else: 138 | return '{}_polygons.json'.format(mode) 139 | -------------------------------------------------------------------------------- /coco_download.sh: -------------------------------------------------------------------------------- 1 | mkdir coco 2 | cd coco 3 | mkdir images 4 | cd images 5 | 6 | wget http://images.cocodataset.org/zips/train2017.zip 7 | wget http://images.cocodataset.org/zips/val2017.zip 8 | 9 | unzip -q train2017.zip 10 | unzip -q val2017.zip 11 | 12 | rm train2017.zip 13 | rm val2017.zip 14 | 15 | cd .. 16 | 17 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 18 | unzip -q annotations_trainval2017.zip 19 | rm annotations_trainval2017.zip 20 | -------------------------------------------------------------------------------- /coco_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.utils.data 4 | import torchvision 5 | from PIL import Image 6 | 7 | import os 8 | 9 | from pycocotools import mask as coco_mask 10 | 11 | from transforms import Compose 12 | 13 | 14 | class FilterAndRemapCocoCategories(object): 15 | def __init__(self, categories, remap=True): 16 | self.categories = categories 17 | self.remap = remap 18 | 19 | def __call__(self, image, anno): 20 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 21 | if not self.remap: 22 | return image, anno 23 | anno = copy.deepcopy(anno) 24 | for obj in anno: 25 | obj["category_id"] = self.categories.index(obj["category_id"]) 26 | return image, anno 27 | 28 | 29 | def convert_coco_poly_to_mask(segmentations, height, width): 30 | masks = [] 31 | for polygons in segmentations: 32 | rles = coco_mask.frPyObjects(polygons, height, width) 33 | mask = coco_mask.decode(rles) 34 | if len(mask.shape) < 3: 35 | mask = mask[..., None] 36 | mask = torch.as_tensor(mask, dtype=torch.uint8) 37 | mask = mask.any(dim=2) 38 | masks.append(mask) 39 | if masks: 40 | masks = torch.stack(masks, dim=0) 41 | else: 42 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 43 | return masks 44 | 45 | 46 | class ConvertCocoPolysToMask(object): 47 | def __call__(self, image, anno): 48 | w, h = image.size 49 | segmentations = [obj["segmentation"] for obj in anno] 50 | cats = 
[obj["category_id"] for obj in anno] 51 | if segmentations: 52 | masks = convert_coco_poly_to_mask(segmentations, h, w) 53 | cats = torch.as_tensor(cats, dtype=masks.dtype) 54 | # merge all instance masks into a single segmentation map 55 | # with its corresponding categories 56 | target, _ = (masks * cats[:, None, None]).max(dim=0) 57 | # discard overlapping instances 58 | target[masks.sum(0) > 1] = 255 59 | else: 60 | target = torch.zeros((h, w), dtype=torch.uint8) 61 | target = Image.fromarray(target.numpy()) 62 | return image, target 63 | 64 | 65 | def _coco_remove_images_without_annotations(dataset, cat_list=None): 66 | def _has_valid_annotation(anno): 67 | # if it's empty, there is no annotation 68 | if len(anno) == 0: 69 | return False 70 | # if more than 1k pixels occupied in the image 71 | return sum(obj["area"] for obj in anno) > 1000 72 | 73 | assert isinstance(dataset, torchvision.datasets.CocoDetection) 74 | ids = [] 75 | for ds_idx, img_id in enumerate(dataset.ids): 76 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) 77 | anno = dataset.coco.loadAnns(ann_ids) 78 | if cat_list: 79 | anno = [obj for obj in anno if obj["category_id"] in cat_list] 80 | if _has_valid_annotation(anno): 81 | ids.append(ds_idx) 82 | 83 | dataset = torch.utils.data.Subset(dataset, ids) 84 | return dataset 85 | 86 | 87 | def get_coco_dataset(root, image_set, transforms): 88 | PATHS = { 89 | "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), 90 | "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), 91 | # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) 92 | } 93 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 94 | 1, 64, 20, 63, 7, 72] 95 | 96 | transforms = Compose([ 97 | FilterAndRemapCocoCategories(CAT_LIST, remap=True), 98 | ConvertCocoPolysToMask(), 99 | transforms 100 | ]) 101 | 102 | img_folder, ann_file = PATHS[image_set] 103 | img_folder = os.path.join(root, img_folder) 104 | ann_file = os.path.join(root, ann_file) 105 | 106 | dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) 107 | 108 | if image_set == "train": 109 | dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) 110 | 111 | return dataset 112 | -------------------------------------------------------------------------------- /configs/cityscapes_regnety40_160epochs_mixed_precision.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: regnety_040 3 | num_classes: 19 4 | pretrained_backbone: True 5 | separable_convolution: False 6 | 7 | #OPTIM: 8 | epochs: 160 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: True 19 | 20 | #TEST: 21 | val_size: 513 22 | 23 | #benchmark 24 | warmup_iter: 3 25 | num_iter: 30 26 | 27 | save_every_k_epochs: 3 28 | save_last_k_epochs: 8 29 | dataset_name: cityscapes 30 | dataset_dir: cityscapes_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40_latest 36 | 37 | #save_best_path: 
checkpoints/voc_regnety40_mixed_precision 38 | #save_latest_path: checkpoints/voc_regnety40_mixed_precision_latest 39 | -------------------------------------------------------------------------------- /configs/configs_sanity_check.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | def f(): 4 | filename="voc_resnet50d_30epochs.yaml" 5 | with open(filename) as file: 6 | dic=yaml.full_load(file) 7 | print(dic) 8 | 9 | if __name__=="__main__": 10 | f() 11 | -------------------------------------------------------------------------------- /configs/voc_mobilenetv2_30epochs.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: mobilenetv2 3 | num_classes: 21 4 | pretrained_backbone: True 5 | separable_convolution: False 6 | 7 | #OPTIM: 8 | epochs: 30 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: False 19 | 20 | #TEST: 21 | eval_steps: 1000 22 | val_size: 513 23 | 24 | #benchmark 25 | warmup_iter: 3 26 | num_iter: 30 27 | 28 | save_every_k_epochs: 3 29 | dataset_name: pascal_voc 30 | dataset_dir: pascal_voc_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2_latest 36 | -------------------------------------------------------------------------------- /configs/voc_regnetx40_30epochs.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: regnetx_040 3 | num_classes: 21 4 | pretrained_backbone: True 5 | separable_convolution: True 6 | 7 | #OPTIM: 8 | epochs: 30 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: False 19 | 20 | #TEST: 21 | eval_steps: 1000 22 | val_size: 513 23 | 24 | #benchmark 25 | warmup_iter: 3 26 | num_iter: 30 27 | 28 | save_every_k_epochs: 3 29 | dataset_name: pascal_voc 30 | dataset_dir: pascal_voc_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40_latest 36 | -------------------------------------------------------------------------------- /configs/voc_regnety40_30epochs.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: regnety_040 3 | num_classes: 21 4 | pretrained_backbone: True 5 | separable_convolution: True 6 | 7 | #OPTIM: 8 | epochs: 30 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: False 19 | 20 | #TEST: 21 | eval_steps: 1000 22 | val_size: 513 23 | 24 | #benchmark 25 | warmup_iter: 3 26 | num_iter: 30 27 | 28 | save_every_k_epochs: 3 29 | dataset_name: pascal_voc 30 | 
dataset_dir: pascal_voc_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_latest 36 | -------------------------------------------------------------------------------- /configs/voc_regnety40_30epochs_mixed_precision.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: regnety_040 3 | num_classes: 21 4 | pretrained_backbone: True 5 | separable_convolution: False 6 | 7 | #OPTIM: 8 | epochs: 30 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: True 19 | 20 | #TEST: 21 | val_size: 513 22 | 23 | #benchmark 24 | warmup_iter: 3 25 | num_iter: 30 26 | 27 | save_every_k_epochs: 3 28 | save_last_k_epochs: 8 29 | dataset_name: pascal_voc 30 | dataset_dir: pascal_voc_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision_latest 36 | 37 | #save_best_path: checkpoints/voc_regnety40_mixed_precision 38 | #save_latest_path: checkpoints/voc_regnety40_mixed_precision_latest 39 | -------------------------------------------------------------------------------- /configs/voc_resnet50d_30epochs.yaml: -------------------------------------------------------------------------------- 1 | #MODEL: 2 | model_name: resnet50d 3 | num_classes: 21 4 | pretrained_backbone: True 5 | separable_convolution: True 6 | 7 | #OPTIM: 8 | epochs: 30 9 | resume: False 10 | lr: 0.01 11 | momentum: 0.9 12 | weight_decay: 0.0001 13 | class_weight: null 14 | 15 | #TRAIN: 16 | batch_size: 16 17 | train_size: 481 18 | mixed_precision: False 19 | 20 | #TEST: 21 | eval_steps: 1000 22 | val_size: 513 23 | 24 | #benchmark 25 | warmup_iter: 3 26 | num_iter: 30 27 | 28 | save_every_k_epochs: 3 29 | dataset_name: pascal_voc 30 | dataset_dir: pascal_voc_dataset 31 | aug_mode: baseline 32 | pretrained_path: '' 33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_latest 34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d 35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_latest 36 | -------------------------------------------------------------------------------- /configs/yoho.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 6, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: 
imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | NUM_GPUS: 8 28 | OUT_DIR: . 29 | -------------------------------------------------------------------------------- /custom_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch.utils.data as data 3 | import os 4 | import json 5 | 6 | class SegmentationDataset(data.Dataset): 7 | def __init__(self,root,image_set,transforms): 8 | # images, masks, json splits 9 | split_f=os.path.join(root,f"{image_set}.json") 10 | file_names=json.load(open(split_f, 'r')) 11 | root= os.path.expanduser(root) 12 | self.transforms=transforms 13 | image_dir = os.path.join(root, "Images") 14 | mask_dir=os.path.join(root, "Masks") 15 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 16 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] 17 | assert (len(self.images) == len(self.masks)) 18 | 19 | def __getitem__(self, index): 20 | img = Image.open(self.images[index]).convert('RGB') 21 | target = Image.open(self.masks[index]) 22 | if self.transforms is not None: 23 | img, target = self.transforms(img, target) 24 | 25 | return img, target 26 | 27 | def __len__(self): 28 | return len(self.images) 29 | 30 | if __name__=='__main__': 31 | dataset=SegmentationDataset("yoho",image_set="train",transforms=None) 32 | print(dataset[1]) 33 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import transforms as T 2 | from data_utils import * 3 | from cityscapes import Cityscapes 4 | from voc12 import Voc12Segmentation 5 | from coco_utils import get_coco_dataset 6 | 7 | def build_transforms(is_train, size, crop_size,mode="baseline"): 8 | mean = (0.485, 0.456, 0.406) 9 | std = (0.229, 0.224, 0.225) 10 | fill = tuple([int(v * 255) for v in mean]) 11 | ignore_value = 255 12 | transforms=[] 13 | min_scale=1 14 | max_scale=1 15 | if is_train: 16 | min_scale=0.5 17 | max_scale=2 18 | transforms.append(T.RandomResize(int(min_scale*size),int(max_scale*size))) 19 | if is_train: 20 | if mode=="baseline": 21 | pass 22 | elif mode=="randaug": 23 | transforms.append(T.RandAugment(2,1/3,prob=1.0,fill=fill,ignore_value=ignore_value)) 24 | elif mode=="custom1": 25 | transforms.append(T.ColorJitter(0.5,0.5,(0.5,2),0.05)) 26 | transforms.append(T.AddNoise(10)) 27 | transforms.append(T.RandomRotation((-10,10), mean=fill, ignore_value=0)) 28 | else: 29 | raise NotImplementedError() 30 | transforms.append( 31 | T.RandomCrop( 32 | crop_size,crop_size, 33 | fill, 34 | ignore_value, 35 | random_pad=is_train 36 | )) 37 | transforms.append(T.RandomHorizontalFlip(0.5)) 38 | transforms.append(T.ToTensor()) 39 | transforms.append(T.Normalize( 40 | mean, 41 | std 42 | )) 43 | return T.Compose(transforms) 44 | 45 | def get_cityscapes(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4): 46 | train=Cityscapes(root, split="train", target_type="semantic", transforms=build_transforms(True, val_size, train_size,mode)) 47 | val=Cityscapes(root, split="val", target_type="semantic", transforms=build_transforms(False, val_size, train_size,mode)) 48 | train_loader = get_dataloader_train(train, batch_size,num_workers) 49 | val_loader = get_dataloader_val(val,num_workers) 50 | return train_loader,val_loader 51 | def 
get_coco(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4): 52 | train=get_coco_dataset(root, "train", build_transforms(True, val_size, train_size,mode)) 53 | val=get_coco_dataset(root, "val", build_transforms(False, val_size, train_size,mode)) 54 | train_loader = get_dataloader_train(train, batch_size,num_workers) 55 | val_loader = get_dataloader_val(val,num_workers) 56 | return train_loader, val_loader 57 | def get_pascal_voc(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4): 58 | download=False 59 | train = Voc12Segmentation(root, 'train_aug', build_transforms(True, val_size, train_size,mode), 60 | download) 61 | val = Voc12Segmentation(root, 'val', build_transforms(False, val_size, train_size,mode), 62 | download) 63 | train_loader = get_dataloader_train(train, batch_size,num_workers) 64 | val_loader = get_dataloader_val(val,num_workers) 65 | return train_loader, val_loader 66 | 67 | if __name__ == '__main__': 68 | train_loader, val_loader=get_pascal_voc("pascal_voc_dataset") 69 | print(iter(train_loader).__next__()) 70 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | 5 | def cat_list(images, fill_value=0): 6 | max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) 7 | batch_shape = (len(images),) + max_size 8 | batched_imgs = images[0].new(*batch_shape).fill_(fill_value) 9 | for img, pad_img in zip(images, batched_imgs): 10 | pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) 11 | return batched_imgs 12 | 13 | def collate_fn(batch): 14 | images, targets = list(zip(*batch)) 15 | batched_imgs = cat_list(images, fill_value=0) 16 | batched_targets = cat_list(targets, fill_value=255) 17 | return batched_imgs, batched_targets 18 | 19 | def get_sampler(dataset,dataset_test,distributed=False): 20 | if distributed: 21 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 22 | test_sampler = torch.utils.data.distributed.DistributedSampler( 23 | dataset_test) 24 | else: 25 | train_sampler = torch.utils.data.RandomSampler(dataset) 26 | #train_sampler = torch.utils.data.SequentialSampler(dataset) 27 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 28 | return train_sampler,test_sampler 29 | 30 | def worker_init_fn(worker_id): 31 | from datetime import datetime 32 | np.random.seed(datetime.now().microsecond) 33 | def get_dataloader_train(dataset,batch_size,num_workers=4): 34 | #dataset = get_coco(image_folder, ann_file, "train",get_temp_transform()) 35 | train_sampler = torch.utils.data.RandomSampler(dataset) 36 | data_loader = torch.utils.data.DataLoader( 37 | dataset, batch_size=batch_size, 38 | sampler=train_sampler, num_workers=num_workers, 39 | collate_fn=collate_fn, drop_last=True) 40 | return data_loader 41 | 42 | def get_dataloader_val(dataset_test,num_workers=4): 43 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 44 | data_loader_test = torch.utils.data.DataLoader( 45 | dataset_test, batch_size=1, 46 | sampler=test_sampler, num_workers=num_workers, 47 | collate_fn=collate_fn) 48 | return data_loader_test 49 | 50 | def find_mean_and_std(data_loader): 51 | K = 0.8 52 | n = 0.0 53 | Ex=torch.zeros(3).float() 54 | Ex2 = torch.zeros(3).float() 55 | count=0 56 | for image,_ in data_loader: 57 | count+=1 58 | assert len(image.size())==4 59 | image = image.transpose(0, 
1).flatten(1) 60 | Ex += (image-K).sum(dim=1) 61 | Ex2 += ((image - K)**2).sum(dim=1) 62 | n +=image.size()[1] 63 | if count==1000: 64 | break 65 | mean=Ex/n+K 66 | variance=(Ex2 - (Ex * Ex)/n)/(n-1) 67 | std=variance.sqrt() 68 | return mean, std 69 | 70 | def find_class_weights(dataloader,num_classes): 71 | print_every=4 72 | class_weights=torch.zeros(num_classes) 73 | for count,(image,target) in enumerate(dataloader): 74 | class_weights+=torch.bincount(target[target=0) 259 | # b=torch.sum(x<0) 260 | # print(a,b) 261 | # x=m(x) 262 | # print(name,torch.mean(x),torch.std(x)) 263 | # return x 264 | # def gradient2(): 265 | # #model=timm.create_model('mobilenetv2_100',features_only=True,out_indices=(4,),output_stride=16) 266 | # #model=Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False) 267 | # # pretrained_path='checkpoints/voc_mobilenetv2' 268 | # # model=Deeplab3P(name='mobilenetv2_100',num_classes=num_classes,pretrained=pretrained_path,sc=False).to( 269 | # # device) 270 | # #model=torchvision.models.vgg16() 271 | # model=torchvision.models.resnet50() 272 | # model=LittleNet() 273 | # model.train() 274 | # print(model) 275 | # x=torch.randn((2,10,128,128)) 276 | # y=model(x) 277 | # loss=torch.mean(y**2) 278 | # loss.backward() 279 | # for name,p in model.named_parameters(): 280 | # if p.grad is not None: 281 | # if "conv" in name: 282 | # print(name,float(torch.sum(torch.abs(p.grad)))) 283 | -------------------------------------------------------------------------------- /fov.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | def field_of_vision(fs): 4 | k=1 5 | s=1 6 | for _k,_s in fs: # kernel size and stride 7 | k=k+(_k-1)*s 8 | s=s*_s 9 | return k,s 10 | 11 | def staged_net(ds): 12 | fs=[] 13 | for d in ds: 14 | for i in range(d): 15 | if i==0: 16 | fs.append([3,2]) 17 | else: 18 | fs.append([3,1]) 19 | return fs 20 | 21 | 22 | def test1(): 23 | regnets=[ 24 | [1,1,4,7], 25 | [1,2,7,12], 26 | [1,3,5,7], 27 | [1,3,7,5], 28 | [2,4,10,2], 29 | [2,6,15,2], 30 | [2,5,14,2], 31 | [2,4,10,1], 32 | [2,5,15,1], 33 | [2,5,11,1], 34 | [2,6,13,1], 35 | [2,7,13,1] 36 | ] 37 | for ds in regnets: 38 | fs=staged_net(ds) 39 | fs=[(3,2)]+fs 40 | k,s=field_of_vision(fs) 41 | print(k,s) 42 | def test2(): 43 | resnets=[ 44 | [2, 2, 2, 2], 45 | [3, 4, 6, 3], 46 | [3, 4, 6, 3], 47 | [3, 4, 23, 3], 48 | [3, 8, 36, 3] 49 | ] 50 | for ds in resnets: 51 | ds[0]=ds[0]+1 52 | fs=staged_net(ds) 53 | fs=[(7,2)]+fs 54 | #fs=[(3,2),(3,1),(3,1)]+fs 55 | k,s=field_of_vision(fs) 56 | print(k,s) 57 | def test3(): 58 | inverted_residual_setting = [ 59 | # t, c, n, s 60 | [1, 16, 1, 1], 61 | [6, 24, 2, 2], 62 | [6, 32, 3, 2], 63 | [6, 64, 4, 2], 64 | [6, 96, 3, 1], 65 | [6, 160, 3, 2], 66 | [6, 320, 1, 1], 67 | ] 68 | fs=[] 69 | mobilenetv2=[2,3,7,4] 70 | mobilenetv3_large_fs=[ 71 | (3,2),(3,1),(3,2),(3,1),(5,2),(5,1),(5,1),(3,2), 72 | (3,1),(3,1),(3,1),(3,1),(3,1),(5,2),(5,1),(5,1)] 73 | fs=staged_net(mobilenetv2) 74 | fs=[(3,2),(3,1)]+fs 75 | k,s=field_of_vision(fs) 76 | print(k,s) 77 | k,s=field_of_vision(mobilenetv3_large_fs) 78 | print(k,s) 79 | 80 | 81 | if __name__=="__main__": 82 | test1() 83 | print() 84 | test2() 85 | print() 86 | test3() 87 | #torchvision.models.resnet50() 88 | #model=torchvision.models.mobilenet_v2() 89 | # fs=[(3,2),(1,1),(3,2),(1,1)] 90 | # k,s=field_of_vision(fs) 91 | # print(k,s) 92 | -------------------------------------------------------------------------------- /google0ccf7212e9c814a7.html: 
-------------------------------------------------------------------------------- 1 | google-site-verification: google0ccf7212e9c814a7.html -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch import nn 3 | import torch 4 | from torch.nn import functional as F 5 | from math import log2 6 | from Deeplabv3 import DeepLabHead,DeepLabHeadNoASSP,get_ASSP,convert_to_separable_conv, ASPP 7 | import timm 8 | 9 | def replace(output_stride=8): 10 | d=[False,False,False] 11 | n=int(log2(32/output_stride)) 12 | assert n<=3,'output_stride too small' 13 | for i in range(n): 14 | d[2-i]=True 15 | return d 16 | 17 | class Deeplab3(nn.Module): 18 | def __init__(self,name="mobilenetv2_100",num_classes=21,pretrained="", 19 | pretrained_backbone=True,aspp=True): 20 | super(Deeplab3,self).__init__() 21 | output_stride = 16 22 | self.backbone=timm.create_model(name, features_only=True, 23 | output_stride=output_stride, out_indices=(4,),pretrained=pretrained_backbone and pretrained =="") 24 | channels=self.backbone.feature_info.channels() 25 | if aspp: 26 | self.head=DeepLabHead(channels[0], num_classes,output_stride) 27 | else: 28 | self.head=DeepLabHeadNoASSP(channels[0], num_classes) 29 | if pretrained != "": 30 | dic = torch.load(pretrained, map_location='cpu') 31 | if type(dic)==dict: 32 | self.load_state_dict(dic['model']) 33 | else: 34 | self.load_state_dict(dic) 35 | def forward(self,x): 36 | input_shape = x.shape[-2:] 37 | x=self.backbone(x) 38 | x=self.head(x[0]) 39 | x = F.interpolate(x, size=input_shape, mode='bilinear',align_corners=False) 40 | return x 41 | 42 | class Deeplab3P(nn.Module): 43 | def __init__(self, name="mobilenetv2_100",num_classes=21,pretrained="", 44 | pretrained_backbone=True,sc=False,filter_multiplier=1.0): 45 | super(Deeplab3P,self).__init__() 46 | output_stride = 16 47 | num_filters = int(256*filter_multiplier) 48 | num_low_filters = int(48*filter_multiplier) 49 | try: 50 | self.backbone=timm.create_model(name, features_only=True, 51 | output_stride=output_stride, out_indices=(1, 4),pretrained=pretrained_backbone and pretrained =="") 52 | except RuntimeError: 53 | print("no model") 54 | print(timm.list_models()) 55 | raise RuntimeError() 56 | channels=self.backbone.feature_info.channels() 57 | self.head16=get_ASSP(channels[1], output_stride,num_filters) 58 | self.head4=torch.nn.Sequential( 59 | nn.Conv2d(channels[0], num_low_filters, 1, bias=False), 60 | nn.BatchNorm2d(num_low_filters), 61 | nn.ReLU(inplace=True)) 62 | self.decoder= nn.Sequential( 63 | nn.Conv2d(num_low_filters+num_filters, num_filters, 3, padding=1, bias=False), 64 | nn.BatchNorm2d(num_filters), 65 | nn.ReLU(inplace=True), 66 | nn.Conv2d(num_filters, num_filters, 3, padding=1, bias=False), 67 | nn.BatchNorm2d(num_filters), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(num_filters, num_classes, 1) 70 | ) 71 | if sc: 72 | self.decoder = convert_to_separable_conv(self.decoder) 73 | if pretrained != "": 74 | dic = torch.load(pretrained, map_location='cpu') 75 | if type(dic)==dict: 76 | self.load_state_dict(dic['model']) 77 | else: 78 | self.load_state_dict(dic) 79 | 80 | def forward(self, x): 81 | input_shape = x.shape[-2:] 82 | features = self.backbone(x) 83 | x=self.head16(features[1]) 84 | x2=self.head4(features[0]) 85 | intermediate_shape=x2.shape[-2:] 86 | x = F.interpolate(x, size=intermediate_shape, mode='bilinear',align_corners=False) 87 | 
x=torch.cat((x,x2),dim=1) 88 | x=self.decoder(x) 89 | x = F.interpolate(x, size=input_shape, mode='bilinear',align_corners=False) 90 | return x 91 | 92 | def profile(model,device,num_iter=15): 93 | import time 94 | model=model.to(device) 95 | model.eval() 96 | x=torch.randn(2,3,321,321).to(device) 97 | t1=time.time() 98 | for i in range(num_iter): 99 | y=model(x) 100 | t2=time.time() 101 | return (t2-t1)/num_iter 102 | 103 | def profiler(models): 104 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 105 | model = torchvision.models.resnet50() 106 | print("warming up") 107 | profile(model, device,num_iter=15) 108 | for model in models: 109 | seconds=profile(model,device,num_iter=30) 110 | print(round(seconds,3)) 111 | 112 | def total_params(models): 113 | for model in models: 114 | total = sum(p.numel() for p in model.parameters() if p.requires_grad) 115 | total=round(total/1000000,2) 116 | print(f"{model.__class__.__name__}: {total}M") 117 | 118 | def profiler2(models): 119 | from ptflops import get_model_complexity_info 120 | for model in models: 121 | macs, params = get_model_complexity_info(model, (3, 480, 480), 122 | as_strings=True, 123 | print_per_layer_stat=False, 124 | verbose=False) 125 | print(f"{model.__class__.__name__}: {macs}, {params}") 126 | 127 | def memory_used(device): 128 | x=torch.cuda.memory_allocated(device) 129 | return round(x/1024/1024) 130 | def max_memory_used(device): 131 | x=torch.cuda.max_memory_allocated(device) 132 | return round(x/1024/1024) 133 | def memory_test_helper(model,device): 134 | model=model.to(device) 135 | model.train() 136 | N=16 137 | x=torch.randn(N, 3, 481, 481).to(device) 138 | target=torch.randint(0,21,(N, 481, 481)).to(device) 139 | t1=memory_used(device) 140 | out=model(x) 141 | loss=nn.functional.cross_entropy(out,target,ignore_index=255) 142 | loss.backward() 143 | t2=max_memory_used(device) 144 | print(t2-t1) 145 | torch.cuda.reset_peak_memory_stats(device) 146 | 147 | def memory_test(models,device): 148 | for i in range(len(models)): 149 | try: 150 | memory_test_helper(models[0],device) 151 | print() 152 | except: 153 | print("out of memory") 154 | for p in models[0].parameters(): 155 | p.grad=None 156 | del models[0] 157 | 158 | def test_separable(): 159 | model1=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True), 160 | nn.Conv2d(256,256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True)) 161 | model2=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True), 162 | nn.Conv2d(256,256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True)) 163 | model2=convert_to_separable_conv(model2) 164 | model3=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True), 165 | nn.Conv2d(256,256, 3, padding=2, bias=False,dilation=2),nn.BatchNorm2d(256),nn.ReLU(inplace=True)) 166 | model4=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True), 167 | nn.Conv2d(256,256, 3, padding=2, bias=False,dilation=2),nn.BatchNorm2d(256),nn.ReLU(inplace=True)) 168 | model4=convert_to_separable_conv(model4) 169 | models=[ 170 | model1,model2,model3,model4 171 | ] 172 | profiler(models) 173 | def test_fast(): 174 | models=[ 175 | Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False), 176 | #Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False,filter_multiplier=0.5), 177 | 
Deeplab3(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False), 178 | Deeplab3(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False,aspp=False), 179 | Deeplab3P(name='resnet50d', num_classes=21,pretrained_backbone=False), 180 | #Deeplab3P(name='resnet50d', num_classes=21,pretrained_backbone=False,filter_multiplier=0.5), 181 | Deeplab3(name='resnet50d', num_classes=21,pretrained_backbone=False), 182 | Deeplab3(name='resnet50d', num_classes=21,pretrained_backbone=False,aspp=False), 183 | ] 184 | profiler(models) 185 | 186 | def test_models(): 187 | names=[ 188 | 'resnet50d', 189 | 'nf_regnet_b0', 190 | 'gernet_m', 191 | 'efficientnet_lite3', 192 | 'efficientnet_lite2' 193 | ] 194 | models = [] 195 | for name in names: 196 | #models.append(Deeplab3P(name=name, num_classes=21,pretrained_backbone=False)) 197 | models.append(timm.create_model(name, features_only=True, 198 | output_stride=16, out_indices=(4,))) 199 | profiler(models) 200 | 201 | if __name__=='__main__': 202 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 203 | #resnet50d 77.1 mIOU 204 | #regnetx_040 77.0 mIOU 205 | #regnety_040 78.6 mIOU 206 | #mobilenetv2_100 72.8 mIOU 207 | #regnetx_080 77.3 mIOU 208 | 209 | num_classes=21 210 | print(timm.list_models()) 211 | model=timm.create_model('resnest50d') 212 | print(model) 213 | #experiment1() 214 | #test_fast() 215 | #test_models() 216 | 217 | 218 | 219 | 220 | # 0.897 221 | # 0.752 222 | # 0.614 223 | # 0.371 224 | # 0.251 225 | 226 | # regnety4G 227 | # 0.764 228 | # 0.691 229 | # 0.572 230 | # 0.289 231 | # 0.131 232 | 233 | # regnet 234 | # 1.016 235 | # 0.946 236 | # 0.853 237 | # 0.49 238 | # 0.181 239 | 240 | 241 | # 'mobilenetv2_100' 242 | # 0.34 243 | # 0.178 244 | # 0.167 245 | # 0.168 246 | # 0.167 247 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cityscapesscripts 2 | opencv-python 3 | pip-check-reqs 4 | pip-chill 5 | ptflops 6 | pycocotools 7 | pyqt5 8 | pyyaml 9 | timm 10 | typing-extensions 11 | -------------------------------------------------------------------------------- /show.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from PIL import Image 4 | import torch 5 | from model import Deeplab3P 6 | import time 7 | from data import get_cityscapes,get_pascal_voc 8 | from cityscapes import Cityscapes 9 | 10 | mean = np.array([0.485, 0.456, 0.406]) 11 | std = np.array([0.229, 0.224, 0.225]) 12 | 13 | def get_colors(): 14 | palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) 15 | colors = torch.arange(255).view(-1, 1) * palette 16 | colors = (colors % 255).numpy().astype("uint8") 17 | return colors 18 | def get_colors_cityscapes(): 19 | colors=np.zeros((256,3)) 20 | colors[255]=[255,255,255] 21 | for c in Cityscapes.classes: 22 | if 0<=c.train_id<=18: 23 | colors[c.train_id]=c.color 24 | return colors.astype("uint8") 25 | 26 | 27 | def show_image(inp, title=None): 28 | """Imshow for Tensor.""" 29 | inp = inp.numpy().transpose((1, 2, 0)) 30 | inp = std * inp + mean 31 | inp = np.clip(inp, 0, 1) 32 | plt.imshow(inp) 33 | if title is not None: 34 | plt.title(title) 35 | 36 | def show_mask(images): 37 | colors=get_colors() 38 | r = Image.fromarray(images.byte().cpu().numpy()) 39 | r.putpalette(colors) 40 | plt.imshow(r) 41 | def show_cityscapes_mask(images): 42 | 
colors=get_colors_cityscapes() 43 | r = Image.fromarray(images.byte().cpu().numpy()) 44 | r.putpalette(colors) 45 | plt.imshow(r) 46 | 47 | def display(data_loader,show_mask,num_images=5,skip=4,images_per_line=6): 48 | images_so_far = 0 49 | fig = plt.figure(figsize=(6, 4)) 50 | num_rows=int(np.ceil(num_images/images_per_line)) 51 | data_loader = iter(data_loader) 52 | for _ in range(skip): 53 | next(data_loader) 54 | for images, targets in data_loader: 55 | for image, target in zip(images, targets): 56 | print(image.size(), target.size()) 57 | plt.subplot(num_rows, 2*images_per_line, images_so_far + 1) 58 | plt.axis('off') 59 | show_image(image) 60 | 61 | plt.subplot(num_rows, 2*images_per_line, images_so_far + 2) 62 | plt.axis('off') 63 | show_mask(target) 64 | 65 | images_so_far += 2 66 | if images_so_far == 2 * num_images: 67 | plt.tight_layout() 68 | plt.show() 69 | return 70 | plt.tight_layout() 71 | plt.show() 72 | def show(model,data_loader,device,show_mask,num_images=5,skip=4,images_per_line=2): 73 | images_so_far=0 74 | model.eval() 75 | num_rows = int(np.ceil(num_images / images_per_line)) 76 | fig=plt.figure(figsize=(8,4)) 77 | data_loader=iter(data_loader) 78 | for _ in range(skip): 79 | next(data_loader) 80 | with torch.no_grad(): 81 | for images, targets in data_loader: 82 | images, targets = images.to(device), targets.to(device) 83 | start=time.time() 84 | outputs = model(images) 85 | end=time.time() 86 | print(end-start) 87 | for image,target,output in zip(images,targets,outputs): 88 | output = output.argmax(0) 89 | print(image.size(),target.size(),output.size()) 90 | plt.subplot(num_rows, 3*images_per_line, images_so_far+1) 91 | plt.axis('off') 92 | show_image(image) 93 | 94 | plt.subplot(num_rows, 3*images_per_line, images_so_far+2) 95 | plt.axis('off') 96 | show_mask(target) 97 | 98 | plt.subplot(num_rows,3*images_per_line,images_so_far+3) 99 | plt.axis('off') 100 | show_mask(output) 101 | 102 | images_so_far+=3 103 | if images_so_far==3*num_images: 104 | plt.tight_layout() 105 | plt.show() 106 | return 107 | plt.tight_layout() 108 | plt.show() 109 | 110 | def show_cityscapes(): 111 | num_images=16 112 | images_per_line=4 113 | skip=0 114 | _,data_loader=get_cityscapes("cityscapes_dataset",16,train_size=481,val_size=513) 115 | display(data_loader,show_cityscapes_mask,num_images=num_images,skip=skip,images_per_line=images_per_line) 116 | 117 | if __name__=="__main__": 118 | show_cityscapes() 119 | 120 | #_,data_loader=get_pascal_voc("pascal_voc_dataset",16,train_size=385,val_size=385) 121 | 122 | #pretrained_path='checkpoints/voc_resnet50d_noise' 123 | #model=torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True).to(device) 124 | # model=Deeplab3P(name='resnet50d',num_classes=num_classes,pretrained=pretrained_path,sc=True).to( 125 | # device) 126 | 127 | #show(model,data_loader,device,show_mask,num_images=num_images,skip=skip,images_per_line=images_per_line) 128 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from model import Deeplab3P 2 | from benchmark import compute_time_full, compute_time_no_loader 3 | from data import get_cityscapes,get_pascal_voc,get_coco 4 | import datetime 5 | import time 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch import nn 10 | import torch.nn.functional 11 | import yaml 12 | import torch.cuda.amp as amp 13 | import os 14 | 15 | class ConfusionMatrix(object): 16 | def 
__init__(self, num_classes): 17 | self.num_classes = num_classes 18 | self.mat = None 19 | 20 | def update(self, a, b): 21 | n = self.num_classes 22 | if self.mat is None: 23 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) 24 | with torch.no_grad(): 25 | k = (a >= 0) & (a < n) 26 | inds = n * a[k].to(torch.int64) + b[k] 27 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) 28 | 29 | def reset(self): 30 | self.mat.zero_() 31 | 32 | def compute(self): 33 | h = self.mat.float() 34 | acc_global = torch.diag(h).sum() / h.sum() 35 | acc = torch.diag(h) / h.sum(1) 36 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 37 | return acc_global, acc, iu 38 | def __str__(self): 39 | acc_global, acc, iu = self.compute() 40 | return ( 41 | 'global correct: {:.1f}\n' 42 | 'average row correct: {}\n' 43 | 'IoU: {}\n' 44 | 'mean IoU: {:.1f}').format( 45 | acc_global.item() * 100, 46 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()], 47 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()], 48 | iu.mean().item() * 100) 49 | 50 | 51 | def criterion2(inputs, target, w): 52 | return nn.functional.cross_entropy(inputs,target,ignore_index=255,weight=w) 53 | 54 | def get_loss_fun(weight): 55 | return nn.CrossEntropyLoss(weight=weight,ignore_index=255) 56 | 57 | def evaluate(model, data_loader, device, num_classes,mixed_precision,print_every=100): 58 | model.eval() 59 | confmat = ConfusionMatrix(num_classes) 60 | with torch.no_grad(): 61 | for i,(image, target) in enumerate(data_loader): 62 | if (i+1)%print_every==0: 63 | print(i+1) 64 | image, target = image.to(device), target.to(device) 65 | with amp.autocast(enabled=mixed_precision): 66 | output = model(image) 67 | confmat.update(target.flatten(), output.argmax(1).flatten()) 68 | 69 | return confmat 70 | 71 | def train_one_epoch(model, loss_fun, optimizer, loader, lr_scheduler, device, print_freq,mixed_precision,scaler): 72 | model.train() 73 | losses=0 74 | for t, x in enumerate(loader): 75 | image, target=x 76 | image, target = image.to(device), target.to(device) 77 | with amp.autocast(enabled=mixed_precision): 78 | output = model(image) 79 | loss = loss_fun(output, target) 80 | optimizer.zero_grad() 81 | scaler.scale(loss).backward() 82 | scaler.step(optimizer) 83 | scaler.update() 84 | lr_scheduler.step() 85 | losses+=loss.item() 86 | if t % print_freq==0: 87 | print(t,loss.item()) 88 | num_iter=len(loader) 89 | print(losses/num_iter) 90 | 91 | def save(model,optimizer,scheduler,epoch,path,best_mIU,scaler): 92 | dic={ 93 | 'model': model.state_dict(), 94 | 'optimizer': optimizer.state_dict(), 95 | 'lr_scheduler': scheduler.state_dict(), 96 | 'scaler':scaler.state_dict(), 97 | 'epoch': epoch, 98 | 'best_mIU':best_mIU 99 | } 100 | torch.save(dic,path) 101 | 102 | def train(model, save_best_path,save_latest_path, epochs,optimizer, data_loader, data_loader_test, lr_scheduler, device,num_classes,save_best_on_epochs,loss_fun,mixed_precision,scaler,best_mIU): 103 | start_time = time.time() 104 | for epoch in epochs: 105 | print(f"epoch: {epoch}") 106 | train_one_epoch(model, loss_fun, optimizer, data_loader, lr_scheduler, 107 | device, print_freq=50,mixed_precision=mixed_precision,scaler=scaler) 108 | if epoch in save_best_on_epochs: 109 | confmat = evaluate(model, data_loader_test, device=device, 110 | num_classes=num_classes,mixed_precision=mixed_precision,print_every=100) 111 | print(confmat) 112 | acc_global, acc, iu = confmat.compute() 113 | mIU=iu.mean().item() * 100 114 | if mIU > best_mIU: 115 | best_mIU=mIU 116 
| save(model, optimizer, lr_scheduler, epoch, save_best_path,best_mIU,scaler) 117 | if save_latest_path != "": 118 | save(model, optimizer, lr_scheduler, epoch, save_latest_path,best_mIU,scaler) 119 | 120 | total_time = time.time() - start_time 121 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 122 | print('Training time {}'.format(total_time_str)) 123 | 124 | def get_dataset_loaders(config): 125 | name=config["dataset_name"] 126 | if name=="pascal_voc": 127 | f=get_pascal_voc 128 | elif name=="cityscapes": 129 | f=get_cityscapes 130 | elif name=="coco": 131 | f=get_coco 132 | else: 133 | raise NotImplementedError() 134 | mode="baseline" 135 | if "aug_mode" in config: 136 | mode=config["aug_mode"] 137 | data_loader, data_loader_test=f(config["dataset_dir"],config["batch_size"],train_size=config["train_size"],val_size=config["val_size"],mode=mode) 138 | print("train size:", len(data_loader)) 139 | print("val size:", len(data_loader_test)) 140 | return data_loader, data_loader_test 141 | 142 | def get_model(config): 143 | pretrained_backbone=config["pretrained_backbone"] 144 | if config["resume"]: 145 | pretrained_backbone=False 146 | return Deeplab3P(name=config["model_name"], 147 | num_classes=config["num_classes"], 148 | pretrained_backbone=pretrained_backbone, 149 | sc=config["separable_convolution"], 150 | pretrained=config["pretrained_path"]) 151 | 152 | def get_config_and_check_files(config_filename): 153 | with open(config_filename) as file: 154 | config=yaml.full_load(file) 155 | save_best_dir=os.path.dirname(config["save_best_path"]) 156 | save_latest_dir=os.path.dirname(config["save_latest_path"]) 157 | if not os.path.isdir(save_best_dir): 158 | raise FileNotFoundError(f"{save_best_dir} is not a directory") 159 | if not os.path.isdir(save_latest_dir): 160 | raise FileNotFoundError(f"{save_latest_dir} is not a directory") 161 | if not os.path.isdir(config["dataset_dir"]): 162 | raise FileNotFoundError(f"{config['dataset_dir']} is not a directory") 163 | if config["resume"]: 164 | if not os.path.isfile(config["resume_path"]): 165 | raise FileNotFoundError(f"{config['resume_path']} is not a file") 166 | elif not config["pretrained_backbone"]: 167 | if not os.path.isfile(config["pretrained_path"]): 168 | raise FileNotFoundError(f"{config['pretrained_path']} is not a file") 169 | return config 170 | 171 | def get_epochs_to_save(config): 172 | epochs=config["epochs"] 173 | save_every_k_epochs=config["save_every_k_epochs"] 174 | save_best_on_epochs=[i*save_every_k_epochs-1 for i in range(1,epochs//save_every_k_epochs+1)] 175 | if epochs-1 not in save_best_on_epochs: 176 | save_best_on_epochs.append(epochs-1) 177 | if "save_last_k_epochs" in config: 178 | for i in range(epochs-config["save_last_k_epochs"],epochs): 179 | if i not in save_best_on_epochs: 180 | save_best_on_epochs.append(i) 181 | save_best_on_epochs=sorted(save_best_on_epochs) 182 | return save_best_on_epochs 183 | 184 | def main2(config_filename): 185 | config=get_config_and_check_files(config_filename) 186 | torch.backends.cudnn.benchmark=True 187 | save_best_path=config["save_best_path"] 188 | save_latest_path=config["save_latest_path"] 189 | epochs=config["epochs"] 190 | num_classes=config["num_classes"] 191 | class_weight=config["class_weight"] 192 | mixed_precision=config["mixed_precision"] 193 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 194 | data_loader, data_loader_test=get_dataset_loaders(config) 195 | model=get_model(config).to(device) 196 | 
params_to_optimize=model.parameters() 197 | optimizer = torch.optim.SGD(params_to_optimize, lr=config["lr"], 198 | momentum=config["momentum"], weight_decay=config["weight_decay"]) 199 | scaler = amp.GradScaler(enabled=mixed_precision) 200 | loss_fun=get_loss_fun(class_weight) 201 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR( 202 | optimizer,lambda x: (1 - x / (len(data_loader) * epochs)) ** 0.9) 203 | 204 | epoch_start=0 205 | best_mIU=0 206 | save_best_on_epochs=get_epochs_to_save(config) 207 | print("save on epochs: ",save_best_on_epochs) 208 | 209 | if config["resume"]: 210 | dic=torch.load(config["resume_path"],map_location='cpu') 211 | model.load_state_dict(dic['model']) 212 | optimizer.load_state_dict(dic['optimizer']) 213 | lr_scheduler.load_state_dict(dic['lr_scheduler']) 214 | epoch_start = dic['epoch'] + 1 215 | if "best_mIU" in dic: 216 | best_mIU=dic["best_mIU"] 217 | if "scaler" in dic: 218 | scaler.load_state_dict(dic["scaler"]) 219 | 220 | train(model, save_best_path,save_latest_path, range(epoch_start,epochs),optimizer, data_loader, 221 | data_loader_test, lr_scheduler, device,num_classes,save_best_on_epochs,loss_fun,mixed_precision,scaler,best_mIU) 222 | def check3(config_filename): 223 | config=get_config_and_check_files(config_filename) 224 | torch.backends.cudnn.benchmark=True 225 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 226 | data_loader, data_loader_test=get_dataset_loaders(config) 227 | model=get_model(config).to(device) 228 | num_classes=config["num_classes"] 229 | mixed_precision=config["mixed_precision"] 230 | print("evaluating") 231 | confmat = evaluate(model, data_loader_test, device=device, 232 | num_classes=num_classes,mixed_precision=mixed_precision) 233 | print(confmat) 234 | 235 | # def check(): 236 | # device = torch.device( 237 | # 'cuda') if torch.cuda.is_available() else torch.device('cpu') 238 | # num_classes = 21 239 | # pretrained_path='/content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_noise2' 240 | # #voc_resnet50d_noise 241 | # data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513) 242 | # eval_steps = len(data_loader_test) 243 | # model=Deeplab3P(name="resnet50d",num_classes=num_classes,pretrained=pretrained_path,sc=True).to( 244 | # device) 245 | # print("evaluating") 246 | # confmat = evaluate(model, data_loader_test, device=device, 247 | # num_classes=num_classes,eval_steps=eval_steps,print_every=100) 248 | # print(confmat) 249 | # def check2(): 250 | # device = torch.device( 251 | # 'cuda') if torch.cuda.is_available() else torch.device('cpu') 252 | # num_classes = 21 253 | # pretrained_path='checkpoints/voc_mobilenetv2' 254 | # #pretrained_path='checkpoints/voc_regnety40' 255 | # data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=385,val_size=385) 256 | # eval_steps = 100 257 | # model=Deeplab3P(name='mobilenetv2_100',num_classes=num_classes,pretrained=pretrained_path,sc=False).to( 258 | # device) 259 | # print("evaluating") 260 | # confmat = evaluate(model, data_loader_test, device=device, 261 | # num_classes=num_classes,eval_steps=eval_steps,print_every=5) 262 | # print(confmat) 263 | 264 | def benchmark(config_filename): 265 | config=get_config_and_check_files(config_filename) 266 | torch.backends.cudnn.benchmark=True 267 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 268 | mixed_precision=config["mixed_precision"] 269 | 
warmup_iter=config["warmup_iter"] 270 | num_iter=config["num_iter"] 271 | crop_size=config["train_size"] 272 | batch_size=config["batch_size"] 273 | num_classes=config["num_classes"] 274 | model=get_model(config).to(device) 275 | #data_loader, data_loader_test=get_dataset_loaders(config) 276 | #dic=compute_time_full(model,data_loader,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision) 277 | dic=compute_time_no_loader(model,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision) 278 | for k,v in dic.items(): 279 | print(f"{k}: {v}") 280 | 281 | if __name__=='__main__': 282 | #validation example nums 283 | #config_filename="PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml" 284 | config_filename2="configs/voc_regnety40_30epochs_mixed_precision.yaml" 285 | config=get_config_and_check_files(config_filename2) 286 | print(config) 287 | print(get_epochs_to_save(config)) 288 | #benchmark("PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml") 289 | #main2("PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml") 290 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from torchvision.transforms import functional as F 4 | from PIL import Image 5 | import torchvision.transforms as T 6 | import torch 7 | from augment import apply_op_both,rand_augment_both 8 | 9 | class Compose(object): 10 | """ 11 | Composes a sequence of transforms. 12 | Arguments: 13 | transforms: A list of transforms. 14 | """ 15 | def __init__(self, transforms): 16 | self.transforms = transforms 17 | 18 | def __call__(self, image, label): 19 | for t in self.transforms: 20 | image, label = t(image, label) 21 | return image, label 22 | 23 | def __repr__(self): 24 | format_string = self.__class__.__name__ + "(" 25 | for t in self.transforms: 26 | format_string += "\n" 27 | format_string += " {0}".format(t) 28 | format_string += "\n)" 29 | return format_string 30 | 31 | 32 | class ToTensor(object): 33 | def __call__(self, image, target): 34 | image = F.to_tensor(image) 35 | target = torch.as_tensor(np.array(target), dtype=torch.int64) 36 | return image, target 37 | 38 | class RandAugment: 39 | def __init__(self,N,M,prob=1.0,fill=(128,128,128),ignore_value=255): 40 | self.N=N 41 | self.M=M 42 | self.prob=prob 43 | self.fill=fill 44 | self.ignore_value=ignore_value 45 | def __call__(self, image, target): 46 | return rand_augment_both(image,target,n_ops=self.N,magnitude=self.M,prob=self.prob,fill=self.fill,ignore_value=self.ignore_value) 47 | 48 | 49 | class Normalize(object): 50 | """ 51 | Normalizes image by mean and std. 
52 | """ 53 | def __init__(self, mean, std): 54 | self.mean = mean 55 | self.std = std 56 | 57 | def __call__(self, image, label): 58 | image = F.normalize(image, mean=self.mean, std=self.std) 59 | return image, label 60 | 61 | class RandomResize(object): 62 | def __init__(self, min_size, max_size=None): 63 | self.min_size = min_size 64 | if max_size is None: 65 | max_size = min_size 66 | self.max_size = max_size 67 | 68 | def __call__(self, image, target): 69 | size = random.randint(self.min_size, self.max_size) 70 | image = F.resize(image, size) 71 | target = F.resize(target, size, interpolation=F.InterpolationMode.NEAREST) 72 | return image, target 73 | 74 | class ColorJitter: 75 | def __init__(self,brightness=0.2, contrast=0.2, saturation=(0.5,4), hue=0.2): 76 | self.jitter=T.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) 77 | def __call__(self, image, target): 78 | image=self.jitter(image) 79 | return image,target 80 | 81 | class AddNoise:#additive gaussian noise 82 | def __init__(self,factor): 83 | self.factor=factor 84 | def __call__(self, image, target): 85 | factor = random.uniform(0, self.factor) 86 | image = np.array(image) 87 | assert(image.dtype==np.uint8) 88 | gauss = np.array(torch.randn(*image.shape)) * factor # keep the noise as float; casting negative values to uint8 would wrap them around 89 | noisy = (image.astype(np.float32) + gauss).clip(0, 255).astype("uint8") 90 | image = Image.fromarray(noisy) 91 | return image, target 92 | 93 | class RandomRotation: 94 | def __init__(self,degrees,mean,ignore_value=255): 95 | self.degrees=degrees 96 | self.mean=mean 97 | self.ignore_value=ignore_value 98 | def __call__(self, image, target): 99 | expand=True 100 | if random.random()<0.5: 101 | angle = random.uniform(*self.degrees) 102 | image=F.rotate(image, angle,fill=self.mean,expand=expand) 103 | target=F.rotate(target,angle,fill=self.ignore_value,expand=expand) 104 | return image,target 105 | 106 | 107 | class RandomScale(object): 108 | """ 109 | Applies random scale augmentation. 110 | Arguments: 111 | min_scale: Minimum scale value. 112 | max_scale: Maximum scale value. 113 | scale_step_size: The step size from minimum to maximum value. 114 | """ 115 | def __init__(self, min_scale, max_scale, scale_step_size): 116 | self.min_scale = min_scale 117 | self.max_scale = max_scale 118 | self.scale_step_size = scale_step_size 119 | 120 | @staticmethod 121 | def get_random_scale(min_scale_factor, max_scale_factor, step_size): 122 | """Gets a random scale value. 123 | Args: 124 | min_scale_factor: Minimum scale value. 125 | max_scale_factor: Maximum scale value. 126 | step_size: The step size from minimum to maximum value. 127 | Returns: 128 | A random scale value selected between minimum and maximum value. 129 | Raises: 130 | ValueError: min_scale_factor has unexpected value. 131 | """ 132 | if min_scale_factor < 0 or min_scale_factor > max_scale_factor: 133 | raise ValueError('Unexpected value of min_scale_factor.') 134 | 135 | if min_scale_factor == max_scale_factor: 136 | return min_scale_factor 137 | 138 | # When step_size = 0, we sample the value uniformly from [min, max). 139 | if step_size == 0: 140 | return random.uniform(min_scale_factor, max_scale_factor) 141 | 142 | # When step_size != 0, we randomly select one discrete value from [min, max]. 
143 | num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1) 144 | scale_factors = np.linspace(min_scale_factor, max_scale_factor, num_steps) 145 | np.random.shuffle(scale_factors) 146 | return scale_factors[0] 147 | 148 | def __call__(self, image, label): 149 | scale = self.get_random_scale(self.min_scale, self.max_scale, self.scale_step_size) 150 | img_w, img_h = image.size 151 | img_w,img_h=int(img_w*scale),int(img_h*scale) 152 | image=F.resize(image,[img_h,img_w]) 153 | label=F.resize(label,[img_h,img_w],interpolation=F.InterpolationMode.NEAREST) 154 | return image,label 155 | 156 | 157 | class RandomCrop(object): 158 | def __init__(self, crop_h, crop_w, pad_value, ignore_label, random_pad): 159 | self.crop_h = crop_h 160 | self.crop_w = crop_w 161 | self.pad_value = pad_value 162 | self.ignore_label = ignore_label 163 | self.random_pad = random_pad 164 | 165 | def __call__(self, image, label): 166 | img_w,img_h=image.size 167 | pad_h = max(self.crop_h - img_h, 0) 168 | pad_w = max(self.crop_w - img_w, 0) 169 | if pad_h > 0 or pad_w > 0: 170 | if self.random_pad: 171 | pad_top = random.randint(0, pad_h) 172 | pad_bottom = pad_h - pad_top 173 | pad_left = random.randint(0, pad_w) 174 | pad_right = pad_w - pad_left 175 | else: 176 | pad_top, pad_bottom, pad_left, pad_right = 0, pad_h, 0, pad_w 177 | image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom), fill=self.pad_value) 178 | label= F.pad(label, (pad_left, pad_top, pad_right, pad_bottom), fill=self.ignore_label) 179 | 180 | crop_params = T.RandomCrop.get_params(image, (self.crop_h, self.crop_w)) 181 | image = F.crop(image, *crop_params) 182 | label = F.crop(label, *crop_params) 183 | return image,label 184 | 185 | 186 | class RandomHorizontalFlip(object): 187 | def __init__(self, flip_prob): 188 | self.flip_prob = flip_prob 189 | 190 | def __call__(self, image, target): 191 | if random.random() < self.flip_prob: 192 | image = F.hflip(image) 193 | target = F.hflip(target) 194 | return image, target 195 | 196 | def f(): 197 | image=np.zeros((50,50),dtype=np.uint8) 198 | assert(image.dtype==np.uint8) 199 | gauss=(np.array(torch.randn(50,50)*10)).astype("uint8") 200 | noisy = (image + gauss).clip(0, 255) 201 | # print(noisy.dtype) 202 | # factor = random.uniform(0, 10) 203 | # image = Image.open("cityscapes_dataset/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png") 204 | # image=np.array(image) 205 | # gauss = np.array(torch.randn(*image.shape)) * factor 206 | # noisy = (image + gauss).clip(0, 255).astype("uint8") 207 | 208 | def g(): 209 | image=np.zeros((50,50)) 210 | gauss=np.array(torch.randn(50,50)*10) 211 | noisy = (image + gauss).clip(0, 255).astype("uint8") 212 | # print(noisy.dtype) 213 | # factor = random.uniform(0, 10) 214 | # image = Image.open("cityscapes_dataset/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png") 215 | # image=np.array(image) 216 | # gauss = np.array(torch.randn(*image.shape)) * factor 217 | # print(gauss.dtype) 218 | # noisy = (image + gauss).clip(0, 255).astype("uint8") 219 | 220 | 221 | if __name__=='__main__': 222 | f() 223 | g() 224 | #import timeit 225 | #print(timeit.timeit('f()', globals=globals(), number=100)) 226 | #print(timeit.timeit('g()', globals=globals(), number=100)) 227 | -------------------------------------------------------------------------------- /voc12.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | import torch.utils.data as data 4 | 5 | from PIL import Image 
6 | from torchvision.datasets.utils import download_url 7 | 8 | class Voc12Segmentation(data.Dataset): 9 | def __init__(self,root,image_set,transforms,download=False): 10 | self.root = os.path.expanduser(root) 11 | self.url='http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar' 12 | self.filename='VOCtrainval_11-May-2012.tar' 13 | self.md5='6cd6e144f989b92b3379bac3b3de84fd' 14 | self.base_dir='VOCdevkit/VOC2012' 15 | self.transforms=transforms 16 | voc_root = os.path.join(self.root, self.base_dir) 17 | image_dir = os.path.join(voc_root, 'JPEGImages') 18 | if download: 19 | download_extract(self.url, self.root, self.filename, self.md5) 20 | if not os.path.isdir(voc_root): 21 | raise RuntimeError(f'{voc_root} not found') 22 | if image_set == 'train_aug': 23 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug') 24 | split_f = os.path.join(voc_root, f'ImageSets/Segmentation/{image_set}.txt') 25 | else: 26 | mask_dir = os.path.join(voc_root, 'SegmentationClass') 27 | split_f = os.path.join(voc_root, f'ImageSets/Segmentation/{image_set}.txt') 28 | if not os.path.exists(split_f): 29 | raise RuntimeError(f'{split_f} not found') 30 | with open(split_f, "r") as f:# os.path.join(split_f) 31 | file_names = [x.strip() for x in f.readlines()] 32 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] 33 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names] 34 | assert (len(self.images) == len(self.masks)) 35 | 36 | def __getitem__(self, index): 37 | img = Image.open(self.images[index]).convert('RGB') 38 | target = Image.open(self.masks[index]) 39 | if self.transforms is not None: 40 | img, target = self.transforms(img, target) 41 | 42 | return img, target 43 | 44 | def __len__(self): 45 | return len(self.images) 46 | 47 | def download_extract(url, root, filename, md5): 48 | download_url(url, root, filename, md5) 49 | with tarfile.open(os.path.join(root, filename), "r") as tar: 50 | tar.extractall(path=root) 51 | --------------------------------------------------------------------------------
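A minimal end-to-end sketch of how the modules above fit together (illustrative only, not a file in the repository): it assumes Pascal VOC has already been unpacked into a local pascal_voc_dataset directory, as the configs expect, and reuses get_pascal_voc from data.py, Deeplab3P from model.py, and evaluate from train.py with the signatures defined above.

import torch
from data import get_pascal_voc
from model import Deeplab3P
from train import evaluate

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Build the VOC train/val loaders with the crop and eval sizes used in the configs.
train_loader, val_loader = get_pascal_voc("pascal_voc_dataset", batch_size=16, val_size=513, train_size=481)
# DeepLabv3+ with an ImageNet-pretrained MobileNetV2 backbone from timm.
model = Deeplab3P(name="mobilenetv2_100", num_classes=21, pretrained_backbone=True).to(device)
# One evaluation pass over the validation set; printing the confusion matrix reports per-class IoU and mean IoU.
confmat = evaluate(model, val_loader, device, num_classes=21, mixed_precision=False)
print(confmat)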