├── README.md
├── SKT_distill.py
├── SKT_distill.sh
├── dataset.py
├── image.py
├── models
│   ├── __init__.py
│   ├── distillation.py
│   ├── model_student_vgg.py
│   ├── model_teacher_vgg.py
│   └── model_vgg.py
├── preprocess
│   ├── ShanghaiTech_GT_generation.py
│   ├── UCF_GT_generation.py
│   ├── make_json.py
│   └── part_A_val.json
├── test.py
├── test.sh
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # [Efficient Crowd Counting via Structured Knowledge Transfer](https://arxiv.org/abs/2003.10120) (ACM MM 2020)
2 | Crowd counting is an application-oriented task, and its inference efficiency is crucial for real-world applications. However, most previous works rely on heavy backbone networks and require prohibitive run-time consumption, which seriously restricts their deployment scope and scalability. To liberate these crowd counting models, we propose a novel Structured Knowledge Transfer (SKT) framework, which fully exploits the structured knowledge of a well-trained teacher network to generate a lightweight yet still highly effective student network.
3 | 
4 | Extensive evaluations on three benchmarks demonstrate the effectiveness of our SKT for a wide range of crowd counting models. In this project, the well-trained teacher networks and the distilled student networks have been released at [GoogleDrive](https://drive.google.com/drive/folders/17oxen8sNHtumcFL8hu9Z0Owuc6dWD8zV?usp=sharing) and [BaiduYun](https://pan.baidu.com/s/10_SLXF_FID9huRbzMHFT4A) (extraction code: srpl). If you use this code and the released models for your research, please cite our paper:
5 | ```
6 | @inproceedings{liu2020efficient,
7 |   title={Efficient Crowd Counting via Structured Knowledge Transfer},
8 |   author={Liu, Lingbo and Chen, Jiaqi and Wu, Hefeng and Chen, Tianshui and Li, Guanbin and Lin, Liang},
9 |   booktitle={ACM International Conference on Multimedia},
10 |   year={2020}
11 | }
12 | ```
13 | 
14 | ## Datasets
15 | ShanghaiTech: [Google Drive](https://drive.google.com/open?id=16dhJn7k4FWVwByRsQAEpl9lwjuV03jVI)
16 | 
17 | UCF-QNRF: [Link](https://www.crcv.ucf.edu/data/ucf-qnrf/)
18 | 
19 | ## Prerequisites
20 | We strongly recommend Anaconda as the environment.
21 | 
22 | Python: 2.7
23 | 
24 | PyTorch: 0.4.0
25 | 
26 | ## Preprocessing
27 | 
28 | 1. Generate the ground-truth density maps for training:
29 | ```
30 | # ShanghaiTech
31 | python preprocess/ShanghaiTech_GT_generation.py
32 | 
33 | # UCF-QNRF
34 | python preprocess/UCF_GT_generation.py --mode train
35 | python preprocess/UCF_GT_generation.py --mode test
36 | ```
37 | 
38 | 2. Edit `preprocess/make_json.py` to point to your original dataset path, then generate the data-path JSON files:
39 | ```
40 | python preprocess/make_json.py
41 | ```
42 | 
43 | 
44 | ## Training
45 | Edit `SKT_distill.sh` to set the data paths, teacher checkpoint, and output directory, then run distillation training:
46 | ```
47 | bash SKT_distill.sh
48 | ```
49 | 
50 | ## Testing
51 | Edit `test.sh` to set the checkpoint and model version to evaluate, then run:
52 | ```
53 | bash test.sh
54 | ```
55 | 
56 | ## Models
57 | The well-trained teacher networks and the distilled student networks have been released at [GoogleDrive](https://drive.google.com/drive/folders/17oxen8sNHtumcFL8hu9Z0Owuc6dWD8zV?usp=sharing) and [BaiduYun](https://pan.baidu.com/s/10_SLXF_FID9huRbzMHFT4A) (extraction code: srpl).
58 | In particular, using only around 6% of the parameters and computation cost of the original models, our distilled VGG-based models obtain at least a 6.5× speed-up on an Nvidia 1080 GPU and even achieve state-of-the-art performance.
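For reference, the distilled students reported in the tables below are trained with four loss terms (see `train()` in `SKT_distill.py`): a hard loss against the ground-truth density map, a soft loss against the teacher's prediction, a dense FSP loss on paired intermediate features, and a cosine-similarity loss. A minimal sketch of how these terms are combined, condensed from the training script — the wrapper name `skt_loss` is only for illustration and does not exist in the repo:
```
import torch
from models.distillation import cosine_similarity

def skt_loss(student_out, teacher_out, target,
             student_fsp, teacher_fsp,
             student_feats, teacher_feats,
             lamb_fsp=0.5, lamb_cos=0.5):
    # Summed MSE, matching nn.MSELoss(size_average=False) in SKT_distill.py (PyTorch 0.4)
    mse = torch.nn.MSELoss(size_average=False)
    loss_h = mse(student_out, target)        # hard loss vs. GT density map
    loss_s = mse(student_out, teacher_out)   # soft loss vs. teacher prediction
    # dense FSP loss on the paired Gram matrices from cal_dense_fsp
    loss_fsp = sum(mse(s, t) for s, t in zip(student_fsp, teacher_fsp)) * lamb_fsp
    # cosine loss on the paired intermediate feature maps
    loss_cos = sum(cosine_similarity(s, t)
                   for s, t in zip(student_feats, teacher_feats)) * lamb_cos
    return loss_h + loss_s + loss_fsp + loss_cos
```
The `lamb_fsp` / `lamb_cos` weights correspond to the `-laf` / `-lac` arguments in `SKT_distill.sh`.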
59 | 60 | #### Shanghaitech A (576×864) 61 | | Method | MAE | RMSE | #Param (M) | FLOPs (G) | GPU (ms) | CPU (s) | Comment | 62 | | --- | --- | --- | --- |--- | --- | --- | --- | 63 | | CSRNet | 68.43 | 105.99 | 16.26 | 205.88 | 66.58 | 7.85 | teacher model, trained with [CSRNet-pytorch](https://github.com/leeyeehoo/CSRNet-pytorch) | 64 | | 1/4-CSRNet + SKT | 71.55 | 114.40 | 1.02 | 13.09 | 8.88 | 0.87 | -- | 65 | | BL | 61.46 | 103.17 | 21.50 | 205.32 | 47.89 | 8.84 | teacher model | 66 | | 1/4-BL + SKT | 62.73 | 102.33 | 1.35 | 13.06 | 7.40 | 0.88 | -- | 67 | 68 | #### UCF-QNRF (2032×2912) 69 | | Method | MAE | RMSE | #Param (M) | FLOPs (G) | GPU (ms) | CPU (s) | Comment | 70 | | --- | --- | --- | --- |--- | --- | --- | --- | 71 | | CSRNet | 145.54 | 233.32 | 16.26 | 2447.91 | 823.84 | 119.67 | teacher model, trained with [CSRNet-pytorch](https://github.com/leeyeehoo/CSRNet-pytorch) | 72 | | 1/4-CSRNet + SKT | 144.36 | 234.64 | 1.02 | 155.69 | 106.08 | 9.71 | -- | 73 | | BL | 87.70 | 158.09 | 21.50 | 2441.23 | 595.72 | 130.76 | teacher model | 74 | | 1/4-BL + SKT | 96.24 | 156.82 | 1.35 | 155.30 | 90.96 | 9.78 | The released model is much better. | 75 | -------------------------------------------------------------------------------- /SKT_distill.py: -------------------------------------------------------------------------------- 1 | """ 2 | SKT distillation 3 | """ 4 | import sys 5 | import os 6 | 7 | import warnings 8 | 9 | from models.model_teacher_vgg import CSRNet as CSRNet_teacher 10 | from models.model_student_vgg import CSRNet as CSRNet_student 11 | 12 | from utils import save_checkpoint, cal_para 13 | from models.distillation import cosine_similarity, scale_process, cal_dense_fsp 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.functional as F 18 | from torch.autograd import Variable 19 | from torchvision import datasets, transforms 20 | 21 | import numpy as np 22 | import argparse 23 | import json 24 | import dataset 25 | import time 26 | 27 | parser = argparse.ArgumentParser(description='CSRNet-SKT distillation') 28 | parser.add_argument('train_json', metavar='TRAIN', 29 | help='path to train json') 30 | parser.add_argument('val_json', metavar='VAL', 31 | help='path to val json') 32 | parser.add_argument('test_json', metavar='TEST', 33 | help='path to test json') 34 | parser.add_argument('--lr', default=None, type=float, 35 | help='learning rate') 36 | # parser.add_argument('--teacher', '-t', default=None, type=str, 37 | # help='teacher net version') 38 | parser.add_argument('--teacher_ckpt', '-tc', default=None, type=str, 39 | help='teacher checkpoint') 40 | # parser.add_argument('--student', '-s', default=None, type=str, 41 | # help='student net version') 42 | parser.add_argument('--student_ckpt', '-sc', default=None, type=str, 43 | help='student checkpoint') 44 | parser.add_argument('--lamb_fsp', '-laf', type=float, default=None, 45 | help='weight of dense fsp loss') 46 | parser.add_argument('--lamb_cos', '-lac', type=float, default=None, 47 | help='weight of cos loss') 48 | parser.add_argument('--gpu', metavar='GPU', type=str, default='0', 49 | help='GPU id to use') 50 | parser.add_argument('--out', metavar='OUTPUT', type=str, 51 | help='path to output') 52 | 53 | 54 | global args 55 | args = parser.parse_args() 56 | 57 | 58 | def main(): 59 | global args, mae_best_prec1, mse_best_prec1 60 | 61 | mae_best_prec1 = 1e6 62 | mse_best_prec1 = 1e6 63 | 64 | args.batch_size = 1 # args.batch 65 | args.momentum = 0.95 66 | args.decay = 5 * 1e-4 67 | 
args.start_epoch = 0 68 | args.epochs = 1000 69 | args.workers = 6 70 | args.seed = time.time() 71 | args.print_freq = 400 72 | with open(args.train_json, 'r') as outfile: 73 | train_list = json.load(outfile) 74 | with open(args.val_json, 'r') as outfile: 75 | val_list = json.load(outfile) 76 | with open(args.test_json, 'r') as outfile: 77 | test_list = json.load(outfile) 78 | 79 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 80 | torch.cuda.manual_seed(args.seed) 81 | 82 | teacher = CSRNet_teacher() 83 | student = CSRNet_student(ratio=4) 84 | cal_para(student) # include 1x1 conv transform parameters 85 | 86 | teacher.regist_hook() # use hook to get teacher's features 87 | teacher = teacher.cuda() 88 | student = student.cuda() 89 | 90 | criterion = nn.MSELoss(size_average=False).cuda() 91 | 92 | optimizer = torch.optim.Adam(student.parameters(), args.lr, weight_decay=args.decay) 93 | 94 | if os.path.isdir(args.out) is False: 95 | os.makedirs(args.out.decode('utf-8')) 96 | 97 | if args.teacher_ckpt: 98 | if os.path.isfile(args.teacher_ckpt): 99 | print("=> loading checkpoint '{}'".format(args.teacher_ckpt)) 100 | checkpoint = torch.load(args.teacher_ckpt) 101 | teacher.load_state_dict(checkpoint['state_dict']) 102 | print("=> loaded checkpoint '{}' (epoch {})" 103 | .format(args.teacher_ckpt, checkpoint['epoch'])) 104 | else: 105 | print("=> no checkpoint found at '{}'".format(args.teacher_ckpt)) 106 | 107 | if args.student_ckpt: 108 | if os.path.isfile(args.student_ckpt): 109 | print("=> loading checkpoint '{}'".format(args.student_ckpt)) 110 | checkpoint = torch.load(args.student_ckpt) 111 | args.start_epoch = checkpoint['epoch'] 112 | if 'best_prec1' in checkpoint.keys(): 113 | mae_best_prec1 = checkpoint['best_prec1'] 114 | else: 115 | mae_best_prec1 = checkpoint['mae_best_prec1'] 116 | if 'mse_best_prec1' in checkpoint.keys(): 117 | mse_best_prec1 = checkpoint['mse_best_prec1'] 118 | student.load_state_dict(checkpoint['state_dict']) 119 | optimizer.load_state_dict(checkpoint['optimizer']) 120 | print("=> loaded checkpoint '{}' (epoch {})" 121 | .format(args.student_ckpt, checkpoint['epoch'])) 122 | else: 123 | print("=> no checkpoint found at '{}'".format(args.student_ckpt)) 124 | 125 | for epoch in range(args.start_epoch, args.epochs): 126 | 127 | train(train_list, teacher, student, criterion, optimizer, epoch) 128 | mae_prec1, mse_prec1 = val(val_list, student) 129 | 130 | mae_is_best = mae_prec1 < mae_best_prec1 131 | mae_best_prec1 = min(mae_prec1, mae_best_prec1) 132 | mse_is_best = mse_prec1 < mse_best_prec1 133 | mse_best_prec1 = min(mse_prec1, mse_best_prec1) 134 | print('Best val * MAE {mae:.3f} * MSE {mse:.3f}' 135 | .format(mae=mae_best_prec1, mse=mse_best_prec1)) 136 | save_checkpoint({ 137 | 'epoch': epoch + 1, 138 | 'arch': args.student_ckpt, 139 | 'state_dict': student.state_dict(), 140 | 'mae_best_prec1': mae_best_prec1, 141 | 'mse_best_prec1': mse_best_prec1, 142 | 'optimizer': optimizer.state_dict(), 143 | }, mae_is_best, mse_is_best, args.out) 144 | 145 | if mae_is_best or mse_is_best: 146 | test(test_list, student) 147 | 148 | 149 | def train(train_list, teacher, student, criterion, optimizer, epoch): 150 | losses_h = AverageMeter() 151 | losses_s = AverageMeter() 152 | losses_fsp = AverageMeter() 153 | losses_cos = AverageMeter() 154 | batch_time = AverageMeter() 155 | data_time = AverageMeter() 156 | train_loader = torch.utils.data.DataLoader( 157 | dataset.listDataset(train_list, 158 | transform=transforms.Compose([ 159 | transforms.ToTensor(), 
transforms.Normalize(mean=[0.485, 0.456, 0.406], 160 | std=[0.229, 0.224, 0.225]), 161 | ]), 162 | train=True, 163 | seen=student.seen, 164 | ), 165 | num_workers=args.workers, 166 | shuffle=True, 167 | batch_size=args.batch_size) 168 | print('epoch %d, lr %.10f %s' % (epoch, args.lr, args.out)) 169 | teacher.eval() 170 | student.train() 171 | end = time.time() 172 | 173 | for i, (img, target) in enumerate(train_loader): 174 | data_time.update(time.time() - end) 175 | img = img.cuda() 176 | img = Variable(img) 177 | target = target.type(torch.FloatTensor).cuda() 178 | target = Variable(target) 179 | 180 | with torch.no_grad(): 181 | teacher_output = teacher(img) 182 | teacher.features.append(teacher_output) 183 | teacher_fsp_features = [scale_process(teacher.features)] 184 | teacher_fsp = cal_dense_fsp(teacher_fsp_features) 185 | 186 | student_features = student(img) 187 | student_output = student_features[-1] 188 | student_fsp_features = [scale_process(student_features)] 189 | student_fsp = cal_dense_fsp(student_fsp_features) 190 | 191 | loss_h = criterion(student_output, target) 192 | loss_s = criterion(student_output, teacher_output) 193 | 194 | loss_fsp = torch.tensor([0.], dtype=torch.float).cuda() 195 | if args.lamb_fsp: 196 | loss_f = [] 197 | assert len(teacher_fsp) == len(student_fsp) 198 | for t in range(len(teacher_fsp)): 199 | loss_f.append(criterion(student_fsp[t], teacher_fsp[t])) 200 | loss_fsp = sum(loss_f) * args.lamb_fsp 201 | 202 | loss_cos = torch.tensor([0.], dtype=torch.float).cuda() 203 | if args.lamb_cos: 204 | loss_c = [] 205 | for t in range(len(student_features) - 1): 206 | loss_c.append(cosine_similarity(student_features[t], teacher.features[t])) 207 | loss_cos = sum(loss_c) * args.lamb_cos 208 | 209 | loss = loss_h + loss_s + loss_fsp + loss_cos 210 | 211 | losses_h.update(loss_h.item(), img.size(0)) 212 | losses_s.update(loss_s.item(), img.size(0)) 213 | losses_fsp.update(loss_fsp.item(), img.size(0)) 214 | losses_cos.update(loss_cos.item(), img.size(0)) 215 | optimizer.zero_grad() 216 | loss.backward() 217 | optimizer.step() 218 | batch_time.update(time.time() - end) 219 | end = time.time() 220 | if i % args.print_freq == (args.print_freq - 1): 221 | print('Epoch: [{0}][{1}/{2}]\t' 222 | 'Time {batch_time.avg:.3f} ' 223 | 'Data {data_time.avg:.3f} ' 224 | 'Loss_h {loss_h.avg:.4f} ' 225 | 'Loss_s {loss_s.avg:.4f} ' 226 | 'Loss_fsp {loss_fsp.avg:.4f} ' 227 | 'Loss_cos {loss_kl.avg:.4f} ' 228 | .format( 229 | epoch, i, len(train_loader), batch_time=batch_time, 230 | data_time=data_time, loss_h=losses_h, loss_s=losses_s, 231 | loss_fsp=losses_fsp, loss_kl=losses_cos)) 232 | 233 | 234 | def val(val_list, model): 235 | print('begin val') 236 | val_loader = torch.utils.data.DataLoader( 237 | dataset.listDataset(val_list, 238 | transform=transforms.Compose([ 239 | transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], 240 | std=[0.229, 0.224, 0.225]), 241 | ]), train=False), 242 | num_workers=args.workers, 243 | shuffle=False, 244 | batch_size=args.batch_size) 245 | 246 | model.eval() 247 | 248 | mae = 0 249 | mse = 0 250 | 251 | for i, (img, target) in enumerate(val_loader): 252 | img = img.cuda() 253 | img = Variable(img) 254 | 255 | with torch.no_grad(): 256 | output = model(img) 257 | 258 | mae += abs(output.data.sum() - target.sum().type(torch.FloatTensor).cuda()) 259 | mse += (output.data.sum() - target.sum().type(torch.FloatTensor).cuda()).pow(2) 260 | 261 | N = len(val_loader) 262 | mae = mae / N 263 | mse = torch.sqrt(mse / N) 264 | 
print('Val * MAE {mae:.3f} * MSE {mse:.3f}' 265 | .format(mae=mae, mse=mse)) 266 | 267 | return mae, mse 268 | 269 | 270 | def test(test_list, model): 271 | print('testing current model...') 272 | test_loader = torch.utils.data.DataLoader( 273 | dataset.listDataset(test_list, 274 | transform=transforms.Compose([ 275 | transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], 276 | std=[0.229, 0.224, 0.225]), 277 | ]), train=False), 278 | num_workers=args.workers, 279 | shuffle=False, 280 | batch_size=args.batch_size) 281 | 282 | model.eval() 283 | 284 | mae = 0 285 | mse = 0 286 | 287 | for i, (img, target) in enumerate(test_loader): 288 | img = img.cuda() 289 | img = Variable(img) 290 | 291 | with torch.no_grad(): 292 | output = model(img) 293 | 294 | mae += abs(output.data.sum() - target.sum().type(torch.FloatTensor).cuda()) 295 | mse += (output.data.sum() - target.sum().type(torch.FloatTensor).cuda()).pow(2) 296 | 297 | N = len(test_loader) 298 | mae = mae / N 299 | mse = torch.sqrt(mse / N) 300 | print('Test * MAE {mae:.3f} * MSE {mse:.3f} ' 301 | .format(mae=mae, mse=mse)) 302 | 303 | 304 | class AverageMeter(object): 305 | """Computes and stores the average and current value""" 306 | 307 | def __init__(self): 308 | self.reset() 309 | 310 | def reset(self): 311 | self.val = 0 312 | self.avg = 0 313 | self.sum = 0 314 | self.count = 0 315 | 316 | def update(self, val, n=1): 317 | self.val = val 318 | self.sum += val * n 319 | self.count += n 320 | self.avg = self.sum / self.count 321 | 322 | 323 | if __name__ == '__main__': 324 | main() 325 | -------------------------------------------------------------------------------- /SKT_distill.sh: -------------------------------------------------------------------------------- 1 | python SKT_distill.py A_train.json A_val.json A_test.json \ 2 | --lr 1e-4 \ 3 | -tc '/teacher/model/ckeckpoint' \ 4 | -laf 0.5 \ 5 | -lac 0.5 \ 6 | --out /save/path/to/output -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | from image import * 8 | import torchvision.transforms.functional as F 9 | 10 | 11 | class listDataset(Dataset): 12 | def __init__(self, root, shape=None, transform=None, train=False, seen=0, 13 | batch_size=1, num_workers=20, dataset='shanghai'): 14 | if train and dataset == 'shanghai': 15 | root = root*4 16 | random.shuffle(root) 17 | 18 | self.nSamples = len(root) 19 | self.lines = root 20 | self.transform = transform 21 | self.train = train 22 | self.shape = shape 23 | self.seen = seen 24 | self.batch_size = batch_size 25 | self.num_workers = num_workers 26 | 27 | self.dataset = dataset 28 | 29 | def __len__(self): 30 | return self.nSamples 31 | 32 | def __getitem__(self, index): 33 | assert index <= len(self), 'index range error' 34 | 35 | img_path = self.lines[index] 36 | 37 | if self.dataset == 'ucf_test': 38 | # test in UCF 39 | img, target = load_ucf_ori_data(img_path) 40 | else: 41 | img, target = load_data(img_path, self.train, self.dataset) 42 | 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | return img, target -------------------------------------------------------------------------------- /image.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | from PIL import 
Image,ImageFilter,ImageDraw 4 | import numpy as np 5 | import h5py 6 | from PIL import ImageStat 7 | import cv2 8 | import time 9 | 10 | 11 | def load_data(img_path,train=True, dataset='shanghai'): 12 | """ Load data 13 | 14 | Use crop_ratio between 0.5 and 1.0 for random crop 15 | """ 16 | gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth') 17 | img = Image.open(img_path).convert('RGB') 18 | gt_file = h5py.File(gt_path) 19 | target = np.asarray(gt_file['density']) 20 | if train: 21 | if dataset == 'shanghai': 22 | crop_ratio = random.uniform(0.5, 1.0) 23 | crop_size = (int(crop_ratio*img.size[0]), int(crop_ratio*img.size[1])) 24 | dx = int(random.random() * (img.size[0]-crop_size[0])) 25 | dy = int(random.random() * (img.size[1]-crop_size[1])) 26 | 27 | img = img.crop((dx,dy,crop_size[0]+dx,crop_size[1]+dy)) 28 | target = target[dy:crop_size[1]+dy,dx:crop_size[0]+dx] 29 | 30 | if random.random() > 0.8: 31 | target = np.fliplr(target) 32 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 33 | 34 | target = reshape_target(target, 3) 35 | target = np.expand_dims(target, axis=0) 36 | 37 | img = img.copy() 38 | target = target.copy() 39 | return img, target 40 | 41 | 42 | def load_ucf_ori_data(img_path): 43 | """ Load original UCF-QNRF data for testing 44 | 45 | """ 46 | gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth') 47 | img = Image.open(img_path).convert('RGB') 48 | gt_file = h5py.File(gt_path) 49 | target = np.asarray(gt_file['density']) 50 | return img, target 51 | 52 | 53 | def reshape_target(target, down_sample=3): 54 | """ Down sample GT to 1/8 55 | 56 | """ 57 | height = target.shape[0] 58 | width = target.shape[1] 59 | 60 | # ceil_mode=True for nn.MaxPool2d in model 61 | for i in range(down_sample): 62 | height = int((height+1)/2) 63 | width = int((width+1)/2) 64 | # height = int(height/2) 65 | # width = int(width/2) 66 | 67 | target = cv2.resize(target, (width, height), interpolation=cv2.INTER_CUBIC) * (2**(down_sample*2)) 68 | return target 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chen-judge/SKT/d5af30f5e6b6957e2e07acce31465f5c76134fd3/models/__init__.py -------------------------------------------------------------------------------- /models/distillation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def cosine_similarity(stu_map, tea_map): 7 | similiar = 1-F.cosine_similarity(stu_map, tea_map, dim=1) 8 | loss = similiar.sum() 9 | return loss 10 | 11 | 12 | def cal_dense_fsp(features): 13 | fsp = [] 14 | for groups in features: 15 | for i in range(len(groups)): 16 | for j in range(i+1, len(groups)): 17 | x = groups[i] 18 | y = groups[j] 19 | 20 | norm1 = nn.InstanceNorm2d(x.shape[1]) 21 | norm2 = nn.InstanceNorm2d(y.shape[1]) 22 | x = norm1(x) 23 | y = norm2(y) 24 | res = gram(x, y) 25 | fsp.append(res) 26 | return fsp 27 | 28 | 29 | def gram(x, y): 30 | n = x.shape[0] 31 | c1 = x.shape[1] 32 | c2 = y.shape[1] 33 | h = x.shape[2] 34 | w = x.shape[3] 35 | x = x.view(n, c1, -1, 1)[0, :, :, 0] 36 | y = y.view(n, c2, -1, 1)[0, :, :, 0] 37 | y = y.transpose(0, 1) 38 | # print x.shape 39 | # print y.shape 40 | z = torch.mm(x, y) / (w*h) 41 | return z 42 | 43 | 44 | def scale_process(features, scale=[3, 2, 1], ceil_mode=True): 45 | # 
process features for multi-scale dense fsp 46 | new_features = [] 47 | for i in range(len(features)): 48 | if i >= len(scale): 49 | new_features.append(features[i]) 50 | continue 51 | down_ratio = pow(2, scale[i]) 52 | pool = nn.MaxPool2d(kernel_size=down_ratio, stride=down_ratio, ceil_mode=ceil_mode) 53 | new_features.append(pool(features[i])) 54 | return new_features -------------------------------------------------------------------------------- /models/model_student_vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Student model (1/n-CSRNet) in SKT. 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from torchvision import models 7 | 8 | channel_nums = [[32, 64, 128, 256], # half 9 | [21, 43, 85, 171], # third 10 | [16, 32, 64, 128], # quarter 11 | [13, 26, 51, 102], # fifth 12 | ] 13 | 14 | 15 | class CSRNet(nn.Module): 16 | def __init__(self, ratio=4, transform=True): 17 | super(CSRNet, self).__init__() 18 | self.seen = 0 19 | self.transform = transform 20 | channel = channel_nums[ratio-2] 21 | self.conv0_0 = conv_layers(3, channel[0]) 22 | if self.transform: 23 | self.transform0_0 = feature_transform(channel[0], 64) 24 | self.conv0_1 = conv_layers(channel[0], channel[0]) 25 | 26 | self.pool0 = pool_layers() 27 | if transform: 28 | self.transform1_0 = feature_transform(channel[0], 64) 29 | self.conv1_0 = conv_layers(channel[0], channel[1]) 30 | self.conv1_1 = conv_layers(channel[1], channel[1]) 31 | 32 | self.pool1 = pool_layers() 33 | if transform: 34 | self.transform2_0 = feature_transform(channel[1], 128) 35 | self.conv2_0 = conv_layers(channel[1], channel[2]) 36 | self.conv2_1 = conv_layers(channel[2], channel[2]) 37 | self.conv2_2 = conv_layers(channel[2], channel[2]) 38 | 39 | self.pool2 = pool_layers() 40 | if transform: 41 | self.transform3_0 = feature_transform(channel[2], 256) 42 | self.conv3_0 = conv_layers(channel[2], channel[3]) 43 | self.conv3_1 = conv_layers(channel[3], channel[3]) 44 | self.conv3_2 = conv_layers(channel[3], channel[3]) 45 | 46 | self.conv4_0 = conv_layers(channel[3], channel[3], dilation=2) 47 | if transform: 48 | self.transform4_0 = feature_transform(channel[3], 512) 49 | self.conv4_1 = conv_layers(channel[3], channel[3], dilation=2) 50 | self.conv4_2 = conv_layers(channel[3], channel[3], dilation=2) 51 | self.conv4_3 = conv_layers(channel[3], channel[2], dilation=2) 52 | if transform: 53 | self.transform4_3 = feature_transform(channel[2], 256) 54 | self.conv4_4 = conv_layers(channel[2], channel[1], dilation=2) 55 | self.conv4_5 = conv_layers(channel[1], channel[0], dilation=2) 56 | 57 | self.conv5_0 = nn.Conv2d(channel[0], 1, kernel_size=1) 58 | 59 | self._initialize_weights() 60 | self.features = [] 61 | 62 | def forward(self, x): 63 | self.features = [] 64 | 65 | x = self.conv0_0(x) 66 | if self.transform: 67 | self.features.append(self.transform0_0(x)) 68 | x = self.conv0_1(x) 69 | 70 | x = self.pool0(x) 71 | if self.transform: 72 | self.features.append(self.transform1_0(x)) 73 | x = self.conv1_0(x) 74 | x = self.conv1_1(x) 75 | 76 | x = self.pool1(x) 77 | if self.transform: 78 | self.features.append(self.transform2_0(x)) 79 | x = self.conv2_0(x) 80 | x = self.conv2_1(x) 81 | x = self.conv2_2(x) 82 | 83 | x = self.pool2(x) 84 | if self.transform: 85 | self.features.append(self.transform3_0(x)) 86 | x = self.conv3_0(x) 87 | x = self.conv3_1(x) 88 | x = self.conv3_2(x) 89 | 90 | x = self.conv4_0(x) 91 | if self.transform: 92 | self.features.append(self.transform4_0(x)) 93 | x = self.conv4_1(x) 
94 | x = self.conv4_2(x) 95 | x = self.conv4_3(x) 96 | if self.transform: 97 | self.features.append(self.transform4_3(x)) 98 | x = self.conv4_4(x) 99 | x = self.conv4_5(x) 100 | 101 | x = self.conv5_0(x) 102 | 103 | self.features.append(x) 104 | 105 | if self.training is True: 106 | return self.features 107 | return x 108 | 109 | def _initialize_weights(self): 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | # nn.init.xavier_normal_(m.weight) 113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 114 | # nn.init.normal_(m.weight, std=0.01) 115 | if m.bias is not None: 116 | nn.init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | 122 | def conv_layers(inp, oup, dilation=False): 123 | if dilation: 124 | d_rate = 2 125 | else: 126 | d_rate = 1 127 | return nn.Sequential( 128 | nn.Conv2d(inp, oup, kernel_size=3, padding=d_rate, dilation=d_rate), 129 | nn.ReLU(inplace=True) 130 | ) 131 | 132 | 133 | def feature_transform(inp, oup): 134 | conv2d = nn.Conv2d(inp, oup, kernel_size=1) # no padding 135 | relu = nn.ReLU(inplace=True) 136 | layers = [] 137 | layers += [conv2d, relu] 138 | return nn.Sequential(*layers) 139 | 140 | 141 | def pool_layers(ceil_mode=True): 142 | return nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode) 143 | -------------------------------------------------------------------------------- /models/model_teacher_vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Teacher model in SKT 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from torchvision import models 7 | from utils import save_net, load_net, cal_para 8 | 9 | 10 | class CSRNet(nn.Module): 11 | def __init__(self, pretrained=False): 12 | super(CSRNet, self).__init__() 13 | self.seen = 0 14 | self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] 15 | self.backend_feat = [512, 512, 512, 256, 128, 64] 16 | self.frontend = make_layers(self.frontend_feat) 17 | self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True) 18 | self.output_layer = nn.Conv2d(64, 1, kernel_size=1) 19 | self._initialize_weights() 20 | self.features = [] 21 | if pretrained: 22 | print 'load vgg pretrained model' 23 | mod = models.vgg16(pretrained=True) 24 | for i in xrange(len(self.frontend.state_dict().items())): 25 | self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:] 26 | 27 | def forward(self, x): 28 | self.features = [] 29 | # frontend: VGG 30 | x = self.frontend(x) 31 | # backend: dilated convolution 32 | x = self.backend(x) 33 | x = self.output_layer(x) 34 | return x 35 | 36 | def _initialize_weights(self): 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.normal_(m.weight, std=0.01) 40 | if m.bias is not None: 41 | nn.init.constant_(m.bias, 0) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | nn.init.constant_(m.weight, 1) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def regist_hook(self): 47 | self.features = [] 48 | 49 | def get(model, input, output): 50 | # function will be automatically called each time, since the hook is injected 51 | self.features.append(output.detach()) 52 | 53 | for name, module in self._modules['frontend']._modules.items(): 54 | if name in ['1', '4', '9', '16']: 55 | self._modules['frontend']._modules[name].register_forward_hook(get) 56 | for name, module in self._modules['backend']._modules.items(): 57 
| if name in ['1', '7']: 58 | self._modules['backend']._modules[name].register_forward_hook(get) 59 | 60 | 61 | def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False): 62 | if dilation: 63 | d_rate = 2 64 | else: 65 | d_rate = 1 66 | layers = [] 67 | for v in cfg: 68 | if v == 'M': 69 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 70 | else: 71 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate) 72 | if batch_norm: 73 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 74 | else: 75 | layers += [conv2d, nn.ReLU(inplace=True)] 76 | in_channels = v 77 | return nn.Sequential(*layers) -------------------------------------------------------------------------------- /models/model_vgg.py: -------------------------------------------------------------------------------- 1 | """ 2 | For training CSRNet teacher 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from torchvision import models 7 | # from utils import save_net,load_net 8 | import time 9 | 10 | 11 | class CSRNet(nn.Module): 12 | def __init__(self, pretrained=True): 13 | super(CSRNet, self).__init__() 14 | self.seen = 0 15 | self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512] 16 | self.backend_feat = [512, 512, 512, 256, 128, 64] 17 | self.frontend = make_layers(self.frontend_feat) 18 | # cal_para(self.frontend) 19 | self.backend = make_layers(self.backend_feat, in_channels=512, dilation=True) 20 | self.output_layer = nn.Conv2d(64, 1, kernel_size=1) 21 | if pretrained: 22 | self._initialize_weights(mode='normal') 23 | mod = models.vgg16(pretrained=True) 24 | state_keys = list(self.frontend.state_dict().keys()) 25 | pretrain_keys = list(mod.state_dict().keys()) 26 | for i in range(len(self.frontend.state_dict().items())): 27 | # self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:] 28 | # print(mod.state_dict()[pretrain_keys[i]]) 29 | self.frontend.state_dict()[state_keys[i]].data = mod.state_dict()[pretrain_keys[i]].data 30 | else: 31 | self._initialize_weights(mode='kaiming') 32 | 33 | def forward(self, x): 34 | # front relates to VGG 35 | x = self.frontend(x) 36 | # backend relates to dilated convolution 37 | x = self.backend(x) 38 | x = self.output_layer(x) 39 | return x 40 | 41 | def _initialize_weights(self, mode): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | if mode == 'normal': 45 | nn.init.normal_(m.weight, std=0.01) 46 | elif mode == 'kaiming': 47 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 48 | if m.bias is not None: 49 | nn.init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | nn.init.constant_(m.weight, 1) 52 | nn.init.constant_(m.bias, 0) 53 | 54 | 55 | def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False): 56 | if dilation: 57 | d_rate = 2 58 | else: 59 | d_rate = 1 60 | layers = [] 61 | for v in cfg: 62 | if v == 'M': 63 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 64 | else: 65 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate) 66 | if batch_norm: 67 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 68 | else: 69 | layers += [conv2d, nn.ReLU(inplace=True)] 70 | in_channels = v 71 | return nn.Sequential(*layers) 72 | 73 | -------------------------------------------------------------------------------- /preprocess/ShanghaiTech_GT_generation.py: 
-------------------------------------------------------------------------------- 1 | import h5py 2 | import scipy.io as io 3 | import numpy as np 4 | import os 5 | import glob 6 | from matplotlib import pyplot as plt 7 | from scipy.ndimage.filters import gaussian_filter 8 | import scipy 9 | 10 | 11 | def gaussian_filter_density(gt): 12 | density = np.zeros(gt.shape, dtype=np.float32) 13 | gt_count = np.count_nonzero(gt) # nonzero value represent people in labels 14 | if gt_count == 0: # gt_count is the amount of people 15 | return density 16 | 17 | pts = np.array(zip(np.nonzero(gt)[1], np.nonzero(gt)[0])) # human label position 18 | leafsize = 2048 19 | # build kdtree 20 | tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize) 21 | # query kdtree 22 | distances, locations = tree.query(pts, k=4) 23 | 24 | print 'generate density...' 25 | for i, pt in enumerate(pts): 26 | pt2d = np.zeros(gt.shape, dtype=np.float32) 27 | pt2d[pt[1],pt[0]] = 1. 28 | if gt_count > 1: 29 | sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1 30 | else: 31 | sigma = np.average(np.array(gt.shape))/2./2. # case: 1 point 32 | density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant') 33 | print 'done.' 34 | return density 35 | 36 | 37 | root = '/media/firstPartition/cjq/ShanghaiTech-test' 38 | 39 | # now generate the ShanghaiA's ground truth 40 | part_A_train = os.path.join(root,'part_A_final/train_data','images') 41 | part_A_test = os.path.join(root,'part_A_final/test_data','images') 42 | part_B_train = os.path.join(root,'part_B_final/train_data','images') 43 | part_B_test = os.path.join(root,'part_B_final/test_data','images') 44 | path_sets = [part_A_train,part_A_test] 45 | 46 | img_paths = [] 47 | for path in path_sets: 48 | for img_path in glob.glob(os.path.join(path, '*.jpg')): 49 | img_paths.append(img_path) 50 | 51 | for img_path in img_paths: 52 | # for every image 53 | print 'image path: ', img_path 54 | mat = io.loadmat(img_path.replace('.jpg','.mat').replace('images','ground_truth').replace('IMG_','GT_IMG_')) 55 | img = plt.imread(img_path) 56 | k = np.zeros((img.shape[0],img.shape[1])) 57 | gt = mat["image_info"][0,0][0,0][0] 58 | for i in range(0,len(gt)): 59 | if int(gt[i][1]) 1: 41 | sigma = distances[i][1] 42 | sigma = min(sigma, threshold) # nearest 43 | else: 44 | sigma = np.average(np.array(gt.shape))/2./2. 
# case: 1 point 45 | density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant') 46 | return density 47 | 48 | 49 | root = '/media/firstPartition/cjq/UCF-QNRF-test' 50 | 51 | train_path = os.path.join(root, 'Train') 52 | test_path = os.path.join(root, 'Test') 53 | 54 | if args.mode == 'train': 55 | path = train_path 56 | else: 57 | path = test_path 58 | 59 | paths = glob.glob(os.path.join(path, '*.jpg')) 60 | paths.sort() 61 | if args.start and args.end: 62 | processed_imgs = paths[args.start:args.end] # It will take a long time and can be processed in parts 63 | else: 64 | processed_imgs = paths 65 | # print processed_imgs 66 | 67 | for img_path in processed_imgs: 68 | start = time.time() 69 | img = plt.imread(img_path) 70 | (name, _) = os.path.splitext(img_path) 71 | mat = io.loadmat(name+'_ann.mat') 72 | gt = mat['annPoints'] 73 | 74 | k = np.zeros((img.shape[0], img.shape[1])) 75 | print 'GT len & shape: ', len(gt), img.shape, ' img path: ', img_path 76 | for i in range(0, len(gt)): 77 | if int(gt[i][1]) < img.shape[0] and int(gt[i][0]) < img.shape[1]: 78 | k[int(gt[i][1]), int(gt[i][0])] = 1 79 | k = gaussian_filter_density(k) 80 | # save the Density Maps GT as h5 format 81 | with h5py.File(img_path.replace('.jpg', '.h5'), 'w') as hf: 82 | hf['density'] = k 83 | print 'time: ', time.time() - start 84 | 85 | -------------------------------------------------------------------------------- /preprocess/make_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make json files for dataset 3 | """ 4 | import json 5 | import os 6 | 7 | 8 | def get_val(root): 9 | """ 10 | Validation set follows part_A_val.json in CSRNet 11 | https://github.com/leeyeehoo/CSRNet-pytorch 12 | """ 13 | with open("preprocess/part_A_val.json") as f: 14 | val_list = json.load(f) 15 | new_val = [] 16 | for item in val_list: 17 | new_item = item.replace('/home/leeyh/Downloads/Shanghai/', root) 18 | new_val.append(new_item) 19 | with open('A_val.json', 'w') as f: 20 | json.dump(new_val, f) 21 | 22 | 23 | def get_train(root): 24 | path = os.path.join(root, 'part_A_final', 'train_data', 'images') 25 | filenames = os.listdir(path) 26 | pathname = [os.path.join(path, filename) for filename in filenames] 27 | with open('A_train.json', 'w') as f: 28 | json.dump(pathname, f) 29 | 30 | 31 | def get_test(root): 32 | path = os.path.join(root, 'part_A_final', 'test_data', 'images') 33 | filenames = os.listdir(path) 34 | pathname = [os.path.join(path, filename) for filename in filenames] 35 | with open('A_test.json', 'w') as f: 36 | json.dump(pathname, f) 37 | 38 | 39 | if __name__ == '__main__': 40 | root = '/media/firstPartition/cjq/ShanghaiTech/' # Dataset path 41 | get_train(root) 42 | get_val(root) 43 | get_test(root) 44 | print 'Finish!' 
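    # Note: the generated A_train.json / A_val.json / A_test.json are the three
    # positional arguments consumed by SKT_distill.py (see SKT_distill.sh), and
    # A_test.json is the positional argument consumed by test.py (see test.sh).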
45 | 46 | -------------------------------------------------------------------------------- /preprocess/part_A_val.json: -------------------------------------------------------------------------------- 1 | ["/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_129.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_3.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_100.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_289.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_66.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_221.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_61.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_24.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_9.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_95.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_233.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_151.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_122.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_188.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_187.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_121.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_58.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_34.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_246.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_81.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_179.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_166.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_67.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_259.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_28.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_245.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_203.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_97.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_152.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_124.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_137.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_183.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_39.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_35.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_109.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_182.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_143.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_125.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_176.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_76.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_273.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_299.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_78.jpg", 
"/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_276.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_89.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_63.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_238.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_199.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_132.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_56.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_20.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_243.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_258.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_253.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_8.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_31.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_44.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_282.jpg", "/home/leeyh/Downloads/Shanghai/part_A_final/train_data/images/IMG_208.jpg"] -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import warnings 5 | 6 | from models.model_vgg import CSRNet as CSRNet_vgg 7 | from models.model_student_vgg import CSRNet as CSRNet_student 8 | 9 | from utils import save_checkpoint 10 | from utils import cal_para, crop_img_patches 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | from torchvision import datasets, transforms 16 | import json 17 | 18 | import numpy as np 19 | import argparse 20 | import json 21 | import dataset 22 | import time 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch CSRNet') 25 | 26 | parser.add_argument('test_json', metavar='TEST', 27 | help='path to test json') 28 | parser.add_argument('--dataset', '-d', default='Shanghai', type=str, 29 | help='Shanghai/UCF') 30 | parser.add_argument('--checkpoint', '-c', metavar='CHECKPOINT', default=None, type=str, 31 | help='path to the checkpoint') 32 | parser.add_argument('--version', '-v', default=None, type=str, 33 | help='vgg/quarter_vgg') 34 | parser.add_argument('--transform', '-t', default=True, type=str, 35 | help='1x1 conv transform') 36 | parser.add_argument('--batch', default=1, type=int, 37 | help='batch size') 38 | parser.add_argument('--gpu', metavar='GPU', default='0', type=str, 39 | help='GPU id to use.') 40 | 41 | args = parser.parse_args() 42 | 43 | 44 | def main(): 45 | global args, best_prec1 46 | 47 | args.batch_size = 1 48 | args.workers = 4 49 | args.seed = time.time() 50 | if args.transform == 'false': 51 | args.transform = False 52 | 53 | with open(args.test_json, 'r') as outfile: 54 | test_list = json.load(outfile) 55 | 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 57 | torch.cuda.manual_seed(args.seed) 58 | 59 | if args.version == 'vgg': 60 | print 'VGG' 61 | model = CSRNet_vgg(pretrained=False) 62 | print model 63 | cal_para(model) 64 | 65 | elif args.version == 'quarter_vgg': 66 | print 'quarter_VGG' 67 | model = CSRNet_student(ratio=4, transform=args.transform) 68 | print model 69 | cal_para(model) # including 1x1conv transform layer that can be removed 70 | else: 71 | raise NotImplementedError() 72 | 73 | model = 
model.cuda() 74 | 75 | if args.checkpoint: 76 | if os.path.isfile(args.checkpoint): 77 | print("=> loading checkpoint '{}'".format(args.checkpoint)) 78 | checkpoint = torch.load(args.checkpoint) 79 | 80 | if args.transform is False: 81 | # remove 1x1 conv para 82 | for k in checkpoint['state_dict'].keys(): 83 | if k[:9] == 'transform': 84 | del checkpoint['state_dict'][k] 85 | 86 | model.load_state_dict(checkpoint['state_dict']) 87 | print("=> loaded checkpoint '{}' (epoch {})" 88 | .format(args.checkpoint, checkpoint['epoch'])) 89 | else: 90 | print("=> no checkpoint found at '{}'".format(args.checkpoint)) 91 | 92 | if args.dataset == 'UCF': 93 | test_ucf(test_list, model) 94 | else: 95 | test(test_list, model) 96 | 97 | 98 | def test(test_list, model): 99 | print('begin test') 100 | test_loader = torch.utils.data.DataLoader( 101 | dataset.listDataset(test_list, 102 | transform=transforms.Compose([ 103 | transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], 104 | std=[0.229, 0.224, 0.225]), 105 | ]), 106 | train=False), 107 | shuffle=False, 108 | batch_size=args.batch_size) 109 | 110 | model.eval() 111 | 112 | mae = 0 113 | mse = 0 114 | 115 | for i, (img, target) in enumerate(test_loader): 116 | img = img.cuda() 117 | img = Variable(img) 118 | with torch.no_grad(): 119 | output = model(img) 120 | 121 | mae += abs(output.data.sum() - target.sum().type(torch.FloatTensor).cuda()) 122 | mse += (output.data.sum() - target.sum().type(torch.FloatTensor).cuda()).pow(2) 123 | 124 | N = len(test_loader) 125 | mae = mae / N 126 | mse = torch.sqrt(mse / N) 127 | print(' * MAE {mae:.3f} \t * MSE {mse:.3f}' 128 | .format(mae=mae, mse=mse)) 129 | 130 | 131 | def test_ucf(test_list, model): 132 | print 'begin test' 133 | test_loader = torch.utils.data.DataLoader( 134 | dataset.listDataset(test_list, 135 | transform=transforms.Compose([ 136 | transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], 137 | std=[0.229, 0.224, 0.225]), 138 | ]), 139 | train=False, 140 | dataset='ucf_test', 141 | ), 142 | shuffle=False, 143 | batch_size=1) 144 | 145 | model.eval() 146 | 147 | mae = 0 148 | mse = 0 149 | 150 | for i, (img, target) in enumerate(test_loader): 151 | img = img.cuda() 152 | img = Variable(img) 153 | 154 | people = 0 155 | img_patches = crop_img_patches(img, size=512) 156 | for patch in img_patches: 157 | with torch.no_grad(): 158 | sub_output = model(patch) 159 | people += sub_output.data.sum() 160 | 161 | error = people - target.sum().type(torch.FloatTensor).cuda() 162 | mae += abs(error) 163 | mse += error.pow(2) 164 | 165 | N = len(test_loader) 166 | mae = mae / N 167 | mse = torch.sqrt(mse / N) 168 | print(' * MAE {mae:.3f} \t * MSE {mse:.3f}' 169 | .format(mae=mae, mse=mse)) 170 | 171 | return mae, mse 172 | 173 | 174 | if __name__ == '__main__': 175 | main() -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py A_test.json \ 2 | -c '/model/file' \ 3 | -v quarter_vgg 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import shutil 6 | import os 7 | import time 8 | 9 | 10 | def save_net(fname, net): 11 | with h5py.File(fname, 'w') as h5f: 12 | for k, v in net.state_dict().items(): 13 | h5f.create_dataset(k, 
data=v.cpu().numpy()) 14 | 15 | 16 | def load_net(fname, net): 17 | with h5py.File(fname, 'r') as h5f: 18 | for k, v in net.state_dict().items(): 19 | param = torch.from_numpy(np.asarray(h5f[k])) 20 | v.copy_(param) 21 | 22 | 23 | def save_checkpoint(state, mae_is_best, mse_is_best, path, filename='checkpoint.pth.tar'): 24 | torch.save(state, os.path.join(path, filename)) 25 | epoch = state['epoch'] 26 | if mae_is_best: 27 | shutil.copyfile(os.path.join(path, filename), os.path.join(path, 'epoch'+str(epoch)+'_best_mae.pth.tar')) 28 | if mse_is_best: 29 | shutil.copyfile(os.path.join(path, filename), os.path.join(path, 'epoch'+str(epoch)+'_best_mse.pth.tar')) 30 | 31 | 32 | def cal_para(net): 33 | params = list(net.parameters()) 34 | k = 0 35 | for i in params: 36 | l = 1 37 | # print "stucture of layer: " + str(list(i.size())) 38 | for j in i.size(): 39 | l *= j 40 | # print "para in this layer: " + str(l) 41 | k = k + l 42 | print("the amount of para: " + str(k)) 43 | 44 | 45 | def crop_img_patches(img, size=512): 46 | """ crop the test images to patches 47 | 48 | while testing UCF data, we load original images, then use crop_img_patches to crop the test images to patches, 49 | calculate the crowd count respectively and sum them together finally 50 | """ 51 | w = img.shape[3] 52 | h = img.shape[2] 53 | x = int(w/size)+1 54 | y = int(h/size)+1 55 | crop_w = int(w/x) 56 | crop_h = int(h/y) 57 | patches = [] 58 | for i in range(x): 59 | for j in range(y): 60 | start_x = crop_w*i 61 | if i == x-1: 62 | end_x = w 63 | else: 64 | end_x = crop_w*(i+1) 65 | 66 | start_y = crop_h*j 67 | if j == y - 1: 68 | end_y = h 69 | else: 70 | end_y = crop_h*(j+1) 71 | 72 | sub_img = img[:, :, start_y:end_y, start_x:end_x] 73 | patches.append(sub_img) 74 | return patches 75 | --------------------------------------------------------------------------------
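As a closing usage note for `crop_img_patches`: during UCF-QNRF evaluation the full-resolution image is tiled into patches of roughly 512×512, the density map is predicted per patch, and the per-patch sums are added to obtain the image-level count, as `test_ucf()` in `test.py` does. A minimal sketch of that pattern — the helper name `count_by_patches` is chosen here purely for illustration:
```
import torch
from utils import crop_img_patches

def count_by_patches(model, img, size=512):
    """img: a 1xCxHxW CUDA tensor of the full image; returns the predicted crowd count."""
    model.eval()
    people = 0.0
    for patch in crop_img_patches(img, size=size):  # non-overlapping patches that cover the image
        with torch.no_grad():
            people += model(patch).sum().item()     # sum of the predicted density map
    return people
```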