├── model
│   ├── __init__.py
│   ├── discriminator.py
│   ├── feature_extractor.py
│   ├── classifier.py
│   └── resnet.py
├── utils
│   ├── __init__.py
│   ├── dropout.py
│   ├── loss.py
│   ├── util.py
│   └── transform.py
├── figs
│   └── fig.png
├── .gitignore
├── LICENSE
├── datasets
│   └── get_thresholds.py
├── README.md
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── gta5_dataset.py
│   ├── cityscapes_val_dataset.py
│   ├── cityscapes_train_dataset.py
│   ├── randaugment.py
│   └── augmentations.py
├── generate_soft_label.py
├── train_phase2.py
└── train_phase1.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figs/fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/DecoupleNet/HEAD/figs/fig.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | snapshots/
2 | *.pth
3 | *__pycache__*
4 | debug/
5 | class_balance_ids_*.p
6 | core.*
7 | datasets/pseudo_labels*/
8 | datasets/soft_labels*/
9 | output/
10 | check.py
11 | data/class_balance_ids_*.p
12 | data/cityscapes_class_balance_ids_*.p
13 | *.pickle
14 | */*.zip
15 | */*.npy
16 | slurm_cmd/
17 | pretrained/
18 | *_soft_labels/
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 DV Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class FCDiscriminator(nn.Module):
6 |
7 | def __init__(self, num_classes, ndf = 64):
8 | super(FCDiscriminator, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
11 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1)
12 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1)
13 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1)
14 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1)
15 |
16 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
17 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear')
18 | #self.sigmoid = nn.Sigmoid()
19 |
20 |
21 | def forward(self, x):
22 | x = self.conv1(x)
23 | x = self.leaky_relu(x)
24 | x = self.conv2(x)
25 | x = self.leaky_relu(x)
26 | x = self.conv3(x)
27 | x = self.leaky_relu(x)
28 | x = self.conv4(x)
29 | x = self.leaky_relu(x)
30 | x = self.classifier(x)
31 | #x = self.up_sample(x)
32 | #x = self.sigmoid(x)
33 |
34 | return x
35 |
--------------------------------------------------------------------------------
/model/feature_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torchvision.models._utils import IntermediateLayerGetter
5 | from . import resnet
6 |
7 | class FrozenBatchNorm2d(nn.Module):
8 | """
9 | BatchNorm2d where the batch statistics and the affine parameters
10 | are fixed
11 | """
12 |
13 | def __init__(self, n):
14 |         super(FrozenBatchNorm2d, self).__init__()
15 | self.register_buffer("weight", torch.ones(n))
16 | self.register_buffer("bias", torch.zeros(n))
17 | self.register_buffer("running_mean", torch.zeros(n))
18 | self.register_buffer("running_var", torch.ones(n))
19 |
20 | def forward(self, x):
21 | output = F.batch_norm(x, self.running_mean, self.running_var, weight=self.weight, bias=self.bias, training=False)
22 | return output
23 |
24 | class resnet_feature_extractor(nn.Module):
25 | def __init__(self, backbone_name, pretrained_weights=None, aux=False, pretrained_backbone=True, freeze_bn=False):
26 | super(resnet_feature_extractor, self).__init__()
27 | bn_layer = nn.BatchNorm2d
28 | if freeze_bn:
29 | bn_layer = FrozenBatchNorm2d
30 | backbone = resnet.__dict__[backbone_name](
31 | pretrained=pretrained_backbone,
32 | replace_stride_with_dilation=[False, True, True], pretrained_weights=pretrained_weights, norm_layer=bn_layer)
33 | return_layers = {'layer4': 'out'}
34 | if aux:
35 | return_layers['layer3'] = 'aux'
36 | self.aux = aux
37 | self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
38 |
39 | def forward(self, x):
40 | if self.aux == True:
41 | output = self.backbone(x)
42 | aux, out = output['aux'], output['out']
43 | return aux, out
44 | else:
45 | out = self.backbone(x)['out']
46 | return out
47 |
--------------------------------------------------------------------------------
/datasets/get_thresholds.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import numpy as np
4 | import pickle
5 | import matplotlib.pyplot as plt
6 | from scipy.special import softmax
7 | import sys
8 |
9 | # python3 get_thresholds.py 0.8 gta2city_soft_labels
10 |
11 | p = float(sys.argv[1])
12 | npy_dir = sys.argv[2]
13 | save_path = "./{}_cls2prob.pickle".format(sys.argv[2])
14 | output_path = "./{}_thresholds_p{}.npy".format(sys.argv[2], p)
15 | ignore_label = 250
16 |
17 | if not os.path.exists(save_path):
18 | cls2prob = {}
19 | files = glob.glob(os.path.join(npy_dir, "*.npy"))
20 | for i, npy_file in enumerate(files):
21 | if i % 100 == 0:
22 | print("i: {}/ {}".format(i, len(files)))
23 | f = np.load(npy_file) #[c, h, w]
24 | f = softmax(f, axis=0)
25 | classes = f.argmax(0) #[h, w]
26 | prob = f.max(0) #[h, w]
27 | for c in np.unique(classes):
28 | if c not in cls2prob:
29 | cls2prob[c] = []
30 | cls2prob[c].extend(prob[classes == c])
31 | for c in cls2prob:
32 | cls2prob[c].sort(reverse=True)
33 | # with open(save_path, "wb+") as f:
34 | # pickle.dump(cls2prob, f)
35 | else:
36 | with open(save_path, "rb") as f:
37 | cls2prob = pickle.load(f)
38 |
39 | class_list = ["road","sidewalk","building","wall",
40 | "fence","pole","traffic_light","traffic_sign","vegetation",
41 | "terrain","sky","person","rider","car",
42 | "truck","bus","train","motorcycle","bicycle"]
43 |
44 | # print("p: {}".format(p))
45 |
46 | thresholds = []
47 | for c in range(len(cls2prob.keys())):
48 | prob_c = cls2prob[c]
49 | rank = int(p * len(prob_c))
50 | thresh = prob_c[rank]
51 | thresholds.append(thresh)
52 | thresholds = np.array(thresholds)
53 |
54 | for i in range(len(thresholds)):
55 | print("i: {}, class i: {}, thresh_i: {}".format(i, class_list[i], thresholds[i]))
56 |
57 | np.save(output_path, thresholds)
58 |
--------------------------------------------------------------------------------
/model/classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import numpy as np
5 | from torch import nn
6 | from torchvision.models._utils import IntermediateLayerGetter
7 |
8 | class ASPP_Classifier(nn.Module):
9 | def __init__(self, in_channels, dilation_series, padding_series, num_classes):
10 | super(ASPP_Classifier, self).__init__()
11 | self.conv2d_list = nn.ModuleList()
12 | for dilation, padding in zip(dilation_series, padding_series):
13 | self.conv2d_list.append(
14 | nn.Conv2d(
15 | in_channels,
16 | num_classes,
17 | kernel_size=3,
18 | stride=1,
19 | padding=padding,
20 | dilation=dilation,
21 | bias=True,
22 | )
23 | )
24 |
25 | for m in self.conv2d_list:
26 | m.weight.data.normal_(0, 0.01)
27 |
28 | def forward(self, x, size=None):
29 | out = self.conv2d_list[0](x)
30 | for i in range(len(self.conv2d_list) - 1):
31 | out += self.conv2d_list[i + 1](x)
32 | if size is not None:
33 | out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
34 | return out
35 |
36 |
37 | class ASPP_Classifier_Gen(nn.Module):
38 | '''Generalized version of ASPP head'''
39 | def __init__(self, in_channels, dilation_series, padding_series, num_classes, hidden_dim=128):
40 | super(ASPP_Classifier_Gen, self).__init__()
41 | self.head = ASPP_Classifier(in_channels, dilation_series, padding_series, hidden_dim)
42 | self.classifier = nn.Conv2d(hidden_dim, num_classes, kernel_size=1, stride=1) # Generalize DeepLabv2 to backbone + classifier structure (make classifier independent)
43 |
44 | def forward(self, x, size=None):
45 | out = self.head(x)
46 | out = self.classifier(out)
47 | if size is not None:
48 | out = F.interpolate(out, size=size, mode='bilinear', align_corners=True)
49 | return out
50 |
--------------------------------------------------------------------------------
/utils/dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from copy import deepcopy
3 |
4 |
5 | def create_adversarial_dropout_mask(mask, jacobian, delta):
6 | """
7 |
8 | :param mask: shape [batch_size, ...]
9 | :param jacobian: shape [batch_size, ...]
10 | :param delta:
11 | :return:
12 | """
13 | num_of_units = int(torch.prod(torch.tensor(mask.size()[1:])).to(torch.float))
14 | change_limit = int(num_of_units * delta)
15 | mask = (mask > 0).to(torch.float)
16 |
17 | if change_limit == 0:
18 | return deepcopy(mask).detach(), torch.Tensor([]).type(torch.int64)
19 |
20 | # mask (mask=1 -> m = 1), (mask=0 -> m=-1)
21 | m = 2 * mask - torch.ones_like(mask)
22 |
23 | # sign of Jacobian (J>0 -> s=1), (J<0 -> s=-1)
24 | s = torch.sign(jacobian)
25 |
26 | # remain (J>0, m=-1) and (J<0, m=1), which are candidates to be changed
27 | change_candidates = ((m * s) < 0).to(torch.float)
28 |
29 | # print("change_candidates: ", change_candidates.sum())
30 |
31 | # ordering abs_jacobian for candidates
32 | # the maximum number of the changes is "change_limit"
33 | # draw top_k elements ( if the top k element is 0, the number of the changes is less than "change_limit" )
34 | abs_jacobian = torch.abs(jacobian)
35 | candidate_abs_jacobian = (change_candidates * abs_jacobian).view(-1, num_of_units)
36 | topk_values, topk_indices = torch.topk(candidate_abs_jacobian, change_limit + 1)
37 | min_values = topk_values[:, -1]
38 | change_target_marker = (candidate_abs_jacobian > min_values.unsqueeze(-1)).view(mask.size()).to(torch.float)
39 |
40 | # changed mask with change_target_marker
41 | adv_mask = torch.abs(mask - change_target_marker)
42 |
43 | # normalization
44 | adv_mask = adv_mask.view(-1, num_of_units)
45 | num_of_undropped_units = torch.sum(adv_mask, dim=1).unsqueeze(-1)
46 | adv_mask = ((adv_mask / num_of_undropped_units) * num_of_units).view(mask.size())
47 |
48 | # return adv_mask.clone().detach(), (adv_mask == 0).nonzero()[:, 1]
49 | return adv_mask.clone().detach(), None
50 |
51 |
52 | def calculate_jacobians(h, clean_logits, head, classifier, consistency_criterion):
53 | cnn_mask = torch.ones((*h.size()[:2], 1, 1)).to(h.device)
54 | # fc_mask = torch.ones(cnn_mask.size(0), fc_mask_size).to(cnn_mask.device)
55 | cnn_mask.requires_grad = True
56 | # fc_mask.requires_grad = True
57 |
58 | # h_logits = classifier(cnn_mask * h, fc_mask)
59 | h_logits = classifier(head(cnn_mask * h))
60 | discrepancy = consistency_criterion(h_logits, clean_logits)
61 |
62 | # print("discrepancy: ", discrepancy)
63 |
64 | discrepancy.backward()
65 |
66 | # reset_grad_fn()
67 | # return cnn_mask.grad.clone(), fc_mask.grad.clone(), h_logits
68 | head.zero_grad()
69 | classifier.zero_grad()
70 |
71 | # print("cnn_mask.grad.max(): {}, cnn_mask.grad.min(): {}, cnn_mask.grad.mean(): {}".format(cnn_mask.grad.max(), cnn_mask.grad.min(), cnn_mask.grad.mean()))
72 |
73 | return cnn_mask.grad.clone(), h_logits
74 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | class CrossEntropy2d(nn.Module):
8 |
9 | def __init__(self, size_average=True, ignore_label=255):
10 | super(CrossEntropy2d, self).__init__()
11 | self.size_average = size_average
12 | self.ignore_label = ignore_label
13 |
14 | def forward(self, predict, target, weight=None):
15 | """
16 | Args:
17 | predict:(n, c, h, w)
18 | target:(n, h, w)
19 | weight (Tensor, optional): a manual rescaling weight given to each class.
20 | If given, has to be a Tensor of size "nclasses"
21 | """
22 | assert not target.requires_grad
23 | assert predict.dim() == 4
24 | assert target.dim() == 3
25 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0))
26 | assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1))
27 |         assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2))
28 | n, c, h, w = predict.size()
29 | target_mask = (target >= 0) * (target != self.ignore_label)
30 | target = target[target_mask]
31 | if not target.data.dim():
32 | return Variable(torch.zeros(1))
33 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous()
34 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c)
35 | loss = F.cross_entropy(predict, target, weight=weight, size_average=self.size_average)
36 | return loss
37 |
38 |
39 | class EntropyLoss(nn.Module):
40 | def __init__(self, reduction='mean'):
41 | super().__init__()
42 | self.reduction = reduction
43 |
44 | def forward(self, logits):
45 | p = F.softmax(logits, dim=1)
46 | elementwise_entropy = -p * F.log_softmax(logits, dim=1)
47 | if self.reduction == 'none':
48 | return elementwise_entropy
49 |
50 | # print("elementwise_entropy.shape: ", elementwise_entropy.shape)
51 |
52 | sum_entropy = torch.sum(elementwise_entropy, dim=1)
53 | if self.reduction == 'sum':
54 | return sum_entropy
55 |
56 | # print("sum_entropy.shape: ", sum_entropy.shape)
57 |
58 | return torch.mean(sum_entropy)
59 |
60 |
61 | class AbstractConsistencyLoss(nn.Module):
62 | def __init__(self, reduction='mean'):
63 | super().__init__()
64 | self.reduction = reduction
65 |
66 | def forward(self, logits1, logits2):
67 | raise NotImplementedError
68 |
69 |
70 | class LossWithLogits(AbstractConsistencyLoss):
71 | def __init__(self, reduction='mean', loss_cls=nn.L1Loss):
72 | super().__init__(reduction)
73 | self.loss_with_softmax = loss_cls(reduction=reduction)
74 |
75 | def forward(self, logits1, logits2):
76 | loss = self.loss_with_softmax(F.softmax(logits1, dim=1), F.softmax(logits2, dim=1))
77 | return loss
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DecoupleNet
2 | Official implementation for our ECCV 2022 paper "DecoupleNet: Decoupled Network for Domain Adaptive Semantic Segmentation" [[arXiv](https://arxiv.org/pdf/2207.09988.pdf)] [[Paper](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136930362.pdf)]
3 |
4 |
5 | ![DecoupleNet](figs/fig.png)
6 |
7 |
8 | # Get Started
9 |
10 | ## Datasets Preparation
11 |
12 | ### GTA5
13 | First, download GTA5 from the [website](https://download.visinf.tu-darmstadt.de/data/from_games/). Then, extract the archives and organize the files as follows.
14 | ```
15 | images/
16 | |---00000.png
17 | |---00001.png
18 | |---...
19 | labels/
20 | |---00000.png
21 | |---00001.png
22 | |---...
23 | split.mat
24 | gtav_label_info.p
25 | ```
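A minimal sketch to sanity-check this layout before training (`GTA5_ROOT` below is a placeholder for your extracted path):
```
import os

GTA5_ROOT = "/path/to/gta5"  # placeholder: root containing images/, labels/, split.mat

images = sorted(os.listdir(os.path.join(GTA5_ROOT, "images")))
labels = sorted(os.listdir(os.path.join(GTA5_ROOT, "labels")))
assert os.path.isfile(os.path.join(GTA5_ROOT, "split.mat"))
assert os.path.isfile(os.path.join(GTA5_ROOT, "gtav_label_info.p"))
assert images == labels, "each image should have a label with the same filename"
print("Found {} GTA5 image/label pairs".format(len(images)))
```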
26 |
27 | ### Cityscapes
28 |
29 | Download the Cityscapes dataset from the [website](https://www.cityscapes-dataset.com/) and organize it as follows.
30 | ```
31 | leftImg8bit/
32 | |---train/
33 | |---val/
34 | |---test/
35 | gtFine
36 | |---train/
37 | |---val/
38 | |---test/
39 | ```
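The loaders pair each `leftImg8bit` image with its `gtFine` label by filename (see `data/cityscapes_val_dataset.py`). A minimal sketch of the mapping, using a placeholder root and an example file name:
```
import os

CITYSCAPES_ROOT = "/path/to/cityscapes"  # placeholder

img_path = os.path.join(CITYSCAPES_ROOT, "leftImg8bit", "val", "frankfurt",
                        "frankfurt_000000_000294_leftImg8bit.png")
# drop the trailing "leftImg8bit.png" (15 characters) and append the gtFine suffix
lbl_path = os.path.join(CITYSCAPES_ROOT, "gtFine", "val",
                        img_path.split(os.sep)[-2],
                        os.path.basename(img_path)[:-15] + "gtFine_labelIds.png")
print(lbl_path)
```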
40 |
41 | ## Training
42 |
43 | ### GTA5 -> Cityscapes
44 | First, download the pretrained ResNet-101 (PyTorch) weights and the source-only model from [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155154502_link_cuhk_edu_hk/EVowKrywcUVJhK0tbO_ebxQBv83FCISbGW_2fTeCWiFvGA), and put them into the directory `./pretrained`.
45 | ```
46 | mkdir pretrained && cd pretrained
47 | wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
48 | # Also put the sourceonly.pth into ./pretrained/
49 | ```
50 |
51 | First-phase training:
52 | ```
53 | python3 train_phase1.py --snapshot-dir ./snapshots/GTA2Cityscapes_phase1 --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --batch_size_val 4 --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT]
54 | ```
55 |
56 | Second-phase training (The trained phase1 model can also be downloaded from [here](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155154502_link_cuhk_edu_hk/EmhCkQ_lJ1FLr9Dj2QopYHkB4gyXPOC2BUzjmw4jGq6FSQ?e=m8XPfC)):
57 | ```
58 | # First, generate the soft pseudo labels from the trained phase1 model
59 | python3 generate_soft_label.py --snapshot-dir ./snapshots/GTA2Cityscapes_generate_soft_labels --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --batch_size_val 4 --resume [PATH_OF_PHASE1_MODEL] --output_folder ./datasets/gta2city_soft_labels --no_droplast --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT]
60 |
61 | # Then, get the thresholds from the generated soft labels:
62 | cd datasets/ && python3 get_thresholds.py 0.8 gta2city_soft_labels
63 |
64 | # Training with soft pseudo labels:
65 | python3 train_phase2.py --snapshot-dir ./snapshots/GTA2Cityscapes_phase2 --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --learning-rate 5e-4 --batch_size_val 4 --soft_labels_folder ./datasets/gta2city_soft_labels --resume [PATH_OF_PHASE1_MODEL] --thresholds_path ./datasets/gta2city_soft_labels_thresholds_p0.8.npy --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT]
66 | ```
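For reference, `datasets/get_thresholds.py` collects, for every class, the softmax confidences of all pixels predicted as that class and takes the value at the top-`p` fraction as that class's threshold. Below is a rough illustration (not the exact logic inside `train_phase2.py`) of how such per-class thresholds can turn one soft label (a `[C, H, W]` logit map saved as `.npy`) into a hard pseudo label, mapping low-confidence pixels to the ignore index 250; the soft-label file name is a placeholder:
```
import numpy as np
from scipy.special import softmax

thresholds = np.load("./datasets/gta2city_soft_labels_thresholds_p0.8.npy")  # [19] per-class thresholds
logits = np.load("./datasets/gta2city_soft_labels/example_frame.npy")        # [C, H, W], placeholder file

prob = softmax(logits, axis=0)
pseudo = prob.argmax(0)                  # predicted class per pixel
conf = prob.max(0)                       # confidence of that prediction
pseudo[conf < thresholds[pseudo]] = 250  # discard low-confidence pixels (ignore index)
```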
67 |
68 | # Acknowledgement
69 | This repository borrows code from the following repositories. Many thanks to the authors for their great work.
70 |
71 | ProDA: https://github.com/microsoft/ProDA
72 |
73 | FADA: https://github.com/JDAI-CV/FADA
74 |
75 | semseg: https://github.com/hszhao/semseg
76 |
77 | # Citation
78 | If you find this project useful, please consider citing:
79 |
80 | ```
81 | @inproceedings{lai2022decouplenet,
82 | title={Decouplenet: Decoupled network for domain adaptive semantic segmentation},
83 | author={Lai, Xin and Tian, Zhuotao and Xu, Xiaogang and Chen, Yingcong and Liu, Shu and Zhao, Hengshuang and Wang, Liwei and Jia, Jiaya},
84 | booktitle={European Conference on Computer Vision},
85 | pages={369--387},
86 | year={2022},
87 | organization={Springer}
88 | }
89 | ```
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import importlib
5 | import numpy as np
6 | import torch.utils.data
7 | from data.base_dataset import BaseDataset
8 | from data.augmentations import *
9 |
10 | def find_dataset_using_name(name):
11 | """Import the module "data/[dataset_name]_dataset.py".
12 |
13 |     In that file, the class named "[dataset_name]_loader" will
14 |     be returned. It has to be a subclass of BaseDataset,
15 |     and the name match is case-insensitive.
16 | """
17 | dataset_filename = "data." + name + "_dataset"
18 | datasetlib = importlib.import_module(dataset_filename)
19 |
20 | dataset = None
21 | target_dataset_name = name + '_loader'
22 | for _name, cls in datasetlib.__dict__.items():
23 | if _name.lower() == target_dataset_name.lower() \
24 | and issubclass(cls, BaseDataset):
25 | dataset = cls
26 |
27 | if dataset is None:
28 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
29 |
30 | return dataset
31 |
32 | def get_option_setter(dataset_name):
33 | """Return the static method of the dataset class."""
34 | dataset_class = find_dataset_using_name(dataset_name)
35 | return dataset_class.modify_commandline_options
36 |
37 | def create_dataset(opt, logger):
38 | """Create a dataset given the option.
39 |
40 | This function wraps the class CustomDatasetDataLoader.
41 | This is the main interface between this package and 'train.py'/'test.py'
42 |
43 | Example:
44 | >>> from data import create_dataset
45 |         >>> dataset = create_dataset(opt, logger)
46 | """
47 | data_loader = CustomDatasetDataLoader(opt, logger)
48 | dataset = data_loader.load_data()
49 | return dataset
50 |
51 | def get_composed_augmentations(opt):
52 | return Compose([RandomSized(opt.resize),
53 | RandomCrop(opt.rcrop),
54 | RandomHorizontallyFlip(opt.hflip)])
55 |
56 | class CustomDatasetDataLoader():
57 | def __init__(self, opt, logger):
58 | self.opt = opt
59 | self.logger = logger
60 |
61 | # status == 'train':
62 | source_train = find_dataset_using_name(opt.src_dataset)
63 | data_aug = None if opt.noaug else get_composed_augmentations(opt)
64 | self.source_train = source_train(opt, logger, augmentations=data_aug)
65 | if logger is not None:
66 | logger.info("{} source dataset has been created".format(self.source_train.__class__.__name__))
67 | print("dataset {} for source was created".format(self.source_train.__class__.__name__))
68 |         self.source_train[0]  # fetch one sample as a sanity check
69 |
70 | data_aug = None if opt.noaug else get_composed_augmentations(opt)
71 | target_train = find_dataset_using_name(opt.tgt_dataset)
72 | self.target_train = target_train(opt, logger, augmentations=data_aug, split='train')
73 | if logger is not None:
74 | logger.info("{} target dataset has been created".format(self.target_train.__class__.__name__))
75 | print("dataset {} for target was created".format(self.target_train.__class__.__name__))
76 |         self.target_train[0]  # fetch one sample as a sanity check
77 |
78 | ## create train loader
79 | self.source_train_sampler = torch.utils.data.distributed.DistributedSampler(self.source_train, shuffle=not opt.noshuffle)
80 | self.source_train_loader = torch.utils.data.DataLoader(
81 | self.source_train,
82 | batch_size=opt.batch_size,
83 | shuffle=False,
84 | sampler=self.source_train_sampler,
85 | num_workers=int(opt.num_workers),
86 | drop_last=True,
87 | pin_memory=True,
88 | )
89 | self.target_train_sampler = torch.utils.data.distributed.DistributedSampler(self.target_train, shuffle=not opt.noshuffle)
90 | self.target_train_loader = torch.utils.data.DataLoader(
91 | self.target_train,
92 | batch_size=opt.batch_size,
93 | shuffle=False,
94 | sampler=self.target_train_sampler,
95 | num_workers=int(opt.num_workers),
96 | drop_last=not opt.no_droplast,
97 | pin_memory=True,
98 | )
99 |
100 | # status == valid
101 | self.source_valid = None
102 | self.source_valid_loader = None
103 |
104 | self.target_valid = None
105 | self.target_valid_loader = None
106 |
107 | target_valid = find_dataset_using_name(opt.tgt_val_dataset)
108 | self.target_valid = target_valid(opt, logger, augmentations=None, split='val')
109 | if logger is not None:
110 | logger.info("{} target_valid dataset has been created".format(self.target_valid.__class__.__name__))
111 | print("dataset {} for target_valid was created".format(self.target_valid.__class__.__name__))
112 |
113 | self.target_valid_sampler = torch.utils.data.distributed.DistributedSampler(self.target_valid, shuffle=False)
114 | self.target_valid_loader = torch.utils.data.DataLoader(
115 | self.target_valid,
116 | batch_size=opt.batch_size_val,
117 | shuffle=False,
118 | sampler=self.target_valid_sampler,
119 | num_workers=int(opt.num_workers),
120 | drop_last=False,
121 | pin_memory=True,
122 | )
123 |
124 | def load_data(self):
125 | return self
126 |
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.init as initer
8 |
9 |
10 | class AverageMeter(object):
11 | """Computes and stores the average and current value"""
12 | def __init__(self):
13 | self.reset()
14 |
15 | def reset(self):
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
29 | """Sets the learning rate to the base LR decayed by 10 every step epochs"""
30 | lr = base_lr * (multiplier ** (epoch // step_epoch))
31 | return lr
32 |
33 |
34 | def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
35 | """poly learning rate policy"""
36 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
37 | return lr
38 |
39 |
40 | def intersectionAndUnion(output, target, K, ignore_index=255):
41 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
42 | assert (output.ndim in [1, 2, 3])
43 | assert output.shape == target.shape
44 | output = output.reshape(output.size).copy()
45 | target = target.reshape(target.size)
46 | output[np.where(target == ignore_index)[0]] = ignore_index
47 | intersection = output[np.where(output == target)[0]]
48 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1))
49 | area_output, _ = np.histogram(output, bins=np.arange(K+1))
50 | area_target, _ = np.histogram(target, bins=np.arange(K+1))
51 | area_union = area_output + area_target - area_intersection
52 | return area_intersection, area_union, area_target
53 |
54 |
55 | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
56 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
57 | assert (output.dim() in [1, 2, 3])
58 | assert output.shape == target.shape
59 | output = output.view(-1)
60 | target = target.view(-1)
61 | output[target == ignore_index] = ignore_index
62 | intersection = output[output == target]
63 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1)
64 | area_output = torch.histc(output, bins=K, min=0, max=K-1)
65 | area_target = torch.histc(target, bins=K, min=0, max=K-1)
66 | area_union = area_output + area_target - area_intersection
67 | return area_intersection, area_union, area_target
68 |
69 |
70 | def check_mkdir(dir_name):
71 | if not os.path.exists(dir_name):
72 | os.mkdir(dir_name)
73 |
74 |
75 | def check_makedirs(dir_name):
76 | if not os.path.exists(dir_name):
77 | os.makedirs(dir_name)
78 |
79 |
80 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'):
81 | """
82 | :param model: Pytorch Model which is nn.Module
83 | :param conv: 'kaiming' or 'xavier'
84 | :param batchnorm: 'normal' or 'constant'
85 | :param linear: 'kaiming' or 'xavier'
86 | :param lstm: 'kaiming' or 'xavier'
87 | """
88 | for m in model.modules():
89 | if isinstance(m, (nn.modules.conv._ConvNd)):
90 | if conv == 'kaiming':
91 | initer.kaiming_normal_(m.weight)
92 | elif conv == 'xavier':
93 | initer.xavier_normal_(m.weight)
94 | else:
95 | raise ValueError("init type of conv error.\n")
96 | if m.bias is not None:
97 | initer.constant_(m.bias, 0)
98 |
99 | elif isinstance(m, (nn.modules.batchnorm._BatchNorm)):
100 | if batchnorm == 'normal':
101 | initer.normal_(m.weight, 1.0, 0.02)
102 | elif batchnorm == 'constant':
103 | initer.constant_(m.weight, 1.0)
104 | else:
105 | raise ValueError("init type of batchnorm error.\n")
106 | initer.constant_(m.bias, 0.0)
107 |
108 | elif isinstance(m, nn.Linear):
109 | if linear == 'kaiming':
110 | initer.kaiming_normal_(m.weight)
111 | elif linear == 'xavier':
112 | initer.xavier_normal_(m.weight)
113 | else:
114 | raise ValueError("init type of linear error.\n")
115 | if m.bias is not None:
116 | initer.constant_(m.bias, 0)
117 |
118 | elif isinstance(m, nn.LSTM):
119 | for name, param in m.named_parameters():
120 | if 'weight' in name:
121 | if lstm == 'kaiming':
122 | initer.kaiming_normal_(param)
123 | elif lstm == 'xavier':
124 | initer.xavier_normal_(param)
125 | else:
126 | raise ValueError("init type of lstm error.\n")
127 | elif 'bias' in name:
128 | initer.constant_(param, 0)
129 |
130 |
131 | def group_weight(weight_group, module, lr):
132 | group_decay = []
133 | group_no_decay = []
134 | for m in module.modules():
135 | if isinstance(m, nn.Linear):
136 | group_decay.append(m.weight)
137 | if m.bias is not None:
138 | group_no_decay.append(m.bias)
139 | elif isinstance(m, nn.modules.conv._ConvNd):
140 | group_decay.append(m.weight)
141 | if m.bias is not None:
142 | group_no_decay.append(m.bias)
143 | elif isinstance(m, nn.modules.batchnorm._BatchNorm):
144 | if m.weight is not None:
145 | group_no_decay.append(m.weight)
146 | if m.bias is not None:
147 | group_no_decay.append(m.bias)
148 | assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
149 | weight_group.append(dict(params=group_decay, lr=lr))
150 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
151 | return weight_group
152 |
153 |
154 | def colorize(gray, palette):
155 | # gray: numpy array of the label and 1*3N size list palette
156 | color = Image.fromarray(gray.astype(np.uint8)).convert('P')
157 | color.putpalette(palette)
158 | return color
159 |
160 |
161 | def find_free_port():
162 | import socket
163 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
164 | # Binding to port 0 will cause the OS to find an available port for us
165 | sock.bind(("", 0))
166 | port = sock.getsockname()[1]
167 | sock.close()
168 | # NOTE: there is still a chance the port could be taken by other processes.
169 | return port
170 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
5 |
6 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
7 | """
8 | import torch.utils.data as data
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | from abc import ABC, abstractmethod
12 |
13 |
14 | class BaseDataset(data.Dataset, ABC):
15 | """This class is an abstract base class (ABC) for datasets.
16 |
17 | To create a subclass, you need to implement the following four functions:
18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
19 | -- <__len__>: return the size of dataset.
20 | -- <__getitem__>: get a data point.
21 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
22 | """
23 |
24 | def __init__(self, opt):
25 | """Initialize the class; save the options in the class
26 |
27 | Parameters:
28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
29 | """
30 | self.opt = opt
31 |
32 | @staticmethod
33 | def modify_commandline_options(parser, is_train):
34 | """Add new dataset-specific options, and rewrite default values for existing options.
35 |
36 | Parameters:
37 | parser -- original option parser
38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 |
40 | Returns:
41 | the modified parser.
42 | """
43 | return parser
44 |
45 | @abstractmethod
46 | def __len__(self):
47 | """Return the total number of images in the dataset."""
48 | return 0
49 |
50 | @abstractmethod
51 | def __getitem__(self, index):
52 | """Return a data point and its metadata information.
53 |
54 | Parameters:
55 | index - - a random integer for data indexing
56 |
57 | Returns:
58 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 | """
60 | pass
61 |
62 |
63 | def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True):
64 | """Create a torchvision transformation function
65 |
66 | The type of transformation is defined by option (e.g., [opt.preprocess], [opt.load_size], [opt.crop_size])
67 | and can be overwritten by arguments such as [convert], [crop], and [flip]
68 |
69 | Parameters:
70 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
71 |         grayscale (bool) -- whether to convert the input RGB image to grayscale
72 |         convert (bool) -- whether to convert the image to a tensor with values in [-1, 1]
73 |         crop (bool) -- whether to apply cropping
74 |         flip (bool) -- whether to apply horizontal flipping
75 | """
76 | transform_list = []
77 | if grayscale:
78 | transform_list.append(transforms.Grayscale(1))
79 | if opt.preprocess == 'resize_and_crop':
80 | osize = [opt.load_size, opt.load_size]
81 | transform_list.append(transforms.Resize(osize, Image.BICUBIC))
82 | transform_list.append(transforms.RandomCrop(opt.crop_size))
83 | elif opt.preprocess == 'crop' and crop:
84 | transform_list.append(transforms.RandomCrop(opt.crop_size))
85 | elif opt.preprocess == 'scale_width':
86 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size)))
87 | elif opt.preprocess == 'scale_width_and_crop':
88 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size)))
89 | if crop:
90 | transform_list.append(transforms.RandomCrop(opt.crop_size))
91 | elif opt.preprocess == 'none':
92 | transform_list.append(transforms.Lambda(lambda img: __adjust(img)))
93 | else:
94 | raise ValueError('--preprocess %s is not a valid option.' % opt.preprocess)
95 |
96 | if not opt.no_flip and flip:
97 | transform_list.append(transforms.RandomHorizontalFlip())
98 | if convert:
99 | transform_list += [transforms.ToTensor(),
100 | transforms.Normalize((0.5, 0.5, 0.5),
101 | (0.5, 0.5, 0.5))]
102 | return transforms.Compose(transform_list)
103 |
104 |
105 | def __adjust(img):
106 | """Modify the width and height to be multiple of 4.
107 |
108 | Parameters:
109 | img (PIL image) -- input image
110 |
111 |     Returns a modified image whose width and height are multiples of 4.
112 |
113 | the size needs to be a multiple of 4,
114 | because going through generator network may change img size
115 | and eventually cause size mismatch error
116 | """
117 | ow, oh = img.size
118 | mult = 4
119 | if ow % mult == 0 and oh % mult == 0:
120 | return img
121 | w = (ow - 1) // mult
122 | w = (w + 1) * mult
123 | h = (oh - 1) // mult
124 | h = (h + 1) * mult
125 |
126 | if ow != w or oh != h:
127 | __print_size_warning(ow, oh, w, h)
128 |
129 | return img.resize((w, h), Image.BICUBIC)
130 |
131 |
132 | def __scale_width(img, target_width):
133 | """Resize images so that the width of the output image is the same as a target width
134 |
135 | Parameters:
136 | img (PIL image) -- input image
137 | target_width (int) -- target image width
138 |
139 | Returns a modified image whose width matches the target image width;
140 |
141 | the size needs to be a multiple of 4,
142 | because going through generator network may change img size
143 | and eventually cause size mismatch error
144 | """
145 | ow, oh = img.size
146 |
147 | mult = 4
148 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
149 | if (ow == target_width and oh % mult == 0):
150 | return img
151 | w = target_width
152 | target_height = int(target_width * oh / ow)
153 | m = (target_height - 1) // mult
154 | h = (m + 1) * mult
155 |
156 | if target_height != h:
157 | __print_size_warning(target_width, target_height, w, h)
158 |
159 | return img.resize((w, h), Image.BICUBIC)
160 |
161 |
162 | def __print_size_warning(ow, oh, w, h):
163 | """Print warning information about image size(only print once)"""
164 | if not hasattr(__print_size_warning, 'has_printed'):
165 | print("The image size needs to be a multiple of 4. "
166 | "The loaded image size was (%d, %d), so it was adjusted to "
167 | "(%d, %d). This adjustment will be done to all images "
168 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
169 | __print_size_warning.has_printed = True
170 |
--------------------------------------------------------------------------------
/data/gta5_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import sys
6 | import torch
7 | import numpy as np
8 | import scipy.misc as m
9 | import matplotlib.pyplot as plt
10 | import matplotlib.image as imgs
11 | from PIL import Image
12 | import random
13 | import scipy.io as io
14 | from tqdm import tqdm
15 | from scipy import stats
16 |
17 | from torch.utils import data
18 |
19 | from data import BaseDataset
20 | from data.randaugment import RandAugmentMC
21 |
22 | import pickle
23 | from torchvision import transforms
24 |
25 |
26 | class GTA5_loader(BaseDataset):
27 | """
28 | GTA5 synthetic dataset
29 | for domain adaptation to Cityscapes
30 | """
31 |
32 | colors = [ # [ 0, 0, 0],
33 | [128, 64, 128],
34 | [244, 35, 232],
35 | [70, 70, 70],
36 | [102, 102, 156],
37 | [190, 153, 153],
38 | [153, 153, 153],
39 | [250, 170, 30],
40 | [220, 220, 0],
41 | [107, 142, 35],
42 | [152, 251, 152],
43 | [0, 130, 180],
44 | [220, 20, 60],
45 | [255, 0, 0],
46 | [0, 0, 142],
47 | [0, 0, 70],
48 | [0, 60, 100],
49 | [0, 80, 100],
50 | [0, 0, 230],
51 | [119, 11, 32],
52 | ]
53 |
54 | label_colours = dict(zip(range(19), colors))
55 | def __init__(self, opt, logger, augmentations=None):
56 | self.opt = opt
57 | self.root = opt.src_rootpath
58 | self.split = 'all'
59 | self.augmentations = augmentations
60 | self.randaug = RandAugmentMC(2, 10)
61 | self.n_classes = 19
62 | self.img_size = (1914, 1052)
63 |
64 | self.mean = [0.0, 0.0, 0.0] #TODO: calculating the mean value of rgb channels on GTA5
65 | self.image_base_path = os.path.join(self.root, 'images')
66 | self.label_base_path = os.path.join(self.root, 'labels')
67 | splits = io.loadmat(os.path.join(self.root, 'split.mat'))
68 | if self.split == 'all':
69 | ids = np.concatenate((splits['trainIds'][:,0], splits['valIds'][:,0], splits['testIds'][:,0]))
70 | elif self.split == 'train':
71 | ids = splits['trainIds'][:,0]
72 | elif self.split == 'val':
73 | ids = splits['valIds'][:200,0]
74 | elif self.split == 'test':
75 | ids = splits['testIds'][:,0]
76 |
77 | max_iters = opt.num_steps * opt.batch_size * opt.world_size
78 | if max_iters is not None:
79 | if not os.path.exists("data/class_balance_ids_{}.p".format(max_iters)):
80 | self.label_to_file, self.file_to_label = pickle.load(open(os.path.join(self.root, "gtav_label_info.p"), "rb"))
81 | self.ids = []
82 | SUB_EPOCH_SIZE = 3000
83 | tmp_list = []
84 | ind = dict()
85 | for i in range(self.n_classes):
86 | ind[i] = 0
87 | for e in range(int(max_iters/SUB_EPOCH_SIZE)+1):
88 | cur_class_dist = np.zeros(self.n_classes)
89 | for i in range(SUB_EPOCH_SIZE):
90 | if cur_class_dist.sum() == 0:
91 | dist1 = cur_class_dist.copy()
92 | else:
93 | dist1 = cur_class_dist/cur_class_dist.sum()
94 | w = 1/np.log(1+1e-2 + dist1)
95 | w = w/w.sum()
96 | c = np.random.choice(self.n_classes, p=w)
97 |
98 | if ind[c] > (len(self.label_to_file[c])-1):
99 | np.random.shuffle(self.label_to_file[c])
100 | ind[c] = ind[c]%(len(self.label_to_file[c])-1)
101 |
102 | c_file = self.label_to_file[c][ind[c]]
103 | tmp_list.append(c_file)
104 | ind[c] = ind[c]+1
105 | cur_class_dist[self.file_to_label[c_file]] += 1
106 |
107 | self.ids = [os.path.join(self.label_base_path, x) for x in tmp_list]
108 | with open("data/class_balance_ids_{}.p".format(max_iters), 'wb') as f:
109 | pickle.dump(self.ids, f)
110 | else:
111 | with open("data/class_balance_ids_{}.p".format(max_iters), 'rb') as f:
112 | self.ids = pickle.load(f)
113 |
114 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34, -1]
115 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,]
116 | self.class_names = ["unlabelled","road","sidewalk","building","wall","fence","pole","traffic_light",
117 | "traffic_sign","vegetation","terrain","sky","person","rider","car","truck","bus","train",
118 | "motorcycle","bicycle",]
119 |
120 | self.ignore_index = 250
121 | self.class_map = dict(zip(self.valid_classes, range(19)))
122 |
123 | clrjit_params = getattr(opt, "clrjit_params", [0.5, 0.5, 0.5, 0.2])
124 | self.train_transform = transforms.Compose([
125 | transforms.ToPILImage(),
126 | transforms.ColorJitter(*clrjit_params),
127 | ])
128 |
129 | if len(self.ids) == 0:
130 | raise Exception(
131 |                 "No files for split=[%s] found in %s" % (self.split, self.image_base_path)
132 | )
133 |
134 | print("Found {} {} images".format(len(self.ids), self.split))
135 |
136 | def __len__(self):
137 | return len(self.ids)
138 |
139 | def __getitem__(self, index):
140 | """__getitem__
141 |
142 | param: index
143 | """
144 | id = self.ids[index]
145 | if self.split != 'all' and self.split != 'val':
146 | filename = '{:05d}.png'.format(id)
147 | img_path = os.path.join(self.image_base_path, filename)
148 | lbl_path = os.path.join(self.label_base_path, filename)
149 | else:
150 | img_path = os.path.join(self.image_base_path, id.split('/')[-1])
151 | lbl_path = id
152 |
153 | img = Image.open(img_path)
154 | lbl = Image.open(lbl_path)
155 |
156 | img = img.resize(self.img_size, Image.BILINEAR)
157 | lbl = lbl.resize(self.img_size, Image.NEAREST)
158 | img = np.asarray(img, dtype=np.uint8)
159 | lbl = np.asarray(lbl, dtype=np.uint8)
160 |
161 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))
162 |
163 | input_dict = {}
164 | if self.augmentations!=None:
165 | img, lbl, _, _, _ = self.augmentations(img, lbl)
166 | img_strong, params = self.randaug(Image.fromarray(img))
167 | img_strong, _ = self.transform(img_strong, lbl)
168 | input_dict['img_strong'] = img_strong
169 | input_dict['params'] = params
170 |
171 | img = self.train_transform(img)
172 |
173 | img, lbl = self.transform(img, lbl)
174 |
175 | input_dict['img'] = img
176 | input_dict['label'] = lbl
177 | input_dict['img_path'] = self.ids[index]
178 | return input_dict
179 |
180 |
181 | def encode_segmap(self, lbl):
182 | for _i in self.void_classes:
183 | lbl[lbl == _i] = self.ignore_index
184 | for _i in self.valid_classes:
185 | lbl[lbl == _i] = self.class_map[_i]
186 | return lbl
187 |
188 | def decode_segmap(self, temp):
189 | r = temp.copy()
190 | g = temp.copy()
191 | b = temp.copy()
192 | for l in range(0, self.n_classes):
193 | r[temp == l] = self.label_colours[l][0]
194 | g[temp == l] = self.label_colours[l][1]
195 | b[temp == l] = self.label_colours[l][2]
196 |
197 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
198 | rgb[:, :, 0] = r / 255.0
199 | rgb[:, :, 1] = g / 255.0
200 | rgb[:, :, 2] = b / 255.0
201 | return rgb
202 |
203 | def transform(self, img, lbl):
204 | """transform
205 |
206 | img, lbl
207 | """
208 | img = np.array(img)
209 | # img = img[:, :, ::-1] # RGB -> BGR
210 | img = img.astype(np.float64)
211 | img -= self.mean
212 | img = img.astype(float) / 255.0
213 | img = img.transpose(2, 0, 1)
214 |
215 | classes = np.unique(lbl)
216 | lbl = np.array(lbl)
217 | lbl = lbl.astype(float)
218 | # lbl = m.imresize(lbl, self.img_size, "nearest", mode='F')
219 | lbl = lbl.astype(int)
220 |
221 | if not np.all(classes == np.unique(lbl)):
222 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed ones
223 |
224 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
225 | print("after det", classes, np.unique(lbl))
226 | raise ValueError("Segmentation map contained invalid class values")
227 |
228 | img = torch.from_numpy(img).float()
229 | lbl = torch.from_numpy(lbl).long()
230 |
231 | return img, lbl
232 |
233 | def get_cls_num_list(self):
234 | cls_num_list = np.array([16139327127, 4158369631, 8495419275, 927064742, 318109335,
235 | 532432540, 67453231, 40526481, 3818867486, 1081467674,
236 | 6800402117, 182228033, 15360044, 1265024472, 567736474,
237 | 184854135, 32542442, 15832619, 2721193])
238 | # cls_num_list = np.zeros(self.n_classes, dtype=np.int64)
239 | # for n in range(len(self.ids)):
240 | # lbl = Image.open(self.ids[n])
241 | # lbl = lbl.resize(self.img_size, Image.NEAREST)
242 | # lbl = np.asarray(lbl, dtype=np.uint8)
243 | # lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))
244 | # for i in range(self.n_classes):
245 | # cls_num_list[i] += (lbl == i).sum()
246 | return cls_num_list
247 |
--------------------------------------------------------------------------------
/data/cityscapes_val_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import torch
6 | import numpy as np
7 | import scipy.misc as m
8 | from tqdm import tqdm
9 |
10 | from torch.utils import data
11 | from PIL import Image
12 |
13 | from data.augmentations import *
14 | from data.base_dataset import BaseDataset
15 | from data.randaugment import RandAugmentMC
16 |
17 | import random
18 |
19 | def recursive_glob(rootdir=".", suffix=""):
20 | """Performs recursive glob with given suffix and rootdir
21 | :param rootdir is the root directory
22 | :param suffix is the suffix to be searched
23 | """
24 | return [
25 | os.path.join(looproot, filename)
26 |         for looproot, _, filenames in os.walk(rootdir) #os.walk: traverses all files in rootdir and its subfolders
27 | for filename in filenames
28 | if filename.endswith(suffix)
29 | ]
30 |
31 | class Cityscapes_val_loader(BaseDataset):
32 | """cityscapesLoader
33 |
34 | https://www.cityscapes-dataset.com
35 |
36 | Data is derived from CityScapes, and can be downloaded from here:
37 | https://www.cityscapes-dataset.com/downloads/
38 |
39 | Many Thanks to @fvisin for the loader repo:
40 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
41 | """
42 |
43 | colors = [ # [ 0, 0, 0],
44 | [128, 64, 128],
45 | [244, 35, 232],
46 | [70, 70, 70],
47 | [102, 102, 156],
48 | [190, 153, 153],
49 | [153, 153, 153],
50 | [250, 170, 30],
51 | [220, 220, 0],
52 | [107, 142, 35],
53 | [152, 251, 152],
54 | [0, 130, 180],
55 | [220, 20, 60],
56 | [255, 0, 0],
57 | [0, 0, 142],
58 | [0, 0, 70],
59 | [0, 60, 100],
60 | [0, 80, 100],
61 | [0, 0, 230],
62 | [119, 11, 32],
63 | ]
64 |
65 | label_colours = dict(zip(range(19), colors))
66 |
67 | mean_rgb = {
68 | "pascal": [103.939, 116.779, 123.68],
69 | "cityscapes": [0.0, 0.0, 0.0],
70 | } # pascal mean for PSPNet and ICNet pre-trained model
71 |
72 | def __init__(self, opt, logger, augmentations = None, split='train'):
73 | """__init__
74 |
75 | :param opt: parameters of dataset
76 |         :param logger: logging file
77 |         :param augmentations: data augmentations applied to images and labels
78 |         :param split: dataset split ('train', 'val' or 'test')
79 | """
80 |
81 | self.opt = opt
82 | self.root = opt.tgt_rootpath
83 | self.split = split
84 | self.augmentations = augmentations
85 | self.randaug = RandAugmentMC(2, 10)
86 | self.n_classes = opt.num_classes
87 | self.img_size = (2048, 1024)
88 | self.mean = np.array(self.mean_rgb['cityscapes'])
89 | self.files = {}
90 | self.paired_files = {}
91 |
92 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
93 | self.annotations_base = os.path.join(
94 | self.root, "gtFine", self.split
95 | )
96 |
97 | self.files = sorted(recursive_glob(rootdir=self.images_base, suffix=".png")) #find all files from rootdir and subfolders with suffix = ".png"
98 |
99 | #self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
100 | if self.n_classes == 19:
101 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,]
102 | self.class_names = ["unlabelled","road","sidewalk","building","wall",
103 | "fence","pole","traffic_light","traffic_sign","vegetation",
104 | "terrain","sky","person","rider","car",
105 | "truck","bus","train","motorcycle","bicycle",
106 | ]
107 | self.to19 = dict(zip(range(19), range(19)))
108 | elif self.n_classes == 16:
109 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,]
110 | self.class_names = ["unlabelled","road","sidewalk","building","wall",
111 | "fence","pole","traffic_light","traffic_sign","vegetation",
112 | "sky","person","rider","car","bus",
113 | "motorcycle","bicycle",
114 | ]
115 | self.to19 = dict(zip(range(16), [0,1,2,3,4,5,6,7,8,10,11,12,13,15,17,18]))
116 | elif self.n_classes == 13:
117 | self.valid_classes = [7, 8, 11, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,]
118 | self.class_names = ["unlabelled","road","sidewalk","building","traffic_light",
119 | "traffic_sign","vegetation","sky","person","rider",
120 | "car","bus","motorcycle","bicycle",
121 | ]
122 | self.to19 = dict(zip(range(13), [0,1,2,6,7,8,10,11,12,13,15,17,18]))
123 |
124 | self.ignore_index = 250
125 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes))) #zip: return tuples
126 |
127 | if not self.files:
128 | raise Exception(
129 | "No files for split=[%s] found in %s" % (self.split, self.images_base)
130 | )
131 |
132 | print("Found %d %s images" % (len(self.files), self.split))
133 |
134 | def __len__(self):
135 | """__len__"""
136 | return len(self.files)
137 |
138 | def __getitem__(self, index):
139 | """__getitem__
140 |
141 | :param index:
142 | """
143 | img_path = self.files[index].rstrip()
144 | lbl_path = os.path.join(
145 | self.annotations_base,
146 | img_path.split(os.sep)[-2],
147 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png",
148 | )
149 |
150 | img = Image.open(img_path)
151 | lbl = Image.open(lbl_path)
152 | img = img.resize(self.img_size, Image.BILINEAR)
153 | lbl = lbl.resize(self.img_size, Image.NEAREST)
154 |
155 | img = np.array(img, dtype=np.uint8)
156 | lbl = np.array(lbl, dtype=np.uint8)
157 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8))
158 |
159 | img_full = img.copy().astype(np.float64)
160 | img_full -= self.mean
161 | img_full = img_full.astype(float) / 255.0
162 | img_full = img_full.transpose(2, 0, 1)
163 | lbl_full = lbl.copy()
164 |
165 | lp, lpsoft, weak_params = None, None, None
166 | input_dict = {}
167 | if self.augmentations!=None:
168 | img, lbl, lp, lpsoft, weak_params = self.augmentations(img, lbl, lp, lpsoft)
169 | img_strong, params = self.randaug(Image.fromarray(img))
170 | img_strong, _, _ = self.transform(img_strong, lbl)
171 | input_dict['img_strong'] = img_strong
172 | input_dict['params'] = params
173 |
174 | img, lbl_, lp = self.transform(img, lbl, lp)
175 |
176 | input_dict['img'] = img
177 | input_dict['img_full'] = torch.from_numpy(img_full).float()
178 | input_dict['label'] = lbl_
179 | input_dict['lp'] = lp
180 | input_dict['lpsoft'] = lpsoft
181 | input_dict['weak_params'] = weak_params #full2weak
182 | input_dict['img_path'] = self.files[index]
183 | input_dict['lbl_full'] = torch.from_numpy(lbl_full).long()
184 |
185 | input_dict = {k:v for k, v in input_dict.items() if v is not None}
186 | return input_dict
187 |
188 | def transform(self, img, lbl, lp=None, check=True):
189 | """transform
190 |
191 | :param img:
192 | :param lbl:
193 | """
194 | # img = m.imresize(
195 | # img, (self.img_size[0], self.img_size[1])
196 | # ) # uint8 with RGB mode
197 | img = np.array(img)
198 | # img = img[:, :, ::-1] # RGB -> BGR
199 | img = img.astype(np.float64)
200 | img -= self.mean
201 | img = img.astype(float) / 255.0
202 | # NHWC -> NCHW
203 | img = img.transpose(2, 0, 1)
204 |
205 | classes = np.unique(lbl)
206 | lbl = np.array(lbl)
207 | lbl = lbl.astype(float)
208 | # lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
209 | lbl = lbl.astype(int)
210 |
211 | if not np.all(classes == np.unique(lbl)):
212 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed ones
213 |
214 | if check and not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): #todo: understanding the meaning
215 | print("after det", classes, np.unique(lbl))
216 | raise ValueError("Segmentation map contained invalid class values")
217 |
218 | img = torch.from_numpy(img).float()
219 | lbl = torch.from_numpy(lbl).long()
220 |
221 | if lp is not None:
222 | classes = np.unique(lp)
223 | lp = np.array(lp)
224 | # if not np.all(np.unique(lp[lp != self.ignore_index]) < self.n_classes):
225 | # raise ValueError("lp Segmentation map contained invalid class values")
226 |
227 | lp = torch.from_numpy(lp).long()
228 |
229 | return img, lbl, lp
230 |
231 | def decode_segmap(self, temp):
232 | r = temp.copy()
233 | g = temp.copy()
234 | b = temp.copy()
235 | for l in range(0, self.n_classes):
236 | r[temp == l] = self.label_colours[self.to19[l]][0]
237 | g[temp == l] = self.label_colours[self.to19[l]][1]
238 | b[temp == l] = self.label_colours[self.to19[l]][2]
239 |
240 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
241 | rgb[:, :, 0] = r / 255.0
242 | rgb[:, :, 1] = g / 255.0
243 | rgb[:, :, 2] = b / 255.0
244 | return rgb
245 |
246 | def encode_segmap(self, mask):
247 | # Put all void classes to zero
248 | label_copy = 250 * np.ones(mask.shape, dtype=np.uint8)
249 | for k, v in list(self.class_map.items()):
250 | label_copy[mask == k] = v
251 | return label_copy
252 |
253 | def get_cls_num_list(self):
254 | cls_num_list = np.array([1557726944, 254364912, 673500400, 18431664, 14431392,
255 | 29361440, 7038112, 7352368, 477239920, 40134240,
256 | 211669120, 36057968, 865184, 264786464, 17128544,
257 | 2385680, 943312, 504112, 2174560])
258 | return cls_num_list
259 |
--------------------------------------------------------------------------------
/data/cityscapes_train_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import torch
6 | import numpy as np
7 | import scipy.misc as m
8 | from tqdm import tqdm
9 |
10 | from torch.utils import data
11 | from PIL import Image
12 |
13 | from data.augmentations import *
14 | from data.base_dataset import BaseDataset
15 | from data.randaugment import RandAugmentMC
16 |
17 | import random
18 | from torchvision import transforms
19 |
20 | def recursive_glob(rootdir=".", suffix=""):
21 | """Performs recursive glob with given suffix and rootdir
22 | :param rootdir is the root directory
23 | :param suffix is the suffix to be searched
24 | """
25 | return [
26 | os.path.join(looproot, filename)
27 |         for looproot, _, filenames in os.walk(rootdir) #os.walk: traverses all files in rootdir and its subfolders
28 | for filename in filenames
29 | if filename.endswith(suffix)
30 | ]
31 |
32 | class Cityscapes_train_loader(BaseDataset):
33 | """cityscapesLoader
34 |
35 | https://www.cityscapes-dataset.com
36 |
37 | Data is derived from CityScapes, and can be downloaded from here:
38 | https://www.cityscapes-dataset.com/downloads/
39 |
40 | Many Thanks to @fvisin for the loader repo:
41 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py
42 | """
43 |
44 | colors = [ # [ 0, 0, 0],
45 | [128, 64, 128],
46 | [244, 35, 232],
47 | [70, 70, 70],
48 | [102, 102, 156],
49 | [190, 153, 153],
50 | [153, 153, 153],
51 | [250, 170, 30],
52 | [220, 220, 0],
53 | [107, 142, 35],
54 | [152, 251, 152],
55 | [0, 130, 180],
56 | [220, 20, 60],
57 | [255, 0, 0],
58 | [0, 0, 142],
59 | [0, 0, 70],
60 | [0, 60, 100],
61 | [0, 80, 100],
62 | [0, 0, 230],
63 | [119, 11, 32],
64 | ]
65 |
66 | label_colours = dict(zip(range(19), colors))
67 |
68 | mean_rgb = {
69 | "pascal": [103.939, 116.779, 123.68],
70 | "cityscapes": [0.0, 0.0, 0.0],
71 | } # pascal mean for PSPNet and ICNet pre-trained model
72 |
73 | def __init__(self, opt, logger, augmentations = None, split='train'):
74 | """__init__
75 |
76 | :param opt: parameters of dataset
77 |         :param logger: logger used to record the dataset setup (may be None)
78 |         :param augmentations: joint augmentations applied to image and labels
79 |         :param split: dataset split, defaults to 'train'
80 | """
81 |
82 | self.opt = opt
83 | self.root = opt.tgt_rootpath
84 | self.split = split
85 | self.augmentations = augmentations
86 | self.randaug = RandAugmentMC(2, 10)
87 | self.n_classes = opt.num_classes
88 | self.img_size = (2048, 1024)
89 | self.mean = np.array(self.mean_rgb['cityscapes'])
90 | self.files = {}
91 | self.paired_files = {}
92 |
93 | if logger is not None:
94 | logger.info("pseudo_labels_folder set to {}".format(opt.pseudo_labels_folder))
95 |
96 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split)
97 | self.annotations_base = os.path.join(opt.pseudo_labels_folder)
98 |
99 | self.files = sorted(recursive_glob(rootdir=self.images_base, suffix=".png")) #find all files from rootdir and subfolders with suffix = ".png"
100 |
101 | #self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
102 | if self.n_classes == 19:
103 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,]
104 | self.class_names = ["unlabelled","road","sidewalk","building","wall",
105 | "fence","pole","traffic_light","traffic_sign","vegetation",
106 | "terrain","sky","person","rider","car",
107 | "truck","bus","train","motorcycle","bicycle",
108 | ]
109 | self.to19 = dict(zip(range(19), range(19)))
110 | elif self.n_classes == 16:
111 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,]
112 | self.class_names = ["unlabelled","road","sidewalk","building","wall",
113 | "fence","pole","traffic_light","traffic_sign","vegetation",
114 | "sky","person","rider","car","bus",
115 | "motorcycle","bicycle",
116 | ]
117 | self.to19 = dict(zip(range(16), [0,1,2,3,4,5,6,7,8,10,11,12,13,15,17,18]))
118 | elif self.n_classes == 13:
119 | self.valid_classes = [7, 8, 11, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,]
120 | self.class_names = ["unlabelled","road","sidewalk","building","traffic_light",
121 | "traffic_sign","vegetation","sky","person","rider",
122 | "car","bus","motorcycle","bicycle",
123 | ]
124 | self.to19 = dict(zip(range(13), [0,1,2,6,7,8,10,11,12,13,15,17,18]))
125 |
126 | self.ignore_index = 250
127 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes))) #zip: return tuples
128 |
129 | if not self.files:
130 | raise Exception(
131 | "No files for split=[%s] found in %s" % (self.split, self.images_base)
132 | )
133 |
134 | clrjit_params = getattr(opt, "clrjit_params", [0.5, 0.5, 0.5, 0.2])
135 | self.train_transform = transforms.Compose([
136 | transforms.ToPILImage(),
137 | transforms.ColorJitter(*clrjit_params),
138 | ])
139 |
140 | print("Found %d %s images" % (len(self.files), self.split))
141 |
142 | def __len__(self):
143 | """__len__"""
144 | return len(self.files)
145 |
146 | def __getitem__(self, index):
147 | """__getitem__
148 |
149 | :param index:
150 | """
151 | img_path = self.files[index].rstrip()
152 | lbl_path = os.path.join(self.annotations_base, img_path.split("/")[-1])
153 |
154 | img = Image.open(img_path)
155 |         lbl = Image.open(lbl_path) if os.path.exists(lbl_path) else Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8))
156 | img = img.resize(self.img_size, Image.BILINEAR)
157 | lbl = lbl.resize(self.img_size, Image.NEAREST)
158 |
159 | img = np.array(img, dtype=np.uint8)
160 | lbl = np.array(lbl, dtype=np.uint8)
161 |
162 | img_full = img.copy().astype(np.float64)
163 | img_full -= self.mean
164 | img_full = img_full.astype(float) / 255.0
165 | img_full = img_full.transpose(2, 0, 1)
166 | lbl_full = lbl.copy()
167 |
168 | lp, lpsoft, weak_params = None, None, None
169 | if self.split == 'train' and hasattr(self.opt, "soft_labels_folder"):
170 | lpsoft = np.load(os.path.join(self.opt.soft_labels_folder, os.path.basename(img_path).replace('.png', '.npy')))
171 |
172 | input_dict = {}
173 | if self.augmentations!=None:
174 | img, lbl, lp, lpsoft, weak_params = self.augmentations(img, lbl, lp, lpsoft)
175 | img_strong, params = self.randaug(Image.fromarray(img))
176 | img_strong, _, _ = self.transform(img_strong, lbl)
177 | input_dict['img_strong'] = img_strong
178 | input_dict['params'] = params
179 |
180 | img = self.train_transform(img)
181 |
182 | img, lbl_, lp = self.transform(img, lbl, lp)
183 |
184 | input_dict['img'] = img
185 | input_dict['img_full'] = torch.from_numpy(img_full).float()
186 | input_dict['label'] = lbl_
187 | input_dict['lp'] = lp
188 | input_dict['lpsoft'] = lpsoft
189 | input_dict['weak_params'] = weak_params #full2weak
190 | input_dict['img_path'] = self.files[index]
191 | input_dict['lbl_full'] = torch.from_numpy(lbl_full).long()
192 |
193 | input_dict = {k:v for k, v in input_dict.items() if v is not None}
194 | return input_dict
195 |
196 | def transform(self, img, lbl, lp=None, check=True):
197 | """transform
198 |
199 | :param img:
200 | :param lbl:
201 | """
202 | # img = m.imresize(
203 | # img, (self.img_size[0], self.img_size[1])
204 | # ) # uint8 with RGB mode
205 | img = np.array(img)
206 | # img = img[:, :, ::-1] # RGB -> BGR
207 | img = img.astype(np.float64)
208 | img -= self.mean
209 | img = img.astype(float) / 255.0
210 | # NHWC -> NCHW
211 | img = img.transpose(2, 0, 1)
212 |
213 | classes = np.unique(lbl)
214 | lbl = np.array(lbl)
215 | lbl = lbl.astype(float)
216 | # lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F")
217 | lbl = lbl.astype(int)
218 |
219 | if not np.all(classes == np.unique(lbl)):
220 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed ones
221 |
222 |         if check and not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): # sanity check: every non-ignore label must be a valid class index
223 | print("after det", classes, np.unique(lbl))
224 | raise ValueError("Segmentation map contained invalid class values")
225 |
226 | img = torch.from_numpy(img).float()
227 | lbl = torch.from_numpy(lbl).long()
228 |
229 | if lp is not None:
230 | classes = np.unique(lp)
231 | lp = np.array(lp)
232 | # if not np.all(np.unique(lp[lp != self.ignore_index]) < self.n_classes):
233 | # raise ValueError("lp Segmentation map contained invalid class values")
234 |
235 | lp = torch.from_numpy(lp).long()
236 |
237 | return img, lbl, lp
238 |
239 | def decode_segmap(self, temp):
240 | r = temp.copy()
241 | g = temp.copy()
242 | b = temp.copy()
243 | for l in range(0, self.n_classes):
244 | r[temp == l] = self.label_colours[self.to19[l]][0]
245 | g[temp == l] = self.label_colours[self.to19[l]][1]
246 | b[temp == l] = self.label_colours[self.to19[l]][2]
247 |
248 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
249 | rgb[:, :, 0] = r / 255.0
250 | rgb[:, :, 1] = g / 255.0
251 | rgb[:, :, 2] = b / 255.0
252 | return rgb
253 |
254 | def encode_segmap(self, mask):
256 |         # Map all void classes to the ignore index (250)
256 | label_copy = 250 * np.ones(mask.shape, dtype=np.uint8)
257 | for k, v in list(self.class_map.items()):
258 | label_copy[mask == k] = v
259 | return label_copy
260 |
261 | def get_cls_num_list(self):
262 | cls_num_list = np.array([1557726944, 254364912, 673500400, 18431664, 14431392,
263 | 29361440, 7038112, 7352368, 477239920, 40134240,
264 | 211669120, 36057968, 865184, 264786464, 17128544,
265 | 2385680, 943312, 504112, 2174560])
266 | return cls_num_list
267 |
--------------------------------------------------------------------------------
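Cityscapes_train_loader reads only a handful of fields from opt (tgt_rootpath, num_classes, pseudo_labels_folder, and optionally soft_labels_folder and clrjit_params). A minimal construction sketch outside the usual data.create_dataset path, with placeholder paths, might look like this:

# Sketch only; the training scripts build this loader through data.create_dataset instead.
from types import SimpleNamespace
from data.cityscapes_train_dataset import Cityscapes_train_loader

opt = SimpleNamespace(
    tgt_rootpath='datasets/cityscapes',             # expects leftImg8bit/train/... underneath
    num_classes=19,
    pseudo_labels_folder='datasets/pseudo_labels',  # placeholder; a flat folder of *.png pseudo-label maps
    # soft_labels_folder='datasets/soft_labels',    # optional; enables loading of *.npy soft labels
)
dataset = Cityscapes_train_loader(opt, logger=None, augmentations=None, split='train')
sample = dataset[0]   # dict with 'img', 'img_full', 'label', 'lbl_full', 'img_path'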
/utils/transform.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numpy as np
4 | import numbers
5 | import collections
6 | import cv2
7 |
8 | import torch
9 |
10 |
11 | class Compose(object):
12 | # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()])
13 | def __init__(self, segtransform):
14 | self.segtransform = segtransform
15 |
16 | def __call__(self, image, label):
17 | for t in self.segtransform:
18 | image, label = t(image, label)
19 | return image, label
20 |
21 |
22 | class ToTensor(object):
23 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
24 | def __call__(self, image, label):
25 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray):
26 |             raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray "
27 |                                 "[e.g. data read by cv2.imread()].\n"))
28 |         if len(image.shape) > 3 or len(image.shape) < 2:
29 |             raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray with 3 dims or 2 dims.\n"))
30 |         if len(image.shape) == 2:
31 |             image = np.expand_dims(image, axis=2)
32 |         if not len(label.shape) == 2:
33 |             raise (RuntimeError("segtransform.ToTensor() only handles np.ndarray label with 2 dims.\n"))
34 |
35 | image = torch.from_numpy(image.transpose((2, 0, 1)))
36 | if not isinstance(image, torch.FloatTensor):
37 | image = image.float()
38 | label = torch.from_numpy(label)
39 | if not isinstance(label, torch.LongTensor):
40 | label = label.long()
41 | return image, label
42 |
43 |
44 | class Normalize(object):
45 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
46 | def __init__(self, mean, std=None):
47 | if std is None:
48 | assert len(mean) > 0
49 | else:
50 | assert len(mean) == len(std)
51 | self.mean = mean
52 | self.std = std
53 |
54 | def __call__(self, image, label):
55 | if self.std is None:
56 | for t, m in zip(image, self.mean):
57 | t.sub_(m)
58 | else:
59 | for t, m, s in zip(image, self.mean, self.std):
60 | t.sub_(m).div_(s)
61 | return image, label
62 |
63 |
64 | class Resize(object):
65 | # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w).
66 | def __init__(self, size):
67 | assert (isinstance(size, collections.Iterable) and len(size) == 2)
68 | self.size = size
69 |
70 | def __call__(self, image, label):
71 | size = self.size[::-1]
72 | image = cv2.resize(image, (self.size[1], self.size[0]), interpolation=cv2.INTER_LINEAR)
73 | label = cv2.resize(label, (self.size[1], self.size[0]), interpolation=cv2.INTER_NEAREST)
74 | return image, label
75 |
76 |
77 | class RandScale(object):
78 | # Randomly resize image & label with scale factor in [scale_min, scale_max]
79 | def __init__(self, scale, aspect_ratio=None):
80 | assert (isinstance(scale, collections.Iterable) and len(scale) == 2)
81 | if isinstance(scale, collections.Iterable) and len(scale) == 2 \
82 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \
83 | and 0 < scale[0] <= scale[1]:
84 | self.scale = scale
85 | else:
86 | raise (RuntimeError("segtransform.RandScale() scale param error.\n"))
87 | if aspect_ratio is None:
88 | self.aspect_ratio = aspect_ratio
89 | elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \
90 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \
91 | and 0 < aspect_ratio[0] < aspect_ratio[1]:
92 | self.aspect_ratio = aspect_ratio
93 | else:
94 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n"))
95 |
96 | def __call__(self, image, label):
97 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random()
98 | temp_aspect_ratio = 1.0
99 | if self.aspect_ratio is not None:
100 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random()
101 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio)
102 | scale_factor_x = temp_scale * temp_aspect_ratio
103 | scale_factor_y = temp_scale / temp_aspect_ratio
104 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR)
105 | label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST)
106 | return image, label
107 |
108 |
109 | class Crop(object):
110 | """Crops the given ndarray image (H*W*C or H*W).
111 | Args:
112 | size (sequence or int): Desired output size of the crop. If size is an
113 | int instead of sequence like (h, w), a square crop (size, size) is made.
114 | """
115 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255):
116 | if isinstance(size, int):
117 | self.crop_h = size
118 | self.crop_w = size
119 | elif isinstance(size, collections.Iterable) and len(size) == 2 \
120 | and isinstance(size[0], int) and isinstance(size[1], int) \
121 | and size[0] > 0 and size[1] > 0:
122 | self.crop_h = size[0]
123 | self.crop_w = size[1]
124 | else:
125 | raise (RuntimeError("crop size error.\n"))
126 | if crop_type == 'center' or crop_type == 'rand':
127 | self.crop_type = crop_type
128 | else:
129 | raise (RuntimeError("crop type error: rand | center\n"))
130 | if padding is None:
131 | self.padding = padding
132 | elif isinstance(padding, list):
133 | if all(isinstance(i, numbers.Number) for i in padding):
134 | self.padding = padding
135 | else:
136 | raise (RuntimeError("padding in Crop() should be a number list\n"))
137 | if len(padding) != 3:
138 | raise (RuntimeError("padding channel is not equal with 3\n"))
139 | else:
140 | raise (RuntimeError("padding in Crop() should be a number list\n"))
141 | if isinstance(ignore_label, int):
142 | self.ignore_label = ignore_label
143 | else:
144 | raise (RuntimeError("ignore_label should be an integer number\n"))
145 |
146 | def __call__(self, image, label):
147 | h, w = label.shape
148 | pad_h = max(self.crop_h - h, 0)
149 | pad_w = max(self.crop_w - w, 0)
150 | pad_h_half = int(pad_h / 2)
151 | pad_w_half = int(pad_w / 2)
152 | if pad_h > 0 or pad_w > 0:
153 | if self.padding is None:
154 | raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n"))
155 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding)
156 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
157 | h, w = label.shape
158 | if self.crop_type == 'rand':
159 | h_off = random.randint(0, h - self.crop_h)
160 | w_off = random.randint(0, w - self.crop_w)
161 | else:
162 | h_off = int((h - self.crop_h) / 2)
163 | w_off = int((w - self.crop_w) / 2)
164 | image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
165 | label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
166 | return image, label
167 |
168 |
169 | class RandRotate(object):
170 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max]
171 | def __init__(self, rotate, padding, ignore_label=255, p=0.5):
172 | assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2)
173 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]:
174 | self.rotate = rotate
175 | else:
176 | raise (RuntimeError("segtransform.RandRotate() scale param error.\n"))
177 | assert padding is not None
178 | assert isinstance(padding, list) and len(padding) == 3
179 | if all(isinstance(i, numbers.Number) for i in padding):
180 | self.padding = padding
181 | else:
182 | raise (RuntimeError("padding in RandRotate() should be a number list\n"))
183 | assert isinstance(ignore_label, int)
184 | self.ignore_label = ignore_label
185 | self.p = p
186 |
187 | def __call__(self, image, label):
188 | if random.random() < self.p:
189 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random()
190 | h, w = label.shape
191 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
192 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding)
193 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label)
194 | return image, label
195 |
196 |
197 | class RandomHorizontalFlip(object):
198 | def __init__(self, p=0.5):
199 | self.p = p
200 |
201 | def __call__(self, image, label):
202 | if random.random() < self.p:
203 | image = cv2.flip(image, 1)
204 | label = cv2.flip(label, 1)
205 | return image, label
206 |
207 |
208 | class RandomVerticalFlip(object):
209 | def __init__(self, p=0.5):
210 | self.p = p
211 |
212 | def __call__(self, image, label):
213 | if random.random() < self.p:
214 | image = cv2.flip(image, 0)
215 | label = cv2.flip(label, 0)
216 | return image, label
217 |
218 |
219 | class RandomGaussianBlur(object):
220 | def __init__(self, radius=5):
221 | self.radius = radius
222 |
223 | def __call__(self, image, label):
224 | if random.random() < 0.5:
225 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0)
226 | return image, label
227 |
228 |
229 | class RGB2BGR(object):
230 | # Converts image from RGB order to BGR order, for model initialized from Caffe
231 | def __call__(self, image, label):
232 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
233 | return image, label
234 |
235 |
236 | class BGR2RGB(object):
237 | # Converts image from BGR order to RGB order, for model initialized from Pytorch
238 | def __call__(self, image, label):
239 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
240 | return image, label
241 |
--------------------------------------------------------------------------------
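The segtransforms above operate on (image, label) pairs of cv2-style numpy arrays and are chained with the Compose shown in the class docstring. A small illustrative chain (crop size, padding value and normalization statistics are placeholders, not values taken from this repo):

import cv2
from utils import transform as T

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]   # illustrative stats for a [0, 1] image
train_transform = T.Compose([
    T.RandScale([0.5, 2.0]),
    T.RandomHorizontalFlip(p=0.5),
    T.Crop([512, 512], crop_type='rand', padding=mean, ignore_label=250),
    T.ToTensor(),                       # HWC ndarray -> CHW FloatTensor, label -> LongTensor
    T.Normalize(mean=mean, std=std),
])
image = cv2.imread('example.png').astype('float32') / 255.0      # placeholder input image
label = cv2.imread('example_label.png', cv2.IMREAD_GRAYSCALE)    # placeholder label map
image_t, label_t = train_transform(image, label)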
/data/randaugment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import random
5 |
6 | import numpy as np
7 | import PIL
8 | import PIL.ImageOps
9 | import PIL.ImageEnhance
10 | import PIL.ImageDraw
11 | from PIL import Image
12 | import torch
13 | import torch.nn.functional as F
14 | import torchvision.transforms as transforms
15 |
16 | PARAMETER_MAX = 10
17 |
18 | def AutoContrast(img, **kwarg):
19 | return PIL.ImageOps.autocontrast(img), None
20 |
21 |
22 | def Brightness(img, v, max_v, bias=0):
23 | v = _float_parameter(v, max_v) + bias
24 | return PIL.ImageEnhance.Brightness(img).enhance(v), v
25 |
26 |
27 | def Color(img, v, max_v, bias=0):
28 | v = _float_parameter(v, max_v) + bias
29 | return PIL.ImageEnhance.Color(img).enhance(v), v
30 |
31 |
32 | def Contrast(img, v, max_v, bias=0):
33 | v = _float_parameter(v, max_v) + bias
34 | return PIL.ImageEnhance.Contrast(img).enhance(v), v
35 |
36 |
37 | def Cutout(img, v, max_v, bias=0):
38 | if v == 0:
39 |         return img, None
40 | v = _float_parameter(v, max_v) + bias
41 | v = int(v * min(img.size))
42 | return CutoutAbs(img, v)
43 |
44 |
45 | def CutoutAbs(img, v, **kwarg):
46 | w, h = img.size
47 | x0 = np.random.uniform(0, w)
48 | y0 = np.random.uniform(0, h)
49 | x0 = int(max(0, x0 - v / 2.))
50 | y0 = int(max(0, y0 - v / 2.))
51 | x1 = int(min(w, x0 + v))
52 | y1 = int(min(h, y0 + v))
53 | xy = (x0, y0, x1, y1)
54 | # gray
55 | color = (127, 127, 127)
56 | img = img.copy()
57 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
58 | return img, xy
59 |
60 |
61 | def Equalize(img, **kwarg):
62 | return PIL.ImageOps.equalize(img), None
63 |
64 |
65 | def Identity(img, **kwarg):
66 | return img, None
67 |
68 |
69 | def Invert(img, **kwarg):
70 | return PIL.ImageOps.invert(img), None
71 |
72 |
73 | def Posterize(img, v, max_v, bias=0):
74 | v = _int_parameter(v, max_v) + bias
75 | return PIL.ImageOps.posterize(img, v), v
76 |
77 |
78 | # def Rotate(img, v, max_v, bias=0):
79 | # v = _int_parameter(v, max_v) + bias
80 | # if random.random() < 0.5:
81 | # v = -v
82 | # #return img.rotate(v), v
83 | # img_t = transforms.ToTensor()(img)
84 | # H = img_t.shape[1]
85 | # W = img_t.shape[2]
86 | # theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi), 0], [np.sin(v/180*np.pi), np.cos(v/180*np.pi), 0]]).astype(np.float)
87 | # theta[0,1] = theta[0,1]*H/W
88 | # theta[1,0] = theta[1,0]*W/H
89 | # #theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi)], [np.sin(v/180*np.pi), np.cos(v/180*np.pi)]]).astype(np.float)
90 | # theta = torch.Tensor(theta).unsqueeze(0)
91 |
92 | # # meshgrid_x, meshgrid_y = torch.meshgrid(torch.arange(W, dtype=torch.float), torch.arange(H, dtype=torch.float))
93 | # # meshgrid = torch.stack((meshgrid_x.t()*2/W - 1, meshgrid_y.t()*2/H - 1), dim=-1).unsqueeze(0)
94 | # # grid = torch.matmul(meshgrid, theta)
95 |
96 | # # s_h = int(abs(H - W) // 2)
97 | # # dim_last = s_h if H > W else 0
98 | # # img_t = F.pad(img_t.unsqueeze(0), (dim_last, dim_last, s_h - dim_last, s_h - dim_last)).squeeze(0)
99 | # grid = F.affine_grid(theta, img_t.unsqueeze(0).size())
100 | # img_t = F.grid_sample(img_t.unsqueeze(0), grid, mode='bilinear').squeeze(0)
101 | # # img_t = img_t[:,:,s_h:-s_h] if H > W else img_t[:,s_h:-s_h,:]
102 | # img_t = transforms.ToPILImage()(img_t)
103 | # return img_t, v
104 |
105 | def Rotate(img, v, max_v, bias=0):
106 | v = _int_parameter(v, max_v) + bias
107 | if random.random() < 0.5:
108 | v = -v
109 | return img.rotate(v, resample=Image.BILINEAR, fillcolor=(127,127,127)), v
110 |
111 | def Sharpness(img, v, max_v, bias=0):
112 | v = _float_parameter(v, max_v) + bias
113 | return PIL.ImageEnhance.Sharpness(img).enhance(v), v
114 |
115 |
116 | def ShearX(img, v, max_v, bias=0):
117 | v = _float_parameter(v, max_v) + bias
118 | if random.random() < 0.5:
119 | v = -v
120 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v
121 |
122 |
123 | def ShearY(img, v, max_v, bias=0):
124 | v = _float_parameter(v, max_v) + bias
125 | if random.random() < 0.5:
126 | v = -v
127 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v
128 |
129 |
130 | def Solarize(img, v, max_v, bias=0):
131 | v = _int_parameter(v, max_v) + bias
132 | return PIL.ImageOps.solarize(img, 256 - v), 256 - v
133 |
134 |
135 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128):
136 | v = _int_parameter(v, max_v) + bias
137 | if random.random() < 0.5:
138 | v = -v
139 | img_np = np.array(img).astype(np.int)
140 | img_np = img_np + v
141 | img_np = np.clip(img_np, 0, 255)
142 | img_np = img_np.astype(np.uint8)
143 | img = Image.fromarray(img_np)
144 | return PIL.ImageOps.solarize(img, threshold), threshold
145 |
146 |
147 | def TranslateX(img, v, max_v, bias=0):
148 | v = _float_parameter(v, max_v) + bias
149 | if random.random() < 0.5:
150 | v = -v
151 | v = int(v * img.size[0])
152 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v
153 |
154 |
155 | def TranslateY(img, v, max_v, bias=0):
156 | v = _float_parameter(v, max_v) + bias
157 | if random.random() < 0.5:
158 | v = -v
159 | v = int(v * img.size[1])
160 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), resample=Image.BILINEAR, fillcolor=(127,127,127)), v
161 |
162 |
163 | def _float_parameter(v, max_v):
164 | return float(v) * max_v / PARAMETER_MAX
165 |
166 |
167 | def _int_parameter(v, max_v):
168 | return int(v * max_v / PARAMETER_MAX)
169 |
170 |
171 | def fixmatch_augment_pool():
172 | # FixMatch paper
173 | augs = [(AutoContrast, None, None),
174 | (Brightness, 0.9, 0.05),
175 | (Color, 0.9, 0.05),
176 | (Contrast, 0.9, 0.05),
177 | (Equalize, None, None),
178 | (Identity, None, None),
179 | (Posterize, 4, 4),
180 | (Rotate, 30, 0),
181 | (Sharpness, 0.9, 0.05),
182 | (ShearX, 0.3, 0),
183 | (ShearY, 0.3, 0),
184 | (Solarize, 256, 0),
185 | (TranslateX, 0.3, 0),
186 | (TranslateY, 0.3, 0)]
187 | return augs
188 |
189 |
190 | def my_augment_pool():
191 | # Test
192 | augs = [(AutoContrast, None, None),
193 | (Brightness, 1.8, 0.1),
194 | (Color, 1.8, 0.1),
195 | (Contrast, 1.8, 0.1),
196 | (Cutout, 0.2, 0),
197 | (Equalize, None, None),
198 | (Invert, None, None),
199 | (Posterize, 4, 4),
200 | (Rotate, 30, 0),
201 | (Sharpness, 1.8, 0.1),
202 | (ShearX, 0.3, 0),
203 | (ShearY, 0.3, 0),
204 | (Solarize, 256, 0),
205 | (SolarizeAdd, 110, 0),
206 | (TranslateX, 0.45, 0),
207 | (TranslateY, 0.45, 0)]
208 | return augs
209 |
210 |
211 | class RandAugmentPC(object):
212 | def __init__(self, n, m):
213 | assert n >= 1
214 | assert 1 <= m <= 10
215 | self.n = n
216 | self.m = m
217 | self.augment_pool = my_augment_pool()
218 |
219 | def __call__(self, img):
220 | ops = random.choices(self.augment_pool, k=self.n)
221 | for op, max_v, bias in ops:
222 | prob = np.random.uniform(0.2, 0.8)
223 | if random.random() + prob >= 1:
224 |                 img, _ = op(img, v=self.m, max_v=max_v, bias=bias)
225 |         img, _ = CutoutAbs(img, 16)
226 | return img
227 |
228 |
229 | class RandAugmentMC(object):
230 | def __init__(self, n, m):
231 | assert n >= 1
232 | assert 1 <= m <= 10
233 | self.n = n
234 | self.m = m
235 | self.augment_pool = fixmatch_augment_pool()
236 |
237 | def __call__(self, img, type='crc'):
238 | aug_type = {'Hflip':False, 'ShearX':1e4, 'ShearY':1e4, 'TranslateX':1e4, 'TranslateY':1e4, 'Rotate':1e4, 'CutoutAbs':1e4}
239 | if random.random() < 0.5:
240 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
241 | #aug_type.append(['Hflip', True])
242 | aug_type['Hflip'] = True
243 | if type == 'cr' or type == 'crc':
244 | ops = random.choices(self.augment_pool, k=self.n)
245 | for op, max_v, bias in ops:
246 | v = np.random.randint(1, self.m)
247 | if random.random() < 0.5:
248 | img, params = op(img, v=v, max_v=max_v, bias=bias)
249 | if op.__name__ in ['ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']:
250 | #aug_type.append([op.__name__, params])
251 | aug_type[op.__name__] = params
252 | if type == 'cc' or type == 'crc':
253 | img, params = CutoutAbs(img, min(img.size[0], img.size[1]) // 3)
254 | #aug_type.append([CutoutAbs.__name__, params])
255 | aug_type['CutoutAbs'] = params
256 | return img, aug_type
257 |
258 | def affine_sample(tensor, v, type):
259 | # tensor: B*C*H*W
260 | # v: scalar, translation param
261 | if type == 'Rotate':
262 | theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi), 0], [np.sin(v/180*np.pi), np.cos(v/180*np.pi), 0]]).astype(np.float)
263 | elif type == 'ShearX':
264 | theta = np.array([[1, v, 0], [0, 1, 0]]).astype(np.float)
265 | elif type == 'ShearY':
266 | theta = np.array([[1, 0, 0], [v, 1, 0]]).astype(np.float)
267 | elif type == 'TranslateX':
268 | theta = np.array([[1, 0, v], [0, 1, 0]]).astype(np.float)
269 | elif type == 'TranslateY':
270 | theta = np.array([[1, 0, 0], [0, 1, v]]).astype(np.float)
271 |
272 | H = tensor.shape[2]
273 | W = tensor.shape[3]
274 | theta[0,1] = theta[0,1]*H/W
275 | theta[1,0] = theta[1,0]*W/H
276 | if type != 'Rotate':
277 | theta[0,2] = theta[0,2]*2/H + theta[0,0] + theta[0,1] - 1
278 | theta[1,2] = theta[1,2]*2/H + theta[1,0] + theta[1,1] - 1
279 |
280 | theta = torch.Tensor(theta).unsqueeze(0)
281 | grid = F.affine_grid(theta, tensor.size()).to(tensor.device)
282 | tensor_t = F.grid_sample(tensor, grid, mode='nearest')
283 | return tensor_t
284 |
285 | if __name__ == '__main__':
286 | randaug = RandAugmentMC(2, 10)
287 | #path = r'E:\WorkHome\IMG_20190131_142431.jpg'
288 | path = r'E:\WorkHome\0.png'
289 | img = Image.open(path)
290 | img_t = transforms.ToTensor()(img).unsqueeze(0)
291 | #img_aug, aug_type = randaug(img)
292 | #img_aug.show()
293 |
294 | # v = 20
295 | # img_pil = img.rotate(v)
296 | # img_T = affine_sample(img_t, v, 'Rotate')
297 |
298 | v = 0.12
299 | img_pil = img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
300 | img_T = affine_sample(img_t, v, 'ShearY')
301 |
302 | img_ten = transforms.ToPILImage()(img_T.squeeze(0))
303 | img_pil.show()
304 | img_ten.show()
--------------------------------------------------------------------------------
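RandAugmentMC returns the augmented image together with an aug_type dict; the geometric entries (1e4 marks "not applied") can in principle be replayed on a label or soft-label tensor with affine_sample so that it stays aligned with the strongly augmented image. A rough sketch of that replay, assuming a 1 x C x H x W tensor; the exact sign/scale conventions should be checked against how the training scripts consume these params:

# Hedged sketch of replaying RandAugmentMC geometry on an assumed soft-label tensor lpsoft.
import torch
from PIL import Image
from data.randaugment import RandAugmentMC, affine_sample

randaug = RandAugmentMC(2, 10)
img_strong, aug_type = randaug(Image.open('example.png'))   # placeholder image path
lpsoft_aug = lpsoft.clone()                                  # lpsoft: 1 x C x H x W tensor (assumed to exist)
if aug_type['Hflip']:
    lpsoft_aug = torch.flip(lpsoft_aug, dims=[3])
for name in ('ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate'):
    v = aug_type[name]
    if v != 1e4:                                             # 1e4 means the op was not chosen
        lpsoft_aug = affine_sample(lpsoft_aug, v, name)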
/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from mmcv.runner import load_checkpoint
3 | from torchvision.models.utils import load_state_dict_from_url
4 |
5 | BatchNorm = nn.BatchNorm2d
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=dilation, groups=groups, bias=False, dilation=dilation)
26 |
27 |
28 | def conv1x1(in_planes, out_planes, stride=1):
29 | """1x1 convolution"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
37 | base_width=64, dilation=1, norm_layer=None):
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
43 | if dilation > 1:
44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = norm_layer(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = norm_layer(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | identity = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | if self.downsample is not None:
65 | identity = self.downsample(x)
66 |
67 | out += identity
68 | out = self.relu(out)
69 |
70 | return out
71 |
72 |
73 | class Bottleneck(nn.Module):
74 | expansion = 4
75 |
76 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
77 | base_width=64, dilation=1, norm_layer=None):
78 | super(Bottleneck, self).__init__()
79 | if norm_layer is None:
80 | norm_layer = nn.BatchNorm2d
81 | width = int(planes * (base_width / 64.)) * groups
82 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
83 | self.conv1 = conv1x1(inplanes, width)
84 | self.bn1 = norm_layer(width)
85 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
86 | self.bn2 = norm_layer(width)
87 | self.conv3 = conv1x1(width, planes * self.expansion)
88 | self.bn3 = norm_layer(planes * self.expansion)
89 | self.relu = nn.ReLU(inplace=True)
90 | self.downsample = downsample
91 | self.stride = stride
92 |
93 | def forward(self, x):
94 | identity = x
95 |
96 | out = self.conv1(x)
97 | out = self.bn1(out)
98 | out = self.relu(out)
99 |
100 | out = self.conv2(out)
101 | out = self.bn2(out)
102 | out = self.relu(out)
103 |
104 | out = self.conv3(out)
105 | out = self.bn3(out)
106 |
107 | if self.downsample is not None:
108 | identity = self.downsample(x)
109 |
110 | out += identity
111 | out = self.relu(out)
112 |
113 | return out
114 |
115 |
116 | class ResNet(nn.Module):
117 |
118 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
119 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
120 | norm_layer=None):
121 | super(ResNet, self).__init__()
122 | if norm_layer is None:
123 | norm_layer = nn.BatchNorm2d
124 | self._norm_layer = norm_layer
125 |
126 | self.inplanes = 64
127 | self.dilation = 1
128 | if replace_stride_with_dilation is None:
129 | # each element in the tuple indicates if we should replace
130 | # the 2x2 stride with a dilated convolution instead
131 | replace_stride_with_dilation = [False, False, False]
132 | if len(replace_stride_with_dilation) != 3:
133 | raise ValueError("replace_stride_with_dilation should be None "
134 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
135 | self.groups = groups
136 | self.base_width = width_per_group
137 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
138 | bias=False)
139 | self.bn1 = norm_layer(self.inplanes)
140 | self.relu = nn.ReLU(inplace=True)
141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
142 | self.layer1 = self._make_layer(block, 64, layers[0])
143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
144 | dilate=replace_stride_with_dilation[0])
145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
146 | dilate=replace_stride_with_dilation[1])
147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
148 | dilate=replace_stride_with_dilation[2])
149 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
150 | self.fc = nn.Linear(512 * block.expansion, num_classes)
151 |
152 | for m in self.modules():
153 | if isinstance(m, nn.Conv2d):
154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
156 | nn.init.constant_(m.weight, 1)
157 | nn.init.constant_(m.bias, 0)
158 |
159 | # Zero-initialize the last BN in each residual branch,
160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
162 | if zero_init_residual:
163 | for m in self.modules():
164 | if isinstance(m, Bottleneck):
165 | nn.init.constant_(m.bn3.weight, 0)
166 | elif isinstance(m, BasicBlock):
167 | nn.init.constant_(m.bn2.weight, 0)
168 |
169 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
170 | norm_layer = self._norm_layer
171 | downsample = None
172 | previous_dilation = self.dilation
173 | if dilate:
174 | self.dilation *= stride
175 | stride = 1
176 | if stride != 1 or self.inplanes != planes * block.expansion:
177 | downsample = nn.Sequential(
178 | conv1x1(self.inplanes, planes * block.expansion, stride),
179 | norm_layer(planes * block.expansion),
180 | )
181 |
182 | layers = []
183 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
184 | self.base_width, previous_dilation, norm_layer))
185 | self.inplanes = planes * block.expansion
186 | for _ in range(1, blocks):
187 | layers.append(block(self.inplanes, planes, groups=self.groups,
188 | base_width=self.base_width, dilation=self.dilation,
189 | norm_layer=norm_layer))
190 |
191 | return nn.Sequential(*layers)
192 |
193 | def forward(self, x):
194 |
195 | x = self.conv1(x)
196 | x = self.bn1(x)
197 | x = self.relu(x)
198 | x = self.maxpool(x)
199 | x = self.layer1(x)
200 | x = self.layer2(x)
201 | x = self.layer3(x)
202 | x = self.layer4(x)
203 |
204 | x = self.avgpool(x)
205 | x = x.reshape(x.size(0), -1)
206 | x = self.fc(x)
207 |
208 | return x
209 |
210 |
211 | def _resnet(arch, block, layers, pretrained, progress, pretrained_weights=None, **kwargs):
212 | model = ResNet(block, layers, **kwargs)
213 | if pretrained:
214 | # load_checkpoint(model, pretrained_weights, map_location='cpu')
215 | import torch
216 | import os
217 | if os.path.exists('./pretrained/resnet101-5d3b4d8f.pth'):
218 | saved_state_dict = torch.load('./pretrained/resnet101-5d3b4d8f.pth', map_location='cpu')
219 | print("load weight from ./pretrained/resnet101-5d3b4d8f.pth")
220 | else:
221 | raise ValueError("No saved_state_dict loaded")
222 | model.load_state_dict(saved_state_dict)
223 | return model
224 |
225 |
226 | def resnet18(pretrained=False, progress=True, **kwargs):
227 | """Constructs a ResNet-18 model.
228 |
229 | Args:
230 | pretrained (bool): If True, returns a model pre-trained on ImageNet
231 | progress (bool): If True, displays a progress bar of the download to stderr
232 | """
233 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
234 | **kwargs)
235 |
236 |
237 | def resnet34(pretrained=False, progress=True, **kwargs):
238 | """Constructs a ResNet-34 model.
239 |
240 | Args:
241 | pretrained (bool): If True, returns a model pre-trained on ImageNet
242 | progress (bool): If True, displays a progress bar of the download to stderr
243 | """
244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
245 | **kwargs)
246 |
247 |
248 | def resnet50(pretrained=False, progress=True, **kwargs):
249 | """Constructs a ResNet-50 model.
250 |
251 | Args:
252 | pretrained (bool): If True, returns a model pre-trained on ImageNet
253 | progress (bool): If True, displays a progress bar of the download to stderr
254 | """
255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
256 | **kwargs)
257 |
258 |
259 | def resnet101(pretrained=False, progress=True, **kwargs):
260 | """Constructs a ResNet-101 model.
261 |
262 | Args:
263 | pretrained (bool): If True, returns a model pre-trained on ImageNet
264 | progress (bool): If True, displays a progress bar of the download to stderr
265 | """
266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
267 | **kwargs)
268 |
269 |
270 | def resnet152(pretrained=False, progress=True, **kwargs):
271 | """Constructs a ResNet-152 model.
272 |
273 | Args:
274 | pretrained (bool): If True, returns a model pre-trained on ImageNet
275 | progress (bool): If True, displays a progress bar of the download to stderr
276 | """
277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
278 | **kwargs)
279 |
280 |
281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
282 | """Constructs a ResNeXt-50 32x4d model.
283 |
284 | Args:
285 | pretrained (bool): If True, returns a model pre-trained on ImageNet
286 | progress (bool): If True, displays a progress bar of the download to stderr
287 | """
288 | kwargs['groups'] = 32
289 | kwargs['width_per_group'] = 4
290 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
291 | pretrained, progress, **kwargs)
292 |
293 |
294 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
295 | """Constructs a ResNeXt-101 32x8d model.
296 |
297 | Args:
298 | pretrained (bool): If True, returns a model pre-trained on ImageNet
299 | progress (bool): If True, displays a progress bar of the download to stderr
300 | """
301 | kwargs['groups'] = 32
302 | kwargs['width_per_group'] = 8
303 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
304 | pretrained, progress, **kwargs)
305 |
--------------------------------------------------------------------------------
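Note that when pretrained=True, _resnet bypasses the torchvision download helpers and hard-requires ./pretrained/resnet101-5d3b4d8f.pth on disk, regardless of which constructor was called (so only the resnet101 variant actually loads cleanly); the checkpoint from the resnet101 entry in model_urls has to be downloaded beforehand. Minimal usage sketch:

import torch
from model.resnet import resnet101

# Expects ./pretrained/resnet101-5d3b4d8f.pth to exist locally when pretrained=True.
backbone = resnet101(pretrained=True)
backbone.eval()
with torch.no_grad():
    logits = backbone(torch.randn(1, 3, 224, 224))   # 1 x 1000 ImageNet classification logits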
/data/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 | # Adapted from https://github.com/ZijunDeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py
4 |
5 | import math
6 | import numbers
7 | import random
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import torchvision.transforms.functional as tf
12 |
13 | from PIL import Image, ImageOps
14 |
15 |
16 | class Compose(object):
17 | def __init__(self, augmentations):
18 | self.augmentations = augmentations
19 | self.PIL2Numpy = False
20 |
21 | def __call__(self, img, mask, mask1=None, lpsoft=None):
22 | params = {}
23 | if isinstance(img, np.ndarray):
24 | img = Image.fromarray(img, mode="RGB")
25 | mask = Image.fromarray(mask, mode="L")
26 | if mask1 is not None:
27 | mask1 = Image.fromarray(mask1, mode="L")
28 | if lpsoft is not None:
29 | lpsoft = torch.from_numpy(lpsoft)
30 | lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[img.size[1], img.size[0]], mode='bilinear', align_corners=True)[0]
31 | self.PIL2Numpy = True
32 |
33 | if img.size != mask.size:
34 | print (img.size, mask.size)
35 | assert img.size == mask.size
36 | if mask1 is not None:
37 | assert (img.size == mask1.size)
38 | for a in self.augmentations:
39 | img, mask, mask1, lpsoft, params = a(img, mask, mask1, lpsoft, params)
40 | # print(img.size)
41 |
42 | if self.PIL2Numpy:
43 | img, mask = np.array(img), np.array(mask, dtype=np.uint8)
44 | if mask1 is not None:
45 | mask1 = np.array(mask1, dtype=np.uint8)
46 | return img, mask, mask1, lpsoft, params
47 |
48 |
49 | class RandomCrop(object):
50 | def __init__(self, size, padding=0):
51 | if isinstance(size, numbers.Number):
52 | self.size = (int(size), int(size))
53 | else:
54 | self.size = size
55 | self.padding = padding
56 |
57 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None):
58 | if self.padding > 0:
59 | img = ImageOps.expand(img, border=self.padding, fill=0)
60 | mask = ImageOps.expand(mask, border=self.padding, fill=0)
61 | if mask1 is not None:
62 | mask1 = ImageOps.expand(mask1, border=self.padding, fill=0)
63 |
64 | assert img.size == mask.size
65 | if mask1 is not None:
66 | assert (img.size == mask1.size)
67 | w, h = img.size
68 |
69 | # print("self.size: ", self.size)
70 |
71 | tw, th = self.size
72 | # if w == tw and h == th:
73 | # return img, mask
74 | if w < tw or h < th:
75 | if lpsoft is not None:
76 |                 lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[th, tw], mode='bilinear', align_corners=True)[0]
77 | if mask1 is not None:
78 | return (
79 | img.resize((tw, th), Image.BILINEAR),
80 | mask.resize((tw, th), Image.NEAREST),
81 | mask1.resize((tw, th), Image.NEAREST),
82 |                     lpsoft, params
83 | )
84 | else:
85 | return (
86 | img.resize((tw, th), Image.BILINEAR),
87 | mask.resize((tw, th), Image.NEAREST),
88 | None,
89 |                     lpsoft, params
90 | )
91 |
92 | x1 = random.randint(0, w - tw)
93 | y1 = random.randint(0, h - th)
94 | params['RandomCrop'] = (y1, y1 + th, x1, x1 + tw)
95 | if lpsoft is not None:
96 | lpsoft = lpsoft[:, y1:y1 + th, x1:x1 + tw]
97 | if mask1 is not None:
98 | return (
99 | img.crop((x1, y1, x1 + tw, y1 + th)),
100 | mask.crop((x1, y1, x1 + tw, y1 + th)),
101 | mask1.crop((x1, y1, x1 + tw, y1 + th)),
102 | lpsoft,
103 | params
104 | )
105 | else:
106 | return (
107 | img.crop((x1, y1, x1 + tw, y1 + th)),
108 | mask.crop((x1, y1, x1 + tw, y1 + th)),
109 | None,
110 | lpsoft,
111 | params
112 | )
113 |
114 |
115 | class AdjustGamma(object):
116 | def __init__(self, gamma):
117 | self.gamma = gamma
118 |
119 | def __call__(self, img, mask):
120 | assert img.size == mask.size
121 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask
122 |
123 |
124 | class AdjustSaturation(object):
125 | def __init__(self, saturation):
126 | self.saturation = saturation
127 |
128 | def __call__(self, img, mask):
129 | assert img.size == mask.size
130 | return tf.adjust_saturation(img,
131 | random.uniform(1 - self.saturation,
132 | 1 + self.saturation)), mask
133 |
134 |
135 | class AdjustHue(object):
136 | def __init__(self, hue):
137 | self.hue = hue
138 |
139 | def __call__(self, img, mask):
140 | assert img.size == mask.size
141 | return tf.adjust_hue(img, random.uniform(-self.hue,
142 | self.hue)), mask
143 |
144 |
145 | class AdjustBrightness(object):
146 | def __init__(self, bf):
147 | self.bf = bf
148 |
149 | def __call__(self, img, mask):
150 | assert img.size == mask.size
151 | return tf.adjust_brightness(img,
152 | random.uniform(1 - self.bf,
153 | 1 + self.bf)), mask
154 |
155 | class AdjustContrast(object):
156 | def __init__(self, cf):
157 | self.cf = cf
158 |
159 | def __call__(self, img, mask):
160 | assert img.size == mask.size
161 | return tf.adjust_contrast(img,
162 | random.uniform(1 - self.cf,
163 | 1 + self.cf)), mask
164 |
165 | class CenterCrop(object):
166 | def __init__(self, size):
167 | if isinstance(size, numbers.Number):
168 | self.size = (int(size), int(size))
169 | else:
170 | self.size = size
171 |
172 | def __call__(self, img, mask):
173 | assert img.size == mask.size
174 | w, h = img.size
175 | th, tw = self.size
176 | x1 = int(round((w - tw) / 2.))
177 | y1 = int(round((h - th) / 2.))
178 | return (
179 | img.crop((x1, y1, x1 + tw, y1 + th)),
180 | mask.crop((x1, y1, x1 + tw, y1 + th)),
181 | )
182 |
183 |
184 | class RandomHorizontallyFlip(object):
185 | def __init__(self, p):
186 | self.p = p
187 |
188 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None):
189 | if random.random() < self.p:
190 | params['RandomHorizontallyFlip'] = True
191 | if lpsoft is not None:
192 | inv_idx = torch.arange(lpsoft.size(2)-1,-1,-1).long() # C x H x W
193 | lpsoft = lpsoft.index_select(2,inv_idx)
194 | if mask1 is not None:
195 | return (
196 | img.transpose(Image.FLIP_LEFT_RIGHT),
197 | mask.transpose(Image.FLIP_LEFT_RIGHT),
198 | mask1.transpose(Image.FLIP_LEFT_RIGHT),
199 | lpsoft,
200 | params
201 | )
202 | else:
203 | return (
204 | img.transpose(Image.FLIP_LEFT_RIGHT),
205 | mask.transpose(Image.FLIP_LEFT_RIGHT),
206 | None,
207 | lpsoft,
208 | params
209 | )
210 | else:
211 | params['RandomHorizontallyFlip'] = False
212 | return img, mask, mask1, lpsoft, params
213 |
214 |
215 | class RandomVerticallyFlip(object):
216 | def __init__(self, p):
217 | self.p = p
218 |
219 | def __call__(self, img, mask):
220 | if random.random() < self.p:
221 | return (
222 | img.transpose(Image.FLIP_TOP_BOTTOM),
223 | mask.transpose(Image.FLIP_TOP_BOTTOM),
224 | )
225 | return img, mask
226 |
227 |
228 | class FreeScale(object):
229 | def __init__(self, size):
230 | self.size = tuple(reversed(size)) # size: (h, w)
231 |
232 | def __call__(self, img, mask):
233 | assert img.size == mask.size
234 | return (
235 | img.resize(self.size, Image.BILINEAR),
236 | mask.resize(self.size, Image.NEAREST),
237 | )
238 |
239 |
240 | class RandomTranslate(object):
241 | def __init__(self, offset):
242 | self.offset = offset # tuple (delta_x, delta_y)
243 |
244 | def __call__(self, img, mask):
245 | assert img.size == mask.size
246 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0])
247 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1])
248 |
249 | x_crop_offset = x_offset
250 | y_crop_offset = y_offset
251 | if x_offset < 0:
252 | x_crop_offset = 0
253 | if y_offset < 0:
254 | y_crop_offset = 0
255 |
256 | cropped_img = tf.crop(img,
257 | y_crop_offset,
258 | x_crop_offset,
259 | img.size[1]-abs(y_offset),
260 | img.size[0]-abs(x_offset))
261 |
262 | if x_offset >= 0 and y_offset >= 0:
263 | padding_tuple = (0, 0, x_offset, y_offset)
264 |
265 | elif x_offset >= 0 and y_offset < 0:
266 | padding_tuple = (0, abs(y_offset), x_offset, 0)
267 |
268 | elif x_offset < 0 and y_offset >= 0:
269 | padding_tuple = (abs(x_offset), 0, 0, y_offset)
270 |
271 | elif x_offset < 0 and y_offset < 0:
272 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0)
273 |
274 | return (
275 | tf.pad(cropped_img,
276 | padding_tuple,
277 | padding_mode='reflect'),
278 | tf.affine(mask,
279 | translate=(-x_offset, -y_offset),
280 | scale=1.0,
281 | angle=0.0,
282 | shear=0.0,
283 | fillcolor=250))
284 |
285 |
286 | class RandomRotate(object):
287 | def __init__(self, degree):
288 | self.degree = degree
289 |
290 | def __call__(self, img, mask):
291 | rotate_degree = random.random() * 2 * self.degree - self.degree
292 | return (
293 | tf.affine(img,
294 | translate=(0, 0),
295 | scale=1.0,
296 | angle=rotate_degree,
297 | resample=Image.BILINEAR,
298 | fillcolor=(0, 0, 0),
299 | shear=0.0),
300 | tf.affine(mask,
301 | translate=(0, 0),
302 | scale=1.0,
303 | angle=rotate_degree,
304 | resample=Image.NEAREST,
305 | fillcolor=250,
306 | shear=0.0))
307 |
308 |
309 |
310 | class Scale(object):
311 | def __init__(self, size):
312 | self.size = size
313 |
314 | def __call__(self, img, mask):
315 | assert img.size == mask.size
316 | w, h = img.size
317 | if (w >= h and w == self.size) or (h >= w and h == self.size):
318 | return img, mask
319 | if w > h:
320 | ow = self.size
321 | oh = int(self.size * h / w)
322 | return (
323 | img.resize((ow, oh), Image.BILINEAR),
324 | mask.resize((ow, oh), Image.NEAREST),
325 | )
326 | else:
327 | oh = self.size
328 | ow = int(self.size * w / h)
329 | return (
330 | img.resize((ow, oh), Image.BILINEAR),
331 | mask.resize((ow, oh), Image.NEAREST),
332 | )
333 |
334 | def MyScale(img, lbl, size):
335 | """scale
336 |
337 |     resize img and lbl to the given (w, h) size
338 | """
339 | if isinstance(img, np.ndarray):
340 | _img = Image.fromarray(img)
341 | _lbl = Image.fromarray(lbl)
342 | else:
343 | _img = img
344 | _lbl = lbl
345 | assert _img.size == _lbl.size
346 | # prop = 1.0 * _img.size[0]/_img.size[1]
347 | w, h = size
348 | # h = int(size / prop)
349 | _img = _img.resize((w, h), Image.BILINEAR)
350 | _lbl = _lbl.resize((w, h), Image.NEAREST)
351 | return np.array(_img), np.array(_lbl)
352 |
353 | def Flip(img, lbl, prop):
354 | """
355 |     flip img and lbl with probability prop
356 | """
357 | if isinstance(img, np.ndarray):
358 | _img = Image.fromarray(img)
359 | _lbl = Image.fromarray(lbl)
360 | else:
361 | _img = img
362 | _lbl = lbl
363 | if random.random() < prop:
364 |         _img = _img.transpose(Image.FLIP_LEFT_RIGHT)
365 |         _lbl = _lbl.transpose(Image.FLIP_LEFT_RIGHT)
366 | return np.array(_img), np.array(_lbl)
367 |
368 | def MyRotate(img, lbl, degree):
369 | """
370 | img, lbl, degree
371 | randomly rotate clockwise or anti-clockwise
372 | """
373 | if isinstance(img, np.ndarray):
374 | _img = Image.fromarray(img)
375 | _lbl = Image.fromarray(lbl)
376 | else:
377 | _img = img
378 | _lbl = lbl
379 | _degree = random.random()*degree
380 |
381 | flags = -1
382 | if random.random() < 0.5:
383 | flags = 1
384 | _img = _img.rotate(_degree * flags)
385 | _lbl = _lbl.rotate(_degree * flags)
386 | return np.array(_img), np.array(_lbl)
387 |
388 | class RandomSizedCrop(object):
389 | def __init__(self, size):
390 | self.size = size
391 |
392 | def __call__(self, img, mask):
393 | assert img.size == mask.size
394 | for attempt in range(10):
395 | area = img.size[0] * img.size[1]
396 | target_area = random.uniform(0.45, 1.0) * area
397 | aspect_ratio = random.uniform(0.5, 2)
398 |
399 | w = int(round(math.sqrt(target_area * aspect_ratio)))
400 | h = int(round(math.sqrt(target_area / aspect_ratio)))
401 |
402 | if random.random() < 0.5:
403 | w, h = h, w
404 |
405 | if w <= img.size[0] and h <= img.size[1]:
406 | x1 = random.randint(0, img.size[0] - w)
407 | y1 = random.randint(0, img.size[1] - h)
408 |
409 | img = img.crop((x1, y1, x1 + w, y1 + h))
410 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
411 | assert img.size == (w, h)
412 |
413 | return (
414 | img.resize((self.size, self.size), Image.BILINEAR),
415 | mask.resize((self.size, self.size), Image.NEAREST),
416 | )
417 |
418 | # Fallback
419 | scale = Scale(self.size)
420 | crop = CenterCrop(self.size)
421 | return crop(*scale(img, mask))
422 |
423 |
424 | class RandomSized(object):
425 | def __init__(self, size):
426 | self.size = size
427 | self.scale = Scale(self.size)
428 | self.crop = RandomCrop(self.size)
429 |
430 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None):
431 | assert img.size == mask.size
432 | if mask1 is not None:
433 | assert (img.size == mask1.size)
434 |
435 | prop = 1.0 * img.size[0] / img.size[1]
436 | w = int(random.uniform(0.5, 1.5) * self.size)
437 | #w = self.size
438 | h = int(w/prop)
439 | params['RandomSized'] = (h, w)
440 | # h = int(random.uniform(0.5, 2) * self.size[1])
441 |
442 | img, mask = (
443 | img.resize((w, h), Image.BILINEAR),
444 | mask.resize((w, h), Image.NEAREST),
445 | )
446 | if mask1 is not None:
447 | mask1 = mask1.resize((w, h), Image.NEAREST)
448 | if lpsoft is not None:
449 | lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[h, w], mode='bilinear', align_corners=True)[0]
450 |
451 | return img, mask, mask1, lpsoft, params
452 | # return self.crop(*self.scale(img, mask))
453 |
--------------------------------------------------------------------------------
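The joint augmentations above share the (img, mask, mask1, lpsoft, params) interface and record their geometry in params so it can be replayed later (the weak_params / full2weak path in the dataset loaders). An illustrative composition with sizes mirroring the default --resize and --rcrop options; the real pipeline is presumably assembled via data.create_dataset from the CLI arguments:

# Illustrative only; inputs are assumed to be numpy arrays as produced by the dataset loaders.
from data.augmentations import Compose, RandomSized, RandomCrop, RandomHorizontallyFlip

augmentations = Compose([
    RandomSized(2200),             # random rescale keeping aspect ratio, logged in params['RandomSized']
    RandomCrop((896, 512)),        # (w, h) crop, logged in params['RandomCrop']
    RandomHorizontallyFlip(0.5),   # flip decision logged in params['RandomHorizontallyFlip']
])
# img: HxWx3 uint8, lbl/lp: HxW uint8, lpsoft: CxHxW float array (lp and lpsoft may be None)
img_a, lbl_a, lp_a, lpsoft_a, weak_params = augmentations(img, lbl, lp, lpsoft)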
/generate_soft_label.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import pickle
6 | import torch.optim as optim
7 | import scipy.misc
8 | import torch.backends.cudnn as cudnn
9 | import torch.nn.functional as F
10 | import sys
11 | import os
12 | import os.path as osp
13 | import random
14 | import logging
15 | import time
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 | from tensorboardX import SummaryWriter
19 |
20 | from model.feature_extractor import resnet_feature_extractor
21 | from model.classifier import ASPP_Classifier_Gen
22 | from model.discriminator import FCDiscriminator
23 |
24 | from utils.util import *
25 | from data import create_dataset
26 | import cv2
27 |
28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)
30 |
31 | MODEL = 'DeepLab'
32 | BATCH_SIZE = 1
33 | ITER_SIZE = 1
34 | NUM_WORKERS = 16
35 | IGNORE_LABEL = 250
36 | LEARNING_RATE = 2.5e-4
37 | MOMENTUM = 0.9
38 | NUM_CLASSES = 19
39 | NUM_STEPS = 62500
40 | NUM_STEPS_STOP = 40000 # early stopping
41 | POWER = 0.9
42 | RANDOM_SEED = 1234
43 | RESUME = './pretrained/model_phase1.pth'
44 | SAVE_NUM_IMAGES = 2
45 | SAVE_PRED_EVERY = 1000
46 | SNAPSHOT_DIR = './snapshots/'
47 | WEIGHT_DECAY = 0.0005
48 | LOG_DIR = './log'
49 |
50 | LEARNING_RATE_D = 1e-4
51 | LAMBDA_SEG = 0.1
52 | LAMBDA_ADV_TARGET1 = 0.0002
53 | LAMBDA_ADV_TARGET2 = 0.001
54 |
55 | SET = 'train'
56 |
57 | def get_arguments():
58 | """Parse all the arguments provided from the CLI.
59 |
60 | Returns:
61 |       The parsed arguments as an argparse.Namespace.
62 | """
63 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
64 | parser.add_argument("--model", type=str, default=MODEL,
65 | help="available options : DeepLab")
66 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
67 | help="Number of images sent to the network in one step.")
68 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE,
69 | help="Accumulate gradients for ITER_SIZE iterations.")
70 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS,
71 | help="number of workers for multithread dataloading.")
72 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
73 | help="The index of the label to ignore during the training.")
74 | parser.add_argument("--is-training", action="store_true",
75 | help="Whether to updates the running means and variances during the training.")
76 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
77 | help="Base learning rate for training with polynomial decay.")
78 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D,
79 | help="Base learning rate for discriminator.")
80 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG,
81 | help="lambda_seg.")
82 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1,
83 | help="lambda_adv for adversarial training.")
84 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2,
85 | help="lambda_adv for adversarial training.")
86 | parser.add_argument("--momentum", type=float, default=MOMENTUM,
87 | help="Momentum component of the optimiser.")
88 | parser.add_argument("--not-restore-last", action="store_true",
89 | help="Whether to not restore last (FC) layers.")
90 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
91 | help="Number of classes to predict (including background).")
92 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
93 | help="Number of training steps.")
94 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP,
95 | help="Number of training steps for early stopping.")
96 | parser.add_argument("--power", type=float, default=POWER,
97 | help="Decay parameter to compute the learning rate.")
98 | parser.add_argument("--random-mirror", action="store_true",
99 | help="Whether to randomly mirror the inputs during the training.")
100 | parser.add_argument("--random-scale", action="store_true",
101 | help="Whether to randomly scale the inputs during the training.")
102 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
103 | help="Random seed to have reproducible results.")
104 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
105 | help="How many images to save.")
106 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
107 |                         help="Save summaries and a checkpoint every this many iterations.")
108 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
109 | help="Where to save snapshots of the model.")
110 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
111 | help="Regularisation parameter for L2-loss.")
112 | parser.add_argument("--cpu", action='store_true', help="choose to use cpu device.")
113 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.")
114 | parser.add_argument("--log-dir", type=str, default=LOG_DIR,
115 | help="Path to the directory of log.")
116 | parser.add_argument("--set", type=str, default=SET,
117 | help="choose adaptation set.")
118 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus")
119 | parser.add_argument("--dist", action="store_true", help="DDP")
120 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node')
121 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations')
122 |
123 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset')
124 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset')
125 |     parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='validation target dataset')
126 |     parser.add_argument("--noaug", action="store_true", help="disable data augmentation")
127 | parser.add_argument('--resize', type=int, default=2200, help='resize long size')
128 | parser.add_argument("--clrjit_params", type=str, default="0.5,0.5,0.5,0.2", help='brightness,contrast,saturation,hue')
129 |     parser.add_argument('--rcrop', type=str, default='896,512', help='random crop size')
130 |     parser.add_argument('--hflip', type=float, default=0.5, help='random flip probability')
131 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5')
132 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes')
133 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle')
134 | parser.add_argument('--no_droplast', action='store_true')
135 | parser.add_argument('--pseudo_labels_folder', type=str, default='')
136 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation')
137 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight')
138 |     parser.add_argument("--freeze_bn", action="store_true", help="freeze batch normalization layers")
139 |     parser.add_argument("--hidden_dim", type=int, default=128, help='hidden dimension of the classifier head')
140 | parser.add_argument("--layer", type=int, default=1, help='separate from which layer')
141 | parser.add_argument("--output_folder", type=str, default="", help='output folder')
142 | return parser.parse_args()
143 |
144 |
145 | args = get_arguments()
146 |
147 | def main_worker(gpu, world_size, dist_url):
148 | """Create the model and start the training."""
149 | if gpu == 0:
150 | if not os.path.exists(args.snapshot_dir):
151 | os.makedirs(args.snapshot_dir)
152 | logFilename = os.path.join(args.snapshot_dir, str(time.time()))
153 | logging.basicConfig(
154 | level = logging.INFO,
155 | format ='%(asctime)s-%(levelname)s-%(message)s',
156 | datefmt = '%y-%m-%d %H:%M',
157 | filename = logFilename,
158 | filemode = 'w+')
159 | filehandler = logging.FileHandler(logFilename, encoding='utf-8')
160 | logger = logging.getLogger()
161 | logger.addHandler(filehandler)
162 | handler = logging.StreamHandler()
163 | logger.addHandler(handler)
164 | logger.info(args)
165 |
166 | np.random.seed(args.random_seed)
167 | random.seed(args.random_seed)
168 | torch.manual_seed(args.random_seed)
169 | torch.cuda.manual_seed(args.random_seed)
170 | # torch.backends.cudnn.deterministic = True
171 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU.
172 | # torch.backends.cudnn.enabled = False
173 |
174 | print("gpu: {}, world_size: {}".format(gpu, world_size))
175 | print("dist_url: ", dist_url)
176 |
177 | torch.cuda.set_device(gpu)
178 | args.batch_size = args.batch_size // world_size
179 | args.batch_size_val = args.batch_size_val // world_size
180 | args.num_workers = args.num_workers // world_size
181 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu)
182 |
183 | if gpu == 0:
184 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val))
185 |
186 | device = torch.device("cuda" if not args.cpu else "cpu")
187 | args.world_size = world_size
188 |
189 | if gpu == 0:
190 | logger.info("args: {}".format(args))
191 |
192 | # cudnn.enabled = True
193 |
194 | # Create network
195 | if args.model == 'DeepLab':
196 | if args.resume:
197 | resume_weight = torch.load(args.resume, map_location='cpu')
198 | print("args.resume: ", args.resume)
199 | # feature_extractor_weights = resume_weight['model_state_dict']
200 | model_B2_weights = resume_weight['model_B2_state_dict']
201 | model_B_weights = resume_weight['model_B_state_dict']
202 | head_weights = resume_weight['head_state_dict']
203 | classifier_weights = resume_weight['classifier_state_dict']
204 | # feature_extractor_weights = {k.replace("module.", ""):v for k,v in feature_extractor_weights.items()}
205 | model_B2_weights = {k.replace("module.", ""):v for k,v in model_B2_weights.items()}
206 | model_B_weights = {k.replace("module.", ""):v for k,v in model_B_weights.items()}
207 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()}
208 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()}
209 |
210 | if gpu == 0:
211 | logger.info("freeze_bn: {}".format(args.freeze_bn))
212 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn)
213 |
214 | if args.layer == 0:
215 | ndf = 64
216 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool)
217 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
218 | elif args.layer == 1:
219 | ndf = 256
220 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1)
221 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
222 | elif args.layer == 2:
223 | ndf = 512
224 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2)
225 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4)
226 |
227 | if args.resume:
228 | model_B2.load_state_dict(model_B2_weights)
229 | model_B.load_state_dict(model_B_weights)
230 |
231 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim)
232 | head, classifier = classifier.head, classifier.classifier
233 | if args.resume:
234 | head.load_state_dict(head_weights)
235 | classifier.load_state_dict(classifier_weights)
236 |
237 | model_B2.train()
238 | model_B.train()
239 | head.train()
240 | classifier.train()
241 |
242 | if gpu == 0:
243 | logger.info(model_B2)
244 | logger.info(model_B)
245 | logger.info(head)
246 | logger.info(classifier)
247 | else:
248 | logger = None
249 |
250 | if gpu == 0:
251 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast))
252 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")]
253 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')]
254 |
255 | datasets = create_dataset(args, logger)
256 | sourceloader_iter = enumerate(datasets.source_train_loader)
257 | targetloader_iter = enumerate(datasets.target_train_loader)
258 |
259 | # define model
260 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2)
261 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True)
262 |
263 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B)
264 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True)
265 |
266 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head)
267 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True)
268 |
269 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
270 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True)
271 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label)
272 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
273 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
274 |
275 | # labels for adversarial training
276 | source_label = 0
277 | target_label = 1
278 |
279 | # set up tensor board
280 | if args.tensorboard and gpu == 0:
281 | writer = SummaryWriter(args.snapshot_dir)
282 |
283 | validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_train_loader, args.output_folder)
284 | # exit()
285 |
286 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader, output_folder):
287 | if gpu == 0:
288 | logger.info("Start Evaluation")
289 | # evaluate
290 | loss_meter = AverageMeter()
291 | intersection_meter = AverageMeter()
292 | union_meter = AverageMeter()
293 |
294 | model_B2.eval()
295 | model_B.eval()
296 | head.eval()
297 | classifier.eval()
298 |
299 | with torch.no_grad():
300 | for i, batch in enumerate(testloader):
301 | images = batch["img_full"].cuda()
302 | labels = batch["lbl_full"].cuda()
303 | img_paths = batch['img_path']
304 |
305 | pred = model_B(model_B2(images))
306 | pred = classifier(head(pred))
307 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True)
308 | loss = seg_loss(output, labels)
309 |
310 | output = F.softmax(output, 1)
311 |
312 |             output_np = pred.detach().cpu().numpy()  # keep the batch dimension; .squeeze() would break a batch of size 1
313 |
314 | logits, output = output.max(1)
315 |
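    |             # Dump the raw class-score map (before softmax and upsampling) for each image as
    |             # "<image name>.npy" under output_folder. These per-image score maps are the soft
    |             # labels, presumably loaded later via the soft-labels folder option of train_phase2.py.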
316 | for b in range(output_np.shape[0]):
317 | mask_filename = img_paths[b].split("/")[-1].split(".")[0]
318 | np.save(os.path.join(output_folder, mask_filename+".npy"), output_np[b])
319 |
320 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label)
321 | dist.all_reduce(intersection), dist.all_reduce(union)
322 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
323 | intersection_meter.update(intersection), union_meter.update(union)
324 | loss_meter.update(loss.item(), images.size(0))
325 | if gpu == 0 and i % 50 == 0 and i != 0:
326 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format(
327 | i, len(testloader), loss_meter.val
328 | ))
329 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
330 | miou = np.mean(iou_class)
331 | if gpu == 0:
332 | logger.info("Val result: mIoU = {:.3f}".format(miou))
333 | for i in range(args.num_classes):
334 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i]))
335 | logger.info("End Evaluation")
336 |
337 | return miou, loss_meter.avg
338 |
339 | def find_free_port():
340 | import socket
341 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
342 | # Binding to port 0 will cause the OS to find an available port for us
343 | sock.bind(("", 0))
344 | port = sock.getsockname()[1]
345 | sock.close()
346 | # NOTE: there is still a chance the port could be taken by other processes.
347 | return port
348 |
349 | if __name__ == '__main__':
350 | args.gpus = [int(x) for x in args.gpus.split(",")]
351 | args.world_size = len(args.gpus)
352 |
353 | os.makedirs(args.output_folder, exist_ok=True)
354 |
355 | if args.dist:
356 | port = find_free_port()
357 | args.dist_url = f"tcp://127.0.0.1:{port}"
358 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url))
359 |     else:
360 |         # single-process fallback: main_worker still calls dist.init_process_group, so give it a local dist_url
361 |         args.dist_url = f"tcp://127.0.0.1:{find_free_port()}"
362 |         main_worker(args.gpus[0], 1, args.dist_url)
--------------------------------------------------------------------------------
/train_phase2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import pickle
6 | import torch.optim as optim
7 | import scipy.misc
8 | import torch.backends.cudnn as cudnn
9 | import torch.nn.functional as F
10 | import sys
11 | import os
12 | import os.path as osp
13 | import random
14 | import logging
15 | import time
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 | from tensorboardX import SummaryWriter
19 |
20 | from model.feature_extractor import resnet_feature_extractor
21 | from model.classifier import ASPP_Classifier_Gen
22 | from model.discriminator import FCDiscriminator
23 |
24 | from utils.util import *
25 | from data import create_dataset
26 | import cv2
27 |
28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)
30 |
31 | MODEL = 'DeepLab'
32 | BATCH_SIZE = 1
33 | ITER_SIZE = 1
34 | NUM_WORKERS = 16
35 | IGNORE_LABEL = 250
36 | LEARNING_RATE = 2.5e-4
37 | MOMENTUM = 0.9
38 | NUM_CLASSES = 19
39 | NUM_STEPS = 62500
40 | NUM_STEPS_STOP = 40000 # early stopping
41 | POWER = 0.9
42 | RANDOM_SEED = 1234
43 | RESUME = './pretrained/model_phase1.pth'
44 | SAVE_NUM_IMAGES = 2
45 | SAVE_PRED_EVERY = 1000
46 | SNAPSHOT_DIR = './snapshots/'
47 | WEIGHT_DECAY = 0.0005
48 | LOG_DIR = './log'
49 |
50 | LEARNING_RATE_D = 1e-4
51 | LAMBDA_SEG = 0.1
52 | LAMBDA_ADV_TARGET1 = 0.0002
53 | LAMBDA_ADV_TARGET2 = 0.001
54 |
55 | SET = 'train'
56 |
57 | def get_arguments():
58 | """Parse all the arguments provided from the CLI.
59 |
60 | Returns:
61 | A list of parsed arguments.
62 | """
63 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
64 | parser.add_argument("--model", type=str, default=MODEL,
65 | help="available options : DeepLab")
66 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
67 | help="Number of images sent to the network in one step.")
68 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE,
69 | help="Accumulate gradients for ITER_SIZE iterations.")
70 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS,
71 | help="number of workers for multithread dataloading.")
72 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
73 | help="The index of the label to ignore during the training.")
74 | parser.add_argument("--is-training", action="store_true",
75 |                         help="Whether to update the running means and variances during training.")
76 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
77 | help="Base learning rate for training with polynomial decay.")
78 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D,
79 | help="Base learning rate for discriminator.")
80 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG,
81 | help="lambda_seg.")
82 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1,
83 | help="lambda_adv for adversarial training.")
84 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2,
85 | help="lambda_adv for adversarial training.")
86 | parser.add_argument("--momentum", type=float, default=MOMENTUM,
87 | help="Momentum component of the optimiser.")
88 | parser.add_argument("--not-restore-last", action="store_true",
89 | help="Whether to not restore last (FC) layers.")
90 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
91 | help="Number of classes to predict (including background).")
92 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
93 | help="Number of training steps.")
94 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP,
95 | help="Number of training steps for early stopping.")
96 | parser.add_argument("--power", type=float, default=POWER,
97 | help="Decay parameter to compute the learning rate.")
98 | parser.add_argument("--random-mirror", action="store_true",
99 | help="Whether to randomly mirror the inputs during the training.")
100 | parser.add_argument("--random-scale", action="store_true",
101 | help="Whether to randomly scale the inputs during the training.")
102 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
103 | help="Random seed to have reproducible results.")
104 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
105 | help="How many images to save.")
106 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
107 |                         help="Save summaries and a checkpoint every this many iterations.")
108 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
109 | help="Where to save snapshots of the model.")
110 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
111 | help="Regularisation parameter for L2-loss.")
112 | parser.add_argument("--cpu", action='store_true', help="choose to use cpu device.")
113 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.")
114 | parser.add_argument("--log-dir", type=str, default=LOG_DIR,
115 | help="Path to the directory of log.")
116 | parser.add_argument("--set", type=str, default=SET,
117 | help="choose adaptation set.")
118 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus")
119 | parser.add_argument("--dist", action="store_true", help="DDP")
120 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node')
121 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations')
122 |
123 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset')
124 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset')
125 |     parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='validation target dataset')
126 |     parser.add_argument("--noaug", action="store_true", help="disable data augmentation")
127 | parser.add_argument('--resize', type=int, default=2200, help='resize long size')
128 | parser.add_argument("--clrjit_params", type=str, default="0.5,0.5,0.5,0.2", help='brightness,contrast,saturation,hue')
129 |     parser.add_argument('--rcrop', type=str, default='896,512', help='random crop size')
130 |     parser.add_argument('--hflip', type=float, default=0.5, help='random flip probability')
131 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5')
132 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes')
133 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle')
134 | parser.add_argument('--no_droplast', action='store_true')
135 | parser.add_argument('--pseudo_labels_folder', type=str, default='')
136 | parser.add_argument('--soft_labels_folder', type=str, default='')
137 | parser.add_argument('--src_loss_weight', type=float, default=1.0, help='loss weight for source domain loss')
138 |     parser.add_argument('--thresholds_path', type=str, default="avg", help='path to the .npy file of class-wise thresholds (avg | pred_only | fix_only variants)')
139 |
140 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation')
141 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight')
142 |     parser.add_argument("--freeze_bn", action="store_true", help="freeze batch normalization layers")
143 |     parser.add_argument("--hidden_dim", type=int, default=128, help='hidden dimension of the classifier head')
144 | parser.add_argument("--layer", type=int, default=1, help='separate from which layer')
145 | return parser.parse_args()
146 |
147 |
148 | args = get_arguments()
149 |
150 |
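    | # Soft-label cross-entropy: for every pixel, sum_c -q_c * log softmax(pred)_c, averaged over
    | # all pixels; pixel_weights, if given, reweights each pixel's contribution.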
151 | def soft_label_cross_entropy(pred, soft_label, pixel_weights=None):
152 | N, C, H, W = pred.shape
153 | loss = -soft_label.float()*F.log_softmax(pred, dim=1)
154 | if pixel_weights is None:
155 | return torch.mean(torch.sum(loss, dim=1))
156 | return torch.mean(pixel_weights*torch.sum(loss, dim=1))
157 |
158 |
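    | # Polynomial learning-rate decay: lr = base_lr * (1 - iter / max_iter) ** power.
    | # The second optimizer param group (head + classifier) is kept at 10x the backbone rate.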
159 | def lr_poly(base_lr, iter, max_iter, power):
160 | return base_lr * ((1 - float(iter) / max_iter) ** (power))
161 |
162 |
163 | def adjust_learning_rate(optimizer, i_iter):
164 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
165 | optimizer.param_groups[0]['lr'] = lr
166 | if len(optimizer.param_groups) > 1:
167 | optimizer.param_groups[1]['lr'] = lr * 10
168 |
169 |
170 | def adjust_learning_rate_D(optimizer, i_iter):
171 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
172 | optimizer.param_groups[0]['lr'] = lr
173 | if len(optimizer.param_groups) > 1:
174 | optimizer.param_groups[1]['lr'] = lr * 10
175 |
176 |
177 | def main_worker(gpu, world_size, dist_url):
178 | """Create the model and start the training."""
179 | if gpu == 0:
180 | if not os.path.exists(args.snapshot_dir):
181 | os.makedirs(args.snapshot_dir)
182 | logFilename = os.path.join(args.snapshot_dir, str(time.time()))
183 | logging.basicConfig(
184 | level = logging.INFO,
185 | format ='%(asctime)s-%(levelname)s-%(message)s',
186 | datefmt = '%y-%m-%d %H:%M',
187 | filename = logFilename,
188 | filemode = 'w+')
189 | filehandler = logging.FileHandler(logFilename, encoding='utf-8')
190 | logger = logging.getLogger()
191 | logger.addHandler(filehandler)
192 | handler = logging.StreamHandler()
193 | logger.addHandler(handler)
194 | logger.info(args)
195 |
196 | np.random.seed(args.random_seed)
197 | random.seed(args.random_seed)
198 | torch.manual_seed(args.random_seed)
199 | torch.cuda.manual_seed(args.random_seed)
200 | # torch.backends.cudnn.deterministic = True
201 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU.
202 | # torch.backends.cudnn.enabled = False
203 |
204 | print("gpu: {}, world_size: {}".format(gpu, world_size))
205 | print("dist_url: ", dist_url)
206 |
207 | torch.cuda.set_device(gpu)
208 | args.batch_size = args.batch_size // world_size
209 | args.batch_size_val = args.batch_size_val // world_size
210 | args.num_workers = args.num_workers // world_size
211 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu)
212 |
213 | if gpu == 0:
214 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val))
215 |
216 | device = torch.device("cuda" if not args.cpu else "cpu")
217 |
218 | args.world_size = world_size
219 |
220 | if gpu == 0:
221 | logger.info("args: {}".format(args))
222 |
223 | # cudnn.enabled = True
224 |
225 | # Create network
226 | if args.model == 'DeepLab':
227 | if args.resume:
228 | resume_weight = torch.load(args.resume, map_location='cpu')
229 | print("args.resume: ", args.resume)
230 | # feature_extractor_weights = resume_weight['model_state_dict']
231 | model_B2_weights = resume_weight['model_B2_state_dict']
232 | model_B_weights = resume_weight['model_B_state_dict']
233 | head_weights = resume_weight['head_state_dict']
234 | classifier_weights = resume_weight['classifier_state_dict']
235 | model_B2_weights = {k.replace("module.", ""):v for k,v in model_B2_weights.items()}
236 | model_B_weights = {k.replace("module.", ""):v for k,v in model_B_weights.items()}
237 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()}
238 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()}
239 |
240 | if gpu == 0:
241 | logger.info("freeze_bn: {}".format(args.freeze_bn))
242 |
243 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn)
244 |
245 | if args.layer == 0:
246 | ndf = 64
247 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool)
248 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
249 | elif args.layer == 1:
250 | ndf = 256
251 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1)
252 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
253 | elif args.layer == 2:
254 | ndf = 512
255 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2)
256 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4)
257 |
258 | if args.resume:
259 | model_B2.load_state_dict(model_B2_weights)
260 | model_B.load_state_dict(model_B_weights)
261 |
262 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim)
263 | head, classifier = classifier.head, classifier.classifier
264 | if args.resume:
265 | head.load_state_dict(head_weights)
266 | classifier.load_state_dict(classifier_weights)
267 |
268 | model_B2.train()
269 | model_B.train()
270 | head.train()
271 | classifier.train()
272 |
273 | if gpu == 0:
274 | logger.info(model_B2)
275 | logger.info(model_B)
276 | logger.info(head)
277 | logger.info(classifier)
278 | else:
279 | logger = None
280 |
281 | if gpu == 0:
282 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast))
283 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")]
284 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')]
285 |
286 | datasets = create_dataset(args, logger)
287 | sourceloader_iter = enumerate(datasets.source_train_loader)
288 | targetloader_iter = enumerate(datasets.target_train_loader)
289 |
290 | # define optimizer
291 | model_params = [{'params': list(model_B2.parameters()) + list(model_B.parameters())},
292 | {'params': list(head.parameters()) + list(classifier.parameters()), 'lr': args.learning_rate * 10}]
293 | optimizer = optim.SGD(model_params, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
294 | assert len(optimizer.param_groups) == 2
295 | optimizer.zero_grad()
296 |
297 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2)
298 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True)
299 |
300 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B)
301 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True)
302 |
303 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head)
304 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True)
305 |
306 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
307 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True)
308 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label)
309 |
310 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
311 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
312 |
313 | # labels for adversarial training
314 | source_label = 0
315 | target_label = 1
316 |
317 | # set up tensor board
318 | if args.tensorboard and gpu == 0:
319 | writer = SummaryWriter(args.snapshot_dir)
320 |
321 | # # Uncomment the following two lines for testing
322 | # validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader)
323 | # exit()
324 |
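    |     # Class-wise confidence thresholds for pseudo-label filtering; thresholds_path is expected
    |     # to be a .npy file holding one float per class (19 values for the classes listed below).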
325 | thresholds = np.load(args.thresholds_path)
326 | class_list = ["road","sidewalk","building","wall",
327 | "fence","pole","traffic_light","traffic_sign","vegetation",
328 | "terrain","sky","person","rider","car",
329 | "truck","bus","train","motorcycle","bicycle"]
330 | if gpu == 0:
331 | logger.info('successfully load class-wise thresholds from {}'.format(args.thresholds_path))
332 | for c in range(len(class_list)):
333 | logger.info("class {}: {}, threshold: {}".format(c, class_list[c], thresholds[c]))
334 | thresholds = torch.from_numpy(thresholds).cuda()
335 |
336 | scaler = torch.cuda.amp.GradScaler()
337 | best_miou = 0.0
338 | filename = None
339 | epoch_s, epoch_t = 0, 0
340 | for i_iter in range(args.num_steps):
341 |
342 | model_B2.train()
343 | model_B.train()
344 | head.train()
345 | classifier.train()
346 |
347 | loss_seg_value = 0
348 | loss_src_seg_value = 0
349 |
350 | optimizer.zero_grad()
351 | adjust_learning_rate(optimizer, i_iter)
352 |
353 | for sub_i in range(args.iter_size):
354 |
355 | # train with source
356 | try:
357 | _, batch = sourceloader_iter.__next__()
358 | except StopIteration:
359 | epoch_s += 1
360 | datasets.source_train_sampler.set_epoch(epoch_s)
361 | sourceloader_iter = enumerate(datasets.source_train_loader)
362 | _, batch = sourceloader_iter.__next__()
363 |
364 | images = batch['img'].cuda()
365 | labels = batch['label'].cuda()
366 | src_size = images.shape[-2:]
367 |
368 | with torch.cuda.amp.autocast():
369 |
370 | feat_src = model_B2(images)
371 | feat_B_src = model_B(feat_src)
372 | pred = classifier(head(feat_B_src))
373 | pred = interp(pred) #[b, num_classes, h, w]
374 |
375 | loss_seg = seg_loss(pred, labels)
376 |
377 | loss = loss_seg
378 |
379 | # proper normalization
380 | loss = args.src_loss_weight * loss / args.iter_size
381 | loss_src_seg_value += loss_seg / args.iter_size
382 |
383 | scaler.scale(loss).backward()
384 |
385 | # train with target
386 | try:
387 | _, batch = targetloader_iter.__next__()
388 | except StopIteration:
389 | epoch_t += 1
390 | datasets.target_train_sampler.set_epoch(epoch_t)
391 | targetloader_iter = enumerate(datasets.target_train_loader)
392 | _, batch = targetloader_iter.__next__()
393 |
394 | images = batch['img'].cuda()
395 | soft_labels = batch['lpsoft'].cuda()
396 | tgt_size = images.shape[-2:]
397 |
398 | with torch.no_grad():
399 |
400 |
401 | soft_labels = F.softmax(soft_labels, 1)
402 |
403 |
404 | images_full = batch['img_full'].cuda()
405 | weak_params = batch['weak_params']
406 | resize_params = weak_params['RandomSized']
407 | crop_params = weak_params['RandomCrop']
408 | flip_params = weak_params['RandomHorizontallyFlip']
409 | # print("resize_params: ", resize_params)
410 | # print("crop_params: ", crop_params)
411 | # print("flip_params: ", flip_params)
412 | with torch.cuda.amp.autocast():
413 | with torch.no_grad():
414 | pred_full = F.softmax(interp(classifier(head(model_B(model_B2(images_full))))), 1)
415 |
416 | # print("v1 pred_full.min(): {}, pred_full.max(): {}, pred_full.mean(): {}".format(pred_full.min(), pred_full.max(), pred_full.mean()))
417 |
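    |                 # Undo the weak augmentation (resize -> crop -> horizontal flip) per sample so the
    |                 # full-image prediction aligns pixel-wise with the cropped view; it is then averaged
    |                 # with the stored soft labels below to form the fused pseudo-label distribution.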
418 | pred_labels = []
419 | for b in range(pred_full.shape[0]):
420 | # restore pred_full to crop
421 | # 1.Resize
422 | h, w = resize_params[0][b], resize_params[1][b]
423 | pred_resize_b = F.interpolate(pred_full[b].unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)[0]
424 | # 2.Crop
425 | ys, ye, xs, xe = crop_params[0][b], crop_params[1][b], crop_params[2][b], crop_params[3][b]
426 | pred_crop_b = pred_resize_b[:, ys:ye, xs:xe]
427 | # 3.Flip
428 | if flip_params[b]:
429 | pred_crop_b = torch.flip(pred_crop_b, dims=(2,)) #[c, h, w]
430 | pred_labels.append(pred_crop_b)
431 | pred_labels = torch.stack(pred_labels, 0)
432 | assert pred_labels.shape[-2:] == tgt_size
433 | pseudo_labels = (pred_labels + soft_labels) / 2.0
434 |
435 |
436 | with torch.cuda.amp.autocast():
437 | feat_tgt = model_B2(images)
438 | feat_B_tgt = model_B(feat_tgt)
439 | pred = classifier(head(feat_B_tgt))
440 | pred = interp(pred) #[b, num_classes, h, w]
441 |
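    |             # Class-wise confidence filtering: keep a pixel's pseudo label only if its fused
    |             # confidence reaches the threshold of its predicted class; the rest are set to
    |             # ignore_label and excluded from the target segmentation loss.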
442 | conf, pseudo_labels = pseudo_labels.max(1) #[b, h, w]
443 |
444 | pseudo_labels[conf < thresholds[pseudo_labels]] = args.ignore_label
445 | pseudo_labels = pseudo_labels.detach()
446 | loss_seg = seg_loss(pred, pseudo_labels)
447 |
448 | loss = loss_seg
449 |
450 | # proper normalization
451 | loss = loss / args.iter_size
452 | loss_seg_value += loss_seg / args.iter_size
453 |
454 | scaler.scale(loss).backward()
455 |
456 | n = torch.tensor(1.0).cuda()
457 |
458 | dist.all_reduce(n), dist.all_reduce(loss_seg_value), dist.all_reduce(loss_src_seg_value)
459 |
460 | loss_seg_value = loss_seg_value.item() / n.item()
461 | loss_src_seg_value = loss_src_seg_value.item() / n.item()
462 |
463 | scaler.step(optimizer)
464 | scaler.update()
465 |
466 | if args.tensorboard and gpu == 0:
467 | scalar_info = {
468 | 'loss_seg': loss_seg_value,
469 | 'loss_src_seg': loss_src_seg_value,
470 | }
471 |
472 | if i_iter % 10 == 0:
473 | for key, val in scalar_info.items():
474 | writer.add_scalar(key, val, i_iter)
475 |
476 | if gpu == 0 and i_iter % args.print_every == 0:
477 | logger.info('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_src_seg = {3:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_src_seg_value))
478 |
479 | if gpu == 0 and i_iter >= args.num_steps_stop - 1:
480 | logger.info('save model ...')
481 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth')
482 | save_file = {'model_B2_state_dict': model_B2.state_dict(), 'model_B_state_dict': model_B.state_dict(), \
483 | 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()}
484 | torch.save(save_file, filename)
485 | logger.info("saving checkpoint model to {}".format(filename))
486 | break
487 |
488 | if i_iter % args.save_pred_every == 0 and i_iter != 0:
489 | miou, loss_val = validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader)
490 | if args.tensorboard and gpu == 0:
491 | scalar_info = {
492 | 'miou_val': miou,
493 | 'loss_val': loss_val
494 | }
495 | for k, v in scalar_info.items():
496 | writer.add_scalar(k, v, i_iter)
497 |
498 | if gpu == 0 and miou > best_miou:
499 | best_miou = miou
500 | logger.info('taking snapshot ...')
501 | if filename is not None and os.path.exists(filename):
502 | os.remove(filename)
503 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + "_{}".format(miou) + '.pth')
504 | save_file = {'model_B2_state_dict': model_B2.state_dict(), 'model_B_state_dict': model_B.state_dict(), \
505 | 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()}
506 | torch.save(save_file, filename)
507 | logger.info("saving checkpoint model to {}".format(filename))
508 |
509 | if args.tensorboard and gpu == 0:
510 | writer.close()
511 |
512 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader):
513 | if gpu == 0:
514 | logger.info("Start Evaluation")
515 | # evaluate
516 | loss_meter = AverageMeter()
517 | intersection_meter = AverageMeter()
518 | union_meter = AverageMeter()
519 |
520 | model_B2.eval()
521 | model_B.eval()
522 | head.eval()
523 | classifier.eval()
524 |
525 | with torch.no_grad():
526 | for i, batch in enumerate(testloader):
527 | images = batch["img"].cuda()
528 | labels = batch["label"].cuda()
529 |
530 | pred = model_B(model_B2(images))
531 | pred = classifier(head(pred))
532 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True)
533 | loss = seg_loss(output, labels)
534 |
535 | output = output.max(1)[1]
536 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label)
537 | dist.all_reduce(intersection), dist.all_reduce(union)
538 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
539 | intersection_meter.update(intersection), union_meter.update(union)
540 | loss_meter.update(loss.item(), images.size(0))
541 | if gpu == 0 and i % 50 == 0 and i != 0:
542 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format(
543 | i, len(testloader), loss_meter.val
544 | ))
545 |
546 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
547 | miou = np.mean(iou_class)
548 | if gpu == 0:
549 | logger.info("Val result: mIoU = {:.3f}".format(miou))
550 | for i in range(args.num_classes):
551 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i]))
552 | logger.info("End Evaluation")
553 |
554 | torch.cuda.empty_cache()
555 |
556 | return miou, loss_meter.avg
557 |
558 | def find_free_port():
559 | import socket
560 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
561 | # Binding to port 0 will cause the OS to find an available port for us
562 | sock.bind(("", 0))
563 | port = sock.getsockname()[1]
564 | sock.close()
565 | # NOTE: there is still a chance the port could be taken by other processes.
566 | return port
567 |
568 | if __name__ == '__main__':
569 | args.gpus = [int(x) for x in args.gpus.split(",")]
570 | args.world_size = len(args.gpus)
571 | if args.dist:
572 | port = find_free_port()
573 | args.dist_url = f"tcp://127.0.0.1:{port}"
574 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url))
575 |     else:
576 |         # single-process fallback: main_worker still calls dist.init_process_group, so give it a local dist_url
577 |         args.dist_url = f"tcp://127.0.0.1:{find_free_port()}"
578 |         main_worker(args.gpus[0], 1, args.dist_url)
--------------------------------------------------------------------------------
/train_phase1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import pickle
6 | import torch.optim as optim
7 | import scipy.misc
8 | import torch.backends.cudnn as cudnn
9 | import torch.nn.functional as F
10 | import sys
11 | import os
12 | import os.path as osp
13 | import random
14 | import logging
15 | import time
16 | import torch.distributed as dist
17 | import torch.multiprocessing as mp
18 | from tensorboardX import SummaryWriter
19 |
20 | from model.feature_extractor import resnet_feature_extractor
21 | from model.classifier import ASPP_Classifier_Gen
22 | from model.discriminator import FCDiscriminator
23 |
24 | from utils.util import *
25 | from data import create_dataset
26 | import cv2
27 |
28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)
30 |
31 | MODEL = 'DeepLab'
32 | BATCH_SIZE = 1
33 | ITER_SIZE = 1
34 | NUM_WORKERS = 16
35 | IGNORE_LABEL = 250
36 | LEARNING_RATE = 2.5e-4
37 | MOMENTUM = 0.9
38 | NUM_CLASSES = 19
39 | NUM_STEPS = 93750
40 | NUM_STEPS_STOP = 60000 # early stopping
41 | POWER = 0.9
42 | RANDOM_SEED = 1234
43 | RESUME = './pretrained/sourceonly.pth'
44 | SAVE_NUM_IMAGES = 2
45 | SAVE_PRED_EVERY = 1000
46 | SNAPSHOT_DIR = './snapshots/'
47 | WEIGHT_DECAY = 0.0005
48 | LOG_DIR = './log'
49 |
50 | LEARNING_RATE_D = 1e-4
51 | LAMBDA_SEG = 0.1
52 | LAMBDA_ADV_TARGET1 = 0.0002
53 | LAMBDA_ADV_TARGET2 = 0.001
54 | GAN = 'LS' #'Vanilla'
55 |
56 | SET = 'train'
57 |
58 | def get_arguments():
59 | """Parse all the arguments provided from the CLI.
60 |
61 | Returns:
62 | A list of parsed arguments.
63 | """
64 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network")
65 | parser.add_argument("--model", type=str, default=MODEL,
66 | help="available options : DeepLab")
67 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
68 | help="Number of images sent to the network in one step.")
69 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE,
70 | help="Accumulate gradients for ITER_SIZE iterations.")
71 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS,
72 | help="number of workers for multithread dataloading.")
73 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL,
74 | help="The index of the label to ignore during the training.")
75 | parser.add_argument("--is-training", action="store_true",
76 |                         help="Whether to update the running means and variances during training.")
77 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE,
78 | help="Base learning rate for training with polynomial decay.")
79 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D,
80 | help="Base learning rate for discriminator.")
81 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG,
82 | help="lambda_seg.")
83 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1,
84 | help="lambda_adv for adversarial training.")
85 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2,
86 | help="lambda_adv for adversarial training.")
87 | parser.add_argument("--momentum", type=float, default=MOMENTUM,
88 | help="Momentum component of the optimiser.")
89 | parser.add_argument("--not-restore-last", action="store_true",
90 | help="Whether to not restore last (FC) layers.")
91 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES,
92 | help="Number of classes to predict (including background).")
93 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS,
94 | help="Number of training steps.")
95 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP,
96 | help="Number of training steps for early stopping.")
97 | parser.add_argument("--power", type=float, default=POWER,
98 | help="Decay parameter to compute the learning rate.")
99 | parser.add_argument("--random-mirror", action="store_true",
100 | help="Whether to randomly mirror the inputs during the training.")
101 | parser.add_argument("--random-scale", action="store_true",
102 | help="Whether to randomly scale the inputs during the training.")
103 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED,
104 | help="Random seed to have reproducible results.")
105 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES,
106 | help="How many images to save.")
107 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY,
108 |                         help="Save summaries and a checkpoint every this many iterations.")
109 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR,
110 | help="Where to save snapshots of the model.")
111 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY,
112 | help="Regularisation parameter for L2-loss.")
113 | parser.add_argument("--cpu", action='store_true', help="choose to use cpu device.")
114 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.")
115 | parser.add_argument("--log-dir", type=str, default=LOG_DIR,
116 | help="Path to the directory of log.")
117 | parser.add_argument("--set", type=str, default=SET,
118 | help="choose adaptation set.")
119 | parser.add_argument("--gan", type=str, default=GAN,
120 | help="choose the GAN objective.")
121 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus")
122 | parser.add_argument("--dist", action="store_true", help="DDP")
123 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node')
124 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations')
125 |
126 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset')
127 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset')
128 |     parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='validation target dataset')
129 |     parser.add_argument("--noaug", action="store_true", help="disable data augmentation")
130 | parser.add_argument('--resize', type=int, default=2200, help='resize long size')
131 | parser.add_argument("--clrjit_params", type=str, default="0.0,0.0,0.0,0.0", help='brightness,contrast,saturation,hue')
132 |     parser.add_argument('--rcrop', type=str, default='896,512', help='random crop size')
133 |     parser.add_argument('--hflip', type=float, default=0.5, help='random flip probability')
134 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5')
135 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes')
136 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle')
137 | parser.add_argument('--no_droplast', action='store_true')
138 | parser.add_argument('--pseudo_labels_folder', type=str, default='')
139 | parser.add_argument('--conf_bank_length', type=int, default=100000)
140 | parser.add_argument('--conf_p', type=float, default=0.8)
141 |
142 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation')
143 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight')
144 |     parser.add_argument("--freeze_bn", action="store_true", help="freeze batch normalization layers")
145 | parser.add_argument("--lambda_adv_src", type=float, default=0.1, help='weight for loss_adv_src')
146 | parser.add_argument("--lambda_adv_tgt", type=float, default=0.01, help='weight for loss_adv_tgt')
147 |     parser.add_argument("--hidden_dim", type=int, default=128, help='hidden dimension of the classifier head')
148 | parser.add_argument("--layer", type=int, default=1, help='separate from which layer')
149 | parser.add_argument("--lambda_st", type=float, default=0.1, help='weight for loss_st')
150 | return parser.parse_args()
151 |
152 |
153 | args = get_arguments()
154 |
155 |
156 | def soft_label_cross_entropy(pred, soft_label, pixel_weights=None):
157 | N, C, H, W = pred.shape
158 | loss = -soft_label.float()*F.log_softmax(pred, dim=1)
159 | if pixel_weights is None:
160 | return torch.mean(torch.sum(loss, dim=1))
161 | return torch.mean(pixel_weights*torch.sum(loss, dim=1))
162 |
163 |
164 | def lr_poly(base_lr, iter, max_iter, power):
165 | return base_lr * ((1 - float(iter) / max_iter) ** (power))
166 |
167 |
168 | def adjust_learning_rate(optimizer, i_iter):
169 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
170 | optimizer.param_groups[0]['lr'] = lr
171 | if len(optimizer.param_groups) > 1:
172 | optimizer.param_groups[1]['lr'] = lr * 10
173 |
174 |
175 | def adjust_learning_rate_D(optimizer, i_iter):
176 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
177 | optimizer.param_groups[0]['lr'] = lr
178 | if len(optimizer.param_groups) > 1:
179 | optimizer.param_groups[1]['lr'] = lr * 10
180 |
181 |
182 | def main_worker(gpu, world_size, dist_url):
183 | """Create the model and start the training."""
184 | if gpu == 0:
185 | if not os.path.exists(args.snapshot_dir):
186 | os.makedirs(args.snapshot_dir)
187 | logFilename = os.path.join(args.snapshot_dir, str(time.time()))
188 | logging.basicConfig(
189 | level = logging.INFO,
190 | format ='%(asctime)s-%(levelname)s-%(message)s',
191 | datefmt = '%y-%m-%d %H:%M',
192 | filename = logFilename,
193 | filemode = 'w+')
194 | filehandler = logging.FileHandler(logFilename, encoding='utf-8')
195 | logger = logging.getLogger()
196 | logger.addHandler(filehandler)
197 | handler = logging.StreamHandler()
198 | logger.addHandler(handler)
199 | logger.info(args)
200 |
201 | np.random.seed(args.random_seed)
202 | random.seed(args.random_seed)
203 | torch.manual_seed(args.random_seed)
204 | torch.cuda.manual_seed(args.random_seed)
205 | # torch.backends.cudnn.deterministic = True
206 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU.
207 | # torch.backends.cudnn.enabled = False
208 |
209 | print("gpu: {}, world_size: {}".format(gpu, world_size))
210 | print("dist_url: ", dist_url)
211 |
212 | torch.cuda.set_device(gpu)
213 | args.batch_size = args.batch_size // world_size
214 | args.batch_size_val = args.batch_size_val // world_size
215 | args.num_workers = args.num_workers // world_size
216 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu)
217 |
218 | if gpu == 0:
219 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val))
220 |
221 | device = torch.device("cuda" if not args.cpu else "cpu")
222 |
223 | args.world_size = world_size
224 |
225 | if gpu == 0:
226 | logger.info("args: {}".format(args))
227 |
228 | # cudnn.enabled = True
229 |
230 | # Create network
231 | if args.model == 'DeepLab':
232 |
233 | if args.resume:
234 | resume_weight = torch.load(args.resume, map_location='cpu')
235 | print("args.resume: ", args.resume)
236 | feature_extractor_weights = resume_weight['model_state_dict']
237 | head_weights = resume_weight['head_state_dict']
238 | classifier_weights = resume_weight['classifier_state_dict']
239 | feature_extractor_weights = {k.replace("module.", ""):v for k,v in feature_extractor_weights.items()}
240 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()}
241 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()}
242 |
243 | if gpu == 0:
244 | logger.info("freeze_bn: {}".format(args.freeze_bn))
245 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn)
246 | if args.resume:
247 | model.load_state_dict(feature_extractor_weights)
248 |
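    |         # Decoupled backbone: model_B1 (from this first backbone copy) is the shallow stem used
    |         # for source images; a second copy built below provides model_B2, the stem used for target
    |         # images, while model_B holds the deeper stages shared by both branches. args.layer sets
    |         # the split depth and ndf is the feature width at the split point.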
249 | if args.layer == 0:
250 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool)
251 | elif args.layer == 1:
252 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1)
253 | elif args.layer == 2:
254 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2)
255 |
256 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn)
257 | if args.resume:
258 | model.load_state_dict(feature_extractor_weights)
259 |
260 | if args.layer == 0:
261 | ndf = 64
262 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool)
263 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
264 | elif args.layer == 1:
265 | ndf = 256
266 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1)
267 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4)
268 | elif args.layer == 2:
269 | ndf = 512
270 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2)
271 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4)
272 |
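    |         # Two discriminators: model_D1 takes the ndf-channel features at the split point,
    |         # model_D2 takes num_classes-channel prediction maps.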
273 | model_D1 = FCDiscriminator(ndf, ndf=32)
274 | model_D2 = FCDiscriminator(args.num_classes, ndf=64)
275 |
276 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim)
277 | head, classifier = classifier.head, classifier.classifier
278 | if args.resume:
279 | head.load_state_dict(head_weights)
280 | classifier.load_state_dict(classifier_weights)
281 |
282 | aux_classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim)
283 | _, aux_classifier = aux_classifier.head, aux_classifier.classifier
284 | if args.resume:
285 | aux_classifier.load_state_dict(classifier_weights)
286 |
287 | model_B1.train()
288 | model_B2.train()
289 | model_B.train()
290 | model_D1.train()
291 | model_D2.train()
292 | head.train()
293 | classifier.train()
294 | aux_classifier.train()
295 |
296 | # cudnn.benchmark = True
297 | if gpu == 0:
298 | logger.info(model_B1)
299 | logger.info(model_B2)
300 | logger.info(model_B)
301 | logger.info(model_D1)
302 | logger.info(model_D2)
303 | logger.info(head)
304 | logger.info(classifier)
305 | logger.info(aux_classifier)
306 | else:
307 | logger = None
308 |
309 | if gpu == 0:
310 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast))
311 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")]
312 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')]
313 |
314 | datasets = create_dataset(args, logger)
315 |
316 | # define optimizer
317 | model_params = [{'params': list(model_B1.parameters()) + list(model_B2.parameters()) + list(model_B.parameters())},
318 | {'params': list(head.parameters()) + list(classifier.parameters()) + \
319 | list(aux_classifier.parameters()), 'lr': args.learning_rate * 10}]
320 | optimizer = optim.SGD(model_params, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
321 | assert len(optimizer.param_groups) == 2
322 | optimizer.zero_grad()
323 |
324 | optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
325 | optimizer_D1.zero_grad()
326 |
327 | optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
328 | optimizer_D2.zero_grad()
329 |
330 | # define model
331 | model_B1 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B1)
332 | model_B1 = torch.nn.parallel.DistributedDataParallel(model_B1.cuda(), device_ids=[gpu], find_unused_parameters=True)
333 |
334 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2)
335 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True)
336 |
337 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B)
338 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True)
339 |
340 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head)
341 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True)
342 |
343 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier)
344 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True)
345 |
346 | model_D1 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_D1)
347 | model_D1 = torch.nn.parallel.DistributedDataParallel(model_D1.cuda(), device_ids=[gpu], find_unused_parameters=True)
348 |
349 | model_D2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_D2)
350 | model_D2 = torch.nn.parallel.DistributedDataParallel(model_D2.cuda(), device_ids=[gpu], find_unused_parameters=True)
351 |
352 | aux_classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(aux_classifier)
353 | aux_classifier = torch.nn.parallel.DistributedDataParallel(aux_classifier.cuda(), device_ids=[gpu], find_unused_parameters=True)
354 |
355 | if args.gan == 'Vanilla':
356 | bce_loss = torch.nn.BCEWithLogitsLoss()
357 | elif args.gan == 'LS':
358 | bce_loss = torch.nn.MSELoss()
359 | if gpu == 0:
360 | logger.info("use LS-GAN")
361 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label)
362 |
363 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
364 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True)
365 |
366 | # labels for adversarial training
367 | source_label = 0
368 | target_label = 1
369 |
370 | # set up tensor board
371 | if args.tensorboard and gpu == 0:
372 | writer = SummaryWriter(args.snapshot_dir)
373 |
374 | if gpu == 0:
375 | logger.info("args.lambda_adv_src: {}, args.lambda_adv_tgt: {}".format(args.lambda_adv_src, args.lambda_adv_tgt))
376 |
377 | # validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader)
378 | # exit()
379 |
380 | trainloader_iter = enumerate(datasets.source_train_loader)
381 | targetloader_iter = enumerate(datasets.target_train_loader)
382 |
383 | conf_bank = {i: [] for i in range(args.num_classes)}
384 | thresholds = torch.zeros(args.num_classes).float().cuda()
385 | class_list = ["road","sidewalk","building","wall",
386 | "fence","pole","traffic_light","traffic_sign","vegetation",
387 | "terrain","sky","person","rider","car",
388 | "truck","bus","train","motorcycle","bicycle"]
389 |
390 | scaler = torch.cuda.amp.GradScaler()
391 | best_miou = 0.0
392 | filename = None
393 | epoch_s, epoch_t = 0, 0
394 | for i_iter in range(args.num_steps):
395 |
396 |         # set every sub-network to training mode
397 | model_B1.train()
398 | model_B2.train()
399 | model_B.train()
400 | model_D1.train()
401 | model_D2.train()
402 | head.train()
403 | classifier.train()
404 | aux_classifier.train()
405 |
406 | loss_seg_value = 0
407 | loss_adv_src_value = 0
408 | loss_adv_tgt_value = 0
409 | loss_D1_value = 0
410 | loss_D2_value = 0
411 | loss_st_value = 0
412 |
413 | optimizer.zero_grad()
414 | adjust_learning_rate(optimizer, i_iter)
415 | optimizer_D1.zero_grad()
416 | adjust_learning_rate_D(optimizer_D1, i_iter)
417 | optimizer_D2.zero_grad()
418 | adjust_learning_rate_D(optimizer_D2, i_iter)
419 |
420 | for sub_i in range(args.iter_size):
421 |
422 |             # train G: freeze both discriminators so adversarial gradients update only the segmentation branch
423 | for param in model_D1.parameters():
424 | param.requires_grad = False
425 | for param in model_D2.parameters():
426 | param.requires_grad = False
427 |
428 | # train with source
429 | try:
430 | _, batch = trainloader_iter.__next__()
431 | except StopIteration:
432 | epoch_s += 1
433 | datasets.source_train_sampler.set_epoch(epoch_s)
434 | trainloader_iter = enumerate(datasets.source_train_loader)
435 | _, batch = trainloader_iter.__next__()
436 |
437 | images = batch['img'].cuda()
438 | labels = batch['label'].cuda()
439 |
440 | src_size = images.shape[-2:]
441 | with torch.cuda.amp.autocast():
442 | feat_src = model_B1(images)
443 |
444 | feat_B_src = model_B(feat_src)
445 | pred = classifier(head(feat_B_src))
446 | pred = interp(pred) #[b, num_classes, h, w]
447 |
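    |                 # temperature-scale the source logits (T=1.8) before the cross-entropy loss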
448 | temperature = 1.8
449 | pred = pred.div(temperature)
450 | loss_seg = seg_loss(pred, labels)
451 |
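    |                 # adversarial loss on source features: push model_B1 features to be classified as target-like by D1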
452 | D_out = model_D1(F.interpolate(feat_src, size=src_size, mode='bilinear', align_corners=True))
453 |
454 | loss_adv_src = args.lambda_adv_src * bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(target_label).cuda())
455 | loss = loss_seg + loss_adv_src
456 |
457 |                 # normalize by iter_size so accumulated gradients match one full-size step
458 | loss = loss / args.iter_size
459 | loss_seg_value += loss_seg / args.iter_size
460 | loss_adv_src_value += loss_adv_src / args.iter_size
461 |
462 | scaler.scale(loss).backward()
463 |
464 | # train with target
465 | try:
466 | _, batch = targetloader_iter.__next__()
467 | except StopIteration:
468 | epoch_t += 1
469 | datasets.target_train_sampler.set_epoch(epoch_t)
470 | targetloader_iter = enumerate(datasets.target_train_loader)
471 | _, batch = targetloader_iter.__next__()
472 |
473 | images = batch['img'].cuda()
474 |
475 | tgt_size = images.shape[-2:]
476 | with torch.cuda.amp.autocast():
477 | feat_tgt = model_B2(images)
478 | feat_B_tgt = model_B(feat_tgt)
479 |
480 | feat_B_tgt_head = head(feat_B_tgt)
481 | pred_tgt = classifier(feat_B_tgt_head)
482 |
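    |                 # pseudo labels: per-pixel max softmax confidence and argmax class of the target prediction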
483 | with torch.no_grad():
484 | pred_logits, pred_idx = F.softmax(pred_tgt.detach(), 1).max(1) #[b, h, w]
485 | assert pred_logits.shape[-2:] == pred_tgt.shape[-2:]
486 |
487 |                     # update per-class thresholds: the top args.conf_p quantile of recent confidences in the bank
488 | for c in range(args.num_classes):
489 | prob_c = pred_logits[pred_idx == c].cpu().numpy().tolist()
490 | if len(prob_c) == 0:
491 | continue
492 | conf_bank[c].extend(prob_c)
493 | rank = int(len(conf_bank[c]) * args.conf_p)
494 | thresholds[c] = sorted(conf_bank[c], reverse=True)[rank]
495 | if len(conf_bank[c]) > args.conf_bank_length:
496 | conf_bank[c] = conf_bank[c][-args.conf_bank_length:]
497 |
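    |                     # average the per-class thresholds across all GPUs (n sums to the world size)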
498 | n = torch.tensor(1.0).cuda()
499 | dist.all_reduce(thresholds)
500 | dist.all_reduce(n)
501 | thresholds = thresholds / n
502 |
503 | if i_iter % 500 == 0 and gpu == 0:
504 | for c in range(args.num_classes):
505 | print("c: {}, class_i: {} threshold: {}, len(conf_bank[c]): {}".format(c, class_list[c], thresholds[c], len(conf_bank[c])))
506 |
507 | # if i_iter % 100 == 0 and gpu == 0:
508 | # num_pos = (pred_logits > thresholds[pred_idx]).float().sum()
509 | # num_all = np.prod(pred_logits.shape)
510 | # ratio = num_pos / (num_all+1e-8)
511 | # logger.info("num_pos: {}, num_all: {}, ratio: {}".format(num_pos, num_all, ratio))
512 |
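    |                     # discard low-confidence pixels: anything below its class threshold becomes ignore_label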
513 | pred_idx[pred_logits < thresholds[pred_idx]] = args.ignore_label
514 |
515 | pred_tgt = interp_target(pred_tgt)
516 | pred_tgt = pred_tgt.div(temperature)
517 |
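    |                 # self-training: the auxiliary classifier is supervised by the filtered pseudo labels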
518 | pred_tgt_aux = aux_classifier(feat_B_tgt_head)
519 | loss_st = args.lambda_st * seg_loss(pred_tgt_aux, pred_idx)
520 |
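    |                 # adversarial loss on target predictions: push them to be classified as source-like by D2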
521 | D_out = model_D2(F.softmax(pred_tgt, 1))
522 |
523 | loss_adv_tgt = args.lambda_adv_tgt * bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda())
524 | loss = loss_adv_tgt + loss_st
525 |
526 | loss = loss / args.iter_size
527 | loss_adv_tgt_value += loss_adv_tgt / args.iter_size
528 | loss_st_value += loss_st / args.iter_size
529 |
530 | scaler.scale(loss).backward()
531 |
532 | # train D
533 | # bring back requires_grad
534 | for param in model_D1.parameters():
535 | param.requires_grad = True
536 |
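    |             # D1 learns to separate domains at the feature level: detached source -> source_label, target -> target_label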
537 | optimizer_D1.zero_grad()
538 | with torch.cuda.amp.autocast():
539 | src_D1_pred = model_D1(F.interpolate(feat_src.detach(), size=src_size, mode='bilinear', align_corners=True))
540 | loss_D1_src = 0.5 * bce_loss(src_D1_pred, torch.FloatTensor(src_D1_pred.data.size()).fill_(source_label).cuda()) / args.iter_size
541 |
542 | scaler.scale(loss_D1_src).backward()
543 |
544 | with torch.cuda.amp.autocast():
545 |
546 | tgt_D1_pred = model_D1(F.interpolate(feat_tgt.detach(), size=tgt_size, mode='bilinear', align_corners=True))
547 | loss_D1_tgt = 0.5 * bce_loss(tgt_D1_pred, torch.FloatTensor(tgt_D1_pred.data.size()).fill_(target_label).cuda()) / args.iter_size
548 |
549 | loss_D1_value += loss_D1_src + loss_D1_tgt
550 |
551 | scaler.scale(loss_D1_tgt).backward()
552 |
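    |             # D2 learns to separate domains at the prediction level: detached softmax outputs of source and target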
553 | for param in model_D2.parameters():
554 | param.requires_grad = True
555 | optimizer_D2.zero_grad()
556 |
557 | with torch.cuda.amp.autocast():
558 | src_D2_pred = model_D2(F.softmax(pred.detach(), 1))
559 | loss_D2_src = 0.5 * bce_loss(src_D2_pred, torch.FloatTensor(src_D2_pred.data.size()).fill_(source_label).cuda()) / args.iter_size
560 |
561 | scaler.scale(loss_D2_src).backward()
562 |
563 | with torch.cuda.amp.autocast():
564 |
565 | tgt_D2_pred = model_D2(F.softmax(pred_tgt.detach(), 1))
566 | loss_D2_tgt = 0.5 * bce_loss(tgt_D2_pred, torch.FloatTensor(tgt_D2_pred.data.size()).fill_(target_label).cuda()) / args.iter_size
567 |
568 | loss_D2_value += loss_D2_src + loss_D2_tgt
569 |
570 | scaler.scale(loss_D2_tgt).backward()
571 |
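    |         # average the logged losses across GPUs before reporting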
572 | n = torch.tensor(1.0).cuda()
573 |
574 | dist.all_reduce(n), dist.all_reduce(loss_seg_value), dist.all_reduce(loss_adv_src_value), dist.all_reduce(loss_adv_tgt_value)
575 | dist.all_reduce(loss_D1_value), dist.all_reduce(loss_D2_value), dist.all_reduce(loss_st_value)
576 |
577 | loss_seg_value = loss_seg_value.item() / n.item()
578 | loss_adv_src_value = loss_adv_src_value.item() / n.item()
579 | loss_adv_tgt_value = loss_adv_tgt_value.item() / n.item()
580 | loss_D1_value = loss_D1_value.item() / n.item()
581 | loss_D2_value = loss_D2_value.item() / n.item()
582 | loss_st_value = loss_st_value.item() / n.item()
583 |
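    |         # step all three optimizers (skipped automatically if gradients overflowed) and update the loss scale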
584 | scaler.step(optimizer)
585 | scaler.step(optimizer_D1)
586 | scaler.step(optimizer_D2)
587 | scaler.update()
588 |
589 | if args.tensorboard and gpu == 0:
590 | scalar_info = {
591 | 'loss_seg': loss_seg_value,
592 | 'loss_adv_src': loss_adv_src_value,
593 | 'loss_adv_tgt': loss_adv_tgt_value,
594 | 'loss_D1': loss_D1_value,
595 | 'loss_D2': loss_D2_value,
596 | "loss_st": loss_st_value,
597 | }
598 |
599 | if i_iter % 10 == 0:
600 | for key, val in scalar_info.items():
601 | writer.add_scalar(key, val, i_iter)
602 |
603 | if gpu == 0 and i_iter % args.print_every == 0:
604 | logger.info('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_src = {3:.5f}, loss_adv_tgt = {4:.5f}, loss_D1 = {5:.3f}, '
605 | 'loss_D2 = {6:.3f}, loss_st = {7:.5f}, epoch_s = {8:3d}, epoch_t = {9:3d}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_src_value, \
606 | loss_adv_tgt_value, loss_D1_value, loss_D2_value, loss_st_value, epoch_s, epoch_t))
607 |
608 |         if i_iter >= args.num_steps_stop - 1:  # all ranks must break together or the others hang on the next collective op
609 |             if gpu == 0:
610 |                 logger.info('save model ...')
611 |                 filename = osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth')
612 |                 save_file = {'model_B1_state_dict': model_B1.state_dict(), 'model_B2_state_dict': model_B2.state_dict(), 'model_B_state_dict': model_B.state_dict(), 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()}
613 |                 torch.save(save_file, filename)
614 |                 logger.info("saving checkpoint model to {}".format(filename))
615 |             break
616 |
617 | if i_iter % args.save_pred_every == 0 and i_iter != 0:
618 | miou, loss_val = validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader)
619 | if args.tensorboard and gpu == 0:
620 | scalar_info = {
621 | 'miou_val': miou,
622 | 'loss_val': loss_val
623 | }
624 | for k, v in scalar_info.items():
625 | writer.add_scalar(k, v, i_iter)
626 |
627 | if gpu == 0 and miou > best_miou:
628 | best_miou = miou
629 | logger.info('taking snapshot ...')
630 | if filename is not None and os.path.exists(filename):
631 | os.remove(filename)
632 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + "_{}".format(miou) + '.pth')
633 | save_file = {'model_B1_state_dict': model_B1.state_dict(), 'model_B2_state_dict': model_B2.state_dict(), \
634 | 'model_B_state_dict': model_B.state_dict(), 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()}
635 | torch.save(save_file, filename)
636 | logger.info("saving checkpoint model to {}".format(filename))
637 |
638 | if args.tensorboard and gpu == 0:
639 | writer.close()
640 |
641 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader):
642 | if gpu == 0:
643 | logger.info("Start Evaluation")
644 | # evaluate
645 | loss_meter = AverageMeter()
646 | intersection_meter = AverageMeter()
647 | union_meter = AverageMeter()
648 |
649 | model_B2.eval()
650 | model_B.eval()
651 | head.eval()
652 | classifier.eval()
653 |
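    |     # evaluate on the target validation set; IoU statistics are all-reduced across GPUs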
654 | with torch.no_grad():
655 | for i, batch in enumerate(testloader):
656 | images = batch['img'].cuda()
657 | labels = batch['label'].cuda()
658 |
659 | pred = model_B(model_B2(images))
660 | pred = classifier(head(pred))
661 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True)
662 | loss = seg_loss(output, labels)
663 |
664 | output = output.max(1)[1]
665 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label)
666 | dist.all_reduce(intersection), dist.all_reduce(union)
667 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
668 | intersection_meter.update(intersection), union_meter.update(union)
669 | loss_meter.update(loss.item(), images.size(0))
670 | if gpu == 0 and i % 50 == 0 and i != 0:
671 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format(
672 | i, len(testloader), loss_meter.val
673 | ))
674 |
675 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
676 | miou = np.mean(iou_class)
677 | if gpu == 0:
678 | logger.info("Val result: mIoU = {:.3f}".format(miou))
679 | for i in range(args.num_classes):
680 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i]))
681 | logger.info("End Evaluation")
682 |
683 | torch.cuda.empty_cache()
684 |
685 | return miou, loss_meter.avg
686 |
687 | def find_free_port():
688 | import socket
689 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
690 | # Binding to port 0 will cause the OS to find an available port for us
691 | sock.bind(("", 0))
692 | port = sock.getsockname()[1]
693 | sock.close()
694 | # NOTE: there is still a chance the port could be taken by other processes.
695 | return port
696 |
697 | if __name__ == '__main__':
698 | args.gpus = [int(x) for x in args.gpus.split(",")]
699 | args.world_size = len(args.gpus)
700 | if args.dist:
701 | port = find_free_port()
702 | args.dist_url = f"tcp://127.0.0.1:{port}"
703 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url))
704 | else:
705 | main_worker(args.train_gpu, args.world_size, args)
706 |
707 |
--------------------------------------------------------------------------------