├── utils
│   ├── __init__.py
│   ├── log_utils.py
│   └── pytorch_utils.py
├── datasets
│   ├── __init__.py
│   └── crowd.py
├── losses
│   ├── __init__.py
│   ├── ramps.py
│   ├── multi_con_loss.py
│   ├── rank_loss.py
│   ├── dm_loss.py
│   ├── ot_loss.py
│   ├── consistency_loss.py
│   └── bregman_pytorch.py
├── sample_imgs
│   └── overview.png
├── requirements.txt
├── README.md
├── test.py
├── train.py
└── network
    └── pvt_cls.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/sample_imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HAAClassic/TreeFormer/HEAD/sample_imgs/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.5
2 | Pillow==9.4.0
3 | scikit_learn==1.2.2
4 | scipy==1.7.3
5 | timm==0.4.12
6 | torch==1.12.1
7 | torchvision==0.13.1
8 |
--------------------------------------------------------------------------------
/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def get_logger(log_file):
5 | logger = logging.getLogger(log_file)
6 | logger.setLevel(logging.DEBUG)
7 | fh = logging.FileHandler(log_file)
8 | fh.setLevel(logging.DEBUG)
9 | ch = logging.StreamHandler()
10 | ch.setLevel(logging.INFO)
11 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
12 | ch.setFormatter(formatter)
13 | fh.setFormatter(formatter)
14 | logger.addHandler(ch)
15 | logger.addHandler(fh)
16 | return logger
17 |
18 |
19 | def print_config(config, logger):
20 | """
21 | Print configuration of the model
22 | """
23 | for k, v in config.items():
24 | logger.info("{}:\t{}".format(k.ljust(15), v))
25 |
--------------------------------------------------------------------------------
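A minimal usage sketch for these helpers (the log filename and config dict are illustrative, not from the repo):

```python
from utils import log_utils

logger = log_utils.get_logger('train.log')  # DEBUG to the file, INFO to the console
log_utils.print_config({'lr': 1e-5, 'crop_size': 256}, logger)
logger.info('training started')
```
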
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # TreeFormer
3 |
4 | This is the code base for the IEEE Transactions on Geoscience and Remote Sensing (TGRS 2023) paper ['TreeFormer: a Semi-Supervised Transformer-based Framework for Tree Counting from a Single High Resolution Image'](https://arxiv.org/abs/2307.06118).
5 |
6 |
7 |
8 | ## Installation
9 |
10 | Python ≥ 3.7.
11 |
12 | To install the required packages, please run:
13 |
14 |
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## Dataset
20 | Download the dataset from [google drive](https://drive.google.com/file/d/1xcjv8967VvvzcDM4aqAi7Corkb11T0i2/view?usp=drive_link).
21 | ## Evaluation
22 | Download our model trained on the [London](https://drive.google.com/file/d/14uuOF5758sxtM5EgeGcRtSln5lUXAHge/view?usp=sharing) dataset.
23 | 
24 | Modify the dataset and model paths in 'test.py' for evaluation.
25 | 
26 | Run 'test.py'.
27 | ## Acknowledgements
28 |
29 | - Parts of the code are borrowed from [PVT](https://github.com/whai362/PVT) and [DM Count](https://github.com/cvlab-stonybrook/DM-Count). Thanks for their great work!
30 |
31 |
32 |
--------------------------------------------------------------------------------
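For reference, a typical evaluation invocation with the flags defined in 'test.py' (the paths are placeholders for the downloaded dataset and checkpoint):

```bash
python test.py \
    --data-path /path/to/London_dataset/ \
    --model-path /path/to/best_model.pth \
    --device 0
```
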
/losses/ramps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Functions for ramping hyperparameters up or down
9 |
10 | Each function takes the current training step or epoch, and the
11 | ramp length in the same format, and returns a multiplier between
12 | 0 and 1.
13 | """
14 |
15 |
16 | import numpy as np
17 |
18 |
19 | def sigmoid_rampup(current, rampup_length):
20 | """Exponential rampup from https://arxiv.org/abs/1610.02242"""
21 | if rampup_length == 0:
22 | return 1.0
23 | else:
24 | current = np.clip(current, 0.0, rampup_length)
25 | phase = 1.0 - current / rampup_length
26 | return float(np.exp(-5.0 * phase * phase))
27 |
28 |
29 | def linear_rampup(current, rampup_length):
30 | """Linear rampup"""
31 | assert current >= 0 and rampup_length >= 0
32 | if current >= rampup_length:
33 | return 1.0
34 | else:
35 | return current / rampup_length
36 |
37 |
38 | def cosine_rampdown(current, rampdown_length):
39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983"""
40 | assert 0 <= current <= rampdown_length
41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
42 |
--------------------------------------------------------------------------------
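A quick sanity check of the schedules above (a sketch; printed values are approximate):

```python
from losses import ramps

print(ramps.sigmoid_rampup(0, 100))    # exp(-5)    ~ 0.0067 at the start
print(ramps.sigmoid_rampup(50, 100))   # exp(-1.25) ~ 0.29 halfway
print(ramps.sigmoid_rampup(100, 100))  # 1.0 once the ramp ends
print(ramps.linear_rampup(25, 100))    # 0.25
print(ramps.cosine_rampdown(50, 100))  # 0.5
```
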
/losses/multi_con_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from losses.consistency_loss import *
5 |
6 |
7 | class MultiConLoss(nn.Module):
8 | def __init__(self):
9 | super(MultiConLoss, self).__init__()
10 | self.countloss_criterion = nn.MSELoss(reduction='sum')
11 | self.multiconloss = 0.0
12 | self.losses = {}
13 |
14 | def forward(self, unlabeled_results):
15 | self.multiconloss = 0.0
16 | self.losses = {}
17 |
18 | if unlabeled_results is None:
19 | self.multiconloss = 0.0
20 | elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
21 | count = 0
22 | for i in range(len(unlabeled_results[0])):
23 | with torch.set_grad_enabled(False):
24 |                     preds_mean = sum(r[i] for r in unlabeled_results) / len(unlabeled_results)  # mean over the decoder branches
25 | for j in range(len(unlabeled_results)):
26 |
27 | var_sel = softmax_kl_loss(unlabeled_results[j][i], preds_mean)
28 | exp_var = torch.exp(-var_sel)
29 | consistency_dist = (preds_mean - unlabeled_results[j][i]) ** 2
30 | temploss = (torch.mean(consistency_dist * exp_var) /(exp_var + 1e-8) + var_sel)
31 |
32 | self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
33 | self.multiconloss += temploss
34 |
35 | count += 1
36 | if count > 0:
37 | self.multiconloss = self.multiconloss / count
38 |
39 |
40 | return self.multiconloss
41 |
42 |
--------------------------------------------------------------------------------
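A shape-level sketch of how this loss is fed: a list of decoder branches, each holding the same number of multi-scale density maps (all shapes below are illustrative):

```python
import torch
from losses.multi_con_loss import MultiConLoss

def branch():
    # two scales of predicted density maps for a batch of 2
    return [torch.rand(2, 1, 64, 64, requires_grad=True),
            torch.rand(2, 1, 32, 32, requires_grad=True)]

outputs_ul = [branch(), branch(), branch()]  # three branches, as in train.py
loss = MultiConLoss()(outputs_ul)
loss.backward()
```
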
/utils/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10):
4 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
5 | lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6)
6 | for param_group in optimizer.param_groups:
7 | param_group['lr'] = lr
8 |
9 |
10 | class Save_Handle(object):
11 | """handle the number of """
12 | def __init__(self, max_num):
13 | self.save_list = []
14 | self.max_num = max_num
15 |
16 | def append(self, save_path):
17 | if len(self.save_list) < self.max_num:
18 | self.save_list.append(save_path)
19 | else:
20 | remove_path = self.save_list[0]
21 | del self.save_list[0]
22 | self.save_list.append(save_path)
23 | if os.path.exists(remove_path):
24 | os.remove(remove_path)
25 |
26 |
27 | class AverageMeter(object):
28 | """Computes and stores the average and current value"""
29 | def __init__(self):
30 | self.reset()
31 |
32 | def reset(self):
33 | self.val = 0
34 | self.avg = 0
35 | self.sum = 0
36 | self.count = 0
37 |
38 | def update(self, val, n=1):
39 | self.val = val
40 | self.sum += val * n
41 | self.count += n
42 | self.avg = 1.0 * self.sum / self.count
43 |
44 | def get_avg(self):
45 | return self.avg
46 |
47 | def get_count(self):
48 | return self.count
49 |
50 |
51 | def set_trainable(model, requires_grad):
52 | for param in model.parameters():
53 | param.requires_grad = requires_grad
54 |
55 |
56 |
57 | def get_num_params(model):
58 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
--------------------------------------------------------------------------------
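Small usage examples for the meter and the LR schedule (a sketch, not repo code):

```python
import torch
from utils.pytorch_utils import AverageMeter, adjust_learning_rate

meter = AverageMeter()
meter.update(2.0, n=4)  # e.g. batch loss 2.0 averaged over 4 samples
meter.update(1.0, n=4)
print(meter.get_avg())  # 1.5

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.001)
adjust_learning_rate(opt, epoch=20, initial_lr=0.001, decay_epoch=10)
print(opt.param_groups[0]['lr'])  # 0.001 * 0.1**2 = 1e-05
```
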
/losses/rank_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MarginRankLoss(nn.Module):
7 | def __init__(self):
8 | super(MarginRankLoss, self).__init__()
9 | self.loss = 0.0
10 |
11 | def forward(self, img_list, margin=0):
12 | length = len(img_list)
13 | self.loss = 0.0
14 | B, C, H, W = img_list[0].shape
15 | for i in range(length - 1):
16 | for j in range(i + 1, length):
17 | self.loss = self.loss + torch.sum(F.relu(img_list[j].sum(-1).sum(-1).sum(-1) - img_list[i].sum(-1).sum(-1).sum(-1) + margin))
18 |
19 | self.loss = self.loss / (B*length*(length-1)/2)
20 | return self.loss
21 |
22 |
23 | class RankLoss(nn.Module):
24 | def __init__(self):
25 | super(RankLoss, self).__init__()
26 | self.countloss_criterion = nn.MSELoss(reduction='sum')
27 | self.rankloss_criterion = MarginRankLoss()
28 | self.rankloss = 0.0
29 | self.losses = {}
30 |
31 | def forward(self, unlabeled_results):
32 | self.rankloss = 0.0
33 | self.losses = {}
34 |
35 |
36 | if unlabeled_results is None:
37 | self.rankloss = 0.0
38 | elif isinstance(unlabeled_results, tuple) and len(unlabeled_results) > 0:
39 | self.rankloss = self.rankloss_criterion(unlabeled_results)
40 | elif isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
41 | count = 0
42 | for i in range(len(unlabeled_results)):
43 | if isinstance(unlabeled_results[i], tuple) and len(unlabeled_results[i]) > 0:
44 | temploss = self.rankloss_criterion(unlabeled_results[i])
45 | self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
46 | self.rankloss += temploss
47 |
48 | count += 1
49 | if count > 0:
50 | self.rankloss = self.rankloss / count
51 |
52 | return self.rankloss
53 |
54 |
--------------------------------------------------------------------------------
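MarginRankLoss penalizes any later entry of img_list whose predicted count exceeds an earlier one, which matches nested crops ordered from the largest (containing) crop to the smallest. A toy check with illustrative tensors:

```python
import torch
from losses.rank_loss import MarginRankLoss

big = torch.full((1, 1, 4, 4), 2.0)    # density summing to 32 (larger crop)
small = torch.full((1, 1, 4, 4), 1.0)  # density summing to 16 (nested crop)
print(MarginRankLoss()([big, small]))  # tensor(0.):  ordering respected
print(MarginRankLoss()([small, big]))  # tensor(16.): violation penalized
```
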
/losses/dm_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from losses.consistency_loss import *
5 | from losses.ot_loss import OT_Loss
6 |
7 | class DMLoss(nn.Module):
8 |     """Uncertainty-weighted consistency loss over the decoder branches.
9 | 
10 |     The counting, OT and TV terms that were previously pasted into this
11 |     forward pass reference training-loop state and are computed in
12 |     train.py; only the distribution-matching consistency term lives here.
13 |     """
14 |     def __init__(self):
15 |         super(DMLoss, self).__init__()
16 |         self.dmloss = 0.0
17 |         self.losses = {}
18 | 
19 |     def forward(self, results, points=None, gt_discrete=None):
20 |         # points and gt_discrete are kept for API compatibility; they are
21 |         # consumed by the counting/OT/TV losses in train.py, not here.
22 |         self.dmloss = 0.0
23 |         self.losses = {}
24 | 
25 |         if results is None:
26 |             self.dmloss = 0.0
27 |         elif isinstance(results, list) and len(results) > 0:
28 |             count = 0
29 |             for i in range(len(results[0])):
30 |                 # Branch-mean prediction, kept out of the autograd graph.
31 |                 with torch.set_grad_enabled(False):
32 |                     preds_mean = sum(r[i] for r in results) / len(results)
33 | 
34 |                 for j in range(len(results)):
35 |                     # KL to the mean acts as a per-branch uncertainty estimate.
36 |                     var_sel = softmax_kl_loss(results[j][i], preds_mean)
37 |                     exp_var = torch.exp(-var_sel)
38 |                     consistency_dist = (preds_mean - results[j][i]) ** 2
39 |                     temploss = (torch.mean(consistency_dist * exp_var) / (exp_var + 1e-8) + var_sel)
40 | 
41 |                     self.losses.update({'unlabel_{}_loss'.format(str(i + 1)): temploss})
42 |                     self.dmloss += temploss
43 | 
44 |                 count += 1
45 |             if count > 0:
46 |                 self.dmloss = self.dmloss / count
47 | 
48 |         return self.dmloss
49 | 
--------------------------------------------------------------------------------
/losses/ot_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Module
3 | from .bregman_pytorch import sinkhorn
4 |
5 | class OT_Loss(Module):
6 | def __init__(self, c_size, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0):
7 | super(OT_Loss, self).__init__()
8 | assert c_size % stride == 0
9 |
10 | self.c_size = c_size
11 | self.device = device
12 | self.norm_cood = norm_cood
13 | self.num_of_iter_in_ot = num_of_iter_in_ot
14 | self.reg = reg
15 |
16 |         # coordinates live in image space; kept constant since the crop size is fixed
17 | self.cood = torch.arange(0, c_size, step=stride,
18 | dtype=torch.float32, device=device) + stride / 2
19 | self.density_size = self.cood.size(0)
20 | self.cood.unsqueeze_(0) # [1, #cood]
21 | if self.norm_cood:
22 | self.cood = self.cood / c_size * 2 - 1 # map to [-1, 1]
23 | self.output_size = self.cood.size(1)
24 |
25 |
26 | def forward(self, normed_density, unnormed_density, points):
27 | batch_size = normed_density.size(0)
28 | assert len(points) == batch_size
29 | assert self.output_size == normed_density.size(2)
30 | loss = torch.zeros([1]).to(self.device)
31 | ot_obj_values = torch.zeros([1]).to(self.device)
32 |         wd = 0  # wasserstein distance
33 | for idx, im_points in enumerate(points):
34 | if len(im_points) > 0:
35 |                 # squared L2 distance between GT points and source coordinates; [#gt, #cood * #cood]
36 | if self.norm_cood:
37 | im_points = im_points / self.c_size * 2 - 1 # map to [-1, 1]
38 | x = im_points[:, 0].unsqueeze_(1) # [#gt, 1]
39 | y = im_points[:, 1].unsqueeze_(1)
40 | x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood # [#gt, #cood]
41 | y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood
42 | y_dis.unsqueeze_(2)
43 | x_dis.unsqueeze_(1)
44 | dis = y_dis + x_dis
45 | dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood]
46 |
47 | source_prob = normed_density[idx][0].view([-1]).detach()
48 | target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device)
49 |
50 | # use sinkhorn to solve OT, compute optimal beta.
51 | P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True)
52 | beta = log['beta'] # size is the same as source_prob: [#cood * #cood]
53 | ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size]))
54 | # compute the gradient of OT loss to predicted density (unnormed_density).
55 | # im_grad = beta / source_count - < beta, source_density> / (source_count)^2
56 | source_density = unnormed_density[idx][0].view([-1]).detach()
57 | source_count = source_density.sum()
58 | im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta # size of [#cood * #cood]
59 | im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) # size of 1
60 | im_grad = im_grad_1 - im_grad_2
61 | im_grad = im_grad.detach().view([1, self.output_size, self.output_size])
62 |                 # Define loss = <unnormed_density, im_grad>; the gradient of the loss w.r.t. the predicted density is then im_grad.
63 | loss += torch.sum(unnormed_density[idx] * im_grad)
64 | wd += torch.sum(dis * P).item()
65 |
66 | return loss, wd, ot_obj_values
67 |
68 |
69 |
--------------------------------------------------------------------------------
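A condensed sketch of how train.py instantiates and calls this loss (CPU here for illustration; shapes follow a 256 crop with downsample ratio 4, and the points are made up):

```python
import torch
from losses.ot_loss import OT_Loss

device = torch.device('cpu')
ot = OT_Loss(c_size=256, stride=4, norm_cood=0, device=device,
             num_of_iter_in_ot=100, reg=10.0)

density = torch.rand(1, 1, 64, 64)             # predicted density map (256 / 4)
normed = density / (density.sum() + 1e-6)      # normalized to a distribution
points = [torch.tensor([[30.0, 40.0], [200.0, 100.0]])]  # GT (x, y) positions
loss, wd, ot_obj = ot(normed, density, points)
```
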
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import numpy as np
5 | import datasets.crowd as crowd  # the dataset module is datasets/crowd.py
6 | from network import pvt_cls as TCN
7 | import torch.nn.functional as F
8 | from scipy.io import savemat
9 | from sklearn.metrics import r2_score
10 |
11 | parser = argparse.ArgumentParser(description='Test ')
12 | parser.add_argument('--device', default='0', help='assign device')
13 | parser.add_argument('--batch-size', type=int, default=8, help='train batch size')
14 | parser.add_argument('--crop-size', type=int, default=256, help='the crop size of the train image')
15 | parser.add_argument('--model-path', type=str, default='/scratch/users/k2254235/ckpts/SEMI/Treeformer/best_model_mae-21.49_epoch-1759.pth', help='saved model path')
16 | parser.add_argument('--data-path', type=str, default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='dataset path')
17 | parser.add_argument('--dataset', type=str, default='TC')
18 |
19 | def test(args, isSave = True):
20 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu
21 | device = torch.device('cuda')
22 |
23 | model_path = args.model_path
24 | crop_size = args.crop_size
25 | data_path = args.data_path
26 |
27 | dataset = crowd.Crowd_TC(os.path.join(data_path, 'test_data'), crop_size, 1, method='val')
28 | dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, num_workers=1, pin_memory=True)
29 |
30 | model = TCN.pvt_treeformer(pretrained=False)
31 | model.to(device)
32 | model.load_state_dict(torch.load(model_path, device))
33 |     model.eval()
34 |     os.makedirs('predictions', exist_ok=True)  # savemat below writes into this folder
35 |     image_errs = []
36 |     result = []
37 |     R2_es = []
38 |     R2_gt = []
39 | for inputs, count, name, imgauss in dataloader:
40 | with torch.no_grad():
41 | inputs = inputs.to(device)
42 | crop_imgs, crop_masks = [], []
43 | b, c, h, w = inputs.size()
44 | rh, rw = args.crop_size, args.crop_size
45 |
46 | for i in range(0, h, rh):
47 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
48 |
49 | for j in range(0, w, rw):
50 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
51 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
52 | mask = torch.zeros([b, 1, h, w]).to(device)
53 | mask[:, :, gis:gie, gjs:gje].fill_(1.0)
54 | crop_masks.append(mask)
55 | crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
56 |
57 | crop_preds = []
58 | nz, bz = crop_imgs.size(0), args.batch_size
59 | for i in range(0, nz, bz):
60 |
61 | gs, gt = i, min(nz, i + bz)
62 | crop_pred, _ = model(crop_imgs[gs:gt])
63 | crop_pred = crop_pred[0]
64 |
65 | _, _, h1, w1 = crop_pred.size()
66 | crop_pred = F.interpolate(crop_pred, size=(h1 * 4, w1 * 4), mode='bilinear', align_corners=True) / 16
67 | crop_preds.append(crop_pred)
68 | crop_preds = torch.cat(crop_preds, dim=0)
70 |
71 | # splice them to the original size
72 | idx = 0
73 | pred_map = torch.zeros([b, 1, h, w]).to(device)
74 | for i in range(0, h, rh):
75 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
76 | for j in range(0, w, rw):
77 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
78 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
79 | idx += 1
80 | # for the overlapping area, compute average value
81 |             mask = crop_masks.sum(dim=0).unsqueeze(0)
82 |             outputs = pred_map / mask
86 |
87 | img_err = count[0].item() - torch.sum(outputs).item()
88 | R2_gt.append(count[0].item())
89 | R2_es.append(torch.sum(outputs).item())
90 |
91 | print("Img name: ", name, "Error: ", img_err, "GT count: ", count[0].item(), "Model out: ", torch.sum(outputs).item())
92 | image_errs.append(img_err)
93 | result.append([name, count[0].item(), torch.sum(outputs).item(), img_err])
94 |
95 |             savemat('predictions/' + name[0] + '.mat', {'estimation': np.squeeze(outputs.cpu().data.numpy()),
96 |                 'image': np.squeeze(inputs.cpu().data.numpy()), 'gt': np.squeeze(imgauss.cpu().data.numpy())})
98 |
99 |
100 | image_errs = np.array(image_errs)
101 |
102 | mse = np.sqrt(np.mean(np.square(image_errs)))
103 | mae = np.mean(np.abs(image_errs))
104 | R_2 = r2_score(R2_gt,R2_es)
105 |
106 | print('{}: mae {}, mse {}, R2 {}\n'.format(model_path, mae, mse,R_2))
107 |
108 |     if isSave:
109 |         with open("test.txt", "w") as f:
110 |             for i in range(len(result)):
111 |                 f.write(str(result[i]).replace('[', '').replace(']', '').replace(',', ' ') + "\n")
113 |
114 | if __name__ == '__main__':
115 | args = parser.parse_args()
116 | test(args, isSave= True)
117 |
118 |
--------------------------------------------------------------------------------
/datasets/crowd.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch.utils.data as data
3 | import os
4 | from glob import glob
5 | import torch
6 | import torchvision.transforms.functional as F
7 | from torchvision import transforms
8 | import random
9 | import numpy as np
10 | import scipy.io as sio
11 |
12 | def random_crop(im_h, im_w, crop_h, crop_w):
13 | res_h = im_h - crop_h
14 | res_w = im_w - crop_w
15 | i = random.randint(0, res_h)
16 | j = random.randint(0, res_w)
17 | return i, j, crop_h, crop_w
18 |
19 |
20 | def gen_discrete_map(im_height, im_width, points):
21 | """
22 | func: generate the discrete map.
23 | points: [num_gt, 2], for each row: [width, height]
24 | """
25 | discrete_map = np.zeros([im_height, im_width], dtype=np.float32)
26 | h, w = discrete_map.shape[:2]
27 | num_gt = points.shape[0]
28 | if num_gt == 0:
29 | return discrete_map
30 |
31 | # fast create discrete map
32 | points_np = np.array(points).round().astype(int)
33 | p_h = np.minimum(points_np[:, 1], np.array([h-1]*num_gt).astype(int))
34 | p_w = np.minimum(points_np[:, 0], np.array([w-1]*num_gt).astype(int))
35 | p_index = torch.from_numpy(p_h* im_width + p_w).to(torch.int64)
36 | discrete_map = torch.zeros(im_width * im_height).scatter_add_(0, index=p_index, src=torch.ones(im_width*im_height)).view(im_height, im_width).numpy()
37 |
38 | ''' slow method
39 | for p in points:
40 | p = np.round(p).astype(int)
41 | p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0])
42 | discrete_map[p[0], p[1]] += 1
43 | '''
44 | assert np.sum(discrete_map) == num_gt
45 | return discrete_map
46 |
47 |
48 | class Base(data.Dataset):
49 | def __init__(self, root_path, crop_size, downsample_ratio=8):
50 |
51 | self.root_path = root_path
52 | self.c_size = crop_size
53 | self.d_ratio = downsample_ratio
54 | assert self.c_size % self.d_ratio == 0
55 | self.dc_size = self.c_size // self.d_ratio
56 | self.trans = transforms.Compose([
57 | transforms.ToTensor(),
58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59 | ])
60 |
61 | def __len__(self):
62 | pass
63 |
64 | def __getitem__(self, item):
65 | pass
66 |
67 | def train_transform(self, img, keypoints, gauss_im):
68 | wd, ht = img.size
69 | st_size = 1.0 * min(wd, ht)
70 | assert st_size >= self.c_size
71 | assert len(keypoints) >= 0
72 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
73 |         img = F.crop(img, i, j, h, w)
74 |         gauss_im = F.crop(gauss_im, i, j, h, w)
75 | if len(keypoints) > 0:
76 | keypoints = keypoints - [j, i]
77 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
78 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
79 | keypoints = keypoints[idx_mask]
80 | else:
81 | keypoints = np.empty([0, 2])
82 |
83 | gt_discrete = gen_discrete_map(h, w, keypoints)
84 | down_w = w // self.d_ratio
85 | down_h = h // self.d_ratio
86 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
87 | assert np.sum(gt_discrete) == len(keypoints)
88 |
89 | if len(keypoints) > 0:
90 | if random.random() > 0.5:
91 | img = F.hflip(img)
92 | gauss_im = F.hflip(gauss_im)
93 | gt_discrete = np.fliplr(gt_discrete)
94 |                 keypoints[:, 0] = w - keypoints[:, 0] - 1  # match np.fliplr indexing
95 | else:
96 | if random.random() > 0.5:
97 | img = F.hflip(img)
98 | gauss_im = F.hflip(gauss_im)
99 | gt_discrete = np.fliplr(gt_discrete)
100 | gt_discrete = np.expand_dims(gt_discrete, 0)
101 |
102 | return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
103 |
104 |
105 |
106 | class Crowd_TC(Base):
107 | def __init__(self, root_path, crop_size, downsample_ratio=8, method='train'):
108 | super().__init__(root_path, crop_size, downsample_ratio)
109 | self.method = method
110 | if method not in ['train', 'val']:
111 |             raise Exception("not implemented")
112 |
113 | self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
114 |
115 | print('number of img [{}]: {}'.format(method, len(self.im_list)))
116 |
117 | def __len__(self):
118 | return len(self.im_list)
119 |
120 | def __getitem__(self, item):
121 | img_path = self.im_list[item]
122 | name = os.path.basename(img_path).split('.')[0]
123 | gd_path = os.path.join(self.root_path, 'ground_truth', 'GT_{}.mat'.format(name))
124 | img = Image.open(img_path).convert('RGB')
125 | keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0]
126 | gauss_path = os.path.join(self.root_path, 'ground_truth', '{}_densitymap.npy'.format(name))
127 | gauss_im = torch.from_numpy(np.load(gauss_path)).float()
130 |
131 | if self.method == 'train':
132 | return self.train_transform(img, keypoints, gauss_im)
133 | elif self.method == 'val':
134 | wd, ht = img.size
135 | st_size = 1.0 * min(wd, ht)
136 | if st_size < self.c_size:
137 | rr = 1.0 * self.c_size / st_size
138 | wd = round(wd * rr)
139 | ht = round(ht * rr)
140 | st_size = 1.0 * min(wd, ht)
141 | img = img.resize((wd, ht), Image.BICUBIC)
142 | img = self.trans(img)
144 |
145 | return img, len(keypoints), name, gauss_im
146 |
147 | def train_transform(self, img, keypoints, gauss_im):
148 | wd, ht = img.size
149 | st_size = 1.0 * min(wd, ht)
150 | # resize the image to fit the crop size
151 | if st_size < self.c_size:
152 | rr = 1.0 * self.c_size / st_size
153 | wd = round(wd * rr)
154 | ht = round(ht * rr)
155 | st_size = 1.0 * min(wd, ht)
156 | img = img.resize((wd, ht), Image.BICUBIC)
157 | #gauss_im = gauss_im.resize((wd, ht), Image.BICUBIC)
158 | keypoints = keypoints * rr
159 |         assert st_size >= self.c_size, '{} {}'.format(wd, ht)
160 | assert len(keypoints) >= 0
161 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
162 | img = F.crop(img, i, j, h, w)
163 | gauss_im = F.crop(gauss_im, i, j, h, w)
164 | if len(keypoints) > 0:
165 | keypoints = keypoints - [j, i]
166 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \
167 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h)
168 | keypoints = keypoints[idx_mask]
169 | else:
170 | keypoints = np.empty([0, 2])
171 |
172 | gt_discrete = gen_discrete_map(h, w, keypoints)
173 | down_w = w // self.d_ratio
174 | down_h = h // self.d_ratio
175 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3))
176 | assert np.sum(gt_discrete) == len(keypoints)
177 |
178 |
179 | if len(keypoints) > 0:
180 | if random.random() > 0.5:
181 | img = F.hflip(img)
182 | gauss_im = F.hflip(gauss_im)
183 | gt_discrete = np.fliplr(gt_discrete)
184 | keypoints[:, 0] = w - keypoints[:, 0] - 1
185 | else:
186 | if random.random() > 0.5:
187 | img = F.hflip(img)
188 | gauss_im = F.hflip(gauss_im)
189 | gt_discrete = np.fliplr(gt_discrete)
190 | gt_discrete = np.expand_dims(gt_discrete, 0)
192 |
193 | return self.trans(img), gauss_im, torch.from_numpy(keypoints.copy()).float(), torch.from_numpy(gt_discrete.copy()).float()
194 |
195 |
196 | class Base_UL(data.Dataset):
197 | def __init__(self, root_path, crop_size, downsample_ratio=8):
198 | self.root_path = root_path
199 | self.c_size = crop_size
200 | self.d_ratio = downsample_ratio
201 | assert self.c_size % self.d_ratio == 0
202 | self.dc_size = self.c_size // self.d_ratio
203 | self.trans = transforms.Compose([
204 | transforms.ToTensor(),
205 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
206 | ])
207 |
208 | def __len__(self):
209 | pass
210 |
211 | def __getitem__(self, item):
212 | pass
213 |
214 | def train_transform_ul(self, img):
215 | wd, ht = img.size
216 | st_size = 1.0 * min(wd, ht)
217 | assert st_size >= self.c_size
218 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
219 | img = F.crop(img, i, j, h, w)
220 |
221 | if random.random() > 0.5:
222 | img = F.hflip(img)
223 |
224 | return self.trans(img)
225 |
226 |
227 | class Crowd_UL_TC(Base_UL):
228 | def __init__(self, root_path, crop_size, downsample_ratio=8, method='train_ul'):
229 | super().__init__(root_path, crop_size, downsample_ratio)
230 | self.method = method
231 | if method not in ['train_ul']:
232 |             raise Exception("not implemented")
233 |
234 | self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg')))
235 | print('number of img [{}]: {}'.format(method, len(self.im_list)))
236 |
237 | def __len__(self):
238 | return len(self.im_list)
239 |
240 | def __getitem__(self, item):
241 | img_path = self.im_list[item]
242 | name = os.path.basename(img_path).split('.')[0]
243 | img = Image.open(img_path).convert('RGB')
245 |
246 | return self.train_transform_ul(img)
247 |
248 |
249 | def train_transform_ul(self, img):
250 | wd, ht = img.size
251 | st_size = 1.0 * min(wd, ht)
252 | # resize the image to fit the crop size
253 | if st_size < self.c_size:
254 | rr = 1.0 * self.c_size / st_size
255 | wd = round(wd * rr)
256 | ht = round(ht * rr)
257 | st_size = 1.0 * min(wd, ht)
258 | img = img.resize((wd, ht), Image.BICUBIC)
259 |
260 |         assert st_size >= self.c_size, '{} {}'.format(wd, ht)
261 |
262 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
263 | img = F.crop(img, i, j, h, w)
264 | if random.random() > 0.5:
265 | img = F.hflip(img)
266 |
267 |         return self.trans(img), 1  # dummy target so each sample keeps an (image, label) tuple shape for the collate function
268 |
269 |
--------------------------------------------------------------------------------
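How the labeled dataset is consumed in training (mirrors train.py; the data root is a placeholder and must contain images/*.jpg plus ground_truth/ with GT_&lt;name&gt;.mat and &lt;name&gt;_densitymap.npy files):

```python
import torch
from torch.utils.data import DataLoader
from datasets.crowd import Crowd_TC

def collate(batch):  # same role as train_collate in train.py
    images, gauss, points, gt_discrete = zip(*batch)
    return (torch.stack(images, 0), torch.stack(gauss, 0),
            list(points), torch.stack(gt_discrete, 0))

train_set = Crowd_TC('/path/to/train_data', crop_size=256,
                     downsample_ratio=4, method='train')
loader = DataLoader(train_set, batch_size=4, shuffle=True, collate_fn=collate)
images, gauss, points, gt_discrete = next(iter(loader))
print(images.shape, gt_discrete.shape)  # [4, 3, 256, 256], [4, 1, 64, 64]
```
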
/losses/consistency_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from losses import ramps
6 |
7 |
8 |
9 | class consistency_weight(object):
10 | """
11 |     ramp_types = ['sigmoid_rampup', 'linear_rampup', 'cosine_rampdown'] (the ramps defined in losses/ramps.py)
12 | """
13 | def __init__(self, final_w, iters_per_epoch, rampup_starts=0, rampup_ends=7, ramp_type='sigmoid_rampup'):
14 | self.final_w = final_w
15 | self.iters_per_epoch = iters_per_epoch
16 | self.rampup_starts = rampup_starts * iters_per_epoch
17 | self.rampup_ends = rampup_ends * iters_per_epoch
18 | self.rampup_length = (self.rampup_ends - self.rampup_starts)
19 | self.rampup_func = getattr(ramps, ramp_type)
20 | self.current_rampup = 0
21 |
22 | def __call__(self, epoch, curr_iter):
23 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter
24 | if cur_total_iter < self.rampup_starts:
25 | return 0
26 | self.current_rampup = self.rampup_func(cur_total_iter - self.rampup_starts, self.rampup_length)
27 | return self.final_w * self.current_rampup
28 |
29 |
30 | def CE_loss(input_logits, target_targets, ignore_index, temperature=1):
31 | return F.cross_entropy(input_logits/temperature, target_targets, ignore_index=ignore_index)
32 |
33 | # for FocalLoss
34 | def softmax_helper(x):
35 | # copy from: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/utilities/nd_softmax.py
36 | rpt = [1 for _ in range(len(x.size()))]
37 | rpt[1] = x.size(1)
38 | x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
39 | e_x = torch.exp(x - x_max)
40 | return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
41 |
42 | def get_alpha(supervised_loader):
43 | # get number of classes
44 | num_labels = 0
45 | for image_batch, label_batch in supervised_loader:
46 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
47 | l_unique = torch.unique(label_batch.data)
48 | list_unique = [element.item() for element in l_unique.flatten()]
49 | num_labels = max(max(list_unique),num_labels)
50 | num_classes = num_labels + 1
51 | # count class occurrences
52 | alpha = [0 for i in range(num_classes)]
53 | for image_batch, label_batch in supervised_loader:
54 | label_batch.data[label_batch.data==255] = 0 # pixels of ignore class added to background
55 | l_unique = torch.unique(label_batch.data)
56 | list_unique = [element.item() for element in l_unique.flatten()]
57 | l_unique_count = torch.stack([(label_batch.data==x_u).sum() for x_u in l_unique]) # tensor([65920, 36480])
58 | list_count = [count.item() for count in l_unique_count.flatten()]
59 | for index in list_unique:
60 | alpha[index] += list_count[list_unique.index(index)]
61 | return alpha
62 |
71 |
72 |
73 | class FocalLoss(nn.Module):
74 | """
75 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
76 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in
77 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)'
78 | Focal_Loss= -1*alpha*(1-pt)*log(pt)
79 | :param num_class:
80 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
81 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more
82 | focus on hard misclassified example
83 | :param smooth: (float,double) smooth value when cross entropy
84 | :param balance_index: (int) balance class index, should be specific when alpha is float
85 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
86 | """
87 |
88 |     def __init__(self, apply_nonlin=None, ignore_index=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
89 |         super(FocalLoss, self).__init__()
90 |         self.apply_nonlin = apply_nonlin
91 |         self.ignore_index = ignore_index  # referenced in forward() but was never stored
91 | self.alpha = alpha
92 | self.gamma = gamma
93 | self.balance_index = balance_index
94 | self.smooth = smooth
95 | self.size_average = size_average
96 |
97 | if self.smooth is not None:
98 | if self.smooth < 0 or self.smooth > 1.0:
99 | raise ValueError('smooth value should be in [0,1]')
100 |
101 | def forward(self, logit, target):
102 | if self.apply_nonlin is not None:
103 | logit = self.apply_nonlin(logit)
104 | num_class = logit.shape[1]
105 |
106 | if logit.dim() > 2:
107 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
108 | logit = logit.view(logit.size(0), logit.size(1), -1)
109 | logit = logit.permute(0, 2, 1).contiguous()
110 | logit = logit.view(-1, logit.size(-1))
111 | target = torch.squeeze(target, 1)
112 | target = target.view(-1, 1)
113 |
114 | valid_mask = None
115 | if self.ignore_index is not None:
116 | valid_mask = target != self.ignore_index
117 | target = target * valid_mask
118 |
119 | alpha = self.alpha
120 |
121 | if alpha is None:
122 | alpha = torch.ones(num_class, 1)
123 | elif isinstance(alpha, (list, np.ndarray)):
124 | assert len(alpha) == num_class
125 | alpha = torch.FloatTensor(alpha).view(num_class, 1)
126 | alpha = alpha / alpha.sum()
127 | alpha = 1/alpha # inverse of class frequency
128 | elif isinstance(alpha, float):
129 | alpha = torch.ones(num_class, 1)
130 | alpha = alpha * (1 - self.alpha)
131 | alpha[self.balance_index] = self.alpha
132 |
133 | else:
134 | raise TypeError('Not support alpha type')
135 |
136 | if alpha.device != logit.device:
137 | alpha = alpha.to(logit.device)
138 |
139 | idx = target.cpu().long()
140 |
141 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
142 |
143 |         # map the ignore label (255) to background so scatter_ stays in range
144 |         idx[idx == 255] = 0
145 |
146 | one_hot_key = one_hot_key.scatter_(1, idx, 1)
147 | if one_hot_key.device != logit.device:
148 | one_hot_key = one_hot_key.to(logit.device)
149 |
150 | if self.smooth:
151 | one_hot_key = torch.clamp(
152 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
153 | pt = (one_hot_key * logit).sum(1) + self.smooth
154 | logpt = pt.log()
155 |
156 | gamma = self.gamma
157 |
158 | alpha = alpha[idx]
159 | alpha = torch.squeeze(alpha)
160 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
161 |
162 | if valid_mask is not None:
163 | loss = loss * valid_mask.squeeze()
164 |
165 | if self.size_average:
166 | loss = loss.mean()
167 | else:
168 | loss = loss.sum()
169 | return loss
170 |
171 |
172 | class abCE_loss(nn.Module):
173 | """
174 | Annealed-Bootstrapped cross-entropy loss
175 | """
176 |     def __init__(self, iters_per_epoch, epochs, num_classes, weight=None,
177 |                  reduction='mean', thresh=0.7, min_kept=1, ramp_type='sigmoid_rampup'):  # ramps.py defines no 'log_rampup'
178 | super(abCE_loss, self).__init__()
179 | self.weight = torch.FloatTensor(weight) if weight is not None else weight
180 | self.reduction = reduction
181 | self.thresh = thresh
182 | self.min_kept = min_kept
183 | self.ramp_type = ramp_type
184 |
185 | if ramp_type is not None:
186 | self.rampup_func = getattr(ramps, ramp_type)
187 | self.iters_per_epoch = iters_per_epoch
188 | self.num_classes = num_classes
189 | self.start = 1/num_classes
190 | self.end = 0.9
191 | self.total_num_iters = (epochs - (0.6 * epochs)) * iters_per_epoch
192 |
193 | def threshold(self, curr_iter, epoch):
194 | cur_total_iter = self.iters_per_epoch * epoch + curr_iter
195 | current_rampup = self.rampup_func(cur_total_iter, self.total_num_iters)
196 | return current_rampup * (self.end - self.start) + self.start
197 |
198 | def forward(self, predict, target, ignore_index, curr_iter, epoch):
199 | batch_kept = self.min_kept * target.size(0)
200 | prob_out = F.softmax(predict, dim=1)
201 | tmp_target = target.clone()
202 | tmp_target[tmp_target == ignore_index] = 0
203 | prob = prob_out.gather(1, tmp_target.unsqueeze(1))
204 | mask = target.contiguous().view(-1, ) != ignore_index
205 | sort_prob, sort_indices = prob.contiguous().view(-1, )[mask].contiguous().sort()
206 |
207 | if self.ramp_type is not None:
208 | thresh = self.threshold(curr_iter=curr_iter, epoch=epoch)
209 | else:
210 | thresh = self.thresh
211 |
212 | min_threshold = sort_prob[min(batch_kept, sort_prob.numel() - 1)] if sort_prob.numel() > 0 else 0.0
213 | threshold = max(min_threshold, thresh)
214 | loss_matrix = F.cross_entropy(predict, target,
215 | weight=self.weight.to(predict.device) if self.weight is not None else None,
216 | ignore_index=ignore_index, reduction='none')
217 |         loss_matrix = loss_matrix.contiguous().view(-1, )
218 |         sort_loss_matrix = loss_matrix[mask][sort_indices]
219 |         select_loss_matrix = sort_loss_matrix[sort_prob < threshold]
220 | if self.reduction == 'sum' or select_loss_matrix.numel() == 0:
221 | return select_loss_matrix.sum()
222 | elif self.reduction == 'mean':
223 | return select_loss_matrix.mean()
224 | else:
225 | raise NotImplementedError('Reduction Error!')
226 |
227 |
228 |
229 | def softmax_mse_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
230 | assert inputs.requires_grad == True and targets.requires_grad == False
231 | assert inputs.size() == targets.size() # (batch_size * num_classes * H * W)
232 | inputs = F.softmax(inputs, dim=1)
233 | if use_softmax:
234 | targets = F.softmax(targets, dim=1)
235 |
236 | if conf_mask:
237 | loss_mat = F.mse_loss(inputs, targets, reduction='none')
238 | mask = (targets.max(1)[0] > threshold)
239 | loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
240 | if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
241 | return loss_mat.mean()
242 | else:
243 | return F.mse_loss(inputs, targets, reduction='mean') # take the mean over the batch_size
244 |
245 |
246 | def softmax_kl_loss(inputs, targets, conf_mask=False, threshold=None, use_softmax=False):
247 |     assert inputs.requires_grad == True and targets.requires_grad == False
248 |     assert inputs.size() == targets.size()
249 |     input_log_softmax = F.log_softmax(inputs, dim=1)  # kl_div expects log-probabilities
250 |     if use_softmax:
251 |         targets = F.softmax(targets, dim=1)
252 |     if conf_mask:
253 |         loss_mat = F.kl_div(input_log_softmax, targets, reduction='none')
254 |         mask = (targets.max(1)[0] > threshold)
255 |         loss_mat = loss_mat[mask.unsqueeze(1).expand_as(loss_mat)]
256 |         if loss_mat.shape.numel() == 0: loss_mat = torch.tensor([0.]).to(inputs.device)
257 |         return loss_mat.sum() / mask.shape.numel()
258 |     else:
259 |         return F.kl_div(input_log_softmax, targets, reduction='mean')
260 |
261 |
262 | def softmax_js_loss(inputs, targets, **_):
263 | assert inputs.requires_grad == True and targets.requires_grad == False
264 | assert inputs.size() == targets.size()
265 | epsilon = 1e-5
266 |
267 | M = (F.softmax(inputs, dim=1) + targets) * 0.5
268 | kl1 = F.kl_div(F.log_softmax(inputs, dim=1), M, reduction='mean')
269 | kl2 = F.kl_div(torch.log(targets+epsilon), M, reduction='mean')
270 | return (kl1 + kl2) * 0.5
271 |
272 |
273 |
274 | def pair_wise_loss(unsup_outputs, size_average=True, nbr_of_pairs=8):
275 | """
276 | Pair-wise loss in the sup. mat.
277 | """
278 | if isinstance(unsup_outputs, list):
279 | unsup_outputs = torch.stack(unsup_outputs)
280 |
281 | # Only for a subset of the aux outputs to reduce computation and memory
282 | unsup_outputs = unsup_outputs[torch.randperm(unsup_outputs.size(0))]
283 | unsup_outputs = unsup_outputs[:nbr_of_pairs]
284 |
285 | temp = torch.zeros_like(unsup_outputs) # For grad purposes
286 | for i, u in enumerate(unsup_outputs):
287 | temp[i] = F.softmax(u, dim=1)
288 | mean_prediction = temp.mean(0).unsqueeze(0) # Mean over the auxiliary outputs
289 | pw_loss = ((temp - mean_prediction)**2).mean(0) # Variance
290 | pw_loss = pw_loss.sum(1) # Sum over classes
291 | if size_average:
292 | return pw_loss.mean()
293 | return pw_loss.sum()
294 |
295 |
--------------------------------------------------------------------------------
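The ramped weight is used like this (illustrative numbers; 'sigmoid_rampup' comes from losses/ramps.py):

```python
from losses.consistency_loss import consistency_weight

cw = consistency_weight(final_w=1.0, iters_per_epoch=100,
                        rampup_starts=0, rampup_ends=7,
                        ramp_type='sigmoid_rampup')
print(cw(epoch=0, curr_iter=0))  # ~0.0067, start of the ramp
print(cw(epoch=7, curr_iter=0))  # 1.0, ramp finished
```
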
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | from torch import optim
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data.dataloader import default_collate
8 | import numpy as np
9 | from datetime import datetime
10 | import torch.nn.functional as F
11 | from datasets.crowd import Crowd_TC, Crowd_UL_TC
12 |
13 | from network import pvt_cls as TCN
14 | from losses.multi_con_loss import MultiConLoss
15 |
16 | from utils.pytorch_utils import Save_Handle, AverageMeter
17 | import utils.log_utils as log_utils
18 | import argparse
19 | from losses.rank_loss import RankLoss
20 |
21 | from losses import ramps
22 | from losses.ot_loss import OT_Loss
23 | from losses.consistency_loss import *
24 |
25 | parser = argparse.ArgumentParser(description='Train')
26 | parser.add_argument('--data-dir', default='/users/k2254235/Lab/TCT/Dataset/London_103050/', help='data path')
27 |
28 | parser.add_argument('--dataset', default='TC')
29 | parser.add_argument('--lr', type=float, default=1e-5, help='the initial learning rate')
30 | parser.add_argument('--weight-decay', type=float, default=1e-4, help='the weight decay')
31 | parser.add_argument('--resume', default='', type=str, help='the path of resume training model')
32 | parser.add_argument('--max-epoch', type=int, default=4000, help='max training epoch')
33 | parser.add_argument('--val-epoch', type=int, default=1, help='run validation every this many epochs')
34 | parser.add_argument('--val-start', type=int, default=0, help='the epoch start to val')
35 | parser.add_argument('--batch-size', type=int, default=16, help='train batch size')
36 | parser.add_argument('--batch-size-ul', type=int, default=16, help='train batch size')
37 | parser.add_argument('--device', default='0', help='assign device')
38 | parser.add_argument('--num-workers', type=int, default=0, help='number of data-loading worker processes')
39 | parser.add_argument('--crop-size', type=int, default= 256, help='the crop size of the train image')
40 | parser.add_argument('--rl', type=float, default=1, help='weight of the ranking loss')
41 | parser.add_argument('--reg', type=float, default=1, help='weight of the counting loss')
42 | parser.add_argument('--ot', type=float, default=0.1, help='weight of the OT loss')
43 | parser.add_argument('--tv', type=float, default=0.01, help='weight of the TV loss')
44 | parser.add_argument('--num-of-iter-in-ot', type=int, default=100, help='sinkhorn iterations')
45 | parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance')
46 | parser.add_argument('--run-name', default='Treeformer_test', help='run name for wandb interface/logging')
47 | parser.add_argument('--consistency', type=int, default=1, help='weight of the multi-level consistency loss')
48 | args = parser.parse_args()
49 |
50 |
51 | def train_collate(batch):
52 | transposed_batch = list(zip(*batch))
53 | images = torch.stack(transposed_batch[0], 0)
54 | gauss = torch.stack(transposed_batch[1], 0)
55 | points = transposed_batch[2]
56 | gt_discretes = torch.stack(transposed_batch[3], 0)
57 | return images, gauss, points, gt_discretes
58 |
59 |
60 | def train_collate_UL(batch):
61 | transposed_batch = list(zip(*batch))
62 | images = torch.stack(transposed_batch[0], 0)
63 |
64 | return images
65 |
66 | def get_current_consistency_weight(epoch):
67 |     # Consistency ramp-up from https://arxiv.org/abs/1610.02242; no --consistency-ramp flag is defined, so fall back to a 30-epoch ramp (arbitrary default)
68 |     return args.consistency * ramps.sigmoid_rampup(epoch, getattr(args, 'consistency_ramp', 30))
69 |
70 |
71 | class Trainer(object):
72 | def __init__(self, args):
73 | self.args = args
74 |
75 | def setup(self):
76 | args = self.args
77 | sub_dir = (
78 | "SEMI/{}_12-1-input-{}_reg-{}_nIter-{}_normCood-{}".format(
79 | args.run_name,args.crop_size,args.reg,
80 | args.num_of_iter_in_ot,args.norm_cood))
81 |
82 | self.save_dir = os.path.join("/scratch/users/k2254235","ckpts", sub_dir)
83 | if not os.path.exists(self.save_dir):
84 | os.makedirs(self.save_dir)
85 |
86 | time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S")
87 | self.logger = log_utils.get_logger(
88 | os.path.join(self.save_dir, "train-{:s}.log".format(time_str)))
89 |
90 | log_utils.print_config(vars(args), self.logger)
91 |
92 | if torch.cuda.is_available():
93 | self.device = torch.device("cuda")
94 | self.device_count = torch.cuda.device_count()
95 | self.logger.info("using {} gpus".format(self.device_count))
96 | else:
97 | raise Exception("gpu is not available")
98 |
99 |
100 | downsample_ratio = 4
101 |         self.datasets = {
102 |             "train": Crowd_TC(os.path.join(args.data_dir, "train_data"), args.crop_size, downsample_ratio, "train"),
103 |             "val": Crowd_TC(os.path.join(args.data_dir, "valid_data"), args.crop_size, downsample_ratio, "val")}
104 |
105 | self.datasets_ul = { "train_ul": Crowd_UL_TC(os.path.join(args.data_dir, "train_data_ul"),
106 | args.crop_size, downsample_ratio, "train_ul")}
107 |
108 |
109 | self.dataloaders = {
110 | x: DataLoader(self.datasets[x],
111 | collate_fn=(train_collate if x == "train" else default_collate),
112 | batch_size=(args.batch_size if x == "train" else 1),
113 | shuffle=(True if x == "train" else False),
114 | num_workers=args.num_workers * self.device_count,
115 | pin_memory=(True if x == "train" else False))
116 | for x in ["train", "val"]}
117 |
118 | self.dataloaders_ul = {
119 | x: DataLoader(self.datasets_ul[x],
120 | collate_fn=(train_collate_UL ),
121 | batch_size=(args.batch_size_ul),
122 | shuffle=(True),
123 | num_workers=args.num_workers * self.device_count,
124 | pin_memory=(True if x == "train" else False))
125 | for x in ["train_ul"]}
126 |
127 |
128 | self.model = TCN.pvt_treeformer(pretrained=False)
129 |
130 | self.model.to(self.device)
131 | self.optimizer = optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
132 | self.start_epoch = 0
133 |
134 | if args.resume:
135 | self.logger.info("loading pretrained model from " + args.resume)
136 | suf = args.resume.rsplit(".", 1)[-1]
137 | if suf == "tar":
138 | checkpoint = torch.load(args.resume, self.device)
139 | self.model.load_state_dict(checkpoint["model_state_dict"])
140 | self.optimizer.load_state_dict(
141 | checkpoint["optimizer_state_dict"])
142 | self.start_epoch = checkpoint["epoch"] + 1
143 | elif suf == "pth":
144 | self.model.load_state_dict(
145 | torch.load(args.resume, self.device))
146 | else:
147 | self.logger.info("random initialization")
148 |
149 | self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood,
150 | self.device, args.num_of_iter_in_ot, args.reg)
151 |
152 | self.tvloss = nn.L1Loss(reduction="none").to(self.device)
153 | self.mse = nn.MSELoss().to(self.device)
154 | self.mae = nn.L1Loss().to(self.device)
155 | self.save_list = Save_Handle(max_num=1)
156 | self.best_mae = np.inf
157 | self.best_mse = np.inf
158 | self.rankloss = RankLoss().to(self.device)
159 | self.kl_distance = nn.KLDivLoss(reduction='none')
160 | self.multiconloss = MultiConLoss().to(self.device)
161 |
162 |
163 | def train(self):
164 | """training process"""
165 | args = self.args
166 | for epoch in range(self.start_epoch, args.max_epoch + 1):
167 | self.logger.info("-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5)
168 | self.epoch = epoch
169 | self.train_epoch()
170 | if epoch % args.val_epoch == 0 and epoch >= args.val_start:
171 | self.val_epoch()
172 |
173 | def train_epoch(self):
174 | epoch_ot_loss = AverageMeter()
175 | epoch_ot_obj_value = AverageMeter()
176 | epoch_wd = AverageMeter()
177 | epoch_tv_loss = AverageMeter()
178 | epoch_count_loss = AverageMeter()
179 | epoch_count_consistency_l = AverageMeter()
180 | epoch_count_consistency_ul = AverageMeter()
181 | epoch_loss = AverageMeter()
182 | epoch_mae = AverageMeter()
183 | epoch_mse = AverageMeter()
184 | epoch_start = time.time()
185 | epoch_rank_loss = AverageMeter()
186 |         epoch_consistency_loss = AverageMeter()
187 |
188 | self.model.train() # Set model to training mode
189 |
190 | for step, (inputs, gausss, points, gt_discrete) in enumerate(self.dataloaders["train"]):
191 | inputs = inputs.to(self.device)
192 | gausss = gausss.to(self.device)
193 | gd_count = np.array([len(p) for p in points], dtype=np.float32)
194 |
195 | points = [p.to(self.device) for p in points]
196 | gt_discrete = gt_discrete.to(self.device)
197 | N = inputs.size(0)
198 |
199 |             # draw one shuffled batch from the unlabeled loader for this labeled step
200 |             inputs_ul = next(iter(self.dataloaders_ul["train_ul"])).to(self.device)
201 | 
202 |
203 |
204 | with torch.set_grad_enabled(True):
205 | outputs_L, outputs_UL, outputs_normed, CLS_L, CLS_UL = self.model(inputs, inputs_ul)
206 | outputs_L = outputs_L[0]
207 |
208 | with torch.set_grad_enabled(False):
209 |                     preds_UL = (outputs_UL[0][0] + outputs_UL[1][0] + outputs_UL[2][0]) / 3  # branch-mean unlabeled prediction (currently unused)
210 |
211 | # Compute counting loss.
212 | count_loss = self.mae(outputs_L.sum(1).sum(1).sum(1),torch.from_numpy(gd_count).float().to(self.device))*self.args.reg
213 |
214 | # Compute OT loss.
215 | ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs_L, points)
216 | ot_loss = ot_loss* self.args.ot
217 | ot_obj_value = ot_obj_value* self.args.ot
218 |
219 | gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3))
220 | gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
221 | tv_loss = (self.tvloss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)*
222 | torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.tv
223 |
224 | epoch_ot_loss.update(ot_loss.item(), N)
225 | epoch_ot_obj_value.update(ot_obj_value.item(), N)
226 | epoch_wd.update(wd, N)
227 | epoch_count_loss.update(count_loss.item(), N)
228 | epoch_tv_loss.update(tv_loss.item(), N)
229 |
230 | # Compute ranking loss.
231 | rank_loss = self.rankloss(outputs_UL)*self.args.rl
232 | epoch_rank_loss.update(rank_loss.item(), N)
233 |
234 | # Compute multi level consistancy loss
235 | consistency_loss = args.consistency * self.multiconloss(outputs_UL)
236 |                 epoch_consistency_loss.update(consistency_loss.item(), N)
237 |
238 |
239 | # Compute consistency count
240 | Con_cls_UL = (CLS_UL[0] + CLS_UL[1] + CLS_UL[2])/3
241 | Con_cls_L = torch.from_numpy(gd_count).float().to(self.device)
242 |
243 | count_loss_l = self.mae(torch.stack((CLS_L[0],CLS_L[1],CLS_L[2])), torch.stack((Con_cls_L, Con_cls_L, Con_cls_L)))
244 | count_loss_ul = self.mae(torch.stack((CLS_UL[0],CLS_UL[1],CLS_UL[2])), torch.stack((Con_cls_UL, Con_cls_UL, Con_cls_UL)))
245 | epoch_count_consistency_l.update(count_loss_l.item(), N)
246 | epoch_count_consistency_ul.update(count_loss_ul.item(), N)
247 |
248 |
249 | loss = count_loss + ot_loss + tv_loss + rank_loss + count_loss_l + count_loss_ul + consistency_loss
250 |
251 |
252 | self.optimizer.zero_grad()
253 | loss.backward()
254 | self.optimizer.step()
255 |
256 | pred_count = (torch.sum(outputs_L.view(N, -1),
257 | dim=1).detach().cpu().numpy())
258 |
259 | pred_err = pred_count - gd_count
260 | epoch_loss.update(loss.item(), N)
261 | epoch_mse.update(np.mean(pred_err * pred_err), N)
262 | epoch_mae.update(np.mean(abs(pred_err)), N)
263 |
264 |
265 | self.logger.info(
266 | "Epoch {} Train, Loss: {:.2f}, Count Loss: {:.2f}, OT Loss: {:.2e}, TV Loss: {:.2e}, Rank Loss: {:.2f},"
267 | "Consistensy Loss: {:.2f}, MSE: {:.2f}, MAE: {:.2f},LC Loss: {:.2f}, ULC Loss: {:.2f}, Cost {:.1f} sec".format(
268 | self.epoch, epoch_loss.get_avg(), epoch_count_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_tv_loss.get_avg(), epoch_rank_loss.get_avg(),
269 | epoch_consistensy_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), epoch_count_consistency_l.get_avg(),
270 | epoch_count_consistency_ul.get_avg(), time.time() - epoch_start))
271 |
272 |
273 |
274 | model_state_dic = self.model.state_dict()
275 | save_path = os.path.join(self.save_dir, "{}_ckpt.tar".format(self.epoch))
276 |
277 | torch.save({"epoch": self.epoch, "optimizer_state_dict": self.optimizer.state_dict(),
278 | "model_state_dict": model_state_dic}, save_path)
279 | self.save_list.append(save_path)
280 |
281 | def val_epoch(self):
282 | args = self.args
283 | epoch_start = time.time()
284 | self.model.eval() # Set model to evaluate mode
285 | epoch_res = []
286 | for inputs, count, name, gauss_im in self.dataloaders["val"]:
287 | with torch.no_grad():
288 | inputs = inputs.to(self.device)
289 | crop_imgs, crop_masks = [], []
290 | b, c, h, w = inputs.size()
291 | rh, rw = args.crop_size, args.crop_size
292 | for i in range(0, h, rh):
293 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
294 | for j in range(0, w, rw):
295 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
296 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje])
297 | mask = torch.zeros([b, 1, h, w]).to(self.device)
298 | mask[:, :, gis:gie, gjs:gje].fill_(1.0)
299 | crop_masks.append(mask)
300 | crop_imgs, crop_masks = map(
301 | lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks))
302 |
303 | crop_preds = []
304 | nz, bz = crop_imgs.size(0), args.batch_size
305 | for i in range(0, nz, bz):
306 | gs, gt = i, min(nz, i + bz)
307 |
308 | crop_pred, _ = self.model(crop_imgs[gs:gt])
309 | crop_pred = crop_pred[0]
310 | _, _, h1, w1 = crop_pred.size()
311 | crop_pred = (F.interpolate(crop_pred, size=(h1 * 4, w1 * 4),
312 | mode="bilinear", align_corners=True) / 16 )
313 |
314 | crop_preds.append(crop_pred)
315 | crop_preds = torch.cat(crop_preds, dim=0)
316 |
317 | # splice them to the original size
318 | idx = 0
319 | pred_map = torch.zeros([b, 1, h, w]).to(self.device)
320 | for i in range(0, h, rh):
321 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh)
322 | for j in range(0, w, rw):
323 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw)
324 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx]
325 | idx += 1
326 | # for the overlapping area, compute average value
327 | mask = crop_masks.sum(dim=0).unsqueeze(0)
328 | outputs = pred_map / mask
329 |
330 | res = count[0].item() - torch.sum(outputs).item()
331 | epoch_res.append(res)
332 | epoch_res = np.array(epoch_res)
333 | mse = np.sqrt(np.mean(np.square(epoch_res)))
334 | mae = np.mean(np.abs(epoch_res))
335 |
336 | self.logger.info("Epoch {} Val, MSE: {:.2f}, MAE: {:.2f}, Cost {:.1f} sec".format(
337 | self.epoch, mse, mae, time.time() - epoch_start ))
338 |
339 |
340 | model_state_dic = self.model.state_dict()
341 | print("Comaprison", mae, self.best_mae)
342 | if mae < self.best_mae:
343 | self.best_mse = mse
344 | self.best_mae = mae
345 | self.logger.info(
346 | "save best mse {:.2f} mae {:.2f} model epoch {}".format(
347 | self.best_mse, self.best_mae, self.epoch))
348 |
349 | print("Saving best model at {} epoch".format(self.epoch))
350 | model_path = os.path.join(
351 | self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format(
352 | self.best_mae, self.epoch))
353 |
354 | torch.save(model_state_dic, model_path)
355 |
356 |
357 | if __name__ == "__main__":
358 | import torch
359 | torch.backends.cudnn.benchmark = True
360 | trainer = Trainer(args)
361 | trainer.setup()
362 | trainer.train()
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
--------------------------------------------------------------------------------
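The `val_epoch` method in train.py above evaluates full-resolution images by tiling them into `crop_size` windows, predicting each tile in mini-batches, and averaging the predictions wherever tiles overlap. Below is a minimal, self-contained sketch of that crop-and-stitch scheme; the constant-density `model` stand-in and the 300x300 input are illustrative assumptions, not part of the repo:

```python
# Sketch of the overlap-averaged tiled inference used in val_epoch.
# Assumption: `model` is any callable mapping a (1, C, ch, cw) tensor to a
# (1, 1, ch, cw) density map at input resolution.
import torch

def tiled_density(model, image, crop_size):
    """image: (1, C, H, W); returns a (1, 1, H, W) overlap-averaged density map."""
    _, _, h, w = image.shape
    pred_sum = torch.zeros(1, 1, h, w)
    coverage = torch.zeros(1, 1, h, w)
    for i in range(0, h, crop_size):
        # shift the last row/column of tiles back so they stay in-bounds
        gis, gie = max(min(h - crop_size, i), 0), min(h, i + crop_size)
        for j in range(0, w, crop_size):
            gjs, gje = max(min(w - crop_size, j), 0), min(w, j + crop_size)
            with torch.no_grad():
                pred = model(image[:, :, gis:gie, gjs:gje])
            pred_sum[:, :, gis:gie, gjs:gje] += pred
            coverage[:, :, gis:gie, gjs:gje] += 1.0
    return pred_sum / coverage  # exact average over overlapping tiles

# toy usage: a "model" that predicts a constant density of 1 everywhere
density = tiled_density(lambda t: torch.ones(1, 1, *t.shape[-2:]),
                        torch.randn(1, 3, 300, 300), crop_size=256)
print(density.shape, density.sum().item())  # torch.Size([1, 1, 300, 300]) 90000.0
```

Shifting the final row and column of tiles back inside the image (the max/min index arithmetic) guarantees full coverage without padding, and the coverage mask makes the average over overlapping regions exact.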
/losses/bregman_pytorch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Rewrite ot.bregman.sinkhorn in Python Optimal Transport (https://pythonot.github.io/_modules/ot/bregman.html#sinkhorn)
4 | using PyTorch operations.
5 | Bregman projections for regularized OT (Sinkhorn distance).
6 | """
7 |
8 | import torch
9 |
10 | M_EPS = 1e-16
11 |
12 |
13 | def sinkhorn(a, b, C, reg=1e-1, method='sinkhorn', maxIter=1000, tau=1e3,
14 | stopThr=1e-9, verbose=False, log=True, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
15 |     r"""
16 | Solve the entropic regularization optimal transport
17 | The input should be PyTorch tensors
18 | The function solves the following optimization problem:
19 |
20 | .. math::
21 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
22 | s.t. \gamma 1 = a
23 | \gamma^T 1= b
24 | \gamma\geq 0
25 | where :
26 | - C is the (ns,nt) metric cost matrix
27 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
28 | - a and b are target and source measures (sum to 1)
29 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
30 |
31 | Parameters
32 | ----------
33 | a : torch.tensor (na,)
34 | samples measure in the target domain
35 | b : torch.tensor (nb,)
36 | samples in the source domain
37 | C : torch.tensor (na,nb)
38 | loss matrix
39 | reg : float
40 | Regularization term > 0
41 | method : str
42 |         method used for the solver, either 'sinkhorn', 'sinkhorn_stabilized' or
43 |         'sinkhorn_epsilon_scaling'; see those functions for specific parameters
44 | maxIter : int, optional
45 | Max number of iterations
46 | stopThr : float, optional
47 |         Stop threshold on error (> 0)
48 | verbose : bool, optional
49 | Print information along iterations
50 | log : bool, optional
51 | record log if True
52 |
53 | Returns
54 | -------
55 | gamma : (na x nb) torch.tensor
56 | Optimal transportation matrix for the given parameters
57 | log : dict
58 |         log dictionary, returned only if log==True in parameters
59 |
60 | References
61 | ----------
62 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
63 | See Also
64 | --------
65 |
66 | """
67 |
68 | if method.lower() == 'sinkhorn':
69 | return sinkhorn_knopp(a, b, C, reg, maxIter=maxIter,
70 | stopThr=stopThr, verbose=verbose, log=log,
71 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
72 | **kwargs)
73 | elif method.lower() == 'sinkhorn_stabilized':
74 | return sinkhorn_stabilized(a, b, C, reg, maxIter=maxIter, tau=tau,
75 | stopThr=stopThr, verbose=verbose, log=log,
76 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
77 | **kwargs)
78 | elif method.lower() == 'sinkhorn_epsilon_scaling':
79 | return sinkhorn_epsilon_scaling(a, b, C, reg,
80 | maxIter=maxIter, maxInnerIter=100, tau=tau,
81 | scaling_base=0.75, scaling_coef=None, stopThr=stopThr,
82 |                                         verbose=verbose, log=log, warm_start=warm_start, eval_freq=eval_freq,
83 | print_freq=print_freq, **kwargs)
84 | else:
85 | raise ValueError("Unknown method '%s'." % method)
86 |
87 |
88 | def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9,
89 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
90 |     r"""
91 | Solve the entropic regularization optimal transport
92 | The input should be PyTorch tensors
93 | The function solves the following optimization problem:
94 |
95 | .. math::
96 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
97 | s.t. \gamma 1 = a
98 | \gamma^T 1= b
99 | \gamma\geq 0
100 | where :
101 | - C is the (ns,nt) metric cost matrix
102 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
103 | - a and b are target and source measures (sum to 1)
104 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1].
105 |
106 | Parameters
107 | ----------
108 | a : torch.tensor (na,)
109 | samples measure in the target domain
110 | b : torch.tensor (nb,)
111 | samples in the source domain
112 | C : torch.tensor (na,nb)
113 | loss matrix
114 | reg : float
115 | Regularization term > 0
116 | maxIter : int, optional
117 | Max number of iterations
118 | stopThr : float, optional
119 |         Stop threshold on error (> 0)
120 | verbose : bool, optional
121 | Print information along iterations
122 | log : bool, optional
123 | record log if True
124 |
125 | Returns
126 | -------
127 | gamma : (na x nb) torch.tensor
128 | Optimal transportation matrix for the given parameters
129 | log : dict
130 |         log dictionary, returned only if log==True in parameters
131 |
132 | References
133 | ----------
134 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
135 | See Also
136 | --------
137 |
138 | """
139 |
140 | device = a.device
141 | na, nb = C.shape
142 |
143 | assert na >= 1 and nb >= 1, 'C needs to be 2d'
144 |     assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
145 | assert reg > 0, 'reg should be greater than 0'
146 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
147 |
148 | if log:
149 | log = {'err': []}
150 |
151 | if warm_start is not None:
152 | u = warm_start['u']
153 | v = warm_start['v']
154 | else:
155 | u = torch.ones(na, dtype=a.dtype).to(device) / na
156 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb
157 |
158 | K = torch.empty(C.shape, dtype=C.dtype).to(device)
159 | torch.div(C, -reg, out=K)
160 | torch.exp(K, out=K)
161 |
162 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
163 |
164 | it = 1
165 | err = 1
166 |
167 | # allocate memory beforehand
168 | KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
169 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
170 |
171 | while (err > stopThr and it <= maxIter):
172 | upre, vpre = u, v
173 | torch.matmul(u, K, out=KTu)
174 | v = torch.div(b, KTu + M_EPS)
175 | torch.matmul(K, v, out=Kv)
176 | u = torch.div(a, Kv + M_EPS)
177 |
178 | if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \
179 | torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)):
180 | print('Warning: numerical errors at iteration', it)
181 | u, v = upre, vpre
182 | break
183 |
184 | if log and it % eval_freq == 0:
185 |             # we can speed up the process by checking the error only
186 |             # every eval_freq iterations
187 |             # below is equivalent to:
188 |             # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0)
189 |             # but more memory efficient
190 | b_hat = torch.matmul(u, K) * v
191 | err = (b - b_hat).pow(2).sum().item()
192 | # err = (b - b_hat).abs().sum().item()
193 | log['err'].append(err)
194 |
195 | if verbose and it % print_freq == 0:
196 | print('iteration {:5d}, constraint error {:5e}'.format(it, err))
197 |
198 | it += 1
199 |
200 | if log:
201 | log['u'] = u
202 | log['v'] = v
203 | log['alpha'] = reg * torch.log(u + M_EPS)
204 | log['beta'] = reg * torch.log(v + M_EPS)
205 |
206 | # transport plan
207 | P = u.reshape(-1, 1) * K * v.reshape(1, -1)
208 | if log:
209 | return P, log
210 | else:
211 | return P
212 |
213 |
214 | def sinkhorn_stabilized(a, b, C, reg=1e-1, maxIter=1000, tau=1e3, stopThr=1e-9,
215 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs):
216 |     r"""
217 | Solve the entropic regularization OT problem with log stabilization
218 | The function solves the following optimization problem:
219 |
220 | .. math::
221 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
222 | s.t. \gamma 1 = a
223 | \gamma^T 1= b
224 | \gamma\geq 0
225 | where :
226 | - C is the (ns,nt) metric cost matrix
227 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
228 | - a and b are target and source measures (sum to 1)
229 |
230 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]
231 |     but with the log stabilization proposed in [3] and defined in [2] (Algo. 3.1)
232 |
233 | Parameters
234 | ----------
235 | a : torch.tensor (na,)
236 | samples measure in the target domain
237 | b : torch.tensor (nb,)
238 | samples in the source domain
239 | C : torch.tensor (na,nb)
240 | loss matrix
241 | reg : float
242 | Regularization term > 0
243 | tau : float
244 |         threshold for max value in u or v for log scaling
245 | maxIter : int, optional
246 | Max number of iterations
247 | stopThr : float, optional
248 |         Stop threshold on error (> 0)
249 | verbose : bool, optional
250 | Print information along iterations
251 | log : bool, optional
252 | record log if True
253 |
254 | Returns
255 | -------
256 | gamma : (na x nb) torch.tensor
257 | Optimal transportation matrix for the given parameters
258 | log : dict
259 |         log dictionary, returned only if log==True in parameters
260 |
261 | References
262 | ----------
263 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
264 | [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
265 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
266 |
267 | See Also
268 | --------
269 |
270 | """
271 |
272 | device = a.device
273 | na, nb = C.shape
274 |
275 | assert na >= 1 and nb >= 1, 'C needs to be 2d'
276 |     assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
277 | assert reg > 0, 'reg should be greater than 0'
278 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
279 |
280 | if log:
281 | log = {'err': []}
282 |
283 | if warm_start is not None:
284 | alpha = warm_start['alpha']
285 | beta = warm_start['beta']
286 | else:
287 | alpha = torch.zeros(na, dtype=a.dtype).to(device)
288 | beta = torch.zeros(nb, dtype=b.dtype).to(device)
289 |
290 | u = torch.ones(na, dtype=a.dtype).to(device) / na
291 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb
292 |
293 | def update_K(alpha, beta):
294 |         """log-space computation of K = exp((alpha + beta - C) / reg);
295 |         memory efficient (in-place ops)"""
296 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K)
297 | torch.add(K, -C, out=K)
298 | torch.div(K, reg, out=K)
299 | torch.exp(K, out=K)
300 |
301 | def update_P(alpha, beta, u, v, ab_updated=False):
302 | """log space P (gamma) computation"""
303 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P)
304 | torch.add(P, -C, out=P)
305 | torch.div(P, reg, out=P)
306 | if not ab_updated:
307 | torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P)
308 | torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P)
309 | torch.exp(P, out=P)
310 |
311 | K = torch.empty(C.shape, dtype=C.dtype).to(device)
312 | update_K(alpha, beta)
313 |
314 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device)
315 |
316 | it = 1
317 | err = 1
318 | ab_updated = False
319 |
320 | # allocate memory beforehand
321 | KTu = torch.empty(v.shape, dtype=v.dtype).to(device)
322 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device)
323 | P = torch.empty(C.shape, dtype=C.dtype).to(device)
324 |
325 | while (err > stopThr and it <= maxIter):
326 | upre, vpre = u, v
327 | torch.matmul(u, K, out=KTu)
328 | v = torch.div(b, KTu + M_EPS)
329 | torch.matmul(K, v, out=Kv)
330 | u = torch.div(a, Kv + M_EPS)
331 |
332 | ab_updated = False
333 |         # absorb u, v into alpha, beta and rebuild K to avoid numerical problems
334 | if u.abs().sum() > tau or v.abs().sum() > tau:
335 | alpha += reg * torch.log(u + M_EPS)
336 | beta += reg * torch.log(v + M_EPS)
337 | u.fill_(1. / na)
338 | v.fill_(1. / nb)
339 | update_K(alpha, beta)
340 | ab_updated = True
341 |
342 | if log and it % eval_freq == 0:
343 |             # we can speed up the process by checking the error only
344 |             # every eval_freq iterations
345 | update_P(alpha, beta, u, v, ab_updated)
346 | b_hat = torch.sum(P, 0)
347 | err = (b - b_hat).pow(2).sum().item()
348 | log['err'].append(err)
349 |
350 | if verbose and it % print_freq == 0:
351 | print('iteration {:5d}, constraint error {:5e}'.format(it, err))
352 |
353 | it += 1
354 |
355 | if log:
356 | log['u'] = u
357 | log['v'] = v
358 | log['alpha'] = alpha + reg * torch.log(u + M_EPS)
359 | log['beta'] = beta + reg * torch.log(v + M_EPS)
360 |
361 | # transport plan
362 | update_P(alpha, beta, u, v, False)
363 |
364 | if log:
365 | return P, log
366 | else:
367 | return P
368 |
369 |
370 | def sinkhorn_epsilon_scaling(a, b, C, reg=1e-1, maxIter=100, maxInnerIter=100, tau=1e3, scaling_base=0.75,
371 | scaling_coef=None, stopThr=1e-9, verbose=False, log=False, warm_start=None, eval_freq=10,
372 | print_freq=200, **kwargs):
373 |     r"""
374 | Solve the entropic regularization OT problem with log stabilization
375 | The function solves the following optimization problem:
376 |
377 | .. math::
378 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma)
379 | s.t. \gamma 1 = a
380 | \gamma^T 1= b
381 | \gamma\geq 0
382 | where :
383 | - C is the (ns,nt) metric cost matrix
384 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
385 | - a and b are target and source measures (sum to 1)
386 |
387 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
388 | scaling algorithm as proposed in [1] but with the log stabilization
389 |     proposed in [3] and the epsilon scaling described in [2] (Algo. 3.2)
390 |
391 | Parameters
392 | ----------
393 | a : torch.tensor (na,)
394 | samples measure in the target domain
395 | b : torch.tensor (nb,)
396 | samples in the source domain
397 | C : torch.tensor (na,nb)
398 | loss matrix
399 | reg : float
400 | Regularization term > 0
401 | tau : float
402 |         threshold for max value in u or v for log scaling
403 | maxIter : int, optional
404 | Max number of iterations
405 | stopThr : float, optional
406 |         Stop threshold on error (> 0)
407 | verbose : bool, optional
408 | Print information along iterations
409 | log : bool, optional
410 | record log if True
411 |
412 | Returns
413 | -------
414 | gamma : (na x nb) torch.tensor
415 | Optimal transportation matrix for the given parameters
416 | log : dict
417 |         log dictionary, returned only if log==True in parameters
418 |
419 | References
420 | ----------
421 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
422 | [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019
423 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
424 |
425 | See Also
426 | --------
427 |
428 | """
429 |
430 | na, nb = C.shape
431 |
432 | assert na >= 1 and nb >= 1, 'C needs to be 2d'
433 |     assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b doesn't match that of C"
434 | assert reg > 0, 'reg should be greater than 0'
435 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0'
436 |
437 | def get_reg(it, reg, pre_reg):
438 | if it == 1:
439 | return scaling_coef
440 | else:
441 | if (pre_reg - reg) * scaling_base < M_EPS:
442 | return reg
443 | else:
444 | return (pre_reg - reg) * scaling_base + reg
445 |
446 | if scaling_coef is None:
447 | scaling_coef = C.max() + reg
448 |
449 | it = 1
450 | err = 1
451 | running_reg = scaling_coef
452 |
453 | if log:
454 | log = {'err': []}
455 |
456 | warm_start = None
457 |
458 | while (err > stopThr and it <= maxIter):
459 | running_reg = get_reg(it, reg, running_reg)
460 | P, _log = sinkhorn_stabilized(a, b, C, running_reg, maxIter=maxInnerIter, tau=tau,
461 | stopThr=stopThr, verbose=False, log=True,
462 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq,
463 | **kwargs)
464 |
465 | warm_start = {}
466 | warm_start['alpha'] = _log['alpha']
467 | warm_start['beta'] = _log['beta']
468 |
469 |         primal_val = (C * P).sum() + reg * (P * torch.log(P + M_EPS)).sum() - reg * P.sum()
470 |         dual_val = (_log['alpha'] * a).sum() + (_log['beta'] * b).sum() - reg * P.sum()
471 |         err = (primal_val - dual_val).item()  # duality gap as convergence criterion
472 |         if log: log['err'].append(err)  # guard: log may be False
473 |
474 | if verbose and it % print_freq == 0:
475 | print('iteration {:5d}, constraint error {:5e}'.format(it, err))
476 |
477 | it += 1
478 |
479 | if log:
480 | log['alpha'] = _log['alpha']
481 | log['beta'] = _log['beta']
482 | return P, log
483 | else:
484 | return P
485 |
--------------------------------------------------------------------------------
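As a quick sanity check for the solver above, the snippet below runs `sinkhorn` on a small random problem and verifies the marginals of the returned transport plan. It assumes execution from the repository root (so `losses` is importable); the measures, cost matrix, and tolerance are illustrative:

```python
import torch

from losses.bregman_pytorch import sinkhorn

a = torch.full((5,), 1.0 / 5)   # target measure (sums to 1)
b = torch.full((8,), 1.0 / 8)   # source measure (sums to 1)
C = torch.rand(5, 8)            # (na, nb) cost matrix

# returns the transport plan and, since log=True, a dict with err/u/v/alpha/beta
P, log = sinkhorn(a, b, C, reg=0.1, method='sinkhorn', log=True)

print(P.shape)                                      # torch.Size([5, 8])
print(torch.allclose(P.sum(dim=1), a, atol=1e-4))   # row marginals ~ a
print(torch.allclose(P.sum(dim=0), b, atol=1e-4))   # column marginals ~ b
```

With the default `stopThr=1e-9` both marginal checks should pass on this well-conditioned toy problem; a smaller `reg` sharpens the plan toward the unregularized optimum but slows convergence.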
/network/pvt_cls.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from functools import partial
5 |
6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7 | from timm.models.registry import register_model
8 | from timm.models.vision_transformer import _cfg
9 |
10 | import math
11 | from torch.distributions.uniform import Uniform
12 | import numpy as np
13 | import random
14 |
15 | __all__ = [
16 |     'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large', 'pvt_treeformer'
17 | ]
18 |
19 |
20 | class SELayer(nn.Module):
21 | def __init__(self, channel, reduction=16):
22 | super(SELayer, self).__init__()
23 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
24 | self.fc = nn.Sequential(
25 | nn.Linear(channel, channel // reduction, bias=False),
26 | nn.ReLU(inplace=True),
27 | nn.Linear(channel // reduction, channel, bias=False),
28 | nn.Sigmoid()
29 | )
30 |
31 | def forward(self, x):
32 | b, c, _, _ = x.size()
33 | y = self.avg_pool(x).view(b, c)
34 | y = self.fc(y).view(b, c, 1, 1)
35 | return x * y.expand_as(x)
36 |
37 |
38 | class Regression(nn.Module):
39 | def __init__(self):
40 | super(Regression, self).__init__()
41 |
42 | self.v1 = nn.Sequential(
43 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
44 | nn.Conv2d(256, 128, 3, padding=1, dilation=1),
45 | nn.BatchNorm2d(128), nn.ReLU(inplace=True))
46 |
47 | self.v2 = nn.Sequential(
48 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
49 | nn.Conv2d(512, 256, 3, padding=1, dilation=1),
50 | nn.BatchNorm2d(256), nn.ReLU(inplace=True))
51 |
52 | self.v3 = nn.Sequential(
53 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
54 | nn.Conv2d(1024, 512, 3, padding=1, dilation=1), nn.BatchNorm2d(512),
55 | nn.ReLU(inplace=True))
56 |
57 | self.ca2 = nn.Sequential(ChannelAttention(512),
58 | nn.Conv2d(512, 512, kernel_size = 3, stride = 1, padding = 1 ),
59 | nn.BatchNorm2d(512), nn.ReLU(inplace=True))
60 |
61 | self.ca1 = nn.Sequential(ChannelAttention(256),
62 | nn.Conv2d(256, 256, kernel_size = 3, stride = 1, padding = 1 ),
63 | nn.BatchNorm2d(256), nn.ReLU(inplace=True))
64 |
65 | self.ca0 = nn.Sequential(ChannelAttention(128),
66 | nn.Conv2d(128, 128, kernel_size = 3, stride = 1, padding = 1 ),
67 | nn.BatchNorm2d(128), nn.ReLU(inplace=True))
68 |
69 | self.res2 = nn.Sequential(
70 | nn.Conv2d(512, 256, 3, padding=1, dilation=1), nn.BatchNorm2d(256),
71 | nn.ReLU(inplace=True),
72 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
73 | nn.ReLU(inplace=True),
74 | nn.Conv2d(128, 1, 3, padding=1, dilation=1),
75 | nn.ReLU(inplace=True))
76 |
77 | self.res1 = nn.Sequential(
78 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), nn.BatchNorm2d(128),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(64, 1, 3, padding=1, dilation=1),
83 | nn.ReLU(inplace=True))
84 |
85 | self.res0 = nn.Sequential(
86 | nn.Conv2d(128, 64, 3, padding=1, dilation=1), nn.BatchNorm2d(64),
87 | nn.ReLU(inplace=True),
88 | nn.Conv2d(64, 1, 3, padding=1, dilation=1),
89 | nn.ReLU(inplace=True))
90 |
91 | self.noise2 = DropOutDecoder(1, 512, 512)
92 | self.noise1 = FeatureDropDecoder(1, 256, 256)
93 | self.noise0 = FeatureNoiseDecoder(1, 128, 128)
94 |
95 | self.upsam2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
96 | self.upsam4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
97 |
98 | self.conv1 = nn.Conv2d(1024, 512, kernel_size=1, bias=False)
99 | self.conv2 = nn.Conv2d(512, 256, kernel_size=1, bias=False)
100 | self.conv3 = nn.Conv2d(256, 128, kernel_size=1, bias=False)
101 | self.conv4 = nn.Conv2d(128, 1, kernel_size=1, bias=False)
102 |
103 |
104 |
105 | self.init_param()
106 |
107 | def forward(self, x, cls):
108 |         x0, x1, x2, x3 = x
109 | cls0 = cls[0].view(cls[0].shape[0], cls[0].shape[1], 1, 1)
110 | cls1 = cls[1].view(cls[1].shape[0], cls[1].shape[1], 1, 1)
111 | cls2 = cls[2].view(cls[2].shape[0], cls[2].shape[1], 1, 1)
112 |
113 | x2_1 = self.ca2(x2)+self.v3(x3)
114 | x1_1 = self.ca1(x1)+self.v2(x2_1)
115 | x0_1 = self.ca0(x0)+self.v1(x1_1)
116 |
117 | if self.training:
118 | yc2 = self.conv4(self.conv3(self.conv2(self.noise2(self.conv1(cls2))))).squeeze()
119 | yc1 = self.conv4(self.conv3(self.noise1(self.conv2(cls1)))).squeeze()
120 | yc0 = self.conv4(self.noise0(self.conv3(cls0))).squeeze()
121 |
122 | y2 = self.res2(self.upsam4(self.noise2(x2_1)))
123 | y1 = self.res1(self.upsam2(self.noise1(x1_1)))
124 | y0 = self.res0(self.noise0(x0_1))
125 |
126 | else:
127 | yc2 = self.conv4(self.conv3(self.conv2(self.conv1(cls2)))).squeeze()
128 | yc1 = self.conv4(self.conv3(self.conv2(cls1))).squeeze()
129 | yc0 = self.conv4(self.conv3(cls0)).squeeze()
130 |
131 | y2 = self.res2(self.upsam4(x2_1))
132 | y1 = self.res1(self.upsam2(x1_1))
133 | y0 = self.res0(x0_1)
134 |
135 | return [y0, y1, y2], [yc0, yc1, yc2]
136 |
137 | def init_param(self):
138 | for m in self.modules():
139 | if isinstance(m, nn.Conv2d):
140 | nn.init.normal_(m.weight, std=0.01)
141 | if m.bias is not None:
142 | nn.init.constant_(m.bias, 0)
143 | elif isinstance(m, nn.BatchNorm2d):
144 | nn.init.constant_(m.weight, 1)
145 | nn.init.constant_(m.bias, 0)
146 |
147 |
148 | class Mlp(nn.Module):
149 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
150 | super().__init__()
151 | out_features = out_features or in_features
152 | hidden_features = hidden_features or in_features
153 | self.fc1 = nn.Linear(in_features, hidden_features)
154 | self.act = act_layer()
155 | self.fc2 = nn.Linear(hidden_features, out_features)
156 | self.drop = nn.Dropout(drop)
157 |
158 | def forward(self, x):
159 | x = self.fc1(x)
160 | x = self.act(x)
161 | x = self.drop(x)
162 | x = self.fc2(x)
163 | x = self.drop(x)
164 | return x
165 |
166 |
167 |
168 | def upsample(in_channels, out_channels, upscale, kernel_size=3):
169 |     # A series of x2 upsamplings until we reach the desired upscale
170 | layers = []
171 | conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
172 | nn.init.kaiming_normal_(conv1x1.weight.data, nonlinearity='relu')
173 | layers.append(conv1x1)
174 | for i in range(int(math.log(upscale, 2))):
175 |         layers.append(PixelShuffle(out_channels, scale=2))  # NOTE: PixelShuffle is not defined in this file; the loop is empty for upscale=1, the only value used below
176 | return nn.Sequential(*layers)
177 |
178 |
179 |
180 | class FeatureDropDecoder(nn.Module):
181 | def __init__(self, upscale, conv_in_ch, num_classes):
182 | super(FeatureDropDecoder, self).__init__()
183 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
184 |
185 | def feature_dropout(self, x):
186 | attention = torch.mean(x, dim=1, keepdim=True)
187 | max_val, _ = torch.max(attention.view(x.size(0), -1), dim=1, keepdim=True)
188 | threshold = max_val * np.random.uniform(0.7, 0.9)
189 | threshold = threshold.view(x.size(0), 1, 1, 1).expand_as(attention)
190 | drop_mask = (attention < threshold).float()
191 | return x.mul(drop_mask)
192 |
193 | def forward(self, x):
194 | x = self.feature_dropout(x)
195 | return x
196 |
197 |
198 | class FeatureNoiseDecoder(nn.Module):
199 | def __init__(self, upscale, conv_in_ch, num_classes, uniform_range=0.3):
200 | super(FeatureNoiseDecoder, self).__init__()
201 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
202 | self.uni_dist = Uniform(-uniform_range, uniform_range)
203 |
204 | def feature_based_noise(self, x):
205 | noise_vector = self.uni_dist.sample(x.shape[1:]).to(x.device).unsqueeze(0)
206 | x_noise = x.mul(noise_vector) + x
207 | return x_noise
208 |
209 | def forward(self, x):
210 | x = self.feature_based_noise(x)
211 | return x
212 |
213 | class DropOutDecoder(nn.Module):
214 | def __init__(self, upscale, conv_in_ch, num_classes, drop_rate=0.3, spatial_dropout=True):
215 | super(DropOutDecoder, self).__init__()
216 | self.dropout = nn.Dropout2d(p=drop_rate) if spatial_dropout else nn.Dropout(drop_rate)
217 | self.upsample = upsample(conv_in_ch, num_classes, upscale=upscale)
218 |
219 | def forward(self, x):
220 | x = self.dropout(x)
221 | return x
222 |
223 |
224 | ## ChannelAttention
225 | class ChannelAttention(nn.Module):
226 | def __init__(self, in_planes, ratio=16):
227 | super(ChannelAttention, self).__init__()
228 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
229 |
230 | self.fc = nn.Sequential(
231 | nn.Linear(in_planes,in_planes // ratio, bias = False),
232 | nn.ReLU(inplace = True),
233 | nn.Linear(in_planes // ratio, in_planes, bias = False)
234 | )
235 | self.sigmoid = nn.Sigmoid()
236 | for m in self.modules():
237 | if isinstance(m, nn.Conv2d):
238 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
239 |
240 | def forward(self, in_feature):
241 | x = in_feature
242 | b, c, _, _ = in_feature.size()
243 | avg_out = self.fc(self.avg_pool(x).view(b,c)).view(b, c, 1, 1)
244 | out = avg_out
245 | return self.sigmoid(out).expand_as(in_feature) * in_feature
246 |
247 |
248 | class Attention(nn.Module):
249 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
250 | super().__init__()
251 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
252 |
253 | self.dim = dim
254 | self.num_heads = num_heads
255 | head_dim = dim // num_heads
256 | self.scale = qk_scale or head_dim ** -0.5
257 |
258 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
259 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
260 | self.attn_drop = nn.Dropout(attn_drop)
261 | self.proj = nn.Linear(dim, dim)
262 | self.proj_drop = nn.Dropout(proj_drop)
263 |
264 | self.sr_ratio = sr_ratio
265 | if sr_ratio > 1:
266 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
267 | self.norm = nn.LayerNorm(dim)
268 |
269 | def forward(self, x, H, W):
270 | B, N, C = x.shape
271 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
272 |
273 |         if self.sr_ratio > 4:  # NOTE: spatial reduction is applied only for sr_ratio > 4 here, although self.sr is built for any sr_ratio > 1 (standard PVT uses > 1)
274 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
275 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
276 | x_ = self.norm(x_)
277 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
278 | else:
279 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
280 | k, v = kv[0], kv[1]
281 |
282 | attn = (q @ k.transpose(-2, -1)) * self.scale
283 | attn = attn.softmax(dim=-1)
284 | attn = self.attn_drop(attn)
285 |
286 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
287 | x = self.proj(x)
288 | x = self.proj_drop(x)
289 |
290 | return x
291 |
292 |
293 | class Block(nn.Module):
294 |
295 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
296 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
297 | super().__init__()
298 | self.norm1 = norm_layer(dim)
299 | self.attn = Attention(
300 | dim,
301 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
302 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
303 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
304 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
305 | self.norm2 = norm_layer(dim)
306 | mlp_hidden_dim = int(dim * mlp_ratio)
307 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
308 |
309 | def forward(self, x, H, W):
310 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
311 | x = x + self.drop_path(self.mlp(self.norm2(x)))
312 |
313 | return x
314 |
315 |
316 | class PatchEmbed(nn.Module):
317 | """ Image to Patch Embedding
318 | """
319 |
320 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
321 | super().__init__()
322 | img_size = to_2tuple(img_size)
323 | patch_size = to_2tuple(patch_size)
324 |
325 | self.img_size = img_size
326 | self.patch_size = patch_size
327 | # assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
328 | # f"img_size {img_size} should be divided by patch_size {patch_size}."
329 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
330 | self.num_patches = self.H * self.W
331 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
332 | self.norm = nn.LayerNorm(embed_dim)
333 |
334 | def forward(self, x):
335 | B, C, H, W = x.shape
336 |
337 | x = self.proj(x).flatten(2).transpose(1, 2)
338 | x = self.norm(x)
339 | H, W = H // self.patch_size[0], W // self.patch_size[1]
340 |
341 | return x, (H, W)
342 |
343 |
344 | class PyramidVisionTransformer(nn.Module):
345 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
346 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
347 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
348 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4):
349 | super().__init__()
350 | self.num_classes = num_classes
351 | self.depths = depths
352 | self.num_stages = num_stages
353 |
354 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
355 | cur = 0
356 |
357 | for i in range(num_stages):
358 | patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
359 | patch_size=patch_size if i == 0 else 2,
360 | in_chans=in_chans if i == 0 else embed_dims[i - 1],
361 | embed_dim=embed_dims[i])
362 | num_patches = patch_embed.num_patches if i == 0 else patch_embed.num_patches + 1
363 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
364 | pos_drop = nn.Dropout(p=drop_rate)
365 |
366 | block = nn.ModuleList([Block(
367 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
368 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
369 | norm_layer=norm_layer, sr_ratio=sr_ratios[i])
370 | for j in range(depths[i])])
371 | cur += depths[i]
372 |
373 | setattr(self, f"patch_embed{i + 1}", patch_embed)
374 | setattr(self, f"pos_embed{i + 1}", pos_embed)
375 | setattr(self, f"pos_drop{i + 1}", pos_drop)
376 | setattr(self, f"block{i + 1}", block)
377 |
378 | self.norm = norm_layer(embed_dims[3])
379 |
380 | # cls_token
381 | self.cls_token_1 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
382 | self.cls_token_2 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
383 | self.cls_token_3 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
384 |
385 | # classification head
386 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
387 |
388 |
389 | self.regression = Regression()
390 |
391 | # init weights
392 | for i in range(num_stages):
393 | pos_embed = getattr(self, f"pos_embed{i + 1}")
394 | trunc_normal_(pos_embed, std=.02)
395 | trunc_normal_(self.cls_token_1, std=.02)
396 | trunc_normal_(self.cls_token_2, std=.02)
397 | trunc_normal_(self.cls_token_3, std=.02)
398 | self.apply(self._init_weights)
399 |
400 |
401 | def _init_weights(self, m):
402 | if isinstance(m, nn.Linear):
403 | trunc_normal_(m.weight, std=.02)
404 | if isinstance(m, nn.Linear) and m.bias is not None:
405 | nn.init.constant_(m.bias, 0)
406 | elif isinstance(m, nn.LayerNorm):
407 | nn.init.constant_(m.bias, 0)
408 | nn.init.constant_(m.weight, 1.0)
409 |
410 | @torch.jit.ignore
411 | def no_weight_decay(self):
412 | # return {'pos_embed', 'cls_token'} # has pos_embed may be better
413 | return {'cls_token'}
414 |
415 | def get_classifier(self):
416 | return self.head
417 |
418 | def reset_classifier(self, num_classes, global_pool=''):
419 | self.num_classes = num_classes
420 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
421 |
422 | def _get_pos_embed(self, pos_embed, patch_embed, H, W):
423 | if H * W == self.patch_embed1.num_patches:
424 | return pos_embed
425 | else:
426 | return F.interpolate(
427 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
428 | size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)
429 |
430 | def forward_features(self, x):
431 | B = x.shape[0]
432 | outputs = list()
433 | cls_output = list()
434 |
435 | for i in range(self.num_stages):
436 | patch_embed = getattr(self, f"patch_embed{i + 1}")
437 | pos_embed = getattr(self, f"pos_embed{i + 1}")
438 | pos_drop = getattr(self, f"pos_drop{i + 1}")
439 | block = getattr(self, f"block{i + 1}")
440 | x, (H, W) = patch_embed(x)
441 |
442 | if i == 0:
443 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)
444 | elif i == 1:
445 | cls_tokens = self.cls_token_1.expand(B, -1, -1)
446 | x = torch.cat((cls_tokens, x), dim=1)
447 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
448 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
449 | elif i == 2:
450 | cls_tokens = self.cls_token_2.expand(B, -1, -1)
451 | x = torch.cat((cls_tokens, x), dim=1)
452 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
453 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
454 |
455 | elif i == 3:
456 | cls_tokens = self.cls_token_3.expand(B, -1, -1)
457 | x = torch.cat((cls_tokens, x), dim=1)
458 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
459 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
460 |
461 |
462 | x = pos_drop(x + pos_embed)
463 | for blk in block:
464 | x = blk(x, H, W)
465 |
466 | if i == 0:
467 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
468 | else:
469 |
470 |                 x_cls = x[:, 1, :]  # NOTE: the prepended class token sits at index 0; this takes the first patch token
471 | x = x[:,1:,:].reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
472 | cls_output.append(x_cls)
473 |
474 |
475 | outputs.append(x)
476 | return outputs, cls_output
477 |
478 |
479 | def forward(self, label_x, unlabel_x=None):
480 |
481 | if self.training:
482 | # labeled image processing
483 | label_x, l_cls = self.forward_features(label_x)
484 | out_label_x, out_cls_l = self.regression(label_x, l_cls)
485 | label_x_1, label_x_2, label_x_3 = out_label_x
486 |
487 | B,C,H,W = label_x_1.size()
488 | label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
489 | label_normed = label_x_1 / (label_sum + 1e-6)
490 |
491 | # unlabeled image processing
492 | B,C,H,W = unlabel_x.shape
493 | unlabel_x, ul_cls = self.forward_features(unlabel_x)
494 | out_unlabel_x, out_cls_ul = self.regression(unlabel_x, ul_cls)
495 | y0, y1, y2 = out_unlabel_x
496 |
497 | unlabel_x_1 = self.generate_feature_patches(y0)
498 | unlabel_x_2 = self.generate_feature_patches(y1)
499 | unlabel_x_3 = self.generate_feature_patches(y2)
500 |
501 | assert unlabel_x_1.shape[0] == B * 5
502 | assert unlabel_x_2.shape[0] == B * 5
503 | assert unlabel_x_3.shape[0] == B * 5
504 |
505 | unlabel_x_1 = torch.split(unlabel_x_1, split_size_or_sections=B, dim=0)
506 | unlabel_x_2 = torch.split(unlabel_x_2, split_size_or_sections=B, dim=0)
507 | unlabel_x_3 = torch.split(unlabel_x_3, split_size_or_sections=B, dim=0)
508 |
509 | return [label_x_1, label_x_2, label_x_3], [unlabel_x_1, unlabel_x_2, unlabel_x_3], label_normed, out_cls_l, out_cls_ul
510 |
511 |
512 | else:
513 |
514 | label_x, l_cls = self.forward_features(label_x)
515 | out_label_x, out_cls_l = self.regression(label_x, l_cls)
516 | label_x_1, label_x_2, label_x_3 = out_label_x
517 | B,C,H,W = label_x_1.size()
518 | label_sum = label_x_1.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3)
519 | label_normed = label_x_1 / (label_sum + 1e-6)
520 |
521 | return [label_x_1, label_x_2, label_x_3], label_normed
522 |
523 |
524 | def generate_feature_patches(self, unlabel_x, ratio=0.75):
525 | # unlabeled image processing
526 |
527 | unlabel_x_1 = unlabel_x
528 | b, c, h, w = unlabel_x.shape
529 |
530 |         center_x = random.randint(int(h // 2 - (h - h * ratio) // 2), int(h // 2 + (h - h * ratio) // 2))
531 |         center_y = random.randint(int(w // 2 - (w - w * ratio) // 2), int(w // 2 + (w - w * ratio) // 2))
532 |
533 | new_h2 = int(h * ratio)
534 | new_w2 = int(w * ratio) # 48*48
535 | unlabel_x_2 = unlabel_x[:, :, center_x - new_h2 // 2:center_x + new_h2 // 2,
536 | center_y - new_w2 // 2:center_y + new_w2 // 2]
537 |
538 | new_h3 = int(new_h2 * ratio)
539 | new_w3 = int(new_w2 * ratio)
540 | unlabel_x_3 = unlabel_x[:, :, center_x - new_h3 // 2:center_x + new_h3 // 2,
541 | center_y - new_w3 // 2:center_y + new_w3 // 2]
542 |
543 | new_h4 = int(new_h3 * ratio)
544 | new_w4 = int(new_w3 * ratio)
545 | unlabel_x_4 = unlabel_x[:, :, center_x - new_h4 // 2:center_x + new_h4 // 2,
546 | center_y - new_w4 // 2:center_y + new_w4 // 2]
547 |
548 | new_h5 = int(new_h4 * ratio)
549 | new_w5 = int(new_w4 * ratio)
550 | unlabel_x_5 = unlabel_x[:, :, center_x - new_h5 // 2:center_x + new_h5 // 2,
551 | center_y - new_w5 // 2:center_y + new_w5 // 2]
552 |
553 | unlabel_x_2 = nn.functional.interpolate(unlabel_x_2, size=(h, w), mode='bilinear')
554 | unlabel_x_3 = nn.functional.interpolate(unlabel_x_3, size=(h, w), mode='bilinear')
555 | unlabel_x_4 = nn.functional.interpolate(unlabel_x_4, size=(h, w), mode='bilinear')
556 | unlabel_x_5 = nn.functional.interpolate(unlabel_x_5, size=(h, w), mode='bilinear')
557 |
558 | unlabel_x = torch.cat([unlabel_x_1, unlabel_x_2, unlabel_x_3, unlabel_x_4, unlabel_x_5], dim=0)
559 |
560 | return unlabel_x
561 |
562 | def _conv_filter(state_dict, patch_size=16):
563 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
564 | out_dict = {}
565 | for k, v in state_dict.items():
566 | if 'patch_embed.proj.weight' in k:
567 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
568 | out_dict[k] = v
569 |
570 | return out_dict
571 |
572 |
573 | @register_model
574 | def pvt_tiny(pretrained=False, **kwargs):
575 | model = PyramidVisionTransformer(
576 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
577 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
578 | **kwargs)
579 | model.default_cfg = _cfg()
580 |
581 | return model
582 |
583 |
584 | @register_model
585 | def pvt_small(pretrained=False, **kwargs):
586 | model = PyramidVisionTransformer(
587 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
588 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
589 | model.default_cfg = _cfg()
590 |
591 | return model
592 |
593 |
594 | @register_model
595 | def pvt_medium(pretrained=False, **kwargs):
596 | model = PyramidVisionTransformer(
597 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
598 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
599 | **kwargs)
600 | model.default_cfg = _cfg()
601 |
602 | return model
603 |
604 |
605 | @register_model
606 | def pvt_large(pretrained=False, **kwargs):
607 | model = PyramidVisionTransformer(
608 | patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
609 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
610 | **kwargs)
611 | model.default_cfg = _cfg()
612 |
613 | return model
614 |
615 | @register_model
616 | def pvt_treeformer(pretrained=False, **kwargs):
617 | model = PyramidVisionTransformer(
618 | patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
619 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
620 | **kwargs)
621 | model.default_cfg = _cfg()
622 |
623 | return model
--------------------------------------------------------------------------------
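Finally, a minimal smoke test for the `pvt_treeformer` backbone defined above. It assumes execution from the repository root; the weights here are randomly initialized, so the predicted count is meaningless and serves only to show the output shapes and the eval-mode interface:

```python
import torch

from network.pvt_cls import pvt_treeformer

model = pvt_treeformer(pretrained=False)
model.eval()  # eval branch: a single (labeled) batch, no unlabeled input
with torch.no_grad():
    x = torch.randn(1, 3, 224, 224)          # H and W divisible by 32
    (d0, d1, d2), d0_normed = model(x)       # three density heads + normalized head

print(d0.shape)          # torch.Size([1, 1, 56, 56]): density map at 1/4 resolution
print(d0.sum().item())   # summing the map gives the tree-count estimate
```

In training mode the forward pass instead takes a labeled and an unlabeled batch and additionally returns the multi-scale crops of the unlabeled predictions (from `generate_feature_patches`) and the class-token heads used by the consistency losses.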