├── model ├── __init__.py ├── discriminator.py ├── feature_extractor.py ├── classifier.py └── resnet.py ├── utils ├── __init__.py ├── dropout.py ├── loss.py ├── util.py └── transform.py ├── figs └── fig.png ├── .gitignore ├── LICENSE ├── datasets └── get_thresholds.py ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── gta5_dataset.py ├── cityscapes_val_dataset.py ├── cityscapes_train_dataset.py ├── randaugment.py └── augmentations.py ├── generate_soft_label.py ├── train_phase2.py └── train_phase1.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figs/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/DecoupleNet/HEAD/figs/fig.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | snapshots/ 2 | *.pth 3 | *__pycache__* 4 | debug/ 5 | class_balance_ids_*.p 6 | core.* 7 | datasets/pseudo_labels*/ 8 | datasets/soft_labels*/ 9 | output/ 10 | check.py 11 | data/class_balance_ids_*.p 12 | data/cityscapes_class_balance_ids_*.p 13 | *.pickle 14 | */*.zip 15 | */*.npy 16 | slurm_cmd/ 17 | pretrained/ 18 | *_soft_labels/ 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 DV Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class FCDiscriminator(nn.Module): 6 | 7 | def __init__(self, num_classes, ndf = 64): 8 | super(FCDiscriminator, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 11 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 12 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 13 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 14 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1) 15 | 16 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 18 | #self.sigmoid = nn.Sigmoid() 19 | 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = self.leaky_relu(x) 24 | x = self.conv2(x) 25 | x = self.leaky_relu(x) 26 | x = self.conv3(x) 27 | x = self.leaky_relu(x) 28 | x = self.conv4(x) 29 | x = self.leaky_relu(x) 30 | x = self.classifier(x) 31 | #x = self.up_sample(x) 32 | #x = self.sigmoid(x) 33 | 34 | return x 35 | -------------------------------------------------------------------------------- /model/feature_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torchvision.models._utils import IntermediateLayerGetter 5 | from . import resnet 6 | 7 | class FrozenBatchNorm2d(nn.Module): 8 | """ 9 | BatchNorm2d where the batch statistics and the affine parameters 10 | are fixed 11 | """ 12 | 13 | def __init__(self, n): 14 | super(FrozenBatchNorm2d, self).__init__() 15 | self.register_buffer("weight", torch.ones(n)) 16 | self.register_buffer("bias", torch.zeros(n)) 17 | self.register_buffer("running_mean", torch.zeros(n)) 18 | self.register_buffer("running_var", torch.ones(n)) 19 | 20 | def forward(self, x): 21 | output = F.batch_norm(x, self.running_mean, self.running_var, weight=self.weight, bias=self.bias, training=False) 22 | return output 23 | 24 | class resnet_feature_extractor(nn.Module): 25 | def __init__(self, backbone_name, pretrained_weights=None, aux=False, pretrained_backbone=True, freeze_bn=False): 26 | super(resnet_feature_extractor, self).__init__() 27 | bn_layer = nn.BatchNorm2d 28 | if freeze_bn: 29 | bn_layer = FrozenBatchNorm2d 30 | backbone = resnet.__dict__[backbone_name]( 31 | pretrained=pretrained_backbone, 32 | replace_stride_with_dilation=[False, True, True], pretrained_weights=pretrained_weights, norm_layer=bn_layer) 33 | return_layers = {'layer4': 'out'} 34 | if aux: 35 | return_layers['layer3'] = 'aux' 36 | self.aux = aux 37 | self.backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) 38 | 39 | def forward(self, x): 40 | if self.aux == True: 41 | output = self.backbone(x) 42 | aux, out = output['aux'], output['out'] 43 | return aux, out 44 | else: 45 | out = self.backbone(x)['out'] 46 | return out 47 | -------------------------------------------------------------------------------- /datasets/get_thresholds.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import pickle 5 | import matplotlib.pyplot as plt 6 | from scipy.special import softmax 7 |
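# The thresholds saved by this script are later used to filter the generated soft labels
# into hard pseudo-labels. A minimal sketch of that filtering step (illustrative only; the
# function name and exact call site below are assumptions, not this repo's API):
def apply_thresholds_example(scores, thresholds, ignore_label=250):
    """scores: [C, H, W] pre-softmax scores; thresholds: [C] per-class confidence cut-offs."""
    probs = softmax(scores, axis=0)        # per-pixel class probabilities
    pred = probs.argmax(0)                 # [H, W] predicted train ids
    conf = probs.max(0)                    # [H, W] confidence of the prediction
    # keep a pixel's label only if its confidence clears the threshold of its predicted class
    return np.where(conf >= thresholds[pred], pred, ignore_label).astype(np.uint8)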
import sys 8 | 9 | # python3 get_thresholds.py 0.8 gta2city_soft_labels 10 | 11 | p = float(sys.argv[1]) 12 | npy_dir = sys.argv[2] 13 | save_path = "./{}_cls2prob.pickle".format(sys.argv[2]) 14 | output_path = "./{}_thresholds_p{}.npy".format(sys.argv[2], p) 15 | ignore_label = 250 16 | 17 | if not os.path.exists(save_path): 18 | cls2prob = {} 19 | files = glob.glob(os.path.join(npy_dir, "*.npy")) 20 | for i, npy_file in enumerate(files): 21 | if i % 100 == 0: 22 | print("i: {}/ {}".format(i, len(files))) 23 | f = np.load(npy_file) #[c, h, w] 24 | f = softmax(f, axis=0) 25 | classes = f.argmax(0) #[h, w] 26 | prob = f.max(0) #[h, w] 27 | for c in np.unique(classes): 28 | if c not in cls2prob: 29 | cls2prob[c] = [] 30 | cls2prob[c].extend(prob[classes == c]) 31 | for c in cls2prob: 32 | cls2prob[c].sort(reverse=True) 33 | # with open(save_path, "wb+") as f: 34 | # pickle.dump(cls2prob, f) 35 | else: 36 | with open(save_path, "rb") as f: 37 | cls2prob = pickle.load(f) 38 | 39 | class_list = ["road","sidewalk","building","wall", 40 | "fence","pole","traffic_light","traffic_sign","vegetation", 41 | "terrain","sky","person","rider","car", 42 | "truck","bus","train","motorcycle","bicycle"] 43 | 44 | # print("p: {}".format(p)) 45 | 46 | thresholds = [] 47 | for c in range(len(cls2prob.keys())): 48 | prob_c = cls2prob[c] 49 | rank = int(p * len(prob_c)) 50 | thresh = prob_c[rank] 51 | thresholds.append(thresh) 52 | thresholds = np.array(thresholds) 53 | 54 | for i in range(len(thresholds)): 55 | print("i: {}, class i: {}, thresh_i: {}".format(i, class_list[i], thresholds[i])) 56 | 57 | np.save(output_path, thresholds) 58 | -------------------------------------------------------------------------------- /model/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import numpy as np 5 | from torch import nn 6 | from torchvision.models._utils import IntermediateLayerGetter 7 | 8 | class ASPP_Classifier(nn.Module): 9 | def __init__(self, in_channels, dilation_series, padding_series, num_classes): 10 | super(ASPP_Classifier, self).__init__() 11 | self.conv2d_list = nn.ModuleList() 12 | for dilation, padding in zip(dilation_series, padding_series): 13 | self.conv2d_list.append( 14 | nn.Conv2d( 15 | in_channels, 16 | num_classes, 17 | kernel_size=3, 18 | stride=1, 19 | padding=padding, 20 | dilation=dilation, 21 | bias=True, 22 | ) 23 | ) 24 | 25 | for m in self.conv2d_list: 26 | m.weight.data.normal_(0, 0.01) 27 | 28 | def forward(self, x, size=None): 29 | out = self.conv2d_list[0](x) 30 | for i in range(len(self.conv2d_list) - 1): 31 | out += self.conv2d_list[i + 1](x) 32 | if size is not None: 33 | out = F.interpolate(out, size=size, mode='bilinear', align_corners=True) 34 | return out 35 | 36 | 37 | class ASPP_Classifier_Gen(nn.Module): 38 | '''Generalized version of ASPP head''' 39 | def __init__(self, in_channels, dilation_series, padding_series, num_classes, hidden_dim=128): 40 | super(ASPP_Classifier_Gen, self).__init__() 41 | self.head = ASPP_Classifier(in_channels, dilation_series, padding_series, hidden_dim) 42 | self.classifier = nn.Conv2d(hidden_dim, num_classes, kernel_size=1, stride=1) # Generalize DeepLabv2 to backbone + classifier structure (make classifier independent) 43 | 44 | def forward(self, x, size=None): 45 | out = self.head(x) 46 | out = self.classifier(out) 47 | if size is not None: 48 | out = F.interpolate(out, size=size, mode='bilinear', align_corners=True) 49 
| return out 50 | -------------------------------------------------------------------------------- /utils/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from copy import deepcopy 3 | 4 | 5 | def create_adversarial_dropout_mask(mask, jacobian, delta): 6 | """ 7 | 8 | :param mask: shape [batch_size, ...] 9 | :param jacobian: shape [batch_size, ...] 10 | :param delta: 11 | :return: 12 | """ 13 | num_of_units = int(torch.prod(torch.tensor(mask.size()[1:])).to(torch.float)) 14 | change_limit = int(num_of_units * delta) 15 | mask = (mask > 0).to(torch.float) 16 | 17 | if change_limit == 0: 18 | return deepcopy(mask).detach(), torch.Tensor([]).type(torch.int64) 19 | 20 | # mask (mask=1 -> m = 1), (mask=0 -> m=-1) 21 | m = 2 * mask - torch.ones_like(mask) 22 | 23 | # sign of Jacobian (J>0 -> s=1), (J<0 -> s=-1) 24 | s = torch.sign(jacobian) 25 | 26 | # remain (J>0, m=-1) and (J<0, m=1), which are candidates to be changed 27 | change_candidates = ((m * s) < 0).to(torch.float) 28 | 29 | # print("change_candidates: ", change_candidates.sum()) 30 | 31 | # ordering abs_jacobian for candidates 32 | # the maximum number of the changes is "change_limit" 33 | # draw top_k elements ( if the top k element is 0, the number of the changes is less than "change_limit" ) 34 | abs_jacobian = torch.abs(jacobian) 35 | candidate_abs_jacobian = (change_candidates * abs_jacobian).view(-1, num_of_units) 36 | topk_values, topk_indices = torch.topk(candidate_abs_jacobian, change_limit + 1) 37 | min_values = topk_values[:, -1] 38 | change_target_marker = (candidate_abs_jacobian > min_values.unsqueeze(-1)).view(mask.size()).to(torch.float) 39 | 40 | # changed mask with change_target_marker 41 | adv_mask = torch.abs(mask - change_target_marker) 42 | 43 | # normalization 44 | adv_mask = adv_mask.view(-1, num_of_units) 45 | num_of_undropped_units = torch.sum(adv_mask, dim=1).unsqueeze(-1) 46 | adv_mask = ((adv_mask / num_of_undropped_units) * num_of_units).view(mask.size()) 47 | 48 | # return adv_mask.clone().detach(), (adv_mask == 0).nonzero()[:, 1] 49 | return adv_mask.clone().detach(), None 50 | 51 | 52 | def calculate_jacobians(h, clean_logits, head, classifier, consistency_criterion): 53 | cnn_mask = torch.ones((*h.size()[:2], 1, 1)).to(h.device) 54 | # fc_mask = torch.ones(cnn_mask.size(0), fc_mask_size).to(cnn_mask.device) 55 | cnn_mask.requires_grad = True 56 | # fc_mask.requires_grad = True 57 | 58 | # h_logits = classifier(cnn_mask * h, fc_mask) 59 | h_logits = classifier(head(cnn_mask * h)) 60 | discrepancy = consistency_criterion(h_logits, clean_logits) 61 | 62 | # print("discrepancy: ", discrepancy) 63 | 64 | discrepancy.backward() 65 | 66 | # reset_grad_fn() 67 | # return cnn_mask.grad.clone(), fc_mask.grad.clone(), h_logits 68 | head.zero_grad() 69 | classifier.zero_grad() 70 | 71 | # print("cnn_mask.grad.max(): {}, cnn_mask.grad.min(): {}, cnn_mask.grad.mean(): {}".format(cnn_mask.grad.max(), cnn_mask.grad.min(), cnn_mask.grad.mean())) 72 | 73 | return cnn_mask.grad.clone(), h_logits 74 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class CrossEntropy2d(nn.Module): 8 | 9 | def __init__(self, size_average=True, ignore_label=255): 10 | super(CrossEntropy2d, self).__init__() 11 | 
self.size_average = size_average 12 | self.ignore_label = ignore_label 13 | 14 | def forward(self, predict, target, weight=None): 15 | """ 16 | Args: 17 | predict:(n, c, h, w) 18 | target:(n, h, w) 19 | weight (Tensor, optional): a manual rescaling weight given to each class. 20 | If given, has to be a Tensor of size "nclasses" 21 | """ 22 | assert not target.requires_grad 23 | assert predict.dim() == 4 24 | assert target.dim() == 3 25 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 26 | assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) 27 | assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2)) 28 | n, c, h, w = predict.size() 29 | target_mask = (target >= 0) * (target != self.ignore_label) 30 | target = target[target_mask] 31 | if not target.data.dim(): 32 | return Variable(torch.zeros(1)) 33 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 34 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 35 | loss = F.cross_entropy(predict, target, weight=weight, size_average=self.size_average) 36 | return loss 37 | 38 | 39 | class EntropyLoss(nn.Module): 40 | def __init__(self, reduction='mean'): 41 | super().__init__() 42 | self.reduction = reduction 43 | 44 | def forward(self, logits): 45 | p = F.softmax(logits, dim=1) 46 | elementwise_entropy = -p * F.log_softmax(logits, dim=1) 47 | if self.reduction == 'none': 48 | return elementwise_entropy 49 | 50 | # print("elementwise_entropy.shape: ", elementwise_entropy.shape) 51 | 52 | sum_entropy = torch.sum(elementwise_entropy, dim=1) 53 | if self.reduction == 'sum': 54 | return sum_entropy 55 | 56 | # print("sum_entropy.shape: ", sum_entropy.shape) 57 | 58 | return torch.mean(sum_entropy) 59 | 60 | 61 | class AbstractConsistencyLoss(nn.Module): 62 | def __init__(self, reduction='mean'): 63 | super().__init__() 64 | self.reduction = reduction 65 | 66 | def forward(self, logits1, logits2): 67 | raise NotImplementedError 68 | 69 | 70 | class LossWithLogits(AbstractConsistencyLoss): 71 | def __init__(self, reduction='mean', loss_cls=nn.L1Loss): 72 | super().__init__(reduction) 73 | self.loss_with_softmax = loss_cls(reduction=reduction) 74 | 75 | def forward(self, logits1, logits2): 76 | loss = self.loss_with_softmax(F.softmax(logits1, dim=1), F.softmax(logits2, dim=1)) 77 | return loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DecoupleNet 2 | Official implementation for our ECCV 2022 paper "DecoupleNet: Decoupled Network for Domain Adaptive Semantic Segmentation" [[arXiv](https://arxiv.org/pdf/2207.09988.pdf)] [[Paper](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136930362.pdf)] 3 | 4 |
5 | <img src="figs/fig.png" width="100%"/> 6 |
7 | 8 | # Get Started 9 | 10 | ## Dataset Preparation 11 | 12 | ### GTA5 13 | First, download GTA5 from the [website](https://download.visinf.tu-darmstadt.de/data/from_games/). Then, extract the files and organize them as follows. 14 | ``` 15 | images/ 16 | |---00000.png 17 | |---00001.png 18 | |---... 19 | labels/ 20 | |---00000.png 21 | |---00001.png 22 | |---... 23 | split.mat 24 | gtav_label_info.p 25 | ``` 26 | 27 | ### Cityscapes 28 | 29 | Download the Cityscapes dataset from the [website](https://www.cityscapes-dataset.com/) and organize it as 30 | ``` 31 | leftImg8bit/ 32 | |---train/ 33 | |---val/ 34 | |---test/ 35 | gtFine/ 36 | |---train/ 37 | |---val/ 38 | |---test/ 39 | ``` 40 | 41 | ## Training 42 | 43 | ### GTA5 -> Cityscapes 44 | First, download the pretrained ResNet101 (PyTorch) and the source-only model from [here](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155154502_link_cuhk_edu_hk/EVowKrywcUVJhK0tbO_ebxQBv83FCISbGW_2fTeCWiFvGA), and put them into the directory `./pretrained`. 45 | ``` 46 | mkdir pretrained && cd pretrained 47 | wget https://download.pytorch.org/models/resnet101-5d3b4d8f.pth 48 | # Also put the sourceonly.pth into ./pretrained/ 49 | ``` 50 | 51 | First-phase training: 52 | ``` 53 | python3 train_phase1.py --snapshot-dir ./snapshots/GTA2Cityscapes_phase1 --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --batch_size_val 4 --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT] 54 | ``` 55 | 56 | Second-phase training (the trained phase-1 model can also be downloaded from [here](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155154502_link_cuhk_edu_hk/EmhCkQ_lJ1FLr9Dj2QopYHkB4gyXPOC2BUzjmw4jGq6FSQ?e=m8XPfC)): 57 | ``` 58 | # First, generate the soft pseudo labels from the trained phase-1 model 59 | python3 generate_soft_label.py --snapshot-dir ./snapshots/GTA2Cityscapes_generate_soft_labels --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --batch_size_val 4 --resume [PATH_OF_PHASE1_MODEL] --output_folder ./datasets/gta2city_soft_labels --no_droplast --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT] 60 | 61 | # Then, get the thresholds from the generated soft labels: 62 | cd datasets/ && python3 get_thresholds.py 0.8 gta2city_soft_labels 63 | 64 | # Training with soft pseudo labels: 65 | python3 train_phase2.py --snapshot-dir ./snapshots/GTA2Cityscapes_phase2 --batch-size 8 --gpus 0,1,2,3 --dist --tensorboard --learning-rate 5e-4 --batch_size_val 4 --soft_labels_folder ./datasets/gta2city_soft_labels --resume [PATH_OF_PHASE1_MODEL] --thresholds_path ./datasets/gta2city_soft_labels_thresholds_p0.8.npy --src_rootpath [YOUR_SOURCE_DATA_ROOT] --tgt_rootpath [YOUR_TARGET_DATA_ROOT] 66 | ``` 67 | 68 | # Acknowledgement 69 | This repository borrows code from the following repositories. Many thanks to the authors for their great work.
70 | 71 | ProDA: https://github.com/microsoft/ProDA 72 | 73 | FADA: https://github.com/JDAI-CV/FADA 74 | 75 | semseg: https://github.com/hszhao/semseg 76 | 77 | # Citation 78 | If you find this project useful, please consider citing: 79 | 80 | ``` 81 | @inproceedings{lai2022decouplenet, 82 | title={Decouplenet: Decoupled network for domain adaptive semantic segmentation}, 83 | author={Lai, Xin and Tian, Zhuotao and Xu, Xiaogang and Chen, Yingcong and Liu, Shu and Zhao, Hengshuang and Wang, Liwei and Jia, Jiaya}, 84 | booktitle={European Conference on Computer Vision}, 85 | pages={369--387}, 86 | year={2022}, 87 | organization={Springer} 88 | } 89 | ``` -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import importlib 5 | import numpy as np 6 | import torch.utils.data 7 | from data.base_dataset import BaseDataset 8 | from data.augmentations import * 9 | 10 | def find_dataset_using_name(name): 11 | """Import the module "data/[dataset_name]_dataset.py". 12 | 13 | In the file, the class called DatasetNameDataset() will 14 | be instantiated. It has to be a subclass of BaseDataset, 15 | and it is case-insensitive. 16 | """ 17 | dataset_filename = "data." + name + "_dataset" 18 | datasetlib = importlib.import_module(dataset_filename) 19 | 20 | dataset = None 21 | target_dataset_name = name + '_loader' 22 | for _name, cls in datasetlib.__dict__.items(): 23 | if _name.lower() == target_dataset_name.lower() \ 24 | and issubclass(cls, BaseDataset): 25 | dataset = cls 26 | 27 | if dataset is None: 28 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 29 | 30 | return dataset 31 | 32 | def get_option_setter(dataset_name): 33 | """Return the static method of the dataset class.""" 34 | dataset_class = find_dataset_using_name(dataset_name) 35 | return dataset_class.modify_commandline_options 36 | 37 | def create_dataset(opt, logger): 38 | """Create a dataset given the option. 39 | 40 | This function wraps the class CustomDatasetDataLoader. 
41 | This is the main interface between this package and 'train.py'/'test.py' 42 | 43 | Example: 44 | >>> from data import create_dataset 45 | >>> dataset = create_dataset(opt) 46 | """ 47 | data_loader = CustomDatasetDataLoader(opt, logger) 48 | dataset = data_loader.load_data() 49 | return dataset 50 | 51 | def get_composed_augmentations(opt): 52 | return Compose([RandomSized(opt.resize), 53 | RandomCrop(opt.rcrop), 54 | RandomHorizontallyFlip(opt.hflip)]) 55 | 56 | class CustomDatasetDataLoader(): 57 | def __init__(self, opt, logger): 58 | self.opt = opt 59 | self.logger = logger 60 | 61 | # status == 'train': 62 | source_train = find_dataset_using_name(opt.src_dataset) 63 | data_aug = None if opt.noaug else get_composed_augmentations(opt) 64 | self.source_train = source_train(opt, logger, augmentations=data_aug) 65 | if logger is not None: 66 | logger.info("{} source dataset has been created".format(self.source_train.__class__.__name__)) 67 | print("dataset {} for source was created".format(self.source_train.__class__.__name__)) 68 | self.source_train[0] 69 | 70 | data_aug = None if opt.noaug else get_composed_augmentations(opt) 71 | target_train = find_dataset_using_name(opt.tgt_dataset) 72 | self.target_train = target_train(opt, logger, augmentations=data_aug, split='train') 73 | if logger is not None: 74 | logger.info("{} target dataset has been created".format(self.target_train.__class__.__name__)) 75 | print("dataset {} for target was created".format(self.target_train.__class__.__name__)) 76 | self.target_train[0] 77 | 78 | ## create train loader 79 | self.source_train_sampler = torch.utils.data.distributed.DistributedSampler(self.source_train, shuffle=not opt.noshuffle) 80 | self.source_train_loader = torch.utils.data.DataLoader( 81 | self.source_train, 82 | batch_size=opt.batch_size, 83 | shuffle=False, 84 | sampler=self.source_train_sampler, 85 | num_workers=int(opt.num_workers), 86 | drop_last=True, 87 | pin_memory=True, 88 | ) 89 | self.target_train_sampler = torch.utils.data.distributed.DistributedSampler(self.target_train, shuffle=not opt.noshuffle) 90 | self.target_train_loader = torch.utils.data.DataLoader( 91 | self.target_train, 92 | batch_size=opt.batch_size, 93 | shuffle=False, 94 | sampler=self.target_train_sampler, 95 | num_workers=int(opt.num_workers), 96 | drop_last=not opt.no_droplast, 97 | pin_memory=True, 98 | ) 99 | 100 | # status == valid 101 | self.source_valid = None 102 | self.source_valid_loader = None 103 | 104 | self.target_valid = None 105 | self.target_valid_loader = None 106 | 107 | target_valid = find_dataset_using_name(opt.tgt_val_dataset) 108 | self.target_valid = target_valid(opt, logger, augmentations=None, split='val') 109 | if logger is not None: 110 | logger.info("{} target_valid dataset has been created".format(self.target_valid.__class__.__name__)) 111 | print("dataset {} for target_valid was created".format(self.target_valid.__class__.__name__)) 112 | 113 | self.target_valid_sampler = torch.utils.data.distributed.DistributedSampler(self.target_valid, shuffle=False) 114 | self.target_valid_loader = torch.utils.data.DataLoader( 115 | self.target_valid, 116 | batch_size=opt.batch_size_val, 117 | shuffle=False, 118 | sampler=self.target_valid_sampler, 119 | num_workers=int(opt.num_workers), 120 | drop_last=False, 121 | pin_memory=True, 122 | ) 123 | 124 | def load_data(self): 125 | return self 126 | -------------------------------------------------------------------------------- /utils/util.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.init as initer 8 | 9 | 10 | class AverageMeter(object): 11 | """Computes and stores the average and current value""" 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1): 29 | """Sets the learning rate to the base LR decayed by 10 every step epochs""" 30 | lr = base_lr * (multiplier ** (epoch // step_epoch)) 31 | return lr 32 | 33 | 34 | def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9): 35 | """poly learning rate policy""" 36 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power 37 | return lr 38 | 39 | 40 | def intersectionAndUnion(output, target, K, ignore_index=255): 41 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 42 | assert (output.ndim in [1, 2, 3]) 43 | assert output.shape == target.shape 44 | output = output.reshape(output.size).copy() 45 | target = target.reshape(target.size) 46 | output[np.where(target == ignore_index)[0]] = ignore_index 47 | intersection = output[np.where(output == target)[0]] 48 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K+1)) 49 | area_output, _ = np.histogram(output, bins=np.arange(K+1)) 50 | area_target, _ = np.histogram(target, bins=np.arange(K+1)) 51 | area_union = area_output + area_target - area_intersection 52 | return area_intersection, area_union, area_target 53 | 54 | 55 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 56 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
57 | assert (output.dim() in [1, 2, 3]) 58 | assert output.shape == target.shape 59 | output = output.view(-1) 60 | target = target.view(-1) 61 | output[target == ignore_index] = ignore_index 62 | intersection = output[output == target] 63 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K-1) 64 | area_output = torch.histc(output, bins=K, min=0, max=K-1) 65 | area_target = torch.histc(target, bins=K, min=0, max=K-1) 66 | area_union = area_output + area_target - area_intersection 67 | return area_intersection, area_union, area_target 68 | 69 | 70 | def check_mkdir(dir_name): 71 | if not os.path.exists(dir_name): 72 | os.mkdir(dir_name) 73 | 74 | 75 | def check_makedirs(dir_name): 76 | if not os.path.exists(dir_name): 77 | os.makedirs(dir_name) 78 | 79 | 80 | def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'): 81 | """ 82 | :param model: Pytorch Model which is nn.Module 83 | :param conv: 'kaiming' or 'xavier' 84 | :param batchnorm: 'normal' or 'constant' 85 | :param linear: 'kaiming' or 'xavier' 86 | :param lstm: 'kaiming' or 'xavier' 87 | """ 88 | for m in model.modules(): 89 | if isinstance(m, (nn.modules.conv._ConvNd)): 90 | if conv == 'kaiming': 91 | initer.kaiming_normal_(m.weight) 92 | elif conv == 'xavier': 93 | initer.xavier_normal_(m.weight) 94 | else: 95 | raise ValueError("init type of conv error.\n") 96 | if m.bias is not None: 97 | initer.constant_(m.bias, 0) 98 | 99 | elif isinstance(m, (nn.modules.batchnorm._BatchNorm)): 100 | if batchnorm == 'normal': 101 | initer.normal_(m.weight, 1.0, 0.02) 102 | elif batchnorm == 'constant': 103 | initer.constant_(m.weight, 1.0) 104 | else: 105 | raise ValueError("init type of batchnorm error.\n") 106 | initer.constant_(m.bias, 0.0) 107 | 108 | elif isinstance(m, nn.Linear): 109 | if linear == 'kaiming': 110 | initer.kaiming_normal_(m.weight) 111 | elif linear == 'xavier': 112 | initer.xavier_normal_(m.weight) 113 | else: 114 | raise ValueError("init type of linear error.\n") 115 | if m.bias is not None: 116 | initer.constant_(m.bias, 0) 117 | 118 | elif isinstance(m, nn.LSTM): 119 | for name, param in m.named_parameters(): 120 | if 'weight' in name: 121 | if lstm == 'kaiming': 122 | initer.kaiming_normal_(param) 123 | elif lstm == 'xavier': 124 | initer.xavier_normal_(param) 125 | else: 126 | raise ValueError("init type of lstm error.\n") 127 | elif 'bias' in name: 128 | initer.constant_(param, 0) 129 | 130 | 131 | def group_weight(weight_group, module, lr): 132 | group_decay = [] 133 | group_no_decay = [] 134 | for m in module.modules(): 135 | if isinstance(m, nn.Linear): 136 | group_decay.append(m.weight) 137 | if m.bias is not None: 138 | group_no_decay.append(m.bias) 139 | elif isinstance(m, nn.modules.conv._ConvNd): 140 | group_decay.append(m.weight) 141 | if m.bias is not None: 142 | group_no_decay.append(m.bias) 143 | elif isinstance(m, nn.modules.batchnorm._BatchNorm): 144 | if m.weight is not None: 145 | group_no_decay.append(m.weight) 146 | if m.bias is not None: 147 | group_no_decay.append(m.bias) 148 | assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay) 149 | weight_group.append(dict(params=group_decay, lr=lr)) 150 | weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr)) 151 | return weight_group 152 | 153 | 154 | def colorize(gray, palette): 155 | # gray: numpy array of the label and 1*3N size list palette 156 | color = Image.fromarray(gray.astype(np.uint8)).convert('P') 157 | color.putpalette(palette) 158 | return color 
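# A minimal sketch of reducing the area counters above to mean IoU during validation
# (illustrative only; the function below and its argument names are assumptions, not
# the exact code used by the training scripts).
def mean_iou_example(outputs, targets, K=19, ignore_index=250):
    intersection_meter, union_meter = AverageMeter(), AverageMeter()
    for output, target in zip(outputs, targets):  # each is an [H, W] tensor of class ids
        inter, union, _ = intersectionAndUnionGPU(output, target, K, ignore_index)
        intersection_meter.update(inter.cpu().numpy())
        union_meter.update(union.cpu().numpy())
    iou_per_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    return float(iou_per_class.mean())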
159 | 160 | 161 | def find_free_port(): 162 | import socket 163 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 164 | # Binding to port 0 will cause the OS to find an available port for us 165 | sock.bind(("", 0)) 166 | port = sock.getsockname()[1] 167 | sock.close() 168 | # NOTE: there is still a chance the port could be taken by other processes. 169 | return port 170 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 5 | 6 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 7 | """ 8 | import torch.utils.data as data 9 | from PIL import Image 10 | import torchvision.transforms as transforms 11 | from abc import ABC, abstractmethod 12 | 13 | 14 | class BaseDataset(data.Dataset, ABC): 15 | """This class is an abstract base class (ABC) for datasets. 16 | 17 | To create a subclass, you need to implement the following four functions: 18 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 19 | -- <__len__>: return the size of dataset. 20 | -- <__getitem__>: get a data point. 21 | -- : (optionally) add dataset-specific options and set default options. 22 | """ 23 | 24 | def __init__(self, opt): 25 | """Initialize the class; save the options in the class 26 | 27 | Parameters: 28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 29 | """ 30 | self.opt = opt 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
59 | """ 60 | pass 61 | 62 | 63 | def get_transform(opt, grayscale=False, convert=True, crop=True, flip=True): 64 | """Create a torchvision transformation function 65 | 66 | The type of transformation is defined by option (e.g., [opt.preprocess], [opt.load_size], [opt.crop_size]) 67 | and can be overwritten by arguments such as [convert], [crop], and [flip] 68 | 69 | Parameters: 70 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 71 | grayscale (bool) -- if convert input RGB image to a grayscale image 72 | convert (bool) -- if convert an image to a tensor array betwen [-1, 1] 73 | crop (bool) -- if apply cropping 74 | flip (bool) -- if apply horizontal flippling 75 | """ 76 | transform_list = [] 77 | if grayscale: 78 | transform_list.append(transforms.Grayscale(1)) 79 | if opt.preprocess == 'resize_and_crop': 80 | osize = [opt.load_size, opt.load_size] 81 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 82 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 83 | elif opt.preprocess == 'crop' and crop: 84 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 85 | elif opt.preprocess == 'scale_width': 86 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.crop_size))) 87 | elif opt.preprocess == 'scale_width_and_crop': 88 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size))) 89 | if crop: 90 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 91 | elif opt.preprocess == 'none': 92 | transform_list.append(transforms.Lambda(lambda img: __adjust(img))) 93 | else: 94 | raise ValueError('--preprocess %s is not a valid option.' % opt.preprocess) 95 | 96 | if not opt.no_flip and flip: 97 | transform_list.append(transforms.RandomHorizontalFlip()) 98 | if convert: 99 | transform_list += [transforms.ToTensor(), 100 | transforms.Normalize((0.5, 0.5, 0.5), 101 | (0.5, 0.5, 0.5))] 102 | return transforms.Compose(transform_list) 103 | 104 | 105 | def __adjust(img): 106 | """Modify the width and height to be multiple of 4. 107 | 108 | Parameters: 109 | img (PIL image) -- input image 110 | 111 | Returns a modified image whose width and height are mulitple of 4. 112 | 113 | the size needs to be a multiple of 4, 114 | because going through generator network may change img size 115 | and eventually cause size mismatch error 116 | """ 117 | ow, oh = img.size 118 | mult = 4 119 | if ow % mult == 0 and oh % mult == 0: 120 | return img 121 | w = (ow - 1) // mult 122 | w = (w + 1) * mult 123 | h = (oh - 1) // mult 124 | h = (h + 1) * mult 125 | 126 | if ow != w or oh != h: 127 | __print_size_warning(ow, oh, w, h) 128 | 129 | return img.resize((w, h), Image.BICUBIC) 130 | 131 | 132 | def __scale_width(img, target_width): 133 | """Resize images so that the width of the output image is the same as a target width 134 | 135 | Parameters: 136 | img (PIL image) -- input image 137 | target_width (int) -- target image width 138 | 139 | Returns a modified image whose width matches the target image width; 140 | 141 | the size needs to be a multiple of 4, 142 | because going through generator network may change img size 143 | and eventually cause size mismatch error 144 | """ 145 | ow, oh = img.size 146 | 147 | mult = 4 148 | assert target_width % mult == 0, "the target width needs to be multiple of %d." 
% mult 149 | if (ow == target_width and oh % mult == 0): 150 | return img 151 | w = target_width 152 | target_height = int(target_width * oh / ow) 153 | m = (target_height - 1) // mult 154 | h = (m + 1) * mult 155 | 156 | if target_height != h: 157 | __print_size_warning(target_width, target_height, w, h) 158 | 159 | return img.resize((w, h), Image.BICUBIC) 160 | 161 | 162 | def __print_size_warning(ow, oh, w, h): 163 | """Print warning information about image size(only print once)""" 164 | if not hasattr(__print_size_warning, 'has_printed'): 165 | print("The image size needs to be a multiple of 4. " 166 | "The loaded image size was (%d, %d), so it was adjusted to " 167 | "(%d, %d). This adjustment will be done to all images " 168 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 169 | __print_size_warning.has_printed = True 170 | -------------------------------------------------------------------------------- /data/gta5_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import torch 7 | import numpy as np 8 | import scipy.misc as m 9 | import matplotlib.pyplot as plt 10 | import matplotlib.image as imgs 11 | from PIL import Image 12 | import random 13 | import scipy.io as io 14 | from tqdm import tqdm 15 | from scipy import stats 16 | 17 | from torch.utils import data 18 | 19 | from data import BaseDataset 20 | from data.randaugment import RandAugmentMC 21 | 22 | import pickle 23 | from torchvision import transforms 24 | 25 | 26 | class GTA5_loader(BaseDataset): 27 | """ 28 | GTA5 synthetic dataset 29 | for domain adaptation to Cityscapes 30 | """ 31 | 32 | colors = [ # [ 0, 0, 0], 33 | [128, 64, 128], 34 | [244, 35, 232], 35 | [70, 70, 70], 36 | [102, 102, 156], 37 | [190, 153, 153], 38 | [153, 153, 153], 39 | [250, 170, 30], 40 | [220, 220, 0], 41 | [107, 142, 35], 42 | [152, 251, 152], 43 | [0, 130, 180], 44 | [220, 20, 60], 45 | [255, 0, 0], 46 | [0, 0, 142], 47 | [0, 0, 70], 48 | [0, 60, 100], 49 | [0, 80, 100], 50 | [0, 0, 230], 51 | [119, 11, 32], 52 | ] 53 | 54 | label_colours = dict(zip(range(19), colors)) 55 | def __init__(self, opt, logger, augmentations=None): 56 | self.opt = opt 57 | self.root = opt.src_rootpath 58 | self.split = 'all' 59 | self.augmentations = augmentations 60 | self.randaug = RandAugmentMC(2, 10) 61 | self.n_classes = 19 62 | self.img_size = (1914, 1052) 63 | 64 | self.mean = [0.0, 0.0, 0.0] #TODO: calculating the mean value of rgb channels on GTA5 65 | self.image_base_path = os.path.join(self.root, 'images') 66 | self.label_base_path = os.path.join(self.root, 'labels') 67 | splits = io.loadmat(os.path.join(self.root, 'split.mat')) 68 | if self.split == 'all': 69 | ids = np.concatenate((splits['trainIds'][:,0], splits['valIds'][:,0], splits['testIds'][:,0])) 70 | elif self.split == 'train': 71 | ids = splits['trainIds'][:,0] 72 | elif self.split == 'val': 73 | ids = splits['valIds'][:200,0] 74 | elif self.split == 'test': 75 | ids = splits['testIds'][:,0] 76 | 77 | max_iters = opt.num_steps * opt.batch_size * opt.world_size 78 | if max_iters is not None: 79 | if not os.path.exists("data/class_balance_ids_{}.p".format(max_iters)): 80 | self.label_to_file, self.file_to_label = pickle.load(open(os.path.join(self.root, "gtav_label_info.p"), "rb")) 81 | self.ids = [] 82 | SUB_EPOCH_SIZE = 3000 83 | tmp_list = [] 84 | ind = dict() 85 | for i in range(self.n_classes): 86 | ind[i] = 0 87 | for e in 
range(int(max_iters/SUB_EPOCH_SIZE)+1): 88 | cur_class_dist = np.zeros(self.n_classes) 89 | for i in range(SUB_EPOCH_SIZE): 90 | if cur_class_dist.sum() == 0: 91 | dist1 = cur_class_dist.copy() 92 | else: 93 | dist1 = cur_class_dist/cur_class_dist.sum() 94 | w = 1/np.log(1+1e-2 + dist1) 95 | w = w/w.sum() 96 | c = np.random.choice(self.n_classes, p=w) 97 | 98 | if ind[c] > (len(self.label_to_file[c])-1): 99 | np.random.shuffle(self.label_to_file[c]) 100 | ind[c] = ind[c]%(len(self.label_to_file[c])-1) 101 | 102 | c_file = self.label_to_file[c][ind[c]] 103 | tmp_list.append(c_file) 104 | ind[c] = ind[c]+1 105 | cur_class_dist[self.file_to_label[c_file]] += 1 106 | 107 | self.ids = [os.path.join(self.label_base_path, x) for x in tmp_list] 108 | with open("data/class_balance_ids_{}.p".format(max_iters), 'wb') as f: 109 | pickle.dump(self.ids, f) 110 | else: 111 | with open("data/class_balance_ids_{}.p".format(max_iters), 'rb') as f: 112 | self.ids = pickle.load(f) 113 | 114 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, 34, -1] 115 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,] 116 | self.class_names = ["unlabelled","road","sidewalk","building","wall","fence","pole","traffic_light", 117 | "traffic_sign","vegetation","terrain","sky","person","rider","car","truck","bus","train", 118 | "motorcycle","bicycle",] 119 | 120 | self.ignore_index = 250 121 | self.class_map = dict(zip(self.valid_classes, range(19))) 122 | 123 | clrjit_params = getattr(opt, "clrjit_params", [0.5, 0.5, 0.5, 0.2]) 124 | self.train_transform = transforms.Compose([ 125 | transforms.ToPILImage(), 126 | transforms.ColorJitter(*clrjit_params), 127 | ]) 128 | 129 | if len(self.ids) == 0: 130 | raise Exception( 131 | "No files for style=[%s] found in %s" % (self.split, self.image_base_path) 132 | ) 133 | 134 | print("Found {} {} images".format(len(self.ids), self.split)) 135 | 136 | def __len__(self): 137 | return len(self.ids) 138 | 139 | def __getitem__(self, index): 140 | """__getitem__ 141 | 142 | param: index 143 | """ 144 | id = self.ids[index] 145 | if self.split != 'all' and self.split != 'val': 146 | filename = '{:05d}.png'.format(id) 147 | img_path = os.path.join(self.image_base_path, filename) 148 | lbl_path = os.path.join(self.label_base_path, filename) 149 | else: 150 | img_path = os.path.join(self.image_base_path, id.split('/')[-1]) 151 | lbl_path = id 152 | 153 | img = Image.open(img_path) 154 | lbl = Image.open(lbl_path) 155 | 156 | img = img.resize(self.img_size, Image.BILINEAR) 157 | lbl = lbl.resize(self.img_size, Image.NEAREST) 158 | img = np.asarray(img, dtype=np.uint8) 159 | lbl = np.asarray(lbl, dtype=np.uint8) 160 | 161 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 162 | 163 | input_dict = {} 164 | if self.augmentations!=None: 165 | img, lbl, _, _, _ = self.augmentations(img, lbl) 166 | img_strong, params = self.randaug(Image.fromarray(img)) 167 | img_strong, _ = self.transform(img_strong, lbl) 168 | input_dict['img_strong'] = img_strong 169 | input_dict['params'] = params 170 | 171 | img = self.train_transform(img) 172 | 173 | img, lbl = self.transform(img, lbl) 174 | 175 | input_dict['img'] = img 176 | input_dict['label'] = lbl 177 | input_dict['img_path'] = self.ids[index] 178 | return input_dict 179 | 180 | 181 | def encode_segmap(self, lbl): 182 | for _i in self.void_classes: 183 | lbl[lbl == _i] = self.ignore_index 184 | for _i in self.valid_classes: 185 | lbl[lbl == _i] = self.class_map[_i] 186 | return lbl 
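    # Example of the mapping performed by encode_segmap above (GTA5 annotations follow the
    # Cityscapes labelIds convention): raw id 7 -> train id 0 (road), 26 -> 13 (car), and
    # every id listed in void_classes -> 250, which the losses ignore.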
187 | 188 | def decode_segmap(self, temp): 189 | r = temp.copy() 190 | g = temp.copy() 191 | b = temp.copy() 192 | for l in range(0, self.n_classes): 193 | r[temp == l] = self.label_colours[l][0] 194 | g[temp == l] = self.label_colours[l][1] 195 | b[temp == l] = self.label_colours[l][2] 196 | 197 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 198 | rgb[:, :, 0] = r / 255.0 199 | rgb[:, :, 1] = g / 255.0 200 | rgb[:, :, 2] = b / 255.0 201 | return rgb 202 | 203 | def transform(self, img, lbl): 204 | """transform 205 | 206 | img, lbl 207 | """ 208 | img = np.array(img) 209 | # img = img[:, :, ::-1] # RGB -> BGR 210 | img = img.astype(np.float64) 211 | img -= self.mean 212 | img = img.astype(float) / 255.0 213 | img = img.transpose(2, 0, 1) 214 | 215 | classes = np.unique(lbl) 216 | lbl = np.array(lbl) 217 | lbl = lbl.astype(float) 218 | # lbl = m.imresize(lbl, self.img_size, "nearest", mode='F') 219 | lbl = lbl.astype(int) 220 | 221 | if not np.all(classes == np.unique(lbl)): 222 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed ones 223 | 224 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 225 | print("after det", classes, np.unique(lbl)) 226 | raise ValueError("Segmentation map contained invalid class values") 227 | 228 | img = torch.from_numpy(img).float() 229 | lbl = torch.from_numpy(lbl).long() 230 | 231 | return img, lbl 232 | 233 | def get_cls_num_list(self): 234 | cls_num_list = np.array([16139327127, 4158369631, 8495419275, 927064742, 318109335, 235 | 532432540, 67453231, 40526481, 3818867486, 1081467674, 236 | 6800402117, 182228033, 15360044, 1265024472, 567736474, 237 | 184854135, 32542442, 15832619, 2721193]) 238 | # cls_num_list = np.zeros(self.n_classes, dtype=np.int64) 239 | # for n in range(len(self.ids)): 240 | # lbl = Image.open(self.ids[n]) 241 | # lbl = lbl.resize(self.img_size, Image.NEAREST) 242 | # lbl = np.asarray(lbl, dtype=np.uint8) 243 | # lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 244 | # for i in range(self.n_classes): 245 | # cls_num_list[i] += (lbl == i).sum() 246 | return cls_num_list 247 | -------------------------------------------------------------------------------- /data/cityscapes_val_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | import scipy.misc as m 8 | from tqdm import tqdm 9 | 10 | from torch.utils import data 11 | from PIL import Image 12 | 13 | from data.augmentations import * 14 | from data.base_dataset import BaseDataset 15 | from data.randaugment import RandAugmentMC 16 | 17 | import random 18 | 19 | def recursive_glob(rootdir=".", suffix=""): 20 | """Performs recursive glob with given suffix and rootdir 21 | :param rootdir is the root directory 22 | :param suffix is the suffix to be searched 23 | """ 24 | return [ 25 | os.path.join(looproot, filename) 26 | for looproot, _, filenames in os.walk(rootdir) #os.walk: traversal all files in rootdir and its subfolders 27 | for filename in filenames 28 | if filename.endswith(suffix) 29 | ] 30 | 31 | class Cityscapes_val_loader(BaseDataset): 32 | """cityscapesLoader 33 | 34 | https://www.cityscapes-dataset.com 35 | 36 | Data is derived from CityScapes, and can be downloaded from here: 37 | https://www.cityscapes-dataset.com/downloads/ 38 | 39 | Many Thanks to @fvisin for the loader repo: 40 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 41 | """ 42 | 43 | colors = [ # [ 0, 0, 0], 44 | [128, 64, 128], 45 | [244, 35, 232], 46 | [70, 70, 70], 47 | [102, 102, 156], 48 | [190, 153, 153], 49 | [153, 153, 153], 50 | [250, 170, 30], 51 | [220, 220, 0], 52 | [107, 142, 35], 53 | [152, 251, 152], 54 | [0, 130, 180], 55 | [220, 20, 60], 56 | [255, 0, 0], 57 | [0, 0, 142], 58 | [0, 0, 70], 59 | [0, 60, 100], 60 | [0, 80, 100], 61 | [0, 0, 230], 62 | [119, 11, 32], 63 | ] 64 | 65 | label_colours = dict(zip(range(19), colors)) 66 | 67 | mean_rgb = { 68 | "pascal": [103.939, 116.779, 123.68], 69 | "cityscapes": [0.0, 0.0, 0.0], 70 | } # pascal mean for PSPNet and ICNet pre-trained model 71 | 72 | def __init__(self, opt, logger, augmentations = None, split='train'): 73 | """__init__ 74 | 75 | :param opt: parameters of dataset 76 | :param writer: save the result of experiment 77 | :param logger: logging file 78 | :param augmentations: 79 | """ 80 | 81 | self.opt = opt 82 | self.root = opt.tgt_rootpath 83 | self.split = split 84 | self.augmentations = augmentations 85 | self.randaug = RandAugmentMC(2, 10) 86 | self.n_classes = opt.num_classes 87 | self.img_size = (2048, 1024) 88 | self.mean = np.array(self.mean_rgb['cityscapes']) 89 | self.files = {} 90 | self.paired_files = {} 91 | 92 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 93 | self.annotations_base = os.path.join( 94 | self.root, "gtFine", self.split 95 | ) 96 | 97 | self.files = sorted(recursive_glob(rootdir=self.images_base, suffix=".png")) #find all files from rootdir and subfolders with suffix = ".png" 98 | 99 | #self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 100 | if self.n_classes == 19: 101 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,] 102 | self.class_names = ["unlabelled","road","sidewalk","building","wall", 103 | "fence","pole","traffic_light","traffic_sign","vegetation", 104 | "terrain","sky","person","rider","car", 105 | "truck","bus","train","motorcycle","bicycle", 106 | ] 107 | self.to19 = dict(zip(range(19), range(19))) 108 | elif self.n_classes == 16: 109 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,] 110 | self.class_names = ["unlabelled","road","sidewalk","building","wall", 111 | "fence","pole","traffic_light","traffic_sign","vegetation", 112 | 
"sky","person","rider","car","bus", 113 | "motorcycle","bicycle", 114 | ] 115 | self.to19 = dict(zip(range(16), [0,1,2,3,4,5,6,7,8,10,11,12,13,15,17,18])) 116 | elif self.n_classes == 13: 117 | self.valid_classes = [7, 8, 11, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,] 118 | self.class_names = ["unlabelled","road","sidewalk","building","traffic_light", 119 | "traffic_sign","vegetation","sky","person","rider", 120 | "car","bus","motorcycle","bicycle", 121 | ] 122 | self.to19 = dict(zip(range(13), [0,1,2,6,7,8,10,11,12,13,15,17,18])) 123 | 124 | self.ignore_index = 250 125 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes))) #zip: return tuples 126 | 127 | if not self.files: 128 | raise Exception( 129 | "No files for split=[%s] found in %s" % (self.split, self.images_base) 130 | ) 131 | 132 | print("Found %d %s images" % (len(self.files), self.split)) 133 | 134 | def __len__(self): 135 | """__len__""" 136 | return len(self.files) 137 | 138 | def __getitem__(self, index): 139 | """__getitem__ 140 | 141 | :param index: 142 | """ 143 | img_path = self.files[index].rstrip() 144 | lbl_path = os.path.join( 145 | self.annotations_base, 146 | img_path.split(os.sep)[-2], 147 | os.path.basename(img_path)[:-15] + "gtFine_labelIds.png", 148 | ) 149 | 150 | img = Image.open(img_path) 151 | lbl = Image.open(lbl_path) 152 | img = img.resize(self.img_size, Image.BILINEAR) 153 | lbl = lbl.resize(self.img_size, Image.NEAREST) 154 | 155 | img = np.array(img, dtype=np.uint8) 156 | lbl = np.array(lbl, dtype=np.uint8) 157 | lbl = self.encode_segmap(np.array(lbl, dtype=np.uint8)) 158 | 159 | img_full = img.copy().astype(np.float64) 160 | img_full -= self.mean 161 | img_full = img_full.astype(float) / 255.0 162 | img_full = img_full.transpose(2, 0, 1) 163 | lbl_full = lbl.copy() 164 | 165 | lp, lpsoft, weak_params = None, None, None 166 | input_dict = {} 167 | if self.augmentations!=None: 168 | img, lbl, lp, lpsoft, weak_params = self.augmentations(img, lbl, lp, lpsoft) 169 | img_strong, params = self.randaug(Image.fromarray(img)) 170 | img_strong, _, _ = self.transform(img_strong, lbl) 171 | input_dict['img_strong'] = img_strong 172 | input_dict['params'] = params 173 | 174 | img, lbl_, lp = self.transform(img, lbl, lp) 175 | 176 | input_dict['img'] = img 177 | input_dict['img_full'] = torch.from_numpy(img_full).float() 178 | input_dict['label'] = lbl_ 179 | input_dict['lp'] = lp 180 | input_dict['lpsoft'] = lpsoft 181 | input_dict['weak_params'] = weak_params #full2weak 182 | input_dict['img_path'] = self.files[index] 183 | input_dict['lbl_full'] = torch.from_numpy(lbl_full).long() 184 | 185 | input_dict = {k:v for k, v in input_dict.items() if v is not None} 186 | return input_dict 187 | 188 | def transform(self, img, lbl, lp=None, check=True): 189 | """transform 190 | 191 | :param img: 192 | :param lbl: 193 | """ 194 | # img = m.imresize( 195 | # img, (self.img_size[0], self.img_size[1]) 196 | # ) # uint8 with RGB mode 197 | img = np.array(img) 198 | # img = img[:, :, ::-1] # RGB -> BGR 199 | img = img.astype(np.float64) 200 | img -= self.mean 201 | img = img.astype(float) / 255.0 202 | # NHWC -> NCHW 203 | img = img.transpose(2, 0, 1) 204 | 205 | classes = np.unique(lbl) 206 | lbl = np.array(lbl) 207 | lbl = lbl.astype(float) 208 | # lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 209 | lbl = lbl.astype(int) 210 | 211 | if not np.all(classes == np.unique(lbl)): 212 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed 
ones 213 | 214 | if check and not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): #todo: understanding the meaning 215 | print("after det", classes, np.unique(lbl)) 216 | raise ValueError("Segmentation map contained invalid class values") 217 | 218 | img = torch.from_numpy(img).float() 219 | lbl = torch.from_numpy(lbl).long() 220 | 221 | if lp is not None: 222 | classes = np.unique(lp) 223 | lp = np.array(lp) 224 | # if not np.all(np.unique(lp[lp != self.ignore_index]) < self.n_classes): 225 | # raise ValueError("lp Segmentation map contained invalid class values") 226 | 227 | lp = torch.from_numpy(lp).long() 228 | 229 | return img, lbl, lp 230 | 231 | def decode_segmap(self, temp): 232 | r = temp.copy() 233 | g = temp.copy() 234 | b = temp.copy() 235 | for l in range(0, self.n_classes): 236 | r[temp == l] = self.label_colours[self.to19[l]][0] 237 | g[temp == l] = self.label_colours[self.to19[l]][1] 238 | b[temp == l] = self.label_colours[self.to19[l]][2] 239 | 240 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 241 | rgb[:, :, 0] = r / 255.0 242 | rgb[:, :, 1] = g / 255.0 243 | rgb[:, :, 2] = b / 255.0 244 | return rgb 245 | 246 | def encode_segmap(self, mask): 247 | # Put all void classes to zero 248 | label_copy = 250 * np.ones(mask.shape, dtype=np.uint8) 249 | for k, v in list(self.class_map.items()): 250 | label_copy[mask == k] = v 251 | return label_copy 252 | 253 | def get_cls_num_list(self): 254 | cls_num_list = np.array([1557726944, 254364912, 673500400, 18431664, 14431392, 255 | 29361440, 7038112, 7352368, 477239920, 40134240, 256 | 211669120, 36057968, 865184, 264786464, 17128544, 257 | 2385680, 943312, 504112, 2174560]) 258 | return cls_num_list 259 | -------------------------------------------------------------------------------- /data/cityscapes_train_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
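# Expected inputs for this loader, as consumed by __getitem__ below (the file names are a
# sketch; the folders are whatever is passed via the pseudo_labels_folder / soft_labels_folder
# options, e.g. --soft_labels_folder in the README commands):
#   {pseudo_labels_folder}/aachen_000000_000019_leftImg8bit.png  # hard pseudo-label (train ids)
#   {soft_labels_folder}/aachen_000000_000019_leftImg8bit.npy    # soft label scores [C, H, W]
# RGB images are still read from {tgt_rootpath}/leftImg8bit/train/... as in the val loader.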
3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | import scipy.misc as m 8 | from tqdm import tqdm 9 | 10 | from torch.utils import data 11 | from PIL import Image 12 | 13 | from data.augmentations import * 14 | from data.base_dataset import BaseDataset 15 | from data.randaugment import RandAugmentMC 16 | 17 | import random 18 | from torchvision import transforms 19 | 20 | def recursive_glob(rootdir=".", suffix=""): 21 | """Performs recursive glob with given suffix and rootdir 22 | :param rootdir is the root directory 23 | :param suffix is the suffix to be searched 24 | """ 25 | return [ 26 | os.path.join(looproot, filename) 27 | for looproot, _, filenames in os.walk(rootdir) #os.walk: traversal all files in rootdir and its subfolders 28 | for filename in filenames 29 | if filename.endswith(suffix) 30 | ] 31 | 32 | class Cityscapes_train_loader(BaseDataset): 33 | """cityscapesLoader 34 | 35 | https://www.cityscapes-dataset.com 36 | 37 | Data is derived from CityScapes, and can be downloaded from here: 38 | https://www.cityscapes-dataset.com/downloads/ 39 | 40 | Many Thanks to @fvisin for the loader repo: 41 | https://github.com/fvisin/dataset_loaders/blob/master/dataset_loaders/images/cityscapes.py 42 | """ 43 | 44 | colors = [ # [ 0, 0, 0], 45 | [128, 64, 128], 46 | [244, 35, 232], 47 | [70, 70, 70], 48 | [102, 102, 156], 49 | [190, 153, 153], 50 | [153, 153, 153], 51 | [250, 170, 30], 52 | [220, 220, 0], 53 | [107, 142, 35], 54 | [152, 251, 152], 55 | [0, 130, 180], 56 | [220, 20, 60], 57 | [255, 0, 0], 58 | [0, 0, 142], 59 | [0, 0, 70], 60 | [0, 60, 100], 61 | [0, 80, 100], 62 | [0, 0, 230], 63 | [119, 11, 32], 64 | ] 65 | 66 | label_colours = dict(zip(range(19), colors)) 67 | 68 | mean_rgb = { 69 | "pascal": [103.939, 116.779, 123.68], 70 | "cityscapes": [0.0, 0.0, 0.0], 71 | } # pascal mean for PSPNet and ICNet pre-trained model 72 | 73 | def __init__(self, opt, logger, augmentations = None, split='train'): 74 | """__init__ 75 | 76 | :param opt: parameters of dataset 77 | :param writer: save the result of experiment 78 | :param logger: logging file 79 | :param augmentations: 80 | """ 81 | 82 | self.opt = opt 83 | self.root = opt.tgt_rootpath 84 | self.split = split 85 | self.augmentations = augmentations 86 | self.randaug = RandAugmentMC(2, 10) 87 | self.n_classes = opt.num_classes 88 | self.img_size = (2048, 1024) 89 | self.mean = np.array(self.mean_rgb['cityscapes']) 90 | self.files = {} 91 | self.paired_files = {} 92 | 93 | if logger is not None: 94 | logger.info("pseudo_labels_folder set to {}".format(opt.pseudo_labels_folder)) 95 | 96 | self.images_base = os.path.join(self.root, "leftImg8bit", self.split) 97 | self.annotations_base = os.path.join(opt.pseudo_labels_folder) 98 | 99 | self.files = sorted(recursive_glob(rootdir=self.images_base, suffix=".png")) #find all files from rootdir and subfolders with suffix = ".png" 100 | 101 | #self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 102 | if self.n_classes == 19: 103 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33,] 104 | self.class_names = ["unlabelled","road","sidewalk","building","wall", 105 | "fence","pole","traffic_light","traffic_sign","vegetation", 106 | "terrain","sky","person","rider","car", 107 | "truck","bus","train","motorcycle","bicycle", 108 | ] 109 | self.to19 = dict(zip(range(19), range(19))) 110 | elif self.n_classes == 16: 111 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,] 112 | 
self.class_names = ["unlabelled","road","sidewalk","building","wall", 113 | "fence","pole","traffic_light","traffic_sign","vegetation", 114 | "sky","person","rider","car","bus", 115 | "motorcycle","bicycle", 116 | ] 117 | self.to19 = dict(zip(range(16), [0,1,2,3,4,5,6,7,8,10,11,12,13,15,17,18])) 118 | elif self.n_classes == 13: 119 | self.valid_classes = [7, 8, 11, 19, 20, 21, 23, 24, 25, 26, 28, 32, 33,] 120 | self.class_names = ["unlabelled","road","sidewalk","building","traffic_light", 121 | "traffic_sign","vegetation","sky","person","rider", 122 | "car","bus","motorcycle","bicycle", 123 | ] 124 | self.to19 = dict(zip(range(13), [0,1,2,6,7,8,10,11,12,13,15,17,18])) 125 | 126 | self.ignore_index = 250 127 | self.class_map = dict(zip(self.valid_classes, range(self.n_classes))) #zip: return tuples 128 | 129 | if not self.files: 130 | raise Exception( 131 | "No files for split=[%s] found in %s" % (self.split, self.images_base) 132 | ) 133 | 134 | clrjit_params = getattr(opt, "clrjit_params", [0.5, 0.5, 0.5, 0.2]) 135 | self.train_transform = transforms.Compose([ 136 | transforms.ToPILImage(), 137 | transforms.ColorJitter(*clrjit_params), 138 | ]) 139 | 140 | print("Found %d %s images" % (len(self.files), self.split)) 141 | 142 | def __len__(self): 143 | """__len__""" 144 | return len(self.files) 145 | 146 | def __getitem__(self, index): 147 | """__getitem__ 148 | 149 | :param index: 150 | """ 151 | img_path = self.files[index].rstrip() 152 | lbl_path = os.path.join(self.annotations_base, img_path.split("/")[-1]) 153 | 154 | img = Image.open(img_path) 155 | lbl = Image.open(lbl_path) if os.path.exists(lbl_path) else Image.fromarray(np.zeros(img.size[:2])) 156 | img = img.resize(self.img_size, Image.BILINEAR) 157 | lbl = lbl.resize(self.img_size, Image.NEAREST) 158 | 159 | img = np.array(img, dtype=np.uint8) 160 | lbl = np.array(lbl, dtype=np.uint8) 161 | 162 | img_full = img.copy().astype(np.float64) 163 | img_full -= self.mean 164 | img_full = img_full.astype(float) / 255.0 165 | img_full = img_full.transpose(2, 0, 1) 166 | lbl_full = lbl.copy() 167 | 168 | lp, lpsoft, weak_params = None, None, None 169 | if self.split == 'train' and hasattr(self.opt, "soft_labels_folder"): 170 | lpsoft = np.load(os.path.join(self.opt.soft_labels_folder, os.path.basename(img_path).replace('.png', '.npy'))) 171 | 172 | input_dict = {} 173 | if self.augmentations!=None: 174 | img, lbl, lp, lpsoft, weak_params = self.augmentations(img, lbl, lp, lpsoft) 175 | img_strong, params = self.randaug(Image.fromarray(img)) 176 | img_strong, _, _ = self.transform(img_strong, lbl) 177 | input_dict['img_strong'] = img_strong 178 | input_dict['params'] = params 179 | 180 | img = self.train_transform(img) 181 | 182 | img, lbl_, lp = self.transform(img, lbl, lp) 183 | 184 | input_dict['img'] = img 185 | input_dict['img_full'] = torch.from_numpy(img_full).float() 186 | input_dict['label'] = lbl_ 187 | input_dict['lp'] = lp 188 | input_dict['lpsoft'] = lpsoft 189 | input_dict['weak_params'] = weak_params #full2weak 190 | input_dict['img_path'] = self.files[index] 191 | input_dict['lbl_full'] = torch.from_numpy(lbl_full).long() 192 | 193 | input_dict = {k:v for k, v in input_dict.items() if v is not None} 194 | return input_dict 195 | 196 | def transform(self, img, lbl, lp=None, check=True): 197 | """transform 198 | 199 | :param img: 200 | :param lbl: 201 | """ 202 | # img = m.imresize( 203 | # img, (self.img_size[0], self.img_size[1]) 204 | # ) # uint8 with RGB mode 205 | img = np.array(img) 206 | # img = img[:, :, ::-1] # RGB 
-> BGR 207 | img = img.astype(np.float64) 208 | img -= self.mean 209 | img = img.astype(float) / 255.0 210 | # NHWC -> NCHW 211 | img = img.transpose(2, 0, 1) 212 | 213 | classes = np.unique(lbl) 214 | lbl = np.array(lbl) 215 | lbl = lbl.astype(float) 216 | # lbl = m.imresize(lbl, (self.img_size[0], self.img_size[1]), "nearest", mode="F") 217 | lbl = lbl.astype(int) 218 | 219 | if not np.all(classes == np.unique(lbl)): 220 | print("WARN: resizing labels yielded fewer classes") #TODO: compare the original and processed ones 221 | 222 | if check and not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): #todo: understanding the meaning 223 | print("after det", classes, np.unique(lbl)) 224 | raise ValueError("Segmentation map contained invalid class values") 225 | 226 | img = torch.from_numpy(img).float() 227 | lbl = torch.from_numpy(lbl).long() 228 | 229 | if lp is not None: 230 | classes = np.unique(lp) 231 | lp = np.array(lp) 232 | # if not np.all(np.unique(lp[lp != self.ignore_index]) < self.n_classes): 233 | # raise ValueError("lp Segmentation map contained invalid class values") 234 | 235 | lp = torch.from_numpy(lp).long() 236 | 237 | return img, lbl, lp 238 | 239 | def decode_segmap(self, temp): 240 | r = temp.copy() 241 | g = temp.copy() 242 | b = temp.copy() 243 | for l in range(0, self.n_classes): 244 | r[temp == l] = self.label_colours[self.to19[l]][0] 245 | g[temp == l] = self.label_colours[self.to19[l]][1] 246 | b[temp == l] = self.label_colours[self.to19[l]][2] 247 | 248 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 249 | rgb[:, :, 0] = r / 255.0 250 | rgb[:, :, 1] = g / 255.0 251 | rgb[:, :, 2] = b / 255.0 252 | return rgb 253 | 254 | def encode_segmap(self, mask): 255 | # Put all void classes to zero 256 | label_copy = 250 * np.ones(mask.shape, dtype=np.uint8) 257 | for k, v in list(self.class_map.items()): 258 | label_copy[mask == k] = v 259 | return label_copy 260 | 261 | def get_cls_num_list(self): 262 | cls_num_list = np.array([1557726944, 254364912, 673500400, 18431664, 14431392, 263 | 29361440, 7038112, 7352368, 477239920, 40134240, 264 | 211669120, 36057968, 865184, 264786464, 17128544, 265 | 2385680, 943312, 504112, 2174560]) 266 | return cls_num_list 267 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import numbers 5 | import collections 6 | import cv2 7 | 8 | import torch 9 | 10 | 11 | class Compose(object): 12 | # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) 13 | def __init__(self, segtransform): 14 | self.segtransform = segtransform 15 | 16 | def __call__(self, image, label): 17 | for t in self.segtransform: 18 | image, label = t(image, label) 19 | return image, label 20 | 21 | 22 | class ToTensor(object): 23 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 
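    # Labels are cast to int64 (LongTensor) in __call__ below so they can be consumed directly by
    # nn.CrossEntropyLoss-style criteria; images become float C x H x W tensors.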
24 | def __call__(self, image, label): 25 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): 26 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray" 27 | "[eg: data readed by cv2.imread()].\n")) 28 | if len(image.shape) > 3 or len(image.shape) < 2: 29 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")) 30 | if len(image.shape) == 2: 31 | image = np.expand_dims(image, axis=2) 32 | if not len(label.shape) == 2: 33 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")) 34 | 35 | image = torch.from_numpy(image.transpose((2, 0, 1))) 36 | if not isinstance(image, torch.FloatTensor): 37 | image = image.float() 38 | label = torch.from_numpy(label) 39 | if not isinstance(label, torch.LongTensor): 40 | label = label.long() 41 | return image, label 42 | 43 | 44 | class Normalize(object): 45 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std 46 | def __init__(self, mean, std=None): 47 | if std is None: 48 | assert len(mean) > 0 49 | else: 50 | assert len(mean) == len(std) 51 | self.mean = mean 52 | self.std = std 53 | 54 | def __call__(self, image, label): 55 | if self.std is None: 56 | for t, m in zip(image, self.mean): 57 | t.sub_(m) 58 | else: 59 | for t, m, s in zip(image, self.mean, self.std): 60 | t.sub_(m).div_(s) 61 | return image, label 62 | 63 | 64 | class Resize(object): 65 | # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 66 | def __init__(self, size): 67 | assert (isinstance(size, collections.Iterable) and len(size) == 2) 68 | self.size = size 69 | 70 | def __call__(self, image, label): 71 | size = self.size[::-1] 72 | image = cv2.resize(image, (self.size[1], self.size[0]), interpolation=cv2.INTER_LINEAR) 73 | label = cv2.resize(label, (self.size[1], self.size[0]), interpolation=cv2.INTER_NEAREST) 74 | return image, label 75 | 76 | 77 | class RandScale(object): 78 | # Randomly resize image & label with scale factor in [scale_min, scale_max] 79 | def __init__(self, scale, aspect_ratio=None): 80 | assert (isinstance(scale, collections.Iterable) and len(scale) == 2) 81 | if isinstance(scale, collections.Iterable) and len(scale) == 2 \ 82 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ 83 | and 0 < scale[0] <= scale[1]: 84 | self.scale = scale 85 | else: 86 | raise (RuntimeError("segtransform.RandScale() scale param error.\n")) 87 | if aspect_ratio is None: 88 | self.aspect_ratio = aspect_ratio 89 | elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ 90 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ 91 | and 0 < aspect_ratio[0] < aspect_ratio[1]: 92 | self.aspect_ratio = aspect_ratio 93 | else: 94 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) 95 | 96 | def __call__(self, image, label): 97 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() 98 | temp_aspect_ratio = 1.0 99 | if self.aspect_ratio is not None: 100 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() 101 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio) 102 | scale_factor_x = temp_scale * temp_aspect_ratio 103 | scale_factor_y = temp_scale / temp_aspect_ratio 104 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) 
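    # the label map is resized with nearest-neighbour interpolation (next line) so class IDs are never blended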
105 | label = cv2.resize(label, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) 106 | return image, label 107 | 108 | 109 | class Crop(object): 110 | """Crops the given ndarray image (H*W*C or H*W). 111 | Args: 112 | size (sequence or int): Desired output size of the crop. If size is an 113 | int instead of sequence like (h, w), a square crop (size, size) is made. 114 | """ 115 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255): 116 | if isinstance(size, int): 117 | self.crop_h = size 118 | self.crop_w = size 119 | elif isinstance(size, collections.Iterable) and len(size) == 2 \ 120 | and isinstance(size[0], int) and isinstance(size[1], int) \ 121 | and size[0] > 0 and size[1] > 0: 122 | self.crop_h = size[0] 123 | self.crop_w = size[1] 124 | else: 125 | raise (RuntimeError("crop size error.\n")) 126 | if crop_type == 'center' or crop_type == 'rand': 127 | self.crop_type = crop_type 128 | else: 129 | raise (RuntimeError("crop type error: rand | center\n")) 130 | if padding is None: 131 | self.padding = padding 132 | elif isinstance(padding, list): 133 | if all(isinstance(i, numbers.Number) for i in padding): 134 | self.padding = padding 135 | else: 136 | raise (RuntimeError("padding in Crop() should be a number list\n")) 137 | if len(padding) != 3: 138 | raise (RuntimeError("padding channel is not equal with 3\n")) 139 | else: 140 | raise (RuntimeError("padding in Crop() should be a number list\n")) 141 | if isinstance(ignore_label, int): 142 | self.ignore_label = ignore_label 143 | else: 144 | raise (RuntimeError("ignore_label should be an integer number\n")) 145 | 146 | def __call__(self, image, label): 147 | h, w = label.shape 148 | pad_h = max(self.crop_h - h, 0) 149 | pad_w = max(self.crop_w - w, 0) 150 | pad_h_half = int(pad_h / 2) 151 | pad_w_half = int(pad_w / 2) 152 | if pad_h > 0 or pad_w > 0: 153 | if self.padding is None: 154 | raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) 155 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding) 156 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label) 157 | h, w = label.shape 158 | if self.crop_type == 'rand': 159 | h_off = random.randint(0, h - self.crop_h) 160 | w_off = random.randint(0, w - self.crop_w) 161 | else: 162 | h_off = int((h - self.crop_h) / 2) 163 | w_off = int((w - self.crop_w) / 2) 164 | image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] 165 | label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w] 166 | return image, label 167 | 168 | 169 | class RandRotate(object): 170 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] 171 | def __init__(self, rotate, padding, ignore_label=255, p=0.5): 172 | assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) 173 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: 174 | self.rotate = rotate 175 | else: 176 | raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) 177 | assert padding is not None 178 | assert isinstance(padding, list) and len(padding) == 3 179 | if all(isinstance(i, numbers.Number) for i in padding): 180 | self.padding = padding 181 | else: 182 | raise (RuntimeError("padding in RandRotate() should be a number list\n")) 183 | assert isinstance(ignore_label, 
int) 184 | self.ignore_label = ignore_label 185 | self.p = p 186 | 187 | def __call__(self, image, label): 188 | if random.random() < self.p: 189 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() 190 | h, w = label.shape 191 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 192 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=self.padding) 193 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=self.ignore_label) 194 | return image, label 195 | 196 | 197 | class RandomHorizontalFlip(object): 198 | def __init__(self, p=0.5): 199 | self.p = p 200 | 201 | def __call__(self, image, label): 202 | if random.random() < self.p: 203 | image = cv2.flip(image, 1) 204 | label = cv2.flip(label, 1) 205 | return image, label 206 | 207 | 208 | class RandomVerticalFlip(object): 209 | def __init__(self, p=0.5): 210 | self.p = p 211 | 212 | def __call__(self, image, label): 213 | if random.random() < self.p: 214 | image = cv2.flip(image, 0) 215 | label = cv2.flip(label, 0) 216 | return image, label 217 | 218 | 219 | class RandomGaussianBlur(object): 220 | def __init__(self, radius=5): 221 | self.radius = radius 222 | 223 | def __call__(self, image, label): 224 | if random.random() < 0.5: 225 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) 226 | return image, label 227 | 228 | 229 | class RGB2BGR(object): 230 | # Converts image from RGB order to BGR order, for model initialized from Caffe 231 | def __call__(self, image, label): 232 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 233 | return image, label 234 | 235 | 236 | class BGR2RGB(object): 237 | # Converts image from BGR order to RGB order, for model initialized from Pytorch 238 | def __call__(self, image, label): 239 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 240 | return image, label 241 | -------------------------------------------------------------------------------- /data/randaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
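The segtransform classes in utils/transform.py above compose into a joint image/label pipeline. A minimal sketch follows; it is not from the repository, the file names and crop size are placeholders, and the PSPNet-style 0-255 mean/std scaling is an assumption:

import cv2
from utils import transform as T

value_scale = 255
mean = [0.485 * value_scale, 0.456 * value_scale, 0.406 * value_scale]
std = [0.229 * value_scale, 0.224 * value_scale, 0.225 * value_scale]

train_transform = T.Compose([
    T.RandScale([0.5, 2.0]),
    T.RandRotate([-10, 10], padding=mean, ignore_label=250),
    T.RandomGaussianBlur(),
    T.RandomHorizontalFlip(),
    T.Crop((512, 896), crop_type='rand', padding=mean, ignore_label=250),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

image = cv2.cvtColor(cv2.imread("image.png"), cv2.COLOR_BGR2RGB).astype('float32')  # H x W x 3
label = cv2.imread("label.png", cv2.IMREAD_GRAYSCALE)                               # H x W train IDs
image, label = train_transform(image, label)   # float tensor [3, 512, 896], long tensor [512, 896]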
3 | 4 | import random 5 | 6 | import numpy as np 7 | import PIL 8 | import PIL.ImageOps 9 | import PIL.ImageEnhance 10 | import PIL.ImageDraw 11 | from PIL import Image 12 | import torch 13 | import torch.nn.functional as F 14 | import torchvision.transforms as transforms 15 | 16 | PARAMETER_MAX = 10 17 | 18 | def AutoContrast(img, **kwarg): 19 | return PIL.ImageOps.autocontrast(img), None 20 | 21 | 22 | def Brightness(img, v, max_v, bias=0): 23 | v = _float_parameter(v, max_v) + bias 24 | return PIL.ImageEnhance.Brightness(img).enhance(v), v 25 | 26 | 27 | def Color(img, v, max_v, bias=0): 28 | v = _float_parameter(v, max_v) + bias 29 | return PIL.ImageEnhance.Color(img).enhance(v), v 30 | 31 | 32 | def Contrast(img, v, max_v, bias=0): 33 | v = _float_parameter(v, max_v) + bias 34 | return PIL.ImageEnhance.Contrast(img).enhance(v), v 35 | 36 | 37 | def Cutout(img, v, max_v, bias=0): 38 | if v == 0: 39 | return img 40 | v = _float_parameter(v, max_v) + bias 41 | v = int(v * min(img.size)) 42 | return CutoutAbs(img, v) 43 | 44 | 45 | def CutoutAbs(img, v, **kwarg): 46 | w, h = img.size 47 | x0 = np.random.uniform(0, w) 48 | y0 = np.random.uniform(0, h) 49 | x0 = int(max(0, x0 - v / 2.)) 50 | y0 = int(max(0, y0 - v / 2.)) 51 | x1 = int(min(w, x0 + v)) 52 | y1 = int(min(h, y0 + v)) 53 | xy = (x0, y0, x1, y1) 54 | # gray 55 | color = (127, 127, 127) 56 | img = img.copy() 57 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 58 | return img, xy 59 | 60 | 61 | def Equalize(img, **kwarg): 62 | return PIL.ImageOps.equalize(img), None 63 | 64 | 65 | def Identity(img, **kwarg): 66 | return img, None 67 | 68 | 69 | def Invert(img, **kwarg): 70 | return PIL.ImageOps.invert(img), None 71 | 72 | 73 | def Posterize(img, v, max_v, bias=0): 74 | v = _int_parameter(v, max_v) + bias 75 | return PIL.ImageOps.posterize(img, v), v 76 | 77 | 78 | # def Rotate(img, v, max_v, bias=0): 79 | # v = _int_parameter(v, max_v) + bias 80 | # if random.random() < 0.5: 81 | # v = -v 82 | # #return img.rotate(v), v 83 | # img_t = transforms.ToTensor()(img) 84 | # H = img_t.shape[1] 85 | # W = img_t.shape[2] 86 | # theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi), 0], [np.sin(v/180*np.pi), np.cos(v/180*np.pi), 0]]).astype(np.float) 87 | # theta[0,1] = theta[0,1]*H/W 88 | # theta[1,0] = theta[1,0]*W/H 89 | # #theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi)], [np.sin(v/180*np.pi), np.cos(v/180*np.pi)]]).astype(np.float) 90 | # theta = torch.Tensor(theta).unsqueeze(0) 91 | 92 | # # meshgrid_x, meshgrid_y = torch.meshgrid(torch.arange(W, dtype=torch.float), torch.arange(H, dtype=torch.float)) 93 | # # meshgrid = torch.stack((meshgrid_x.t()*2/W - 1, meshgrid_y.t()*2/H - 1), dim=-1).unsqueeze(0) 94 | # # grid = torch.matmul(meshgrid, theta) 95 | 96 | # # s_h = int(abs(H - W) // 2) 97 | # # dim_last = s_h if H > W else 0 98 | # # img_t = F.pad(img_t.unsqueeze(0), (dim_last, dim_last, s_h - dim_last, s_h - dim_last)).squeeze(0) 99 | # grid = F.affine_grid(theta, img_t.unsqueeze(0).size()) 100 | # img_t = F.grid_sample(img_t.unsqueeze(0), grid, mode='bilinear').squeeze(0) 101 | # # img_t = img_t[:,:,s_h:-s_h] if H > W else img_t[:,s_h:-s_h,:] 102 | # img_t = transforms.ToPILImage()(img_t) 103 | # return img_t, v 104 | 105 | def Rotate(img, v, max_v, bias=0): 106 | v = _int_parameter(v, max_v) + bias 107 | if random.random() < 0.5: 108 | v = -v 109 | return img.rotate(v, resample=Image.BILINEAR, fillcolor=(127,127,127)), v 110 | 111 | def Sharpness(img, v, max_v, bias=0): 112 | v = _float_parameter(v, max_v) + bias 
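    # like the other ops in this file, the applied parameter is returned alongside the augmented image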
113 | return PIL.ImageEnhance.Sharpness(img).enhance(v), v 114 | 115 | 116 | def ShearX(img, v, max_v, bias=0): 117 | v = _float_parameter(v, max_v) + bias 118 | if random.random() < 0.5: 119 | v = -v 120 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v 121 | 122 | 123 | def ShearY(img, v, max_v, bias=0): 124 | v = _float_parameter(v, max_v) + bias 125 | if random.random() < 0.5: 126 | v = -v 127 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v 128 | 129 | 130 | def Solarize(img, v, max_v, bias=0): 131 | v = _int_parameter(v, max_v) + bias 132 | return PIL.ImageOps.solarize(img, 256 - v), 256 - v 133 | 134 | 135 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 136 | v = _int_parameter(v, max_v) + bias 137 | if random.random() < 0.5: 138 | v = -v 139 | img_np = np.array(img).astype(np.int) 140 | img_np = img_np + v 141 | img_np = np.clip(img_np, 0, 255) 142 | img_np = img_np.astype(np.uint8) 143 | img = Image.fromarray(img_np) 144 | return PIL.ImageOps.solarize(img, threshold), threshold 145 | 146 | 147 | def TranslateX(img, v, max_v, bias=0): 148 | v = _float_parameter(v, max_v) + bias 149 | if random.random() < 0.5: 150 | v = -v 151 | v = int(v * img.size[0]) 152 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0), resample=Image.BILINEAR, fillcolor=(127,127,127)), v 153 | 154 | 155 | def TranslateY(img, v, max_v, bias=0): 156 | v = _float_parameter(v, max_v) + bias 157 | if random.random() < 0.5: 158 | v = -v 159 | v = int(v * img.size[1]) 160 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v), resample=Image.BILINEAR, fillcolor=(127,127,127)), v 161 | 162 | 163 | def _float_parameter(v, max_v): 164 | return float(v) * max_v / PARAMETER_MAX 165 | 166 | 167 | def _int_parameter(v, max_v): 168 | return int(v * max_v / PARAMETER_MAX) 169 | 170 | 171 | def fixmatch_augment_pool(): 172 | # FixMatch paper 173 | augs = [(AutoContrast, None, None), 174 | (Brightness, 0.9, 0.05), 175 | (Color, 0.9, 0.05), 176 | (Contrast, 0.9, 0.05), 177 | (Equalize, None, None), 178 | (Identity, None, None), 179 | (Posterize, 4, 4), 180 | (Rotate, 30, 0), 181 | (Sharpness, 0.9, 0.05), 182 | (ShearX, 0.3, 0), 183 | (ShearY, 0.3, 0), 184 | (Solarize, 256, 0), 185 | (TranslateX, 0.3, 0), 186 | (TranslateY, 0.3, 0)] 187 | return augs 188 | 189 | 190 | def my_augment_pool(): 191 | # Test 192 | augs = [(AutoContrast, None, None), 193 | (Brightness, 1.8, 0.1), 194 | (Color, 1.8, 0.1), 195 | (Contrast, 1.8, 0.1), 196 | (Cutout, 0.2, 0), 197 | (Equalize, None, None), 198 | (Invert, None, None), 199 | (Posterize, 4, 4), 200 | (Rotate, 30, 0), 201 | (Sharpness, 1.8, 0.1), 202 | (ShearX, 0.3, 0), 203 | (ShearY, 0.3, 0), 204 | (Solarize, 256, 0), 205 | (SolarizeAdd, 110, 0), 206 | (TranslateX, 0.45, 0), 207 | (TranslateY, 0.45, 0)] 208 | return augs 209 | 210 | 211 | class RandAugmentPC(object): 212 | def __init__(self, n, m): 213 | assert n >= 1 214 | assert 1 <= m <= 10 215 | self.n = n 216 | self.m = m 217 | self.augment_pool = my_augment_pool() 218 | 219 | def __call__(self, img): 220 | ops = random.choices(self.augment_pool, k=self.n) 221 | for op, max_v, bias in ops: 222 | prob = np.random.uniform(0.2, 0.8) 223 | if random.random() + prob >= 1: 224 | img = op(img, v=self.m, max_v=max_v, bias=bias) 225 | img = CutoutAbs(img, 16) 226 | return img 227 | 228 | 229 | class RandAugmentMC(object): 230 | def __init__(self, n, m): 231 | 
assert n >= 1 232 | assert 1 <= m <= 10 233 | self.n = n 234 | self.m = m 235 | self.augment_pool = fixmatch_augment_pool() 236 | 237 | def __call__(self, img, type='crc'): 238 | aug_type = {'Hflip':False, 'ShearX':1e4, 'ShearY':1e4, 'TranslateX':1e4, 'TranslateY':1e4, 'Rotate':1e4, 'CutoutAbs':1e4} 239 | if random.random() < 0.5: 240 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 241 | #aug_type.append(['Hflip', True]) 242 | aug_type['Hflip'] = True 243 | if type == 'cr' or type == 'crc': 244 | ops = random.choices(self.augment_pool, k=self.n) 245 | for op, max_v, bias in ops: 246 | v = np.random.randint(1, self.m) 247 | if random.random() < 0.5: 248 | img, params = op(img, v=v, max_v=max_v, bias=bias) 249 | if op.__name__ in ['ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']: 250 | #aug_type.append([op.__name__, params]) 251 | aug_type[op.__name__] = params 252 | if type == 'cc' or type == 'crc': 253 | img, params = CutoutAbs(img, min(img.size[0], img.size[1]) // 3) 254 | #aug_type.append([CutoutAbs.__name__, params]) 255 | aug_type['CutoutAbs'] = params 256 | return img, aug_type 257 | 258 | def affine_sample(tensor, v, type): 259 | # tensor: B*C*H*W 260 | # v: scalar, translation param 261 | if type == 'Rotate': 262 | theta = np.array([[np.cos(v/180*np.pi), -np.sin(v/180*np.pi), 0], [np.sin(v/180*np.pi), np.cos(v/180*np.pi), 0]]).astype(np.float) 263 | elif type == 'ShearX': 264 | theta = np.array([[1, v, 0], [0, 1, 0]]).astype(np.float) 265 | elif type == 'ShearY': 266 | theta = np.array([[1, 0, 0], [v, 1, 0]]).astype(np.float) 267 | elif type == 'TranslateX': 268 | theta = np.array([[1, 0, v], [0, 1, 0]]).astype(np.float) 269 | elif type == 'TranslateY': 270 | theta = np.array([[1, 0, 0], [0, 1, v]]).astype(np.float) 271 | 272 | H = tensor.shape[2] 273 | W = tensor.shape[3] 274 | theta[0,1] = theta[0,1]*H/W 275 | theta[1,0] = theta[1,0]*W/H 276 | if type != 'Rotate': 277 | theta[0,2] = theta[0,2]*2/H + theta[0,0] + theta[0,1] - 1 278 | theta[1,2] = theta[1,2]*2/H + theta[1,0] + theta[1,1] - 1 279 | 280 | theta = torch.Tensor(theta).unsqueeze(0) 281 | grid = F.affine_grid(theta, tensor.size()).to(tensor.device) 282 | tensor_t = F.grid_sample(tensor, grid, mode='nearest') 283 | return tensor_t 284 | 285 | if __name__ == '__main__': 286 | randaug = RandAugmentMC(2, 10) 287 | #path = r'E:\WorkHome\IMG_20190131_142431.jpg' 288 | path = r'E:\WorkHome\0.png' 289 | img = Image.open(path) 290 | img_t = transforms.ToTensor()(img).unsqueeze(0) 291 | #img_aug, aug_type = randaug(img) 292 | #img_aug.show() 293 | 294 | # v = 20 295 | # img_pil = img.rotate(v) 296 | # img_T = affine_sample(img_t, v, 'Rotate') 297 | 298 | v = 0.12 299 | img_pil = img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 300 | img_T = affine_sample(img_t, v, 'ShearY') 301 | 302 | img_ten = transforms.ToPILImage()(img_T.squeeze(0)) 303 | img_pil.show() 304 | img_ten.show() -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.runner import load_checkpoint 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | BatchNorm = nn.BatchNorm2d 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 
'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | expansion = 4 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 77 | base_width=64, dilation=1, norm_layer=None): 78 | super(Bottleneck, self).__init__() 79 | if norm_layer is None: 80 | norm_layer = nn.BatchNorm2d 81 | width = int(planes * (base_width / 64.)) * groups 82 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 83 | self.conv1 = conv1x1(inplanes, width) 84 | self.bn1 = norm_layer(width) 85 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 86 | self.bn2 = norm_layer(width) 87 | self.conv3 = conv1x1(width, planes * self.expansion) 88 | self.bn3 = norm_layer(planes * self.expansion) 89 | self.relu = nn.ReLU(inplace=True) 90 | self.downsample = downsample 91 | self.stride = stride 92 | 93 | def forward(self, x): 94 | identity = x 95 | 96 | out = self.conv1(x) 97 | out = self.bn1(out) 98 | out = self.relu(out) 99 | 100 | out = self.conv2(out) 101 | out = self.bn2(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv3(out) 105 | out = self.bn3(out) 106 | 107 | if self.downsample is not None: 108 | identity = self.downsample(x) 109 | 110 | out += identity 111 | out = self.relu(out) 112 | 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | 118 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 119 | groups=1, 
width_per_group=64, replace_stride_with_dilation=None, 120 | norm_layer=None): 121 | super(ResNet, self).__init__() 122 | if norm_layer is None: 123 | norm_layer = nn.BatchNorm2d 124 | self._norm_layer = norm_layer 125 | 126 | self.inplanes = 64 127 | self.dilation = 1 128 | if replace_stride_with_dilation is None: 129 | # each element in the tuple indicates if we should replace 130 | # the 2x2 stride with a dilated convolution instead 131 | replace_stride_with_dilation = [False, False, False] 132 | if len(replace_stride_with_dilation) != 3: 133 | raise ValueError("replace_stride_with_dilation should be None " 134 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 135 | self.groups = groups 136 | self.base_width = width_per_group 137 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 138 | bias=False) 139 | self.bn1 = norm_layer(self.inplanes) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 144 | dilate=replace_stride_with_dilation[0]) 145 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 146 | dilate=replace_stride_with_dilation[1]) 147 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 148 | dilate=replace_stride_with_dilation[2]) 149 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 150 | self.fc = nn.Linear(512 * block.expansion, num_classes) 151 | 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 156 | nn.init.constant_(m.weight, 1) 157 | nn.init.constant_(m.bias, 0) 158 | 159 | # Zero-initialize the last BN in each residual branch, 160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 162 | if zero_init_residual: 163 | for m in self.modules(): 164 | if isinstance(m, Bottleneck): 165 | nn.init.constant_(m.bn3.weight, 0) 166 | elif isinstance(m, BasicBlock): 167 | nn.init.constant_(m.bn2.weight, 0) 168 | 169 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 170 | norm_layer = self._norm_layer 171 | downsample = None 172 | previous_dilation = self.dilation 173 | if dilate: 174 | self.dilation *= stride 175 | stride = 1 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | conv1x1(self.inplanes, planes * block.expansion, stride), 179 | norm_layer(planes * block.expansion), 180 | ) 181 | 182 | layers = [] 183 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 184 | self.base_width, previous_dilation, norm_layer)) 185 | self.inplanes = planes * block.expansion 186 | for _ in range(1, blocks): 187 | layers.append(block(self.inplanes, planes, groups=self.groups, 188 | base_width=self.base_width, dilation=self.dilation, 189 | norm_layer=norm_layer)) 190 | 191 | return nn.Sequential(*layers) 192 | 193 | def forward(self, x): 194 | 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | 204 | x = self.avgpool(x) 205 | x = x.reshape(x.size(0), -1) 206 | x = self.fc(x) 207 | 208 | return x 209 | 210 | 211 | def _resnet(arch, block, layers, pretrained, progress, pretrained_weights, **kwargs): 212 | model = ResNet(block, layers, **kwargs) 213 | if pretrained: 214 | # load_checkpoint(model, pretrained_weights, map_location='cpu') 215 | import torch 216 | import os 217 | if os.path.exists('./pretrained/resnet101-5d3b4d8f.pth'): 218 | saved_state_dict = torch.load('./pretrained/resnet101-5d3b4d8f.pth', map_location='cpu') 219 | print("load weight from ./pretrained/resnet101-5d3b4d8f.pth") 220 | else: 221 | raise ValueError("No saved_state_dict loaded") 222 | model.load_state_dict(saved_state_dict) 223 | return model 224 | 225 | 226 | def resnet18(pretrained=False, progress=True, **kwargs): 227 | """Constructs a ResNet-18 model. 228 | 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | progress (bool): If True, displays a progress bar of the download to stderr 232 | """ 233 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 234 | **kwargs) 235 | 236 | 237 | def resnet34(pretrained=False, progress=True, **kwargs): 238 | """Constructs a ResNet-34 model. 239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | """Constructs a ResNet-50 model. 250 | 251 | Args: 252 | pretrained (bool): If True, returns a model pre-trained on ImageNet 253 | progress (bool): If True, displays a progress bar of the download to stderr 254 | """ 255 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 256 | **kwargs) 257 | 258 | 259 | def resnet101(pretrained=False, progress=True, **kwargs): 260 | """Constructs a ResNet-101 model. 
261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | """Constructs a ResNet-152 model. 272 | 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 282 | """Constructs a ResNeXt-50 32x4d model. 283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | kwargs['groups'] = 32 289 | kwargs['width_per_group'] = 4 290 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 291 | pretrained, progress, **kwargs) 292 | 293 | 294 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 295 | """Constructs a ResNeXt-101 32x8d model. 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | kwargs['groups'] = 32 302 | kwargs['width_per_group'] = 8 303 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 304 | pretrained, progress, **kwargs) 305 | -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
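A minimal sketch of building this ResNet-101 as a dilated segmentation backbone and checking its output stride; it is not part of the repository and assumes the module's own mmcv/torchvision imports resolve. Passing pretrained=False sidesteps the hard-coded ./pretrained/resnet101-5d3b4d8f.pth check in _resnet:

import torch
import torch.nn as nn
from model import resnet

# dilate layer3/layer4 instead of striding -> overall output stride 8
backbone = resnet.resnet101(
    pretrained=False,
    pretrained_weights=None,
    replace_stride_with_dilation=[False, True, True],
    norm_layer=nn.BatchNorm2d,
)

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)
    x = backbone.maxpool(backbone.relu(backbone.bn1(backbone.conv1(x))))
    for stage in (backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4):
        x = stage(x)
print(x.shape)  # torch.Size([1, 2048, 64, 64]), i.e. 512 / 8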
3 | # Adapted from https://github.com/ZijunDeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 4 | 5 | import math 6 | import numbers 7 | import random 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision.transforms.functional as tf 12 | 13 | from PIL import Image, ImageOps 14 | 15 | 16 | class Compose(object): 17 | def __init__(self, augmentations): 18 | self.augmentations = augmentations 19 | self.PIL2Numpy = False 20 | 21 | def __call__(self, img, mask, mask1=None, lpsoft=None): 22 | params = {} 23 | if isinstance(img, np.ndarray): 24 | img = Image.fromarray(img, mode="RGB") 25 | mask = Image.fromarray(mask, mode="L") 26 | if mask1 is not None: 27 | mask1 = Image.fromarray(mask1, mode="L") 28 | if lpsoft is not None: 29 | lpsoft = torch.from_numpy(lpsoft) 30 | lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[img.size[1], img.size[0]], mode='bilinear', align_corners=True)[0] 31 | self.PIL2Numpy = True 32 | 33 | if img.size != mask.size: 34 | print (img.size, mask.size) 35 | assert img.size == mask.size 36 | if mask1 is not None: 37 | assert (img.size == mask1.size) 38 | for a in self.augmentations: 39 | img, mask, mask1, lpsoft, params = a(img, mask, mask1, lpsoft, params) 40 | # print(img.size) 41 | 42 | if self.PIL2Numpy: 43 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 44 | if mask1 is not None: 45 | mask1 = np.array(mask1, dtype=np.uint8) 46 | return img, mask, mask1, lpsoft, params 47 | 48 | 49 | class RandomCrop(object): 50 | def __init__(self, size, padding=0): 51 | if isinstance(size, numbers.Number): 52 | self.size = (int(size), int(size)) 53 | else: 54 | self.size = size 55 | self.padding = padding 56 | 57 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None): 58 | if self.padding > 0: 59 | img = ImageOps.expand(img, border=self.padding, fill=0) 60 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 61 | if mask1 is not None: 62 | mask1 = ImageOps.expand(mask1, border=self.padding, fill=0) 63 | 64 | assert img.size == mask.size 65 | if mask1 is not None: 66 | assert (img.size == mask1.size) 67 | w, h = img.size 68 | 69 | # print("self.size: ", self.size) 70 | 71 | tw, th = self.size 72 | # if w == tw and h == th: 73 | # return img, mask 74 | if w < tw or h < th: 75 | if lpsoft is not None: 76 | lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[th, tw], mode='bolinear', align_corners=True)[0] 77 | if mask1 is not None: 78 | return ( 79 | img.resize((tw, th), Image.BILINEAR), 80 | mask.resize((tw, th), Image.NEAREST), 81 | mask1.resize((tw, th), Image.NEAREST), 82 | lpsoft 83 | ) 84 | else: 85 | return ( 86 | img.resize((tw, th), Image.BILINEAR), 87 | mask.resize((tw, th), Image.NEAREST), 88 | None, 89 | lpsoft 90 | ) 91 | 92 | x1 = random.randint(0, w - tw) 93 | y1 = random.randint(0, h - th) 94 | params['RandomCrop'] = (y1, y1 + th, x1, x1 + tw) 95 | if lpsoft is not None: 96 | lpsoft = lpsoft[:, y1:y1 + th, x1:x1 + tw] 97 | if mask1 is not None: 98 | return ( 99 | img.crop((x1, y1, x1 + tw, y1 + th)), 100 | mask.crop((x1, y1, x1 + tw, y1 + th)), 101 | mask1.crop((x1, y1, x1 + tw, y1 + th)), 102 | lpsoft, 103 | params 104 | ) 105 | else: 106 | return ( 107 | img.crop((x1, y1, x1 + tw, y1 + th)), 108 | mask.crop((x1, y1, x1 + tw, y1 + th)), 109 | None, 110 | lpsoft, 111 | params 112 | ) 113 | 114 | 115 | class AdjustGamma(object): 116 | def __init__(self, gamma): 117 | self.gamma = gamma 118 | 119 | def __call__(self, img, mask): 120 | assert img.size == mask.size 121 | 
return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 122 | 123 | 124 | class AdjustSaturation(object): 125 | def __init__(self, saturation): 126 | self.saturation = saturation 127 | 128 | def __call__(self, img, mask): 129 | assert img.size == mask.size 130 | return tf.adjust_saturation(img, 131 | random.uniform(1 - self.saturation, 132 | 1 + self.saturation)), mask 133 | 134 | 135 | class AdjustHue(object): 136 | def __init__(self, hue): 137 | self.hue = hue 138 | 139 | def __call__(self, img, mask): 140 | assert img.size == mask.size 141 | return tf.adjust_hue(img, random.uniform(-self.hue, 142 | self.hue)), mask 143 | 144 | 145 | class AdjustBrightness(object): 146 | def __init__(self, bf): 147 | self.bf = bf 148 | 149 | def __call__(self, img, mask): 150 | assert img.size == mask.size 151 | return tf.adjust_brightness(img, 152 | random.uniform(1 - self.bf, 153 | 1 + self.bf)), mask 154 | 155 | class AdjustContrast(object): 156 | def __init__(self, cf): 157 | self.cf = cf 158 | 159 | def __call__(self, img, mask): 160 | assert img.size == mask.size 161 | return tf.adjust_contrast(img, 162 | random.uniform(1 - self.cf, 163 | 1 + self.cf)), mask 164 | 165 | class CenterCrop(object): 166 | def __init__(self, size): 167 | if isinstance(size, numbers.Number): 168 | self.size = (int(size), int(size)) 169 | else: 170 | self.size = size 171 | 172 | def __call__(self, img, mask): 173 | assert img.size == mask.size 174 | w, h = img.size 175 | th, tw = self.size 176 | x1 = int(round((w - tw) / 2.)) 177 | y1 = int(round((h - th) / 2.)) 178 | return ( 179 | img.crop((x1, y1, x1 + tw, y1 + th)), 180 | mask.crop((x1, y1, x1 + tw, y1 + th)), 181 | ) 182 | 183 | 184 | class RandomHorizontallyFlip(object): 185 | def __init__(self, p): 186 | self.p = p 187 | 188 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None): 189 | if random.random() < self.p: 190 | params['RandomHorizontallyFlip'] = True 191 | if lpsoft is not None: 192 | inv_idx = torch.arange(lpsoft.size(2)-1,-1,-1).long() # C x H x W 193 | lpsoft = lpsoft.index_select(2,inv_idx) 194 | if mask1 is not None: 195 | return ( 196 | img.transpose(Image.FLIP_LEFT_RIGHT), 197 | mask.transpose(Image.FLIP_LEFT_RIGHT), 198 | mask1.transpose(Image.FLIP_LEFT_RIGHT), 199 | lpsoft, 200 | params 201 | ) 202 | else: 203 | return ( 204 | img.transpose(Image.FLIP_LEFT_RIGHT), 205 | mask.transpose(Image.FLIP_LEFT_RIGHT), 206 | None, 207 | lpsoft, 208 | params 209 | ) 210 | else: 211 | params['RandomHorizontallyFlip'] = False 212 | return img, mask, mask1, lpsoft, params 213 | 214 | 215 | class RandomVerticallyFlip(object): 216 | def __init__(self, p): 217 | self.p = p 218 | 219 | def __call__(self, img, mask): 220 | if random.random() < self.p: 221 | return ( 222 | img.transpose(Image.FLIP_TOP_BOTTOM), 223 | mask.transpose(Image.FLIP_TOP_BOTTOM), 224 | ) 225 | return img, mask 226 | 227 | 228 | class FreeScale(object): 229 | def __init__(self, size): 230 | self.size = tuple(reversed(size)) # size: (h, w) 231 | 232 | def __call__(self, img, mask): 233 | assert img.size == mask.size 234 | return ( 235 | img.resize(self.size, Image.BILINEAR), 236 | mask.resize(self.size, Image.NEAREST), 237 | ) 238 | 239 | 240 | class RandomTranslate(object): 241 | def __init__(self, offset): 242 | self.offset = offset # tuple (delta_x, delta_y) 243 | 244 | def __call__(self, img, mask): 245 | assert img.size == mask.size 246 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0]) 247 | y_offset = int(2 * (random.random() - 0.5) * 
self.offset[1]) 248 | 249 | x_crop_offset = x_offset 250 | y_crop_offset = y_offset 251 | if x_offset < 0: 252 | x_crop_offset = 0 253 | if y_offset < 0: 254 | y_crop_offset = 0 255 | 256 | cropped_img = tf.crop(img, 257 | y_crop_offset, 258 | x_crop_offset, 259 | img.size[1]-abs(y_offset), 260 | img.size[0]-abs(x_offset)) 261 | 262 | if x_offset >= 0 and y_offset >= 0: 263 | padding_tuple = (0, 0, x_offset, y_offset) 264 | 265 | elif x_offset >= 0 and y_offset < 0: 266 | padding_tuple = (0, abs(y_offset), x_offset, 0) 267 | 268 | elif x_offset < 0 and y_offset >= 0: 269 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 270 | 271 | elif x_offset < 0 and y_offset < 0: 272 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 273 | 274 | return ( 275 | tf.pad(cropped_img, 276 | padding_tuple, 277 | padding_mode='reflect'), 278 | tf.affine(mask, 279 | translate=(-x_offset, -y_offset), 280 | scale=1.0, 281 | angle=0.0, 282 | shear=0.0, 283 | fillcolor=250)) 284 | 285 | 286 | class RandomRotate(object): 287 | def __init__(self, degree): 288 | self.degree = degree 289 | 290 | def __call__(self, img, mask): 291 | rotate_degree = random.random() * 2 * self.degree - self.degree 292 | return ( 293 | tf.affine(img, 294 | translate=(0, 0), 295 | scale=1.0, 296 | angle=rotate_degree, 297 | resample=Image.BILINEAR, 298 | fillcolor=(0, 0, 0), 299 | shear=0.0), 300 | tf.affine(mask, 301 | translate=(0, 0), 302 | scale=1.0, 303 | angle=rotate_degree, 304 | resample=Image.NEAREST, 305 | fillcolor=250, 306 | shear=0.0)) 307 | 308 | 309 | 310 | class Scale(object): 311 | def __init__(self, size): 312 | self.size = size 313 | 314 | def __call__(self, img, mask): 315 | assert img.size == mask.size 316 | w, h = img.size 317 | if (w >= h and w == self.size) or (h >= w and h == self.size): 318 | return img, mask 319 | if w > h: 320 | ow = self.size 321 | oh = int(self.size * h / w) 322 | return ( 323 | img.resize((ow, oh), Image.BILINEAR), 324 | mask.resize((ow, oh), Image.NEAREST), 325 | ) 326 | else: 327 | oh = self.size 328 | ow = int(self.size * w / h) 329 | return ( 330 | img.resize((ow, oh), Image.BILINEAR), 331 | mask.resize((ow, oh), Image.NEAREST), 332 | ) 333 | 334 | def MyScale(img, lbl, size): 335 | """scale 336 | 337 | img, lbl, longer size 338 | """ 339 | if isinstance(img, np.ndarray): 340 | _img = Image.fromarray(img) 341 | _lbl = Image.fromarray(lbl) 342 | else: 343 | _img = img 344 | _lbl = lbl 345 | assert _img.size == _lbl.size 346 | # prop = 1.0 * _img.size[0]/_img.size[1] 347 | w, h = size 348 | # h = int(size / prop) 349 | _img = _img.resize((w, h), Image.BILINEAR) 350 | _lbl = _lbl.resize((w, h), Image.NEAREST) 351 | return np.array(_img), np.array(_lbl) 352 | 353 | def Flip(img, lbl, prop): 354 | """ 355 | flip img and lbl with probablity prop 356 | """ 357 | if isinstance(img, np.ndarray): 358 | _img = Image.fromarray(img) 359 | _lbl = Image.fromarray(lbl) 360 | else: 361 | _img = img 362 | _lbl = lbl 363 | if random.random() < prop: 364 | _img.transpose(Image.FLIP_LEFT_RIGHT), 365 | _lbl.transpose(Image.FLIP_LEFT_RIGHT), 366 | return np.array(_img), np.array(_lbl) 367 | 368 | def MyRotate(img, lbl, degree): 369 | """ 370 | img, lbl, degree 371 | randomly rotate clockwise or anti-clockwise 372 | """ 373 | if isinstance(img, np.ndarray): 374 | _img = Image.fromarray(img) 375 | _lbl = Image.fromarray(lbl) 376 | else: 377 | _img = img 378 | _lbl = lbl 379 | _degree = random.random()*degree 380 | 381 | flags = -1 382 | if random.random() < 0.5: 383 | flags = 1 384 | _img = 
_img.rotate(_degree * flags) 385 | _lbl = _lbl.rotate(_degree * flags) 386 | return np.array(_img), np.array(_lbl) 387 | 388 | class RandomSizedCrop(object): 389 | def __init__(self, size): 390 | self.size = size 391 | 392 | def __call__(self, img, mask): 393 | assert img.size == mask.size 394 | for attempt in range(10): 395 | area = img.size[0] * img.size[1] 396 | target_area = random.uniform(0.45, 1.0) * area 397 | aspect_ratio = random.uniform(0.5, 2) 398 | 399 | w = int(round(math.sqrt(target_area * aspect_ratio))) 400 | h = int(round(math.sqrt(target_area / aspect_ratio))) 401 | 402 | if random.random() < 0.5: 403 | w, h = h, w 404 | 405 | if w <= img.size[0] and h <= img.size[1]: 406 | x1 = random.randint(0, img.size[0] - w) 407 | y1 = random.randint(0, img.size[1] - h) 408 | 409 | img = img.crop((x1, y1, x1 + w, y1 + h)) 410 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 411 | assert img.size == (w, h) 412 | 413 | return ( 414 | img.resize((self.size, self.size), Image.BILINEAR), 415 | mask.resize((self.size, self.size), Image.NEAREST), 416 | ) 417 | 418 | # Fallback 419 | scale = Scale(self.size) 420 | crop = CenterCrop(self.size) 421 | return crop(*scale(img, mask)) 422 | 423 | 424 | class RandomSized(object): 425 | def __init__(self, size): 426 | self.size = size 427 | self.scale = Scale(self.size) 428 | self.crop = RandomCrop(self.size) 429 | 430 | def __call__(self, img, mask, mask1=None, lpsoft=None, params=None): 431 | assert img.size == mask.size 432 | if mask1 is not None: 433 | assert (img.size == mask1.size) 434 | 435 | prop = 1.0 * img.size[0] / img.size[1] 436 | w = int(random.uniform(0.5, 1.5) * self.size) 437 | #w = self.size 438 | h = int(w/prop) 439 | params['RandomSized'] = (h, w) 440 | # h = int(random.uniform(0.5, 2) * self.size[1]) 441 | 442 | img, mask = ( 443 | img.resize((w, h), Image.BILINEAR), 444 | mask.resize((w, h), Image.NEAREST), 445 | ) 446 | if mask1 is not None: 447 | mask1 = mask1.resize((w, h), Image.NEAREST) 448 | if lpsoft is not None: 449 | lpsoft = F.interpolate(lpsoft.unsqueeze(0), size=[h, w], mode='bilinear', align_corners=True)[0] 450 | 451 | return img, mask, mask1, lpsoft, params 452 | # return self.crop(*self.scale(img, mask)) 453 | -------------------------------------------------------------------------------- /generate_soft_label.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import pickle 6 | import torch.optim as optim 7 | import scipy.misc 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn.functional as F 10 | import sys 11 | import os 12 | import os.path as osp 13 | import random 14 | import logging 15 | import time 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | from tensorboardX import SummaryWriter 19 | 20 | from model.feature_extractor import resnet_feature_extractor 21 | from model.classifier import ASPP_Classifier_Gen 22 | from model.discriminator import FCDiscriminator 23 | 24 | from utils.util import * 25 | from data import create_dataset 26 | import cv2 27 | 28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32) 29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32) 30 | 31 | MODEL = 'DeepLab' 32 | BATCH_SIZE = 1 33 | ITER_SIZE = 1 34 | NUM_WORKERS = 16 35 | IGNORE_LABEL = 250 36 | LEARNING_RATE = 2.5e-4 37 | MOMENTUM = 0.9 38 | NUM_CLASSES = 19 39 | NUM_STEPS = 62500 40 | NUM_STEPS_STOP = 40000 # early stopping 41 | POWER = 0.9 42 
| RANDOM_SEED = 1234 43 | RESUME = './pretrained/model_phase1.pth' 44 | SAVE_NUM_IMAGES = 2 45 | SAVE_PRED_EVERY = 1000 46 | SNAPSHOT_DIR = './snapshots/' 47 | WEIGHT_DECAY = 0.0005 48 | LOG_DIR = './log' 49 | 50 | LEARNING_RATE_D = 1e-4 51 | LAMBDA_SEG = 0.1 52 | LAMBDA_ADV_TARGET1 = 0.0002 53 | LAMBDA_ADV_TARGET2 = 0.001 54 | 55 | SET = 'train' 56 | 57 | def get_arguments(): 58 | """Parse all the arguments provided from the CLI. 59 | 60 | Returns: 61 | A list of parsed arguments. 62 | """ 63 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 64 | parser.add_argument("--model", type=str, default=MODEL, 65 | help="available options : DeepLab") 66 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 67 | help="Number of images sent to the network in one step.") 68 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE, 69 | help="Accumulate gradients for ITER_SIZE iterations.") 70 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 71 | help="number of workers for multithread dataloading.") 72 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 73 | help="The index of the label to ignore during the training.") 74 | parser.add_argument("--is-training", action="store_true", 75 | help="Whether to updates the running means and variances during the training.") 76 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 77 | help="Base learning rate for training with polynomial decay.") 78 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D, 79 | help="Base learning rate for discriminator.") 80 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG, 81 | help="lambda_seg.") 82 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1, 83 | help="lambda_adv for adversarial training.") 84 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2, 85 | help="lambda_adv for adversarial training.") 86 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 87 | help="Momentum component of the optimiser.") 88 | parser.add_argument("--not-restore-last", action="store_true", 89 | help="Whether to not restore last (FC) layers.") 90 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 91 | help="Number of classes to predict (including background).") 92 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 93 | help="Number of training steps.") 94 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP, 95 | help="Number of training steps for early stopping.") 96 | parser.add_argument("--power", type=float, default=POWER, 97 | help="Decay parameter to compute the learning rate.") 98 | parser.add_argument("--random-mirror", action="store_true", 99 | help="Whether to randomly mirror the inputs during the training.") 100 | parser.add_argument("--random-scale", action="store_true", 101 | help="Whether to randomly scale the inputs during the training.") 102 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 103 | help="Random seed to have reproducible results.") 104 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, 105 | help="How many images to save.") 106 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 107 | help="Save summaries and checkpoint every often.") 108 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, 109 | help="Where to save snapshots of the 
model.") 110 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 111 | help="Regularisation parameter for L2-loss.") 112 | parser.add_argument("--cpu", action='store_true', help="choose to use cpu device.") 113 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.") 114 | parser.add_argument("--log-dir", type=str, default=LOG_DIR, 115 | help="Path to the directory of log.") 116 | parser.add_argument("--set", type=str, default=SET, 117 | help="choose adaptation set.") 118 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus") 119 | parser.add_argument("--dist", action="store_true", help="DDP") 120 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node') 121 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations') 122 | 123 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset') 124 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset') 125 | parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='training target dataset') 126 | parser.add_argument("--noaug", action="store_true", help="augmentation") 127 | parser.add_argument('--resize', type=int, default=2200, help='resize long size') 128 | parser.add_argument("--clrjit_params", type=str, default="0.5,0.5,0.5,0.2", help='brightness,contrast,saturation,hue') 129 | parser.add_argument('--rcrop', type=str, default='896,512', help='rondom crop size') 130 | parser.add_argument('--hflip', type=float, default=0.5, help='random flip probility') 131 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5') 132 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes') 133 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle') 134 | parser.add_argument('--no_droplast', action='store_true') 135 | parser.add_argument('--pseudo_labels_folder', type=str, default='') 136 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation') 137 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight') 138 | parser.add_argument("--freeze_bn", action="store_true", help="augmentation") 139 | parser.add_argument("--hidden_dim", type=int, default=128, help='number of selected negative samples') 140 | parser.add_argument("--layer", type=int, default=1, help='separate from which layer') 141 | parser.add_argument("--output_folder", type=str, default="", help='output folder') 142 | return parser.parse_args() 143 | 144 | 145 | args = get_arguments() 146 | 147 | def main_worker(gpu, world_size, dist_url): 148 | """Create the model and start the training.""" 149 | if gpu == 0: 150 | if not os.path.exists(args.snapshot_dir): 151 | os.makedirs(args.snapshot_dir) 152 | logFilename = os.path.join(args.snapshot_dir, str(time.time())) 153 | logging.basicConfig( 154 | level = logging.INFO, 155 | format ='%(asctime)s-%(levelname)s-%(message)s', 156 | datefmt = '%y-%m-%d %H:%M', 157 | filename = logFilename, 158 | filemode = 'w+') 159 | filehandler = logging.FileHandler(logFilename, encoding='utf-8') 160 | logger = logging.getLogger() 161 | logger.addHandler(filehandler) 162 | handler = logging.StreamHandler() 163 | logger.addHandler(handler) 164 | logger.info(args) 165 | 166 | np.random.seed(args.random_seed) 167 | random.seed(args.random_seed) 168 | 
torch.manual_seed(args.random_seed) 169 | torch.cuda.manual_seed(args.random_seed) 170 | # torch.backends.cudnn.deterministic = True 171 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU. 172 | # torch.backends.cudnn.enabled = False 173 | 174 | print("gpu: {}, world_size: {}".format(gpu, world_size)) 175 | print("dist_url: ", dist_url) 176 | 177 | torch.cuda.set_device(gpu) 178 | args.batch_size = args.batch_size // world_size 179 | args.batch_size_val = args.batch_size_val // world_size 180 | args.num_workers = args.num_workers // world_size 181 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu) 182 | 183 | if gpu == 0: 184 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val)) 185 | 186 | device = torch.device("cuda" if not args.cpu else "cpu") 187 | args.world_size = world_size 188 | 189 | if gpu == 0: 190 | logger.info("args: {}".format(args)) 191 | 192 | # cudnn.enabled = True 193 | 194 | # Create network 195 | if args.model == 'DeepLab': 196 | if args.resume: 197 | resume_weight = torch.load(args.resume, map_location='cpu') 198 | print("args.resume: ", args.resume) 199 | # feature_extractor_weights = resume_weight['model_state_dict'] 200 | model_B2_weights = resume_weight['model_B2_state_dict'] 201 | model_B_weights = resume_weight['model_B_state_dict'] 202 | head_weights = resume_weight['head_state_dict'] 203 | classifier_weights = resume_weight['classifier_state_dict'] 204 | # feature_extractor_weights = {k.replace("module.", ""):v for k,v in feature_extractor_weights.items()} 205 | model_B2_weights = {k.replace("module.", ""):v for k,v in model_B2_weights.items()} 206 | model_B_weights = {k.replace("module.", ""):v for k,v in model_B_weights.items()} 207 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()} 208 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()} 209 | 210 | if gpu == 0: 211 | logger.info("freeze_bn: {}".format(args.freeze_bn)) 212 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn) 213 | 214 | if args.layer == 0: 215 | ndf = 64 216 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool) 217 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 218 | elif args.layer == 1: 219 | ndf = 256 220 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1) 221 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 222 | elif args.layer == 2: 223 | ndf = 512 224 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2) 225 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4) 226 | 227 | if args.resume: 228 | model_B2.load_state_dict(model_B2_weights) 229 | model_B.load_state_dict(model_B_weights) 230 | 231 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim) 232 | head, classifier = classifier.head, classifier.classifier 233 | if args.resume: 234 | head.load_state_dict(head_weights) 235 | classifier.load_state_dict(classifier_weights) 236 | 237 | model_B2.train() 238 | model_B.train() 
239 | head.train() 240 | classifier.train() 241 | 242 | if gpu == 0: 243 | logger.info(model_B2) 244 | logger.info(model_B) 245 | logger.info(head) 246 | logger.info(classifier) 247 | else: 248 | logger = None 249 | 250 | if gpu == 0: 251 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast)) 252 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")] 253 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')] 254 | 255 | datasets = create_dataset(args, logger) 256 | sourceloader_iter = enumerate(datasets.source_train_loader) 257 | targetloader_iter = enumerate(datasets.target_train_loader) 258 | 259 | # define model 260 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2) 261 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True) 262 | 263 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B) 264 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True) 265 | 266 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head) 267 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True) 268 | 269 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 270 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True) 271 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label) 272 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 273 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 274 | 275 | # labels for adversarial training 276 | source_label = 0 277 | target_label = 1 278 | 279 | # set up tensor board 280 | if args.tensorboard and gpu == 0: 281 | writer = SummaryWriter(args.snapshot_dir) 282 | 283 | validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_train_loader, args.output_folder) 284 | # exit() 285 | 286 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader, output_folder): 287 | if gpu == 0: 288 | logger.info("Start Evaluation") 289 | # evaluate 290 | loss_meter = AverageMeter() 291 | intersection_meter = AverageMeter() 292 | union_meter = AverageMeter() 293 | 294 | model_B2.eval() 295 | model_B.eval() 296 | head.eval() 297 | classifier.eval() 298 | 299 | with torch.no_grad(): 300 | for i, batch in enumerate(testloader): 301 | images = batch["img_full"].cuda() 302 | labels = batch["lbl_full"].cuda() 303 | img_paths = batch['img_path'] 304 | 305 | pred = model_B(model_B2(images)) 306 | pred = classifier(head(pred)) 307 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True) 308 | loss = seg_loss(output, labels) 309 | 310 | output = F.softmax(output, 1) 311 | 312 | output_np = pred.detach().cpu().numpy().squeeze() 313 | 314 | logits, output = output.max(1) 315 | 316 | for b in range(output_np.shape[0]): 317 | mask_filename = img_paths[b].split("/")[-1].split(".")[0] 318 | np.save(os.path.join(output_folder, mask_filename+".npy"), output_np[b]) 319 | 320 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label) 321 | dist.all_reduce(intersection), 
dist.all_reduce(union) 322 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy() 323 | intersection_meter.update(intersection), union_meter.update(union) 324 | loss_meter.update(loss.item(), images.size(0)) 325 | if gpu == 0 and i % 50 == 0 and i != 0: 326 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format( 327 | i, len(testloader), loss_meter.val 328 | )) 329 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 330 | miou = np.mean(iou_class) 331 | if gpu == 0: 332 | logger.info("Val result: mIoU = {:.3f}".format(miou)) 333 | for i in range(args.num_classes): 334 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i])) 335 | logger.info("End Evaluation") 336 | 337 | return miou, loss_meter.avg 338 | 339 | def find_free_port(): 340 | import socket 341 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 342 | # Binding to port 0 will cause the OS to find an available port for us 343 | sock.bind(("", 0)) 344 | port = sock.getsockname()[1] 345 | sock.close() 346 | # NOTE: there is still a chance the port could be taken by other processes. 347 | return port 348 | 349 | if __name__ == '__main__': 350 | args.gpus = [int(x) for x in args.gpus.split(",")] 351 | args.world_size = len(args.gpus) 352 | 353 | os.makedirs(args.output_folder, exist_ok=True) 354 | 355 | if args.dist: 356 | port = find_free_port() 357 | args.dist_url = f"tcp://127.0.0.1:{port}" 358 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url)) 359 | else: 360 | main_worker(args.train_gpu, args.world_size, args) 361 | 362 | -------------------------------------------------------------------------------- /train_phase2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import pickle 6 | import torch.optim as optim 7 | import scipy.misc 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn.functional as F 10 | import sys 11 | import os 12 | import os.path as osp 13 | import random 14 | import logging 15 | import time 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | from tensorboardX import SummaryWriter 19 | 20 | from model.feature_extractor import resnet_feature_extractor 21 | from model.classifier import ASPP_Classifier_Gen 22 | from model.discriminator import FCDiscriminator 23 | 24 | from utils.util import * 25 | from data import create_dataset 26 | import cv2 27 | 28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32) 29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32) 30 | 31 | MODEL = 'DeepLab' 32 | BATCH_SIZE = 1 33 | ITER_SIZE = 1 34 | NUM_WORKERS = 16 35 | IGNORE_LABEL = 250 36 | LEARNING_RATE = 2.5e-4 37 | MOMENTUM = 0.9 38 | NUM_CLASSES = 19 39 | NUM_STEPS = 62500 40 | NUM_STEPS_STOP = 40000 # early stopping 41 | POWER = 0.9 42 | RANDOM_SEED = 1234 43 | RESUME = './pretrained/model_phase1.pth' 44 | SAVE_NUM_IMAGES = 2 45 | SAVE_PRED_EVERY = 1000 46 | SNAPSHOT_DIR = './snapshots/' 47 | WEIGHT_DECAY = 0.0005 48 | LOG_DIR = './log' 49 | 50 | LEARNING_RATE_D = 1e-4 51 | LAMBDA_SEG = 0.1 52 | LAMBDA_ADV_TARGET1 = 0.0002 53 | LAMBDA_ADV_TARGET2 = 0.001 54 | 55 | SET = 'train' 56 | 57 | def get_arguments(): 58 | """Parse all the arguments provided from the CLI. 59 | 60 | Returns: 61 | A list of parsed arguments. 
62 | """ 63 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 64 | parser.add_argument("--model", type=str, default=MODEL, 65 | help="available options : DeepLab") 66 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 67 | help="Number of images sent to the network in one step.") 68 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE, 69 | help="Accumulate gradients for ITER_SIZE iterations.") 70 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 71 | help="number of workers for multithread dataloading.") 72 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 73 | help="The index of the label to ignore during the training.") 74 | parser.add_argument("--is-training", action="store_true", 75 | help="Whether to updates the running means and variances during the training.") 76 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 77 | help="Base learning rate for training with polynomial decay.") 78 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D, 79 | help="Base learning rate for discriminator.") 80 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG, 81 | help="lambda_seg.") 82 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1, 83 | help="lambda_adv for adversarial training.") 84 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2, 85 | help="lambda_adv for adversarial training.") 86 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 87 | help="Momentum component of the optimiser.") 88 | parser.add_argument("--not-restore-last", action="store_true", 89 | help="Whether to not restore last (FC) layers.") 90 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 91 | help="Number of classes to predict (including background).") 92 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 93 | help="Number of training steps.") 94 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP, 95 | help="Number of training steps for early stopping.") 96 | parser.add_argument("--power", type=float, default=POWER, 97 | help="Decay parameter to compute the learning rate.") 98 | parser.add_argument("--random-mirror", action="store_true", 99 | help="Whether to randomly mirror the inputs during the training.") 100 | parser.add_argument("--random-scale", action="store_true", 101 | help="Whether to randomly scale the inputs during the training.") 102 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 103 | help="Random seed to have reproducible results.") 104 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, 105 | help="How many images to save.") 106 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 107 | help="Save summaries and checkpoint every often.") 108 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, 109 | help="Where to save snapshots of the model.") 110 | parser.add_argument("--weight-decay", type=float, default=WEIGHT_DECAY, 111 | help="Regularisation parameter for L2-loss.") 112 | parser.add_argument("--cpu", action='store_true', help="choose to use cpu device.") 113 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.") 114 | parser.add_argument("--log-dir", type=str, default=LOG_DIR, 115 | help="Path to the directory of log.") 116 | parser.add_argument("--set", type=str, 
default=SET, 117 | help="choose adaptation set.") 118 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus") 119 | parser.add_argument("--dist", action="store_true", help="DDP") 120 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node') 121 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations') 122 | 123 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset') 124 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset') 125 | parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='validation target dataset') 126 | parser.add_argument("--noaug", action="store_true", help="disable data augmentation") 127 | parser.add_argument('--resize', type=int, default=2200, help='resize long size') 128 | parser.add_argument("--clrjit_params", type=str, default="0.5,0.5,0.5,0.2", help='brightness,contrast,saturation,hue') 129 | parser.add_argument('--rcrop', type=str, default='896,512', help='random crop size') 130 | parser.add_argument('--hflip', type=float, default=0.5, help='random flip probability') 131 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5') 132 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes') 133 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle') 134 | parser.add_argument('--no_droplast', action='store_true') 135 | parser.add_argument('--pseudo_labels_folder', type=str, default='') 136 | parser.add_argument('--soft_labels_folder', type=str, default='') 137 | parser.add_argument('--src_loss_weight', type=float, default=1.0, help='loss weight for source domain loss') 138 | parser.add_argument('--thresholds_path', type=str, default="avg", help='avg | pred_only | fix_only') 139 | 140 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation') 141 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight') 142 | parser.add_argument("--freeze_bn", action="store_true", help="freeze batch norm statistics") 143 | parser.add_argument("--hidden_dim", type=int, default=128, help='hidden feature dimension of the classifier') 144 | parser.add_argument("--layer", type=int, default=1, help='backbone layer at which the encoder is split') 145 | return parser.parse_args() 146 | 147 | 148 | args = get_arguments() 149 | 150 | 151 | def soft_label_cross_entropy(pred, soft_label, pixel_weights=None): 152 | N, C, H, W = pred.shape 153 | loss = -soft_label.float()*F.log_softmax(pred, dim=1) 154 | if pixel_weights is None: 155 | return torch.mean(torch.sum(loss, dim=1)) 156 | return torch.mean(pixel_weights*torch.sum(loss, dim=1)) 157 | 158 | 159 | def lr_poly(base_lr, iter, max_iter, power): 160 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 161 | 162 | 163 | def adjust_learning_rate(optimizer, i_iter): 164 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 165 | optimizer.param_groups[0]['lr'] = lr 166 | if len(optimizer.param_groups) > 1: 167 | optimizer.param_groups[1]['lr'] = lr * 10 168 | 169 | 170 | def adjust_learning_rate_D(optimizer, i_iter): 171 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 172 | optimizer.param_groups[0]['lr'] = lr 173 | if len(optimizer.param_groups) > 1: 174 | optimizer.param_groups[1]['lr'] = lr * 10 175 | 176 | 177 | def main_worker(gpu, world_size, dist_url): 178 | """Create the model and start the
training.""" 179 | if gpu == 0: 180 | if not os.path.exists(args.snapshot_dir): 181 | os.makedirs(args.snapshot_dir) 182 | logFilename = os.path.join(args.snapshot_dir, str(time.time())) 183 | logging.basicConfig( 184 | level = logging.INFO, 185 | format ='%(asctime)s-%(levelname)s-%(message)s', 186 | datefmt = '%y-%m-%d %H:%M', 187 | filename = logFilename, 188 | filemode = 'w+') 189 | filehandler = logging.FileHandler(logFilename, encoding='utf-8') 190 | logger = logging.getLogger() 191 | logger.addHandler(filehandler) 192 | handler = logging.StreamHandler() 193 | logger.addHandler(handler) 194 | logger.info(args) 195 | 196 | np.random.seed(args.random_seed) 197 | random.seed(args.random_seed) 198 | torch.manual_seed(args.random_seed) 199 | torch.cuda.manual_seed(args.random_seed) 200 | # torch.backends.cudnn.deterministic = True 201 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU. 202 | # torch.backends.cudnn.enabled = False 203 | 204 | print("gpu: {}, world_size: {}".format(gpu, world_size)) 205 | print("dist_url: ", dist_url) 206 | 207 | torch.cuda.set_device(gpu) 208 | args.batch_size = args.batch_size // world_size 209 | args.batch_size_val = args.batch_size_val // world_size 210 | args.num_workers = args.num_workers // world_size 211 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu) 212 | 213 | if gpu == 0: 214 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val)) 215 | 216 | device = torch.device("cuda" if not args.cpu else "cpu") 217 | 218 | args.world_size = world_size 219 | 220 | if gpu == 0: 221 | logger.info("args: {}".format(args)) 222 | 223 | # cudnn.enabled = True 224 | 225 | # Create network 226 | if args.model == 'DeepLab': 227 | if args.resume: 228 | resume_weight = torch.load(args.resume, map_location='cpu') 229 | print("args.resume: ", args.resume) 230 | # feature_extractor_weights = resume_weight['model_state_dict'] 231 | model_B2_weights = resume_weight['model_B2_state_dict'] 232 | model_B_weights = resume_weight['model_B_state_dict'] 233 | head_weights = resume_weight['head_state_dict'] 234 | classifier_weights = resume_weight['classifier_state_dict'] 235 | model_B2_weights = {k.replace("module.", ""):v for k,v in model_B2_weights.items()} 236 | model_B_weights = {k.replace("module.", ""):v for k,v in model_B_weights.items()} 237 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()} 238 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()} 239 | 240 | if gpu == 0: 241 | logger.info("freeze_bn: {}".format(args.freeze_bn)) 242 | 243 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn) 244 | 245 | if args.layer == 0: 246 | ndf = 64 247 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool) 248 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 249 | elif args.layer == 1: 250 | ndf = 256 251 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1) 252 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 253 | elif args.layer == 2: 254 | ndf = 512 255 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, 
model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2) 256 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4) 257 | 258 | if args.resume: 259 | model_B2.load_state_dict(model_B2_weights) 260 | model_B.load_state_dict(model_B_weights) 261 | 262 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim) 263 | head, classifier = classifier.head, classifier.classifier 264 | if args.resume: 265 | head.load_state_dict(head_weights) 266 | classifier.load_state_dict(classifier_weights) 267 | 268 | model_B2.train() 269 | model_B.train() 270 | head.train() 271 | classifier.train() 272 | 273 | if gpu == 0: 274 | logger.info(model_B2) 275 | logger.info(model_B) 276 | logger.info(head) 277 | logger.info(classifier) 278 | else: 279 | logger = None 280 | 281 | if gpu == 0: 282 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast)) 283 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")] 284 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')] 285 | 286 | datasets = create_dataset(args, logger) 287 | sourceloader_iter = enumerate(datasets.source_train_loader) 288 | targetloader_iter = enumerate(datasets.target_train_loader) 289 | 290 | # define optimizer 291 | model_params = [{'params': list(model_B2.parameters()) + list(model_B.parameters())}, 292 | {'params': list(head.parameters()) + list(classifier.parameters()), 'lr': args.learning_rate * 10}] 293 | optimizer = optim.SGD(model_params, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 294 | assert len(optimizer.param_groups) == 2 295 | optimizer.zero_grad() 296 | 297 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2) 298 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True) 299 | 300 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B) 301 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True) 302 | 303 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head) 304 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True) 305 | 306 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 307 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True) 308 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label) 309 | 310 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 311 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 312 | 313 | # labels for adversarial training 314 | source_label = 0 315 | target_label = 1 316 | 317 | # set up tensor board 318 | if args.tensorboard and gpu == 0: 319 | writer = SummaryWriter(args.snapshot_dir) 320 | 321 | # # Uncomment the following two lines for testing 322 | # validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader) 323 | # exit() 324 | 325 | thresholds = np.load(args.thresholds_path) 326 | class_list = ["road","sidewalk","building","wall", 327 | "fence","pole","traffic_light","traffic_sign","vegetation", 328 | 
"terrain","sky","person","rider","car", 329 | "truck","bus","train","motorcycle","bicycle"] 330 | if gpu == 0: 331 | logger.info('successfully load class-wise thresholds from {}'.format(args.thresholds_path)) 332 | for c in range(len(class_list)): 333 | logger.info("class {}: {}, threshold: {}".format(c, class_list[c], thresholds[c])) 334 | thresholds = torch.from_numpy(thresholds).cuda() 335 | 336 | scaler = torch.cuda.amp.GradScaler() 337 | best_miou = 0.0 338 | filename = None 339 | epoch_s, epoch_t = 0, 0 340 | for i_iter in range(args.num_steps): 341 | 342 | model_B2.train() 343 | model_B.train() 344 | head.train() 345 | classifier.train() 346 | 347 | loss_seg_value = 0 348 | loss_src_seg_value = 0 349 | 350 | optimizer.zero_grad() 351 | adjust_learning_rate(optimizer, i_iter) 352 | 353 | for sub_i in range(args.iter_size): 354 | 355 | # train with source 356 | try: 357 | _, batch = sourceloader_iter.__next__() 358 | except StopIteration: 359 | epoch_s += 1 360 | datasets.source_train_sampler.set_epoch(epoch_s) 361 | sourceloader_iter = enumerate(datasets.source_train_loader) 362 | _, batch = sourceloader_iter.__next__() 363 | 364 | images = batch['img'].cuda() 365 | labels = batch['label'].cuda() 366 | src_size = images.shape[-2:] 367 | 368 | with torch.cuda.amp.autocast(): 369 | 370 | feat_src = model_B2(images) 371 | feat_B_src = model_B(feat_src) 372 | pred = classifier(head(feat_B_src)) 373 | pred = interp(pred) #[b, num_classes, h, w] 374 | 375 | loss_seg = seg_loss(pred, labels) 376 | 377 | loss = loss_seg 378 | 379 | # proper normalization 380 | loss = args.src_loss_weight * loss / args.iter_size 381 | loss_src_seg_value += loss_seg / args.iter_size 382 | 383 | scaler.scale(loss).backward() 384 | 385 | # train with target 386 | try: 387 | _, batch = targetloader_iter.__next__() 388 | except StopIteration: 389 | epoch_t += 1 390 | datasets.target_train_sampler.set_epoch(epoch_t) 391 | targetloader_iter = enumerate(datasets.target_train_loader) 392 | _, batch = targetloader_iter.__next__() 393 | 394 | images = batch['img'].cuda() 395 | soft_labels = batch['lpsoft'].cuda() 396 | tgt_size = images.shape[-2:] 397 | 398 | with torch.no_grad(): 399 | 400 | 401 | soft_labels = F.softmax(soft_labels, 1) 402 | 403 | 404 | images_full = batch['img_full'].cuda() 405 | weak_params = batch['weak_params'] 406 | resize_params = weak_params['RandomSized'] 407 | crop_params = weak_params['RandomCrop'] 408 | flip_params = weak_params['RandomHorizontallyFlip'] 409 | # print("resize_params: ", resize_params) 410 | # print("crop_params: ", crop_params) 411 | # print("flip_params: ", flip_params) 412 | with torch.cuda.amp.autocast(): 413 | with torch.no_grad(): 414 | pred_full = F.softmax(interp(classifier(head(model_B(model_B2(images_full))))), 1) 415 | 416 | # print("v1 pred_full.min(): {}, pred_full.max(): {}, pred_full.mean(): {}".format(pred_full.min(), pred_full.max(), pred_full.mean())) 417 | 418 | pred_labels = [] 419 | for b in range(pred_full.shape[0]): 420 | # restore pred_full to crop 421 | # 1.Resize 422 | h, w = resize_params[0][b], resize_params[1][b] 423 | pred_resize_b = F.interpolate(pred_full[b].unsqueeze(0), size=(h, w), mode='bilinear', align_corners=True)[0] 424 | # 2.Crop 425 | ys, ye, xs, xe = crop_params[0][b], crop_params[1][b], crop_params[2][b], crop_params[3][b] 426 | pred_crop_b = pred_resize_b[:, ys:ye, xs:xe] 427 | # 3.Flip 428 | if flip_params[b]: 429 | pred_crop_b = torch.flip(pred_crop_b, dims=(2,)) #[c, h, w] 430 | pred_labels.append(pred_crop_b) 431 | pred_labels 
= torch.stack(pred_labels, 0) 432 | assert pred_labels.shape[-2:] == tgt_size 433 | pseudo_labels = (pred_labels + soft_labels) / 2.0 434 | 435 | 436 | with torch.cuda.amp.autocast(): 437 | feat_tgt = model_B2(images) 438 | feat_B_tgt = model_B(feat_tgt) 439 | pred = classifier(head(feat_B_tgt)) 440 | pred = interp(pred) #[b, num_classes, h, w] 441 | 442 | conf, pseudo_labels = pseudo_labels.max(1) #[b, h, w] 443 | 444 | pseudo_labels[conf < thresholds[pseudo_labels]] = args.ignore_label 445 | pseudo_labels = pseudo_labels.detach() 446 | loss_seg = seg_loss(pred, pseudo_labels) 447 | 448 | loss = loss_seg 449 | 450 | # proper normalization 451 | loss = loss / args.iter_size 452 | loss_seg_value += loss_seg / args.iter_size 453 | 454 | scaler.scale(loss).backward() 455 | 456 | n = torch.tensor(1.0).cuda() 457 | 458 | dist.all_reduce(n), dist.all_reduce(loss_seg_value), dist.all_reduce(loss_src_seg_value) 459 | 460 | loss_seg_value = loss_seg_value.item() / n.item() 461 | loss_src_seg_value = loss_src_seg_value.item() / n.item() 462 | 463 | scaler.step(optimizer) 464 | scaler.update() 465 | 466 | if args.tensorboard and gpu == 0: 467 | scalar_info = { 468 | 'loss_seg': loss_seg_value, 469 | 'loss_src_seg': loss_src_seg_value, 470 | } 471 | 472 | if i_iter % 10 == 0: 473 | for key, val in scalar_info.items(): 474 | writer.add_scalar(key, val, i_iter) 475 | 476 | if gpu == 0 and i_iter % args.print_every == 0: 477 | logger.info('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_src_seg = {3:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_src_seg_value)) 478 | 479 | if gpu == 0 and i_iter >= args.num_steps_stop - 1: 480 | logger.info('save model ...') 481 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth') 482 | save_file = {'model_B2_state_dict': model_B2.state_dict(), 'model_B_state_dict': model_B.state_dict(), \ 483 | 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()} 484 | torch.save(save_file, filename) 485 | logger.info("saving checkpoint model to {}".format(filename)) 486 | break 487 | 488 | if i_iter % args.save_pred_every == 0 and i_iter != 0: 489 | miou, loss_val = validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader) 490 | if args.tensorboard and gpu == 0: 491 | scalar_info = { 492 | 'miou_val': miou, 493 | 'loss_val': loss_val 494 | } 495 | for k, v in scalar_info.items(): 496 | writer.add_scalar(k, v, i_iter) 497 | 498 | if gpu == 0 and miou > best_miou: 499 | best_miou = miou 500 | logger.info('taking snapshot ...') 501 | if filename is not None and os.path.exists(filename): 502 | os.remove(filename) 503 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + "_{}".format(miou) + '.pth') 504 | save_file = {'model_B2_state_dict': model_B2.state_dict(), 'model_B_state_dict': model_B.state_dict(), \ 505 | 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()} 506 | torch.save(save_file, filename) 507 | logger.info("saving checkpoint model to {}".format(filename)) 508 | 509 | if args.tensorboard and gpu == 0: 510 | writer.close() 511 | 512 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader): 513 | if gpu == 0: 514 | logger.info("Start Evaluation") 515 | # evaluate 516 | loss_meter = AverageMeter() 517 | intersection_meter = AverageMeter() 518 | union_meter = AverageMeter() 519 | 520 | model_B2.eval() 521 | model_B.eval() 522 | head.eval() 523 | classifier.eval() 524 
| 525 | with torch.no_grad(): 526 | for i, batch in enumerate(testloader): 527 | images = batch["img"].cuda() 528 | labels = batch["label"].cuda() 529 | 530 | pred = model_B(model_B2(images)) 531 | pred = classifier(head(pred)) 532 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True) 533 | loss = seg_loss(output, labels) 534 | 535 | output = output.max(1)[1] 536 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label) 537 | dist.all_reduce(intersection), dist.all_reduce(union) 538 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy() 539 | intersection_meter.update(intersection), union_meter.update(union) 540 | loss_meter.update(loss.item(), images.size(0)) 541 | if gpu == 0 and i % 50 == 0 and i != 0: 542 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format( 543 | i, len(testloader), loss_meter.val 544 | )) 545 | 546 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 547 | miou = np.mean(iou_class) 548 | if gpu == 0: 549 | logger.info("Val result: mIoU = {:.3f}".format(miou)) 550 | for i in range(args.num_classes): 551 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i])) 552 | logger.info("End Evaluation") 553 | 554 | torch.cuda.empty_cache() 555 | 556 | return miou, loss_meter.avg 557 | 558 | def find_free_port(): 559 | import socket 560 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 561 | # Binding to port 0 will cause the OS to find an available port for us 562 | sock.bind(("", 0)) 563 | port = sock.getsockname()[1] 564 | sock.close() 565 | # NOTE: there is still a chance the port could be taken by other processes. 566 | return port 567 | 568 | if __name__ == '__main__': 569 | args.gpus = [int(x) for x in args.gpus.split(",")] 570 | args.world_size = len(args.gpus) 571 | if args.dist: 572 | port = find_free_port() 573 | args.dist_url = f"tcp://127.0.0.1:{port}" 574 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url)) 575 | else: 576 | main_worker(args.train_gpu, args.world_size, args) 577 | 578 | -------------------------------------------------------------------------------- /train_phase1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import pickle 6 | import torch.optim as optim 7 | import scipy.misc 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn.functional as F 10 | import sys 11 | import os 12 | import os.path as osp 13 | import random 14 | import logging 15 | import time 16 | import torch.distributed as dist 17 | import torch.multiprocessing as mp 18 | from tensorboardX import SummaryWriter 19 | 20 | from model.feature_extractor import resnet_feature_extractor 21 | from model.classifier import ASPP_Classifier_Gen 22 | from model.discriminator import FCDiscriminator 23 | 24 | from utils.util import * 25 | from data import create_dataset 26 | import cv2 27 | 28 | IMG_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32) 29 | IMG_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32) 30 | 31 | MODEL = 'DeepLab' 32 | BATCH_SIZE = 1 33 | ITER_SIZE = 1 34 | NUM_WORKERS = 16 35 | IGNORE_LABEL = 250 36 | LEARNING_RATE = 2.5e-4 37 | MOMENTUM = 0.9 38 | NUM_CLASSES = 19 39 | NUM_STEPS = 93750 40 | NUM_STEPS_STOP = 60000 # early stopping 41 | POWER = 0.9 42 | RANDOM_SEED = 1234 43 | RESUME = './pretrained/sourceonly.pth' 44 | SAVE_NUM_IMAGES = 2 
45 | SAVE_PRED_EVERY = 1000 46 | SNAPSHOT_DIR = './snapshots/' 47 | WEIGHT_DECAY = 0.0005 48 | LOG_DIR = './log' 49 | 50 | LEARNING_RATE_D = 1e-4 51 | LAMBDA_SEG = 0.1 52 | LAMBDA_ADV_TARGET1 = 0.0002 53 | LAMBDA_ADV_TARGET2 = 0.001 54 | GAN = 'LS' #'Vanilla' 55 | 56 | SET = 'train' 57 | 58 | def get_arguments(): 59 | """Parse all the arguments provided from the CLI. 60 | 61 | Returns: 62 | A list of parsed arguments. 63 | """ 64 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 65 | parser.add_argument("--model", type=str, default=MODEL, 66 | help="available options: DeepLab") 67 | parser.add_argument("--batch-size", type=int, default=BATCH_SIZE, 68 | help="Number of images sent to the network in one step.") 69 | parser.add_argument("--iter-size", type=int, default=ITER_SIZE, 70 | help="Accumulate gradients for ITER_SIZE iterations.") 71 | parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, 72 | help="number of workers for multi-threaded data loading.") 73 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 74 | help="The index of the label to ignore during the training.") 75 | parser.add_argument("--is-training", action="store_true", 76 | help="Whether to update the running means and variances during training.") 77 | parser.add_argument("--learning-rate", type=float, default=LEARNING_RATE, 78 | help="Base learning rate for training with polynomial decay.") 79 | parser.add_argument("--learning-rate-D", type=float, default=LEARNING_RATE_D, 80 | help="Base learning rate for discriminator.") 81 | parser.add_argument("--lambda-seg", type=float, default=LAMBDA_SEG, 82 | help="lambda_seg.") 83 | parser.add_argument("--lambda-adv-target1", type=float, default=LAMBDA_ADV_TARGET1, 84 | help="lambda_adv for adversarial training.") 85 | parser.add_argument("--lambda-adv-target2", type=float, default=LAMBDA_ADV_TARGET2, 86 | help="lambda_adv for adversarial training.") 87 | parser.add_argument("--momentum", type=float, default=MOMENTUM, 88 | help="Momentum component of the optimiser.") 89 | parser.add_argument("--not-restore-last", action="store_true", 90 | help="Whether to not restore last (FC) layers.") 91 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 92 | help="Number of classes to predict (including background).") 93 | parser.add_argument("--num-steps", type=int, default=NUM_STEPS, 94 | help="Number of training steps.") 95 | parser.add_argument("--num-steps-stop", type=int, default=NUM_STEPS_STOP, 96 | help="Number of training steps for early stopping.") 97 | parser.add_argument("--power", type=float, default=POWER, 98 | help="Decay parameter to compute the learning rate.") 99 | parser.add_argument("--random-mirror", action="store_true", 100 | help="Whether to randomly mirror the inputs during the training.") 101 | parser.add_argument("--random-scale", action="store_true", 102 | help="Whether to randomly scale the inputs during the training.") 103 | parser.add_argument("--random-seed", type=int, default=RANDOM_SEED, 104 | help="Random seed to have reproducible results.") 105 | parser.add_argument("--save-num-images", type=int, default=SAVE_NUM_IMAGES, 106 | help="How many images to save.") 107 | parser.add_argument("--save-pred-every", type=int, default=SAVE_PRED_EVERY, 108 | help="Save summaries and a checkpoint every SAVE_PRED_EVERY iterations.") 109 | parser.add_argument("--snapshot-dir", type=str, default=SNAPSHOT_DIR, 110 | help="Where to save snapshots of the model.") 111 | parser.add_argument("--weight-decay", type=float,
default=WEIGHT_DECAY, 112 | help="Regularisation parameter for L2-loss.") 113 | parser.add_argument("--cpu", action='store_true', help="choose to use the cpu device.") 114 | parser.add_argument("--tensorboard", action='store_true', help="choose whether to use tensorboard.") 115 | parser.add_argument("--log-dir", type=str, default=LOG_DIR, 116 | help="Path to the log directory.") 117 | parser.add_argument("--set", type=str, default=SET, 118 | help="choose adaptation set.") 119 | parser.add_argument("--gan", type=str, default=GAN, 120 | help="choose the GAN objective.") 121 | parser.add_argument("--gpus", type=str, default="0,1", help="selected gpus") 122 | parser.add_argument("--dist", action="store_true", help="DDP") 123 | parser.add_argument("--ngpus_per_node", type=int, default=1, help='number of gpus in each node') 124 | parser.add_argument("--print-every", type=int, default=20, help='output message every n iterations') 125 | 126 | parser.add_argument("--src_dataset", type=str, default="gta5", help='training source dataset') 127 | parser.add_argument("--tgt_dataset", type=str, default="cityscapes_train", help='training target dataset') 128 | parser.add_argument("--tgt_val_dataset", type=str, default="cityscapes_val", help='validation target dataset') 129 | parser.add_argument("--noaug", action="store_true", help="disable data augmentation") 130 | parser.add_argument('--resize', type=int, default=2200, help='resize long size') 131 | parser.add_argument("--clrjit_params", type=str, default="0.0,0.0,0.0,0.0", help='brightness,contrast,saturation,hue') 132 | parser.add_argument('--rcrop', type=str, default='896,512', help='random crop size') 133 | parser.add_argument('--hflip', type=float, default=0.5, help='random flip probability') 134 | parser.add_argument('--src_rootpath', type=str, default='datasets/gta5') 135 | parser.add_argument('--tgt_rootpath', type=str, default='datasets/cityscapes') 136 | parser.add_argument('--noshuffle', action='store_true', help='do not use shuffle') 137 | parser.add_argument('--no_droplast', action='store_true') 138 | parser.add_argument('--pseudo_labels_folder', type=str, default='') 139 | parser.add_argument('--conf_bank_length', type=int, default=100000) 140 | parser.add_argument('--conf_p', type=float, default=0.8) 141 | 142 | parser.add_argument("--batch_size_val", type=int, default=4, help='batch_size for validation') 143 | parser.add_argument("--resume", type=str, default=RESUME, help='resume weight') 144 | parser.add_argument("--freeze_bn", action="store_true", help="freeze batch norm statistics") 145 | parser.add_argument("--lambda_adv_src", type=float, default=0.1, help='weight for loss_adv_src') 146 | parser.add_argument("--lambda_adv_tgt", type=float, default=0.01, help='weight for loss_adv_tgt') 147 | parser.add_argument("--hidden_dim", type=int, default=128, help='hidden feature dimension of the classifier') 148 | parser.add_argument("--layer", type=int, default=1, help='backbone layer at which the encoder is split') 149 | parser.add_argument("--lambda_st", type=float, default=0.1, help='weight for loss_st') 150 | return parser.parse_args() 151 | 152 | 153 | args = get_arguments() 154 | 155 | 156 | def soft_label_cross_entropy(pred, soft_label, pixel_weights=None): 157 | N, C, H, W = pred.shape 158 | loss = -soft_label.float()*F.log_softmax(pred, dim=1) 159 | if pixel_weights is None: 160 | return torch.mean(torch.sum(loss, dim=1)) 161 | return torch.mean(pixel_weights*torch.sum(loss, dim=1)) 162 | 163 | 164 | def lr_poly(base_lr, iter, max_iter, power): 165 | return base_lr * ((1 - float(iter) /
max_iter) ** (power)) 166 | 167 | 168 | def adjust_learning_rate(optimizer, i_iter): 169 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 170 | optimizer.param_groups[0]['lr'] = lr 171 | if len(optimizer.param_groups) > 1: 172 | optimizer.param_groups[1]['lr'] = lr * 10 173 | 174 | 175 | def adjust_learning_rate_D(optimizer, i_iter): 176 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 177 | optimizer.param_groups[0]['lr'] = lr 178 | if len(optimizer.param_groups) > 1: 179 | optimizer.param_groups[1]['lr'] = lr * 10 180 | 181 | 182 | def main_worker(gpu, world_size, dist_url): 183 | """Create the model and start the training.""" 184 | if gpu == 0: 185 | if not os.path.exists(args.snapshot_dir): 186 | os.makedirs(args.snapshot_dir) 187 | logFilename = os.path.join(args.snapshot_dir, str(time.time())) 188 | logging.basicConfig( 189 | level = logging.INFO, 190 | format ='%(asctime)s-%(levelname)s-%(message)s', 191 | datefmt = '%y-%m-%d %H:%M', 192 | filename = logFilename, 193 | filemode = 'w+') 194 | filehandler = logging.FileHandler(logFilename, encoding='utf-8') 195 | logger = logging.getLogger() 196 | logger.addHandler(filehandler) 197 | handler = logging.StreamHandler() 198 | logger.addHandler(handler) 199 | logger.info(args) 200 | 201 | np.random.seed(args.random_seed) 202 | random.seed(args.random_seed) 203 | torch.manual_seed(args.random_seed) 204 | torch.cuda.manual_seed(args.random_seed) 205 | # torch.backends.cudnn.deterministic = True 206 | torch.cuda.manual_seed_all(args.random_seed) # if you are using multi-GPU. 207 | # torch.backends.cudnn.enabled = False 208 | 209 | print("gpu: {}, world_size: {}".format(gpu, world_size)) 210 | print("dist_url: ", dist_url) 211 | 212 | torch.cuda.set_device(gpu) 213 | args.batch_size = args.batch_size // world_size 214 | args.batch_size_val = args.batch_size_val // world_size 215 | args.num_workers = args.num_workers // world_size 216 | dist.init_process_group(backend='nccl', init_method=dist_url, world_size=world_size, rank=gpu) 217 | 218 | if gpu == 0: 219 | logger.info("args.batch_size: {}, args.batch_size_val: {}".format(args.batch_size, args.batch_size_val)) 220 | 221 | device = torch.device("cuda" if not args.cpu else "cpu") 222 | 223 | args.world_size = world_size 224 | 225 | if gpu == 0: 226 | logger.info("args: {}".format(args)) 227 | 228 | # cudnn.enabled = True 229 | 230 | # Create network 231 | if args.model == 'DeepLab': 232 | 233 | if args.resume: 234 | resume_weight = torch.load(args.resume, map_location='cpu') 235 | print("args.resume: ", args.resume) 236 | feature_extractor_weights = resume_weight['model_state_dict'] 237 | head_weights = resume_weight['head_state_dict'] 238 | classifier_weights = resume_weight['classifier_state_dict'] 239 | feature_extractor_weights = {k.replace("module.", ""):v for k,v in feature_extractor_weights.items()} 240 | head_weights = {k.replace("module.", ""):v for k,v in head_weights.items()} 241 | classifier_weights = {k.replace("module.", ""):v for k,v in classifier_weights.items()} 242 | 243 | if gpu == 0: 244 | logger.info("freeze_bn: {}".format(args.freeze_bn)) 245 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn) 246 | if args.resume: 247 | model.load_state_dict(feature_extractor_weights) 248 | 249 | if args.layer == 0: 250 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool) 251 | elif args.layer == 1: 
252 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1) 253 | elif args.layer == 2: 254 | model_B1 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2) 255 | 256 | model = resnet_feature_extractor('resnet101', 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', freeze_bn=args.freeze_bn) 257 | if args.resume: 258 | model.load_state_dict(feature_extractor_weights) 259 | 260 | if args.layer == 0: 261 | ndf = 64 262 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool) 263 | model_B = nn.Sequential(model.backbone.layer1, model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 264 | elif args.layer == 1: 265 | ndf = 256 266 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1) 267 | model_B = nn.Sequential(model.backbone.layer2, model.backbone.layer3, model.backbone.layer4) 268 | elif args.layer == 2: 269 | ndf = 512 270 | model_B2 = nn.Sequential(model.backbone.conv1, model.backbone.bn1, model.backbone.relu, model.backbone.maxpool, model.backbone.layer1, model.backbone.layer2) 271 | model_B = nn.Sequential(model.backbone.layer3, model.backbone.layer4) 272 | 273 | model_D1 = FCDiscriminator(ndf, ndf=32) 274 | model_D2 = FCDiscriminator(args.num_classes, ndf=64) 275 | 276 | classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim) 277 | head, classifier = classifier.head, classifier.classifier 278 | if args.resume: 279 | head.load_state_dict(head_weights) 280 | classifier.load_state_dict(classifier_weights) 281 | 282 | aux_classifier = ASPP_Classifier_Gen(2048, [6, 12, 18, 24], [6, 12, 18, 24], args.num_classes, hidden_dim=args.hidden_dim) 283 | _, aux_classifier = aux_classifier.head, aux_classifier.classifier 284 | if args.resume: 285 | aux_classifier.load_state_dict(classifier_weights) 286 | 287 | model_B1.train() 288 | model_B2.train() 289 | model_B.train() 290 | model_D1.train() 291 | model_D2.train() 292 | head.train() 293 | classifier.train() 294 | aux_classifier.train() 295 | 296 | # cudnn.benchmark = True 297 | if gpu == 0: 298 | logger.info(model_B1) 299 | logger.info(model_B2) 300 | logger.info(model_B) 301 | logger.info(model_D1) 302 | logger.info(model_D2) 303 | logger.info(head) 304 | logger.info(classifier) 305 | logger.info(aux_classifier) 306 | else: 307 | logger = None 308 | 309 | if gpu == 0: 310 | logger.info("args.noaug: {}, args.resize: {}, args.rcrop: {}, args.hflip: {}, args.noshuffle: {}, args.no_droplast: {}".format(args.noaug, args.resize, args.rcrop, args.hflip, args.noshuffle, args.no_droplast)) 311 | args.rcrop = [int(x.strip()) for x in args.rcrop.split(",")] 312 | args.clrjit_params = [float(x) for x in args.clrjit_params.split(',')] 313 | 314 | datasets = create_dataset(args, logger) 315 | 316 | # define optimizer 317 | model_params = [{'params': list(model_B1.parameters()) + list(model_B2.parameters()) + list(model_B.parameters())}, 318 | {'params': list(head.parameters()) + list(classifier.parameters()) + \ 319 | list(aux_classifier.parameters()), 'lr': args.learning_rate * 10}] 320 | optimizer = optim.SGD(model_params, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay) 321 | assert len(optimizer.param_groups) == 2 322 | 
optimizer.zero_grad() 323 | 324 | optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 325 | optimizer_D1.zero_grad() 326 | 327 | optimizer_D2 = optim.Adam(model_D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 328 | optimizer_D2.zero_grad() 329 | 330 | # define model 331 | model_B1 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B1) 332 | model_B1 = torch.nn.parallel.DistributedDataParallel(model_B1.cuda(), device_ids=[gpu], find_unused_parameters=True) 333 | 334 | model_B2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B2) 335 | model_B2 = torch.nn.parallel.DistributedDataParallel(model_B2.cuda(), device_ids=[gpu], find_unused_parameters=True) 336 | 337 | model_B = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_B) 338 | model_B = torch.nn.parallel.DistributedDataParallel(model_B.cuda(), device_ids=[gpu], find_unused_parameters=True) 339 | 340 | head = torch.nn.SyncBatchNorm.convert_sync_batchnorm(head) 341 | head = torch.nn.parallel.DistributedDataParallel(head.cuda(), device_ids=[gpu], find_unused_parameters=True) 342 | 343 | classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(classifier) 344 | classifier = torch.nn.parallel.DistributedDataParallel(classifier.cuda(), device_ids=[gpu], find_unused_parameters=True) 345 | 346 | model_D1 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_D1) 347 | model_D1 = torch.nn.parallel.DistributedDataParallel(model_D1.cuda(), device_ids=[gpu], find_unused_parameters=True) 348 | 349 | model_D2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_D2) 350 | model_D2 = torch.nn.parallel.DistributedDataParallel(model_D2.cuda(), device_ids=[gpu], find_unused_parameters=True) 351 | 352 | aux_classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(aux_classifier) 353 | aux_classifier = torch.nn.parallel.DistributedDataParallel(aux_classifier.cuda(), device_ids=[gpu], find_unused_parameters=True) 354 | 355 | if args.gan == 'Vanilla': 356 | bce_loss = torch.nn.BCEWithLogitsLoss() 357 | elif args.gan == 'LS': 358 | bce_loss = torch.nn.MSELoss() 359 | if gpu == 0: 360 | logger.info("use LS-GAN") 361 | seg_loss = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label) 362 | 363 | interp = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 364 | interp_target = nn.Upsample(size=(args.rcrop[1], args.rcrop[0]), mode='bilinear', align_corners=True) 365 | 366 | # labels for adversarial training 367 | source_label = 0 368 | target_label = 1 369 | 370 | # set up tensor board 371 | if args.tensorboard and gpu == 0: 372 | writer = SummaryWriter(args.snapshot_dir) 373 | 374 | if gpu == 0: 375 | logger.info("args.lambda_adv_src: {}, args.lambda_adv_tgt: {}".format(args.lambda_adv_src, args.lambda_adv_tgt)) 376 | 377 | # validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader) 378 | # exit() 379 | 380 | trainloader_iter = enumerate(datasets.source_train_loader) 381 | targetloader_iter = enumerate(datasets.target_train_loader) 382 | 383 | conf_bank = {i: [] for i in range(args.num_classes)} 384 | thresholds = torch.zeros(args.num_classes).float().cuda() 385 | class_list = ["road","sidewalk","building","wall", 386 | "fence","pole","traffic_light","traffic_sign","vegetation", 387 | "terrain","sky","person","rider","car", 388 | "truck","bus","train","motorcycle","bicycle"] 389 | 390 | scaler = torch.cuda.amp.GradScaler() 391 | best_miou = 0.0 392 | filename = None 393 | epoch_s, epoch_t = 0, 0 394 
| for i_iter in range(args.num_steps): 395 | 396 | # model.train() 397 | model_B1.train() 398 | model_B2.train() 399 | model_B.train() 400 | model_D1.train() 401 | model_D2.train() 402 | head.train() 403 | classifier.train() 404 | aux_classifier.train() 405 | 406 | loss_seg_value = 0 407 | loss_adv_src_value = 0 408 | loss_adv_tgt_value = 0 409 | loss_D1_value = 0 410 | loss_D2_value = 0 411 | loss_st_value = 0 412 | 413 | optimizer.zero_grad() 414 | adjust_learning_rate(optimizer, i_iter) 415 | optimizer_D1.zero_grad() 416 | adjust_learning_rate_D(optimizer_D1, i_iter) 417 | optimizer_D2.zero_grad() 418 | adjust_learning_rate_D(optimizer_D2, i_iter) 419 | 420 | for sub_i in range(args.iter_size): 421 | 422 | # train G 423 | for param in model_D1.parameters(): 424 | param.requires_grad = False 425 | for param in model_D2.parameters(): 426 | param.requires_grad = False 427 | 428 | # train with source 429 | try: 430 | _, batch = trainloader_iter.__next__() 431 | except StopIteration: 432 | epoch_s += 1 433 | datasets.source_train_sampler.set_epoch(epoch_s) 434 | trainloader_iter = enumerate(datasets.source_train_loader) 435 | _, batch = trainloader_iter.__next__() 436 | 437 | images = batch['img'].cuda() 438 | labels = batch['label'].cuda() 439 | 440 | src_size = images.shape[-2:] 441 | with torch.cuda.amp.autocast(): 442 | feat_src = model_B1(images) 443 | 444 | feat_B_src = model_B(feat_src) 445 | pred = classifier(head(feat_B_src)) 446 | pred = interp(pred) #[b, num_classes, h, w] 447 | 448 | temperature = 1.8 449 | pred = pred.div(temperature) 450 | loss_seg = seg_loss(pred, labels) 451 | 452 | D_out = model_D1(F.interpolate(feat_src, size=src_size, mode='bilinear', align_corners=True)) 453 | 454 | loss_adv_src = args.lambda_adv_src * bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(target_label).cuda()) 455 | loss = loss_seg + loss_adv_src 456 | 457 | # proper normalization 458 | loss = loss / args.iter_size 459 | loss_seg_value += loss_seg / args.iter_size 460 | loss_adv_src_value += loss_adv_src / args.iter_size 461 | 462 | scaler.scale(loss).backward() 463 | 464 | # train with target 465 | try: 466 | _, batch = targetloader_iter.__next__() 467 | except StopIteration: 468 | epoch_t += 1 469 | datasets.target_train_sampler.set_epoch(epoch_t) 470 | targetloader_iter = enumerate(datasets.target_train_loader) 471 | _, batch = targetloader_iter.__next__() 472 | 473 | images = batch['img'].cuda() 474 | 475 | tgt_size = images.shape[-2:] 476 | with torch.cuda.amp.autocast(): 477 | feat_tgt = model_B2(images) 478 | feat_B_tgt = model_B(feat_tgt) 479 | 480 | feat_B_tgt_head = head(feat_B_tgt) 481 | pred_tgt = classifier(feat_B_tgt_head) 482 | 483 | with torch.no_grad(): 484 | pred_logits, pred_idx = F.softmax(pred_tgt.detach(), 1).max(1) #[b, h, w] 485 | assert pred_logits.shape[-2:] == pred_tgt.shape[-2:] 486 | 487 | # update_thresholds 488 | for c in range(args.num_classes): 489 | prob_c = pred_logits[pred_idx == c].cpu().numpy().tolist() 490 | if len(prob_c) == 0: 491 | continue 492 | conf_bank[c].extend(prob_c) 493 | rank = int(len(conf_bank[c]) * args.conf_p) 494 | thresholds[c] = sorted(conf_bank[c], reverse=True)[rank] 495 | if len(conf_bank[c]) > args.conf_bank_length: 496 | conf_bank[c] = conf_bank[c][-args.conf_bank_length:] 497 | 498 | n = torch.tensor(1.0).cuda() 499 | dist.all_reduce(thresholds) 500 | dist.all_reduce(n) 501 | thresholds = thresholds / n 502 | 503 | if i_iter % 500 == 0 and gpu == 0: 504 | for c in range(args.num_classes): 505 | print("c: {}, class_i: {} 
threshold: {}, len(conf_bank[c]): {}".format(c, class_list[c], thresholds[c], len(conf_bank[c]))) 506 | 507 | # if i_iter % 100 == 0 and gpu == 0: 508 | # num_pos = (pred_logits > thresholds[pred_idx]).float().sum() 509 | # num_all = np.prod(pred_logits.shape) 510 | # ratio = num_pos / (num_all+1e-8) 511 | # logger.info("num_pos: {}, num_all: {}, ratio: {}".format(num_pos, num_all, ratio)) 512 | 513 | pred_idx[pred_logits < thresholds[pred_idx]] = args.ignore_label 514 | 515 | pred_tgt = interp_target(pred_tgt) 516 | pred_tgt = pred_tgt.div(temperature) 517 | 518 | pred_tgt_aux = aux_classifier(feat_B_tgt_head) 519 | loss_st = args.lambda_st * seg_loss(pred_tgt_aux, pred_idx) 520 | 521 | D_out = model_D2(F.softmax(pred_tgt, 1)) 522 | 523 | loss_adv_tgt = args.lambda_adv_tgt * bce_loss(D_out, torch.FloatTensor(D_out.data.size()).fill_(source_label).cuda()) 524 | loss = loss_adv_tgt + loss_st 525 | 526 | loss = loss / args.iter_size 527 | loss_adv_tgt_value += loss_adv_tgt / args.iter_size 528 | loss_st_value += loss_st / args.iter_size 529 | 530 | scaler.scale(loss).backward() 531 | 532 | # train D 533 | # bring back requires_grad 534 | for param in model_D1.parameters(): 535 | param.requires_grad = True 536 | 537 | optimizer_D1.zero_grad() 538 | with torch.cuda.amp.autocast(): 539 | src_D1_pred = model_D1(F.interpolate(feat_src.detach(), size=src_size, mode='bilinear', align_corners=True)) 540 | loss_D1_src = 0.5 * bce_loss(src_D1_pred, torch.FloatTensor(src_D1_pred.data.size()).fill_(source_label).cuda()) / args.iter_size 541 | 542 | scaler.scale(loss_D1_src).backward() 543 | 544 | with torch.cuda.amp.autocast(): 545 | 546 | tgt_D1_pred = model_D1(F.interpolate(feat_tgt.detach(), size=tgt_size, mode='bilinear', align_corners=True)) 547 | loss_D1_tgt = 0.5 * bce_loss(tgt_D1_pred, torch.FloatTensor(tgt_D1_pred.data.size()).fill_(target_label).cuda()) / args.iter_size 548 | 549 | loss_D1_value += loss_D1_src + loss_D1_tgt 550 | 551 | scaler.scale(loss_D1_tgt).backward() 552 | 553 | for param in model_D2.parameters(): 554 | param.requires_grad = True 555 | optimizer_D2.zero_grad() 556 | 557 | with torch.cuda.amp.autocast(): 558 | src_D2_pred = model_D2(F.softmax(pred.detach(), 1)) 559 | loss_D2_src = 0.5 * bce_loss(src_D2_pred, torch.FloatTensor(src_D2_pred.data.size()).fill_(source_label).cuda()) / args.iter_size 560 | 561 | scaler.scale(loss_D2_src).backward() 562 | 563 | with torch.cuda.amp.autocast(): 564 | 565 | tgt_D2_pred = model_D2(F.softmax(pred_tgt.detach(), 1)) 566 | loss_D2_tgt = 0.5 * bce_loss(tgt_D2_pred, torch.FloatTensor(tgt_D2_pred.data.size()).fill_(target_label).cuda()) / args.iter_size 567 | 568 | loss_D2_value += loss_D2_src + loss_D2_tgt 569 | 570 | scaler.scale(loss_D2_tgt).backward() 571 | 572 | n = torch.tensor(1.0).cuda() 573 | 574 | dist.all_reduce(n), dist.all_reduce(loss_seg_value), dist.all_reduce(loss_adv_src_value), dist.all_reduce(loss_adv_tgt_value) 575 | dist.all_reduce(loss_D1_value), dist.all_reduce(loss_D2_value), dist.all_reduce(loss_st_value) 576 | 577 | loss_seg_value = loss_seg_value.item() / n.item() 578 | loss_adv_src_value = loss_adv_src_value.item() / n.item() 579 | loss_adv_tgt_value = loss_adv_tgt_value.item() / n.item() 580 | loss_D1_value = loss_D1_value.item() / n.item() 581 | loss_D2_value = loss_D2_value.item() / n.item() 582 | loss_st_value = loss_st_value.item() / n.item() 583 | 584 | scaler.step(optimizer) 585 | scaler.step(optimizer_D1) 586 | scaler.step(optimizer_D2) 587 | scaler.update() 588 | 589 | if args.tensorboard and gpu == 0: 
590 | scalar_info = { 591 | 'loss_seg': loss_seg_value, 592 | 'loss_adv_src': loss_adv_src_value, 593 | 'loss_adv_tgt': loss_adv_tgt_value, 594 | 'loss_D1': loss_D1_value, 595 | 'loss_D2': loss_D2_value, 596 | "loss_st": loss_st_value, 597 | } 598 | 599 | if i_iter % 10 == 0: 600 | for key, val in scalar_info.items(): 601 | writer.add_scalar(key, val, i_iter) 602 | 603 | if gpu == 0 and i_iter % args.print_every == 0: 604 | logger.info('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_src = {3:.5f}, loss_adv_tgt = {4:.5f}, loss_D1 = {5:.3f}, ' 605 | 'loss_D2 = {6:.3f}, loss_st = {7:.5f}, epoch_s = {8:3d}, epoch_t = {9:3d}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_src_value, \ 606 | loss_adv_tgt_value, loss_D1_value, loss_D2_value, loss_st_value, epoch_s, epoch_t)) 607 | 608 | if gpu == 0 and i_iter >= args.num_steps_stop - 1: 609 | logger.info('save model ...') 610 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(args.num_steps_stop) + '.pth') 611 | save_file = {'model_B1_state_dict': model_B1.state_dict(), 'model_B2_state_dict': model_B2.state_dict(), \ 612 | 'model_B_state_dict': model_B.state_dict(), 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()} 613 | torch.save(save_file, filename) 614 | logger.info("saving checkpoint model to {}".format(filename)) 615 | break 616 | 617 | if i_iter % args.save_pred_every == 0 and i_iter != 0: 618 | miou, loss_val = validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger if gpu == 0 else None, datasets.target_valid_loader) 619 | if args.tensorboard and gpu == 0: 620 | scalar_info = { 621 | 'miou_val': miou, 622 | 'loss_val': loss_val 623 | } 624 | for k, v in scalar_info.items(): 625 | writer.add_scalar(k, v, i_iter) 626 | 627 | if gpu == 0 and miou > best_miou: 628 | best_miou = miou 629 | logger.info('taking snapshot ...') 630 | if filename is not None and os.path.exists(filename): 631 | os.remove(filename) 632 | filename = osp.join(args.snapshot_dir, 'GTA5_' + str(i_iter) + "_{}".format(miou) + '.pth') 633 | save_file = {'model_B1_state_dict': model_B1.state_dict(), 'model_B2_state_dict': model_B2.state_dict(), \ 634 | 'model_B_state_dict': model_B.state_dict(), 'head_state_dict': head.state_dict(), 'classifier_state_dict': classifier.state_dict()} 635 | torch.save(save_file, filename) 636 | logger.info("saving checkpoint model to {}".format(filename)) 637 | 638 | if args.tensorboard and gpu == 0: 639 | writer.close() 640 | 641 | def validate(model_B2, model_B, head, classifier, seg_loss, gpu, logger, testloader): 642 | if gpu == 0: 643 | logger.info("Start Evaluation") 644 | # evaluate 645 | loss_meter = AverageMeter() 646 | intersection_meter = AverageMeter() 647 | union_meter = AverageMeter() 648 | 649 | model_B2.eval() 650 | model_B.eval() 651 | head.eval() 652 | classifier.eval() 653 | 654 | with torch.no_grad(): 655 | for i, batch in enumerate(testloader): 656 | images = batch['img'].cuda() 657 | labels = batch['label'].cuda() 658 | 659 | pred = model_B(model_B2(images)) 660 | pred = classifier(head(pred)) 661 | output = F.interpolate(pred, size=labels.size()[-2:], mode='bilinear', align_corners=True) 662 | loss = seg_loss(output, labels) 663 | 664 | output = output.max(1)[1] 665 | intersection, union, _ = intersectionAndUnionGPU(output, labels, args.num_classes, args.ignore_label) 666 | dist.all_reduce(intersection), dist.all_reduce(union) 667 | intersection, union = intersection.cpu().numpy(), union.cpu().numpy() 668 | intersection_meter.update(intersection), 
union_meter.update(union) 669 | loss_meter.update(loss.item(), images.size(0)) 670 | if gpu == 0 and i % 50 == 0 and i != 0: 671 | logger.info("Evaluation iter = {0:5d}/{1:5d}, loss_eval = {2:.3f}".format( 672 | i, len(testloader), loss_meter.val 673 | )) 674 | 675 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 676 | miou = np.mean(iou_class) 677 | if gpu == 0: 678 | logger.info("Val result: mIoU = {:.3f}".format(miou)) 679 | for i in range(args.num_classes): 680 | logger.info("Class_{} Result: iou = {:.3f}".format(i, iou_class[i])) 681 | logger.info("End Evaluation") 682 | 683 | torch.cuda.empty_cache() 684 | 685 | return miou, loss_meter.avg 686 | 687 | def find_free_port(): 688 | import socket 689 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 690 | # Binding to port 0 will cause the OS to find an available port for us 691 | sock.bind(("", 0)) 692 | port = sock.getsockname()[1] 693 | sock.close() 694 | # NOTE: there is still a chance the port could be taken by other processes. 695 | return port 696 | 697 | if __name__ == '__main__': 698 | args.gpus = [int(x) for x in args.gpus.split(",")] 699 | args.world_size = len(args.gpus) 700 | if args.dist: 701 | port = find_free_port() 702 | args.dist_url = f"tcp://127.0.0.1:{port}" 703 | mp.spawn(main_worker, nprocs=args.world_size, args=(args.world_size, args.dist_url)) 704 | else: 705 | main_worker(args.train_gpu, args.world_size, args) 706 | 707 | --------------------------------------------------------------------------------
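
Note on the target-branch pseudo-labelling in the training loop above: the loop maintains a per-class confidence bank and derives a dynamic, percentile-based threshold for each class before using the target predictions as pseudo labels for self-training. The standalone helper below is a minimal sketch of that mechanism, assuming hyper-parameters analogous to args.conf_p, args.conf_bank_length and args.ignore_label; here conf_bank is a list of per-class confidence lists and thresholds a float tensor of shape [num_classes] on the same device as pred_tgt. It is an illustrative re-implementation for clarity, not code from this repository.

import torch
import torch.nn.functional as F

def filter_pseudo_labels(pred_tgt, conf_bank, thresholds, num_classes=19,
                         conf_p=0.5, conf_bank_length=100000, ignore_label=255):
    """Mask out low-confidence target predictions using per-class,
    percentile-based thresholds (illustrative sketch)."""
    with torch.no_grad():
        # Per-pixel max softmax score and the predicted class index.
        pred_logits, pred_idx = F.softmax(pred_tgt.detach(), dim=1).max(dim=1)
        for c in range(num_classes):
            # Confidences of all pixels currently predicted as class c.
            prob_c = pred_logits[pred_idx == c].cpu().numpy().tolist()
            if len(prob_c) == 0:
                continue
            conf_bank[c].extend(prob_c)
            # Threshold = confidence at the conf_p quantile, counted from the top.
            rank = min(int(len(conf_bank[c]) * conf_p), len(conf_bank[c]) - 1)
            thresholds[c] = sorted(conf_bank[c], reverse=True)[rank]
            # Keep only the most recent confidences so the bank stays bounded.
            if len(conf_bank[c]) > conf_bank_length:
                conf_bank[c] = conf_bank[c][-conf_bank_length:]
        # Pixels below their class threshold are ignored by the self-training loss.
        pred_idx[pred_logits < thresholds[pred_idx]] = ignore_label
    return pred_idx, thresholds

In the loop above, the resulting pred_idx serves as the pseudo label for the auxiliary classifier's self-training loss (weighted by args.lambda_st), and the per-class thresholds are additionally summed across processes with dist.all_reduce and divided by the all-reduced process count, so every GPU applies the same averaged thresholds.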
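
The validation routine relies on a histogram-based intersection/union helper and reduces the per-class counts across processes before computing mIoU. A common implementation of such a helper is sketched below; it is an assumption about what the repository's intersectionAndUnionGPU (in utils/util.py) provides and may differ from it in detail, and is included only to make the distributed evaluation self-explanatory.

import torch

def intersection_and_union_gpu(output, target, num_classes, ignore_index=255):
    """Per-class intersection/union pixel counts for one batch (illustrative sketch)."""
    output = output.view(-1).float()
    target = target.view(-1).float()
    # Treat pixels ignored in the label as ignored in the prediction as well.
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    # Values outside [0, num_classes-1] (e.g. ignore_index) are dropped by torch.histc.
    area_intersection = torch.histc(intersection, bins=num_classes, min=0, max=num_classes - 1)
    area_output = torch.histc(output, bins=num_classes, min=0, max=num_classes - 1)
    area_target = torch.histc(target, bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target

In validate(), these per-class counts are all-reduced across GPUs and accumulated in AverageMeters, so intersection_meter.sum / (union_meter.sum + 1e-10) yields the per-class IoU aggregated over all processes, and the mean of that vector is the reported mIoU.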