├── utils ├── __init__.py ├── log_utils.py └── pytorch_utils.py ├── datasets ├── __init__.py └── crowd.py ├── losses ├── __init__.py ├── ot_loss.py └── bregman_pytorch.py ├── preprocess ├── __init__.py ├── qnrf_val.txt ├── preprocess_dataset_qnrf.py ├── preprocess_dataset_nwpu.py └── qnrf_train.txt ├── .gitattributes ├── example_images ├── 1.png ├── 2.png └── 3.png ├── vis └── part_A_final │ └── IMG_121 │ ├── image.png │ ├── gt_dmap.png │ └── pred_map.png ├── requirements.txt ├── preprocess_dataset.py ├── README.md ├── .gitignore ├── train.py ├── test_image_patch.py ├── vis_densityMap.py ├── train_helper_ALTGVT.py └── Networks └── ALTGVT.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /example_images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/example_images/1.png -------------------------------------------------------------------------------- /example_images/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/example_images/2.png -------------------------------------------------------------------------------- /example_images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/example_images/3.png -------------------------------------------------------------------------------- /vis/part_A_final/IMG_121/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/vis/part_A_final/IMG_121/image.png -------------------------------------------------------------------------------- /vis/part_A_final/IMG_121/gt_dmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/vis/part_A_final/IMG_121/gt_dmap.png -------------------------------------------------------------------------------- /vis/part_A_final/IMG_121/pred_map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wfs123456/CCTrans/HEAD/vis/part_A_final/IMG_121/pred_map.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/cu113/torch_stable.html 2 | torch==1.10.0+cu113 3 | torchvision==0.11.1+cu113 4 
| torchaudio==0.10.0+cu113 5 | numpy>=1.16.5 6 | scipy>=1.3.0 7 | opencv-python 8 | gdown 9 | pillow 10 | gradio 11 | timm==0.4.12 12 | wandb 13 | matplotlib -------------------------------------------------------------------------------- /utils/log_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def get_logger(log_file): 5 | logger = logging.getLogger(log_file) 6 | logger.setLevel(logging.DEBUG) 7 | fh = logging.FileHandler(log_file) 8 | fh.setLevel(logging.DEBUG) 9 | ch = logging.StreamHandler() 10 | ch.setLevel(logging.INFO) 11 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 12 | ch.setFormatter(formatter) 13 | fh.setFormatter(formatter) 14 | logger.addHandler(ch) 15 | logger.addHandler(fh) 16 | return logger 17 | 18 | 19 | def print_config(config, logger): 20 | """ 21 | Print configuration of the model 22 | """ 23 | for k, v in config.items(): 24 | logger.info("{}:\t{}".format(k.ljust(15), v)) 25 | -------------------------------------------------------------------------------- /preprocess_dataset.py: -------------------------------------------------------------------------------- 1 | # Preprocess images in QNRF and NWPU dataset. 2 | 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser(description='Preprocess') 6 | parser.add_argument('--dataset', default='qnrf', 7 | help='dataset name, only support qnrf and nwpu') 8 | parser.add_argument('--input-dataset-path', default='data/QNRF', 9 | help='original data directory') 10 | parser.add_argument('--output-dataset-path', default='data/QNRF-Train-Val-Test', 11 | help='processed data directory') 12 | args = parser.parse_args() 13 | 14 | if args.dataset.lower() == 'qnrf': 15 | from preprocess.preprocess_dataset_qnrf import main 16 | 17 | main(args.input_dataset_path, args.output_dataset_path, 512, 2048) 18 | elif args.dataset.lower() == 'nwpu': 19 | from preprocess.preprocess_dataset_nwpu import main 20 | 21 | main(args.input_dataset_path, args.output_dataset_path, 384, 1920) 22 | else: 23 | raise NotImplementedError 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CCTrans: Simplifying and Improving Crowd Counting with Transformer(Code reproduction) 2 | * Code reproduction 3 | * Original paper [Link](https://arxiv.org/pdf/2109.14483.pdf) 4 | 5 | ## Overview 6 | * Presentate only the experiment on dataset ShanghaiTech Part A (loss: DM-Count) 7 | * ShanghaiTech Part A 8 | 9 | | Code | MAE | MSE | 10 | |-----------|-------|-------| 11 | | PAPER | 54.8 | 86.6 | 12 | | This code | 54.20 | 88.97 | 13 | 14 | Our code reaches this result with the standard hyperparameter set in code. Trained with batch-size=8 for around 1500 epoch(as said in the paper). Best validation at around epoch 606 15 | # code framework 16 | * adopt code of DM-Count. 17 | * [link](https://github.com/cvlab-stonybrook/DM-Count) 18 | 19 | # Training 20 | Take a look at the arguments accepted by ```train.py``` 21 | * update root "data-dir" in ./train.py. 22 | * load pretrained weights of ImageNet-1k in ./Networks/ALTGVT.py. 23 | * pretrained weights [link](https://drive.google.com/file/d/1um39wxIaicmOquP2fr_SiZdxNCUou8w-/view) 24 | * [new] Added [wandb](https://wandb.ai/) integration. 
If you want to log with wandb, set ```--wandb 1``` in ```train.py``` after having logged in to wandb (```wandb login``` in console) 25 | * launch with ```python train.py``` 26 | 27 | # Testing 28 | * python test_image_patch.py 29 | * Due to crop training with size of 256x256, the validation image is divided into several patches with size of 256x256, and the overlapping area is averaged. 30 | * Download the pretrained model from Baidu-Disk(提取码: se59) [link](https://pan.baidu.com/s/16qY_cFIUAUaDRsdr5vNsWQ) 31 | 32 | # Visualization 33 | * python vis_densityMap.py 34 | * save to ./vis/part_A_final 35 | 36 | # Environment 37 | See requirements.txt 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10): 4 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 5 | lr = max(initial_lr * (0.1 ** (epoch // decay_epoch)), 1e-6) 6 | for param_group in optimizer.param_groups: 7 | param_group['lr'] = lr 8 | 9 | 10 | class Save_Handle(object): 11 | """handle the number of """ 12 | def __init__(self, max_num): 13 | self.save_list = [] 14 | self.max_num = max_num 15 | 16 | def append(self, save_path): 17 | if len(self.save_list) < self.max_num: 18 | self.save_list.append(save_path) 19 | else: 20 | remove_path = self.save_list[0] 21 | del self.save_list[0] 22 | self.save_list.append(save_path) 23 | if os.path.exists(remove_path): 24 | os.remove(remove_path) 25 | 26 | 27 | class AverageMeter(object): 28 | """Computes and stores the average and current value""" 29 | def __init__(self): 30 | self.reset() 31 | 32 | def reset(self): 33 | self.val = 0 34 | self.avg = 0 35 | self.sum = 0 36 | self.count = 0 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = 1.0 * self.sum / self.count 43 | 44 | def get_avg(self): 45 | return self.avg 46 | 47 | def get_count(self): 48 | return self.count 49 | 50 | 51 | def set_trainable(model, requires_grad): 52 | for param in model.parameters(): 53 | param.requires_grad = requires_grad 54 | 55 | 56 | 57 | def get_num_params(model): 58 | return sum(p.numel() for p in model.parameters() if p.requires_grad) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .nox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # IPython 77 | profile_default/ 78 | ipython_config.py 79 | 80 | # pyenv 81 | .python-version 82 | 83 | # celery beat schedule file 84 | celerybeat-schedule 85 | 86 | # SageMath parsed files 87 | *.sage.py 88 | 89 | # Environments 90 | .env 91 | .venv 92 | env/ 93 | venv/ 94 | ENV/ 95 | env.bak/ 96 | venv.bak/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | .spyproject 101 | 102 | # Rope project settings 103 | .ropeproject 104 | 105 | # mkdocs documentation 106 | /site 107 | 108 | # mypy 109 | .mypy_cache/ 110 | .dmypy.json 111 | dmypy.json 112 | 113 | # Pyre type checker 114 | .pyre/ 115 | 116 | # wandb 117 | wandb/ 118 | 119 | #other project file 120 | *_test.txt -------------------------------------------------------------------------------- /preprocess/qnrf_val.txt: -------------------------------------------------------------------------------- 1 | img_0042.jpg 2 | img_0697.jpg 3 | img_0012.jpg 4 | img_0062.jpg 5 | img_0990.jpg 6 | img_1048.jpg 7 | img_0576.jpg 8 | img_0802.jpg 9 | img_0116.jpg 10 | img_0119.jpg 11 | img_0967.jpg 12 | img_0054.jpg 13 | img_0782.jpg 14 | img_0514.jpg 15 | img_0929.jpg 16 | img_0809.jpg 17 | img_0033.jpg 18 | img_0125.jpg 19 | img_0633.jpg 20 | img_0038.jpg 21 | img_0775.jpg 22 | img_0600.jpg 23 | img_0157.jpg 24 | img_0824.jpg 25 | img_0103.jpg 26 | img_0984.jpg 27 | img_0250.jpg 28 | img_0505.jpg 29 | img_0631.jpg 30 | img_0556.jpg 31 | img_1049.jpg 32 | img_1181.jpg 33 | img_0097.jpg 34 | img_0536.jpg 35 | img_1104.jpg 36 | img_0733.jpg 37 | img_1130.jpg 38 | img_0808.jpg 39 | img_0086.jpg 40 | img_0302.jpg 41 | img_0114.jpg 42 | img_0470.jpg 43 | img_0715.jpg 44 | img_0641.jpg 45 | img_0557.jpg 46 | img_0510.jpg 47 | img_0152.jpg 48 | img_0485.jpg 49 | img_0190.jpg 50 | img_0065.jpg 51 | img_0839.jpg 52 | img_0068.jpg 53 | img_0864.jpg 54 | img_0477.jpg 55 | img_0441.jpg 56 | img_0546.jpg 57 | img_0091.jpg 58 | img_0853.jpg 59 | img_0975.jpg 60 | img_0357.jpg 61 | img_1004.jpg 62 | img_0794.jpg 63 | img_0750.jpg 64 | img_0791.jpg 65 | img_0605.jpg 66 | img_0590.jpg 67 | img_0489.jpg 68 | img_0191.jpg 69 | img_0007.jpg 70 | img_0778.jpg 71 | img_0658.jpg 72 | img_0289.jpg 73 | img_0925.jpg 74 | img_1184.jpg 75 | img_0521.jpg 76 | img_0291.jpg 77 | img_0823.jpg 78 | img_0382.jpg 79 | img_0416.jpg 80 | img_0736.jpg 81 | img_0268.jpg 82 | img_0128.jpg 83 | img_0280.jpg 84 | img_1022.jpg 85 | img_0545.jpg 86 | img_0257.jpg 87 | img_0251.jpg 88 | img_0684.jpg 89 | img_1092.jpg 90 | img_0638.jpg 91 | img_1079.jpg 92 | img_0790.jpg 93 | img_0811.jpg 94 | img_0303.jpg 95 | img_0542.jpg 96 | img_1019.jpg 97 | img_0472.jpg 98 | img_0027.jpg 99 | img_0539.jpg 100 | img_0856.jpg 101 | img_1094.jpg 102 | img_1030.jpg 103 | img_1063.jpg 104 | img_0887.jpg 105 | img_0067.jpg 106 | img_0379.jpg 107 | img_0919.jpg 108 | img_1155.jpg 109 | img_0221.jpg 110 | img_1053.jpg 111 | 
img_0916.jpg 112 | img_1072.jpg 113 | img_0347.jpg 114 | img_1199.jpg 115 | img_1080.jpg 116 | img_0385.jpg 117 | img_0344.jpg 118 | img_1073.jpg 119 | img_0339.jpg 120 | img_0338.jpg 121 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from train_helper_ALTGVT import Trainer 5 | 6 | def str2bool(v): 7 | return v.lower() in ("yes", "true", "t", "1") 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Train') 12 | parser.add_argument('--data-dir', default='/home/hdd/dataset_xzw/ShanghaiTech/part_A_final', help='data path') 13 | parser.add_argument('--dataset', default='sha', help='dataset name: qnrf, nwpu, sha, shb, custom') 14 | parser.add_argument('--lr', type=float, default=1e-5, 15 | help='the initial learning rate') 16 | parser.add_argument('--weight-decay', type=float, default=1e-4, 17 | help='the weight decay') 18 | parser.add_argument('--resume', default='', type=str, 19 | help='the path of resume training model') 20 | parser.add_argument('--max-epoch', type=int, default=4000, 21 | help='max training epoch') 22 | parser.add_argument('--val-epoch', type=int, default=1, 23 | help='the num of steps to log training information') 24 | parser.add_argument('--val-start', type=int, default=0, 25 | help='the epoch start to val') 26 | parser.add_argument('--batch-size', type=int, default=16, 27 | help='train batch size') 28 | parser.add_argument('--device', default='0', help='assign device') 29 | parser.add_argument('--num-workers', type=int, default=16, 30 | help='the num of training process') 31 | parser.add_argument('--crop-size', type=int, default= 256, 32 | help='the crop size of the train image') 33 | parser.add_argument('--wot', type=float, default=0.1, help='weight on OT loss') 34 | parser.add_argument('--wtv', type=float, default=0.01, help='weight on TV loss') 35 | parser.add_argument('--reg', type=float, default=10.0, 36 | help='entropy regularization in sinkhorn') 37 | parser.add_argument('--num-of-iter-in-ot', type=int, default=100, 38 | help='sinkhorn iterations') 39 | parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance') 40 | 41 | parser.add_argument('--run-name', default='CCTrans', help='run name for wandb interface/logging') 42 | parser.add_argument('--wandb', default=0, type=int, help='boolean to set wandb logging') 43 | 44 | args = parser.parse_args() 45 | 46 | if args.dataset.lower() == 'qnrf': 47 | args.crop_size = 512 48 | elif args.dataset.lower() == 'nwpu': 49 | args.crop_size = 384 50 | args.val_epoch = 50 51 | elif args.dataset.lower() == 'sha': 52 | args.crop_size = 256 53 | elif args.dataset.lower() == 'shb': 54 | args.crop_size = 512 55 | elif args.dataset.lower() == 'custom': 56 | args.crop_size = 256 57 | else: 58 | raise NotImplementedError 59 | return args 60 | 61 | 62 | if __name__ == '__main__': 63 | args = parse_args() 64 | torch.backends.cudnn.benchmark = True 65 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip() # set vis gpu 66 | trainer = Trainer(args) 67 | trainer.setup() 68 | trainer.train() 69 | -------------------------------------------------------------------------------- /preprocess/preprocess_dataset_qnrf.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | from PIL import Image 3 | import numpy as np 4 | 
import os 5 | from glob import glob 6 | import cv2 7 | 8 | dir_name = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | def cal_new_size(im_h, im_w, min_size, max_size): 11 | if im_h < im_w: 12 | if im_h < min_size: 13 | ratio = 1.0 * min_size / im_h 14 | im_h = min_size 15 | im_w = round(im_w * ratio) 16 | elif im_h > max_size: 17 | ratio = 1.0 * max_size / im_h 18 | im_h = max_size 19 | im_w = round(im_w * ratio) 20 | else: 21 | ratio = 1.0 22 | else: 23 | if im_w < min_size: 24 | ratio = 1.0 * min_size / im_w 25 | im_w = min_size 26 | im_h = round(im_h * ratio) 27 | elif im_w > max_size: 28 | ratio = 1.0 * max_size / im_w 29 | im_w = max_size 30 | im_h = round(im_h * ratio) 31 | else: 32 | ratio = 1.0 33 | return im_h, im_w, ratio 34 | 35 | 36 | def generate_data(im_path, min_size, max_size): 37 | im = Image.open(im_path) 38 | im_w, im_h = im.size 39 | mat_path = im_path.replace('.jpg', '_ann.mat') 40 | points = loadmat(mat_path)['annPoints'].astype(np.float32) 41 | idx_mask = (points[:, 0] >= 0) * (points[:, 0] <= im_w) * (points[:, 1] >= 0) * (points[:, 1] <= im_h) 42 | points = points[idx_mask] 43 | im_h, im_w, rr = cal_new_size(im_h, im_w, min_size, max_size) 44 | im = np.array(im) 45 | if rr != 1.0: 46 | im = cv2.resize(np.array(im), (im_w, im_h), cv2.INTER_CUBIC) 47 | points = points * rr 48 | return Image.fromarray(im), points 49 | 50 | 51 | def main(input_dataset_path, output_dataset_path, min_size=512, max_size=2048): 52 | for phase in ['Train', 'Test']: 53 | sub_dir = os.path.join(input_dataset_path, phase) 54 | if phase == 'Train': 55 | sub_phase_list = ['train', 'val'] 56 | for sub_phase in sub_phase_list: 57 | sub_save_dir = os.path.join(output_dataset_path, sub_phase) 58 | if not os.path.exists(sub_save_dir): 59 | os.makedirs(sub_save_dir) 60 | with open(os.path.join(dir_name, 'qnrf_{}.txt'.format(sub_phase))) as f: 61 | for i in f: 62 | im_path = os.path.join(sub_dir, i.strip()) 63 | name = os.path.basename(im_path) 64 | print(name) 65 | im, points = generate_data(im_path, min_size, max_size) 66 | im_save_path = os.path.join(sub_save_dir, name) 67 | im.save(im_save_path) 68 | gd_save_path = im_save_path.replace('jpg', 'npy') 69 | np.save(gd_save_path, points) 70 | else: 71 | sub_save_dir = os.path.join(output_dataset_path, 'test') 72 | if not os.path.exists(sub_save_dir): 73 | os.makedirs(sub_save_dir) 74 | im_list = glob(os.path.join(sub_dir, '*jpg')) 75 | for im_path in im_list: 76 | name = os.path.basename(im_path) 77 | print(name) 78 | im, points = generate_data(im_path, min_size, max_size) 79 | im_save_path = os.path.join(sub_save_dir, name) 80 | im.save(im_save_path) 81 | gd_save_path = im_save_path.replace('jpg', 'npy') 82 | np.save(gd_save_path, points) 83 | -------------------------------------------------------------------------------- /losses/ot_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | from .bregman_pytorch import sinkhorn 4 | 5 | class OT_Loss(Module): 6 | def __init__(self, c_size, stride, norm_cood, device, num_of_iter_in_ot=100, reg=10.0): 7 | super(OT_Loss, self).__init__() 8 | assert c_size % stride == 0 9 | 10 | self.c_size = c_size 11 | self.device = device 12 | self.norm_cood = norm_cood 13 | self.num_of_iter_in_ot = num_of_iter_in_ot 14 | self.reg = reg 15 | 16 | # coordinate is same to image space, set to constant since crop size is same 17 | self.cood = torch.arange(0, c_size, step=stride, 18 | dtype=torch.float32, device=device) + stride / 2 19 | 
self.density_size = self.cood.size(0) 20 | self.cood.unsqueeze_(0) # [1, #cood] 21 | if self.norm_cood: 22 | self.cood = self.cood / c_size * 2 - 1 # map to [-1, 1] 23 | self.output_size = self.cood.size(1) 24 | 25 | 26 | def forward(self, normed_density, unnormed_density, points): 27 | batch_size = normed_density.size(0) 28 | assert len(points) == batch_size 29 | assert self.output_size == normed_density.size(2) 30 | loss = torch.zeros([1]).to(self.device) 31 | ot_obj_values = torch.zeros([1]).to(self.device) 32 | wd = 0 # wasserstain distance 33 | for idx, im_points in enumerate(points): 34 | if len(im_points) > 0: 35 | # compute l2 square distance, it should be source target distance. [#gt, #cood * #cood] 36 | if self.norm_cood: 37 | im_points = im_points / self.c_size * 2 - 1 # map to [-1, 1] 38 | x = im_points[:, 0].unsqueeze_(1) # [#gt, 1] 39 | y = im_points[:, 1].unsqueeze_(1) 40 | x_dis = -2 * torch.matmul(x, self.cood) + x * x + self.cood * self.cood # [#gt, #cood] 41 | y_dis = -2 * torch.matmul(y, self.cood) + y * y + self.cood * self.cood 42 | y_dis.unsqueeze_(2) 43 | x_dis.unsqueeze_(1) 44 | dis = y_dis + x_dis 45 | dis = dis.view((dis.size(0), -1)) # size of [#gt, #cood * #cood] 46 | 47 | source_prob = normed_density[idx][0].view([-1]).detach() 48 | target_prob = (torch.ones([len(im_points)]) / len(im_points)).to(self.device) 49 | # use sinkhorn to solve OT, compute optimal beta. 50 | P, log = sinkhorn(target_prob, source_prob, dis, self.reg, maxIter=self.num_of_iter_in_ot, log=True) 51 | beta = log['beta'] # size is the same as source_prob: [#cood * #cood] 52 | ot_obj_values += torch.sum(normed_density[idx] * beta.view([1, self.output_size, self.output_size])) 53 | # compute the gradient of OT loss to predicted density (unnormed_density). 54 | # im_grad = beta / source_count - < beta, source_density> / (source_count)^2 55 | source_density = unnormed_density[idx][0].view([-1]).detach() 56 | source_count = source_density.sum() 57 | im_grad_1 = (source_count) / (source_count * source_count+1e-8) * beta # size of [#cood * #cood] 58 | im_grad_2 = (source_density * beta).sum() / (source_count * source_count + 1e-8) # size of 1 59 | im_grad = im_grad_1 - im_grad_2 60 | im_grad = im_grad.detach().view([1, self.output_size, self.output_size]) 61 | # Define loss = . The gradient of loss w.r.t prediced density is im_grad. 
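# i.e. loss = <unnormed_density, im_grad>. Since im_grad is detached, autograd of the sum
# below hands im_grad unchanged to the predicted density, so this surrogate term applies
# exactly the OT gradient derived above (the gradient trick adopted from DM-Count).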
62 | loss += torch.sum(unnormed_density[idx] * im_grad) 63 | wd += torch.sum(dis * P).item() 64 | 65 | return loss, wd, ot_obj_values 66 | 67 | 68 | -------------------------------------------------------------------------------- /test_image_patch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import datasets.crowd as crowd 6 | from Networks import ALTGVT 7 | import torch.nn.functional as F 8 | 9 | def tensor_divideByfactor(img_tensor, factor=32): 10 | _, _, h, w = img_tensor.size() 11 | h, w = int(h//factor*factor), int(w//factor*factor) 12 | img_tensor = F.interpolate(img_tensor, (h, w), mode='bilinear', align_corners=True) 13 | 14 | return img_tensor 15 | def cal_new_tensor(img_tensor, min_size=256): 16 | _, _, h, w = img_tensor.size() 17 | if min(h, w) < min_size: 18 | ratio_h, ratio_w = min_size / h, min_size / w 19 | if ratio_h >= ratio_w: 20 | img_tensor = F.interpolate(img_tensor, (min_size, int(min_size / h * w)), mode='bilinear', align_corners=True) 21 | else: 22 | img_tensor = F.interpolate(img_tensor, (int(min_size / w * h), min_size), mode='bilinear', align_corners=True) 23 | return img_tensor 24 | 25 | parser = argparse.ArgumentParser(description='Test ') 26 | parser.add_argument('--device', default='0', help='assign device') 27 | parser.add_argument('--batch-size', type=int, default=8, 28 | help='train batch size') 29 | parser.add_argument('--crop-size', type=int, default=256, 30 | help='the crop size of the train image') 31 | parser.add_argument('--model-path', type=str, required=True, 32 | help='saved model path') 33 | parser.add_argument('--data-path', type=str, 34 | help='dataset path') 35 | parser.add_argument('--dataset', type=str, default='sha', 36 | help='dataset name: qnrf, nwpu, sha, shb, custom') 37 | parser.add_argument('--pred-density-map-path', type=str, default='inference_results', 38 | help='save predicted density maps when pred-density-map-path is not empty.') 39 | 40 | def test(args, isSave = True): 41 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu 42 | device = torch.device('cuda') 43 | 44 | model_path = args.model_path 45 | crop_size = args.crop_size 46 | data_path = args.data_path 47 | if args.dataset.lower() == 'qnrf': 48 | dataset = crowd.Crowd_qnrf(os.path.join(data_path, 'test'), crop_size, 8, method='val') 49 | elif args.dataset.lower() == 'nwpu': 50 | dataset = crowd.Crowd_nwpu(os.path.join(data_path, 'val'), crop_size, 8, method='val') 51 | elif args.dataset.lower() == 'sha' or args.dataset.lower() == 'shb': 52 | dataset = crowd.Crowd_sh(os.path.join(data_path, 'test_data'), crop_size, 8, method='val') 53 | elif args.dataset.lower() == 'custom': 54 | dataset = crowd.CustomDataset(data_path, crop_size, downsample_ratio=8, method='test') 55 | else: 56 | raise NotImplementedError 57 | dataloader = torch.utils.data.DataLoader(dataset, 1, shuffle=False, 58 | num_workers=1, pin_memory=True) 59 | 60 | model = ALTGVT.alt_gvt_large(pretrained=True) 61 | model.to(device) 62 | model.load_state_dict(torch.load(model_path, device)) 63 | model.eval() 64 | image_errs = [] 65 | result = [] 66 | for inputs, count, name in dataloader: 67 | with torch.no_grad(): 68 | # nputs = cal_new_tensor(inputs, min_size=args.crop_size) 69 | inputs = inputs.to(device) 70 | crop_imgs, crop_masks = [], [] 71 | b, c, h, w = inputs.size() 72 | rh, rw = args.crop_size, args.crop_size 73 | for i in range(0, h, rh): 74 | gis, gie = max(min(h - rh, i), 
0), min(h, i + rh) 75 | for j in range(0, w, rw): 76 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 77 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) 78 | mask = torch.zeros([b, 1, h, w]).to(device) 79 | mask[:, :, gis:gie, gjs:gje].fill_(1.0) 80 | crop_masks.append(mask) 81 | crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks)) 82 | 83 | crop_preds = [] 84 | nz, bz = crop_imgs.size(0), args.batch_size 85 | for i in range(0, nz, bz): 86 | gs, gt = i, min(nz, i + bz) 87 | crop_pred, _ = model(crop_imgs[gs:gt]) 88 | 89 | _, _, h1, w1 = crop_pred.size() 90 | 91 | crop_pred = F.interpolate(crop_pred, size=(h1 * 8, w1 * 8), mode='bilinear', align_corners=True) / 64 92 | 93 | crop_preds.append(crop_pred) 94 | crop_preds = torch.cat(crop_preds, dim=0) 95 | 96 | # splice them to the original size 97 | idx = 0 98 | pred_map = torch.zeros([b, 1, h, w]).to(device) 99 | for i in range(0, h, rh): 100 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 101 | for j in range(0, w, rw): 102 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 103 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx] 104 | idx += 1 105 | # for the overlapping area, compute average value 106 | mask = crop_masks.sum(dim=0).unsqueeze(0) 107 | outputs = pred_map / mask 108 | 109 | img_err = count[0].item() - torch.sum(outputs).item() 110 | print("Img name: ", name, "Error: ", img_err, "GT count: ", count[0].item(), "Model out: ", torch.sum(outputs).item()) 111 | image_errs.append(img_err) 112 | result.append([name, count[0].item(), torch.sum(outputs).item(), img_err]) 113 | 114 | image_errs = np.array(image_errs) 115 | mse = np.sqrt(np.mean(np.square(image_errs))) 116 | mae = np.mean(np.abs(image_errs)) 117 | print('{}: mae {}, mse {}\n'.format(model_path, mae, mse)) 118 | 119 | if isSave: 120 | with open("ALGVT_sha_test.txt","w") as f: 121 | for i in range(len(result)): 122 | f.write(str(result[i]).replace('[','').replace(']','').replace(',', ' ')+"\n") 123 | f.close() 124 | 125 | if __name__ == '__main__': 126 | args = parser.parse_args() 127 | test(args, isSave= True) 128 | 129 | 130 | -------------------------------------------------------------------------------- /preprocess/preprocess_dataset_nwpu.py: -------------------------------------------------------------------------------- 1 | from scipy.io import loadmat 2 | from PIL import Image 3 | import numpy as np 4 | import os 5 | import cv2 6 | 7 | 8 | def cal_new_size_v2(im_h, im_w, min_size, max_size): 9 | rate = 1.0 * max_size / im_h 10 | rate_w = im_w * rate 11 | if rate_w > max_size: 12 | rate = 1.0 * max_size / im_w 13 | tmp_h = int(1.0 * im_h * rate / 16) * 16 14 | 15 | if tmp_h < min_size: 16 | rate = 1.0 * min_size / im_h 17 | tmp_w = int(1.0 * im_w * rate / 16) * 16 18 | 19 | if tmp_w < min_size: 20 | rate = 1.0 * min_size / im_w 21 | tmp_h = min(max(int(1.0 * im_h * rate / 16) * 16, min_size), max_size) 22 | tmp_w = min(max(int(1.0 * im_w * rate / 16) * 16, min_size), max_size) 23 | 24 | rate_h = 1.0 * tmp_h / im_h 25 | rate_w = 1.0 * tmp_w / im_w 26 | assert tmp_h >= min_size and tmp_h <= max_size 27 | assert tmp_w >= min_size and tmp_w <= max_size 28 | return tmp_h, tmp_w, rate_h, rate_w 29 | 30 | 31 | def gen_density_map_gaussian(im_height, im_width, points, sigma=4): 32 | """ 33 | func: generate the density map. 
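Each annotated point is rendered as a unit-mass Gaussian kernel (width set by sigma, in pixels),
clipped at the image border; the map is finally rescaled so that it sums to the number of annotated heads.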
34 | points: [num_gt, 2], for each row: [width, height] 35 | """ 36 | density_map = np.zeros([im_height, im_width], dtype=np.float32) 37 | h, w = density_map.shape[:2] 38 | num_gt = np.squeeze(points).shape[0] 39 | if num_gt == 0: 40 | return density_map 41 | for p in points: 42 | p = np.round(p).astype(int) 43 | p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0]) 44 | gaussian_radius = sigma * 2 - 1 45 | gaussian_map = np.multiply( 46 | cv2.getGaussianKernel(int(gaussian_radius * 2 + 1), sigma), 47 | cv2.getGaussianKernel(int(gaussian_radius * 2 + 1), sigma).T 48 | ) 49 | x_left, x_right, y_up, y_down = 0, gaussian_map.shape[1], 0, gaussian_map.shape[0] 50 | # cut the gaussian kernel 51 | if p[1] < gaussian_radius: 52 | x_left = gaussian_radius - p[1] 53 | if p[0] < gaussian_radius: 54 | y_up = gaussian_radius - p[0] 55 | if p[1] + gaussian_radius >= w: 56 | x_right = gaussian_map.shape[1] - (gaussian_radius + p[1] - w) - 1 57 | if p[0] + gaussian_radius >= h: 58 | y_down = gaussian_map.shape[0] - (gaussian_radius + p[0] - h) - 1 59 | gaussian_map = gaussian_map[y_up:y_down, x_left:x_right] 60 | if np.sum(gaussian_map): 61 | gaussian_map = gaussian_map / np.sum(gaussian_map) 62 | density_map[ 63 | max(0, p[0] - gaussian_radius):min(h, p[0] + gaussian_radius + 1), 64 | max(0, p[1] - gaussian_radius):min(w, p[1] + gaussian_radius + 1) 65 | ] += gaussian_map 66 | density_map = density_map / (np.sum(density_map / num_gt)) 67 | return density_map 68 | 69 | 70 | def generate_data(im_path, mat_path, min_size, max_size): 71 | im = Image.open(im_path).convert('RGB') 72 | im_w, im_h = im.size 73 | points = loadmat(mat_path)['annPoints'].astype(np.float32) 74 | if len(points) > 0: # some image has no crowd 75 | idx_mask = (points[:, 0] >= 0) * (points[:, 0] <= im_w) * (points[:, 1] >= 0) * (points[:, 1] <= im_h) 76 | points = points[idx_mask] 77 | im_h, im_w, rr_h, rr_w = cal_new_size_v2(im_h, im_w, min_size, max_size) 78 | im = np.array(im) 79 | if rr_h != 1.0 or rr_w != 1.0: 80 | im = cv2.resize(np.array(im), (im_w, im_h), cv2.INTER_CUBIC) 81 | if len(points) > 0: # some image has no crowd 82 | points[:, 0] = points[:, 0] * rr_w 83 | points[:, 1] = points[:, 1] * rr_h 84 | 85 | density_map = gen_density_map_gaussian(im_h, im_w, points, sigma=8) 86 | return Image.fromarray(im), points, density_map 87 | 88 | 89 | def generate_image(im_path, min_size, max_size): 90 | im = Image.open(im_path).convert('RGB') 91 | im_w, im_h = im.size 92 | im_h, im_w, rr_h, rr_w = cal_new_size_v2(im_h, im_w, min_size, max_size) 93 | im = np.array(im) 94 | if rr_h != 1.0 or rr_w != 1.0: 95 | im = cv2.resize(np.array(im), (im_w, im_h), cv2.INTER_CUBIC) 96 | return Image.fromarray(im) 97 | 98 | 99 | def main(input_dataset_path, output_dataset_path, min_size=384, max_size=1920): 100 | ori_img_path = os.path.join(input_dataset_path, 'images') 101 | ori_anno_path = os.path.join(input_dataset_path, 'mats') 102 | 103 | for phase in ['train', 'val']: 104 | sub_save_dir = os.path.join(output_dataset_path, phase) 105 | if not os.path.exists(sub_save_dir): 106 | os.makedirs(sub_save_dir) 107 | with open(os.path.join(input_dataset_path, '{}.txt'.format(phase))) as f: 108 | lines = f.readlines() 109 | for i in lines: 110 | i = i.strip().split(' ')[0] 111 | im_path = os.path.join(ori_img_path, i + '.jpg') 112 | mat_path = os.path.join(ori_anno_path, i + '.mat') 113 | name = os.path.basename(im_path) 114 | im_save_path = os.path.join(sub_save_dir, name) 115 | print(name) 116 | # The Gaussian smoothed density map is just for visualization. 
It's not used in training. 117 | im, points, density_map = generate_data(im_path, mat_path, min_size, max_size) 118 | im.save(im_save_path) 119 | gd_save_path = im_save_path.replace('jpg', 'npy') 120 | np.save(gd_save_path, points) 121 | dm_save_path = im_save_path.replace('.jpg', '_densitymap.npy') 122 | np.save(dm_save_path, density_map) 123 | 124 | for phase in ['test']: 125 | sub_save_dir = os.path.join(output_dataset_path, phase) 126 | if not os.path.exists(sub_save_dir): 127 | os.makedirs(sub_save_dir) 128 | with open(os.path.join(input_dataset_path, '{}.txt'.format(phase))) as f: 129 | lines = f.readlines() 130 | for i in lines: 131 | i = i.strip().split(' ')[0] 132 | im_path = os.path.join(ori_img_path, i + '.jpg') 133 | name = os.path.basename(im_path) 134 | im_save_path = os.path.join(sub_save_dir, name) 135 | print(name) 136 | im = generate_image(im_path, min_size, max_size) 137 | im.save(im_save_path) 138 | -------------------------------------------------------------------------------- /vis_densityMap.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | # @author : 王福森 3 | # @time : 2021/12/3 20:54 4 | # @File : vis_densityMap.py 5 | # @Software : PyCharm 6 | from Networks import ALTGVT 7 | import numpy as np 8 | from torch.autograd import Variable 9 | import torchvision.transforms as transforms 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | import scipy.io as io 13 | import torch 14 | import os 15 | import argparse 16 | import torch.nn.functional as F 17 | import cv2 18 | 19 | def vis(args): 20 | os.environ['CUDA_VISIBLE_DEVICES'] = args.device # set vis gpu 21 | device = torch.device('cuda') 22 | 23 | model_path = args.weight_path 24 | crop_size = args.crop_size 25 | image_path = args.image_path 26 | 27 | model = ALTGVT.alt_gvt_large(pretrained=True) 28 | model.to(device) 29 | model.load_state_dict(torch.load(model_path, device)) 30 | model.eval() 31 | 32 | if not os.path.exists(image_path): 33 | print("not find image path!") 34 | exit(-1) 35 | 36 | if image_path.split("/")[-4] != "QNRF": 37 | dataset = "/".join(image_path.split("/")[-5:-3]) 38 | else: 39 | dataset = "QNRF" 40 | 41 | print("detect image '%s'..." % image_path) 42 | if not os.path.exists(image_path): 43 | print("not find image path!") 44 | exit(-1) 45 | 46 | if dataset == "QNRF": 47 | mat = io.loadmat( 48 | image_path.replace('.jpg', '.mat').replace('images', 'ground_truth').replace('.', '_ann.').replace( 49 | "UCF-QNRF-Nor", "UCF-QNRF")) 50 | points = mat["annPoints"] 51 | else: 52 | mat = io.loadmat(image_path.replace('.jpg', '.mat').replace('images', 'ground_truth').replace("IMG", "GT_IMG")) 53 | points = mat["image_info"][0, 0][0, 0][0] 54 | 55 | gt_count = len(points) 56 | image = Image.open(image_path).convert("RGB") 57 | wd, ht = image.size 58 | st_size = 1.0 * min(wd, ht) 59 | if st_size < crop_size: 60 | rr = 1.0 * crop_size / st_size 61 | wd = round(wd * rr) 62 | ht = round(ht * rr) 63 | st_size = 1.0 * min(wd, ht) 64 | image = image.resize((wd, ht), Image.BICUBIC) 65 | 66 | # image = np.asarray(image, dtype=np.float32) 67 | # if len(image.shape) == 2: # expand grayscale image to three channel. 
68 | # image = image[:, :, np.newaxis] 69 | # image = np.concatenate((image, image, image), 2) 70 | # vis_img = image.copy() 71 | 72 | transform = transforms.Compose([ 73 | transforms.ToTensor(), 74 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 75 | ]) 76 | image = transform(image) 77 | gt_dmap_path = image_path.replace('.jpg', '.npy').replace('images', 'density_maps_constant15') 78 | gt_dmap = np.load(gt_dmap_path) 79 | 80 | with torch.no_grad(): 81 | # nputs = cal_new_tensor(inputs, min_size=args.crop_size) 82 | inputs = image.unsqueeze(0).to(device) 83 | crop_imgs, crop_masks = [], [] 84 | b, c, h, w = inputs.size() 85 | rh, rw = args.crop_size, args.crop_size 86 | for i in range(0, h, rh): 87 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 88 | for j in range(0, w, rw): 89 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 90 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) 91 | mask = torch.zeros([b, 1, h, w]).to(device) 92 | mask[:, :, gis:gie, gjs:gje].fill_(1.0) 93 | crop_masks.append(mask) 94 | crop_imgs, crop_masks = map(lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks)) 95 | 96 | crop_preds = [] 97 | nz, bz = crop_imgs.size(0), args.batch_size 98 | for i in range(0, nz, bz): 99 | gs, gt = i, min(nz, i + bz) 100 | crop_pred, _ = model(crop_imgs[gs:gt]) 101 | 102 | _, _, h1, w1 = crop_pred.size() 103 | crop_pred = F.interpolate(crop_pred, size=(h1 * 8, w1 * 8), mode='bilinear', align_corners=True) / 64 104 | 105 | crop_preds.append(crop_pred) 106 | crop_preds = torch.cat(crop_preds, dim=0) 107 | 108 | # splice them to the original size 109 | idx = 0 110 | pred_map = torch.zeros([b, 1, h, w]).to(device) 111 | for i in range(0, h, rh): 112 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 113 | for j in range(0, w, rw): 114 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 115 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx] 116 | idx += 1 117 | # for the overlapping area, compute average value 118 | mask = crop_masks.sum(dim=0).unsqueeze(0) 119 | pred_map = pred_map / mask 120 | pred_map = pred_map.squeeze(0).squeeze(0).cpu().data.numpy() 121 | return pred_map, gt_dmap, gt_count 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--device', default='0', help='assign device') 127 | parser.add_argument("--image_path", type=str, required=True, 128 | help="the image path to be detected.") 129 | parser.add_argument("--weight_path", type=str, required=True, 130 | help="the weight path to be loaded") 131 | parser.add_argument('--crop_size', type=int, default=256, 132 | help='the crop size of the train image') 133 | parser.add_argument('--batch-size', type=int, default=1, help='train batch size') 134 | 135 | args = parser.parse_args() 136 | print(args) 137 | 138 | pred_map, gt_dmap, gt_count = vis(args) 139 | 140 | save_path = "vis/%s"%(args.image_path.split("/")[-4]+"/"+args.image_path.split("/")[-1][:-4]) 141 | if not os.path.exists(save_path): 142 | os.makedirs(save_path) 143 | 144 | print("predmap count is %.2f, gt_dmap count is %.2f, gt count is %d"%(pred_map.sum(),gt_dmap.sum(),gt_count)) 145 | 146 | vis_img = pred_map 147 | # normalize density map values from 0 to 1, then map it to 0-255. 
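# the 1e-5 in the denominator below avoids division by zero when the predicted map is constant (e.g. all zeros).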
148 | vis_img = (vis_img - vis_img.min()) / (vis_img.max() - vis_img.min() + 1e-5) 149 | vis_img = (vis_img * 255).astype(np.uint8) 150 | vis_img = cv2.applyColorMap(vis_img, cv2.COLORMAP_JET) 151 | cv2.imwrite("%s/pred_map.png" % save_path, vis_img) 152 | 153 | # plt.imsave("%s/pred_map.png" % save_path, pred_map) 154 | plt.imsave("%s/gt_dmap.png" % save_path, gt_dmap, cmap = 'jet') 155 | 156 | print("the visual result saved in %s"%save_path) -------------------------------------------------------------------------------- /datasets/crowd.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch.utils.data as data 3 | import os 4 | from glob import glob 5 | import torch 6 | import torchvision.transforms.functional as F 7 | from torchvision import transforms 8 | import random 9 | import numpy as np 10 | import scipy.io as sio 11 | 12 | 13 | def random_crop(im_h, im_w, crop_h, crop_w): 14 | res_h = im_h - crop_h 15 | res_w = im_w - crop_w 16 | i = random.randint(0, res_h) 17 | j = random.randint(0, res_w) 18 | return i, j, crop_h, crop_w 19 | 20 | 21 | def gen_discrete_map(im_height, im_width, points): 22 | """ 23 | func: generate the discrete map. 24 | points: [num_gt, 2], for each row: [width, height] 25 | """ 26 | discrete_map = np.zeros([im_height, im_width], dtype=np.float32) 27 | h, w = discrete_map.shape[:2] 28 | num_gt = points.shape[0] 29 | if num_gt == 0: 30 | return discrete_map 31 | 32 | # fast create discrete map 33 | points_np = np.array(points).round().astype(int) 34 | p_h = np.minimum(points_np[:, 1], np.array([h-1]*num_gt).astype(int)) 35 | p_w = np.minimum(points_np[:, 0], np.array([w-1]*num_gt).astype(int)) 36 | p_index = torch.from_numpy(p_h* im_width + p_w).to(torch.int64) 37 | discrete_map = torch.zeros(im_width * im_height).scatter_add_(0, index=p_index, src=torch.ones(im_width*im_height)).view(im_height, im_width).numpy() 38 | 39 | ''' slow method 40 | for p in points: 41 | p = np.round(p).astype(int) 42 | p[0], p[1] = min(h - 1, p[1]), min(w - 1, p[0]) 43 | discrete_map[p[0], p[1]] += 1 44 | ''' 45 | assert np.sum(discrete_map) == num_gt 46 | return discrete_map 47 | 48 | 49 | class Base(data.Dataset): 50 | def __init__(self, root_path, crop_size, downsample_ratio=8): 51 | 52 | self.root_path = root_path 53 | self.c_size = crop_size 54 | self.d_ratio = downsample_ratio 55 | assert self.c_size % self.d_ratio == 0 56 | self.dc_size = self.c_size // self.d_ratio 57 | self.trans = transforms.Compose([ 58 | transforms.ToTensor(), 59 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 60 | ]) 61 | 62 | def __len__(self): 63 | pass 64 | 65 | def __getitem__(self, item): 66 | pass 67 | 68 | def train_transform(self, img, keypoints): 69 | wd, ht = img.size 70 | st_size = 1.0 * min(wd, ht) 71 | assert st_size >= self.c_size 72 | assert len(keypoints) >= 0 73 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 74 | img = F.crop(img, i, j, h, w) 75 | if len(keypoints) > 0: 76 | keypoints = keypoints - [j, i] 77 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \ 78 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h) 79 | keypoints = keypoints[idx_mask] 80 | else: 81 | keypoints = np.empty([0, 2]) 82 | 83 | gt_discrete = gen_discrete_map(h, w, keypoints) 84 | down_w = w // self.d_ratio 85 | down_h = h // self.d_ratio 86 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3)) 87 | assert np.sum(gt_discrete) == len(keypoints) 88 | 89 
| if len(keypoints) > 0: 90 | if random.random() > 0.5: 91 | img = F.hflip(img) 92 | gt_discrete = np.fliplr(gt_discrete) 93 | keypoints[:, 0] = w - keypoints[:, 0] 94 | else: 95 | if random.random() > 0.5: 96 | img = F.hflip(img) 97 | gt_discrete = np.fliplr(gt_discrete) 98 | gt_discrete = np.expand_dims(gt_discrete, 0) 99 | 100 | return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy( 101 | gt_discrete.copy()).float() 102 | 103 | 104 | class Crowd_qnrf(Base): 105 | def __init__(self, root_path, crop_size, 106 | downsample_ratio=8, 107 | method='train'): 108 | super().__init__(root_path, crop_size, downsample_ratio) 109 | self.method = method 110 | self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg'))) 111 | print('number of img: {}'.format(len(self.im_list))) 112 | if method not in ['train', 'val']: 113 | raise Exception("not implement") 114 | 115 | def __len__(self): 116 | return len(self.im_list) 117 | 118 | def __getitem__(self, item): 119 | img_path = self.im_list[item] 120 | gd_path = img_path.replace('jpg', 'npy') 121 | img = Image.open(img_path).convert('RGB') 122 | if self.method == 'train': 123 | keypoints = np.load(gd_path) 124 | return self.train_transform(img, keypoints) 125 | elif self.method == 'val': 126 | keypoints = np.load(gd_path) 127 | img = self.trans(img) 128 | name = os.path.basename(img_path).split('.')[0] 129 | return img, len(keypoints), name 130 | 131 | 132 | class Crowd_nwpu(Base): 133 | def __init__(self, root_path, crop_size, 134 | downsample_ratio=8, 135 | method='train'): 136 | super().__init__(root_path, crop_size, downsample_ratio) 137 | self.method = method 138 | self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg'))) 139 | print('number of img: {}'.format(len(self.im_list))) 140 | 141 | if method not in ['train', 'val', 'test']: 142 | raise Exception("not implement") 143 | 144 | def __len__(self): 145 | return len(self.im_list) 146 | 147 | def __getitem__(self, item): 148 | img_path = self.im_list[item] 149 | gd_path = img_path.replace('jpg', 'npy') 150 | img = Image.open(img_path).convert('RGB') 151 | if self.method == 'train': 152 | keypoints = np.load(gd_path) 153 | return self.train_transform(img, keypoints) 154 | elif self.method == 'val': 155 | keypoints = np.load(gd_path) 156 | img = self.trans(img) 157 | name = os.path.basename(img_path).split('.')[0] 158 | return img, len(keypoints), name 159 | elif self.method == 'test': 160 | img = self.trans(img) 161 | name = os.path.basename(img_path).split('.')[0] 162 | return img, name 163 | 164 | 165 | class Crowd_sh(Base): 166 | def __init__(self, root_path, crop_size, 167 | downsample_ratio=8, 168 | method='train'): 169 | super().__init__(root_path, crop_size, downsample_ratio) 170 | self.method = method 171 | if method not in ['train', 'val']: 172 | raise Exception("not implement") 173 | 174 | self.im_list = sorted(glob(os.path.join(self.root_path, 'images', '*.jpg'))) 175 | 176 | print('number of img [{}]: {}'.format(method, len(self.im_list))) 177 | 178 | def __len__(self): 179 | return len(self.im_list) 180 | 181 | def __getitem__(self, item): 182 | img_path = self.im_list[item] 183 | name = os.path.basename(img_path).split('.')[0] 184 | gd_path = os.path.join(self.root_path, 'ground-truth', 'GT_{}.mat'.format(name)) 185 | img = Image.open(img_path).convert('RGB') 186 | keypoints = sio.loadmat(gd_path)['image_info'][0][0][0][0][0] 187 | if self.method == 'train': 188 | return self.train_transform(img, keypoints) 189 | elif self.method == 'val': 
190 | wd, ht = img.size 191 | st_size = 1.0 * min(wd, ht) 192 | if st_size < self.c_size: 193 | rr = 1.0 * self.c_size / st_size 194 | wd = round(wd * rr) 195 | ht = round(ht * rr) 196 | st_size = 1.0 * min(wd, ht) 197 | img = img.resize((wd, ht), Image.BICUBIC) 198 | img = self.trans(img) 199 | return img, len(keypoints), name 200 | 201 | def train_transform(self, img, keypoints): 202 | wd, ht = img.size 203 | st_size = 1.0 * min(wd, ht) 204 | # resize the image to fit the crop size 205 | if st_size < self.c_size: 206 | rr = 1.0 * self.c_size / st_size 207 | wd = round(wd * rr) 208 | ht = round(ht * rr) 209 | st_size = 1.0 * min(wd, ht) 210 | img = img.resize((wd, ht), Image.BICUBIC) 211 | keypoints = keypoints * rr 212 | assert st_size >= self.c_size, print(wd, ht) 213 | assert len(keypoints) >= 0 214 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 215 | img = F.crop(img, i, j, h, w) 216 | if len(keypoints) > 0: 217 | keypoints = keypoints - [j, i] 218 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \ 219 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h) 220 | keypoints = keypoints[idx_mask] 221 | else: 222 | keypoints = np.empty([0, 2]) 223 | 224 | gt_discrete = gen_discrete_map(h, w, keypoints) 225 | down_w = w // self.d_ratio 226 | down_h = h // self.d_ratio 227 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3)) 228 | assert np.sum(gt_discrete) == len(keypoints) 229 | 230 | if len(keypoints) > 0: 231 | if random.random() > 0.5: 232 | img = F.hflip(img) 233 | gt_discrete = np.fliplr(gt_discrete) 234 | keypoints[:, 0] = w - keypoints[:, 0] - 1 235 | else: 236 | if random.random() > 0.5: 237 | img = F.hflip(img) 238 | gt_discrete = np.fliplr(gt_discrete) 239 | gt_discrete = np.expand_dims(gt_discrete, 0) 240 | 241 | return self.trans(img), torch.from_numpy(keypoints.copy()).float(), torch.from_numpy( 242 | gt_discrete.copy()).float() 243 | 244 | 245 | class CustomDataset(Base): 246 | ''' 247 | Class that allows training for a custom dataset. The folder are designed in the following way: 248 | root_dataset_path: 249 | -> images_1 250 | ->another_folder_with_image 251 | ->train.list 252 | ->valid.list 253 | 254 | The content of the lists file (csv with space as separator) are: 255 | img_xx__path label_xx_path 256 | img_xx1__path label_xx1_path 257 | 258 | where label_xx_path contains a list of x,y position of the head. 
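    For example (hypothetical file names), one row of train.list could be

        images_1/scene_0001.jpg labels/scene_0001.txt

    where labels/scene_0001.txt lists one head per line as space-separated "x y"
    pixel coordinates, e.g.

        312.4 208.9
        455.0 190.2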
259 | ''' 260 | def __init__(self, root_path, crop_size, 261 | downsample_ratio=8, 262 | method='train'): 263 | super().__init__(root_path, crop_size, downsample_ratio) 264 | self.method = method 265 | if method not in ['train', 'valid', 'test']: 266 | raise Exception("not implement") 267 | 268 | # read the list file 269 | self.img_to_label = {} 270 | list_file = f'{method}.list' # train.list, valid.list or test.list 271 | with open(os.path.join(self.root_path, list_file)) as fin: 272 | for line in fin: 273 | if len(line) < 2: 274 | continue 275 | line = line.strip().split() 276 | self.img_to_label[os.path.join(self.root_path, line[0].strip())] = \ 277 | os.path.join(self.root_path, line[1].strip()) 278 | self.img_list = sorted(list(self.img_to_label.keys())) 279 | 280 | 281 | print('number of img [{}]: {}'.format(method, len(self.img_list))) 282 | 283 | def __len__(self): 284 | return len(self.img_list) 285 | 286 | def __getitem__(self, item): 287 | img_path = self.img_list[item] 288 | gt_path = self.img_to_label[img_path] 289 | img_name = os.path.basename(img_path).split('.')[0] 290 | 291 | img = Image.open(img_path).convert('RGB') 292 | keypoints = self.load_head_annotation(gt_path) 293 | 294 | if self.method == 'train': 295 | return self.train_transform(img, keypoints) 296 | elif self.method == 'valid' or self.method == 'test': 297 | wd, ht = img.size 298 | st_size = 1.0 * min(wd, ht) 299 | if st_size < self.c_size: 300 | rr = 1.0 * self.c_size / st_size 301 | wd = round(wd * rr) 302 | ht = round(ht * rr) 303 | st_size = 1.0 * min(wd, ht) 304 | img = img.resize((wd, ht), Image.BICUBIC) 305 | img = self.trans(img) 306 | return img, len(keypoints), img_name 307 | 308 | def load_head_annotation(self, gt_path): 309 | annotations = [] 310 | with open(gt_path) as annotation: 311 | for line in annotation: 312 | x = float(line.strip().split(' ')[0]) 313 | y = float(line.strip().split(' ')[1]) 314 | annotations.append([x, y]) 315 | return np.array(annotations) 316 | 317 | def train_transform(self, img, keypoints): 318 | wd, ht = img.size 319 | st_size = 1.0 * min(wd, ht) 320 | # resize the image to fit the crop size 321 | if st_size < self.c_size: 322 | rr = 1.0 * self.c_size / st_size 323 | wd = round(wd * rr) 324 | ht = round(ht * rr) 325 | st_size = 1.0 * min(wd, ht) 326 | img = img.resize((wd, ht), Image.BICUBIC) 327 | keypoints = keypoints * rr 328 | assert st_size >= self.c_size, print(wd, ht) 329 | assert len(keypoints) >= 0 330 | i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size) 331 | img = F.crop(img, i, j, h, w) 332 | if len(keypoints) > 0: 333 | keypoints = keypoints - [j, i] 334 | idx_mask = (keypoints[:, 0] >= 0) * (keypoints[:, 0] <= w) * \ 335 | (keypoints[:, 1] >= 0) * (keypoints[:, 1] <= h) 336 | keypoints = keypoints[idx_mask] 337 | else: 338 | keypoints = np.empty([0, 2]) 339 | 340 | gt_discrete = gen_discrete_map(h, w, keypoints) 341 | down_w = w // self.d_ratio 342 | down_h = h // self.d_ratio 343 | gt_discrete = gt_discrete.reshape([down_h, self.d_ratio, down_w, self.d_ratio]).sum(axis=(1, 3)) 344 | assert np.sum(gt_discrete) == len(keypoints) 345 | 346 | if len(keypoints) > 0: 347 | if random.random() > 0.5: 348 | img = F.hflip(img) 349 | gt_discrete = np.fliplr(gt_discrete) 350 | keypoints[:, 0] = w - keypoints[:, 0] - 1 351 | else: 352 | if random.random() > 0.5: 353 | img = F.hflip(img) 354 | gt_discrete = np.fliplr(gt_discrete) 355 | gt_discrete = np.expand_dims(gt_discrete, 0) 356 | 357 | return self.trans(img), torch.from_numpy(keypoints.copy()).float(), 
torch.from_numpy( 358 | gt_discrete.copy()).float() -------------------------------------------------------------------------------- /train_helper_ALTGVT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | from torch import optim 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.dataloader import default_collate 8 | import numpy as np 9 | from datetime import datetime 10 | import torch.nn.functional as F 11 | from datasets.crowd import Crowd_qnrf, Crowd_nwpu, Crowd_sh, CustomDataset 12 | 13 | # from models import vgg19 14 | from Networks import ALTGVT 15 | from losses.ot_loss import OT_Loss 16 | from utils.pytorch_utils import Save_Handle, AverageMeter 17 | import utils.log_utils as log_utils 18 | import wandb 19 | 20 | 21 | def train_collate(batch): 22 | transposed_batch = list(zip(*batch)) 23 | images = torch.stack(transposed_batch[0], 0) 24 | points = transposed_batch[ 25 | 1 26 | ] # the number of points is not fixed, keep it as a list of tensor 27 | gt_discretes = torch.stack(transposed_batch[2], 0) 28 | return images, points, gt_discretes 29 | 30 | 31 | class Trainer(object): 32 | def __init__(self, args): 33 | self.args = args 34 | 35 | def setup(self): 36 | args = self.args 37 | sub_dir = ( 38 | "ALTGVT/{}_12-1-input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}".format( 39 | args.run_name, 40 | args.crop_size, 41 | args.wot, 42 | args.wtv, 43 | args.reg, 44 | args.num_of_iter_in_ot, 45 | args.norm_cood, 46 | ) 47 | ) 48 | 49 | self.save_dir = os.path.join("ckpts", sub_dir) 50 | if not os.path.exists(self.save_dir): 51 | os.makedirs(self.save_dir) 52 | 53 | time_str = datetime.strftime(datetime.now(), "%m%d-%H%M%S") 54 | self.logger = log_utils.get_logger( 55 | os.path.join(self.save_dir, "train-{:s}.log".format(time_str)) 56 | ) 57 | log_utils.print_config(vars(args), self.logger) 58 | 59 | if torch.cuda.is_available(): 60 | self.device = torch.device("cuda") 61 | self.device_count = torch.cuda.device_count() 62 | assert self.device_count == 1 63 | self.logger.info("using {} gpus".format(self.device_count)) 64 | else: 65 | raise Exception("gpu is not available") 66 | 67 | downsample_ratio = 8 68 | if args.dataset.lower() == "qnrf": 69 | self.datasets = { 70 | x: Crowd_qnrf( 71 | os.path.join( 72 | args.data_dir, x), args.crop_size, downsample_ratio, x 73 | ) 74 | for x in ["train", "val"] 75 | } 76 | elif args.dataset.lower() == "nwpu": 77 | self.datasets = { 78 | x: Crowd_nwpu( 79 | os.path.join( 80 | args.data_dir, x), args.crop_size, downsample_ratio, x 81 | ) 82 | for x in ["train", "val"] 83 | } 84 | elif args.dataset.lower() == "sha" or args.dataset.lower() == "shb": 85 | self.datasets = { 86 | "train": Crowd_sh( 87 | os.path.join(args.data_dir, "train_data"), 88 | args.crop_size, 89 | downsample_ratio, 90 | "train", 91 | ), 92 | "val": Crowd_sh( 93 | os.path.join(args.data_dir, "test_data"), 94 | args.crop_size, 95 | downsample_ratio, 96 | "val", 97 | ), 98 | } 99 | elif args.dataset.lower() == "custom": 100 | self.datasets = { 101 | "train": CustomDataset( 102 | args.data_dir, args.crop_size, downsample_ratio, method="train" 103 | ), 104 | "val": CustomDataset( 105 | args.data_dir, args.crop_size, downsample_ratio, method="valid" 106 | ), 107 | } 108 | else: 109 | raise NotImplementedError 110 | 111 | self.dataloaders = { 112 | x: DataLoader( 113 | self.datasets[x], 114 | collate_fn=(train_collate if x == 115 | "train" else default_collate), 116 
| batch_size=(args.batch_size if x == "train" else 1), 117 | shuffle=(True if x == "train" else False), 118 | num_workers=args.num_workers * self.device_count, 119 | pin_memory=(True if x == "train" else False), 120 | ) 121 | for x in ["train", "val"] 122 | } 123 | self.model = ALTGVT.alt_gvt_large(pretrained=True) 124 | self.model.to(self.device) 125 | self.optimizer = optim.AdamW( 126 | self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay 127 | ) 128 | self.start_epoch = 0 129 | 130 | # check if wandb has to log 131 | if args.wandb: 132 | self.wandb_run = wandb.init( 133 | config=args, project="CTTrans", name=args.run_name 134 | ) 135 | else : 136 | wandb.init(mode="disabled") 137 | 138 | 139 | if args.resume: 140 | self.logger.info("loading pretrained model from " + args.resume) 141 | suf = args.resume.rsplit(".", 1)[-1] 142 | if suf == "tar": 143 | checkpoint = torch.load(args.resume, self.device) 144 | self.model.load_state_dict(checkpoint["model_state_dict"]) 145 | self.optimizer.load_state_dict( 146 | checkpoint["optimizer_state_dict"]) 147 | self.start_epoch = checkpoint["epoch"] + 1 148 | elif suf == "pth": 149 | self.model.load_state_dict( 150 | torch.load(args.resume, self.device)) 151 | else: 152 | self.logger.info("random initialization") 153 | 154 | self.ot_loss = OT_Loss( 155 | args.crop_size, 156 | downsample_ratio, 157 | args.norm_cood, 158 | self.device, 159 | args.num_of_iter_in_ot, 160 | args.reg, 161 | ) 162 | self.tv_loss = nn.L1Loss(reduction="none").to(self.device) 163 | self.mse = nn.MSELoss().to(self.device) 164 | self.mae = nn.L1Loss().to(self.device) 165 | self.save_list = Save_Handle(max_num=1) 166 | self.best_mae = np.inf 167 | self.best_mse = np.inf 168 | # self.best_count = 0 169 | 170 | def train(self): 171 | """training process""" 172 | args = self.args 173 | for epoch in range(self.start_epoch, args.max_epoch + 1): 174 | self.logger.info( 175 | "-" * 5 + "Epoch {}/{}".format(epoch, args.max_epoch) + "-" * 5 176 | ) 177 | self.epoch = epoch 178 | self.train_epoch() 179 | if epoch % args.val_epoch == 0 and epoch >= args.val_start: 180 | self.val_epoch() 181 | 182 | def train_epoch(self): 183 | epoch_ot_loss = AverageMeter() 184 | epoch_ot_obj_value = AverageMeter() 185 | epoch_wd = AverageMeter() 186 | epoch_count_loss = AverageMeter() 187 | epoch_tv_loss = AverageMeter() 188 | epoch_loss = AverageMeter() 189 | epoch_mae = AverageMeter() 190 | epoch_mse = AverageMeter() 191 | epoch_start = time.time() 192 | self.model.train() # Set model to training mode 193 | 194 | for step, (inputs, points, gt_discrete) in enumerate(self.dataloaders["train"]): 195 | inputs = inputs.to(self.device) 196 | gd_count = np.array([len(p) for p in points], dtype=np.float32) 197 | points = [p.to(self.device) for p in points] 198 | gt_discrete = gt_discrete.to(self.device) 199 | N = inputs.size(0) 200 | 201 | with torch.set_grad_enabled(True): 202 | outputs, outputs_normed = self.model(inputs) 203 | # Compute OT loss. 204 | ot_loss, wd, ot_obj_value = self.ot_loss( 205 | outputs_normed, outputs, points 206 | ) 207 | ot_loss = ot_loss * self.args.wot 208 | ot_obj_value = ot_obj_value * self.args.wot 209 | epoch_ot_loss.update(ot_loss.item(), N) 210 | epoch_ot_obj_value.update(ot_obj_value.item(), N) 211 | epoch_wd.update(wd, N) 212 | 213 | # Compute counting loss. 
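# L1 loss between each image's predicted count (spatial sum of the raw density map)
# and its ground-truth number of annotated points, averaged over the batch.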
214 | count_loss = self.mae( 215 | outputs.sum(1).sum(1).sum(1), 216 | torch.from_numpy(gd_count).float().to(self.device), 217 | ) 218 | epoch_count_loss.update(count_loss.item(), N) 219 | 220 | # Compute TV loss. 221 | gd_count_tensor = ( 222 | torch.from_numpy(gd_count) 223 | .float() 224 | .to(self.device) 225 | .unsqueeze(1) 226 | .unsqueeze(2) 227 | .unsqueeze(3) 228 | ) 229 | gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6) 230 | tv_loss = ( 231 | self.tv_loss(outputs_normed, gt_discrete_normed) 232 | .sum(1) 233 | .sum(1) 234 | .sum(1) 235 | * torch.from_numpy(gd_count).float().to(self.device) 236 | ).mean(0) * self.args.wtv 237 | epoch_tv_loss.update(tv_loss.item(), N) 238 | 239 | loss = ot_loss + count_loss + tv_loss 240 | 241 | self.optimizer.zero_grad() 242 | loss.backward() 243 | self.optimizer.step() 244 | 245 | pred_count = ( 246 | torch.sum(outputs.view(N, -1), 247 | dim=1).detach().cpu().numpy() 248 | ) 249 | pred_err = pred_count - gd_count 250 | epoch_loss.update(loss.item(), N) 251 | epoch_mse.update(np.mean(pred_err * pred_err), N) 252 | epoch_mae.update(np.mean(abs(pred_err)), N) 253 | 254 | # log wandb 255 | wandb.log( 256 | { 257 | "train/TOTAL_loss": loss, 258 | "train/count_loss": count_loss, 259 | "train/tv_loss": tv_loss, 260 | "train/pred_err": pred_err, 261 | }, 262 | step=self.epoch, 263 | ) 264 | 265 | self.logger.info( 266 | "Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, " 267 | "Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec".format( 268 | self.epoch, 269 | epoch_loss.get_avg(), 270 | epoch_ot_loss.get_avg(), 271 | epoch_wd.get_avg(), 272 | epoch_ot_obj_value.get_avg(), 273 | epoch_count_loss.get_avg(), 274 | epoch_tv_loss.get_avg(), 275 | np.sqrt(epoch_mse.get_avg()), 276 | epoch_mae.get_avg(), 277 | time.time() - epoch_start, 278 | ) 279 | ) 280 | model_state_dic = self.model.state_dict() 281 | save_path = os.path.join( 282 | self.save_dir, "{}_ckpt.tar".format(self.epoch)) 283 | torch.save( 284 | { 285 | "epoch": self.epoch, 286 | "optimizer_state_dict": self.optimizer.state_dict(), 287 | "model_state_dict": model_state_dic, 288 | }, 289 | save_path, 290 | ) 291 | self.save_list.append(save_path) 292 | 293 | def val_epoch(self): 294 | args = self.args 295 | epoch_start = time.time() 296 | self.model.eval() # Set model to evaluate mode 297 | epoch_res = [] 298 | for inputs, count, name in self.dataloaders["val"]: 299 | with torch.no_grad(): 300 | # nputs = cal_new_tensor(inputs, min_size=args.crop_size) 301 | inputs = inputs.to(self.device) 302 | crop_imgs, crop_masks = [], [] 303 | b, c, h, w = inputs.size() 304 | rh, rw = args.crop_size, args.crop_size 305 | for i in range(0, h, rh): 306 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 307 | for j in range(0, w, rw): 308 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 309 | crop_imgs.append(inputs[:, :, gis:gie, gjs:gje]) 310 | mask = torch.zeros([b, 1, h, w]).to(self.device) 311 | mask[:, :, gis:gie, gjs:gje].fill_(1.0) 312 | crop_masks.append(mask) 313 | crop_imgs, crop_masks = map( 314 | lambda x: torch.cat(x, dim=0), (crop_imgs, crop_masks) 315 | ) 316 | 317 | crop_preds = [] 318 | nz, bz = crop_imgs.size(0), args.batch_size 319 | for i in range(0, nz, bz): 320 | gs, gt = i, min(nz, i + bz) 321 | crop_pred, _ = self.model(crop_imgs[gs:gt]) 322 | 323 | _, _, h1, w1 = crop_pred.size() 324 | crop_pred = ( 325 | F.interpolate( 326 | crop_pred, 327 | size=(h1 * 8, w1 * 8), 328 | mode="bilinear", 329 | 
align_corners=True, 330 | ) 331 | / 64 332 | ) 333 | 334 | crop_preds.append(crop_pred) 335 | crop_preds = torch.cat(crop_preds, dim=0) 336 | 337 | # splice them to the original size 338 | idx = 0 339 | pred_map = torch.zeros([b, 1, h, w]).to(self.device) 340 | for i in range(0, h, rh): 341 | gis, gie = max(min(h - rh, i), 0), min(h, i + rh) 342 | for j in range(0, w, rw): 343 | gjs, gje = max(min(w - rw, j), 0), min(w, j + rw) 344 | pred_map[:, :, gis:gie, gjs:gje] += crop_preds[idx] 345 | idx += 1 346 | # for the overlapping area, compute average value 347 | mask = crop_masks.sum(dim=0).unsqueeze(0) 348 | outputs = pred_map / mask 349 | 350 | res = count[0].item() - torch.sum(outputs).item() 351 | epoch_res.append(res) 352 | epoch_res = np.array(epoch_res) 353 | mse = np.sqrt(np.mean(np.square(epoch_res))) 354 | mae = np.mean(np.abs(epoch_res)) 355 | 356 | self.logger.info( 357 | "Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec".format( 358 | self.epoch, mse, mae, time.time() - epoch_start 359 | ) 360 | ) 361 | 362 | # log wandb 363 | wandb.log({"val/MSE": mse, "val/MAE": mae}, step=self.epoch) 364 | 365 | model_state_dic = self.model.state_dict() 366 | # if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae): 367 | print("Comaprison", mae, self.best_mae) 368 | if mae < self.best_mae: 369 | self.best_mse = mse 370 | self.best_mae = mae 371 | self.logger.info( 372 | "save best mse {:.2f} mae {:.2f} model epoch {}".format( 373 | self.best_mse, self.best_mae, self.epoch 374 | ) 375 | ) 376 | print("Saving best model at {} epoch".format(self.epoch)) 377 | model_path = os.path.join( 378 | self.save_dir, "best_model_mae-{:.2f}_epoch-{}.pth".format( 379 | self.best_mae, self.epoch) 380 | ) 381 | torch.save( 382 | model_state_dic, 383 | model_path, 384 | ) 385 | 386 | if args.wandb: 387 | artifact = wandb.Artifact("model", type="model") 388 | artifact.add_file(model_path) 389 | 390 | self.wandb_run.log_artifact(artifact) 391 | 392 | # torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count))) 393 | # self.best_count += 1 394 | 395 | 396 | def tensor_divideByfactor(img_tensor, factor=32): 397 | _, _, h, w = img_tensor.size() 398 | h, w = int(h // factor * factor), int(w // factor * factor) 399 | img_tensor = F.interpolate( 400 | img_tensor, (h, w), mode="bilinear", align_corners=True) 401 | 402 | return img_tensor 403 | 404 | 405 | def cal_new_tensor(img_tensor, min_size=256): 406 | _, _, h, w = img_tensor.size() 407 | if min(h, w) < min_size: 408 | ratio_h, ratio_w = min_size / h, min_size / w 409 | if ratio_h >= ratio_w: 410 | img_tensor = F.interpolate( 411 | img_tensor, 412 | (min_size, int(min_size / h * w)), 413 | mode="bilinear", 414 | align_corners=True, 415 | ) 416 | else: 417 | img_tensor = F.interpolate( 418 | img_tensor, 419 | (int(min_size / w * h), min_size), 420 | mode="bilinear", 421 | align_corners=True, 422 | ) 423 | return img_tensor 424 | 425 | 426 | if __name__ == "__main__": 427 | import torch 428 | 429 | print(torch.__file__) 430 | x = torch.ones(1, 3, 768, 1152) 431 | y = tensor_spilt(x) 432 | print(y.size()) 433 | -------------------------------------------------------------------------------- /losses/bregman_pytorch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Rewrite ot.bregman.sinkhorn in Python Optimal Transport (https://pythonot.github.io/_modules/ot/bregman.html#sinkhorn) 4 | using pytorch operations. 
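A minimal usage sketch of the `sinkhorn` entry point defined below (hedged:
uniform marginals, a random cost matrix, CPU only; the import path assumes the
repository root is on PYTHONPATH)::

    import torch
    from losses.bregman_pytorch import sinkhorn

    a = torch.full((64,), 1.0 / 64)     # row marginal of the plan, sums to 1
    b = torch.full((128,), 1.0 / 128)   # column marginal of the plan, sums to 1
    C = torch.rand(64, 128)             # cost matrix
    P, log = sinkhorn(a, b, C, reg=10.0, maxIter=100, log=True)
    print(P.shape, P.sum())             # torch.Size([64, 128]), sums to ~1.0
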
5 | Bregman projections for regularized OT (Sinkhorn distance). 6 | """ 7 | 8 | import torch 9 | 10 | M_EPS = 1e-16 11 | 12 | 13 | def sinkhorn(a, b, C, reg=1e-1, method='sinkhorn', maxIter=1000, tau=1e3, 14 | stopThr=1e-9, verbose=False, log=True, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 15 | """ 16 | Solve the entropic regularization optimal transport 17 | The input should be PyTorch tensors 18 | The function solves the following optimization problem: 19 | 20 | .. math:: 21 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 22 | s.t. \gamma 1 = a 23 | \gamma^T 1= b 24 | \gamma\geq 0 25 | where : 26 | - C is the (ns,nt) metric cost matrix 27 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 28 | - a and b are target and source measures (sum to 1) 29 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. 30 | 31 | Parameters 32 | ---------- 33 | a : torch.tensor (na,) 34 | samples measure in the target domain 35 | b : torch.tensor (nb,) 36 | samples in the source domain 37 | C : torch.tensor (na,nb) 38 | loss matrix 39 | reg : float 40 | Regularization term > 0 41 | method : str 42 | method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or 43 | 'sinkhorn_epsilon_scaling', see those function for specific parameters 44 | maxIter : int, optional 45 | Max number of iterations 46 | stopThr : float, optional 47 | Stop threshol on error ( > 0 ) 48 | verbose : bool, optional 49 | Print information along iterations 50 | log : bool, optional 51 | record log if True 52 | 53 | Returns 54 | ------- 55 | gamma : (na x nb) torch.tensor 56 | Optimal transportation matrix for the given parameters 57 | log : dict 58 | log dictionary return only if log==True in parameters 59 | 60 | References 61 | ---------- 62 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 63 | See Also 64 | -------- 65 | 66 | """ 67 | 68 | if method.lower() == 'sinkhorn': 69 | return sinkhorn_knopp(a, b, C, reg, maxIter=maxIter, 70 | stopThr=stopThr, verbose=verbose, log=log, 71 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 72 | **kwargs) 73 | elif method.lower() == 'sinkhorn_stabilized': 74 | return sinkhorn_stabilized(a, b, C, reg, maxIter=maxIter, tau=tau, 75 | stopThr=stopThr, verbose=verbose, log=log, 76 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 77 | **kwargs) 78 | elif method.lower() == 'sinkhorn_epsilon_scaling': 79 | return sinkhorn_epsilon_scaling(a, b, C, reg, 80 | maxIter=maxIter, maxInnerIter=100, tau=tau, 81 | scaling_base=0.75, scaling_coef=None, stopThr=stopThr, 82 | verbose=False, log=log, warm_start=warm_start, eval_freq=eval_freq, 83 | print_freq=print_freq, **kwargs) 84 | else: 85 | raise ValueError("Unknown method '%s'." % method) 86 | 87 | 88 | def sinkhorn_knopp(a, b, C, reg=1e-1, maxIter=1000, stopThr=1e-9, 89 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 90 | """ 91 | Solve the entropic regularization optimal transport 92 | The input should be PyTorch tensors 93 | The function solves the following optimization problem: 94 | 95 | .. math:: 96 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 97 | s.t. 
\gamma 1 = a 98 | \gamma^T 1= b 99 | \gamma\geq 0 100 | where : 101 | - C is the (ns,nt) metric cost matrix 102 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 103 | - a and b are target and source measures (sum to 1) 104 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1]. 105 | 106 | Parameters 107 | ---------- 108 | a : torch.tensor (na,) 109 | samples measure in the target domain 110 | b : torch.tensor (nb,) 111 | samples in the source domain 112 | C : torch.tensor (na,nb) 113 | loss matrix 114 | reg : float 115 | Regularization term > 0 116 | maxIter : int, optional 117 | Max number of iterations 118 | stopThr : float, optional 119 | Stop threshol on error ( > 0 ) 120 | verbose : bool, optional 121 | Print information along iterations 122 | log : bool, optional 123 | record log if True 124 | 125 | Returns 126 | ------- 127 | gamma : (na x nb) torch.tensor 128 | Optimal transportation matrix for the given parameters 129 | log : dict 130 | log dictionary return only if log==True in parameters 131 | 132 | References 133 | ---------- 134 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 135 | See Also 136 | -------- 137 | 138 | """ 139 | 140 | device = a.device 141 | na, nb = C.shape 142 | 143 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 144 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 145 | assert reg > 0, 'reg should be greater than 0' 146 | assert a.min() >= 0. and b.min() >= 0., 'Elements in a or b less than 0' 147 | 148 | if log: 149 | log = {'err': []} 150 | 151 | if warm_start is not None: 152 | u = warm_start['u'] 153 | v = warm_start['v'] 154 | else: 155 | u = torch.ones(na, dtype=a.dtype).to(device) / na 156 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb 157 | 158 | K = torch.empty(C.shape, dtype=C.dtype).to(device) 159 | torch.div(C, -reg, out=K) 160 | torch.exp(K, out=K) 161 | 162 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device) 163 | 164 | it = 1 165 | err = 1 166 | 167 | # allocate memory beforehand 168 | KTu = torch.empty(v.shape, dtype=v.dtype).to(device) 169 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device) 170 | 171 | while (err > stopThr and it <= maxIter): 172 | upre, vpre = u, v 173 | torch.matmul(u, K, out=KTu) 174 | v = torch.div(b, KTu + M_EPS) 175 | torch.matmul(K, v, out=Kv) 176 | u = torch.div(a, Kv + M_EPS) 177 | 178 | if torch.any(torch.isnan(u)) or torch.any(torch.isnan(v)) or \ 179 | torch.any(torch.isinf(u)) or torch.any(torch.isinf(v)): 180 | print('Warning: numerical errors at iteration', it) 181 | u, v = upre, vpre 182 | break 183 | 184 | if log and it % eval_freq == 0: 185 | # we can speed up the process by checking for the error only all 186 | # the eval_freq iterations 187 | # below is equivalent to: 188 | # b_hat = torch.sum(u.reshape(-1, 1) * K * v.reshape(1, -1), 0) 189 | # but with more memory efficient 190 | b_hat = torch.matmul(u, K) * v 191 | err = (b - b_hat).pow(2).sum().item() 192 | # err = (b - b_hat).abs().sum().item() 193 | log['err'].append(err) 194 | 195 | if verbose and it % print_freq == 0: 196 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 197 | 198 | it += 1 199 | 200 | if log: 201 | log['u'] = u 202 | log['v'] = v 203 | log['alpha'] = reg * torch.log(u + M_EPS) 204 | log['beta'] = reg * torch.log(v + M_EPS) 205 | 206 | # 
transport plan 207 | P = u.reshape(-1, 1) * K * v.reshape(1, -1) 208 | if log: 209 | return P, log 210 | else: 211 | return P 212 | 213 | 214 | def sinkhorn_stabilized(a, b, C, reg=1e-1, maxIter=1000, tau=1e3, stopThr=1e-9, 215 | verbose=False, log=False, warm_start=None, eval_freq=10, print_freq=200, **kwargs): 216 | """ 217 | Solve the entropic regularization OT problem with log stabilization 218 | The function solves the following optimization problem: 219 | 220 | .. math:: 221 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 222 | s.t. \gamma 1 = a 223 | \gamma^T 1= b 224 | \gamma\geq 0 225 | where : 226 | - C is the (ns,nt) metric cost matrix 227 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 228 | - a and b are target and source measures (sum to 1) 229 | 230 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [1] 231 | but with the log stabilization proposed in [3] an defined in [2] (Algo 3.1) 232 | 233 | Parameters 234 | ---------- 235 | a : torch.tensor (na,) 236 | samples measure in the target domain 237 | b : torch.tensor (nb,) 238 | samples in the source domain 239 | C : torch.tensor (na,nb) 240 | loss matrix 241 | reg : float 242 | Regularization term > 0 243 | tau : float 244 | thershold for max value in u or v for log scaling 245 | maxIter : int, optional 246 | Max number of iterations 247 | stopThr : float, optional 248 | Stop threshol on error ( > 0 ) 249 | verbose : bool, optional 250 | Print information along iterations 251 | log : bool, optional 252 | record log if True 253 | 254 | Returns 255 | ------- 256 | gamma : (na x nb) torch.tensor 257 | Optimal transportation matrix for the given parameters 258 | log : dict 259 | log dictionary return only if log==True in parameters 260 | 261 | References 262 | ---------- 263 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 264 | [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019 265 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 266 | 267 | See Also 268 | -------- 269 | 270 | """ 271 | 272 | device = a.device 273 | na, nb = C.shape 274 | 275 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 276 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 277 | assert reg > 0, 'reg should be greater than 0' 278 | assert a.min() >= 0. 
and b.min() >= 0., 'Elements in a or b less than 0' 279 | 280 | if log: 281 | log = {'err': []} 282 | 283 | if warm_start is not None: 284 | alpha = warm_start['alpha'] 285 | beta = warm_start['beta'] 286 | else: 287 | alpha = torch.zeros(na, dtype=a.dtype).to(device) 288 | beta = torch.zeros(nb, dtype=b.dtype).to(device) 289 | 290 | u = torch.ones(na, dtype=a.dtype).to(device) / na 291 | v = torch.ones(nb, dtype=b.dtype).to(device) / nb 292 | 293 | def update_K(alpha, beta): 294 | """log space computation""" 295 | """memory efficient""" 296 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=K) 297 | torch.add(K, -C, out=K) 298 | torch.div(K, reg, out=K) 299 | torch.exp(K, out=K) 300 | 301 | def update_P(alpha, beta, u, v, ab_updated=False): 302 | """log space P (gamma) computation""" 303 | torch.add(alpha.reshape(-1, 1), beta.reshape(1, -1), out=P) 304 | torch.add(P, -C, out=P) 305 | torch.div(P, reg, out=P) 306 | if not ab_updated: 307 | torch.add(P, torch.log(u + M_EPS).reshape(-1, 1), out=P) 308 | torch.add(P, torch.log(v + M_EPS).reshape(1, -1), out=P) 309 | torch.exp(P, out=P) 310 | 311 | K = torch.empty(C.shape, dtype=C.dtype).to(device) 312 | update_K(alpha, beta) 313 | 314 | b_hat = torch.empty(b.shape, dtype=C.dtype).to(device) 315 | 316 | it = 1 317 | err = 1 318 | ab_updated = False 319 | 320 | # allocate memory beforehand 321 | KTu = torch.empty(v.shape, dtype=v.dtype).to(device) 322 | Kv = torch.empty(u.shape, dtype=u.dtype).to(device) 323 | P = torch.empty(C.shape, dtype=C.dtype).to(device) 324 | 325 | while (err > stopThr and it <= maxIter): 326 | upre, vpre = u, v 327 | torch.matmul(u, K, out=KTu) 328 | v = torch.div(b, KTu + M_EPS) 329 | torch.matmul(K, v, out=Kv) 330 | u = torch.div(a, Kv + M_EPS) 331 | 332 | ab_updated = False 333 | # remove numerical problems and store them in K 334 | if u.abs().sum() > tau or v.abs().sum() > tau: 335 | alpha += reg * torch.log(u + M_EPS) 336 | beta += reg * torch.log(v + M_EPS) 337 | u.fill_(1. / na) 338 | v.fill_(1. / nb) 339 | update_K(alpha, beta) 340 | ab_updated = True 341 | 342 | if log and it % eval_freq == 0: 343 | # we can speed up the process by checking for the error only all 344 | # the eval_freq iterations 345 | update_P(alpha, beta, u, v, ab_updated) 346 | b_hat = torch.sum(P, 0) 347 | err = (b - b_hat).pow(2).sum().item() 348 | log['err'].append(err) 349 | 350 | if verbose and it % print_freq == 0: 351 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 352 | 353 | it += 1 354 | 355 | if log: 356 | log['u'] = u 357 | log['v'] = v 358 | log['alpha'] = alpha + reg * torch.log(u + M_EPS) 359 | log['beta'] = beta + reg * torch.log(v + M_EPS) 360 | 361 | # transport plan 362 | update_P(alpha, beta, u, v, False) 363 | 364 | if log: 365 | return P, log 366 | else: 367 | return P 368 | 369 | 370 | def sinkhorn_epsilon_scaling(a, b, C, reg=1e-1, maxIter=100, maxInnerIter=100, tau=1e3, scaling_base=0.75, 371 | scaling_coef=None, stopThr=1e-9, verbose=False, log=False, warm_start=None, eval_freq=10, 372 | print_freq=200, **kwargs): 373 | """ 374 | Solve the entropic regularization OT problem with log stabilization 375 | The function solves the following optimization problem: 376 | 377 | .. math:: 378 | \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) 379 | s.t. 
\gamma 1 = a 380 | \gamma^T 1= b 381 | \gamma\geq 0 382 | where : 383 | - C is the (ns,nt) metric cost matrix 384 | - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` 385 | - a and b are target and source measures (sum to 1) 386 | 387 | The algorithm used for solving the problem is the Sinkhorn-Knopp matrix 388 | scaling algorithm as proposed in [1] but with the log stabilization 389 | proposed in [3] and the log scaling proposed in [2] algorithm 3.2 390 | 391 | Parameters 392 | ---------- 393 | a : torch.tensor (na,) 394 | samples measure in the target domain 395 | b : torch.tensor (nb,) 396 | samples in the source domain 397 | C : torch.tensor (na,nb) 398 | loss matrix 399 | reg : float 400 | Regularization term > 0 401 | tau : float 402 | thershold for max value in u or v for log scaling 403 | maxIter : int, optional 404 | Max number of iterations 405 | stopThr : float, optional 406 | Stop threshol on error ( > 0 ) 407 | verbose : bool, optional 408 | Print information along iterations 409 | log : bool, optional 410 | record log if True 411 | 412 | Returns 413 | ------- 414 | gamma : (na x nb) torch.tensor 415 | Optimal transportation matrix for the given parameters 416 | log : dict 417 | log dictionary return only if log==True in parameters 418 | 419 | References 420 | ---------- 421 | [1] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 422 | [2] Bernhard Schmitzer. Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems. SIAM Journal on Scientific Computing, 2019 423 | [3] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 424 | 425 | See Also 426 | -------- 427 | 428 | """ 429 | 430 | na, nb = C.shape 431 | 432 | assert na >= 1 and nb >= 1, 'C needs to be 2d' 433 | assert na == a.shape[0] and nb == b.shape[0], "Shape of a or b does't match that of C" 434 | assert reg > 0, 'reg should be greater than 0' 435 | assert a.min() >= 0. 
and b.min() >= 0., 'Elements in a or b less than 0' 436 | 437 | def get_reg(it, reg, pre_reg): 438 | if it == 1: 439 | return scaling_coef 440 | else: 441 | if (pre_reg - reg) * scaling_base < M_EPS: 442 | return reg 443 | else: 444 | return (pre_reg - reg) * scaling_base + reg 445 | 446 | if scaling_coef is None: 447 | scaling_coef = C.max() + reg 448 | 449 | it = 1 450 | err = 1 451 | running_reg = scaling_coef 452 | 453 | if log: 454 | log = {'err': []} 455 | 456 | warm_start = None 457 | 458 | while (err > stopThr and it <= maxIter): 459 | running_reg = get_reg(it, reg, running_reg) 460 | P, _log = sinkhorn_stabilized(a, b, C, running_reg, maxIter=maxInnerIter, tau=tau, 461 | stopThr=stopThr, verbose=False, log=True, 462 | warm_start=warm_start, eval_freq=eval_freq, print_freq=print_freq, 463 | **kwargs) 464 | 465 | warm_start = {} 466 | warm_start['alpha'] = _log['alpha'] 467 | warm_start['beta'] = _log['beta'] 468 | 469 | primal_val = (C * P).sum() + reg * (P * torch.log(P)).sum() - reg * P.sum() 470 | dual_val = (_log['alpha'] * a).sum() + (_log['beta'] * b).sum() - reg * P.sum() 471 | err = primal_val - dual_val 472 | log['err'].append(err) 473 | 474 | if verbose and it % print_freq == 0: 475 | print('iteration {:5d}, constraint error {:5e}'.format(it, err)) 476 | 477 | it += 1 478 | 479 | if log: 480 | log['alpha'] = _log['alpha'] 481 | log['beta'] = _log['beta'] 482 | return P, log 483 | else: 484 | return P 485 | -------------------------------------------------------------------------------- /preprocess/qnrf_train.txt: -------------------------------------------------------------------------------- 1 | img_0526.jpg 2 | img_0639.jpg 3 | img_0826.jpg 4 | img_0415.jpg 5 | img_0720.jpg 6 | img_0123.jpg 7 | img_0529.jpg 8 | img_1071.jpg 9 | img_0501.jpg 10 | img_0804.jpg 11 | img_0873.jpg 12 | img_0601.jpg 13 | img_0177.jpg 14 | img_0173.jpg 15 | img_0675.jpg 16 | img_1001.jpg 17 | img_0096.jpg 18 | img_1139.jpg 19 | img_0001.jpg 20 | img_0084.jpg 21 | img_0395.jpg 22 | img_0166.jpg 23 | img_0368.jpg 24 | img_0093.jpg 25 | img_0004.jpg 26 | img_0572.jpg 27 | img_0956.jpg 28 | img_0721.jpg 29 | img_0120.jpg 30 | img_0554.jpg 31 | img_0308.jpg 32 | img_0131.jpg 33 | img_0992.jpg 34 | img_0156.jpg 35 | img_0532.jpg 36 | img_0476.jpg 37 | img_0427.jpg 38 | img_1162.jpg 39 | img_0660.jpg 40 | img_0538.jpg 41 | img_0298.jpg 42 | img_0306.jpg 43 | img_1173.jpg 44 | img_1157.jpg 45 | img_0777.jpg 46 | img_0859.jpg 47 | img_0537.jpg 48 | img_0236.jpg 49 | img_0986.jpg 50 | img_0370.jpg 51 | img_0491.jpg 52 | img_1150.jpg 53 | img_0719.jpg 54 | img_1083.jpg 55 | img_0107.jpg 56 | img_1029.jpg 57 | img_0927.jpg 58 | img_0893.jpg 59 | img_0286.jpg 60 | img_1135.jpg 61 | img_0640.jpg 62 | img_0530.jpg 63 | img_1115.jpg 64 | img_0533.jpg 65 | img_0105.jpg 66 | img_0945.jpg 67 | img_1035.jpg 68 | img_0484.jpg 69 | img_1168.jpg 70 | img_0760.jpg 71 | img_0939.jpg 72 | img_0907.jpg 73 | img_0401.jpg 74 | img_0429.jpg 75 | img_0828.jpg 76 | img_1167.jpg 77 | img_0144.jpg 78 | img_0553.jpg 79 | img_0421.jpg 80 | img_0560.jpg 81 | img_0743.jpg 82 | img_0817.jpg 83 | img_0657.jpg 84 | img_0106.jpg 85 | img_0079.jpg 86 | img_0473.jpg 87 | img_0865.jpg 88 | img_0730.jpg 89 | img_0989.jpg 90 | img_0243.jpg 91 | img_0182.jpg 92 | img_0252.jpg 93 | img_0812.jpg 94 | img_0508.jpg 95 | img_0744.jpg 96 | img_0439.jpg 97 | img_0181.jpg 98 | img_0965.jpg 99 | img_0487.jpg 100 | img_0710.jpg 101 | img_1054.jpg 102 | img_0947.jpg 103 | img_0321.jpg 104 | img_0758.jpg 105 | img_0014.jpg 106 | img_0504.jpg 107 | 
img_0674.jpg 108 | img_0991.jpg 109 | img_0358.jpg 110 | img_1138.jpg 111 | img_0019.jpg 112 | img_0677.jpg 113 | img_0336.jpg 114 | img_0070.jpg 115 | img_0766.jpg 116 | img_0612.jpg 117 | img_1109.jpg 118 | img_0840.jpg 119 | img_0616.jpg 120 | img_0926.jpg 121 | img_0376.jpg 122 | img_0761.jpg 123 | img_0020.jpg 124 | img_0795.jpg 125 | img_0046.jpg 126 | img_0459.jpg 127 | img_0267.jpg 128 | img_0428.jpg 129 | img_1122.jpg 130 | img_0247.jpg 131 | img_1143.jpg 132 | img_0290.jpg 133 | img_0524.jpg 134 | img_0275.jpg 135 | img_1120.jpg 136 | img_0115.jpg 137 | img_0698.jpg 138 | img_0092.jpg 139 | img_0922.jpg 140 | img_1052.jpg 141 | img_0297.jpg 142 | img_0112.jpg 143 | img_0180.jpg 144 | img_0520.jpg 145 | img_0351.jpg 146 | img_0478.jpg 147 | img_0588.jpg 148 | img_0109.jpg 149 | img_0738.jpg 150 | img_0592.jpg 151 | img_0752.jpg 152 | img_1028.jpg 153 | img_1164.jpg 154 | img_0450.jpg 155 | img_0168.jpg 156 | img_1108.jpg 157 | img_0799.jpg 158 | img_0649.jpg 159 | img_0272.jpg 160 | img_0902.jpg 161 | img_0874.jpg 162 | img_0870.jpg 163 | img_0821.jpg 164 | img_0153.jpg 165 | img_0426.jpg 166 | img_0949.jpg 167 | img_0527.jpg 168 | img_1198.jpg 169 | img_0443.jpg 170 | img_0063.jpg 171 | img_0013.jpg 172 | img_0564.jpg 173 | img_0040.jpg 174 | img_0764.jpg 175 | img_0411.jpg 176 | img_0118.jpg 177 | img_1172.jpg 178 | img_0196.jpg 179 | img_0879.jpg 180 | img_0985.jpg 181 | img_0437.jpg 182 | img_0918.jpg 183 | img_0493.jpg 184 | img_0271.jpg 185 | img_0860.jpg 186 | img_0059.jpg 187 | img_0645.jpg 188 | img_1126.jpg 189 | img_0911.jpg 190 | img_1082.jpg 191 | img_0383.jpg 192 | img_0422.jpg 193 | img_0139.jpg 194 | img_1192.jpg 195 | img_0904.jpg 196 | img_0503.jpg 197 | img_0512.jpg 198 | img_0541.jpg 199 | img_0330.jpg 200 | img_0348.jpg 201 | img_0425.jpg 202 | img_0673.jpg 203 | img_0210.jpg 204 | img_0950.jpg 205 | img_0151.jpg 206 | img_0792.jpg 207 | img_0469.jpg 208 | img_0661.jpg 209 | img_0003.jpg 210 | img_0089.jpg 211 | img_0312.jpg 212 | img_0555.jpg 213 | img_0215.jpg 214 | img_0023.jpg 215 | img_1129.jpg 216 | img_0249.jpg 217 | img_0451.jpg 218 | img_1032.jpg 219 | img_0689.jpg 220 | img_1189.jpg 221 | img_0391.jpg 222 | img_0146.jpg 223 | img_0653.jpg 224 | img_0248.jpg 225 | img_0695.jpg 226 | img_0402.jpg 227 | img_0075.jpg 228 | img_1018.jpg 229 | img_1020.jpg 230 | img_0163.jpg 231 | img_0440.jpg 232 | img_0756.jpg 233 | img_0253.jpg 234 | img_0712.jpg 235 | img_0962.jpg 236 | img_0471.jpg 237 | img_0842.jpg 238 | img_0525.jpg 239 | img_1176.jpg 240 | img_1021.jpg 241 | img_0127.jpg 242 | img_0295.jpg 243 | img_1045.jpg 244 | img_1088.jpg 245 | img_1090.jpg 246 | img_0622.jpg 247 | img_0650.jpg 248 | img_0518.jpg 249 | img_0854.jpg 250 | img_0262.jpg 251 | img_0323.jpg 252 | img_0522.jpg 253 | img_0933.jpg 254 | img_0951.jpg 255 | img_0366.jpg 256 | img_0325.jpg 257 | img_1034.jpg 258 | img_0827.jpg 259 | img_0194.jpg 260 | img_0636.jpg 261 | img_0051.jpg 262 | img_0683.jpg 263 | img_0558.jpg 264 | img_0309.jpg 265 | img_0345.jpg 266 | img_0438.jpg 267 | img_1091.jpg 268 | img_0577.jpg 269 | img_0500.jpg 270 | img_0279.jpg 271 | img_1145.jpg 272 | img_0886.jpg 273 | img_1161.jpg 274 | img_0617.jpg 275 | img_0726.jpg 276 | img_0620.jpg 277 | img_0444.jpg 278 | img_1118.jpg 279 | img_0506.jpg 280 | img_0164.jpg 281 | img_0507.jpg 282 | img_0614.jpg 283 | img_0769.jpg 284 | img_1131.jpg 285 | img_0185.jpg 286 | img_0694.jpg 287 | img_1055.jpg 288 | img_0754.jpg 289 | img_0569.jpg 290 | img_0317.jpg 291 | img_0228.jpg 292 | img_0492.jpg 293 | img_1190.jpg 294 | 
img_0566.jpg 295 | img_0921.jpg 296 | img_0818.jpg 297 | img_0204.jpg 298 | img_0974.jpg 299 | img_0866.jpg 300 | img_1039.jpg 301 | img_0101.jpg 302 | img_0169.jpg 303 | img_0375.jpg 304 | img_0334.jpg 305 | img_1078.jpg 306 | img_0061.jpg 307 | img_0113.jpg 308 | img_0981.jpg 309 | img_0080.jpg 310 | img_0324.jpg 311 | img_0316.jpg 312 | img_0643.jpg 313 | img_0408.jpg 314 | img_0890.jpg 315 | img_0363.jpg 316 | img_0765.jpg 317 | img_0822.jpg 318 | img_0430.jpg 319 | img_0245.jpg 320 | img_0671.jpg 321 | img_0486.jpg 322 | img_1201.jpg 323 | img_0129.jpg 324 | img_1142.jpg 325 | img_0843.jpg 326 | img_1133.jpg 327 | img_0238.jpg 328 | img_0955.jpg 329 | img_1017.jpg 330 | img_0858.jpg 331 | img_1154.jpg 332 | img_0559.jpg 333 | img_0002.jpg 334 | img_0407.jpg 335 | img_1146.jpg 336 | img_1086.jpg 337 | img_0495.jpg 338 | img_0857.jpg 339 | img_0133.jpg 340 | img_0121.jpg 341 | img_0973.jpg 342 | img_0830.jpg 343 | img_0165.jpg 344 | img_0278.jpg 345 | img_1012.jpg 346 | img_0393.jpg 347 | img_0202.jpg 348 | img_0700.jpg 349 | img_0313.jpg 350 | img_0024.jpg 351 | img_0055.jpg 352 | img_0979.jpg 353 | img_0162.jpg 354 | img_0135.jpg 355 | img_0098.jpg 356 | img_0727.jpg 357 | img_0969.jpg 358 | img_1137.jpg 359 | img_0932.jpg 360 | img_1102.jpg 361 | img_0301.jpg 362 | img_0047.jpg 363 | img_0595.jpg 364 | img_0805.jpg 365 | img_0801.jpg 366 | img_1151.jpg 367 | img_0387.jpg 368 | img_0999.jpg 369 | img_0136.jpg 370 | img_1037.jpg 371 | img_1087.jpg 372 | img_1186.jpg 373 | img_0032.jpg 374 | img_0195.jpg 375 | img_0360.jpg 376 | img_0276.jpg 377 | img_0642.jpg 378 | img_0913.jpg 379 | img_0231.jpg 380 | img_0670.jpg 381 | img_1123.jpg 382 | img_0517.jpg 383 | img_0707.jpg 384 | img_0088.jpg 385 | img_0594.jpg 386 | img_0838.jpg 387 | img_0848.jpg 388 | img_0354.jpg 389 | img_0936.jpg 390 | img_0876.jpg 391 | img_1081.jpg 392 | img_0322.jpg 393 | img_0637.jpg 394 | img_0739.jpg 395 | img_0917.jpg 396 | img_0244.jpg 397 | img_0591.jpg 398 | img_0628.jpg 399 | img_0964.jpg 400 | img_0691.jpg 401 | img_0609.jpg 402 | img_0342.jpg 403 | img_1097.jpg 404 | img_1077.jpg 405 | img_0502.jpg 406 | img_0423.jpg 407 | img_0561.jpg 408 | img_1059.jpg 409 | img_0568.jpg 410 | img_0920.jpg 411 | img_0389.jpg 412 | img_0940.jpg 413 | img_0787.jpg 414 | img_0634.jpg 415 | img_0516.jpg 416 | img_0900.jpg 417 | img_0463.jpg 418 | img_0942.jpg 419 | img_0796.jpg 420 | img_0835.jpg 421 | img_0789.jpg 422 | img_0184.jpg 423 | img_0397.jpg 424 | img_1195.jpg 425 | img_1089.jpg 426 | img_0319.jpg 427 | img_0328.jpg 428 | img_0724.jpg 429 | img_0852.jpg 430 | img_0662.jpg 431 | img_0225.jpg 432 | img_0479.jpg 433 | img_0266.jpg 434 | img_0499.jpg 435 | img_0134.jpg 436 | img_1023.jpg 437 | img_1064.jpg 438 | img_0400.jpg 439 | img_0226.jpg 440 | img_0015.jpg 441 | img_0203.jpg 442 | img_0548.jpg 443 | img_1084.jpg 444 | img_0970.jpg 445 | img_0718.jpg 446 | img_0138.jpg 447 | img_0095.jpg 448 | img_0831.jpg 449 | img_0482.jpg 450 | img_1000.jpg 451 | img_0234.jpg 452 | img_0183.jpg 453 | img_0687.jpg 454 | img_0923.jpg 455 | img_0197.jpg 456 | img_1016.jpg 457 | img_1100.jpg 458 | img_0034.jpg 459 | img_0587.jpg 460 | img_0229.jpg 461 | img_1178.jpg 462 | img_0124.jpg 463 | img_0424.jpg 464 | img_0496.jpg 465 | img_0179.jpg 466 | img_1110.jpg 467 | img_0998.jpg 468 | img_0742.jpg 469 | img_0578.jpg 470 | img_0207.jpg 471 | img_0305.jpg 472 | img_0373.jpg 473 | img_0971.jpg 474 | img_0292.jpg 475 | img_0861.jpg 476 | img_0621.jpg 477 | img_0414.jpg 478 | img_1140.jpg 479 | img_0737.jpg 480 | img_0176.jpg 481 | 
img_1057.jpg 482 | img_1095.jpg 483 | img_0667.jpg 484 | img_0755.jpg 485 | img_0318.jpg 486 | img_0170.jpg 487 | img_0418.jpg 488 | img_0178.jpg 489 | img_1200.jpg 490 | img_0021.jpg 491 | img_0652.jpg 492 | img_0327.jpg 493 | img_0627.jpg 494 | img_1051.jpg 495 | img_0837.jpg 496 | img_0352.jpg 497 | img_0029.jpg 498 | img_0833.jpg 499 | img_0952.jpg 500 | img_0488.jpg 501 | img_0474.jpg 502 | img_0702.jpg 503 | img_0819.jpg 504 | img_1188.jpg 505 | img_0261.jpg 506 | img_0685.jpg 507 | img_1024.jpg 508 | img_0008.jpg 509 | img_0734.jpg 510 | img_0509.jpg 511 | img_0888.jpg 512 | img_0676.jpg 513 | img_0404.jpg 514 | img_1046.jpg 515 | img_1127.jpg 516 | img_1008.jpg 517 | img_0161.jpg 518 | img_0699.jpg 519 | img_0085.jpg 520 | img_0703.jpg 521 | img_0083.jpg 522 | img_0934.jpg 523 | img_0626.jpg 524 | img_1170.jpg 525 | img_1065.jpg 526 | img_0664.jpg 527 | img_0883.jpg 528 | img_0655.jpg 529 | img_0263.jpg 530 | img_1005.jpg 531 | img_1061.jpg 532 | img_0333.jpg 533 | img_0881.jpg 534 | img_1041.jpg 535 | img_0540.jpg 536 | img_1185.jpg 537 | img_0953.jpg 538 | img_0586.jpg 539 | img_1011.jpg 540 | img_0846.jpg 541 | img_0149.jpg 542 | img_1075.jpg 543 | img_0894.jpg 544 | img_0759.jpg 545 | img_1177.jpg 546 | img_0258.jpg 547 | img_0171.jpg 548 | img_0740.jpg 549 | img_0006.jpg 550 | img_0353.jpg 551 | img_0615.jpg 552 | img_0810.jpg 553 | img_0142.jpg 554 | img_0958.jpg 555 | img_0584.jpg 556 | img_0390.jpg 557 | img_0585.jpg 558 | img_0365.jpg 559 | img_0026.jpg 560 | img_0458.jpg 561 | img_0143.jpg 562 | img_0575.jpg 563 | img_1027.jpg 564 | img_1183.jpg 565 | img_0535.jpg 566 | img_0891.jpg 567 | img_1085.jpg 568 | img_0757.jpg 569 | img_0549.jpg 570 | img_0436.jpg 571 | img_0815.jpg 572 | img_0635.jpg 573 | img_0954.jpg 574 | img_0367.jpg 575 | img_0064.jpg 576 | img_0410.jpg 577 | img_0277.jpg 578 | img_1111.jpg 579 | img_1025.jpg 580 | img_0434.jpg 581 | img_1175.jpg 582 | img_1171.jpg 583 | img_0610.jpg 584 | img_0618.jpg 585 | img_0208.jpg 586 | img_0281.jpg 587 | img_0058.jpg 588 | img_0851.jpg 589 | img_0300.jpg 590 | img_0017.jpg 591 | img_0110.jpg 592 | img_0265.jpg 593 | img_0362.jpg 594 | img_1038.jpg 595 | img_0580.jpg 596 | img_1096.jpg 597 | img_0972.jpg 598 | img_0666.jpg 599 | img_0090.jpg 600 | img_1007.jpg 601 | img_0982.jpg 602 | img_0287.jpg 603 | img_0714.jpg 604 | img_0218.jpg 605 | img_0832.jpg 606 | img_0145.jpg 607 | img_0072.jpg 608 | img_0222.jpg 609 | img_0137.jpg 610 | img_0741.jpg 611 | img_0028.jpg 612 | img_0413.jpg 613 | img_0232.jpg 614 | img_0573.jpg 615 | img_0849.jpg 616 | img_0855.jpg 617 | img_0770.jpg 618 | img_0283.jpg 619 | img_0914.jpg 620 | img_0611.jpg 621 | img_1047.jpg 622 | img_0596.jpg 623 | img_0706.jpg 624 | img_0847.jpg 625 | img_0868.jpg 626 | img_0193.jpg 627 | img_0780.jpg 628 | img_0100.jpg 629 | img_0786.jpg 630 | img_0337.jpg 631 | img_0728.jpg 632 | img_0656.jpg 633 | img_0602.jpg 634 | img_1015.jpg 635 | img_0273.jpg 636 | img_0797.jpg 637 | img_0398.jpg 638 | img_0693.jpg 639 | img_0944.jpg 640 | img_0593.jpg 641 | img_0768.jpg 642 | img_0995.jpg 643 | img_1125.jpg 644 | img_0078.jpg 645 | img_0543.jpg 646 | img_0167.jpg 647 | img_0420.jpg 648 | img_0264.jpg 649 | img_0016.jpg 650 | img_0599.jpg 651 | img_0417.jpg 652 | img_0448.jpg 653 | img_0748.jpg 654 | img_0311.jpg 655 | img_0071.jpg 656 | img_0749.jpg 657 | img_0941.jpg 658 | img_0237.jpg 659 | img_0214.jpg 660 | img_1149.jpg 661 | img_0241.jpg 662 | img_0461.jpg 663 | img_0018.jpg 664 | img_0356.jpg 665 | img_0483.jpg 666 | img_0099.jpg 667 | img_0130.jpg 668 | 
img_0372.jpg 669 | img_0800.jpg 670 | img_0654.jpg 671 | img_0544.jpg 672 | img_1099.jpg 673 | img_1068.jpg 674 | img_0326.jpg 675 | img_0374.jpg 676 | img_0074.jpg 677 | img_0938.jpg 678 | img_0117.jpg 679 | img_0456.jpg 680 | img_0901.jpg 681 | img_0713.jpg 682 | img_0788.jpg 683 | img_0665.jpg 684 | img_0294.jpg 685 | img_0841.jpg 686 | img_0269.jpg 687 | img_0579.jpg 688 | img_1098.jpg 689 | img_0466.jpg 690 | img_0480.jpg 691 | img_0709.jpg 692 | img_0672.jpg 693 | img_1010.jpg 694 | img_0314.jpg 695 | img_0043.jpg 696 | img_0349.jpg 697 | img_0172.jpg 698 | img_1187.jpg 699 | img_0371.jpg 700 | img_0320.jpg 701 | img_1103.jpg 702 | img_1159.jpg 703 | img_0629.jpg 704 | img_0399.jpg 705 | img_0663.jpg 706 | img_0335.jpg 707 | img_1148.jpg 708 | img_0108.jpg 709 | img_0254.jpg 710 | img_0432.jpg 711 | img_0915.jpg 712 | img_0624.jpg 713 | img_0997.jpg 714 | img_0711.jpg 715 | img_0704.jpg 716 | img_1147.jpg 717 | img_0036.jpg 718 | img_0519.jpg 719 | img_0680.jpg 720 | img_0498.jpg 721 | img_0651.jpg 722 | img_0230.jpg 723 | img_0198.jpg 724 | img_0905.jpg 725 | img_0751.jpg 726 | img_0928.jpg 727 | img_0630.jpg 728 | img_0140.jpg 729 | img_0644.jpg 730 | img_0776.jpg 731 | img_0057.jpg 732 | img_0361.jpg 733 | img_0209.jpg 734 | img_0158.jpg 735 | img_1160.jpg 736 | img_1169.jpg 737 | img_0735.jpg 738 | img_0551.jpg 739 | img_0681.jpg 740 | img_0515.jpg 741 | img_0077.jpg 742 | img_0968.jpg 743 | img_0240.jpg 744 | img_1166.jpg 745 | img_0937.jpg 746 | img_0877.jpg 747 | img_0513.jpg 748 | img_0528.jpg 749 | img_0150.jpg 750 | img_1165.jpg 751 | img_0200.jpg 752 | img_0246.jpg 753 | img_0869.jpg 754 | img_0011.jpg 755 | img_0160.jpg 756 | img_0464.jpg 757 | img_0285.jpg 758 | img_0132.jpg 759 | img_0701.jpg 760 | img_0082.jpg 761 | img_1182.jpg 762 | img_0030.jpg 763 | img_0126.jpg 764 | img_0632.jpg 765 | img_0731.jpg 766 | img_0875.jpg 767 | img_0978.jpg 768 | img_0717.jpg 769 | img_0460.jpg 770 | img_1044.jpg 771 | img_1194.jpg 772 | img_0910.jpg 773 | img_0049.jpg 774 | img_0331.jpg 775 | img_0213.jpg 776 | img_0885.jpg 777 | img_0468.jpg 778 | img_0419.jpg 779 | img_1158.jpg 780 | img_0022.jpg 781 | img_0174.jpg 782 | img_0747.jpg 783 | img_1006.jpg 784 | img_0381.jpg 785 | img_1036.jpg 786 | img_0863.jpg 787 | img_0994.jpg 788 | img_0783.jpg 789 | img_0346.jpg 790 | img_0233.jpg 791 | img_0820.jpg 792 | img_1107.jpg 793 | img_1193.jpg 794 | img_0943.jpg 795 | img_1191.jpg 796 | img_0005.jpg 797 | img_0087.jpg 798 | img_0039.jpg 799 | img_0813.jpg 800 | img_0239.jpg 801 | img_0206.jpg 802 | img_0256.jpg 803 | img_1070.jpg 804 | img_0409.jpg 805 | img_0377.jpg 806 | img_0446.jpg 807 | img_0216.jpg 808 | img_0189.jpg 809 | img_0785.jpg 810 | img_0041.jpg 811 | img_0598.jpg 812 | img_0310.jpg 813 | img_0307.jpg 814 | img_1093.jpg 815 | img_0465.jpg 816 | img_0746.jpg 817 | img_0380.jpg 818 | img_0732.jpg 819 | img_0781.jpg 820 | img_0906.jpg 821 | img_0619.jpg 822 | img_0604.jpg 823 | img_0983.jpg 824 | img_0753.jpg 825 | img_0211.jpg 826 | img_0552.jpg 827 | img_0892.jpg 828 | img_0767.jpg 829 | img_1180.jpg 830 | img_1069.jpg 831 | img_0154.jpg 832 | img_0899.jpg 833 | img_0343.jpg 834 | img_0025.jpg 835 | img_1196.jpg 836 | img_0155.jpg 837 | img_0433.jpg 838 | img_0597.jpg 839 | img_0570.jpg 840 | img_0867.jpg 841 | img_0223.jpg 842 | img_0581.jpg 843 | img_0186.jpg 844 | img_0122.jpg 845 | img_1134.jpg 846 | img_0340.jpg 847 | img_0957.jpg 848 | img_0364.jpg 849 | img_0069.jpg 850 | img_1114.jpg 851 | img_0646.jpg 852 | img_0679.jpg 853 | img_0623.jpg 854 | img_0392.jpg 855 | 
img_0814.jpg 856 | img_0589.jpg 857 | img_0299.jpg 858 | img_0931.jpg 859 | img_0836.jpg 860 | img_0963.jpg 861 | img_0094.jpg 862 | img_0987.jpg 863 | img_0930.jpg 864 | img_0976.jpg 865 | img_0924.jpg 866 | img_0384.jpg 867 | img_0035.jpg 868 | img_0076.jpg 869 | img_1101.jpg 870 | img_0405.jpg 871 | img_0350.jpg 872 | img_0147.jpg 873 | img_0659.jpg 874 | img_1013.jpg 875 | img_0948.jpg 876 | img_0066.jpg 877 | img_1132.jpg 878 | img_0829.jpg 879 | img_0690.jpg 880 | img_1060.jpg 881 | img_0457.jpg 882 | img_0897.jpg 883 | img_0825.jpg 884 | img_1163.jpg 885 | img_0803.jpg 886 | img_0563.jpg 887 | img_0574.jpg 888 | img_0175.jpg 889 | img_1112.jpg 890 | img_0668.jpg 891 | img_0045.jpg 892 | img_0259.jpg 893 | img_0341.jpg 894 | img_1067.jpg 895 | img_1040.jpg 896 | img_1106.jpg 897 | img_0205.jpg 898 | img_0296.jpg 899 | img_0255.jpg 900 | img_1152.jpg 901 | img_0772.jpg 902 | img_0613.jpg 903 | img_1121.jpg 904 | img_0834.jpg 905 | img_0406.jpg 906 | img_0762.jpg 907 | img_0442.jpg 908 | img_0192.jpg 909 | img_0044.jpg 910 | img_0774.jpg 911 | img_0606.jpg 912 | img_0359.jpg 913 | img_0467.jpg 914 | img_0779.jpg 915 | img_0060.jpg 916 | img_1074.jpg 917 | img_0494.jpg 918 | img_1153.jpg 919 | img_0102.jpg 920 | img_0582.jpg 921 | img_0386.jpg 922 | img_0212.jpg 923 | img_0625.jpg 924 | img_0844.jpg 925 | img_0872.jpg 926 | img_1105.jpg 927 | img_0396.jpg 928 | img_1119.jpg 929 | img_0052.jpg 930 | img_0454.jpg 931 | img_1179.jpg 932 | img_0862.jpg 933 | img_0481.jpg 934 | img_1026.jpg 935 | img_0511.jpg 936 | img_0912.jpg 937 | img_1124.jpg 938 | img_0148.jpg 939 | img_0960.jpg 940 | img_0523.jpg 941 | img_0531.jpg 942 | img_0729.jpg 943 | img_0571.jpg 944 | img_0908.jpg 945 | img_0889.jpg 946 | img_0188.jpg 947 | img_0037.jpg 948 | img_0716.jpg 949 | img_1014.jpg 950 | img_0394.jpg 951 | img_1056.jpg 952 | img_0462.jpg 953 | img_0850.jpg 954 | img_0784.jpg 955 | img_1002.jpg 956 | img_0763.jpg 957 | img_0159.jpg 958 | img_0009.jpg 959 | img_0708.jpg 960 | img_1050.jpg 961 | img_0678.jpg 962 | img_0648.jpg 963 | img_0010.jpg 964 | img_1031.jpg 965 | img_0445.jpg 966 | img_0355.jpg 967 | img_1117.jpg 968 | img_0378.jpg 969 | img_0550.jpg 970 | img_0217.jpg 971 | img_0260.jpg 972 | img_0816.jpg 973 | img_0996.jpg 974 | img_0081.jpg 975 | img_0878.jpg 976 | img_0199.jpg 977 | img_0431.jpg 978 | img_1144.jpg 979 | img_0688.jpg 980 | img_0745.jpg 981 | img_0686.jpg 982 | img_1042.jpg 983 | img_0187.jpg 984 | img_1066.jpg 985 | img_0682.jpg 986 | img_0048.jpg 987 | img_0896.jpg 988 | img_0608.jpg 989 | img_1003.jpg 990 | img_1156.jpg 991 | img_0723.jpg 992 | img_0692.jpg 993 | img_0220.jpg 994 | img_0993.jpg 995 | img_1197.jpg 996 | img_0447.jpg 997 | img_0369.jpg 998 | img_0056.jpg 999 | img_0807.jpg 1000 | img_0315.jpg 1001 | img_0567.jpg 1002 | img_0452.jpg 1003 | img_1128.jpg 1004 | img_0647.jpg 1005 | img_0242.jpg 1006 | img_0201.jpg 1007 | img_0497.jpg 1008 | img_0031.jpg 1009 | img_0771.jpg 1010 | img_0547.jpg 1011 | img_0705.jpg 1012 | img_0725.jpg 1013 | img_1058.jpg 1014 | img_0053.jpg 1015 | img_1043.jpg 1016 | img_0722.jpg 1017 | img_0435.jpg 1018 | img_0284.jpg 1019 | img_0583.jpg 1020 | img_0882.jpg 1021 | img_0111.jpg 1022 | img_0959.jpg 1023 | img_1076.jpg 1024 | img_0880.jpg 1025 | img_0224.jpg 1026 | img_0977.jpg 1027 | img_0270.jpg 1028 | img_0793.jpg 1029 | img_0603.jpg 1030 | img_1116.jpg 1031 | img_0304.jpg 1032 | img_0884.jpg 1033 | img_1136.jpg 1034 | img_0235.jpg 1035 | img_0412.jpg 1036 | img_0980.jpg 1037 | img_0988.jpg 1038 | img_0773.jpg 1039 | img_1174.jpg 1040 
| img_0562.jpg 1041 | img_0871.jpg 1042 | img_0798.jpg 1043 | img_0453.jpg 1044 | img_0696.jpg 1045 | img_0104.jpg 1046 | img_0607.jpg 1047 | img_0669.jpg 1048 | img_0293.jpg 1049 | img_1141.jpg 1050 | img_0329.jpg 1051 | img_0534.jpg 1052 | img_1113.jpg 1053 | img_0288.jpg 1054 | img_0961.jpg 1055 | img_0388.jpg 1056 | img_0073.jpg 1057 | img_0141.jpg 1058 | img_0935.jpg 1059 | img_1062.jpg 1060 | img_0227.jpg 1061 | img_0895.jpg 1062 | img_0449.jpg 1063 | img_0565.jpg 1064 | img_1009.jpg 1065 | img_0282.jpg 1066 | img_0806.jpg 1067 | img_1033.jpg 1068 | img_0332.jpg 1069 | img_0903.jpg 1070 | img_0475.jpg 1071 | img_0050.jpg 1072 | img_0455.jpg 1073 | img_0845.jpg 1074 | img_0946.jpg 1075 | img_0490.jpg 1076 | img_0274.jpg 1077 | img_0909.jpg 1078 | img_0966.jpg 1079 | img_0219.jpg 1080 | img_0898.jpg 1081 | img_0403.jpg 1082 | -------------------------------------------------------------------------------- /Networks/ALTGVT.py: -------------------------------------------------------------------------------- 1 | # _*_ coding: utf-8 _*_ 2 | # @author : 王福森 3 | # @time : 2021/11/12 17:16 4 | # @File : ALTGVT.py 5 | # @Software : PyCharm 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from functools import partial 10 | 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | from timm.models.registry import register_model 13 | from timm.models.vision_transformer import _cfg 14 | from timm.models.vision_transformer import Block as TimmBlock 15 | from timm.models.vision_transformer import Attention as TimmAttention 16 | 17 | class Regression(nn.Module): 18 | def __init__(self): 19 | super(Regression, self).__init__() 20 | 21 | self.v1 = nn.Sequential( 22 | # nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True), 23 | nn.Conv2d(256, 256, 3, padding=1, dilation=1), 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(inplace=True) 26 | ) 27 | 28 | self.v2 = nn.Sequential( 29 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 30 | nn.Conv2d(512, 256, 3, padding=1, dilation=1), 31 | nn.BatchNorm2d(256), 32 | nn.ReLU(inplace=True) 33 | ) 34 | self.v3 = nn.Sequential( 35 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), 36 | nn.Conv2d(1024, 256, 3, padding=1, dilation=1), 37 | nn.BatchNorm2d(256), 38 | nn.ReLU(inplace=True) 39 | ) 40 | self.stage1 = nn.Sequential( 41 | nn.Conv2d(256, 128, 3, padding=1, dilation=1), 42 | nn.BatchNorm2d(128), 43 | nn.ReLU(inplace=True) 44 | ) 45 | self.stage2 = nn.Sequential( 46 | nn.Conv2d(256, 128, 3, padding=2, dilation=2), 47 | nn.BatchNorm2d(128), 48 | nn.ReLU(inplace=True) 49 | ) 50 | self.stage3 = nn.Sequential( 51 | nn.Conv2d(256, 128, 3, padding=3, dilation=3), 52 | nn.BatchNorm2d(128), 53 | nn.ReLU(inplace=True) 54 | ) 55 | self.stage4 = nn.Sequential( 56 | nn.Conv2d(256, 384, 1), 57 | nn.BatchNorm2d(384), 58 | nn.ReLU(inplace=True) 59 | ) 60 | self.res = nn.Sequential( 61 | nn.Conv2d(384, 64, 3, padding=1, dilation=1), 62 | nn.BatchNorm2d(64), 63 | nn.ReLU(inplace=True), 64 | nn.Conv2d(64, 1, 1), 65 | nn.ReLU() 66 | ) 67 | 68 | self.init_param() 69 | 70 | def forward(self, x1, x2, x3): 71 | x1 = self.v1(x1) 72 | x2 = self.v2(x2) 73 | x3 = self.v3(x3) 74 | x = x1 + x2 + x3 75 | y1 = self.stage1(x) 76 | y2 = self.stage2(x) 77 | y3 = self.stage3(x) 78 | y4 = self.stage4(x) 79 | y = torch.cat((y1,y2,y3), dim=1) + y4 80 | y = self.res(y) 81 | return y 82 | 83 | def init_param(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | nn.init.normal_(m.weight, 
std=0.01) 87 | if m.bias is not None: 88 | nn.init.constant_(m.bias, 0) 89 | elif isinstance(m, nn.BatchNorm2d): 90 | nn.init.constant_(m.weight, 1) 91 | nn.init.constant_(m.bias, 0) 92 | 93 | class Mlp(nn.Module): 94 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 95 | super().__init__() 96 | out_features = out_features or in_features 97 | hidden_features = hidden_features or in_features 98 | self.fc1 = nn.Linear(in_features, hidden_features) 99 | self.act = act_layer() 100 | self.fc2 = nn.Linear(hidden_features, out_features) 101 | self.drop = nn.Dropout(drop) 102 | 103 | def forward(self, x): 104 | x = self.fc1(x) 105 | x = self.act(x) 106 | x = self.drop(x) 107 | x = self.fc2(x) 108 | x = self.drop(x) 109 | return x 110 | 111 | 112 | class GroupAttention(nn.Module): 113 | """ 114 | LSA: self attention within a group 115 | """ 116 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ws=1): 117 | assert ws != 1 118 | super(GroupAttention, self).__init__() 119 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 120 | 121 | self.dim = dim 122 | self.num_heads = num_heads 123 | head_dim = dim // num_heads 124 | self.scale = qk_scale or head_dim ** -0.5 125 | 126 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 127 | self.attn_drop = nn.Dropout(attn_drop) 128 | self.proj = nn.Linear(dim, dim) 129 | self.proj_drop = nn.Dropout(proj_drop) 130 | self.ws = ws 131 | 132 | def forward(self, x, H, W): 133 | B, N, C = x.shape 134 | h_group, w_group = H // self.ws, W // self.ws 135 | 136 | total_groups = h_group * w_group 137 | x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3) 138 | 139 | qkv = self.qkv(x).reshape(B, total_groups, -1, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) 140 | # B, hw, ws*ws, 3, n_head, head_dim -> 3, B, hw, n_head, ws*ws, head_dim 141 | q, k, v = qkv[0], qkv[1], qkv[2] # B, hw, n_head, ws*ws, head_dim 142 | attn = (q @ k.transpose(-2, -1)) * self.scale # B, hw, n_head, ws*ws, ws*ws 143 | attn = attn.softmax(dim=-1) 144 | attn = self.attn_drop( 145 | attn) # attn @ v-> B, hw, n_head, ws*ws, head_dim -> (t(2,3)) B, hw, ws*ws, n_head, head_dim 146 | attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, C) 147 | x = attn.transpose(2, 3).reshape(B, N, C) 148 | x = self.proj(x) 149 | x = self.proj_drop(x) 150 | return x 151 | 152 | 153 | class Attention(nn.Module): 154 | """ 155 | GSA: using a key to summarize the information for a group to be efficient. 156 | """ 157 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 158 | super().__init__() 159 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
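        # Global sub-sampled attention (GSA): when sr_ratio > 1, keys and values
        # are computed from a feature map spatially reduced by the strided conv
        # `self.sr` below, so the attention matrix shrinks from (H*W x H*W) to
        # roughly (H*W x H*W / sr_ratio**2) per head. For instance, on a 56x56
        # feature map with sr_ratio=8 the key/value length drops from 3136 to 49.
        # In ALTGVT further down, blocks alternate between this GSA and the
        # windowed GroupAttention above (ws == 1 selects GSA, ws == wss[k]
        # selects the local window attention).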
160 | 161 | self.dim = dim 162 | self.num_heads = num_heads 163 | head_dim = dim // num_heads 164 | self.scale = qk_scale or head_dim ** -0.5 165 | 166 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 167 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 168 | self.attn_drop = nn.Dropout(attn_drop) 169 | self.proj = nn.Linear(dim, dim) 170 | self.proj_drop = nn.Dropout(proj_drop) 171 | 172 | self.sr_ratio = sr_ratio 173 | if sr_ratio > 1: 174 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 175 | self.norm = nn.LayerNorm(dim) 176 | 177 | def forward(self, x, H, W): 178 | B, N, C = x.shape 179 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 180 | 181 | if self.sr_ratio > 1: 182 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 183 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 184 | x_ = self.norm(x_) 185 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 186 | else: 187 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 188 | k, v = kv[0], kv[1] 189 | 190 | attn = (q @ k.transpose(-2, -1)) * self.scale 191 | attn = attn.softmax(dim=-1) 192 | attn = self.attn_drop(attn) 193 | 194 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 195 | x = self.proj(x) 196 | x = self.proj_drop(x) 197 | 198 | return x 199 | 200 | 201 | class Block(nn.Module): 202 | 203 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 204 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 205 | super().__init__() 206 | self.norm1 = norm_layer(dim) 207 | self.attn = Attention( 208 | dim, 209 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 210 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 211 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 212 | self.norm2 = norm_layer(dim) 213 | mlp_hidden_dim = int(dim * mlp_ratio) 214 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 215 | 216 | def forward(self, x, H, W): 217 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 218 | x = x + self.drop_path(self.mlp(self.norm2(x))) 219 | 220 | return x 221 | 222 | 223 | class SBlock(TimmBlock): 224 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 225 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 226 | super(SBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, 227 | drop_path, act_layer, norm_layer) 228 | 229 | def forward(self, x, H, W): 230 | return super(SBlock, self).forward(x) 231 | 232 | 233 | class GroupBlock(TimmBlock): 234 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 235 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=1): 236 | 237 | #super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, 238 | # drop_path, act_layer, norm_layer) 239 | 240 | #delete the qk_scale 241 | super(GroupBlock, self).__init__(dim, num_heads, mlp_ratio, qkv_bias, drop, attn_drop, 242 | drop_path, act_layer, norm_layer) 243 | del self.attn 244 | if ws == 1: 245 | self.attn = Attention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, sr_ratio) 246 | else: 247 | self.attn = GroupAttention(dim, num_heads, qkv_bias, qk_scale, attn_drop, drop, ws) 248 | 249 | def forward(self, x, H, W): 250 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 251 | x = x + self.drop_path(self.mlp(self.norm2(x))) 252 | 253 | return x 254 | 255 | 256 | class PatchEmbed(nn.Module): 257 | """ Image to Patch Embedding 258 | """ 259 | 260 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 261 | super().__init__() 262 | img_size = to_2tuple(img_size) 263 | patch_size = to_2tuple(patch_size) 264 | 265 | self.img_size = img_size 266 | self.patch_size = patch_size 267 | assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ 268 | f"img_size {img_size} should be divided by patch_size {patch_size}." 
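        # The patch embedding is a single conv with kernel_size == stride ==
        # patch_size: it cuts the image into non-overlapping patches, projects
        # each patch to `embed_dim` channels, and the resulting
        # (H/patch_size * W/patch_size) token sequence is layer-normalized
        # in forward().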
269 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 270 | self.num_patches = self.H * self.W 271 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 272 | self.norm = nn.LayerNorm(embed_dim) 273 | 274 | def forward(self, x): 275 | B, C, H, W = x.shape 276 | 277 | x = self.proj(x).flatten(2).transpose(1, 2) 278 | x = self.norm(x) 279 | H, W = H // self.patch_size[0], W // self.patch_size[1] 280 | 281 | return x, (H, W) 282 | 283 | 284 | # borrow from PVT https://github.com/whai362/PVT.git 285 | class PyramidVisionTransformer(nn.Module): 286 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 287 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 288 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 289 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): 290 | super().__init__() 291 | self.num_classes = num_classes 292 | self.depths = depths 293 | 294 | # patch_embed 295 | self.patch_embeds = nn.ModuleList() 296 | self.pos_embeds = nn.ParameterList() 297 | self.pos_drops = nn.ModuleList() 298 | self.blocks = nn.ModuleList() 299 | 300 | for i in range(len(depths)): 301 | if i == 0: 302 | self.patch_embeds.append(PatchEmbed(img_size, patch_size, in_chans, embed_dims[i])) 303 | else: 304 | self.patch_embeds.append( 305 | PatchEmbed(img_size // patch_size // 2 ** (i - 1), 2, embed_dims[i - 1], embed_dims[i])) 306 | patch_num = self.patch_embeds[-1].num_patches + 1 if i == len(embed_dims) - 1 else self.patch_embeds[ 307 | -1].num_patches 308 | self.pos_embeds.append(nn.Parameter(torch.zeros(1, patch_num, embed_dims[i]))) 309 | self.pos_drops.append(nn.Dropout(p=drop_rate)) 310 | 311 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 312 | cur = 0 313 | for k in range(len(depths)): 314 | _block = nn.ModuleList([ 315 | block_cls( 316 | dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, 317 | qk_scale=qk_scale, 318 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 319 | sr_ratio=sr_ratios[k] 320 | ) 321 | for i in range(depths[k])]) 322 | self.blocks.append(_block) 323 | cur += depths[k] 324 | 325 | self.norm = norm_layer(embed_dims[-1]) 326 | 327 | # cls_token 328 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1])) 329 | 330 | # classification head 331 | self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 332 | 333 | # init weights 334 | for pos_emb in self.pos_embeds: 335 | trunc_normal_(pos_emb, std=.02) 336 | self.apply(self._init_weights) 337 | 338 | def reset_drop_path(self, drop_path_rate): 339 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] 340 | cur = 0 341 | for k in range(len(self.depths)): 342 | for i in range(self.depths[k]): 343 | self.blocks[k][i].drop_path.drop_prob = dpr[cur + i] 344 | cur += self.depths[k] 345 | 346 | def _init_weights(self, m): 347 | if isinstance(m, nn.Linear): 348 | trunc_normal_(m.weight, std=.02) 349 | if isinstance(m, nn.Linear) and m.bias is not None: 350 | nn.init.constant_(m.bias, 0) 351 | elif isinstance(m, nn.LayerNorm): 352 | nn.init.constant_(m.bias, 0) 353 | nn.init.constant_(m.weight, 1.0) 354 | 355 | @torch.jit.ignore 356 | def no_weight_decay(self): 357 | return {'cls_token'} 358 | 359 | def get_classifier(self): 360 | return self.head 361 | 
362 | def reset_classifier(self, num_classes, global_pool=''): 363 | self.num_classes = num_classes 364 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 365 | 366 | def forward_features(self, x): 367 | B = x.shape[0] 368 | for i in range(len(self.depths)): 369 | x, (H, W) = self.patch_embeds[i](x) 370 | if i == len(self.depths) - 1: 371 | cls_tokens = self.cls_token.expand(B, -1, -1) 372 | x = torch.cat((cls_tokens, x), dim=1) 373 | x = x + self.pos_embeds[i] 374 | x = self.pos_drops[i](x) 375 | for blk in self.blocks[i]: 376 | x = blk(x, H, W) 377 | if i < len(self.depths) - 1: 378 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 379 | x = self.norm(x) 380 | return x[:, 0] 381 | 382 | def forward(self, x): 383 | x = self.forward_features(x) 384 | x = self.head(x) 385 | 386 | return x 387 | 388 | 389 | # PEG from https://arxiv.org/abs/2102.10882 390 | class PosCNN(nn.Module): 391 | def __init__(self, in_chans, embed_dim=768, s=1): 392 | super(PosCNN, self).__init__() 393 | self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, s, 1, bias=True, groups=embed_dim), ) 394 | self.s = s 395 | 396 | def forward(self, x, H, W): 397 | B, N, C = x.shape 398 | feat_token = x 399 | cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W) 400 | if self.s == 1: 401 | x = self.proj(cnn_feat) + cnn_feat 402 | else: 403 | x = self.proj(cnn_feat) 404 | x = x.flatten(2).transpose(1, 2) 405 | return x 406 | 407 | def no_weight_decay(self): 408 | return ['proj.%d.weight' % i for i in range(4)] 409 | 410 | 411 | class CPVTV2(PyramidVisionTransformer): 412 | """ 413 | Use useful results from CPVT. PEG and GAP. 414 | Therefore, cls token is no longer required. 415 | PEG is used to encode the absolute position on the fly, which greatly affects the performance when input resolution 416 | changes during the training (such as segmentation, detection) 417 | """ 418 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 419 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 420 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 421 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], block_cls=Block): 422 | super(CPVTV2, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, mlp_ratios, 423 | qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, norm_layer, depths, 424 | sr_ratios, block_cls) 425 | del self.pos_embeds 426 | del self.cls_token 427 | self.pos_block = nn.ModuleList( 428 | [PosCNN(embed_dim, embed_dim) for embed_dim in embed_dims] 429 | ) 430 | 431 | self.regression = Regression() 432 | self.apply(self._init_weights) 433 | 434 | 435 | def _init_weights(self, m): 436 | import math 437 | if isinstance(m, nn.Linear): 438 | trunc_normal_(m.weight, std=.02) 439 | if isinstance(m, nn.Linear) and m.bias is not None: 440 | nn.init.constant_(m.bias, 0) 441 | elif isinstance(m, nn.LayerNorm): 442 | nn.init.constant_(m.bias, 0) 443 | nn.init.constant_(m.weight, 1.0) 444 | elif isinstance(m, nn.Conv2d): 445 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 446 | fan_out //= m.groups 447 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 448 | if m.bias is not None: 449 | m.bias.data.zero_() 450 | elif isinstance(m, nn.BatchNorm2d): 451 | m.weight.data.fill_(1.0) 452 | m.bias.data.zero_() 453 | 454 | def no_weight_decay(self): 455 | return set(['cls_token'] + ['pos_block.' 
+ n for n, p in self.pos_block.named_parameters()]) 456 | 457 | def forward_features(self, x): 458 | outputs = list() 459 | 460 | B = x.shape[0] 461 | 462 | for i in range(len(self.depths)): 463 | x, (H, W) = self.patch_embeds[i](x) 464 | x = self.pos_drops[i](x) 465 | for j, blk in enumerate(self.blocks[i]): 466 | x = blk(x, H, W) 467 | if j == 0: 468 | x = self.pos_block[i](x, H, W) 469 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 470 | outputs.append(x) 471 | 472 | return outputs 473 | 474 | def forward(self, x): 475 | x = self.forward_features(x) 476 | mu = self.regression(x[1], x[2], x[3]) 477 | B, C, H, W = mu.size() 478 | mu_sum = mu.view([B, -1]).sum(1).unsqueeze(1).unsqueeze(2).unsqueeze(3) 479 | mu_normed = mu / (mu_sum + 1e-6) 480 | return mu, mu_normed 481 | 482 | 483 | class PCPVT(CPVTV2): 484 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], 485 | num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 486 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 487 | depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=SBlock): 488 | super(PCPVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, 489 | mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, 490 | norm_layer, depths, sr_ratios, block_cls) 491 | 492 | 493 | class ALTGVT(PCPVT): 494 | """ 495 | alias Twins-SVT 496 | """ 497 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256], 498 | num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 499 | attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 500 | depths=[4, 4, 4], sr_ratios=[4, 2, 1], block_cls=GroupBlock, wss=[7, 7, 7]): 501 | super(ALTGVT, self).__init__(img_size, patch_size, in_chans, num_classes, embed_dims, num_heads, 502 | mlp_ratios, qkv_bias, qk_scale, drop_rate, attn_drop_rate, drop_path_rate, 503 | norm_layer, depths, sr_ratios, block_cls) 504 | del self.blocks 505 | self.wss = wss 506 | # transformer encoder 507 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 508 | cur = 0 509 | self.blocks = nn.ModuleList() 510 | for k in range(len(depths)): 511 | _block = nn.ModuleList([block_cls( 512 | dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], qkv_bias=qkv_bias, 513 | qk_scale=qk_scale, 514 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, 515 | sr_ratio=sr_ratios[k], ws=1 if i % 2 == 1 else wss[k]) for i in range(depths[k])]) 516 | self.blocks.append(_block) 517 | cur += depths[k] 518 | self.apply(self._init_weights) 519 | 520 | 521 | def _conv_filter(state_dict, patch_size=16): 522 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 523 | out_dict = {} 524 | for k, v in state_dict.items(): 525 | if 'patch_embed.proj.weight' in k: 526 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 527 | out_dict[k] = v 528 | 529 | return out_dict 530 | 531 | @register_model 532 | def alt_gvt_small(pretrained=False, **kwargs): 533 | model = ALTGVT( 534 | patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 535 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], 536 | **kwargs) 537 | model.default_cfg = _cfg() 538 | return model 539 | 540 | 541 | 
@register_model 542 | def alt_gvt_base(pretrained=False, **kwargs): 543 | model = ALTGVT( 544 | patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 545 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], 546 | **kwargs) 547 | 548 | model.default_cfg = _cfg() 549 | return model 550 | 551 | 552 | @register_model 553 | def alt_gvt_large(pretrained=False, **kwargs): 554 | model = ALTGVT( 555 | patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], 556 | qkv_bias=True, 557 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 18, 2], wss=[8, 8, 8, 8], sr_ratios=[8, 4, 2, 1], 558 | **kwargs) 559 | model.default_cfg = _cfg() 560 | if pretrained: 561 | '''download from https://github.com/Meituan-AutoML/Twins/alt_gvt_large.pth''' 562 | checkpoint = torch.load('/train_folder/head_detection/CCTrans/model_weights/alt_gvt_large.pth') # TODO: pass the checkpoint path as an argument 563 | model.load_state_dict(checkpoint, strict=False) 564 | print("loaded pretrained transformer weights") 565 | return model 566 | 567 | if __name__ == '__main__': 568 | model = alt_gvt_large(pretrained=True) 569 | x = torch.ones(1, 3, 256, 256) 570 | mu, mu_norm = model(x) 571 | print(mu.size(), mu_norm.size()) --------------------------------------------------------------------------------
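
For readers who want to see how this counting model is meant to be consumed beyond the shape check in the `__main__` block, here is a minimal usage sketch. It assumes the file is importable as `Networks.ALTGVT` (matching the repository layout shown above) and loads no pretrained weights; the input size and the count-by-summation step mirror the `__main__` block and the normalization in `CPVTV2.forward`, but the snippet itself is illustrative rather than repository code.

    # usage sketch -- illustrative only, not part of the repository
    import torch
    from Networks.ALTGVT import alt_gvt_large

    model = alt_gvt_large(pretrained=False)   # avoids the hard-coded checkpoint path
    model.eval()

    # 256x256 keeps every stage's feature map divisible by the window size (wss=[8, 8, 8, 8]),
    # the same resolution used in the __main__ block above.
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        mu, mu_normed = model(x)    # predicted density map and its per-image normalized form

    # The spatial sum of the density map serves as the predicted crowd count;
    # mu_normed sums to roughly 1 per image (see the 1e-6-stabilized division in CPVTV2.forward).
    print(mu.size(), mu_normed.size())
    print("estimated count:", mu.sum().item())

The hard-coded checkpoint path in `alt_gvt_large` is already flagged with a TODO; one way to address it is to lift the path into a parameter while keeping the `strict=False` load. The helper below is a hypothetical sketch (the name `load_backbone_weights` and the optional `'state_dict'` unwrapping are assumptions, not repository code) and reuses the `torch` import from the snippet above.

    # hypothetical helper -- name and checkpoint layout are assumptions
    def load_backbone_weights(model, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        # some released checkpoints wrap the weights in a dict; unwrap if a 'state_dict' key exists
        state_dict = checkpoint.get('state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print("loaded backbone weights: %d missing / %d unexpected keys" % (len(missing), len(unexpected)))
        return model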