├── .gitignore ├── LICENSE ├── README.md ├── data_preparation ├── create_gt_test_set_shtech.m ├── create_training_set_shtech.m └── get_density_map_gaussian.m ├── src ├── crowd_count.py ├── data_loader.py ├── evaluate_model.py ├── models │ ├── base.py │ ├── deep.py │ └── wide.py ├── network.py ├── timer.py └── utils.py ├── test.py ├── thumbnails └── stackpool.jpg └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 siyuhuang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Stacked Pooling for Boosting Scale Invariance of Crowd Counting 2 | 3 | PyTorch implementation of "**Stacked Pooling for Boosting Scale Invariance of Crowd Counting**" [\[ICASSP 2020\]](https://siyuhuang.github.io/papers/ICASSP-2020-STACKED%20POOLING%20FOR%20BOOSTING%20SCALE%20INVARIANCE%20OF%20CROWD%20COUNTING.pdf). 4 | 5 | ``` 6 | @inproceedings{huang2020stacked, 7 | title={Stacked Pooling for Boosting Scale Invariance of Crowd Counting}, 8 | author={Huang, Siyu and Li, Xi and Cheng, Zhi-Qi and Zhang, Zhongfei and Hauptmann, Alexander}, 9 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing}, 10 | pages={2578--2582}, 11 | year={2020}, 12 | } 13 | ``` 14 | 15 | This implementation is based on [https://github.com/svishwa/crowdcount-mcnn](https://github.com/svishwa/crowdcount-mcnn). 16 | 17 |

18 | ![stacked pooling](thumbnails/stackpool.jpg) 19 |
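Stacked pooling is implemented in this repo as PyTorch `MaxPool2d`/`ReplicationPad2d` layers (see the `stack_pool` classes in `src/models/`). As an illustration only, here is a minimal NumPy sketch of the same operator on a single-channel map; `max_pool2d` is a hypothetical helper written for this example and is not part of the repo:

```python
import numpy as np

def max_pool2d(x, k, stride, pad=0):
    # 2-D max pooling over a single-channel array. `pad` mimics the implicit
    # padding of nn.MaxPool2d: padded borders use -inf so they never win the max.
    if pad:
        x = np.pad(x, pad, mode="constant", constant_values=-np.inf)
    h, w = x.shape
    oh, ow = (h - k) // stride + 1, (w - k) // stride + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = x[i * stride:i * stride + k,
                          j * stride:j * stride + k].max()
    return out

def stack_pool(x):
    # Stacked pooling: a vanilla 2x2 stride-2 pooled map, averaged with two
    # stride-1 re-poolings of itself (all three maps share the same size).
    x1 = max_pool2d(x, 2, 2)                          # vanilla pooling: 2x2, stride 2
    x1p = np.pad(x1, ((0, 1), (0, 1)), mode="edge")   # ReplicationPad2d((0, 1, 0, 1))
    x2 = max_pool2d(x1p, 2, 1)                        # 2x2, stride 1
    x3 = max_pool2d(x2, 3, 1, pad=1)                  # 3x3, stride 1, padding 1
    return (x1 + x2 + x3) / 3.0                       # equal-weight average
```

Compared with vanilla pooling (`x1` alone), the output keeps the same 2x downsampling while each value mixes progressively larger receptive fields, which is what the results below measure.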

20 | 21 | | Method | ShanghaiTech-A | ShanghaiTech-B | WorldExpo'10 | 22 | | -------- | :-----: | :----: | :----: | 23 | | Vanilla Pooling | 97.63 | 21.17 | 14.74 | 24 | | Stacked Pooling | **93.98** | **18.73** | **12.92** | 25 | 26 | 27 | ## Dependency 28 | 1. Python 2.7 29 | 2. PyTorch 0.4.0 30 | 31 | ## Data Setup 32 | 1. Download the ShanghaiTech dataset from 33 | Dropbox: https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0 34 | Baidu Disk: http://pan.baidu.com/s/1nuAYslz 35 | 2. Create the directory: `mkdir ./data/original/shanghaitech/` 36 | 3. Save "part_A_final" under ./data/original/shanghaitech/ 37 | Save "part_B_final" under ./data/original/shanghaitech/ 38 | 4. `cd ./data_preparation/` 39 | Run `create_gt_test_set_shtech.m` in MATLAB to create ground truth files for the test data 40 | Run `create_training_set_shtech.m` in MATLAB to create the training and validation sets along with their ground truth files 41 | 42 | ## Train 43 | 1. To train **Deep Net**+**vanilla pooling** on **ShanghaiTechA**, edit the configuration in `train.py`: 44 | ```python 45 | pool = pools[0] 46 | ``` 47 | 48 | To train **Deep Net**+**stacked pooling** on **ShanghaiTechA**, edit the configuration in `train.py`: 49 | ```python 50 | pool = pools[1] 51 | ``` 52 | 2. Run `python train.py` to start training 53 | 54 | ## Test 55 | 1. Follow step 1 of **Train** to set the corresponding `pool` in `test.py` 56 | 2. Edit `model_path` in `test.py` to point to the best checkpoint on the validation set (output by the training process) 57 | 3. Run `python test.py` for each configuration to compare them 58 | 59 | ## Note 60 | 1. To try the pooling methods (**vanilla pooling**, **stacked pooling**, and **multi-kernel pooling**) described in our paper: 61 | 62 | Edit `pool` in `train.py` and `test.py` 63 | 64 | 2.
To evaluate on the datasets (**ShanghaiTechA**, **ShanghaiTechB**) or backbone models (**Base-Net**, **Wide-Net**, **Deep-Net**) described in our paper: 65 | 66 | Edit `dataset_name` or `model` in `train.py` and `test.py` 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /data_preparation/create_gt_test_set_shtech.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | % File to create ground truth density map for test set% 3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 4 | 5 | 6 | clc; clear all; 7 | dataset = 'A'; 8 | dataset_name = ['shanghaitech_part_' dataset ]; 9 | path = ['../data/original/shanghaitech/part_' dataset '_final/test_data/images/']; 10 | gt_path = ['../data/original/shanghaitech/part_' dataset '_final/test_data/ground_truth/']; 11 | gt_path_csv = ['../data/original/shanghaitech/part_' dataset '_final/test_data/ground_truth_csv/']; 12 | 13 | mkdir(gt_path_csv ) 14 | if (dataset == 'A') 15 | num_images = 182; 16 | else 17 | num_images = 316; 18 | end 19 | 20 | for i = 1:num_images 21 | if (mod(i,10)==0) 22 | fprintf(1,'Processing %3d/%d files\n', i, num_images); 23 | end 24 | load(strcat(gt_path, 'GT_IMG_',num2str(i),'.mat')) ; 25 | input_img_name = strcat(path,'IMG_',num2str(i),'.jpg'); 26 | im = imread(input_img_name); 27 | [h, w, c] = size(im); 28 | if (c == 3) 29 | im = rgb2gray(im); 30 | end 31 | annPoints = image_info{1}.location; 32 | [h, w, c] = size(im); 33 | im_density = get_density_map_gaussian(im,annPoints); 34 | csvwrite([gt_path_csv ,'IMG_',num2str(i) '.csv'], im_density); 35 | end 36 | 37 | -------------------------------------------------------------------------------- /data_preparation/create_training_set_shtech.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | % File to create training
and validation set % 3 | % for ShanghaiTech Dataset Part A and B. 10% of % 4 | % the training set is set aside for validation % 5 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 6 | 7 | 8 | clc; clear all; 9 | seed = 95461354; 10 | rng(seed) 11 | N = 9; 12 | dataset = 'A'; 13 | dataset_name = ['shanghaitech_part_' dataset '_patches_' num2str(N)]; 14 | path = ['../data/original/shanghaitech/part_' dataset '_final/train_data/images/']; 15 | output_path = '../data/formatted_trainval/'; 16 | train_path_img = strcat(output_path, dataset_name,'/train/'); 17 | train_path_den = strcat(output_path, dataset_name,'/train_den/'); 18 | val_path_img = strcat(output_path, dataset_name,'/val/'); 19 | val_path_den = strcat(output_path, dataset_name,'/val_den/'); 20 | gt_path = ['../data/original/shanghaitech/part_' dataset '_final/train_data/ground_truth/']; 21 | 22 | mkdir(output_path) 23 | mkdir(train_path_img); 24 | mkdir(train_path_den); 25 | mkdir(val_path_img); 26 | mkdir(val_path_den); 27 | 28 | if (dataset == 'A') 29 | num_images = 300; 30 | else 31 | num_images = 400; 32 | end 33 | num_val = ceil(num_images*0.1); 34 | indices = randperm(num_images); 35 | 36 | for idx = 1:num_images 37 | i = indices(idx); 38 | if (mod(idx,10)==0) 39 | fprintf(1,'Processing %3d/%d files\n', idx, num_images); 40 | end 41 | load(strcat(gt_path, 'GT_IMG_',num2str(i),'.mat')) ; 42 | input_img_name = strcat(path,'IMG_',num2str(i),'.jpg'); 43 | im = imread(input_img_name); 44 | [h, w, c] = size(im); 45 | if (c == 3) 46 | im = rgb2gray(im); 47 | end 48 | 49 | wn2 = w/8; hn2 = h/8; 50 | wn2 =8 * floor(wn2/8); 51 | hn2 =8 * floor(hn2/8); 52 | 53 | annPoints = image_info{1}.location; 54 | if( w <= 2*wn2 ) 55 | im = imresize(im,[ h,2*wn2+1]); 56 | annPoints(:,1) = annPoints(:,1)*2*wn2/w; 57 | end 58 | if( h <= 2*hn2) 59 | im = imresize(im,[2*hn2+1,w]); 60 | annPoints(:,2) = annPoints(:,2)*2*hn2/h; 61 | end 62 | [h, w, c] = size(im); 63 | a_w = wn2+1; b_w = w - wn2; 64 | a_h = hn2+1; b_h = h - 
hn2; 65 | 66 | im_density = get_density_map_gaussian(im,annPoints); 67 | for j = 1:N 68 | 69 | x = floor((b_w - a_w) * rand + a_w); 70 | y = floor((b_h - a_h) * rand + a_h); 71 | x1 = x - wn2; y1 = y - hn2; 72 | x2 = x + wn2-1; y2 = y + hn2-1; 73 | 74 | 75 | im_sampled = im(y1:y2, x1:x2,:); 76 | im_density_sampled = im_density(y1:y2,x1:x2); 77 | 78 | annPoints_sampled = annPoints(annPoints(:,1)>x1 & ... 79 | annPoints(:,1) < x2 & ... 80 | annPoints(:,2) > y1 & ... 81 | annPoints(:,2) < y2,:); 82 | annPoints_sampled(:,1) = annPoints_sampled(:,1) - x1; 83 | annPoints_sampled(:,2) = annPoints_sampled(:,2) - y1; 84 | img_idx = strcat(num2str(i), '_',num2str(j)); 85 | 86 | if(idx < num_val) 87 | imwrite(im_sampled, [val_path_img num2str(img_idx) '.jpg']); 88 | csvwrite([val_path_den num2str(img_idx) '.csv'], im_density_sampled); 89 | else 90 | imwrite(im_sampled, [train_path_img num2str(img_idx) '.jpg']); 91 | csvwrite([train_path_den num2str(img_idx) '.csv'], im_density_sampled); 92 | end 93 | 94 | end 95 | 96 | end 97 | 98 | -------------------------------------------------------------------------------- /data_preparation/get_density_map_gaussian.m: -------------------------------------------------------------------------------- 1 | function im_density = get_density_map_gaussian(im,points) 2 | 3 | 4 | im_density = zeros(size(im)); 5 | [h,w] = size(im_density); 6 | 7 | if(length(points)==0) 8 | return; 9 | end 10 | 11 | if(length(points(:,1))==1) 12 | x1 = max(1,min(w,round(points(1,1)))); 13 | y1 = max(1,min(h,round(points(1,2)))); 14 | im_density(y1,x1) = 255; 15 | return; 16 | end 17 | for j = 1:length(points) 18 | f_sz = 15; 19 | sigma = 4.0; 20 | H = fspecial('Gaussian',[f_sz, f_sz],sigma); 21 | x = min(w,max(1,abs(int32(floor(points(j,1)))))); 22 | y = min(h,max(1,abs(int32(floor(points(j,2)))))); 23 | if(x > w || y > h) 24 | continue; 25 | end 26 | x1 = x - int32(floor(f_sz/2)); y1 = y - int32(floor(f_sz/2)); 27 | x2 = x + int32(floor(f_sz/2)); y2 = y + 
int32(floor(f_sz/2)); 28 | dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0; 29 | change_H = false; 30 | if(x1 < 1) 31 | dfx1 = abs(x1)+1; 32 | x1 = 1; 33 | change_H = true; 34 | end 35 | if(y1 < 1) 36 | dfy1 = abs(y1)+1; 37 | y1 = 1; 38 | change_H = true; 39 | end 40 | if(x2 > w) 41 | dfx2 = x2 - w; 42 | x2 = w; 43 | change_H = true; 44 | end 45 | if(y2 > h) 46 | dfy2 = y2 - h; 47 | y2 = h; 48 | change_H = true; 49 | end 50 | x1h = 1+dfx1; y1h = 1+dfy1; x2h = f_sz - dfx2; y2h = f_sz - dfy2; 51 | if (change_H == true) 52 | H = fspecial('Gaussian',[double(y2h-y1h+1), double(x2h-x1h+1)],sigma); 53 | end 54 | im_density(y1:y2,x1:x2) = im_density(y1:y2,x1:x2) + H; 55 | 56 | end 57 | 58 | end -------------------------------------------------------------------------------- /src/crowd_count.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import network 3 | 4 | class CrowdCounter(nn.Module): 5 | def __init__(self,model,pool): 6 | super(CrowdCounter, self).__init__() 7 | if model=='base': 8 | from models.base import base 9 | self.DME = base(pool) 10 | if model=='wide': 11 | from models.wide import wide 12 | self.DME = wide(pool) 13 | if model=='deep': 14 | from models.deep import deep 15 | self.DME = deep(pool) 16 | 17 | self.loss_fn = nn.MSELoss() 18 | 19 | @property 20 | def loss(self): 21 | return self.loss_mse 22 | 23 | def forward(self, im_data, gt_data=None): 24 | im_data = network.np_to_variable(im_data, is_cuda=True, is_training=self.training) 25 | density_map = self.DME(im_data) 26 | if self.training: 27 | gt_data = network.np_to_variable(gt_data, is_cuda=True, is_training=self.training) 28 | self.loss_mse = self.build_loss(density_map, gt_data) 29 | 30 | return density_map 31 | 32 | def build_loss(self, density_map, gt_data): 33 | loss = self.loss_fn(density_map, gt_data) 34 | return loss 35 | -------------------------------------------------------------------------------- /src/data_loader.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import random 5 | import pandas as pd 6 | 7 | class ImageDataLoader(): 8 | def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False, 9 | batch_size=1, scaling=4, re_scale=1.0, re_size=None): 10 | #pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing. 11 | # This avoids frequent file reads. Use this only for small datasets. 12 | self.data_path = data_path 13 | self.gt_path = gt_path 14 | self.gt_downsample = gt_downsample 15 | self.pre_load = pre_load 16 | self.data_files = [filename for filename in os.listdir(data_path) \ 17 | if os.path.isfile(os.path.join(data_path,filename))] 18 | self.data_files.sort() 19 | self.shuffle = shuffle 20 | self.scaling = scaling 21 | self.re_scale = re_scale 22 | self.re_size = re_size 23 | if shuffle: 24 | random.seed(2468) 25 | self.num_samples = len(self.data_files) 26 | self.blob_list = {} 27 | self.id_list = range(0,self.num_samples/batch_size) 28 | 29 | batch = -1 30 | batch_full=False 31 | if self.pre_load: 32 | print 'Pre-loading the data. This may take a while...' 
33 | idx = 0 34 | for fname in self.data_files: 35 | 36 | img = cv2.imread(os.path.join(self.data_path,fname),0) 37 | img = img.astype(np.float32, copy=False) 38 | if self.re_size is None: 39 | ht = img.shape[0] 40 | wd = img.shape[1] 41 | else: 42 | ht = self.re_size[0] 43 | wd = self.re_size[1] 44 | ht_1 = (ht/self.scaling)*self.scaling 45 | wd_1 = (wd/self.scaling)*self.scaling 46 | img = cv2.resize(img,(wd_1,ht_1)) 47 | img = img.reshape((1,1,img.shape[0],img.shape[1])) 48 | img = img/self.re_scale 49 | den = pd.read_csv(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.csv'), sep=',',header=None).as_matrix() 50 | den = den.astype(np.float32, copy=False) 51 | if self.gt_downsample: 52 | wd_1 = wd_1/self.scaling 53 | ht_1 = ht_1/self.scaling 54 | den = cv2.resize(den,(wd_1,ht_1)) 55 | den = den * ((wd*ht)/(wd_1*ht_1)) 56 | else: 57 | den = cv2.resize(den,(wd_1,ht_1)) 58 | den = den * ((wd*ht)/(wd_1*ht_1)) 59 | 60 | den = den.reshape((1,1,den.shape[0],den.shape[1])) 61 | if idx==0: 62 | blob = {} 63 | blob['data']=img 64 | blob['gt_density']=den 65 | blob['fname'] = [fname] 66 | idx+=1 67 | batch_full=False 68 | if idx==batch_size: 69 | idx = 0 70 | batch_full=True 71 | else: 72 | blob['data']=np.concatenate((blob['data'],img)) 73 | blob['gt_density']=np.concatenate((blob['gt_density'],den)) 74 | blob['fname'].append(fname) 75 | idx+=1 76 | batch_full=False 77 | if idx==batch_size: 78 | idx = 0 79 | batch_full=True 80 | 81 | if batch_full: 82 | batch+=1 83 | self.blob_list[batch] = blob 84 | if batch % 200 == 0: 85 | print 'Loaded', batch, 'batch', batch*batch_size, '/', self.num_samples, 'files' 86 | 87 | print 'Completed Loading ', batch+1, 'batches' 88 | 89 | 90 | def __iter__(self): 91 | if self.shuffle: 92 | if self.pre_load: 93 | random.shuffle(self.id_list) 94 | else: 95 | random.shuffle(self.data_files) 96 | files = self.data_files 97 | id_list = self.id_list 98 | 99 | for idx in id_list: 100 | if self.pre_load: 101 | blob = self.blob_list[idx] 102 
| blob['idx'] = idx 103 | else: 104 | fname = files[idx] 105 | img = cv2.imread(os.path.join(self.data_path,fname),0) 106 | img = img.astype(np.float32, copy=False) 107 | if self.re_size is None: 108 | ht = img.shape[0] 109 | wd = img.shape[1] 110 | else: 111 | ht = self.re_size[0] 112 | wd = self.re_size[1] 113 | ht_1 = (ht/self.scaling)*self.scaling 114 | wd_1 = (wd/self.scaling)*self.scaling 115 | img = cv2.resize(img,(wd_1,ht_1)) 116 | img = img.reshape((1,1,img.shape[0],img.shape[1])) 117 | img = img/self.re_scale 118 | den = pd.read_csv(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.csv'), sep=',',header=None).as_matrix() 119 | den = den.astype(np.float32, copy=False) 120 | if self.gt_downsample: 121 | wd_1 = wd_1/self.scaling 122 | ht_1 = ht_1/self.scaling 123 | den = cv2.resize(den,(wd_1,ht_1)) 124 | den = den * ((wd*ht)/(wd_1*ht_1)) 125 | else: 126 | den = cv2.resize(den,(wd_1,ht_1)) 127 | den = den * ((wd*ht)/(wd_1*ht_1)) 128 | 129 | den = den.reshape((1,1,den.shape[0],den.shape[1])) 130 | blob = {} 131 | blob['data']=img 132 | blob['gt_density']=den 133 | blob['fname'] = fname 134 | 135 | yield blob 136 | 137 | def get_num_samples(self): 138 | return self.num_samples 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /src/evaluate_model.py: -------------------------------------------------------------------------------- 1 | from crowd_count import CrowdCounter 2 | import network 3 | import numpy as np 4 | 5 | 6 | def evaluate_model(trained_model, data_loader, model, pool): 7 | net = CrowdCounter(model=model,pool=pool) 8 | network.load_net(trained_model, net) 9 | net.cuda() 10 | net.eval() 11 | mae = 0.0 12 | mse = 0.0 13 | for blob in data_loader: 14 | im_data = blob['data'] 15 | gt_data = blob['gt_density'] 16 | density_map = net(im_data, gt_data) 17 | density_map = density_map.data.cpu().numpy() 18 | gt_count = np.sum(gt_data) 19 | et_count = np.sum(density_map) 20 | mae += 
abs(gt_count-et_count) 21 | mse += ((gt_count-et_count)*(gt_count-et_count)) 22 | mae = mae/data_loader.get_num_samples() 23 | mse = np.sqrt(mse/data_loader.get_num_samples()) 24 | return mae,mse -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | The Base-Net described in our paper. 4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | from src.network import Conv2d 9 | import time 10 | 11 | class base(nn.Module): 12 | 13 | def __init__(self, pool, bn=False): 14 | super(base, self).__init__() 15 | 16 | kernel_size = 5 17 | self.pool = pool 18 | if kernel_size==7: 19 | self.c1 = Conv2d( 1, 16, 9, same_padding=True, bn=bn) 20 | self.c2 = Conv2d(16, 32, 7, same_padding=True, bn=bn) 21 | self.c3_5 = nn.Sequential(Conv2d(32, 16, 7, same_padding=True, bn=bn), 22 | Conv2d(16, 8, 7, same_padding=True, bn=bn), 23 | Conv2d( 8, 1, 1, same_padding=True, bn=bn)) 24 | if kernel_size==5: 25 | self.c1 = Conv2d( 1, 20, 7, same_padding=True, bn=bn) 26 | self.c2 = Conv2d(20, 40, 5, same_padding=True, bn=bn) 27 | self.c3_5 = nn.Sequential(Conv2d(40, 20, 5, same_padding=True, bn=bn), 28 | Conv2d(20, 10, 5, same_padding=True, bn=bn), 29 | Conv2d( 10, 1, 1, same_padding=True, bn=bn)) 30 | if kernel_size==3: 31 | self.c1 = Conv2d( 1, 24, 5, same_padding=True, bn=bn) 32 | self.c2 = Conv2d(24, 48, 3, same_padding=True, bn=bn) 33 | self.c3_5 = nn.Sequential(Conv2d(48, 24, 3, same_padding=True, bn=bn), 34 | Conv2d(24, 12, 3, same_padding=True, bn=bn), 35 | Conv2d( 12, 1, 1, same_padding=True, bn=bn)) 36 | 37 | self.pool2 = nn.MaxPool2d(2, stride=2) 38 | self.pool2s1 = nn.MaxPool2d(2, stride=1) 39 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1) 40 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1) 41 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3) 42 | 43 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1)) 44 | 45 | def 
multi_pool(self, x): 46 | x1 = self.pool2(x) 47 | x2 = self.pool4(x) 48 | x3 = self.pool8(x) 49 | y = (x1+x2+x3)/3.0 50 | return y 51 | 52 | def stack_pool(self, x): 53 | x1 = self.pool2(x) 54 | x2 = self.pool2s1(self.padding(x1)) 55 | x3 = self.pool3s1(x2) 56 | y = (x1+x2+x3)/3.0 57 | return y 58 | 59 | def forward(self, im_data): 60 | x = self.c1(im_data) 61 | 62 | if self.pool=='mpool': 63 | x = self.multi_pool(x) 64 | if self.pool=='stackpool': 65 | x = self.stack_pool(x) 66 | if self.pool=='vpool': 67 | x = self.pool2(x) 68 | 69 | x = self.c2(x) 70 | 71 | if self.pool=='mpool': 72 | x = self.multi_pool(x) 73 | if self.pool=='stackpool': 74 | x = self.stack_pool(x) 75 | if self.pool=='vpool': 76 | x = self.pool2(x) 77 | 78 | x = self.c3_5(x) 79 | 80 | return x -------------------------------------------------------------------------------- /src/models/deep.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | The Deep-Net described in our paper. 
4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import numpy as np 10 | 11 | cfg = { 12 | # 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 512, 512], 13 | # 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 512, 512], 14 | 'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 128, 64, 32, 16], 15 | # 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512], 16 | # 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512, 512, 512], 17 | } 18 | 19 | class multi_pool(nn.Module): 20 | def __init__(self): 21 | super(multi_pool, self).__init__() 22 | self.pool2 = nn.MaxPool2d(2, stride=2) 23 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1) 24 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3) 25 | def forward(self, x): 26 | x1 = self.pool2(x) 27 | x2 = self.pool4(x) 28 | x3 = self.pool8(x) 29 | y = (x1+x2+x3)/3.0 30 | return y 31 | 32 | class stack_pool(nn.Module): 33 | def __init__(self): 34 | super(stack_pool, self).__init__() 35 | self.pool2 = nn.MaxPool2d(2, stride=2) 36 | self.pool2s1 = nn.MaxPool2d(2, stride=1) 37 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1) 38 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1)) 39 | def forward(self, x): 40 | x1 = self.pool2(x) 41 | x2 = self.pool2s1(self.padding(x1)) 42 | x3 = self.pool3s1(x2) 43 | y = (x1+x2+x3)/3.0 44 | return y 45 | 46 | class feature_net(nn.Module): 47 | def __init__(self,pool): 48 | super(feature_net, self).__init__() 49 | self.pool = pool 50 | self.features = self.make_layers(cfg = cfg['C'], batch_norm = False) 51 | def forward(self, x): 52 | feature = self.features(x) 53 | return feature 54 | def make_layers(self, cfg, batch_norm = False): 55 | layers = [] 56 | in_channels = 1 57 | idx_M = 0 58 | conv_size = 5 59 | for v in cfg: 60 | if v == 'M': 61 | idx_M += 1 62 | if idx_M >= 2: 63 | conv_size = 3 64 | if self.pool == 'mpool': 65 | layers += [multi_pool()] 66 | if self.pool == 'stackpool': 67 | 
layers += [stack_pool()] 68 | if self.pool == 'vpool': 69 | layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)] 70 | else: 71 | conv2d = nn.Conv2d(in_channels, v, kernel_size = conv_size, padding = (conv_size-1)/2 ) 72 | if batch_norm: 73 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)] 74 | else: 75 | layers += [conv2d, nn.ReLU(inplace = True)] 76 | in_channels = v 77 | return nn.Sequential(*layers) 78 | 79 | class deep(nn.Module): 80 | def __init__(self,pool): 81 | super(deep, self).__init__() 82 | self.conv2d = nn.Conv2d(16, 1, kernel_size = 1) 83 | self.feature_net = feature_net(pool) 84 | #self._initialize_weights() 85 | def forward(self, x): 86 | x = self.feature_net.forward(x) 87 | heat_map = self.conv2d(x) 88 | return heat_map 89 | def _initialize_weights(self): 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv2d): 92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 93 | if m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | -------------------------------------------------------------------------------- /src/models/wide.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | The Wide-Net described in our paper. 
4 | ''' 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import numpy as np 10 | 11 | cfg = { 12 | 'wide': [128, 'M', 256, 'M',128, 64], 13 | } 14 | 15 | class multi_pool(nn.Module): 16 | def __init__(self): 17 | super(multi_pool, self).__init__() 18 | self.pool2 = nn.MaxPool2d(2, stride=2) 19 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1) 20 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3) 21 | def forward(self, x): 22 | x1 = self.pool2(x) 23 | x2 = self.pool4(x) 24 | x3 = self.pool8(x) 25 | y = (x1+x2+x3)/3.0 26 | return y 27 | 28 | class stack_pool(nn.Module): 29 | def __init__(self): 30 | super(stack_pool, self).__init__() 31 | self.pool2 = nn.MaxPool2d(2, stride=2) 32 | self.pool2s1 = nn.MaxPool2d(2, stride=1) 33 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1) 34 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1)) 35 | def forward(self, x): 36 | x1 = self.pool2(x) 37 | x2 = self.pool2s1(self.padding(x1)) 38 | x3 = self.pool3s1(x2) 39 | y = (x1+x2+x3)/3.0 40 | return y 41 | 42 | class feature_net(nn.Module): 43 | def __init__(self,pool): 44 | super(feature_net, self).__init__() 45 | self.pool = pool 46 | self.features = self.make_layers(cfg = cfg['wide'], batch_norm = False) 47 | def forward(self, x): 48 | feature = self.features(x) 49 | return feature 50 | def make_layers(self, cfg, batch_norm = False): 51 | layers = [] 52 | in_channels = 1 53 | idx_M = 0 54 | conv_size = 7 55 | for v in cfg: 56 | if v == 'M': 57 | idx_M += 1 58 | if idx_M >= 1: 59 | conv_size = 5 60 | if idx_M >= 2: 61 | conv_size = 3 62 | if self.pool == 'mpool': 63 | layers += [multi_pool()] 64 | if self.pool == 'stackpool': 65 | layers += [stack_pool()] 66 | if self.pool == 'vpool': 67 | layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)] 68 | else: 69 | conv2d = nn.Conv2d(in_channels, v, kernel_size = conv_size, padding = (conv_size-1)/2 ) 70 | if batch_norm: 71 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)] 72 | else: 73 | layers += 
[conv2d, nn.ReLU(inplace = True)] 74 | in_channels = v 75 | return nn.Sequential(*layers) 76 | 77 | class wide(nn.Module): 78 | def __init__(self,pool): 79 | super(wide, self).__init__() 80 | self.conv2d = nn.Conv2d(64, 1, kernel_size = 1) 81 | self.feature_net = feature_net(pool) 82 | #self._initialize_weights() 83 | def forward(self, x): 84 | x = self.feature_net.forward(x) 85 | heat_map = self.conv2d(x) 86 | return heat_map 87 | def _initialize_weights(self): 88 | for m in self.modules(): 89 | if isinstance(m, nn.Conv2d): 90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 91 | if m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | class Conv2d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False, bn=False): 8 | super(Conv2d, self).__init__() 9 | 10 | if isinstance(kernel_size, tuple)==True: 11 | padding = (int((kernel_size[0] - 1) / 2), int((kernel_size[1] - 1) / 2)) if same_padding else 0 12 | else: 13 | padding = int((kernel_size - 1) / 2) if same_padding else 0 14 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding) 15 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None 16 | self.relu = nn.ReLU(inplace=True) if relu else None 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | if self.bn is not None: 21 | x = self.bn(x) 22 | if self.relu is not None: 23 | x = self.relu(x) 24 | return x 25 | 26 | 27 | class FC(nn.Module): 28 | def __init__(self, in_features, out_features, relu=True): 29 | super(FC, self).__init__() 30 | self.fc = nn.Linear(in_features, out_features) 31 | self.relu = 
nn.ReLU(inplace=True) if relu else None 32 | 33 | def forward(self, x): 34 | x = self.fc(x) 35 | if self.relu is not None: 36 | x = self.relu(x) 37 | return x 38 | 39 | class MaxPoolSame(nn.Module): 40 | def __init__(self): 41 | super(MaxPoolSame, self).__init__() 42 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1)) 43 | self.pooling = nn.MaxPool2d(2, 1) 44 | 45 | def forward(self, x): 46 | out = self.pooling(self.padding(x)) 47 | assert(x.shape == out.shape) 48 | return out 49 | 50 | 51 | def save_net(fname, net): 52 | import h5py 53 | h5f = h5py.File(fname, mode='w') 54 | for k, v in net.state_dict().items(): 55 | h5f.create_dataset(k, data=v.cpu().numpy()) 56 | 57 | 58 | def load_net(fname, net): 59 | import h5py 60 | h5f = h5py.File(fname, mode='r') 61 | for k, v in net.state_dict().items(): 62 | param = torch.from_numpy(np.asarray(h5f[k])) 63 | v.copy_(param) 64 | 65 | 66 | def np_to_variable(x, is_cuda=True, is_training=False, dtype=torch.FloatTensor): 67 | if is_training: 68 | v = Variable(torch.from_numpy(x).type(dtype)) 69 | else: 70 | v = Variable(torch.from_numpy(x).type(dtype), requires_grad = False)#, volatile = True) 71 | if is_cuda: 72 | v = v.cuda() 73 | return v 74 | 75 | 76 | def set_trainable(model, requires_grad): 77 | for param in model.parameters(): 78 | param.requires_grad = requires_grad 79 | 80 | 81 | def weights_normal_init(model, dev=0.01): 82 | if isinstance(model, list): 83 | for m in model: 84 | weights_normal_init(m, dev) 85 | else: 86 | for m in model.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | #print torch.sum(m.weight) 89 | m.weight.data.normal_(0.0, dev) 90 | if m.bias is not None: 91 | m.bias.data.fill_(0.0) 92 | elif isinstance(m, nn.Linear): 93 | m.weight.data.normal_(0.0, dev) 94 | -------------------------------------------------------------------------------- /src/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer(object): 4 | def 
__init__(self): 5 | self.tot_time = 0. 6 | self.calls = 0 7 | self.start_time = 0. 8 | self.diff = 0. 9 | self.average_time = 0. 10 | 11 | def tic(self): 12 | # using time.time instead of time.clock because time.clock 13 | # does not normalize for multithreading 14 | self.start_time = time.time() 15 | 16 | def toc(self, average=True): 17 | self.diff = time.time() - self.start_time 18 | self.tot_time += self.diff 19 | self.calls += 1 20 | self.average_time = self.tot_time / self.calls 21 | if average: 22 | return self.average_time 23 | else: 24 | return self.diff 25 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | def save_results(input_img, gt_data,density_map,output_dir, fname='results.png'): 6 | input_img = input_img[0][0] 7 | gt_data = 255*gt_data/np.max(gt_data) 8 | density_map = 255*density_map/np.max(density_map) 9 | gt_data = gt_data[0][0] 10 | density_map= density_map[0][0] 11 | if density_map.shape[1] != input_img.shape[1]: 12 | density_map = cv2.resize(density_map, (input_img.shape[1],input_img.shape[0])) 13 | gt_data = cv2.resize(gt_data, (input_img.shape[1],input_img.shape[0])) 14 | result_img = np.hstack((input_img,gt_data,density_map)) 15 | cv2.imwrite(os.path.join(output_dir,fname),result_img) 16 | 17 | 18 | def save_density_map(density_map,output_dir, fname='results.png'): 19 | density_map = 255*density_map/np.max(density_map) 20 | density_map= density_map[0][0] 21 | cv2.imwrite(os.path.join(output_dir,fname),density_map) 22 | 23 | def display_results(input_img, gt_data,density_map): 24 | input_img = input_img[0][0] 25 | gt_data = 255*gt_data/np.max(gt_data) 26 | density_map = 255*density_map/np.max(density_map) 27 | gt_data = gt_data[0][0] 28 | density_map= density_map[0][0] 29 | if density_map.shape[1] != input_img.shape[1]: 30 | input_img =
cv2.resize(input_img, (density_map.shape[1], density_map.shape[0])) 31 | result_img = np.hstack((input_img, gt_data, density_map)) 32 | result_img = result_img.astype(np.uint8, copy=False) 33 | cv2.imshow('Result', result_img) 34 | cv2.waitKey(0) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | np.set_printoptions(threshold=np.inf) 5 | 6 | from src.crowd_count import CrowdCounter 7 | from src import network 8 | from src.data_loader import ImageDataLoader 9 | from src import utils 10 | 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 12 | np.warnings.filterwarnings('ignore') 13 | # dataset, model, and pooling method 14 | datasets = ['shtechA', 'shtechB'] # datasets 15 | models = ['base', 'wide', 'deep'] # backbone network architecture 16 | pools = ['vpool', 'stackpool', 'mpool'] # vpool is vanilla pooling; stackpool is stacked pooling; mpool is multi-kernel pooling 17 | 18 | ### 19 | dataset_name = datasets[0] # choose the dataset 20 | model = models[2] # choose the backbone network architecture 21 | pool = pools[0] # choose the pooling method 22 | method = model + '_' + pool 23 | 24 | name = dataset_name[-1] 25 | data_path = './data/original/shanghaitech/part_'+name+'_final/test_data/images/' 26 | gt_path = './data/original/shanghaitech/part_'+name+'_final/test_data/ground_truth_csv/' 27 | model_path = './saved_models/'+method+'_shtech'+name+'_0.h5' 28 | print 'Testing %s' % (model_path) 29 | 30 | torch.backends.cudnn.enabled = True 31 | torch.backends.cudnn.benchmark = True 32 | vis = False 33 | save_output = True 34 | 35 | output_dir = './output/' 36 | model_name = os.path.basename(model_path).split('.')[0] 37 | file_results = os.path.join(output_dir, 'results_' + model_name + '_.txt') 38 | if not os.path.exists(output_dir): 39 | os.mkdir(output_dir) 40 | output_dir = os.path.join(output_dir, 
'density_maps_' + model_name) 41 | if not os.path.exists(output_dir): 42 | os.mkdir(output_dir) 43 | 44 | net = CrowdCounter(model, pool) 45 | trained_model = model_path 46 | network.load_net(trained_model, net) 47 | net.cuda() 48 | net.eval() 49 | 50 | if model in ['base', 'wide']: 51 | scaling = 4 52 | elif model == 'deep': 53 | scaling = 8 54 | 55 | # load test data 56 | data_loader = ImageDataLoader(data_path, gt_path, shuffle=False, gt_downsample=True, pre_load=False, batch_size=1, scaling=scaling) 57 | 58 | mae = 0.0 59 | mse = 0.0 60 | num = 0 61 | for blob in data_loader: 62 | num += 1 63 | im_data = blob['data'] 64 | gt_data = blob['gt_density'] 65 | density_map = net(im_data) 66 | density_map = density_map.data.cpu().numpy() 67 | gt_count = np.sum(gt_data) 68 | et_count = np.sum(density_map) 69 | mae += abs(gt_count-et_count) 70 | mse += (gt_count-et_count)**2 71 | if vis: 72 | utils.display_results(im_data, gt_data, density_map) 73 | if save_output: 74 | utils.save_density_map(density_map, output_dir, 'output_' + blob['fname'].split('.')[0] + '.png') 75 | if num % 100 == 0: 76 | print '%d/%d' % (num, data_loader.get_num_samples()) 77 | 78 | mae = mae / data_loader.get_num_samples() 79 | mse = np.sqrt(mse / data_loader.get_num_samples()) # root mean squared error, reported as MSE by crowd-counting convention 80 | print 'MAE: %0.2f, MSE: %0.2f' % (mae, mse) 81 | 82 | f = open(file_results, 'w') 83 | f.write('MAE: %0.2f, MSE: %0.2f' % (mae, mse)) 84 | f.close() 85 | -------------------------------------------------------------------------------- /thumbnails/stackpool.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/siyuhuang/crowdcount-stackpool/bbba3d9e91a5a89642b4bd3638ae8e68801ea7bf/thumbnails/stackpool.jpg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import 
sys 5 | 6 | from src.crowd_count import CrowdCounter 7 | from src import network 8 | from src.data_loader import ImageDataLoader 9 | from src.timer import Timer 10 | from src import utils 11 | from src.evaluate_model import evaluate_model 12 | import time 13 | 14 | np.warnings.filterwarnings('ignore') 15 | ### assign dataset, model, and pooling method 16 | datasets = ['shtechA', 'shtechB'] # datasets 17 | models = ['base', 'wide', 'deep'] # backbone network architecture 18 | pools = ['vpool', 'stackpool', 'mpool'] # vpool is vanilla pooling; stackpool is stacked pooling; mpool is multi-kernel pooling 19 | 20 | dataset_name = datasets[0] # choose the dataset 21 | model = models[2] # choose the backbone network architecture 22 | pool = pools[0] # choose the pooling method 23 | method = model + '_' + pool 24 | print 'Training %s on %s' % (method, dataset_name) 25 | 26 | ### assign GPU 27 | if pool == 'vpool': 28 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 29 | elif pool == 'stackpool': 30 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 31 | elif pool == 'mpool': 32 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 33 | 34 | ### PyTorch configuration 35 | torch.backends.cudnn.enabled = True 36 | torch.backends.cudnn.benchmark = True 37 | 38 | ### model saving folder 39 | output_dir = './saved_models/' 40 | if not os.path.exists(output_dir): 41 | os.mkdir(output_dir) 42 | 43 | ### data folder 44 | name = dataset_name[-1] 45 | train_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/train' 46 | train_gt_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/train_den' 47 | val_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/val' 48 | val_gt_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/val_den' 49 | 50 | ### training configuration 51 | start_step = 0 52 | end_step = 500 53 | batch_size = 1 54 | disp_interval = 1500 55 | if model == 'base': 56 | if dataset_name == 'shtechA': 57 | lr = 5e-5 58 | elif dataset_name == 
'shtechB': 59 | lr = 2e-5 60 | scaling = 4 # output density map is 1/4 size of input image 61 | elif model == 'wide': 62 | if dataset_name == 'shtechA': 63 | lr = 1e-5 64 | elif dataset_name == 'shtechB': 65 | lr = 1e-5 66 | scaling = 4 # output density map is 1/4 size of input image 67 | elif model == 'deep': 68 | if dataset_name == 'shtechA': 69 | lr = 1e-5 70 | elif dataset_name == 'shtechB': 71 | lr = 5e-6 72 | scaling = 8 # output density map is 1/8 size of input image 73 | print 'learning rate %f' % (lr) 74 | 75 | ### random seed 76 | rand_seed = 64678 77 | if rand_seed is not None: 78 | np.random.seed(rand_seed) 79 | torch.manual_seed(rand_seed) 80 | torch.cuda.manual_seed_all(rand_seed) 81 | 82 | ### initialize network 83 | net = CrowdCounter(model=model, pool=pool) 84 | network.weights_normal_init(net, dev=0.01) 85 | net.cuda() 86 | net.train() 87 | 88 | ### optimizer 89 | params = filter(lambda p: p.requires_grad, net.parameters()) 90 | optimizer = torch.optim.Adam(params, lr=lr) 91 | 92 | ### load data 93 | pre_load = True 94 | data_loader = ImageDataLoader(train_path, train_gt_path, shuffle=True, gt_downsample=True, pre_load=pre_load, 95 | batch_size=batch_size, scaling=scaling) 96 | data_loader_val = ImageDataLoader(val_path, val_gt_path, shuffle=False, gt_downsample=True, pre_load=pre_load, 97 | batch_size=1, scaling=scaling) 98 | 99 | ### training 100 | train_loss = 0 101 | t = Timer() 102 | t.tic() 103 | best_mae = float('inf') 104 | 105 | for epoch in range(start_step, end_step+1): 106 | step = 0 107 | train_loss = 0 108 | for blob in data_loader: 109 | step += 1 110 | im_data = blob['data'] 111 | gt_data = blob['gt_density'] 112 | density_map = net(im_data, gt_data) 113 | loss = net.loss 114 | train_loss += loss.item() 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | if step % disp_interval == 0: 120 | duration = t.toc(average=False) 121 | density_map = density_map.data.cpu().numpy() 122 | 
utils.save_results(im_data, gt_data, density_map, output_dir) 123 | print 'epoch: %4d, step %4d, Time: %.4fs, loss: %4.10f' % (epoch, step, duration, train_loss/disp_interval) 124 | train_loss = 0 125 | t.tic() 126 | 127 | if epoch % 2 == 0: 128 | # save model checkpoint 129 | save_name = os.path.join(output_dir, '{}_{}_{}.h5'.format(method, dataset_name, epoch)) 130 | network.save_net(save_name, net) 131 | # calculate error on the validation dataset 132 | mae, mse = evaluate_model(save_name, data_loader_val, model, pool) 133 | if mae < best_mae: 134 | best_mae = mae 135 | best_mse = mse 136 | best_model = os.path.basename(save_name) 137 | print 'EPOCH: %d, MAE: %0.2f, MSE: %0.2f' % (epoch, mae, mse) 138 | print 'BEST MAE: %0.2f, BEST MSE: %0.2f, BEST MODEL: %s' % (best_mae, best_mse, best_model) 139 | 140 | 141 | t.tic() 142 | 143 | 144 | 145 | --------------------------------------------------------------------------------