├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── semantic_segmentation.iml
│   └── vcs.xml
├── Deeplabv3.py
├── README.md
├── accuracy_upperbound.py
├── augment.py
├── benchmark.py
├── blocks.py
├── cityscapes.py
├── coco_download.sh
├── coco_utils.py
├── configs
│   ├── cityscapes_regnety40_160epochs_mixed_precision.yaml
│   ├── configs_sanity_check.py
│   ├── voc_mobilenetv2_30epochs.yaml
│   ├── voc_regnetx40_30epochs.yaml
│   ├── voc_regnety40_30epochs.yaml
│   ├── voc_regnety40_30epochs_mixed_precision.yaml
│   ├── voc_resnet50d_30epochs.yaml
│   └── yoho.yaml
├── custom_dataset.py
├── data.py
├── data_utils.py
├── experimental_models.py
├── fov.py
├── google0ccf7212e9c814a7.html
├── model.py
├── requirements.txt
├── show.py
├── train.py
├── transforms.py
└── voc12.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project exclude paths
2 | /venv/
3 | cityscapes_dataset.zip
4 | hello.py
5 | pascal_voc_dataset.zip
6 | /cifar/
7 | /cityscapes_dataset/
8 | /pascal_voc_dataset/
9 | checkpoints
10 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/semantic_segmentation.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from copy import deepcopy
5 |
6 | class AtrousSeparableConvolution(nn.Sequential): # depthwise (optionally atrous) conv followed by a pointwise 1x1 projection
7 | def __init__(self, in_channels, out_channels, kernel_size,
8 | stride=1, padding=0, dilation=1, bias=True,add_norm=True):
9 | modules=[]
10 | modules.append(nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
11 | stride=stride, padding=padding, dilation=dilation,
12 | bias=(not add_norm), groups=in_channels))
13 | if add_norm:
14 | modules.append(nn.BatchNorm2d(in_channels))
15 | modules.append(nn.ReLU(inplace=True))
16 | modules.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
17 | padding=0, bias=bias))
18 | super().__init__(*modules)
19 |
20 |
21 | class DeepLabHead(nn.Sequential):
22 | def __init__(self, in_channels, num_classes,output_stride):
23 | base_rates=[3,6,9]
24 | mul=32//output_stride
25 | rates=[x*mul for x in base_rates] # e.g. output_stride=16 -> rates=[6,12,18]
26 | super().__init__(
27 | ASPP(in_channels, rates),
28 | nn.Conv2d(256, 256, 3, padding=1, bias=False),
29 | nn.BatchNorm2d(256),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(256, num_classes, 1)
32 | )
33 | class DeepLabHeadNoASSP(nn.Sequential):
34 | def __init__(self, in_channels, num_classes):
35 | super().__init__(
36 | ASPP(in_channels, []),
37 | nn.Conv2d(256, num_classes, 1)
38 | )
39 |
40 | def get_ASSP(in_channels,output_stride,output_channels=256):
41 | base_rates = [3, 6, 9]
42 | mul = 32 // output_stride
43 | rates = [x * mul for x in base_rates]
44 | return ASPP(in_channels, rates,output_channels)
45 |
46 | class ASPPConv(nn.Sequential):
47 | def __init__(self, in_channels, out_channels, dilation):
48 | modules = [
49 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
50 | nn.BatchNorm2d(out_channels),
51 | nn.ReLU(inplace=True)
52 | ]
53 | super(ASPPConv, self).__init__(*modules)
54 |
55 | class ASPPPooling(nn.Sequential):
56 | def __init__(self, in_channels, out_channels):
57 | super(ASPPPooling, self).__init__(
58 | nn.AdaptiveAvgPool2d(1),
59 | nn.Conv2d(in_channels, out_channels, 1, bias=False),
60 | nn.BatchNorm2d(out_channels),
61 | nn.ReLU(inplace=True))
62 |
63 | def forward(self, x):
64 | size = x.shape[-2:]
65 | for mod in self:
66 | x = mod(x)
67 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
68 |
69 | class ASPP(nn.Module):
70 | def __init__(self, in_channels, atrous_rates, out_channels=256,intermediate_channels=256,dropout=0.5):
71 | super(ASPP, self).__init__()
72 | modules = []
73 | modules.append(nn.Sequential(
74 | nn.Conv2d(in_channels, intermediate_channels, 1, bias=False),
75 | nn.BatchNorm2d(intermediate_channels),
76 | nn.ReLU(inplace=True)))
77 |
78 | rates = tuple(atrous_rates)
79 | for rate in rates:
80 | modules.append(ASPPConv(in_channels, intermediate_channels, rate))
81 |
82 | modules.append(ASPPPooling(in_channels, intermediate_channels))
83 |
84 | self.convs = nn.ModuleList(modules)
85 | num_branches=len(self.convs)
86 | self.project = nn.Sequential(
87 | nn.Conv2d(num_branches * intermediate_channels, out_channels, 1, bias=False),
88 | nn.BatchNorm2d(out_channels),
89 | nn.ReLU(inplace=True),
90 | nn.Dropout(dropout))
91 |
92 | def forward(self, x):
93 | res = []
94 | for conv in self.convs:
95 | res.append(conv(x))
96 | res = torch.cat(res, dim=1)
97 | return self.project(res)
98 |
99 |
100 | def convert_to_separable_conv(module,deep_copy=True):
101 | new_module=module
102 | if deep_copy:
103 | new_module = deepcopy(module)
104 | if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1 and module.groups == 1:
105 | new_module = AtrousSeparableConvolution(
106 | module.in_channels,
107 | module.out_channels,
108 | module.kernel_size,
109 | module.stride,
110 | module.padding,
111 | module.dilation,
112 | module.bias is not None)
113 | for name, child in new_module.named_children():
114 | new_module.add_module(name, convert_to_separable_conv(child,deep_copy=False))
115 | return new_module
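# Usage sketch (hypothetical, mirroring the README's sc=True option): replace every
# k x k convolution (k > 1) in a model with its separable counterpart:
# model = convert_to_separable_conv(model)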
116 |
117 | if __name__=='__main__':
118 | module=nn.Conv2d(5,5,1,bias=False)
119 | print(module.in_channels,
120 | module.out_channels,
121 | module.kernel_size,
122 | module.stride,
123 | module.padding,
124 | module.dilation,
125 | module.bias is not None)
126 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch_DeepLab
2 |
3 | This repo is old. Go check out my new model [RegSeg](https://github.com/RolandGao/RegSeg) that achieved SOTA on real-time semantic segmentation on Cityscapes.
4 |
5 | Currently, the code supports DeepLabv3+ with many common backbones, such as Mobilenetv2, Mobilenetv3, Resnet, Resnetv2, XceptionAligned, Regnet, EfficientNet, and many more, thanks to the package [timm](https://github.com/rwightman/pytorch-image-models). The code supports 3 datasets, namely PascalVoc, Coco, and Cityscapes.
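
If you want to browse the available backbone names, timm can list them. A minimal sketch (requires timm installed; most of the returned names can be passed as the `name` argument of `Deeplab3P`):

```python
import timm

# print a few RegNet backbone names known to timm
print(timm.list_models("regnet*")[:5])
```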
6 |
7 | I trained a few models on Cityscapes and PascalVoc; weights for the PascalVoc models are linked in the table below.
8 |
9 | ## Results
10 |
11 | Using separable convolution in the decoder
12 | reduces model size and the number of flops,
13 | but increases the memory requirement by 1 GB during training.
14 |
15 | #### PascalVoc
16 | To use the weights, download them from the links in the table below, and instantiate the model as in the following lines,
17 | changing the name, sc ("separable convolution"), and the path to the pretrained weights that you just downloaded.
18 |
19 | ```python
20 | model=Deeplab3P(name='regnetx_040',num_classes=21,
21 | sc=False,pretrained=pretrained_path).to(device)
22 | ```
23 | name | separable convolution | mIOU | weights
24 | --- | --- | --- | ---
25 | resnet50d | yes | 77.1 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_resnet50d)
26 | regnetx_040 | yes | 77.0 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnetx40)
27 | regnety_040 | yes | 78.6 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnety40)
28 | regnetx_080 | no | 77.3 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_regnetx80)
29 | mobilenetv2 | no | 72.8 | [link](https://github.com/RolandGao/PyTorch_DeepLab/releases/download/v1.0-alpha/voc_mobilenetv2)
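
As a quick sanity check after loading the weights, you can run a dummy forward pass. This is a minimal sketch, assuming the model returns per-class logits upsampled to the input resolution, as DeepLabv3+ heads typically do; `model` and `device` are as defined above:

```python
import torch

model.eval()
x = torch.randn(1, 3, 513, 513).to(device)  # 513 matches val_size in the configs
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 21, 513, 513])
```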
30 |
31 | ## Installation
32 | After cloning the repository, run the following command to install all dependencies:
33 | `pip install -r requirements.txt`
34 |
35 | ## Datasets
36 | #### COCO
37 | Run the command
38 | ```shell
39 | sh coco_download.sh
40 | ```
41 | We use the 21 classes that intersect PascalVoc's (see CAT_LIST in coco_utils.py).
42 |
43 | #### Cityscapes
44 | Go to https://www.cityscapes-dataset.com, create an account, and download
45 | gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip.
46 | You can delete the test images to save some space if you don't want to submit to the benchmark server.
47 | Name the directory cityscapes_dataset.
48 | Make sure that you have installed the required Python packages (csCreateTrainIdLabelImgs comes from the cityscapesscripts package) and run
49 | ```shell
50 | CITYSCAPES_DATASET=cityscapes_dataset csCreateTrainIdLabelImgs
51 | ```
52 | There are 19 classes.
53 |
54 | #### PascalVoc
55 | Download the original dataset [here](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar).
56 |
57 | Then download the augmented dataset [here](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0),
58 | and create a text file named train_aug.txt with [this content](https://gist.githubusercontent.com/sun11/2dbda6b31acc7c6292d14a872d0c90b7/raw/5f5a5270089239ef2f6b65b1cc55208355b5acca/trainaug.txt).
59 |
60 | Place train_aug.txt in VOCdevkit/VOC2012/ImageSets/Segmentation/train_aug.txt
61 |
62 | Place SegmentationClassAug directory in VOCdevkit/VOC2012/SegmentationClassAug
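
After these two steps, the relevant part of the dataset should look like this (the other folders are the standard VOC2012 ones):

```
VOCdevkit/
└── VOC2012/
    ├── ImageSets/
    │   └── Segmentation/
    │       └── train_aug.txt
    ├── JPEGImages/
    ├── SegmentationClass/
    └── SegmentationClassAug/
```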
63 |
64 | There are 21 classes.
65 |
66 | Credits to https://www.sun11.me/blog/2018/how-to-use-10582-trainaug-images-on-DeeplabV3-code/
67 |
68 |
69 | #### Once you have downloaded the dataset
70 | use one of the following three lines in train.py:
71 | ```python
72 | data_loader, data_loader_test=get_coco(root,batch_size=16)
73 | data_loader, data_loader_test=get_pascal_voc(root,batch_size=16)
74 | data_loader, data_loader_test=get_cityscapes(root,batch_size=16)
75 | ```
76 | where root is usually "." or the top-level directory of the dataset (e.g. cityscapes_dataset for the Cityscapes setup above).
77 |
78 | ## To train a model yourself
79 | Download one of the three datasets, change save_path and num_classes in train.py if necessary, and run the command
80 | ```shell
81 | python train.py
82 | ```
83 |
84 | ## To resume training
85 | In train.py, set resume=True, and change resume_path to the save_path of your last training session.
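
If you drive training from one of the YAML configs instead, the corresponding fields look like this (the paths here are only examples):

```yaml
resume: True
resume_path: checkpoints/voc_regnety40_latest
```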
86 |
--------------------------------------------------------------------------------
/accuracy_upperbound.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch import nn
3 | import torch
4 | from torch.nn import functional as F
5 | from train import ConfusionMatrix
6 | from data import get_cityscapes,get_pascal_voc
7 |
8 | class DoNothingNet(nn.Module):
9 | def __init__(self,output_stride=16,mode="bilinear"):
10 | super(DoNothingNet,self).__init__()
11 | self.os=output_stride
12 | self.mode=mode
13 | def forward(self,y):
14 | shape = y.shape[-2:]
15 | downsample_shape=((shape[0]-1)//self.os+1,(shape[1]-1)//self.os+1)
16 | if self.mode=="bilinear":
17 | y=F.interpolate(y,size=downsample_shape,mode='bilinear',align_corners=False)
18 | elif self.mode=="adaptive_avg":
19 | y=F.adaptive_avg_pool2d(y,downsample_shape)
20 | elif self.mode=="adaptive_max":
21 | y=F.adaptive_max_pool2d(y,downsample_shape)
22 | elif self.mode=="max3x3":
23 | y=F.max_pool2d(y,kernel_size=3,stride=2,padding=1)
24 | elif self.mode=="avg3x3":
25 | y=F.avg_pool2d(y,kernel_size=3,stride=2,padding=1)
26 | else:
27 | raise NotImplementedError()
28 | y=F.interpolate(y,size=shape,mode='bilinear',align_corners=False)
29 | return y
30 |
31 | def evaluate(model, data_loader, device, num_classes,eval_steps,print_every=100):
32 | model.eval()
33 | confmat = ConfusionMatrix(num_classes)
34 | with torch.no_grad():
35 | for i,(image, target) in enumerate(data_loader):
36 | if (i+1)%print_every==0:
37 | print(i+1)
38 | if i==eval_steps:
39 | break
40 | target = target.to(device)
41 | logits=torch.zeros(target.shape[0],num_classes,target.shape[1],target.shape[2],device=device)
42 | target2=torch.unsqueeze(target,1).clone() # clone so that zeroing the 255 ignore label below does not mutate target
43 | target2[target2==255]=0
44 | logits.scatter_(1,target2,1)
45 | output = model(logits)
46 | confmat.update(target.flatten(), output.argmax(1).flatten())
47 | return confmat
48 |
49 | def f(os,mode,device,num_classes):
50 | net=DoNothingNet(output_stride=os,mode=mode).to(device)
51 | #data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513)
52 | data_loader, data_loader_test=get_cityscapes("cityscapes_dataset",16,train_size=480,val_size=1024,num_workers=0)
53 | confmat=evaluate(net,data_loader_test,device,num_classes,eval_steps=100,print_every=20)
54 | return confmat
55 |
56 | def experiment1():
57 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
58 | num_classes=19
59 | print()
60 | for mode in ["adaptive_avg"]:#"bilinear","adaptive_max","max3x3",
61 | for os in [2,4,8,16,32]:
62 | confmat=f(os,mode,device,num_classes)
63 | # net=DoNothingNet(output_stride=os,mode=mode).to(device)
64 | # #data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513)
65 | # data_loader, data_loader_test=get_cityscapes("cityscapes_dataset",16,train_size=480,val_size=1024)
66 | # confmat=evaluate(net,data_loader_test,device,num_classes,eval_steps=300,print_every=100)
67 | print(mode,os)
68 | print(confmat)
69 |
70 | if __name__=="__main__":
71 | experiment1()
72 |
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """
9 | Lightweight and simple implementation of AutoAugment and RandAugment.
10 |
11 | AutoAugment - https://arxiv.org/abs/1805.09501
12 | RandAugment - https://arxiv.org/abs/1909.13719
13 |
14 | http://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
15 | Note that the official implementation varies substantially from the papers :-(
16 |
17 | Our AutoAugment policy should be nearly identical to the official AutoAugment policy.
18 | The main difference is that we set POSTERIZE_MIN = 1, which avoids degenerate (all 0) images.
19 | Our RandAugment policy differs, and uses transforms that increase in intensity with
20 | increasing magnitude. This allows for more natural control of the magnitude. That is,
21 | setting magnitude = 0 results in ops that leave the image unchanged, if possible.
22 | We also set the range of the magnitude to be 0 to 1 to avoid setting a "max level".
23 |
24 | Our implementation is inspired by and uses policies similar to those in:
25 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
26 | Specifically, our implementation can be *numerically identical* to the implementation in
27 | timm if using timm's "v0" policy for AutoAugment and "inc" transforms for RandAugment
28 | and if we set POSTERIZE_MIN = 0 (although as noted our default is POSTERIZE_MIN = 1).
29 | Note that magnitude in our code ranges from 0 to 1 (compared to 0 to 10 in timm).
30 |
31 | Specifically, given the same seeds, the functions from timm:
32 | out_auto = auto_augment_transform("v0", {"interpolation": 2})(im)
33 | out_rand = rand_augment_transform("rand-inc1-n2-m05", {"interpolation": 2})(im)
34 | Are numerically equivalent to:
35 | POSTERIZE_MIN = 0
36 | out_auto = auto_augment(im)
37 | out_rand = rand_augment(im, prob=0.5, n_ops=2, magnitude=0.5)
38 | Tested as of 10/07/2020. Can alter corresponding params for both and should match.
39 |
40 | Finally, the ops and augmentations can be visualized as follows:
41 | from PIL import Image
42 | import pycls.datasets.augment as augment
43 | im = Image.open("scratch.jpg")
44 | im_ops = augment.visualize_ops(im)
45 | im_rand = augment.visualize_aug(im, augment=augment.rand_augment, magnitude=0.5)
46 | im_auto = augment.visualize_aug(im, augment=augment.auto_augment)
47 | im_ops.show()
48 | im_auto.show()
49 | im_rand.show()
50 | """
51 |
52 | import random
53 |
54 | import numpy as np
55 | from PIL import Image, ImageEnhance, ImageOps
56 |
57 |
58 | # Minimum value for posterize (0 in EfficientNet implementation)
59 | POSTERIZE_MIN = 1
60 |
61 | # Parameters for affine warping and rotation
62 | WARP_PARAMS = {"fillcolor": (128, 128, 128), "resample": Image.BILINEAR}
63 |
64 |
65 | def affine_warp(im, data):
66 | """Applies affine transform to image."""
67 | return im.transform(im.size, Image.AFFINE, data, **WARP_PARAMS)
68 |
69 |
70 | OP_FUNCTIONS = {
71 | # Each op takes an image x and a level v and returns an augmented image.
72 | "auto_contrast": lambda x, _: ImageOps.autocontrast(x),
73 | "equalize": lambda x, _: ImageOps.equalize(x),
74 | "invert": lambda x, _: ImageOps.invert(x),
75 | "rotate": lambda x, v: x.rotate(v, **WARP_PARAMS),
76 | "posterize": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, int(v))),
77 | "posterize_inc": lambda x, v: ImageOps.posterize(x, max(POSTERIZE_MIN, 4 - int(v))),
78 | "solarize": lambda x, v: x.point(lambda i: i if i < int(v) else 255 - i),
79 | "solarize_inc": lambda x, v: x.point(lambda i: i if i < 256 - v else 255 - i),
80 | "solarize_add": lambda x, v: x.point(lambda i: min(255, v + i) if i < 128 else i),
81 | "color": lambda x, v: ImageEnhance.Color(x).enhance(v),
82 | "contrast": lambda x, v: ImageEnhance.Contrast(x).enhance(v),
83 | "brightness": lambda x, v: ImageEnhance.Brightness(x).enhance(v),
84 | "sharpness": lambda x, v: ImageEnhance.Sharpness(x).enhance(v),
85 | "color_inc": lambda x, v: ImageEnhance.Color(x).enhance(1 + v),
86 | "contrast_inc": lambda x, v: ImageEnhance.Contrast(x).enhance(1 + v),
87 | "brightness_inc": lambda x, v: ImageEnhance.Brightness(x).enhance(1 + v),
88 | "sharpness_inc": lambda x, v: ImageEnhance.Sharpness(x).enhance(1 + v),
89 | "shear_x": lambda x, v: affine_warp(x, (1, v, 0, 0, 1, 0)),
90 | "shear_y": lambda x, v: affine_warp(x, (1, 0, 0, v, 1, 0)),
91 | "trans_x": lambda x, v: affine_warp(x, (1, 0, v * x.size[0], 0, 1, 0)),
92 | "trans_y": lambda x, v: affine_warp(x, (1, 0, 0, 0, 1, v * x.size[1])),
93 | }
94 | affine_ops=[
95 | "rotate","shear_x","shear_y","trans_x","trans_y"
96 | ]
97 |
98 |
99 | OP_RANGES = {
100 | # Ranges for each op in the form of a (min, max, negate).
101 | "auto_contrast": (0, 1, False),
102 | "equalize": (0, 1, False),
103 | "invert": (0, 1, False),
104 | "rotate": (0.0, 30.0, True),
105 | "posterize": (0, 4, False),
106 | "posterize_inc": (0, 4, False),
107 | "solarize": (0, 256, False),
108 | "solarize_inc": (0, 256, False),
109 | "solarize_add": (0, 110, False),
110 | "color": (0.1, 1.9, False),
111 | "contrast": (0.1, 1.9, False),
112 | "brightness": (0.1, 1.9, False),
113 | "sharpness": (0.1, 1.9, False),
114 | "color_inc": (0, 0.9, True),
115 | "contrast_inc": (0, 0.9, True),
116 | "brightness_inc": (0, 0.9, True),
117 | "sharpness_inc": (0, 0.9, True),
118 | "shear_x": (0.0, 0.3, True),
119 | "shear_y": (0.0, 0.3, True),
120 | "trans_x": (0.0, 0.45, True),
121 | "trans_y": (0.0, 0.45, True),
122 | }
123 |
124 |
125 | AUTOAUG_POLICY = [
126 | # AutoAugment "policy_v0" in form of (op, prob, magnitude), where magnitude <= 1.
127 | [("equalize", 0.8, 0.1), ("shear_y", 0.8, 0.4)],
128 | [("color", 0.4, 0.9), ("equalize", 0.6, 0.3)],
129 | [("color", 0.4, 0.1), ("rotate", 0.6, 0.8)],
130 | [("solarize", 0.8, 0.3), ("equalize", 0.4, 0.7)],
131 | [("solarize", 0.4, 0.2), ("solarize", 0.6, 0.2)],
132 | [("color", 0.2, 0.0), ("equalize", 0.8, 0.8)],
133 | [("equalize", 0.4, 0.8), ("solarize_add", 0.8, 0.3)],
134 | [("shear_x", 0.2, 0.9), ("rotate", 0.6, 0.8)],
135 | [("color", 0.6, 0.1), ("equalize", 1.0, 0.2)],
136 | [("invert", 0.4, 0.9), ("rotate", 0.6, 0.0)],
137 | [("equalize", 1.0, 0.9), ("shear_y", 0.6, 0.3)],
138 | [("color", 0.4, 0.7), ("equalize", 0.6, 0.0)],
139 | [("posterize", 0.4, 0.6), ("auto_contrast", 0.4, 0.7)],
140 | [("solarize", 0.6, 0.8), ("color", 0.6, 0.9)],
141 | [("solarize", 0.2, 0.4), ("rotate", 0.8, 0.9)],
142 | [("rotate", 1.0, 0.7), ("trans_y", 0.8, 0.9)],
143 | [("shear_x", 0.0, 0.0), ("solarize", 0.8, 0.4)],
144 | [("shear_y", 0.8, 0.0), ("color", 0.6, 0.4)],
145 | [("color", 1.0, 0.0), ("rotate", 0.6, 0.2)],
146 | [("equalize", 0.8, 0.4), ("equalize", 0.0, 0.8)],
147 | [("equalize", 1.0, 0.4), ("auto_contrast", 0.6, 0.2)],
148 | [("shear_y", 0.4, 0.7), ("solarize_add", 0.6, 0.7)],
149 | [("posterize", 0.8, 0.2), ("solarize", 0.6, 1.0)],
150 | [("solarize", 0.6, 0.8), ("equalize", 0.6, 0.1)],
151 | [("color", 0.8, 0.6), ("rotate", 0.4, 0.5)],
152 | ]
153 |
154 |
155 | RANDAUG_OPS = [
156 | # RandAugment list of operations using "increasing" transforms.
157 | "auto_contrast",
158 | "equalize",
159 | "invert",
160 | "rotate",
161 | "posterize_inc",
162 | "solarize_inc",
163 | "solarize_add",
164 | "color_inc",
165 | "contrast_inc",
166 | "brightness_inc",
167 | "sharpness_inc",
168 | "shear_x",
169 | "shear_y",
170 | "trans_x",
171 | "trans_y",
172 | ]
173 |
174 | def check_support():
175 | mask=np.zeros((100,100)).astype("uint8")
176 | mask=Image.fromarray(mask)
177 | magnitude=1.0
178 | for op in RANDAUG_OPS:
179 | min_v, max_v, negate = OP_RANGES[op]
180 | v = magnitude * (max_v - min_v) + min_v
181 | v = -v if negate and random.random() > 0.5 else v
182 | OP_FUNCTIONS[op](mask, v)
183 |
184 |
185 | def apply_op(im, op, prob, magnitude):
186 | """Apply the selected op to image with given probability and magnitude."""
187 | # The magnitude is converted to an absolute value v for an op (some ops use -v or v)
188 | assert 0 <= magnitude <= 1
189 | assert op in OP_RANGES and op in OP_FUNCTIONS, "unknown op " + op
190 | if prob < 1 and random.random() > prob:
191 | return im
192 | min_v, max_v, negate = OP_RANGES[op]
193 | v = magnitude * (max_v - min_v) + min_v
194 | v = -v if negate and random.random() > 0.5 else v
195 | return OP_FUNCTIONS[op](im, v)
196 |
197 | def apply_op_both(im,mask, op, prob, magnitude,fill,ignore_value=255):
198 | """Apply the selected op to image with given probability and magnitude."""
199 | # The magnitude is converted to an absolute value v for an op (some ops use -v or v)
200 | assert 0 <= magnitude <= 1
201 | assert op in OP_RANGES and op in OP_FUNCTIONS, "unknown op " + op
202 | if prob < 1 and random.random() > prob:
203 | return im,mask
204 | min_v, max_v, negate = OP_RANGES[op]
205 | v = magnitude * (max_v - min_v) + min_v
206 | v = -v if negate and random.random() > 0.5 else v
207 | WARP_PARAMS["fillcolor"]=fill
208 | im=OP_FUNCTIONS[op](im, v)
209 | if op in affine_ops:
210 | WARP_PARAMS["fillcolor"]=ignore_value
211 | mask=OP_FUNCTIONS[op](mask, v)
212 | return im,mask
213 |
214 | def rand_augment_both(im, mask,magnitude, ops=None, n_ops=2, prob=1.0,fill=(128,128,128),ignore_value=255):
215 | """Applies random augmentation to an image."""
216 | ops = ops if ops else RANDAUG_OPS
217 | for op in np.random.choice(ops, int(n_ops)):
218 | im,mask = apply_op_both(im,mask, op, prob, magnitude,fill,ignore_value)
219 | return im,mask
220 |
221 | def rand_augment(im, magnitude, ops=None, n_ops=2, prob=1.0):
222 | """Applies random augmentation to an image."""
223 | ops = ops if ops else RANDAUG_OPS
224 | for op in np.random.choice(ops, int(n_ops)):
225 | im = apply_op(im, op, prob, magnitude)
226 | return im
227 |
228 |
229 | def auto_augment(im, policy=None):
230 | """Apply auto augmentation to an image."""
231 | policy = policy if policy else AUTOAUG_POLICY
232 | for op, prob, magnitude in random.choice(policy):
233 | im = apply_op(im, op, prob, magnitude)
234 | return im
235 |
236 |
237 | def make_augment(augment_str):
238 | """Generate augmentation function from separated parameter string.
239 | The parameter string augment_str may be either "AutoAugment" or "RandAugment".
240 | Undocumented use allows for specifying extra params, e.g. "RandAugment_N2_M0.5"."""
241 | params = augment_str.split("_")
242 | names = {"N": "n_ops", "M": "magnitude", "P": "prob"}
243 | assert params[0] in ["RandAugment", "AutoAugment"]
244 | assert all(p[0] in names for p in params[1:])
245 | keys = [names[p[0]] for p in params[1:]]
246 | vals = [float(p[1:]) for p in params[1:]]
247 | augment = rand_augment if params[0] == "RandAugment" else auto_augment
248 | return lambda im: augment(im, **dict(zip(keys, vals)))
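# Usage sketch (hypothetical): build an augmenter from a parameter string and apply it.
# aug = make_augment("RandAugment_N2_M0.5") # RandAugment with n_ops=2, magnitude=0.5
# im_aug = aug(im) # im is a PIL.Image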
249 |
250 |
251 | def visualize_ops(im, ops=None, num_steps=10):
252 | """Visualize ops by applying each op by varying amounts."""
253 | ops = ops if ops else RANDAUG_OPS
254 | w, h, magnitudes = im.size[0], im.size[1], np.linspace(0, 1, num_steps)
255 | output = Image.new("RGB", (w * num_steps, h * len(ops)))
256 | for i, op in enumerate(ops):
257 | for j, m in enumerate(magnitudes):
258 | out = apply_op(im, op, prob=1.0, magnitude=m)
259 | output.paste(out, (j * w, i * h))
260 | return output
261 |
262 |
263 | def visualize_aug(im, augment=rand_augment, num_trials=10, **kwargs):
264 | """Visualize augmentation by applying random augmentations."""
265 | w, h = im.size[0], im.size[1]
266 | output = Image.new("RGB", (w * num_trials, h * num_trials))
267 | for i in range(num_trials):
268 | for j in range(num_trials):
269 | output.paste(augment(im, **kwargs), (j * w, i * h))
270 | return output
271 |
272 | if __name__=="__main__":
273 | im=Image.open("prima_4class/Images/pc-00000085.jpg")
274 | w, h = im.size[0], im.size[1]
275 | im=im.resize((w//4,h//4))
276 | output=visualize_ops(im,num_steps=10)
277 | #output=visualize_aug(im,magnitude=1/3)
278 | output.show()
279 |
--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | import time
4 | import torch.cuda.amp as amp
5 | import torch.nn.functional
6 |
7 | @torch.no_grad()
8 | def compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision):
9 | model.eval()
10 | x=torch.randn(batch_size,3,crop_size,crop_size).to(device)
11 | times=[]
12 | for cur_iter in range(warmup_iter+num_iter):
13 | if cur_iter == warmup_iter:
14 | times.clear()
15 | t1=time.time()
16 | with amp.autocast(enabled=mixed_precision):
17 | output = model(x)
18 | torch.cuda.synchronize()
19 | t2=time.time()
20 | times.append(t2-t1)
21 | return average(times)
22 |
23 | def average(v):
24 | return sum(v)/len(v)
25 | def compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision):
26 | model.train()
27 | x=torch.randn(batch_size, 3, crop_size, crop_size).cuda(non_blocking=False)
28 | target=torch.randint(0,num_classes,(batch_size, crop_size, crop_size)).cuda(non_blocking=False)
29 | fw_times=[]
30 | bw_times=[]
31 | scaler = amp.GradScaler(enabled=mixed_precision)
32 | for cur_iter in range(warmup_iter+num_iter):
33 | if cur_iter == warmup_iter:
34 | fw_times.clear()
35 | bw_times.clear()
36 | t1=time.time()
37 | with amp.autocast(enabled=mixed_precision):
38 | output = model(x)
39 | loss = nn.functional.cross_entropy(output,target,ignore_index=255)
40 | torch.cuda.synchronize()
41 | t2=time.time()
42 | scaler.scale(loss).backward()
43 | torch.cuda.synchronize()
44 | t3=time.time()
45 | fw_times.append(t2-t1)
46 | bw_times.append(t3-t2)
47 | return average(fw_times),average(bw_times)
48 |
49 | def compute_loader_time(data_loader,warmup_iter,num_iter):
50 | times=[]
51 | data_loader_iter=iter(data_loader)
52 | for cur_iter in range(warmup_iter+num_iter):
53 | if cur_iter == warmup_iter:
54 | times.clear()
55 | t1=time.time()
56 | next(data_loader_iter)
57 | t2=time.time()
58 | times.append(t2-t1)
59 | return average(times)
60 |
61 |
62 | def memory_used(device):
63 | x=torch.cuda.memory_allocated(device)
64 | return round(x/1024/1024)
65 | def max_memory_used(device):
66 | x=torch.cuda.max_memory_allocated(device)
67 | return round(x/1024/1024)
68 | def memory_test_helper(model,device,crop_size,batch_size,num_classes,mixed_precision):
69 | model.train()
70 | scaler = amp.GradScaler(enabled=mixed_precision)
71 | x=torch.randn(batch_size, 3, crop_size, crop_size).to(device)
72 | target=torch.randint(0,num_classes,(batch_size, crop_size, crop_size)).to(device)
73 | t1=memory_used(device)
74 | with amp.autocast(enabled=mixed_precision):
75 | output = model(x)
76 | loss = nn.functional.cross_entropy(output,target,ignore_index=255)
77 | scaler.scale(loss).backward()
78 | torch.cuda.synchronize()
79 | t2=max_memory_used(device)
80 | torch.cuda.reset_peak_memory_stats(device)
81 | return t2-t1
82 |
83 | def compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision):
84 | for p in model.parameters():
85 | p.grad=None
86 | try:
87 | t=memory_test_helper(model,device,crop_size,batch_size,num_classes,mixed_precision)
88 | print()
89 | except RuntimeError: # most commonly a CUDA out-of-memory error
90 | t=-1
91 | print("out of memory")
92 | for p in model.parameters():
93 | p.grad=None
94 | return t
95 |
96 | def compute_time_no_loader(model,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision):
97 | model=model.to(device)
98 | print("benchmarking eval time")
99 | eval_time=compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision)
100 | print("benchmarking train time")
101 | train_fw_time,train_bw_time=compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision)
102 | train_time=train_fw_time+train_bw_time
103 | print("benchmarking memory usage")
104 | memory_usage=compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision)
105 | print("benchmarking loader time")
106 | dic1={
107 | "eval_time":eval_time,
108 | "train_time":train_time,
109 | "memory_usage":memory_usage
110 | }
111 | return dic1
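# Usage sketch (hypothetical values; warmup_iter and num_iter mirror the configs):
# stats = compute_time_no_loader(model, warmup_iter=3, num_iter=30,
#                                device=torch.device("cuda"), crop_size=513,
#                                batch_size=16, num_classes=21, mixed_precision=True)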
112 |
113 | def compute_time_full(model,data_loader,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision):
114 | model=model.to(device)
115 | print("benchmarking eval time")
116 | eval_time=compute_eval_time(model,device,warmup_iter,num_iter,crop_size,batch_size,mixed_precision)
117 | print("benchmarking train time")
118 | train_fw_time,train_bw_time=compute_train_time(model,warmup_iter,num_iter,crop_size,batch_size,num_classes,mixed_precision)
119 | train_time=train_fw_time+train_bw_time
120 | print("benchmarking memory usage")
121 | memory_usage=compute_memory_usage(model,device,crop_size,batch_size,num_classes,mixed_precision)
122 | print("benchmarking loader time")
123 | loader_time=compute_loader_time(data_loader,warmup_iter,num_iter)
124 | loader_overhead=max(0,loader_time-train_time)/train_time
125 | dic1={
126 | "eval_time":eval_time,
127 | "train_time":train_time,
128 | "memory_usage":memory_usage,
129 | "loader_time":loader_time,
130 | "loader_overhead":loader_overhead
131 | }
132 | dic2={ # full-dataset totals; currently unused (dic1 is returned below)
133 | "eval_time":eval_time*len(data_loader),
134 | "train_time":train_time*len(data_loader),
135 | "memory_usage":memory_usage,
136 | "loader_time":loader_time,
137 | "loader_overhead":loader_overhead
138 | }
139 | return dic1
140 |
--------------------------------------------------------------------------------
/blocks.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | class XBlock(nn.Module): # From figure 4
6 | def __init__(self, in_channels, out_channels, bottleneck_ratio, group_width, stride):
7 | super(XBlock, self).__init__()
8 | inter_channels = out_channels // bottleneck_ratio
9 | groups = inter_channels // group_width
10 |
11 | self.conv_block_1 = nn.Sequential(
12 | nn.Conv2d(in_channels, inter_channels, kernel_size=1, bias=False),
13 | nn.BatchNorm2d(inter_channels),
14 | nn.ReLU(inplace=True)
15 | )
16 | self.conv_block_2 = nn.Sequential(
17 | nn.Conv2d(inter_channels, inter_channels, kernel_size=3, stride=stride, groups=groups, padding=1, bias=False),
18 | nn.BatchNorm2d(inter_channels),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 | self.conv_block_3 = nn.Sequential(
23 | nn.Conv2d(inter_channels, out_channels, kernel_size=1, bias=False),
24 | nn.BatchNorm2d(out_channels)
25 | )
26 | if stride != 1 or in_channels != out_channels:
27 | self.shortcut = nn.Sequential(
28 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
29 | nn.BatchNorm2d(out_channels)
30 | )
31 | else:
32 | self.shortcut = None
33 | self.rl = nn.ReLU(inplace=True)
34 |
35 | def forward(self, x):
36 | shortcut=self.shortcut(x) if self.shortcut else x
37 | x = self.conv_block_1(x)
38 | x = self.conv_block_2(x)
39 | x = self.conv_block_3(x)
40 | x = self.rl(x + shortcut)
41 | return x
42 |
43 |
44 | class VBlock(nn.Module):
45 | def __init__(self, in_channels, out_channels, stride):
46 | super(VBlock, self).__init__()
47 | self.conv_block_1 = nn.Sequential(
48 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
49 | nn.BatchNorm2d(out_channels),
50 | nn.ReLU(inplace=True)
51 | )
52 | self.conv_block_2 = nn.Sequential(
53 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
54 | nn.BatchNorm2d(out_channels),
55 | )
56 | if stride != 1 or in_channels != out_channels:
57 | self.shortcut = nn.Sequential(
58 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
59 | nn.BatchNorm2d(out_channels)
60 | )
61 | else:
62 | self.shortcut = None
63 | self.rl = nn.ReLU(inplace=True)
64 |
65 | def forward(self, x):
66 | shortcut=self.shortcut(x) if self.shortcut else x
67 | x = self.conv_block_1(x)
68 | x = self.conv_block_2(x)
69 | x = self.rl(x + shortcut)
70 | return x
71 |
72 | class AtrousConcat(nn.Module):
73 | def __init__(self, in_channels, out_channels, group_width,rates):
74 | super(AtrousConcat, self).__init__()
75 | modules = []
76 | groups = out_channels // group_width
77 | for rate in rates:
78 | modules.append(nn.Sequential(
79 | nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False,groups=groups),
80 | nn.BatchNorm2d(out_channels),
81 | nn.ReLU(inplace=True)
82 | ))
83 | self.convs = nn.ModuleList(modules)
84 |
85 | def forward(self, x):
86 | res = []
87 | for conv in self.convs:
88 | res.append(conv(x))
89 | res = torch.cat(res, dim=1)
90 | return res
91 | class SpatialConcat(nn.Module):
92 | def __init__(self, in_channels, out_channels,group_width, bin_sizes):
93 | super(SpatialConcat, self).__init__()
94 | modules = []
95 | groups = out_channels // group_width
96 | self.reg_conv=nn.Sequential(
97 | nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False,groups=groups),
98 | nn.BatchNorm2d(out_channels),
99 | nn.ReLU(inplace=True),
100 | )
101 | for size in bin_sizes:
102 | modules.append(nn.Sequential(
103 | nn.AdaptiveAvgPool2d(output_size=(size,size)),
104 | nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False,groups=groups),
105 | nn.BatchNorm2d(out_channels),
106 | nn.ReLU(inplace=True),
107 | ))
108 | self.convs = nn.ModuleList(modules)
109 |
110 | def forward(self, x):
111 | input_shape = x.shape[-2:]
112 | res = [self.reg_conv(x)]
113 | for conv in self.convs:
114 | res.append(F.interpolate(conv(x), size=input_shape, mode='bilinear',align_corners=False))
115 | res = torch.cat(res, dim=1)
116 | return res
117 | class AtrousSum(nn.Module):
118 | def __init__(self, in_channels, out_channels,group_width, rates):
119 | super(AtrousSum, self).__init__()
120 | modules = []
121 | groups = out_channels // group_width
122 | for rate in rates:
123 | modules.append(nn.Sequential(
124 | nn.Conv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False,groups=groups),
125 | nn.BatchNorm2d(out_channels),
126 | nn.ReLU(inplace=True)
127 | ))
128 | self.convs = nn.ModuleList(modules)
129 |
130 | def forward(self, x):
131 | out=self.convs[0](x)
132 | for conv in self.convs[1:]:
133 | out=out+conv(x)
134 | return out
135 |
136 | class ABlock(nn.Module):
137 | def __init__(self, in_channels, conv1_channels,conv2_channels, out_channels,group_width,rates=(1,6,12,18),mode="concat"):
138 | super(ABlock, self).__init__()
139 | self.conv1 = nn.Sequential(
140 | nn.Conv2d(in_channels, conv1_channels, kernel_size=1, bias=False),
141 | nn.BatchNorm2d(conv1_channels),
142 | nn.ReLU(inplace=True)
143 | )
144 | if mode=="concat":
145 | self.conv2=AtrousConcat(conv1_channels, conv2_channels, group_width,rates)
146 | conv2_out_channels=conv2_channels * len(rates)
147 | elif mode=="sum":
148 | self.conv2=AtrousSum(conv1_channels, conv2_channels, group_width,rates)
149 | conv2_out_channels=conv2_channels
150 | else:
151 | raise NotImplementedError()
152 | self.conv3 = nn.Sequential(
153 | nn.Conv2d(conv2_out_channels, out_channels, 1, bias=False),
154 | nn.BatchNorm2d(out_channels),
155 | )
156 | self.shortcut=None
157 | if in_channels != out_channels:
158 | self.shortcut = nn.Sequential(
159 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
160 | nn.BatchNorm2d(out_channels)
161 | )
162 | self.rl = nn.Sequential(
163 | nn.ReLU(inplace=True)
164 | )
165 |
166 | def forward(self, x):
167 | shortcut=self.shortcut(x) if self.shortcut else x
168 | x = self.conv1(x)
169 | x = self.conv2(x)
170 | x = self.conv3(x)
171 | x = self.rl(x + shortcut)
172 | return x
173 |
174 | class SBlock(nn.Module):
175 | def __init__(self, in_channels, conv1_channels,conv2_channels, out_channels,group_width,bin_sizes=(1,2,3,6),mode="concat"):
176 | super(SBlock, self).__init__()
177 | self.conv1 = nn.Sequential(
178 | nn.Conv2d(in_channels, conv1_channels, kernel_size=1, bias=False),
179 | nn.BatchNorm2d(conv1_channels),
180 | nn.ReLU(inplace=True)
181 | )
182 | if mode=="concat":
183 | self.conv2=SpatialConcat(conv1_channels, conv2_channels, group_width,bin_sizes)
184 | conv2_out_channels=conv2_channels * (len(bin_sizes)+1)
185 | else:
186 | raise NotImplementedError()
187 | self.conv3 = nn.Sequential(
188 | nn.Conv2d(conv2_out_channels, out_channels, 1, bias=False),
189 | nn.BatchNorm2d(out_channels),
190 | )
191 | self.shortcut=None
192 | if in_channels != out_channels:
193 | self.shortcut = nn.Sequential(
194 | nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
195 | nn.BatchNorm2d(out_channels)
196 | )
197 | self.rl = nn.Sequential(
198 | nn.ReLU(inplace=True)
199 | )
200 |
201 | def forward(self, x):
202 | shortcut=self.shortcut(x) if self.shortcut else x
203 | x = self.conv1(x)
204 | x = self.conv2(x)
205 | x = self.conv3(x)
206 | x = self.rl(x + shortcut)
207 | return x
208 |
209 | def profile(model,x,device,num_iter=15):
210 | import time
211 | model=model.to(device)
212 | model.eval()
213 | t1=time.time()
214 | for i in range(num_iter):
215 | y=model(x)
216 | t2=time.time()
217 | return (t2-t1)/num_iter
218 |
219 | def profiler(models,x):
220 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
221 | print("warming up")
222 | profile(models[0],x, device,num_iter=15)
223 | for model in models:
224 | seconds=profile(model,x,device,num_iter=30)
225 | print(round(seconds,3))
226 |
227 | def try_models():
228 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
229 | x=torch.randn(2,128,128,128)
230 | group_width=128
231 | w=128
232 | model1=ABlock(w,w,w,w,group_width,rates=(1,6))
233 | model2=SBlock(w,w,w,w,group_width,bin_sizes=(4,16))
234 | model3=XBlock(w,w,1,group_width,1)
235 | models=[model1,model2,model3]
236 | profiler(models,x)
237 |
238 |
239 | if __name__=="__main__":
240 | try_models()
241 |
--------------------------------------------------------------------------------
/cityscapes.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import namedtuple
4 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple
5 | import torch.utils.data as data
6 | from PIL import Image
7 |
8 |
9 | class Cityscapes(data.Dataset):
10 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
11 | 'has_instances', 'ignore_in_eval', 'color'])
12 |
13 | classes = [
14 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
15 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
16 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
17 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
18 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
19 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
20 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
21 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
22 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
23 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
24 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
25 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
26 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
27 | CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
28 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
29 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
30 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
31 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
32 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
33 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
34 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
35 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
36 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
37 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
38 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
39 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
40 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
41 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
42 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
43 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
44 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
45 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
46 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
47 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
48 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
49 | ]
50 |
51 | def __init__(
52 | self,
53 | root: str,
54 | split: str = "train",
55 | mode: str = "fine",
56 | target_type: Union[List[str], str] = "semantic",
57 | transforms: Optional[Callable] = None,
58 | ) -> None:
59 | self.root=root
60 | self.transforms=transforms
61 | self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
62 | self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
63 | self.targets_dir = os.path.join(self.root, self.mode, split)
64 | self.target_type = target_type
65 | self.split = split
66 | self.images = []
67 | self.targets = []
68 |
69 | if not isinstance(target_type, list):
70 | self.target_type = [target_type]
71 |
72 | if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
73 | raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
74 | ' specified "split" and "mode" are inside the "root" directory')
75 |
76 | for city in os.listdir(self.images_dir):
77 | if city[0]==".":
78 | continue
79 | img_dir = os.path.join(self.images_dir, city)
80 | target_dir = os.path.join(self.targets_dir, city)
81 | for file_name in os.listdir(img_dir):
82 | target_types = []
83 | for t in self.target_type:
84 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
85 | self._get_target_suffix(self.mode, t))
86 | target_types.append(os.path.join(target_dir, target_name))
87 |
88 | self.images.append(os.path.join(img_dir, file_name))
89 | self.targets.append(target_types)
90 |
91 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
92 | """
93 | Args:
94 | index (int): Index
95 | Returns:
96 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
97 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
98 | """
99 |
100 | image = Image.open(self.images[index]).convert('RGB')
101 |
102 | targets: Any = []
103 | for i, t in enumerate(self.target_type):
104 | if t == 'polygon':
105 | target = self._load_json(self.targets[index][i])
106 | else:
107 | target = Image.open(self.targets[index][i])
108 |
109 | targets.append(target)
110 |
111 | target = tuple(targets) if len(targets) > 1 else targets[0]
112 |
113 | if self.transforms is not None:
114 | image, target = self.transforms(image, target)
115 |
116 | return image, target
117 |
118 | def __len__(self) -> int:
119 | return len(self.images)
120 |
121 | def extra_repr(self) -> str:
122 | lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
123 | return '\n'.join(lines).format(**self.__dict__)
124 |
125 | def _load_json(self, path: str) -> Dict[str, Any]:
126 | with open(path, 'r') as file:
127 | data = json.load(file)
128 | return data
129 |
130 | def _get_target_suffix(self, mode: str, target_type: str) -> str:
131 | if target_type == 'instance':
132 | return '{}_instanceIds.png'.format(mode)
133 | elif target_type == 'semantic':
134 | return '{}_labelTrainIds.png'.format(mode)
135 | elif target_type == 'color':
136 | return '{}_color.png'.format(mode)
137 | else:
138 | return '{}_polygons.json'.format(mode)
139 |
--------------------------------------------------------------------------------
/coco_download.sh:
--------------------------------------------------------------------------------
1 | mkdir coco
2 | cd coco
3 | mkdir images
4 | cd images
5 |
6 | wget http://images.cocodataset.org/zips/train2017.zip
7 | wget http://images.cocodataset.org/zips/val2017.zip
8 |
9 | unzip -q train2017.zip
10 | unzip -q val2017.zip
11 |
12 | rm train2017.zip
13 | rm val2017.zip
14 |
15 | cd ..
16 |
17 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
18 | unzip -q annotations_trainval2017.zip
19 | rm annotations_trainval2017.zip
20 |
--------------------------------------------------------------------------------
/coco_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.utils.data
4 | import torchvision
5 | from PIL import Image
6 |
7 | import os
8 |
9 | from pycocotools import mask as coco_mask
10 |
11 | from transforms import Compose
12 |
13 |
14 | class FilterAndRemapCocoCategories(object):
15 | def __init__(self, categories, remap=True):
16 | self.categories = categories
17 | self.remap = remap
18 |
19 | def __call__(self, image, anno):
20 | anno = [obj for obj in anno if obj["category_id"] in self.categories]
21 | if not self.remap:
22 | return image, anno
23 | anno = copy.deepcopy(anno)
24 | for obj in anno:
25 | obj["category_id"] = self.categories.index(obj["category_id"])
26 | return image, anno
27 |
28 |
29 | def convert_coco_poly_to_mask(segmentations, height, width):
30 | masks = []
31 | for polygons in segmentations:
32 | rles = coco_mask.frPyObjects(polygons, height, width)
33 | mask = coco_mask.decode(rles)
34 | if len(mask.shape) < 3:
35 | mask = mask[..., None]
36 | mask = torch.as_tensor(mask, dtype=torch.uint8)
37 | mask = mask.any(dim=2)
38 | masks.append(mask)
39 | if masks:
40 | masks = torch.stack(masks, dim=0)
41 | else:
42 | masks = torch.zeros((0, height, width), dtype=torch.uint8)
43 | return masks
44 |
45 |
46 | class ConvertCocoPolysToMask(object):
47 | def __call__(self, image, anno):
48 | w, h = image.size
49 | segmentations = [obj["segmentation"] for obj in anno]
50 | cats = [obj["category_id"] for obj in anno]
51 | if segmentations:
52 | masks = convert_coco_poly_to_mask(segmentations, h, w)
53 | cats = torch.as_tensor(cats, dtype=masks.dtype)
54 | # merge all instance masks into a single segmentation map
55 | # with its corresponding categories
56 | target, _ = (masks * cats[:, None, None]).max(dim=0)
57 | # discard overlapping instances
58 | target[masks.sum(0) > 1] = 255
59 | else:
60 | target = torch.zeros((h, w), dtype=torch.uint8)
61 | target = Image.fromarray(target.numpy())
62 | return image, target
63 |
64 |
65 | def _coco_remove_images_without_annotations(dataset, cat_list=None):
66 | def _has_valid_annotation(anno):
67 | # if it's empty, there is no annotation
68 | if len(anno) == 0:
69 | return False
70 | # if more than 1k pixels occupied in the image
71 | return sum(obj["area"] for obj in anno) > 1000
72 |
73 | assert isinstance(dataset, torchvision.datasets.CocoDetection)
74 | ids = []
75 | for ds_idx, img_id in enumerate(dataset.ids):
76 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
77 | anno = dataset.coco.loadAnns(ann_ids)
78 | if cat_list:
79 | anno = [obj for obj in anno if obj["category_id"] in cat_list]
80 | if _has_valid_annotation(anno):
81 | ids.append(ds_idx)
82 |
83 | dataset = torch.utils.data.Subset(dataset, ids)
84 | return dataset
85 |
86 |
87 | def get_coco_dataset(root, image_set, transforms):
88 | PATHS = {
89 | "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
90 | "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
91 | # "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
92 | }
93 | CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
94 | 1, 64, 20, 63, 7, 72]
95 |
96 | transforms = Compose([
97 | FilterAndRemapCocoCategories(CAT_LIST, remap=True),
98 | ConvertCocoPolysToMask(),
99 | transforms
100 | ])
101 |
102 | img_folder, ann_file = PATHS[image_set]
103 | img_folder = os.path.join(root, img_folder)
104 | ann_file = os.path.join(root, ann_file)
105 |
106 | dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
107 |
108 | if image_set == "train":
109 | dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)
110 |
111 | return dataset
112 |
--------------------------------------------------------------------------------
/configs/cityscapes_regnety40_160epochs_mixed_precision.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: regnety_040
3 | num_classes: 19
4 | pretrained_backbone: True
5 | separable_convolution: False
6 |
7 | #OPTIM:
8 | epochs: 160
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: True
19 |
20 | #TEST:
21 | val_size: 513
22 |
23 | #benchmark
24 | warmup_iter: 3
25 | num_iter: 30
26 |
27 | save_every_k_epochs: 3
28 | save_last_k_epochs: 8
29 | dataset_name: cityscapes
30 | dataset_dir: cityscapes_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/cityscapes_regnety40_latest
36 |
37 | #save_best_path: checkpoints/voc_regnety40_mixed_precision
38 | #save_latest_path: checkpoints/voc_regnety40_mixed_precision_latest
39 |
--------------------------------------------------------------------------------
/configs/configs_sanity_check.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | def f():
4 | filename="voc_resnet50d_30epochs.yaml"
5 | with open(filename) as file:
6 | dic=yaml.full_load(file)
7 | print(dic)
8 |
9 | if __name__=="__main__":
10 | f()
11 |
--------------------------------------------------------------------------------
/configs/voc_mobilenetv2_30epochs.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: mobilenetv2
3 | num_classes: 21
4 | pretrained_backbone: True
5 | separable_convolution: False
6 |
7 | #OPTIM:
8 | epochs: 30
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: False
19 |
20 | #TEST:
21 | eval_steps: 1000
22 | val_size: 513
23 |
24 | #benchmark
25 | warmup_iter: 3
26 | num_iter: 30
27 |
28 | save_every_k_epochs: 3
29 | dataset_name: pascal_voc
30 | dataset_dir: pascal_voc_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_mobilenetv2_latest
36 |
--------------------------------------------------------------------------------
/configs/voc_regnetx40_30epochs.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: regnetx_040
3 | num_classes: 21
4 | pretrained_backbone: True
5 | separable_convolution: True
6 |
7 | #OPTIM:
8 | epochs: 30
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: False
19 |
20 | #TEST:
21 | eval_steps: 1000
22 | val_size: 513
23 |
24 | #benchmark
25 | warmup_iter: 3
26 | num_iter: 30
27 |
28 | save_every_k_epochs: 3
29 | dataset_name: pascal_voc
30 | dataset_dir: pascal_voc_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnetx40_latest
36 |
--------------------------------------------------------------------------------
/configs/voc_regnety40_30epochs.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: regnety_040
3 | num_classes: 21
4 | pretrained_backbone: True
5 | separable_convolution: True
6 |
7 | #OPTIM:
8 | epochs: 30
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: False
19 |
20 | #TEST:
21 | eval_steps: 1000
22 | val_size: 513
23 |
24 | #benchmark
25 | warmup_iter: 3
26 | num_iter: 30
27 |
28 | save_every_k_epochs: 3
29 | dataset_name: pascal_voc
30 | dataset_dir: pascal_voc_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_latest
36 |
--------------------------------------------------------------------------------
/configs/voc_regnety40_30epochs_mixed_precision.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: regnety_040
3 | num_classes: 21
4 | pretrained_backbone: True
5 | separable_convolution: False
6 |
7 | #OPTIM:
8 | epochs: 30
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: True
19 |
20 | #TEST:
21 | val_size: 513
22 |
23 | #benchmark
24 | warmup_iter: 3
25 | num_iter: 30
26 |
27 | save_every_k_epochs: 3
28 | save_last_k_epochs: 8
29 | dataset_name: pascal_voc
30 | dataset_dir: pascal_voc_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_regnety40_mixed_precision_latest
36 |
37 | #save_best_path: checkpoints/voc_regnety40_mixed_precision
38 | #save_latest_path: checkpoints/voc_regnety40_mixed_precision_latest
39 |
--------------------------------------------------------------------------------
/configs/voc_resnet50d_30epochs.yaml:
--------------------------------------------------------------------------------
1 | #MODEL:
2 | model_name: resnet50d
3 | num_classes: 21
4 | pretrained_backbone: True
5 | separable_convolution: True
6 |
7 | #OPTIM:
8 | epochs: 30
9 | resume: False
10 | lr: 0.01
11 | momentum: 0.9
12 | weight_decay: 0.0001
13 | class_weight: null
14 |
15 | #TRAIN:
16 | batch_size: 16
17 | train_size: 481
18 | mixed_precision: False
19 |
20 | #TEST:
21 | eval_steps: 1000
22 | val_size: 513
23 |
24 | #benchmark
25 | warmup_iter: 3
26 | num_iter: 30
27 |
28 | save_every_k_epochs: 3
29 | dataset_name: pascal_voc
30 | dataset_dir: pascal_voc_dataset
31 | aug_mode: baseline
32 | pretrained_path: ''
33 | resume_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_latest
34 | save_best_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d
35 | save_latest_path: /content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_latest
36 |
--------------------------------------------------------------------------------
/configs/yoho.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 6, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 | NUM_GPUS: 8
28 | OUT_DIR: .
29 |
--------------------------------------------------------------------------------
/custom_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch.utils.data as data
3 | import os
4 | import json
5 |
6 | class SegmentationDataset(data.Dataset):
7 | def __init__(self,root,image_set,transforms):
8 | # images, masks, json splits
9 | root = os.path.expanduser(root)
10 | split_f = os.path.join(root, f"{image_set}.json")
11 | with open(split_f) as f: file_names = json.load(f)
12 | self.transforms=transforms
13 | image_dir = os.path.join(root, "Images")
14 | mask_dir=os.path.join(root, "Masks")
15 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
16 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
17 | assert (len(self.images) == len(self.masks))
18 |
19 | def __getitem__(self, index):
20 | img = Image.open(self.images[index]).convert('RGB')
21 | target = Image.open(self.masks[index])
22 | if self.transforms is not None:
23 | img, target = self.transforms(img, target)
24 |
25 | return img, target
26 |
27 | def __len__(self):
28 | return len(self.images)
29 |
30 | if __name__=='__main__':
31 | dataset=SegmentationDataset("yoho",image_set="train",transforms=None)
32 | print(dataset[1])
33 |
--------------------------------------------------------------------------------
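
SegmentationDataset hard-codes its layout: an {image_set}.json list of basenames sitting next to Images/ and Masks/ directories under root. The expected on-disk structure, sketched with hypothetical file names:

# root/
#   train.json            # JSON list of basenames, e.g. ["frame_0001", "frame_0002"]
#   Images/
#     frame_0001.jpg
#     frame_0002.jpg
#   Masks/
#     frame_0001.png      # class-index masks, one per image
#     frame_0002.png
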
/data.py:
--------------------------------------------------------------------------------
1 | import transforms as T
2 | from data_utils import *
3 | from cityscapes import Cityscapes
4 | from voc12 import Voc12Segmentation
5 | from coco_utils import get_coco_dataset
6 |
7 | def build_transforms(is_train, size, crop_size,mode="baseline"):
8 | mean = (0.485, 0.456, 0.406)
9 | std = (0.229, 0.224, 0.225)
10 | fill = tuple([int(v * 255) for v in mean])
11 | ignore_value = 255
12 | transforms=[]
13 | min_scale=1
14 | max_scale=1
15 | if is_train:
16 | min_scale=0.5
17 | max_scale=2
18 | transforms.append(T.RandomResize(int(min_scale*size),int(max_scale*size)))
19 | if is_train:
20 | if mode=="baseline":
21 | pass
22 | elif mode=="randaug":
23 | transforms.append(T.RandAugment(2,1/3,prob=1.0,fill=fill,ignore_value=ignore_value))
24 | elif mode=="custom1":
25 | transforms.append(T.ColorJitter(0.5,0.5,(0.5,2),0.05))
26 | transforms.append(T.AddNoise(10))
27 | transforms.append(T.RandomRotation((-10,10), mean=fill, ignore_value=0))
28 | else:
29 | raise NotImplementedError()
30 | transforms.append(
31 | T.RandomCrop(
32 | crop_size,crop_size,
33 | fill,
34 | ignore_value,
35 | random_pad=is_train
36 | ))
37 | transforms.append(T.RandomHorizontalFlip(0.5))
38 | transforms.append(T.ToTensor())
39 | transforms.append(T.Normalize(
40 | mean,
41 | std
42 | ))
43 | return T.Compose(transforms)
44 |
45 | def get_cityscapes(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4):
46 | train=Cityscapes(root, split="train", target_type="semantic", transforms=build_transforms(True, val_size, train_size,mode))
47 | val=Cityscapes(root, split="val", target_type="semantic", transforms=build_transforms(False, val_size, train_size,mode))
48 | train_loader = get_dataloader_train(train, batch_size,num_workers)
49 | val_loader = get_dataloader_val(val,num_workers)
50 | return train_loader,val_loader
51 | def get_coco(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4):
52 | train=get_coco_dataset(root, "train", build_transforms(True, val_size, train_size,mode))
53 | val=get_coco_dataset(root, "val", build_transforms(False, val_size, train_size,mode))
54 | train_loader = get_dataloader_train(train, batch_size,num_workers)
55 | val_loader = get_dataloader_val(val,num_workers)
56 | return train_loader, val_loader
57 | def get_pascal_voc(root,batch_size=16,val_size=513,train_size=481,mode="baseline",num_workers=4):
58 | download=False
59 | train = Voc12Segmentation(root, 'train_aug', build_transforms(True, val_size, train_size,mode),
60 | download)
61 | val = Voc12Segmentation(root, 'val', build_transforms(False, val_size, train_size,mode),
62 | download)
63 | train_loader = get_dataloader_train(train, batch_size,num_workers)
64 | val_loader = get_dataloader_val(val,num_workers)
65 | return train_loader, val_loader
66 |
67 | if __name__ == '__main__':
68 | train_loader, val_loader=get_pascal_voc("pascal_voc_dataset")
69 | print(next(iter(train_loader)))
70 |
--------------------------------------------------------------------------------
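
Note that the get_* helpers above pass val_size as build_transforms' size argument and train_size as its crop_size, so training first random-resizes the shorter side into [0.5*val_size, 2*val_size] and then pads/crops to train_size x train_size. A minimal sketch of applying the pipeline to one PIL image/mask pair (paths are illustrative):

from PIL import Image
from data import build_transforms

# Any RGB image plus an integer-label mask of the same size works here.
img = Image.open("example.jpg").convert("RGB")
mask = Image.open("example_mask.png")

tf = build_transforms(is_train=True, size=513, crop_size=481, mode="baseline")
img_t, mask_t = tf(img, mask)
print(img_t.shape, mask_t.shape)  # torch.Size([3, 481, 481]) torch.Size([481, 481])
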
/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import numpy as np
4 |
5 | def cat_list(images, fill_value=0):
6 | max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
7 | batch_shape = (len(images),) + max_size
8 | batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
9 | for img, pad_img in zip(images, batched_imgs):
10 | pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
11 | return batched_imgs
12 |
13 | def collate_fn(batch):
14 | images, targets = list(zip(*batch))
15 | batched_imgs = cat_list(images, fill_value=0)
16 | batched_targets = cat_list(targets, fill_value=255)
17 | return batched_imgs, batched_targets
18 |
19 | def get_sampler(dataset,dataset_test,distributed=False):
20 | if distributed:
21 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
22 | test_sampler = torch.utils.data.distributed.DistributedSampler(
23 | dataset_test)
24 | else:
25 | train_sampler = torch.utils.data.RandomSampler(dataset)
26 | #train_sampler = torch.utils.data.SequentialSampler(dataset)
27 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
28 | return train_sampler,test_sampler
29 |
30 | def worker_init_fn(worker_id):
31 | from datetime import datetime
32 | np.random.seed(datetime.now().microsecond)
33 | def get_dataloader_train(dataset,batch_size,num_workers=4):
34 | #dataset = get_coco(image_folder, ann_file, "train",get_temp_transform())
35 | train_sampler = torch.utils.data.RandomSampler(dataset)
36 | data_loader = torch.utils.data.DataLoader(
37 | dataset, batch_size=batch_size,
38 | sampler=train_sampler, num_workers=num_workers,
39 | collate_fn=collate_fn, drop_last=True)
40 | return data_loader
41 |
42 | def get_dataloader_val(dataset_test,num_workers=4):
43 | test_sampler = torch.utils.data.SequentialSampler(dataset_test)
44 | data_loader_test = torch.utils.data.DataLoader(
45 | dataset_test, batch_size=1,
46 | sampler=test_sampler, num_workers=num_workers,
47 | collate_fn=collate_fn)
48 | return data_loader_test
49 |
50 | def find_mean_and_std(data_loader):
51 | K = 0.8
52 | n = 0.0
53 | Ex=torch.zeros(3).float()
54 | Ex2 = torch.zeros(3).float()
55 | count=0
56 | for image,_ in data_loader:
57 | count+=1
58 | assert len(image.size())==4
59 | image = image.transpose(0, 1).flatten(1)
60 | Ex += (image-K).sum(dim=1)
61 | Ex2 += ((image - K)**2).sum(dim=1)
62 | n +=image.size()[1]
63 | if count==1000:
64 | break
65 | mean=Ex/n+K
66 | variance=(Ex2 - (Ex * Ex)/n)/(n-1)
67 | std=variance.sqrt()
68 | return mean, std
69 |
70 | def find_class_weights(dataloader,num_classes):
71 | print_every=4
72 | class_weights=torch.zeros(num_classes)
73 | for count,(image,target) in enumerate(dataloader):
74 | class_weights+=torch.bincount(target[target<num_classes],minlength=num_classes)

--------------------------------------------------------------------------------
/experimental_models.py:
--------------------------------------------------------------------------------
258 | # a=torch.sum(x>=0)
259 | # b=torch.sum(x<0)
260 | # print(a,b)
261 | # x=m(x)
262 | # print(name,torch.mean(x),torch.std(x))
263 | # return x
264 | # def gradient2():
265 | # #model=timm.create_model('mobilenetv2_100',features_only=True,out_indices=(4,),output_stride=16)
266 | # #model=Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False)
267 | # # pretrained_path='checkpoints/voc_mobilenetv2'
268 | # # model=Deeplab3P(name='mobilenetv2_100',num_classes=num_classes,pretrained=pretrained_path,sc=False).to(
269 | # # device)
270 | # #model=torchvision.models.vgg16()
271 | # model=torchvision.models.resnet50()
272 | # model=LittleNet()
273 | # model.train()
274 | # print(model)
275 | # x=torch.randn((2,10,128,128))
276 | # y=model(x)
277 | # loss=torch.mean(y**2)
278 | # loss.backward()
279 | # for name,p in model.named_parameters():
280 | # if p.grad is not None:
281 | # if "conv" in name:
282 | # print(name,float(torch.sum(torch.abs(p.grad))))
283 |
--------------------------------------------------------------------------------
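
find_mean_and_std in data_utils.py above estimates per-channel statistics with the shifted-data formula, which keeps the running sums small: mean = Ex/n + K and variance = (Ex2 - Ex^2/n)/(n - 1), where Ex = sum(x - K) and Ex2 = sum((x - K)^2). A self-contained numeric check of the same arithmetic:

import torch

x = torch.rand(3, 10000)  # 3 "channels" of fake pixel values in [0, 1)
K = 0.8                   # the shift only needs to be in the data's ballpark

n = x.size(1)
ex = (x - K).sum(dim=1)
ex2 = ((x - K) ** 2).sum(dim=1)
mean = ex / n + K
var = (ex2 - ex * ex / n) / (n - 1)

assert torch.allclose(mean, x.mean(dim=1), atol=1e-5)
assert torch.allclose(var.sqrt(), x.std(dim=1), atol=1e-4)
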
/fov.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 |
3 | def field_of_vision(fs):
4 | k=1
5 | s=1
6 | for _k,_s in fs: # kernel size and stride
7 | k=k+(_k-1)*s
8 | s=s*_s
9 | return k,s
10 |
11 | def staged_net(ds):
12 | fs=[]
13 | for d in ds:
14 | for i in range(d):
15 | if i==0:
16 | fs.append([3,2])
17 | else:
18 | fs.append([3,1])
19 | return fs
20 |
21 |
22 | def test1():
23 | regnets=[
24 | [1,1,4,7],
25 | [1,2,7,12],
26 | [1,3,5,7],
27 | [1,3,7,5],
28 | [2,4,10,2],
29 | [2,6,15,2],
30 | [2,5,14,2],
31 | [2,4,10,1],
32 | [2,5,15,1],
33 | [2,5,11,1],
34 | [2,6,13,1],
35 | [2,7,13,1]
36 | ]
37 | for ds in regnets:
38 | fs=staged_net(ds)
39 | fs=[(3,2)]+fs
40 | k,s=field_of_vision(fs)
41 | print(k,s)
42 | def test2():
43 | resnets=[
44 | [2, 2, 2, 2],
45 | [3, 4, 6, 3],
46 | [3, 4, 6, 3],
47 | [3, 4, 23, 3],
48 | [3, 8, 36, 3]
49 | ]
50 | for ds in resnets:
51 | ds[0]=ds[0]+1
52 | fs=staged_net(ds)
53 | fs=[(7,2)]+fs
54 | #fs=[(3,2),(3,1),(3,1)]+fs
55 | k,s=field_of_vision(fs)
56 | print(k,s)
57 | def test3():
58 | inverted_residual_setting = [
59 | # t, c, n, s
60 | [1, 16, 1, 1],
61 | [6, 24, 2, 2],
62 | [6, 32, 3, 2],
63 | [6, 64, 4, 2],
64 | [6, 96, 3, 1],
65 | [6, 160, 3, 2],
66 | [6, 320, 1, 1],
67 | ]
68 | fs=[]
69 | mobilenetv2=[2,3,7,4]
70 | mobilenetv3_large_fs=[
71 | (3,2),(3,1),(3,2),(3,1),(5,2),(5,1),(5,1),(3,2),
72 | (3,1),(3,1),(3,1),(3,1),(3,1),(5,2),(5,1),(5,1)]
73 | fs=staged_net(mobilenetv2)
74 | fs=[(3,2),(3,1)]+fs
75 | k,s=field_of_vision(fs)
76 | print(k,s)
77 | k,s=field_of_vision(mobilenetv3_large_fs)
78 | print(k,s)
79 |
80 |
81 | if __name__=="__main__":
82 | test1()
83 | print()
84 | test2()
85 | print()
86 | test3()
87 | #torchvision.models.resnet50()
88 | #model=torchvision.models.mobilenet_v2()
89 | # fs=[(3,2),(1,1),(3,2),(1,1)]
90 | # k,s=field_of_vision(fs)
91 | # print(k,s)
92 |
--------------------------------------------------------------------------------
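
field_of_vision above composes the standard receptive-field recurrence: scanning layers in order, k <- k + (k_i - 1) * s and s <- s * s_i, where s is the cumulative stride before layer i. A worked check with two stacked convolutions:

from fov import field_of_vision

# Two 3x3 convolutions, both stride 2.
# Layer 1: k = 1 + (3-1)*1 = 3, s = 2
# Layer 2: k = 3 + (3-1)*2 = 7, s = 4
k, s = field_of_vision([(3, 2), (3, 2)])
print(k, s)  # 7 4
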
/google0ccf7212e9c814a7.html:
--------------------------------------------------------------------------------
1 | google-site-verification: google0ccf7212e9c814a7.html
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | from torch import nn
3 | import torch
4 | from torch.nn import functional as F
5 | from math import log2
6 | from Deeplabv3 import DeepLabHead,DeepLabHeadNoASSP,get_ASSP,convert_to_separable_conv, ASPP
7 | import timm
8 |
9 | def replace(output_stride=8):
10 | d=[False,False,False]
11 | n=int(log2(32/output_stride))
12 | assert n<=3,'output_stride too small'
13 | for i in range(n):
14 | d[2-i]=True
15 | return d
16 |
17 | class Deeplab3(nn.Module):
18 | def __init__(self,name="mobilenetv2_100",num_classes=21,pretrained="",
19 | pretrained_backbone=True,aspp=True):
20 | super(Deeplab3,self).__init__()
21 | output_stride = 16
22 | self.backbone=timm.create_model(name, features_only=True,
23 | output_stride=output_stride, out_indices=(4,),pretrained=pretrained_backbone and pretrained =="")
24 | channels=self.backbone.feature_info.channels()
25 | if aspp:
26 | self.head=DeepLabHead(channels[0], num_classes,output_stride)
27 | else:
28 | self.head=DeepLabHeadNoASSP(channels[0], num_classes)
29 | if pretrained != "":
30 | dic = torch.load(pretrained, map_location='cpu')
31 | if type(dic)==dict:
32 | self.load_state_dict(dic['model'])
33 | else:
34 | self.load_state_dict(dic)
35 | def forward(self,x):
36 | input_shape = x.shape[-2:]
37 | x=self.backbone(x)
38 | x=self.head(x[0])
39 | x = F.interpolate(x, size=input_shape, mode='bilinear',align_corners=False)
40 | return x
41 |
42 | class Deeplab3P(nn.Module):
43 | def __init__(self, name="mobilenetv2_100",num_classes=21,pretrained="",
44 | pretrained_backbone=True,sc=False,filter_multiplier=1.0):
45 | super(Deeplab3P,self).__init__()
46 | output_stride = 16
47 | num_filters = int(256*filter_multiplier)
48 | num_low_filters = int(48*filter_multiplier)
49 | try:
50 | self.backbone=timm.create_model(name, features_only=True,
51 | output_stride=output_stride, out_indices=(1, 4),pretrained=pretrained_backbone and pretrained =="")
52 | except RuntimeError:
53 | print("no model")
54 | print(timm.list_models())
55 | raise RuntimeError()
56 | channels=self.backbone.feature_info.channels()
57 | self.head16=get_ASSP(channels[1], output_stride,num_filters)
58 | self.head4=torch.nn.Sequential(
59 | nn.Conv2d(channels[0], num_low_filters, 1, bias=False),
60 | nn.BatchNorm2d(num_low_filters),
61 | nn.ReLU(inplace=True))
62 | self.decoder= nn.Sequential(
63 | nn.Conv2d(num_low_filters+num_filters, num_filters, 3, padding=1, bias=False),
64 | nn.BatchNorm2d(num_filters),
65 | nn.ReLU(inplace=True),
66 | nn.Conv2d(num_filters, num_filters, 3, padding=1, bias=False),
67 | nn.BatchNorm2d(num_filters),
68 | nn.ReLU(inplace=True),
69 | nn.Conv2d(num_filters, num_classes, 1)
70 | )
71 | if sc:
72 | self.decoder = convert_to_separable_conv(self.decoder)
73 | if pretrained != "":
74 | dic = torch.load(pretrained, map_location='cpu')
75 | if type(dic)==dict:
76 | self.load_state_dict(dic['model'])
77 | else:
78 | self.load_state_dict(dic)
79 |
80 | def forward(self, x):
81 | input_shape = x.shape[-2:]
82 | features = self.backbone(x)
83 | x=self.head16(features[1])
84 | x2=self.head4(features[0])
85 | intermediate_shape=x2.shape[-2:]
86 | x = F.interpolate(x, size=intermediate_shape, mode='bilinear',align_corners=False)
87 | x=torch.cat((x,x2),dim=1)
88 | x=self.decoder(x)
89 | x = F.interpolate(x, size=input_shape, mode='bilinear',align_corners=False)
90 | return x
91 |
92 | def profile(model,device,num_iter=15):
93 | import time
94 | model=model.to(device)
95 | model.eval()
96 | x=torch.randn(2,3,321,321).to(device)
97 | t1=time.time()
98 | for i in range(num_iter):
99 | y=model(x)
100 | t2=time.time()
101 | return (t2-t1)/num_iter
102 |
103 | def profiler(models):
104 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
105 | model = torchvision.models.resnet50()
106 | print("warming up")
107 | profile(model, device,num_iter=15)
108 | for model in models:
109 | seconds=profile(model,device,num_iter=30)
110 | print(round(seconds,3))
111 |
112 | def total_params(models):
113 | for model in models:
114 | total = sum(p.numel() for p in model.parameters() if p.requires_grad)
115 | total=round(total/1000000,2)
116 | print(f"{model.__class__.__name__}: {total}M")
117 |
118 | def profiler2(models):
119 | from ptflops import get_model_complexity_info
120 | for model in models:
121 | macs, params = get_model_complexity_info(model, (3, 480, 480),
122 | as_strings=True,
123 | print_per_layer_stat=False,
124 | verbose=False)
125 | print(f"{model.__class__.__name__}: {macs}, {params}")
126 |
127 | def memory_used(device):
128 | x=torch.cuda.memory_allocated(device)
129 | return round(x/1024/1024)
130 | def max_memory_used(device):
131 | x=torch.cuda.max_memory_allocated(device)
132 | return round(x/1024/1024)
133 | def memory_test_helper(model,device):
134 | model=model.to(device)
135 | model.train()
136 | N=16
137 | x=torch.randn(N, 3, 481, 481).to(device)
138 | target=torch.randint(0,21,(N, 481, 481)).to(device)
139 | t1=memory_used(device)
140 | out=model(x)
141 | loss=nn.functional.cross_entropy(out,target,ignore_index=255)
142 | loss.backward()
143 | t2=max_memory_used(device)
144 | print(t2-t1)
145 | torch.cuda.reset_peak_memory_stats(device)
146 |
147 | def memory_test(models,device):
148 | for _ in range(len(models)): # always test models[0]; it is deleted below so memory is freed between models
149 | try:
150 | memory_test_helper(models[0],device)
151 | print()
152 | except RuntimeError: # CUDA out-of-memory surfaces as a RuntimeError
153 | print("out of memory")
154 | for p in models[0].parameters():
155 | p.grad=None
156 | del models[0]
157 |
158 | def test_separable():
159 | model1=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),
160 | nn.Conv2d(256,256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True))
161 | model2=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),
162 | nn.Conv2d(256,256, 3, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True))
163 | model2=convert_to_separable_conv(model2)
164 | model3=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),
165 | nn.Conv2d(256,256, 3, padding=2, bias=False,dilation=2),nn.BatchNorm2d(256),nn.ReLU(inplace=True))
166 | model4=nn.Sequential(nn.Conv2d(3,256, 1, padding=0, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),
167 | nn.Conv2d(256,256, 3, padding=2, bias=False,dilation=2),nn.BatchNorm2d(256),nn.ReLU(inplace=True))
168 | model4=convert_to_separable_conv(model4)
169 | models=[
170 | model1,model2,model3,model4
171 | ]
172 | profiler(models)
173 | def test_fast():
174 | models=[
175 | Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False),
176 | #Deeplab3P(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False,filter_multiplier=0.5),
177 | Deeplab3(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False),
178 | Deeplab3(name='mobilenetv2_100', num_classes=21,pretrained_backbone=False,aspp=False),
179 | Deeplab3P(name='resnet50d', num_classes=21,pretrained_backbone=False),
180 | #Deeplab3P(name='resnet50d', num_classes=21,pretrained_backbone=False,filter_multiplier=0.5),
181 | Deeplab3(name='resnet50d', num_classes=21,pretrained_backbone=False),
182 | Deeplab3(name='resnet50d', num_classes=21,pretrained_backbone=False,aspp=False),
183 | ]
184 | profiler(models)
185 |
186 | def test_models():
187 | names=[
188 | 'resnet50d',
189 | 'nf_regnet_b0',
190 | 'gernet_m',
191 | 'efficientnet_lite3',
192 | 'efficientnet_lite2'
193 | ]
194 | models = []
195 | for name in names:
196 | #models.append(Deeplab3P(name=name, num_classes=21,pretrained_backbone=False))
197 | models.append(timm.create_model(name, features_only=True,
198 | output_stride=16, out_indices=(4,)))
199 | profiler(models)
200 |
201 | if __name__=='__main__':
202 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
203 | #resnet50d 77.1 mIOU
204 | #regnetx_040 77.0 mIOU
205 | #regnety_040 78.6 mIOU
206 | #mobilenetv2_100 72.8 mIOU
207 | #regnetx_080 77.3 mIOU
208 |
209 | num_classes=21
210 | print(timm.list_models())
211 | model=timm.create_model('resnest50d')
212 | print(model)
213 | #experiment1()
214 | #test_fast()
215 | #test_models()
216 |
217 |
218 |
219 |
220 | # 0.897
221 | # 0.752
222 | # 0.614
223 | # 0.371
224 | # 0.251
225 |
226 | # regnety4G
227 | # 0.764
228 | # 0.691
229 | # 0.572
230 | # 0.289
231 | # 0.131
232 |
233 | # regnet
234 | # 1.016
235 | # 0.946
236 | # 0.853
237 | # 0.49
238 | # 0.181
239 |
240 |
241 | # 'mobilenetv2_100'
242 | # 0.34
243 | # 0.178
244 | # 0.167
245 | # 0.168
246 | # 0.167
247 |
--------------------------------------------------------------------------------
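
A minimal end-to-end sketch of Deeplab3P inference (weights are random here; pass pretrained=<checkpoint path> to load one of the checkpoints referenced in the configs):

import torch
from model import Deeplab3P

model = Deeplab3P(name="mobilenetv2_100", num_classes=21, pretrained_backbone=False)
model.eval()

x = torch.randn(1, 3, 513, 513)  # one normalized RGB image
with torch.no_grad():
    logits = model(x)            # (1, 21, 513, 513), upsampled to the input size
pred = logits.argmax(1)          # (1, 513, 513) class indices
print(pred.shape)
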
/requirements.txt:
--------------------------------------------------------------------------------
1 | cityscapesscripts
2 | opencv-python
3 | pip-check-reqs
4 | pip-chill
5 | ptflops
6 | pycocotools
7 | pyqt5
8 | pyyaml
9 | timm
10 | typing-extensions
11 |
--------------------------------------------------------------------------------
/show.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from PIL import Image
4 | import torch
5 | from model import Deeplab3P
6 | import time
7 | from data import get_cityscapes,get_pascal_voc
8 | from cityscapes import Cityscapes
9 |
10 | mean = np.array([0.485, 0.456, 0.406])
11 | std = np.array([0.229, 0.224, 0.225])
12 |
13 | def get_colors():
14 | palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
15 | colors = torch.arange(255).view(-1, 1) * palette
16 | colors = (colors % 255).numpy().astype("uint8")
17 | return colors
18 | def get_colors_cityscapes():
19 | colors=np.zeros((256,3))
20 | colors[255]=[255,255,255]
21 | for c in Cityscapes.classes:
22 | if 0<=c.train_id<=18:
23 | colors[c.train_id]=c.color
24 | return colors.astype("uint8")
25 |
26 |
27 | def show_image(inp, title=None):
28 | """Imshow for Tensor."""
29 | inp = inp.numpy().transpose((1, 2, 0))
30 | inp = std * inp + mean
31 | inp = np.clip(inp, 0, 1)
32 | plt.imshow(inp)
33 | if title is not None:
34 | plt.title(title)
35 |
36 | def show_mask(images):
37 | colors=get_colors()
38 | r = Image.fromarray(images.byte().cpu().numpy())
39 | r.putpalette(colors)
40 | plt.imshow(r)
41 | def show_cityscapes_mask(images):
42 | colors=get_colors_cityscapes()
43 | r = Image.fromarray(images.byte().cpu().numpy())
44 | r.putpalette(colors)
45 | plt.imshow(r)
46 |
47 | def display(data_loader,show_mask,num_images=5,skip=4,images_per_line=6):
48 | images_so_far = 0
49 | fig = plt.figure(figsize=(6, 4))
50 | num_rows=int(np.ceil(num_images/images_per_line))
51 | data_loader = iter(data_loader)
52 | for _ in range(skip):
53 | next(data_loader)
54 | for images, targets in data_loader:
55 | for image, target in zip(images, targets):
56 | print(image.size(), target.size())
57 | plt.subplot(num_rows, 2*images_per_line, images_so_far + 1)
58 | plt.axis('off')
59 | show_image(image)
60 |
61 | plt.subplot(num_rows, 2*images_per_line, images_so_far + 2)
62 | plt.axis('off')
63 | show_mask(target)
64 |
65 | images_so_far += 2
66 | if images_so_far == 2 * num_images:
67 | plt.tight_layout()
68 | plt.show()
69 | return
70 | plt.tight_layout()
71 | plt.show()
72 | def show(model,data_loader,device,show_mask,num_images=5,skip=4,images_per_line=2):
73 | images_so_far=0
74 | model.eval()
75 | num_rows = int(np.ceil(num_images / images_per_line))
76 | fig=plt.figure(figsize=(8,4))
77 | data_loader=iter(data_loader)
78 | for _ in range(skip):
79 | next(data_loader)
80 | with torch.no_grad():
81 | for images, targets in data_loader:
82 | images, targets = images.to(device), targets.to(device)
83 | start=time.time()
84 | outputs = model(images)
85 | end=time.time()
86 | print(end-start)
87 | for image,target,output in zip(images,targets,outputs):
88 | output = output.argmax(0)
89 | print(image.size(),target.size(),output.size())
90 | plt.subplot(num_rows, 3*images_per_line, images_so_far+1)
91 | plt.axis('off')
92 | show_image(image)
93 |
94 | plt.subplot(num_rows, 3*images_per_line, images_so_far+2)
95 | plt.axis('off')
96 | show_mask(target)
97 |
98 | plt.subplot(num_rows,3*images_per_line,images_so_far+3)
99 | plt.axis('off')
100 | show_mask(output)
101 |
102 | images_so_far+=3
103 | if images_so_far==3*num_images:
104 | plt.tight_layout()
105 | plt.show()
106 | return
107 | plt.tight_layout()
108 | plt.show()
109 |
110 | def show_cityscapes():
111 | num_images=16
112 | images_per_line=4
113 | skip=0
114 | _,data_loader=get_cityscapes("cityscapes_dataset",16,train_size=481,val_size=513)
115 | display(data_loader,show_cityscapes_mask,num_images=num_images,skip=skip,images_per_line=images_per_line)
116 |
117 | if __name__=="__main__":
118 | show_cityscapes()
119 |
120 | #_,data_loader=get_pascal_voc("pascal_voc_dataset",16,train_size=385,val_size=385)
121 |
122 | #pretrained_path='checkpoints/voc_resnet50d_noise'
123 | #model=torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True).to(device)
124 | # model=Deeplab3P(name='resnet50d',num_classes=num_classes,pretrained=pretrained_path,sc=True).to(
125 | # device)
126 |
127 | #show(model,data_loader,device,show_mask,num_images=num_images,skip=skip,images_per_line=images_per_line)
128 |
--------------------------------------------------------------------------------
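
get_colors above builds a deterministic pseudo-random palette: multiplying each class index by a vector of large Mersenne-style constants and reducing mod 255 scatters the classes across RGB space. A quick look at the first entries:

import torch

palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = (torch.arange(255).view(-1, 1) * palette % 255).numpy().astype("uint8")
print(colors[:2])
# [[  0   0   0]     class 0 -> black (background)
#  [  1 127  31]]    each later class gets a distinct, stable color
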
/train.py:
--------------------------------------------------------------------------------
1 | from model import Deeplab3P
2 | from benchmark import compute_time_full, compute_time_no_loader
3 | from data import get_cityscapes,get_pascal_voc,get_coco
4 | import datetime
5 | import time
6 |
7 | import torch
8 | import torch.utils.data
9 | from torch import nn
10 | import torch.nn.functional
11 | import yaml
12 | import torch.cuda.amp as amp
13 | import os
14 |
15 | class ConfusionMatrix(object):
16 | def __init__(self, num_classes):
17 | self.num_classes = num_classes
18 | self.mat = None
19 |
20 | def update(self, a, b):
21 | n = self.num_classes
22 | if self.mat is None:
23 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
24 | with torch.no_grad():
25 | k = (a >= 0) & (a < n)
26 | inds = n * a[k].to(torch.int64) + b[k]
27 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
28 |
29 | def reset(self):
30 | self.mat.zero_()
31 |
32 | def compute(self):
33 | h = self.mat.float()
34 | acc_global = torch.diag(h).sum() / h.sum()
35 | acc = torch.diag(h) / h.sum(1)
36 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
37 | return acc_global, acc, iu
38 | def __str__(self):
39 | acc_global, acc, iu = self.compute()
40 | return (
41 | 'global correct: {:.1f}\n'
42 | 'average row correct: {}\n'
43 | 'IoU: {}\n'
44 | 'mean IoU: {:.1f}').format(
45 | acc_global.item() * 100,
46 | ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
47 | ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
48 | iu.mean().item() * 100)
49 |
50 |
51 | def criterion2(inputs, target, w):
52 | return nn.functional.cross_entropy(inputs,target,ignore_index=255,weight=w)
53 |
54 | def get_loss_fun(weight):
55 | return nn.CrossEntropyLoss(weight=weight,ignore_index=255)
56 |
57 | def evaluate(model, data_loader, device, num_classes,mixed_precision,print_every=100):
58 | model.eval()
59 | confmat = ConfusionMatrix(num_classes)
60 | with torch.no_grad():
61 | for i,(image, target) in enumerate(data_loader):
62 | if (i+1)%print_every==0:
63 | print(i+1)
64 | image, target = image.to(device), target.to(device)
65 | with amp.autocast(enabled=mixed_precision):
66 | output = model(image)
67 | confmat.update(target.flatten(), output.argmax(1).flatten())
68 |
69 | return confmat
70 |
71 | def train_one_epoch(model, loss_fun, optimizer, loader, lr_scheduler, device, print_freq,mixed_precision,scaler):
72 | model.train()
73 | losses=0
74 | for t, x in enumerate(loader):
75 | image, target=x
76 | image, target = image.to(device), target.to(device)
77 | with amp.autocast(enabled=mixed_precision):
78 | output = model(image)
79 | loss = loss_fun(output, target)
80 | optimizer.zero_grad()
81 | scaler.scale(loss).backward()
82 | scaler.step(optimizer)
83 | scaler.update()
84 | lr_scheduler.step()
85 | losses+=loss.item()
86 | if t % print_freq==0:
87 | print(t,loss.item())
88 | num_iter=len(loader)
89 | print(losses/num_iter)
90 |
91 | def save(model,optimizer,scheduler,epoch,path,best_mIU,scaler):
92 | dic={
93 | 'model': model.state_dict(),
94 | 'optimizer': optimizer.state_dict(),
95 | 'lr_scheduler': scheduler.state_dict(),
96 | 'scaler':scaler.state_dict(),
97 | 'epoch': epoch,
98 | 'best_mIU':best_mIU
99 | }
100 | torch.save(dic,path)
101 |
102 | def train(model, save_best_path,save_latest_path, epochs,optimizer, data_loader, data_loader_test, lr_scheduler, device,num_classes,save_best_on_epochs,loss_fun,mixed_precision,scaler,best_mIU):
103 | start_time = time.time()
104 | for epoch in epochs:
105 | print(f"epoch: {epoch}")
106 | train_one_epoch(model, loss_fun, optimizer, data_loader, lr_scheduler,
107 | device, print_freq=50,mixed_precision=mixed_precision,scaler=scaler)
108 | if epoch in save_best_on_epochs:
109 | confmat = evaluate(model, data_loader_test, device=device,
110 | num_classes=num_classes,mixed_precision=mixed_precision,print_every=100)
111 | print(confmat)
112 | acc_global, acc, iu = confmat.compute()
113 | mIU=iu.mean().item() * 100
114 | if mIU > best_mIU:
115 | best_mIU=mIU
116 | save(model, optimizer, lr_scheduler, epoch, save_best_path,best_mIU,scaler)
117 | if save_latest_path != "":
118 | save(model, optimizer, lr_scheduler, epoch, save_latest_path,best_mIU,scaler)
119 |
120 | total_time = time.time() - start_time
121 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
122 | print('Training time {}'.format(total_time_str))
123 |
124 | def get_dataset_loaders(config):
125 | name=config["dataset_name"]
126 | if name=="pascal_voc":
127 | f=get_pascal_voc
128 | elif name=="cityscapes":
129 | f=get_cityscapes
130 | elif name=="coco":
131 | f=get_coco
132 | else:
133 | raise NotImplementedError()
134 | mode="baseline"
135 | if "aug_mode" in config:
136 | mode=config["aug_mode"]
137 | data_loader, data_loader_test=f(config["dataset_dir"],config["batch_size"],train_size=config["train_size"],val_size=config["val_size"],mode=mode)
138 | print("train size:", len(data_loader))
139 | print("val size:", len(data_loader_test))
140 | return data_loader, data_loader_test
141 |
142 | def get_model(config):
143 | pretrained_backbone=config["pretrained_backbone"]
144 | if config["resume"]:
145 | pretrained_backbone=False
146 | return Deeplab3P(name=config["model_name"],
147 | num_classes=config["num_classes"],
148 | pretrained_backbone=pretrained_backbone,
149 | sc=config["separable_convolution"],
150 | pretrained=config["pretrained_path"])
151 |
152 | def get_config_and_check_files(config_filename):
153 | with open(config_filename) as file:
154 | config=yaml.full_load(file)
155 | save_best_dir=os.path.dirname(config["save_best_path"])
156 | save_latest_dir=os.path.dirname(config["save_latest_path"])
157 | if not os.path.isdir(save_best_dir):
158 | raise FileNotFoundError(f"{save_best_dir} is not a directory")
159 | if not os.path.isdir(save_latest_dir):
160 | raise FileNotFoundError(f"{save_latest_dir} is not a directory")
161 | if not os.path.isdir(config["dataset_dir"]):
162 | raise FileNotFoundError(f"{config['dataset_dir']} is not a directory")
163 | if config["resume"]:
164 | if not os.path.isfile(config["resume_path"]):
165 | raise FileNotFoundError(f"{config['resume_path']} is not a file")
166 | elif not config["pretrained_backbone"]:
167 | if not os.path.isfile(config["pretrained_path"]):
168 | raise FileNotFoundError(f"{config['pretrained_path']} is not a file")
169 | return config
170 |
171 | def get_epochs_to_save(config):
172 | epochs=config["epochs"]
173 | save_every_k_epochs=config["save_every_k_epochs"]
174 | save_best_on_epochs=[i*save_every_k_epochs-1 for i in range(1,epochs//save_every_k_epochs+1)]
175 | if epochs-1 not in save_best_on_epochs:
176 | save_best_on_epochs.append(epochs-1)
177 | if "save_last_k_epochs" in config:
178 | for i in range(epochs-config["save_last_k_epochs"],epochs):
179 | if i not in save_best_on_epochs:
180 | save_best_on_epochs.append(i)
181 | save_best_on_epochs=sorted(save_best_on_epochs)
182 | return save_best_on_epochs
183 |
184 | def main2(config_filename):
185 | config=get_config_and_check_files(config_filename)
186 | torch.backends.cudnn.benchmark=True
187 | save_best_path=config["save_best_path"]
188 | save_latest_path=config["save_latest_path"]
189 | epochs=config["epochs"]
190 | num_classes=config["num_classes"]
191 | class_weight=config["class_weight"]
192 | mixed_precision=config["mixed_precision"]
193 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
194 | data_loader, data_loader_test=get_dataset_loaders(config)
195 | model=get_model(config).to(device)
196 | params_to_optimize=model.parameters()
197 | optimizer = torch.optim.SGD(params_to_optimize, lr=config["lr"],
198 | momentum=config["momentum"], weight_decay=config["weight_decay"])
199 | scaler = amp.GradScaler(enabled=mixed_precision)
200 | loss_fun=get_loss_fun(class_weight)
201 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
202 | optimizer,lambda x: (1 - x / (len(data_loader) * epochs)) ** 0.9)
203 |
204 | epoch_start=0
205 | best_mIU=0
206 | save_best_on_epochs=get_epochs_to_save(config)
207 | print("save on epochs: ",save_best_on_epochs)
208 |
209 | if config["resume"]:
210 | dic=torch.load(config["resume_path"],map_location='cpu')
211 | model.load_state_dict(dic['model'])
212 | optimizer.load_state_dict(dic['optimizer'])
213 | lr_scheduler.load_state_dict(dic['lr_scheduler'])
214 | epoch_start = dic['epoch'] + 1
215 | if "best_mIU" in dic:
216 | best_mIU=dic["best_mIU"]
217 | if "scaler" in dic:
218 | scaler.load_state_dict(dic["scaler"])
219 |
220 | train(model, save_best_path,save_latest_path, range(epoch_start,epochs),optimizer, data_loader,
221 | data_loader_test, lr_scheduler, device,num_classes,save_best_on_epochs,loss_fun,mixed_precision,scaler,best_mIU)
222 | def check3(config_filename):
223 | config=get_config_and_check_files(config_filename)
224 | torch.backends.cudnn.benchmark=True
225 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
226 | data_loader, data_loader_test=get_dataset_loaders(config)
227 | model=get_model(config).to(device)
228 | num_classes=config["num_classes"]
229 | mixed_precision=config["mixed_precision"]
230 | print("evaluating")
231 | confmat = evaluate(model, data_loader_test, device=device,
232 | num_classes=num_classes,mixed_precision=mixed_precision)
233 | print(confmat)
234 |
235 | # def check():
236 | # device = torch.device(
237 | # 'cuda') if torch.cuda.is_available() else torch.device('cpu')
238 | # num_classes = 21
239 | # pretrained_path='/content/drive/My Drive/Colab Notebooks/SemanticSegmentation/checkpoints/voc_resnet50d_noise2'
240 | # #voc_resnet50d_noise
241 | # data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=481,val_size=513)
242 | # eval_steps = len(data_loader_test)
243 | # model=Deeplab3P(name="resnet50d",num_classes=num_classes,pretrained=pretrained_path,sc=True).to(
244 | # device)
245 | # print("evaluating")
246 | # confmat = evaluate(model, data_loader_test, device=device,
247 | # num_classes=num_classes,eval_steps=eval_steps,print_every=100)
248 | # print(confmat)
249 | # def check2():
250 | # device = torch.device(
251 | # 'cuda') if torch.cuda.is_available() else torch.device('cpu')
252 | # num_classes = 21
253 | # pretrained_path='checkpoints/voc_mobilenetv2'
254 | # #pretrained_path='checkpoints/voc_regnety40'
255 | # data_loader, data_loader_test=get_pascal_voc("pascal_voc_dataset",16,train_size=385,val_size=385)
256 | # eval_steps = 100
257 | # model=Deeplab3P(name='mobilenetv2_100',num_classes=num_classes,pretrained=pretrained_path,sc=False).to(
258 | # device)
259 | # print("evaluating")
260 | # confmat = evaluate(model, data_loader_test, device=device,
261 | # num_classes=num_classes,eval_steps=eval_steps,print_every=5)
262 | # print(confmat)
263 |
264 | def benchmark(config_filename):
265 | config=get_config_and_check_files(config_filename)
266 | torch.backends.cudnn.benchmark=True
267 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
268 | mixed_precision=config["mixed_precision"]
269 | warmup_iter=config["warmup_iter"]
270 | num_iter=config["num_iter"]
271 | crop_size=config["train_size"]
272 | batch_size=config["batch_size"]
273 | num_classes=config["num_classes"]
274 | model=get_model(config).to(device)
275 | #data_loader, data_loader_test=get_dataset_loaders(config)
276 | #dic=compute_time_full(model,data_loader,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision)
277 | dic=compute_time_no_loader(model,warmup_iter,num_iter,device,crop_size,batch_size,num_classes,mixed_precision)
278 | for k,v in dic.items():
279 | print(f"{k}: {v}")
280 |
281 | if __name__=='__main__':
282 | # example: load a config and print which epochs will be checkpointed
283 | #config_filename="PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml"
284 | config_filename2="configs/voc_regnety40_30epochs_mixed_precision.yaml"
285 | config=get_config_and_check_files(config_filename2)
286 | print(config)
287 | print(get_epochs_to_save(config))
288 | #benchmark("PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml")
289 | #main2("PyTorch_DeepLab/configs/voc_regnety40_30epochs_mixed_precision.yaml")
290 |
--------------------------------------------------------------------------------
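
ConfusionMatrix.compute above derives per-class IoU as diag / (row_sum + col_sum - diag), i.e. true positives over the union of ground-truth and predicted pixels. A hand-checkable two-class example:

import torch
from train import ConfusionMatrix

cm = ConfusionMatrix(num_classes=2)
target = torch.tensor([0, 0, 0, 1, 1, 1])
pred   = torch.tensor([0, 0, 1, 1, 1, 0])
cm.update(target, pred)

# mat = [[2, 1],   class 0: IoU = 2 / (3 + 3 - 2) = 0.5
#        [1, 2]]   class 1: IoU = 2 / (3 + 3 - 2) = 0.5
acc_global, acc, iu = cm.compute()
print(acc_global.item(), iu.tolist())  # 0.666..., [0.5, 0.5]
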
/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from torchvision.transforms import functional as F
4 | from PIL import Image
5 | import torchvision.transforms as T
6 | import torch
7 | from augment import apply_op_both,rand_augment_both
8 |
9 | class Compose(object):
10 | """
11 | Composes a sequence of transforms.
12 | Arguments:
13 | transforms: A list of transforms.
14 | """
15 | def __init__(self, transforms):
16 | self.transforms = transforms
17 |
18 | def __call__(self, image, label):
19 | for t in self.transforms:
20 | image, label = t(image, label)
21 | return image, label
22 |
23 | def __repr__(self):
24 | format_string = self.__class__.__name__ + "("
25 | for t in self.transforms:
26 | format_string += "\n"
27 | format_string += " {0}".format(t)
28 | format_string += "\n)"
29 | return format_string
30 |
31 |
32 | class ToTensor(object):
33 | def __call__(self, image, target):
34 | image = F.to_tensor(image)
35 | target = torch.as_tensor(np.array(target), dtype=torch.int64)
36 | return image, target
37 |
38 | class RandAugment:
39 | def __init__(self,N,M,prob=1.0,fill=(128,128,128),ignore_value=255):
40 | self.N=N
41 | self.M=M
42 | self.prob=prob
43 | self.fill=fill
44 | self.ignore_value=ignore_value
45 | def __call__(self, image, target):
46 | return rand_augment_both(image,target,n_ops=self.N,magnitude=self.M,prob=self.prob,fill=self.fill,ignore_value=self.ignore_value)
47 |
48 |
49 | class Normalize(object):
50 | """
51 | Normalizes image by mean and std.
52 | """
53 | def __init__(self, mean, std):
54 | self.mean = mean
55 | self.std = std
56 |
57 | def __call__(self, image, label):
58 | image = F.normalize(image, mean=self.mean, std=self.std)
59 | return image, label
60 |
61 | class RandomResize(object):
62 | def __init__(self, min_size, max_size=None):
63 | self.min_size = min_size
64 | if max_size is None:
65 | max_size = min_size
66 | self.max_size = max_size
67 |
68 | def __call__(self, image, target):
69 | size = random.randint(self.min_size, self.max_size)
70 | image = F.resize(image, size)
71 | target = F.resize(target, size, interpolation=F.InterpolationMode.NEAREST)
72 | return image, target
73 |
74 | class ColorJitter:
75 | def __init__(self,brightness=0.2, contrast=0.2, saturation=(0.5,4), hue=0.2):
76 | self.jitter=T.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
77 | def __call__(self, image, target):
78 | image=self.jitter(image)
79 | return image,target
80 |
81 | class AddNoise:#additive gaussian noise
82 | def __init__(self,factor):
83 | self.factor=factor
84 | def __call__(self, image, target):
85 | factor = random.uniform(0, self.factor)
86 | image = np.array(image)
87 | assert(image.dtype==np.uint8)
88 | gauss = np.array(torch.randn(*image.shape)) * factor
89 | noisy = (image.astype(np.float32) + gauss).clip(0, 255).astype("uint8")
90 | image = Image.fromarray(noisy)
91 | return image, target
92 |
93 | class RandomRotation:
94 | def __init__(self,degrees,mean,ignore_value=255):
95 | self.degrees=degrees
96 | self.mean=mean
97 | self.ignore_value=ignore_value
98 | def __call__(self, image, target):
99 | expand=True
100 | if random.random()<0.5:
101 | angle = random.uniform(*self.degrees)
102 | image=F.rotate(image, angle,fill=self.mean,expand=expand)
103 | target=F.rotate(target,angle,fill=self.ignore_value,expand=expand)
104 | return image,target
105 |
106 |
107 | class RandomScale(object):
108 | """
109 | Applies random scale augmentation.
110 | Arguments:
111 | min_scale: Minimum scale value.
112 | max_scale: Maximum scale value.
113 | scale_step_size: The step size from minimum to maximum value.
114 | """
115 | def __init__(self, min_scale, max_scale, scale_step_size):
116 | self.min_scale = min_scale
117 | self.max_scale = max_scale
118 | self.scale_step_size = scale_step_size
119 |
120 | @staticmethod
121 | def get_random_scale(min_scale_factor, max_scale_factor, step_size):
122 | """Gets a random scale value.
123 | Args:
124 | min_scale_factor: Minimum scale value.
125 | max_scale_factor: Maximum scale value.
126 | step_size: The step size from minimum to maximum value.
127 | Returns:
128 | A random scale value selected between minimum and maximum value.
129 | Raises:
130 | ValueError: min_scale_factor has unexpected value.
131 | """
132 | if min_scale_factor < 0 or min_scale_factor > max_scale_factor:
133 | raise ValueError('Unexpected value of min_scale_factor.')
134 |
135 | if min_scale_factor == max_scale_factor:
136 | return min_scale_factor
137 |
138 | # When step_size = 0, we sample the value uniformly from [min, max).
139 | if step_size == 0:
140 | return random.uniform(min_scale_factor, max_scale_factor)
141 |
142 | # When step_size != 0, we randomly select one discrete value from [min, max].
143 | num_steps = int((max_scale_factor - min_scale_factor) / step_size + 1)
144 | scale_factors = np.linspace(min_scale_factor, max_scale_factor, num_steps)
145 | np.random.shuffle(scale_factors)
146 | return scale_factors[0]
147 |
148 | def __call__(self, image, label):
149 | scale = self.get_random_scale(self.min_scale, self.max_scale, self.scale_step_size)
150 | img_w, img_h = image.size
151 | img_w,img_h=int(img_w*scale),int(img_h*scale)
152 | image=F.resize(image,[img_h,img_w])
153 | label=F.resize(label,[img_h,img_w],interpolation=F.InterpolationMode.NEAREST)
154 | return image,label
155 |
156 |
157 | class RandomCrop(object):
158 | def __init__(self, crop_h, crop_w, pad_value, ignore_label, random_pad):
159 | self.crop_h = crop_h
160 | self.crop_w = crop_w
161 | self.pad_value = pad_value
162 | self.ignore_label = ignore_label
163 | self.random_pad = random_pad
164 |
165 | def __call__(self, image, label):
166 | img_w,img_h=image.size
167 | pad_h = max(self.crop_h - img_h, 0)
168 | pad_w = max(self.crop_w - img_w, 0)
169 | if pad_h > 0 or pad_w > 0:
170 | if self.random_pad:
171 | pad_top = random.randint(0, pad_h)
172 | pad_bottom = pad_h - pad_top
173 | pad_left = random.randint(0, pad_w)
174 | pad_right = pad_w - pad_left
175 | else:
176 | pad_top, pad_bottom, pad_left, pad_right = 0, pad_h, 0, pad_w
177 | image = F.pad(image, (pad_left, pad_top, pad_right, pad_bottom), fill=self.pad_value)
178 | label= F.pad(label, (pad_left, pad_top, pad_right, pad_bottom), fill=self.ignore_label)
179 |
180 | crop_params = T.RandomCrop.get_params(image, (self.crop_h, self.crop_w))
181 | image = F.crop(image, *crop_params)
182 | label = F.crop(label, *crop_params)
183 | return image,label
184 |
185 |
186 | class RandomHorizontalFlip(object):
187 | def __init__(self, flip_prob):
188 | self.flip_prob = flip_prob
189 |
190 | def __call__(self, image, target):
191 | if random.random() < self.flip_prob:
192 | image = F.hflip(image)
193 | target = F.hflip(target)
194 | return image, target
195 |
196 | def f(): # uint8 + uint8: the sum wraps modulo 256, so the later clip is a no-op
197 | image=np.zeros((50,50),dtype=np.uint8)
198 | assert(image.dtype==np.uint8)
199 | gauss=(np.array(torch.randn(50,50)*10)).astype("uint8")
200 | noisy = (image + gauss).clip(0, 255)
201 | # print(noisy.dtype)
202 | # factor = random.uniform(0, 10)
203 | # image = Image.open("cityscapes_dataset/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png")
204 | # image=np.array(image)
205 | # gauss = np.array(torch.randn(*image.shape)) * factor
206 | # noisy = (image + gauss).clip(0, 255).astype("uint8")
207 |
208 | def g(): # add in float, then clip and cast: the safe ordering
209 | image=np.zeros((50,50))
210 | gauss=np.array(torch.randn(50,50)*10)
211 | noisy = (image + gauss).clip(0, 255).astype("uint8")
212 | # print(noisy.dtype)
213 | # factor = random.uniform(0, 10)
214 | # image = Image.open("cityscapes_dataset/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png")
215 | # image=np.array(image)
216 | # gauss = np.array(torch.randn(*image.shape)) * factor
217 | # print(gauss.dtype)
218 | # noisy = (image + gauss).clip(0, 255).astype("uint8")
219 |
220 |
221 | if __name__=='__main__':
222 | f()
223 | g()
224 | #import timeit
225 | #print(timeit.timeit('f()', globals=globals(), number=100))
226 | #print(timeit.timeit('g()', globals=globals(), number=100))
227 |
--------------------------------------------------------------------------------
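
f() and g() above probe the pitfall AddNoise has to avoid: adding noise in uint8 wraps modulo 256 before clip ever runs, while accumulating in float and clipping before the final cast is safe. A compact demonstration:

import numpy as np

image = np.full((2, 2), 250, dtype=np.uint8)
noise = np.full((2, 2), 10.5)  # stand-in for torch.randn(...) * factor

# Wrong order: cast the noise to uint8 and add in uint8; 250 + 10 wraps to 4,
# so the clip that follows never sees the overflow.
wrapped = (image + noise.astype("uint8")).clip(0, 255)
# Safe order (what AddNoise does above): add in float, clip, then cast once.
safe = (image.astype(np.float32) + noise).clip(0, 255).astype("uint8")

print(wrapped[0, 0], safe[0, 0])  # 4 255
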
/voc12.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tarfile
3 | import torch.utils.data as data
4 |
5 | from PIL import Image
6 | from torchvision.datasets.utils import download_url
7 |
8 | class Voc12Segmentation(data.Dataset):
9 | def __init__(self,root,image_set,transforms,download=False):
10 | self.root = os.path.expanduser(root)
11 | self.url='http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar'
12 | self.filename='VOCtrainval_11-May-2012.tar'
13 | self.md5='6cd6e144f989b92b3379bac3b3de84fd'
14 | self.base_dir='VOCdevkit/VOC2012'
15 | self.transforms=transforms
16 | voc_root = os.path.join(self.root, self.base_dir)
17 | image_dir = os.path.join(voc_root, 'JPEGImages')
18 | if download:
19 | download_extract(self.url, self.root, self.filename, self.md5)
20 | if not os.path.isdir(voc_root):
21 | raise RuntimeError(f'{voc_root} not found')
22 | if image_set == 'train_aug':
23 | mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
24 | split_f = os.path.join(voc_root, f'ImageSets/Segmentation/{image_set}.txt')
25 | else:
26 | mask_dir = os.path.join(voc_root, 'SegmentationClass')
27 | split_f = os.path.join(voc_root, f'ImageSets/Segmentation/{image_set}.txt')
28 | if not os.path.exists(split_f):
29 | raise RuntimeError(f'{split_f} not found')
30 | with open(split_f, "r") as f:
31 | file_names = [x.strip() for x in f.readlines()]
32 | self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
33 | self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
34 | assert (len(self.images) == len(self.masks))
35 |
36 | def __getitem__(self, index):
37 | img = Image.open(self.images[index]).convert('RGB')
38 | target = Image.open(self.masks[index])
39 | if self.transforms is not None:
40 | img, target = self.transforms(img, target)
41 |
42 | return img, target
43 |
44 | def __len__(self):
45 | return len(self.images)
46 |
47 | def download_extract(url, root, filename, md5):
48 | download_url(url, root, filename, md5)
49 | with tarfile.open(os.path.join(root, filename), "r") as tar:
50 | tar.extractall(path=root)
51 |
--------------------------------------------------------------------------------
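
One caveat for Voc12Segmentation: the download path only fetches the official VOCtrainval tar, which contains SegmentationClass but neither SegmentationClassAug nor ImageSets/Segmentation/train_aug.txt. For image_set='train_aug', those augmented annotations (typically distributed as the SBD extra annotations) have to be placed there separately. The layout the class then expects, sketched with an illustrative image id:

# VOCdevkit/VOC2012/
#   JPEGImages/2007_000032.jpg
#   SegmentationClass/2007_000032.png        # used for image_set="train" / "val"
#   SegmentationClassAug/2007_000032.png     # extra masks, supplied by the user
#   ImageSets/Segmentation/train_aug.txt     # basenames, one per line (user-supplied)
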