├── .gitignore
├── LICENSE
├── README.md
├── data_preparation
│   ├── create_gt_test_set_shtech.m
│   ├── create_training_set_shtech.m
│   └── get_density_map_gaussian.m
├── src
│   ├── crowd_count.py
│   ├── data_loader.py
│   ├── evaluate_model.py
│   ├── models
│   │   ├── base.py
│   │   ├── deep.py
│   │   └── wide.py
│   ├── network.py
│   ├── timer.py
│   └── utils.py
├── test.py
├── thumbnails
│   └── stackpool.jpg
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 siyuhuang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Stacked Pooling for Boosting Scale Invariance of Crowd Counting
2 |
3 | PyTorch implementation of "**Stacked Pooling for Boosting Scale Invariance of Crowd Counting**" [\[ICASSP 2020\]](https://siyuhuang.github.io/papers/ICASSP-2020-STACKED%20POOLING%20FOR%20BOOSTING%20SCALE%20INVARIANCE%20OF%20CROWD%20COUNTING.pdf).
4 |
5 | ```
6 | @inproceedings{huang2020stacked,
7 | title={Stacked Pooling for Boosting Scale Invariance of Crowd Counting},
8 | author={Huang, Siyu and Li, Xi and Cheng, Zhi-Qi and Zhang, Zhongfei and Hauptmann, Alexander},
9 | booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing},
10 | pages={2578--2582},
11 | year={2020},
12 | }
13 | ```
14 |
15 | This implementation is based on [https://github.com/svishwa/crowdcount-mcnn](https://github.com/svishwa/crowdcount-mcnn).
16 |
17 |
18 |
19 |
20 |
21 | | Method | ShanghaiTech-A | ShanghaiTech-B | WorldExpo'10 |
22 | | -------- | :-----: | :----: | :----: |
23 | | Vanilla Pooling | 97.63 | 21.17 | 14.74 |
24 | | Stacked Pooling | **93.98** | **18.73** | **12.92**|
25 |
26 |
27 | ## Dependency
28 | 1. Python 2.7
29 | 2. PyTorch 0.4.0
30 |
31 | ## Data Setup
32 | 1. Download the ShanghaiTech dataset from
33 | Dropbox: https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0
34 | Baidu Disk: http://pan.baidu.com/s/1nuAYslz
35 | 2. Create the directory: `mkdir ./data/original/shanghaitech/`
36 | 3. Save "part_A_final" under ./data/original/shanghaitech/
37 | Save "part_B_final" under ./data/original/shanghaitech/
38 | 4. `cd ./data_preparation/`
39 | Run `create_gt_test_set_shtech.m` in MATLAB to create ground truth files for the test data
40 | Run `create_training_set_shtech.m` in MATLAB to create the training and validation sets along with their ground truth files
41 |
42 | ## Train
43 | 1. To train **Deep Net** + **vanilla pooling** on **ShanghaiTechA**, set the pooling method in `train.py`:
44 | ```python
45 | pool = pools[0]
46 | ```
47 | 
48 | To train **Deep Net** + **stacked pooling** on **ShanghaiTechA**, set instead:
49 | ```python
50 | pool = pools[1]
51 | ```
52 | 2. Run `python train.py` to start each training run
53 |
54 | ## Test
55 | 1. Follow step 1 of **Train** to set the corresponding `pool` in `test.py`
56 | 2. Set `model_path` in `test.py` to the best checkpoint on the validation set (saved during training)
57 | 3. Run `python test.py` for each configuration to compare them
58 |
59 | ## Note
60 | 1. To try the pooling methods (**vanilla pooling**, **stacked pooling**, and **multi-kernel pooling**) described in our paper:
61 | 
62 | Edit `pool` in `train.py` and `test.py`
63 | 
64 | 2. To evaluate on the datasets (**ShanghaiTechA**, **ShanghaiTechB**) or backbone models (**Base-Net**, **Wide-Net**, **Deep-Net**) described in our paper:
65 | 
66 | Edit `dataset_name` or `model` in `train.py` and `test.py`
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
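The three `pool` options correspond to the pooling modules in `src/models/`. As a rough standalone illustration (not part of the repository; NumPy stand-ins for the PyTorch layers, where `-inf` padding emulates `MaxPool2d`'s implicit padding and edge padding emulates `ReplicationPad2d`), vanilla, stacked, and multi-kernel pooling can be sketched as:

```python
import numpy as np

def max_pool(x, k, stride, pad=0):
    """Naive 2-D max pooling; pads with -inf like PyTorch's MaxPool2d padding."""
    if pad:
        x = np.pad(x, pad, mode='constant', constant_values=-np.inf)
    h, w = x.shape
    oh, ow = (h - k) // stride + 1, (w - k) // stride + 1
    out = np.empty((oh, ow))
    for i in range(oh):
        for j in range(ow):
            out[i, j] = x[i * stride:i * stride + k, j * stride:j * stride + k].max()
    return out

def vanilla_pool(x):
    # 'vpool': plain MaxPool2d(2, stride=2)
    return max_pool(x, 2, 2)

def stacked_pool(x):
    # 'stackpool': stack_pool in src/models/deep.py and wide.py
    x1 = max_pool(x, 2, 2)
    x1_pad = np.pad(x1, ((0, 1), (0, 1)), mode='edge')  # ReplicationPad2d((0, 1, 0, 1))
    x2 = max_pool(x1_pad, 2, 1)                         # MaxPool2d(2, stride=1)
    x3 = max_pool(x2, 3, 1, pad=1)                      # MaxPool2d(3, stride=1, padding=1)
    return (x1 + x2 + x3) / 3.0

def multi_kernel_pool(x):
    # 'mpool': multi_pool in src/models/deep.py and wide.py
    x1 = max_pool(x, 2, 2)
    x2 = max_pool(x, 4, 2, pad=1)
    x3 = max_pool(x, 8, 2, pad=3)
    return (x1 + x2 + x3) / 3.0

x = np.arange(16, dtype=float).reshape(4, 4)
# All three variants halve the spatial size, so they are drop-in replacements.
assert vanilla_pool(x).shape == stacked_pool(x).shape == multi_kernel_pool(x).shape == (2, 2)
```

Note how stacked pooling chains small pools on the already-downsampled map, approximating larger pooling windows at little extra cost, while multi-kernel pooling applies the larger kernels directly to the input.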
/data_preparation/create_gt_test_set_shtech.m:
--------------------------------------------------------------------------------
1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2 | % File to create ground truth density maps for test set %
3 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
4 |
5 |
6 | clc; clear all;
7 | dataset = 'A';
8 | dataset_name = ['shanghaitech_part_' dataset ];
9 | path = ['../data/original/shanghaitech/part_' dataset '_final/test_data/images/'];
10 | gt_path = ['../data/original/shanghaitech/part_' dataset '_final/test_data/ground_truth/'];
11 | gt_path_csv = ['../data/original/shanghaitech/part_' dataset '_final/test_data/ground_truth_csv/'];
12 |
13 | mkdir(gt_path_csv )
14 | if (dataset == 'A')
15 | num_images = 182;
16 | else
17 | num_images = 316;
18 | end
19 |
20 | for i = 1:num_images
21 | if (mod(i,10)==0)
22 | fprintf(1,'Processing %3d/%d files\n', i, num_images);
23 | end
24 | load(strcat(gt_path, 'GT_IMG_',num2str(i),'.mat')) ;
25 | input_img_name = strcat(path,'IMG_',num2str(i),'.jpg');
26 | im = imread(input_img_name);
27 | [h, w, c] = size(im);
28 | if (c == 3)
29 | im = rgb2gray(im);
30 | end
31 | annPoints = image_info{1}.location;
32 | [h, w, c] = size(im);
33 | im_density = get_density_map_gaussian(im,annPoints);
34 | csvwrite([gt_path_csv ,'IMG_',num2str(i) '.csv'], im_density);
35 | end
36 |
37 |
--------------------------------------------------------------------------------
/data_preparation/create_training_set_shtech.m:
--------------------------------------------------------------------------------
1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
2 | % File to create training and validation set %
3 | % for ShanghaiTech Dataset Part A and B. 10% of %
4 | % the training set is set aside for validation %
5 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
6 |
7 |
8 | clc; clear all;
9 | seed = 95461354;
10 | rng(seed)
11 | N = 9;
12 | dataset = 'A';
13 | dataset_name = ['shanghaitech_part_' dataset '_patches_' num2str(N)];
14 | path = ['../data/original/shanghaitech/part_' dataset '_final/train_data/images/'];
15 | output_path = '../data/formatted_trainval/';
16 | train_path_img = strcat(output_path, dataset_name,'/train/');
17 | train_path_den = strcat(output_path, dataset_name,'/train_den/');
18 | val_path_img = strcat(output_path, dataset_name,'/val/');
19 | val_path_den = strcat(output_path, dataset_name,'/val_den/');
20 | gt_path = ['../data/original/shanghaitech/part_' dataset '_final/train_data/ground_truth/'];
21 |
22 | mkdir(output_path)
23 | mkdir(train_path_img);
24 | mkdir(train_path_den);
25 | mkdir(val_path_img);
26 | mkdir(val_path_den);
27 |
28 | if (dataset == 'A')
29 | num_images = 300;
30 | else
31 | num_images = 400;
32 | end
33 | num_val = ceil(num_images*0.1);
34 | indices = randperm(num_images);
35 |
36 | for idx = 1:num_images
37 | i = indices(idx);
38 | if (mod(idx,10)==0)
39 | fprintf(1,'Processing %3d/%d files\n', idx, num_images);
40 | end
41 | load(strcat(gt_path, 'GT_IMG_',num2str(i),'.mat')) ;
42 | input_img_name = strcat(path,'IMG_',num2str(i),'.jpg');
43 | im = imread(input_img_name);
44 | [h, w, c] = size(im);
45 | if (c == 3)
46 | im = rgb2gray(im);
47 | end
48 |
49 | wn2 = w/8; hn2 = h/8;
50 | wn2 =8 * floor(wn2/8);
51 | hn2 =8 * floor(hn2/8);
52 |
53 | annPoints = image_info{1}.location;
54 | if( w <= 2*wn2 )
55 | im = imresize(im,[ h,2*wn2+1]);
56 | annPoints(:,1) = annPoints(:,1)*2*wn2/w;
57 | end
58 | if( h <= 2*hn2)
59 | im = imresize(im,[2*hn2+1,w]);
60 | annPoints(:,2) = annPoints(:,2)*2*hn2/h;
61 | end
62 | [h, w, c] = size(im);
63 | a_w = wn2+1; b_w = w - wn2;
64 | a_h = hn2+1; b_h = h - hn2;
65 |
66 | im_density = get_density_map_gaussian(im,annPoints);
67 | for j = 1:N
68 |
69 | x = floor((b_w - a_w) * rand + a_w);
70 | y = floor((b_h - a_h) * rand + a_h);
71 | x1 = x - wn2; y1 = y - hn2;
72 | x2 = x + wn2-1; y2 = y + hn2-1;
73 |
74 |
75 | im_sampled = im(y1:y2, x1:x2,:);
76 | im_density_sampled = im_density(y1:y2,x1:x2);
77 |
78 | annPoints_sampled = annPoints(annPoints(:,1)>x1 & ...
79 | annPoints(:,1) < x2 & ...
80 | annPoints(:,2) > y1 & ...
81 | annPoints(:,2) < y2,:);
82 | annPoints_sampled(:,1) = annPoints_sampled(:,1) - x1;
83 | annPoints_sampled(:,2) = annPoints_sampled(:,2) - y1;
84 | img_idx = strcat(num2str(i), '_',num2str(j));
85 |
86 | if(idx < num_val)
87 | imwrite(im_sampled, [val_path_img num2str(img_idx) '.jpg']);
88 | csvwrite([val_path_den num2str(img_idx) '.csv'], im_density_sampled);
89 | else
90 | imwrite(im_sampled, [train_path_img num2str(img_idx) '.jpg']);
91 | csvwrite([train_path_den num2str(img_idx) '.csv'], im_density_sampled);
92 | end
93 |
94 | end
95 |
96 | end
97 |
98 |
--------------------------------------------------------------------------------
/data_preparation/get_density_map_gaussian.m:
--------------------------------------------------------------------------------
1 | function im_density = get_density_map_gaussian(im,points)
2 |
3 |
4 | im_density = zeros(size(im));
5 | [h,w] = size(im_density);
6 |
7 | if(length(points)==0)
8 | return;
9 | end
10 |
11 | if(length(points(:,1))==1)
12 | x1 = max(1,min(w,round(points(1,1))));
13 | y1 = max(1,min(h,round(points(1,2))));
14 | im_density(y1,x1) = 255;
15 | return;
16 | end
17 | for j = 1:length(points)
18 | f_sz = 15;
19 | sigma = 4.0;
20 | H = fspecial('Gaussian',[f_sz, f_sz],sigma);
21 | x = min(w,max(1,abs(int32(floor(points(j,1))))));
22 | y = min(h,max(1,abs(int32(floor(points(j,2))))));
23 | if(x > w || y > h)
24 | continue;
25 | end
26 | x1 = x - int32(floor(f_sz/2)); y1 = y - int32(floor(f_sz/2));
27 | x2 = x + int32(floor(f_sz/2)); y2 = y + int32(floor(f_sz/2));
28 | dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0;
29 | change_H = false;
30 | if(x1 < 1)
31 | dfx1 = abs(x1)+1;
32 | x1 = 1;
33 | change_H = true;
34 | end
35 | if(y1 < 1)
36 | dfy1 = abs(y1)+1;
37 | y1 = 1;
38 | change_H = true;
39 | end
40 | if(x2 > w)
41 | dfx2 = x2 - w;
42 | x2 = w;
43 | change_H = true;
44 | end
45 | if(y2 > h)
46 | dfy2 = y2 - h;
47 | y2 = h;
48 | change_H = true;
49 | end
50 | x1h = 1+dfx1; y1h = 1+dfy1; x2h = f_sz - dfx2; y2h = f_sz - dfy2;
51 | if (change_H == true)
52 | H = fspecial('Gaussian',[double(y2h-y1h+1), double(x2h-x1h+1)],sigma);
53 | end
54 | im_density(y1:y2,x1:x2) = im_density(y1:y2,x1:x2) + H;
55 |
56 | end
57 |
58 | end
--------------------------------------------------------------------------------
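For readers who prefer Python, the routine above can be approximated in NumPy as follows (a sketch with our own function name; unlike the MATLAB code, which regenerates a renormalized kernel via `fspecial` at image borders, this version simply crops the kernel, so densities near borders sum to slightly less than one per head):

```python
import numpy as np

def gaussian_density_map(shape, points, f_sz=15, sigma=4.0):
    """Sum a normalized 2-D Gaussian at each head annotation.

    shape  : (h, w) of the image
    points : iterable of (x, y) coordinates
    """
    h, w = shape
    density = np.zeros((h, w), dtype=np.float64)
    # Normalized Gaussian kernel, analogous to fspecial('Gaussian', [f_sz f_sz], sigma)
    ax = np.arange(f_sz) - f_sz // 2
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    kernel /= kernel.sum()
    r = f_sz // 2
    for px, py in points:
        x = int(min(w - 1, max(0, np.floor(px))))   # clamp to the image, 0-indexed
        y = int(min(h - 1, max(0, np.floor(py))))
        y1, y2 = y - r, y + r + 1
        x1, x2 = x - r, x + r + 1
        # Crop the kernel where the window leaves the image (approximation:
        # the MATLAB version rebuilds a renormalized kernel at borders).
        ky1, kx1 = max(0, -y1), max(0, -x1)
        ky2, kx2 = f_sz - max(0, y2 - h), f_sz - max(0, x2 - w)
        density[max(0, y1):min(h, y2), max(0, x1):min(w, x2)] += kernel[ky1:ky2, kx1:kx2]
    return density
```

For a point well inside the image, the map integrates to one per annotation, which is what makes the sum of the density map an estimate of the crowd count.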
/src/crowd_count.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import network
3 |
4 | class CrowdCounter(nn.Module):
5 | def __init__(self,model,pool):
6 | super(CrowdCounter, self).__init__()
7 |         if model=='base':
8 |             from models.base import base
9 |             self.DME = base(pool)
10 |         elif model=='wide':
11 |             from models.wide import wide
12 |             self.DME = wide(pool)
13 |         elif model=='deep':
14 |             from models.deep import deep
15 |             self.DME = deep(pool)
16 |
17 | self.loss_fn = nn.MSELoss()
18 |
19 | @property
20 | def loss(self):
21 | return self.loss_mse
22 |
23 | def forward(self, im_data, gt_data=None):
24 | im_data = network.np_to_variable(im_data, is_cuda=True, is_training=self.training)
25 | density_map = self.DME(im_data)
26 | if self.training:
27 | gt_data = network.np_to_variable(gt_data, is_cuda=True, is_training=self.training)
28 | self.loss_mse = self.build_loss(density_map, gt_data)
29 |
30 | return density_map
31 |
32 | def build_loss(self, density_map, gt_data):
33 | loss = self.loss_fn(density_map, gt_data)
34 | return loss
35 |
--------------------------------------------------------------------------------
/src/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import os
4 | import random
5 | import pandas as pd
6 |
7 | class ImageDataLoader():
8 | def __init__(self, data_path, gt_path, shuffle=False, gt_downsample=False, pre_load=False,
9 | batch_size=1, scaling=4, re_scale=1.0, re_size=None):
10 | #pre_load: if true, all training and validation images are loaded into CPU RAM for faster processing.
11 | # This avoids frequent file reads. Use this only for small datasets.
12 | self.data_path = data_path
13 | self.gt_path = gt_path
14 | self.gt_downsample = gt_downsample
15 | self.pre_load = pre_load
16 | self.data_files = [filename for filename in os.listdir(data_path) \
17 | if os.path.isfile(os.path.join(data_path,filename))]
18 | self.data_files.sort()
19 | self.shuffle = shuffle
20 | self.scaling = scaling
21 | self.re_scale = re_scale
22 | self.re_size = re_size
23 | if shuffle:
24 | random.seed(2468)
25 | self.num_samples = len(self.data_files)
26 | self.blob_list = {}
27 | self.id_list = range(0,self.num_samples/batch_size)
28 |
29 | batch = -1
30 | batch_full=False
31 | if self.pre_load:
32 | print 'Pre-loading the data. This may take a while...'
33 | idx = 0
34 | for fname in self.data_files:
35 |
36 | img = cv2.imread(os.path.join(self.data_path,fname),0)
37 | img = img.astype(np.float32, copy=False)
38 | if self.re_size is None:
39 | ht = img.shape[0]
40 | wd = img.shape[1]
41 | else:
42 | ht = self.re_size[0]
43 | wd = self.re_size[1]
44 | ht_1 = (ht/self.scaling)*self.scaling
45 | wd_1 = (wd/self.scaling)*self.scaling
46 | img = cv2.resize(img,(wd_1,ht_1))
47 | img = img.reshape((1,1,img.shape[0],img.shape[1]))
48 | img = img/self.re_scale
49 | den = pd.read_csv(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.csv'), sep=',',header=None).as_matrix()
50 | den = den.astype(np.float32, copy=False)
51 | if self.gt_downsample:
52 | wd_1 = wd_1/self.scaling
53 | ht_1 = ht_1/self.scaling
54 | den = cv2.resize(den,(wd_1,ht_1))
55 | den = den * ((wd*ht)/(wd_1*ht_1))
56 | else:
57 | den = cv2.resize(den,(wd_1,ht_1))
58 | den = den * ((wd*ht)/(wd_1*ht_1))
59 |
60 | den = den.reshape((1,1,den.shape[0],den.shape[1]))
61 | if idx==0:
62 | blob = {}
63 | blob['data']=img
64 | blob['gt_density']=den
65 | blob['fname'] = [fname]
66 | idx+=1
67 | batch_full=False
68 | if idx==batch_size:
69 | idx = 0
70 | batch_full=True
71 | else:
72 | blob['data']=np.concatenate((blob['data'],img))
73 | blob['gt_density']=np.concatenate((blob['gt_density'],den))
74 | blob['fname'].append(fname)
75 | idx+=1
76 | batch_full=False
77 | if idx==batch_size:
78 | idx = 0
79 | batch_full=True
80 |
81 | if batch_full:
82 | batch+=1
83 | self.blob_list[batch] = blob
84 | if batch % 200 == 0:
85 | print 'Loaded', batch, 'batch', batch*batch_size, '/', self.num_samples, 'files'
86 |
87 | print 'Completed Loading ', batch+1, 'batches'
88 |
89 |
90 | def __iter__(self):
91 | if self.shuffle:
92 | if self.pre_load:
93 | random.shuffle(self.id_list)
94 | else:
95 | random.shuffle(self.data_files)
96 | files = self.data_files
97 | id_list = self.id_list
98 |
99 | for idx in id_list:
100 | if self.pre_load:
101 | blob = self.blob_list[idx]
102 | blob['idx'] = idx
103 | else:
104 | fname = files[idx]
105 | img = cv2.imread(os.path.join(self.data_path,fname),0)
106 | img = img.astype(np.float32, copy=False)
107 | if self.re_size is None:
108 | ht = img.shape[0]
109 | wd = img.shape[1]
110 | else:
111 | ht = self.re_size[0]
112 | wd = self.re_size[1]
113 | ht_1 = (ht/self.scaling)*self.scaling
114 | wd_1 = (wd/self.scaling)*self.scaling
115 | img = cv2.resize(img,(wd_1,ht_1))
116 | img = img.reshape((1,1,img.shape[0],img.shape[1]))
117 | img = img/self.re_scale
118 | den = pd.read_csv(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.csv'), sep=',',header=None).as_matrix()
119 | den = den.astype(np.float32, copy=False)
120 | if self.gt_downsample:
121 | wd_1 = wd_1/self.scaling
122 | ht_1 = ht_1/self.scaling
123 | den = cv2.resize(den,(wd_1,ht_1))
124 | den = den * ((wd*ht)/(wd_1*ht_1))
125 | else:
126 | den = cv2.resize(den,(wd_1,ht_1))
127 | den = den * ((wd*ht)/(wd_1*ht_1))
128 |
129 | den = den.reshape((1,1,den.shape[0],den.shape[1]))
130 | blob = {}
131 | blob['data']=img
132 | blob['gt_density']=den
133 | blob['fname'] = fname
134 |
135 | yield blob
136 |
137 | def get_num_samples(self):
138 | return self.num_samples
139 |
140 |
141 |
142 |
143 |
--------------------------------------------------------------------------------
/src/evaluate_model.py:
--------------------------------------------------------------------------------
1 | from crowd_count import CrowdCounter
2 | import network
3 | import numpy as np
4 |
5 |
6 | def evaluate_model(trained_model, data_loader, model, pool):
7 | net = CrowdCounter(model=model,pool=pool)
8 | network.load_net(trained_model, net)
9 | net.cuda()
10 | net.eval()
11 | mae = 0.0
12 | mse = 0.0
13 | for blob in data_loader:
14 | im_data = blob['data']
15 | gt_data = blob['gt_density']
16 | density_map = net(im_data, gt_data)
17 | density_map = density_map.data.cpu().numpy()
18 | gt_count = np.sum(gt_data)
19 | et_count = np.sum(density_map)
20 | mae += abs(gt_count-et_count)
21 | mse += ((gt_count-et_count)*(gt_count-et_count))
22 | mae = mae/data_loader.get_num_samples()
23 | mse = np.sqrt(mse/data_loader.get_num_samples())
24 | return mae,mse
--------------------------------------------------------------------------------
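The accumulation loop above boils down to two standard metrics. The following standalone NumPy sketch (the helper name is ours) mirrors how `evaluate_model` derives per-image counts by summing density maps; note that the variable named `mse` in the code above actually holds the root of the mean squared error:

```python
import numpy as np

def count_errors(gt_maps, pred_maps):
    """MAE and RMSE between per-image crowd counts.

    Each count is the sum over a density map, as in evaluate_model.
    """
    gt_counts = np.array([np.sum(m) for m in gt_maps], dtype=np.float64)
    et_counts = np.array([np.sum(m) for m in pred_maps], dtype=np.float64)
    diff = gt_counts - et_counts
    mae = np.mean(np.abs(diff))          # mean absolute error of the counts
    rmse = np.sqrt(np.mean(diff ** 2))   # root mean squared error of the counts
    return mae, rmse
```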
/src/models/base.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | The Base-Net described in our paper.
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | from src.network import Conv2d
9 | import time
10 |
11 | class base(nn.Module):
12 |
13 | def __init__(self, pool, bn=False):
14 | super(base, self).__init__()
15 |
16 | kernel_size = 5
17 | self.pool = pool
18 | if kernel_size==7:
19 | self.c1 = Conv2d( 1, 16, 9, same_padding=True, bn=bn)
20 | self.c2 = Conv2d(16, 32, 7, same_padding=True, bn=bn)
21 | self.c3_5 = nn.Sequential(Conv2d(32, 16, 7, same_padding=True, bn=bn),
22 | Conv2d(16, 8, 7, same_padding=True, bn=bn),
23 | Conv2d( 8, 1, 1, same_padding=True, bn=bn))
24 | if kernel_size==5:
25 | self.c1 = Conv2d( 1, 20, 7, same_padding=True, bn=bn)
26 | self.c2 = Conv2d(20, 40, 5, same_padding=True, bn=bn)
27 | self.c3_5 = nn.Sequential(Conv2d(40, 20, 5, same_padding=True, bn=bn),
28 | Conv2d(20, 10, 5, same_padding=True, bn=bn),
29 | Conv2d( 10, 1, 1, same_padding=True, bn=bn))
30 | if kernel_size==3:
31 | self.c1 = Conv2d( 1, 24, 5, same_padding=True, bn=bn)
32 | self.c2 = Conv2d(24, 48, 3, same_padding=True, bn=bn)
33 | self.c3_5 = nn.Sequential(Conv2d(48, 24, 3, same_padding=True, bn=bn),
34 | Conv2d(24, 12, 3, same_padding=True, bn=bn),
35 | Conv2d( 12, 1, 1, same_padding=True, bn=bn))
36 |
37 | self.pool2 = nn.MaxPool2d(2, stride=2)
38 | self.pool2s1 = nn.MaxPool2d(2, stride=1)
39 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1)
40 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1)
41 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3)
42 |
43 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1))
44 |
45 | def multi_pool(self, x):
46 | x1 = self.pool2(x)
47 | x2 = self.pool4(x)
48 | x3 = self.pool8(x)
49 | y = (x1+x2+x3)/3.0
50 | return y
51 |
52 | def stack_pool(self, x):
53 | x1 = self.pool2(x)
54 | x2 = self.pool2s1(self.padding(x1))
55 | x3 = self.pool3s1(x2)
56 | y = (x1+x2+x3)/3.0
57 | return y
58 |
59 | def forward(self, im_data):
60 | x = self.c1(im_data)
61 |
62 | if self.pool=='mpool':
63 | x = self.multi_pool(x)
64 | if self.pool=='stackpool':
65 | x = self.stack_pool(x)
66 | if self.pool=='vpool':
67 | x = self.pool2(x)
68 |
69 | x = self.c2(x)
70 |
71 | if self.pool=='mpool':
72 | x = self.multi_pool(x)
73 | if self.pool=='stackpool':
74 | x = self.stack_pool(x)
75 | if self.pool=='vpool':
76 | x = self.pool2(x)
77 |
78 | x = self.c3_5(x)
79 |
80 | return x
--------------------------------------------------------------------------------
/src/models/deep.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | The Deep-Net described in our paper.
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 | import numpy as np
10 |
11 | cfg = {
12 | # 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 512, 512],
13 | # 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 512, 512],
14 | 'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 128, 64, 32, 16],
15 | # 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512],
16 | # 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512, 512, 512],
17 | }
18 |
19 | class multi_pool(nn.Module):
20 | def __init__(self):
21 | super(multi_pool, self).__init__()
22 | self.pool2 = nn.MaxPool2d(2, stride=2)
23 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1)
24 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3)
25 | def forward(self, x):
26 | x1 = self.pool2(x)
27 | x2 = self.pool4(x)
28 | x3 = self.pool8(x)
29 | y = (x1+x2+x3)/3.0
30 | return y
31 |
32 | class stack_pool(nn.Module):
33 | def __init__(self):
34 | super(stack_pool, self).__init__()
35 | self.pool2 = nn.MaxPool2d(2, stride=2)
36 | self.pool2s1 = nn.MaxPool2d(2, stride=1)
37 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1)
38 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1))
39 | def forward(self, x):
40 | x1 = self.pool2(x)
41 | x2 = self.pool2s1(self.padding(x1))
42 | x3 = self.pool3s1(x2)
43 | y = (x1+x2+x3)/3.0
44 | return y
45 |
46 | class feature_net(nn.Module):
47 | def __init__(self,pool):
48 | super(feature_net, self).__init__()
49 | self.pool = pool
50 | self.features = self.make_layers(cfg = cfg['C'], batch_norm = False)
51 | def forward(self, x):
52 | feature = self.features(x)
53 | return feature
54 | def make_layers(self, cfg, batch_norm = False):
55 | layers = []
56 | in_channels = 1
57 | idx_M = 0
58 | conv_size = 5
59 | for v in cfg:
60 | if v == 'M':
61 | idx_M += 1
62 | if idx_M >= 2:
63 | conv_size = 3
64 | if self.pool == 'mpool':
65 | layers += [multi_pool()]
66 | if self.pool == 'stackpool':
67 | layers += [stack_pool()]
68 | if self.pool == 'vpool':
69 | layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)]
70 | else:
71 | conv2d = nn.Conv2d(in_channels, v, kernel_size = conv_size, padding = (conv_size-1)/2 )
72 | if batch_norm:
73 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)]
74 | else:
75 | layers += [conv2d, nn.ReLU(inplace = True)]
76 | in_channels = v
77 | return nn.Sequential(*layers)
78 |
79 | class deep(nn.Module):
80 | def __init__(self,pool):
81 | super(deep, self).__init__()
82 | self.conv2d = nn.Conv2d(16, 1, kernel_size = 1)
83 | self.feature_net = feature_net(pool)
84 | #self._initialize_weights()
85 | def forward(self, x):
86 | x = self.feature_net.forward(x)
87 | heat_map = self.conv2d(x)
88 | return heat_map
89 | def _initialize_weights(self):
90 | for m in self.modules():
91 | if isinstance(m, nn.Conv2d):
92 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
93 | if m.bias is not None:
94 | nn.init.constant_(m.bias, 0)
95 |
--------------------------------------------------------------------------------
/src/models/wide.py:
--------------------------------------------------------------------------------
1 |
2 | '''
3 | The Wide-Net described in our paper.
4 | '''
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 | import numpy as np
10 |
11 | cfg = {
12 | 'wide': [128, 'M', 256, 'M',128, 64],
13 | }
14 |
15 | class multi_pool(nn.Module):
16 | def __init__(self):
17 | super(multi_pool, self).__init__()
18 | self.pool2 = nn.MaxPool2d(2, stride=2)
19 | self.pool4 = nn.MaxPool2d(4, stride=2, padding=1)
20 | self.pool8 = nn.MaxPool2d(8, stride=2, padding=3)
21 | def forward(self, x):
22 | x1 = self.pool2(x)
23 | x2 = self.pool4(x)
24 | x3 = self.pool8(x)
25 | y = (x1+x2+x3)/3.0
26 | return y
27 |
28 | class stack_pool(nn.Module):
29 | def __init__(self):
30 | super(stack_pool, self).__init__()
31 | self.pool2 = nn.MaxPool2d(2, stride=2)
32 | self.pool2s1 = nn.MaxPool2d(2, stride=1)
33 | self.pool3s1 = nn.MaxPool2d(3, stride=1, padding=1)
34 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1))
35 | def forward(self, x):
36 | x1 = self.pool2(x)
37 | x2 = self.pool2s1(self.padding(x1))
38 | x3 = self.pool3s1(x2)
39 | y = (x1+x2+x3)/3.0
40 | return y
41 |
42 | class feature_net(nn.Module):
43 | def __init__(self,pool):
44 | super(feature_net, self).__init__()
45 | self.pool = pool
46 | self.features = self.make_layers(cfg = cfg['wide'], batch_norm = False)
47 | def forward(self, x):
48 | feature = self.features(x)
49 | return feature
50 | def make_layers(self, cfg, batch_norm = False):
51 | layers = []
52 | in_channels = 1
53 | idx_M = 0
54 | conv_size = 7
55 | for v in cfg:
56 | if v == 'M':
57 | idx_M += 1
58 | if idx_M >= 1:
59 | conv_size = 5
60 | if idx_M >= 2:
61 | conv_size = 3
62 | if self.pool == 'mpool':
63 | layers += [multi_pool()]
64 | if self.pool == 'stackpool':
65 | layers += [stack_pool()]
66 | if self.pool == 'vpool':
67 | layers += [nn.MaxPool2d(kernel_size = 2, stride = 2)]
68 | else:
69 | conv2d = nn.Conv2d(in_channels, v, kernel_size = conv_size, padding = (conv_size-1)/2 )
70 | if batch_norm:
71 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace = True)]
72 | else:
73 | layers += [conv2d, nn.ReLU(inplace = True)]
74 | in_channels = v
75 | return nn.Sequential(*layers)
76 |
77 | class wide(nn.Module):
78 | def __init__(self,pool):
79 | super(wide, self).__init__()
80 | self.conv2d = nn.Conv2d(64, 1, kernel_size = 1)
81 | self.feature_net = feature_net(pool)
82 | #self._initialize_weights()
83 | def forward(self, x):
84 | x = self.feature_net.forward(x)
85 | heat_map = self.conv2d(x)
86 | return heat_map
87 | def _initialize_weights(self):
88 | for m in self.modules():
89 | if isinstance(m, nn.Conv2d):
90 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
91 | if m.bias is not None:
92 | nn.init.constant_(m.bias, 0)
93 |
--------------------------------------------------------------------------------
/src/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import numpy as np
5 |
6 | class Conv2d(nn.Module):
7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, same_padding=False, bn=False):
8 | super(Conv2d, self).__init__()
9 |
10 |         if isinstance(kernel_size, tuple):
11 | padding = (int((kernel_size[0] - 1) / 2), int((kernel_size[1] - 1) / 2)) if same_padding else 0
12 | else:
13 | padding = int((kernel_size - 1) / 2) if same_padding else 0
14 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)
15 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None
16 | self.relu = nn.ReLU(inplace=True) if relu else None
17 |
18 | def forward(self, x):
19 | x = self.conv(x)
20 | if self.bn is not None:
21 | x = self.bn(x)
22 | if self.relu is not None:
23 | x = self.relu(x)
24 | return x
25 |
26 |
27 | class FC(nn.Module):
28 | def __init__(self, in_features, out_features, relu=True):
29 | super(FC, self).__init__()
30 | self.fc = nn.Linear(in_features, out_features)
31 | self.relu = nn.ReLU(inplace=True) if relu else None
32 |
33 | def forward(self, x):
34 | x = self.fc(x)
35 | if self.relu is not None:
36 | x = self.relu(x)
37 | return x
38 |
39 | class MaxPoolSame(nn.Module):
40 | def __init__(self):
41 | super(MaxPoolSame, self).__init__()
42 | self.padding = nn.ReplicationPad2d((0, 1, 0, 1))
43 | self.pooling = nn.MaxPool2d(2, 1)
44 |
45 | def forward(self, x):
46 | out = self.pooling(self.padding(x))
47 | assert(x.shape == out.shape)
48 | return out
49 |
50 |
51 | def save_net(fname, net):
52 |     import h5py
53 |     with h5py.File(fname, mode='w') as h5f:
54 |         for k, v in net.state_dict().items():
55 |             h5f.create_dataset(k, data=v.cpu().numpy())
56 |
57 |
58 | def load_net(fname, net):
59 |     import h5py
60 |     with h5py.File(fname, mode='r') as h5f:
61 |         for k, v in net.state_dict().items():
62 |             param = torch.from_numpy(np.asarray(h5f[k]))
63 |             v.copy_(param)
64 |
65 |
66 | def np_to_variable(x, is_cuda=True, is_training=False, dtype=torch.FloatTensor):
67 | if is_training:
68 | v = Variable(torch.from_numpy(x).type(dtype))
69 | else:
70 | v = Variable(torch.from_numpy(x).type(dtype), requires_grad = False)#, volatile = True)
71 | if is_cuda:
72 | v = v.cuda()
73 | return v
74 |
75 |
76 | def set_trainable(model, requires_grad):
77 | for param in model.parameters():
78 | param.requires_grad = requires_grad
79 |
80 |
81 | def weights_normal_init(model, dev=0.01):
82 | if isinstance(model, list):
83 | for m in model:
84 | weights_normal_init(m, dev)
85 | else:
86 | for m in model.modules():
87 | if isinstance(m, nn.Conv2d):
88 | #print torch.sum(m.weight)
89 | m.weight.data.normal_(0.0, dev)
90 | if m.bias is not None:
91 | m.bias.data.fill_(0.0)
92 | elif isinstance(m, nn.Linear):
93 | m.weight.data.normal_(0.0, dev)
94 |
--------------------------------------------------------------------------------
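A quick sanity check of the `same_padding` arithmetic in `Conv2d` above (plain Python, our own illustration): with the standard convolution output-size formula, a padding of `(k - 1) // 2` preserves spatial size for the odd kernel sizes and stride 1 used throughout the models.

```python
def conv_out_size(n, k, p, s=1):
    """Output size of a convolution along one axis: floor((n + 2p - k) / s) + 1."""
    return (n + 2 * p - k) // s + 1

# Odd kernel sizes used by Base-Net, Wide-Net, and Deep-Net
for k in (1, 3, 5, 7, 9):
    p = (k - 1) // 2  # what Conv2d computes when same_padding=True
    assert conv_out_size(28, k, p) == 28
```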
/src/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class Timer(object):
4 | def __init__(self):
5 | self.tot_time = 0.
6 | self.calls = 0
7 | self.start_time = 0.
8 | self.diff = 0.
9 | self.average_time = 0.
10 |
11 | def tic(self):
12 |         # using time.time instead of time.clock because time.clock
13 |         # does not normalize for multithreading
14 | self.start_time = time.time()
15 |
16 | def toc(self, average=True):
17 | self.diff = time.time() - self.start_time
18 | self.tot_time += self.diff
19 | self.calls += 1
20 | self.average_time = self.tot_time / self.calls
21 | if average:
22 | return self.average_time
23 | else:
24 | return self.diff
25 |
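Usage is tic/toc around the timed region: `tic()` stamps the start, `toc()` accumulates the elapsed interval and returns either the running average or the latest interval. A condensed copy of the class above, exercised on its bookkeeping:

```python
import time

class Timer(object):
    # condensed copy of src/timer.py, for illustration
    def __init__(self):
        self.tot_time, self.calls = 0., 0
        self.start_time, self.diff, self.average_time = 0., 0., 0.

    def tic(self):
        self.start_time = time.time()

    def toc(self, average=True):
        self.diff = time.time() - self.start_time
        self.tot_time += self.diff
        self.calls += 1
        self.average_time = self.tot_time / self.calls
        return self.average_time if average else self.diff

t = Timer()
for _ in range(3):
    t.tic()
    t.toc()
assert t.calls == 3
assert abs(t.average_time - t.tot_time / 3) < 1e-9
```

Note that `toc(average=False)` is what train.py uses for its per-interval `Time:` readout, since it wants the latest span, not the lifetime mean.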
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 |
5 | def save_results(input_img, gt_data, density_map, output_dir, fname='results.png'):
6 |     input_img = input_img[0][0]
7 |     gt_data = 255 * gt_data / np.max(gt_data)
8 |     density_map = 255 * density_map / np.max(density_map)
9 |     gt_data = gt_data[0][0]
10 |     density_map = density_map[0][0]
11 |     if density_map.shape[1] != input_img.shape[1]:
12 |         density_map = cv2.resize(density_map, (input_img.shape[1], input_img.shape[0]))
13 |         gt_data = cv2.resize(gt_data, (input_img.shape[1], input_img.shape[0]))
14 |     result_img = np.hstack((input_img, gt_data, density_map))
15 |     cv2.imwrite(os.path.join(output_dir, fname), result_img)
16 |
17 |
18 | def save_density_map(density_map, output_dir, fname='results.png'):
19 |     density_map = 255 * density_map / np.max(density_map)
20 |     density_map = density_map[0][0]
21 |     cv2.imwrite(os.path.join(output_dir, fname), density_map)
22 |
23 | def display_results(input_img, gt_data, density_map):
24 |     input_img = input_img[0][0]
25 |     gt_data = 255 * gt_data / np.max(gt_data)
26 |     density_map = 255 * density_map / np.max(density_map)
27 |     gt_data = gt_data[0][0]
28 |     density_map = density_map[0][0]
29 |     if density_map.shape[1] != input_img.shape[1]:
30 |         input_img = cv2.resize(input_img, (density_map.shape[1], density_map.shape[0]))
31 |     result_img = np.hstack((input_img, gt_data, density_map))
32 |     result_img = result_img.astype(np.uint8, copy=False)
33 | cv2.imshow('Result', result_img)
34 | cv2.waitKey(0)
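All three helpers above rescale maps to [0, 255] via `255 * map / np.max(map)`, which produces NaNs when a map is all zeros (e.g. an empty scene or an untrained network's output). A guarded variant of that normalization, offered as a hedged sketch rather than a change to the repo's helpers:

```python
import numpy as np

def to_uint8_map(density_map):
    """Rescale a density map to [0, 255] for saving/display; an all-zero
    map stays all zeros instead of dividing by zero."""
    peak = np.max(density_map)
    if peak > 0:
        density_map = 255.0 * density_map / peak
    return density_map.astype(np.uint8)

assert to_uint8_map(np.zeros((2, 2))).max() == 0   # no NaN on empty maps
assert to_uint8_map(np.array([[0.0, 0.5]])).max() == 255
```

The `astype(np.uint8)` step matches what `display_results` does explicitly and what `cv2.imwrite` effectively expects for an 8-bit grayscale image.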
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import numpy as np
5 | np.set_printoptions(threshold=sys.maxsize)
5 |
6 | from src.crowd_count import CrowdCounter
7 | from src import network
8 | from src.data_loader import ImageDataLoader
9 | from src import utils
10 |
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12 | import warnings
13 | warnings.filterwarnings('ignore')
13 | # dataset, model, and pooling method
14 | datasets = ['shtechA', 'shtechB'] # datasets
15 | models = ['base', 'wide', 'deep'] # backbone network architecture
16 | pools = ['vpool','stackpool','mpool'] # vpool is vanilla pooling; stackpool is stacked pooling; mpool is multi-kernel pooling
17 |
18 | ###
19 | dataset_name = datasets[0] # choose the dataset
20 | model = models[2] # choose the backbone network architecture
21 | pool = pools[0] # choose the pooling method
22 | method=model+'_'+pool
23 |
24 | name = dataset_name[-1]
25 | data_path = './data/original/shanghaitech/part_'+name+'_final/test_data/images/'
26 | gt_path = './data/original/shanghaitech/part_'+name+'_final/test_data/ground_truth_csv/'
27 | model_path = './saved_models/'+method+'_shtech'+name+'_0.h5'
28 | print('Testing %s' % model_path)
29 |
30 | torch.backends.cudnn.enabled = True
31 | torch.backends.cudnn.benchmark = True
32 | vis = False
33 | save_output = True
34 |
35 | output_dir = './output/'
36 | model_name = os.path.basename(model_path).split('.')[0]
37 | file_results = os.path.join(output_dir,'results_' + model_name + '_.txt')
38 | if not os.path.exists(output_dir):
39 | os.mkdir(output_dir)
40 | output_dir = os.path.join(output_dir, 'density_maps_' + model_name)
41 | if not os.path.exists(output_dir):
42 | os.mkdir(output_dir)
43 |
44 | net = CrowdCounter(model,pool)
45 | trained_model = os.path.join(model_path)
46 | network.load_net(trained_model, net)
47 | net.cuda()
48 | net.eval()
49 |
50 | if model in ['base', 'wide']:
51 |     scaling = 4
52 | elif model == 'deep':
53 |     scaling = 8
54 |
55 | #load test data
56 | data_loader = ImageDataLoader(data_path, gt_path, shuffle=False, gt_downsample=True, pre_load=False, batch_size=1, scaling=scaling)
57 |
58 | mae = 0.0
59 | mse = 0.0
60 | num = 0
61 | for blob in data_loader:
62 | num+=1
63 | im_data = blob['data']
64 | gt_data = blob['gt_density']
65 |     with torch.no_grad():  # no autograd bookkeeping needed at test time
66 |         density_map = net(im_data)
66 | density_map = density_map.data.cpu().numpy()
67 | gt_count = np.sum(gt_data)
68 | et_count = np.sum(density_map)
69 | mae += abs(gt_count-et_count)
70 | mse += ((gt_count-et_count)*(gt_count-et_count))
71 | if vis:
72 | utils.display_results(im_data, gt_data, density_map)
73 | if save_output:
74 | utils.save_density_map(density_map, output_dir, 'output_' + blob['fname'].split('.')[0] + '.png')
75 |     if num % 100 == 0:
76 |         print('%d/%d' % (num, data_loader.get_num_samples()))
77 |
78 | mae = mae/data_loader.get_num_samples()
79 | mse = np.sqrt(mse/data_loader.get_num_samples())
80 | print('MAE: %0.2f, MSE: %0.2f' % (mae, mse))
81 |
82 | with open(file_results, 'w') as f:
83 |     f.write('MAE: %0.2f, MSE: %0.2f' % (mae, mse))
85 |
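The loop above accumulates per-image absolute and squared count errors, then divides by the sample count and takes a square root of the squared term, so the reported "MSE" is really an RMSE over counts. The metric arithmetic in isolation, on hypothetical counts (the function name `count_errors` is ours):

```python
import numpy as np

def count_errors(gt_counts, et_counts):
    """Per-image count errors, reduced exactly as in the test loop."""
    gt = np.asarray(gt_counts, dtype=float)
    et = np.asarray(et_counts, dtype=float)
    mae = np.mean(np.abs(gt - et))           # mae / num_samples
    rmse = np.sqrt(np.mean((gt - et) ** 2))  # sqrt(mse / num_samples)
    return mae, rmse

mae, rmse = count_errors([10, 20], [13, 16])  # errors of 3 and 4 people
assert mae == 3.5
assert abs(rmse - 12.5 ** 0.5) < 1e-12
```

Because both metrics compare whole-image sums of the density maps, a prediction can be spatially wrong yet still score well; that is the standard convention for ShanghaiTech-style crowd counting evaluation.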
--------------------------------------------------------------------------------
/thumbnails/stackpool.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/siyuhuang/crowdcount-stackpool/bbba3d9e91a5a89642b4bd3638ae8e68801ea7bf/thumbnails/stackpool.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import sys
5 |
6 | from src.crowd_count import CrowdCounter
7 | from src import network
8 | from src.data_loader import ImageDataLoader
9 | from src.timer import Timer
10 | from src import utils
11 | from src.evaluate_model import evaluate_model
12 | import time
13 |
14 | import warnings
15 | warnings.filterwarnings('ignore')
15 | ### assign dataset, model, and pooling method
16 | datasets = ['shtechA', 'shtechB'] # datasets
17 | models = ['base', 'wide', 'deep'] # backbone network architecture
18 | pools = ['vpool','stackpool','mpool'] # vpool is vanilla pooling; stackpool is stacked pooling; mpool is multi-kernel pooling;
19 |
20 | dataset_name = datasets[0] # choose the dataset
21 | model = models[2] # choose the backbone network architecture
22 | pool = pools[0] # choose the pooling method
23 | method=model+'_'+pool
24 | print('Training %s on %s' % (method, dataset_name))
25 |
26 | ### assign GPU
27 | if pool == 'vpool':
28 |     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
29 | elif pool == 'stackpool':
30 |     os.environ["CUDA_VISIBLE_DEVICES"] = "1"
31 | elif pool == 'mpool':
32 |     os.environ["CUDA_VISIBLE_DEVICES"] = "2"
33 |
34 | ### PyTorch configuration
35 | torch.backends.cudnn.enabled = True
36 | torch.backends.cudnn.benchmark = True
37 |
38 | ### model saving folder
39 | output_dir = './saved_models/'
40 | if not os.path.exists(output_dir):
41 | os.mkdir(output_dir)
42 |
43 | ### data folder
44 | name = dataset_name[-1]
45 | train_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/train'
46 | train_gt_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/train_den'
47 | val_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/val'
48 | val_gt_path = './data/formatted_trainval/shanghaitech_part_'+name+'_patches_9/val_den'
49 |
50 | ### training configuration
51 | start_step = 0
52 | end_step = 500
53 | batch_size=1
54 | disp_interval = 1500
55 | if model == 'base':
56 |     lr = 5e-5 if dataset_name == 'shtechA' else 2e-5
57 |     scaling = 4  # output density map is 1/4 size of input image
58 | elif model == 'wide':
59 |     lr = 1e-5
60 |     scaling = 4  # output density map is 1/4 size of input image
61 | elif model == 'deep':
62 |     lr = 1e-5 if dataset_name == 'shtechA' else 5e-6
63 |     scaling = 8  # output density map is 1/8 size of input image
64 | print('learning rate %f' % lr)
74 |
75 | ### random seed
76 | rand_seed = 64678
77 | if rand_seed is not None:
78 | np.random.seed(rand_seed)
79 | torch.manual_seed(rand_seed)
80 | torch.cuda.manual_seed_all(rand_seed)
81 |
82 | ### initialize network
83 | net = CrowdCounter(model=model,pool=pool)
84 | network.weights_normal_init(net, dev=0.01)
85 | net.cuda()
86 | net.train()
87 |
88 | ### optimizer
89 | params = list(net.parameters())
90 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)
91 |
92 | ### load data
93 | pre_load=True
94 | data_loader = ImageDataLoader(train_path, train_gt_path, shuffle=True, gt_downsample=True, pre_load=pre_load,
95 | batch_size=batch_size,scaling=scaling)
96 | data_loader_val = ImageDataLoader(val_path, val_gt_path, shuffle=False, gt_downsample=True, pre_load=pre_load,
97 | batch_size=1,scaling=scaling)
98 |
99 | ### training
100 | train_loss = 0
101 | t = Timer()
102 | t.tic()
103 | best_mae = sys.maxsize  # sys.maxint was removed in Python 3
104 |
105 | for epoch in range(start_step, end_step+1):
106 | step = 0
107 | train_loss = 0
108 | for blob in data_loader:
109 | step = step + 1
110 | im_data = blob['data']
111 | gt_data = blob['gt_density']
112 | density_map = net(im_data, gt_data)
113 | loss = net.loss
114 | train_loss += loss.item()
115 | optimizer.zero_grad()
116 | loss.backward()
117 | optimizer.step()
118 |
119 | if step % disp_interval == 0:
120 | duration = t.toc(average=False)
121 | density_map = density_map.data.cpu().numpy()
122 | utils.save_results(im_data,gt_data,density_map, output_dir)
123 |             print('epoch: %4d, step %4d, Time: %.4fs, loss: %4.10f' % (epoch, step, duration, train_loss/disp_interval))
124 | train_loss = 0
125 | t.tic()
126 |
127 | if (epoch % 2 == 0):
128 | # save model checkpoint
129 | save_name = os.path.join(output_dir, '{}_{}_{}.h5'.format(method,dataset_name,epoch))
130 | network.save_net(save_name, net)
131 | # calculate error on the validation dataset
132 | mae,mse = evaluate_model(save_name, data_loader_val, model, pool)
133 | if mae < best_mae:
134 | best_mae = mae
135 | best_mse = mse
136 | best_model = '{}_{}_{}.h5'.format(method,dataset_name,epoch)
137 |         print('EPOCH: %d, MAE: %0.2f, MSE: %0.2f' % (epoch, mae, mse))
138 |         print('BEST MAE: %0.2f, BEST MSE: %0.2f, BEST MODEL: %s' % (best_mae, best_mse, best_model))
139 |
140 |
141 | t.tic()
142 |
143 |
144 |
145 |
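Every second epoch the script checkpoints the network, evaluates it on the validation split, and keeps the checkpoint with the lowest MAE. The selection logic on its own, run on hypothetical per-epoch scores (`select_best` and the numbers are illustrative, not from the repo):

```python
import sys

def select_best(epoch_scores):
    """epoch_scores: iterable of (epoch, mae, mse) tuples; returns the
    record with the lowest MAE, mirroring the best_mae bookkeeping above."""
    best_mae, best = sys.maxsize, None
    for epoch, mae, mse in epoch_scores:
        if mae < best_mae:
            best_mae, best = mae, (epoch, mae, mse)
    return best

scores = [(0, 120.3, 180.1), (2, 98.7, 150.2), (4, 101.4, 149.9)]
assert select_best(scores) == (2, 98.7, 150.2)
```

Note the tie to MAE only: epoch 4 above has the lower MSE but is not selected, just as in the training loop, where `best_mse` is merely recorded alongside the MAE-optimal checkpoint.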
--------------------------------------------------------------------------------