├── README.md
├── config
│   ├── CityScapes_config_baseline.yaml
│   └── VOC_config_baseline.yaml
├── cross_label.py
├── generalframeworks
│   ├── augmentation
│   │   ├── __init__.py
│   │   └── transform.py
│   ├── dataset_helpers
│   │   ├── Cityscapes.py
│   │   ├── VOC.py
│   │   ├── __init__.py
│   │   └── __pycache__
│   │       ├── Cityscapes.cpython-37.pyc
│   │       ├── Cityscapes.cpython-38.pyc
│   │       ├── VOC.cpython-36.pyc
│   │       ├── VOC.cpython-37.pyc
│   │       ├── VOC.cpython-38.pyc
│   │       ├── __init__.cpython-36.pyc
│   │       ├── __init__.cpython-37.pyc
│   │       └── __init__.cpython-38.pyc
│   ├── loss
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── loss.cpython-37.pyc
│   │   │   └── loss.cpython-38.pyc
│   │   └── loss.py
│   ├── meter
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── meter.cpython-38.pyc
│   │   ├── mIOU_metrics.py
│   │   └── meter.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── ddp_model.cpython-37.pyc
│   │   │   ├── ddp_model.cpython-38.pyc
│   │   │   ├── module.cpython-38.pyc
│   │   │   ├── uncer_head.cpython-37.pyc
│   │   │   └── uncer_head.cpython-38.pyc
│   │   ├── ddp_model.py
│   │   ├── deeplabv3
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-37.pyc
│   │   │   │   ├── __init__.cpython-38.pyc
│   │   │   │   ├── aspp.cpython-37.pyc
│   │   │   │   ├── aspp.cpython-38.pyc
│   │   │   │   ├── deeplabv3.cpython-37.pyc
│   │   │   │   └── deeplabv3.cpython-38.pyc
│   │   │   ├── aspp.py
│   │   │   └── deeplabv3.py
│   │   ├── module.py
│   │   └── resnet.py
│   ├── scheduler
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── my_lr_scheduler.cpython-37.pyc
│   │   │   ├── my_lr_scheduler.cpython-38.pyc
│   │   │   ├── rampscheduler.cpython-37.pyc
│   │   │   └── rampscheduler.cpython-38.pyc
│   │   ├── my_lr_scheduler.py
│   │   └── rampscheduler.py
│   ├── util
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── dist_init.cpython-38.pyc
│   │   │   ├── meter.cpython-37.pyc
│   │   │   ├── meter.cpython-38.pyc
│   │   │   ├── miou.cpython-37.pyc
│   │   │   ├── miou.cpython-38.pyc
│   │   │   ├── torch_dist_sum.cpython-37.pyc
│   │   │   └── torch_dist_sum.cpython-38.pyc
│   │   ├── dist_init.py
│   │   ├── meter.py
│   │   ├── miou.py
│   │   └── torch_dist_sum.py
│   └── utils.py
├── mix_label.py
├── ori_pseudo.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
  1 | # Space Engage: Collaborative Space Supervision for Contrastive-based Semi-Supervised Semantic Segmentation (ICCV 2023)
  2 | ![cover figure.pdf](https://github.com/WangChangqi98/CSS/files/12594221/cover.figure.pdf)
  3 | 
  4 | This repository contains the code of **CSS** from the paper: [Space Engage: Collaborative Space Supervision for Contrastive-based Semi-Supervised Semantic Segmentation](https://arxiv.org/pdf/2307.09755.pdf)
  5 | 
  6 | In this paper, we propose a novel approach that uses the pseudo-labels from the logit space and the representation space in a collaborative way. Meanwhile, we use the softmax similarity as an indicator to tilt training in the representation space.
  7 | ## Updates
  8 | **Sep. 2023** -- Uploaded the code.
  9 | 
 10 | ## Prepare
 11 | CSS is evaluated with two datasets: PASCAL VOC 2012 and CityScapes.
 12 | - For PASCAL VOC, please download the original training images from the [official PASCAL site](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar): `VOCtrainval_11-May-2012.tar` and the augmented labels [here](http://vllab1.ucmerced.edu/~whung/adv-semi-seg/SegmentationClassAug.zip): `SegmentationClassAug.zip`.
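 13 | For example, a download-and-extract sketch (assuming `wget`, `tar`, and `unzip` are available; adjust the paths if the archive layout differs):
 14 | ```
 15 | mkdir -p data
 16 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
 17 | wget http://vllab1.ucmerced.edu/~whung/adv-semi-seg/SegmentationClassAug.zip
 18 | tar -xf VOCtrainval_11-May-2012.tar -C ./data
 19 | unzip -q SegmentationClassAug.zip -d ./data/VOCdevkit/VOC2012
 20 | ```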
 21 | Extract the folder `JPEGImages` and `SegmentationClassAug` as follows:
 22 | ```
 23 | ├── data
 24 | │   ├── VOCdevkit
 25 | │   │   ├──VOC2012
 26 | │   │   | ├──JPEGImages
 27 | │   │   | ├──SegmentationClassAug
 28 | ```
 29 | - For CityScapes, please download the original images and labels from the [official CityScapes site](https://www.cityscapes-dataset.com/downloads/): `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip`.
 30 | Extract `leftImg8bit_trainvaltest.zip` and `gtFine_trainvaltest.zip` as follows:
 31 | ```
 32 | ├── data
 33 | │   ├── cityscapes
 34 | │   │   ├──leftImg8bit
 35 | │   │   | ├──train
 36 | │   │   | ├──val
 37 | │   │   ├──gtFine
 38 | │   │   | ├──train
 39 | │   │   | ├──val
 40 | ```
 41 | Folders `train` and `val` under `leftImg8bit` contain the training and validation images, while folders `train` and `val` under `gtFine` contain the corresponding labels.
 42 | 
 43 | The data split folder of VOC and CityScapes is as follows:
 44 | ```
 45 | ├── VOC(CityScapes)_split
 46 | │   ├── labeled number
 47 | │   │   ├──seed
 48 | │   │   | ├──labeled_filename.txt
 49 | │   │   | ├──unlabeled_filename.txt
 50 | │   │   | ├──valid_filename.txt
 51 | ```
 52 | You need to change the folder names (labeled number and seed) according to your actual experiments.
 53 | 
 54 | CSS uses a ResNet-101 pretrained on ImageNet and a ResNet-101 with a deep stem block. Please download the weights from [here](https://download.pytorch.org/models/resnet101-63fe2227.pth) for ResNet-101 and [here](https://drive.google.com/file/d/131dWv_zbr1ADUr_8H6lNyuGWsItHygSb/view?usp=sharing) for ResNet-101 stem. Remember to change the checkpoint directory in the corresponding Python file.
 55 | 
 56 | In order to install the correct environment, please run the following script:
 57 | ```
 58 | conda create -n css_env python=3.8.5
 59 | conda activate css_env
 60 | pip install -r requirements.txt
 61 | ```
 62 | It may take a long time, so take a break and have a cup of coffee!
 63 | It is OK to install the environment manually instead; just remember to check the package versions CAREFULLY!
 64 | 
 65 | ## Run
 66 | You can run our code with multiple GPUs.
 67 | - For our baseline, please run the following script:
 68 | ```
 69 | python ori_pseudo.py
 70 | ```
 71 | - For the mix label strategy, please run the following script:
 72 | ```
 73 | python mix_label.py
 74 | ```
 75 | - For the cross label strategy, please run the following script:
 76 | ```
 77 | python cross_label.py
 78 | ```
 79 | The seed in our experiments is 3407. You can change the label rate and seed as you like; remember to change the corresponding config files and the data split directory accordingly.
 80 | ## Hyper-parameters
 81 | Some critical hyper-parameters used in the code are shown below:
 82 | |Name | Description | Value |
 83 | | :-: |:-:| :-:|
 84 | | `alpha_t` | update speed of the teacher model | `0.99` |
 85 | | `alpha_p` | update speed of the prototypes | `0.99` |
 86 | | `un_threshold` | threshold in the unsupervised loss | `0.97` |
 87 | | `weak_threshold` | weak threshold in the contrastive loss | `0.7` |
 88 | | `strong_threshold` | strong threshold in the contrastive loss | `0.8` |
 89 | | `temp` | temperature in the contrastive loss | `0.5` |
 90 | | `num_queries` | number of queries in the contrastive loss | `256` |
 91 | | `num_negatives` | number of negatives in the contrastive loss | `512` |
 92 | | `warm_up` | warm-up epochs, if needed | `20` |
 93 | 
 94 | ## Acknowledgement
 95 | The data processing and augmentation (CutMix, CutOut, and ClassMix) are borrowed from ReCo.
 96 | - ReCo: https://github.com/lorenmt/reco
 97 | 
 98 | Thanks a lot for their splendid work!
 99 | 
100 | ## Citation
101 | If you find this work useful for your research, please consider citing the following:
102 | ```
103 | @article{wang2023space,
104 |   title={Space Engage: Collaborative Space Supervision for Contrastive-based Semi-Supervised Semantic Segmentation},
105 |   author={Wang, Changqi and Xie, Haoyu and Yuan, Yuhui and Fu, Chong and Yue, Xiangyu},
106 |   journal={arXiv preprint arXiv:2307.09755},
107 |   year={2023}
108 | }
109 | ```
110 | 
111 | ## Contact
112 | If you have any questions or run into any problems, please feel free to contact us.
113 | - Changqi Wang, [wangchangqi98@gmail.com](mailto:wangchangqi98@gmail.com)
114 | 
--------------------------------------------------------------------------------
/config/CityScapes_config_baseline.yaml:
--------------------------------------------------------------------------------
 1 | Network:
 2 |   name: DeepLabv3Plus
 3 |   num_class: 19
 4 | 
 5 | EMA:
 6 |   alpha: 0.99
 7 | 
 8 | Optim:
 9 |   lr: 6.4e-3
10 |   weight_decay: 5e-4
11 | 
12 | Lr_Scheduler:
13 |   name: PolyLR
14 |   step_size: 90
15 |   gamma: 0.1
16 | 
17 | Dataset:
18 |   name: CityScapes
19 |   data_dir: /
20 |   txt_dir: /
21 |   num_labels: 372
22 |   batch_size: 4
23 |   mix_mode: cutmix
24 |   crop_size: !!python/tuple [769,769]
25 |   scale_size: !!python/tuple [0.5,2.0]
26 | 
27 | Training_Setting:
28 |   epoch: 200
29 |   save_dir: ./checkpoints
30 | 
31 | Seed: 3407
32 | 
33 | Ramp_Scheduler:
34 |   begin_epoch: 0
35 |   max_epoch: 200
36 |   max_value: 1.0
37 |   min_value: 0
38 |   ramp_mult: -5.0
39 | 
40 | Loss:
41 |   is_available: True
42 |   warm_up: 0
43 |   un_threshold: 0.97
44 |   strong_threshold: 0.97
45 |   weak_threshold: 0.7
46 |   temp: 0.5
47 |   num_queries: 256
48 |   num_negatives: 512
49 |   alpha: 0.99
50 | 
51 | Distributed:
52 |   world_size: 4
53 |   gpu_id: 0,1,2,3
--------------------------------------------------------------------------------
/config/VOC_config_baseline.yaml:
--------------------------------------------------------------------------------
 1 | Network:
 2 |   name: DeepLabv3Plus
 3 |   num_class: 21
 4 | 
 5 | EMA:
 6 |   alpha: 0.99
 7 | 
 8 | Optim:
 9 |   lr: 6.4e-3
10 |   weight_decay: 5e-4
11 | 
12 | Lr_Scheduler:
13 |   name: PolyLR
14 |   step_size: 90
15 |   gamma: 0.1
16 | 
17 | Dataset:
18 |   name: VOC
19 |   data_dir: /
20 |   txt_dir: /
21 |   num_labels: 331
22 |   batch_size: 8
23 |   crop_size: !!python/tuple [512,512]
24 |   scale_size: !!python/tuple [0.5,1.5]
25 |   mix_mode: cutmix
26 | 
27 | Training_Setting:
28 |   epoch: 200
29 |   save_dir: ./checkpoints/
30 | 
31 | Seed: 3407
32 | 
33 | Loss:
34 |   is_available: True
35 |   warm_up: 0
36 |   un_threshold: 0.97
37 |   strong_threshold: 0.97
38 |   weak_threshold: 0.7
39 |   temp: 0.5
40 |   num_queries: 256
41 |   num_negatives: 512
42 |   alpha: 0.99
43 | 
44 | Ramp_Scheduler:
45 |   begin_epoch: 0
46 |   max_epoch: 200
47 |   max_value: 1.0
48 |   min_value: 0
49 |   ramp_mult: -5.0
50 | 
51 | Distributed:
52 |   world_size: 2
53 |   gpu_id: 0,1
54 | 
--------------------------------------------------------------------------------
/cross_label.py:
--------------------------------------------------------------------------------
  1 | import shutup
  2 | shutup.please()
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | from torch.nn.parallel import DistributedDataParallel
  7 | import torch.distributed as dist
  8 | import torch.multiprocessing as mp
  9 | from generalframeworks.dataset_helpers.VOC import VOC_BuildData
 10 | from generalframeworks.dataset_helpers.Cityscapes import City_BuildData
 11 | from generalframeworks.networks.ddp_model import Model_cross
 12 | from 
generalframeworks.scheduler.my_lr_scheduler import PolyLR 13 | from generalframeworks.scheduler.rampscheduler import RampdownScheduler 14 | from generalframeworks.utils import iterator_, Logger 15 | from generalframeworks.util.meter import * 16 | from generalframeworks.utils import label_onehot 17 | from generalframeworks.util.torch_dist_sum import * 18 | from generalframeworks.util.miou import * 19 | from generalframeworks.util.dist_init import local_dist_init 20 | from generalframeworks.loss.loss import ProbOhemCrossEntropy2d, Attention_Threshold_Loss, Contrast_Loss 21 | import yaml 22 | import os 23 | import time 24 | import torchvision.models as models 25 | import argparse 26 | import random 27 | 28 | def main(rank, config, args): 29 | ##### Distribution init ##### 30 | local_dist_init(args, rank) 31 | print('Hello from rank {}\n'.format(rank)) 32 | 33 | ##### Load the dataset ##### 34 | if config['Dataset']['name'] == 'VOC': 35 | data = VOC_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 36 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 37 | if config['Dataset']['name'] == 'CityScapes': 38 | data = City_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 39 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 40 | train_l_dataset, train_u_dataset, test_dataset = data.build() 41 | train_l_sampler = torch.utils.data.distributed.DistributedSampler(train_l_dataset) 42 | train_l_loader = torch.utils.data.DataLoader(train_l_dataset, 43 | batch_size=config['Dataset']['batch_size'], 44 | pin_memory=True, 45 | sampler=train_l_sampler, 46 | num_workers=4, 47 | drop_last=True) 48 | train_u_sampler = torch.utils.data.distributed.DistributedSampler(train_u_dataset) 49 | train_u_loader = torch.utils.data.DataLoader(train_u_dataset, 50 | batch_size=config['Dataset']['batch_size'], 51 | pin_memory=True, 52 | sampler=train_u_sampler, 53 | num_workers=4, 54 | drop_last=True) 55 | test_loader = torch.utils.data.DataLoader(test_dataset, 56 | batch_size=config['Dataset']['batch_size'], 57 | pin_memory=True, 58 | num_workers=4) 59 | 60 | ##### Load the weight for each class ##### 61 | weight = torch.FloatTensor( 62 | [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 63 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 64 | 1.0865, 1.1529, 1.0507]).cuda() 65 | 66 | ##### Model init ##### 67 | backbone = models.resnet101() 68 | ckpt = torch.load('./pretrained/resnet101.pth', map_location='cpu') 69 | backbone.load_state_dict(ckpt) 70 | 71 | # for Resnet-101 stem users 72 | #backbone = resnet.resnet101(pretrained=True) 73 | 74 | model = Model_cross(backbone, num_classes=config['Network']['num_class'], output_dim=256, config=config, temp=args.temp).cuda() 75 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 76 | model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) 77 | 78 | ##### Loss init ##### 79 | criterion = {'sup_loss': ProbOhemCrossEntropy2d(ignore_label=-1, thresh=0.7, min_kept=50000 * config['Dataset']['batch_size']).cuda(), 80 | 'ce_loss': nn.CrossEntropyLoss(ignore_index=-1).cuda(), 81 | 'unsup_loss': Attention_Threshold_Loss(strong_threshold=args.un_threshold).cuda(), 82 | 'contrast_loss': Contrast_Loss(strong_threshold=args.strong_threshold, 83 | num_queries=config['Loss']['num_queries'], 84 | num_negatives=config['Loss']['num_negatives'], 
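                                                           # `temp` is the contrastive temperature (0.5) and `alpha` the
                                                           # prototype EMA rate (0.99); both mirror the Loss section of the config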
85 | temp=config['Loss']['temp'], 86 | alpha=config['Loss']['alpha']).cuda(), 87 | } 88 | 89 | ##### Prototype init ##### 90 | global prototypes 91 | 92 | prototypes = torch.zeros(config['Network']['num_class'], 256).cuda() 93 | if os.path.exists(args.prototypes_resume): 94 | print('prototypes resume from', args.prototypes_resume) 95 | checkpoint = torch.load(args.prototypes_resume, map_location='cpu') 96 | prototypes = torch.tensor(checkpoint['prototypes']).cuda() 97 | 98 | ##### Other init ##### 99 | optimizer = torch.optim.SGD(model.module.model.parameters(), 100 | lr=float(config['Optim']['lr']), weight_decay=float(config['Optim']['weight_decay']), momentum=0.9, nesterov=True) 101 | total_iter = args.total_iter 102 | total_epoch = int(total_iter / len(train_l_loader)) 103 | if dist.get_rank() == 0: 104 | print('total epoch is {}'.format(total_epoch)) 105 | lr_scheduler = PolyLR(optimizer, total_iter, min_lr=1e-4) 106 | 107 | if os.path.exists(args.resume): 108 | print('resume from', args.resume) 109 | checkpoint = torch.load(args.resume, map_location='cpu') 110 | model.module.model.load_state_dict(checkpoint['model']) 111 | model.module.ema_model.load_state_dict(checkpoint['ema_model']) 112 | optimizer.load_state_dict(checkpoint['optimizer']) 113 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 114 | start_epoch = checkpoint['epoch'] 115 | prototypes = torch.tensor(checkpoint['prototypes']).cuda() 116 | else: 117 | start_epoch = 0 118 | sche_d = RampdownScheduler(begin_epoch=config['Ramp_Scheduler']['begin_epoch'], 119 | max_epoch=config['Ramp_Scheduler']['max_epoch'], 120 | current_epoch=start_epoch, 121 | max_value=config['Ramp_Scheduler']['max_value'], 122 | min_value=config['Ramp_Scheduler']['min_value'], 123 | ramp_mult=config['Ramp_Scheduler']['ramp_mult']) 124 | 125 | # if dist.get_rank() == 0: 126 | # log = Logger(logFile='./log/' + str(args.job_name) + '.log') 127 | best_miou = 0 128 | 129 | model.module.model.train() 130 | model.module.ema_model.train() 131 | for epoch in range(start_epoch, total_epoch): 132 | train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, lr_scheduler, sche_d, config, args) 133 | miou = test(test_loader, model.module.ema_model, config) 134 | best_miou = max(best_miou, miou) 135 | if dist.get_rank() == 0: 136 | print('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}'.format(epoch, miou, best_miou, time.asctime(time.localtime(time.time())))) 137 | # log.write('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}\n'.format(epoch, miou, best_miou, time.asctime( time.localtime(time.time()) ))) 138 | # Save model 139 | if miou == best_miou: 140 | save_dir = './checkpoints/' + str(args.job_name) 141 | torch.save( 142 | { 143 | 'epoch': epoch+1, 144 | 'model': model.module.model.state_dict(), 145 | 'ema_model': model.module.ema_model.state_dict(), 146 | 'optimizer': optimizer.state_dict(), 147 | 'lr_scheduler': lr_scheduler.state_dict(), 148 | 'prototypes': prototypes.data.cpu().numpy(), 149 | }, os.path.join(save_dir, 'best_model.pth')) 150 | 151 | 152 | 153 | def train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, scheduler, sche_d, config, args): 154 | num_class = config['Network']['num_class'] 155 | # switch to train mode 156 | model.module.model.train() 157 | model.module.ema_model.train() 158 | 159 | train_u_loader.sampler.set_epoch(epoch) 160 | training_u_iter = iterator_(train_u_loader) 161 | train_l_loader.sampler.set_epoch(epoch) 162 | for i, (train_l_image, train_l_label) in enumerate(train_l_loader): 
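        # One step: forward the labeled and the unlabeled batch through the student and the EMA teacher.
        # The model returns upsampled predictions for both batches, pseudo-labels and confidences from the
        # logit space (*_cls) and the representation space (*_rep), and the representations/predictions
        # consumed by the contrastive loss below.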
163 |         train_l_image, train_l_label = train_l_image.cuda(), train_l_label.cuda()
164 |         train_u_image, train_u_label = training_u_iter.__next__()
165 |         train_u_image, train_u_label = train_u_image.cuda(), train_u_label.cuda()
166 |         pred_l_large, pred_u_large, train_u_aug_label_cls, train_u_aug_label_rep, train_u_aug_logits_cls, train_u_aug_logits_rep, rep_all, pred_all = model(train_l_image, train_u_image, prototypes)
167 | 
168 |         if config['Dataset']['name'] == 'VOC':
169 |             sup_loss = criterion['ce_loss'](pred_l_large, train_l_label)
170 |         else:
171 |             sup_loss = criterion['sup_loss'](pred_l_large, train_l_label)
172 |         if epoch < args.warmup:
173 |             unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label_cls, train_u_aug_logits_cls)
174 |         else:
175 |             unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label_rep, train_u_aug_logits_rep)
176 | 
177 |         ##### Contrastive learning #####
178 |         with torch.no_grad():
179 |             train_u_aug_mask = train_u_aug_logits_cls.ge(args.weak_threshold).float()
180 |             mask_all = torch.cat(((train_l_label.unsqueeze(1) >= 0).float(), train_u_aug_mask.unsqueeze(1)))
181 |             mask_all = F.interpolate(mask_all, size=pred_all.shape[2:], mode='nearest')
182 | 
183 |             label_l = F.interpolate(label_onehot(train_l_label, num_class), size=pred_all.shape[2:], mode='nearest')
184 |             label_u = F.interpolate(label_onehot(train_u_aug_label_cls, num_class), size=pred_all.shape[2:], mode='nearest')
185 |             label_all = torch.cat((label_l, label_u))
186 | 
187 |         contrast_loss = criterion['contrast_loss'](rep_all, label_all, mask_all, pred_all, prototypes)
188 | 
189 |         if args.sche:
190 |             total_loss = sup_loss + unsup_loss + contrast_loss * sche_d.value
191 |         else:
192 |             total_loss = sup_loss + unsup_loss + contrast_loss
193 | 
194 |         # Optimization step: update the student, then the EMA teacher and the schedulers
195 |         optimizer.zero_grad()
196 |         total_loss.backward()
197 |         optimizer.step()
198 |         model.module.ema_update()
199 |         scheduler.step()
200 |         sche_d.step()
201 | 
202 | @torch.no_grad()
203 | def test(test_loader, model, config):
204 |     batch_time = AverageMeter('Time', ':6.3f')
205 |     data_time = AverageMeter('Data', ':6.3f')
206 |     miou_meter = ConfMatrix(num_classes=config['Network']['num_class'], fmt=':6.4f', name='test_miou')
207 | 
208 |     # switch to eval mode
209 |     model.eval()
210 | 
211 |     end = time.time()
212 |     test_iter = iter(test_loader)
213 |     for _ in range(len(test_loader)):
214 |         data_time.update(time.time() - end)
215 |         test_image, test_label = next(test_iter)  # `test_iter.next()` is Python 2 style and fails on Python 3 iterators
216 |         test_image, test_label = test_image.cuda(), test_label.cuda()
217 | 
218 |         pred, _ = model(test_image)
219 |         pred = F.interpolate(pred, size=test_label.shape[1:], mode='bilinear', align_corners=True)
220 | 
221 |         miou_meter.update(pred.argmax(1).flatten(), test_label.flatten())
222 |         batch_time.update(time.time() - end)
223 |         end = time.time()
224 | 
225 |     mat = torch_dist_sum(dist.get_rank(), miou_meter.mat)
226 |     miou = mean_intersection_over_union(mat[0])
227 | 
228 |     return miou
229 | 
230 | 
231 | if __name__ == '__main__':
232 |     parser = argparse.ArgumentParser()
233 |     parser.add_argument('--config', type=str, default='./config/VOC_config_baseline.yaml')
234 |     parser.add_argument('--resume', type=str, default='')
235 |     parser.add_argument('--prototypes_resume', type=str, default='')
236 |     parser.add_argument('--num_labels', type=int, default=92)
237 |     parser.add_argument('--job_name', type=str, default='VOC_92_cross_label')
238 | 
239 |     # Distributed
240 |     parser.add_argument('--gpu_id', type=str, default='0,1,2,3')
241 |     parser.add_argument('--world_size', 
type=str, default='4')
242 |     parser.add_argument('--port', type=str, default='12301')
243 | 
244 |     # Hyperparameter
245 |     parser.add_argument('--strong_threshold', type=float, default=0.8)
246 |     parser.add_argument('--weak_threshold', type=float, default=0.7)
247 |     parser.add_argument('--un_threshold', type=float, default=0.97)
248 |     parser.add_argument('--temp', type=float, default=0.5)
249 |     parser.add_argument('--warmup', type=int, default=0)
250 |     parser.add_argument('--sche', type=bool, default=True)  # argparse quirk: with type=bool, any non-empty value parses as True
251 |     parser.add_argument('--total_iter', type=int, default=40000)  # referenced in main() but was never defined; the default is an assumed value, set it to your training budget
252 | 
253 |     args = parser.parse_args()
254 | 
255 |     ##### Config init #####
256 |     with open(args.config, 'r') as f:
257 |         config = yaml.load(f.read(), Loader=yaml.FullLoader)
258 |     save_dir = './checkpoints/' + str(args.job_name)
259 |     if not os.path.exists(save_dir):
260 |         os.makedirs(save_dir)
261 |     with open(save_dir + '/config.yaml', 'w') as f:
262 |         yaml.dump(config, f, default_flow_style=False)
263 |     print(config)
264 | 
265 |     ##### Init Seed #####
266 |     random.seed(config['Seed'])
267 |     torch.manual_seed(config['Seed'])
268 |     torch.backends.cudnn.deterministic = True
269 | 
270 |     mp.spawn(main, nprocs=int(args.world_size), args=(config, args))
271 | 
--------------------------------------------------------------------------------
/generalframeworks/augmentation/__init__.py:
--------------------------------------------------------------------------------
 1 | 
 2 | 
--------------------------------------------------------------------------------
/generalframeworks/augmentation/transform.py:
--------------------------------------------------------------------------------
  1 | 
  2 | from PIL import Image, ImageFilter
  3 | 
  4 | import torch
  5 | from typing import Tuple
  6 | from torchvision import transforms
  7 | import torchvision.transforms.functional as transform_f
  8 | import random
  9 | import numpy as np
 10 | 
 11 | def batch_transform(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, crop_size: Tuple[int, int], scale_size,
 12 |                     apply_augmentation=False):
 13 |     image_list, label_list, logits_list = [], [], []
 14 |     device = image.device
 15 | 
 16 |     for k in range(image.shape[0]):
 17 |         image_pil, label_pil, logits_pil = tensor_to_pil(image[k], label[k], logits[k])
 18 |         aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil,
 19 |                                                      crop_size=crop_size,
 20 |                                                      scale_size=scale_size,
 21 |                                                      augmentation=apply_augmentation)
 22 |         image_list.append(aug_image.unsqueeze(0))
 23 |         label_list.append(aug_label)
 24 |         logits_list.append(aug_logits)
 25 | 
 26 |     image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), \
 27 |                                              torch.cat(logits_list).to(device)
 28 |     return image_trans, label_trans, logits_trans
 29 | 
 30 | def tensor_to_pil(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor):
 31 |     image = denormalise(image)
 32 |     image = transform_f.to_pil_image(image.cpu())
 33 | 
 34 |     label = label.float() / 255.
 35 |     label = transform_f.to_pil_image(label.unsqueeze(0).cpu())
 36 | 
 37 |     logits = transform_f.to_pil_image(logits.unsqueeze(0).cpu())
 38 | 
 39 | 
 40 |     return image, label, logits
 41 | 
 42 | def tensor_to_pil_1(image: torch.Tensor, label: torch.Tensor, uncertainty_u: torch.Tensor, logits: torch.Tensor, logits_all: torch.Tensor):
 43 |     image = denormalise(image)
 44 |     image = transform_f.to_pil_image(image.cpu())
 45 | 
 46 |     label = label.float() / 255.
 47 |     label = transform_f.to_pil_image(label.unsqueeze(0).cpu())
 48 |     uncertainty_u = uncertainty_u.float() / 255.
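    # convert the scaled uncertainty map to a single-channel PIL image, mirroring the label handling above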
 49 |     uncertainty_u = transform_f.to_pil_image(uncertainty_u.unsqueeze(0).cpu())
 50 |     logits_all_l = []
 51 |     for i in range(logits_all.shape[0]):
 52 |         logits_all_l.append(transform_f.to_pil_image(logits_all[i].float().unsqueeze(0).cpu(), mode='F'))
 53 | 
 54 |     logits = transform_f.to_pil_image(logits.unsqueeze(0).cpu(), 'F')
 55 | 
 56 |     return image, label, uncertainty_u, logits, logits_all_l
 57 | 
 58 | 
 59 | def denormalise(x, imagenet=True):
 60 |     if imagenet:
 61 |         x = transform_f.normalize(x, mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
 62 |         x = transform_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.])
 63 |         return x
 64 |     else:
 65 |         return (x + 1) / 2
 66 | 
 67 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), label_fill=255, augmentation=False):
 68 |     '''
 69 |     Apply to a single 3-D image (one image, not a batch).
 70 |     '''
 71 |     # Random Rescale image
 72 |     raw_w, raw_h = image.size
 73 |     scale_ratio = random.uniform(scale_size[0], scale_size[1])
 74 | 
 75 |     resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio))
 76 |     image = transform_f.resize(image, resized_size, Image.NEAREST)  # note: NEAREST here, unlike the BILINEAR used in the other transform variants
 77 |     label = transform_f.resize(label, resized_size, Image.NEAREST)
 78 |     if logits is not None:
 79 |         logits = transform_f.resize(logits, resized_size, Image.NEAREST)
 80 | 
 81 |     # Adding padding if rescaled image size is less than crop size
 82 |     if crop_size == -1: # Use original image size without crop or padding
 83 |         crop_size = (raw_h, raw_w)
 84 | 
 85 |     if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]:
 86 |         right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0)
 87 |         image = transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect')
 88 |         label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=label_fill, padding_mode='constant')
 89 |         if logits is not None:
 90 |             logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant')
 91 | 
 92 |     # Random Cropping
 93 |     i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size)
 94 |     image = transform_f.crop(image, i, j, h, w)
 95 |     label = transform_f.crop(label, i, j, h, w)
 96 |     if logits is not None:
 97 |         logits = transform_f.crop(logits, i, j, h, w)
 98 | 
 99 |     if augmentation:
100 |         # Random Color jitter
101 |         if torch.rand(1) > 0.2:
102 |             color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25))
103 |             image = color_transform(image)
104 | 
105 |         # Random Gaussian filter
106 |         if torch.rand(1) > 0.5:
107 |             sigma = random.uniform(0.15, 1.15)
108 |             image = image.filter(ImageFilter.GaussianBlur(radius=sigma))
109 | 
110 |         # Random horizontal flipping
111 |         if torch.rand(1) > 0.5:
112 |             image = transform_f.hflip(image)
113 |             label = transform_f.hflip(label)
114 |             if logits is not None:
115 |                 logits = transform_f.hflip(logits)
116 | 
117 |     # Transform to Tensor
118 |     image = transform_f.to_tensor(image)
119 |     label = (transform_f.to_tensor(label) * 255).long()
120 |     label[label == 255] = -1 # invalid pixels are re-mapped to index -1
121 |     if logits is not None:
122 |         logits = transform_f.to_tensor(logits)
123 | 
124 |     # Apply (ImageNet) normalization
125 |     #image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
126 |     image = transform_f.normalize(image, mean=[0.5], std=[0.299])
127 |     if logits is not None:
128 |         return image, label, logits
129 |     else:
130 |         return image, 
label
131 | 
132 | def generate_cut(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'):
133 |     batch_size, _, image_h, image_w = image.shape
134 |     device = image.device
135 | 
136 |     new_image = []
137 |     new_label = []
138 |     new_logits = []
139 |     for i in range(batch_size):
140 |         if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0
141 |             mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device)
142 |             label[i][(1 - mix_mask).bool()] = -1
143 | 
144 |             new_image.append((image[i] * mix_mask).unsqueeze(0))
145 |             new_label.append(label[i].unsqueeze(0))
146 |             new_logits.append((logits[i] * mix_mask).unsqueeze(0))
147 |             continue
148 |         elif mode == 'cutmix':
149 |             mix_mask = generate_cutout_mask([image_h, image_w]).to(device)
150 |         elif mode == 'classmix':
151 |             mix_mask = generate_class_mask(label[i]).to(device)
152 |         else:
153 |             raise ValueError('mode must be in cutout, cutmix, or classmix')
154 | 
155 |         new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
156 |         new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
157 |         new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
158 |     new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits)
159 | 
160 |     return new_image, new_label.long(), new_logits
161 | 
162 | 
163 | 
164 | def generate_cutout_mask(image_size, ratio=2):
165 |     # Cutout: randomly generate a mask where the region inside is 0 and the one outside is 1
166 |     cutout_area = image_size[0] * image_size[1] / ratio
167 | 
168 |     w = np.random.randint(image_size[1] / ratio + 1, image_size[1])
169 |     h = np.round(cutout_area / w)
170 | 
171 |     x_start = np.random.randint(0, image_size[1] - w + 1)
172 |     y_start = np.random.randint(0, image_size[0] - h + 1)
173 | 
174 |     x_end = int(x_start + w)
175 |     y_end = int(y_start + h)
176 | 
177 |     mask = torch.ones(image_size)
178 |     mask[y_start: y_end, x_start: x_end] = 0
179 | 
180 |     return mask.float()
181 | 
182 | def generate_class_mask(pseudo_labels: torch.Tensor):
183 |     # select half of the classes and cover them up
184 |     labels = torch.unique(pseudo_labels) # all unique labels
185 |     labels_select: torch.Tensor = labels[torch.randperm(len(labels))][:len(labels) // 2] # Randomly select half of the labels
186 |     mask = (pseudo_labels.unsqueeze(-1) == labels_select).any(dim=-1)
187 |     return mask.float()
188 | 
189 | def batch_transform_1(data, label, uncertainty_u, logits, logits_all, crop_size, scale_size, apply_augmentation):
190 |     data_list, label_list, uncertainty_u_list, logits_list, logits_all_list = [], [], [], [], []
191 |     device = data.device
192 | 
193 |     for k in range(data.shape[0]):
194 |         data_pil, label_pil, uncertainty_u_pil, logits_pil, logits_all_pil = tensor_to_pil_1(data[k], label[k], uncertainty_u[k], logits[k], logits_all[k])##ok
195 |         aug_data, aug_label, aug_uncertainty_u, aug_logits, aug_logits_all = transform_1(data_pil, label_pil, uncertainty_u_pil, logits_pil, logits_all_pil,
196 |                                                                                          crop_size=crop_size,
197 |                                                                                          scale_size=scale_size,
198 |                                                                                          augmentation=apply_augmentation)
199 | 
200 | 
201 |         tmp = aug_label.squeeze(0).cuda().eq(aug_logits_all.cuda().argmax(0))  # debug consistency check (unused)
202 |         all = tmp.cuda().sum() + (aug_label.cuda() == -1).sum()  # debug consistency check (unused)
203 |         data_list.append(aug_data.unsqueeze(0))
204 |         label_list.append(aug_label)
205 |         uncertainty_u_list.append(aug_uncertainty_u)
206 |         logits_list.append(aug_logits)
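        # `logits` is a single-channel confidence map; `logits_all` below keeps every class channel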
207 | logits_all_list.append(aug_logits_all.unsqueeze(0)) 208 | #ok 209 | 210 | data_trans, label_trans, uncertainty_u_trans, logits_trans, logits_all_trans = \ 211 | torch.cat(data_list).to(device), torch.cat(label_list).to(device), torch.cat(uncertainty_u_list).to(device), torch.cat(logits_list).to(device), torch.cat(logits_all_list).to(device) 212 | return data_trans, label_trans, uncertainty_u_trans, logits_trans, logits_all_trans 213 | 214 | def transform_1(image, label, uncertainty_u=None, logits=None, logits_all=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 215 | # Random rescale image 216 | 217 | raw_w, raw_h = image.size 218 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 219 | 220 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 221 | image = transform_f.resize(image, resized_size, Image.BILINEAR) 222 | label = transform_f.resize(label, resized_size, Image.NEAREST) 223 | if uncertainty_u is not None: 224 | uncertainty_u = transform_f.resize(uncertainty_u, resized_size, Image.NEAREST) 225 | if logits is not None: 226 | logits = transform_f.resize(logits, resized_size, Image.NEAREST) 227 | logits_all_l = [] 228 | if logits_all is not None: 229 | for logits_item in logits_all: 230 | logits_all_l.append(transform_f.resize(logits_item, resized_size, Image.NEAREST)) 231 | logits_all = logits_all_l 232 | 233 | # Add padding if rescaled image size is less than crop size 234 | if crop_size == -1: # use original im size without crop or padding 235 | crop_size = (raw_h, raw_w) 236 | 237 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 238 | ##ok 239 | right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0) 240 | image = transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 241 | label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 242 | if uncertainty_u is not None: 243 | uncertainty_u = transform_f.pad(uncertainty_u, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 244 | if logits is not None: 245 | logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 246 | if logits_all is not None: 247 | logits_all_l_tmp = [] 248 | for logits_item in logits_all: 249 | logits_all_l_tmp.append(transform_f.pad(logits_item, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant')) 250 | logits_all = logits_all_l_tmp 251 | # ok 252 | 253 | 254 | # Random Cropping 255 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 256 | image = transform_f.crop(image, i, j, h, w) 257 | label = transform_f.crop(label, i, j, h, w) 258 | if uncertainty_u is not None: 259 | uncertainty_u = transform_f.crop(uncertainty_u, i, j, h, w) 260 | if logits is not None: 261 | logits = transform_f.crop(logits, i, j, h, w) 262 | if logits_all is not None: 263 | logits_all_l_tmp = [] 264 | for logits_item in logits_all: 265 | logits_all_l_tmp.append(transform_f.crop(logits_item, i, j, h, w)) 266 | logits_all = logits_all_l_tmp 267 | 268 | if augmentation: 269 | # Random color jitter 270 | if torch.rand(1) > 0.2: 271 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) # For PyTorch 1.9/TorchVision 0.10 users 272 | # color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 273 | image = color_transform(image) 274 | 
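        # Note: in torchvision >= 0.10, ColorJitter.get_params returns raw parameter values rather than a
        # callable transform, hence the ColorJitter module above is constructed and applied directly.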
275 | # Random Gaussian filter 276 | if torch.rand(1) > 0.5: 277 | sigma = random.uniform(0.15, 1.15) 278 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 279 | 280 | # Random horizontal flipping 281 | if torch.rand(1) > 0.5: 282 | image = transform_f.hflip(image) 283 | label = transform_f.hflip(label) 284 | if uncertainty_u is not None: 285 | uncertainty_u = transform_f.hflip(uncertainty_u) 286 | if logits is not None: 287 | logits = transform_f.hflip(logits) 288 | if logits_all is not None: 289 | logits_all_l_tmp = [] 290 | for logits_item in logits_all: 291 | logits_all_l_tmp.append(transform_f.hflip(logits_item)) 292 | logits_all = logits_all_l_tmp 293 | 294 | # Transform to tensor 295 | image = transform_f.to_tensor(image) 296 | label = (transform_f.to_tensor(label) * 255).long() 297 | uncertainty_u = (transform_f.to_tensor(uncertainty_u) * 255).long() 298 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 299 | if logits is not None: 300 | logits = transform_f.to_tensor(logits) 301 | if uncertainty_u is not None: 302 | uncertainty_u[uncertainty_u == 255] = -1 303 | if logits_all is not None: 304 | logits_all_l_tmp = [] 305 | for logits_item in logits_all: 306 | logits_all_l_tmp.append(transform_f.to_tensor(logits_item)) 307 | logits_all = torch.cat(logits_all_l_tmp) 308 | 309 | # Apply (ImageNet) normalisation 310 | # image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 311 | if logits is not None and uncertainty_u is not None and logits_all is not None: 312 | return image, label, uncertainty_u, logits, logits_all 313 | elif logits is not None and uncertainty_u is None: 314 | return image, label, logits 315 | elif logits is None and uncertainty_u is not None: 316 | return image, label, uncertainty_u 317 | else: 318 | return image, label 319 | 320 | def generate_cut_1(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, uncertainty_u: torch.Tensor=None, logits_all=None, mode='cutout'): 321 | batch_size, _, image_h, image_w = image.shape 322 | device = image.device 323 | 324 | new_image = [] 325 | new_label = [] 326 | new_uncertainty_u = [] 327 | new_logits = [] 328 | new_logits_all = [] 329 | for i in range(batch_size): 330 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 331 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 332 | label[i][(1 - mix_mask).bool()] = -1 333 | if uncertainty_u is not None: 334 | uncertainty_u[i][(1 - mix_mask).bool()] = 0 335 | 336 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 337 | new_label.append(label[i].unsqueeze(0)) 338 | if uncertainty_u is not None: 339 | new_uncertainty_u.append(uncertainty_u[i].unsqueeze(0)) 340 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 341 | continue 342 | elif mode == 'cutmix': 343 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 344 | elif mode == 'classmix': 345 | mix_mask = generate_class_mask(label[i]).to(device) 346 | else: 347 | raise ValueError('mode must be in cutout, cutmix, or classmix') 348 | 349 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 350 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 351 | if uncertainty_u is not None: 352 | new_uncertainty_u.append((uncertainty_u[i] * mix_mask + uncertainty_u[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 353 | if logits_all is not None: 354 | 
new_logits_all.append((logits_all[i] * mix_mask + logits_all[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 355 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 356 | 357 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 358 | 359 | if uncertainty_u is not None and logits_all is not None: 360 | new_uncertainty_u = torch.cat(new_uncertainty_u) 361 | new_logits_all = torch.cat(new_logits_all) 362 | 363 | return new_image, new_label.long(), new_uncertainty_u.long(), new_logits, new_logits_all 364 | else: 365 | return new_image, new_label.long(), new_logits 366 | 367 | 368 | def batch_transform_2(data, label, uncertainty_u, logits, crop_size, scale_size, apply_augmentation): 369 | data_list, label_list, uncertainty_u_list, logits_list = [], [], [], [] 370 | device = data.device 371 | 372 | for k in range(data.shape[0]): 373 | data_pil, label_pil, logits_pil = tensor_to_pil(data[k], label[k], logits[k]) 374 | aug_data, aug_label, aug_uncertainty_u, aug_logits = transform_2(data_pil, label_pil, uncertainty_u[k].unsqueeze(0), logits_pil, 375 | crop_size=crop_size, 376 | scale_size=scale_size, 377 | augmentation=apply_augmentation) 378 | data_list.append(aug_data.unsqueeze(0)) 379 | label_list.append(aug_label) 380 | # uncertainty_u_list.append(aug_uncertainty_u.unsqueeze(0)) 381 | uncertainty_u_list.append(aug_uncertainty_u) 382 | logits_list.append(aug_logits) 383 | 384 | data_trans, label_trans, uncertainty_u_trans, logits_trans = \ 385 | torch.cat(data_list).to(device), torch.cat(label_list).to(device), torch.cat(uncertainty_u_list).to(device), torch.cat(logits_list).to(device) 386 | return data_trans, label_trans, uncertainty_u_trans, logits_trans 387 | 388 | def transform_2(image, label, uncertainty_u=None, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 389 | # Random rescale image 390 | raw_w, raw_h = image.size 391 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 392 | 393 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 394 | image = transform_f.resize(image, resized_size, Image.BILINEAR) 395 | label = transform_f.resize(label, resized_size, Image.NEAREST) 396 | if uncertainty_u is not None: 397 | uncertainty_u = transform_f.resize(uncertainty_u, resized_size, Image.NEAREST) 398 | if logits is not None: 399 | logits = transform_f.resize(logits, resized_size, Image.NEAREST) 400 | 401 | # Add padding if rescaled image size is less than crop size 402 | if crop_size == -1: # use original im size without crop or padding 403 | crop_size = (raw_h, raw_w) 404 | 405 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 406 | right_pad, bottom_pad = max(crop_size[1] - resized_size[1], 0), max(crop_size[0] - resized_size[0], 0) 407 | image = transform_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 408 | label = transform_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 409 | if uncertainty_u is not None: 410 | uncertainty_u = transform_f.pad(uncertainty_u, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 411 | if logits is not None: 412 | logits = transform_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 413 | 414 | # Random Cropping 415 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 416 | image = transform_f.crop(image, i, j, h, w) 417 | label = 
transform_f.crop(label, i, j, h, w) 418 | if uncertainty_u is not None: 419 | uncertainty_u = transform_f.crop(uncertainty_u, i, j, h, w) 420 | if logits is not None: 421 | logits = transform_f.crop(logits, i, j, h, w) 422 | 423 | if augmentation: 424 | # Random color jitter 425 | if torch.rand(1) > 0.2: 426 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) # For PyTorch 1.9/TorchVision 0.10 users 427 | # color_transform = transforms.ColorJitter.get_params((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 428 | image = color_transform(image) 429 | 430 | # Random Gaussian filter 431 | if torch.rand(1) > 0.5: 432 | sigma = random.uniform(0.15, 1.15) 433 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 434 | 435 | # Random horizontal flipping 436 | if torch.rand(1) > 0.5: 437 | image = transform_f.hflip(image) 438 | label = transform_f.hflip(label) 439 | if uncertainty_u is not None: 440 | uncertainty_u = transform_f.hflip(uncertainty_u) 441 | if logits is not None: 442 | logits = transform_f.hflip(logits) 443 | 444 | # Transform to tensor 445 | image = transform_f.to_tensor(image) 446 | label = (transform_f.to_tensor(label) * 255).long() 447 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 448 | if logits is not None: 449 | logits = transform_f.to_tensor(logits) 450 | 451 | # Apply (ImageNet) normalisation 452 | image = transform_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 453 | if logits is not None and uncertainty_u is not None: 454 | return image, label, uncertainty_u, logits 455 | elif logits is not None and uncertainty_u is None: 456 | return image, label, logits 457 | elif logits is None and uncertainty_u is not None: 458 | return image, label, uncertainty_u 459 | else: 460 | return image, label 461 | 462 | def generate_cut_2(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, uncertainty_u: torch.Tensor=None, mode='cutout'): 463 | batch_size, _, image_h, image_w = image.shape 464 | device = image.device 465 | 466 | new_image = [] 467 | new_label = [] 468 | new_uncertainty_u = [] 469 | new_logits = [] 470 | for i in range(batch_size): 471 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 472 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 473 | label[i][(1 - mix_mask).bool()] = -1 474 | if uncertainty_u is not None: 475 | uncertainty_u[i][(1 - mix_mask).bool()] = 0 476 | 477 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 478 | new_label.append(label[i].unsqueeze(0)) 479 | if uncertainty_u is not None: 480 | new_uncertainty_u.append(uncertainty_u[i].unsqueeze(0)) 481 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 482 | continue 483 | elif mode == 'cutmix': 484 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 485 | elif mode == 'classmix': 486 | mix_mask = generate_class_mask(label[i]).to(device) 487 | else: 488 | raise ValueError('mode must be in cutout, cutmix, or classmix') 489 | 490 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 491 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 492 | if uncertainty_u is not None: 493 | new_uncertainty_u.append((uncertainty_u[i] * mix_mask + uncertainty_u[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 494 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] 
* (1 - mix_mask)).unsqueeze(0)) 495 | 496 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 497 | if uncertainty_u is not None: 498 | new_uncertainty_u = torch.cat(new_uncertainty_u) 499 | return new_image, new_label.long(), new_uncertainty_u, new_logits 500 | else: 501 | return new_image, new_label.long(), new_logits 502 | -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/Cityscapes.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import os 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as transforms_f 6 | import random 7 | from PIL import Image, ImageFilter 8 | import numpy as np 9 | 10 | class Cityscapes_Dataset_cache(data.Dataset): 11 | def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True, 12 | apply_partial=None, partial_seed=None): 13 | self.root = os.path.expanduser(root) 14 | self.train = train 15 | self.crop_size = crop_size 16 | self.augmentation = augmentation 17 | self.scale_size = scale_size 18 | self.idx_list = idx_list 19 | self.apply_partial = apply_partial 20 | self.partial_seed = partial_seed 21 | 22 | 23 | def __getitem__(self, index): 24 | if self.train: 25 | image_root, city_name = image_root_transform(self.idx_list[index], mode='train') 26 | image = Image.open(self.root + image_root) 27 | label_root = label_root_transform(self.idx_list[index], city_name, mode='train') 28 | label = Image.open(self.root + label_root) 29 | else: 30 | image_root, city_name = image_root_transform(self.idx_list[index], mode='val') 31 | image = Image.open(self.root + image_root) 32 | label_root = label_root_transform(self.idx_list[index], city_name, mode='val') 33 | label = Image.open(self.root + label_root) 34 | image, label = transform(image, label, None, self.crop_size, self.scale_size, self.augmentation) 35 | return image, label.squeeze(0) 36 | 37 | def __len__(self): 38 | return len(self.idx_list) 39 | 40 | class Cityscapes_Dataset(data.Dataset): 41 | def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True): 42 | self.root = os.path.expanduser(root) 43 | self.train = train 44 | self.crop_size = crop_size 45 | self.augmentation = augmentation 46 | self.scale_size = scale_size 47 | self.idx_list = idx_list 48 | 49 | def __getitem__(self, index): 50 | if self.train: 51 | image_root, city_name = image_root_transform(self.idx_list[index], mode='train') 52 | image = Image.open(self.root + image_root) 53 | label_root = label_root_transform(self.idx_list[index], city_name, mode='train') 54 | label = Image.open(self.root + label_root) 55 | else: 56 | image_root, city_name = image_root_transform(self.idx_list[index], mode='val') 57 | image = Image.open(self.root + image_root) 58 | label_root = label_root_transform(self.idx_list[index], city_name, mode='val') 59 | label = Image.open(self.root + label_root) 60 | image, label = transform(image, label, None, self.crop_size, self.scale_size, self.augmentation) 61 | return image, label.squeeze(0) 62 | 63 | def __len__(self): 64 | return len(self.idx_list) 65 | 66 | class City_BuildData(): 67 | def __init__(self, data_path, txt_path, label_num, seed, crop_size=[512,512]): 68 | self.data_path = data_path 69 | self.txt_path = txt_path 70 | self.label_num = label_num 71 | self.seed = seed 72 | 
self.im_size = [512, 1024] 73 | self.crop_size = crop_size 74 | self.num_segments = 19 75 | self.scale_size = (1.0, 1.0) 76 | self.train_l_idx, self.train_u_idx, self.test_idx= get_cityscapes_idx_via_txt(self.txt_path, self.label_num, self.seed) 77 | 78 | def build(self): 79 | train_l_dataset = Cityscapes_Dataset(self.data_path, self.train_l_idx, self.crop_size, self.scale_size, 80 | augmentation=True, train=True) 81 | train_u_dataset = Cityscapes_Dataset(self.data_path, self.train_u_idx, self.crop_size, scale_size=(1.0, 1.0), 82 | augmentation=False, train=True) 83 | test_dataset = Cityscapes_Dataset(self.data_path, self.test_idx, self.crop_size, scale_size=(1.0, 1.0),augmentation=False, 84 | train=False) 85 | return train_l_dataset, train_u_dataset, test_dataset 86 | 87 | def get_cityscapes_idx_via_txt(root, label_num, seed): 88 | ''' 89 | Read idx list via generated txt, pre-perform make_list.py 90 | ''' 91 | root = root + '/' + str(label_num) + '/' + str(seed) 92 | with open(root + '/labeled_filename.txt') as f: 93 | labeled_list = f.read().splitlines() 94 | f.close() 95 | with open(root + '/unlabeled_filename.txt') as f: 96 | unlabeled_list = f.read().splitlines() 97 | f.close() 98 | with open(root + '/valid_filename.txt') as f: 99 | test_list = f.read().splitlines() 100 | f.close() 101 | return labeled_list, unlabeled_list, test_list 102 | 103 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 104 | # Randomly rescale images 105 | raw_w, raw_h = image.size 106 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 107 | 108 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 109 | image = transforms_f.resize(image, resized_size, Image.BILINEAR) 110 | label = transforms_f.resize(label, resized_size, Image.NEAREST) 111 | if logits is not None: 112 | logits = transforms_f.resize(logits, resized_size, Image.NEAREST) 113 | 114 | # Add padding if rescaled image is smaller than crop size 115 | if crop_size == -1: # Use original image size 116 | crop_size = (raw_w, raw_h) 117 | 118 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 119 | right_pad = max(crop_size[1] - resized_size[1], 0) 120 | bottom_pad = max(crop_size[0] - resized_size[0], 0) 121 | image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 122 | label = transforms_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 123 | if logits is not None: 124 | logits = transforms_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 125 | 126 | # Randomly crop images 127 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 128 | image = transforms_f.crop(image, i, j, h, w) 129 | label = transforms_f.crop(label, i, j, h, w) 130 | if logits is not None: 131 | logits = transforms_f.crop(logits, i, j, h, w) 132 | 133 | if augmentation: 134 | # Random color jittering 135 | if torch.rand(1) > 0.2: 136 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 137 | image = color_transform(image) 138 | 139 | # Random Gaussian filtering 140 | if torch.rand(1) > 0.5: 141 | sigma = random.uniform(0.15, 1.15) 142 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 143 | 144 | # Random horizontal flipping 145 | if torch.rand(1) > 0.5: 146 | image = transforms_f.hflip(image) 147 | label = transforms_f.hflip(label) 148 | if logits is not None: 149 | logits = 
transforms_f.hflip(logits) 150 | 151 | # Transform to Tensor 152 | image = transforms_f.to_tensor(image) 153 | label = (transforms_f.to_tensor(label) * 255).long() 154 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 155 | if logits is not None: 156 | logits = transforms_f.to_tensor(logits) 157 | 158 | # Apply ImageNet normalization 159 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 160 | if logits is not None: 161 | return image, label, logits 162 | else: 163 | return image, label 164 | 165 | def tensor_to_pil(image, label, logits): 166 | image = denormalise(image) 167 | image = transforms_f.to_pil_image(image.cpu()) 168 | label = label.float() / 255. 169 | label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 170 | logits = transforms_f.to_pil_image(logits.unsqueeze(0).cpu()) 171 | return image, label, logits 172 | 173 | def denormalise(x, imagenet=True): 174 | if imagenet: 175 | x = transforms_f.normalize(x, mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]) 176 | x = transforms_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 177 | return x 178 | else: 179 | return (x + 1) / 2 180 | 181 | def batch_transform(images, labels, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 182 | image_list, label_list, logits_list = [], [], [] 183 | device = images.device 184 | for k in range(images.shape[0]): 185 | image_pil, label_pil, logits_pil = tensor_to_pil(images[k], labels[k], logits[k]) 186 | aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil, crop_size, scale_size, augmentation) 187 | image_list.append(aug_image.unsqueeze(0)) 188 | label_list.append(aug_label) 189 | logits_list.append(aug_logits) 190 | 191 | image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), torch.cat(logits_list).to(device) 192 | return image_trans, label_trans, logits_trans 193 | 194 | def cityscapes_class_map(mask): 195 | # source: https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 196 | mask_map = np.zeros_like(mask) 197 | mask_map[np.isin(mask, [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30])] = 255 198 | mask_map[np.isin(mask, [7])] = 0 199 | mask_map[np.isin(mask, [8])] = 1 200 | mask_map[np.isin(mask, [11])] = 2 201 | mask_map[np.isin(mask, [12])] = 3 202 | mask_map[np.isin(mask, [13])] = 4 203 | mask_map[np.isin(mask, [17])] = 5 204 | mask_map[np.isin(mask, [19])] = 6 205 | mask_map[np.isin(mask, [20])] = 7 206 | mask_map[np.isin(mask, [21])] = 8 207 | mask_map[np.isin(mask, [22])] = 9 208 | mask_map[np.isin(mask, [23])] = 10 209 | mask_map[np.isin(mask, [24])] = 11 210 | mask_map[np.isin(mask, [25])] = 12 211 | mask_map[np.isin(mask, [26])] = 13 212 | mask_map[np.isin(mask, [27])] = 14 213 | mask_map[np.isin(mask, [28])] = 15 214 | mask_map[np.isin(mask, [31])] = 16 215 | mask_map[np.isin(mask, [32])] = 17 216 | mask_map[np.isin(mask, [33])] = 18 217 | return mask_map 218 | 219 | def label_root_transform(root: str, name: str, mode: str): 220 | label_root = root.strip()[0: -12] + '_gtFine_trainIds' 221 | return '/gtFine/{}/{}/{}.png'.format(mode, name, label_root) 222 | 223 | def image_root_transform(root: str, mode: str): 224 | name = root[0: root.find('_')] 225 | return '/leftImg8bit/{}/{}/{}.png'.format(mode, name, root), name -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/VOC.py: 
-------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import os 4 | import torchvision.transforms as transforms 5 | import torchvision.transforms.functional as transforms_f 6 | import random 7 | from PIL import Image, ImageFilter 8 | import numpy as np 9 | import torch.distributed as dist 10 | 11 | class Pascal_VOC_Dataset(data.Dataset): 12 | def __init__(self, root, idx_list, crop_size=(512, 512), scale_size=(0.5, 2.0), augmentation=True, train=True): 13 | self.root = os.path.expanduser(root) 14 | self.train = train 15 | self.crop_size = crop_size 16 | self.augmentation = augmentation 17 | self.scale_size = scale_size 18 | self.idx_list = idx_list 19 | 20 | def __getitem__(self, index): 21 | image_root = Image.open(self.root + '/JPEGImages/{}.jpg'.format(self.idx_list[index])) 22 | label_root = Image.open(self.root + '/SegmentationClassAug/{}.png'.format(self.idx_list[index])) 23 | image, label = transform(image_root, label_root, None, crop_size=self.crop_size, scale_size=self.scale_size, augmentation=self.augmentation) 24 | return image, label.squeeze(0) 25 | 26 | def __len__(self): 27 | return len(self.idx_list) 28 | 29 | class VOC_BuildData(): 30 | def __init__(self, data_path, txt_path, label_num, seed, crop_size=[512,512]): 31 | self.data_path = data_path 32 | self.txt_path = txt_path 33 | self.image_size = [513, 513] 34 | self.crop_size = crop_size 35 | self.num_segments = 21 36 | self.scale_size = (0.5, 1.5) 37 | self.train_l_idx, self.train_u_idx, self.test_idx= get_pascal_idx_via_txt(self.txt_path, label_num=label_num, seed=seed) 38 | 39 | def build(self): 40 | train_l_dataset = Pascal_VOC_Dataset(self.data_path, self.train_l_idx, self.crop_size, self.scale_size, 41 | augmentation=True, train=True) 42 | train_u_dataset = Pascal_VOC_Dataset(self.data_path, self.train_u_idx, self.crop_size, scale_size=(1.0, 1.0), 43 | augmentation=False, train=True) 44 | test_dataset = Pascal_VOC_Dataset(self.data_path, self.test_idx, self.crop_size, scale_size=(1.0, 1.0),augmentation=False, 45 | train=False) 46 | return train_l_dataset, train_u_dataset, test_dataset 47 | 48 | def get_pascal_idx_via_txt(root, label_num, seed): 49 | ''' 50 | Read idx list via generated txt, pre-perform make_list.py 51 | ''' 52 | root = root + '/' + str(label_num) + '/' + str(seed) 53 | with open(root + '/labeled_filename.txt') as f: 54 | labeled_list = f.read().splitlines() 55 | f.close() 56 | with open(root + '/unlabeled_filename.txt') as f: 57 | unlabeled_list = f.read().splitlines() 58 | f.close() 59 | with open(root + '/valid_filename.txt') as f: 60 | test_list = f.read().splitlines() 61 | f.close() 62 | return labeled_list, unlabeled_list, test_list 63 | 64 | def transform(image, label, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 65 | # Randomly rescale images 66 | raw_w, raw_h = image.size 67 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 68 | 69 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 70 | image = transforms_f.resize(image, resized_size, Image.BILINEAR) 71 | label = transforms_f.resize(label, resized_size, Image.NEAREST) 72 | if logits is not None: 73 | logits = transforms_f.resize(logits, resized_size, Image.NEAREST) 74 | 75 | # Add padding if rescaled image is smaller than crop size 76 | if crop_size == -1: # Use original image size 77 | crop_size = (raw_w, raw_h) 78 | 79 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 80 | 
right_pad = max(crop_size[1] - resized_size[1], 0) 81 | bottom_pad = max(crop_size[0] - resized_size[0], 0) 82 | image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 83 | label = transforms_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 84 | if logits is not None: 85 | logits = transforms_f.pad(logits, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 86 | 87 | # Randomly crop images 88 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 89 | image = transforms_f.crop(image, i, j, h, w) 90 | label = transforms_f.crop(label, i, j, h, w) 91 | if logits is not None: 92 | logits = transforms_f.crop(logits, i, j, h, w) 93 | 94 | if augmentation: 95 | # Random color jittering 96 | if torch.rand(1) > 0.2: 97 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 98 | image = color_transform(image) 99 | 100 | # Random Gaussian filtering 101 | if torch.rand(1) > 0.5: 102 | sigma = random.uniform(0.15, 1.15) 103 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 104 | 105 | # Random horizontal flipping 106 | if torch.rand(1) > 0.5: 107 | image = transforms_f.hflip(image) 108 | label = transforms_f.hflip(label) 109 | if logits is not None: 110 | logits = transforms_f.hflip(logits) 111 | 112 | # Transform to Tensor 113 | image = transforms_f.to_tensor(image) 114 | label = (transforms_f.to_tensor(label) * 255).long() 115 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 116 | if logits is not None: 117 | logits = transforms_f.to_tensor(logits) 118 | 119 | # Apply ImageNet normalization 120 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 121 | if logits is not None: 122 | return image, label, logits 123 | else: 124 | return image, label 125 | 126 | def transform_2(image, label, logits1=None, logits2=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 127 | # Randomly rescale images 128 | raw_w, raw_h = image.size 129 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 130 | 131 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 132 | image = transforms_f.resize(image, resized_size, Image.BILINEAR) 133 | label = transforms_f.resize(label, resized_size, Image.NEAREST) 134 | if logits1 is not None: 135 | logits1 = transforms_f.resize(logits1, resized_size, Image.NEAREST) 136 | if logits2 is not None: 137 | logits2 = transforms_f.resize(logits2, resized_size, Image.NEAREST) 138 | 139 | # Add padding if rescaled image is smaller than crop size 140 | if crop_size == -1: # Use original image size 141 | crop_size = (raw_w, raw_h) 142 | 143 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 144 | right_pad = max(crop_size[1] - resized_size[1], 0) 145 | bottom_pad = max(crop_size[0] - resized_size[0], 0) 146 | image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 147 | label = transforms_f.pad(label, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 148 | if logits1 is not None: 149 | logits1 = transforms_f.pad(logits1, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 150 | if logits2 is not None: 151 | logits2 = transforms_f.pad(logits2, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 152 | 153 | # Randomly crop images 154 | i, j, h, w = transforms.RandomCrop.get_params(image, 
output_size=crop_size) 155 | image = transforms_f.crop(image, i, j, h, w) 156 | label = transforms_f.crop(label, i, j, h, w) 157 | if logits1 is not None: 158 | logits1 = transforms_f.crop(logits1, i, j, h, w) 159 | if logits2 is not None: 160 | logits2 = transforms_f.crop(logits2, i, j, h, w) 161 | 162 | if augmentation: 163 | # Random color jittering 164 | if torch.rand(1) > 0.2: 165 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 166 | image = color_transform(image) 167 | 168 | # Random Gaussian filtering 169 | if torch.rand(1) > 0.5: 170 | sigma = random.uniform(0.15, 1.15) 171 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 172 | 173 | # Random horizontal flipping 174 | if torch.rand(1) > 0.5: 175 | image = transforms_f.hflip(image) 176 | label = transforms_f.hflip(label) 177 | if logits1 is not None: 178 | logits1 = transforms_f.hflip(logits1) 179 | if logits2 is not None: 180 | logits2 = transforms_f.hflip(logits2) 181 | 182 | # Transform to Tensor 183 | image = transforms_f.to_tensor(image) 184 | label = (transforms_f.to_tensor(label) * 255).long() 185 | label[label == 255] = -1 # invalid pixels are re-mapped to index -1 186 | if logits1 is not None: 187 | logits1 = transforms_f.to_tensor(logits1) 188 | if logits2 is not None: 189 | logits2 = transforms_f.to_tensor(logits2) 190 | 191 | # Apply ImageNet normalization 192 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 193 | if logits1 is not None: 194 | return image, label, logits1, logits2 195 | else: 196 | return image, label 197 | 198 | def transform_3(image, label1, label2, logits1=None, logits2=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 199 | # Randomly rescale images 200 | raw_w, raw_h = image.size 201 | scale_ratio = random.uniform(scale_size[0], scale_size[1]) 202 | 203 | resized_size = (int(raw_h * scale_ratio), int(raw_w * scale_ratio)) 204 | image = transforms_f.resize(image, resized_size, Image.BILINEAR) 205 | label1 = transforms_f.resize(label1, resized_size, Image.NEAREST) 206 | label2 = transforms_f.resize(label2, resized_size, Image.NEAREST) 207 | if logits1 is not None: 208 | logits1 = transforms_f.resize(logits1, resized_size, Image.NEAREST) 209 | if logits2 is not None: 210 | logits2 = transforms_f.resize(logits2, resized_size, Image.NEAREST) 211 | 212 | # Add padding if rescaled image is smaller than crop size 213 | if crop_size == -1: # Use original image size 214 | crop_size = (raw_w, raw_h) 215 | 216 | if crop_size[0] > resized_size[0] or crop_size[1] > resized_size[1]: 217 | right_pad = max(crop_size[1] - resized_size[1], 0) 218 | bottom_pad = max(crop_size[0] - resized_size[0], 0) 219 | image = transforms_f.pad(image, padding=(0, 0, right_pad, bottom_pad), padding_mode='reflect') 220 | label1 = transforms_f.pad(label1, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 221 | label2 = transforms_f.pad(label2, padding=(0, 0, right_pad, bottom_pad), fill=255, padding_mode='constant') 222 | if logits1 is not None: 223 | logits1 = transforms_f.pad(logits1, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 224 | if logits2 is not None: 225 | logits2 = transforms_f.pad(logits2, padding=(0, 0, right_pad, bottom_pad), fill=0, padding_mode='constant') 226 | 227 | # Randomly crop images 228 | i, j, h, w = transforms.RandomCrop.get_params(image, output_size=crop_size) 229 | image = transforms_f.crop(image, i, j, h, w) 230 | label1 = 
transforms_f.crop(label1, i, j, h, w) 231 | label2 = transforms_f.crop(label2, i, j, h, w) 232 | if logits1 is not None: 233 | logits1 = transforms_f.crop(logits1, i, j, h, w) 234 | if logits2 is not None: 235 | logits2 = transforms_f.crop(logits2, i, j, h, w) 236 | 237 | if augmentation: 238 | # Random color jittering 239 | if torch.rand(1) > 0.2: 240 | color_transform = transforms.ColorJitter((0.75, 1.25), (0.75, 1.25), (0.75, 1.25), (-0.25, 0.25)) 241 | image = color_transform(image) 242 | 243 | # Random Gaussian filtering 244 | if torch.rand(1) > 0.5: 245 | sigma = random.uniform(0.15, 1.15) 246 | image = image.filter(ImageFilter.GaussianBlur(radius=sigma)) 247 | 248 | # Random horizontal flipping 249 | if torch.rand(1) > 0.5: 250 | image = transforms_f.hflip(image) 251 | label1 = transforms_f.hflip(label1) 252 | label2 = transforms_f.hflip(label2) 253 | if logits1 is not None: 254 | logits1 = transforms_f.hflip(logits1) 255 | if logits2 is not None: 256 | logits2 = transforms_f.hflip(logits2) 257 | 258 | # Transform to Tensor 259 | image = transforms_f.to_tensor(image) 260 | label1 = (transforms_f.to_tensor(label1) * 255).long() 261 | label2 = (transforms_f.to_tensor(label2) * 255).long() 262 | label1[label1 == 255] = -1 # invalid pixels are re-mapped to index -1 263 | label2[label2 == 255] = -1 # invalid pixels are re-mapped to index -1 264 | if logits1 is not None: 265 | logits1 = transforms_f.to_tensor(logits1) 266 | if logits2 is not None: 267 | logits2 = transforms_f.to_tensor(logits2) 268 | 269 | # Apply ImageNet normalization 270 | image = transforms_f.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 271 | if logits1 is not None: 272 | return image, label1, label2, logits1, logits2 273 | else: 274 | return image, label1 275 | 276 | def tensor_to_pil(image, label, logits): 277 | image = denormalise(image) 278 | image = transforms_f.to_pil_image(image.cpu()) 279 | label = label.float() / 255. 280 | label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 281 | logits = transforms_f.to_pil_image(logits.unsqueeze(0).cpu()) 282 | return image, label, logits 283 | 284 | def tensor_to_pil_2(image, label, logits1, logits2): 285 | image = denormalise(image) 286 | image = transforms_f.to_pil_image(image.cpu()) 287 | label = label.float() / 255. 288 | label = transforms_f.to_pil_image(label.unsqueeze(0).cpu()) 289 | logits1 = transforms_f.to_pil_image(logits1.unsqueeze(0).cpu()) 290 | logits2 = transforms_f.to_pil_image(logits2.unsqueeze(0).cpu()) 291 | return image, label, logits1, logits2 292 | 293 | def tensor_to_pil_3(image, label1, label2, logits1, logits2): 294 | image = denormalise(image) 295 | image = transforms_f.to_pil_image(image.cpu()) 296 | label1 = label1.float() / 255. 297 | label1 = transforms_f.to_pil_image(label1.unsqueeze(0).cpu()) 298 | label2 = label2.float() / 255. 
299 | label2 = transforms_f.to_pil_image(label2.unsqueeze(0).cpu()) 300 | logits1 = transforms_f.to_pil_image(logits1.unsqueeze(0).cpu()) 301 | logits2 = transforms_f.to_pil_image(logits2.unsqueeze(0).cpu()) 302 | return image, label1, label2, logits1, logits2 303 | 304 | def denormalise(x, imagenet=True): 305 | if imagenet: 306 | x = transforms_f.normalize(x, mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]) 307 | x = transforms_f.normalize(x, mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]) 308 | return x 309 | else: 310 | return (x + 1) / 2 311 | 312 | def batch_transform(images, labels, logits=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 313 | image_list, label_list, logits_list = [], [], [] 314 | device = images.device 315 | for k in range(images.shape[0]): 316 | image_pil, label_pil, logits_pil = tensor_to_pil(images[k], labels[k], logits[k]) 317 | aug_image, aug_label, aug_logits = transform(image_pil, label_pil, logits_pil, crop_size, scale_size, augmentation) 318 | image_list.append(aug_image.unsqueeze(0)) 319 | label_list.append(aug_label) 320 | logits_list.append(aug_logits) 321 | 322 | image_trans, label_trans, logits_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), torch.cat(logits_list).to(device) 323 | return image_trans, label_trans, logits_trans 324 | 325 | def batch_transform_2(images, labels, logits_1=None, logits_2=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 326 | image_list, label_list, logits_1_list, logits_2_list = [], [], [], [] 327 | device = images.device 328 | for k in range(images.shape[0]): 329 | image_pil, label_pil, logits_pil_1, logits_pil_2 = tensor_to_pil_2(images[k], labels[k], logits_1[k], logits_2[k]) 330 | aug_image, aug_label, aug_logits_1, aug_logits_2 = transform_2(image_pil, label_pil, logits_pil_1, logits_pil_2, crop_size, scale_size, augmentation) 331 | image_list.append(aug_image.unsqueeze(0)) 332 | label_list.append(aug_label) 333 | logits_1_list.append(aug_logits_1) 334 | logits_2_list.append(aug_logits_2) 335 | 336 | image_trans, label_trans, logits_1_trans, logits_2_trans = torch.cat(image_list).to(device), torch.cat(label_list).to(device), torch.cat(logits_1_list).to(device), torch.cat(logits_2_list).to(device) 337 | return image_trans, label_trans, logits_1_trans, logits_2_trans 338 | 339 | def batch_transform_3(images, labels1, labels2, logits_1=None, logits_2=None, crop_size=(512, 512), scale_size=(0.8, 1.0), augmentation=True): 340 | image_list, label1_list, label2_list, logits_1_list, logits_2_list = [], [], [], [], [] 341 | device = images.device 342 | for k in range(images.shape[0]): 343 | image_pil, label1_pil, label2_pil, logits_pil_1, logits_pil_2 = tensor_to_pil_3(images[k], labels1[k], labels2[k], logits_1[k], logits_2[k]) 344 | aug_image, aug_label1, aug_label2, aug_logits_1, aug_logits_2 = transform_3(image_pil, label1_pil, label2_pil, logits_pil_1, logits_pil_2, crop_size, scale_size, augmentation) 345 | image_list.append(aug_image.unsqueeze(0)) 346 | label1_list.append(aug_label1) 347 | label2_list.append(aug_label2) 348 | logits_1_list.append(aug_logits_1) 349 | logits_2_list.append(aug_logits_2) 350 | 351 | image_trans, label1_trans, label2_trans, logits_1_trans, logits_2_trans = torch.cat(image_list).to(device), torch.cat(label1_list).to(device), torch.cat(label2_list).to(device), torch.cat(logits_1_list).to(device), torch.cat(logits_2_list).to(device) 352 | return image_trans, label1_trans, label2_trans, logits_1_trans, logits_2_trans 353 | 
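# A minimal usage sketch of batch_transform (shapes and values below are hypothetical,
# not taken from the repo). batch_transform round-trips already-normalized tensors
# through PIL so that the PIL-based transform() above can re-augment pseudo-labelled
# batches, as done in ddp_model.py:
if __name__ == "__main__":
    _images = torch.randn(2, 3, 512, 512)          # already ImageNet-normalized crops
    _labels = torch.randint(0, 21, (2, 512, 512))  # non-negative argmax pseudo-labels
    _conf = torch.rand(2, 512, 512)                # per-pixel max softmax confidence
    _img, _lbl, _cnf = batch_transform(_images, _labels, _conf,
                                       crop_size=(512, 512), scale_size=(1.0, 1.0),
                                       augmentation=True)
    print(_img.shape, _lbl.shape, _cnf.shape)      # [2, 3, 512, 512] [2, 512, 512] [2, 512, 512]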
354 | def generate_cut_gather(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'): 355 | 356 | batch_size, _, image_h, image_w = image.shape 357 | image = concat_all_gather(image) 358 | label = concat_all_gather(label) 359 | logits = concat_all_gather(logits) 360 | total_size = image.shape[0] 361 | device = image.device 362 | rank = dist.get_rank() 363 | 364 | if mode == 'none': 365 | return image[rank * batch_size: (rank + 1) * batch_size], label[rank * batch_size: (rank + 1) * batch_size].long(), logits[rank * batch_size: (rank + 1) * batch_size] 366 | 367 | new_image = [] 368 | new_label = [] 369 | new_logits = [] 370 | for i in range(total_size): 371 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 372 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 373 | label[i][(1 - mix_mask).bool()] = -1 374 | 375 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 376 | new_label.append(label[i].unsqueeze(0)) 377 | new_logits.append((logits[i] * mix_mask).unsqueeze(0)) 378 | continue 379 | elif mode == 'cutmix': 380 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 381 | elif mode == 'classmix': 382 | mix_mask = generate_class_mask(label[i]).to(device) 383 | else: 384 | raise ValueError('mode must be in cutout, cutmix, or classmix') 385 | 386 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 387 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 388 | new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 389 | new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits) 390 | 391 | return new_image[rank * batch_size: (rank + 1) * batch_size], new_label[rank * batch_size: (rank + 1) * batch_size].long(), new_logits[rank * batch_size: (rank + 1) * batch_size] 392 | 393 | def generate_cut_gather_2(image: torch.Tensor, label: torch.Tensor, logits1: torch.Tensor, logits2: torch.Tensor, mode='cutout'): 394 | 395 | batch_size, _, image_h, image_w = image.shape 396 | image = concat_all_gather(image) 397 | label = concat_all_gather(label) 398 | logits1 = concat_all_gather(logits1) 399 | logits2 = concat_all_gather(logits2) 400 | total_size = image.shape[0] 401 | device = image.device 402 | rank = dist.get_rank() 403 | 404 | if mode == 'none': 405 | return image[rank * batch_size: (rank + 1) * batch_size], label[rank * batch_size: (rank + 1) * batch_size].long(), logits1[rank * batch_size: (rank + 1) * batch_size], logits2[rank * batch_size: (rank + 1) * batch_size] 406 | 407 | new_image = [] 408 | new_label = [] 409 | new_logits1 = [] 410 | new_logits2 = [] 411 | for i in range(total_size): 412 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 413 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 414 | label[i][(1 - mix_mask).bool()] = -1 415 | 416 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 417 | new_label.append(label[i].unsqueeze(0)) 418 | new_logits1.append((logits1[i] * mix_mask).unsqueeze(0)) 419 | new_logits2.append((logits2[i] * mix_mask).unsqueeze(0)) 420 | continue 421 | elif mode == 'cutmix': 422 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 423 | elif mode == 'classmix': 424 | mix_mask = generate_class_mask(label[i]).to(device) 425 | else: 426 | 
raise ValueError('mode must be in cutout, cutmix, or classmix') 427 | 428 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 429 | new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 430 | new_logits1.append((logits1[i] * mix_mask + logits1[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 431 | new_logits2.append((logits2[i] * mix_mask + logits2[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 432 | new_image, new_label, new_logits1, new_logits2 = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits1), torch.cat(new_logits2) 433 | 434 | return new_image[rank * batch_size: (rank + 1) * batch_size], new_label[rank * batch_size: (rank + 1) * batch_size].long(), new_logits1[rank * batch_size: (rank + 1) * batch_size], new_logits2[rank * batch_size: (rank + 1) * batch_size] 435 | 436 | def generate_cut_gather_3(image: torch.Tensor, label1: torch.Tensor, label2: torch.Tensor, logits1: torch.Tensor, logits2: torch.Tensor, mode='cutout'): 437 | 438 | batch_size, _, image_h, image_w = image.shape 439 | image = concat_all_gather(image) 440 | label1 = concat_all_gather(label1) 441 | label2 = concat_all_gather(label2) 442 | logits1 = concat_all_gather(logits1) 443 | logits2 = concat_all_gather(logits2) 444 | total_size = image.shape[0] 445 | device = image.device 446 | rank = dist.get_rank() 447 | 448 | new_image = [] 449 | new_label1 = [] 450 | new_label2 = [] 451 | new_logits1 = [] 452 | new_logits2 = [] 453 | for i in range(total_size): 454 | if mode == 'cutout': # label: generated region is masked by -1, image: generated region is masked by 0 455 | mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device) 456 | label1[i][(1 - mix_mask).bool()] = -1 457 | 458 | new_image.append((image[i] * mix_mask).unsqueeze(0)) 459 | new_label1.append(label1[i].unsqueeze(0)) 460 | new_logits1.append((logits1[i] * mix_mask).unsqueeze(0)) 461 | new_logits2.append((logits2[i] * mix_mask).unsqueeze(0)) 462 | continue 463 | elif mode == 'cutmix': 464 | mix_mask = generate_cutout_mask([image_h, image_w]).to(device) 465 | elif mode == 'classmix': 466 | mix_mask = generate_class_mask(label1[i]).to(device) 467 | else: 468 | raise ValueError('mode must be in cutout, cutmix, or classmix') 469 | 470 | new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 471 | new_label1.append((label1[i] * mix_mask + label1[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 472 | new_label2.append((label2[i] * mix_mask + label2[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 473 | new_logits1.append((logits1[i] * mix_mask + logits1[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 474 | new_logits2.append((logits2[i] * mix_mask + logits2[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0)) 475 | new_image, new_label1, new_label2, new_logits1, new_logits2 = torch.cat(new_image), torch.cat(new_label1), torch.cat(new_label2), torch.cat(new_logits1), torch.cat(new_logits2) 476 | 477 | return new_image[rank * batch_size: (rank + 1) * batch_size], new_label1[rank * batch_size: (rank + 1) * batch_size].long(), new_label2[rank * batch_size: (rank + 1) * batch_size].long(), new_logits1[rank * batch_size: (rank + 1) * batch_size], new_logits2[rank * batch_size: (rank + 1) * batch_size] 478 | 479 | def generate_cut(image: torch.Tensor, label: torch.Tensor, logits: torch.Tensor, mode='cutout'): 480 | if mode == 'none': 481 | return image, 
label.long(), logits
482 |     batch_size, _, image_h, image_w = image.shape
483 |     device = image.device
484 | 
485 |     new_image = []
486 |     new_label = []
487 |     new_logits = []
488 |     for i in range(batch_size):
489 |         if mode == 'cutout':  # label: generated region is masked by -1, image: generated region is masked by 0
490 |             mix_mask: torch.Tensor = generate_cutout_mask([image_h, image_w], ratio=2).to(device)
491 |             label[i][(1 - mix_mask).bool()] = -1
492 | 
493 |             new_image.append((image[i] * mix_mask).unsqueeze(0))
494 |             new_label.append(label[i].unsqueeze(0))
495 |             new_logits.append((logits[i] * mix_mask).unsqueeze(0))
496 |             continue
497 |         elif mode == 'cutmix':
498 |             mix_mask = generate_cutout_mask([image_h, image_w]).to(device)
499 |         elif mode == 'classmix':
500 |             mix_mask = generate_class_mask(label[i]).to(device)
501 |         else:
502 |             raise ValueError('mode must be in cutout, cutmix, or classmix')
503 | 
504 |         new_image.append((image[i] * mix_mask + image[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
505 |         new_label.append((label[i] * mix_mask + label[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
506 |         new_logits.append((logits[i] * mix_mask + logits[(i + 1) % batch_size] * (1 - mix_mask)).unsqueeze(0))
507 |     new_image, new_label, new_logits = torch.cat(new_image), torch.cat(new_label), torch.cat(new_logits)
508 | 
509 |     return new_image, new_label.long(), new_logits
510 | 
511 | def generate_class_mask(pseudo_labels: torch.Tensor):
512 |     # randomly select half of the classes present and build a binary mask of their pixels
513 |     labels = torch.unique(pseudo_labels)  # all unique labels
514 |     labels_select: torch.Tensor = labels[torch.randperm(len(labels))][:len(labels) // 2]  # Randomly select half of the labels
515 |     mask = (pseudo_labels.unsqueeze(-1) == labels_select).any(dim=-1)
516 |     return mask.float()
517 | 
518 | def generate_cutout_mask(image_size, ratio=2):
519 |     # Cutout: randomly generate a mask where the region inside the box is 0 and the one outside is 1
520 |     cutout_area = image_size[0] * image_size[1] / ratio
521 | 
522 |     w = np.random.randint(image_size[1] / ratio + 1, image_size[1])
523 |     h = np.round(cutout_area / w)
524 | 
525 |     x_start = np.random.randint(0, image_size[1] - w + 1)
526 |     y_start = np.random.randint(0, image_size[0] - h + 1)
527 | 
528 |     x_end = int(x_start + w)
529 |     y_end = int(y_start + h)
530 | 
531 |     mask = torch.ones(image_size)
532 |     mask[y_start: y_end, x_start: x_end] = 0
533 | 
534 |     return mask.float()
535 | 
536 | @torch.no_grad()
537 | def concat_all_gather(tensor):
538 |     """
539 |     Performs all_gather operation on the provided tensors.
540 |     Warning: torch.distributed.all_gather has no gradient.
541 | """ 542 | tensor_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 543 | torch.distributed.all_gather(tensor_gather, tensor, async_op=False) 544 | output = torch.cat(tensor_gather, dim=0) 545 | 546 | return output -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/Cityscapes.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/VOC.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/VOC.cpython-36.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/VOC.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/VOC.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/VOC.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/VOC.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/dataset_helpers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/dataset_helpers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/loss/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/loss/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/loss/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/loss/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from generalframeworks.utils import simplex 6 | from generalframeworks.networks.ddp_model import concat_all_gather 7 | 8 | class ProbOhemCrossEntropy2d(nn.Module): 9 | def __init__(self, ignore_label, reduction='mean', thresh=0.6, min_kept=256, 10 | down_ratio=1, use_weight=False): 11 | super(ProbOhemCrossEntropy2d, self).__init__() 12 | self.ignore_label = ignore_label 13 | self.thresh = float(thresh) 14 | self.min_kept = int(min_kept) 15 | self.down_ratio = down_ratio 16 | self.criterion = torch.nn.CrossEntropyLoss(reduction=reduction, 17 | ignore_index=ignore_label) 18 | 19 | def forward(self, pred, target): 20 | b, c, h, w = pred.size() 21 | target = target.view(-1) 22 | valid_mask = target.ne(self.ignore_label) 23 | target = target * valid_mask.long() 24 | num_valid = valid_mask.sum() 25 | prob = F.softmax(pred, dim=1) 26 | prob = (prob.transpose(0, 1)).reshape(c, -1) 27 | 28 | if self.min_kept > num_valid: 29 | print('Labels: {}'.format(num_valid)) 30 | elif num_valid > 0: 31 | prob = prob.masked_fill_(~valid_mask, 1) 32 | mask_prob = prob[target, torch.arange(len(target), dtype=torch.long)] 33 | threshold = self.thresh 34 | if self.min_kept > 0: 35 | index = mask_prob.argsort() 36 | threshold_index = index[min(len(index), self.min_kept) - 1] 37 | if mask_prob[threshold_index] > self.thresh: 38 | threshold = mask_prob[threshold_index] 39 | kept_mask = 
mask_prob.le(threshold)
40 |                 target = target * kept_mask.long()
41 |                 valid_mask = valid_mask * kept_mask
42 | 
43 |         target = target.masked_fill_(~valid_mask, self.ignore_label)
44 |         target = target.view(b, h, w)
45 | 
46 |         return self.criterion(pred, target)
47 | 
48 | class Attention_Threshold_Loss(nn.Module):
49 |     def __init__(self, strong_threshold):
50 |         super(Attention_Threshold_Loss, self).__init__()
51 |         self.strong_threshold = strong_threshold
52 | 
53 |     def forward(self, pred: torch.Tensor, pseudo_label: torch.Tensor, logits: torch.Tensor):
54 |         batch_size = pred.shape[0]
55 |         valid_mask = (pseudo_label >= 0).float()  # only count valid pixels (class)
56 |         weighting = logits.view(batch_size, -1).ge(self.strong_threshold).sum(-1) / (valid_mask.view(batch_size, -1).sum(-1))  # may be NaN if the whole sample is masked by cutout
57 |         #self.tmp_valid_num = logits.ge(self.strong_threshold).view(logits.shape[0], -1).float().sum(-1).mean(0)
58 |         # weighting is the per-sample fraction of valid pixels whose confidence reaches strong_threshold
59 |         loss = F.cross_entropy(pred, pseudo_label, reduction='none', ignore_index=-1)  # pixel-wise
60 |         weighted_loss = torch.mean(torch.masked_select(weighting[:, None, None] * loss, loss > 0))
61 |         # weighting: torch.Size([batch]) -> weighting[:, None, None] torch.Size([batch, 1, 1]) to broadcast over the per-pixel loss map
62 |         # torch.masked_select keeps only the pixels with loss > 0
63 | 
64 |         return weighted_loss
65 | 
66 | class Contrast_Loss(nn.Module):
67 |     def __init__(self, num_queries, num_negatives, temp=0.5, mean=False, strong_threshold=0.97, alpha=0.99):
68 |         super(Contrast_Loss, self).__init__()
69 |         self.temp = temp
70 |         self.mean = mean
71 |         self.num_queries = num_queries
72 |         self.num_negatives = num_negatives
73 |         self.strong_threshold = strong_threshold
74 |         self.alpha = alpha
75 |     def forward(self, rep, label, mask, prob, prototypes):
76 |         # we gather all representations across multiple GPUs during this step
77 |         rep_prt = concat_all_gather(rep)  # For prototype computing on all cards (w/o gradients)
78 |         batch_size, num_feat, rep_w, rep_h = rep.shape
79 |         num_segments = label.shape[1]  #21
80 |         valid_pixel_all = label * mask
81 |         valid_pixel_all_prt = concat_all_gather(valid_pixel_all)  # For prototype computing on all cards
82 | 
83 |         # Permute representation for indexing: [batch, rep_h, rep_w, num_feat]
84 | 
85 |         rep = rep.permute(0, 2, 3, 1)
86 |         rep_prt = rep_prt.permute(0, 2, 3, 1)
87 | 
88 |         rep_all_list = []
89 |         rep_hard_list = []
90 |         num_list = []
91 |         proto_rep_list = []
92 | 
93 |         for i in range(num_segments):  #21
94 |             valid_pixel = valid_pixel_all[:, i]
95 |             valid_pixel_gather = valid_pixel_all_prt[:, i]
96 |             if valid_pixel.sum() == 0:
97 |                 continue
98 |             prob_seg = prob[:, i, :, :]
99 |             rep_mask_hard = (prob_seg < self.strong_threshold) * valid_pixel.bool()  # Only on single card
100 |             # Prototype computing on all cards
101 |             with torch.no_grad():
102 |                 proto_rep_ = torch.mean((rep_prt[valid_pixel_gather.bool()]), dim=0, keepdim=True)
103 |                 if (prototypes[i].sum() == torch.tensor(0.0)):
104 |                     proto_rep_list.append(proto_rep_)
105 |                     prototypes[i] = proto_rep_
106 |                 else:
107 |                     # Update global prototype
108 |                     prototypes[i] = self.alpha * prototypes[i] + (1 - self.alpha) * proto_rep_
109 |                     proto_rep_list.append(prototypes[i].unsqueeze(0))
110 | 
111 |             rep_all_list.append(rep[valid_pixel.bool()])
112 |             rep_hard_list.append(rep[rep_mask_hard])
113 |             num_list.append(int(valid_pixel.sum().item()))
114 | 
115 |         # Compute Probabilistic Representation Contrastive Loss
116 |         if (len(num_list) <= 1):  # in some rare cases, a small mini-batch contains only one or no semantic class
117 |             return torch.tensor(0.0) + 0 * rep.sum()  # keeps rep in the autograd graph so DDP gradient synchronization does not stall
118 |         else:
119 |             contrast_loss = torch.tensor(0.0)
120 |             proto_rep = torch.cat(proto_rep_list)  # [c]
121 |             valid_num = len(num_list)
122 |             seg_len = torch.arange(valid_num)
123 | 
124 |             for i in range(valid_num):
125 |                 if len(rep_hard_list[i]) > 0:
126 |                     # Randomly sample anchor representations
127 |                     sample_idx = torch.randint(len(rep_hard_list[i]), size=(self.num_queries, ))
128 |                     anchor_rep = rep_hard_list[i][sample_idx]
129 |                 else:
130 |                     continue
131 |                 with torch.no_grad():
132 |                     # Select negatives
133 |                     id_mask = torch.cat(([seg_len[i: ], seg_len[: i]]))
134 |                     proto_sim = torch.cosine_similarity(proto_rep[id_mask[0]].unsqueeze(0), proto_rep[id_mask[1:]], dim=1)
135 |                     proto_prob = torch.softmax(proto_sim / self.temp, dim=0)
136 |                     negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob)
137 |                     samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives])
138 |                     samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1)
139 |                     negative_num_list = num_list[i+1: ] + num_list[: i]
140 |                     negative_index = negative_index_sampler(samp_num, negative_num_list)
141 |                     negative_rep_all = torch.cat(rep_all_list[i+1: ] + rep_all_list[: i])
142 |                     negative_rep = negative_rep_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat)
143 |                     positive_rep = proto_rep[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1)
144 |                     all_rep = torch.cat((positive_rep, negative_rep), dim=1)
145 | 
146 |                 logits = torch.cosine_similarity(anchor_rep.unsqueeze(1), all_rep, dim=2)
147 |                 contrast_loss = contrast_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda())
148 | 
149 |             return contrast_loss / valid_num
150 | 
151 | class Contrast_Loss_ds(nn.Module):
152 |     def __init__(self, num_queries, num_negatives, temp=0.5, mean=False, strong_threshold=0.97, alpha=0.99):
153 |         super(Contrast_Loss_ds, self).__init__()
154 |         self.temp = temp
155 |         self.mean = mean
156 |         self.num_queries = num_queries
157 |         self.num_negatives = num_negatives
158 |         self.alpha = alpha
159 |     def forward(self, rep, label, mask, prob, prototypes, strong_threshold):
160 |         # we gather all representations across multiple GPUs during this step
161 |         rep_prt = concat_all_gather(rep)  # For prototype computing on all cards (w/o gradients)
162 |         batch_size, num_feat, rep_w, rep_h = rep.shape
163 |         num_segments = label.shape[1]  #21
164 |         valid_pixel_all = label * mask
165 |         valid_pixel_all_prt = concat_all_gather(valid_pixel_all)  # For prototype computing on all cards
166 | 
167 |         # Permute representation for indexing: [batch, rep_h, rep_w, num_feat]
168 | 
169 |         rep = rep.permute(0, 2, 3, 1)
170 |         rep_prt = rep_prt.permute(0, 2, 3, 1)
171 | 
172 |         rep_all_list = []
173 |         rep_hard_list = []
174 |         num_list = []
175 |         proto_rep_list = []
176 | 
177 |         for i in range(num_segments):  #21
178 |             valid_pixel = valid_pixel_all[:, i]
179 |             valid_pixel_gather = valid_pixel_all_prt[:, i]
180 |             if valid_pixel.sum() == 0:
181 |                 continue
182 |             prob_seg = prob[:, i, :, :]
183 |             rep_mask_hard = (prob_seg < strong_threshold) * valid_pixel.bool()  # Only on single card
184 |             # Prototype computing on all cards
185 |             with torch.no_grad():
186 |                 proto_rep_ = torch.mean((rep_prt[valid_pixel_gather.bool()]), dim=0, keepdim=True)
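# The global prototype of each class is maintained as an exponential moving average of
# the batch-level class-mean representation:
#     prototypes[i] <- alpha * prototypes[i] + (1 - alpha) * proto_rep_    (alpha = 0.99 by default)
# An all-zero prototype is treated as uninitialized and is seeded with the current batch mean below.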
187 |                 if (prototypes[i].sum() == torch.tensor(0.0)):
188 |                     proto_rep_list.append(proto_rep_)
189 |                     prototypes[i] = proto_rep_
190 |                 else:
191 |                     # Update global prototype
192 |                     prototypes[i] = self.alpha * prototypes[i] + (1 - self.alpha) * proto_rep_
193 |                     proto_rep_list.append(prototypes[i].unsqueeze(0))
194 | 
195 |             rep_all_list.append(rep[valid_pixel.bool()])
196 |             rep_hard_list.append(rep[rep_mask_hard])
197 |             num_list.append(int(valid_pixel.sum().item()))
198 | 
199 |         # Compute Probabilistic Representation Contrastive Loss
200 |         if (len(num_list) <= 1):  # in some rare cases, a small mini-batch contains only one or no semantic class
201 |             return torch.tensor(0.0) + 0 * rep.sum()  # keeps rep in the autograd graph so DDP gradient synchronization does not stall
202 |         else:
203 |             contrast_loss = torch.tensor(0.0)
204 |             proto_rep = torch.cat(proto_rep_list)  # [c]
205 |             valid_num = len(num_list)
206 |             seg_len = torch.arange(valid_num)
207 | 
208 |             for i in range(valid_num):
209 |                 if len(rep_hard_list[i]) > 0:
210 |                     # Randomly sample anchor representations
211 |                     sample_idx = torch.randint(len(rep_hard_list[i]), size=(self.num_queries, ))
212 |                     anchor_rep = rep_hard_list[i][sample_idx]
213 |                 else:
214 |                     continue
215 |                 with torch.no_grad():
216 |                     # Select negatives
217 |                     id_mask = torch.cat(([seg_len[i: ], seg_len[: i]]))
218 |                     proto_sim = torch.cosine_similarity(proto_rep[id_mask[0]].unsqueeze(0), proto_rep[id_mask[1:]], dim=1)
219 |                     proto_prob = torch.softmax(proto_sim / self.temp, dim=0)
220 |                     negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob)
221 |                     samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives])
222 |                     samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1)
223 |                     negative_num_list = num_list[i+1: ] + num_list[: i]
224 |                     negative_index = negative_index_sampler(samp_num, negative_num_list)
225 |                     negative_rep_all = torch.cat(rep_all_list[i+1: ] + rep_all_list[: i])
226 |                     negative_rep = negative_rep_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat)
227 |                     positive_rep = proto_rep[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1)
228 |                     all_rep = torch.cat((positive_rep, negative_rep), dim=1)
229 | 
230 |                 logits = torch.cosine_similarity(anchor_rep.unsqueeze(1), all_rep, dim=2)
231 |                 contrast_loss = contrast_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda())
232 | 
233 |             return contrast_loss / valid_num
234 | 
235 | class Contrast_Loss_sig(nn.Module):
236 |     def __init__(self, num_queries, num_negatives, temp=0.5, mean=False, strong_threshold=0.97, alpha=0.99):
237 |         super(Contrast_Loss_sig, self).__init__()
238 |         self.temp = temp
239 |         self.mean = mean
240 |         self.num_queries = num_queries
241 |         self.num_negatives = num_negatives
242 |         self.strong_threshold = strong_threshold
243 |         self.alpha = alpha
244 |     def forward(self, rep, label, mask, prob, prototypes):
245 |         # single-card variant: representations are not gathered across GPUs
246 |         batch_size, num_feat, rep_w, rep_h = rep.shape
247 |         num_segments = label.shape[1]  #21
248 |         valid_pixel_all = label * mask
249 | 
250 |         # Permute representation for indexing: [batch, rep_h, rep_w, num_feat]
251 | 
252 |         rep = rep.permute(0, 2, 3, 1)
253 |         rep_prt = rep  # no cross-GPU gather in this variant; prototypes use the local (already permuted) representations
254 | 
255 |         rep_all_list = []
256 |         rep_hard_list = []
257 |         num_list = []
258 |         proto_rep_list = []
259 | 
260 |         for i in range(num_segments):  #21
261 |             valid_pixel = valid_pixel_all[:, i]
262 |             if valid_pixel.sum() == 0:
263 |                 continue
264 |             prob_seg = prob[:, i, :, :]
265 |             rep_mask_hard = (prob_seg < self.strong_threshold) * valid_pixel.bool()  # Only on single card
266 |             # Prototype computing (local card only in this variant)
267 |             with torch.no_grad():
268 |                 proto_rep_ = torch.mean((rep_prt[valid_pixel.bool()]), dim=0, keepdim=True)
269 |                 if (prototypes[i].sum() == torch.tensor(0.0)):
270 |                     proto_rep_list.append(proto_rep_)
271 |                     prototypes[i] = proto_rep_
272 |                 else:
273 |                     # Update global prototype
274 |                     prototypes[i] = self.alpha * prototypes[i] + (1 - self.alpha) * proto_rep_
275 |                     proto_rep_list.append(prototypes[i].unsqueeze(0))
276 | 
277 |             rep_all_list.append(rep[valid_pixel.bool()])
278 |             rep_hard_list.append(rep[rep_mask_hard])
279 |             num_list.append(int(valid_pixel.sum().item()))
280 | 
281 |         # Compute Probabilistic Representation Contrastive Loss
282 |         if (len(num_list) <= 1):  # in some rare cases, a small mini-batch contains only one or no semantic class
283 |             return torch.tensor(0.0) + 0 * rep.sum()  # keeps rep in the autograd graph so DDP gradient synchronization does not stall
284 |         else:
285 |             contrast_loss = torch.tensor(0.0)
286 |             proto_rep = torch.cat(proto_rep_list)  # [c]
287 |             valid_num = len(num_list)
288 |             seg_len = torch.arange(valid_num)
289 | 
290 |             for i in range(valid_num):
291 |                 if len(rep_hard_list[i]) > 0:
292 |                     # Randomly sample anchor representations
293 |                     sample_idx = torch.randint(len(rep_hard_list[i]), size=(self.num_queries, ))
294 |                     anchor_rep = rep_hard_list[i][sample_idx]
295 |                 else:
296 |                     continue
297 |                 with torch.no_grad():
298 |                     # Select negatives
299 |                     id_mask = torch.cat(([seg_len[i: ], seg_len[: i]]))
300 |                     proto_sim = torch.cosine_similarity(proto_rep[id_mask[0]].unsqueeze(0), proto_rep[id_mask[1:]], dim=1)
301 |                     proto_prob = torch.softmax(proto_sim / self.temp, dim=0)
302 |                     negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob)
303 |                     samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives])
304 |                     samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1)
305 |                     negative_num_list = num_list[i+1: ] + num_list[: i]
306 |                     negative_index = negative_index_sampler(samp_num, negative_num_list)
307 |                     negative_rep_all = torch.cat(rep_all_list[i+1: ] + rep_all_list[: i])
308 |                     negative_rep = negative_rep_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat)
309 |                     positive_rep = proto_rep[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1)
310 |                     all_rep = torch.cat((positive_rep, negative_rep), dim=1)
311 | 
312 |                 logits = torch.cosine_similarity(anchor_rep.unsqueeze(1), all_rep, dim=2)
313 |                 contrast_loss = contrast_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda())
314 | 
315 |             return contrast_loss / valid_num
316 | 
317 | class Prcl_Loss_single(nn.Module):
318 |     # For single GPU users
319 |     def __init__(self, num_queries, num_negatives, temp=0.5, mean=False, strong_threshold=0.97):
320 |         super(Prcl_Loss_single, self).__init__()
321 |         self.temp = temp
322 |         self.mean = mean
323 |         self.num_queries = num_queries
324 |         self.num_negatives = num_negatives
325 |         self.strong_threshold = strong_threshold
326 |     def forward(self, mu, sigma, label, mask, prob):
327 |         batch_size, num_feat, mu_w, mu_h = mu.shape
328 |         num_segments = label.shape[1]  #21
329 |         valid_pixel_all = label * mask
330 |         # Permute representation for indexing: [batch, rep_h, rep_w, num_feat]
331 | 
332 |         mu = mu.permute(0, 2, 3, 1)
333 |         sigma = sigma.permute(0, 2, 3, 1)
334 | 
335 |         mu_all_list = []
336 |         sigma_all_list = []
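# Per-class buffers for the probabilistic (mu, sigma) embeddings: the *_all_list
# buffers hold every valid pixel of a class, the *_hard_list buffers hold only the
# low-confidence "hard" pixels later used as anchors, num_list records valid-pixel
# counts for negative sampling, and the proto_*_list buffers collect class prototypes.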
337 | mu_hard_list = [] 338 | sigma_hard_list = [] 339 | num_list = [] 340 | proto_mu_list = [] 341 | proto_sigma_list = [] 342 | 343 | for i in range(num_segments): #21 344 | valid_pixel = valid_pixel_all[:, i] 345 | if valid_pixel.sum() == 0: 346 | continue 347 | prob_seg = prob[:, i, :, :] 348 | rep_mask_hard = (prob_seg < self.strong_threshold) * valid_pixel.bool() # Only on single card 349 | # Prototype computing 350 | with torch.no_grad(): 351 | proto_sigma_ = 1 / torch.sum((1 / sigma[valid_pixel.bool()]), dim=0, keepdim=True) 352 | proto_mu_ = torch.sum((proto_sigma_ / sigma[valid_pixel.bool()]) \ 353 | * mu[valid_pixel.bool()], dim=0, keepdim=True) 354 | proto_mu_list.append(proto_mu_) 355 | proto_sigma_list.append(proto_sigma_) 356 | 357 | mu_all_list.append(mu[valid_pixel.bool()]) 358 | sigma_all_list.append(sigma[valid_pixel.bool()]) 359 | mu_hard_list.append(mu[rep_mask_hard]) 360 | sigma_hard_list.append(sigma[rep_mask_hard]) 361 | num_list.append(int(valid_pixel.sum().item())) 362 | 363 | # Compute Probabilistic Representation Contrastive Loss 364 | if (len(num_list) <= 1) : # in some rare cases, a small mini-batch only contain 1 or no semantic class 365 | return torch.tensor(0.0) #+ 0 * mu.sum() + 0 * sigma.sum() # A trick for avoiding data leakage in DDP training 366 | else: 367 | prcl_loss = torch.tensor(0.0) 368 | proto_mu = torch.cat(proto_mu_list) # [c] 369 | proto_sigma = torch.cat(proto_sigma_list) 370 | valid_num = len(num_list) 371 | seg_len = torch.arange(valid_num) 372 | 373 | for i in range(valid_num): 374 | if len(mu_hard_list[i]) > 0: 375 | # Random Sampling anchor representations 376 | sample_idx = torch.randint(len(mu_hard_list[i]), size=(self.num_queries, )) 377 | anchor_mu = mu_hard_list[i][sample_idx] 378 | anchor_sigma = sigma_hard_list[i][sample_idx] 379 | else: 380 | continue 381 | with torch.no_grad(): 382 | # Select negatives 383 | id_mask = torch.cat(([seg_len[i: ], seg_len[: i]])) 384 | proto_sim = mutual_likelihood_score(proto_mu[id_mask[0].unsqueeze(0)], 385 | proto_mu[id_mask[1: ]], 386 | proto_sigma[id_mask[0].unsqueeze(0)], 387 | proto_sigma[id_mask[1: ]]) 388 | proto_prob = torch.softmax(proto_sim / self.temp, dim=0) 389 | negative_dist = torch.distributions.categorical.Categorical(probs=proto_prob) 390 | samp_class = negative_dist.sample(sample_shape=[self.num_queries, self.num_negatives]) 391 | samp_num = torch.stack([(samp_class == c).sum(1) for c in range(len(proto_prob))], dim=1) 392 | negative_num_list = num_list[i+1: ] + num_list[: i] 393 | negative_index = negative_index_sampler(samp_num, negative_num_list) 394 | negative_mu_all = torch.cat(mu_all_list[i+1: ] + mu_all_list[: i]) 395 | negative_sigma_all = torch.cat(sigma_all_list[i+1: ] + sigma_all_list[: i]) 396 | negative_mu = negative_mu_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 397 | negative_sigma = negative_sigma_all[negative_index].reshape(self.num_queries, self.num_negatives, num_feat) 398 | positive_mu = proto_mu[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 399 | positive_sigma = proto_sigma[i].unsqueeze(0).unsqueeze(0).repeat(self.num_queries, 1, 1) 400 | all_mu = torch.cat((positive_mu, negative_mu), dim=1) 401 | all_sigma = torch.cat((positive_sigma, negative_sigma), dim=1) 402 | 403 | logits = mutual_likelihood_score(anchor_mu.unsqueeze(1), all_mu, anchor_sigma.unsqueeze(1), all_sigma) 404 | prcl_loss = prcl_loss + F.cross_entropy(logits / self.temp, torch.zeros(self.num_queries).long().cuda()) 405 | 406 | return prcl_loss / 
valid_num
407 | 
408 | #### Utils ####
409 | 
410 | def negative_index_sampler(samp_num, seg_num_list):
411 |     negative_index = []
412 |     for i in range(samp_num.shape[0]):
413 |         for j in range(samp_num.shape[1]):
414 |             negative_index += np.random.randint(low=sum(seg_num_list[: j]),
415 |                                                 high=sum(seg_num_list[: j+1]),
416 |                                                 size=int(samp_num[i, j])).tolist()
417 | 
418 |     return negative_index
419 | 
420 | def mutual_likelihood_score(mu_0, mu_1, sigma_0, sigma_1):
421 |     '''
422 |     Compute the mutual likelihood score (MLS)
423 |     param: mu_0, mu_1: e.g. [256, 513, 256] and [256, 1, 256]
424 |            sigma_0, sigma_1: e.g. [256, 513, 256] and [256, 1, 256]
425 |     '''
426 |     mu_0 = F.normalize(mu_0, dim=-1)
427 |     mu_1 = F.normalize(mu_1, dim=-1)
428 |     up = (mu_0 - mu_1) ** 2
429 |     down = sigma_0 + sigma_1
430 |     mls = -0.5 * (up / down + torch.log(down)).mean(-1)
431 | 
432 | 
433 |     return mls
--------------------------------------------------------------------------------
/generalframeworks/meter/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/generalframeworks/meter/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/meter/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/generalframeworks/meter/__pycache__/meter.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/meter/__pycache__/meter.cpython-38.pyc
--------------------------------------------------------------------------------
/generalframeworks/meter/mIOU_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | 
6 | class mIOUMetrics:
7 |     def __init__(self, num_classes, ignore_index, device='cpu'):
8 |         self.num_classes = num_classes
9 |         self.ignore_index = ignore_index
10 |         self.total_area_inter = torch.zeros(size=(num_classes,), dtype=torch.float64).to(device)
11 |         self.total_area_union = torch.zeros(size=(num_classes,), dtype=torch.float64).to(device)
12 |         self.device = device
13 | 
14 |     def update(self, predict, target):
15 |         # Pre-processing: filter out the pixels whose label equals the ignore index
16 |         # target = target.squeeze(1)
17 |         # print('t', target.shape, 'p', predict.shape)
18 |         target_mask = (target != self.ignore_index)  # [batch, height, width] mask selecting all pixels that participate in training
19 |         target = target[target_mask]  # [num_pixels]
20 | 
21 |         # _, predict = torch.max(predict, dim=1)
22 |         # print('unique', torch.unique(predict), 'shape:', predict.shape)
23 |         # predict = predict.permute(0,2,3,1)
24 |         # predict = torch.nn.functional.one_hot(predict, 19)
25 |         # print('unique1', torch.unique(predict), 'shape:', predict.shape)
26 | 
27 |         batch, num_class, height, width = predict.size()
28 |         predict = predict.permute(0, 2, 3, 1)  # [batch, height, width, num_class]
29 | 
30 |         # Compute pixel accuracy
31 |         predict = predict[target_mask.unsqueeze(-1).repeat(1, 1, 1, num_class)].view(-1, num_class)
32 |         predict = predict.argmax(dim=1)
33 |         num_pixels = target.numel()
34 |         correct = (predict == target).sum()
35 |         pixel_acc = correct / num_pixels
36 | 
37 |         # Compute the mIoU over all classes
38 |         predict = predict + 1
39 |         target = target + 1
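# The +1 shift above turns 0 into a sentinel value: classes now occupy [1, num_class],
# intersection = predict * (predict == target) is 0 wherever prediction and target
# disagree, and the torch.histc calls below (bins=num_class, min=1, max=num_class)
# therefore count only the true per-class matches while mismatch zeros fall outside the range.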
40 |         intersection = predict * (predict == target).long()
41 |         area_inter = torch.histc(intersection.float(), bins=num_class, max=num_class, min=1)
42 |         area_pred = torch.histc(predict.float(), bins=num_class, max=num_class, min=1)
43 |         area_label = torch.histc(target.float(), bins=num_class, max=num_class, min=1)
44 | 
45 |         self.total_area_inter += area_inter
46 |         area_union = (area_pred + area_label - area_inter)
47 |         self.total_area_union += area_union
48 | 
49 | 
50 |     def reset(self):
51 |         self.total_area_inter = torch.zeros(size=(self.num_classes,), dtype=torch.float64).to(self.device)
52 |         self.total_area_union = torch.zeros(size=(self.num_classes,), dtype=torch.float64).to(self.device)
53 | 
54 |     def get_mIOU(self):
55 |         iou = self.total_area_inter / self.total_area_union
56 |         #print(iou)
57 |         miou = torch.mean(iou[~iou.isnan()])
58 |         if torch.isnan(miou).any():
59 |             print('get_mIOU: something went wrong! NaN detected!')
60 |         return miou
61 | 
62 | 
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
70 | 
71 | 
72 | 
--------------------------------------------------------------------------------
/generalframeworks/meter/meter.py:
--------------------------------------------------------------------------------
1 | from multiprocessing.sharedctypes import Value
2 | import torch
3 | from generalframeworks.utils import class2one_hot
4 | import numpy as np
5 | 
6 | 
7 | class Meter(object):
8 | 
9 |     def reset(self):
10 |         # Reset the Meter to default settings
11 |         pass
12 | 
13 |     def add(self, pred_logits, label):
14 |         # Log a new value to the meter
15 |         pass
16 | 
17 |     def value(self):
18 |         # Get the value of the meter in the current state
19 |         pass
20 | 
21 |     def summary(self) -> dict:
22 |         raise NotImplementedError
23 | 
24 |     def detailed_summary(self) -> dict:
25 |         raise NotImplementedError
26 | 
27 | class ConfMatrix(object):
28 |     def __init__(self, num_classes):
29 |         self.num_classes = num_classes
30 |         self.mat = None
31 | 
32 |     def update(self, pred, target):
33 |         n = self.num_classes
34 |         if self.mat is None:
35 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device)
36 |         with torch.no_grad():
37 |             k = (target >= 0) & (target < n)
38 |             inds = n * target[k].to(torch.int64) + pred[k]
39 |             self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
40 | 
41 | 
42 |     def get_metrics(self):
43 |         h = self.mat.float()
44 |         acc = torch.diag(h).sum() / h.sum()
45 |         up = torch.diag(h)
46 |         down = h.sum(1) + h.sum(0) - torch.diag(h)
47 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h) + 1e-6)
48 |         return torch.mean(iu).item(), acc.item()
49 | 
50 |     def get_valid_metrics(self):
51 |         h = self.mat.float()
52 |         acc = torch.diag(h).sum() / h.sum()
53 |         up = torch.diag(h)
54 |         down = h.sum(1) + h.sum(0) - torch.diag(h)
55 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h) + 1e-6)
56 |         num_zero = (iu == 0).sum()  # number of classes with zero IoU, excluded from the mean
57 |         return iu.sum() / (len(iu) - num_zero).item(), acc.item()
58 | 
59 | 
60 | class My_ConfMatrix(Meter):
61 |     def __init__(self, num_classes):
62 |         super(My_ConfMatrix, self).__init__()
63 |         self.num_classes = num_classes
64 |         self.mat = None
65 |         self.reset()
66 |         self.mIOU = []
67 |         self.Acc = []
68 | 
69 |     def add(self, pred_logits, label):
70 |         pred_logits = pred_logits.argmax(1).flatten()
71 |         label = label.flatten()
72 |         n = self.num_classes
73 |         if self.mat is None:
74 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred_logits.device)
75 |         with torch.no_grad():
76 |             k = (label >= 0) & (label < n)
77 |             inds = n * label[k].to(torch.int64) + pred_logits[k]
78 |             self.mat
+= torch.bincount(inds, minlength=n ** 2).reshape(n, n) 79 | 80 | def value(self, mode='mean'): 81 | h = self.mat.float() 82 | self.acc = torch.diag(h).sum() / h.sum() 83 | self.iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 84 | if mode == 'mean': 85 | return torch.mean(self.iu).item(), self.acc.item() 86 | else: 87 | raise ValueError("mode must be in (mean)") 88 | 89 | def reset(self): 90 | self.mIOU = [] 91 | self.Acc = [] 92 | 93 | def summary(self) -> dict: 94 | mIOU_dct: dict = {} 95 | Acc_dct: dict = {} 96 | for c in range(self.num_classes): 97 | if c != 0: 98 | mIOU_dct['mIOU_{}'.format(c)] = np.array([self.value(i, mode='all')[0] for i in range(len(self.mIOU))])[ 99 | :, c].mean() 100 | Acc_dct['Acc_{}'.format(c)] = np.array([self.value(i, mode='all')[1] for i in range(len(self.mIOU))])[:, 101 | c].mean() 102 | return mIOU_dct, Acc_dct 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /generalframeworks/networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/ddp_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/ddp_model.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/ddp_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/ddp_model.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/module.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/__pycache__/uncer_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/uncer_head.cpython-37.pyc -------------------------------------------------------------------------------- 
/generalframeworks/networks/__pycache__/uncer_head.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/__pycache__/uncer_head.cpython-38.pyc
--------------------------------------------------------------------------------
/generalframeworks/networks/ddp_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | from generalframeworks.networks.deeplabv3.deeplabv3 import DeepLabv3Plus_with_rep
5 | import torch.nn.functional as F
6 | from generalframeworks.dataset_helpers.VOC import batch_transform, generate_cut_gather, batch_transform_2, batch_transform_3, generate_cut_gather_2, generate_cut_gather_3
7 | 
8 | class Model_ori_pseudo(nn.Module):
9 |     '''
10 |     Build a model for DDP with a DeepLabV3+ backbone, an EMA teacher, and an MLP representation head
11 |     '''
12 | 
13 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None) -> None:
14 |         super(Model_ori_pseudo, self).__init__()
15 |         self.model = DeepLabv3Plus_with_rep(base_encoder, num_classes=num_classes, output_dim=output_dim, dilate_scale=8)
16 |         ##### Init EMA #####
17 |         self.step = 0
18 |         self.ema_model = copy.deepcopy(self.model)
19 |         for p in self.ema_model.parameters():
20 |             p.requires_grad = False
21 |         self.alpha = ema_alpha
22 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
23 | 
24 |         self.config = config
25 | 
26 |     def ema_update(self):
27 |         decay = min(1 - 1 / (self.step + 1), self.alpha)
28 |         for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
29 |             ema_param.data = decay * ema_param.data + (1 - decay) * param.data
30 |         self.step += 1
31 | 
32 |     def forward(self, train_l_image, train_u_image):
33 |         ##### generate pseudo label #####
34 |         with torch.no_grad():
35 |             pred_u, _ = self.ema_model(train_u_image)
36 |             pred_u_large_raw = F.interpolate(pred_u, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
37 |             pseudo_logits, pseudo_labels = torch.max(torch.softmax(pred_u_large_raw, dim=1), dim=1)
38 | 
39 |             # Randomly scale images
40 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_image, pseudo_labels,
41 |                                                                                        pseudo_logits,
42 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
43 |                                                                                        scale_size=self.config['Dataset']['scale_size'],
44 |                                                                                        augmentation=False)
45 |             # Apply the mixing strategy; we gather all images across multiple GPUs during this step
46 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = generate_cut_gather(train_u_aug_image,
47 |                                                                                            train_u_aug_label,
48 |                                                                                            train_u_aug_logits,
49 |                                                                                            mode=self.config['Dataset'][
50 |                                                                                                'mix_mode'])
51 |             # Apply augmentation: color jitter + flip + Gaussian blur
52 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits = batch_transform(train_u_aug_image,
53 |                                                                                        train_u_aug_label,
54 |                                                                                        train_u_aug_logits,
55 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
56 |                                                                                        scale_size=(1.0, 1.0),
57 |                                                                                        augmentation=True)
58 | 
59 | 
60 |         pred_l, rep_l = self.model(train_l_image)
61 |         pred_l_large = F.interpolate(pred_l, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
62 | 
63 |         pred_u, rep_u = self.model(train_u_aug_image)
64 |         pred_u_large = F.interpolate(pred_u, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
65 | 
66 | 
67 |         rep_all = torch.cat((rep_l, rep_u))
68 |         pred_all = torch.cat((pred_l, pred_u))
69 | 
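        # Returned below: labeled / unlabeled predictions at input resolution, the pseudo-labels and
        # confidences consumed by the unsupervised loss, and the stacked representations / logits
        # consumed by the contrastive loss.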
70 |         return pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw
71 | 
72 | 
73 | class Model_mix(nn.Module):
74 |     '''
75 |     Build a model for DDP with a DeepLabV3+ backbone, an EMA teacher, and an MLP representation head
76 |     '''
77 | 
78 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None, temp=0.25) -> None:
79 |         super(Model_mix, self).__init__()
80 |         self.model = DeepLabv3Plus_with_rep(base_encoder, num_classes=num_classes, output_dim=output_dim, dilate_scale=8)
81 |         self.temp = temp
82 |         self.num_classes = num_classes
83 |         ##### Init EMA #####
84 |         self.step = 0
85 |         self.ema_model = copy.deepcopy(self.model)
86 |         for p in self.ema_model.parameters():
87 |             p.requires_grad = False
88 |         self.alpha = ema_alpha
89 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
90 | 
91 |         self.config = config
92 | 
93 |     def ema_update(self):
94 |         decay = min(1 - 1 / (self.step + 1), self.alpha)
95 |         for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
96 |             ema_param.data = decay * ema_param.data + (1 - decay) * param.data
97 |         self.step += 1
98 | 
99 |     def forward(self, train_l_image, train_u_image, prototypes):
100 |         ##### Generate pseudo-labels from the class predictor and the indicator from the representation predictor #####
101 |         with torch.no_grad():
102 |             pred_l, rep_l = self.ema_model(train_l_image)
103 |             pred_u, rep_u = self.ema_model(train_u_image)
104 |             rep_b, rep_dim, rep_w, rep_h = rep_u.shape
105 |             norm_rep_u = rep_u.permute(0, 2, 3, 1)
106 |             norm_rep_u = F.normalize(norm_rep_u, dim=-1)
107 |             norm_proto = F.normalize(prototypes, dim=-1).permute(1, 0)
108 |             norm_rep_u = norm_rep_u.reshape(rep_b * rep_w * rep_h, rep_dim)
109 |             sim_mat = torch.mm(norm_rep_u, norm_proto)
110 |             sim_mat = sim_mat.reshape(rep_b, rep_w, rep_h, self.num_classes).permute(0, 3, 1, 2)
111 |             sim_mat_large_raw = F.interpolate(sim_mat, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
112 |             pseudo_logits_rep, pseudo_labels_rep = torch.max(F.softmax(sim_mat_large_raw / self.temp, dim=1), dim=1)
113 |             pred_u_large_raw = F.interpolate(pred_u, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
114 |             pseudo_logits_cls, pseudo_labels_cls = torch.max(torch.softmax(pred_u_large_raw, dim=1), dim=1)
115 |             label_mask = pseudo_labels_cls.eq(pseudo_labels_rep)
116 |             label_mask = (~label_mask).float()
117 |             pseudo_labels = pseudo_labels_cls - label_mask * self.num_classes
118 |             pseudo_labels[pseudo_labels < 0] = 255
119 | 
120 |             # Randomly scale images
121 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits_cls, train_u_aug_logits_rep = batch_transform_2(train_u_image, pseudo_labels,
122 |                                                                                        pseudo_logits_cls, pseudo_logits_rep,
123 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
124 |                                                                                        scale_size=self.config['Dataset']['scale_size'],
125 |                                                                                        augmentation=False)
126 |             # Apply the mixing strategy; we gather all images across multiple GPUs during this step
127 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits_cls, train_u_aug_logits_rep = generate_cut_gather_2(train_u_aug_image,
128 |                                                                                            train_u_aug_label,
129 |                                                                                            train_u_aug_logits_cls, train_u_aug_logits_rep,
130 |                                                                                            mode=self.config['Dataset']['mix_mode'])
131 |             # Apply augmentation: color jitter + flip + Gaussian blur
132 |             train_u_aug_image, train_u_aug_label, train_u_aug_logits_cls, train_u_aug_logits_rep = batch_transform_2(train_u_aug_image,
133 |                                                                                        train_u_aug_label,
134 |                                                                                        train_u_aug_logits_cls, train_u_aug_logits_rep,
135 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
136 |                                                                                        scale_size=(1.0, 1.0),
137 |                                                                                        augmentation=True)
138 | 
139 | 
140 |         pred_l, rep_l = self.model(train_l_image)
141 |         pred_l_large = F.interpolate(pred_l, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
142 | 
143 |         pred_u, rep_u = self.model(train_u_aug_image)
144 |         pred_u_large = F.interpolate(pred_u, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
145 | 
146 | 
147 |         rep_all = torch.cat((rep_l, rep_u))
148 |         rep_all_b = rep_all.shape[0]
149 |         norm_rep_all = rep_all.permute(0, 2, 3, 1)
150 |         norm_rep_all = F.normalize(norm_rep_all, dim=-1)
151 |         norm_rep_all = norm_rep_all.reshape(rep_all_b * rep_w * rep_h, rep_dim)
152 |         prob_all = torch.mm(norm_rep_all, norm_proto)
153 |         prob_all = prob_all.reshape(rep_all_b, rep_w, rep_h, self.num_classes).permute(0, 3, 1, 2)
154 |         prob_all = F.softmax(prob_all / self.temp, dim=1)
155 | 
156 |         return pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits_cls, train_u_aug_logits_rep, rep_all, prob_all
157 | 
158 | class Model_cross(nn.Module):
159 |     '''
160 |     Build a model for DDP with a DeepLabV3+ backbone, an EMA teacher, and an MLP representation head
161 |     '''
162 | 
163 |     def __init__(self, base_encoder, num_classes=21, output_dim=256, ema_alpha=0.99, config=None, temp=0.1) -> None:
164 |         super(Model_cross, self).__init__()
165 |         self.model = DeepLabv3Plus_with_rep(base_encoder, num_classes=num_classes, output_dim=output_dim, dilate_scale=8)
166 |         self.temp = temp
167 |         self.num_classes = num_classes
168 |         ##### Init EMA #####
169 |         self.step = 0
170 |         self.ema_model = copy.deepcopy(self.model)
171 |         for p in self.ema_model.parameters():
172 |             p.requires_grad = False
173 |         self.alpha = ema_alpha
174 |         print('EMA model has been prepared. Alpha = {}'.format(self.alpha))
175 | 
176 |         self.config = config
177 | 
178 |     def ema_update(self):
179 |         decay = min(1 - 1 / (self.step + 1), self.alpha)
180 |         for ema_param, param in zip(self.ema_model.parameters(), self.model.parameters()):
181 |             ema_param.data = decay * ema_param.data + (1 - decay) * param.data
182 |         self.step += 1
183 | 
184 |     def forward(self, train_l_image, train_u_image, prototypes):
185 |         ##### Generate pseudo-labels from the class predictor and the indicator from the representation predictor #####
186 |         with torch.no_grad():
187 |             pred_l, rep_l = self.ema_model(train_l_image)
188 |             pred_u, rep_u = self.ema_model(train_u_image)
189 |             rep_b, rep_dim, rep_w, rep_h = rep_u.shape
190 |             norm_rep_u = rep_u.permute(0, 2, 3, 1)
191 |             norm_rep_u = F.normalize(norm_rep_u, dim=-1)
192 |             norm_proto = F.normalize(prototypes, dim=-1).permute(1, 0)
193 |             norm_rep_u = norm_rep_u.reshape(rep_b * rep_w * rep_h, rep_dim)
194 |             sim_mat = torch.mm(norm_rep_u, norm_proto)
195 |             sim_mat = sim_mat.reshape(rep_b, rep_w, rep_h, self.num_classes).permute(0, 3, 1, 2)
196 |             sim_mat_large_raw = F.interpolate(sim_mat, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
197 |             pseudo_logits_rep, pseudo_labels_rep = torch.max(F.softmax(sim_mat_large_raw / self.temp, dim=1), dim=1)
198 |             pred_u_large_raw = F.interpolate(pred_u, size=train_u_image.shape[2:], mode='bilinear', align_corners=True)
199 |             pseudo_logits_cls, pseudo_labels_cls = torch.max(torch.softmax(pred_u_large_raw, dim=1), dim=1)
200 | 
201 |             # Randomly scale images
202 |             train_u_aug_image, train_u_aug_label_cls, train_u_aug_label_rep, train_u_aug_logits_cls, train_u_aug_logits_rep = batch_transform_3(train_u_image, pseudo_labels_cls, pseudo_labels_rep,
203 |                                                                                        pseudo_logits_cls, pseudo_logits_rep,
204 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
205 |                                                                                        scale_size=self.config['Dataset']['scale_size'],
206 |                                                                                        augmentation=False)
207 |             # Apply the mixing strategy; we gather all images across multiple GPUs during this step
208 |             train_u_aug_image, train_u_aug_label_cls, train_u_aug_label_rep, train_u_aug_logits_cls, train_u_aug_logits_rep = generate_cut_gather_3(train_u_aug_image,
209 |                                                                                            train_u_aug_label_cls,
210 |                                                                                            train_u_aug_label_rep,
211 |                                                                                            train_u_aug_logits_cls, train_u_aug_logits_rep,
212 |                                                                                            mode=self.config['Dataset']['mix_mode'])
213 |             # Apply augmentation: color jitter + flip + Gaussian blur
214 |             train_u_aug_image, train_u_aug_label_cls, train_u_aug_label_rep, train_u_aug_logits_cls, train_u_aug_logits_rep = batch_transform_3(train_u_aug_image,
215 |                                                                                        train_u_aug_label_cls,
216 |                                                                                        train_u_aug_label_rep,
217 |                                                                                        train_u_aug_logits_cls, train_u_aug_logits_rep,
218 |                                                                                        crop_size=self.config['Dataset']['crop_size'],
219 |                                                                                        scale_size=(1.0, 1.0),
220 |                                                                                        augmentation=True)
221 | 
222 | 
223 |         pred_l, rep_l = self.model(train_l_image)
224 |         pred_l_large = F.interpolate(pred_l, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
225 | 
226 |         pred_u, rep_u = self.model(train_u_aug_image)
227 |         pred_u_large = F.interpolate(pred_u, size=train_l_image.shape[2:], mode='bilinear', align_corners=True)
228 | 
229 | 
230 |         rep_all = torch.cat((rep_l, rep_u))
231 |         rep_all_b = rep_all.shape[0]
232 |         norm_rep_all = rep_all.permute(0, 2, 3, 1)
233 |         norm_rep_all = F.normalize(norm_rep_all, dim=-1)
234 |         norm_rep_all = norm_rep_all.reshape(rep_all_b * rep_w * rep_h, rep_dim)
235 |         prob_all = torch.mm(norm_rep_all, norm_proto)
236 |         prob_all = prob_all.reshape(rep_all_b, rep_w, rep_h, self.num_classes).permute(0, 3, 1, 2)
237 |         prob_all = F.softmax(prob_all / self.temp, dim=1)
238 | 
239 |         return pred_l_large, pred_u_large, train_u_aug_label_cls, train_u_aug_label_rep, train_u_aug_logits_cls, train_u_aug_logits_rep, rep_all, prob_all
240 | 
241 | @torch.no_grad()
242 | def concat_all_gather(tensor):
243 |     """
244 |     Performs an all_gather operation on the provided tensor.
245 |     Warning: torch.distributed.all_gather has no gradient.
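    Illustrative use (not from the repository; assumes an initialized process group and hypothetical shapes):
        local = torch.randn(8, 256, device='cuda')   # per-GPU tensor
        full = concat_all_gather(local)              # [8 * world_size, 256]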
246 | """ 247 | tensor_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 248 | torch.distributed.all_gather(tensor_gather, tensor, async_op=False) 249 | output = torch.cat(tensor_gather, dim=0) 250 | 251 | return output -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/aspp.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/networks/deeplabv3/__pycache__/deeplabv3.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DeepLabHead(nn.Sequential): 7 | def __init__(self, in_channels, num_classes): 8 | super(DeepLabHead, self).__init__( 9 | ASPP(in_channels, [12, 24, 36]), 10 | nn.Conv2d(256, 256, 3, padding=1, bias=False), 11 | nn.BatchNorm2d(256), 12 | nn.ReLU(), 13 | nn.Conv2d(256, num_classes, 1) 14 | ) 15 | 16 | 17 | class ASPPConv(nn.Sequential): 18 | def __init__(self, 
in_channels, out_channels, dilation): 19 | modules = [ 20 | nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), 21 | nn.BatchNorm2d(out_channels), 22 | nn.ReLU() 23 | ] 24 | super(ASPPConv, self).__init__(*modules) 25 | 26 | 27 | class ASPPPooling(nn.Sequential): 28 | def __init__(self, in_channels, out_channels): 29 | super(ASPPPooling, self).__init__( 30 | nn.AdaptiveAvgPool2d(1), 31 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 32 | nn.BatchNorm2d(out_channels), 33 | nn.ReLU()) 34 | 35 | def forward(self, x): 36 | size = x.shape[-2:] 37 | x = super(ASPPPooling, self).forward(x) 38 | return F.interpolate(x, size=size, mode='bilinear', align_corners=False) 39 | 40 | 41 | class ASPP(nn.Module): 42 | def __init__(self, in_channels, atrous_rates): 43 | super(ASPP, self).__init__() 44 | out_channels = 256 45 | #modules = [] 46 | modules = torch.nn.ModuleList() 47 | modules.append(nn.Sequential( 48 | nn.Conv2d(in_channels, out_channels, 1, bias=False), 49 | nn.BatchNorm2d(out_channels), 50 | nn.ReLU())) 51 | 52 | rate1, rate2, rate3 = tuple(atrous_rates) 53 | modules.append(ASPPConv(in_channels, out_channels, rate1)) 54 | modules.append(ASPPConv(in_channels, out_channels, rate2)) 55 | modules.append(ASPPConv(in_channels, out_channels, rate3)) 56 | modules.append(ASPPPooling(in_channels, out_channels)) 57 | # self.convs = nn.ModuleList(modules) 58 | self.convs = modules 59 | 60 | self.project = nn.Sequential( 61 | nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), 62 | nn.BatchNorm2d(out_channels), 63 | nn.ReLU(), 64 | # nn.Dropout(0.5) 65 | ) 66 | 67 | def forward(self, x): 68 | res = [] 69 | for conv in self.convs: 70 | res.append(conv(x)) 71 | res = torch.cat(res, dim=1) 72 | return self.project(res) 73 | -------------------------------------------------------------------------------- /generalframeworks/networks/deeplabv3/deeplabv3.py: -------------------------------------------------------------------------------- 1 | from .aspp import * 2 | from functools import partial 3 | 4 | ##### For PRCL Loss ##### 5 | class DeepLabv3Plus_with_un(nn.Module): 6 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 7 | super(DeepLabv3Plus_with_un, self).__init__() 8 | if dilate_scale == 8: 9 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 10 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 11 | aspp_dilate = [12, 24, 36] 12 | 13 | elif dilate_scale == 16: 14 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 15 | aspp_dilate = [6, 12, 18] 16 | 17 | # take pre-defined ResNet, except AvgPool and FC 18 | self.resnet_conv1 = orig_resnet.conv1 19 | #self.resnet_conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) # Change the num of input channel 20 | self.resnet_bn1 = orig_resnet.bn1 21 | self.resnet_relu1 = orig_resnet.relu 22 | self.resnet_maxpool = orig_resnet.maxpool 23 | 24 | self.resnet_layer1 = orig_resnet.layer1 25 | self.resnet_layer2 = orig_resnet.layer2 26 | self.resnet_layer3 = orig_resnet.layer3 27 | self.resnet_layer4 = orig_resnet.layer4 28 | 29 | self.ASPP = ASPP(2048, aspp_dilate) 30 | 31 | self.project = nn.Sequential( 32 | nn.Conv2d(256, 48, 1, bias=False), 33 | nn.BatchNorm2d(48), 34 | nn.ReLU(inplace=True), 35 | ) 36 | 37 | self.classifier = nn.Sequential( 38 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 39 | nn.BatchNorm2d(256), 40 | nn.ReLU(), 41 | nn.Conv2d(256, num_classes, 1) 42 | ) 43 | 44 | self.representation = 
nn.Sequential( 45 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 46 | nn.BatchNorm2d(256), 47 | nn.ReLU(), 48 | nn.Conv2d(256, output_dim, 1) 49 | ) 50 | 51 | def _nostride_dilate(self, m, dilate): 52 | classname = m.__class__.__name__ 53 | if classname.find('Conv') != -1: 54 | # the convolution with stride 55 | if m.stride == (2, 2): 56 | m.stride = (1, 1) 57 | if m.kernel_size == (3, 3): 58 | m.dilation = (dilate // 2, dilate // 2) 59 | m.padding = (dilate // 2, dilate // 2) 60 | 61 | # other convoluions 62 | else: 63 | if m.kernel_size == (3, 3): 64 | m.dilation = (dilate, dilate) 65 | m.padding = (dilate, dilate) 66 | 67 | def forward(self, x): 68 | h_w = x.shape[2:] 69 | # with ResNet-50 Encoder 70 | x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x))) 71 | x = self.resnet_maxpool(x) 72 | 73 | x_low = self.resnet_layer1(x) 74 | x = self.resnet_layer2(x_low) 75 | x = self.resnet_layer3(x) 76 | x = self.resnet_layer4(x) 77 | 78 | feature = self.ASPP(x) 79 | 80 | # Decoder 81 | x_low = self.project(x_low) 82 | output_feature = F.interpolate(feature, size=x_low.shape[2:], mode='bilinear', align_corners=True) 83 | prediction = self.classifier(torch.cat([x_low, output_feature], dim=1)) 84 | representation = self.representation(torch.cat([x_low, output_feature], dim=1)) 85 | 86 | 87 | return prediction, representation, torch.cat([x_low, output_feature], dim=1) 88 | 89 | 90 | class DeepLabv3Plus_with_rep(nn.Module): 91 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 92 | super(DeepLabv3Plus_with_rep, self).__init__() 93 | if dilate_scale == 8: 94 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 95 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 96 | aspp_dilate = [12, 24, 36] 97 | 98 | elif dilate_scale == 16: 99 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 100 | aspp_dilate = [6, 12, 18] 101 | 102 | # take pre-defined ResNet, except AvgPool and FC 103 | self.resnet_conv1 = orig_resnet.conv1 104 | self.resnet_bn1 = orig_resnet.bn1 105 | self.resnet_relu1 = orig_resnet.relu 106 | self.resnet_maxpool = orig_resnet.maxpool 107 | 108 | self.resnet_layer1 = orig_resnet.layer1 109 | self.resnet_layer2 = orig_resnet.layer2 110 | self.resnet_layer3 = orig_resnet.layer3 111 | self.resnet_layer4 = orig_resnet.layer4 112 | 113 | self.ASPP = ASPP(2048, aspp_dilate) 114 | 115 | self.project = nn.Sequential( 116 | nn.Conv2d(256, 48, 1, bias=False), 117 | nn.BatchNorm2d(48), 118 | nn.ReLU(inplace=True), 119 | ) 120 | 121 | self.classifier = nn.Sequential( 122 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 123 | nn.BatchNorm2d(256), 124 | nn.ReLU(), 125 | nn.Conv2d(256, num_classes, 1) 126 | ) 127 | 128 | self.representation = nn.Sequential( 129 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 130 | nn.BatchNorm2d(256), 131 | nn.ReLU(), 132 | nn.Conv2d(256, output_dim, 1) 133 | ) 134 | 135 | def _nostride_dilate(self, m, dilate): 136 | classname = m.__class__.__name__ 137 | if classname.find('Conv') != -1: 138 | # the convolution with stride 139 | if m.stride == (2, 2): 140 | m.stride = (1, 1) 141 | if m.kernel_size == (3, 3): 142 | m.dilation = (dilate // 2, dilate // 2) 143 | m.padding = (dilate // 2, dilate // 2) 144 | 145 | # other convoluions 146 | else: 147 | if m.kernel_size == (3, 3): 148 | m.dilation = (dilate, dilate) 149 | m.padding = (dilate, dilate) 150 | 151 | def forward(self, x): 152 | x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x))) 153 | x = self.resnet_maxpool(x) 154 
| 155 | x_low = self.resnet_layer1(x) 156 | x = self.resnet_layer2(x_low) 157 | x = self.resnet_layer3(x) 158 | x = self.resnet_layer4(x) 159 | 160 | feature = self.ASPP(x) 161 | 162 | # Decoder 163 | x_low = self.project(x_low) 164 | output_feature = F.interpolate(feature, size=x_low.shape[2:], mode='bilinear', align_corners=True) 165 | prediction = self.classifier(torch.cat([x_low, output_feature], dim=1)) 166 | representation = self.representation(torch.cat([x_low, output_feature], dim=1)) 167 | 168 | 169 | return prediction, representation 170 | 171 | class DeepLabv3Plus(nn.Module): 172 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 173 | super(DeepLabv3Plus, self).__init__() 174 | if dilate_scale == 8: 175 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 176 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 177 | aspp_dilate = [12, 24, 36] 178 | 179 | elif dilate_scale == 16: 180 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 181 | aspp_dilate = [6, 12, 18] 182 | 183 | # take pre-defined ResNet, except AvgPool and FC 184 | self.resnet_conv1 = orig_resnet.conv1 185 | self.resnet_bn1 = orig_resnet.bn1 186 | self.resnet_relu1 = orig_resnet.relu 187 | self.resnet_maxpool = orig_resnet.maxpool 188 | 189 | self.resnet_layer1 = orig_resnet.layer1 190 | self.resnet_layer2 = orig_resnet.layer2 191 | self.resnet_layer3 = orig_resnet.layer3 192 | self.resnet_layer4 = orig_resnet.layer4 193 | 194 | self.ASPP = ASPP(2048, aspp_dilate) 195 | 196 | self.project = nn.Sequential( 197 | nn.Conv2d(256, 48, 1, bias=False), 198 | nn.BatchNorm2d(48), 199 | nn.ReLU(inplace=True), 200 | ) 201 | 202 | self.classifier = nn.Sequential( 203 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 204 | nn.BatchNorm2d(256), 205 | nn.ReLU(), 206 | nn.Conv2d(256, num_classes, 1) 207 | ) 208 | 209 | def _nostride_dilate(self, m, dilate): 210 | classname = m.__class__.__name__ 211 | if classname.find('Conv') != -1: 212 | # the convolution with stride 213 | if m.stride == (2, 2): 214 | m.stride = (1, 1) 215 | if m.kernel_size == (3, 3): 216 | m.dilation = (dilate // 2, dilate // 2) 217 | m.padding = (dilate // 2, dilate // 2) 218 | 219 | # other convoluions 220 | else: 221 | if m.kernel_size == (3, 3): 222 | m.dilation = (dilate, dilate) 223 | m.padding = (dilate, dilate) 224 | 225 | def forward(self, x): 226 | x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x))) 227 | x = self.resnet_maxpool(x) 228 | 229 | x_low = self.resnet_layer1(x) 230 | x = self.resnet_layer2(x_low) 231 | x = self.resnet_layer3(x) 232 | x = self.resnet_layer4(x) 233 | 234 | feature = self.ASPP(x) 235 | 236 | # Decoder 237 | x_low = self.project(x_low) 238 | output_feature = F.interpolate(feature, size=x_low.shape[2:], mode='bilinear', align_corners=True) 239 | prediction = self.classifier(torch.cat([x_low, output_feature], dim=1)) 240 | 241 | 242 | return prediction 243 | 244 | class DeepLabv3Plus_E(nn.Module): 245 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 246 | super(DeepLabv3Plus_E, self).__init__() 247 | if dilate_scale == 8: 248 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 249 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 250 | aspp_dilate = [12, 24, 36] 251 | 252 | elif dilate_scale == 16: 253 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 254 | aspp_dilate = [6, 12, 18] 255 | 256 | # take pre-defined ResNet, except AvgPool and FC 257 
| self.resnet_conv1 = orig_resnet.conv1 258 | self.resnet_bn1 = orig_resnet.bn1 259 | self.resnet_relu1 = orig_resnet.relu 260 | self.resnet_maxpool = orig_resnet.maxpool 261 | 262 | self.resnet_layer1 = orig_resnet.layer1 263 | self.resnet_layer2 = orig_resnet.layer2 264 | self.resnet_layer3 = orig_resnet.layer3 265 | self.resnet_layer4 = orig_resnet.layer4 266 | 267 | self.ASPP = ASPP(2048, aspp_dilate) 268 | 269 | def _nostride_dilate(self, m, dilate): 270 | classname = m.__class__.__name__ 271 | if classname.find('Conv') != -1: 272 | # the convolution with stride 273 | if m.stride == (2, 2): 274 | m.stride = (1, 1) 275 | if m.kernel_size == (3, 3): 276 | m.dilation = (dilate // 2, dilate // 2) 277 | m.padding = (dilate // 2, dilate // 2) 278 | 279 | # other convoluions 280 | else: 281 | if m.kernel_size == (3, 3): 282 | m.dilation = (dilate, dilate) 283 | m.padding = (dilate, dilate) 284 | 285 | def forward(self, x): 286 | x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x))) 287 | x = self.resnet_maxpool(x) 288 | 289 | x_low = self.resnet_layer1(x) 290 | x = self.resnet_layer2(x_low) 291 | x = self.resnet_layer3(x) 292 | x = self.resnet_layer4(x) 293 | 294 | feature = self.ASPP(x) 295 | 296 | return x_low, feature 297 | 298 | class DeepLabv3Plus_r(nn.Module): 299 | def __init__(self, orig_resnet, dilate_scale=16, num_classes=21, output_dim=256): 300 | super(DeepLabv3Plus_r, self).__init__() 301 | if dilate_scale == 8: 302 | orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) 303 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) 304 | aspp_dilate = [12, 24, 36] 305 | 306 | elif dilate_scale == 16: 307 | orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) 308 | aspp_dilate = [6, 12, 18] 309 | 310 | # take pre-defined ResNet, except AvgPool and FC 311 | self.resnet_conv1 = orig_resnet.conv1 312 | self.resnet_bn1 = orig_resnet.bn1 313 | self.resnet_relu1 = orig_resnet.relu 314 | self.resnet_maxpool = orig_resnet.maxpool 315 | 316 | self.resnet_layer1 = orig_resnet.layer1 317 | self.resnet_layer2 = orig_resnet.layer2 318 | self.resnet_layer3 = orig_resnet.layer3 319 | self.resnet_layer4 = orig_resnet.layer4 320 | 321 | self.ASPP = ASPP(2048, aspp_dilate) 322 | 323 | self.project = nn.Sequential( 324 | nn.Conv2d(256, 48, 1, bias=False), 325 | nn.BatchNorm2d(48), 326 | nn.ReLU(inplace=True), 327 | ) 328 | 329 | self.representation = nn.Sequential( 330 | nn.Conv2d(304, 256, 3, padding=1, bias=False), 331 | nn.BatchNorm2d(256), 332 | nn.ReLU(), 333 | nn.Conv2d(256, output_dim, 1) 334 | ) 335 | 336 | def _nostride_dilate(self, m, dilate): 337 | classname = m.__class__.__name__ 338 | if classname.find('Conv') != -1: 339 | # the convolution with stride 340 | if m.stride == (2, 2): 341 | m.stride = (1, 1) 342 | if m.kernel_size == (3, 3): 343 | m.dilation = (dilate // 2, dilate // 2) 344 | m.padding = (dilate // 2, dilate // 2) 345 | 346 | # other convoluions 347 | else: 348 | if m.kernel_size == (3, 3): 349 | m.dilation = (dilate, dilate) 350 | m.padding = (dilate, dilate) 351 | 352 | def forward(self, x): 353 | x = self.resnet_relu1(self.resnet_bn1(self.resnet_conv1(x))) 354 | x = self.resnet_maxpool(x) 355 | 356 | x_low = self.resnet_layer1(x) 357 | x = self.resnet_layer2(x_low) 358 | x = self.resnet_layer3(x) 359 | x = self.resnet_layer4(x) 360 | 361 | feature = self.ASPP(x) 362 | 363 | # Decoder 364 | x_low = self.project(x_low) 365 | output_feature = F.interpolate(feature, size=x_low.shape[2:], mode='bilinear', align_corners=True) 366 | 
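# Fuse the projected low-level features with the upsampled ASPP features, then map them into the representation (embedding) space.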
representation = self.representation(torch.cat([x_low, output_feature], dim=1)) 367 | 368 | 369 | return representation -------------------------------------------------------------------------------- /generalframeworks/networks/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | 6 | class Uncertainty_head(nn.Module): # feature -> log(sigma^2) 7 | def __init__(self, in_feat=304, out_feat=256): 8 | super(Uncertainty_head, self).__init__() 9 | self.fc1 = Parameter(torch.Tensor(out_feat, in_feat)) 10 | self.bn1 = nn.BatchNorm2d(out_feat, affine=True) 11 | self.relu = nn.ReLU() 12 | self.fc2 = Parameter(torch.Tensor(out_feat, out_feat)) 13 | self.bn2 = nn.BatchNorm2d(out_feat, affine=False) 14 | self.gamma = Parameter(torch.Tensor([1.0])) 15 | self.beta = Parameter(torch.Tensor([0.0])) 16 | 17 | nn.init.kaiming_normal_(self.fc1) 18 | nn.init.kaiming_normal_(self.fc2) 19 | 20 | def forward(self, x: torch.Tensor): 21 | x = x.permute(0, 2, 3, 1) 22 | x = F.linear(x, F.normalize(self.fc1, dim=-1)) # [B, W, H, D] 23 | x = x.permute(0, 3, 1, 2) # [B, W, H, D] -> [B, D, W, H] 24 | x = self.bn1(x) 25 | x = self.relu(x) 26 | x = x.permute(0, 2, 3, 1) 27 | x = F.linear(x, F.normalize(self.fc2, dim=-1)) 28 | x = x.permute(0, 3, 1, 2) 29 | x = self.bn2(x) 30 | x = self.gamma * x + self.beta 31 | x = torch.log(torch.exp(x) + 1e-6) 32 | x = torch.sigmoid(x) 33 | 34 | return x 35 | 36 | class Classifier(nn.Module): 37 | def __init__(self, in_feat=304, num_classes=21): 38 | super(Classifier, self).__init__() 39 | self.conv1 = nn.Conv2d(in_feat, 256, 3, padding=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(256) 41 | self.relu = nn.ReLU() 42 | self.conv2 = nn.Conv2d(256, num_classes, 1) 43 | 44 | def forward(self, x: torch.Tensor): 45 | x = self.conv1(x) 46 | x = self.bn1(x) 47 | x = self.relu(x) 48 | x = self.conv2(x) 49 | 50 | return x 51 | 52 | class Decoder(nn.Module): 53 | def __init__(self, in_feat=256, num_classes=19): 54 | super(Decoder, self).__init__() 55 | self.conv1 = nn.Conv2d(in_feat, 48, 1, bias=False) 56 | self.bn1 = nn.BatchNorm2d(48) 57 | self.relu1 = nn.ReLU(inplace=True) 58 | self.conv2 = nn.Conv2d(304, 256, 3, padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(256) 60 | self.relu2 = nn.ReLU() 61 | self.conv3 = nn.Conv2d(256, num_classes, 1) 62 | 63 | def forward(self, x_low: torch.Tensor, x: torch.Tensor): 64 | x_low = self.conv1(x_low) 65 | x_low = self.bn1(x_low) 66 | x_low = self.relu1(x_low) 67 | x = F.interpolate(x, size=x_low.shape[2:], mode='bilinear', align_corners=True) 68 | x = torch.cat([x_low, x], dim=1) 69 | x = self.conv2(x) 70 | x = self.bn2(x) 71 | x = self.relu2(x) 72 | x = self.conv3(x) 73 | 74 | return x -------------------------------------------------------------------------------- /generalframeworks/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = [ 6 | "ResNet", 7 | "resnet18", 8 | "resnet34", 9 | "resnet50", 10 | "resnet101", 11 | "resnet152", 12 | ] 13 | 14 | 15 | model_urls = { 16 | "resnet18": "/path/to/resnet18.pth", 17 | "resnet34": "/path/to/resnet34.pth", 18 | "resnet50": "/path/to/resnet50.pth", 19 | "resnet101": "path/to/resnet101.pth", 20 | "resnet152": "/path/to/resnet152.pth", 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 
convolution with padding""" 26 | return nn.Conv2d( 27 | in_planes, 28 | out_planes, 29 | kernel_size=3, 30 | stride=stride, 31 | padding=dilation, 32 | groups=groups, 33 | bias=False, 34 | dilation=dilation, 35 | ) 36 | 37 | 38 | def conv1x1(in_planes, out_planes, stride=1): 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__( 47 | self, 48 | inplanes, 49 | planes, 50 | stride=1, 51 | downsample=None, 52 | groups=1, 53 | base_width=64, 54 | dilation=1, 55 | norm_layer=None, 56 | ): 57 | super(BasicBlock, self).__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 62 | if dilation > 1: 63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = norm_layer(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = norm_layer(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | expansion = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes, 98 | planes, 99 | stride=1, 100 | downsample=None, 101 | groups=1, 102 | base_width=64, 103 | dilation=1, 104 | norm_layer=nn.BatchNorm2d, 105 | ): 106 | super(Bottleneck, self).__init__() 107 | width = int(planes * (base_width / 64.0)) * groups 108 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 109 | self.conv1 = conv1x1(inplanes, width) 110 | self.bn1 = norm_layer(width) 111 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 112 | self.bn2 = norm_layer(width) 113 | self.conv3 = conv1x1(width, planes * self.expansion) 114 | self.bn3 = norm_layer(planes * self.expansion) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.downsample = downsample 117 | self.stride = stride 118 | 119 | def forward(self, x): 120 | identity = x 121 | 122 | out = self.conv1(x) 123 | out = self.bn1(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv2(out) 127 | out = self.bn2(out) 128 | out = self.relu(out) 129 | 130 | out = self.conv3(out) 131 | out = self.bn3(out) 132 | 133 | if self.downsample is not None: 134 | identity = self.downsample(x) 135 | 136 | out += identity 137 | out = self.relu(out) 138 | 139 | return out 140 | 141 | 142 | class ResNet_Stem(nn.Module): 143 | def __init__( 144 | self, 145 | block, 146 | layers, 147 | zero_init_residual=True, 148 | groups=1, 149 | width_per_group=64, 150 | replace_stride_with_dilation=[False, True, True], 151 | multi_grid=True, 152 | fpn=True, 153 | ): 154 | super(ResNet_Stem, self).__init__() 155 | 156 | # norm_layer = 157 | norm_layer = nn.BatchNorm2d 158 | self._norm_layer = norm_layer 159 | 160 | self.inplanes = 128 161 | self.dilation = 1 162 | 163 | if replace_stride_with_dilation is None: 164 | # each element in the tuple indicates if we should replace 165 | # the 2x2 stride with a 
dilated convolution instead 166 | replace_stride_with_dilation = [False, False, False] 167 | 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError( 170 | "replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 172 | ) 173 | 174 | self.groups = groups 175 | self.base_width = width_per_group 176 | self.fpn = fpn 177 | self.conv1 = nn.Sequential( 178 | conv3x3(3, 64, stride=2), 179 | norm_layer(64), 180 | nn.ReLU(inplace=True), 181 | conv3x3(64, 64), 182 | norm_layer(64), 183 | nn.ReLU(inplace=True), 184 | conv3x3(64, self.inplanes), 185 | ) 186 | self.bn1 = norm_layer(self.inplanes) 187 | self.relu = nn.ReLU(inplace=True) 188 | self.maxpool = nn.MaxPool2d( 189 | kernel_size=3, stride=2, padding=1, ceil_mode=True 190 | ) # change 191 | 192 | self.layer1 = self._make_layer(block, 64, layers[0]) 193 | self.layer2 = self._make_layer( 194 | block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 195 | ) 196 | self.layer3 = self._make_layer( 197 | block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 198 | ) 199 | self.layer4 = self._make_layer( 200 | block, 201 | 512, 202 | layers[3], 203 | stride=2, 204 | dilate=replace_stride_with_dilation[2], 205 | multi_grid=multi_grid, 206 | ) 207 | 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv2d): 210 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 211 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm)): 212 | nn.init.constant_(m.weight, 1) 213 | nn.init.constant_(m.bias, 0) 214 | 215 | # Zero-initialize the last BN in each residual branch, 216 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
217 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 218 | if zero_init_residual: 219 | for m in self.modules(): 220 | if isinstance(m, Bottleneck): 221 | nn.init.constant_(m.bn3.weight, 0) 222 | elif isinstance(m, BasicBlock): 223 | nn.init.constant_(m.bn2.weight, 0) 224 | 225 | def get_outplanes(self): 226 | return self.inplanes 227 | 228 | def get_auxplanes(self): 229 | return self.inplanes // 2 230 | 231 | def _make_layer( 232 | self, block, planes, blocks, stride=1, dilate=False, multi_grid=False 233 | ): 234 | norm_layer = self._norm_layer 235 | downsample = None 236 | previous_dilation = self.dilation 237 | if dilate: 238 | self.dilation *= stride 239 | stride = 1 240 | if stride != 1 or self.inplanes != planes * block.expansion: 241 | downsample = nn.Sequential( 242 | conv1x1(self.inplanes, planes * block.expansion, stride), 243 | norm_layer(planes * block.expansion), 244 | ) 245 | 246 | grids = [1] * blocks 247 | if multi_grid: 248 | grids = [2, 2, 4] 249 | 250 | layers = [] 251 | layers.append( 252 | block( 253 | self.inplanes, 254 | planes, 255 | stride, 256 | downsample, 257 | self.groups, 258 | self.base_width, 259 | previous_dilation * grids[0], 260 | norm_layer, 261 | ) 262 | ) 263 | self.inplanes = planes * block.expansion 264 | for i in range(1, blocks): 265 | layers.append( 266 | block( 267 | self.inplanes, 268 | planes, 269 | groups=self.groups, 270 | base_width=self.base_width, 271 | dilation=self.dilation * grids[i], 272 | norm_layer=norm_layer, 273 | ) 274 | ) 275 | 276 | return nn.Sequential(*layers) 277 | 278 | def forward(self, x): 279 | x = self.relu(self.bn1(self.conv1(x))) 280 | x = self.maxpool(x) 281 | 282 | x = self.layer1(x) 283 | x1 = x 284 | x = self.layer2(x) 285 | x2 = x 286 | x3 = self.layer3(x) 287 | x4 = self.layer4(x3) 288 | if self.fpn: 289 | return [x1, x2, x3, x4] 290 | else: 291 | return [x3, x4] 292 | 293 | 294 | 295 | def resnet18(pretrained=False, **kwargs): 296 | """Constructs a ResNet-18 model. 297 | 298 | Args: 299 | pretrained (bool): If True, returns a model pre-trained on ImageNet 300 | """ 301 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 302 | if pretrained: 303 | model_url = model_urls["resnet18"] 304 | state_dict = torch.load(model_url) 305 | 306 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 307 | print( 308 | f"[Info] Load ImageNet pretrain from '{model_url}'", 309 | "\nmissing_keys: ", 310 | missing_keys, 311 | "\nunexpected_keys: ", 312 | unexpected_keys, 313 | ) 314 | return model 315 | 316 | 317 | def resnet34(pretrained=False, **kwargs): 318 | """Constructs a ResNet-34 model. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | """ 323 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 324 | if pretrained: 325 | model_url = model_urls["resnet34"] 326 | state_dict = torch.load(model_url) 327 | 328 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 329 | print( 330 | f"[Info] Load ImageNet pretrain from '{model_url}'", 331 | "\nmissing_keys: ", 332 | missing_keys, 333 | "\nunexpected_keys: ", 334 | unexpected_keys, 335 | ) 336 | return model 337 | 338 | 339 | def resnet50(pretrained=True, **kwargs): 340 | """Constructs a ResNet-50 model. 
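    Weights are loaded with strict=False from the local path in `model_urls` above; point that path at a downloaded checkpoint before relying on pretrained=True.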
341 | 342 | Args: 343 | pretrained (bool): If True, returns a model pre-trained on ImageNet 344 | """ 345 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 346 | if pretrained: 347 | model_url = model_urls["resnet50"] 348 | state_dict = torch.load(model_url) 349 | 350 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 351 | print( 352 | f"[Info] Load ImageNet pretrain from '{model_url}'", 353 | "\nmissing_keys: ", 354 | missing_keys, 355 | "\nunexpected_keys: ", 356 | unexpected_keys, 357 | ) 358 | return model 359 | 360 | 361 | def resnet101(pretrained=True, **kwargs): 362 | """Constructs a ResNet-101 model. 363 | 364 | Args: 365 | pretrained (bool): If True, returns a model pre-trained on ImageNet 366 | """ 367 | model = ResNet_Stem(Bottleneck, [3, 4, 23, 3], **kwargs) 368 | if pretrained: 369 | model_url = model_urls["resnet101"] 370 | state_dict = torch.load(model_url) 371 | 372 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 373 | print( 374 | f"[Info] Load ImageNet pretrain from '{model_url}'", 375 | "\nmissing_keys: ", 376 | missing_keys, 377 | "\nunexpected_keys: ", 378 | unexpected_keys, 379 | ) 380 | return model 381 | 382 | 383 | def resnet152(pretrained=True, **kwargs): 384 | """Constructs a ResNet-152 model. 385 | 386 | Args: 387 | pretrained (bool): If True, returns a model pre-trained on ImageNet 388 | """ 389 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 390 | if pretrained: 391 | model_url = model_urls["resnet152"] 392 | state_dict = torch.load(model_url) 393 | 394 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 395 | print( 396 | f"[Info] Load ImageNet pretrain from '{model_url}'", 397 | "\nmissing_keys: ", 398 | missing_keys, 399 | "\nunexpected_keys: ", 400 | unexpected_keys, 401 | ) 402 | return model 403 | -------------------------------------------------------------------------------- /generalframeworks/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/my_lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/rampscheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/rampscheduler.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/__pycache__/rampscheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/scheduler/__pycache__/rampscheduler.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/scheduler/my_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | class PolyLR(_LRScheduler): 5 | def __init__(self, optimizer, max_iters, power=0.9, last_epoch=-1, min_lr=1e-6): 6 | self.power = power 7 | self.max_iters = max_iters 8 | self.min_lr = min_lr 9 | super(PolyLR, self).__init__(optimizer, last_epoch) 10 | 11 | def get_lr(self): 12 | return [max(base_lr * (1 - self.last_epoch / self.max_iters) ** self.power, self.min_lr) 13 | for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /generalframeworks/scheduler/rampscheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | class RampScheduler(object): 3 | 4 | def __init__(self, begin_epoch, max_epoch, max_value, ramp_mult): 5 | super().__init__() 6 | self.begin_epoch = int(begin_epoch) 7 | self.max_epoch = int(max_epoch) 8 | self.max_value = float(max_value) 9 | self.mult = float(ramp_mult) 10 | self.epoch = 0 11 | 12 | def step(self): 13 | self.epoch += 1 14 | 15 | @property 16 | def value(self): 17 | return self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.mult) 18 | 19 | @staticmethod 20 | def get_lr(epoch, begin_epoch, max_epochs, max_val, mult): 21 | if epoch < begin_epoch: 22 | return 0. 23 | elif epoch >= max_epochs: 24 | return max_val 25 | return max_val * np.exp(mult * (1. - float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2) 26 | 27 | class RampdownScheduler(object): 28 | 29 | def __init__(self, begin_epoch, max_epoch, current_epoch, max_value, min_value, ramp_mult): 30 | super().__init__() 31 | self.begin_epoch = int(begin_epoch) 32 | self.max_epoch = int(max_epoch) 33 | self.max_value = float(max_value) 34 | self.mult = float(ramp_mult) 35 | self.epoch = current_epoch 36 | self.min_value = min_value 37 | 38 | def step(self): 39 | self.epoch += 1 40 | 41 | @property 42 | def value(self): 43 | current_value = self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.min_value, self.mult) 44 | if current_value < self.min_value: 45 | current_value = self.min_value 46 | return current_value 47 | 48 | @staticmethod 49 | def get_lr(epoch, begin_epoch, max_epochs, max_val, min_value, mult): 50 | if epoch < begin_epoch: 51 | return 0. 
52 | elif epoch >= max_epochs: 53 | return min_value 54 | return max_val * np.exp(mult * (float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2) -------------------------------------------------------------------------------- /generalframeworks/util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/dist_init.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/dist_init.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/meter.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/meter.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/miou.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/miou.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/miou.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/miou.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/torch_dist_sum.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/torch_dist_sum.cpython-37.pyc -------------------------------------------------------------------------------- /generalframeworks/util/__pycache__/torch_dist_sum.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/WangChangqi98/CSS/8aec4a874c94dfba76c5150a89882ebe055e6552/generalframeworks/util/__pycache__/torch_dist_sum.cpython-38.pyc -------------------------------------------------------------------------------- /generalframeworks/util/dist_init.py: -------------------------------------------------------------------------------- 1 | def dist_init(port): 2 | import torch 3 | import os 4 | 5 | def init(host_addr, rank, local_rank, world_size, port): 6 | host_addr_full = 'tcp://' + host_addr + ':' + str(port) 7 | torch.distributed.init_process_group("nccl", init_method=host_addr_full, 8 | rank=rank, world_size=world_size) 9 | torch.cuda.set_device(local_rank) 10 | assert torch.distributed.is_initialized() 11 | 12 | def parse_host_addr(s): 13 | if '[' in s: 14 | left_bracket = s.index('[') 15 | right_bracket = s.index(']') 16 | prefix = s[:left_bracket] 17 | first_number = s[left_bracket+1:right_bracket].split(',')[0].split('-')[0] 18 | return prefix + first_number 19 | else: 20 | return s 21 | 22 | rank = int(os.environ['SLURM_PROCID']) 23 | local_rank = int(os.environ['SLURM_LOCALID']) 24 | world_size = int(os.environ['SLURM_NTASKS']) 25 | 26 | ip = parse_host_addr(os.environ['SLURM_STEP_NODELIST']) 27 | 28 | init(ip, rank, local_rank, world_size, port) 29 | 30 | return rank, local_rank, world_size 31 | 32 | def local_dist_init(args, rank): 33 | import torch 34 | import os 35 | import torch.distributed as dist 36 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 37 | os.environ['MASTER_ADDR'] = 'localhost' 38 | os.environ['MASTER_PORT'] = args.port 39 | os.environ['WORLD_SIZE'] = args.world_size 40 | os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL' 41 | dist.init_process_group(backend='nccl', world_size=int(args.world_size), rank=rank) 42 | torch.cuda.set_device(rank) 43 | torch.autograd.set_detect_anomaly(True) 44 | 45 | -------------------------------------------------------------------------------- /generalframeworks/util/meter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, name, fmt=':f'): 7 | self.name = name 8 | self.fmt = fmt 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | 23 | def __str__(self): 24 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 25 | return fmtstr.format(**self.__dict__) 26 | 27 | 28 | class ConfMatrix(object): 29 | def __init__(self, num_classes, fmt, name='miou'): 30 | self.name = name 31 | self.fmt = fmt 32 | self.num_classes = num_classes 33 | self.mat = None 34 | self.temp_mat = None 35 | self.val = 0 36 | self.avg = 0 37 | 38 | 39 | def update(self, pred, target): 40 | n = self.num_classes 41 | self.temp_mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 42 | if self.mat is None: 43 | self.mat = torch.zeros((n, n), dtype=torch.int64, device=pred.device) 44 | with torch.no_grad(): 45 | k = (target >= 0) & (target < n) 46 | inds = n * target[k].to(torch.int64) + pred[k] 47 | self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) 48 | self.temp_mat = torch.bincount(inds, minlength=n**2).reshape(n, n) 
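        # __str__ below reports val (mIoU of the most recent batch, from temp_mat) and
        # avg (cumulative mIoU over all batches seen so far, from mat).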
49 | 50 | 51 | def __str__(self): 52 | h = self.mat.float() 53 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 54 | self.avg = torch.mean(iu).item() 55 | 56 | h_t = self.temp_mat.float() 57 | iu_a = torch.diag(h_t) / (h_t.sum(1) + h_t.sum(0) - torch.diag(h_t)) 58 | self.val = torch.mean(iu_a).item() 59 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 60 | return fmtstr.format(**self.__dict__) 61 | 62 | 63 | class ProgressMeter(object): 64 | def __init__(self, num_batches, meters, prefix=""): 65 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 66 | self.meters = meters 67 | self.prefix = prefix 68 | 69 | def display(self, batch): 70 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 71 | entries += [str(meter) for meter in self.meters] 72 | print('\t'.join(entries)) 73 | 74 | def _get_batch_fmtstr(self, num_batches): 75 | num_digits = len(str(num_batches // 1)) 76 | fmt = '{:' + str(num_digits) + 'd}' 77 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 78 | 79 | -------------------------------------------------------------------------------- /generalframeworks/util/miou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mean_intersection_over_union(mat: torch.Tensor): 4 | ''' Compute miou via Confmatrix''' 5 | h = mat.float() 6 | iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) 7 | miou = torch.mean(iu).item() 8 | 9 | return miou -------------------------------------------------------------------------------- /generalframeworks/util/torch_dist_sum.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | __all__ = ['torch_dist_sum'] 5 | 6 | def torch_dist_sum(gpu, *args): 7 | process_group = torch.distributed.group.WORLD 8 | tensor_args = [] 9 | pending_res = [] 10 | for arg in args: 11 | # if isinstance(arg, torch.Tensor): 12 | # tensor_arg = arg.clone().reshape(-1).detach().cuda(gpu) 13 | # else: 14 | # tensor_arg = torch.tensor(arg).reshape(-1).cuda(gpu) 15 | tensor_arg = arg.clone().detach().cuda(gpu) 16 | tensor_args.append(tensor_arg) 17 | pending_res.append(torch.distributed.all_reduce(tensor_arg, group=process_group, async_op=True)) 18 | for res in pending_res: 19 | res.wait() 20 | return tensor_args 21 | -------------------------------------------------------------------------------- /generalframeworks/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from typing import Iterable, Union 4 | from copy import deepcopy as dcopy 5 | from typing import List, Set 6 | import collections 7 | from functools import partial, reduce 8 | import torch 9 | import numpy as np 10 | import os 11 | import datetime 12 | # from tqdm import tqdm 13 | from torch.utils.data import DataLoader 14 | import warnings 15 | import torch.nn as nn 16 | import sys 17 | 18 | ##### Hyper Parameters Define ##### 19 | 20 | def _parser_(input_strings: str) -> Union[dict, None]: 21 | if input_strings.__len__() == 0: 22 | return None 23 | assert input_strings.find('=') > 0, f"Input args should include '=' to include value" 24 | keys, value = input_strings.split('=')[:-1][0].replace(' ', ''), input_strings.split('=')[1].replace(' ', '') 25 | keys = keys.split('.') 26 | keys.reverse() 27 | for k in keys: 28 | d = {} 29 | d[k] =value 30 | value = dcopy(d) 31 | return dict(value) 32 | 33 | def _parser(strings: List[str]) -> List[dict]: 34 | assert 
isinstance(strings, list) 35 | args: List[dict] = [_parser_(s) for s in strings] 36 | args = reduce(lambda x, y: dict_merge(x, y, True), args) 37 | return args 38 | 39 | def yaml_parser() -> dict: 40 | parser = argparse.ArgumentParser('Argument parser for yaml config') 41 | parser.add_argument('strings', nargs='*', type=str, default=['']) 42 | parser.add_argument("--local_rank", type=int) 43 | #parser.add_argument('--var', type=int, default=24) 44 | #add args.variable here 45 | args: argparse.Namespace = parser.parse_args() 46 | args: dict = _parser(args.strings) 47 | return args 48 | 49 | def dict_merge(dct: dict, merge_dct: dict, re=False): 50 | ''' 51 | Recursive dict merge. Instead of updating only top-level keys, dict_merge recurses down into dicts nested 52 | to an arbitrary depth, updating keys. The "merge_dct" is merged into "dct". 53 | ''' 54 | if merge_dct is None: 55 | if re: 56 | return dct 57 | else: 58 | return 59 | for k, v in merge_dct.items(): 60 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], collections.abc.Mapping)): 61 | dict_merge(dct[k], merge_dct[k]) 62 | else: 63 | try: 64 | dct[k] = type(dct[k])(eval(merge_dct[k])) if type(dct[k]) in (bool, list) else type(dct[k])( 65 | merge_dct[k]) 66 | except: 67 | dct[k] = merge_dct[k] 68 | if re: 69 | return dcopy(dct) 70 | 71 | ##### Timer ###### 72 | def now_time(): 73 | time = datetime.datetime.now() 74 | return str(time)[:19] 75 | 76 | # ##### Progress Bar ##### 77 | 78 | # tqdm_ = partial(tqdm, ncols=125, leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [' '{rate_fmt}{postfix}]') 79 | 80 | ##### Coding ##### 81 | def class2one_hot(seg: torch.Tensor, num_class: int) -> torch.Tensor: 82 | ''' 83 | [b, w, h] containing (0, 1, ..., c) -> [b, c, w, h] containing (0, 1) 84 | ''' 85 | if len(seg.shape) == 2: 86 | seg = seg.unsqueeze(dim=0) # Must be 3-dim 87 | if len(seg.shape) == 4: 88 | seg = seg.squeeze(dim=1) 89 | assert sset(seg, list(range(num_class))), 'Segmentation contains values outside range(num_class)!'
90 | b, w, h = seg.shape # Tuple [int, int, int] 91 | res = torch.stack([seg == c for c in range(num_class)], dim=1).type(torch.int32) 92 | assert res.shape == (b, num_class, w, h) 93 | assert one_hot(res) 94 | 95 | return res 96 | 97 | def probs2class(probs: torch.Tensor) -> torch.Tensor: 98 | ''' 99 | [b, c, w, h] containing (float in range(0, 1)) -> [b, w, h] containing ([0, 1, ..., c]) 100 | ''' 101 | b, _, w, h = probs.shape 102 | assert simplex(probs), '{} is not a probability'.format(probs) 103 | res = probs.argmax(dim=1) 104 | assert res.shape == (b, w, h) 105 | 106 | return res 107 | 108 | def probs2one_hot(probs: torch.Tensor) -> torch.Tensor: 109 | _, num_class, _, _ = probs.shape 110 | assert simplex(probs), '{} is not a probability'.format(probs) 111 | res = class2one_hot(probs2class(probs), num_class) 112 | assert res.shape == probs.shape 113 | assert one_hot(res) 114 | return res 115 | 116 | def label_onehot(inputs, num_class): 117 | ''' 118 | inputs is a class label 119 | return the one-hot label 120 | a channel dim will be added 121 | ''' 122 | batch_size, image_h, image_w = inputs.shape 123 | inputs = torch.relu(inputs) 124 | outputs = torch.zeros([batch_size, num_class, image_h, image_w]).to(inputs.device) 125 | return outputs.scatter_(1, inputs.unsqueeze(1), 1.0) 126 | 127 | def label_onehot_2(inputs, num_class): 128 | ''' 129 | inputs is a class label 130 | return the one-hot label 131 | a channel dim will be added 132 | ''' 133 | batch_size, image_h, image_w = inputs.shape 134 | inputs = inputs + 1 135 | outputs = torch.zeros([batch_size, (num_class + 1), image_h, image_w]).to(inputs.device) 136 | return outputs.scatter_(1, inputs.unsqueeze(1), 1.0) 137 | 138 | def uniq(a: torch.Tensor) -> Set: 139 | return set(torch.unique(a.cpu()).numpy()) 140 | 141 | def sset(a: torch.Tensor, sub: Iterable) -> bool: 142 | return uniq(a).issubset(sub) 143 | 144 | def simplex(t: torch.Tensor, axis=1) -> bool: 145 | ''' 146 | Check if the matrix is a probability distribution along the given axis. 147 | ''' 148 | _sum = t.sum(axis).type(torch.float32) 149 | _ones = torch.ones_like(_sum, dtype=torch.float32) 150 | return torch.allclose(_sum, _ones) 151 | 152 | def one_hot(t: torch.Tensor, axis=1) -> bool: 153 | ''' 154 | Check if the tensor is one-hot encoded 155 | ''' 156 | return simplex(t, axis) and sset(t, [0, 1]) 157 | 158 | def intersection(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 159 | ''' 160 | a and b must only contain 0 or 1; the function computes the intersection of the two tensors:
161 | a & b 162 | ''' 163 | assert a.shape == b.shape, '{} must be the same as {}'.format(a.shape, b.shape) 164 | assert sset(a, [0, 1]), '{} must only contain 0, 1'.format(a) 165 | assert sset(b, [0, 1]), '{} must only contain 0, 1'.format(b) 166 | return a & b 167 | 168 | class iterator_(object): 169 | def __init__(self, dataloader: DataLoader) -> None: 170 | super().__init__() 171 | self.dataloader = dcopy(dataloader) 172 | self.iter_dataloader = iter(dataloader) 173 | self.cache = None 174 | 175 | def __next__(self): 176 | try: 177 | self.cache = self.iter_dataloader.__next__() 178 | return self.cache 179 | except StopIteration: 180 | self.iter_dataloader = iter(self.dataloader) 181 | self.cache = self.iter_dataloader.__next__() 182 | return self.cache 183 | def __cache__(self): 184 | if self.cache is not None: 185 | return self.cache 186 | else: 187 | warnings.warn('No cache found, iterator moves forward') 188 | return self.__next__() 189 | 190 | def apply_dropout(m): 191 | if type(m) == nn.Dropout2d: 192 | m.train() 193 | 194 | ##### Scheduler ##### 195 | class RampUpScheduler(): 196 | def __init__(self, begin_epoch, max_epoch, max_value, ramp_mult): 197 | super().__init__() 198 | self.begin_epoch = begin_epoch 199 | self.max_epoch = max_epoch 200 | self.ramp_mult = ramp_mult 201 | self.max_value = max_value 202 | self.epoch = 0 203 | 204 | def step(self): 205 | self.epoch += 1 206 | 207 | @property 208 | def value(self): 209 | return self.get_lr(self.epoch, self.begin_epoch, self.max_epoch, self.max_value, self.ramp_mult) 210 | 211 | def get_lr(self, epoch, begin_epoch, max_epochs, max_val, mult): 212 | if epoch < begin_epoch: 213 | return 0. 214 | elif epoch >= max_epochs: 215 | return max_val 216 | return max_val * np.exp(mult * (1 - float(epoch - begin_epoch) / (max_epochs - begin_epoch)) ** 2 ) 217 | 218 | 219 | ##### Compute mIoU ##### 220 | def mask_label(label, mask): 221 | ''' 222 | label is the original label (contains -1), mask is the valid region in the pseudo label (type=long) 223 | return a label with invalid region = -1 224 | ''' 225 | label_tmp = label.clone() 226 | mask_ = (1 - mask.float()).bool() 227 | label_tmp[mask_] = -1 228 | return label_tmp.long() 229 | 230 | ##### Logger ##### 231 | class Logger(object): 232 | def __init__(self, logFile="Default.log"): 233 | self.terminal = sys.stdout 234 | self.log = open(logFile, 'a') 235 | 236 | def write(self, message): 237 | self.terminal.write(message) 238 | self.log.write(message) 239 | 240 | def flush(self): 241 | pass -------------------------------------------------------------------------------- /mix_label.py: -------------------------------------------------------------------------------- 1 | import shutup 2 | shutup.please() 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parallel import DistributedDataParallel 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from generalframeworks.dataset_helpers.VOC import VOC_BuildData 10 | from generalframeworks.dataset_helpers.Cityscapes import City_BuildData 11 | from generalframeworks.networks import resnet 12 | from generalframeworks.networks.ddp_model import Model_mix 13 | from generalframeworks.scheduler.my_lr_scheduler import PolyLR 14 | from generalframeworks.scheduler.rampscheduler import RampdownScheduler 15 | from generalframeworks.utils import iterator_, Logger 16 | from generalframeworks.util.meter import * 17 | from generalframeworks.utils import label_onehot, label_onehot_2 18 | from
generalframeworks.util.torch_dist_sum import * 19 | from generalframeworks.util.miou import * 20 | from generalframeworks.util.dist_init import local_dist_init 21 | from generalframeworks.loss.loss import ProbOhemCrossEntropy2d, Attention_Threshold_Loss, Contrast_Loss 22 | import yaml 23 | import os 24 | import time 25 | import torchvision.models as models 26 | import argparse 27 | import random 28 | 29 | def main(rank, config, args): 30 | ##### Distribution init ##### 31 | local_dist_init(args, rank) 32 | print('Hello from rank {}\n'.format(rank)) 33 | 34 | ##### Load the dataset ##### 35 | if config['Dataset']['name'] == 'VOC': 36 | data = VOC_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 37 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 38 | if config['Dataset']['name'] == 'CityScapes': 39 | data = City_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 40 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 41 | train_l_dataset, train_u_dataset, test_dataset = data.build() 42 | train_l_sampler = torch.utils.data.distributed.DistributedSampler(train_l_dataset) 43 | train_l_loader = torch.utils.data.DataLoader(train_l_dataset, 44 | batch_size=config['Dataset']['batch_size'], 45 | pin_memory=True, 46 | sampler=train_l_sampler, 47 | num_workers=4, 48 | drop_last=True) 49 | train_u_sampler = torch.utils.data.distributed.DistributedSampler(train_u_dataset) 50 | train_u_loader = torch.utils.data.DataLoader(train_u_dataset, 51 | batch_size=config['Dataset']['batch_size'], 52 | pin_memory=True, 53 | sampler=train_u_sampler, 54 | num_workers=4, 55 | drop_last=True) 56 | test_loader = torch.utils.data.DataLoader(test_dataset, 57 | batch_size=config['Dataset']['batch_size'], 58 | pin_memory=True, 59 | num_workers=4) 60 | 61 | ##### Load the weight for each class ##### 62 | weight = torch.FloatTensor( 63 | [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 64 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 65 | 1.0865, 1.1529, 1.0507]).cuda() 66 | 67 | ##### Model init ##### 68 | backbone = models.resnet101() 69 | ckpt = torch.load('./pretrained/resnet101.pth', map_location='cpu') 70 | backbone.load_state_dict(ckpt) 71 | 72 | # for Resnet-101 stem users 73 | #backbone = resnet.resnet101(pretrained=True) 74 | 75 | model = Model_mix(backbone, num_classes=config['Network']['num_class'], output_dim=256, config=config, temp=args.temp).cuda() 76 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 77 | model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) 78 | 79 | ##### Loss init ##### 80 | criterion = {'sup_loss': ProbOhemCrossEntropy2d(ignore_label=-1, thresh=0.7, min_kept=50000 * config['Dataset']['batch_size']).cuda(), 81 | 'ce_loss': nn.CrossEntropyLoss(ignore_index=-1).cuda(), 82 | 'unsup_loss': Attention_Threshold_Loss(strong_threshold=args.un_threshold).cuda(), 83 | 'contrast_loss': Contrast_Loss(strong_threshold=args.strong_threshold, 84 | num_queries=config['Loss']['num_queries'], 85 | num_negatives=config['Loss']['num_negatives'], 86 | temp=config['Loss']['temp'], 87 | alpha=config['Loss']['alpha']).cuda(), 88 | } 89 | 90 | ##### Prototype init ##### 91 | global prototypes 92 | 93 | prototypes = torch.zeros(config['Network']['num_class'], 256).cuda() 94 | 95 | ##### Other init ##### 96 | optimizer = 
torch.optim.SGD(model.module.model.parameters(), 97 | lr=float(config['Optim']['lr']), weight_decay=float(config['Optim']['weight_decay']), momentum=0.9, nesterov=True) 98 | total_iter = args.total_iter 99 | total_epoch = int(total_iter / len(train_l_loader)) 100 | if dist.get_rank() == 0: 101 | print('total epoch is {}'.format(total_epoch)) 102 | lr_scheduler = PolyLR(optimizer, total_iter, min_lr=1e-4) 103 | 104 | if os.path.exists(args.resume): 105 | print('resume from', args.resume) 106 | checkpoint = torch.load(args.resume, map_location='cpu') 107 | model.module.model.load_state_dict(checkpoint['model']) 108 | model.module.ema_model.load_state_dict(checkpoint['ema_model']) 109 | optimizer.load_state_dict(checkpoint['optimizer']) 110 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 111 | start_epoch = checkpoint['epoch'] 112 | prototypes = torch.tensor(checkpoint['prototypes']).cuda() 113 | else: 114 | start_epoch = 0 115 | sche_d = RampdownScheduler(begin_epoch=config['Ramp_Scheduler']['begin_epoch'], 116 | max_epoch=config['Ramp_Scheduler']['max_epoch'], 117 | current_epoch=start_epoch, 118 | max_value=config['Ramp_Scheduler']['max_value'], 119 | min_value=config['Ramp_Scheduler']['min_value'], 120 | ramp_mult=config['Ramp_Scheduler']['ramp_mult']) 121 | 122 | # if dist.get_rank() == 0: 123 | # log = Logger(logFile='./log/' + str(args.job_name) + '.log') 124 | best_miou = 0 125 | 126 | model.module.model.train() 127 | model.module.ema_model.train() 128 | for epoch in range(start_epoch, total_epoch): 129 | train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, lr_scheduler, sche_d, config, args) 130 | if epoch % 20 == 0 or epoch > total_epoch - 50: # evaluate every 20 epochs (and every epoch near the end) to save eval time 131 | miou = test(test_loader, model.module.ema_model, config) 132 | best_miou = max(best_miou, miou) 133 | if dist.get_rank() == 0: 134 | print('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}'.format(epoch, miou, best_miou, time.asctime(time.localtime(time.time())))) 135 | # log.write('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}\n'.format(epoch, miou, best_miou, time.asctime( time.localtime(time.time()) ))) 136 | # Save model 137 | if miou == best_miou: 138 | save_dir = './checkpoints/' + str(args.job_name) 139 | torch.save( 140 | { 141 | 'epoch': epoch+1, 142 | 'model': model.module.model.state_dict(), 143 | 'ema_model': model.module.ema_model.state_dict(), 144 | 'optimizer': optimizer.state_dict(), 145 | 'lr_scheduler': lr_scheduler.state_dict(), 146 | 'prototypes': prototypes.data.cpu().numpy(), 147 | }, os.path.join(save_dir, 'best_model.pth')) 148 | else: 149 | if dist.get_rank() == 0: 150 | print('Epoch:{} * Time {}'.format(epoch, time.asctime(time.localtime(time.time())))) 151 | 152 | 153 | 154 | def train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, scheduler, sche_d, config, args): 155 | num_class = config['Network']['num_class'] 156 | model.module.model.train() 157 | model.module.ema_model.train() 158 | 159 | train_u_loader.sampler.set_epoch(epoch) 160 | training_u_iter = iterator_(train_u_loader) 161 | train_l_loader.sampler.set_epoch(epoch) 162 | for i, (train_l_image, train_l_label) in enumerate(train_l_loader): 163 | train_l_image, train_l_label = train_l_image.cuda(), train_l_label.cuda() 164 | train_u_image, train_u_label = training_u_iter.__next__() 165 | train_u_image, train_u_label = train_u_image.cuda(), train_u_label.cuda() 166 | pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits_cls, train_u_aug_logits_rep,
rep_all, pred_all = model(train_l_image, train_u_image, prototypes) 167 | 168 | if config['Dataset']['name'] == 'VOC': 169 | sup_loss = criterion['ce_loss'](pred_l_large, train_l_label) 170 | else: 171 | sup_loss = criterion['sup_loss'](pred_l_large, train_l_label) 172 | unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label, train_u_aug_logits_cls) 173 | 174 | ##### Contrastive learning ##### 175 | with torch.no_grad(): 176 | train_u_aug_mask = train_u_aug_logits_cls.ge(args.weak_threshold).float() 177 | mask_all = torch.cat(((train_l_label.unsqueeze(1) >= 0).float(), train_u_aug_mask.unsqueeze(1))) 178 | mask_all = F.interpolate(mask_all, size=pred_all.shape[2:], mode='nearest') 179 | 180 | label_l = F.interpolate(label_onehot(train_l_label, num_class), size=pred_all.shape[2:], mode='nearest') 181 | label_u = F.interpolate(label_onehot_2(train_u_aug_label, num_class), size=pred_all.shape[2:], mode='nearest') 182 | label_u = label_u[:, 1:, :, :] 183 | label_all = torch.cat((label_l, label_u)) 184 | 185 | contrast_loss = criterion['contrast_loss'](rep_all, label_all, mask_all, pred_all, prototypes) 186 | 187 | if args.sche: 188 | total_loss = sup_loss + unsup_loss + contrast_loss * sche_d.value 189 | else: 190 | total_loss = sup_loss + unsup_loss + contrast_loss 191 | 192 | optimizer.zero_grad() 193 | total_loss.backward() 194 | optimizer.step() 195 | model.module.ema_update() 196 | scheduler.step() 197 | sche_d.step() 198 | 199 | @torch.no_grad() 200 | def test(test_loader, model, config): 201 | batch_time = AverageMeter('Time', ':6.3f') 202 | data_time = AverageMeter('Data', ':6.3f') 203 | miou_meter = ConfMatrix(num_classes=config['Network']['num_class'], fmt=':6.4f', name='test_miou') 204 | 205 | # switch to eval mode 206 | model.eval() 207 | 208 | end = time.time() 209 | test_iter = iter(test_loader) 210 | for _ in range(len(test_loader)): 211 | data_time.update(time.time() - end) 212 | test_image, test_label = next(test_iter) 213 | test_image, test_label = test_image.cuda(), test_label.cuda() 214 | 215 | pred, _ = model(test_image) 216 | pred = F.interpolate(pred, size=test_label.shape[1:], mode='bilinear', align_corners=True) 217 | 218 | miou_meter.update(pred.argmax(1).flatten(), test_label.flatten()) 219 | batch_time.update(time.time() - end) 220 | end = time.time() 221 | 222 | mat = torch_dist_sum(dist.get_rank(), miou_meter.mat) 223 | miou = mean_intersection_over_union(mat[0]) 224 | 225 | return miou 226 | 227 | 228 | if __name__ == '__main__': 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument('--config', type=str, default='./config/VOC_config_baseline.yaml') 231 | parser.add_argument('--resume', type=str, default='') 232 | parser.add_argument('--num_labels', type=int, default=92) 233 | parser.add_argument('--total_iter', type=int, default=80000) 234 | parser.add_argument('--job_name', type=str, default='VOC_92_mix_label') 235 | 236 | # Distributed 237 | parser.add_argument('--gpu_id', type=str, default='0,1,2,3') 238 | parser.add_argument('--world_size', type=str, default='4') 239 | parser.add_argument('--port', type=str, default='12301') 240 | 241 | # Hyperparameter 242 | parser.add_argument('--strong_threshold', type=float, default=0.8) 243 | parser.add_argument('--weak_threshold', type=float, default=0.7) 244 | parser.add_argument('--un_threshold', type=float, default=0.97) 245 | parser.add_argument('--temp', type=float, default=0.5) 246 | parser.add_argument('--sche', type=bool, default=True) # caution: argparse with type=bool parses any non-empty string as True 247 | 248 | args = parser.parse_args() 249 | 250
| ##### Config init ##### 251 | with open(args.config, 'r') as f: 252 | config = yaml.load(f.read(), Loader=yaml.FullLoader) 253 | save_dir = './checkpoints/' + str(args.job_name) 254 | if not os.path.exists(save_dir): 255 | os.makedirs(save_dir) 256 | with open(save_dir + '/config.yaml', 'w') as f: 257 | yaml.dump(config, f, default_flow_style=False) 258 | print(config) 259 | 260 | ##### Init Seed ##### 261 | random.seed(config['Seed']) 262 | torch.manual_seed(config['Seed']) 263 | torch.backends.cudnn.deterministic = True 264 | 265 | mp.spawn(main, nprocs=int(args.world_size), args=(config, args)) -------------------------------------------------------------------------------- /ori_pseudo.py: -------------------------------------------------------------------------------- 1 | import shutup 2 | shutup.please() 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parallel import DistributedDataParallel 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from generalframeworks.dataset_helpers.VOC import VOC_BuildData 10 | from generalframeworks.dataset_helpers.Cityscapes import City_BuildData 11 | from generalframeworks.networks.ddp_model import Model_ori_pseudo 12 | from generalframeworks.scheduler.my_lr_scheduler import PolyLR 13 | from generalframeworks.scheduler.rampscheduler import RampdownScheduler 14 | from generalframeworks.utils import iterator_, Logger 15 | from generalframeworks.util.meter import * 16 | from generalframeworks.utils import label_onehot 17 | from generalframeworks.util.torch_dist_sum import * 18 | from generalframeworks.util.miou import * 19 | from generalframeworks.util.dist_init import local_dist_init 20 | from generalframeworks.loss.loss import ProbOhemCrossEntropy2d, Attention_Threshold_Loss, Contrast_Loss 21 | import yaml 22 | import os 23 | import time 24 | import torchvision.models as models 25 | import argparse 26 | import random 27 | 28 | def main(rank, config, args): 29 | ##### Distribution init ##### 30 | local_dist_init(args, rank) 31 | print('Hello from rank {}\n'.format(rank)) 32 | 33 | ##### Load the dataset ##### 34 | if config['Dataset']['name'] == 'VOC': 35 | data = VOC_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 36 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 37 | if config['Dataset']['name'] == 'CityScapes': 38 | data = City_BuildData(data_path=config['Dataset']['data_dir'], txt_path=config['Dataset']['txt_dir'], 39 | label_num=args.num_labels, seed=config['Seed'], crop_size=config['Dataset']['crop_size']) 40 | train_l_dataset, train_u_dataset, test_dataset = data.build() 41 | train_l_sampler = torch.utils.data.distributed.DistributedSampler(train_l_dataset) 42 | train_l_loader = torch.utils.data.DataLoader(train_l_dataset, 43 | batch_size=config['Dataset']['batch_size'], 44 | pin_memory=True, 45 | sampler=train_l_sampler, 46 | num_workers=4, 47 | drop_last=True) 48 | train_u_sampler = torch.utils.data.distributed.DistributedSampler(train_u_dataset) 49 | train_u_loader = torch.utils.data.DataLoader(train_u_dataset, 50 | batch_size=config['Dataset']['batch_size'], 51 | pin_memory=True, 52 | sampler=train_u_sampler, 53 | num_workers=4, 54 | drop_last=True) 55 | test_loader = torch.utils.data.DataLoader(test_dataset, 56 | batch_size=config['Dataset']['batch_size'], 57 | pin_memory=True, 58 | num_workers=4) 59 | 60 | ##### Load the weight for each class ##### 61 | weight = 
torch.FloatTensor( 62 | [0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 63 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 64 | 1.0865, 1.1529, 1.0507]).cuda() 65 | 66 | ##### Model init ##### 67 | backbone = models.resnet101() 68 | ckpt = torch.load('./pretrained/resnet101.pth', map_location='cpu') 69 | backbone.load_state_dict(ckpt) 70 | 71 | # for Resnet-101 stem users 72 | #backbone = resnet.resnet101(pretrained=True) 73 | 74 | model = Model_ori_pseudo(backbone, num_classes=config['Network']['num_class'], output_dim=256, config=config).cuda() 75 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda() 76 | model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()], find_unused_parameters=True) 77 | 78 | ##### Loss init ##### 79 | criterion = {'sup_loss': ProbOhemCrossEntropy2d(ignore_label=-1, thresh=0.7, min_kept=50000 * config['Dataset']['batch_size']).cuda(), 80 | 'ce_loss': nn.CrossEntropyLoss(ignore_index=-1).cuda(), 81 | 'unsup_loss': Attention_Threshold_Loss(strong_threshold=config['Loss']['un_threshold']).cuda(), 82 | 'contrast_loss': Contrast_Loss(strong_threshold=config['Loss']['strong_threshold'], 83 | num_queries=config['Loss']['num_queries'], 84 | num_negatives=config['Loss']['num_negatives'], 85 | temp=config['Loss']['temp'], 86 | alpha=config['Loss']['alpha']).cuda(), 87 | } 88 | 89 | ##### Prototype init ##### 90 | global prototypes 91 | 92 | prototypes = torch.zeros(config['Network']['num_class'], 256).cuda() 93 | 94 | ##### Other init ##### 95 | optimizer = torch.optim.SGD(model.module.model.parameters(), 96 | lr=float(config['Optim']['lr']), weight_decay=float(config['Optim']['weight_decay']), momentum=0.9, nesterov=True) 97 | total_iter = args.total_iter 98 | total_epoch = int(total_iter / len(train_l_loader)) 99 | if dist.get_rank() == 0: 100 | print('total epoch is {}'.format(total_epoch)) 101 | lr_scheduler = PolyLR(optimizer, total_iter, min_lr=1e-4) 102 | 103 | if os.path.exists(args.resume): 104 | print('resume from', args.resume) 105 | checkpoint = torch.load(args.resume, map_location='cpu') 106 | model.module.model.load_state_dict(checkpoint['model']) 107 | model.module.ema_model.load_state_dict(checkpoint['ema_model']) 108 | optimizer.load_state_dict(checkpoint['optimizer']) 109 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 110 | start_epoch = checkpoint['epoch'] 111 | prototypes = torch.tensor(checkpoint['prototypes']).cuda() 112 | else: 113 | start_epoch = 0 114 | sche_d = RampdownScheduler(begin_epoch=config['Ramp_Scheduler']['begin_epoch'], 115 | max_epoch=config['Ramp_Scheduler']['max_epoch'], 116 | current_epoch=start_epoch, 117 | max_value=config['Ramp_Scheduler']['max_value'], 118 | min_value=config['Ramp_Scheduler']['min_value'], 119 | ramp_mult=config['Ramp_Scheduler']['ramp_mult']) 120 | 121 | # if dist.get_rank() == 0: 122 | # log = Logger(logFile='./log/' + str(args.job_name) + '.log') 123 | best_miou = 0 124 | 125 | model.module.model.train() 126 | model.module.ema_model.train() 127 | for epoch in range(start_epoch, total_epoch): 128 | train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, lr_scheduler, sche_d, config, args) 129 | miou = test(test_loader, model.module.ema_model, config) 130 | best_miou = max(best_miou, miou) 131 | if dist.get_rank() == 0: 132 | print('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time {}'.format(epoch, miou, best_miou, time.asctime(time.localtime(time.time())))) 133 | # log.write('Epoch:{} * mIoU {:.4f} Best_mIoU {:.4f} Time 
{}\n'.format(epoch, miou, best_miou, time.asctime( time.localtime(time.time()) ))) 134 | # Save model 135 | if miou == best_miou: 136 | save_dir = './checkpoints/' + str(args.job_name) 137 | torch.save( 138 | { 139 | 'epoch': epoch+1, 140 | 'model': model.module.model.state_dict(), 141 | 'ema_model': model.module.ema_model.state_dict(), 142 | 'optimizer': optimizer.state_dict(), 143 | 'lr_scheduler': lr_scheduler.state_dict(), 144 | 'prototypes': prototypes.data.cpu().numpy(), 145 | }, os.path.join(save_dir, 'best_model.pth')) 146 | 147 | 148 | 149 | def train(train_l_loader, train_u_loader, model, optimizer, criterion, epoch, scheduler, sche_d, config, args): 150 | num_class = config['Network']['num_class'] 151 | # switch to train mode 152 | model.module.model.train() 153 | model.module.ema_model.train() 154 | 155 | train_u_loader.sampler.set_epoch(epoch) 156 | training_u_iter = iterator_(train_u_loader) 157 | train_l_loader.sampler.set_epoch(epoch) 158 | for i, (train_l_image, train_l_label) in enumerate(train_l_loader): 159 | train_l_image, train_l_label = train_l_image.cuda(), train_l_label.cuda() 160 | train_u_image, train_u_label = training_u_iter.__next__() 161 | train_u_image, train_u_label = train_u_image.cuda(), train_u_label.cuda() 162 | pred_l_large, pred_u_large, train_u_aug_label, train_u_aug_logits, rep_all, pred_all, pred_u_large_raw = model(train_l_image, train_u_image) 163 | 164 | if config['Dataset']['name'] == 'VOC': 165 | sup_loss = criterion['ce_loss'](pred_l_large, train_l_label) 166 | else: 167 | sup_loss = criterion['sup_loss'](pred_l_large, train_l_label) 168 | unsup_loss = criterion['unsup_loss'](pred_u_large, train_u_aug_label, train_u_aug_logits) 169 | 170 | ##### Contrastive learning ##### 171 | with torch.no_grad(): 172 | train_u_aug_mask = train_u_aug_logits.ge(config['Loss']['weak_threshold']).float() 173 | mask_all = torch.cat(((train_l_label.unsqueeze(1) >= 0).float(), train_u_aug_mask.unsqueeze(1))) 174 | mask_all = F.interpolate(mask_all, size=pred_all.shape[2:], mode='nearest') 175 | 176 | label_l = F.interpolate(label_onehot(train_l_label, num_class), size=pred_all.shape[2:], mode='nearest') 177 | label_u = F.interpolate(label_onehot(train_u_aug_label, num_class), size=pred_all.shape[2:], mode='nearest') 178 | label_all = torch.cat((label_l, label_u)) 179 | 180 | prob_all = torch.softmax(pred_all, dim=1) 181 | contrast_loss = criterion['contrast_loss'](rep_all, label_all, mask_all, prob_all, prototypes) 182 | 183 | total_loss = sup_loss + unsup_loss + contrast_loss 184 | 185 | optimizer.zero_grad() 186 | total_loss.backward() 187 | optimizer.step() 188 | model.module.ema_update() 189 | scheduler.step() 190 | 191 | @torch.no_grad() 192 | def test(test_loader, model, config): 193 | miou_meter = ConfMatrix(num_classes=config['Network']['num_class'], fmt=':6.4f', name='test_miou') 194 | 195 | # switch to eval mode 196 | model.eval() 197 | 198 | test_iter = iter(test_loader) 199 | for _ in range(len(test_loader)): 200 | test_image, test_label = next(test_iter) 201 | test_image, test_label = test_image.cuda(), test_label.cuda() 202 | 203 | pred, _ = model(test_image) 204 | pred = F.interpolate(pred, size=test_label.shape[1:], mode='bilinear', align_corners=True) 205 | 206 | miou_meter.update(pred.argmax(1).flatten(), test_label.flatten()) 207 | 208 | mat = torch_dist_sum(dist.get_rank(), miou_meter.mat) 209 | miou = mean_intersection_over_union(mat[0]) 210 | 211 | return miou 212 | 213 | 214 | if __name__ == '__main__': 215 | parser =
argparse.ArgumentParser() 216 | parser.add_argument('--config', type=str, default='./config/VOC_config_baseline.yaml') 217 | parser.add_argument('--resume', type=str, default='') 218 | parser.add_argument('--num_labels', type=int, default=92) 219 | parser.add_argument('--total_iter', type=int, default=80000) 220 | parser.add_argument('--job_name', type=str, default='VOC_92_baseline') 221 | 222 | # Distributed 223 | parser.add_argument('--gpu_id', type=str, default='0,1,2,3') 224 | parser.add_argument('--world_size', type=str, default='4') 225 | parser.add_argument('--port', type=str, default='12301') 226 | 227 | args = parser.parse_args() 228 | 229 | ##### Config init ##### 230 | with open(args.config, 'r') as f: 231 | config = yaml.load(f.read(), Loader=yaml.FullLoader) 232 | save_dir = './checkpoints/' + str(args.job_name) 233 | if not os.path.exists(save_dir): 234 | os.makedirs(save_dir) 235 | with open(save_dir + '/config.yaml', 'w') as f: 236 | yaml.dump(config, f, default_flow_style=False) 237 | print(config) 238 | 239 | ##### Init Seed ##### 240 | random.seed(config['Seed']) 241 | torch.manual_seed(config['Seed']) 242 | torch.backends.cudnn.deterministic = True 243 | 244 | mp.spawn(main, nprocs=int(args.world_size), args=(config, args)) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | Bottleneck==1.3.4 3 | cachetools==5.0.0 4 | certifi==2021.10.8 5 | charset-normalizer==2.0.12 6 | colorama==0.4.4 7 | google-auth==2.6.6 8 | google-auth-oauthlib==0.4.6 9 | grpcio==1.44.0 10 | idna==3.3 11 | importlib-metadata==4.11.3 12 | Markdown==3.3.6 13 | numexpr==2.8.0 14 | numpy==1.21.6 15 | oauthlib==3.2.0 16 | opencv-python==4.5.5.64 17 | pandas==1.3.4 18 | Pillow==9.1.0 19 | pip==22.0.4 20 | protobuf==3.20.1 21 | pyasn1==0.4.8 22 | pyasn1-modules==0.2.8 23 | python-dateutil==2.8.2 24 | pytz==2022.1 25 | PyYAML==6.0 26 | requests==2.27.1 27 | requests-oauthlib==1.3.1 28 | rsa==4.8 29 | setuptools==62.1.0 30 | shutup==0.2.0 31 | six==1.16.0 32 | tensorboard==2.8.0 33 | tensorboard-data-server==0.6.1 34 | tensorboard-plugin-wit==1.8.1 35 | tensorboardX==2.5 36 | torch==1.7.1+cu110 37 | torchaudio==0.7.2 38 | torchvision==0.8.2+cu110 39 | tqdm==4.64.0 40 | typing_extensions==4.2.0 41 | urllib3==1.26.9 42 | Werkzeug==2.1.2 43 | wheel==0.37.1 44 | zipp==3.8.0 45 | --------------------------------------------------------------------------------