├── script ├── __init__.py ├── batch_train.sh └── command.sh ├── generalframeworks ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── loss.cpython-37.pyc │ │ └── __init__.cpython-37.pyc │ └── loss.py ├── meter │ ├── __init__.py │ └── meter.py ├── networks │ ├── __init__.py │ ├── deeplabv3 │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── aspp.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── deeplabv3.cpython-37.pyc │ │ ├── aspp.py │ │ └── deeplabv3.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── ddp_model.cpython-37.pyc │ │ └── uncer_head.cpython-37.pyc │ ├── uncer_head.py │ ├── ddp_model.py │ └── resnet.py ├── util │ ├── __init__.py │ ├── __pycache__ │ │ ├── miou.cpython-37.pyc │ │ ├── meter.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── torch_dist_sum.cpython-37.pyc │ ├── miou.py │ ├── torch_dist_sum.py │ ├── dist_init.py │ └── meter.py ├── augmentation │ ├── __init__.py │ └── transform.py ├── scheduler │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── rampscheduler.cpython-37.pyc │ │ └── my_lr_scheduler.cpython-37.pyc │ ├── my_lr_scheduler.py │ └── rampscheduler.py ├── dataset_helpers │ ├── __init__.py │ ├── __pycache__ │ │ ├── VOC.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── Cityscapes.cpython-37.pyc │ ├── Cityscapes.py │ └── VOC.py ├── __pycache__ │ └── utils.cpython-37.pyc └── utils.py ├── PRCL.gif ├── pretrained └── Download.txt ├── config ├── CityScapes_prcl_config_150.yaml └── VOC_prcl_config_662.yaml ├── requirements.txt ├── visual.py ├── CityScapes_split └── 150 │ └── 3407 │ └── labeled_filename.txt ├── README.md ├── VOC_split └── 662 │ └── 3407 │ ├── labeled_filename.txt │ └── valid_filename.txt ├── prcl_sig.py └── prcl.py /script/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/meter/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /PRCL.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/PRCL.gif -------------------------------------------------------------------------------- /pretrained/Download.txt: -------------------------------------------------------------------------------- 1 | Please download the model pretrained on Imagenet following README.md 2 | -------------------------------------------------------------------------------- /generalframeworks/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/loss/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/miou.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/util/__pycache__/miou.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/util/__pycache__/meter.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/VOC.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/dataset_helpers/__pycache__/VOC.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/ddp_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/__pycache__/ddp_model.cpython-37.pyc 
-------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/uncer_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/__pycache__/uncer_head.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/scheduler/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/torch_dist_sum.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/util/__pycache__/torch_dist_sum.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/dataset_helpers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/rampscheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/scheduler/__pycache__/rampscheduler.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Haoyu-Xie/PRCL/HEAD/generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-37.pyc -------------------------------------------------------------------------------- /script/batch_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ./script/command.sh prcl_662 8 1 "python -u prcl.py --config ./config/VOC_prcl_config_662.yaml" 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /generalframeworks/util/miou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mean_intersection_over_union(mat: torch.Tensor): 4 | ''' Compute miou via Confmatrix''' 5 | h = mat.float() 6 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 7 | miou = torch.mean(iu).item() 8 | 9 | return miou -------------------------------------------------------------------------------- /generalframeworks/scheduler/my_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | class PolyLR(_LRScheduler): 5 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 6 | self.power = power 7 | self.max_iters = max_iters 8 | self.min_lr = min_lr 9 | super(PolyLR, self).__init__(optimizer, last_epoch) 10 | 11 | def get_lr(self): 12 | return [max(base_lr * (1 - self.last_epoch / self.max_iters) ** self.power, self.min_lr) 13 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /script/command.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | job_name=$1 5 | train_gpu=$2 6 | num_node=$3 7 | command=$4 8 | total_process=$((train_gpu*num_node)) 9 | 10 | mkdir -p log 11 | 12 | 13 | port=$(( $RANDOM % 300 + 23450 )) 14 | 15 | 16 | # nohup 17 | GLOG_vmodule=MemcachedClient=-1 \ 18 | srun --partition=VA \ 19 | --mpi=pmi2 -n$total_process \ 20 | --gres=gpu:$train_gpu \ 21 | --ntasks-per-node=$train_gpu \ 22 | --job-name=$job_name \ 23 | --kill-on-bad-exit=1 \ 24 | --cpus-per-task=6 \ 25 | -x "BJ-IDC1-10-10-16-[53,98,115,116,117,119,120,121,122,123,124]" \ 26 | $command --port $port --job_name $job_name 2>&1|tee -a log/$job_name.log & 27 | -------------------------------------------------------------------------------- /generalframeworks/util/torch_dist_sum.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | __all__ = ['torch_dist_sum'] 5 | 6 | def torch_dist_sum(gpu, *args): 7 | process_group = torch.distributed.group.WORLD 8 | tensor_args = [] 9 | pending_res = [] 10 | for arg in args: 11 | # if isinstance(arg, torch.Tensor): 12 | # tensor_arg = arg.clone().reshape(-1).detach().cuda(gpu) 13 | # else: 14 | # tensor_arg = torch.tensor(arg).reshape(-1).cuda(gpu) 15 | tensor_arg = arg.clone().detach().cuda(gpu) 16 | tensor_args.append(tensor_arg) 17 | pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True)) 18 | for res in pending_res: 19 | res.wait() 20 | return tensor_args 21 | -------------------------------------------------------------------------------- /config/CityScapes_prcl_config_150.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | name: DeepLabv3Plus 3 | num_class: 19 4 | 5 | EMA: 6 | alpha: 
0.99 7 | 8 | Optim: 9 | lr: 3.2e-3 10 | uncer_lr: 5e-5 11 | weight_decay: 5e-4 12 | 13 | Lr_Scheduler: 14 | name: PolyLR 15 | step_size: 90 16 | gamma: 0.1 17 | 18 | Dataset: 19 | name: CityScapes 20 | data_dir: ./cityscapes 21 | txt_dir: ./CityScapes_split 22 | num_labels: 150 23 | batch_size: 6 24 | mix_mode: classmix 25 | crop_size: !!python/tuple [512,512] 26 | scale_size: !!python/tuple [0.5,2.0] 27 | 28 | Training_Setting: 29 | epoch: 200 30 | save_dir: ./checkpoints 31 | 32 | Seed: 3407 33 | 34 | Ramp_Scheduler: 35 | begin_epoch: 0 36 | max_epoch: 200 37 | max_value: 1.0 38 | min_value: 0 39 | ramp_mult: -5.0 40 | 41 | Prcl_Loss: 42 | is_available: True 43 | warm_up: 0 44 | un_threshold: 0.97 45 | strong_threshold: 0.8 46 | weak_threshold: 0.7 47 | temp: 100 48 | num_queries: 256 49 | num_negatives: 512 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | Bottleneck==1.3.4 3 | cachetools==5.0.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.12 6 | colorama==0.4.4 7 | google-auth==2.6.6 8 | google-auth-oauthlib==0.4.6 9 | grpcio==1.44.0 10 | idna==3.3 11 | importlib-metadata==4.11.3 12 | Markdown==3.3.6 13 | numexpr==2.8.0 14 | numpy==1.21.6 15 | oauthlib==3.2.0 16 | opencv-python==4.5.5.64 17 | pandas==1.3.4 18 | Pillow==9.1.0 19 | pip==22.0.4 20 | protobuf==3.20.1 21 | pyasn1==0.4.8 22 | pyasn1-modules==0.2.8 23 | python-dateutil==2.8.2 24 | pytz==2022.1 25 | PyYAML==6.0 26 | requests==2.27.1 27 | requests-oauthlib==1.3.1 28 | rsa==4.8 29 | setuptools==62.1.0 30 | shutup==0.2.0 31 | six==1.16.0 32 | tensorboard==2.8.0 33 | tensorboard-data-server==0.6.1 34 | tensorboard-plugin-wit==1.8.1 35 | tensorboardX==2.5 36 | torch==1.7.1+cu110 37 | torchaudio==0.7.2 38 | torchvision==0.8.2+cu110 39 | tqdm==4.64.0 40 | typing_extensions==4.2.0 41 | urllib3==1.26.9 42 | Werkzeug==2.1.2 43 | wheel==0.37.1 44 | zipp==3.8.0 45 | -------------------------------------------------------------------------------- /config/VOC_prcl_config_662.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | name: DeepLabv3Plus 3 | num_class: 21 4 | 5 | EMA: 6 | alpha: 0.99 7 | 8 | Optim: 9 | lr: 3.2e-3 10 | uncer_lr: 5e-5 11 | weight_decay: 5e-4 12 | 13 | Lr_Scheduler: 14 | name: PolyLR 15 | step_size: 90 16 | gamma: 0.1 17 | 18 | Dataset: 19 | name: VOC 20 | data_dir: ./VOC 21 | txt_dir: ./VOC_split 22 | num_labels: 662 23 | batch_size: 2 24 | crop_size: !!python/tuple [321,321] 25 | scale_size: !!python/tuple [0.5,1.5] 26 | mix_mode: classmix 27 | 28 | Training_Setting: 29 | epoch: 200 30 | save_dir: ./checkpoints/ 31 | 32 | Seed: 3407 33 | 34 | Prcl_Loss: 35 | is_available: True 36 | warm_up: 0 37 | un_threshold: 0.97 38 | strong_threshold: 0.8 39 | weak_threshold: 0.7 40 | temp: 100 41 | num_queries: 256 42 | num_negatives: 512 43 | 44 | Ramp_Scheduler: 45 | begin_epoch: 0 46 | max_epoch: 200 47 | max_value: 1.0 48 | min_value: 0 49 | ramp_mult: -5.0 50 | 51 | Distributed: 52 | world_size: 8 53 | gpu_id: 0,1,2,3,4,5,6,7 54 | -------------------------------------------------------------------------------- /generalframeworks/networks/uncer_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | 6 | class Uncertainty_head(nn.Module): # feature -> log(sigma^2) 7 | 
def __init__(self, in_feat=304, out_feat=256): 8 | super(Uncertainty_head, self).__init__() 9 | self.fc1 = Parameter(torch.Tensor(out_feat, in_feat)) 10 | self.bn1 = nn.BatchNorm2d(out_feat, affine=True) 11 | self.relu = nn.ReLU() 12 | self.fc2 = Parameter(torch.Tensor(out_feat, out_feat)) 13 | self.bn2 = nn.BatchNorm2d(out_feat, affine=False) 14 | self.gamma = Parameter(torch.Tensor([1.0])) 15 | self.beta = Parameter(torch.Tensor([0.0])) 16 | 17 | nn.init.kaiming_normal_(self.fc1) 18 | nn.init.kaiming_normal_(self.fc2) 19 | 20 | def forward(self, x: torch.Tensor): 21 | x = x.permute(0, 2, 3, 1) 22 | x = F.linear(x, F.normalize(self.fc1, dim=-1)) # [B, W, H, D] 23 | x = x.permute(0, 3, 1, 2) # [B, W, H, D] -> [B, D, W, H] 24 | x = self.bn1(x) 25 | x = self.relu(x) 26 | x = x.permute(0, 2, 3, 1) 27 | x = F.linear(x, F.normalize(self.fc2, dim=-1)) 28 | x = x.permute(0, 3, 1, 2) 29 | x = self.bn2(x) 30 | x = self.gamma * x + self.beta 31 | x = torch.log(torch.exp(x) + 1e-6) 32 | x = torch.sigmoid(x) 33 | 34 | return x -------------------------------------------------------------------------------- /generalframeworks/util/dist_init.py: -------------------------------------------------------------------------------- 1 | def dist_init(port): 2 | import torch 3 | import os 4 | 5 | def init(host_addr, rank, local_rank, world_size, port): 6 | host_addr_full = 'tcp://' + host_addr + ':' + str(port) 7 | torch.distributed.init_process_group("nccl", init_method=host_addr_full, 8 | rank=rank, world_size=world_size) 9 | torch.cuda.set_device(local_rank) 10 | assert torch.distributed.is_initialized() 11 | 12 | def parse_host_addr(s): 13 | if '[' in s: 14 | left_bracket = s.index('[') 15 | right_bracket = s.index(']') 16 | prefix = s[:left_bracket] 17 | first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0] 18 | return prefix + first_number 19 | else: 20 | return s 21 | 22 | rank = int(os.environ['SLURM_PROCID']) 23 | local_rank = int(os.environ['SLURM_LOCALID']) 24 | world_size = int(os.environ['SLURM_NTASKS']) 25 | 26 | ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST']) 27 | 28 | init(ip, rank, local_rank, world_size, port) 29 | 30 | return rank, local_rank, world_size 31 | 32 | def local_dist_init(config, port): 33 | import torch 34 | import os 35 | 36 | os.environ['MASTER_ADDR'] = 'localhost' 37 | os.environ['MASTER_PORT'] = str(port) 38 | os.environ['WORLD_SIZE'] = str(config['Distributed']['world_size']) 39 | os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL' 40 | 41 | -------------------------------------------------------------------------------- /generalframeworks/scheduler/rampscheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class RampScheduler(object): 3 | 4 | def __init__(self, begin_epoch, max_epoch, max_value, ramp_mult): 5 | super().__init__() 6 | self.begin_epoch = int(begin_epoch) 7 | self.max_epoch = int(max_epoch) 8 | self.max_value = float(max_value) 9 | self.mult = float(ramp_mult) 10 | self.epoch = 0 11 | 12 | def step(self): 13 | self.epoch += 1 14 | 15 | @property 16 | def value(self): 17 | return self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.mult) 18 | 19 | @staticmethod 20 | def get_lr(epoch, begin_epoch, max_epochs, max_val, mult): 21 | if epoch < begin_epoch: 22 | return 0. 23 | elif epoch >= max_epochs: 24 | return max_val 25 | return max_val * np.exp(mult * (1. 
- float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2) 26 | 27 | class RampdownScheduler(object): 28 | 29 | def __init__(self, begin_epoch, max_epoch, current_epoch, max_value, min_value, ramp_mult): 30 | super().__init__() 31 | self.begin_epoch = int(begin_epoch) 32 | self.max_epoch = int(max_epoch) 33 | self.max_value = float(max_value) 34 | self.mult = float(ramp_mult) 35 | self.epoch = current_epoch 36 | self.min_value = min_value 37 | 38 | def step(self): 39 | self.epoch += 1 40 | 41 | @property 42 | def value(self): 43 | current_value = self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.min_value, self.mult) 44 | if current_value < self.min_value: 45 | current_value = self.min_value 46 | return current_value 47 | 48 | @staticmethod 49 | def get_lr(epoch, begin_epoch, max_epochs, max_val, min_value, mult): 50 | if epoch < begin_epoch: 51 | return 0. 52 | elif epoch >= max_epochs: 53 | return min_value 54 | return max_val * np.exp(mult * (float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2) -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DeepLabHead(nn.Sequential): 7 | def __init__(self, in_channels, num_classes): 8 | super(DeepLabHead, self).__init__( 9 | ASPP(in_channels, [12, 24, 36]), 10 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 11 | nn.BatchNorm2d(256), 12 | nn.ReLU(), 13 | nn.Conv2d(256, num_classes, 1) 14 | ) 15 | 16 | 17 | class ASPPConv(nn.Sequential): 18 | def __init__(self, in_channels, out_channels, dilation): 19 | modules = [ 20 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU() 23 | ] 24 | super(ASPPConv, self).__init__(*modules) 25 | 26 | 27 | class ASPPPooling(nn.Sequential): 28 | def __init__(self, in_channels, out_channels): 29 | super(ASPPPooling, self).__init__( 30 | nn.AdaptiveAvgPool2d(1), 31 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 32 | nn.BatchNorm2d(out_channels), 33 | nn.ReLU()) 34 | 35 | def forward(self, x): 36 | size = x.shape[-2:] 37 | x = super(ASPPPooling, self).forward(x) 38 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 39 | 40 | 41 | class ASPP(nn.Module): 42 | def __init__(self, in_channels, atrous_rates): 43 | super(ASPP, self).__init__() 44 | out_channels = 256 45 | #modules = [] 46 | modules = torch.nn.ModuleList() 47 | modules.append(nn.Sequential( 48 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU())) 51 | 52 | rate1, rate2, rate3 = tuple(atrous_rates) 53 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 54 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 55 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 56 | modules.append(ASPPPooling(in_channels, out_channels)) 57 | # self.convs = nn.ModuleList(modules) 58 | self.convs = modules 59 | 60 | self.project = nn.Sequential( 61 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 62 | nn.BatchNorm2d(out_channels), 63 | nn.ReLU(), 64 | # nn.Dropout(0.5) 65 | ) 66 | 67 | def forward(self, x): 68 | res = [] 69 | for conv in self.convs: 70 | res.append(conv(x)) 71 | res = torch.cat(res, dim=1) 72 | return self.project(res) 73 | 
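
if __name__ == '__main__':
    # A minimal shape sketch added for illustration only (not part of the
    # original module). It assumes the 2048-channel ResNet layer4 feature
    # map that DeepLabv3Plus_with_un feeds into ASPP in deeplabv3.py.
    aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18])
    dummy = torch.randn(2, 2048, 33, 33)  # e.g. two 513x513 crops at output stride 16
    out = aspp(dummy)
    print(out.shape)  # torch.Size([2, 256, 33, 33]): five branches fused back to 256 channels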
-------------------------------------------------------------------------------- /generalframeworks/util/meter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, name, fmt=':f'): 7 | self.name = name 8 | self.fmt = fmt 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | def __str__(self): 24 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 25 | return fmtstr.format(**self.__dict__) 26 | 27 | 28 | class ConfMatrix(object): 29 | def __init__(self, num_classes, fmt, name='miou'): 30 | self.name = name 31 | self.fmt = fmt 32 | self.num_classes = num_classes 33 | self.mat = None 34 | self.temp_mat = None 35 | self.val = 0 36 | self.avg = 0 37 | 38 | 39 | def update(self, pred, target): 40 | n = self.num_classes 41 | self.temp_mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 42 | if self.mat is None: 43 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 44 | with torch.no_grad(): 45 | k = (target >= 0) & (target < n) 46 | inds = n * target[k].to(torch.int64) + pred[k] 47 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) 48 | self.temp_mat = torch.bincount(inds, minlength=n**2).reshape(n, n) 49 | 50 | 51 | def __str__(self): 52 | h = self.mat.float() 53 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 54 | self.avg = torch.mean(iu).item() 55 | 56 | h_t = self.temp_mat.float() 57 | iu_a = torch.diag(h_t) / (h_t.sum(1) + h_t.sum(0) - torch.diag(h_t)) 58 | self.val = torch.mean(iu_a).item() 59 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 60 | return fmtstr.format(**self.__dict__) 61 | 62 | 63 | class ProgressMeter(object): 64 | def __init__(self, num_batches, meters, prefix=""): 65 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 66 | self.meters = meters 67 | self.prefix = prefix 68 | 69 | def display(self, batch): 70 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 71 | entries += [str(meter) for meter in self.meters] 72 | print('\t'.join(entries)) 73 | 74 | def _get_batch_fmtstr(self, num_batches): 75 | num_digits = len(str(num_batches // 1)) 76 | fmt = '{:' + str(num_digits) + 'd}' 77 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 78 | 79 | -------------------------------------------------------------------------------- /visual.py: -------------------------------------------------------------------------------- 1 | import shutup 2 | shutup.please() 3 | import torch 4 | import torchvision.models as models 5 | from torch.nn import functional as F 6 | 7 | import numpy as np 8 | from PIL import Image 9 | from generalframeworks.networks.deeplabv3.deeplabv3 import DeepLabv3Plus_with_un 10 | from generalframeworks.augmentation.transform import transform 11 | 12 | # ++++++++++++++++++++ Utils +++++++++++++++++++++++++ 13 | def create_pascal_label_colormap(): 14 | """Creates a label colormap used in Pascal segmentation benchmark. 15 | Returns: 16 | A colormap for visualizing segmentation results. 
17 | """ 18 | colormap = 255 * np.ones((256, 3), dtype=np.uint8) 19 | colormap[0] = [0, 0, 0] 20 | colormap[1] = [128, 0, 0] 21 | colormap[2] = [0, 128, 0] 22 | colormap[3] = [128, 128, 0] 23 | colormap[4] = [0, 0, 128] 24 | colormap[5] = [128, 0, 128] 25 | colormap[6] = [0, 128, 128] 26 | colormap[7] = [128, 128, 128] 27 | colormap[8] = [64, 0, 0] 28 | colormap[9] = [192, 0, 0] 29 | colormap[10] = [64, 128, 0] 30 | colormap[11] = [192, 128, 0] 31 | colormap[12] = [64, 0, 128] 32 | colormap[13] = [192, 0, 128] 33 | colormap[14] = [64, 128, 128] 34 | colormap[15] = [192, 128, 128] 35 | colormap[16] = [0, 64, 0] 36 | colormap[17] = [128, 64, 0] 37 | colormap[18] = [0, 192, 0] 38 | colormap[19] = [128, 192, 0] 39 | colormap[20] = [0, 64, 128] 40 | return colormap 41 | 42 | def color_map(mask, colormap): 43 | color_mask = np.zeros([mask.shape[0], mask.shape[1], 3]) 44 | for i in np.unique(mask): 45 | color_mask[mask == i] = colormap[i] 46 | return np.uint8(color_mask) 47 | 48 | 49 | # ++++++++++++++++++++ Pascal VOC Visualisation +++++++++++++++++++++++++ 50 | # Initialization 51 | im_size = [513, 513] 52 | root = './dataset/pascal' 53 | num_segments = 21 54 | device = torch.device("cpu") 55 | model = DeepLabv3Plus_with_un(models.resnet101(), num_classes=num_segments).to(device) 56 | 57 | # Load checkpoint 58 | checkpoint = torch.load('./best_model.pth', map_location='cpu') 59 | model.load_state_dict(checkpoint['model']) 60 | 61 | # Switch to eval mode 62 | model.eval() 63 | 64 | # Generate color map for visualisation 65 | colormap = create_pascal_label_colormap() 66 | 67 | # Visualise image in validation set 68 | 69 | # Load images and pre-process 70 | with open(root + '/test_val.txt') as f: 71 | idx_list = f.read().splitlines() 72 | for id in idx_list: 73 | print('Image {} start!'.format(id)) 74 | im = Image.open('./dataset/VOCdevkit/VOC2012/JPEGImages/{}.jpg'.format(id)) 75 | im.save('./vis/image/{}.png'.format(id)) 76 | gt_label = Image.open('./dataset/VOCdevkit/VOC2012/SegmentationClassAug/{}.png'.format(id)) 77 | im_tensor, label_tensor = transform(im, gt_label, None, crop_size=im_size, scale_size=(1.0, 1.0), augmentation=False) 78 | im_w, im_h = im.size 79 | 80 | # Inference 81 | logits, _, _ = model(im_tensor.unsqueeze(0)) 82 | logits = F.interpolate(logits, size=im_size, mode='bilinear', align_corners=True) 83 | max_logits, label_prcl = torch.max(torch.softmax(logits, dim=1), dim=1) 84 | 85 | # Show the results and save 86 | gt_blend = Image.blend(im, Image.fromarray(color_map(label_tensor[0].numpy(), colormap)[:im_h, :im_w]), alpha=0.7) 87 | prcl_blend = Image.blend(im, Image.fromarray(color_map(label_prcl[0].numpy(), colormap)[:im_h, :im_w]), alpha=0.7) 88 | prcl_blend.save('./vis/prcl/{}.png'.format(id)) 89 | -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/deeplabv3.py: -------------------------------------------------------------------------------- 1 | from .aspp import * 2 | from functools import partial 3 | 4 | ##### For PRCL Loss ##### 5 | class DeepLabv3Plus_with_un(nn.Module): 6 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 7 | super(DeepLabv3Plus_with_un, self).__init__() 8 | if dilate_scale == 8: 9 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 10 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 11 | aspp_dilate = [12, 24, 36] 12 | 13 | elif dilate_scale == 16: 14 | orig_resnet.layer4.apply(partial(self._nostride_dilate, 
dilate=2))
15 |             aspp_dilate = [6, 12, 18]
16 | 
17 |         # take pre-defined ResNet, except AvgPool and FC
18 |         self.resnet_conv1 = orig_resnet.conv1
19 |         #self.resnet_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) # Change the num of input channel
20 |         self.resnet_bn1 = orig_resnet.bn1
21 |         self.resnet_relu1 = orig_resnet.relu
22 |         self.resnet_maxpool = orig_resnet.maxpool
23 | 
24 |         self.resnet_layer1 = orig_resnet.layer1
25 |         self.resnet_layer2 = orig_resnet.layer2
26 |         self.resnet_layer3 = orig_resnet.layer3
27 |         self.resnet_layer4 = orig_resnet.layer4
28 | 
29 |         self.ASPP = ASPP(2048, aspp_dilate)
30 | 
31 |         self.project = nn.Sequential(
32 |             nn.Conv2d(256, 48, 1, bias=False),
33 |             nn.BatchNorm2d(48),
34 |             nn.ReLU(inplace=True),
35 |         )
36 | 
37 |         self.classifier = nn.Sequential(
38 |             nn.Conv2d(304, 256, 3, padding=1, bias=False),
39 |             nn.BatchNorm2d(256),
40 |             nn.ReLU(),
41 |             nn.Conv2d(256, num_classes, 1)
42 |         )
43 | 
44 |         self.representation = nn.Sequential(
45 |             nn.Conv2d(304, 256, 3, padding=1, bias=False),
46 |             nn.BatchNorm2d(256),
47 |             nn.ReLU(),
48 |             nn.Conv2d(256, output_dim, 1)
49 |         )
50 | 
51 |     def _nostride_dilate(self, m, dilate):
52 |         classname = m.__class__.__name__
53 |         if classname.find('Conv') != -1:
54 |             # the convolution with stride
55 |             if m.stride == (2, 2):
56 |                 m.stride = (1, 1)
57 |                 if m.kernel_size == (3, 3):
58 |                     m.dilation = (dilate // 2, dilate // 2)
59 |                     m.padding = (dilate // 2, dilate // 2)
60 | 
61 |             # other convolutions
62 |             else:
63 |                 if m.kernel_size == (3, 3):
64 |                     m.dilation = (dilate, dilate)
65 |                     m.padding = (dilate, dilate)
66 | 
67 |     def forward(self, x):
68 |         h_w = x.shape[2:]
69 |         # with ResNet encoder
70 |         x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x)))
71 |         x = self.resnet_maxpool(x)
72 | 
73 |         x_low = self.resnet_layer1(x)
74 |         x = self.resnet_layer2(x_low)
75 |         x = self.resnet_layer3(x)
76 |         x = self.resnet_layer4(x)
77 | 
78 |         feature = self.ASPP(x)
79 | 
80 |         # Decoder
81 |         x_low = self.project(x_low)
82 |         output_feature = F.interpolate(feature, size=x_low.shape[2:], mode='bilinear', align_corners=True)
83 |         prediction = self.classifier(torch.cat([x_low, output_feature], dim=1))
84 |         representation = self.representation(torch.cat([x_low, output_feature], dim=1))
85 | 
86 | 
87 |         return prediction, representation, torch.cat([x_low, output_feature], dim=1)
--------------------------------------------------------------------------------
/generalframeworks/meter/meter.py:
--------------------------------------------------------------------------------
1 | from multiprocessing.sharedctypes import Value
2 | import torch
3 | from generalframeworks.utils import class2one_hot
4 | import numpy as np
5 | 
6 | 
7 | class Meter(object):
8 | 
9 |     def reset(self):
10 |         # Reset the Meter to default settings
11 |         pass
12 | 
13 |     def add(self, pred_logits, label):
14 |         # Log a new value to the meter
15 |         pass
16 | 
17 |     def value(self):
18 |         # Get the value of the meter in the current state
19 |         pass
20 | 
21 |     def summary(self) -> dict:
22 |         raise NotImplementedError
23 | 
24 |     def detailed_summary(self) -> dict:
25 |         raise NotImplementedError
26 | 
27 | class ConfMatrix(object):
28 |     def __init__(self, num_classes):
29 |         self.num_classes = num_classes
30 |         self.mat = None
31 | 
32 |     def update(self, pred, target):
33 |         n = self.num_classes
34 |         if self.mat is None:
35 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
36 |         with torch.no_grad():
37 |             k = (target >= 0) & (target < n)
38 |             inds = n * target[k].to(torch.int64) + pred[k]
39 |             self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
40 | 
41 | 
42 |     def get_metrics(self):
43 |         h = self.mat.float()
44 |         acc = torch.diag(h).sum() / h.sum()
45 |         up = torch.diag(h)
46 |         down = h.sum(1) + h.sum(0) - torch.diag(h)
47 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h) + 1e-6)
48 |         return torch.mean(iu).item(), acc.item()
49 | 
50 |     def get_valid_metrics(self):
51 |         h = self.mat.float()
52 |         acc = torch.diag(h).sum() / h.sum()
53 |         up = torch.diag(h)
54 |         down = h.sum(1) + h.sum(0) - torch.diag(h)
55 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h) + 1e-6)
56 |         num_zero = (iu == 0).sum()  # classes with zero IoU are excluded from the mean
57 |         return (iu.sum() / (len(iu) - num_zero)).item(), acc.item()
58 | 
59 | 
60 | 
61 | class My_ConfMatrix(Meter):
62 |     def __init__(self, num_classes):
63 |         super(My_ConfMatrix, self).__init__()
64 |         self.num_classes = num_classes
65 |         self.mat = None
66 |         self.reset()
67 |         self.mIOU = []
68 |         self.Acc = []
69 | 
70 |     def add(self, pred_logits, label):
71 |         pred_logits = pred_logits.argmax(1).flatten()
72 |         label = label.flatten()
73 |         n = self.num_classes
74 |         if self.mat is None:
75 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred_logits.device)
76 |         with torch.no_grad():
77 |             k = (label >= 0) & (label < n)
78 |             inds = n * label[k].to(torch.int64) + pred_logits[k]
79 |             self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
80 | 
81 |     def value(self, mode='mean'):
82 |         h = self.mat.float()
83 |         self.acc = torch.diag(h).sum() / h.sum()
84 |         self.iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
85 |         if mode == 'mean':
86 |             return torch.mean(self.iu).item(), self.acc.item()
87 |         else:
88 |             raise ValueError("mode must be 'mean'")
89 | 
90 |     def reset(self):
91 |         self.mIOU = []
92 |         self.Acc = []
93 | 
94 |     def summary(self) -> dict:
95 |         mIOU_dct: dict = {}
96 |         Acc_dct: dict = {}
97 |         for c in range(self.num_classes):
98 |             if c != 0:
99 |                 mIOU_dct['mIOU_{}'.format(c)] = np.array([self.value(i, mode='all')[0] for i in range(len(self.mIOU))])[
100 |                                                 :, c].mean()
101 |                 Acc_dct['Acc_{}'.format(c)] = np.array([self.value(i, mode='all')[1] for i in range(len(self.mIOU))])[:,
102 |                                               c].mean()
103 |         return mIOU_dct, Acc_dct
104 | 
105 | 
106 | 
107 | 
108 | 
109 | 
110 | 
111 | 
112 | 
113 | 
--------------------------------------------------------------------------------
/CityScapes_split/150/3407/labeled_filename.txt:
--------------------------------------------------------------------------------
1 | aachen_000013_000019_leftImg8bit
2 | strasbourg_000000_031067_leftImg8bit
3 | aachen_000172_000019_leftImg8bit
4 | strasbourg_000000_003846_leftImg8bit
5 | stuttgart_000030_000019_leftImg8bit
6 | ulm_000065_000019_leftImg8bit
7 | bochum_000000_033714_leftImg8bit
8 | ulm_000048_000019_leftImg8bit
9 | bremen_000004_000019_leftImg8bit
10 | jena_000022_000019_leftImg8bit
11 | zurich_000044_000019_leftImg8bit
12 | monchengladbach_000000_011383_leftImg8bit
13 | hanover_000000_008200_leftImg8bit
14 | zurich_000003_000019_leftImg8bit
15 | hamburg_000000_057816_leftImg8bit
16 | bremen_000168_000019_leftImg8bit
17 | hamburg_000000_045437_leftImg8bit
18 | darmstadt_000007_000019_leftImg8bit
19 | zurich_000017_000019_leftImg8bit
20 | hamburg_000000_068693_leftImg8bit
21 | tubingen_000001_000019_leftImg8bit
22 | zurich_000118_000019_leftImg8bit
23 | strasbourg_000001_028852_leftImg8bit
24 | hanover_000000_048379_leftImg8bit
25 | bremen_000107_000019_leftImg8bit
26 | stuttgart_000050_000019_leftImg8bit
27 | 
cologne_000068_000019_leftImg8bit 28 | strasbourg_000000_025907_leftImg8bit 29 | hanover_000000_012675_leftImg8bit 30 | bremen_000003_000019_leftImg8bit 31 | weimar_000096_000019_leftImg8bit 32 | bremen_000311_000019_leftImg8bit 33 | bremen_000097_000019_leftImg8bit 34 | tubingen_000032_000019_leftImg8bit 35 | dusseldorf_000120_000019_leftImg8bit 36 | hanover_000000_004230_leftImg8bit 37 | bremen_000077_000019_leftImg8bit 38 | zurich_000038_000019_leftImg8bit 39 | bremen_000229_000019_leftImg8bit 40 | bremen_000076_000019_leftImg8bit 41 | cologne_000152_000019_leftImg8bit 42 | hanover_000000_009004_leftImg8bit 43 | strasbourg_000000_017283_leftImg8bit 44 | hanover_000000_036562_leftImg8bit 45 | strasbourg_000001_047336_leftImg8bit 46 | tubingen_000143_000019_leftImg8bit 47 | bremen_000033_000019_leftImg8bit 48 | hamburg_000000_053086_leftImg8bit 49 | bremen_000224_000019_leftImg8bit 50 | zurich_000043_000019_leftImg8bit 51 | dusseldorf_000142_000019_leftImg8bit 52 | hanover_000000_049465_leftImg8bit 53 | bremen_000244_000019_leftImg8bit 54 | bremen_000165_000019_leftImg8bit 55 | darmstadt_000081_000019_leftImg8bit 56 | bremen_000269_000019_leftImg8bit 57 | strasbourg_000001_028379_leftImg8bit 58 | aachen_000106_000019_leftImg8bit 59 | hamburg_000000_092476_leftImg8bit 60 | dusseldorf_000137_000019_leftImg8bit 61 | aachen_000011_000019_leftImg8bit 62 | stuttgart_000066_000019_leftImg8bit 63 | strasbourg_000001_024152_leftImg8bit 64 | stuttgart_000063_000019_leftImg8bit 65 | bremen_000158_000019_leftImg8bit 66 | darmstadt_000006_000019_leftImg8bit 67 | bremen_000144_000019_leftImg8bit 68 | zurich_000041_000019_leftImg8bit 69 | hanover_000000_015587_leftImg8bit 70 | hanover_000000_034015_leftImg8bit 71 | dusseldorf_000093_000019_leftImg8bit 72 | zurich_000059_000019_leftImg8bit 73 | bremen_000078_000019_leftImg8bit 74 | dusseldorf_000128_000019_leftImg8bit 75 | hamburg_000000_019760_leftImg8bit 76 | ulm_000052_000019_leftImg8bit 77 | cologne_000063_000019_leftImg8bit 78 | cologne_000105_000019_leftImg8bit 79 | dusseldorf_000020_000019_leftImg8bit 80 | hanover_000000_055937_leftImg8bit 81 | strasbourg_000001_016481_leftImg8bit 82 | weimar_000114_000019_leftImg8bit 83 | hamburg_000000_028439_leftImg8bit 84 | zurich_000010_000019_leftImg8bit 85 | stuttgart_000103_000019_leftImg8bit 86 | ulm_000084_000019_leftImg8bit 87 | bremen_000052_000019_leftImg8bit 88 | hamburg_000000_088939_leftImg8bit 89 | strasbourg_000001_031116_leftImg8bit 90 | hanover_000000_056457_leftImg8bit 91 | stuttgart_000083_000019_leftImg8bit 92 | hanover_000000_049005_leftImg8bit 93 | hamburg_000000_025986_leftImg8bit 94 | weimar_000075_000019_leftImg8bit 95 | bremen_000132_000019_leftImg8bit 96 | aachen_000031_000019_leftImg8bit 97 | dusseldorf_000090_000019_leftImg8bit 98 | darmstadt_000080_000019_leftImg8bit 99 | aachen_000062_000019_leftImg8bit 100 | bremen_000094_000019_leftImg8bit 101 | erfurt_000102_000019_leftImg8bit 102 | darmstadt_000047_000019_leftImg8bit 103 | cologne_000019_000019_leftImg8bit 104 | bremen_000242_000019_leftImg8bit 105 | hanover_000000_029325_leftImg8bit 106 | aachen_000155_000019_leftImg8bit 107 | erfurt_000057_000019_leftImg8bit 108 | bremen_000138_000019_leftImg8bit 109 | hamburg_000000_016691_leftImg8bit 110 | hamburg_000000_069289_leftImg8bit 111 | bremen_000140_000019_leftImg8bit 112 | hamburg_000000_070444_leftImg8bit 113 | dusseldorf_000141_000019_leftImg8bit 114 | stuttgart_000167_000019_leftImg8bit 115 | zurich_000060_000019_leftImg8bit 116 | jena_000007_000019_leftImg8bit 117 | 
cologne_000064_000019_leftImg8bit
118 | strasbourg_000001_042558_leftImg8bit
119 | hamburg_000000_087822_leftImg8bit
120 | stuttgart_000043_000019_leftImg8bit
121 | bremen_000109_000019_leftImg8bit
122 | cologne_000108_000019_leftImg8bit
123 | strasbourg_000000_013574_leftImg8bit
124 | tubingen_000090_000019_leftImg8bit
125 | stuttgart_000034_000019_leftImg8bit
126 | strasbourg_000001_026856_leftImg8bit
127 | strasbourg_000001_000113_leftImg8bit
128 | bremen_000268_000019_leftImg8bit
129 | strasbourg_000000_031223_leftImg8bit
130 | zurich_000042_000019_leftImg8bit
131 | zurich_000088_000019_leftImg8bit
132 | hamburg_000000_073549_leftImg8bit
133 | bremen_000160_000019_leftImg8bit
134 | cologne_000146_000019_leftImg8bit
135 | hanover_000000_013814_leftImg8bit
136 | bremen_000084_000019_leftImg8bit
137 | erfurt_000072_000019_leftImg8bit
138 | hamburg_000000_039420_leftImg8bit
139 | hanover_000000_009420_leftImg8bit
140 | bremen_000053_000019_leftImg8bit
141 | zurich_000045_000019_leftImg8bit
142 | dusseldorf_000197_000019_leftImg8bit
143 | stuttgart_000090_000019_leftImg8bit
144 | hanover_000000_039021_leftImg8bit
145 | jena_000044_000019_leftImg8bit
146 | dusseldorf_000073_000019_leftImg8bit
147 | darmstadt_000012_000019_leftImg8bit
148 | hanover_000000_048274_leftImg8bit
149 | hamburg_000000_073758_leftImg8bit
150 | hanover_000000_049269_leftImg8bit
151 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Boosting Pixel-Wise Contrastive Learning with Probabilistic Representations (AAAI 2023)
2 | ![https://github.com/Haoyu-Xie/PRCL/blob/main/PRCL.gif](https://github.com/Haoyu-Xie/PRCL/blob/main/PRCL.gif)
3 | This repository contains the source code of **PRCL** from the paper [Boosting Pixel-Wise Contrastive Learning with Probabilistic Representations](https://arxiv.org/abs/2210.14670).
4 | 
5 | In this paper, we redefine the representation in pixel-wise contrastive learning from the perspective of probability theory: we model each representation as a random variable, namely a **Probabilistic Representation**.
6 | ## Updates
7 | **Nov. 2022** -- Uploaded the source code.
8 | 
9 | ## Prepare
10 | PRCL is evaluated with two datasets: PASCAL VOC 2012 and CityScapes.
11 | - For PASCAL VOC, please download the original training images from the [official PASCAL site](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar): `VOCtrainval_11-May-2012.tar` and the augmented labels [here](http://vllab1.ucmerced.edu/~whung/adv-semi-seg/SegmentationClassAug.zip): `SegmentationClassAug.zip`.
12 | Extract the folders `JPEGImages` and `SegmentationClassAug` as follows:
13 | ```
14 | ├── data
15 | │   ├── VOCdevkit
16 | │   │   ├──VOC2012
17 | │   │   | ├──JPEGImages
18 | │   │   | ├──SegmentationClassAug
19 | ```
20 | - For CityScapes, please download the original images and labels from the [official CityScapes site](https://www.cityscapes-dataset.com/downloads/): `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip`.
21 | Extract `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip` as follows:
22 | ```
23 | ├── data
24 | │   ├── cityscapes
25 | │   │   ├──leftImg8bit
26 | │   │   | ├──train
27 | │   │   | ├──val
28 | │   │   ├──train
29 | │   │   ├──val
30 | ```
31 | Folders `train` and `val` under `leftImg8bit` contain the training and validation images, while folders `train` and `val` directly under `cityscapes` contain the corresponding labels.
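After extracting both datasets, you can optionally sanity-check the layout with a short snippet like this (a sketch; `data_root` is a placeholder, point it at your own `data` folder):
```
import os

data_root = './data'  # placeholder: point this at your own data folder
expected = [
    'VOCdevkit/VOC2012/JPEGImages',
    'VOCdevkit/VOC2012/SegmentationClassAug',
    'cityscapes/leftImg8bit/train',
    'cityscapes/leftImg8bit/val',
    'cityscapes/train',
    'cityscapes/val',
]
for folder in expected:
    status = 'OK' if os.path.isdir(os.path.join(data_root, folder)) else 'MISSING'
    print('{:45s} {}'.format(folder, status))
```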
32 | 
33 | The data split folders of VOC and CityScapes are organized as follows:
34 | ```
35 | ├── VOC(CityScapes)_split
36 | │   ├── labeled number
37 | │   │   ├──seed
38 | │   │   | ├──labeled_filename.txt
39 | │   │   | ├──unlabeled_filename.txt
40 | │   │   | ├──valid_filename.txt
41 | ```
42 | You need to rename the folders (labeled number and seed) according to your actual experiments.
43 | 
44 | PRCL uses ResNet-101 pretrained on ImageNet and ResNet-101 with a deep stem block; please download ResNet-101 from [here](https://download.pytorch.org/models/resnet101-63fe2227.pth) and ResNet-101 stem from [here](https://drive.google.com/file/d/131dWv_zbr1ADUr_8H6lNyuGWsItHygSb/view?usp=sharing). Remember to change the directory in the corresponding Python file.
45 | 
46 | In order to install the correct environment, please run the following script:
47 | ```
48 | conda create -n prcl python=3.8.5
49 | conda activate prcl
50 | pip install -r requirements.txt
51 | ```
52 | It may take a long time, so take a break and have a cup of coffee!
53 | It is also OK to install the environment manually; just remember to check each package version CAREFULLY!
54 | 
55 | ## Run
56 | You can run our code with a single GPU or multiple GPUs.
57 | - For single-GPU users, please run the following script:
58 | ```
59 | python prcl_sig.py [--config]
60 | ```
61 | You need to change the file name after --config according to your actual experiments.
62 | - For multi-GPU users, please run the following script:
63 | ```
64 | bash ./script/batch_train.sh
65 | ```
66 | We provide 662 labels for VOC and 150 labels for CityScapes; the seed in our experiments is 3407. You can change the label rate and seed as you like; remember to change the corresponding config files and data_split directory as well.
67 | ## Hyper-parameters
68 | All hyper-parameters used in the code are shown below:
69 | |Name | Description | Value |
70 | | :-: |:-:| :-:|
71 | | `alpha` | hyper-parameter in EMA model | `0.99` |
72 | | `lr` | learning rate of backbone, prediction head, and project head | `3.2e-3` |
73 | | `uncer_lr` | learning rate of probability head | `5e-5` |
74 | | `un_threshold` | threshold in unsupervised loss | `0.97` |
75 | | `weak_threshold` | weak threshold in PRCL loss | `0.7` |
76 | | `strong_threshold` | strong threshold in PRCL loss | `0.8` |
77 | | `temp` | temperature in PRCL loss | `100` |
78 | | `num_queries` | number of queries in PRCL loss | `256` |
79 | | `num_negatives` | number of negatives in PRCL loss | `512` |
80 | | `begin_epoch` | the starting epoch of scheduler $\lambda_c$ | `0` |
81 | | `max_epoch` | the final epoch of scheduler $\lambda_c$ | `200` |
82 | | `max_value` | the max value of scheduler $\lambda_c$ | `1.0` |
83 | | `min_value` | the min value of scheduler $\lambda_c$ | `0` |
84 | | `ramp_mult` | the $\alpha$ of scheduler $\lambda_c$ | `-5.0` |
85 | 
86 | **It is worth noting that `uncer_lr` is very sensitive; training may crash if `uncer_lr` is not fine-tuned CAREFULLY.**
87 | 
88 | ## Acknowledgement
89 | The data processing and augmentation (CutMix, CutOut, and ClassMix) are borrowed from ReCo.
90 | - ReCo: https://github.com/lorenmt/reco
91 | 
92 | Thanks a lot for their splendid work!
93 | 
94 | ## Citation
95 | If you think this work is useful for you and your research, please consider citing the following:
96 | ```
97 | @article{PRCL,
98 |   title={Boosting Semi-Supervised Semantic Segmentation with Probabilistic Representations},
99 |   author={Xie, Haoyu and Wang, Changqi and Zheng, Mingkai and Dong, Minjing and You, Shan and Xu, Chang},
100 |   journal={arXiv preprint arXiv:2210.14670},
101 |   year={2022}
102 | }
103 | ```
104 | More interesting works based on Probabilistic Representations are coming soon. 👣
105 | 
106 | ## Contact
107 | If you have any questions or run into any problems, please feel free to contact us.
108 | - Haoyu Xie, [895852154@qq.com](mailto:895852154@qq.com)
109 | - Changqi Wang, [wangchangqi98@gmail.com](mailto:wangchangqi98@gmail.com)
110 | 
--------------------------------------------------------------------------------
/generalframeworks/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from typing import Iterable, Union
4 | from copy import deepcopy as dcopy
5 | from typing import List, Set
6 | import collections
7 | from functools import partial, reduce
8 | import torch
9 | import numpy as np
10 | import os
11 | import datetime
12 | from tqdm import tqdm
13 | from torch.utils.data import DataLoader
14 | import warnings
15 | import torch.nn as nn
16 | 
17 | ##### Hyper Parameters Define #####
18 | 
19 | def _parser_(input_strings: str) -> Union[dict, None]:
20 |     if input_strings.__len__() == 0:
21 |         return None
22 |     assert input_strings.find('=') > 0, "Input args should include '=' to include value"
23 |     keys, value = input_strings.split('=')[:-1][0].replace(' ', ''), input_strings.split('=')[1].replace(' ', '')
24 |     keys = keys.split('.')
25 |     keys.reverse()
26 |     for k in keys:
27 |         d = {}
28 |         d[k] = value
29 |         value = dcopy(d)
30 |     return dict(value)
31 | 
32 | def _parser(strings: List[str]) -> List[dict]:
33 |     assert isinstance(strings, list)
34 |     args: List[dict] = [_parser_(s) for s in strings]
35 |     args = reduce(lambda x, y: dict_merge(x, y, True), args)
36 |     return args
37 | 
38 | def yaml_parser() -> dict:
39 |     parser = argparse.ArgumentParser('Argument parser for YAML config')
40 |     parser.add_argument('strings', nargs='*', type=str, default=[''])
41 |     parser.add_argument("--local_rank", type=int)
42 |     #parser.add_argument('--var', type=int, default=24)
43 |     #add args.variable here
44 |     args: argparse.Namespace = parser.parse_args()
45 |     args: dict = _parser(args.strings)
46 |     return args
47 | 
48 | def dict_merge(dct: dict, merge_dct: dict, re=False):
49 |     '''
50 |     Recursive dict merge. Instead of updating only top-level keys, dict_merge recurses down into nested dicts
51 |     to an arbitrary depth, updating keys. The "merge_dct" is merged into "dct".
52 |     '''
53 |     if merge_dct is None:
54 |         if re:
55 |             return dct
56 |         else:
57 |             return
58 |     for k, v in merge_dct.items():
59 |         if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], collections.abc.Mapping)):
60 |             dict_merge(dct[k], merge_dct[k])
61 |         else:
62 |             try:
63 |                 dct[k] = type(dct[k])(eval(merge_dct[k])) if type(dct[k]) in (bool, list) else type(dct[k])(
64 |                     merge_dct[k])
65 |             except:
66 |                 dct[k] = merge_dct[k]
67 |     if re:
68 |         return dcopy(dct)
69 | 
70 | ##### Timer #####
71 | def now_time():
72 |     time = datetime.datetime.now()
73 |     return str(time)[:19]
74 | 
75 | ##### Progress Bar #####
76 | 
77 | tqdm_ = partial(tqdm, ncols=125, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [' '{rate_fmt}{postfix}]')
78 | 
79 | ##### Coding #####
80 | def class2one_hot(seg: torch.Tensor, num_class: int) -> torch.Tensor:
81 |     '''
82 |     [b, w, h] containing (0, 1, ..., c) -> [b, c, w, h] containing (0, 1)
83 |     '''
84 |     if len(seg.shape) == 2:
85 |         seg = seg.unsqueeze(dim=0) # Must be 3-dim
86 |     if len(seg.shape) == 4:
87 |         seg = seg.squeeze(dim=1)
88 |     assert sset(seg, list(range(num_class))), 'The value of segmentation is outside the num_class!'
89 |     b, w, h = seg.shape # Tuple [int, int, int]
90 |     res = torch.stack([seg == c for c in range(num_class)], dim=1).type(torch.int32)
91 |     assert res.shape == (b, num_class, w, h)
92 |     assert one_hot(res)
93 | 
94 |     return res
95 | 
96 | def probs2class(probs: torch.Tensor) -> torch.Tensor:
97 |     '''
98 |     [b, c, w, h] containing(float in range(0, 1)) -> [b, w, h] containing ([0, 1, ..., c])
99 |     '''
100 |     b, _, w, h = probs.shape
101 |     assert simplex(probs), '{} is not a probability'.format(probs)
102 |     res = probs.argmax(dim=1)
103 |     assert res.shape == (b, w, h)
104 | 
105 |     return res
106 | 
107 | def probs2one_hot(probs: torch.Tensor) -> torch.Tensor:
108 |     _, num_class, _, _ = probs.shape
109 |     assert simplex(probs), '{} is not a probability'.format(probs)
110 |     res = class2one_hot(probs2class(probs), num_class)
111 |     assert res.shape == probs.shape
112 |     assert one_hot(res)
113 |     return res
114 | 
115 | def label_onehot(inputs, num_class):
116 |     '''
117 |     inputs is a class label;
118 |     returns a one-hot label.
119 |     The dim will be increased.
120 |     '''
121 |     batch_size, image_h, image_w = inputs.shape
122 |     inputs = torch.relu(inputs)
123 |     outputs = torch.zeros([batch_size, num_class, image_h, image_w]).to(inputs.device)
124 |     return outputs.scatter_(1, inputs.unsqueeze(1), 1.0)
125 | 
126 | 
127 | def uniq(a: torch.Tensor) -> Set:
128 |     return set(torch.unique(a.cpu()).numpy())
129 | 
130 | def sset(a: torch.Tensor, sub: Iterable) -> bool:
131 |     return uniq(a).issubset(sub)
132 | 
133 | def simplex(t: torch.Tensor, axis=1) -> bool:
134 |     '''
135 |     Check if the matrix is a probability distribution along the given axis.
136 |     '''
137 |     _sum = t.sum(axis).type(torch.float32)
138 |     _ones = torch.ones_like(_sum, dtype=torch.float32)
139 |     return torch.allclose(_sum, _ones)
140 | 
141 | def one_hot(t: torch.Tensor, axis=1) -> bool:
142 |     '''
143 |     Check if the Tensor is One-hot coding
144 |     '''
145 |     return simplex(t, axis) and sset(t, [0, 1])
146 | 
147 | def intersection(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
148 |     '''
149 |     a and b must only contain 0 or 1; the function computes the intersection of two tensors:
150 |     a & b
151 |     '''
152 |     assert a.shape == b.shape, '{}.shape must be the same as {}'.format(a, b)
153 |     assert sset(a, [0, 1]), '{} must only contain 0, 1'.format(a)
154 |     assert sset(b, [0, 1]), '{} must only contain 0, 1'.format(b)
155 |     return a & b
156 | 
157 | class iterator_(object):
158 |     def __init__(self, dataloader: DataLoader) -> None:
159 |         super().__init__()
160 |         self.dataloader = dcopy(dataloader)
161 |         self.iter_dataloader = iter(dataloader)
162 |         self.cache = None
163 | 
164 |     def __next__(self):
165 |         try:
166 |             self.cache = self.iter_dataloader.__next__()
167 |             return self.cache
168 |         except StopIteration:
169 |             self.iter_dataloader = iter(self.dataloader)
170 |             self.cache = self.iter_dataloader.__next__()
171 |             return self.cache
172 |     def __cache__(self):
173 |         if self.cache is not None:
174 |             return self.cache
175 |         else:
176 |             warnings.warn('No cache found, iterator forward')
177 |             return self.__next__()
178 | 
179 | def apply_dropout(m):
180 |     if type(m) == nn.Dropout2d:
181 |         m.train()
182 | 
183 | ##### Scheduler #####
184 | class RampUpScheduler():
185 |     def __init__(self, begin_epoch, max_epoch, max_value, ramp_mult):
186 |         super().__init__()
187 |         self.begin_epoch = begin_epoch
188 |         self.max_epoch = max_epoch
189 |         self.ramp_mult = ramp_mult
190 |         self.max_value = max_value
191 |         self.epoch = 0
192 | 
193 |     def step(self):
194 |         self.epoch += 1
195 | 
196 |     @property
197 |     def value(self):
198 |         return self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.ramp_mult)
199 | 
200 |     def get_lr(self, epoch, begin_epoch, max_epochs, max_val, mult):
201 |         if epoch < begin_epoch:
202 |             return 0.
203 |         elif epoch >= max_epochs:
204 |             return max_val
205 |         return max_val * np.exp(mult * (1 - float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2)
206 | 
207 | 
208 | ##### Compute mIoU #####
209 | def mask_label(label, mask):
210 |     '''
211 |     label is the original label (contains -1), mask is the valid region in pseudo label (type=long)
212 |     return a label with invalid region = -1
213 |     '''
214 |     label_tmp = label.clone()
215 |     mask_ = (1 - mask.float()).bool()
216 |     label_tmp[mask_] = -1
217 |     return label_tmp.long()
--------------------------------------------------------------------------------
/generalframeworks/networks/ddp_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from generalframeworks.networks.deeplabv3.deeplabv3 import DeepLabv3Plus_with_un
5 | import torch.nn.functional as F
6 | from generalframeworks.dataset_helpers.VOC import batch_transform, generate_cut_gather, generate_cut
7 | from generalframeworks.networks.uncer_head import Uncertainty_head
8 | 
9 | class Model_with_un(nn.Module):
10 |     '''
11 |     Build a model for DDP with: a DeepLabv3+, an EMA model, and an MLP
12 |     '''
13 | 
14 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None) -> None:
15 |         super(Model_with_un, self).__init__()
16 |         self.model = DeepLabv3Plus_with_un(base_encoder, num_classes=num_classes, output_dim=output_dim)
17 |         ##### Init EMA #####
18 |         self.step = 0
19 |         self.ema_model = copy.deepcopy(self.model)
20 |         for p in self.ema_model.parameters():
21 |             p.requires_grad = False
22 |         self.alpha = ema_alpha
23 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
206 |
207 |
208 | ##### Compute mIoU #####
209 | def mask_label(label, mask):
210 |     '''
211 |     label is the original label (contains -1); mask marks the valid region of the pseudo label (type=long).
212 |     Returns a label whose invalid region is set to -1.
213 |     '''
214 |     label_tmp = label.clone()
215 |     mask_ = (1 - mask.float()).bool()
216 |     label_tmp[mask_] = -1
217 |     return label_tmp.long()
-------------------------------------------------------------------------------- /generalframeworks/networks/ddp_model.py: --------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from generalframeworks.networks.deeplabv3.deeplabv3 import DeepLabv3Plus_with_un
5 | import torch.nn.functional as F
6 | from generalframeworks.dataset_helpers.VOC import batch_transform, generate_cut_gather, generate_cut
7 | from generalframeworks.networks.uncer_head import Uncertainty_head
8 |
9 | class Model_with_un(nn.Module):
10 |     '''
11 |     Build a model for DDP with a DeepLabV3+ backbone, an EMA teacher, and an MLP projection head.
12 |     '''
13 |
14 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None) -> None:
15 |         super(Model_with_un, self).__init__()
16 |         self.model = DeepLabv3Plus_with_un(base_encoder, num_classes=num_classes, output_dim=output_dim)
17 |         ##### Init EMA #####
18 |         self.step = 0
19 |         self.ema_model = copy.deepcopy(self.model)
20 |         for p in self.ema_model.parameters():
21 |             p.requires_grad = False
22 |         self.alpha = ema_alpha
23 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
24 |
25 |         ##### Init Uncertainty Head #####
26 |         self.uncer_head = Uncertainty_head()
27 |
28 |         self.config = config
29 |
30 |     def ema_update(self):
31 |         decay = min(1 - 1 / (self.step + 1), self.alpha)
32 |         for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
33 |             ema_param.data = decay * ema_param.data + (1 - decay) * param.data
34 |         self.step += 1
35 |
36 |     def forward(self, train_l_image, train_u_image):
37 |         ##### Generate pseudo labels #####
38 |         with torch.no_grad():
39 |             pred_u, _, _ = self.ema_model(train_u_image)
40 |             pred_u_large_raw = F.interpolate(pred_u, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
41 |             pseudo_logits, pseudo_labels = torch.max(torch.softmax(pred_u_large_raw, dim=1), dim=1)
42 |
43 |             # Randomly scale images
44 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_image, pseudo_labels,
45 |                                                                                        pseudo_logits,
46 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
47 |                                                                                        scale_size=self.config['Dataset']['scale_size'],
48 |                                                                                        augmentation=False)
49 |             # Apply the mixing strategy; we gather images across multiple GPUs during this process
50 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = generate_cut_gather(train_u_aug_image,
51 |                                                                                            train_u_aug_label,
52 |                                                                                            train_u_aug_logits,
53 |                                                                                            mode=self.config['Dataset']['mix_mode'])
54 |
55 |             # Apply augmentation: color jitter + flip + Gaussian blur
56 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_aug_image,
57 |                                                                                        train_u_aug_label,
58 |                                                                                        train_u_aug_logits,
59 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
60 |                                                                                        scale_size=(1.0, 1.0),
61 |                                                                                        augmentation=True)
62 |
63 |
64 |         pred_l, rep_l, raw_feat_l = self.model(train_l_image)
65 |         pred_l_large = F.interpolate(pred_l, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
66 |
67 |         pred_u, rep_u, raw_feat_u = self.model(train_u_aug_image)
68 |         pred_u_large = F.interpolate(pred_u, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
69 |
70 |         rep_all = torch.cat((rep_l, rep_u))
71 |         pred_all = torch.cat((pred_l, pred_u))
72 |
73 |         uncer_all = self.uncer_head(torch.cat((raw_feat_l, raw_feat_u), dim=0))
74 |
75 |         return pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw, uncer_all
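# Editor's note (illustrative): ema_update above uses a warm-up decay of
# min(1 - 1 / (step + 1), alpha), so the teacher tracks the student almost
# directly at the start (decay = 0.0 at step 0, 0.5 at step 1, 0.9 at step 9)
# and the decay saturates at alpha (e.g. 0.99) once step >= 99. A standalone
# sketch of the same rule:
def _demo_ema_decay(alpha=0.99, steps=(0, 1, 9, 99, 1000)):
    return {s: min(1 - 1 / (s + 1), alpha) for s in steps}
    # -> {0: 0.0, 1: 0.5, 9: 0.9, 99: 0.99, 1000: 0.99}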
76 |
77 | class Model_with_un_single(nn.Module):
78 |     '''
79 |     Build a model for DDP with a DeepLabV3+ backbone, an EMA teacher, and an MLP projection head.
80 |     This model is for single-GPU users!
81 |     '''
82 |
83 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None) -> None:
84 |         super(Model_with_un_single, self).__init__()
85 |         self.model = DeepLabv3Plus_with_un(base_encoder, num_classes=num_classes, output_dim=output_dim)
86 |         ##### Init EMA #####
87 |         self.step = 0
88 |         self.ema_model = copy.deepcopy(self.model)
89 |         for p in self.ema_model.parameters():
90 |             p.requires_grad = False
91 |         self.alpha = ema_alpha
92 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
93 |
94 |         ##### Init Uncertainty Head #####
95 |         self.uncer_head = Uncertainty_head()
96 |
97 |         self.config = config
98 |
99 |     def ema_update(self):
100 |         decay = min(1 - 1 / (self.step + 1), self.alpha)
101 |         for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
102 |             ema_param.data = decay * ema_param.data + (1 - decay) * param.data
103 |         self.step += 1
104 |
105 |     def forward(self, train_l_image, train_u_image):
106 |         ##### Generate pseudo labels #####
107 |         with torch.no_grad():
108 |             pred_u, _, _ = self.ema_model(train_u_image)
109 |             pred_u_large_raw = F.interpolate(pred_u, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
110 |             pseudo_logits, pseudo_labels = torch.max(torch.softmax(pred_u_large_raw, dim=1), dim=1)
111 |
112 |             # Randomly scale images
113 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_image, pseudo_labels,
114 |                                                                                        pseudo_logits,
115 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
116 |                                                                                        scale_size=self.config['Dataset']['scale_size'],
117 |                                                                                        augmentation=False)
118 |             # Apply the mixing strategy on a single GPU
119 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = generate_cut(train_u_aug_image,
120 |                                                                                     train_u_aug_label,
121 |                                                                                     train_u_aug_logits,
122 |                                                                                     mode=self.config['Dataset']['mix_mode'])
123 |
124 |             # Apply augmentation: color jitter + flip + Gaussian blur
125 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_aug_image,
126 |                                                                                        train_u_aug_label,
127 |                                                                                        train_u_aug_logits,
128 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
129 |                                                                                        scale_size=(1.0, 1.0),
130 |                                                                                        augmentation=True)
131 |
132 |
133 |         pred_l, rep_l, raw_feat_l = self.model(train_l_image)
134 |         pred_l_large = F.interpolate(pred_l, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
135 |
136 |         pred_u, rep_u, raw_feat_u = self.model(train_u_aug_image)
137 |         pred_u_large = F.interpolate(pred_u, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
138 |
139 |         rep_all = torch.cat((rep_l, rep_u))
140 |         pred_all = torch.cat((pred_l, pred_u))
141 |
142 |         log_uncer_all = self.uncer_head(torch.cat((raw_feat_l, raw_feat_u), dim=0))
143 |         # uncer_all = torch.exp(log_uncer_all)
144 |         uncer_all = log_uncer_all
145 |
146 |         return pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw, uncer_all
147 | # utils
148 | @torch.no_grad()
149 | def concat_all_gather(tensor):
150 |     """
151 |     Performs an all_gather operation on the provided tensor.
152 |     Warning: torch.distributed.all_gather has no gradient.
153 | """ 154 | tensor_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 155 | torch.distributed.all_gather(tensor_gather, tensor, async_op=False) 156 | output = torch.cat(tensor_gather, dim=0) 157 | 158 | return output -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/Cityscapes.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import os 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as transforms_f 6 | import random 7 | from PIL import Image, ImageFilter 8 | import numpy as np 9 | 10 | class Cityscapes_Dataset_cache(data.Dataset): 11 | def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True, 12 | apply_partial=None, partial_seed=None): 13 | self.root = os.path.expanduser(root) 14 | self.train = train 15 | self.crop_size = crop_size 16 | self.augmentation = augmentation 17 | self.scale_size = scale_size 18 | self.idx_list = idx_list 19 | self.apply_partial = apply_partial 20 | self.partial_seed = partial_seed 21 | 22 | 23 | def __getitem__(self, index): 24 | if self.train: 25 | image_root, city_name = image_root_transform(self.idx_list[index], mode='train') 26 | image = Image.open(self.root + image_root) 27 | label_root = label_root_transform(self.idx_list[index], city_name, mode='train') 28 | label = Image.open(self.root + label_root) 29 | label = Image.fromarray(cityscapes_class_map(np.array(label))) 30 | else: 31 | image_root, city_name = image_root_transform(self.idx_list[index], mode='val') 32 | image = Image.open(self.root + image_root) 33 | label_root = label_root_transform(self.idx_list[index], city_name, mode='val') 34 | label = Image.open(self.root + label_root) 35 | label = Image.fromarray(cityscapes_class_map(np.array(label))) 36 | image, label = transform(image, label, None, self.crop_size, self.scale_size, self.augmentation) 37 | return image, label.squeeze(0) 38 | 39 | def __len__(self): 40 | return len(self.idx_list) 41 | 42 | class Cityscapes_Dataset(data.Dataset): 43 | def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True): 44 | self.root = os.path.expanduser(root) 45 | self.train = train 46 | self.crop_size = crop_size 47 | self.augmentation = augmentation 48 | self.scale_size = scale_size 49 | self.idx_list = idx_list 50 | 51 | def __getitem__(self, index): 52 | if self.train: 53 | image_root, city_name = image_root_transform(self.idx_list[index], mode='train') 54 | image = Image.open(self.root + image_root) 55 | label_root = label_root_transform(self.idx_list[index], city_name, mode='train') 56 | label = Image.open(self.root + label_root) 57 | label = Image.fromarray(cityscapes_class_map(np.array(label))) 58 | else: 59 | image_root, city_name = image_root_transform(self.idx_list[index], mode='val') 60 | image = Image.open(self.root + image_root) 61 | label_root = label_root_transform(self.idx_list[index], city_name, mode='val') 62 | label = Image.open(self.root + label_root) 63 | label = Image.fromarray(cityscapes_class_map(np.array(label))) 64 | image, label = transform(image, label, None, self.crop_size, self.scale_size, self.augmentation) 65 | return image, label.squeeze(0) 66 | 67 | def __len__(self): 68 | return len(self.idx_list) 69 | 70 | class City_BuildData(): 71 | def __init__(self, data_path, txt_path, label_num, 
seed): 72 | self.data_path = data_path 73 | self.txt_path = txt_path 74 | self.label_num = label_num 75 | self.seed = seed 76 | self.im_size = [512, 1024] 77 | self.crop_size = [512, 512] 78 | self.num_segments = 19 79 | self.scale_size = (1.0, 1.0) 80 | self.train_l_idx, self.train_u_idx, self.test_idx= get_cityscapes_idx_via_txt(self.txt_path, self.label_num, self.seed) 81 | 82 | def build(self): 83 | train_l_dataset = Cityscapes_Dataset(self.data_path, self.train_l_idx, self.crop_size, self.scale_size, 84 | augmentation=True, train=True) 85 | train_u_dataset = Cityscapes_Dataset(self.data_path, self.train_u_idx, self.crop_size, scale_size=(1.0, 1.0), 86 | augmentation=False, train=True) 87 | test_dataset = Cityscapes_Dataset(self.data_path, self.test_idx, self.crop_size, scale_size=(1.0, 1.0),augmentation=False, 88 | train=False) 89 | return train_l_dataset, train_u_dataset, test_dataset 90 | 91 | def get_cityscapes_idx_via_txt(root, label_num, seed): 92 | ''' 93 | Read idx list via generated txt, pre-perform make_list.py 94 | ''' 95 | root = root + '/' + str(label_num) + '/' + str(seed) 96 | with open(root + '/labeled_filename.txt') as f: 97 | labeled_list = f.read().splitlines() 98 | f.close() 99 | with open(root + '/unlabeled_filename.txt') as f: 100 | unlabeled_list = f.read().splitlines() 101 | f.close() 102 | with open(root + '/valid_filename.txt') as f: 103 | test_list = f.read().splitlines() 104 | f.close() 105 | return labeled_list, unlabeled_list, test_list 106 | 107 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 108 | # Randomly rescale images 109 | raw_w, raw_h = image.size 110 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 111 | 112 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 113 | image = transforms_f.resize(image, resized_size, Image.BILINEAR) 114 | label = transforms_f.resize(label, resized_size, Image.NEAREST) 115 | if logits is not None: 116 | logits = transforms_f.resize(logits, resized_size, Image.NEAREST) 117 | 118 | # Add padding if rescaled image is smaller than crop size 119 | if crop_size == -1: # Use original image size 120 | crop_size = (raw_w, raw_h) 121 | 122 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 123 | right_pad = max(crop_size[1] - resized_size[1], 0) 124 | bottom_pad = max(crop_size[0] - resized_size[0], 0) 125 | image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 126 | label = transforms_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 127 | if logits is not None: 128 | logits = transforms_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 129 | 130 | # Randomly crop images 131 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 132 | image = transforms_f.crop(image, i, j, h, w) 133 | label = transforms_f.crop(label, i, j, h, w) 134 | if logits is not None: 135 | logits = transforms_f.crop(logits, i, j, h, w) 136 | 137 | if augmentation: 138 | # Random color jittering 139 | if torch.rand(1) > 0.2: 140 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 141 | image = color_transform(image) 142 | 143 | # Random Gaussian filtering 144 | if torch.rand(1) > 0.5: 145 | sigma = random.uniform(0.15, 1.15) 146 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 147 | 148 | # Random horizontal flipping 149 | if torch.rand(1) > 0.5: 150 
| image = transforms_f.hflip(image) 151 | label = transforms_f.hflip(label) 152 | if logits is not None: 153 | logits = transforms_f.hflip(logits) 154 | 155 | # Transform to Tensor 156 | image = transforms_f.to_tensor(image) 157 | label = (transforms_f.to_tensor(label) * 255).long() 158 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 159 | if logits is not None: 160 | logits = transforms_f.to_tensor(logits) 161 | 162 | # Apply ImageNet normalization 163 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 164 | if logits is not None: 165 | return image, label, logits 166 | else: 167 | return image, label 168 | 169 | def tensor_to_pil(image, label, logits): 170 | image = denormalise(image) 171 | image = transforms_f.to_pil_image(image.cpu()) 172 | label = label.float() / 255. 173 | label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 174 | logits = transforms_f.to_pil_image(logits.unsqueeze(0).cpu()) 175 | return image, label, logits 176 | 177 | def denormalise(x, imagenet=True): 178 | if imagenet: 179 | x = transforms_f.normalize(x, mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]) 180 | x = transforms_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 181 | return x 182 | else: 183 | return (x + 1) / 2 184 | 185 | def batch_transform(images, labels, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 186 | image_list, label_list, logits_list = [], [], [] 187 | device = images.device 188 | for k in range(images.shape[0]): 189 | image_pil, label_pil, logits_pil = tensor_to_pil(images[k], labels[k], logits[k]) 190 | aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil, crop_size, scale_size, augmentation) 191 | image_list.append(aug_image.unsqueeze(0)) 192 | label_list.append(aug_label) 193 | logits_list.append(aug_logits) 194 | 195 | image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), torch.cat(logits_list).to(device) 196 | return image_trans, label_trans, logits_trans 197 | 198 | def cityscapes_class_map(mask): 199 | # source: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 200 | mask_map = np.zeros_like(mask) 201 | mask_map[np.isin(mask, [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30])] = 255 202 | mask_map[np.isin(mask, [7])] = 0 203 | mask_map[np.isin(mask, [8])] = 1 204 | mask_map[np.isin(mask, [11])] = 2 205 | mask_map[np.isin(mask, [12])] = 3 206 | mask_map[np.isin(mask, [13])] = 4 207 | mask_map[np.isin(mask, [17])] = 5 208 | mask_map[np.isin(mask, [19])] = 6 209 | mask_map[np.isin(mask, [20])] = 7 210 | mask_map[np.isin(mask, [21])] = 8 211 | mask_map[np.isin(mask, [22])] = 9 212 | mask_map[np.isin(mask, [23])] = 10 213 | mask_map[np.isin(mask, [24])] = 11 214 | mask_map[np.isin(mask, [25])] = 12 215 | mask_map[np.isin(mask, [26])] = 13 216 | mask_map[np.isin(mask, [27])] = 14 217 | mask_map[np.isin(mask, [28])] = 15 218 | mask_map[np.isin(mask, [31])] = 16 219 | mask_map[np.isin(mask, [32])] = 17 220 | mask_map[np.isin(mask, [33])] = 18 221 | return mask_map 222 | 223 | def label_root_transform(root: str, name: str, mode: str): 224 | label_root = root.strip()[0: -12] + '_gtFine_labelIds' 225 | return '/{}/{}/{}.png'.format(mode, name, label_root) 226 | 227 | def image_root_transform(root: str, mode: str): 228 | name = root[0: root.find('_')] 229 | return '/leftImg8bit/{}/{}/{}.png'.format(mode, name, root), name 
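# Editor's sketch (illustrative, not used by the loaders above): the chain of
# np.isin calls in cityscapes_class_map can equivalently be written as a single
# lookup table over the raw Cityscapes label ids, assuming the mask only contains
# ids in the range 0..33.
def cityscapes_class_map_lut(mask: np.ndarray) -> np.ndarray:
    lut = np.full(34, 255, dtype=np.uint8)  # ignored ids map to 255
    train_ids = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
    for train_id, raw_id in enumerate(train_ids):
        lut[raw_id] = train_id
    return lut[mask]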
-------------------------------------------------------------------------------- /VOC_split/662/3407/labeled_filename.txt: -------------------------------------------------------------------------------- 1 | 2011_000068 2 | 2008_004588 3 | 2009_001664 4 | 2011_002350 5 | 2011_000895 6 | 2010_002733 7 | 2007_002055 8 | 2010_000043 9 | 2008_004080 10 | 2011_000790 11 | 2007_003431 12 | 2008_000162 13 | 2007_007591 14 | 2009_002083 15 | 2010_000132 16 | 2008_003947 17 | 2009_002422 18 | 2010_003174 19 | 2007_006832 20 | 2010_004478 21 | 2011_003078 22 | 2007_000768 23 | 2008_002255 24 | 2009_003200 25 | 2008_006032 26 | 2008_000860 27 | 2009_002713 28 | 2008_000217 29 | 2007_004951 30 | 2007_009436 31 | 2010_005932 32 | 2010_005927 33 | 2009_004643 34 | 2007_008526 35 | 2007_008948 36 | 2007_000068 37 | 2007_000333 38 | 2008_002073 39 | 2008_002247 40 | 2008_000348 41 | 2008_000361 42 | 2008_006655 43 | 2009_000100 44 | 2008_001592 45 | 2010_003062 46 | 2011_001922 47 | 2011_000145 48 | 2008_001876 49 | 2011_002752 50 | 2011_003216 51 | 2008_001498 52 | 2009_002472 53 | 2008_007357 54 | 2007_001602 55 | 2009_001744 56 | 2009_003539 57 | 2011_002503 58 | 2007_009665 59 | 2008_005214 60 | 2007_004476 61 | 2009_000887 62 | 2008_007691 63 | 2010_005627 64 | 2010_004948 65 | 2007_000241 66 | 2010_005318 67 | 2011_002398 68 | 2009_001095 69 | 2007_004500 70 | 2008_008550 71 | 2009_001516 72 | 2008_008343 73 | 2010_000746 74 | 2009_000176 75 | 2008_001106 76 | 2010_000748 77 | 2010_000371 78 | 2008_000645 79 | 2007_003715 80 | 2008_004607 81 | 2007_001698 82 | 2009_001640 83 | 2009_003087 84 | 2007_001340 85 | 2010_003384 86 | 2009_003711 87 | 2009_001894 88 | 2009_002281 89 | 2010_004620 90 | 2007_000549 91 | 2009_004990 92 | 2007_003118 93 | 2008_000832 94 | 2011_000105 95 | 2008_001479 96 | 2007_007908 97 | 2008_003939 98 | 2009_005194 99 | 2009_001203 100 | 2010_002039 101 | 2009_001514 102 | 2010_002499 103 | 2008_004259 104 | 2009_005234 105 | 2007_007387 106 | 2010_005098 107 | 2009_001937 108 | 2008_007201 109 | 2007_002368 110 | 2009_004249 111 | 2007_006254 112 | 2008_002066 113 | 2009_001124 114 | 2010_002047 115 | 2010_005775 116 | 2009_001783 117 | 2007_009788 118 | 2009_001117 119 | 2008_005839 120 | 2010_003250 121 | 2007_006303 122 | 2009_000028 123 | 2009_000250 124 | 2008_008462 125 | 2008_000711 126 | 2007_005262 127 | 2011_000152 128 | 2007_009807 129 | 2008_008545 130 | 2010_000847 131 | 2008_000259 132 | 2007_003191 133 | 2007_005360 134 | 2010_005506 135 | 2010_002838 136 | 2009_003146 137 | 2007_002142 138 | 2008_003415 139 | 2010_005223 140 | 2009_001197 141 | 2008_004112 142 | 2009_005056 143 | 2009_003799 144 | 2011_002224 145 | 2007_003604 146 | 2011_000457 147 | 2007_002212 148 | 2010_001282 149 | 2009_000987 150 | 2008_002032 151 | 2007_008927 152 | 2010_004109 153 | 2007_008994 154 | 2011_000228 155 | 2009_003519 156 | 2007_007230 157 | 2008_000422 158 | 2009_004301 159 | 2009_002153 160 | 2007_009832 161 | 2007_003876 162 | 2009_003768 163 | 2008_007142 164 | 2007_009554 165 | 2009_005000 166 | 2010_006009 167 | 2008_000436 168 | 2010_002413 169 | 2011_000345 170 | 2010_001154 171 | 2009_000347 172 | 2010_000661 173 | 2008_005716 174 | 2008_006873 175 | 2010_000187 176 | 2009_003075 177 | 2007_009724 178 | 2007_007585 179 | 2009_002314 180 | 2010_002794 181 | 2008_000764 182 | 2010_001561 183 | 2009_001096 184 | 2008_000778 185 | 2007_006136 186 | 2010_001261 187 | 2008_005698 188 | 2010_002107 189 | 2010_003380 190 | 2010_001630 191 | 2011_002291 192 | 2007_004291 
193 | 2009_001544 194 | 2010_005016 195 | 2007_002234 196 | 2011_001967 197 | 2009_001828 198 | 2010_003954 199 | 2008_006140 200 | 2010_000978 201 | 2008_005266 202 | 2008_007012 203 | 2011_001133 204 | 2008_000733 205 | 2010_004704 206 | 2007_000720 207 | 2010_000437 208 | 2010_003342 209 | 2011_001463 210 | 2009_000626 211 | 2008_005714 212 | 2009_005236 213 | 2008_006558 214 | 2010_001576 215 | 2009_001625 216 | 2009_001283 217 | 2010_005800 218 | 2009_001251 219 | 2009_003249 220 | 2009_001359 221 | 2008_000544 222 | 2009_004890 223 | 2009_001871 224 | 2011_000573 225 | 2010_002236 226 | 2007_009580 227 | 2010_002720 228 | 2011_002488 229 | 2008_001566 230 | 2009_001782 231 | 2010_000519 232 | 2008_006221 233 | 2007_005264 234 | 2009_003646 235 | 2008_000515 236 | 2010_005836 237 | 2010_002154 238 | 2011_000382 239 | 2008_007858 240 | 2010_003680 241 | 2008_005668 242 | 2007_001917 243 | 2008_001462 244 | 2008_000226 245 | 2010_004960 246 | 2009_004446 247 | 2011_001652 248 | 2008_002425 249 | 2008_000785 250 | 2007_002760 251 | 2010_000269 252 | 2008_002885 253 | 2011_002227 254 | 2007_003541 255 | 2007_009052 256 | 2007_007480 257 | 2009_000285 258 | 2008_000841 259 | 2010_000810 260 | 2008_001399 261 | 2011_003255 262 | 2007_002293 263 | 2009_003555 264 | 2010_000002 265 | 2009_001390 266 | 2009_002419 267 | 2011_002709 268 | 2008_001829 269 | 2008_003429 270 | 2009_002216 271 | 2007_000039 272 | 2009_001385 273 | 2009_003142 274 | 2008_005945 275 | 2009_003345 276 | 2007_007726 277 | 2010_005820 278 | 2010_004805 279 | 2010_004963 280 | 2010_003534 281 | 2011_002134 282 | 2007_006483 283 | 2008_007165 284 | 2011_001015 285 | 2011_001519 286 | 2009_001104 287 | 2008_007011 288 | 2009_001403 289 | 2007_008778 290 | 2009_003734 291 | 2009_003340 292 | 2010_004074 293 | 2010_003010 294 | 2009_002010 295 | 2010_000114 296 | 2007_009594 297 | 2011_003038 298 | 2009_002460 299 | 2010_004370 300 | 2010_000503 301 | 2011_002447 302 | 2007_009030 303 | 2011_001259 304 | 2010_004938 305 | 2010_004773 306 | 2009_002448 307 | 2011_001412 308 | 2010_005805 309 | 2011_002410 310 | 2007_009550 311 | 2009_000603 312 | 2009_002264 313 | 2009_004178 314 | 2009_004213 315 | 2010_005129 316 | 2011_000577 317 | 2011_000222 318 | 2010_000675 319 | 2009_002416 320 | 2009_002530 321 | 2008_000207 322 | 2008_004892 323 | 2008_007581 324 | 2010_001413 325 | 2008_004321 326 | 2010_005317 327 | 2007_001420 328 | 2009_000405 329 | 2008_008511 330 | 2011_000108 331 | 2008_006434 332 | 2007_009605 333 | 2010_004258 334 | 2007_001709 335 | 2007_008085 336 | 2009_004561 337 | 2008_003913 338 | 2010_004493 339 | 2007_002611 340 | 2007_009322 341 | 2009_002844 342 | 2007_000793 343 | 2010_000685 344 | 2010_002811 345 | 2010_005746 346 | 2008_003769 347 | 2010_005830 348 | 2011_001336 349 | 2011_001571 350 | 2010_001184 351 | 2008_005367 352 | 2008_000089 353 | 2008_006389 354 | 2011_001730 355 | 2010_001273 356 | 2010_000466 357 | 2010_005669 358 | 2007_007902 359 | 2007_001073 360 | 2007_008468 361 | 2008_007472 362 | 2008_006215 363 | 2011_003066 364 | 2009_001027 365 | 2008_007242 366 | 2011_003151 367 | 2007_001872 368 | 2009_005269 369 | 2010_005725 370 | 2011_000025 371 | 2011_000542 372 | 2007_002361 373 | 2007_000170 374 | 2008_004365 375 | 2011_000646 376 | 2007_007355 377 | 2008_001263 378 | 2007_000836 379 | 2010_004069 380 | 2009_004095 381 | 2010_000772 382 | 2009_002586 383 | 2010_000492 384 | 2009_002820 385 | 2011_001622 386 | 2010_002938 387 | 2011_002114 388 | 2011_001991 389 | 2008_008106 390 | 
2011_002111 391 | 2011_001270 392 | 2010_003696 393 | 2011_001904 394 | 2008_008773 395 | 2007_009597 396 | 2008_000336 397 | 2010_000588 398 | 2010_001457 399 | 2009_000774 400 | 2008_000033 401 | 2009_003933 402 | 2011_000840 403 | 2007_004707 404 | 2008_003068 405 | 2008_005345 406 | 2007_005878 407 | 2007_004289 408 | 2007_001834 409 | 2009_000906 410 | 2010_002937 411 | 2010_003097 412 | 2009_003090 413 | 2010_001043 414 | 2007_003205 415 | 2007_008203 416 | 2007_008140 417 | 2011_000713 418 | 2007_007481 419 | 2007_009139 420 | 2008_006920 421 | 2009_004368 422 | 2007_009889 423 | 2008_000870 424 | 2008_001112 425 | 2011_000469 426 | 2011_000882 427 | 2007_004627 428 | 2008_002200 429 | 2007_002896 430 | 2010_004144 431 | 2009_002409 432 | 2010_004916 433 | 2008_006065 434 | 2007_003778 435 | 2010_003093 436 | 2009_002988 437 | 2010_001944 438 | 2007_000645 439 | 2010_004060 440 | 2007_006281 441 | 2009_000133 442 | 2009_001146 443 | 2008_007433 444 | 2008_003814 445 | 2010_001514 446 | 2008_000676 447 | 2010_000148 448 | 2009_000684 449 | 2008_001159 450 | 2009_005085 451 | 2007_001960 452 | 2007_004768 453 | 2009_001311 454 | 2007_009759 455 | 2007_005130 456 | 2008_006345 457 | 2011_001928 458 | 2009_001100 459 | 2008_004416 460 | 2007_009649 461 | 2009_005055 462 | 2010_000195 463 | 2008_005926 464 | 2008_005770 465 | 2008_008770 466 | 2007_007447 467 | 2007_008945 468 | 2007_002120 469 | 2011_001027 470 | 2007_007891 471 | 2007_008571 472 | 2009_003636 473 | 2009_002262 474 | 2010_005663 475 | 2008_004838 476 | 2007_004537 477 | 2011_001810 478 | 2007_007890 479 | 2010_001590 480 | 2008_001118 481 | 2008_002894 482 | 2009_001964 483 | 2008_001408 484 | 2007_006151 485 | 2007_004481 486 | 2010_005758 487 | 2007_000584 488 | 2009_000662 489 | 2008_000495 490 | 2007_002845 491 | 2010_002455 492 | 2009_000938 493 | 2007_001225 494 | 2007_002967 495 | 2011_002050 496 | 2007_004166 497 | 2008_004430 498 | 2008_007245 499 | 2010_003345 500 | 2007_007930 501 | 2008_001413 502 | 2009_004374 503 | 2008_003087 504 | 2008_003562 505 | 2008_000273 506 | 2010_005119 507 | 2009_002984 508 | 2008_005294 509 | 2007_000876 510 | 2008_005678 511 | 2011_001139 512 | 2008_001358 513 | 2010_002935 514 | 2009_003865 515 | 2010_001279 516 | 2008_002218 517 | 2011_001198 518 | 2009_002932 519 | 2007_007523 520 | 2009_002245 521 | 2007_003788 522 | 2010_005202 523 | 2011_000182 524 | 2009_000895 525 | 2010_003230 526 | 2009_003353 527 | 2007_007098 528 | 2007_004810 529 | 2009_001388 530 | 2011_000999 531 | 2008_001169 532 | 2008_000696 533 | 2009_003690 534 | 2010_002697 535 | 2007_009630 536 | 2008_000365 537 | 2008_008476 538 | 2009_001443 539 | 2008_000588 540 | 2010_004361 541 | 2008_001387 542 | 2008_001208 543 | 2007_004705 544 | 2009_001755 545 | 2009_003613 546 | 2008_000760 547 | 2009_000544 548 | 2011_001135 549 | 2009_003039 550 | 2010_003717 551 | 2011_002834 552 | 2009_003961 553 | 2010_005232 554 | 2010_000815 555 | 2011_001790 556 | 2008_007239 557 | 2008_006481 558 | 2010_002786 559 | 2009_002019 560 | 2007_006212 561 | 2008_001523 562 | 2010_002054 563 | 2010_002254 564 | 2009_004620 565 | 2010_003088 566 | 2007_006400 567 | 2007_006803 568 | 2007_006615 569 | 2010_003799 570 | 2010_002379 571 | 2008_003585 572 | 2010_002363 573 | 2010_004933 574 | 2008_000584 575 | 2007_000032 576 | 2008_002210 577 | 2008_003252 578 | 2007_009899 579 | 2011_001974 580 | 2008_002215 581 | 2010_000131 582 | 2008_003168 583 | 2008_003083 584 | 2007_003251 585 | 2007_006530 586 | 2007_005227 587 | 
2009_000655 588 | 2007_005266 589 | 2007_002281 590 | 2010_001399 591 | 2008_001375 592 | 2008_003208 593 | 2007_005144 594 | 2008_005843 595 | 2010_000887 596 | 2007_003580 597 | 2011_002656 598 | 2010_004429 599 | 2011_000834 600 | 2010_000855 601 | 2010_005457 602 | 2010_002962 603 | 2009_005031 604 | 2009_004417 605 | 2008_003665 606 | 2007_001027 607 | 2008_000540 608 | 2011_001753 609 | 2008_001056 610 | 2008_000188 611 | 2007_002088 612 | 2011_000652 613 | 2007_000713 614 | 2009_001036 615 | 2008_002972 616 | 2008_003362 617 | 2007_005248 618 | 2011_000400 619 | 2009_002845 620 | 2010_005678 621 | 2008_000284 622 | 2011_000641 623 | 2010_005198 624 | 2009_001070 625 | 2007_007947 626 | 2010_005700 627 | 2007_005430 628 | 2007_001149 629 | 2007_006699 630 | 2009_001636 631 | 2007_000738 632 | 2008_006482 633 | 2007_007772 634 | 2009_000690 635 | 2007_002545 636 | 2008_008541 637 | 2008_001632 638 | 2009_001444 639 | 2011_002381 640 | 2010_005891 641 | 2008_000131 642 | 2011_001004 643 | 2008_003200 644 | 2011_002585 645 | 2011_000359 646 | 2007_002216 647 | 2011_001754 648 | 2009_001690 649 | 2009_003034 650 | 2009_003317 651 | 2009_002628 652 | 2007_007948 653 | 2007_006004 654 | 2009_003783 655 | 2007_007003 656 | 2011_002300 657 | 2009_000532 658 | 2008_008263 659 | 2008_002221 660 | 2009_001802 661 | 2007_009216 662 | 2011_001475 -------------------------------------------------------------------------------- /prcl_sig.py: -------------------------------------------------------------------------------- 1 | import shutup 2 | shutup.please() 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from generalframeworks.dataset_helpers.VOC import VOC_BuildData 7 | from generalframeworks.dataset_helpers.Cityscapes import City_BuildData 8 | from generalframeworks.networks.ddp_model import Model_with_un_single 9 | from generalframeworks.loss.loss import Attention_Threshold_Loss, Prcl_Loss_single 10 | from generalframeworks.scheduler.my_lr_scheduler import PolyLR 11 | from generalframeworks.scheduler.rampscheduler import RampdownScheduler 12 | from generalframeworks.utils import iterator_ 13 | from generalframeworks.util.meter import * 14 | from generalframeworks.utils import label_onehot 15 | from generalframeworks.util.torch_dist_sum import * 16 | from generalframeworks.util.miou import * 17 | import yaml 18 | import os 19 | import time 20 | import torchvision.models as models 21 | from generalframeworks.networks import resnet 22 | import argparse 23 | import random 24 | 25 | def main(args): 26 | ##### Config init ##### 27 | with open(args.config, 'r') as f: 28 | config = yaml.load(f.read(), Loader=yaml.FullLoader) 29 | save_dir = './checkpoints/' + str(args.job_name) 30 | if not os.path.exists(save_dir): 31 | os.makedirs(save_dir) 32 | with open(save_dir + '/config.yaml', 'w') as f: 33 | yaml.dump(config, f, default_flow_style=False) 34 | print(config) 35 | 36 | ##### Init Seed ##### 37 | random.seed(config['Seed']) 38 | torch.manual_seed(config['Seed']) 39 | torch.backends.cudnn.deterministic = True 40 | 41 | ##### Load the dataset ##### 42 | if config['Dataset']['name'] == 'VOC': 43 | data = VOC_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 44 | label_num=config['Dataset']['num_labels'], seed=config['Seed']) 45 | if config['Dataset']['name'] == 'CityScapes': 46 | data = City_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 47 | label_num=config['Dataset']['num_labels'], 
seed=config['Seed']) 48 | train_l_dataset, train_u_dataset, test_dataset = data.build() 49 | train_l_sampler = torch.utils.data.RandomSampler(train_l_dataset) 50 | train_l_loader = torch.utils.data.DataLoader(train_l_dataset, 51 | batch_size=config['Dataset']['batch_size'], 52 | pin_memory=True, 53 | sampler=train_l_sampler, 54 | num_workers=4) 55 | train_u_sampler = torch.utils.data.RandomSampler(train_u_dataset) 56 | train_u_loader = torch.utils.data.DataLoader(train_u_dataset, 57 | batch_size=config['Dataset']['batch_size'], 58 | pin_memory=True, 59 | sampler=train_u_sampler, 60 | num_workers=4) 61 | test_loader = torch.utils.data.DataLoader(test_dataset, 62 | batch_size=config['Dataset']['batch_size'], 63 | pin_memory=True, 64 | num_workers=4) 65 | 66 | ##### Model init ##### 67 | backbone = models.resnet101() 68 | ckpt = torch.load('./pretrained/resnet101.pth', map_location='cpu') 69 | backbone.load_state_dict(ckpt) 70 | 71 | # for Resnet-101 stem users 72 | #backbone = resnet.resnet101(pretrained=True) 73 | 74 | model = Model_with_un_single(backbone, num_classes=config['Network']['num_class'], output_dim=256, ema_alpha=config['EMA']['alpha'], config=config).cuda() 75 | 76 | ##### Loss init ##### 77 | criterion = {'ce_loss': nn.CrossEntropyLoss(ignore_index=-1).cuda(), 78 | 'unsup_loss': Attention_Threshold_Loss(strong_threshold=config['Prcl_Loss']['un_threshold']).cuda(), 79 | 'prcl_loss': Prcl_Loss_single(strong_threshold=config['Prcl_Loss']['strong_threshold'], 80 | num_queries=config['Prcl_Loss']['num_queries'], 81 | num_negatives=config['Prcl_Loss']['num_negatives'], 82 | temp=config['Prcl_Loss']['temp']).cuda() 83 | } 84 | 85 | ##### Other init ##### 86 | optimizer = torch.optim.SGD(model.model.parameters(), lr=float(config['Optim']['lr']), weight_decay=float(config['Optim']['weight_decay']), 87 | momentum=0.9, nesterov=True) 88 | optimizer_uncer = torch.optim.SGD(model.uncer_head.parameters(), lr=float(config['Optim']['uncer_lr']), weight_decay=float(config['Optim']['weight_decay']), 89 | momentum=0.9, nesterov=True) 90 | total_epoch = config['Training_Setting']['epoch'] 91 | lr_scheduler = PolyLR(optimizer, total_epoch) 92 | lr_scheduler_uncer = PolyLR(optimizer_uncer, total_epoch) 93 | 94 | if os.path.exists(args.resume): 95 | print('resume from', args.resume) 96 | checkpoint = torch.load(args.resume, map_location='cpu') 97 | model.model.load_state_dict(checkpoint['model']) 98 | model.ema_model.load_state_dict(checkpoint['ema']) 99 | model.uncer_head.load_state_dict(checkpoint['uncer_head']) 100 | optimizer.load_state_dict(checkpoint['optimizer']) 101 | optimizer_uncer.load_state_dict(checkpoint['optimizer_uncer']) 102 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 103 | lr_scheduler_uncer.load_state_dict(checkpoint['lr_scheduler_uncer']) 104 | start_epoch = checkpoint['epoch'] 105 | else: 106 | start_epoch = 0 107 | sche_d = RampdownScheduler(begin_epoch=config['Ramp_Scheduler']['begin_epoch'], 108 | max_epoch=config['Ramp_Scheduler']['max_epoch'], 109 | current_epoch=start_epoch, 110 | max_value=config['Ramp_Scheduler']['max_value'], 111 | min_value=config['Ramp_Scheduler']['min_value'], 112 | ramp_mult=config['Ramp_Scheduler']['ramp_mult']) 113 | 114 | 115 | best_miou = 0 116 | 117 | model.model.train() 118 | model.ema_model.train() 119 | model.uncer_head.train() 120 | for epoch in range(start_epoch, total_epoch): 121 | train(train_l_loader, train_u_loader, model, optimizer, optimizer_uncer, criterion, epoch, lr_scheduler, lr_scheduler_uncer, sche_d, config) 
122 |         miou = test(test_loader, model.ema_model, config)
123 |         best_miou = max(best_miou, miou)
124 |         print('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}'.format(epoch, miou, best_miou, time.asctime(time.localtime(time.time()))))
125 |         # Save model
126 |         if miou == best_miou:
127 |             save_dir = './checkpoints/' + str(args.job_name)
128 |             torch.save(
129 |                 {
130 |                     'epoch': epoch + 1,
131 |                     'ema': model.ema_model.state_dict(),
132 |                     'model': model.model.state_dict(),
133 |                     'uncer_head': model.uncer_head.state_dict(),
134 |                     'optimizer': optimizer.state_dict(),
135 |                     'optimizer_uncer': optimizer_uncer.state_dict(),
136 |                     'lr_scheduler': lr_scheduler.state_dict(),
137 |                     'lr_scheduler_uncer': lr_scheduler_uncer.state_dict(),
138 |                 }, os.path.join(save_dir, 'best_model.pth'))
139 |
140 |
141 |
142 | def train(train_l_loader, train_u_loader, model, optimizer, optimizer_uncer, criterion, epoch, scheduler, scheduler_uncer, sche_d, config):
143 |     batch_time = AverageMeter('Time', ':6.3f')
144 |     data_time = AverageMeter('Data', ':6.3f')
145 |     sup_loss_meter = AverageMeter('Sup_loss', ':6.3f')
146 |     unsup_loss_meter = AverageMeter('Unsup_loss', ':6.3f')
147 |     contr_loss_meter = AverageMeter('Contr_loss', ':6.3f')
148 |     num_class = config['Network']['num_class']
149 |     mious_conf_l = ConfMatrix(num_classes=num_class, fmt=':6.3f', name='l_miou')
150 |     mious_conf_u = ConfMatrix(num_classes=num_class, fmt=':6.3f', name='u_miou')
151 |     iter_num = int(2000 / config['Dataset']['batch_size'] / len(train_l_loader))  # ~2000 images per epoch
152 |     progress = ProgressMeter(
153 |         iter_num,
154 |         [batch_time, data_time, sup_loss_meter, unsup_loss_meter, contr_loss_meter, mious_conf_l, mious_conf_u],
155 |         prefix='Epoch: [{}]'.format(epoch)
156 |     )
157 |     # switch to train mode
158 |     model.model.train()
159 |     model.ema_model.train()
160 |     model.uncer_head.train()
161 |
162 |     end = time.time()
163 |     for iter_i in range(iter_num):
164 |         training_u_iter = iterator_(train_u_loader)
165 |         for i, (train_l_image, train_l_label) in enumerate(train_l_loader):
166 |             data_time.update(time.time() - end)
167 |             train_l_image, train_l_label = train_l_image.cuda(), train_l_label.cuda()
168 |             train_u_image, train_u_label = training_u_iter.__next__()
169 |             train_u_image, train_u_label = train_u_image.cuda(), train_u_label.cuda()
170 |             pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw, uncer_all = model(train_l_image, train_u_image)
171 |
172 |             sup_loss = criterion['ce_loss'](pred_l_large, train_l_label)
173 |             unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label, train_u_aug_logits)
174 |
175 |             ##### Contrastive learning #####
176 |             with torch.no_grad():
177 |                 train_u_aug_mask = train_u_aug_logits.ge(config['Prcl_Loss']['weak_threshold']).float()
178 |                 mask_all = torch.cat(((train_l_label.unsqueeze(1) >= 0).float(), train_u_aug_mask.unsqueeze(1)))
179 |                 mask_all = F.interpolate(mask_all, size=pred_all.shape[2:], mode='nearest')
180 |
181 |                 label_l = F.interpolate(label_onehot(train_l_label, num_class), size=pred_all.shape[2:], mode='nearest')
182 |                 label_u = F.interpolate(label_onehot(train_u_aug_label, num_class), size=pred_all.shape[2:], mode='nearest')
183 |                 label_all = torch.cat((label_l, label_u))
184 |
185 |                 prob_all = torch.softmax(pred_all, dim=1)
186 |
187 |             prcl_loss = criterion['prcl_loss'](rep_all, uncer_all, label_all, mask_all, prob_all)
188 |             total_loss = sup_loss + unsup_loss + prcl_loss * sche_d.value
189 |
190 |             # Update meters
191 |             sup_loss_meter.update(sup_loss.item(), pred_all.shape[0])
192 |             unsup_loss_meter.update(unsup_loss.item(), pred_all.shape[0])
193 |             mious_conf_l.update(pred_l_large.argmax(1).flatten(), train_l_label.flatten())
194 |             mious_conf_u.update(pred_u_large_raw.argmax(1).flatten(), train_u_label.flatten())
195 |             contr_loss_meter.update(prcl_loss.item(), pred_all.shape[0])
196 |             optimizer.zero_grad()
197 |             optimizer_uncer.zero_grad()
198 |             total_loss.backward()
199 |             optimizer.step()
200 |             optimizer_uncer.step()
201 |             model.ema_update()
202 |             batch_time.update(time.time() - end)
203 |             end = time.time()
204 |     scheduler.step()
205 |     scheduler_uncer.step()
206 |     sche_d.step()
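# Editor's note (illustrative): train() above draws roughly 2000 images per epoch,
# so the number of outer passes over the labeled loader is
# iter_num = int(2000 / batch_size / len(train_l_loader)). A worked example with
# assumed values (batch_size=8 and the 662-label VOC split):
def _demo_iter_num(batch_size=8, num_labeled=662):
    import math
    len_train_l_loader = math.ceil(num_labeled / batch_size)  # 83 batches
    return int(2000 / batch_size / len_train_l_loader)        # -> 3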
207 |
208 | @torch.no_grad()
209 | def test(test_loader, model, config):
210 |     batch_time = AverageMeter('Time', ':6.3f')
211 |     data_time = AverageMeter('Data', ':6.3f')
212 |     miou_meter = ConfMatrix(num_classes=config['Network']['num_class'], fmt=':6.4f', name='test_miou')
213 |
214 |     # switch to eval mode
215 |     model.eval()
216 |
217 |     end = time.time()
218 |     test_iter = iter(test_loader)
219 |     for _ in range(len(test_loader)):
220 |         data_time.update(time.time() - end)
221 |         test_image, test_label = next(test_iter)  # Python 3 iterator protocol; the .next() method was removed
222 |         test_image, test_label = test_image.cuda(), test_label.cuda()
223 |
224 |         pred, _, _ = model(test_image)
225 |         pred = F.interpolate(pred, size=test_label.shape[1:], mode='bilinear', align_corners=True)
226 |
227 |         miou_meter.update(pred.argmax(1).flatten(), test_label.flatten())
228 |         batch_time.update(time.time() - end)
229 |         end = time.time()
230 |
231 |     miou = mean_intersection_over_union(miou_meter.mat)
232 |
233 |     return miou
234 |
235 |
236 | if __name__ == '__main__':
237 |     parser = argparse.ArgumentParser()
238 |     parser.add_argument('--config', type=str, default='')
239 |     parser.add_argument('--resume', type=str, default='')
240 |     parser.add_argument('--job_name', type=str, default='')
241 |
242 |
243 |     args = parser.parse_args()
244 |     main(args)
245 |
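# Example invocation (illustrative; paths depend on your setup):
#   python prcl_sig.py --config config/VOC_prcl_config_662.yaml --job_name voc_662_run1
# Checkpoints are written to ./checkpoints/<job_name>/best_model.pth, and an
# interrupted run can be resumed with --resume ./checkpoints/<job_name>/best_model.pth.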
-------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/VOC.py: --------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torch
3 | import os
4 | import torchvision.transforms as transforms
5 | import torchvision.transforms.functional as transforms_f
6 | import random
7 | from PIL import Image, ImageFilter
8 | import numpy as np
9 | import torch.distributed as dist
10 |
11 | class Pascal_VOC_Dataset(data.Dataset):
12 |     def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True):
13 |         self.root = os.path.expanduser(root)
14 |         self.train = train
15 |         self.crop_size = crop_size
16 |         self.augmentation = augmentation
17 |         self.scale_size = scale_size
18 |         self.idx_list = idx_list
19 |
20 |     def __getitem__(self, index):
21 |         image_root = Image.open(self.root + '/JPEGImages/{}.jpg'.format(self.idx_list[index]))
22 |         label_root = Image.open(self.root + '/SegmentationClassAug/{}.png'.format(self.idx_list[index]))
23 |         image, label = transform(image_root, label_root, None, crop_size=self.crop_size, scale_size=self.scale_size, augmentation=self.augmentation)
24 |         return image, label.squeeze(0)
25 |
26 |     def __len__(self):
27 |         return len(self.idx_list)
28 |
29 | class VOC_BuildData():
30 |     def __init__(self, data_path, txt_path, label_num, seed):
31 |         self.data_path = data_path
32 |         self.txt_path = txt_path
33 |         self.image_size = [513, 513]
34 |         self.crop_size = [321, 321]
35 |         self.num_segments = 21
36 |         self.scale_size = (0.5, 1.5)
37 |         self.train_l_idx, self.train_u_idx, self.test_idx = get_pascal_idx_via_txt(self.txt_path, label_num=label_num, seed=seed)
38 |
39 |     def build(self):
40 |         train_l_dataset = Pascal_VOC_Dataset(self.data_path, self.train_l_idx, self.crop_size, self.scale_size,
41 |                                              augmentation=True, train=True)
42 |         train_u_dataset = Pascal_VOC_Dataset(self.data_path, self.train_u_idx, self.crop_size, scale_size=(1.0, 1.0),
43 |                                              augmentation=False, train=True)
44 |         test_dataset = Pascal_VOC_Dataset(self.data_path, self.test_idx, self.crop_size, scale_size=(1.0, 1.0), augmentation=False,
45 |                                           train=False)
46 |         return train_l_dataset, train_u_dataset, test_dataset
47 |
48 | def get_pascal_idx_via_txt(root, label_num, seed):
49 |     '''
50 |     Read the index lists from the generated txt files (run make_list.py beforehand to create them).
51 |     '''
52 |     root = root + '/' + str(label_num) + '/' + str(seed)
53 |     with open(root + '/labeled_filename.txt') as f:
54 |         labeled_list = f.read().splitlines()
55 |     f.close()
56 |     with open(root + '/unlabeled_filename.txt') as f:
57 |         unlabeled_list = f.read().splitlines()
58 |     f.close()
59 |     with open(root + '/valid_filename.txt') as f:
60 |         test_list = f.read().splitlines()
61 |     f.close()
62 |     return labeled_list, unlabeled_list, test_list
63 |
64 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True):
65 |     # Randomly rescale images
66 |     raw_w, raw_h = image.size
67 |     scale_ratio = random.uniform(scale_size[0], scale_size[1])
68 |
69 |     resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio))
70 |     image = transforms_f.resize(image, resized_size, Image.BILINEAR)
71 |     label = transforms_f.resize(label, resized_size, Image.NEAREST)
72 |     if logits is not None:
73 |         logits = transforms_f.resize(logits, resized_size, Image.NEAREST)
74 |
75 |     # Add padding if the rescaled image is smaller than the crop size
76 |     if crop_size == -1:  # Use original image size
77 |         crop_size = (raw_w, raw_h)
78 |
79 |     if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]:
80 |         right_pad = max(crop_size[1] - resized_size[1], 0)
81 |         bottom_pad = max(crop_size[0] - resized_size[0], 0)
82 |         image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect')
83 |         label = transforms_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant')
84 |         if logits is not None:
85 |             logits = transforms_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant')
86 |
87 |     # Randomly crop images
88 |     i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size)
89 |     image = transforms_f.crop(image, i, j, h, w)
90 |     label = transforms_f.crop(label, i, j, h, w)
91 |     if logits is not None:
92 |         logits = transforms_f.crop(logits, i, j, h, w)
93 |
94 |     if augmentation:
95 |         # Random color jittering
96 |         if torch.rand(1) > 0.2:
97 |             color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25))
98 |             image = color_transform(image)
99 |
100 |         # Random Gaussian blurring
101 |         if torch.rand(1) > 0.5:
102 |             sigma = random.uniform(0.15, 1.15)
103 |             image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
104 |
105 |         # Random horizontal flipping
106 |         if torch.rand(1) > 0.5:
107 |             image = transforms_f.hflip(image)
108 |             label = transforms_f.hflip(label)
109 |             if logits is not None:
110 |                 logits = transforms_f.hflip(logits)
111 |
112 |     # Transform to Tensor
113 |     image = transforms_f.to_tensor(image)
114 |     label = 
(transforms_f.to_tensor(label) * 255).long() 115 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 116 | if logits is not None: 117 | logits = transforms_f.to_tensor(logits) 118 | 119 | # Apply ImageNet normalization 120 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 121 | if logits is not None: 122 | return image, label, logits 123 | else: 124 | return image, label 125 | 126 | def tensor_to_pil(image, label, logits): 127 | image = denormalise(image) 128 | image = transforms_f.to_pil_image(image.cpu()) 129 | label = label.float() / 255. 130 | label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 131 | logits = transforms_f.to_pil_image(logits.unsqueeze(0).cpu()) 132 | return image, label, logits 133 | 134 | def denormalise(x, imagenet=True): 135 | if imagenet: 136 | x = transforms_f.normalize(x, mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]) 137 | x = transforms_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 138 | return x 139 | else: 140 | return (x + 1) / 2 141 | 142 | def batch_transform(images, labels, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 143 | image_list, label_list, logits_list = [], [], [] 144 | device = images.device 145 | for k in range(images.shape[0]): 146 | image_pil, label_pil, logits_pil = tensor_to_pil(images[k], labels[k], logits[k]) 147 | aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil, crop_size, scale_size, augmentation) 148 | image_list.append(aug_image.unsqueeze(0)) 149 | label_list.append(aug_label) 150 | logits_list.append(aug_logits) 151 | 152 | image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), torch.cat(logits_list).to(device) 153 | return image_trans, label_trans, logits_trans 154 | 155 | def generate_cut_gather(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'): 156 | 157 | batch_size, _, image_h, image_w = image.shape 158 | image = concat_all_gather(image) 159 | label = concat_all_gather(label) 160 | logits = concat_all_gather(logits) 161 | total_size = image.shape[0] 162 | device = image.device 163 | rank = dist.get_rank() 164 | 165 | if mode == 'none': 166 | return image[rank * batch_size: (rank + 1) * batch_size], label[rank * batch_size: (rank + 1) * batch_size].long(), logits[rank * batch_size: (rank + 1) * batch_size] 167 | 168 | new_image = [] 169 | new_label = [] 170 | new_logits = [] 171 | for i in range(total_size): 172 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 173 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 174 | label[i][(1 - mix_mask).bool()] = -1 175 | 176 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 177 | new_label.append(label[i].unsqueeze(0)) 178 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 179 | continue 180 | elif mode == 'cutmix': 181 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 182 | elif mode == 'classmix': 183 | mix_mask = generate_class_mask(label[i]).to(device) 184 | else: 185 | raise ValueError('mode must be in cutout, cutmix, or classmix') 186 | 187 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 188 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 189 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - 
mix_mask)).unsqueeze(0))
190 |     new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits)
191 |
192 |     return new_image[rank * batch_size: (rank + 1) * batch_size], new_label[rank * batch_size: (rank + 1) * batch_size].long(), new_logits[rank * batch_size: (rank + 1) * batch_size]
193 |
194 | def generate_cut(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'):
195 |     if mode == 'none':
196 |         return image, label.long(), logits
197 |     batch_size, _, image_h, image_w = image.shape
198 |     device = image.device
199 |
200 |     new_image = []
201 |     new_label = []
202 |     new_logits = []
203 |     for i in range(batch_size):
204 |         if mode == 'cutout':  # label: generated region is masked by -1, image: generated region is masked by 0
205 |             mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device)
206 |             label[i][(1 - mix_mask).bool()] = -1
207 |
208 |             new_image.append((image[i] * mix_mask).unsqueeze(0))
209 |             new_label.append(label[i].unsqueeze(0))
210 |             new_logits.append((logits[i] * mix_mask).unsqueeze(0))
211 |             continue
212 |         elif mode == 'cutmix':
213 |             mix_mask = generate_cutout_mask([image_h, image_w]).to(device)
214 |         elif mode == 'classmix':
215 |             mix_mask = generate_class_mask(label[i]).to(device)
216 |         else:
217 |             raise ValueError('mode must be in cutout, cutmix, or classmix')
218 |
219 |         new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
220 |         new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
221 |         new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
222 |     new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits)
223 |
224 |     return new_image, new_label.long(), new_logits
225 |
226 | def generate_class_mask(pseudo_labels: torch.Tensor):
227 |     # Randomly select half of the classes and keep only their pixels
228 |     labels = torch.unique(pseudo_labels)  # all unique labels
229 |     labels_select: torch.Tensor = labels[torch.randperm(len(labels))][:len(labels) // 2]  # randomly select half of the labels
230 |     mask = (pseudo_labels.unsqueeze(-1) == labels_select).any(dim=-1)
231 |     return mask.float()
232 |
233 | def generate_cutout_mask(image_size, ratio=2):
234 |     # Cutout: randomly generate a mask whose region inside the box is 0 and outside is 1
235 |     cutout_area = image_size[0] * image_size[1] / ratio
236 |
237 |     w = np.random.randint(image_size[1] / ratio + 1, image_size[1])
238 |     h = np.round(cutout_area / w)
239 |
240 |     x_start = np.random.randint(0, image_size[1] - w + 1)
241 |     y_start = np.random.randint(0, image_size[0] - h + 1)
242 |
243 |     x_end = int(x_start + w)
244 |     y_end = int(y_start + h)
245 |
246 |     mask = torch.ones(image_size)
247 |     mask[y_start: y_end, x_start: x_end] = 0
248 |
249 |     return mask.float()
250 |
251 | @torch.no_grad()
252 | def concat_all_gather(tensor):
253 |     """
254 |     Performs an all_gather operation on the provided tensor.
255 |     Warning: torch.distributed.all_gather has no gradient.
256 |     """
257 |     tensor_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
258 |     torch.distributed.all_gather(tensor_gather, tensor, async_op=False)
259 |     output = torch.cat(tensor_gather, dim=0)
260 |
261 |     return output
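# Editor's sketch (illustrative): generate_class_mask keeps a random half of the
# classes present in a pseudo-label map; this is the mask that the 'classmix' mode
# above mixes with. The toy label map below is made up for illustration.
def _demo_class_mask():
    pseudo = torch.tensor([[0, 1], [2, 3]])
    mask = generate_class_mask(pseudo)  # 1.0 where a selected class sits, else 0.0
    assert mask.shape == pseudo.shape
    assert set(mask.unique().tolist()) <= {0.0, 1.0}
    return mask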
256 | """ 257 | tensor_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 258 | torch.distributed.all_gather(tensor_gather, tensor, async_op=False) 259 | output = torch.cat(tensor_gather, dim=0) 260 | 261 | return output -------------------------------------------------------------------------------- /generalframeworks/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = [ 6 | "ResNet", 7 | "resnet18", 8 | "resnet34", 9 | "resnet50", 10 | "resnet101", 11 | "resnet152", 12 | ] 13 | 14 | 15 | model_urls = { 16 | "resnet18": "/path/to/resnet18.pth", 17 | "resnet34": "/path/to/resnet34.pth", 18 | "resnet50": "/path/to/resnet50.pth", 19 | "resnet101": "path/to/resnet101.pth", 20 | "resnet152": "/path/to/resnet152.pth", 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d( 27 | in_planes, 28 | out_planes, 29 | kernel_size=3, 30 | stride=stride, 31 | padding=dilation, 32 | groups=groups, 33 | bias=False, 34 | dilation=dilation, 35 | ) 36 | 37 | 38 | def conv1x1(in_planes, out_planes, stride=1): 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__( 47 | self, 48 | inplanes, 49 | planes, 50 | stride=1, 51 | downsample=None, 52 | groups=1, 53 | base_width=64, 54 | dilation=1, 55 | norm_layer=None, 56 | ): 57 | super(BasicBlock, self).__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 62 | if dilation > 1: 63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = norm_layer(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = norm_layer(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | expansion = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes, 98 | planes, 99 | stride=1, 100 | downsample=None, 101 | groups=1, 102 | base_width=64, 103 | dilation=1, 104 | norm_layer=nn.BatchNorm2d, 105 | ): 106 | super(Bottleneck, self).__init__() 107 | width = int(planes * (base_width / 64.0)) * groups 108 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 109 | self.conv1 = conv1x1(inplanes, width) 110 | self.bn1 = norm_layer(width) 111 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 112 | self.bn2 = norm_layer(width) 113 | self.conv3 = conv1x1(width, planes * self.expansion) 114 | self.bn3 = norm_layer(planes * self.expansion) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.downsample = downsample 117 | self.stride = stride 118 | 119 | def forward(self, x): 120 | identity = x 121 | 122 | out = self.conv1(x) 123 | out = 
self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv3(out) 131 | out = self.bn3(out) 132 | 133 | if self.downsample is not None: 134 | identity = self.downsample(x) 135 | 136 | out += identity 137 | out = self.relu(out) 138 | 139 | return out 140 | 141 | 142 | class ResNet_Stem(nn.Module): 143 | def __init__( 144 | self, 145 | block, 146 | layers, 147 | zero_init_residual=True, 148 | groups=1, 149 | width_per_group=64, 150 | replace_stride_with_dilation=[False, True, True], 151 | multi_grid=True, 152 | fpn=True, 153 | ): 154 | super(ResNet_Stem, self).__init__() 155 | 156 | # norm_layer = 157 | norm_layer = nn.BatchNorm2d 158 | self._norm_layer = norm_layer 159 | 160 | self.inplanes = 128 161 | self.dilation = 1 162 | 163 | if replace_stride_with_dilation is None: 164 | # each element in the tuple indicates if we should replace 165 | # the 2x2 stride with a dilated convolution instead 166 | replace_stride_with_dilation = [False, False, False] 167 | 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError( 170 | "replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 172 | ) 173 | 174 | self.groups = groups 175 | self.base_width = width_per_group 176 | self.fpn = fpn 177 | self.conv1 = nn.Sequential( 178 | conv3x3(3, 64, stride=2), 179 | norm_layer(64), 180 | nn.ReLU(inplace=True), 181 | conv3x3(64, 64), 182 | norm_layer(64), 183 | nn.ReLU(inplace=True), 184 | conv3x3(64, self.inplanes), 185 | ) 186 | self.bn1 = norm_layer(self.inplanes) 187 | self.relu = nn.ReLU(inplace=True) 188 | self.maxpool = nn.MaxPool2d( 189 | kernel_size=3, stride=2, padding=1, ceil_mode=True 190 | ) # change 191 | 192 | self.layer1 = self._make_layer(block, 64, layers[0]) 193 | self.layer2 = self._make_layer( 194 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 195 | ) 196 | self.layer3 = self._make_layer( 197 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 198 | ) 199 | self.layer4 = self._make_layer( 200 | block, 201 | 512, 202 | layers[3], 203 | stride=2, 204 | dilate=replace_stride_with_dilation[2], 205 | multi_grid=multi_grid, 206 | ) 207 | 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv2d): 210 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 211 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): 212 | nn.init.constant_(m.weight, 1) 213 | nn.init.constant_(m.bias, 0) 214 | 215 | # Zero-initialize the last BN in each residual branch, 216 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
217 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 218 | if zero_init_residual: 219 | for m in self.modules(): 220 | if isinstance(m, Bottleneck): 221 | nn.init.constant_(m.bn3.weight, 0) 222 | elif isinstance(m, BasicBlock): 223 | nn.init.constant_(m.bn2.weight, 0) 224 | 225 | def get_outplanes(self): 226 | return self.inplanes 227 | 228 | def get_auxplanes(self): 229 | return self.inplanes // 2 230 | 231 | def _make_layer( 232 | self, block, planes, blocks, stride=1, dilate=False, multi_grid=False 233 | ): 234 | norm_layer = self._norm_layer 235 | downsample = None 236 | previous_dilation = self.dilation 237 | if dilate: 238 | self.dilation *= stride 239 | stride = 1 240 | if stride != 1 or self.inplanes != planes * block.expansion: 241 | downsample = nn.Sequential( 242 | conv1x1(self.inplanes, planes * block.expansion, stride), 243 | norm_layer(planes * block.expansion), 244 | ) 245 | 246 | grids = [1] * blocks 247 | if multi_grid: 248 | grids = [2, 2, 4] 249 | 250 | layers = [] 251 | layers.append( 252 | block( 253 | self.inplanes, 254 | planes, 255 | stride, 256 | downsample, 257 | self.groups, 258 | self.base_width, 259 | previous_dilation * grids[0], 260 | norm_layer, 261 | ) 262 | ) 263 | self.inplanes = planes * block.expansion 264 | for i in range(1, blocks): 265 | layers.append( 266 | block( 267 | self.inplanes, 268 | planes, 269 | groups=self.groups, 270 | base_width=self.base_width, 271 | dilation=self.dilation * grids[i], 272 | norm_layer=norm_layer, 273 | ) 274 | ) 275 | 276 | return nn.Sequential(*layers) 277 | 278 | def forward(self, x): 279 | x = self.relu(self.bn1(self.conv1(x))) 280 | x = self.maxpool(x) 281 | 282 | x = self.layer1(x) 283 | x1 = x 284 | x = self.layer2(x) 285 | x2 = x 286 | x3 = self.layer3(x) 287 | x4 = self.layer4(x3) 288 | if self.fpn: 289 | return [x1, x2, x3, x4] 290 | else: 291 | return [x3, x4] 292 | 293 | 294 | 295 | def resnet18(pretrained=False, **kwargs): 296 | """Constructs a ResNet-18 model. 297 | 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | """ 301 | model = ResNet_Stem(BasicBlock, [2, 2, 2, 2], **kwargs) 302 | if pretrained: 303 | model_url = model_urls["resnet18"] 304 | state_dict = torch.load(model_url) 305 | 306 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 307 | print( 308 | f"[Info] Load ImageNet pretrain from '{model_url}'", 309 | "\nmissing_keys: ", 310 | missing_keys, 311 | "\nunexpected_keys: ", 312 | unexpected_keys, 313 | ) 314 | return model 315 | 316 | 317 | def resnet34(pretrained=False, **kwargs): 318 | """Constructs a ResNet-34 model. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | """ 323 | model = ResNet_Stem(BasicBlock, [3, 4, 6, 3], **kwargs) 324 | if pretrained: 325 | model_url = model_urls["resnet34"] 326 | state_dict = torch.load(model_url) 327 | 328 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 329 | print( 330 | f"[Info] Load ImageNet pretrain from '{model_url}'", 331 | "\nmissing_keys: ", 332 | missing_keys, 333 | "\nunexpected_keys: ", 334 | unexpected_keys, 335 | ) 336 | return model 337 | 338 | 339 | def resnet50(pretrained=True, **kwargs): 340 | """Constructs a ResNet-50 model. 
341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | """ 345 | model = ResNet_Stem(Bottleneck, [3, 4, 6, 3], **kwargs) 346 | if pretrained: 347 | model_url = model_urls["resnet50"] 348 | state_dict = torch.load(model_url) 349 | 350 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 351 | print( 352 | f"[Info] Load ImageNet pretrain from '{model_url}'", 353 | "\nmissing_keys: ", 354 | missing_keys, 355 | "\nunexpected_keys: ", 356 | unexpected_keys, 357 | ) 358 | return model 359 | 360 | 361 | def resnet101(pretrained=True, **kwargs): 362 | """Constructs a ResNet-101 model. 363 | 364 | Args: 365 | pretrained (bool): If True, returns a model pre-trained on ImageNet 366 | """ 367 | model = ResNet_Stem(Bottleneck, [3, 4, 23, 3], **kwargs) 368 | if pretrained: 369 | model_url = model_urls["resnet101"] 370 | state_dict = torch.load(model_url) 371 | 372 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 373 | print( 374 | f"[Info] Load ImageNet pretrain from '{model_url}'", 375 | "\nmissing_keys: ", 376 | missing_keys, 377 | "\nunexpected_keys: ", 378 | unexpected_keys, 379 | ) 380 | return model 381 | 382 | 383 | def resnet152(pretrained=True, **kwargs): 384 | """Constructs a ResNet-152 model. 385 | 386 | Args: 387 | pretrained (bool): If True, returns a model pre-trained on ImageNet 388 | """ 389 | model = ResNet_Stem(Bottleneck, [3, 8, 36, 3], **kwargs) 390 | if pretrained: 391 | model_url = model_urls["resnet152"] 392 | state_dict = torch.load(model_url) 393 | 394 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 395 | print( 396 | f"[Info] Load ImageNet pretrain from '{model_url}'", 397 | "\nmissing_keys: ", 398 | missing_keys, 399 | "\nunexpected_keys: ", 400 | unexpected_keys, 401 | ) 402 | return model 403 | -------------------------------------------------------------------------------- /generalframeworks/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from generalframeworks.utils import simplex 6 | from generalframeworks.networks.ddp_model import concat_all_gather 7 | 8 | class Attention_Threshold_Loss(nn.Module): 9 | def __init__(self, strong_threshold): 10 | super(Attention_Threshold_Loss, self).__init__() 11 | self.strong_threshold = strong_threshold 12 | 13 | def forward(self, pred: torch.Tensor, pseudo_label: torch.Tensor, logits: torch.Tensor): 14 | batch_size = pred.shape[0] 15 | valid_mask = (pseudo_label >= 0).float() # only count valid pixels (class) 16 | weighting = logits.view(batch_size, -1).ge(self.strong_threshold).sum(-1) / (valid_mask.view(batch_size, -1).sum(-1)) # May be nan if the whole target is masked in cutout 17 | #self.tmp_valid_num = logits.ge(self.strong_threshold).view(logits.shape[0], -1).float().sum(-1).mean(0) 18 | # weighting is the per-image proportion of confident (above-threshold) pixels among the valid pixels of this batch 19 | loss = F.cross_entropy(pred, pseudo_label, reduction='none', ignore_index=-1) # pixel-wise 20 | weighted_loss = torch.mean(torch.masked_select(weighting[:, None, None] * loss, loss > 0)) 21 | # weighting torch.Size([4]) -> weighting[:, None, None] torch.Size([4, 1, 1]), so the per-image weight broadcasts over the spatial loss map 22 | # torch.masked_select keeps only the entries with loss > 0 23 | 24 | return weighted_loss 25 | 26 | class Prcl_Loss(nn.Module): 27 | def __init__(self, num_queries, num_negatives, 
temp=0.5, mean=False, strong_threshold=0.97): 28 | super(Prcl_Loss, self).__init__() 29 | self.temp = temp 30 | self.mean = mean 31 | self.num_queries = num_queries # anchor 32 | self.num_negatives = num_negatives 33 | self.strong_threshold = strong_threshold 34 | def forward(self, mu, sigma, label, mask, prob): 35 | # We gather all representations (mu and sigma) across multiple GPUs during this process 36 | mu_prt = concat_all_gather(mu) # For prototype computing on all cards (w/o gradients) 37 | sigma_prt = concat_all_gather(sigma) 38 | batch_size, num_feat, mu_w, mu_h = mu.shape 39 | num_segments = label.shape[1] #21 40 | valid_pixel_all = label * mask # Valid rep (sampling strategy a) 41 | valid_pixel_all_prt = concat_all_gather(valid_pixel_all) # For prototype computing on all cards 42 | 43 | # Permute representation for indexing: [batch, rep_h, rep_w, feat_num] 44 | 45 | mu = mu.permute(0, 2, 3, 1) 46 | sigma = sigma.permute(0, 2, 3, 1) 47 | mu_prt = mu_prt.permute(0, 2, 3, 1) 48 | sigma_prt = sigma_prt.permute(0, 2, 3, 1) 49 | 50 | mu_all_list = [] # all valid rep pool 51 | sigma_all_list = [] # all valid sigma pool 52 | mu_hard_list = [] # anchor pool 53 | sigma_hard_list = [] 54 | num_list = [] # Valid num of each class 55 | proto_mu_list = [] # Prototype 56 | proto_sigma_list = [] 57 | 58 | for i in range(num_segments): #21 59 | valid_pixel = valid_pixel_all[:, i] # on a single card 60 | valid_pixel_gather = valid_pixel_all_prt[:, i] # on multiple cards 61 | if valid_pixel.sum() == 0: 62 | continue 63 | prob_seg = prob[:, i, :, :] 64 | rep_mask_hard = (prob_seg < self.strong_threshold) * valid_pixel.bool() # Anchor sampling (strategy b) 65 | # Prototype calculation 66 | with torch.no_grad(): 67 | proto_sigma_ = 1 / torch.sum((1 / sigma_prt[valid_pixel_gather.bool()]), dim=0, keepdim=True) # Equation 8 68 | proto_mu_ = torch.sum((proto_sigma_ / sigma_prt[valid_pixel_gather.bool()]) \ 69 | * mu_prt[valid_pixel_gather.bool()], dim=0, keepdim=True) # Equation 7 70 | proto_mu_list.append(proto_mu_) 71 | proto_sigma_list.append(proto_sigma_) 72 | 73 | mu_all_list.append(mu[valid_pixel.bool()]) 74 | sigma_all_list.append(sigma[valid_pixel.bool()]) 75 | mu_hard_list.append(mu[rep_mask_hard]) 76 | sigma_hard_list.append(sigma[rep_mask_hard]) 77 | num_list.append(int(valid_pixel.sum().item())) 78 | 79 | # Compute Probabilistic Representation Contrastive Loss 80 | if len(num_list) <= 1: # in some rare cases, a small mini-batch contains only one semantic class, or none 81 | return torch.tensor(0.0) #+ 0 * mu.sum() + 0 * sigma.sum() # A trick that keeps mu and sigma in the autograd graph during DDP training; uncomment it if you hit the unused-parameters warning. 
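# Editor's note (a sketch, not the repository's code): the commented-out term above
# is the usual remedy when a loss sometimes ignores a module's outputs under
# DistributedDataParallel. Returning a bare constant leaves the parameters that
# produced mu and sigma out of the autograd graph for this step, which DDP reports
# as unused parameters unless find_unused_parameters=True (prcl.py does set it).
# Zero-weighted terms keep every producer in the graph at no numerical cost, e.g.:
#     return torch.tensor(0.0, device=mu.device) + 0 * mu.sum() + 0 * sigma.sum()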
82 | else: 83 | prcl_loss = torch.tensor(0.0) 84 | proto_mu = torch.cat(proto_mu_list) # [c] 85 | proto_sigma = torch.cat(proto_sigma_list) 86 | valid_num = len(num_list) 87 | seg_len = torch.arange(valid_num) 88 | 89 | for i in range(valid_num): 90 | if len(mu_hard_list[i]) > 0: 91 | # Randomly sample anchor representations 92 | sample_idx = torch.randint(len(mu_hard_list[i]), size=(self.num_queries, )) 93 | anchor_mu = mu_hard_list[i][sample_idx] 94 | anchor_sigma = sigma_hard_list[i][sample_idx] 95 | else: 96 | continue 97 | with torch.no_grad(): 98 | # Sample negatives 99 | id_mask = torch.cat(([seg_len[i: ], seg_len[: i]])) 100 | proto_sim = mutual_likelihood_score(proto_mu[id_mask[0].unsqueeze(0)], 101 | proto_mu[id_mask[1: ]], 102 | proto_sigma[id_mask[0].unsqueeze(0)], 103 | proto_sigma[id_mask[1: ]]) # Calculate the similarity among prototypes 104 | proto_prob = torch.softmax(proto_sim / self.temp, dim=0) # The sampling distribution (strategy c) 105 | negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob) 106 | samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives]) # Sample negatives according to the similarity 107 | samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1) 108 | negative_num_list = num_list[i+1: ] + num_list[: i] 109 | negative_index = negative_index_sampler(samp_num, negative_num_list) 110 | negative_mu_all = torch.cat(mu_all_list[i+1: ] + mu_all_list[: i]) 111 | negative_sigma_all = torch.cat(sigma_all_list[i+1: ] + sigma_all_list[: i]) 112 | negative_mu = negative_mu_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 113 | negative_sigma = negative_sigma_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 114 | positive_mu = proto_mu[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 115 | positive_sigma = proto_sigma[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 116 | all_mu = torch.cat((positive_mu, negative_mu), dim=1) 117 | all_sigma = torch.cat((positive_sigma, negative_sigma), dim=1) 118 | 119 | logits = mutual_likelihood_score(anchor_mu.unsqueeze(1), all_mu, anchor_sigma.unsqueeze(1), all_sigma) 120 | prcl_loss = prcl_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda()) 121 | 122 | return prcl_loss / valid_num 123 | 124 | class Prcl_Loss_single(nn.Module): 125 | # For single GPU users 126 | def __init__(self, num_queries, num_negatives, temp=0.5, mean=False, strong_threshold=0.97): 127 | super(Prcl_Loss_single, self).__init__() 128 | self.temp = temp 129 | self.mean = mean 130 | self.num_queries = num_queries 131 | self.num_negatives = num_negatives 132 | self.strong_threshold = strong_threshold 133 | def forward(self, mu, sigma, label, mask, prob): 134 | batch_size, num_feat, mu_w, mu_h = mu.shape 135 | num_segments = label.shape[1] #21 136 | valid_pixel_all = label * mask 137 | # Permute representation for indexing: [batch, rep_h, rep_w, feat_num] 138 | 139 | mu = mu.permute(0, 2, 3, 1) 140 | sigma = sigma.permute(0, 2, 3, 1) 141 | 142 | mu_all_list = [] 143 | sigma_all_list = [] 144 | mu_hard_list = [] 145 | sigma_hard_list = [] 146 | num_list = [] 147 | proto_mu_list = [] 148 | proto_sigma_list = [] 149 | 150 | for i in range(num_segments): #21 151 | valid_pixel = valid_pixel_all[:, i] 152 | if valid_pixel.sum() == 0: 153 | continue 154 | prob_seg = prob[:, i, :, :] 155 | rep_mask_hard = (prob_seg < self.strong_threshold) * valid_pixel.bool() # Only on a single card
156 | # Prototype computing 157 | with torch.no_grad(): 158 | proto_sigma_ = 1 / torch.sum((1 / sigma[valid_pixel.bool()]), dim=0, keepdim=True) 159 | proto_mu_ = torch.sum((proto_sigma_ / sigma[valid_pixel.bool()]) \ 160 | * mu[valid_pixel.bool()], dim=0, keepdim=True) 161 | proto_mu_list.append(proto_mu_) 162 | proto_sigma_list.append(proto_sigma_) 163 | 164 | mu_all_list.append(mu[valid_pixel.bool()]) 165 | sigma_all_list.append(sigma[valid_pixel.bool()]) 166 | mu_hard_list.append(mu[rep_mask_hard]) 167 | sigma_hard_list.append(sigma[rep_mask_hard]) 168 | num_list.append(int(valid_pixel.sum().item())) 169 | 170 | # Compute Probabilistic Representation Contrastive Loss 171 | if len(num_list) <= 1: # in some rare cases, a small mini-batch contains only one semantic class, or none 172 | return torch.tensor(0.0) 173 | else: 174 | prcl_loss = torch.tensor(0.0) 175 | proto_mu = torch.cat(proto_mu_list) # [c] 176 | proto_sigma = torch.cat(proto_sigma_list) 177 | valid_num = len(num_list) 178 | seg_len = torch.arange(valid_num) 179 | 180 | for i in range(valid_num): 181 | if len(mu_hard_list[i]) > 0: 182 | # Randomly sample anchor representations 183 | sample_idx = torch.randint(len(mu_hard_list[i]), size=(self.num_queries, )) 184 | anchor_mu = mu_hard_list[i][sample_idx] 185 | anchor_sigma = sigma_hard_list[i][sample_idx] 186 | else: 187 | continue 188 | with torch.no_grad(): 189 | # Select negatives 190 | id_mask = torch.cat(([seg_len[i: ], seg_len[: i]])) 191 | proto_sim = mutual_likelihood_score(proto_mu[id_mask[0].unsqueeze(0)], 192 | proto_mu[id_mask[1: ]], 193 | proto_sigma[id_mask[0].unsqueeze(0)], 194 | proto_sigma[id_mask[1: ]]) 195 | proto_prob = torch.softmax(proto_sim / self.temp, dim=0) 196 | negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob) 197 | samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives]) 198 | samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1) 199 | negative_num_list = num_list[i+1: ] + num_list[: i] 200 | negative_index = negative_index_sampler(samp_num, negative_num_list) 201 | negative_mu_all = torch.cat(mu_all_list[i+1: ] + mu_all_list[: i]) 202 | negative_sigma_all = torch.cat(sigma_all_list[i+1: ] + sigma_all_list[: i]) 203 | negative_mu = negative_mu_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 204 | negative_sigma = negative_sigma_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 205 | positive_mu = proto_mu[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 206 | positive_sigma = proto_sigma[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 207 | all_mu = torch.cat((positive_mu, negative_mu), dim=1) 208 | all_sigma = torch.cat((positive_sigma, negative_sigma), dim=1) 209 | 210 | logits = mutual_likelihood_score(anchor_mu.unsqueeze(1), all_mu, anchor_sigma.unsqueeze(1), all_sigma) 211 | prcl_loss = prcl_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda()) 212 | 213 | return prcl_loss / valid_num 214 | 215 | #### Utils #### 216 | def negative_index_sampler(samp_num, seg_num_list): 217 | negative_index = [] 218 | for i in range(samp_num.shape[0]): 219 | for j in range(samp_num.shape[1]): 220 | negative_index += np.random.randint(low=sum(seg_num_list[: j]), 221 | high=sum(seg_num_list[: j+1]), 222 | size=int(samp_num[i, j])).tolist() 223 | 224 | return negative_index 225 | 
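# ---- Editor's sketch (not part of the repository) ----
# The mutual likelihood score (MLS) defined in the next block scores a pair of
# diagonal-Gaussian embeddings N(mu, sigma) by the log-likelihood that they generate
# the same point. A tiny self-contained numerical check of the same formula
# (the name _mls_demo, shapes, and values are illustrative assumptions):
def _mls_demo(num_pairs: int = 8, dim: int = 256):
    mu_a = F.normalize(torch.randn(num_pairs, dim), dim=-1)
    mu_b = F.normalize(torch.randn(num_pairs, dim), dim=-1)
    sigma_a = torch.rand(num_pairs, dim) + 0.1  # variances must stay positive
    sigma_b = torch.rand(num_pairs, dim) + 0.1
    down = sigma_a + sigma_b
    mls = -0.5 * (((mu_a - mu_b) ** 2) / down + torch.log(down)).mean(-1)
    # An identical, low-variance pair maximises the score: the squared-distance
    # term vanishes and the log-variance penalty stays small.
    return mls  # shape [num_pairs]; larger means more similar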
226 | #### MLS #### 227 | def mutual_likelihood_score(mu_0, mu_1, sigma_0, sigma_1): 228 | ''' 229 | Compute the mutual likelihood score (MLS) 230 | param: mu_0, mu_1 with shapes e.g. [256, 513, 256] and [256, 1, 256] 231 | sigma_0, sigma_1 with shapes e.g. [256, 513, 256] and [256, 1, 256] 232 | ''' 233 | mu_0 = F.normalize(mu_0, dim=-1) 234 | mu_1 = F.normalize(mu_1, dim=-1) 235 | up = (mu_0 - mu_1) ** 2 236 | down = sigma_0 + sigma_1 237 | mls = -0.5 * (up / down + torch.log(down)).mean(-1) 238 | 239 | 240 | return mls 241 | -------------------------------------------------------------------------------- /prcl.py: -------------------------------------------------------------------------------- 1 | import shutup 2 | shutup.please() 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from generalframeworks.dataset_helpers.VOC import VOC_BuildData 7 | from generalframeworks.dataset_helpers.Cityscapes import City_BuildData 8 | from generalframeworks.networks.ddp_model import Model_with_un 9 | from generalframeworks.loss.loss import Attention_Threshold_Loss, Prcl_Loss 10 | from generalframeworks.scheduler.my_lr_scheduler import PolyLR 11 | from generalframeworks.scheduler.rampscheduler import RampdownScheduler 12 | from generalframeworks.utils import iterator_ 13 | from generalframeworks.util.dist_init import dist_init 14 | from generalframeworks.util.meter import * 15 | from generalframeworks.utils import label_onehot 16 | from generalframeworks.util.torch_dist_sum import * 17 | from generalframeworks.util.miou import * 18 | import yaml 19 | import os 20 | import time 21 | import torchvision.models as models 22 | from generalframeworks.networks import resnet 23 | import argparse 24 | import random 25 | 26 | def main(): 27 | args = parser.parse_args() 28 | from torch.nn.parallel import DistributedDataParallel 29 | ##### Distribution init ##### 30 | rank, local_rank, world_size = dist_init(args.port) 31 | print('Hello from rank {}\n'.format(rank)) 32 | 33 | ##### Config init ##### 34 | with open(args.config, 'r') as f: 35 | config = yaml.load(f.read(), Loader=yaml.FullLoader) 36 | save_dir = './checkpoints/' + str(args.job_name) 37 | if rank == 0: 38 | if not os.path.exists(save_dir): 39 | os.makedirs(save_dir) 40 | with open(save_dir + '/config.yaml', 'w') as f: 41 | yaml.dump(config, f, default_flow_style=False) 42 | print(config) 43 | 44 | ##### Init Seed ##### 45 | random.seed(config['Seed']) 46 | torch.manual_seed(config['Seed']) 47 | torch.backends.cudnn.deterministic = True 48 | 49 | ##### Load the dataset ##### 50 | if config['Dataset']['name'] == 'VOC': 51 | data = VOC_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 52 | label_num=config['Dataset']['num_labels'], seed=config['Seed']) 53 | if config['Dataset']['name'] == 'CityScapes': 54 | data = City_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 55 | label_num=config['Dataset']['num_labels'], seed=config['Seed']) 56 | train_l_dataset, train_u_dataset, test_dataset = data.build() 57 | train_l_sampler = torch.utils.data.distributed.DistributedSampler(train_l_dataset) 58 | train_l_loader = torch.utils.data.DataLoader(train_l_dataset, 59 | batch_size=config['Dataset']['batch_size'], 60 | num_workers=4, 61 | pin_memory=True, 62 | sampler=train_l_sampler, 63 | persistent_workers=True) 64 | train_u_sampler = torch.utils.data.distributed.DistributedSampler(train_u_dataset) 65 | train_u_loader = torch.utils.data.DataLoader(train_u_dataset, 66 | batch_size=config['Dataset']['batch_size'], 67 | num_workers=4, 68 | pin_memory=True, 69 | sampler=train_u_sampler, 70 | 
persistent_workers=True) 71 | test_loader = torch.utils.data.DataLoader(test_dataset, 72 | batch_size=config['Dataset']['batch_size'], 73 | num_workers=4, 74 | pin_memory=True, 75 | persistent_workers=True) 76 | 77 | ##### Model init ##### 78 | backbone = models.resnet101() 79 | ckpt = torch.load('./pretrained/resnet101.pth', map_location='cpu') 80 | backbone.load_state_dict(ckpt) 81 | 82 | # for Resnet-101 stem users 83 | #backbone = resnet.resnet101(pretrained=True) 84 | 85 | model = Model_with_un(backbone, num_classes=config['Network']['num_class'], output_dim=256, ema_alpha=config['EMA']['alpha'], config=config) 86 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda(local_rank) # Added 87 | model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) 88 | 89 | ##### Loss init ##### 90 | criterion = {'ce_loss': nn.CrossEntropyLoss(ignore_index=-1).cuda(local_rank), 91 | 'unsup_loss': Attention_Threshold_Loss(strong_threshold=config['Prcl_Loss']['un_threshold']).cuda(local_rank), 92 | 'prcl_loss': Prcl_Loss(strong_threshold=config['Prcl_Loss']['strong_threshold'], 93 | num_queries=config['Prcl_Loss']['num_queries'], 94 | num_negatives=config['Prcl_Loss']['num_negatives'], 95 | temp=config['Prcl_Loss']['temp']).cuda(local_rank) 96 | } 97 | 98 | 99 | ##### Other init ##### 100 | optimizer = torch.optim.SGD(model.module.model.parameters(), lr=float(config['Optim']['lr']), weight_decay=float(config['Optim']['weight_decay']), 101 | momentum=0.9, nesterov=True) 102 | optimizer_uncer = torch.optim.SGD(model.module.uncer_head.parameters(), lr=float(config['Optim']['uncer_lr']), weight_decay=float(config['Optim']['weight_decay']), 103 | momentum=0.9, nesterov=True) 104 | total_epoch = config['Training_Setting']['epoch'] 105 | lr_scheduler = PolyLR(optimizer, total_epoch) 106 | lr_scheduler_uncer = PolyLR(optimizer_uncer, total_epoch) 107 | 108 | if os.path.exists(args.resume): 109 | print('resume from', args.resume) 110 | checkpoint = torch.load(args.resume, map_location='cpu') 111 | model.module.model.load_state_dict(checkpoint['model']) 112 | model.module.ema_model.load_state_dict(checkpoint['ema']) 113 | model.module.uncer_head.load_state_dict(checkpoint['uncer_head']) 114 | optimizer.load_state_dict(checkpoint['optimizer']) 115 | optimizer_uncer.load_state_dict(checkpoint['optimizer_uncer']) 116 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 117 | start_epoch = checkpoint['epoch'] 118 | else: 119 | start_epoch = 0 120 | 121 | sche_d = RampdownScheduler(begin_epoch=config['Ramp_Scheduler']['begin_epoch'], 122 | max_epoch=config['Ramp_Scheduler']['max_epoch'], 123 | current_epoch=start_epoch, 124 | max_value=config['Ramp_Scheduler']['max_value'], 125 | min_value=config['Ramp_Scheduler']['min_value'], 126 | ramp_mult=config['Ramp_Scheduler']['ramp_mult']) 127 | 128 | best_miou = 0 129 | 130 | model.module.model.train() 131 | model.module.ema_model.train() 132 | model.module.uncer_head.train() 133 | for epoch in range(start_epoch, total_epoch): 134 | train(train_l_loader, train_u_loader, model, rank, local_rank, world_size, optimizer, optimizer_uncer, criterion, epoch, lr_scheduler, lr_scheduler_uncer, sche_d, config, args) 135 | miou = test(test_loader, model.module.ema_model, rank) 136 | best_miou = max(best_miou, miou) 137 | if rank == 0: 138 | print('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}'.format(epoch, miou, best_miou, time.asctime( time.localtime(time.time()) ))) 139 | # Save model 140 | if miou == best_miou: 
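# Editor's note: the checkpoint written below mirrors the resume branch near the top
# of main(), with one asymmetry worth knowing about: 'lr_scheduler_uncer' is saved
# here but never restored on resume. A symmetric resume would also call (a sketch of
# the missing line, not code from the repository):
#     lr_scheduler_uncer.load_state_dict(checkpoint['lr_scheduler_uncer'])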
141 | save_dir = './checkpoints/' + str(args.job_name) 142 | torch.save( 143 | { 144 | 'epoch': epoch+1, 145 | 'ema': model.module.ema_model.state_dict(), 146 | 'model': model.module.model.state_dict(), 147 | 'uncer_head': model.module.uncer_head.state_dict(), 148 | 'optimizer': optimizer.state_dict(), 149 | 'optimizer_uncer': optimizer_uncer.state_dict(), 150 | 'lr_scheduler': lr_scheduler.state_dict(), 151 | 'lr_scheduler_uncer': lr_scheduler_uncer.state_dict() 152 | }, os.path.join(save_dir, 'best_model.pth')) 153 | 154 | 155 | 156 | def train(train_l_loader, train_u_loader, model, rank, local_rank, world_size, optimizer, optimizer_uncer, criterion, epoch, scheduler, scheduler_uncer, sche_d, config, args): 157 | batch_time = AverageMeter('Time', ':6.3f') 158 | data_time = AverageMeter('Data', ':6.3f') 159 | sup_loss_meter = AverageMeter('Sup_loss', ':6.3f') 160 | unsup_loss_meter = AverageMeter('Unsup_loss', ':6.3f') 161 | contr_loss_meter = AverageMeter('Contr_loss', ':6.3f') 162 | num_class = config['Network']['num_class'] # VOC=21 163 | mious_conf_l = ConfMatrix(num_classes=num_class, fmt=':6.3f', name='l_miou') 164 | mious_conf_u = ConfMatrix(num_classes=num_class, fmt=':6.3f', name='u_miou') 165 | iter_num = int(2000 / config['Dataset']['batch_size'] / world_size / len(train_l_loader)) # 2000 images per epoch 166 | progress = ProgressMeter( 167 | iter_num, 168 | [batch_time, data_time, sup_loss_meter, unsup_loss_meter, contr_loss_meter, mious_conf_l, mious_conf_u], 169 | prefix='Epoch: [{}]'.format(epoch) 170 | ) 171 | # switch to train mode 172 | model.module.model.train() 173 | model.module.ema_model.train() 174 | model.module.uncer_head.train() 175 | 176 | end = time.time() 177 | train_u_loader.sampler.set_epoch(epoch) 178 | training_u_iter = iterator_(train_u_loader) 179 | for iter_i in range(iter_num): 180 | train_l_loader.sampler.set_epoch(epoch + iter_i * 200) 181 | for i, (train_l_image, train_l_label) in enumerate(train_l_loader): 182 | data_time.update(time.time() - end) 183 | train_l_image, train_l_label = train_l_image.cuda(local_rank), train_l_label.cuda(local_rank) 184 | train_u_image, train_u_label = next(training_u_iter) 185 | train_u_image, train_u_label = train_u_image.cuda(local_rank), train_u_label.cuda(local_rank) 186 | pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw, uncer_all = model(train_l_image, train_u_image) 187 | 188 | sup_loss = criterion['ce_loss'](pred_l_large, train_l_label) 189 | unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label, train_u_aug_logits) 190 | 191 | ##### Contrastive learning ##### 192 | with torch.no_grad(): 193 | train_u_aug_mask = train_u_aug_logits.ge(config['Prcl_Loss']['weak_threshold']).float() 194 | mask_all = torch.cat(((train_l_label.unsqueeze(1) >= 0).float(), train_u_aug_mask.unsqueeze(1))) 195 | mask_all = F.interpolate(mask_all, size=pred_all.shape[2:], mode='nearest') 196 | 197 | label_l = F.interpolate(label_onehot(train_l_label, num_class), size=pred_all.shape[2:], mode='nearest') 198 | label_u = F.interpolate(label_onehot(train_u_aug_label, num_class), size=pred_all.shape[2:], mode='nearest') 199 | label_all = torch.cat((label_l, label_u)) 200 | 201 | prob_all = torch.softmax(pred_all, dim=1) 202 | 203 | prcl_loss = criterion['prcl_loss'](rep_all, uncer_all, label_all, mask_all, prob_all) 204 | total_loss = sup_loss + unsup_loss + prcl_loss * sche_d.value 205 | 206 | # Update Meter 207 | sup_loss_meter.update(sup_loss.item(), pred_all.shape[0]) 
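# Editor's note on model.module.ema_update() called a few lines below: ddp_model.py
# is not shown in this section, so as an assumption this sketch mirrors the standard
# mean-teacher rule, with alpha taken from config['EMA']['alpha'] as in main():
#     with torch.no_grad():
#         for param_s, param_t in zip(model.parameters(), ema_model.parameters()):
#             param_t.mul_(alpha).add_(param_s, alpha=1 - alpha)  # t = a*t + (1-a)*s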
208 | unsup_loss_meter.update(unsup_loss.item(), pred_all.shape[0]) 209 | mious_conf_l.update(pred_l_large.argmax(1).flatten(), train_l_label.flatten()) 210 | mious_conf_u.update(pred_u_large_raw.argmax(1).flatten(), train_u_label.flatten()) 211 | contr_loss_meter.update(prcl_loss.item(), pred_all.shape[0]) 212 | optimizer.zero_grad() 213 | optimizer_uncer.zero_grad() 214 | total_loss.backward() 215 | optimizer.step() 216 | optimizer_uncer.step() 217 | model.module.ema_update() 218 | batch_time.update(time.time() - end) 219 | end = time.time() 220 | # if i % 20 == 0 and rank == 0: 221 | # progress.display(iter_i) 222 | scheduler.step() 223 | scheduler_uncer.step() 224 | sche_d.step() 225 | 226 | 227 | @torch.no_grad() 228 | def test(test_loader, model, rank): 229 | batch_time = AverageMeter('Time', ':6.3f') 230 | data_time = AverageMeter('Data', ':6.3f') 231 | miou_meter = ConfMatrix(num_classes=21, fmt=':6.4f', name='test_miou') 232 | 233 | # switch to eval mode 234 | model.eval() 235 | 236 | end = time.time() 237 | test_iter = iter(test_loader) 238 | for _ in range(len(test_loader)): 239 | data_time.update(time.time() - end) 240 | test_image, test_label = next(test_iter) 241 | test_image, test_label = test_image.cuda(), test_label.cuda() 242 | 243 | pred, _, _ = model(test_image) 244 | pred = F.interpolate(pred, size=test_label.shape[1:], mode='bilinear', align_corners=True) 245 | 246 | miou_meter.update(pred.argmax(1).flatten(), test_label.flatten()) 247 | batch_time.update(time.time() - end) 248 | end = time.time() 249 | 250 | mat = torch_dist_sum(rank, miou_meter.mat) # We refined this function so that no reshape is needed 251 | miou = mean_intersection_over_union(mat[0]) 252 | 253 | return miou 254 | 255 | 256 | 257 | if __name__ == '__main__': 258 | parser = argparse.ArgumentParser() 259 | parser.add_argument('--port', type=int, default=23456) 260 | parser.add_argument('--config', type=str, default='') 261 | parser.add_argument('--resume', type=str, default='') 262 | parser.add_argument('--job_name', type=str, default='') 263 | 264 | 265 | main() 266 | 267 | -------------------------------------------------------------------------------- /generalframeworks/augmentation/transform.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image, ImageFilter 3 | 4 | import torch 5 | from typing import Tuple 6 | from torchvision import transforms 7 | import torchvision.transforms.functional as transform_f 8 | import random 9 | import numpy as np 10 | 11 | def batch_transform(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, crop_size: Tuple['h', 'w'], scale_size, 12 | apply_augmentation=False): 13 | image_list, label_list, logits_list = [], [], [] 14 | device = image.device 15 | 16 | for k in range(image.shape[0]): 17 | image_pil, label_pil, logits_pil = tensor_to_pil(image[k], label[k], logits[k]) 18 | aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil, 19 | crop_size=crop_size, 20 | scale_size=scale_size, 21 | augmentation=apply_augmentation) 22 | image_list.append(aug_image.unsqueeze(0)) 23 | label_list.append(aug_label) 24 | logits_list.append(aug_logits) 25 | 26 | image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), \ 27 | torch.cat(logits_list).to(device) 28 | return image_trans, label_trans, logits_trans 29 | 30 | def tensor_to_pil(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor): 31 | image = denormalise(image) 32 | image = 
transform_f.to_pil_image(image.cpu()) 33 | 34 | label = label.float() / 255. 35 | label = transform_f.to_pil_image(label.unsqueeze(0).cpu()) 36 | 37 | logits = transform_f.to_pil_image(logits.unsqueeze(0).cpu()) 38 | 39 | 40 | return image, label, logits 41 | 42 | def tensor_to_pil_1(image: torch.Tensor, label: torch.Tensor, uncertainty_u:torch.Tensor, logits: torch.Tensor, logits_all: torch.Tensor): 43 | image = denormalise(image) 44 | image = transform_f.to_pil_image(image.cpu()) 45 | 46 | label = label.float() / 255. 47 | label = transform_f.to_pil_image(label.unsqueeze(0).cpu()) 48 | uncertainty_u = uncertainty_u.float() / 255. 49 | uncertainty_u = transform_f.to_pil_image(uncertainty_u.unsqueeze(0).cpu()) 50 | logits_all_l = [] 51 | for i in range(logits_all.shape[0]): 52 | logits_all_l.append(transform_f.to_pil_image(logits_all[i].float().unsqueeze(0).cpu(), mode='F')) 53 | 54 | logits = transform_f.to_pil_image(logits.unsqueeze(0).cpu(), 'F') 55 | 56 | return image, label, uncertainty_u, logits, logits_all_l 57 | 58 | 59 | def denormalise(x, imagenet=True): 60 | if imagenet: 61 | x = transform_f.normalize(x, mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) 62 | x = transform_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 63 | return x 64 | else: 65 | return (x + 1) / 2 66 | 67 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), label_fill=255, augmentation=False): 68 | ''' 69 | Apply to a single 3-D image (one image, not a batch) 70 | ''' 71 | # Random Rescale image 72 | raw_w, raw_h = image.size 73 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 74 | 75 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 76 | image = transform_f.resize(image, resized_size, Image.NEAREST) 77 | label = transform_f.resize(label, resized_size, Image.NEAREST) 78 | if logits is not None: 79 | logits = transform_f.resize(logits, resized_size, Image.NEAREST) 80 | 81 | # Add padding if the rescaled image is smaller than the crop size 82 | if crop_size == -1: # Use original image size without crop or padding 83 | crop_size = (raw_h, raw_w) 84 | 85 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 86 | right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0) 87 | image = transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 88 | label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=label_fill, padding_mode='constant') 89 | if logits is not None: 90 | logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 91 | 92 | # Random Cropping 93 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 94 | image = transform_f.crop(image, i, j, h, w) 95 | label = transform_f.crop(label, i, j, h, w) 96 | if logits is not None: 97 | logits = transform_f.crop(logits, i, j, h, w) 98 | 99 | if augmentation: 100 | # Random Color jitter 101 | if torch.rand(1) > 0.2: 102 | color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 103 | image = color_transform(image) 104 | 105 | # Random Gaussian filter 106 | if torch.rand(1) > 0.5: 107 | sigma = random.uniform(0.15, 1.15) 108 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 109 | 110 | # Random horizontal flipping 111 | if torch.rand(1) > 0.5: 112 | image = transform_f.hflip(image) 113 | label = transform_f.hflip(label) 114 | if logits is not 
None: 115 | logits = transform_f.hflip(logits) 116 | 117 | # Transform to Tensor 118 | image = transform_f.to_tensor(image) 119 | label = (transform_f.to_tensor(label) * 255).long() 120 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 121 | if logits is not None: 122 | logits = transform_f.to_tensor(logits) 123 | 124 | # Apply (ImageNet) normalization 125 | #image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 126 | image = transform_f.normalize(image, mean=[0.5], std=[0.299]) 127 | if logits is not None: 128 | return image, label, logits 129 | else: 130 | return image, label 131 | 132 | def generate_cut(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'): 133 | batch_size, _, image_h, image_w = image.shape 134 | device = image.device 135 | 136 | new_image = [] 137 | new_label = [] 138 | new_logits = [] 139 | for i in range(batch_size): 140 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 141 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 142 | label[i][(1 - mix_mask).bool()] = -1 143 | 144 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 145 | new_label.append(label[i].unsqueeze(0)) 146 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 147 | continue 148 | elif mode == 'cutmix': 149 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 150 | elif mode == 'classmix': 151 | mix_mask = generate_class_mask(label[i]).to(device) 152 | else: 153 | raise ValueError('mode must be in cutout, cutmix, or classmix') 154 | 155 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 156 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 157 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 158 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 159 | 160 | return new_image, new_label.long(), new_logits 161 | 162 | 163 | 164 | def generate_cutout_mask(image_size, ratio=2): 165 | # Cutout: randomly generate a mask that is 0 inside the cut region and 1 outside 166 | cutout_area = image_size[0] * image_size[1] / ratio 167 | 168 | w = np.random.randint(image_size[1] / ratio + 1, image_size[1]) 169 | h = np.round(cutout_area / w) 170 | 171 | x_start = np.random.randint(0, image_size[1] - w + 1) 172 | y_start = np.random.randint(0, image_size[0] - h + 1) 173 | 174 | x_end = int(x_start + w) 175 | y_end = int(y_start + h) 176 | 177 | mask = torch.ones(image_size) 178 | mask[y_start: y_end, x_start: x_end] = 0 179 | 180 | return mask.float() 181 | 182 | def generate_class_mask(pseudo_labels: torch.Tensor): 183 | # select half of the classes and mask them out 184 | labels = torch.unique(pseudo_labels) # all unique labels 185 | labels_select: torch.Tensor = labels[torch.randperm(len(labels))][:len(labels) // 2] # Randomly select half of the labels 186 | mask = (pseudo_labels.unsqueeze(-1) == labels_select).any(dim=-1) 187 | return mask.float() 188 | 189 | def batch_transform_1(data, label, uncertainty_u, logits, logits_all, crop_size, scale_size, apply_augmentation): 190 | data_list, label_list, uncertainty_u_list, logits_list, logits_all_list = [], [], [], [], [] 191 | device = data.device 192 | 193 | for k in range(data.shape[0]): 194 | data_pil, label_pil, uncertainty_u_pil, logits_pil, logits_all_pil = 
tensor_to_pil_1(data[k], label[k], uncertainty_u[k], logits[k], logits_all[k]) 195 | aug_data, aug_label, aug_uncertainty_u, aug_logits, aug_logits_all = transform_1(data_pil, label_pil, uncertainty_u_pil, logits_pil, logits_all_pil, 196 | crop_size=crop_size, 197 | scale_size=scale_size, 198 | augmentation=apply_augmentation) 199 | 200 | 201 | # debug-only consistency check (unused): tmp = aug_label.squeeze(0).cuda().eq(aug_logits_all.cuda().argmax(0)) 202 | # all_agree = tmp.cuda().sum() + (aug_label.cuda() == -1).sum() 203 | data_list.append(aug_data.unsqueeze(0)) 204 | label_list.append(aug_label) 205 | uncertainty_u_list.append(aug_uncertainty_u) 206 | logits_list.append(aug_logits) 207 | logits_all_list.append(aug_logits_all.unsqueeze(0)) 208 | 209 | 210 | data_trans, label_trans, uncertainty_u_trans, logits_trans, logits_all_trans = \ 211 | torch.cat(data_list).to(device), torch.cat(label_list).to(device), torch.cat(uncertainty_u_list).to(device), torch.cat(logits_list).to(device), torch.cat(logits_all_list).to(device) 212 | return data_trans, label_trans, uncertainty_u_trans, logits_trans, logits_all_trans 213 | 214 | def transform_1(image, label, uncertainty_u=None, logits=None, logits_all=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 215 | # Random rescale image 216 | 217 | raw_w, raw_h = image.size 218 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 219 | 220 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 221 | image = transform_f.resize(image, resized_size, Image.BILINEAR) 222 | label = transform_f.resize(label, resized_size, Image.NEAREST) 223 | if uncertainty_u is not None: 224 | uncertainty_u = transform_f.resize(uncertainty_u, resized_size, Image.NEAREST) 225 | if logits is not None: 226 | logits = transform_f.resize(logits, resized_size, Image.NEAREST) 227 | logits_all_l = [] 228 | if logits_all is not None: 229 | for logits_item in logits_all: 230 | logits_all_l.append(transform_f.resize(logits_item, resized_size, Image.NEAREST)) 231 | logits_all = logits_all_l 232 | 233 | # Add padding if rescaled image size is less than crop size 234 | if crop_size == -1: # use original im size without crop or padding 235 | crop_size = (raw_h, raw_w) 236 | 237 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 238 | 239 | right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0) 240 | image = transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 241 | label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 242 | if uncertainty_u is not None: 243 | uncertainty_u = transform_f.pad(uncertainty_u, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 244 | if logits is not None: 245 | logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 246 | if logits_all is not None: 247 | logits_all_l_tmp = [] 248 | for logits_item in logits_all: 249 | logits_all_l_tmp.append(transform_f.pad(logits_item, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant')) 250 | logits_all = logits_all_l_tmp 251 | 252 | 253 | 254 | # Random Cropping 255 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 256 | image = transform_f.crop(image, i, j, h, w) 257 | label = transform_f.crop(label, i, j, h, w) 258 | if uncertainty_u is not None: 259 | uncertainty_u = transform_f.crop(uncertainty_u, i, j, h, w) 260 | if logits is not None: 
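# Editor's note: labels, uncertainty maps, and per-class logits are resized with
# Image.NEAREST above on purpose: bilinear resampling would blend neighbouring class
# indices into meaningless in-between values, so only the RGB image (resized with
# Image.BILINEAR) is interpolated smoothly. The same reasoning is behind the
# mode='nearest' F.interpolate calls on masks and one-hot labels in prcl.py.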
261 | logits = transform_f.crop(logits, i, j, h, w) 262 | if logits_all is not None: 263 | logits_all_l_tmp = [] 264 | for logits_item in logits_all: 265 | logits_all_l_tmp.append(transform_f.crop(logits_item, i, j, h, w)) 266 | logits_all = logits_all_l_tmp 267 | 268 | if augmentation: 269 | # Random color jitter 270 | if torch.rand(1) > 0.2: 271 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) # For PyTorch 1.9/TorchVision 0.10 users 272 | # color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 273 | image = color_transform(image) 274 | 275 | # Random Gaussian filter 276 | if torch.rand(1) > 0.5: 277 | sigma = random.uniform(0.15, 1.15) 278 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 279 | 280 | # Random horizontal flipping 281 | if torch.rand(1) > 0.5: 282 | image = transform_f.hflip(image) 283 | label = transform_f.hflip(label) 284 | if uncertainty_u is not None: 285 | uncertainty_u = transform_f.hflip(uncertainty_u) 286 | if logits is not None: 287 | logits = transform_f.hflip(logits) 288 | if logits_all is not None: 289 | logits_all_l_tmp = [] 290 | for logits_item in logits_all: 291 | logits_all_l_tmp.append(transform_f.hflip(logits_item)) 292 | logits_all = logits_all_l_tmp 293 | 294 | # Transform to tensor 295 | image = transform_f.to_tensor(image) 296 | label = (transform_f.to_tensor(label) * 255).long() 297 | if uncertainty_u is not None: uncertainty_u = (transform_f.to_tensor(uncertainty_u) * 255).long() # guarded: uncertainty_u defaults to None 298 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 299 | if logits is not None: 300 | logits = transform_f.to_tensor(logits) 301 | if uncertainty_u is not None: 302 | uncertainty_u[uncertainty_u == 255] = -1 303 | if logits_all is not None: 304 | logits_all_l_tmp = [] 305 | for logits_item in logits_all: 306 | logits_all_l_tmp.append(transform_f.to_tensor(logits_item)) 307 | logits_all = torch.cat(logits_all_l_tmp) 308 | 309 | # Apply (ImageNet) normalisation 310 | # image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 311 | if logits is not None and uncertainty_u is not None and logits_all is not None: 312 | return image, label, uncertainty_u, logits, logits_all 313 | elif logits is not None and uncertainty_u is None: 314 | return image, label, logits 315 | elif logits is None and uncertainty_u is not None: 316 | return image, label, uncertainty_u 317 | else: 318 | return image, label 319 | 320 | def generate_cut_1(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, uncertainty_u: torch.Tensor=None, logits_all=None, mode='cutout'): 321 | batch_size, _, image_h, image_w = image.shape 322 | device = image.device 323 | 324 | new_image = [] 325 | new_label = [] 326 | new_uncertainty_u = [] 327 | new_logits = [] 328 | new_logits_all = [] 329 | for i in range(batch_size): 330 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 331 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 332 | label[i][(1 - mix_mask).bool()] = -1 333 | if uncertainty_u is not None: 334 | uncertainty_u[i][(1 - mix_mask).bool()] = 0 335 | 336 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 337 | new_label.append(label[i].unsqueeze(0)) 338 | if uncertainty_u is not None: 339 | new_uncertainty_u.append(uncertainty_u[i].unsqueeze(0)) 340 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 341 | continue 342 | elif mode == 'cutmix': 343 | 
mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 344 | elif mode == 'classmix': 345 | mix_mask = generate_class_mask(label[i]).to(device) 346 | else: 347 | raise ValueError('mode must be in cutout, cutmix, or classmix') 348 | 349 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 350 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 351 | if uncertainty_u is not None: 352 | new_uncertainty_u.append((uncertainty_u[i] * mix_mask + uncertainty_u[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 353 | if logits_all is not None: 354 | new_logits_all.append((logits_all[i] * mix_mask + logits_all[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 355 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 356 | 357 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 358 | 359 | if uncertainty_u is not None and logits_all is not None: 360 | new_uncertainty_u = torch.cat(new_uncertainty_u) 361 | new_logits_all = torch.cat(new_logits_all) 362 | 363 | return new_image, new_label.long(), new_uncertainty_u.long(), new_logits, new_logits_all 364 | else: 365 | return new_image, new_label.long(), new_logits 366 | 367 | 368 | def batch_transform_2(data, label, uncertainty_u, logits, crop_size, scale_size, apply_augmentation): 369 | data_list, label_list, uncertainty_u_list, logits_list = [], [], [], [] 370 | device = data.device 371 | 372 | for k in range(data.shape[0]): 373 | data_pil, label_pil, logits_pil = tensor_to_pil(data[k], label[k], logits[k]) 374 | aug_data, aug_label, aug_uncertainty_u, aug_logits = transform_2(data_pil, label_pil, uncertainty_u[k].unsqueeze(0), logits_pil, 375 | crop_size=crop_size, 376 | scale_size=scale_size, 377 | augmentation=apply_augmentation) 378 | data_list.append(aug_data.unsqueeze(0)) 379 | label_list.append(aug_label) 380 | # uncertainty_u_list.append(aug_uncertainty_u.unsqueeze(0)) 381 | uncertainty_u_list.append(aug_uncertainty_u) 382 | logits_list.append(aug_logits) 383 | 384 | data_trans, label_trans, uncertainty_u_trans, logits_trans = \ 385 | torch.cat(data_list).to(device), torch.cat(label_list).to(device), torch.cat(uncertainty_u_list).to(device), torch.cat(logits_list).to(device) 386 | return data_trans, label_trans, uncertainty_u_trans, logits_trans 387 | 388 | def transform_2(image, label, uncertainty_u=None, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 389 | # Random rescale image 390 | raw_w, raw_h = image.size 391 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 392 | 393 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 394 | image = transform_f.resize(image, resized_size, Image.BILINEAR) 395 | label = transform_f.resize(label, resized_size, Image.NEAREST) 396 | if uncertainty_u is not None: 397 | uncertainty_u = transform_f.resize(uncertainty_u, resized_size, Image.NEAREST) 398 | if logits is not None: 399 | logits = transform_f.resize(logits, resized_size, Image.NEAREST) 400 | 401 | # Add padding if rescaled image size is less than crop size 402 | if crop_size == -1: # use original im size without crop or padding 403 | crop_size = (raw_h, raw_w) 404 | 405 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 406 | right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0) 407 | image = 
transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 408 | label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 409 | if uncertainty_u is not None: 410 | uncertainty_u = transform_f.pad(uncertainty_u, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 411 | if logits is not None: 412 | logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 413 | 414 | # Random Cropping 415 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 416 | image = transform_f.crop(image, i, j, h, w) 417 | label = transform_f.crop(label, i, j, h, w) 418 | if uncertainty_u is not None: 419 | uncertainty_u = transform_f.crop(uncertainty_u, i, j, h, w) 420 | if logits is not None: 421 | logits = transform_f.crop(logits, i, j, h, w) 422 | 423 | if augmentation: 424 | # Random color jitter 425 | if torch.rand(1) > 0.2: 426 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) # For PyTorch 1.9/TorchVision 0.10 users 427 | # color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 428 | image = color_transform(image) 429 | 430 | # Random Gaussian filter 431 | if torch.rand(1) > 0.5: 432 | sigma = random.uniform(0.15, 1.15) 433 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 434 | 435 | # Random horizontal flipping 436 | if torch.rand(1) > 0.5: 437 | image = transform_f.hflip(image) 438 | label = transform_f.hflip(label) 439 | if uncertainty_u is not None: 440 | uncertainty_u = transform_f.hflip(uncertainty_u) 441 | if logits is not None: 442 | logits = transform_f.hflip(logits) 443 | 444 | # Transform to tensor 445 | image = transform_f.to_tensor(image) 446 | label = (transform_f.to_tensor(label) * 255).long() 447 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 448 | if logits is not None: 449 | logits = transform_f.to_tensor(logits) 450 | 451 | # Apply (ImageNet) normalisation 452 | image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 453 | if logits is not None and uncertainty_u is not None: 454 | return image, label, uncertainty_u, logits 455 | elif logits is not None and uncertainty_u is None: 456 | return image, label, logits 457 | elif logits is None and uncertainty_u is not None: 458 | return image, label, uncertainty_u 459 | else: 460 | return image, label 461 | 462 | def generate_cut_2(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, uncertainty_u: torch.Tensor=None, mode='cutout'): 463 | batch_size, _, image_h, image_w = image.shape 464 | device = image.device 465 | 466 | new_image = [] 467 | new_label = [] 468 | new_uncertainty_u = [] 469 | new_logits = [] 470 | for i in range(batch_size): 471 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 472 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 473 | label[i][(1 - mix_mask).bool()] = -1 474 | if uncertainty_u is not None: 475 | uncertainty_u[i][(1 - mix_mask).bool()] = 0 476 | 477 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 478 | new_label.append(label[i].unsqueeze(0)) 479 | if uncertainty_u is not None: 480 | new_uncertainty_u.append(uncertainty_u[i].unsqueeze(0)) 481 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 482 | continue 483 | elif mode == 'cutmix': 484 | mix_mask = 
generate_cutout_mask([image_h, image_w]).to(device) 485 | elif mode == 'classmix': 486 | mix_mask = generate_class_mask(label[i]).to(device) 487 | else: 488 | raise ValueError('mode must be in cutout, cutmix, or classmix') 489 | 490 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 491 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 492 | if uncertainty_u is not None: 493 | new_uncertainty_u.append((uncertainty_u[i] * mix_mask + uncertainty_u[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 494 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 495 | 496 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 497 | if uncertainty_u is not None: 498 | new_uncertainty_u = torch.cat(new_uncertainty_u) 499 | return new_image, new_label.long(), new_uncertainty_u, new_logits 500 | else: 501 | return new_image, new_label.long(), new_logits 502 | -------------------------------------------------------------------------------- /VOC_split/662/3407/valid_filename.txt: -------------------------------------------------------------------------------- 1 | 2007_000033 2 | 2007_000042 3 | 2007_000061 4 | 2007_000123 5 | 2007_000129 6 | 2007_000175 7 | 2007_000187 8 | 2007_000323 9 | 2007_000332 10 | 2007_000346 11 | 2007_000452 12 | 2007_000464 13 | 2007_000491 14 | 2007_000529 15 | 2007_000559 16 | 2007_000572 17 | 2007_000629 18 | 2007_000636 19 | 2007_000661 20 | 2007_000663 21 | 2007_000676 22 | 2007_000727 23 | 2007_000762 24 | 2007_000783 25 | 2007_000799 26 | 2007_000804 27 | 2007_000830 28 | 2007_000837 29 | 2007_000847 30 | 2007_000862 31 | 2007_000925 32 | 2007_000999 33 | 2007_001154 34 | 2007_001175 35 | 2007_001239 36 | 2007_001284 37 | 2007_001288 38 | 2007_001289 39 | 2007_001299 40 | 2007_001311 41 | 2007_001321 42 | 2007_001377 43 | 2007_001408 44 | 2007_001423 45 | 2007_001430 46 | 2007_001457 47 | 2007_001458 48 | 2007_001526 49 | 2007_001568 50 | 2007_001585 51 | 2007_001586 52 | 2007_001587 53 | 2007_001594 54 | 2007_001630 55 | 2007_001677 56 | 2007_001678 57 | 2007_001717 58 | 2007_001733 59 | 2007_001761 60 | 2007_001763 61 | 2007_001774 62 | 2007_001884 63 | 2007_001955 64 | 2007_002046 65 | 2007_002094 66 | 2007_002119 67 | 2007_002132 68 | 2007_002260 69 | 2007_002266 70 | 2007_002268 71 | 2007_002284 72 | 2007_002376 73 | 2007_002378 74 | 2007_002387 75 | 2007_002400 76 | 2007_002412 77 | 2007_002426 78 | 2007_002427 79 | 2007_002445 80 | 2007_002470 81 | 2007_002539 82 | 2007_002565 83 | 2007_002597 84 | 2007_002618 85 | 2007_002619 86 | 2007_002624 87 | 2007_002643 88 | 2007_002648 89 | 2007_002719 90 | 2007_002728 91 | 2007_002823 92 | 2007_002824 93 | 2007_002852 94 | 2007_002903 95 | 2007_003011 96 | 2007_003020 97 | 2007_003022 98 | 2007_003051 99 | 2007_003088 100 | 2007_003101 101 | 2007_003106 102 | 2007_003110 103 | 2007_003131 104 | 2007_003134 105 | 2007_003137 106 | 2007_003143 107 | 2007_003169 108 | 2007_003188 109 | 2007_003194 110 | 2007_003195 111 | 2007_003201 112 | 2007_003349 113 | 2007_003367 114 | 2007_003373 115 | 2007_003499 116 | 2007_003503 117 | 2007_003506 118 | 2007_003530 119 | 2007_003571 120 | 2007_003587 121 | 2007_003611 122 | 2007_003621 123 | 2007_003682 124 | 2007_003711 125 | 2007_003714 126 | 2007_003742 127 | 2007_003786 128 | 2007_003841 129 | 2007_003848 130 | 2007_003861 131 | 2007_003872 132 | 2007_003917 133 | 2007_003957 134 
| 2007_003991 135 | 2007_004033 136 | 2007_004052 137 | 2007_004112 138 | 2007_004121 139 | 2007_004143 140 | 2007_004189 141 | 2007_004190 142 | 2007_004193 143 | 2007_004241 144 | 2007_004275 145 | 2007_004281 146 | 2007_004380 147 | 2007_004392 148 | 2007_004405 149 | 2007_004468 150 | 2007_004483 151 | 2007_004510 152 | 2007_004538 153 | 2007_004558 154 | 2007_004644 155 | 2007_004649 156 | 2007_004712 157 | 2007_004722 158 | 2007_004856 159 | 2007_004866 160 | 2007_004902 161 | 2007_004969 162 | 2007_005058 163 | 2007_005074 164 | 2007_005107 165 | 2007_005114 166 | 2007_005149 167 | 2007_005173 168 | 2007_005281 169 | 2007_005294 170 | 2007_005296 171 | 2007_005304 172 | 2007_005331 173 | 2007_005354 174 | 2007_005358 175 | 2007_005428 176 | 2007_005460 177 | 2007_005469 178 | 2007_005509 179 | 2007_005547 180 | 2007_005600 181 | 2007_005608 182 | 2007_005626 183 | 2007_005689 184 | 2007_005696 185 | 2007_005705 186 | 2007_005759 187 | 2007_005803 188 | 2007_005813 189 | 2007_005828 190 | 2007_005844 191 | 2007_005845 192 | 2007_005857 193 | 2007_005911 194 | 2007_005915 195 | 2007_005978 196 | 2007_006028 197 | 2007_006035 198 | 2007_006046 199 | 2007_006076 200 | 2007_006086 201 | 2007_006117 202 | 2007_006171 203 | 2007_006241 204 | 2007_006260 205 | 2007_006277 206 | 2007_006348 207 | 2007_006364 208 | 2007_006373 209 | 2007_006444 210 | 2007_006449 211 | 2007_006549 212 | 2007_006553 213 | 2007_006560 214 | 2007_006647 215 | 2007_006678 216 | 2007_006680 217 | 2007_006698 218 | 2007_006761 219 | 2007_006802 220 | 2007_006837 221 | 2007_006841 222 | 2007_006864 223 | 2007_006866 224 | 2007_006946 225 | 2007_007007 226 | 2007_007084 227 | 2007_007109 228 | 2007_007130 229 | 2007_007165 230 | 2007_007168 231 | 2007_007195 232 | 2007_007196 233 | 2007_007203 234 | 2007_007211 235 | 2007_007235 236 | 2007_007341 237 | 2007_007414 238 | 2007_007417 239 | 2007_007470 240 | 2007_007477 241 | 2007_007493 242 | 2007_007498 243 | 2007_007524 244 | 2007_007534 245 | 2007_007624 246 | 2007_007651 247 | 2007_007688 248 | 2007_007748 249 | 2007_007795 250 | 2007_007810 251 | 2007_007815 252 | 2007_007818 253 | 2007_007836 254 | 2007_007849 255 | 2007_007881 256 | 2007_007996 257 | 2007_008051 258 | 2007_008084 259 | 2007_008106 260 | 2007_008110 261 | 2007_008204 262 | 2007_008222 263 | 2007_008256 264 | 2007_008260 265 | 2007_008339 266 | 2007_008374 267 | 2007_008415 268 | 2007_008430 269 | 2007_008543 270 | 2007_008547 271 | 2007_008596 272 | 2007_008645 273 | 2007_008670 274 | 2007_008708 275 | 2007_008722 276 | 2007_008747 277 | 2007_008802 278 | 2007_008815 279 | 2007_008897 280 | 2007_008944 281 | 2007_008964 282 | 2007_008973 283 | 2007_008980 284 | 2007_009015 285 | 2007_009068 286 | 2007_009084 287 | 2007_009088 288 | 2007_009096 289 | 2007_009221 290 | 2007_009245 291 | 2007_009251 292 | 2007_009252 293 | 2007_009258 294 | 2007_009320 295 | 2007_009323 296 | 2007_009331 297 | 2007_009346 298 | 2007_009392 299 | 2007_009413 300 | 2007_009419 301 | 2007_009446 302 | 2007_009458 303 | 2007_009521 304 | 2007_009562 305 | 2007_009592 306 | 2007_009654 307 | 2007_009655 308 | 2007_009684 309 | 2007_009687 310 | 2007_009691 311 | 2007_009706 312 | 2007_009750 313 | 2007_009756 314 | 2007_009764 315 | 2007_009794 316 | 2007_009817 317 | 2007_009841 318 | 2007_009897 319 | 2007_009911 320 | 2007_009923 321 | 2007_009938 322 | 2008_000009 323 | 2008_000016 324 | 2008_000073 325 | 2008_000075 326 | 2008_000080 327 | 2008_000107 328 | 2008_000120 329 | 2008_000123 330 | 2008_000149 331 | 
2008_000182 332 | 2008_000213 333 | 2008_000215 334 | 2008_000223 335 | 2008_000233 336 | 2008_000234 337 | 2008_000239 338 | 2008_000254 339 | 2008_000270 340 | 2008_000271 341 | 2008_000345 342 | 2008_000359 343 | 2008_000391 344 | 2008_000401 345 | 2008_000464 346 | 2008_000469 347 | 2008_000474 348 | 2008_000501 349 | 2008_000510 350 | 2008_000533 351 | 2008_000573 352 | 2008_000589 353 | 2008_000602 354 | 2008_000630 355 | 2008_000657 356 | 2008_000661 357 | 2008_000662 358 | 2008_000666 359 | 2008_000673 360 | 2008_000700 361 | 2008_000725 362 | 2008_000731 363 | 2008_000763 364 | 2008_000765 365 | 2008_000782 366 | 2008_000795 367 | 2008_000811 368 | 2008_000848 369 | 2008_000853 370 | 2008_000863 371 | 2008_000911 372 | 2008_000919 373 | 2008_000943 374 | 2008_000992 375 | 2008_001013 376 | 2008_001028 377 | 2008_001040 378 | 2008_001070 379 | 2008_001074 380 | 2008_001076 381 | 2008_001078 382 | 2008_001135 383 | 2008_001150 384 | 2008_001170 385 | 2008_001231 386 | 2008_001249 387 | 2008_001260 388 | 2008_001283 389 | 2008_001308 390 | 2008_001379 391 | 2008_001404 392 | 2008_001433 393 | 2008_001439 394 | 2008_001478 395 | 2008_001491 396 | 2008_001504 397 | 2008_001513 398 | 2008_001514 399 | 2008_001531 400 | 2008_001546 401 | 2008_001547 402 | 2008_001580 403 | 2008_001629 404 | 2008_001640 405 | 2008_001682 406 | 2008_001688 407 | 2008_001715 408 | 2008_001821 409 | 2008_001874 410 | 2008_001885 411 | 2008_001895 412 | 2008_001966 413 | 2008_001971 414 | 2008_001992 415 | 2008_002043 416 | 2008_002152 417 | 2008_002205 418 | 2008_002212 419 | 2008_002239 420 | 2008_002240 421 | 2008_002241 422 | 2008_002269 423 | 2008_002273 424 | 2008_002358 425 | 2008_002379 426 | 2008_002383 427 | 2008_002429 428 | 2008_002464 429 | 2008_002467 430 | 2008_002492 431 | 2008_002495 432 | 2008_002504 433 | 2008_002521 434 | 2008_002536 435 | 2008_002588 436 | 2008_002623 437 | 2008_002680 438 | 2008_002681 439 | 2008_002775 440 | 2008_002778 441 | 2008_002835 442 | 2008_002859 443 | 2008_002864 444 | 2008_002900 445 | 2008_002904 446 | 2008_002929 447 | 2008_002936 448 | 2008_002942 449 | 2008_002958 450 | 2008_003003 451 | 2008_003026 452 | 2008_003034 453 | 2008_003076 454 | 2008_003105 455 | 2008_003108 456 | 2008_003110 457 | 2008_003135 458 | 2008_003141 459 | 2008_003155 460 | 2008_003210 461 | 2008_003238 462 | 2008_003270 463 | 2008_003330 464 | 2008_003333 465 | 2008_003369 466 | 2008_003379 467 | 2008_003451 468 | 2008_003461 469 | 2008_003477 470 | 2008_003492 471 | 2008_003499 472 | 2008_003511 473 | 2008_003546 474 | 2008_003576 475 | 2008_003577 476 | 2008_003676 477 | 2008_003709 478 | 2008_003733 479 | 2008_003777 480 | 2008_003782 481 | 2008_003821 482 | 2008_003846 483 | 2008_003856 484 | 2008_003858 485 | 2008_003874 486 | 2008_003876 487 | 2008_003885 488 | 2008_003886 489 | 2008_003926 490 | 2008_003976 491 | 2008_004069 492 | 2008_004101 493 | 2008_004140 494 | 2008_004172 495 | 2008_004175 496 | 2008_004212 497 | 2008_004279 498 | 2008_004339 499 | 2008_004345 500 | 2008_004363 501 | 2008_004367 502 | 2008_004396 503 | 2008_004399 504 | 2008_004453 505 | 2008_004477 506 | 2008_004552 507 | 2008_004562 508 | 2008_004575 509 | 2008_004610 510 | 2008_004612 511 | 2008_004621 512 | 2008_004624 513 | 2008_004654 514 | 2008_004659 515 | 2008_004687 516 | 2008_004701 517 | 2008_004704 518 | 2008_004705 519 | 2008_004754 520 | 2008_004758 521 | 2008_004854 522 | 2008_004910 523 | 2008_004995 524 | 2008_005049 525 | 2008_005089 526 | 2008_005097 527 | 2008_005105 528 | 
2008_005145 529 | 2008_005197 530 | 2008_005217 531 | 2008_005242 532 | 2008_005245 533 | 2008_005254 534 | 2008_005262 535 | 2008_005338 536 | 2008_005398 537 | 2008_005399 538 | 2008_005422 539 | 2008_005439 540 | 2008_005445 541 | 2008_005525 542 | 2008_005544 543 | 2008_005628 544 | 2008_005633 545 | 2008_005637 546 | 2008_005642 547 | 2008_005676 548 | 2008_005680 549 | 2008_005691 550 | 2008_005727 551 | 2008_005738 552 | 2008_005812 553 | 2008_005904 554 | 2008_005915 555 | 2008_006008 556 | 2008_006036 557 | 2008_006055 558 | 2008_006063 559 | 2008_006108 560 | 2008_006130 561 | 2008_006143 562 | 2008_006159 563 | 2008_006216 564 | 2008_006219 565 | 2008_006229 566 | 2008_006254 567 | 2008_006275 568 | 2008_006325 569 | 2008_006327 570 | 2008_006341 571 | 2008_006408 572 | 2008_006480 573 | 2008_006523 574 | 2008_006526 575 | 2008_006528 576 | 2008_006553 577 | 2008_006554 578 | 2008_006703 579 | 2008_006722 580 | 2008_006752 581 | 2008_006784 582 | 2008_006835 583 | 2008_006874 584 | 2008_006981 585 | 2008_006986 586 | 2008_007025 587 | 2008_007031 588 | 2008_007048 589 | 2008_007120 590 | 2008_007123 591 | 2008_007143 592 | 2008_007194 593 | 2008_007219 594 | 2008_007273 595 | 2008_007350 596 | 2008_007378 597 | 2008_007392 598 | 2008_007402 599 | 2008_007497 600 | 2008_007498 601 | 2008_007507 602 | 2008_007513 603 | 2008_007527 604 | 2008_007548 605 | 2008_007596 606 | 2008_007677 607 | 2008_007737 608 | 2008_007797 609 | 2008_007804 610 | 2008_007811 611 | 2008_007814 612 | 2008_007828 613 | 2008_007836 614 | 2008_007945 615 | 2008_007994 616 | 2008_008051 617 | 2008_008103 618 | 2008_008127 619 | 2008_008221 620 | 2008_008252 621 | 2008_008268 622 | 2008_008296 623 | 2008_008301 624 | 2008_008335 625 | 2008_008362 626 | 2008_008392 627 | 2008_008393 628 | 2008_008421 629 | 2008_008434 630 | 2008_008469 631 | 2008_008629 632 | 2008_008682 633 | 2008_008711 634 | 2008_008746 635 | 2009_000012 636 | 2009_000013 637 | 2009_000022 638 | 2009_000032 639 | 2009_000037 640 | 2009_000039 641 | 2009_000074 642 | 2009_000080 643 | 2009_000087 644 | 2009_000096 645 | 2009_000121 646 | 2009_000136 647 | 2009_000149 648 | 2009_000156 649 | 2009_000201 650 | 2009_000205 651 | 2009_000219 652 | 2009_000242 653 | 2009_000309 654 | 2009_000318 655 | 2009_000335 656 | 2009_000351 657 | 2009_000354 658 | 2009_000387 659 | 2009_000391 660 | 2009_000412 661 | 2009_000418 662 | 2009_000421 663 | 2009_000426 664 | 2009_000440 665 | 2009_000446 666 | 2009_000455 667 | 2009_000457 668 | 2009_000469 669 | 2009_000487 670 | 2009_000488 671 | 2009_000523 672 | 2009_000573 673 | 2009_000619 674 | 2009_000628 675 | 2009_000641 676 | 2009_000664 677 | 2009_000675 678 | 2009_000704 679 | 2009_000705 680 | 2009_000712 681 | 2009_000716 682 | 2009_000723 683 | 2009_000727 684 | 2009_000730 685 | 2009_000731 686 | 2009_000732 687 | 2009_000771 688 | 2009_000825 689 | 2009_000828 690 | 2009_000839 691 | 2009_000840 692 | 2009_000845 693 | 2009_000879 694 | 2009_000892 695 | 2009_000919 696 | 2009_000924 697 | 2009_000931 698 | 2009_000935 699 | 2009_000964 700 | 2009_000989 701 | 2009_000991 702 | 2009_000998 703 | 2009_001008 704 | 2009_001082 705 | 2009_001108 706 | 2009_001160 707 | 2009_001215 708 | 2009_001240 709 | 2009_001255 710 | 2009_001278 711 | 2009_001299 712 | 2009_001300 713 | 2009_001314 714 | 2009_001332 715 | 2009_001333 716 | 2009_001363 717 | 2009_001391 718 | 2009_001411 719 | 2009_001433 720 | 2009_001505 721 | 2009_001535 722 | 2009_001536 723 | 2009_001565 724 | 2009_001607 725 | 
2009_001644 726 | 2009_001663 727 | 2009_001683 728 | 2009_001684 729 | 2009_001687 730 | 2009_001718 731 | 2009_001731 732 | 2009_001765 733 | 2009_001768 734 | 2009_001775 735 | 2009_001804 736 | 2009_001816 737 | 2009_001818 738 | 2009_001850 739 | 2009_001851 740 | 2009_001854 741 | 2009_001941 742 | 2009_001991 743 | 2009_002012 744 | 2009_002035 745 | 2009_002042 746 | 2009_002082 747 | 2009_002094 748 | 2009_002097 749 | 2009_002122 750 | 2009_002150 751 | 2009_002155 752 | 2009_002164 753 | 2009_002165 754 | 2009_002171 755 | 2009_002185 756 | 2009_002202 757 | 2009_002221 758 | 2009_002238 759 | 2009_002239 760 | 2009_002265 761 | 2009_002268 762 | 2009_002291 763 | 2009_002295 764 | 2009_002317 765 | 2009_002320 766 | 2009_002346 767 | 2009_002366 768 | 2009_002372 769 | 2009_002382 770 | 2009_002390 771 | 2009_002415 772 | 2009_002445 773 | 2009_002487 774 | 2009_002521 775 | 2009_002527 776 | 2009_002535 777 | 2009_002539 778 | 2009_002549 779 | 2009_002562 780 | 2009_002568 781 | 2009_002571 782 | 2009_002573 783 | 2009_002584 784 | 2009_002591 785 | 2009_002594 786 | 2009_002604 787 | 2009_002618 788 | 2009_002635 789 | 2009_002638 790 | 2009_002649 791 | 2009_002651 792 | 2009_002727 793 | 2009_002732 794 | 2009_002749 795 | 2009_002753 796 | 2009_002771 797 | 2009_002808 798 | 2009_002856 799 | 2009_002887 800 | 2009_002888 801 | 2009_002928 802 | 2009_002936 803 | 2009_002975 804 | 2009_002982 805 | 2009_002990 806 | 2009_003003 807 | 2009_003005 808 | 2009_003043 809 | 2009_003059 810 | 2009_003063 811 | 2009_003065 812 | 2009_003071 813 | 2009_003080 814 | 2009_003105 815 | 2009_003123 816 | 2009_003193 817 | 2009_003196 818 | 2009_003217 819 | 2009_003224 820 | 2009_003241 821 | 2009_003269 822 | 2009_003273 823 | 2009_003299 824 | 2009_003304 825 | 2009_003311 826 | 2009_003323 827 | 2009_003343 828 | 2009_003378 829 | 2009_003387 830 | 2009_003406 831 | 2009_003433 832 | 2009_003450 833 | 2009_003466 834 | 2009_003481 835 | 2009_003494 836 | 2009_003498 837 | 2009_003504 838 | 2009_003507 839 | 2009_003517 840 | 2009_003523 841 | 2009_003542 842 | 2009_003549 843 | 2009_003551 844 | 2009_003564 845 | 2009_003569 846 | 2009_003576 847 | 2009_003589 848 | 2009_003607 849 | 2009_003640 850 | 2009_003666 851 | 2009_003696 852 | 2009_003703 853 | 2009_003707 854 | 2009_003756 855 | 2009_003771 856 | 2009_003773 857 | 2009_003804 858 | 2009_003806 859 | 2009_003810 860 | 2009_003849 861 | 2009_003857 862 | 2009_003858 863 | 2009_003895 864 | 2009_003903 865 | 2009_003904 866 | 2009_003928 867 | 2009_003938 868 | 2009_003971 869 | 2009_003991 870 | 2009_004021 871 | 2009_004033 872 | 2009_004043 873 | 2009_004070 874 | 2009_004072 875 | 2009_004084 876 | 2009_004099 877 | 2009_004125 878 | 2009_004140 879 | 2009_004217 880 | 2009_004221 881 | 2009_004247 882 | 2009_004248 883 | 2009_004255 884 | 2009_004298 885 | 2009_004324 886 | 2009_004455 887 | 2009_004494 888 | 2009_004497 889 | 2009_004504 890 | 2009_004507 891 | 2009_004509 892 | 2009_004540 893 | 2009_004568 894 | 2009_004579 895 | 2009_004581 896 | 2009_004590 897 | 2009_004592 898 | 2009_004594 899 | 2009_004635 900 | 2009_004653 901 | 2009_004687 902 | 2009_004721 903 | 2009_004730 904 | 2009_004732 905 | 2009_004738 906 | 2009_004748 907 | 2009_004789 908 | 2009_004799 909 | 2009_004801 910 | 2009_004848 911 | 2009_004859 912 | 2009_004867 913 | 2009_004882 914 | 2009_004886 915 | 2009_004895 916 | 2009_004942 917 | 2009_004969 918 | 2009_004987 919 | 2009_004993 920 | 2009_004994 921 | 2009_005038 922 | 
2009_005078 923 | 2009_005087 924 | 2009_005089 925 | 2009_005137 926 | 2009_005148 927 | 2009_005156 928 | 2009_005158 929 | 2009_005189 930 | 2009_005190 931 | 2009_005217 932 | 2009_005219 933 | 2009_005220 934 | 2009_005231 935 | 2009_005260 936 | 2009_005262 937 | 2009_005302 938 | 2010_000003 939 | 2010_000038 940 | 2010_000065 941 | 2010_000083 942 | 2010_000084 943 | 2010_000087 944 | 2010_000110 945 | 2010_000159 946 | 2010_000160 947 | 2010_000163 948 | 2010_000174 949 | 2010_000216 950 | 2010_000238 951 | 2010_000241 952 | 2010_000256 953 | 2010_000272 954 | 2010_000284 955 | 2010_000309 956 | 2010_000318 957 | 2010_000330 958 | 2010_000335 959 | 2010_000342 960 | 2010_000372 961 | 2010_000422 962 | 2010_000426 963 | 2010_000427 964 | 2010_000502 965 | 2010_000530 966 | 2010_000552 967 | 2010_000559 968 | 2010_000572 969 | 2010_000573 970 | 2010_000622 971 | 2010_000628 972 | 2010_000639 973 | 2010_000666 974 | 2010_000679 975 | 2010_000682 976 | 2010_000683 977 | 2010_000724 978 | 2010_000738 979 | 2010_000764 980 | 2010_000788 981 | 2010_000814 982 | 2010_000836 983 | 2010_000874 984 | 2010_000904 985 | 2010_000906 986 | 2010_000907 987 | 2010_000918 988 | 2010_000929 989 | 2010_000941 990 | 2010_000952 991 | 2010_000961 992 | 2010_001000 993 | 2010_001010 994 | 2010_001011 995 | 2010_001016 996 | 2010_001017 997 | 2010_001024 998 | 2010_001036 999 | 2010_001061 1000 | 2010_001069 1001 | 2010_001070 1002 | 2010_001079 1003 | 2010_001104 1004 | 2010_001124 1005 | 2010_001149 1006 | 2010_001151 1007 | 2010_001174 1008 | 2010_001206 1009 | 2010_001246 1010 | 2010_001251 1011 | 2010_001256 1012 | 2010_001264 1013 | 2010_001292 1014 | 2010_001313 1015 | 2010_001327 1016 | 2010_001331 1017 | 2010_001351 1018 | 2010_001367 1019 | 2010_001376 1020 | 2010_001403 1021 | 2010_001448 1022 | 2010_001451 1023 | 2010_001522 1024 | 2010_001534 1025 | 2010_001553 1026 | 2010_001557 1027 | 2010_001563 1028 | 2010_001577 1029 | 2010_001579 1030 | 2010_001646 1031 | 2010_001656 1032 | 2010_001692 1033 | 2010_001699 1034 | 2010_001734 1035 | 2010_001752 1036 | 2010_001767 1037 | 2010_001768 1038 | 2010_001773 1039 | 2010_001820 1040 | 2010_001830 1041 | 2010_001851 1042 | 2010_001908 1043 | 2010_001913 1044 | 2010_001951 1045 | 2010_001956 1046 | 2010_001962 1047 | 2010_001966 1048 | 2010_001995 1049 | 2010_002017 1050 | 2010_002025 1051 | 2010_002030 1052 | 2010_002106 1053 | 2010_002137 1054 | 2010_002142 1055 | 2010_002146 1056 | 2010_002147 1057 | 2010_002150 1058 | 2010_002161 1059 | 2010_002200 1060 | 2010_002228 1061 | 2010_002232 1062 | 2010_002251 1063 | 2010_002271 1064 | 2010_002305 1065 | 2010_002310 1066 | 2010_002336 1067 | 2010_002348 1068 | 2010_002361 1069 | 2010_002390 1070 | 2010_002396 1071 | 2010_002422 1072 | 2010_002450 1073 | 2010_002480 1074 | 2010_002512 1075 | 2010_002531 1076 | 2010_002536 1077 | 2010_002538 1078 | 2010_002546 1079 | 2010_002623 1080 | 2010_002682 1081 | 2010_002691 1082 | 2010_002693 1083 | 2010_002701 1084 | 2010_002763 1085 | 2010_002792 1086 | 2010_002868 1087 | 2010_002900 1088 | 2010_002902 1089 | 2010_002921 1090 | 2010_002929 1091 | 2010_002939 1092 | 2010_002988 1093 | 2010_003014 1094 | 2010_003060 1095 | 2010_003123 1096 | 2010_003127 1097 | 2010_003132 1098 | 2010_003168 1099 | 2010_003183 1100 | 2010_003187 1101 | 2010_003207 1102 | 2010_003231 1103 | 2010_003239 1104 | 2010_003275 1105 | 2010_003276 1106 | 2010_003293 1107 | 2010_003302 1108 | 2010_003325 1109 | 2010_003362 1110 | 2010_003365 1111 | 2010_003381 1112 | 2010_003402 1113 | 
2010_003409 1114 | 2010_003418 1115 | 2010_003446 1116 | 2010_003453 1117 | 2010_003468 1118 | 2010_003473 1119 | 2010_003495 1120 | 2010_003506 1121 | 2010_003514 1122 | 2010_003531 1123 | 2010_003532 1124 | 2010_003541 1125 | 2010_003547 1126 | 2010_003597 1127 | 2010_003675 1128 | 2010_003708 1129 | 2010_003716 1130 | 2010_003746 1131 | 2010_003758 1132 | 2010_003764 1133 | 2010_003768 1134 | 2010_003771 1135 | 2010_003772 1136 | 2010_003781 1137 | 2010_003813 1138 | 2010_003820 1139 | 2010_003854 1140 | 2010_003912 1141 | 2010_003915 1142 | 2010_003947 1143 | 2010_003956 1144 | 2010_003971 1145 | 2010_004041 1146 | 2010_004042 1147 | 2010_004056 1148 | 2010_004063 1149 | 2010_004104 1150 | 2010_004120 1151 | 2010_004149 1152 | 2010_004165 1153 | 2010_004208 1154 | 2010_004219 1155 | 2010_004226 1156 | 2010_004314 1157 | 2010_004320 1158 | 2010_004322 1159 | 2010_004337 1160 | 2010_004348 1161 | 2010_004355 1162 | 2010_004369 1163 | 2010_004382 1164 | 2010_004419 1165 | 2010_004432 1166 | 2010_004472 1167 | 2010_004479 1168 | 2010_004519 1169 | 2010_004520 1170 | 2010_004529 1171 | 2010_004543 1172 | 2010_004550 1173 | 2010_004551 1174 | 2010_004556 1175 | 2010_004559 1176 | 2010_004628 1177 | 2010_004635 1178 | 2010_004662 1179 | 2010_004697 1180 | 2010_004757 1181 | 2010_004763 1182 | 2010_004772 1183 | 2010_004783 1184 | 2010_004789 1185 | 2010_004795 1186 | 2010_004815 1187 | 2010_004825 1188 | 2010_004828 1189 | 2010_004856 1190 | 2010_004857 1191 | 2010_004861 1192 | 2010_004941 1193 | 2010_004946 1194 | 2010_004951 1195 | 2010_004980 1196 | 2010_004994 1197 | 2010_005013 1198 | 2010_005021 1199 | 2010_005046 1200 | 2010_005063 1201 | 2010_005108 1202 | 2010_005118 1203 | 2010_005159 1204 | 2010_005160 1205 | 2010_005166 1206 | 2010_005174 1207 | 2010_005180 1208 | 2010_005187 1209 | 2010_005206 1210 | 2010_005245 1211 | 2010_005252 1212 | 2010_005284 1213 | 2010_005305 1214 | 2010_005344 1215 | 2010_005353 1216 | 2010_005366 1217 | 2010_005401 1218 | 2010_005421 1219 | 2010_005428 1220 | 2010_005432 1221 | 2010_005433 1222 | 2010_005496 1223 | 2010_005501 1224 | 2010_005508 1225 | 2010_005531 1226 | 2010_005534 1227 | 2010_005575 1228 | 2010_005582 1229 | 2010_005606 1230 | 2010_005626 1231 | 2010_005644 1232 | 2010_005664 1233 | 2010_005705 1234 | 2010_005706 1235 | 2010_005709 1236 | 2010_005718 1237 | 2010_005719 1238 | 2010_005727 1239 | 2010_005762 1240 | 2010_005788 1241 | 2010_005860 1242 | 2010_005871 1243 | 2010_005877 1244 | 2010_005888 1245 | 2010_005899 1246 | 2010_005922 1247 | 2010_005991 1248 | 2010_005992 1249 | 2010_006026 1250 | 2010_006034 1251 | 2010_006054 1252 | 2010_006070 1253 | 2011_000045 1254 | 2011_000051 1255 | 2011_000054 1256 | 2011_000066 1257 | 2011_000070 1258 | 2011_000112 1259 | 2011_000173 1260 | 2011_000178 1261 | 2011_000185 1262 | 2011_000226 1263 | 2011_000234 1264 | 2011_000238 1265 | 2011_000239 1266 | 2011_000248 1267 | 2011_000283 1268 | 2011_000291 1269 | 2011_000310 1270 | 2011_000312 1271 | 2011_000338 1272 | 2011_000396 1273 | 2011_000412 1274 | 2011_000419 1275 | 2011_000435 1276 | 2011_000436 1277 | 2011_000438 1278 | 2011_000455 1279 | 2011_000456 1280 | 2011_000479 1281 | 2011_000481 1282 | 2011_000482 1283 | 2011_000503 1284 | 2011_000512 1285 | 2011_000521 1286 | 2011_000526 1287 | 2011_000536 1288 | 2011_000548 1289 | 2011_000566 1290 | 2011_000585 1291 | 2011_000598 1292 | 2011_000607 1293 | 2011_000618 1294 | 2011_000638 1295 | 2011_000658 1296 | 2011_000661 1297 | 2011_000669 1298 | 2011_000747 1299 | 2011_000780 1300 | 
2011_000789 1301 | 2011_000807 1302 | 2011_000809 1303 | 2011_000813 1304 | 2011_000830 1305 | 2011_000843 1306 | 2011_000874 1307 | 2011_000888 1308 | 2011_000900 1309 | 2011_000912 1310 | 2011_000953 1311 | 2011_000969 1312 | 2011_001005 1313 | 2011_001014 1314 | 2011_001020 1315 | 2011_001047 1316 | 2011_001060 1317 | 2011_001064 1318 | 2011_001069 1319 | 2011_001071 1320 | 2011_001082 1321 | 2011_001110 1322 | 2011_001114 1323 | 2011_001159 1324 | 2011_001161 1325 | 2011_001190 1326 | 2011_001232 1327 | 2011_001263 1328 | 2011_001276 1329 | 2011_001281 1330 | 2011_001287 1331 | 2011_001292 1332 | 2011_001313 1333 | 2011_001341 1334 | 2011_001346 1335 | 2011_001350 1336 | 2011_001407 1337 | 2011_001416 1338 | 2011_001421 1339 | 2011_001434 1340 | 2011_001447 1341 | 2011_001489 1342 | 2011_001529 1343 | 2011_001530 1344 | 2011_001534 1345 | 2011_001546 1346 | 2011_001567 1347 | 2011_001589 1348 | 2011_001597 1349 | 2011_001601 1350 | 2011_001607 1351 | 2011_001613 1352 | 2011_001614 1353 | 2011_001619 1354 | 2011_001624 1355 | 2011_001642 1356 | 2011_001665 1357 | 2011_001669 1358 | 2011_001674 1359 | 2011_001708 1360 | 2011_001713 1361 | 2011_001714 1362 | 2011_001722 1363 | 2011_001726 1364 | 2011_001745 1365 | 2011_001748 1366 | 2011_001775 1367 | 2011_001782 1368 | 2011_001793 1369 | 2011_001794 1370 | 2011_001812 1371 | 2011_001862 1372 | 2011_001863 1373 | 2011_001868 1374 | 2011_001880 1375 | 2011_001910 1376 | 2011_001984 1377 | 2011_001988 1378 | 2011_002002 1379 | 2011_002040 1380 | 2011_002041 1381 | 2011_002064 1382 | 2011_002075 1383 | 2011_002098 1384 | 2011_002110 1385 | 2011_002121 1386 | 2011_002124 1387 | 2011_002150 1388 | 2011_002156 1389 | 2011_002178 1390 | 2011_002200 1391 | 2011_002223 1392 | 2011_002244 1393 | 2011_002247 1394 | 2011_002279 1395 | 2011_002295 1396 | 2011_002298 1397 | 2011_002308 1398 | 2011_002317 1399 | 2011_002322 1400 | 2011_002327 1401 | 2011_002343 1402 | 2011_002358 1403 | 2011_002371 1404 | 2011_002379 1405 | 2011_002391 1406 | 2011_002498 1407 | 2011_002509 1408 | 2011_002515 1409 | 2011_002532 1410 | 2011_002535 1411 | 2011_002548 1412 | 2011_002575 1413 | 2011_002578 1414 | 2011_002589 1415 | 2011_002592 1416 | 2011_002623 1417 | 2011_002641 1418 | 2011_002644 1419 | 2011_002662 1420 | 2011_002675 1421 | 2011_002685 1422 | 2011_002713 1423 | 2011_002730 1424 | 2011_002754 1425 | 2011_002812 1426 | 2011_002863 1427 | 2011_002879 1428 | 2011_002885 1429 | 2011_002929 1430 | 2011_002951 1431 | 2011_002975 1432 | 2011_002993 1433 | 2011_002997 1434 | 2011_003003 1435 | 2011_003011 1436 | 2011_003019 1437 | 2011_003030 1438 | 2011_003055 1439 | 2011_003085 1440 | 2011_003103 1441 | 2011_003114 1442 | 2011_003145 1443 | 2011_003146 1444 | 2011_003182 1445 | 2011_003197 1446 | 2011_003205 1447 | 2011_003240 1448 | 2011_003256 1449 | 2011_003271 --------------------------------------------------------------------------------