├── utils ├── __init__.py ├── log_utils.py └── pytorch_utils.py ├── datasets ├── __init__.py └── crowd.py ├── losses ├── __init__.py ├── ramps.py ├── multi_con_loss.py ├── rank_loss.py ├── dm_loss.py ├── ot_loss.py ├── consistency_loss.py └── bregman_pytorch.py ├── sample_imgs └── overview.png ├── requirements.txt ├── README.md ├── test.py ├── train.py └── network └── pvt_cls.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sample_imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HAAClassic/TreeFormer/HEAD/sample_imgs/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.5 2 | Pillow==9.4.0 3 | scikit_learn==1.2.2 4 | scipy==1.7.3 5 | timm==0.4.12 6 | torch==1.12.1 7 | torchvision==0.13.1 8 | -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(log_file): 5 | logger = logging.getLogger(log_file) 6 | logger.setLevel(logging.DEBUG) 7 | fh = logging.FileHandler(log_file) 8 | fh.setLevel(logging.DEBUG) 9 | ch = logging.StreamHandler() 10 | ch.setLevel(logging.INFO) 11 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 12 | 
ch.setFormatter(formatter) 13 | fh.setFormatter(formatter) 14 | logger.addHandler(ch) 15 | logger.addHandler(fh) 16 | return logger 17 | 18 | 19 | def print_config(config, logger): 20 | """ 21 | Print configuration of the model 22 | """ 23 | for k, v in config.items(): 24 | logger.info("{}:\t{}".format(k.ljust(15), v)) 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TreeFormer 3 | 4 | This is the code base for IEEE TRANSACTIONS ON GEOSCIENCE AND REMOTE SENSING (TGRS 2023) paper ['TreeFormer: a Semi-Supervised Transformer-based Framework for Tree Counting from a Single High Resolution Image'](https://arxiv.org/abs/2307.06118) 5 | 6 | 7 | 8 | ## Installation 9 | 10 | Python ≥ 3.7. 11 | 12 | To install the required packages, please run: 13 | 14 | 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ## Dataset 20 | Download the dataset from [google drive](https://drive.google.com/file/d/1xcjv8967VvvzcDM4aqAi7Corkb11T0i2/view?usp=drive_link). 21 | ## Evaluation 22 | Download our trained model on [London](https://drive.google.com/file/d/14uuOF5758sxtM5EgeGcRtSln5lUXAHge/view?usp=sharing) dataset. 23 | 24 | Modify the path to the dataset and model for evaluation in 'test.py'. 25 | 26 | Run 'test.py' 27 | ## Acknowledgements 28 | 29 | - Part of codes are borrowed from [PVT](https://github.com/whai362/PVT) and [DM Count](https://github.com/cvlab-stonybrook/DM-Count). Thanks for their great work! 30 | 31 | 32 | -------------------------------------------------------------------------------- /losses/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. 
To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Sigmoid-shaped rampup exp(-5 * (1 - t)**2), where t = clip(current, 0, rampup_length) / rampup_length; returns 1.0 when rampup_length is 0. From https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup current / rampup_length, saturating at 1.0 once current >= rampup_length.""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from 1.0 at current=0 to 0.0 at current=rampdown_length; from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /losses/multi_con_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from losses.consistency_loss import * 5 | 6 | 7 | class MultiConLoss(nn.Module): """Uncertainty-weighted consistency loss for unlabeled data: each branch prediction is compared with the mean over all branches (computed without gradient tracking), and the squared difference is weighted by exp(-var_sel), where var_sel = softmax_kl_loss(branch, mean).""" 8 | def __init__(self): 9 | super(MultiConLoss, self).__init__() 10 | self.countloss_criterion = nn.MSELoss(reduction='sum') 11 | self.multiconloss = 0.0 12 | self.losses = {} 13 | 14 | def forward(self, unlabeled_results): 15 | self.multiconloss = 0.0 16 | self.losses = {} 17 | 18 | if unlabeled_results is None: 19 | self.multiconloss = 0.0 20 | elif isinstance(unlabeled_results, list) and
len(unlabeled_results) > 0: 21 | count = 0 22 | for i in range(len(unlabeled_results[0])): 23 | with torch.set_grad_enabled(False): 24 | preds_mean = (unlabeled_results[0][i] + unlabeled_results[1][i] + unlabeled_results[2][i])/len(unlabeled_results) 25 | for j in range(len(unlabeled_results)): 26 | 27 | var_sel = softmax_kl_loss(unlabeled_results[j][i], preds_mean) 28 | exp_var = torch.exp(-var_sel) 29 | consistency_dist = (preds_mean - unlabeled_results[j][i]) ** 2 30 | temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel) 31 | 32 | self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss}) 33 | self.multiconloss += temploss 34 | 35 | count += 1 36 | if count > 0: 37 | self.multiconloss = self.multiconloss / count 38 | 39 | 40 | return self.multiconloss 41 | 42 | -------------------------------------------------------------------------------- /utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10): 4 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 5 | lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6) 6 | for param_group in optimizer.param_groups: 7 | param_group['lr'] = lr 8 | 9 | 10 | class Save_Handle(object): 11 | """handle the number of """ 12 | def __init__(self, max_num): 13 | self.save_list = [] 14 | self.max_num = max_num 15 | 16 | def append(self, save_path): 17 | if len(self.save_list) < self.max_num: 18 | self.save_list.append(save_path) 19 | else: 20 | remove_path = self.save_list[0] 21 | del self.save_list[0] 22 | self.save_list.append(save_path) 23 | if os.path.exists(remove_path): 24 | os.remove(remove_path) 25 | 26 | 27 | class AverageMeter(object): 28 | """Computes and stores the average and current value""" 29 | def __init__(self): 30 | self.reset() 31 | 32 | def reset(self): 33 | self.val = 0 34 | self.avg = 0 35 | self.sum 
= 0 36 | self.count = 0 37 | 38 | def update(self, val, n=1): """Record val as the latest value and add it n times to the running sum/count used for avg.""" 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = 1.0 * self.sum / self.count 43 | 44 | def get_avg(self): 45 | return self.avg 46 | 47 | def get_count(self): 48 | return self.count 49 | 50 | 51 | def set_trainable(model, requires_grad): """Set requires_grad on every parameter of model.""" 52 | for param in model.parameters(): 53 | param.requires_grad = requires_grad 54 | 55 | 56 | 57 | def get_num_params(model): """Return the total number of elements in model parameters that have requires_grad=True.""" 58 | return sum(p.numel() for p in model.parameters() if p.requires_grad) -------------------------------------------------------------------------------- /losses/rank_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MarginRankLoss(nn.Module): 7 | def __init__(self): 8 | super(MarginRankLoss, self).__init__() 9 | self.loss = 0.0 10 | 11 | def forward(self, img_list, margin=0): """Pairwise hinge loss over img_list: for every i < j add sum(relu(total(img_list[j]) - total(img_list[i]) + margin)), where total sums each sample over its last three dims, then divide by B * number_of_pairs. img_list[0] is unpacked as (B, C, H, W), so entries are presumably 4-D tensors with the same batch size.""" 12 | length = len(img_list) 13 | self.loss = 0.0 14 | B, C, H, W = img_list[0].shape 15 | for i in range(length - 1): 16 | for j in range(i + 1, length): 17 | self.loss = self.loss + torch.sum(F.relu(img_list[j].sum(-1).sum(-1).sum(-1) - img_list[i].sum(-1).sum(-1).sum(-1) + margin)) 18 | 19 | self.loss = self.loss / (B*length*(length-1)/2) 20 | return self.loss 21 | 22 | 23 | class RankLoss(nn.Module): 24 | def __init__(self): 25 | super(RankLoss, self).__init__() 26 | self.countloss_criterion = nn.MSELoss(reduction='sum') 27 | self.rankloss_criterion = MarginRankLoss() 28 | self.rankloss = 0.0 29 | self.losses = {} 30 | 31 | def forward(self, unlabeled_results): """Return the MarginRankLoss of a tuple of predictions, or its mean over a list of such tuples; 0.0 when unlabeled_results is None.""" 32 | self.rankloss = 0.0 33 | self.losses = {} 34 | 35 | 36 | if unlabeled_results is None: 37 | self.rankloss = 0.0 38 | elif isinstance(unlabeled_results, tuple) and len(unlabeled_results) > 0: 39 | self.rankloss = self.rankloss_criterion(unlabeled_results) 40 | elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0: 41 | count =
0 42 | for i in range(len(unlabeled_results)): 43 | if isinstance(unlabeled_results[i], tuple) and len(unlabeled_results[i]) > 0: 44 | temploss = self.rankloss_criterion(unlabeled_results[i]) 45 | self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss}) 46 | self.rankloss += temploss 47 | 48 | count += 1 49 | if count > 0: 50 | self.rankloss = self.rankloss / count 51 | 52 | return self.rankloss 53 | 54 | -------------------------------------------------------------------------------- /losses/dm_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from losses.consistency_loss import * 5 | from losses.ot_loss import OT_Loss 6 | 7 | class DMLoss(nn.Module): 8 | def __init__(self): 9 | super(DMLoss, self).__init__() 10 | self.DMLoss = 0.0 11 | self.losses = {} 12 | 13 | def forward(self, results, points, gt_discrete): """NOTE(review): this method looks unfinished. The list branch uses names that are neither parameters, locals, nor attributes set in __init__ (self.mae, outputs_L, gd_count, self.device, self.args, N, epoch_count_loss, self.ot_loss, outputs_normed, self.tvloss, ...), so reaching it raises NameError/AttributeError. preds_mean divides results[0][i] by len(results[0][0][0]) instead of averaging over branches as MultiConLoss does; confirm the intended computation against the training script.""" 14 | self.DMLoss = 0.0 15 | self.losses = {} 16 | 17 | if results is None: 18 | self.DMLoss = 0.0 19 | elif isinstance(results, list) and len(results) > 0: 20 | count = 0 21 | for i in range(len(results[0])): 22 | with torch.set_grad_enabled(False): 23 | preds_mean = (results[0][i])/len(results[0][0][0]) 24 | 25 | for j in range(len(results)): 26 | var_sel = softmax_kl_loss(results[j][i], preds_mean) 27 | exp_var = torch.exp(-var_sel) 28 | consistency_dist = (preds_mean - results[j][i]) ** 2 29 | temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel) 30 | 31 | self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss}) 32 | self.DMLoss += temploss 33 | 34 | # Compute counting loss. 35 | count_loss = self.mae(outputs_L[0].sum(1).sum(1).sum(1), 36 | torch.from_numpy(gd_count).float().to(self.device))*self.args.reg 37 | epoch_count_loss.update(count_loss.item(), N) 38 | 39 | # Compute OT loss.
40 | ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L[0], points) 41 | 42 | ot_loss = ot_loss * self.args.ot 43 | ot_obj_value = ot_obj_value * self.args.ot 44 | epoch_ot_loss.update(ot_loss.item(), N) 45 | epoch_ot_obj_value.update(ot_obj_value.item(), N) 46 | epoch_wd.update(wd, N) 47 | 48 | gd_count_tensor = (torch.from_numpy(gd_count).float() 49 | .to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)) 50 | 51 | gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6) 52 | tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)* 53 | torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv 54 | epoch_tv_loss.update(tv_loss.item(), N) 55 | 56 | count += 1 57 | if count > 0: 58 | self.multiconloss = self.multiconloss / count 59 | 60 | 61 | return self.multiconloss 62 | 63 | -------------------------------------------------------------------------------- /losses/ot_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | from .bregman_pytorch import sinkhorn 4 | 5 | class OT_Loss(Module): 6 | def __init__(self, c_size, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0): 7 | super(OT_Loss, self).__init__() 8 | assert c_size % stride == 0 9 | 10 | self.c_size = c_size 11 | self.device = device 12 | self.norm_cood = norm_cood 13 | self.num_of_iter_in_ot = num_of_iter_in_ot 14 | self.reg = reg 15 | 16 | # coordinate is same to image space, set to constant since crop size is same 17 | self.cood = torch.arange(0, c_size, step=stride, 18 | dtype=torch.float32, device=device) + stride / 2 19 | self.density_size = self.cood.size(0) 20 | self.cood.unsqueeze_(0) # [1, #cood] 21 | if self.norm_cood: 22 | self.cood = self.cood / c_size * 2 - 1 # map to [-1, 1] 23 | self.output_size = self.cood.size(1) 24 | 25 | 26 | def forward(self, normed_density, unnormed_density, points): 27 | batch_size = normed_density.size(0) 28 | 
assert len(points) == batch_size 29 | assert self.output_size == normed_density.size(2) 30 | loss = torch.zeros([1]).to(self.device) 31 | ot_obj_values = torch.zeros([1]).to(self.device) 32 | wd = 0 # Wasserstein distance, accumulated as sum(dis * P) over the images of the batch 33 | for idx, im_points in enumerate(points): 34 | if len(im_points) > 0: 35 | # squared L2 distance from every GT point (target) to every density-bin center (source): [#gt, #cood * #cood] 36 | if self.norm_cood: 37 | im_points = im_points / self.c_size * 2 - 1 # map to [-1, 1] 38 | x = im_points[:, 0].unsqueeze_(1) # [#gt, 1] 39 | y = im_points[:, 1].unsqueeze_(1) 40 | x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood # [#gt, #cood] 41 | y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood 42 | y_dis.unsqueeze_(2) 43 | x_dis.unsqueeze_(1) 44 | dis = y_dis + x_dis 45 | dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood] 46 | 47 | source_prob = normed_density[idx][0].view([-1]).detach() 48 | target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device) 49 | 50 | # use sinkhorn to solve OT, compute optimal beta. 51 | P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True) 52 | beta = log['beta'] # size is the same as source_prob: [#cood * #cood] 53 | ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size])) 54 | # compute the gradient of the OT loss w.r.t. the predicted density (unnormed_density). 55 | # im_grad = beta / source_count - < beta, source_density> / (source_count)^2 56 | source_density = unnormed_density[idx][0].view([-1]).detach() 57 | source_count = source_density.sum() 58 | im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta # size of [#cood * #cood] 59 | im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) # size of 1 60 | im_grad = im_grad_1 - im_grad_2 61 | im_grad = im_grad.detach().view([1, self.output_size, self.output_size]) 62 | # Define loss = sum(unnormed_density * im_grad).
The gradient of loss w.r.t prediced density is im_grad. 63 | loss += torch.sum(unnormed_density[idx] * im_grad) 64 | wd += torch.sum(dis * P).item() 65 | 66 | return loss, wd, ot_obj_values 67 | 68 | 69 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import datasets.crowd_gauss as crowd 6 | from network import pvt_cls as TCN 7 | import torch.nn.functional as F 8 | from scipy.io import savemat 9 | from sklearn.metrics import r2_score 10 | 11 | parser = argparse.ArgumentParser(description='Test ') 12 | parser.add_argument('--device', default='0', help='assign device') 13 | parser.add_argument('--batch-size', type=int, default=8, help='train batch size') 14 | parser.add_argument('--crop-size', type=int, default=256, help='the crop size of the train image') 15 | parser.add_argument('--model-path', type=str, default='/scratch/users/k2254235/ckpts/SEMI/Treeformer/best_model_mae-21.49_epoch-1759.pth', help='saved model path') 16 | parser.add_argument('--data-path', type=str, default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='dataset path') 17 | parser.add_argument('--dataset', type=str, default='TC') 18 | 19 | def test(args, isSave = True): 20 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu 21 | device = torch.device('cuda') 22 | 23 | model_path = args.model_path 24 | crop_size = args.crop_size 25 | data_path = args.data_path 26 | 27 | dataset = crowd.Crowd_TC(os.path.join(data_path, 'test_data'), crop_size, 1, method='val') 28 | dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, pin_memory=True) 29 | 30 | model = TCN.pvt_treeformer(pretrained=False) 31 | model.to(device) 32 | model.load_state_dict(torch.load(model_path, device)) 33 | model.eval() 34 | image_errs = [] 35 | result = [] 36 | R2_es = [] 37 | R2_gt 
= [] 38 | l=0; 39 | for inputs, count, name, imgauss in dataloader: 40 | with torch.no_grad(): 41 | inputs = inputs.to(device) 42 | crop_imgs, crop_masks = [], [] 43 | b, c, h, w = inputs.size() 44 | rh, rw = args.crop_size, args.crop_size 45 | 46 | for i in range(0, h, rh): 47 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 48 | 49 | for j in range(0, w, rw): 50 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 51 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) 52 | mask = torch.zeros([b, 1, h, w]).to(device) 53 | mask[:, :, gis:gie, gjs:gje].fill_(1.0) 54 | crop_masks.append(mask) 55 | crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks)) 56 | 57 | crop_preds = [] 58 | nz, bz = crop_imgs.size(0), args.batch_size 59 | for i in range(0, nz, bz): 60 | 61 | gs, gt = i, min(nz, i + bz) 62 | crop_pred, _ = model(crop_imgs[gs:gt]) 63 | crop_pred = crop_pred[0] 64 | 65 | _, _, h1, w1 = crop_pred.size() 66 | crop_pred = F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16 67 | crop_preds.append(crop_pred) 68 | crop_preds = torch.cat(crop_preds, dim=0) 69 | #import pdb;pdb.set_trace() 70 | 71 | # splice them to the original size 72 | idx = 0 73 | pred_map = torch.zeros([b, 1, h, w]).to(device) 74 | for i in range(0, h, rh): 75 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 76 | for j in range(0, w, rw): 77 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 78 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx] 79 | idx += 1 80 | # for the overlapping area, compute average value 81 | mask = crop_masks.sum(dim=0).unsqueeze(0) 82 | outputs = pred_map / mask 83 | 84 | outputs = F.interpolate(outputs, size=(h, w), mode='bilinear', align_corners=True)/4 85 | outputs = pred_map / mask 86 | 87 | img_err = count[0].item() - torch.sum(outputs).item() 88 | R2_gt.append(count[0].item()) 89 | R2_es.append(torch.sum(outputs).item()) 90 | 91 | print("Img name: ", name, "Error: ", img_err, "GT count: ", 
count[0].item(), "Model out: ", torch.sum(outputs).item()) 92 | image_errs.append(img_err) 93 | result.append([name, count[0].item(), torch.sum(outputs).item(), img_err]) 94 | 95 | savemat('predictions/'+name[0]+'.mat', {'estimation':np.squeeze(outputs.cpu().data.numpy()), 96 | 'image': np.squeeze(inputs.cpu().data.numpy()), 'gt': np.squeeze(imgauss.cpu().data.numpy())}) 97 | l=l+1 98 | 99 | 100 | image_errs = np.array(image_errs) 101 | 102 | mse = np.sqrt(np.mean(np.square(image_errs))) 103 | mae = np.mean(np.abs(image_errs)) 104 | R_2 = r2_score(R2_gt,R2_es) 105 | 106 | print('{}: mae {}, mse {}, R2 {}\n'.format(model_path, mae, mse,R_2)) 107 | 108 | if isSave: 109 | with open("test.txt","w") as f: 110 | for i in range(len(result)): 111 | f.write(str(result[i]).replace('[','').replace(']','').replace(',', ' ')+"\n") 112 | f.close() 113 | 114 | if __name__ == '__main__': 115 | args = parser.parse_args() 116 | test(args, isSave= True) 117 | 118 | -------------------------------------------------------------------------------- /datasets/crowd.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch.utils.data as data 3 | import os 4 | from glob import glob 5 | import torch 6 | import torchvision.transforms.functional as F 7 | from torchvision import transforms 8 | import random 9 | import numpy as np 10 | import scipy.io as sio 11 | 12 | def random_crop(im_h, im_w, crop_h, crop_w): 13 | res_h = im_h - crop_h 14 | res_w = im_w - crop_w 15 | i = random.randint(0, res_h) 16 | j = random.randint(0, res_w) 17 | return i, j, crop_h, crop_w 18 | 19 | 20 | def gen_discrete_map(im_height, im_width, points): 21 | """ 22 | func: generate the discrete map. 
23 | points: [num_gt, 2], for each row: [width, height] 24 | """ 25 | discrete_map = np.zeros([im_height, im_width], dtype=np.float32) 26 | h, w = discrete_map.shape[:2] 27 | num_gt = points.shape[0] 28 | if num_gt == 0: 29 | return discrete_map 30 | 31 | # fast create discrete map 32 | points_np = np.array(points).round().astype(int) 33 | p_h = np.minimum(points_np[:, 1], np.array([h-1]*num_gt).astype(int)) 34 | p_w = np.minimum(points_np[:, 0], np.array([w-1]*num_gt).astype(int)) 35 | p_index = torch.from_numpy(p_h* im_width + p_w).to(torch.int64) 36 | discrete_map = torch.zeros(im_width * im_height).scatter_add_(0, index=p_index, src=torch.ones(im_width*im_height)).view(im_height, im_width).numpy() 37 | 38 | ''' slow method 39 | for p in points: 40 | p = np.round(p).astype(int) 41 | p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0]) 42 | discrete_map[p[0], p[1]] += 1 43 | ''' 44 | assert np.sum(discrete_map) == num_gt 45 | return discrete_map 46 | 47 | 48 | class Base(data.Dataset): 49 | def __init__(self, root_path, crop_size, downsample_ratio=8): 50 | 51 | self.root_path = root_path 52 | self.c_size = crop_size 53 | self.d_ratio = downsample_ratio 54 | assert self.c_size % self.d_ratio == 0 55 | self.dc_size = self.c_size // self.d_ratio 56 | self.trans = transforms.Compose([ 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | ]) 60 | 61 | def __len__(self): 62 | pass 63 | 64 | def __getitem__(self, item): 65 | pass 66 | 67 | def train_transform(self, img, keypoints, gauss_im): 68 | wd, ht = img.size 69 | st_size = 1.0 * min(wd, ht) 70 | assert st_size >= self.c_size 71 | assert len(keypoints) >= 0 72 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 73 | img = F.crop(img, i, j, h, w) 74 | gauss_im = F.crop(img, i, j, h, w) 75 | if len(keypoints) > 0: 76 | keypoints = keypoints - [j, i] 77 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \ 78 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] 
<= h) 79 | keypoints = keypoints[idx_mask] 80 | else: 81 | keypoints = np.empty([0, 2]) 82 | 83 | gt_discrete = gen_discrete_map(h, w, keypoints) 84 | down_w = w // self.d_ratio 85 | down_h = h // self.d_ratio 86 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3)) 87 | assert np.sum(gt_discrete) == len(keypoints) 88 | # Random horizontal flip. NOTE(review): keypoints are mirrored as w - x here but as w - x - 1 in Crowd_TC.train_transform; confirm which convention is intended. 89 | if len(keypoints) > 0: 90 | if random.random() > 0.5: 91 | img = F.hflip(img) 92 | gauss_im = F.hflip(gauss_im) 93 | gt_discrete = np.fliplr(gt_discrete) 94 | keypoints[:, 0] = w - keypoints[:, 0] 95 | else: 96 | if random.random() > 0.5: 97 | img = F.hflip(img) 98 | gauss_im = F.hflip(gauss_im) 99 | gt_discrete = np.fliplr(gt_discrete) 100 | gt_discrete = np.expand_dims(gt_discrete, 0) 101 | 102 | return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float() 103 | 104 | 105 | 106 | class Crowd_TC(Base): """Labeled tree-counting dataset: images from <root>/images/*.jpg, point annotations from <root>/ground_truth/GT_<name>.mat and density maps from <root>/ground_truth/<name>_densitymap.npy; method is 'train' (random crop/flip) or 'val'.""" 107 | def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'): 108 | super().__init__(root_path, crop_size, downsample_ratio) 109 | self.method = method 110 | if method not in ['train', 'val']: 111 | raise Exception("not implement") 112 | 113 | self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg'))) 114 | 115 | print('number of img [{}]: {}'.format(method, len(self.im_list))) 116 | 117 | def __len__(self): 118 | return len(self.im_list) 119 | 120 | def __getitem__(self, item): 121 | img_path = self.im_list[item] 122 | name = os.path.basename(img_path).split('.')[0] 123 | gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name)) 124 | img = Image.open(img_path).convert('RGB') 125 | keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0] 126 | gauss_path = os.path.join(self.root_path, 'ground_truth', '{}_densitymap.npy'.format(name)) 127 | gauss_im = torch.from_numpy(np.load(gauss_path)).float() 128 | #import pdb;pdb.set_trace() 129 | #print("label {}", item) 130 | 131 | 
if self.method == 'train': 132 | return self.train_transform(img, keypoints, gauss_im) 133 | elif self.method == 'val': 134 | wd, ht = img.size 135 | st_size = 1.0 * min(wd, ht) 136 | if st_size < self.c_size: 137 | rr = 1.0 * self.c_size / st_size 138 | wd = round(wd * rr) 139 | ht = round(ht * rr) 140 | st_size = 1.0 * min(wd, ht) 141 | img = img.resize((wd, ht), Image.BICUBIC) 142 | img = self.trans(img) 143 | #import pdb;pdb.set_trace() 144 | 145 | return img, len(keypoints), name, gauss_im 146 | 147 | def train_transform(self, img, keypoints, gauss_im): 148 | wd, ht = img.size 149 | st_size = 1.0 * min(wd, ht) 150 | # resize the image to fit the crop size 151 | if st_size < self.c_size: 152 | rr = 1.0 * self.c_size / st_size 153 | wd = round(wd * rr) 154 | ht = round(ht * rr) 155 | st_size = 1.0 * min(wd, ht) 156 | img = img.resize((wd, ht), Image.BICUBIC) 157 | #gauss_im = gauss_im.resize((wd, ht), Image.BICUBIC) 158 | keypoints = keypoints * rr 159 | assert st_size >= self.c_size, print(wd, ht) 160 | assert len(keypoints) >= 0 161 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 162 | img = F.crop(img, i, j, h, w) 163 | gauss_im = F.crop(gauss_im, i, j, h, w) 164 | if len(keypoints) > 0: 165 | keypoints = keypoints - [j, i] 166 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \ 167 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h) 168 | keypoints = keypoints[idx_mask] 169 | else: 170 | keypoints = np.empty([0, 2]) 171 | 172 | gt_discrete = gen_discrete_map(h, w, keypoints) 173 | down_w = w // self.d_ratio 174 | down_h = h // self.d_ratio 175 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3)) 176 | assert np.sum(gt_discrete) == len(keypoints) 177 | 178 | 179 | if len(keypoints) > 0: 180 | if random.random() > 0.5: 181 | img = F.hflip(img) 182 | gauss_im = F.hflip(gauss_im) 183 | gt_discrete = np.fliplr(gt_discrete) 184 | keypoints[:, 0] = w - keypoints[:, 0] - 1 185 | else: 186 | if 
random.random() > 0.5: 187 | img = F.hflip(img) 188 | gauss_im = F.hflip(gauss_im) 189 | gt_discrete = np.fliplr(gt_discrete) 190 | gt_discrete = np.expand_dims(gt_discrete, 0) 191 | #import pdb;pdb.set_trace() 192 | 193 | return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float() 194 | 195 | 196 | class Base_UL(data.Dataset): """Base dataset for unlabeled images: stores crop/downsample settings and the ImageNet mean/std normalizing transform; subclasses provide the image list.""" 197 | def __init__(self, root_path, crop_size, downsample_ratio=8): 198 | self.root_path = root_path 199 | self.c_size = crop_size 200 | self.d_ratio = downsample_ratio 201 | assert self.c_size % self.d_ratio == 0 202 | self.dc_size = self.c_size // self.d_ratio 203 | self.trans = transforms.Compose([ 204 | transforms.ToTensor(), 205 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 206 | ]) 207 | 208 | def __len__(self): 209 | pass 210 | 211 | def __getitem__(self, item): 212 | pass 213 | 214 | def train_transform_ul(self, img): 215 | wd, ht = img.size 216 | st_size = 1.0 * min(wd, ht) 217 | assert st_size >= self.c_size 218 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 219 | img = F.crop(img, i, j, h, w) 220 | 221 | if random.random() > 0.5: 222 | img = F.hflip(img) 223 | 224 | return self.trans(img) 225 | 226 | 227 | class Crowd_UL_TC(Base_UL): """Unlabeled tree-counting dataset over <root>/images/*.jpg; only method='train_ul' is supported.""" 228 | def __init__(self, root_path, crop_size, downsample_ratio=8, method='train_ul'): 229 | super().__init__(root_path, crop_size, downsample_ratio) 230 | self.method = method 231 | if method not in ['train_ul']: 232 | raise Exception("not implement") 233 | 234 | self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg'))) 235 | print('number of img [{}]: {}'.format(method, len(self.im_list))) 236 | 237 | def __len__(self): 238 | return len(self.im_list) 239 | 240 | def __getitem__(self, item): 241 | img_path = self.im_list[item] 242 | name = os.path.basename(img_path).split('.')[0] 243 | img = Image.open(img_path).convert('RGB') 244 | #print("un_label {}", item) 245 | # NOTE(review): `name` above is computed but not used 246 | return
self.train_transform_ul(img) 247 | 248 | 249 | def train_transform_ul(self, img): 250 | wd, ht = img.size 251 | st_size = 1.0 * min(wd, ht) 252 | # resize the image to fit the crop size 253 | if st_size < self.c_size: 254 | rr = 1.0 * self.c_size / st_size 255 | wd = round(wd * rr) 256 | ht = round(ht * rr) 257 | st_size = 1.0 * min(wd, ht) 258 | img = img.resize((wd, ht), Image.BICUBIC) 259 | 260 | assert st_size >= self.c_size, print(wd, ht) 261 | 262 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 263 | img = F.crop(img, i, j, h, w) 264 | if random.random() > 0.5: 265 | img = F.hflip(img) 266 | 267 | return self.trans(img),1 268 | 269 | -------------------------------------------------------------------------------- /losses/consistency_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from losses import ramps 6 | 7 | 8 | 9 | class consistency_weight(object): 10 | """ 11 | ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampup', 'log_rampup', 'exp_rampup'] 12 | """ 13 | def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'): 14 | self.final_w = final_w 15 | self.iters_per_epoch = iters_per_epoch 16 | self.rampup_starts = rampup_starts * iters_per_epoch 17 | self.rampup_ends = rampup_ends * iters_per_epoch 18 | self.rampup_length = (self.rampup_ends - self.rampup_starts) 19 | self.rampup_func = getattr(ramps, ramp_type) 20 | self.current_rampup = 0 21 | 22 | def __call__(self, epoch, curr_iter): 23 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter 24 | if cur_total_iter < self.rampup_starts: 25 | return 0 26 | self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length) 27 | return self.final_w * self.current_rampup 28 | 29 | 30 | def CE_loss(input_logits, target_targets, ignore_index, temperature=1): 31 | 
return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index) 32 | 33 | # for FocalLoss 34 | def softmax_helper(x): 35 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py 36 | rpt = [1 for _ in range(len(x.size()))] 37 | rpt[1] = x.size(1) 38 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt) 39 | e_x = torch.exp(x - x_max) 40 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) 41 | 42 | def get_alpha(supervised_loader): 43 | # get number of classes 44 | num_labels = 0 45 | for image_batch, label_batch in supervised_loader: 46 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background 47 | l_unique = torch.unique(label_batch.data) 48 | list_unique = [element.item() for element in l_unique.flatten()] 49 | num_labels = max(max(list_unique),num_labels) 50 | num_classes = num_labels + 1 51 | # count class occurrences 52 | alpha = [0 for i in range(num_classes)] 53 | for image_batch, label_batch in supervised_loader: 54 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background 55 | l_unique = torch.unique(label_batch.data) 56 | list_unique = [element.item() for element in l_unique.flatten()] 57 | l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480]) 58 | list_count = [count.item() for count in l_unique_count.flatten()] 59 | for index in list_unique: 60 | alpha[index] += list_count[list_unique.index(index)] 61 | return alpha 62 | 63 | # for FocalLoss 64 | def softmax_helper(x): 65 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py 66 | rpt = [1 for _ in range(len(x.size()))] 67 | rpt[1] = x.size(1) 68 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt) 69 | e_x = torch.exp(x - x_max) 70 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt) 71 | 72 | 73 | class FocalLoss(nn.Module): 74 | """ 75 | copy from: 
https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 76 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 77 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 78 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 79 | :param num_class: 80 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 81 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 82 | focus on hard misclassified example 83 | :param smooth: (float,double) smooth value when cross entropy 84 | :param balance_index: (int) balance class index, should be specific when alpha is float 85 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 86 | """ 87 | 88 | def __init__(self, apply_nonlin=None, ignore_index = None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True): 89 | super(FocalLoss, self).__init__() 90 | self.apply_nonlin = apply_nonlin 91 | self.alpha = alpha 92 | self.gamma = gamma 93 | self.balance_index = balance_index 94 | self.smooth = smooth 95 | self.size_average = size_average 96 | 97 | if self.smooth is not None: 98 | if self.smooth < 0 or self.smooth > 1.0: 99 | raise ValueError('smooth value should be in [0,1]') 100 | 101 | def forward(self, logit, target): 102 | if self.apply_nonlin is not None: 103 | logit = self.apply_nonlin(logit) 104 | num_class = logit.shape[1] 105 | 106 | if logit.dim() > 2: 107 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
108 | logit = logit.view(logit.size(0), logit.size(1), -1) 109 | logit = logit.permute(0, 2, 1).contiguous() 110 | logit = logit.view(-1, logit.size(-1)) 111 | target = torch.squeeze(target, 1) 112 | target = target.view(-1, 1) 113 | 114 | valid_mask = None 115 | if self.ignore_index is not None: 116 | valid_mask = target != self.ignore_index 117 | target = target * valid_mask 118 | 119 | alpha = self.alpha 120 | 121 | if alpha is None: 122 | alpha = torch.ones(num_class, 1) 123 | elif isinstance(alpha, (list, np.ndarray)): 124 | assert len(alpha) == num_class 125 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 126 | alpha = alpha / alpha.sum() 127 | alpha = 1/alpha # inverse of class frequency 128 | elif isinstance(alpha, float): 129 | alpha = torch.ones(num_class, 1) 130 | alpha = alpha * (1 - self.alpha) 131 | alpha[self.balance_index] = self.alpha 132 | 133 | else: 134 | raise TypeError('Not support alpha type') 135 | 136 | if alpha.device != logit.device: 137 | alpha = alpha.to(logit.device) 138 | 139 | idx = target.cpu().long() 140 | 141 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 142 | 143 | # to resolve error in idx in scatter_ 144 | idx[idx==225]=0 145 | 146 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 147 | if one_hot_key.device != logit.device: 148 | one_hot_key = one_hot_key.to(logit.device) 149 | 150 | if self.smooth: 151 | one_hot_key = torch.clamp( 152 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth) 153 | pt = (one_hot_key * logit).sum(1) + self.smooth 154 | logpt = pt.log() 155 | 156 | gamma = self.gamma 157 | 158 | alpha = alpha[idx] 159 | alpha = torch.squeeze(alpha) 160 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 161 | 162 | if valid_mask is not None: 163 | loss = loss * valid_mask.squeeze() 164 | 165 | if self.size_average: 166 | loss = loss.mean() 167 | else: 168 | loss = loss.sum() 169 | return loss 170 | 171 | 172 | class abCE_loss(nn.Module): 173 | """ 174 | Annealed-Bootstrapped 
cross-entropy loss: only non-ignored pixels whose predicted probability for their target class is below an (annealed) threshold contribute 175 | """ 176 | def __init__(self, iters_per_epoch, epochs, num_classes, weight=None, 177 | reduction='mean', thresh=0.7, min_kept=1, ramp_type='log_rampup'): 178 | super(abCE_loss, self).__init__() 179 | self.weight = torch.FloatTensor(weight) if weight is not None else weight 180 | self.reduction = reduction 181 | self.thresh = thresh 182 | self.min_kept = min_kept 183 | self.ramp_type = ramp_type 184 | 185 | if ramp_type is not None: # rampup_func exists only in this case; forward() calls threshold() only then 186 | self.rampup_func = getattr(ramps, ramp_type) 187 | self.iters_per_epoch = iters_per_epoch 188 | self.num_classes = num_classes 189 | self.start = 1/num_classes 190 | self.end = 0.9 191 | self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch # ramp length = 40% of all training iterations 192 | 193 | def threshold(self, curr_iter, epoch): 194 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter 195 | current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters) 196 | return current_rampup * (self.end - self.start) + self.start # moves from start (1/num_classes) toward end (0.9) as the ramp (presumably 0 -> 1) progresses 197 | 198 | def forward(self, predict, target, ignore_index, curr_iter, epoch): 199 | batch_kept = self.min_kept * target.size(0) 200 | prob_out = F.softmax(predict, dim=1) 201 | tmp_target = target.clone() 202 | tmp_target[tmp_target == ignore_index] = 0 203 | prob = prob_out.gather(1, tmp_target.unsqueeze(1)) 204 | mask = target.contiguous().view(-1, ) != ignore_index 205 | sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort() 206 | 207 | if self.ramp_type is not None: 208 | thresh = self.threshold(curr_iter=curr_iter, epoch=epoch) 209 | else: 210 | thresh = self.thresh 211 | 212 | min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if 
sort_prob.numel() > 0 else 0.0 # probability at rank batch_kept: lets roughly min_kept * batch lowest-confidence pixels through even when thresh is small 213 | threshold = max(min_threshold, thresh) 214 | loss_matrix = F.cross_entropy(predict, target, 215 | weight=self.weight.to(predict.device) if self.weight is not None else None, 216 | ignore_index=ignore_index, reduction='none') 217 | loss_matirx = loss_matrix.contiguous().view(-1, ) 218 |
sort_loss_matirx = loss_matirx[mask][sort_indices] 219 | select_loss_matrix = sort_loss_matirx[sort_prob < threshold] 220 | if self.reduction == 'sum' or select_loss_matrix.numel() == 0: 221 | return select_loss_matrix.sum() 222 | elif self.reduction == 'mean': 223 | return select_loss_matrix.mean() 224 | else: 225 | raise NotImplementedError('Reduction Error!') 226 | 227 | 228 | 229 | def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): 230 | assert inputs.requires_grad == True and targets.requires_grad == False 231 | assert inputs.size() == targets.size() # (batch_size * num_classes * H * W) 232 | inputs = F.softmax(inputs, dim=1) 233 | if use_softmax: 234 | targets = F.softmax(targets, dim=1) 235 | 236 | if conf_mask: 237 | loss_mat = F.mse_loss(inputs, targets, reduction='none') 238 | mask = (targets.max(1)[0] > threshold) 239 | loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] 240 | if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) 241 | return loss_mat.mean() 242 | else: 243 | return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size 244 | 245 | 246 | def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False): 247 | assert inputs.requires_grad == True and targets.requires_grad == False 248 | assert inputs.size() == targets.size() 249 | 250 | if use_softmax: 251 | targets = F.softmax(targets, dim=1) 252 | if conf_mask: 253 | loss_mat = F.kl_div(input_log_softmax, targets, reduction='none') 254 | mask = (targets.max(1)[0] > threshold) 255 | loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)] 256 | if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device) 257 | return loss_mat.sum() / mask.shape.numel() 258 | else: 259 | return F.kl_div(inputs, targets, reduction='mean') 260 | 261 | 262 | def softmax_js_loss(inputs, targets, **_): 263 | assert inputs.requires_grad == True and targets.requires_grad 
== False 264 | assert inputs.size() == targets.size() 265 | epsilon = 1e-5 # guards log(0) for the targets 266 | # NOTE(review): F.kl_div(input, target) computes KL(target || exp(input)), so kl1/kl2 are KL(M||softmax(inputs)) and KL(M||targets) - the reverse of the usual JS form 0.5*(KL(p||M)+KL(q||M)); targets are assumed to be probabilities - confirm intended 267 | M = (F.softmax(inputs, dim=1) + targets) * 0.5 268 | kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean') 269 | kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean') 270 | return (kl1 + kl2) * 0.5 271 | 272 | 273 | 274 | def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8): 275 | """ 276 | Pair-wise loss in the sup. mat.: variance of the softmax predictions of up to nbr_of_pairs randomly chosen outputs, summed over dim 1 (classes) and then averaged (size_average=True) or summed 277 | """ 278 | if isinstance(unsup_outputs, list): 279 | unsup_outputs = torch.stack(unsup_outputs) 280 | 281 | # Only for a subset of the aux outputs to reduce computation and memory 282 | unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))] 283 | unsup_outputs = unsup_outputs[:nbr_of_pairs] 284 | 285 | temp = torch.zeros_like(unsup_outputs) # For grad purposes 286 | for i, u in enumerate(unsup_outputs): 287 | temp[i] = F.softmax(u, dim=1) 288 | mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs 289 | pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance 290 | pw_loss = pw_loss.sum(1) # Sum over classes 291 | if size_average: 292 | return pw_loss.mean() 293 | return pw_loss.sum() 294 | 295 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from torch import optim 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.dataloader import default_collate 8 | 
import numpy as np 9 | from datetime import datetime 10 | import torch.nn.functional as F 11 | from datasets.crowd import Crowd_TC, Crowd_UL_TC 12 | 13 | from network import pvt_cls as TCN 14 | from losses.multi_con_loss import MultiConLoss 15 | 16 | from utils.pytorch_utils import Save_Handle, AverageMeter 17 | import utils.log_utils as log_utils 18 | import argparse 19 | from losses.rank_loss import RankLoss 20 |
| 21 | from losses import ramps 22 | from losses.ot_loss import OT_Loss 23 | from losses.consistency_loss import * 24 | 25 | parser = argparse.ArgumentParser(description='Train') 26 | parser.add_argument('--data-dir', default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='data path') 27 | 28 | parser.add_argument('--dataset', default='TC') 29 | parser.add_argument('--lr', type=float, default=1e-5, help='the initial learning rate') 30 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='the weight decay') 31 | parser.add_argument('--resume', default='', type=str, help='the path of resume training model') 32 | parser.add_argument('--max-epoch', type=int, default=4000, help='max training epoch') 33 | parser.add_argument('--val-epoch', type=int, default=1, help='the num of steps to log training information') 34 | parser.add_argument('--val-start', type=int, default=0, help='the epoch start to val') 35 | parser.add_argument('--batch-size', type=int, default=16, help='train batch size') 36 | parser.add_argument('--batch-size-ul', type=int, default=16, help='train batch size') 37 | parser.add_argument('--device', default='0', help='assign device') 38 | parser.add_argument('--num-workers', type=int, default=0, help='the num of training process') 39 | parser.add_argument('--crop-size', type=int, default= 256, help='the crop size of the train image') 40 | parser.add_argument('--rl', type=float, default=1, help='entropy regularization in sinkhorn') 41 | parser.add_argument('--reg', type=float, default=1, help='entropy regularization in sinkhorn') 42 | parser.add_argument('--ot', type=float, default=0.1, help='entropy regularization in sinkhorn') 43 | parser.add_argument('--tv', type=float, default=0.01, help='entropy regularization in sinkhorn') 44 | parser.add_argument('--num-of-iter-in-ot', type=int, default=100, help='sinkhorn iterations') 45 | parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing 
distance') 46 | parser.add_argument('--run-name', default='Treeformer_test', help='run name for wandb interface/logging') 47 | parser.add_argument('--consistency', type=int, default=1, help='whether to norm cood when computing distance') 48 | args = parser.parse_args() 49 | 50 | 51 | def train_collate(batch): 52 | transposed_batch = list(zip(*batch)) 53 | images = torch.stack(transposed_batch[0], 0) 54 | gauss = torch.stack(transposed_batch[1], 0) 55 | points = transposed_batch[2] 56 | gt_discretes = torch.stack(transposed_batch[3], 0) 57 | return images, gauss, points, gt_discretes 58 | 59 | 60 | def train_collate_UL(batch): 61 | transposed_batch = list(zip(*batch)) 62 | images = torch.stack(transposed_batch[0], 0) 63 | 64 | return images 65 | 66 | def get_current_consistency_weight(epoch): 67 | # Consistency ramp-up from https://arxiv.org/abs/1610.02242 68 | return args.consistency * ramps.sigmoid_rampup(epoch, args.consistency_ramp) 69 | 70 | 71 | class Trainer(object): 72 | def __init__(self, args): 73 | self.args = args 74 | 75 | def setup(self): 76 | args = self.args 77 | sub_dir = ( 78 | "SEMI/{}_12-1-input-{}_reg-{}_nIter-{}_normCood-{}".format( 79 | args.run_name,args.crop_size,args.reg, 80 | args.num_of_iter_in_ot,args.norm_cood)) 81 | 82 | self.save_dir = os.path.join("/scratch/users/k2254235","ckpts", sub_dir) 83 | if not os.path.exists(self.save_dir): 84 | os.makedirs(self.save_dir) 85 | 86 | time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S") 87 | self.logger = log_utils.get_logger( 88 | os.path.join(self.save_dir, "train-{:s}.log".format(time_str))) 89 | 90 | log_utils.print_config(vars(args), self.logger) 91 | 92 | if torch.cuda.is_available(): 93 | self.device = torch.device("cuda") 94 | self.device_count = torch.cuda.device_count() 95 | self.logger.info("using {} gpus".format(self.device_count)) 96 | else: 97 | raise Exception("gpu is not available") 98 | 99 | 100 | downsample_ratio = 4 101 | self.datasets = {"train": 
Crowd_TC(os.path.join(args.data_dir, "train_data"), args.crop_size, 102 | downsample_ratio, "train"), "val": Crowd_TC(os.path.join(args.data_dir, "valid_data"), 103 | args.crop_size, downsample_ratio, "val")} 104 | # unlabeled images (train_data_ul) feed the unsupervised loss terms in train_epoch 105 | self.datasets_ul = { "train_ul": Crowd_UL_TC(os.path.join(args.data_dir, "train_data_ul"), 106 | args.crop_size, downsample_ratio, "train_ul")} 107 | 108 | 109 | self.dataloaders = { 110 | x: DataLoader(self.datasets[x], 111 | collate_fn=(train_collate if x == "train" else default_collate), 112 | batch_size=(args.batch_size if x == "train" else 1), # val images go one at a time and are tiled in val_epoch 113 | shuffle=(True if x == "train" else False), 114 | num_workers=args.num_workers * self.device_count, # --num-workers is per GPU 115 | pin_memory=(True if x == "train" else False)) 116 | for x in ["train", "val"]} 117 | 118 | self.dataloaders_ul = { 119 | x: DataLoader(self.datasets_ul[x], 120 | collate_fn=(train_collate_UL ), 121 | batch_size=(args.batch_size_ul), 122 | shuffle=(True), 123 | num_workers=args.num_workers * self.device_count, 124 | pin_memory=(True if x == "train" else False)) # NOTE(review): x is always "train_ul" here, so pin_memory is always False for the unlabeled loader - confirm intended 125 | for x in ["train_ul"]} 126 | 127 | 128 | self.model = TCN.pvt_treeformer(pretrained=False) 129 | 130 | self.model.to(self.device) 131 | self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 132 | self.start_epoch = 0 133 | 134 | if args.resume: 135 | self.logger.info("loading pretrained model from " + args.resume) 136 | suf = args.resume.rsplit(".", 1)[-1] # ".tar" = full checkpoint (model, optimizer, epoch); ".pth" = bare model state_dict 137 | if suf == "tar": 138 | checkpoint = torch.load(args.resume, self.device) 139 | self.model.load_state_dict(checkpoint["model_state_dict"]) 140 | self.optimizer.load_state_dict( 141 | 
checkpoint["optimizer_state_dict"]) 142 | self.start_epoch = checkpoint["epoch"] + 1 143 | elif suf == "pth": 144 | self.model.load_state_dict( 145 | torch.load(args.resume, self.device)) 146 | else: 147 | self.logger.info("random initialization") 148 | 149 | self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood, 150 | self.device,
args.num_of_iter_in_ot, args.reg) 151 | 152 | self.tvloss = nn.L1Loss(reduction="none").to(self.device) 153 | self.mse = nn.MSELoss().to(self.device) # NOTE(review): self.mse and self.kl_distance are not used elsewhere in Trainer 154 | self.mae = nn.L1Loss().to(self.device) 155 | self.save_list = Save_Handle(max_num=1) 156 | self.best_mae = np.inf 157 | self.best_mse = np.inf 158 | self.rankloss = RankLoss().to(self.device) 159 | self.kl_distance = nn.KLDivLoss(reduction='none') 160 | self.multiconloss = MultiConLoss().to(self.device) 161 | 162 | 163 | def train(self): 164 | """Run epochs start_epoch..max_epoch (inclusive); validate every val_epoch epochs once epoch >= val_start.""" 165 | args = self.args 166 | for epoch in range(self.start_epoch, args.max_epoch + 1): 167 | self.logger.info("-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5) 168 | self.epoch = epoch 169 | self.train_epoch() 170 | if epoch % args.val_epoch == 0 and epoch >= args.val_start: 171 | self.val_epoch() 172 | 173 | def train_epoch(self): 174 | epoch_ot_loss = AverageMeter() 175 | epoch_ot_obj_value = AverageMeter() 176 | epoch_wd = AverageMeter() 177 | epoch_tv_loss = AverageMeter() 178 | epoch_count_loss = AverageMeter() 179 | epoch_count_consistency_l = AverageMeter() 180 | epoch_count_consistency_ul = AverageMeter() 181 | epoch_loss = AverageMeter() 182 | epoch_mae = AverageMeter() 183 | epoch_mse = AverageMeter() 184 | epoch_start = time.time() 185 | epoch_rank_loss = AverageMeter() 186 | epoch_consistensy_loss = AverageMeter() 187 | 188 | self.model.train() # Set model to training mode 189 | 190 | for step, (inputs, gausss, points, gt_discrete) in enumerate(self.dataloaders["train"]): 191 | inputs = inputs.to(self.device) 192 | gausss = gausss.to(self.device) # NOTE(review): gausss is not used later in train_epoch 193 | gd_count = np.array([len(p) for p in points], dtype=np.float32) # ground-truth count per image = number of annotated points 194 | 195 | points = [p.to(self.device) for p in points] 196 | 
gt_discrete = gt_discrete.to(self.device) 197 | N = inputs.size(0) 198 | 199 | for st, unlabel_data in enumerate(self.dataloaders_ul["train_ul"]): # NOTE(review): builds a fresh iterator over the shuffled unlabeled loader on every labeled step and keeps only its first batch; with num_workers > 0 this presumably restarts worker processes each step - consider a persistent iterator 200 | inputs_ul = unlabel_data.to(self.device) 201 | break 202 | 203 | 204 | with
torch.set_grad_enabled(True): 205 | outputs_L, outputs_UL, outputs_normed, CLS_L, CLS_UL = self.model(inputs, inputs_ul) 206 | outputs_L = outputs_L[0] # assumes the model returns a sequence of density maps for the labeled batch and only the first is supervised - TODO confirm against pvt_treeformer 207 | 208 | with torch.set_grad_enabled(False): 209 | preds_UL = (outputs_UL[0][0] + outputs_UL[1][0] + outputs_UL[2][0])/3 # NOTE(review): preds_UL is not used later in train_epoch 210 | 211 | # Compute counting loss. 212 | count_loss = self.mae(outputs_L.sum(1).sum(1).sum(1),torch.from_numpy(gd_count).float().to(self.device))*self.args.reg # NOTE(review): weighted by args.reg, the same value passed to OT_Loss as its Sinkhorn regularization 213 | 214 | # Compute OT loss. 215 | ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L, points) 216 | ot_loss = ot_loss* self.args.ot 217 | ot_obj_value = ot_obj_value* self.args.ot 218 | 219 | gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)) 220 | gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6) 221 | tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)* 222 | torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv 223 | 224 | epoch_ot_loss.update(ot_loss.item(), N) 225 | epoch_ot_obj_value.update(ot_obj_value.item(), N) 226 | epoch_wd.update(wd, N) 227 | epoch_count_loss.update(count_loss.item(), N) 228 | epoch_tv_loss.update(tv_loss.item(), N) 229 | 230 | # Compute ranking loss.
231 | rank_loss = self.rankloss(outputs_UL)*self.args.rl 232 | epoch_rank_loss.update(rank_loss.item(), N) 233 | 234 | # Compute multi level consistancy loss 235 | consistency_loss = args.consistency * self.multiconloss(outputs_UL) 236 | epoch_consistensy_loss.update(consistency_loss.item(), N) 237 | 238 | 239 | # Compute consistency count 240 | Con_cls_UL = (CLS_UL[0] + CLS_UL[1] + CLS_UL[2])/3 241 | Con_cls_L = torch.from_numpy(gd_count).float().to(self.device) 242 | 243 | count_loss_l = self.mae(torch.stack((CLS_L[0],CLS_L[1],CLS_L[2])), torch.stack((Con_cls_L, Con_cls_L, Con_cls_L))) 244 | count_loss_ul = self.mae(torch.stack((CLS_UL[0],CLS_UL[1],CLS_UL[2])), torch.stack((Con_cls_UL, Con_cls_UL, Con_cls_UL))) 245 | epoch_count_consistency_l.update(count_loss_l.item(), N) 246 | epoch_count_consistency_ul.update(count_loss_ul.item(), N) 247 | 248 | 249 | loss = count_loss + ot_loss + tv_loss + rank_loss + count_loss_l + count_loss_ul + consistency_loss 250 | 251 | 252 | self.optimizer.zero_grad() 253 | loss.backward() 254 | self.optimizer.step() 255 | 256 | pred_count = (torch.sum(outputs_L.view(N, -1), 257 | dim=1).detach().cpu().numpy()) 258 | 259 | pred_err = pred_count - gd_count 260 | epoch_loss.update(loss.item(), N) 261 | epoch_mse.update(np.mean(pred_err * pred_err), N) 262 | epoch_mae.update(np.mean(abs(pred_err)), N) 263 | 264 | 265 | self.logger.info( 266 | "Epoch {} Train, Loss: {:.2f}, Count Loss: {:.2f}, OT Loss: {:.2e}, TV Loss: {:.2e}, Rank Loss: {:.2f}," 267 | "Consistensy Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f},LC Loss: {:.2f}, ULC Loss: {:.2f}, Cost {:.1f} sec".format( 268 | self.epoch, epoch_loss.get_avg(), epoch_count_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_tv_loss.get_avg(), epoch_rank_loss.get_avg(), 269 | epoch_consistensy_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), epoch_count_consistency_l.get_avg(), 270 | epoch_count_consistency_ul.get_avg(), time.time() - epoch_start)) 271 | 272 | 273 | 274 | 
model_state_dic = self.model.state_dict() 275 | save_path = os.path.join(self.save_dir, "{}_ckpt.tar".format(self.epoch)) 276 | # a full checkpoint is written every epoch; Save_Handle(max_num=1) presumably keeps only the newest - TODO confirm 277 | torch.save({"epoch": self.epoch, "optimizer_state_dict": self.optimizer.state_dict(), 278 | "model_state_dict": model_state_dic}, save_path) 279 | self.save_list.append(save_path) 280 | 281 | def val_epoch(self): 282 | args = self.args 283 | epoch_start = time.time() 284 | self.model.eval() # Set model to evaluate mode 285 | epoch_res = [] 286 | for inputs, count, name, gauss_im in self.dataloaders["val"]: 287 | with torch.no_grad(): 288 | inputs = inputs.to(self.device) 289 | crop_imgs, crop_masks = [], [] 290 | b, c, h, w = inputs.size() 291 | rh, rw = args.crop_size, args.crop_size 292 | for i in range(0, h, rh): 293 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) # window clamped inside the image, so the last row of crops may overlap the previous one 294 | for j in range(0, w, rw): 295 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 296 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) 297 | mask = torch.zeros([b, 1, h, w]).to(self.device) 298 | mask[:, :, gis:gie, gjs:gje].fill_(1.0) 299 | crop_masks.append(mask) 300 | crop_imgs, crop_masks = map( 301 | lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks)) 302 | 303 | crop_preds = [] 304 | nz, bz = crop_imgs.size(0), args.batch_size 305 | for i in range(0, nz, bz): 306 | gs, gt = i, min(nz, i + bz) 307 | 308 | crop_pred, _ = self.model(crop_imgs[gs:gt]) # assumes a single-input call returns a 2-tuple whose first item is a sequence of density maps - TODO confirm against pvt_treeformer 309 | crop_pred = crop_pred[0] 310 | _, _, h1, w1 = crop_pred.size() 311 | crop_pred = (F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), 312 | mode="bilinear", align_corners=True) / 16 ) # 4x upsampling spreads each value over 16 pixels; dividing by 16 approximately preserves the total count 313 | 314 | 
crop_preds.append(crop_pred) 315 | crop_preds = torch.cat(crop_preds, dim=0) 316 | 317 | # splice them to the original size 318 | idx = 0 319 | pred_map = torch.zeros([b, 1, h, w]).to(self.device) 320 | for i in range(0, h, rh): 321 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 322 | for j in range(0, w, rw): 323 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 324 | pred_map[:, :, gis:gie, gjs:gje] +=
crop_preds[idx] 325 | idx += 1 326 | # for the overlapping area, compute average value 327 | mask = crop_masks.sum(dim=0).unsqueeze(0) 328 | outputs = pred_map / mask 329 | 330 | res = count[0].item() - torch.sum(outputs).item() 331 | epoch_res.append(res) 332 | epoch_res = np.array(epoch_res) 333 | mse = np.sqrt(np.mean(np.square(epoch_res))) 334 | mae = np.mean(np.abs(epoch_res)) 335 | 336 | self.logger.info("Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec".format( 337 | self.epoch, mse, mae, time.time() - epoch_start )) 338 | 339 | 340 | model_state_dic = self.model.state_dict() 341 | print("Comaprison", mae, self.best_mae) 342 | if mae < self.best_mae: 343 | self.best_mse = mse 344 | self.best_mae = mae 345 | self.logger.info( 346 | "save best mse {:.2f} mae {:.2f} model epoch {}".format( 347 | self.best_mse, self.best_mae, self.epoch)) 348 | 349 | print("Saving best model at {} epoch".format(self.epoch)) 350 | model_path = os.path.join( 351 | self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format( 352 | self.best_mae, self.epoch)) 353 | 354 | torch.save(model_state_dic, model_path) 355 | 356 | 357 | if __name__ == "__main__": 358 | import torch 359 | torch.backends.cudnn.benchmark = True 360 | trainer = Trainer(args) 361 | trainer.setup() 362 | trainer.train() 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | -------------------------------------------------------------------------------- /losses/bregman_pytorch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Rewrite ot.bregman.sinkhorn in Python Optimal Transport (https://pythonot.github.io/_modules/ot/bregman.html#sinkhorn) 4 | using pytorch operations. 5 | Bregman projections for regularized OT (Sinkhorn distance). 
6 | """ 7 | 8 | import torch 9 | 10 | M_EPS = 1e-16 11 | 12 | 13 | def sinkhorn(a, b, C, reg=1e-1, method='sinkhorn', maxIter=1000, tau=1e3, 14 | stopThr=1e-9, verbose=False, log=True, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 15 | """ 16 | Solve the entropic regularization optimal transport 17 | The input should be PyTorch tensors 18 | The function solves the following optimization problem: 19 | 20 | .. math:: 21 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 22 | s.t. \gamma 1 = a 23 | \gamma^T 1= b 24 | \gamma\geq 0 25 | where : 26 | - C is the (ns,nt) metric cost matrix 27 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 28 | - a and b are target and source measures (sum to 1) 29 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. 30 | 31 | Parameters 32 | ---------- 33 | a : torch.tensor (na,) 34 | samples measure in the target domain 35 | b : torch.tensor (nb,) 36 | samples in the source domain 37 | C : torch.tensor (na,nb) 38 | loss matrix 39 | reg : float 40 | Regularization term > 0 41 | method : str 42 | method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or 43 | 'sinkhorn_epsilon_scaling', see those function for specific parameters 44 | maxIter : int, optional 45 | Max number of iterations 46 | stopThr : float, optional 47 | Stop threshol on error ( > 0 ) 48 | verbose : bool, optional 49 | Print information along iterations 50 | log : bool, optional 51 | record log if True 52 | 53 | Returns 54 | ------- 55 | gamma : (na x nb) torch.tensor 56 | Optimal transportation matrix for the given parameters 57 | log : dict 58 | log dictionary return only if log==True in parameters 59 | 60 | References 61 | ---------- 62 | [1] M. 
Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 63 | See Also 64 | -------- 65 | 66 | """ 67 | 68 | if method.lower() == 'sinkhorn': 69 | return sinkhorn_knopp(a, b, C, reg, maxIter=maxIter, 70 | stopThr=stopThr, verbose=verbose, log=log, 71 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 72 | **kwargs) 73 | elif method.lower() == 'sinkhorn_stabilized': 74 | return sinkhorn_stabilized(a, b, C, reg, maxIter=maxIter, tau=tau, 75 | stopThr=stopThr, verbose=verbose, log=log, 76 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 77 | **kwargs) 78 | elif method.lower() == 'sinkhorn_epsilon_scaling': 79 | return sinkhorn_epsilon_scaling(a, b, C, reg, 80 | maxIter=maxIter, maxInnerIter=100, tau=tau, 81 | scaling_base=0.75, scaling_coef=None, stopThr=stopThr, 82 | verbose=False, log=log, warm_start=warm_start, eval_freq=eval_freq, 83 | print_freq=print_freq, **kwargs) 84 | else: 85 | raise ValueError("Unknown method '%s'." % method) 86 | 87 | 88 | def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9, 89 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 90 | """ 91 | Solve the entropic regularization optimal transport 92 | The input should be PyTorch tensors 93 | The function solves the following optimization problem: 94 | 95 | .. math:: 96 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 97 | s.t. \gamma 1 = a 98 | \gamma^T 1= b 99 | \gamma\geq 0 100 | where : 101 | - C is the (ns,nt) metric cost matrix 102 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 103 | - a and b are target and source measures (sum to 1) 104 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. 
105 | 106 | Parameters 107 | ---------- 108 | a : torch.tensor (na,) 109 | samples measure in the target domain 110 | b : torch.tensor (nb,) 111 | samples in the source domain 112 | C : torch.tensor (na,nb) 113 | loss matrix 114 | reg : float 115 | Regularization term > 0 116 | maxIter : int, optional 117 | Max number of iterations 118 | stopThr : float, optional 119 | Stop threshold on error ( > 0 ) 120 | verbose : bool, optional 121 | Print information along iterations 122 | log : bool, optional 123 | record log if True 124 | warm_start : dict, optional - initial scaling vectors under keys 'u' and 'v'; eval_freq / print_freq : int - check the error / print every N iterations (the error is only tracked when log is True) 125 | Returns 126 | ------- 127 | gamma : (na x nb) torch.tensor 128 | Optimal transportation matrix for the given parameters 129 | log : dict 130 | log dictionary return only if log==True in parameters 131 | 132 | References 133 | ---------- 134 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 135 | See Also 136 | -------- 137 | 138 | """ 139 | 140 | device = a.device 141 | na, nb = C.shape 142 | 143 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 144 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 145 | assert reg > 0, 'reg should be greater than 0' 146 | assert a.min() >= 0.
and b.min() >= 0., 'Elements in a or b less than 0' 147 | 148 | if log: 149 | log = {'err': []} 150 | 151 | if warm_start is not None: 152 | u = warm_start['u'] 153 | v = warm_start['v'] 154 | else: 155 | u = torch.ones(na, dtype=a.dtype).to(device) / na 156 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb 157 | 158 | K = torch.empty(C.shape, dtype=C.dtype).to(device) # K = exp(-C / reg), built in place below 159 | torch.div(C, -reg, out=K) 160 | torch.exp(K, out=K) 161 | 162 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device) # preallocated, but rebound (not filled in place) inside the loop 163 | 164 | it = 1 165 | err = 1 # NOTE: err is only recomputed when log is truthy, so with log=False the loop runs until maxIter (or a NaN/inf break) 166 | 167 | # allocate memory beforehand 168 | KTu = torch.empty(v.shape, dtype=v.dtype).to(device) 169 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device) 170 | 171 | while (err > stopThr and it <= maxIter): 172 | upre, vpre = u, v 173 | torch.matmul(u, K, out=KTu) 174 | v = torch.div(b, KTu + M_EPS) 175 | torch.matmul(K, v, out=Kv) 176 | u = torch.div(a, Kv + M_EPS) 177 | 178 | if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \ 179 | torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)): 180 | print('Warning: numerical errors at iteration', it) 181 | u, v = upre, vpre # fall back to the last finite scalings 182 | break 183 | 184 | if log and it % eval_freq == 0: 185 | # we can speed up the process by checking for the error only all 186 | # the eval_freq iterations 187 | # below is equivalent to: 188 | # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0) 189 | # but more memory-efficient 190 | b_hat = torch.matmul(u, K) * v 191 | err = (b - b_hat).pow(2).sum().item() 192 | # err = (b - b_hat).abs().sum().item() 193 | log['err'].append(err) 194 | 195 | if verbose and it % print_freq == 0: 196 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 197 | 198 | it += 1 199 | 200 | if log: 201 | log['u'] = u 202 | 
log['v'] = v 203 | log['alpha'] = reg * torch.log(u + M_EPS) 204 | log['beta'] = reg * torch.log(v + M_EPS) 205 | 206 | # transport plan 207 | P = u.reshape(-1, 1) * K * v.reshape(1, -1) 208 | if log: 209 | return P, log 210 | else:
211 | return P 212 | 213 | 214 | def sinkhorn_stabilized(a, b, C, reg=1e-1, maxIter=1000, tau=1e3, stopThr=1e-9, 215 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 216 | """ 217 | Solve the entropic regularization OT problem with log stabilization 218 | The function solves the following optimization problem: 219 | 220 | .. math:: 221 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 222 | s.t. \gamma 1 = a 223 | \gamma^T 1= b 224 | \gamma\geq 0 225 | where : 226 | - C is the (ns,nt) metric cost matrix 227 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 228 | - a and b are target and source measures (sum to 1) 229 | 230 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1] 231 | but with the log stabilization proposed in [3] an defined in [2] (Algo 3.1) 232 | 233 | Parameters 234 | ---------- 235 | a : torch.tensor (na,) 236 | samples measure in the target domain 237 | b : torch.tensor (nb,) 238 | samples in the source domain 239 | C : torch.tensor (na,nb) 240 | loss matrix 241 | reg : float 242 | Regularization term > 0 243 | tau : float 244 | thershold for max value in u or v for log scaling 245 | maxIter : int, optional 246 | Max number of iterations 247 | stopThr : float, optional 248 | Stop threshol on error ( > 0 ) 249 | verbose : bool, optional 250 | Print information along iterations 251 | log : bool, optional 252 | record log if True 253 | 254 | Returns 255 | ------- 256 | gamma : (na x nb) torch.tensor 257 | Optimal transportation matrix for the given parameters 258 | log : dict 259 | log dictionary return only if log==True in parameters 260 | 261 | References 262 | ---------- 263 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 264 | [2] Bernhard Schmitzer. 
Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019 265 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 266 | 267 | See Also 268 | -------- 269 | 270 | """ 271 | 272 | device = a.device 273 | na, nb = C.shape 274 | 275 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 276 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 277 | assert reg > 0, 'reg should be greater than 0' 278 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0' 279 | 280 | if log: 281 | log = {'err': []} 282 | 283 | if warm_start is not None: 284 | alpha = warm_start['alpha'] 285 | beta = warm_start['beta'] 286 | else: 287 | alpha = torch.zeros(na, dtype=a.dtype).to(device) 288 | beta = torch.zeros(nb, dtype=b.dtype).to(device) 289 | 290 | u = torch.ones(na, dtype=a.dtype).to(device) / na 291 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb 292 | 293 | def update_K(alpha, beta): 294 | """log space computation""" 295 | """memory efficient""" 296 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K) 297 | torch.add(K, -C, out=K) 298 | torch.div(K, reg, out=K) 299 | torch.exp(K, out=K) 300 | 301 | def update_P(alpha, beta, u, v, ab_updated=False): 302 | """log space P (gamma) computation""" 303 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P) 304 | torch.add(P, -C, out=P) 305 | torch.div(P, reg, out=P) 306 | if not ab_updated: 307 | torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P) 308 | torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P) 309 | torch.exp(P, out=P) 310 | 311 | K = torch.empty(C.shape, dtype=C.dtype).to(device) 312 | update_K(alpha, beta) 313 | 314 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device) 315 | 316 | it = 1 317 | err = 1 318 | ab_updated = False 319 | 320 | # allocate memory beforehand 321 | KTu = 
torch.empty(v.shape, dtype=v.dtype).to(device) 322 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device) 323 | P = torch.empty(C.shape, dtype=C.dtype).to(device) 324 | 325 | while (err > stopThr and it <= maxIter): 326 | upre, vpre = u, v 327 | torch.matmul(u, K, out=KTu) 328 | v = torch.div(b, KTu + M_EPS) 329 | torch.matmul(K, v, out=Kv) 330 | u = torch.div(a, Kv + M_EPS) 331 | 332 | ab_updated = False 333 | # remove numerical problems and store them in K 334 | if u.abs().sum() > tau or v.abs().sum() > tau: 335 | alpha += reg * torch.log(u + M_EPS) 336 | beta += reg * torch.log(v + M_EPS) 337 | u.fill_(1. / na) 338 | v.fill_(1. / nb) 339 | update_K(alpha, beta) 340 | ab_updated = True 341 | 342 | if log and it % eval_freq == 0: 343 | # we can speed up the process by checking for the error only all 344 | # the eval_freq iterations 345 | update_P(alpha, beta, u, v, ab_updated) 346 | b_hat = torch.sum(P, 0) 347 | err = (b - b_hat).pow(2).sum().item() 348 | log['err'].append(err) 349 | 350 | if verbose and it % print_freq == 0: 351 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 352 | 353 | it += 1 354 | 355 | if log: 356 | log['u'] = u 357 | log['v'] = v 358 | log['alpha'] = alpha + reg * torch.log(u + M_EPS) 359 | log['beta'] = beta + reg * torch.log(v + M_EPS) 360 | 361 | # transport plan 362 | update_P(alpha, beta, u, v, False) 363 | 364 | if log: 365 | return P, log 366 | else: 367 | return P 368 | 369 | 370 | def sinkhorn_epsilon_scaling(a, b, C, reg=1e-1, maxIter=100, maxInnerIter=100, tau=1e3, scaling_base=0.75, 371 | scaling_coef=None, stopThr=1e-9, verbose=False, log=False, warm_start=None, eval_freq=10, 372 | print_freq=200, **kwargs): 373 | """ 374 | Solve the entropic regularization OT problem with log stabilization 375 | The function solves the following optimization problem: 376 | 377 | .. math:: 378 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 379 | s.t. 
\gamma 1 = a 380 | \gamma^T 1= b 381 | \gamma\geq 0 382 | where : 383 | - C is the (ns,nt) metric cost matrix 384 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 385 | - a and b are target and source measures (sum to 1) 386 | 387 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix 388 | scaling algorithm as proposed in [1] but with the log stabilization 389 | proposed in [3] and the log scaling proposed in [2] algorithm 3.2 390 | 391 | Parameters 392 | ---------- 393 | a : torch.tensor (na,) 394 | samples measure in the target domain 395 | b : torch.tensor (nb,) 396 | samples in the source domain 397 | C : torch.tensor (na,nb) 398 | loss matrix 399 | reg : float 400 | Regularization term > 0 401 | tau : float 402 | thershold for max value in u or v for log scaling 403 | maxIter : int, optional 404 | Max number of iterations 405 | stopThr : float, optional 406 | Stop threshol on error ( > 0 ) 407 | verbose : bool, optional 408 | Print information along iterations 409 | log : bool, optional 410 | record log if True 411 | 412 | Returns 413 | ------- 414 | gamma : (na x nb) torch.tensor 415 | Optimal transportation matrix for the given parameters 416 | log : dict 417 | log dictionary return only if log==True in parameters 418 | 419 | References 420 | ---------- 421 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 422 | [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019 423 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
424 | 425 | See Also 426 | -------- 427 | 428 | """ 429 | 430 | na, nb = C.shape 431 | 432 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 433 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 434 | assert reg > 0, 'reg should be greater than 0' 435 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0' 436 | 437 | def get_reg(it, reg, pre_reg): 438 | if it == 1: 439 | return scaling_coef 440 | else: 441 | if (pre_reg - reg) * scaling_base < M_EPS: 442 | return reg 443 | else: 444 | return (pre_reg - reg) * scaling_base + reg 445 | 446 | if scaling_coef is None: 447 | scaling_coef = C.max() + reg 448 | 449 | it = 1 450 | err = 1 451 | running_reg = scaling_coef 452 | 453 | if log: 454 | log = {'err': []} 455 | 456 | warm_start = None 457 | 458 | while (err > stopThr and it <= maxIter): 459 | running_reg = get_reg(it, reg, running_reg) 460 | P, _log = sinkhorn_stabilized(a, b, C, running_reg, maxIter=maxInnerIter, tau=tau, 461 | stopThr=stopThr, verbose=False, log=True, 462 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 463 | **kwargs) 464 | 465 | warm_start = {} 466 | warm_start['alpha'] = _log['alpha'] 467 | warm_start['beta'] = _log['beta'] 468 | 469 | primal_val = (C * P).sum() + reg * (P * torch.log(P)).sum() - reg * P.sum() 470 | dual_val = (_log['alpha'] * a).sum() + (_log['beta'] * b).sum() - reg * P.sum() 471 | err = primal_val - dual_val 472 | log['err'].append(err) 473 | 474 | if verbose and it % print_freq == 0: 475 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 476 | 477 | it += 1 478 | 479 | if log: 480 | log['alpha'] = _log['alpha'] 481 | log['beta'] = _log['beta'] 482 | return P, log 483 | else: 484 | return P 485 | -------------------------------------------------------------------------------- /network/pvt_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | 10 | import math 11 | from torch.distributions.uniform import Uniform 12 | import numpy as np 13 | import random 14 | 15 | __all__ = [ 16 | 'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large' 17 | ] 18 | 19 | 20 | class SELayer(nn.Module): 21 | def __init__(self, channel, reduction=16): 22 | super(SELayer, self).__init__() 23 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 24 | self.fc = nn.Sequential( 25 | nn.Linear(channel, channel // reduction, bias=False), 26 | nn.ReLU(inplace=True), 27 | nn.Linear(channel // reduction, channel, bias=False), 28 | nn.Sigmoid() 29 | ) 30 | 31 | def forward(self, x): 32 | b, c, _, _ = x.size() 33 | y = self.avg_pool(x).view(b, c) 34 | y = self.fc(y).view(b, c, 1, 1) 35 | return x * y.expand_as(x) 36 | 37 | 38 | class Regression(nn.Module): 39 | def __init__(self): 40 | super(Regression, self).__init__() 41 | 42 | self.v1 = nn.Sequential( 43 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 44 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), 45 | nn.BatchNorm2d(128), nn.ReLU(inplace=True)) 46 | 47 | self.v2 = nn.Sequential( 48 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 49 | nn.Conv2d(512, 256, 3, padding=1, dilation=1), 50 | nn.BatchNorm2d(256), nn.ReLU(inplace=True)) 51 | 52 | self.v3 = nn.Sequential( 53 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 54 | nn.Conv2d(1024, 512, 3, padding=1, dilation=1), nn.BatchNorm2d(512), 55 | nn.ReLU(inplace=True)) 56 | 57 | self.ca2 = nn.Sequential(ChannelAttention(512), 58 | nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1 ), 59 | nn.BatchNorm2d(512), nn.ReLU(inplace=True)) 60 | 61 | self.ca1 = nn.Sequential(ChannelAttention(256), 62 | nn.Conv2d(256, 256, kernel_size = 3, stride = 1, 
padding = 1 ), 63 | nn.BatchNorm2d(256), nn.ReLU(inplace=True)) 64 | 65 | self.ca0 = nn.Sequential(ChannelAttention(128), 66 | nn.Conv2d(128, 128, kernel_size = 3, stride = 1, padding = 1 ), 67 | nn.BatchNorm2d(128), nn.ReLU(inplace=True)) 68 | 69 | self.res2 = nn.Sequential( 70 | nn.Conv2d(512, 256, 3, padding=1, dilation=1), nn.BatchNorm2d(256), 71 | nn.ReLU(inplace=True), 72 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128), 73 | nn.ReLU(inplace=True), 74 | nn.Conv2d(128, 1, 3, padding=1, dilation=1), 75 | nn.ReLU(inplace=True)) 76 | 77 | self.res1 = nn.Sequential( 78 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(64, 1, 3, padding=1, dilation=1), 83 | nn.ReLU(inplace=True)) 84 | 85 | self.res0 = nn.Sequential( 86 | nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64), 87 | nn.ReLU(inplace=True), 88 | nn.Conv2d(64, 1, 3, padding=1, dilation=1), 89 | nn.ReLU(inplace=True)) 90 | 91 | self.noise2 = DropOutDecoder(1, 512, 512) 92 | self.noise1 = FeatureDropDecoder(1, 256, 256) 93 | self.noise0 = FeatureNoiseDecoder(1, 128, 128) 94 | 95 | self.upsam2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 96 | self.upsam4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 97 | 98 | self.conv1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False) 99 | self.conv2 = nn.Conv2d(512, 256, kernel_size=1, bias=False) 100 | self.conv3 = nn.Conv2d(256, 128, kernel_size=1, bias=False) 101 | self.conv4 = nn.Conv2d(128, 1, kernel_size=1, bias=False) 102 | 103 | #cls2.view(8, 1024, 1, 1)) 104 | 105 | self.init_param() 106 | 107 | def forward(self, x, cls): 108 | x0 = x[0]; x1 = x[1]; x2 = x[2]; x3 = x[3] 109 | cls0 = cls[0].view(cls[0].shape[0], cls[0].shape[1], 1, 1) 110 | cls1 = cls[1].view(cls[1].shape[0], cls[1].shape[1], 1, 1) 111 | cls2 = 
cls[2].view(cls[2].shape[0], cls[2].shape[1], 1, 1) 112 | 113 | x2_1 = self.ca2(x2)+self.v3(x3) 114 | x1_1 = self.ca1(x1)+self.v2(x2_1) 115 | x0_1 = self.ca0(x0)+self.v1(x1_1) 116 | 117 | if self.training: 118 | yc2 = self.conv4(self.conv3(self.conv2(self.noise2(self.conv1(cls2))))).squeeze() 119 | yc1 = self.conv4(self.conv3(self.noise1(self.conv2(cls1)))).squeeze() 120 | yc0 = self.conv4(self.noise0(self.conv3(cls0))).squeeze() 121 | 122 | y2 = self.res2(self.upsam4(self.noise2(x2_1))) 123 | y1 = self.res1(self.upsam2(self.noise1(x1_1))) 124 | y0 = self.res0(self.noise0(x0_1)) 125 | 126 | else: 127 | yc2 = self.conv4(self.conv3(self.conv2(self.conv1(cls2)))).squeeze() 128 | yc1 = self.conv4(self.conv3(self.conv2(cls1))).squeeze() 129 | yc0 = self.conv4(self.conv3(cls0)).squeeze() 130 | 131 | y2 = self.res2(self.upsam4(x2_1)) 132 | y1 = self.res1(self.upsam2(x1_1)) 133 | y0 = self.res0(x0_1) 134 | 135 | return [y0, y1, y2], [yc0, yc1, yc2] 136 | 137 | def init_param(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | nn.init.normal_(m.weight, std=0.01) 141 | if m.bias is not None: 142 | nn.init.constant_(m.bias, 0) 143 | elif isinstance(m, nn.BatchNorm2d): 144 | nn.init.constant_(m.weight, 1) 145 | nn.init.constant_(m.bias, 0) 146 | 147 | 148 | class Mlp(nn.Module): 149 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 150 | super().__init__() 151 | out_features = out_features or in_features 152 | hidden_features = hidden_features or in_features 153 | self.fc1 = nn.Linear(in_features, hidden_features) 154 | self.act = act_layer() 155 | self.fc2 = nn.Linear(hidden_features, out_features) 156 | self.drop = nn.Dropout(drop) 157 | 158 | def forward(self, x): 159 | x = self.fc1(x) 160 | x = self.act(x) 161 | x = self.drop(x) 162 | x = self.fc2(x) 163 | x = self.drop(x) 164 | return x 165 | 166 | 167 | 168 | def upsample(in_channels, out_channels, upscale, kernel_size=3): 169 | # A series 
of x 2 upsamling until we get to the upscale we want 170 | layers = [] 171 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 172 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu') 173 | layers.append(conv1x1) 174 | for i in range(int(math.log(upscale, 2))): 175 | layers.append(PixelShuffle(out_channels, scale=2)) 176 | return nn.Sequential(*layers) 177 | 178 | 179 | 180 | class FeatureDropDecoder(nn.Module): 181 | def __init__(self, upscale, conv_in_ch, num_classes): 182 | super(FeatureDropDecoder, self).__init__() 183 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) 184 | 185 | def feature_dropout(self, x): 186 | attention = torch.mean(x, dim=1, keepdim=True) 187 | max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True) 188 | threshold = max_val * np.random.uniform(0.7, 0.9) 189 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention) 190 | drop_mask = (attention < threshold).float() 191 | return x.mul(drop_mask) 192 | 193 | def forward(self, x): 194 | x = self.feature_dropout(x) 195 | return x 196 | 197 | 198 | class FeatureNoiseDecoder(nn.Module): 199 | def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3): 200 | super(FeatureNoiseDecoder, self).__init__() 201 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale) 202 | self.uni_dist = Uniform(-uniform_range, uniform_range) 203 | 204 | def feature_based_noise(self, x): 205 | noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0) 206 | x_noise = x.mul(noise_vector) + x 207 | return x_noise 208 | 209 | def forward(self, x): 210 | x = self.feature_based_noise(x) 211 | return x 212 | 213 | class DropOutDecoder(nn.Module): 214 | def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True): 215 | super(DropOutDecoder, self).__init__() 216 | self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate) 217 | self.upsample 
= upsample(conv_in_ch, num_classes, upscale=upscale) 218 | 219 | def forward(self, x): 220 | x = self.dropout(x) 221 | return x 222 | 223 | 224 | ## ChannelAttetion 225 | class ChannelAttention(nn.Module): 226 | def __init__(self, in_planes, ratio=16): 227 | super(ChannelAttention, self).__init__() 228 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 229 | 230 | self.fc = nn.Sequential( 231 | nn.Linear(in_planes,in_planes // ratio, bias = False), 232 | nn.ReLU(inplace = True), 233 | nn.Linear(in_planes // ratio, in_planes, bias = False) 234 | ) 235 | self.sigmoid = nn.Sigmoid() 236 | for m in self.modules(): 237 | if isinstance(m, nn.Conv2d): 238 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 239 | 240 | def forward(self, in_feature): 241 | x = in_feature 242 | b, c, _, _ = in_feature.size() 243 | avg_out = self.fc(self.avg_pool(x).view(b,c)).view(b, c, 1, 1) 244 | out = avg_out 245 | return self.sigmoid(out).expand_as(in_feature) * in_feature 246 | 247 | 248 | class Attention(nn.Module): 249 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 250 | super().__init__() 251 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
252 | 253 | self.dim = dim 254 | self.num_heads = num_heads 255 | head_dim = dim // num_heads 256 | self.scale = qk_scale or head_dim ** -0.5 257 | 258 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 259 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 260 | self.attn_drop = nn.Dropout(attn_drop) 261 | self.proj = nn.Linear(dim, dim) 262 | self.proj_drop = nn.Dropout(proj_drop) 263 | 264 | self.sr_ratio = sr_ratio 265 | if sr_ratio > 1: 266 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 267 | self.norm = nn.LayerNorm(dim) 268 | 269 | def forward(self, x, H, W): 270 | B, N, C = x.shape 271 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 272 | 273 | if self.sr_ratio > 4: 274 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 275 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 276 | x_ = self.norm(x_) 277 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 278 | else: 279 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 280 | k, v = kv[0], kv[1] 281 | 282 | attn = (q @ k.transpose(-2, -1)) * self.scale 283 | attn = attn.softmax(dim=-1) 284 | attn = self.attn_drop(attn) 285 | 286 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 287 | x = self.proj(x) 288 | x = self.proj_drop(x) 289 | 290 | return x 291 | 292 | 293 | class Block(nn.Module): 294 | 295 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 296 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 297 | super().__init__() 298 | self.norm1 = norm_layer(dim) 299 | self.attn = Attention( 300 | dim, 301 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 302 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 303 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 304 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 305 | self.norm2 = norm_layer(dim) 306 | mlp_hidden_dim = int(dim * mlp_ratio) 307 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 308 | 309 | def forward(self, x, H, W): 310 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 311 | x = x + self.drop_path(self.mlp(self.norm2(x))) 312 | 313 | return x 314 | 315 | 316 | class PatchEmbed(nn.Module): 317 | """ Image to Patch Embedding 318 | """ 319 | 320 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 321 | super().__init__() 322 | img_size = to_2tuple(img_size) 323 | patch_size = to_2tuple(patch_size) 324 | 325 | self.img_size = img_size 326 | self.patch_size = patch_size 327 | # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 328 | # f"img_size {img_size} should be divided by patch_size {patch_size}." 329 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 330 | self.num_patches = self.H * self.W 331 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 332 | self.norm = nn.LayerNorm(embed_dim) 333 | 334 | def forward(self, x): 335 | B, C, H, W = x.shape 336 | 337 | x = self.proj(x).flatten(2).transpose(1, 2) 338 | x = self.norm(x) 339 | H, W = H // self.patch_size[0], W // self.patch_size[1] 340 | 341 | return x, (H, W) 342 | 343 | 344 | class PyramidVisionTransformer(nn.Module): 345 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 346 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 347 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 348 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4): 349 | super().__init__() 350 | self.num_classes = num_classes 351 | self.depths = depths 352 | self.num_stages = num_stages 353 | 354 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 
sum(depths))] # stochastic depth decay rule 355 | cur = 0 356 | 357 | for i in range(num_stages): 358 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 359 | patch_size=patch_size if i == 0 else 2, 360 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 361 | embed_dim=embed_dims[i]) 362 | num_patches = patch_embed.num_patches if i == 0 else patch_embed.num_patches + 1 363 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i])) 364 | pos_drop = nn.Dropout(p=drop_rate) 365 | 366 | block = nn.ModuleList([Block( 367 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, 368 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], 369 | norm_layer=norm_layer, sr_ratio=sr_ratios[i]) 370 | for j in range(depths[i])]) 371 | cur += depths[i] 372 | 373 | setattr(self, f"patch_embed{i + 1}", patch_embed) 374 | setattr(self, f"pos_embed{i + 1}", pos_embed) 375 | setattr(self, f"pos_drop{i + 1}", pos_drop) 376 | setattr(self, f"block{i + 1}", block) 377 | 378 | self.norm = norm_layer(embed_dims[3]) 379 | 380 | # cls_token 381 | self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dims[1])) 382 | self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dims[2])) 383 | self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) 384 | 385 | # classification head 386 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 387 | 388 | 389 | self.regression = Regression() 390 | 391 | # init weights 392 | for i in range(num_stages): 393 | pos_embed = getattr(self, f"pos_embed{i + 1}") 394 | trunc_normal_(pos_embed, std=.02) 395 | trunc_normal_(self.cls_token_1, std=.02) 396 | trunc_normal_(self.cls_token_2, std=.02) 397 | trunc_normal_(self.cls_token_3, std=.02) 398 | self.apply(self._init_weights) 399 | 400 | 401 | def _init_weights(self, m): 402 | if isinstance(m, nn.Linear): 403 | trunc_normal_(m.weight, std=.02) 404 | if 
isinstance(m, nn.Linear) and m.bias is not None: 405 | nn.init.constant_(m.bias, 0) 406 | elif isinstance(m, nn.LayerNorm): 407 | nn.init.constant_(m.bias, 0) 408 | nn.init.constant_(m.weight, 1.0) 409 | 410 | @torch.jit.ignore 411 | def no_weight_decay(self): 412 | # return {'pos_embed', 'cls_token'} # has pos_embed may be better 413 | return {'cls_token'} 414 | 415 | def get_classifier(self): 416 | return self.head 417 | 418 | def reset_classifier(self, num_classes, global_pool=''): 419 | self.num_classes = num_classes 420 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 421 | 422 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 423 | if H * W == self.patch_embed1.num_patches: 424 | return pos_embed 425 | else: 426 | return F.interpolate( 427 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 428 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) 429 | 430 | def forward_features(self, x): 431 | B = x.shape[0] 432 | outputs = list() 433 | cls_output = list() 434 | 435 | for i in range(self.num_stages): 436 | patch_embed = getattr(self, f"patch_embed{i + 1}") 437 | pos_embed = getattr(self, f"pos_embed{i + 1}") 438 | pos_drop = getattr(self, f"pos_drop{i + 1}") 439 | block = getattr(self, f"block{i + 1}") 440 | x, (H, W) = patch_embed(x) 441 | 442 | if i == 0: 443 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 444 | elif i == 1: 445 | cls_tokens = self.cls_token_1.expand(B, -1, -1) 446 | x = torch.cat((cls_tokens, x), dim=1) 447 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 448 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) 449 | elif i == 2: 450 | cls_tokens = self.cls_token_2.expand(B, -1, -1) 451 | x = torch.cat((cls_tokens, x), dim=1) 452 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 453 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) 454 | 455 | elif i == 3: 456 | 
cls_tokens = self.cls_token_3.expand(B, -1, -1) 457 | x = torch.cat((cls_tokens, x), dim=1) 458 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 459 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) 460 | 461 | 462 | x = pos_drop(x + pos_embed) 463 | for blk in block: 464 | x = blk(x, H, W) 465 | 466 | if i == 0: 467 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 468 | else: 469 | 470 | x_cls = x[:,1,:] 471 | x = x[:,1:,:].reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 472 | cls_output.append(x_cls) 473 | 474 | 475 | outputs.append(x) 476 | return outputs, cls_output 477 | 478 | 479 | def forward(self, label_x, unlabel_x=None): 480 | 481 | if self.training: 482 | # labeled image processing 483 | label_x, l_cls = self.forward_features(label_x) 484 | out_label_x, out_cls_l = self.regression(label_x, l_cls) 485 | label_x_1, label_x_2, label_x_3 = out_label_x 486 | 487 | B,C,H,W = label_x_1.size() 488 | label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3) 489 | label_normed = label_x_1 / (label_sum + 1e-6) 490 | 491 | # unlabeled image processing 492 | B,C,H,W = unlabel_x.shape 493 | unlabel_x, ul_cls = self.forward_features(unlabel_x) 494 | out_unlabel_x, out_cls_ul = self.regression(unlabel_x, ul_cls) 495 | y0, y1, y2 = out_unlabel_x 496 | 497 | unlabel_x_1 = self.generate_feature_patches(y0) 498 | unlabel_x_2 = self.generate_feature_patches(y1) 499 | unlabel_x_3 = self.generate_feature_patches(y2) 500 | 501 | assert unlabel_x_1.shape[0] == B * 5 502 | assert unlabel_x_2.shape[0] == B * 5 503 | assert unlabel_x_3.shape[0] == B * 5 504 | 505 | unlabel_x_1 = torch.split(unlabel_x_1, split_size_or_sections=B, dim=0) 506 | unlabel_x_2 = torch.split(unlabel_x_2, split_size_or_sections=B, dim=0) 507 | unlabel_x_3 = torch.split(unlabel_x_3, split_size_or_sections=B, dim=0) 508 | 509 | return [label_x_1, label_x_2, label_x_3], [unlabel_x_1, unlabel_x_2, unlabel_x_3], label_normed, 
out_cls_l, out_cls_ul 510 | 511 | 512 | else: 513 | 514 | label_x, l_cls = self.forward_features(label_x) 515 | out_label_x, out_cls_l = self.regression(label_x, l_cls) 516 | label_x_1, label_x_2, label_x_3 = out_label_x 517 | B,C,H,W = label_x_1.size() 518 | label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3) 519 | label_normed = label_x_1 / (label_sum + 1e-6) 520 | 521 | return [label_x_1, label_x_2, label_x_3], label_normed 522 | 523 | 524 | def generate_feature_patches(self, unlabel_x, ratio=0.75): 525 | # unlabeled image processing 526 | 527 | unlabel_x_1 = unlabel_x 528 | b, c, h, w = unlabel_x.shape 529 | 530 | center_x = random.randint(h // 2 - (h - h * ratio) // 2, h // 2 + (h - h * ratio) // 2) 531 | center_y = random.randint(w // 2 - (w - w * ratio) // 2, w // 2 + (w - w * ratio) // 2) 532 | 533 | new_h2 = int(h * ratio) 534 | new_w2 = int(w * ratio) # 48*48 535 | unlabel_x_2 = unlabel_x[:, :, center_x - new_h2 // 2:center_x + new_h2 // 2, 536 | center_y - new_w2 // 2:center_y + new_w2 // 2] 537 | 538 | new_h3 = int(new_h2 * ratio) 539 | new_w3 = int(new_w2 * ratio) 540 | unlabel_x_3 = unlabel_x[:, :, center_x - new_h3 // 2:center_x + new_h3 // 2, 541 | center_y - new_w3 // 2:center_y + new_w3 // 2] 542 | 543 | new_h4 = int(new_h3 * ratio) 544 | new_w4 = int(new_w3 * ratio) 545 | unlabel_x_4 = unlabel_x[:, :, center_x - new_h4 // 2:center_x + new_h4 // 2, 546 | center_y - new_w4 // 2:center_y + new_w4 // 2] 547 | 548 | new_h5 = int(new_h4 * ratio) 549 | new_w5 = int(new_w4 * ratio) 550 | unlabel_x_5 = unlabel_x[:, :, center_x - new_h5 // 2:center_x + new_h5 // 2, 551 | center_y - new_w5 // 2:center_y + new_w5 // 2] 552 | 553 | unlabel_x_2 = nn.functional.interpolate(unlabel_x_2, size=(h, w), mode='bilinear') 554 | unlabel_x_3 = nn.functional.interpolate(unlabel_x_3, size=(h, w), mode='bilinear') 555 | unlabel_x_4 = nn.functional.interpolate(unlabel_x_4, size=(h, w), mode='bilinear') 556 | unlabel_x_5 = 
nn.functional.interpolate(unlabel_x_5, size=(h, w), mode='bilinear') 557 | 558 | unlabel_x = torch.cat([unlabel_x_1, unlabel_x_2, unlabel_x_3, unlabel_x_4, unlabel_x_5], dim=0) 559 | 560 | return unlabel_x 561 | 562 | def _conv_filter(state_dict, patch_size=16): 563 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 564 | out_dict = {} 565 | for k, v in state_dict.items(): 566 | if 'patch_embed.proj.weight' in k: 567 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 568 | out_dict[k] = v 569 | 570 | return out_dict 571 | 572 | 573 | @register_model 574 | def pvt_tiny(pretrained=False, **kwargs): 575 | model = PyramidVisionTransformer( 576 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 577 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 578 | **kwargs) 579 | model.default_cfg = _cfg() 580 | 581 | return model 582 | 583 | 584 | @register_model 585 | def pvt_small(pretrained=False, **kwargs): 586 | model = PyramidVisionTransformer( 587 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 588 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) 589 | model.default_cfg = _cfg() 590 | 591 | return model 592 | 593 | 594 | @register_model 595 | def pvt_medium(pretrained=False, **kwargs): 596 | model = PyramidVisionTransformer( 597 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 598 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 599 | **kwargs) 600 | model.default_cfg = _cfg() 601 | 602 | return model 603 | 604 | 605 | @register_model 606 | def pvt_large(pretrained=False, **kwargs): 607 | model = PyramidVisionTransformer( 608 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 
4], qkv_bias=True, 609 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], 610 | **kwargs) 611 | model.default_cfg = _cfg() 612 | 613 | return model 614 | 615 | @register_model 616 | def pvt_treeformer(pretrained=False, **kwargs): 617 | model = PyramidVisionTransformer( 618 | patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 619 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], 620 | **kwargs) 621 | model.default_cfg = _cfg() 622 | 623 | return model --------------------------------------------------------------------------------