├── .idea
│   ├── SANet_pytoch.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── .ipynb_checkpoints
│   ├── main-checkpoint.ipynb
│   ├── predict-checkpoint.ipynb
│   └── second_phase-checkpoint.ipynb
├── DataConstructor.py
├── FPN_SAN_Net.py
├── README.md
├── __pycache__
│   ├── DataConstructor.cpython-36.pyc
│   ├── FPN_SAN_Net.cpython-36.pyc
│   ├── metrics.cpython-36.pyc
│   ├── net.cpython-36.pyc
│   ├── ssim_loss.cpython-36.pyc
│   └── utils.cpython-36.pyc
├── checkpoints
│   └── model_1_rate_b_0315_19:35.pkl
├── generate_density_map.py
├── main.py
├── metrics.py
├── net.py
├── predict.py
├── ssim_loss.py
└── utils.py
/DataConstructor.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import torchvision.transforms as transforms
4 | import torchvision.transforms.functional as F
5 | import torch.utils.data as data
6 | import random
7 | import time
8 | from utils import HSI_Calculator
9 | import torch
10 |
11 |
12 | class DatasetConstructor(data.Dataset):
13 | def __init__(self,
14 | data_dir_path,
15 | gt_dir_path,
16 | train_num,
17 | validate_num,
18 | if_train=True
19 | ):
20 | self.train_num = train_num
21 | self.validate_num = validate_num
22 | self.imgs = []
23 | self.data_root = data_dir_path
24 | self.gt_root = gt_dir_path
25 | self.train = if_train
 26 |         self.train_permutation = np.random.permutation(self.train_num)
 27 |         self.eval_permutation = random.sample(range(0, self.train_num), self.validate_num)
28 | self.calcu = HSI_Calculator()
29 | for i in range(self.train_num):
30 | img_name = '/IMG_' + str(i + 1) + ".jpg"
31 | gt_map_name = '/GT_IMG_' + str(i + 1) + ".npy"
32 | img = Image.open(self.data_root + img_name).convert("RGB")
33 | gt_map = Image.fromarray(np.squeeze(np.load(self.gt_root + gt_map_name)))
34 | self.imgs.append([img, gt_map])
35 |
36 | def __getitem__(self, index):
37 |
38 | start = time.time()
39 | if self.train:
 40 |             img, gt_map = self.imgs[self.train_permutation[index]]
41 | img = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1)(img)
42 | flip_random = random.random()
43 | if flip_random > 0.5:
44 | img = F.hflip(img)
45 | gt_map = F.hflip(gt_map)
46 | img = transforms.ToTensor()(img)
47 | gt_map = transforms.ToTensor()(gt_map)
48 | img_shape = img.shape # C, H, W
 49 |             random_h = random.randint(0, int(3 * img_shape[1] / 4) - 1)  # randint needs integer bounds
 50 |             random_w = random.randint(0, int(3 * img_shape[2] / 4) - 1)
51 | patch_height = int(img_shape[1] / 4)
52 | patch_width = int(img_shape[2] / 4)
53 | img = img[:, random_h:random_h + patch_height, random_w:random_w + patch_width]
54 | gt_map = gt_map[:, random_h:random_h + patch_height, random_w:random_w + patch_width]
55 | end = time.time()
56 | img = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(img)
 57 |             return self.train_permutation[index] + 1, img, gt_map, (end - start)
58 |
59 | else:
 60 |             img, gt_map = self.imgs[self.eval_permutation[index]]
61 | img = transforms.ToTensor()(img)
62 | gt_map = transforms.ToTensor()(gt_map)
63 | img_shape = img.shape # C, H, W
64 | patch_height = int(img_shape[1] / 4)
65 | patch_width = int(img_shape[2] / 4)
66 | img = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(img)
67 | imgs = []
68 | for i in range(7):
69 | for j in range(7):
70 | start_h = int(patch_height / 2) * i
71 | start_w = int(patch_width / 2) * j
72 | # print(img.shape, start_h, start_w, patch_height, patch_width)
73 | imgs.append(img[:, start_h:start_h + patch_height, start_w:start_w + patch_width])
74 | imgs = torch.stack(imgs)
75 | end = time.time()
 76 |             return self.eval_permutation[index] + 1, imgs, gt_map, (end - start)
77 |
78 | def __len__(self):
79 | if self.train:
80 | return self.train_num
81 | else:
82 | return self.validate_num
83 |
84 | def shuffle(self):
85 | if self.train:
 86 |             self.train_permutation = np.random.permutation(self.train_num)
87 | else:
 88 |             self.eval_permutation = random.sample(range(0, self.train_num), self.validate_num)
89 | return self
90 |
91 | def eval_model(self):
92 | self.train = False
93 | return self
94 |
95 | def train_model(self):
96 | self.train = True
97 | return self
98 |
--------------------------------------------------------------------------------
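A minimal usage sketch for DatasetConstructor, assuming the ShanghaiTech Part B layout used throughout this repo (all Part B images share one resolution, so the random quarter-size crops can be batched; the paths below are hypothetical placeholders). The training branch returns one jittered, randomly flipped crop per image; the eval branch instead returns a 7 x 7 grid of half-overlapping patches that main.py and predict.py stitch back together:

import torch
from DataConstructor import DatasetConstructor

dataset = DatasetConstructor("/path/to/train_data/images",  # placeholder path
                             "/path/to/train_data/gt_map",  # placeholder path
                             train_num=400, validate_num=50)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
index, img, gt_map, load_time = next(iter(loader))
# img:    [4, 3, H/4, W/4] random crops, normalized with ImageNet statistics
# gt_map: [4, 1, H/4, W/4] the matching density-map crops
--------------------------------------------------------------------------------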
/FPN_SAN_Net.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/FPN_SAN_Net.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SANet_CROWD_COUNTING
2 | Implemented in PyTorch.
3 | The results are virtually the same as those of my TensorFlow implementation, [BIGKnight/SANet_implementation](https://github.com/BIGKnight/SANet_implementation).
4 | However, this version is closer to the paper's network structure (except that I put the BN or IN after the ReLU instead of before it; in practice, the BN layer did not seem to improve the network's performance).
5 | I still cannot reach the results reported in the paper (in fact, mine are far worse). To be honest, I somewhat doubt the results shown in the paper.
--------------------------------------------------------------------------------
/__pycache__/DataConstructor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/DataConstructor.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/FPN_SAN_Net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/FPN_SAN_Net.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/net.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/ssim_loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/ssim_loss.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/checkpoints/model_1_rate_b_0315_19:35.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BIGKnight/SANet_pytorch/594bb050f5ff72e5fb0b54d17cefa809868e59a7/checkpoints/model_1_rate_b_0315_19:35.pkl
--------------------------------------------------------------------------------
/generate_density_map.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
 3 | import scipy.spatial  # needed for scipy.spatial.KDTree below
4 | import scipy.io as scio
5 | from PIL import Image
6 |
7 |
8 | def get_density_map_gaussian(N, M, points, adaptive_kernel=False, fixed_value=15):
9 | density_map = np.zeros([N, M], dtype=np.float32)
10 | h, w = density_map.shape[:2]
11 | num_gt = np.squeeze(points).shape[0]
12 | if num_gt == 0:
13 | return density_map
14 |
15 | if adaptive_kernel:
16 | # referred from https://github.com/vlad3996/computing-density-maps/blob/master/make_ShanghaiTech.ipynb
17 | leafsize = 2048
18 | tree = scipy.spatial.KDTree(points.copy(), leafsize=leafsize)
19 | distances = tree.query(points, k=4)[0]
20 |
21 | for idx, p in enumerate(points):
22 | p = np.round(p).astype(int)
23 |         p[0], p[1] = min(h-1, p[1]), min(w-1, p[0])  # swap annotation (x, y) into (row, col) order and clamp into the image
24 | if num_gt > 1:
25 | if adaptive_kernel:
26 | sigma = int(np.sum(distances[idx][1:4]) // 3 * 0.3)
27 | else:
28 | sigma = fixed_value
29 | else:
30 | sigma = fixed_value # np.average([h, w]) / 2. / 2.
31 | sigma = max(1, sigma)
32 |
33 | gaussian_radius = sigma * 3
34 | gaussian_map = np.multiply(
35 | cv2.getGaussianKernel(gaussian_radius*2+1, sigma),
36 | cv2.getGaussianKernel(gaussian_radius*2+1, sigma).T
37 | )
38 | x_left, x_right, y_up, y_down = 0, gaussian_map.shape[1], 0, gaussian_map.shape[0]
39 | # cut the gaussian kernel
40 | if p[1] < 0 or p[0] < 0:
41 | continue
42 | if p[1] < gaussian_radius:
43 | x_left = gaussian_radius - p[1]
44 | if p[0] < gaussian_radius:
45 | y_up = gaussian_radius - p[0]
46 | if p[1] + gaussian_radius >= w:
47 | x_right = gaussian_map.shape[1] - (gaussian_radius + p[1] - w) - 1
48 | if p[0] + gaussian_radius >= h:
49 | y_down = gaussian_map.shape[0] - (gaussian_radius + p[0] - h) - 1
50 | density_map[
51 | max(0, p[0]-gaussian_radius):min(density_map.shape[0], p[0]+gaussian_radius+1),
52 | max(0, p[1]-gaussian_radius):min(density_map.shape[1], p[1]+gaussian_radius+1)
53 | ] += gaussian_map[y_up:y_down, x_left:x_right]
54 | return density_map
55 |
56 |
57 | # 22, 37
58 | if __name__ == "__main__":
59 | image_dir_path = "/home/zzn/part_B_final/test_data/images"
60 | ground_truth_dir_path = "/home/zzn/part_B_final/test_data/ground_truth"
61 | output_gt_dir = "/home/zzn/part_B_final/test_data/gt_map"
62 | for i in range(316):
63 | img_path = image_dir_path + "/IMG_" + str(i + 1) + ".jpg"
64 | gt_path = ground_truth_dir_path + "/GT_IMG_" + str(i + 1) + ".mat"
65 | img = Image.open(img_path)
66 | height = img.size[1]
67 |         width = img.size[0]
68 |         points = scio.loadmat(gt_path)['image_info'][0][0][0][0][0]
69 |         gt = get_density_map_gaussian(height, width, points, False, 5)
70 |         gt = np.reshape(gt, [height, width])  # density map is already (h, w)
71 | np.save(output_gt_dir + "/GT_IMG_" + str(i + 1), gt)
72 | print("complete!")
73 |
--------------------------------------------------------------------------------
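Since cv2.getGaussianKernel returns a normalized kernel, each annotated point adds a Gaussian that integrates to 1, so the density map should sum to roughly the number of heads, minus whatever mass is clipped at the image borders. A quick sanity-check sketch with synthetic points rather than the ShanghaiTech annotations:

import numpy as np
from generate_density_map import get_density_map_gaussian

# Three synthetic (x, y) head positions inside a 256 x 256 image.
points = np.array([[50.0, 40.0], [120.0, 80.0], [200.0, 150.0]])
dmap = get_density_map_gaussian(256, 256, points, adaptive_kernel=False, fixed_value=5)
print(dmap.shape, dmap.sum())  # (256, 256), approximately 3.0
--------------------------------------------------------------------------------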
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from net import *
3 | from PIL import Image
4 | import time
5 | import numpy as np
6 | import random
7 | from ssim_loss import *
8 | from metrics import *
9 | from DataConstructor import *
11 | from utils import show
12 | from FPN_SAN_Net import *
13 | import sys
14 | import torchvision.transforms as transforms
15 | MAE = 10240000
16 | MSE = 10240000
17 | RATE = 10000000
18 | SHANGHAITECH = "B"
19 | # %matplotlib inline
20 | # data_load
21 | img_dir = "/home/zzn/part_" + SHANGHAITECH + "_final/train_data/images"
22 | gt_dir = "/home/zzn/part_" + SHANGHAITECH + "_final/train_data/gt_map"
23 |
24 | img_dir_t = "/home/zzn/part_" + SHANGHAITECH + "_final/test_data/images"
25 | gt_dir_t = "/home/zzn/part_" + SHANGHAITECH + "_final/test_data/gt_map"
26 |
27 | dataset = DatasetConstructor(img_dir, gt_dir, 400, 50)
28 | test_data_set = DatasetConstructor(img_dir_t, gt_dir_t, 316, 50, False)
29 |
30 | train_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4)
31 | eval_loader = torch.utils.data.DataLoader(dataset=test_data_set, batch_size=1)
32 | # obtain the gpu device
33 | assert torch.cuda.is_available()
34 | cuda_device = torch.device("cuda")
35 |
36 | # model construct
37 | # net = FPN_SA_Net().to(cuda_device)
38 | net = SANet().to(cuda_device)
39 | # net = torch.load("/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_in_time_0316_14:18.pkl").to(cuda_device)
40 | # set optimizer and estimator
41 | criterion = SANetLoss(1).to(cuda_device)
42 | optimizer = torch.optim.Adam(net.parameters(), 1e-6)
43 | ae_batch = AEBatch().to(cuda_device)
44 | se_batch = SEBatch().to(cuda_device)
45 | step = 0
46 | for epoch_index in range(10000):
47 | # dataset = dataset.train_model().shuffle()
48 | dataset = dataset.shuffle()
49 | for train_img_index, train_img, train_gt, data_ptc in train_loader:
50 | # eval per 100 batch
51 | if step % 100 == 0:
52 | net.eval()
53 | test_data_set = test_data_set.shuffle()
54 | loss_ = []
55 | MAE_ = []
56 | MSE_ = []
57 | difference_rates = []
58 |
59 | rand_number = random.randint(0, 19)
60 | counter = 0
61 |
62 | for eval_img_index, eval_img, eval_gt, eval_data_ptc in eval_loader:
63 |
64 | image_shape = eval_img.shape
65 | patch_height = int(image_shape[3])
66 | patch_width = int(image_shape[4])
67 | # B
68 | eval_x = eval_img.view(49, 3, patch_height, patch_width)
69 | eval_y = eval_gt.view(1, 1, patch_height * 4, patch_width * 4).cuda()
70 | prediction_map = torch.zeros(1, 1, patch_height * 4, patch_width * 4).cuda()
71 | for i in range(7):
72 | for j in range(7):
73 | eval_x_sample = eval_x[i * 7 + j:i * 7 + j + 1].cuda()
74 | eval_prediction = net(eval_x_sample)
75 | start_h = int(patch_height / 4)
76 | start_w = int(patch_width / 4)
77 | valid_h = int(patch_height / 2)
78 | valid_w = int(patch_width / 2)
79 | h_pred = 3 * int(patch_height / 4) + 2 * int(patch_height / 4) * (i - 1)
80 | w_pred = 3 * int(patch_width / 4) + 2 * int(patch_width / 4) * (j - 1)
81 | if i == 0:
82 | valid_h = int((3 * patch_height) / 4)
83 | start_h = 0
84 | h_pred = 0
85 | elif i == 6:
86 | valid_h = int((3 * patch_height) / 4)
87 |
88 | if j == 0:
89 | valid_w = int((3 * patch_width) / 4)
90 | start_w = 0
91 | w_pred = 0
92 | elif j == 6:
93 | valid_w = int((3 * patch_width) / 4)
94 |
95 | prediction_map[:, :, h_pred:h_pred + valid_h, w_pred:w_pred + valid_w] += eval_prediction[:, :,
96 | start_h:start_h + valid_h,
97 | start_w:start_w + valid_w]
98 | # That’s because numpy doesn’t support CUDA,
99 | # so there’s no way to make it use GPU memory without a copy to CPU first.
100 | # Remember that .numpy() doesn’t do any copy,
101 | # but returns an array that uses the same memory as the tensor
102 | eval_loss = criterion(prediction_map, eval_y).data.cpu().numpy()
103 | batch_ae = ae_batch(prediction_map, eval_y).data.cpu().numpy()
104 | batch_se = se_batch(prediction_map, eval_y).data.cpu().numpy()
105 |
106 | validate_pred_map = np.squeeze(prediction_map.permute(0, 2, 3, 1).data.cpu().numpy())
107 | validate_gt_map = np.squeeze(eval_y.permute(0, 2, 3, 1).data.cpu().numpy())
108 | gt_counts = np.sum(validate_gt_map)
109 | pred_counts = np.sum(validate_pred_map)
110 | # random show 1 sample
111 | if rand_number == counter and step % 2000 == 0:
112 | origin_image = Image.open("/home/zzn/part_" + SHANGHAITECH + "_final/test_data/images/IMG_" + str(
113 | eval_img_index.numpy()[0]) + ".jpg")
114 | show(origin_image, validate_gt_map, validate_pred_map, eval_img_index.numpy()[0])
115 | sys.stdout.write(
116 | 'The gt counts of the above sample:{}, and the pred counts:{}\n'.format(gt_counts, pred_counts))
117 |
118 | difference_rates.append(np.abs(gt_counts - pred_counts) / gt_counts)
119 | loss_.append(eval_loss)
120 | MAE_.append(batch_ae)
121 | MSE_.append(batch_se)
122 | counter += 1
123 |
124 | # calculate the validate loss, validate MAE and validate RMSE
125 | loss_ = np.reshape(loss_, [-1])
126 | MAE_ = np.reshape(MAE_, [-1])
127 | MSE_ = np.reshape(MSE_, [-1])
128 |
129 | validate_loss = np.mean(loss_)
130 | validate_MAE = np.mean(MAE_)
131 | validate_RMSE = np.sqrt(np.mean(MSE_))
132 | validate_rate = np.mean(difference_rates)
133 |
134 | sys.stdout.write(
135 | 'In step {}, epoch {}, with loss {}, rate = {}, MAE = {}, MSE = {}\n'.format(step, epoch_index + 1,
136 | validate_loss,
137 | validate_rate,
138 | validate_MAE,
139 | validate_RMSE))
140 | sys.stdout.flush()
141 |
142 | if RATE > validate_rate:
143 | RATE = validate_rate
144 | torch.save(net, "/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_rate_b.pkl")
145 |
146 | # save model
147 | if MAE > validate_MAE:
148 | MAE = validate_MAE
149 | torch.save(net, "/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_mae_b.pkl")
150 |
151 | # save model
152 | if MSE > validate_RMSE:
153 | MSE = validate_RMSE
154 | torch.save(net, "/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_mse_b.pkl")
155 |
156 | torch.save(net, "/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_in_time.pkl")
157 |
158 | # return train model
159 |
160 | net.train()
161 | # dataset = dataset.train_model()
162 | optimizer.zero_grad()
163 | # B
164 | x = train_img.cuda()
165 | y = train_gt.cuda()
166 |
167 | prediction = net(x)
168 | loss = criterion(prediction, y)
169 | loss.backward()
170 | optimizer.step()
171 | step += 1
--------------------------------------------------------------------------------
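The evaluation loop above stitches the 7 x 7 overlapping patches by keeping only the central half of each interior patch (three quarters for the border patches), so every pixel of the full-resolution prediction map is written exactly once. A standalone sketch checking that tiling arithmetic along one axis, with a hypothetical quarter-height:

ph = 192  # hypothetical quarter-height of a 768-pixel ShanghaiTech Part B image
spans = []
for i in range(7):
    valid_h = ph // 2
    h_pred = 3 * (ph // 4) + 2 * (ph // 4) * (i - 1)
    if i == 0:
        valid_h, h_pred = 3 * (ph // 4), 0
    elif i == 6:
        valid_h = 3 * (ph // 4)
    spans.append((h_pred, h_pred + valid_h))
assert spans[0][0] == 0 and spans[-1][1] == 4 * ph          # covers [0, H)
assert all(a[1] == b[0] for a, b in zip(spans, spans[1:]))  # no gaps, no double counting
--------------------------------------------------------------------------------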
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from ssim_loss import *
4 |
5 |
6 | class SANetLoss(nn.Module):
7 | def __init__(self, in_channels, size=5, sigma=1.5, size_average=True):
8 | super(SANetLoss, self).__init__()
9 | self.ssim_loss = SSIM_Loss(in_channels, size, sigma, size_average)
10 |
11 | def forward(self, estimated_density_map, gt_map):
12 | height = estimated_density_map.shape[2]
13 | width = estimated_density_map.shape[3]
14 | loss_c = self.ssim_loss(estimated_density_map, gt_map)
15 | loss_e = torch.mean((estimated_density_map - gt_map) ** 2, dim=(0, 1, 2, 3))
16 | return torch.mul(torch.add(torch.mul(loss_c, 0.001), loss_e), height * width)
17 |
18 |
19 | class ScalingLoss(nn.Module):
20 | def __init__(self):
21 | super(ScalingLoss, self).__init__()
22 |
23 | def forward(self, x_map, gt_map):
24 | gt_counts = torch.sum(gt_map, dim=(1, 2, 3))
25 | x_counts = torch.sum(x_map, dim=(1, 2, 3))
26 | return torch.mean(gt_counts.sub(x_counts).div(torch.add(gt_counts, 1)).pow(2))
27 |
28 |
29 | class AEBatch(nn.Module):
30 | def __init__(self):
31 | super(AEBatch, self).__init__()
32 |
33 | def forward(self, estimated_density_map, gt_map):
34 | return torch.abs(torch.sum(estimated_density_map - gt_map, dim=(1, 2, 3)))
35 |
36 |
37 | class SEBatch(nn.Module):
38 | def __init__(self):
39 | super(SEBatch, self).__init__()
40 |
41 | def forward(self, estimated_density_map, gt_map):
42 | return torch.pow(torch.sum(estimated_density_map - gt_map, dim=(1, 2, 3)), 2)
43 |
--------------------------------------------------------------------------------
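SANetLoss combines the two terms as (0.001 * L_SSIM + L_MSE) * H * W: the per-pixel MSE is rescaled into a per-image sum, and the SSIM term is weighted by alpha = 0.001 as in the SANet paper. A minimal sketch exercising it on random tensors:

import torch
from metrics import SANetLoss, AEBatch, SEBatch

pred = torch.rand(2, 1, 96, 128)  # two predicted density maps
gt = torch.rand(2, 1, 96, 128)    # matching ground-truth maps
criterion = SANetLoss(in_channels=1)
print(criterion(pred, gt))  # scalar: (0.001 * ssim_term + mse) * H * W
print(AEBatch()(pred, gt))  # per-image absolute count error, shape [2]
print(SEBatch()(pred, gt))  # per-image squared count error; sqrt of its mean is the reported "MSE"
--------------------------------------------------------------------------------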
/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class BasicConv(nn.Module):
9 | def __init__(self, in_channels, out_channels, use_bn=False, **kwargs):
10 | super(BasicConv, self).__init__()
11 | self.use_bn = use_bn
12 | self.conv = nn.Conv2d(in_channels, out_channels, bias=not self.use_bn, **kwargs)
 13 |         self.bn = nn.InstanceNorm2d(out_channels, affine=True) if self.use_bn else None  # note: InstanceNorm2d, not BatchNorm, despite the flag name
14 |
15 | def forward(self, x):
16 | x = self.conv(x)
17 | x = F.relu(x, inplace=True)
18 | if self.use_bn:
19 | x = self.bn(x)
20 | return x
21 |
22 |
23 | class BasicDeconv(nn.Module):
24 | def __init__(self, in_channels, out_channels, use_bn=False, **kwargs):
25 | super(BasicDeconv, self).__init__()
26 | self.use_bn = use_bn
27 | self.tconv = nn.ConvTranspose2d(in_channels, out_channels, bias=not self.use_bn, **kwargs)
28 | self.bn = nn.InstanceNorm2d(out_channels, affine=True) if self.use_bn else None
29 |
30 | def forward(self, x):
31 | x = self.tconv(x)
32 | # if self.use_bn:
33 | # x = self.bn(x)
34 | # return F.relu(x, inplace=True)
35 | x = F.relu(x, inplace=True)
36 | if self.use_bn:
37 | x = self.bn(x)
38 | return x
39 |
40 |
41 | class SAModule_Head(nn.Module):
42 | def __init__(self, in_channels, out_channels, use_bn):
43 | super(SAModule_Head, self).__init__()
44 | branch_out = out_channels // 4
45 | self.branch1x1 = BasicConv(in_channels, branch_out, use_bn=use_bn,
46 | kernel_size=1)
47 | self.branch3x3 = BasicConv(in_channels, branch_out, use_bn=use_bn,
48 | kernel_size=3, padding=1)
49 | self.branch5x5 = BasicConv(in_channels, branch_out, use_bn=use_bn,
50 | kernel_size=5, padding=2)
51 | self.branch7x7 = BasicConv(in_channels, branch_out, use_bn=use_bn,
52 | kernel_size=7, padding=3)
53 |
54 | def forward(self, x):
55 | branch1x1 = self.branch1x1(x)
56 | branch3x3 = self.branch3x3(x)
57 | branch5x5 = self.branch5x5(x)
58 | branch7x7 = self.branch7x7(x)
59 | out = torch.cat([branch1x1, branch3x3, branch5x5, branch7x7], 1)
60 | return out
61 |
62 |
63 | class SAModule(nn.Module):
64 | def __init__(self, in_channels, out_channels, use_bn):
65 | super(SAModule, self).__init__()
66 | branch_out = out_channels // 4
67 | self.branch1x1 = BasicConv(in_channels, branch_out, use_bn=use_bn,
68 | kernel_size=1)
69 | self.branch3x3 = nn.Sequential(
70 | BasicConv(in_channels, 2 * branch_out, use_bn=use_bn,
71 | kernel_size=1),
72 | BasicConv(2 * branch_out, branch_out, use_bn=use_bn,
73 | kernel_size=3, padding=1),
74 | )
75 | self.branch5x5 = nn.Sequential(
76 | BasicConv(in_channels, 2 * branch_out, use_bn=use_bn,
77 | kernel_size=1),
78 | BasicConv(2 * branch_out, branch_out, use_bn=use_bn,
79 | kernel_size=5, padding=2),
80 | )
81 | self.branch7x7 = nn.Sequential(
82 | BasicConv(in_channels, 2 * branch_out, use_bn=use_bn,
83 | kernel_size=1),
84 | BasicConv(2 * branch_out, branch_out, use_bn=use_bn,
85 | kernel_size=7, padding=3),
86 | )
87 |
88 | def forward(self, x):
89 | branch1x1 = self.branch1x1(x)
90 | branch3x3 = self.branch3x3(x)
91 | branch5x5 = self.branch5x5(x)
92 | branch7x7 = self.branch7x7(x)
93 | out = torch.cat([branch1x1, branch3x3, branch5x5, branch7x7], 1)
94 | return out
95 |
96 |
97 | class SANet(nn.Module):
98 | def __init__(self, gray_input=False, use_bn=True):
99 | super(SANet, self).__init__()
100 | if gray_input:
101 | in_channels = 1
102 | else:
103 | in_channels = 3
104 |
105 | self.encoder = nn.Sequential(
106 | SAModule_Head(in_channels, 64, use_bn),
107 | nn.MaxPool2d(2, 2),
108 | SAModule(64, 128, use_bn),
109 | nn.MaxPool2d(2, 2),
110 | SAModule(128, 128, use_bn),
111 | nn.MaxPool2d(2, 2),
112 | SAModule(128, 128, use_bn),
113 | )
114 |
115 | self.decoder = nn.Sequential(
116 | BasicConv(128, 64, use_bn=use_bn, kernel_size=9, padding=4),
117 | BasicDeconv(64, 64, use_bn=use_bn, kernel_size=2, stride=2),
118 | BasicConv(64, 32, use_bn=use_bn, kernel_size=7, padding=3),
119 | BasicDeconv(32, 32, use_bn=use_bn, kernel_size=2, stride=2),
120 | BasicConv(32, 16, use_bn=use_bn, kernel_size=5, padding=2),
121 | BasicDeconv(16, 16, use_bn=use_bn, kernel_size=2, stride=2),
122 | BasicConv(16, 16, use_bn=use_bn, kernel_size=3, padding=1),
123 | BasicConv(16, 1, use_bn=False, kernel_size=1),
124 | )
125 |
126 | self._initialize_weights()
127 |
128 | def _initialize_weights(self):
129 | for m in self.modules():
130 | if isinstance(m, nn.InstanceNorm2d):
131 | nn.init.constant_(m.weight, 1)
132 | nn.init.constant_(m.bias, 0)
133 | elif isinstance(m, nn.Conv2d):
134 | nn.init.normal_(m.weight, std=0.01)
135 | if m.bias is not None:
136 | nn.init.constant_(m.bias, 0)
137 | elif isinstance(m, nn.ConvTranspose2d):
138 | nn.init.normal_(m.weight, std=0.01)
139 | if m.bias is not None:
140 | nn.init.constant_(m.bias, 0)
141 |
142 | def forward(self, x):
143 | features = self.encoder(x)
144 | out = self.decoder(features)
145 | return out
146 |
--------------------------------------------------------------------------------
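The encoder halves the resolution three times (2 x 2 max-pools, 8x total) and the decoder doubles it three times (stride-2 transposed convolutions), so the single-channel density output matches the input's spatial size whenever H and W are divisible by 8. A quick shape-check sketch:

import torch
from net import SANet

net = SANet().eval()
with torch.no_grad():
    out = net(torch.randn(1, 3, 96, 128))  # H and W divisible by 8
print(out.shape)  # torch.Size([1, 1, 96, 128]): same resolution as the input
--------------------------------------------------------------------------------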
/predict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import matplotlib.pyplot as plt
4 | import sys
5 | import numpy as np
6 | import random
7 | from utils import *
8 | import torchvision.transforms as transforms
9 | from DataConstructor import DatasetConstructor
10 | import metrics
11 | from PIL import Image
12 | import time
13 | SHANGHAITECH = "B"
14 | # %matplotlib inline
15 | # obtain the gpu device
16 | assert torch.cuda.is_available()
17 | cuda_device = torch.device("cuda") # device object representing GPU
18 | # data_load
19 | img_dir = "/home/zzn/part_" + SHANGHAITECH + "_final/test_data/images"
20 | gt_dir = "/home/zzn/part_" + SHANGHAITECH + "_final/test_data/gt_map"
21 | dataset = DatasetConstructor(img_dir, gt_dir, 316, 316, False)
22 | test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)
23 | mae_metrics = []
24 | mse_metrics = []
25 | net = torch.load("/home/zzn/PycharmProjects/SANet_pytoch/checkpoints/model_1_rate_b_0315_19:35.pkl").to(cuda_device)
26 | net.eval()
27 |
28 |
29 | ae_batch = metrics.AEBatch().to(cuda_device)
30 | se_batch = metrics.SEBatch().to(cuda_device)
31 |
32 | for real_index, test_img, test_gt, test_time_cost in test_loader:
33 | image_shape = test_img.shape
34 | patch_height = int(image_shape[3])
35 | patch_width = int(image_shape[4])
36 | # B
37 | eval_x = test_img.view(49, 3, patch_height, patch_width)
38 | eval_y = test_gt.view(1, 1, patch_height * 4, patch_width * 4).cuda()
39 | prediction_map = torch.zeros(1, 1, patch_height * 4, patch_width * 4).cuda()
40 | for i in range(7):
41 | for j in range(7):
42 | eval_x_sample = eval_x[i * 7 + j:i * 7 + j + 1].cuda()
44 | eval_prediction = net(eval_x_sample)
45 | start_h = int(patch_height / 4)
46 | start_w = int(patch_width / 4)
47 | valid_h = int(patch_height / 2)
48 | valid_w = int(patch_width / 2)
49 | h_pred = 3 * int(patch_height / 4) + 2 * int(patch_height / 4) * (i - 1)
50 | w_pred = 3 * int(patch_width / 4) + 2 * int(patch_width / 4) * (j - 1)
51 | if i == 0:
52 | valid_h = int((3 * patch_height) / 4)
53 | start_h = 0
54 | h_pred = 0
55 | elif i == 6:
56 | valid_h = int((3 * patch_height) / 4)
57 |
58 | if j == 0:
59 | valid_w = int((3 * patch_width) / 4)
60 | start_w = 0
61 | w_pred = 0
62 | elif j == 6:
63 | valid_w = int((3 * patch_width) / 4)
64 |
65 | prediction_map[:, :, h_pred:h_pred + valid_h, w_pred:w_pred + valid_w] += eval_prediction[:, :,
66 | start_h:start_h + valid_h,
67 | start_w:start_w + valid_w]
68 |
69 | batch_ae = ae_batch(prediction_map, eval_y).data.cpu().numpy()
70 | batch_se = se_batch(prediction_map, eval_y).data.cpu().numpy()
71 | mae_metrics.append(batch_ae)
72 | mse_metrics.append(batch_se)
73 | # to numpy
74 | numpy_predict_map = prediction_map.permute(0, 2, 3, 1).data.cpu().numpy()
75 | numpy_gt_map = eval_y.permute(0, 2, 3, 1).data.cpu().numpy()
76 |
77 | # show current prediction
78 | figure, (origin, dm_gt, dm_pred) = plt.subplots(1, 3, figsize=(20, 4))
79 | origin.imshow(Image.open("/home/zzn/part_B_final/test_data/images/IMG_" + str(real_index.numpy()[0]) + ".jpg"))
80 | origin.set_title('Origin Image')
81 | dm_gt.imshow(np.squeeze(numpy_gt_map), cmap=plt.cm.jet)
82 | dm_gt.set_title('ground_truth_1')
83 |
84 | dm_pred.imshow(np.squeeze(numpy_predict_map), cmap=plt.cm.jet)
85 | dm_pred.set_title('prediction')
86 |
 87 |     plt.suptitle('The ' + str(real_index.numpy()[0]) + 'th image\'s prediction')
88 | plt.show()
 89 |     sys.stdout.write('The ground truth crowd number is:{}, and the predicted number is:{}'.format(np.sum(numpy_gt_map),
90 | np.sum(
91 | numpy_predict_map)))
92 | sys.stdout.flush()
93 |
94 | mae_metrics = np.reshape(mae_metrics, [-1])
95 | mse_metrics = np.reshape(mse_metrics, [-1])
96 | MAE = np.mean(mae_metrics)
97 | MSE = np.sqrt(np.mean(mse_metrics))
98 | print('MAE:', MAE, 'MSE:', MSE)
--------------------------------------------------------------------------------
/ssim_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 | from torch.nn.parameter import Parameter
5 | from torch.nn import functional as F
6 | from torch.autograd import Variable
7 | from torch.nn.modules.loss import _Loss
8 | import numpy as np
9 |
10 |
11 | def gaussian_kernel(size, sigma):
12 | x, y = np.mgrid[-size:size + 1, -size:size + 1]
13 | kernel = np.exp(-0.5 * (x * x + y * y) / (sigma * sigma))
14 | kernel /= kernel.sum()
15 | return kernel
16 |
17 |
18 | class SSIM_Loss(_Loss):
19 | def __init__(self, in_channels, size=5, sigma=1.5, size_average=True):
20 | super(SSIM_Loss, self).__init__(size_average)
21 | # assert in_channels == 1, 'Only support single-channel input'
22 | self.in_channels = in_channels
23 | self.size = int(size)
24 | self.sigma = sigma
25 | self.size_average = size_average
26 |
27 | kernel = gaussian_kernel(self.size, self.sigma)
28 | self.kernel_size = kernel.shape
29 | weight = np.tile(kernel, (in_channels, 1, 1, 1))
30 | self.weight = Parameter(torch.from_numpy(weight).float(), requires_grad=False)
31 |
32 | def forward(self, input, target, mask=None):
33 | mean1 = F.conv2d(input, self.weight, padding=self.size, groups=self.in_channels)
34 | mean2 = F.conv2d(target, self.weight, padding=self.size, groups=self.in_channels)
35 | mean1_sq = mean1 * mean1
36 | mean2_sq = mean2 * mean2
37 | mean_12 = mean1 * mean2
38 |
39 | sigma1_sq = F.conv2d(input * input, self.weight, padding=self.size, groups=self.in_channels) - mean1_sq
40 | sigma2_sq = F.conv2d(target * target, self.weight, padding=self.size, groups=self.in_channels) - mean2_sq
41 | sigma_12 = F.conv2d(input * target, self.weight, padding=self.size, groups=self.in_channels) - mean_12
42 |
43 | C1 = 0.01 ** 2
44 | C2 = 0.03 ** 2
45 |
46 | ssim = ((2 * mean_12 + C1) * (2 * sigma_12 + C2)) / ((mean1_sq + mean2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
47 | if self.size_average:
48 | out = 1 - ssim.mean()
49 | else:
50 | out = 1 - ssim.view(ssim.size(0), -1).mean(1)
51 | return out
52 |
53 |
54 | if __name__ == '__main__':
55 | data = torch.zeros(1, 1, 1, 1)
56 | data += 0.001
57 | target = torch.zeros(1, 1, 1, 1)
58 | data = Variable(data, requires_grad=True)
59 | target = Variable(target)
60 |
61 | model = SSIM_Loss(1)
62 | loss = model(data, target)
63 | loss.backward()
64 | print(loss)
65 | print(data.grad)
66 |
--------------------------------------------------------------------------------
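gaussian_kernel(size, sigma) builds a (2 * size + 1)-by-(2 * size + 1) window, and padding=self.size keeps the SSIM map the same size as the input, so the loss averages over every pixel position. As a sanity-check sketch, the loss of an image against itself should be numerically zero, since SSIM of identical inputs is 1 everywhere:

import torch
from ssim_loss import SSIM_Loss

x = torch.rand(1, 1, 32, 32)
loss = SSIM_Loss(1)(x, x.clone())
print(float(loss))  # ~0.0
--------------------------------------------------------------------------------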
/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | from PIL import Image
6 | import torchvision.transforms as transforms
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | def show(origin_map, gt_map, predict, index):
11 | figure, (origin, gt, pred) = plt.subplots(1, 3, figsize=(20, 4))
12 | origin.imshow(origin_map)
13 | origin.set_title("origin picture")
14 | gt.imshow(gt_map, cmap=plt.cm.jet)
15 | gt.set_title("gt map")
16 | pred.imshow(predict, cmap=plt.cm.jet)
17 | pred.set_title("prediction")
18 | plt.suptitle(str(index) + "th sample")
19 | plt.show()
20 | plt.close()
21 |
22 |
23 | def show_phase2(origin_map, gt_map, predict_1, predict_2, index):
24 | figure, (origin, gt, pred_1, pred_2) = plt.subplots(1, 4, figsize=(20, 4))
25 | origin.imshow(origin_map)
26 | origin.set_title("origin picture")
27 | gt.imshow(gt_map, cmap=plt.cm.jet)
28 | gt.set_title("gt map")
29 | pred_1.imshow(predict_1, cmap=plt.cm.jet)
30 | pred_1.set_title("prediction_phase_1")
31 | pred_2.imshow(predict_2, cmap=plt.cm.jet)
32 | pred_2.set_title("prediction_phase_2")
33 | plt.suptitle(str(index) + "th sample")
34 | plt.show()
35 | plt.close()
36 |
37 |
38 | class HSI_Calculator(nn.Module):
39 | def __init__(self):
40 | super(HSI_Calculator, self).__init__()
41 |
42 | def forward(self, image):
43 | image = transforms.ToTensor()(image)
44 | I = torch.mean(image)
45 | Sum = image.sum(0)
46 | Min = 3 * image.min(0)[0]
47 | S = (1 - Min.div(Sum.clamp(1e-6))).mean()
48 | numerator = (2 * image[0] - image[1] - image[2]) / 2
49 | denominator = ((image[0] - image[1]) ** 2 + (image[0] - image[2]) * (image[1] - image[2])).sqrt()
50 |         theta = (numerator.div(denominator.clamp(1e-6))).clamp(-1 + 1e-6, 1 - 1e-6).acos() * 180 / np.pi  # acos returns radians; convert to degrees for the 360-degree hue formula below
51 | logistic_matrix = (image[1] - image[2]).ceil()
52 | H = (theta * logistic_matrix + (1 - logistic_matrix) * (360 - theta)).mean() / 360
53 | return H, S, I
54 |
55 | #
56 | # test = Image.open("/home/zzn/part_B_final/test_data/images/IMG_100.jpg")
57 | # new = F.adjust_brightness(test, 0.43 / 0.376)
58 | # figure, (origin, new_fig) = plt.subplots(1, 2, figsize=(40, 4))
59 | # origin.imshow(test)
60 | # new_fig.imshow(new)
61 | # plt.show()
62 | # calcu = HSI_Calculator()
63 | # H, S, I = calcu(test)
64 | # print(H, S, I)
--------------------------------------------------------------------------------