├── README.md ├── data_preparation ├── create_nine_patches.m ├── create_one_patch.m ├── create_same_size_patches.m ├── get_circumscribed_rectangle_roi.m ├── get_density_map_gaussian.m ├── get_density_map_gaussian_different_kernel.m ├── get_density_map_gaussian_perspective.m ├── get_inscribed_rectangle_roi.m ├── get_mask_map.m ├── get_mask_map_airport.m ├── get_mask_map_ucsd.m ├── get_mask_map_worldexpo.m ├── match_perspective_map.m └── view_density_map.m ├── src ├── __init__.py ├── computing_bar.py ├── crowd_count.py ├── data_multithread.py ├── data_multithread_preload.py ├── data_path.py ├── evaluate_model.py ├── models.py ├── network.py ├── psnr.py ├── ssim.py └── utils.py └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # ASNet 2 | 3 | This is the repo for [Attention scaling for crowd counting](http://openaccess.thecvf.com/content_CVPR_2020/html/Jiang_Attention_Scaling_for_Crowd_Counting_CVPR_2020_paper.html) in CVPR 2020, which proposed a simple but effective architecture for crowd counting. 4 | 5 | ## Requirements 6 | 7 | In an Anaconda environment, we need the following packages. 8 | 9 | Python: 3.7.5 10 | 11 | PyTorch: 1.3.1 12 | 13 | ## Results 14 | 15 | For ShanghaiTech Part A dataset 16 | 17 | Model: [Link](https://pan.baidu.com/s/1jQgBsDy90UfzlLafXgTcXQ), Password: 585s 18 | 19 | ## References 20 | 21 | If you find the ASNet useful, please cite our paper. Thank you! 22 | 23 | ``` 24 | @inproceedings{jiang2020attention, 25 | title={Attention scaling for crowd counting}, 26 | author={Jiang, Xiaoheng and Zhang, Li and Xu, Mingliang and Zhang, Tianzhu and Lv, Pei and Zhou, Bing and Yang, Xin and Pang, Yanwei}, 27 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 28 | pages={4706--4715}, 29 | month={June}, 30 | year={2020} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /data_preparation/create_nine_patches.m: -------------------------------------------------------------------------------- 1 | clc; clear all; 2 | seed = 95461354; 3 | rng(seed); 4 | 5 | % size of patchs will be number times TIMES 6 | % at least 4 7 | % default 16 8 | TIMES = 16; 9 | 10 | %if true, will get gray images(one channel) 11 | %if false, will get RGB images(three channel) 12 | is_gray_image = false; 13 | 14 | is_random_position = true;% default true 15 | 16 | % if true, will get four patches which are fixed position out of nine 17 | is_fixed_position = false; 18 | 19 | %if true, will get a smallest patch outside ROI of every image 20 | %if true, N must be 1 21 | %if true, must give a ROI(roi.mask) 22 | %if true, is_inscribed_rectangle must be false 23 | %if true, is_half_bottom must be false 24 | %if true, is_use_roi must be true 25 | is_circumscribed_rectangle = false; 26 | 27 | %if true, will get original patchs and overturned patchs 28 | is_overturn = true; 29 | 30 | %if true, will apply ROI to every image 31 | %if true, must give a ROI(roi.mask) 32 | is_use_roi = false; 33 | 34 | %if true, will save roi info for every image 35 | %not support N=16 now 36 | %if true, must give a ROI(roi.mask) 37 | %if true, is_use_roi must be true 38 | is_save_roi = false; 39 | 40 | % try to get patches in ROI 41 | is_all_in_roi = false; 42 | 43 | % try to get patches which have more than one pedestrain in each of them 44 | is_more_than_one_pedestrain = false; 45 | 46 | % resize image and annotation to fit max_height_width 47 | is_resize_to_max_height_width = false; 48 | max_height_width = 1024; 49 | 50 | is_shanghaitech = false; 51 | is_ucf_cc_50 = false; 52 | is_ucsd = false; 53 | is_worldexpo = false; 54 | is_airport = false; 55 | is_ucf_qnrf = false; 56 | is_gcc = true; 57 | if is_shanghaitech 58 | dataset = 'B'; 59 | dataset_name = ['shanghaitech_part_' dataset '_patches_9']; 60 | output_path = ['D:\Dataset\ShanghaiTech\formatted_trainval\']; 61 | img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\images\']; 62 | gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\ground_truth\']; 63 | % img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\images\']; 64 | % gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\ground_truth\']; 65 | elseif is_ucf_cc_50 66 | dataset_name = ['ucf_cc_50_patches_9']; 67 | img_path = 'D:\Dataset\UCF_CC_50\original\image\'; 68 | gt_path = 'D:\Dataset\UCF_CC_50\original\gt\'; 69 | output_path = 'D:\Dataset\UCF_CC_50\formatted_trainval\'; 70 | elseif is_worldexpo 71 | dataset_name = ['worldexpo_patches_9']; 72 | output_path = 'D:\Dataset\WorldExpo10\formatted_trainval\'; 73 | img_path = 'D:\Dataset\WorldExpo10\train_frame\'; 74 | gt_path = 'D:\Dataset\WorldExpo10\train_label\'; 75 | % img_path = 'D:\Dataset\WorldExpo10\test_frame\'; 76 | % gt_path = 'D:\Dataset\WorldExpo10\test_label\'; 77 | elseif is_airport 78 | dataset_name = ['airport_patches_9']; 79 | output_path = 'D:\Dataset\airport\formatted_trainval\'; 80 | img_path = 'D:\Dataset\airport\img\'; 81 | gt_path = 'D:\Dataset\airport\gt\'; 82 | roi_root_path = 'D:\Dataset\airport\roi\'; 83 | elseif is_ucf_qnrf 84 | dataset_name = ['ucf_qnrf_patches_1']; 85 | output_path = 'D:\Dataset\UCF-QNRF\kernel\'; 86 | img_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 87 | gt_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 88 | % img_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 89 | % gt_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 90 | elseif is_gcc 91 | dataset_name = 'gcc_patches_9'; 92 | output_path = ['D:\Dataset\GCC\ours\formatted_trainval\']; 93 | img_path = ['D:\Dataset\GCC\ours\original\train\jpgs\']; 94 | gt_path = ['D:\Dataset\GCC\ours\original\train\mats\']; 95 | end 96 | 97 | output_path_img = strcat(output_path, dataset_name,'\img\'); 98 | output_path_den = strcat(output_path, dataset_name,'\den\'); 99 | if is_save_roi 100 | output_path_roi = strcat(output_path, dataset_name,'\roi\'); 101 | end 102 | 103 | mkdir(output_path); 104 | mkdir(output_path_img); 105 | mkdir(output_path_den); 106 | if is_save_roi 107 | mkdir(output_path_roi); 108 | end 109 | 110 | dir_output = dir(fullfile(img_path,'*.jpg')); 111 | img_name_list = {dir_output.name}; 112 | if is_use_roi 113 | roi_container = containers.Map; 114 | roi_fourth_size_container = containers.Map; 115 | roi_eighth_size_container = containers.Map; 116 | end 117 | 118 | num_images = numel(img_name_list); 119 | for index = 1:num_images 120 | [~, img_name, ~] = fileparts(img_name_list{index}); 121 | 122 | if is_worldexpo 123 | scene_number = img_name(1:6); 124 | end 125 | 126 | if is_airport 127 | scene_number = img_name(1:2); 128 | end 129 | 130 | if (mod(index, 10)==0) 131 | fprintf(1,'Processing %3d/%d files\n', index, num_images); 132 | end 133 | 134 | if is_shanghaitech 135 | load(strcat(gt_path, 'GT_', img_name, '.mat')); 136 | im = imread(strcat(img_path, img_name, '.jpg')); 137 | elseif is_ucf_cc_50 138 | load(strcat(gt_path, img_name, '_ann.mat')) ; 139 | im = imread(strcat(img_path, img_name, '.jpg')); 140 | elseif is_worldexpo 141 | load(strcat(gt_path, scene_number, '\', img_name, '.mat')); 142 | im = imread(strcat(img_path, img_name, '.jpg')); 143 | if is_use_roi 144 | roi_path = strcat(gt_path, scene_number, '\roi.mat'); 145 | end 146 | elseif is_airport 147 | load(strcat(gt_path, img_name, '.mat')); 148 | im = imread(strcat(img_path, img_name, '.jpg')); 149 | if is_use_roi 150 | roi_path = strcat(roi_root_path, 'roi-', scene_number, '.mat'); 151 | end 152 | elseif is_ucf_qnrf 153 | load(strcat(gt_path, img_name, '_ann.mat')) ; 154 | im = imread(strcat(img_path, img_name, '.jpg')); 155 | elseif is_gcc 156 | load(strcat(gt_path, img_name, '.mat')) ; 157 | im = imread(strcat(img_path, img_name, '.jpg')); 158 | end 159 | 160 | [height, width, channel] = size(im); 161 | if is_gray_image 162 | if (channel == 3) 163 | im = rgb2gray(im); 164 | elseif (channel == 1) 165 | im = im; 166 | end 167 | else 168 | if (channel == 3) 169 | im = im; 170 | elseif (channel == 1) 171 | im_original = im; 172 | im = uint8(zeros(height, width, 3)); 173 | im(:, :, 1) = im_original; 174 | im(:, :, 2) = im_original; 175 | im(:, :, 3) = im_original; 176 | end 177 | end 178 | 179 | if is_shanghaitech 180 | annPoints = image_info{1}.location; 181 | elseif is_ucf_cc_50 182 | % nothing need to do here 183 | elseif is_worldexpo 184 | annPoints = point_position; 185 | elseif is_airport 186 | annPoints = image_info.location; 187 | elseif is_gcc 188 | annPoints = image_info.location; 189 | end 190 | 191 | if is_resize_to_max_height_width 192 | [height, width, ~] = size(im); 193 | if height > width 194 | resized_height = max_height_width; 195 | resized_width = round(width / height * resized_height); 196 | else 197 | resized_width = max_height_width; 198 | resized_height = round(height / width * resized_width); 199 | end 200 | im = imresize(im, [resized_height, resized_width], 'bilinear'); 201 | annPoints(:, 1) = annPoints(:, 1) / width * resized_width; 202 | annPoints(:, 2) = annPoints(:, 2) / height * resized_height; 203 | if is_use_roi 204 | error('roi is not supported'); 205 | end 206 | end 207 | 208 | im_density = get_density_map_gaussian(im, annPoints, 15, 4); 209 | 210 | % view_density_map(im_density); 211 | 212 | % fid=fopen('D:\count.txt','a'); 213 | % fprintf(fid, '%s\t%f\n', img_name, sum(im_density(:))); 214 | % fclose(fid); 215 | 216 | if is_worldexpo 217 | if is_use_roi 218 | if roi_container.isKey(scene_number) 219 | roi_original = roi_container(scene_number); 220 | roi_fourth_size = roi_fourth_size_container(scene_number); 221 | roi_eighth_size = roi_eighth_size_container(scene_number); 222 | else 223 | roi_original = get_mask_map_worldexpo(roi_path); 224 | roi_container(scene_number) = roi_original; 225 | roi_fourth_size = get_mask_map_worldexpo(roi_path, 'fourth'); 226 | roi_fourth_size_container(scene_number) = roi_fourth_size; 227 | roi_eighth_size = get_mask_map_worldexpo(roi_path, 'eighth'); 228 | roi_eighth_size_container(scene_number) = roi_eighth_size; 229 | end 230 | im_density = im_density .* roi_original.mask; 231 | end 232 | end 233 | 234 | if is_airport 235 | if is_use_roi 236 | if roi_container.isKey(scene_number) 237 | roi_original = roi_container(scene_number); 238 | roi_fourth_size = roi_fourth_size_container(scene_number); 239 | roi_eighth_size = roi_eighth_size_container(scene_number); 240 | else 241 | roi_original = get_mask_map_airport(roi_path, TIMES); 242 | roi_container(scene_number) = roi_original; 243 | roi_fourth_size = get_mask_map_airport(roi_path, TIMES, 'fourth'); 244 | roi_fourth_size_container(scene_number) = roi_fourth_size; 245 | roi_eighth_size = get_mask_map_airport(roi_path, TIMES, 'eighth'); 246 | roi_eighth_size_container(scene_number) = roi_eighth_size; 247 | end 248 | im_density = im_density .* roi_original.matrix; 249 | end 250 | end 251 | 252 | if is_circumscribed_rectangle 253 | [top, bottom, left, right] = get_circumscribed_rectangle_roi(roi_original.mask(:,:)); 254 | if is_offset 255 | offset_height = min(top, height - bottom) * (rand - 0.5) * 2; 256 | offset_width = min(left, width - right) * (rand - 0.5) * 2; 257 | top = max(floor(top + offset_height), 1); 258 | bottom = min(floor(bottom + offset_height), height); 259 | left = max(floor(left + offset_width), 1); 260 | right = min(floor(right + offset_width), width); 261 | end 262 | rectangle_height = TIMES * floor((bottom - top) / TIMES); 263 | rectangle_width = TIMES * floor((right - left) / TIMES); 264 | bottom = top + rectangle_height - 1; 265 | right = left + rectangle_width - 1; 266 | top_fourth_size = max(round(top / 4), 1); 267 | left_fourth_size = max(round(left / 4), 1); 268 | bottom_fourth_size = top_fourth_size + rectangle_height / 4 - 1; 269 | right_fourth_size = left_fourth_size + rectangle_width / 4 - 1; 270 | top_eighth_size = max(round(top / 8), 1); 271 | left_eighth_size = max(round(left / 8), 1); 272 | bottom_eighth_size = top_eighth_size + rectangle_height / 8 - 1; 273 | right_eighth_size = left_eighth_size + rectangle_width / 8 - 1; 274 | 275 | if is_gray_image 276 | im = im(top:bottom, left:right); 277 | else 278 | im = im(top:bottom, left:right, :); 279 | end 280 | im_density = im_density(top:bottom, left:right); 281 | roi_original.mask = roi_original.mask(top:bottom, left:right); 282 | roi_fourth_size.mask = roi_fourth_size.mask(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 283 | roi_eighth_size.mask = roi_eighth_size.mask(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 284 | roi_original.matrix = roi_original.matrix(top:bottom, left:right); 285 | roi_fourth_size.matrix = roi_fourth_size.matrix(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 286 | roi_eighth_size.matrix = roi_eighth_size.matrix(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 287 | end 288 | 289 | [height, width, ~] = size(im); 290 | half_width_of_patches = floor(width/4); 291 | half_height_of_patches = floor(height/4); 292 | half_width_of_patches = TIMES * floor(half_width_of_patches / TIMES); 293 | half_height_of_patches = TIMES * floor(half_height_of_patches / TIMES); 294 | half_width_of_patches_fourth_size = half_width_of_patches / 4; 295 | half_height_of_patches_fourth_size = half_height_of_patches / 4; 296 | half_width_of_patches_eighth_size = half_width_of_patches / 8; 297 | half_height_of_patches_eighth_size = half_height_of_patches / 8; 298 | 299 | a_width = half_width_of_patches+1; 300 | b_width = width - half_width_of_patches; 301 | a_height = half_height_of_patches+1; 302 | b_height = height - half_height_of_patches; 303 | 304 | j = 0; 305 | patience = 0; 306 | while(j < 9) 307 | x_position = -1; 308 | y_position = -1; 309 | 310 | if is_fixed_position && j < 4 311 | position = [0 1]; 312 | x_position = position(mod(j, 2) + 1); 313 | y_position = position(mod(j, 2) + 1); 314 | else 315 | if is_random_position 316 | x_position = rand; 317 | y_position = rand; 318 | else 319 | position = [0 0.5 1]; 320 | x_position = position(mod(j, 3) + 1); 321 | y_position = position(ceil(j / 3)); 322 | end 323 | end 324 | 325 | x = floor((b_width - a_width) * x_position + a_width); 326 | y = floor((b_height - a_height) * y_position + a_height); 327 | x1 = x - half_width_of_patches; 328 | y1 = y - half_height_of_patches; 329 | x2 = x + half_width_of_patches-1; 330 | y2 = y + half_height_of_patches-1; 331 | 332 | x_fourth_size = max(floor(x / 4), 1) + 1; 333 | y_fourth_size = max(floor(y / 4), 1) + 1; 334 | x1_fourth_size = x_fourth_size - half_width_of_patches_fourth_size; 335 | y1_fourth_size = y_fourth_size - half_height_of_patches_fourth_size; 336 | x2_fourth_size = x_fourth_size + half_width_of_patches_fourth_size - 1; 337 | y2_fourth_size = y_fourth_size + half_height_of_patches_fourth_size - 1; 338 | 339 | x_eighth_size = max(floor(x / 8), 1) + 1; 340 | y_eighth_size = max(floor(y / 8), 1) + 1; 341 | x1_eighth_size = x_eighth_size - half_width_of_patches_eighth_size; 342 | y1_eighth_size = y_eighth_size - half_height_of_patches_eighth_size; 343 | x2_eighth_size = x_eighth_size + half_width_of_patches_eighth_size - 1; 344 | y2_eighth_size = y_eighth_size + half_height_of_patches_eighth_size - 1; 345 | 346 | if is_gray_image 347 | im_sampled = im(y1:y2, x1:x2); 348 | else 349 | im_sampled = im(y1:y2, x1:x2, :); 350 | end 351 | im_density_sampled = im_density(y1:y2, x1:x2); 352 | if is_use_roi 353 | roi_sampled.mask = roi_original.mask(y1:y2,x1:x2); 354 | roi_fourth_size_sampled.mask = roi_fourth_size.mask(y1_fourth_size:y2_fourth_size, x1_fourth_size:x2_fourth_size); 355 | roi_eighth_size_sampled.mask = roi_eighth_size.mask(y1_eighth_size:y2_eighth_size, x1_eighth_size:x2_eighth_size); 356 | roi_sampled.matrix = roi_original.matrix(y1:y2,x1:x2); 357 | roi_fourth_size_sampled.matrix = roi_fourth_size.matrix(y1_fourth_size:y2_fourth_size, x1_fourth_size:x2_fourth_size); 358 | roi_eighth_size_sampled.matrix = roi_eighth_size.matrix(y1_eighth_size:y2_eighth_size, x1_eighth_size:x2_eighth_size); 359 | end 360 | 361 | if is_fixed_position && j < 4 362 | % nothing to do here 363 | else 364 | if is_all_in_roi && sum(roi_sampled.matrix(:)) ~= (y2 - y1 + 1) * (x2 - x1 + 1) 365 | patience = patience + 1; 366 | if patience < 81 367 | continue 368 | end 369 | elseif is_more_than_one_pedestrain && sum(im_density_sampled(:)) < 1.0 370 | patience = patience + 1; 371 | if patience < 81 372 | continue 373 | end 374 | end 375 | end 376 | patience = 0; 377 | j = j + 1; 378 | 379 | if is_worldexpo || is_airport 380 | img_idx = strcat(scene_number, '_', num2str(index), '_',num2str(j)); 381 | else 382 | img_idx = strcat(num2str(index), '_',num2str(j)); 383 | end 384 | 385 | imwrite(im_sampled, [output_path_img img_name '_' num2str(j) '.jpg']); 386 | csvwrite([output_path_den img_name '_' num2str(j) '.csv'], im_density_sampled); 387 | if is_save_roi 388 | roi = roi_sampled; 389 | save([output_path_roi img_name '_' num2str(j) '_roi.mat'], 'roi'); 390 | roi = roi_fourth_size_sampled; 391 | save([output_path_roi img_name '_' num2str(j) '_roi_fourth_size.mat'], 'roi'); 392 | roi = roi_eighth_size_sampled; 393 | save([output_path_roi img_name '_' num2str(j) '_roi_eighth_size.mat'], 'roi'); 394 | end 395 | if is_overturn 396 | im_sampled_overturn = fliplr(im_sampled); 397 | im_density_sampled_overturn = fliplr(im_density_sampled); 398 | if is_use_roi 399 | roi_sampled_overturn.mask = fliplr(roi_sampled.mask); 400 | roi_fourth_size_sampled_overturn.mask = fliplr(roi_fourth_size_sampled.mask); 401 | roi_eighth_size_sampled_overturn.mask = fliplr(roi_eighth_size_sampled.mask); 402 | roi_sampled_overturn.matrix = fliplr(roi_sampled.matrix); 403 | roi_fourth_size_sampled_overturn.matrix = fliplr(roi_fourth_size_sampled.matrix); 404 | roi_eighth_size_sampled_overturn.matrix = fliplr(roi_eighth_size_sampled.matrix); 405 | end 406 | imwrite(im_sampled_overturn, [output_path_img img_name '_' num2str(j) '_overturn.jpg']); 407 | csvwrite([output_path_den img_name '_' num2str(j) '_overturn.csv'], im_density_sampled_overturn); 408 | if is_save_roi 409 | roi = roi_sampled_overturn; 410 | save([output_path_roi img_name '_' num2str(j) '_overturn_roi.mat'], 'roi'); 411 | roi = roi_fourth_size_sampled_overturn; 412 | save([output_path_roi img_name '_' num2str(j) '_overturn_roi_fourth_size.mat'], 'roi'); 413 | roi = roi_eighth_size_sampled_overturn; 414 | save([output_path_roi img_name '_' num2str(j) '_overturn_roi_eighth_size.mat'], 'roi'); 415 | end 416 | end 417 | end 418 | end 419 | -------------------------------------------------------------------------------- /data_preparation/create_one_patch.m: -------------------------------------------------------------------------------- 1 | % create one patch from every image 2 | clc; clear all; 3 | seed = 95461354; 4 | rng(seed); 5 | 6 | % default 16 7 | TIMES = 16; 8 | 9 | %if true, will get gray images(one channel) 10 | %if false, will get RGB images(three channel) 11 | is_gray_image = false; 12 | 13 | %if true, will apply ROI to every image 14 | %if true, must give a ROI(roi.mask) 15 | is_use_roi = false; 16 | 17 | %if true, will save roi info for every image 18 | %if true, must give a ROI(roi.mask) 19 | %if true, is_use_roi must be true 20 | is_save_roi = false; 21 | 22 | %if true, is_use_roi must be true 23 | is_apply_roi_to_image = false; 24 | 25 | %if true, will get a smallest patch outside ROI of every image. then get patches from this patch 26 | %if true, must give a ROI(roi.mask) 27 | %if true, is_use_roi must be true 28 | is_circumscribed_rectangle = false; 29 | 30 | %if true, will get original patchs and overturned patchs 31 | is_overturn = false; 32 | 33 | %if true, will resize image and density map and roi to fit [TIMES] 34 | %if false, will crop image and density map and roi to fit [TIMES] 35 | %default true 36 | is_resize = false; 37 | 38 | %this has nothing to do with [is_resize] 39 | %use this to resize image and density map and roi 40 | %use for ucsd 41 | %default 1.0 42 | resize_level = 1.0; 43 | 44 | % resize image and annotation to fit max_height_width 45 | is_resize_to_max_height_width = false; 46 | max_height_width = 1024; 47 | 48 | is_shanghaitech = false; 49 | is_ucf_cc_50 = false; 50 | is_worldexpo = false; 51 | is_airport = false; 52 | is_ucsd = false; 53 | is_trancos = false; 54 | is_mall = false; 55 | is_ucf_qnrf = false; 56 | is_gcc = true; 57 | if is_shanghaitech 58 | dataset = 'B'; 59 | dataset_name = ['shanghaitech_part_' dataset '_patches_1']; 60 | root_save_path = ['D:\Dataset\ShanghaiTech\formatted_trainval\']; 61 | % img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\images\']; 62 | % gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\ground_truth\']; 63 | img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\images\']; 64 | gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\ground_truth\']; 65 | elseif is_ucf_cc_50 66 | dataset_name = 'ucf_cc_50_patches_1'; 67 | img_path = 'D:\Dataset\UCF_CC_50\original\image\'; 68 | gt_path = 'D:\Dataset\UCF_CC_50\original\gt\'; 69 | root_save_path = 'D:\Dataset\UCF_CC_50\formatted_trainval\'; 70 | elseif is_worldexpo 71 | dataset_name = 'worldexpo_patches_1'; 72 | root_save_path = 'D:\Dataset\WorldExpo10\formatted_trainval\'; 73 | img_path = 'D:\Dataset\WorldExpo10\original\train_frame\'; 74 | gt_path = 'D:\Dataset\WorldExpo10\original\train_label\'; 75 | roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 76 | % img_path = 'D:\Dataset\WorldExpo10\original\test_frame\'; 77 | % gt_path = 'D:\Dataset\WorldExpo10\original\test_label\'; 78 | % roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 79 | elseif is_airport 80 | dataset_name = 'airport_patches_1'; 81 | root_save_path = 'D:\Dataset\airport\formatted_trainval\'; 82 | img_path = 'D:\Dataset\airport\img\'; 83 | gt_path = 'D:\Dataset\airport\gt\'; 84 | roi_root_path = 'D:\Dataset\airport\roi\'; 85 | roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 86 | elseif is_ucsd 87 | dataset_name = 'ucsd_patches_1'; 88 | root_save_path = 'D:\Dataset\UCSD\formatted_trainval\'; 89 | img_path = 'D:\Dataset\UCSD\original\img\'; 90 | gt_path = 'D:\Dataset\UCSD\original\gt\'; 91 | roi_root_path = 'D:\Dataset\UCSD\original\roi\'; 92 | roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 93 | elseif is_trancos 94 | dataset_name = 'trancos_patches_1'; 95 | root_save_path = 'D:\Dataset\TRANCOS\original\formatted_trainval\'; 96 | img_path = 'D:\Dataset\TRANCOS\original\images\'; 97 | gt_path = 'D:\Dataset\TRANCOS\original\images\'; 98 | roi_root_path = 'D:\Dataset\TRANCOS\original\images\'; 99 | roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 100 | elseif is_mall 101 | dataset_name = 'mall_patches_1'; 102 | root_save_path = 'D:\Dataset\mall\formatted_trainval\'; 103 | img_path = 'D:\Dataset\mall\original\frames\'; 104 | gt_path = 'D:\Dataset\mall\original\'; 105 | roi_root_path = 'D:\Dataset\mall\original\'; 106 | roi_save_path = strcat(root_save_path, dataset_name,'\roi\'); 107 | elseif is_ucf_qnrf 108 | dataset_name = ['ucf_qnrf_patches_1']; 109 | root_save_path = 'D:\Dataset\UCF-QNRF\kernel\'; 110 | % img_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 111 | % gt_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 112 | img_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 113 | gt_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 114 | elseif is_gcc 115 | dataset_name = 'gcc_patches_1'; 116 | root_save_path = ['D:\Dataset\GCC\ours\formatted_trainval\']; 117 | % img_path = ['D:\Dataset\GCC\ours\original\train\jpgs\']; 118 | % gt_path = ['D:\Dataset\GCC\ours\original\train\mats\']; 119 | img_path = ['D:\Dataset\GCC\ours\original\val\jpgs\']; 120 | gt_path = ['D:\Dataset\GCC\ours\original\val\mats\']; 121 | end 122 | img_save_path = strcat(root_save_path, dataset_name,'\img\'); 123 | den_save_path = strcat(root_save_path, dataset_name,'\den\'); 124 | 125 | mkdir(root_save_path); 126 | mkdir(img_save_path); 127 | mkdir(den_save_path); 128 | if is_save_roi 129 | mkdir(roi_save_path); 130 | end 131 | 132 | if is_ucsd 133 | dir_output = dir(fullfile(img_path,'*.png')); 134 | else 135 | dir_output = dir(fullfile(img_path,'*.jpg')); 136 | end 137 | img_name_list = {dir_output.name}; 138 | 139 | if is_use_roi 140 | roi_container = containers.Map; 141 | roi_fourth_size_container = containers.Map; 142 | roi_eighth_size_container = containers.Map; 143 | end 144 | 145 | num_images = numel(img_name_list); 146 | for idx = 1:num_images 147 | [~, img_name, ~] = fileparts(img_name_list{idx}); 148 | 149 | if is_worldexpo 150 | scene_number = img_name(1:6); 151 | end 152 | if is_airport 153 | scene_number = img_name(1:2); 154 | end 155 | if is_ucsd 156 | scene_number = img_name(1:12); 157 | end 158 | if is_mall 159 | scene_number = img_name(5:10); 160 | end 161 | 162 | if (mod(idx, 10)==0) 163 | fprintf(1,'Processing %3d/%d files\n', idx, num_images); 164 | end 165 | 166 | if is_shanghaitech 167 | load(strcat(gt_path, 'GT_', img_name, '.mat')); 168 | annPoints = image_info{1}.location; 169 | im = imread(strcat(img_path, img_name, '.jpg')); 170 | elseif is_ucf_cc_50 171 | load(strcat(gt_path, num2str(idx), '_ann.mat')) ; 172 | im = imread(strcat(img_path, img_name, '.jpg')); 173 | elseif is_worldexpo 174 | load(strcat(gt_path, scene_number, '\', img_name, '.mat')); 175 | annPoints = point_position; 176 | im = imread(strcat(img_path, img_name, '.jpg')); 177 | if is_use_roi 178 | load(strcat(gt_path, scene_number, '\roi.mat')); 179 | roi_raw.x = maskVerticesXCoordinates; 180 | roi_raw.y = maskVerticesYCoordinates; 181 | end 182 | elseif is_airport 183 | load(strcat(gt_path, img_name, '.mat')); 184 | annPoints = image_info.location; 185 | im = imread(strcat(img_path, img_name, '.jpg')); 186 | if is_use_roi 187 | load(strcat(roi_root_path, 'roi-', scene_number, '.mat')); 188 | roi_raw.x = roi_coordinate(:, 1); 189 | roi_raw.y = roi_coordinate(:, 2); 190 | end 191 | elseif is_ucsd 192 | load(strcat(gt_path, scene_number, '_frame_full', '.mat')); 193 | annPoints = fgt.frame{1, str2num(img_name(15:17))}.loc(:, 1:2); 194 | im = imread(strcat(img_path, img_name, '.png')); 195 | if is_use_roi 196 | load(strcat(roi_root_path, 'vidf1_33_roi_mainwalkway.mat')); 197 | roi_raw.x = roi.xi; 198 | roi_raw.y = roi.yi; 199 | end 200 | elseif is_trancos 201 | annPoints = load(strcat(gt_path, img_name, '.txt')); 202 | im = imread(strcat(img_path, img_name, '.jpg')); 203 | if is_use_roi 204 | load(strcat(roi_root_path, img_name, 'mask.mat')); 205 | roi_raw_map = BW * 1.0; 206 | end 207 | elseif is_mall 208 | load(strcat(gt_path, 'mall_gt.mat')); 209 | annPoints = frame{1, str2double(scene_number)}.loc; 210 | im = imread(strcat(img_path, img_name, '.jpg')); 211 | if is_use_roi 212 | load(strcat(roi_root_path, 'perspective_roi.mat')); 213 | roi_raw_map = roi.mask * 1.0; 214 | end 215 | elseif is_ucf_qnrf 216 | load(strcat(gt_path, img_name, '_ann.mat')) ; 217 | im = imread(strcat(img_path, img_name, '.jpg')); 218 | elseif is_gcc 219 | load(strcat(gt_path, img_name, '.mat')) ; 220 | annPoints = image_info.location; 221 | im = imread(strcat(img_path, img_name, '.jpg')); 222 | end 223 | 224 | [height, width, channel] = size(im); 225 | if is_gray_image 226 | if (channel == 3) 227 | im = rgb2gray(im); 228 | elseif (channel == 1) 229 | im = im; 230 | end 231 | else 232 | if (channel == 3) 233 | im = im; 234 | elseif (channel == 1) 235 | im_original = im; 236 | im = uint8(zeros(height, width, 3)); 237 | im(:, :, 1) = im_original; 238 | im(:, :, 2) = im_original; 239 | im(:, :, 3) = im_original; 240 | end 241 | end 242 | 243 | if resize_level ~= 1.0 244 | im = imresize(im, resize_level, 'bilinear'); 245 | annPoints = annPoints * resize_level; 246 | if is_use_roi 247 | if is_worldexpo || is_airport || is_ucsd 248 | roi_raw.x = roi_raw.x * resize_level; 249 | roi_raw.y = roi_raw.y * resize_level; 250 | elseif is_trancos || is_mall 251 | roi_raw_map = round(imresize(roi_raw_map, resize_level, 'bilinear')); 252 | end 253 | end 254 | end 255 | 256 | if is_resize 257 | [height, width, ~] = size(im); 258 | height_1 = TIMES * round(height / TIMES); 259 | width_1 = TIMES * round(width / TIMES); 260 | im = imresize(im, [height_1, width_1], 'bilinear'); 261 | annPoints(:, 1) = annPoints(:, 1) / width * width_1; 262 | annPoints(:, 2) = annPoints(:, 2) / height * height_1; 263 | if is_use_roi 264 | if is_worldexpo || is_airport || is_ucsd 265 | roi_raw.x = roi_raw.x / width * width_1; 266 | roi_raw.y = roi_raw.y / height * height_1; 267 | elseif is_trancos || is_mall 268 | roi_raw_map = round(imresize(roi_raw_map, [height_1, width_1], 'bilinear')); 269 | end 270 | end 271 | end 272 | if is_use_roi 273 | if is_worldexpo || is_airport || is_ucsd 274 | [height, width, ~] = size(im); 275 | roi_raw.width = width; 276 | roi_raw.height = height; 277 | elseif is_trancos || is_mall 278 | % nothing need to do here 279 | end 280 | end 281 | 282 | if is_resize_to_max_height_width 283 | [height, width, ~] = size(im); 284 | if height > width 285 | resized_height = max_height_width; 286 | resized_width = round(width / height * resized_height); 287 | else 288 | resized_width = max_height_width; 289 | resized_height = round(height / width * resized_width); 290 | end 291 | im = imresize(im, [resized_height, resized_width], 'bilinear'); 292 | annPoints(:, 1) = annPoints(:, 1) / width * resized_width; 293 | annPoints(:, 2) = annPoints(:, 2) / height * resized_height; 294 | if is_use_roi 295 | error('roi is not supported'); 296 | end 297 | end 298 | 299 | im_density = get_density_map_gaussian(im, annPoints, 15, 15); 300 | 301 | % im_gray = rgb2gray(im); 302 | % max_density = max(im_density(:)); 303 | % imshow(imlincomb(0.5, im_gray, 0.5, uint8(im_density / max_density * 256))); 304 | 305 | if is_use_roi 306 | if is_worldexpo || is_airport || is_ucsd 307 | if is_worldexpo || is_airport 308 | roi_key = scene_number; 309 | elseif is_ucsd 310 | roi_key = 'all'; 311 | end 312 | if roi_container.isKey(roi_key) 313 | roi_original = roi_container(roi_key); 314 | roi_fourth_size = roi_fourth_size_container(roi_key); 315 | roi_eighth_size = roi_eighth_size_container(roi_key); 316 | else 317 | roi_original = get_mask_map(roi_raw); 318 | roi_fourth_size = get_mask_map(roi_raw, 'fourth'); 319 | roi_eighth_size = get_mask_map(roi_raw, 'eighth'); 320 | roi_container(roi_key) = roi_original; 321 | roi_fourth_size_container(roi_key) = roi_fourth_size; 322 | roi_eighth_size_container(roi_key) = roi_eighth_size; 323 | end 324 | elseif is_trancos || is_mall 325 | roi_original.matrix = roi_raw_map; 326 | roi_original.mask = roi_raw_map == 1; 327 | roi_fourth_size.matrix = round(imresize(roi_raw_map, 0.25, 'bilinear')); 328 | roi_fourth_size.mask = roi_fourth_size.matrix == 1; 329 | roi_eighth_size.matrix = round(imresize(roi_raw_map, 0.125, 'bilinear')); 330 | roi_eighth_size.mask = roi_eighth_size.matrix == 1; 331 | end 332 | im_density = im_density .* roi_original.matrix; 333 | if is_apply_roi_to_image 334 | if is_gray_image 335 | im = im .* uint8(roi_original.matrix); 336 | else 337 | im(:, :, 1) = im(:, :, 1) .* uint8(roi_original.matrix); 338 | im(:, :, 2) = im(:, :, 2) .* uint8(roi_original.matrix); 339 | im(:, :, 3) = im(:, :, 3) .* uint8(roi_original.matrix); 340 | end 341 | end 342 | end 343 | 344 | if is_circumscribed_rectangle 345 | [top, bottom, left, right] = get_circumscribed_rectangle_roi(roi_original.mask(:,:)); 346 | rectangle_height = 8 * floor((bottom - top) / 8); 347 | rectangle_width = 8 * floor((right - left) / 8); 348 | bottom = top + rectangle_height - 1; 349 | right = left + rectangle_width - 1; 350 | top_fourth_size = max(round(top / 4), 1); 351 | left_fourth_size = max(round(left / 4), 1); 352 | bottom_fourth_size = top_fourth_size + rectangle_height / 4 - 1; 353 | right_fourth_size = left_fourth_size + rectangle_width / 4 - 1; 354 | top_eighth_size = max(round(top / 8), 1); 355 | left_eighth_size = max(round(left / 8), 1); 356 | bottom_eighth_size = top_eighth_size + rectangle_height / 8 - 1; 357 | right_eighth_size = left_eighth_size + rectangle_width / 8 - 1; 358 | 359 | if is_gray_image 360 | im = im(top:bottom, left:right); 361 | else 362 | im = im(top:bottom, left:right, :); 363 | end 364 | im_density = im_density(top:bottom, left:right); 365 | roi_original.mask = roi_original.mask(top:bottom, left:right); 366 | roi_fourth_size.mask = roi_fourth_size.mask(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 367 | roi_eighth_size.mask = roi_eighth_size.mask(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 368 | roi_original.matrix = roi_original.matrix(top:bottom, left:right); 369 | roi_fourth_size.matrix = roi_fourth_size.matrix(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 370 | roi_eighth_size.matrix = roi_eighth_size.matrix(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 371 | end 372 | 373 | if is_resize 374 | % nothing need to do here 375 | else 376 | [height, width, ~] = size(im); 377 | height_1 = TIMES * floor(height / TIMES); 378 | width_1 = TIMES * floor(width / TIMES); 379 | height_1_fourth_size = height_1 / 4; 380 | width_1_fourth_size = width_1 / 4; 381 | height_1_eighth_size = height_1 / 8; 382 | width_1_eighth_size = width_1 / 8; 383 | if is_gray_image 384 | im = im(1:height_1, 1:width_1); 385 | else 386 | im = im(1:height_1, 1:width_1, :); 387 | end 388 | im_density = im_density(1:height_1, 1:width_1); 389 | if is_use_roi 390 | roi_original.mask = roi_original.mask(1:height_1, 1:width_1); 391 | roi_fourth_size.mask = roi_fourth_size.mask(1:height_1_fourth_size, 1:width_1_fourth_size); 392 | roi_eighth_size.mask = roi_eighth_size.mask(1:height_1_eighth_size, 1:width_1_eighth_size); 393 | roi_original.matrix = roi_original.matrix(1:height_1, 1:width_1); 394 | roi_fourth_size.matrix = roi_fourth_size.matrix(1:height_1_fourth_size, 1:width_1_fourth_size); 395 | roi_eighth_size.matrix = roi_eighth_size.matrix(1:height_1_eighth_size, 1:width_1_eighth_size); 396 | end 397 | end 398 | 399 | imwrite(im, [img_save_path img_name '.jpg']); 400 | csvwrite([den_save_path img_name '.csv'], im_density); 401 | if is_save_roi 402 | roi = roi_original; 403 | save([roi_save_path img_name '_roi.mat'], 'roi'); 404 | roi = roi_fourth_size; 405 | save([roi_save_path img_name '_roi_fourth_size.mat'], 'roi'); 406 | roi = roi_eighth_size; 407 | save([roi_save_path img_name '_roi_eighth_size.mat'], 'roi'); 408 | end 409 | if is_overturn 410 | im_overturn = fliplr(im); 411 | im_density_overturn = fliplr(im_density); 412 | if is_use_roi 413 | roi_original_overturn.mask = fliplr(roi_original.mask); 414 | roi_fourth_size_overturn.mask = fliplr(roi_fourth_size.mask); 415 | roi_eighth_size_overturn.mask = fliplr(roi_eighth_size.mask); 416 | roi_original_overturn.matrix = fliplr(roi_original.matrix); 417 | roi_fourth_size_overturn.matrix = fliplr(roi_fourth_size.matrix); 418 | roi_eighth_size_overturn.matrix = fliplr(roi_eighth_size.matrix); 419 | end 420 | imwrite(im_overturn, [img_save_path img_name '_overturn.jpg']); 421 | csvwrite([den_save_path img_name '_overturn.csv'], im_density_overturn); 422 | if is_save_roi 423 | roi = roi_original_overturn; 424 | save([roi_save_path img_name '_overturn_roi.mat'], 'roi'); 425 | roi = roi_fourth_size_overturn; 426 | save([roi_save_path img_name '_overturn_roi_fourth_size.mat'], 'roi'); 427 | roi = roi_eighth_size_overturn; 428 | save([roi_save_path img_name '_overturn_roi_eighth_size.mat'], 'roi'); 429 | end 430 | end 431 | end 432 | 433 | -------------------------------------------------------------------------------- /data_preparation/create_same_size_patches.m: -------------------------------------------------------------------------------- 1 | clc; clear all; 2 | % seed = 95461354; 3 | % rng(seed); 4 | 5 | % number of patches 6 | % set to 0 to get adaptive number of patches 7 | number_of_patches = 0; 8 | 9 | % size of patches 10 | height_of_patches = 128; 11 | width_of_patches = 128; 12 | 13 | % resize image and annotation to fit max_height_width 14 | is_resize = false; 15 | max_height_width = 1024; 16 | 17 | %if true, will get gray images(one channel) 18 | %if false, will get RGB images(three channel) 19 | is_gray_image = false; 20 | 21 | %if true, will apply ROI to every image 22 | %if true, must give a ROI(roi.mask) 23 | is_use_roi = false; 24 | 25 | %if true, will save roi info for every image 26 | %not support N=16 now 27 | %if true, must give a ROI(roi.mask) 28 | %if true, is_use_roi must be true 29 | is_save_roi = false; 30 | 31 | %if true, will get a smallest patch outside ROI of every image. then get patches from this patch 32 | %if true, must give a ROI(roi.mask) 33 | %if true, is_use_roi must be true 34 | is_circumscribed_rectangle = false; 35 | 36 | % try to get patches in ROI 37 | is_all_in_roi = false; 38 | 39 | % try to get patches which have more than some pedestrians in each of them 40 | is_more_than_pedestrian = true; 41 | pedstrian_number = 1; 42 | 43 | % try to get patches which have zero pedestrians and more than one pedestrian in each of them 44 | is_forbid_zero_to_one_pedestrian = false; 45 | 46 | %if true, will get original patchs and overturned patchs 47 | is_overturn = true; 48 | 49 | % get fewer patches when number_of_patches == 0 50 | is_fewer_samples = false; 51 | 52 | is_shanghaitech = false; 53 | is_ucf_cc_50 = false; 54 | is_worldexpo = false; 55 | is_airport = false; 56 | is_ucf_qnrf = false; 57 | is_gcc = true; 58 | if is_shanghaitech 59 | dataset = 'B'; 60 | dataset_name = ['shanghaitech_part_' dataset '_patches_' num2str(number_of_patches)]; 61 | output_path = ['D:\Dataset\ShanghaiTech\formatted_trainval\']; 62 | img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\images\']; 63 | gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\train_data\ground_truth\']; 64 | train_path_img = strcat(output_path, dataset_name,'\train\'); 65 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 66 | % img_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\images\']; 67 | % gt_path = ['D:\Dataset\ShanghaiTech\original\shanghaitech\part_' dataset '_final\test_data\ground_truth\']; 68 | % train_path_img = strcat(output_path, dataset_name,'\test\'); 69 | % train_path_den = strcat(output_path, dataset_name,'\test_den\'); 70 | elseif is_ucf_cc_50 71 | dataset_name = ['ucf_cc_50_patches_' num2str(number_of_patches)]; 72 | img_path = 'D:\Dataset\UCF_CC_50\original\image\'; 73 | gt_path = 'D:\Dataset\UCF_CC_50\original\gt\'; 74 | output_path = 'D:\Dataset\UCF_CC_50\formatted_trainval\'; 75 | train_path_img = strcat(output_path, dataset_name,'\train\'); 76 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 77 | elseif is_worldexpo 78 | dataset_name = ['worldexpo_patches_' num2str(number_of_patches)]; 79 | output_path = 'D:\Dataset\WorldExpo10\formatted_trainval\'; 80 | img_path = 'D:\Dataset\WorldExpo10\train_frame\'; 81 | gt_path = 'D:\Dataset\WorldExpo10\train_label\'; 82 | train_path_img = strcat(output_path, dataset_name,'\train\'); 83 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 84 | train_path_roi = strcat(output_path, dataset_name,'\train_roi\'); 85 | % img_path = 'D:\Dataset\WorldExpo10\test_frame\'; 86 | % gt_path = 'D:\Dataset\WorldExpo10\test_label\'; 87 | % train_path_img = strcat(output_path, dataset_name,'\test\'); 88 | % train_path_den = strcat(output_path, dataset_name,'\test_den\'); 89 | % train_path_roi = strcat(output_path, dataset_name,'\test_roi\'); 90 | elseif is_airport 91 | dataset_name = ['airport_patches_' num2str(number_of_patches)]; 92 | output_path = 'D:\Dataset\airport\formatted_trainval\'; 93 | img_path = 'D:\Dataset\airport\img\'; 94 | gt_path = 'D:\Dataset\airport\gt\'; 95 | roi_root_path = 'D:\Dataset\airport\roi\'; 96 | train_path_img = strcat(output_path, dataset_name,'\train\'); 97 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 98 | train_path_roi = strcat(output_path, dataset_name,'\train_roi\'); 99 | elseif is_ucf_qnrf 100 | dataset_name = ['ucf_qnrf_patches_' num2str(number_of_patches)]; 101 | img_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 102 | gt_path = 'D:\Dataset\UCF-QNRF\original\Train\'; 103 | % img_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 104 | % gt_path = 'D:\Dataset\UCF-QNRF\original\Test\'; 105 | output_path = 'D:\Dataset\UCF-QNRF\kernel\'; 106 | train_path_img = strcat(output_path, dataset_name,'\train\'); 107 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 108 | elseif is_gcc 109 | dataset_name = ['gcc_patches_' num2str(number_of_patches)]; 110 | output_path = ['D:\Dataset\GCC\all\formatted_trainval\']; 111 | img_path = ['D:\Dataset\GCC\all\train\jpgs\']; 112 | gt_path = ['D:\Dataset\GCC\all\train\mats\']; 113 | train_path_img = strcat(output_path, dataset_name,'\train\'); 114 | train_path_den = strcat(output_path, dataset_name,'\train_den\'); 115 | end 116 | 117 | mkdir(output_path); 118 | mkdir(train_path_img); 119 | mkdir(train_path_den); 120 | if is_save_roi 121 | mkdir(train_path_roi); 122 | end 123 | 124 | dir_output = dir(fullfile(img_path,'*.jpg')); 125 | img_name_list = {dir_output.name}; 126 | if is_use_roi 127 | roi_container = containers.Map; 128 | roi_fourth_size_container = containers.Map; 129 | roi_eighth_size_container = containers.Map; 130 | end 131 | 132 | num_images = numel(img_name_list); 133 | for index = 1:num_images 134 | [~, img_name, ~] = fileparts(img_name_list{index}); 135 | if is_worldexpo 136 | scene_number = img_name(1:6); 137 | end 138 | if is_airport 139 | scene_number = img_name(1:2); 140 | end 141 | 142 | if (mod(index, 10)==0) 143 | fprintf(1,'Processing %3d/%d files\n', index, num_images); 144 | end 145 | 146 | if is_shanghaitech 147 | load(strcat(gt_path, 'GT_', img_name, '.mat')) ; 148 | input_img_name = strcat(img_path, img_name, '.jpg'); 149 | elseif is_ucf_cc_50 150 | load(strcat(gt_path, img_name, '_ann.mat')) ; 151 | input_img_name = strcat(img_path, img_name, '.jpg'); 152 | elseif is_worldexpo 153 | load(strcat(gt_path, scene_number, '\', img_name, '.mat')); 154 | input_img_name = strcat(img_path, img_name_list{index}); 155 | if is_use_roi 156 | roi_path = strcat(gt_path, scene_number, '\roi.mat'); 157 | end 158 | elseif is_airport 159 | load(strcat(gt_path, img_name, '.mat')); 160 | input_img_name = strcat(img_path, img_name_list{index}); 161 | if is_use_roi 162 | roi_path = strcat(roi_root_path, 'roi-', scene_number, '.mat'); 163 | end 164 | elseif is_ucf_qnrf 165 | load(strcat(gt_path, img_name, '_ann.mat')) ; 166 | input_img_name = strcat(img_path, img_name, '.jpg'); 167 | elseif is_gcc 168 | load(strcat(gt_path, img_name, '.mat')) ; 169 | input_img_name = strcat(img_path, img_name, '.jpg'); 170 | end 171 | 172 | im = imread(input_img_name); 173 | 174 | [height, width, channel] = size(im); 175 | if is_gray_image 176 | if (channel == 3) 177 | im = rgb2gray(im); 178 | elseif (channel == 1) 179 | im = im; 180 | end 181 | else 182 | if (channel == 3) 183 | im = im; 184 | elseif (channel == 1) 185 | im_original = im; 186 | im = uint8(zeros(height, width, 3)); 187 | im(:, :, 1) = im_original; 188 | im(:, :, 2) = im_original; 189 | im(:, :, 3) = im_original; 190 | end 191 | end 192 | 193 | if is_shanghaitech 194 | annPoints = image_info{1}.location; 195 | elseif is_ucf_cc_50 196 | % nothing need to do here 197 | elseif is_worldexpo 198 | annPoints = point_position; 199 | elseif is_airport 200 | annPoints = image_info.location; 201 | elseif is_ucf_qnrf 202 | % nothing need to do here 203 | elseif is_gcc 204 | annPoints = image_info.location; 205 | end 206 | 207 | if is_resize 208 | [height, width, ~] = size(im); 209 | if height > width 210 | resized_height = max_height_width; 211 | resized_width = round(width / height * resized_height); 212 | else 213 | resized_width = max_height_width; 214 | resized_height = round(height / width * resized_width); 215 | end 216 | im = imresize(im, [resized_height, resized_width], 'bilinear'); 217 | annPoints(:, 1) = annPoints(:, 1) / width * resized_width; 218 | annPoints(:, 2) = annPoints(:, 2) / height * resized_height; 219 | if is_use_roi 220 | error('roi is not supported'); 221 | end 222 | end 223 | 224 | % im_density_1 = get_density_map_gaussian(im, annPoints, 9, 4); 225 | % im_density_2 = get_density_map_gaussian(im, annPoints, 15, 4); 226 | % im_density_3 = get_density_map_gaussian(im, annPoints, 21, 4); 227 | % im_density = (im_density_1 + im_density_2 + im_density_3) / 3.0; 228 | 229 | % im_density_1 = get_density_map_gaussian(im, annPoints, 9, 4); 230 | % im_density_2 = get_density_map_gaussian(im, annPoints, 15, 4); 231 | % im_density = (im_density_1 + im_density_2) / 2.0; 232 | 233 | im_density = get_density_map_gaussian(im, annPoints, 15, 4); 234 | 235 | % view_density_map(im_density); 236 | 237 | % fid=fopen('D:\count.txt','a'); 238 | % fprintf(fid, '%s\t%f\n', img_name, sum(im_density(:))); 239 | % fclose(fid); 240 | 241 | if is_worldexpo 242 | if is_use_roi 243 | if roi_container.isKey(scene_number) 244 | roi_original = roi_container(scene_number); 245 | roi_fourth_size = roi_fourth_size_container(scene_number); 246 | roi_eighth_size = roi_eighth_size_container(scene_number); 247 | else 248 | roi_original = get_mask_map_worldexpo(roi_path); 249 | roi_container(scene_number) = roi_original; 250 | roi_fourth_size = get_mask_map_worldexpo(roi_path, 'fourth'); 251 | roi_fourth_size_container(scene_number) = roi_fourth_size; 252 | roi_eighth_size = get_mask_map_worldexpo(roi_path, 'eighth'); 253 | roi_eighth_size_container(scene_number) = roi_eighth_size; 254 | end 255 | im_density = im_density .* roi_original.mask; 256 | end 257 | end 258 | 259 | if is_airport 260 | if is_use_roi 261 | if roi_container.isKey(scene_number) 262 | roi_original = roi_container(scene_number); 263 | roi_fourth_size = roi_fourth_size_container(scene_number); 264 | roi_eighth_size = roi_eighth_size_container(scene_number); 265 | else 266 | roi_original = get_mask_map_airport(roi_path); 267 | roi_container(scene_number) = roi_original; 268 | roi_fourth_size = get_mask_map_airport(roi_path, 'fourth'); 269 | roi_fourth_size_container(scene_number) = roi_fourth_size; 270 | roi_eighth_size = get_mask_map_airport(roi_path, 'eighth'); 271 | roi_eighth_size_container(scene_number) = roi_eighth_size; 272 | end 273 | im_density = im_density .* roi_original.matrix; 274 | end 275 | end 276 | 277 | if is_circumscribed_rectangle 278 | [top, bottom, left, right] = get_circumscribed_rectangle_roi(roi_original.mask(:,:)); 279 | rectangle_height = 8 * floor((bottom - top) / 8); 280 | rectangle_width = 8 * floor((right - left) / 8); 281 | bottom = top + rectangle_height - 1; 282 | right = left + rectangle_width - 1; 283 | top_fourth_size = max(round(top / 4), 1); 284 | left_fourth_size = max(round(left / 4), 1); 285 | bottom_fourth_size = top_fourth_size + rectangle_height / 4 - 1; 286 | right_fourth_size = left_fourth_size + rectangle_width / 4 - 1; 287 | top_eighth_size = max(round(top / 8), 1); 288 | left_eighth_size = max(round(left / 8), 1); 289 | bottom_eighth_size = top_eighth_size + rectangle_height / 8 - 1; 290 | right_eighth_size = left_eighth_size + rectangle_width / 8 - 1; 291 | 292 | if is_gray_image 293 | im = im(top:bottom, left:right); 294 | else 295 | im = im(top:bottom, left:right, :); 296 | end 297 | im_density = im_density(top:bottom, left:right); 298 | roi_original.mask = roi_original.mask(top:bottom, left:right); 299 | roi_fourth_size.mask = roi_fourth_size.mask(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 300 | roi_eighth_size.mask = roi_eighth_size.mask(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 301 | roi_original.matrix = roi_original.matrix(top:bottom, left:right); 302 | roi_fourth_size.matrix = roi_fourth_size.matrix(top_fourth_size:bottom_fourth_size, left_fourth_size:right_fourth_size); 303 | roi_eighth_size.matrix = roi_eighth_size.matrix(top_eighth_size:bottom_eighth_size, left_eighth_size:right_eighth_size); 304 | end 305 | 306 | [height, width, ~] = size(im); 307 | half_width_of_patches = floor(width_of_patches / 2); 308 | half_height_of_patches = floor(height_of_patches / 2); 309 | half_width_of_patches_fourth_size = half_width_of_patches / 4; 310 | half_height_of_patches_fourth_size = half_height_of_patches / 4; 311 | half_width_of_patches_eighth_size = half_width_of_patches / 8; 312 | half_height_of_patches_eighth_size = half_height_of_patches / 8; 313 | 314 | start_width = half_width_of_patches + 1; 315 | end_width = width - half_width_of_patches; 316 | start_height = half_height_of_patches + 1; 317 | end_height = height - half_height_of_patches; 318 | 319 | if number_of_patches == 0 320 | if is_fewer_samples 321 | N = ceil(height * width / height_of_patches / width_of_patches); 322 | else 323 | N = 4 * ceil(height * width / height_of_patches / width_of_patches); 324 | end 325 | else 326 | N = number_of_patches; 327 | end 328 | 329 | i = 0; 330 | patience = 0; 331 | while(i <= N) 332 | x_position = rand; 333 | y_position = rand; 334 | 335 | x = floor((end_width - start_width) * x_position + start_width); 336 | y = floor((end_height - start_height) * y_position + start_height); 337 | x1 = x - half_width_of_patches; 338 | y1 = y - half_height_of_patches; 339 | x2 = x + half_width_of_patches-1; 340 | y2 = y + half_height_of_patches-1; 341 | 342 | x_fourth_size = max(floor(x / 4), 1) + 1; 343 | y_fourth_size = max(floor(y / 4), 1) + 1; 344 | x1_fourth_size = x_fourth_size - half_width_of_patches_fourth_size; 345 | y1_fourth_size = y_fourth_size - half_height_of_patches_fourth_size; 346 | x2_fourth_size = x_fourth_size + half_width_of_patches_fourth_size - 1; 347 | y2_fourth_size = y_fourth_size + half_height_of_patches_fourth_size - 1; 348 | 349 | x_eighth_size = max(floor(x / 8), 1) + 1; 350 | y_eighth_size = max(floor(y / 8), 1) + 1; 351 | x1_eighth_size = x_eighth_size - half_width_of_patches_eighth_size; 352 | y1_eighth_size = y_eighth_size - half_height_of_patches_eighth_size; 353 | x2_eighth_size = x_eighth_size + half_width_of_patches_eighth_size - 1; 354 | y2_eighth_size = y_eighth_size + half_height_of_patches_eighth_size - 1; 355 | 356 | if is_gray_image 357 | im_sampled = im(y1:y2, x1:x2); 358 | else 359 | im_sampled = im(y1:y2, x1:x2, :); 360 | end 361 | im_density_sampled = im_density(y1:y2, x1:x2); 362 | if is_use_roi 363 | roi_sampled.mask = roi_original.mask(y1:y2,x1:x2); 364 | roi_fourth_size_sampled.mask = roi_fourth_size.mask(y1_fourth_size:y2_fourth_size, x1_fourth_size:x2_fourth_size); 365 | roi_eighth_size_sampled.mask = roi_eighth_size.mask(y1_eighth_size:y2_eighth_size, x1_eighth_size:x2_eighth_size); 366 | roi_sampled.matrix = roi_original.matrix(y1:y2,x1:x2); 367 | roi_fourth_size_sampled.matrix = roi_fourth_size.matrix(y1_fourth_size:y2_fourth_size, x1_fourth_size:x2_fourth_size); 368 | roi_eighth_size_sampled.matrix = roi_eighth_size.matrix(y1_eighth_size:y2_eighth_size, x1_eighth_size:x2_eighth_size); 369 | end 370 | 371 | if is_all_in_roi && sum(roi_sampled.matrix(:)) ~= width_of_patches * height_of_patches 372 | patience = patience + 1; 373 | if patience < N 374 | continue 375 | end 376 | elseif is_more_than_pedestrian && sum(im_density_sampled(:)) < pedstrian_number 377 | patience = patience + 1; 378 | if patience < N 379 | continue 380 | end 381 | elseif is_forbid_zero_to_one_pedestrian && sum(im_density_sampled(:)) > 0.0 && sum(im_density_sampled(:)) < 1.0 382 | patience = patience + 1; 383 | if patience < N 384 | continue 385 | end 386 | end 387 | patience = 0; 388 | i = i + 1; 389 | 390 | save_name = strcat(img_name, '_', num2str(i)); 391 | 392 | imwrite(im_sampled, [train_path_img save_name '.jpg']); 393 | csvwrite([train_path_den save_name '.csv'], im_density_sampled); 394 | if is_save_roi 395 | roi = roi_sampled; 396 | save([train_path_roi save_name '_roi.mat'], 'roi'); 397 | roi = roi_fourth_size_sampled; 398 | save([train_path_roi save_name '_roi_fourth_size.mat'], 'roi'); 399 | roi = roi_eighth_size_sampled; 400 | save([train_path_roi save_name '_roi_eighth_size.mat'], 'roi'); 401 | end 402 | 403 | if is_overturn 404 | im_sampled_overturn = fliplr(im_sampled); 405 | im_density_sampled_overturn = fliplr(im_density_sampled); 406 | if is_use_roi 407 | roi_sampled_overturn.mask = fliplr(roi_sampled.mask); 408 | roi_fourth_size_sampled_overturn.mask = fliplr(roi_fourth_size_sampled.mask); 409 | roi_eighth_size_sampled_overturn.mask = fliplr(roi_eighth_size_sampled.mask); 410 | roi_sampled_overturn.matrix = fliplr(roi_sampled.matrix); 411 | roi_fourth_size_sampled_overturn.matrix = fliplr(roi_fourth_size_sampled.matrix); 412 | roi_eighth_size_sampled_overturn.matrix = fliplr(roi_eighth_size_sampled.matrix); 413 | end 414 | imwrite(im_sampled_overturn, [train_path_img save_name '_overturn.jpg']); 415 | csvwrite([train_path_den save_name '_overturn.csv'], im_density_sampled_overturn); 416 | if is_save_roi 417 | roi = roi_sampled_overturn; 418 | save([train_path_roi save_name '_overturn_roi.mat'], 'roi'); 419 | roi = roi_fourth_size_sampled_overturn; 420 | save([train_path_roi save_name '_overturn_roi_fourth_size.mat'], 'roi'); 421 | roi = roi_eighth_size_sampled_overturn; 422 | save([train_path_roi save_name '_overturn_roi_eighth_size.mat'], 'roi'); 423 | end 424 | end 425 | end 426 | end 427 | 428 | -------------------------------------------------------------------------------- /data_preparation/get_circumscribed_rectangle_roi.m: -------------------------------------------------------------------------------- 1 | function [top, bottom, left, right] = get_circumscribed_rectangle_roi(mask) 2 | %get the smallest circumscribed rectangle in ROI 3 | %input mask is a m*n matrix 4 | 5 | whole_size = size(mask); 6 | outside_top = 0; 7 | outside_bottom = whole_size(1); 8 | outside_left = 0; 9 | outside_right = whole_size(2); 10 | 11 | mark = false; 12 | for i = 1:whole_size(1) 13 | for j = 1:whole_size(2) 14 | if mask(i, j) == 1 15 | mark = true; 16 | end 17 | end 18 | if mark 19 | outside_top = i; 20 | mark = false; 21 | break; 22 | end 23 | end 24 | 25 | mark = false; 26 | for i = whole_size(1):-1:1 27 | for j = 1:whole_size(2) 28 | if mask(i, j) == 1 29 | mark = true; 30 | end 31 | end 32 | if mark 33 | outside_bottom = i; 34 | mark = false; 35 | break; 36 | end 37 | end 38 | 39 | mark = false; 40 | for j = 1:whole_size(2) 41 | for i = 1:whole_size(1) 42 | if mask(i, j) == 1 43 | mark = true; 44 | end 45 | end 46 | if mark 47 | outside_left = j; 48 | mark = false; 49 | break; 50 | end 51 | end 52 | 53 | mark = false; 54 | for j = whole_size(2):-1:1 55 | for i = 1:whole_size(1) 56 | if mask(i, j) == 1 57 | mark = true; 58 | end 59 | end 60 | if mark 61 | outside_right = j; 62 | mark = false; 63 | break; 64 | end 65 | end 66 | 67 | top = outside_top; 68 | bottom = outside_bottom; 69 | left = outside_left; 70 | right = outside_right; 71 | end -------------------------------------------------------------------------------- /data_preparation/get_density_map_gaussian.m: -------------------------------------------------------------------------------- 1 | function im_density = get_density_map_gaussian(im, points, gaussian_size, sigma) 2 | if nargin == 2 3 | gaussian_size = 15; 4 | sigma = 4; 5 | elseif nargin ~= 4 6 | error('No size or sigma provided.') 7 | end 8 | 9 | [h, w, ~] = size(im); 10 | im_density = zeros(h, w); 11 | 12 | if(isempty(points)) 13 | return; 14 | end 15 | %{ 16 | if(length(points(:,1))==1) 17 | x1 = max(1,min(w,round(points(1,1)))); 18 | y1 = max(1,min(h,round(points(1,2)))); 19 | im_density(y1,x1) = 255; 20 | return; 21 | end 22 | %} 23 | for j = 1:length(points(:,1)) 24 | H = fspecial('Gaussian', gaussian_size, sigma); 25 | 26 | x = floor(points(j,1)); 27 | y = floor(points(j,2)); 28 | if(x > w || y > h || x < 1 || y < 1) 29 | continue; 30 | end 31 | x = min(w,max(1,abs(int32(x)))); 32 | y = min(h,max(1,abs(int32(y)))); 33 | 34 | x1 = x - int32(floor(gaussian_size / 2)); y1 = y - int32(floor(gaussian_size / 2)); 35 | x2 = x + int32(floor(gaussian_size / 2)); y2 = y + int32(floor(gaussian_size / 2)); 36 | dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0; 37 | change_H = false; 38 | if(x1 < 1) 39 | dfx1 = abs(x1)+1; 40 | x1 = 1; 41 | change_H = true; 42 | end 43 | if(y1 < 1) 44 | dfy1 = abs(y1)+1; 45 | y1 = 1; 46 | change_H = true; 47 | end 48 | if(x2 > w) 49 | dfx2 = x2 - w; 50 | x2 = w; 51 | change_H = true; 52 | end 53 | if(y2 > h) 54 | dfy2 = y2 - h; 55 | y2 = h; 56 | change_H = true; 57 | end 58 | x1h = 1+dfx1; y1h = 1+dfy1; x2h = gaussian_size - dfx2; y2h = gaussian_size - dfy2; 59 | if (change_H == true) 60 | H = fspecial('Gaussian',[double(y2h-y1h+1), double(x2h-x1h+1)],sigma); 61 | end 62 | im_density(y1:y2, x1:x2) = im_density(y1:y2, x1:x2) + H; 63 | 64 | end 65 | 66 | end -------------------------------------------------------------------------------- /data_preparation/get_density_map_gaussian_different_kernel.m: -------------------------------------------------------------------------------- 1 | function im_density = get_density_map_gaussian_different_kernel(im, points, density_mask) 2 | density_mask = double(density_mask); 3 | 4 | original_gaussian_size = 15; 5 | original_sigma = 4; 6 | 7 | [h, w, ~] = size(im); 8 | im_density = zeros(h, w); 9 | 10 | if(isempty(points)) 11 | return; 12 | end 13 | %{ 14 | if(length(points(:,1))==1) 15 | x1 = max(1,min(w,round(points(1,1)))); 16 | y1 = max(1,min(h,round(points(1,2)))); 17 | im_density(y1,x1) = 255; 18 | return; 19 | end 20 | %} 21 | for j = 1:length(points(:,1)) 22 | x = min(w,max(1,abs(int32(floor(points(j,1)))))); 23 | y = min(h,max(1,abs(int32(floor(points(j,2)))))); 24 | density_class = density_mask(y, x); 25 | 26 | if density_class == 0 27 | gaussian_size = 45; 28 | sigma = 12; 29 | elseif density_class == 1 30 | gaussian_size = 31; 31 | sigma = 8; 32 | elseif density_class == 2 33 | gaussian_size = 15; 34 | sigma = 4; 35 | end 36 | 37 | H = fspecial('Gaussian', gaussian_size, sigma); 38 | if(x > w || y > h) 39 | continue; 40 | end 41 | x1 = x - int32(floor(gaussian_size / 2)); y1 = y - int32(floor(gaussian_size / 2)); 42 | x2 = x + int32(floor(gaussian_size / 2)); y2 = y + int32(floor(gaussian_size / 2)); 43 | dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0; 44 | change_H = false; 45 | if(x1 < 1) 46 | dfx1 = abs(x1)+1; 47 | x1 = 1; 48 | change_H = true; 49 | end 50 | if(y1 < 1) 51 | dfy1 = abs(y1)+1; 52 | y1 = 1; 53 | change_H = true; 54 | end 55 | if(x2 > w) 56 | dfx2 = x2 - w; 57 | x2 = w; 58 | change_H = true; 59 | end 60 | if(y2 > h) 61 | dfy2 = y2 - h; 62 | y2 = h; 63 | change_H = true; 64 | end 65 | x1h = 1+dfx1; y1h = 1+dfy1; x2h = gaussian_size - dfx2; y2h = gaussian_size - dfy2; 66 | if (change_H == true) 67 | H = fspecial('Gaussian',[double(y2h-y1h+1), double(x2h-x1h+1)],sigma); 68 | end 69 | im_density(y1:y2, x1:x2) = im_density(y1:y2, x1:x2) + H; 70 | 71 | end 72 | 73 | end -------------------------------------------------------------------------------- /data_preparation/get_density_map_gaussian_perspective.m: -------------------------------------------------------------------------------- 1 | function im_density = get_density_map_gaussian_perspective(im,points,perspective_map) 2 | 3 | im_density = zeros(size(im)); 4 | [h,w] = size(im_density); 5 | 6 | if(isempty(points)) 7 | return; 8 | end 9 | 10 | if(length(points(:,1))==1) 11 | x1 = max(1,min(w,round(points(1,1)))); 12 | y1 = max(1,min(h,round(points(1,2)))); 13 | im_density(y1,x1) = 255; 14 | return; 15 | end 16 | 17 | for j = 1:length(points) 18 | x = min(w,max(1,abs(int32(floor(points(j,1)))))); 19 | y = min(h,max(1,abs(int32(floor(points(j,2)))))); 20 | if(x > w || y > h) 21 | continue; 22 | end 23 | 24 | f_sz = 15; 25 | sigma = perspective_map(y, x); 26 | H = fspecial('Gaussian',[f_sz, f_sz],sigma); 27 | 28 | x1 = x - int32(floor(f_sz/2)); y1 = y - int32(floor(f_sz/2)); 29 | x2 = x + int32(floor(f_sz/2)); y2 = y + int32(floor(f_sz/2)); 30 | dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0; 31 | change_H = false; 32 | if(x1 < 1) 33 | dfx1 = abs(x1)+1; 34 | x1 = 1; 35 | change_H = true; 36 | end 37 | if(y1 < 1) 38 | dfy1 = abs(y1)+1; 39 | y1 = 1; 40 | change_H = true; 41 | end 42 | if(x2 > w) 43 | dfx2 = x2 - w; 44 | x2 = w; 45 | change_H = true; 46 | end 47 | if(y2 > h) 48 | dfy2 = y2 - h; 49 | y2 = h; 50 | change_H = true; 51 | end 52 | x1h = 1+dfx1; y1h = 1+dfy1; x2h = f_sz - dfx2; y2h = f_sz - dfy2; 53 | if (change_H == true) 54 | H = fspecial('Gaussian',[double(y2h-y1h+1), double(x2h-x1h+1)],sigma); 55 | end 56 | im_density(y1:y2,x1:x2) = im_density(y1:y2,x1:x2) + H; 57 | 58 | end 59 | 60 | end -------------------------------------------------------------------------------- /data_preparation/get_inscribed_rectangle_roi.m: -------------------------------------------------------------------------------- 1 | function [top, bottom, left, right] = get_inscribed_rectangle_roi(mask) 2 | %get the largest inscribed rectangle in ROI 3 | %input mask is a m*n matrix 4 | 5 | whole_size = size(mask); 6 | outside_top = 0; 7 | outside_bottom = whole_size(1); 8 | outside_left = 0; 9 | outside_right = whole_size(2); 10 | 11 | mark = false; 12 | for i = 1:whole_size(1) 13 | for j = 1:whole_size(2) 14 | if mask(i, j) == 1 15 | mark = true; 16 | end 17 | end 18 | if mark 19 | outside_top = i; 20 | mark = false; 21 | break; 22 | end 23 | end 24 | 25 | mark = false; 26 | for i = whole_size(1):-1:1 27 | for j = 1:whole_size(2) 28 | if mask(i, j) == 1 29 | mark = true; 30 | end 31 | end 32 | if mark 33 | outside_bottom = i; 34 | mark = false; 35 | break; 36 | end 37 | end 38 | 39 | mark = false; 40 | for j = 1:whole_size(2) 41 | for i = 1:whole_size(1) 42 | if mask(i, j) == 1 43 | mark = true; 44 | end 45 | end 46 | if mark 47 | outside_left = j; 48 | mark = false; 49 | break; 50 | end 51 | end 52 | 53 | mark = false; 54 | for j = whole_size(2):-1:1 55 | for i = 1:whole_size(1) 56 | if mask(i, j) == 1 57 | mark = true; 58 | end 59 | end 60 | if mark 61 | outside_right = j; 62 | mark = false; 63 | break; 64 | end 65 | end 66 | 67 | inside_top = round((outside_top + outside_bottom) / 2) - 1; 68 | inside_bottom = round((outside_top + outside_bottom) / 2) + 1; 69 | inside_left = round((outside_left + outside_right) / 2) - 1; 70 | inside_right = round((outside_left + outside_right) / 2) + 1; 71 | raw_inside_left = inside_left; 72 | raw_inside_right = inside_right; 73 | 74 | proportion = abs(outside_right - outside_left) / abs(outside_bottom - outside_top); 75 | 76 | mark_top = true; 77 | mark_bottom = true; 78 | mark_left = true; 79 | mark_right = true; 80 | 81 | while mark_top || mark_bottom || mark_left || mark_right 82 | if mark_top 83 | for i = inside_left:inside_right 84 | if mask(inside_top, i) == 0 85 | mark_top = false; 86 | end 87 | end 88 | if mark_top 89 | inside_top = inside_top - 1; 90 | end 91 | end 92 | 93 | if mark_bottom 94 | for i = inside_left:inside_right 95 | if mask(inside_bottom, i) == 0 96 | mark_bottom = false; 97 | end 98 | end 99 | if mark_bottom 100 | inside_bottom = inside_bottom + 1; 101 | end 102 | end 103 | 104 | if mark_left 105 | for i = inside_top:inside_bottom 106 | if mask(i, inside_left) == 0 107 | mark_left = false; 108 | end 109 | end 110 | if mark_left 111 | raw_inside_left = raw_inside_left - proportion; 112 | inside_left = round(raw_inside_left); 113 | end 114 | end 115 | 116 | if mark_right 117 | for i = inside_top:inside_bottom 118 | if mask(i, inside_right) == 0 119 | mark_right = false; 120 | end 121 | end 122 | if mark_right 123 | raw_inside_right = raw_inside_right + proportion; 124 | inside_right = round(raw_inside_right); 125 | end 126 | end 127 | end 128 | 129 | top = inside_top; 130 | bottom = inside_bottom; 131 | left = inside_left; 132 | right = inside_right; 133 | 134 | end -------------------------------------------------------------------------------- /data_preparation/get_mask_map.m: -------------------------------------------------------------------------------- 1 | function roi = get_mask_map(roi_data, mode) 2 | %if is_quarter is true, will get quarter size ROI (same size of density map) 3 | is_fourth = false; 4 | is_eighth = false; 5 | 6 | if nargin < 2 7 | is_fourth = false; 8 | is_eighth = false; 9 | else 10 | if strcmp(mode, 'fourth') 11 | is_fourth = true; 12 | elseif strcmp(mode, 'eighth') 13 | is_eighth = true; 14 | end 15 | end 16 | 17 | X = roi_data.width; 18 | Y = roi_data.height; 19 | 20 | if is_fourth 21 | X = ceil(X / 4); 22 | Y = ceil(Y / 4); 23 | roi_data.x = roi_data.x / 4; 24 | roi_data.y = roi_data.y / 4; 25 | end 26 | 27 | if is_eighth 28 | X = ceil(X / 8); 29 | Y = ceil(Y / 8); 30 | roi_data.x = roi_data.x / 8; 31 | roi_data.y = roi_data.y / 8; 32 | end 33 | 34 | roi.mask = false(Y, X); 35 | roi.matrix= zeros(Y, X); 36 | for y_ = 1:Y 37 | for x_ = 1:X 38 | % IN = inpolygon(x_, y_, roi_data.x, roi_data.y); 39 | IN = inpolygon(y_, x_, roi_data.y, roi_data.x); 40 | roi.mask(y_, x_) = (IN==1); 41 | roi.matrix(y_, x_) = IN; 42 | end 43 | end 44 | %view_density_map(ma); 45 | end -------------------------------------------------------------------------------- /data_preparation/get_mask_map_airport.m: -------------------------------------------------------------------------------- 1 | function roi = get_mask_map_airport(roi_data, mode) 2 | %if mode is 'fourth', will get fourth size ROI (same size of density map) 3 | is_fourth = false; 4 | is_eighth = false; 5 | 6 | if nargin < 2 7 | is_fourth = false; 8 | is_eighth = false; 9 | else 10 | if strcmp(mode, 'fourth') 11 | is_fourth = true; 12 | elseif strcmp(mode, 'eighth') 13 | is_eighth = true; 14 | end 15 | end 16 | 17 | X = 960; 18 | Y = 540; 19 | 20 | if is_fourth 21 | X = X / 4; 22 | Y = Y / 4; 23 | roi_data.x = roi_data.x / 4; 24 | roi_data.y = roi_data.y / 4; 25 | end 26 | 27 | if is_eighth 28 | X = X / 8; 29 | Y = Y / 8; 30 | roi_data.x = roi_data.x / 8; 31 | roi_data.y = roi_data.y / 8; 32 | end 33 | 34 | roi.mask = false(Y, X); 35 | roi.matrix = zeros(Y, X); 36 | for y_ = 1:Y 37 | for x_ = 1:X 38 | IN = inpolygon(x_, y_, roi_data.x, roi_data.y); 39 | roi.mask(y_,x_) = (IN==1); 40 | roi.matrix(y_, x_) = IN; 41 | end 42 | end 43 | %view_density_map(roi.matrix); 44 | end -------------------------------------------------------------------------------- /data_preparation/get_mask_map_ucsd.m: -------------------------------------------------------------------------------- 1 | function roi = get_mask_map_ucsd(roi_data, mode) 2 | %if mode is 'fourth', will get fourth size ROI (same size of density map) 3 | is_fourth = false; 4 | is_eighth = false; 5 | 6 | if nargin < 2 7 | is_fourth = false; 8 | is_eighth = false; 9 | else 10 | if strcmp(mode, 'fourth') 11 | is_fourth = true; 12 | elseif strcmp(mode, 'eighth') 13 | is_eighth = true; 14 | end 15 | end 16 | 17 | X = 238; 18 | Y = 158; 19 | 20 | if is_fourth 21 | X = ceil(X / 4); 22 | Y = ceil(Y / 4); 23 | roi_data.x = roi_data.x / 4; 24 | roi_data.y = roi_data.y / 4; 25 | end 26 | 27 | if is_eighth 28 | X = ceil(X / 8); 29 | Y = ceil(Y / 8); 30 | roi_data.x = roi_data.x / 8; 31 | roi_data.y = roi_data.y / 8; 32 | end 33 | 34 | roi.mask = false(Y, X); 35 | roi.matrix = zeros(Y, X); 36 | for y_ = 1:Y 37 | for x_ = 1:X 38 | IN = inpolygon(x_, y_, roi_data.x, roi_data.y); 39 | roi.mask(y_,x_) = (IN==1); 40 | roi.matrix(y_, x_) = IN; 41 | end 42 | end 43 | %view_density_map(roi.matrix); 44 | end -------------------------------------------------------------------------------- /data_preparation/get_mask_map_worldexpo.m: -------------------------------------------------------------------------------- 1 | function roi = get_mask_map_worldexpo(roi_data, mode) 2 | %if is_quarter is true, will get quarter size ROI (same size of density map) 3 | is_fourth = false; 4 | is_eighth = false; 5 | 6 | if nargin < 2 7 | is_fourth = false; 8 | is_eighth = false; 9 | else 10 | if strcmp(mode, 'fourth') 11 | is_fourth = true; 12 | elseif strcmp(mode, 'eighth') 13 | is_eighth = true; 14 | end 15 | end 16 | 17 | %104207 18 | %200608 19 | %200702 20 | %202201 21 | %500717 22 | %roi_path = 'D:\Dataset\WorldExpo10\test_label\500717\roi.mat'; 23 | 24 | X = 720; 25 | Y = 576; 26 | 27 | if is_fourth 28 | X = X / 4; 29 | Y = Y / 4; 30 | roi_data.x = roi_data.x / 4; 31 | roi_data.y = roi_data.y / 4; 32 | end 33 | 34 | if is_eighth 35 | X = X / 8; 36 | Y = Y / 8; 37 | roi_data.x = roi_data.x / 8; 38 | roi_data.y = roi_data.y / 8; 39 | end 40 | 41 | roi.mask = false(Y, X); 42 | roi.matrix= zeros(Y, X); 43 | for y_ = 1:Y 44 | for x_ = 1:X 45 | IN = inpolygon(x_, y_, roi_data.x, roi_data.y); 46 | roi.mask(y_,x_) = (IN==1); 47 | roi.matrix(y_, x_) = IN; 48 | end 49 | end 50 | %view_density_map(ma); 51 | end -------------------------------------------------------------------------------- /data_preparation/match_perspective_map.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear all; 3 | 4 | % target_perspective_path = 'D:\Dataset\WorldExpo10\original\test_perspective\104207.mat'; 5 | % target_perspective_path = 'D:\Dataset\WorldExpo10\original\test_perspective\200608.mat'; 6 | % target_perspective_path = 'D:\Dataset\WorldExpo10\original\test_perspective\200702.mat'; 7 | % target_perspective_path = 'D:\Dataset\WorldExpo10\original\test_perspective\202201.mat'; 8 | target_perspective_path = 'D:\Dataset\WorldExpo10\original\test_perspective\500717.mat'; 9 | 10 | train_perspective_path = 'D:\Dataset\WorldExpo10\original\train_perspective'; 11 | train_perspective_dir = dir(fullfile(train_perspective_path,'*.mat')); 12 | train_perspective_name_list = {train_perspective_dir.name}; 13 | 14 | load(target_perspective_path); 15 | target_perspective_map = pMap; 16 | 17 | absolute_error_list = zeros(103, 1); 18 | 19 | for i = 1:103 20 | load([train_perspective_path '\' train_perspective_name_list{i}]); 21 | train_perspective_map = pMap; 22 | absolute_error_map = abs(target_perspective_map - train_perspective_map); 23 | absolute_error = sum(absolute_error_map(:)); 24 | absolute_error_list(i) = absolute_error; 25 | end 26 | 27 | [~, sort_idx] = sort(absolute_error_list, 'ascend'); 28 | for i = 1:4 29 | k = sort_idx(i); 30 | disp(train_perspective_name_list{k}); 31 | disp(absolute_error_list(k)); 32 | end -------------------------------------------------------------------------------- /data_preparation/view_density_map.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laridzhang/ASNet/059a963fb286e5d03633951f98d703d84e925ef1/data_preparation/view_density_map.m -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laridzhang/ASNet/059a963fb286e5d03633951f98d703d84e925ef1/src/__init__.py -------------------------------------------------------------------------------- /src/computing_bar.py: -------------------------------------------------------------------------------- 1 | import progressbar 2 | import sys 3 | 4 | 5 | class Bar(): 6 | def __init__(self, title='computing', max_value=100): 7 | self.count = 0 8 | widgets = [title, ': ', progressbar.Percentage(), ' ', progressbar.Bar(marker='#', left='[', right=']'), ' ', progressbar.ETA(), ] 9 | self.bar = progressbar.ProgressBar(widgets=widgets, max_value=max_value).start() 10 | 11 | def update(self): 12 | self.count += 1 13 | self.bar.update(self.count) 14 | sys.stdout.flush() 15 | 16 | def done(self): 17 | self.bar.finish() 18 | sys.stdout.flush() -------------------------------------------------------------------------------- /src/crowd_count.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | import torch.nn as nn 4 | # import numpy as np 5 | import torch.nn.functional as functional 6 | from functools import reduce 7 | 8 | # from src import network 9 | from src.models import Model 10 | from src.utils import build_block 11 | 12 | 13 | # only calculate with pixels which have value 14 | # pool size=stride=4 or 2 15 | # unless specified, the default Gaussian kernel size is 15 and sigma is 4 16 | # the first threshold is the average density of pooling ground truth 17 | # the next threshold is the average pooling density of the area remaining after excluding areas with a density below the previous threshold 18 | # different pooling density maps are inconsistent in size 19 | dataset_density_level = dict() 20 | dataset_density_level['shtA1_train_4_2'] = (3.206821266542162, 1.7318443746056567) 21 | dataset_density_level['shtA1_train_8_4'] = (8.464520906327428, 4.9102145881524950) 22 | 23 | 24 | class CrowdCount(nn.Module): 25 | def __init__(self): 26 | super(CrowdCount, self).__init__() 27 | self.features = Model() 28 | self.my_loss = None 29 | self.this_dataset_density_level = dataset_density_level['shtA1_train_8_4'] 30 | 31 | @property 32 | def loss(self): 33 | return self.my_loss 34 | 35 | def forward(self, im_data, roi, ground_truth=None): 36 | estimate_map, foreground_mask, visual_dict = self.features(im_data.cuda(), roi.cuda()) 37 | 38 | if self.training: 39 | self.my_loss, loss_dict = self.build_loss(ground_truth.cuda(), estimate_map, foreground_mask) 40 | else: 41 | loss_dict = None 42 | 43 | return estimate_map, loss_dict, visual_dict 44 | 45 | def build_loss(self, ground_truth_map, estimate_map, foreground_mask): 46 | if ground_truth_map.shape != estimate_map.shape: 47 | raise Exception('shapes of ground_truth_map and estimate_map are mismatch') 48 | if ground_truth_map.shape != foreground_mask.shape: 49 | raise Exception('shapes of ground_truth_map and foreground_mask are mismatch') 50 | 51 | ground_truth_map = ground_truth_map * foreground_mask 52 | estimate_map = estimate_map * foreground_mask 53 | 54 | pool8_loss_map = self.pooling_loss_map(ground_truth_map, estimate_map, 8) 55 | pool4_loss_map = self.pooling_loss_map(ground_truth_map, estimate_map, 4) 56 | 57 | foreground_active_for_pool8 = functional.interpolate(foreground_mask, scale_factor=1 / 8.0, mode='nearest') 58 | foreground_active_for_pool4 = functional.interpolate(foreground_mask, scale_factor=1 / 4.0, mode='nearest') 59 | 60 | pool8_deactive = build_block(ground_truth_map, 8) 61 | pool8_deactive[pool8_deactive < self.this_dataset_density_level[0]] = 0.0 62 | pool8_deactive[pool8_deactive > 0] = 1.0 63 | pool8_active = 1 - pool8_deactive 64 | 65 | pool8_deactive_for_pool4 = functional.interpolate(pool8_deactive, scale_factor=2.0, mode='nearest') 66 | pool4_active = torch.ones_like(pool4_loss_map) 67 | if pool8_deactive_for_pool4.shape != pool4_active.shape: 68 | raise Exception('active map mismatch') 69 | pool4_active = pool4_active * pool8_deactive_for_pool4 70 | 71 | pool8_active = pool8_active * foreground_active_for_pool8 72 | pool4_active = pool4_active * foreground_active_for_pool4 73 | 74 | pool8_active_sum = torch.sum(pool8_active) 75 | pool4_active_sum = torch.sum(pool4_active) 76 | 77 | pool8_loss = torch.sum(pool8_loss_map * pool8_active) / pool8_active_sum if pool8_active_sum > 0 else torch.sum(pool8_loss_map * pool8_active) 78 | pool4_loss = torch.sum(pool4_loss_map * pool4_active) / pool4_active_sum if pool4_active_sum > 0 else torch.sum(pool4_loss_map * pool4_active) 79 | total_loss = pool8_loss * 4 + pool4_loss 80 | 81 | loss_dict = dict() 82 | loss_dict['pool8'] = pool8_loss 83 | loss_dict['pool4'] = pool4_loss 84 | loss_dict['total'] = total_loss 85 | return total_loss, loss_dict 86 | 87 | @staticmethod 88 | def pooling_loss_map(ground_truth, estimate, block_size=4): 89 | square_error = (ground_truth - estimate) ** 2 90 | element_amount = reduce(lambda x, y: x * y, square_error.shape) 91 | block_square_error = build_block(square_error / element_amount, block_size) 92 | block_ground_truth = build_block(ground_truth, block_size) 93 | block_loss = block_square_error / (block_ground_truth + 1) 94 | return block_loss 95 | -------------------------------------------------------------------------------- /src/data_multithread.py: -------------------------------------------------------------------------------- 1 | # load data in multithreading 2 | 3 | import numpy as np 4 | import os 5 | import cv2 6 | import random 7 | import pandas 8 | import scipy.io as scio 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils.data import Dataset, DataLoader 12 | import psutil 13 | import warnings 14 | from os.path import join 15 | 16 | from src.utils import ndarray_to_tensor 17 | from src.data_path import DataPath 18 | 19 | 20 | DOWNSAMPLE = 8 # The output of the network needs to be downsampled to one of DOWNSAMPLE 21 | NUMBER_OF_LABELS = 2 22 | 23 | 24 | class Data(Dataset): 25 | def __init__(self, dataset_name, image_path, density_map_path, roi_path=None, is_label=False, is_mask=False): 26 | # image_path: path of all image file 27 | # density_map_path: path of all density map file 28 | # roi_path: path of all region of interest file 29 | self.dataset_name = dataset_name 30 | self.image_path = image_path 31 | self.density_map_path = density_map_path 32 | self.roi_path = roi_path 33 | self.is_label = is_label 34 | self.is_mask = is_mask 35 | 36 | self.image_channel = 3 37 | 38 | # get all image file name 39 | self.image_filename_list = [filename for filename in os.listdir(self.image_path) if os.path.isfile(join(self.image_path, filename))] 40 | self.image_filename_list.sort() 41 | 42 | self.number_of_samples = len(self.image_filename_list) 43 | 44 | def __len__(self): 45 | return self.number_of_samples 46 | 47 | def __getitem__(self, index): 48 | this_blob = self.read_blob(self.image_filename_list[index]) 49 | return this_blob 50 | 51 | def read_blob(self, filename): 52 | image_name, _ = os.path.splitext(filename) 53 | blob = dict() 54 | blob['image_name'] = image_name 55 | 56 | if self.image_channel == 1: 57 | image = cv2.imread(join(self.image_path, filename), 0) 58 | elif self.image_channel == 3: 59 | image = cv2.imread(join(self.image_path, filename), 1) 60 | else: 61 | raise Exception('invalid number of image channels') 62 | 63 | density_map = pandas.read_csv(join(self.density_map_path, image_name + '.csv'), sep=',', header=None).values 64 | 65 | if image.shape[0] != density_map.shape[0] or image.shape[1] != density_map.shape[1]: 66 | raise Exception('density map size mismatch.') 67 | 68 | density_map = self.downsample(density_map, DOWNSAMPLE) 69 | 70 | if self.roi_path is not None: 71 | if DOWNSAMPLE == 1: 72 | roi = self.load_roi(join(self.roi_path, image_name + '_roi.mat')) 73 | elif DOWNSAMPLE == 4: 74 | roi = self.load_roi(join(self.roi_path, image_name + '_roi_fourth_size.mat')) 75 | elif DOWNSAMPLE == 8: 76 | roi = self.load_roi(join(self.roi_path, image_name + '_roi_eighth_size.mat')) 77 | else: 78 | raise Exception('no suitable RoI file available') 79 | else: 80 | roi = None 81 | 82 | image = self.reshape_data(image) 83 | density_map = self.reshape_data(density_map) 84 | if roi is not None: 85 | roi = self.reshape_data(roi) 86 | if roi.shape[2] != density_map.shape[2] or roi.shape[3] != density_map.shape[3]: 87 | raise Exception('RoI size mismatch') 88 | 89 | blob['image'] = ndarray_to_tensor(image, is_cuda=False) 90 | blob['density'] = ndarray_to_tensor(density_map, is_cuda=False) 91 | if roi is not None: 92 | blob['roi'] = ndarray_to_tensor(roi, is_cuda=False) 93 | else: 94 | blob['roi'] = ndarray_to_tensor(np.ones_like(density_map), is_cuda=False) 95 | 96 | if self.is_label: 97 | blob['label'] = ndarray_to_tensor(self.get_label(blob['density']), is_cuda=False) 98 | 99 | if self.is_mask: 100 | blob['mask'] = ndarray_to_tensor(self.get_mask(blob['density']), is_cuda=False) 101 | return blob 102 | 103 | def compute_label(self, count): 104 | if count == 0: 105 | return 0 106 | # label = int(min(max(np.floor(np.log2(density * 3200 / DOWNSAMPLE ** 2)), 0), self.number_of_labels - 1)) 107 | # label = int(min(max(np.floor(np.log2(count / 10)), 0), self.number_of_labels - 1)) 108 | # return label 109 | else: 110 | return 1 111 | 112 | def get_label(self, density_map): 113 | # density_map torch.Tensor shape=(1, 1, h, w) 114 | if density_map.shape[0] != 1: 115 | raise Exception('invalid density map shape') 116 | # average_density = torch.mean(density_map) 117 | count = torch.sum(density_map) 118 | label = np.zeros(NUMBER_OF_LABELS, dtype=np.int) 119 | label[self.compute_label(count)] = 1 120 | return label 121 | 122 | def get_mask(self, ground_truth_map): 123 | # ground_truth numpy.ndarray shape=(1, 1, h, w) 124 | raise Exception('mask is not supported yet') 125 | 126 | def get_label_weights(self): 127 | return self.label_weights 128 | 129 | def get_number_of_samples(self): 130 | return self.number_of_samples 131 | 132 | @staticmethod 133 | def downsample(density_map, downsample_value=1): 134 | # height and width of output density map are about 1/[downsample_value] times that of original density map 135 | import torch 136 | import torch.nn.functional as functional 137 | 138 | if density_map.shape[0] % downsample_value != 0 or density_map.shape[1] % downsample_value != 0: 139 | raise Exception('density map size is not suitable for downsample value') 140 | 141 | density_map = density_map.reshape((1, 1, density_map.shape[0], density_map.shape[1])) 142 | if downsample_value > 1: 143 | density_map_tensor = torch.tensor(density_map, dtype=torch.float32) 144 | density_map_tensor = functional.avg_pool2d(density_map_tensor, downsample_value, stride=downsample_value) 145 | density_map = density_map_tensor.data.cpu().numpy() 146 | density_map = density_map * downsample_value * downsample_value 147 | density_map = density_map.reshape((density_map.shape[2], density_map.shape[3])) 148 | 149 | return density_map 150 | 151 | @staticmethod 152 | def load_roi(path): 153 | roi_mat = scio.loadmat(path) 154 | roi = roi_mat['roi'] 155 | raw_mask = roi['mask'] 156 | mask = raw_mask[0, 0] 157 | mask = mask.astype(np.float32, copy=False) 158 | return mask 159 | 160 | @staticmethod 161 | def reshape_data(data): 162 | # data numpy.ndarray shape=(x, y) or (x, y, 3) 163 | # return numpy.ndarray shape=(1, x, y) or (3, x, y) 164 | data = data.astype(np.float32, copy=False) 165 | height = data.shape[0] 166 | width = data.shape[1] 167 | if len(data.shape) == 3 and data.shape[2] == 3: 168 | # image_r = data[:, :, 0] 169 | # image_g = data[:, :, 1] 170 | # image_b = data[:, :, 2] 171 | # image = np.zeros((3, height, width), dtype=np.float32) 172 | # image[0] = image_r 173 | # image[1] = image_g 174 | # image[2] = image_b 175 | data = np.moveaxis(data, 2, 0) 176 | reshaped_data = data.reshape((3, height, width)) 177 | elif len(data.shape) == 2: 178 | reshaped_data = data.reshape((1, height, width)) 179 | else: 180 | raise Exception('Invalid data shape.') 181 | 182 | return reshaped_data 183 | 184 | 185 | def multithread_dataloader(data_config): 186 | # data_config: dict, a dictionay contains several datasets info, 187 | # key is dataset name, 188 | # value is a dict which contains is_preload and is_label and is_mask 189 | data_path = DataPath() 190 | 191 | data_dict = dict() 192 | 193 | for name in data_config: 194 | this_dataset_flag = data_config[name] 195 | if 'label' in this_dataset_flag: 196 | is_label = this_dataset_flag['label'] 197 | else: 198 | is_label = False 199 | if 'mask' in this_dataset_flag: 200 | is_mask = this_dataset_flag['mask'] 201 | else: 202 | is_mask = False 203 | if 'shuffle' in this_dataset_flag: 204 | is_shuffle = this_dataset_flag['shuffle'] 205 | else: 206 | is_shuffle = False 207 | if 'seed' in this_dataset_flag: 208 | random_seed = this_dataset_flag['seed'] 209 | else: 210 | random_seed = None 211 | if 'batch_size' in this_dataset_flag: 212 | batch_size = this_dataset_flag['batch_size'] 213 | else: 214 | batch_size = 1 215 | 216 | if random_seed is not None: 217 | def worker_init_fn(x): 218 | seed = random_seed + x 219 | np.random.seed(seed) 220 | random.seed(seed) 221 | torch.manual_seed(seed) 222 | return 223 | else: 224 | worker_init_fn = None 225 | 226 | path = data_path.get_path(name) 227 | this_data = Data(name, path['image'], path['gt'], roi_path=path['roi'], is_label=is_label, is_mask=is_mask) 228 | this_dataloader = DataLoader(this_data, batch_size=batch_size, shuffle=is_shuffle, num_workers=8, drop_last=False, worker_init_fn=worker_init_fn) 229 | 230 | if is_label: 231 | label_histogram = np.zeros(NUMBER_OF_LABELS) 232 | index = 0 233 | for blob in this_dataloader: 234 | label_histogram[torch.argmax(blob['label'])] += 1 235 | index += 1 236 | if index % 100 == 0: 237 | print('Built %6d of %d labels.' % (index, this_data.get_number_of_samples())) 238 | 239 | print('Completed building %d labels. Label histogram is %s' % (index, ' '.join([str(i) for i in label_histogram]))) 240 | label_weights = 1 - label_histogram / sum(label_histogram) 241 | label_weights = label_weights / sum(label_weights) 242 | else: 243 | label_weights = None 244 | 245 | this_dataset_dict = dict() 246 | this_dataset_dict['data'] = this_dataloader 247 | if is_label: 248 | this_dataset_dict['label_weights'] = ndarray_to_tensor(label_weights, is_cuda=False) 249 | 250 | data_dict[name] = this_dataset_dict 251 | 252 | return data_dict 253 | -------------------------------------------------------------------------------- /src/data_multithread_preload.py: -------------------------------------------------------------------------------- 1 | # load data in multithreading 2 | 3 | import numpy as np 4 | import os 5 | import cv2 6 | import random 7 | import pandas 8 | import scipy.io as scio 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as functional 12 | from torch.utils.data import Dataset, DataLoader 13 | import torchvision 14 | import psutil 15 | # import warnings 16 | from os.path import join 17 | import pickle 18 | import string 19 | # import shutil 20 | import datetime 21 | 22 | from src.utils import ndarray_to_tensor, print_red, make_path 23 | from src.data_path import DataPath 24 | 25 | 26 | DOWNSAMPLE = 8 # The output of the network needs to be downsampled to one of DOWNSAMPLE 27 | NUMBER_OF_LABELS = 3 28 | 29 | # only calculate with pixels which have value 30 | # pool size=8, stride=1, average of average density of pooled area 31 | # unless specified, the default Gaussian kernel size is 15 and sigma is 4 32 | dataset_average_density = dict() 33 | dataset_average_density['shtA1_train'] = 0.093094 34 | # dataset_average_density['shtB1_train'] = 0.033040 35 | dataset_average_density['shtB1_train'] = 0.033038 # gaussian kernel 15 sigma 15 36 | dataset_average_density['ucfQnrf1Resize1024_train'] = 0.133973 37 | dataset_average_density['ucf1_train1'] = 0.181556 38 | dataset_average_density['ucf1_train2'] = 0.186568 39 | dataset_average_density['ucf1_train3'] = 0.171988 40 | dataset_average_density['ucf1_train4'] = 0.170218 41 | dataset_average_density['ucf1_train5'] = 0.155635 42 | dataset_average_density['we1_train'] = 0.028217 43 | 44 | 45 | class PreloadData: 46 | def __init__(self, image_path, density_map_path, roi_path=None, is_preload=False, is_label=False, is_mask=False, is_transfrom=False, is_transform_in_gray=False): 47 | # image_path: path of all image file 48 | # density_map_path: path of all density map file 49 | # roi_path: path of all region of interest file 50 | self.image_path = image_path 51 | self.density_map_path = density_map_path 52 | self.roi_path = roi_path 53 | self.is_preload = is_preload 54 | self.is_label = is_label 55 | self.is_mask = is_mask 56 | self.is_transform = is_transfrom 57 | self.is_transform_in_gray = is_transform_in_gray 58 | 59 | self.image2tensor = torchvision.transforms.ToTensor() 60 | 61 | if self.is_transform: 62 | self.image2pil = torchvision.transforms.ToPILImage() 63 | self.color_jitter = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5) 64 | if self.is_transform_in_gray: 65 | self.image2gray = torchvision.transforms.Grayscale(num_output_channels=1) 66 | self.image2grayRGB = torchvision.transforms.Grayscale(num_output_channels=3) 67 | 68 | self.min_available_memory = 8 * 1024 * 1024 * 1024 # GB 69 | 70 | # make path for pickle 71 | time_now = datetime.datetime.now() 72 | self.pickle_path = os.path.join(r'/home/antec/PycharmProjects/pickle', '%4d%02d%02d%02d%02d%02d%06d_%s' % 73 | (time_now.year, time_now.month, time_now.day, time_now.hour, time_now.minute, time_now.second, time_now.microsecond, ''.join(random.sample(string.ascii_letters, 4)))) 74 | make_path(self.pickle_path) 75 | 76 | # get all image file name 77 | self.image_filename_list = [filename for filename in os.listdir(self.image_path) if os.path.isfile(join(self.image_path, filename))] 78 | self.image_filename_list.sort() 79 | 80 | self.number_of_samples = len(self.image_filename_list) 81 | 82 | self.preload_data_dict = dict() # store all preload data in this dict 83 | 84 | index = 0 85 | for filename in self.image_filename_list: 86 | if self.is_preload: 87 | if psutil.virtual_memory().available > self.min_available_memory: 88 | index += 1 89 | self.preload_data_dict[filename] = self.read_blob(filename) 90 | if index % 100 == 0: 91 | print('Loaded %6d of %d files.' % (index, self.number_of_samples)) 92 | else: 93 | self.preload_data_dict[filename] = None 94 | 95 | else: 96 | self.preload_data_dict[filename] = None 97 | print('Completed loading %d files. %d files are preloaded.' % (self.number_of_samples, index)) 98 | 99 | return 100 | 101 | def get_number_of_samples(self): 102 | return self.number_of_samples 103 | 104 | def get_blob_by_index(self, index): 105 | filename = self.image_filename_list[index] 106 | this_blob = self.preload_data_dict[filename] 107 | 108 | if this_blob is None: # no data is preloaded for this blob 109 | pickle_file_path = os.path.join(self.pickle_path, filename + '.pickle') 110 | if os.path.isfile(pickle_file_path): 111 | with open(pickle_file_path, 'rb') as file: 112 | this_blob = pickle.load(file) 113 | else: 114 | this_blob = self.read_blob(filename) 115 | with open(pickle_file_path, 'wb') as file: 116 | pickle.dump(this_blob, file) 117 | 118 | # transform image 119 | if self.is_transform: 120 | image = this_blob['image'] 121 | image = self.image2pil(image) 122 | # image.show() 123 | if self.is_transform_in_gray: 124 | image = self.image2gray(image) 125 | image = self.color_jitter(image) 126 | image = self.image2grayRGB(image) 127 | else: 128 | image = self.color_jitter(image) 129 | # image.show() 130 | image = self.image2tensor(image) 131 | 132 | this_blob['image'] = image 133 | 134 | return this_blob 135 | 136 | def read_blob(self, filename): 137 | image_name, _ = os.path.splitext(filename) 138 | blob = dict() 139 | blob['image_name'] = image_name 140 | 141 | # read image 142 | image = cv2.imread(join(self.image_path, filename), 1) 143 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 144 | 145 | density_map = pandas.read_csv(join(self.density_map_path, image_name + '.csv'), sep=',', header=None).values 146 | 147 | if image.shape[0] != density_map.shape[0] or image.shape[1] != density_map.shape[1]: 148 | raise Exception('density map size mismatch.') 149 | 150 | density_map = self.downsample(density_map, DOWNSAMPLE) 151 | 152 | if self.roi_path is not None: 153 | # if DOWNSAMPLE == 1: 154 | # roi = self.load_roi(join(self.roi_path, image_name + '_roi.mat')) 155 | # elif DOWNSAMPLE == 4: 156 | # roi = self.load_roi(join(self.roi_path, image_name + '_roi_fourth_size.mat')) 157 | # elif DOWNSAMPLE == 8: 158 | # roi = self.load_roi(join(self.roi_path, image_name + '_roi_eighth_size.mat')) 159 | # else: 160 | # raise Exception('no suitable RoI file available') 161 | roi = self.load_roi(join(self.roi_path, image_name + '_roi.mat')) 162 | else: 163 | roi = None 164 | 165 | # image = self.reshape_data(image) 166 | image = self.image2tensor(image) 167 | density_map = self.reshape_data(density_map) 168 | if roi is not None: 169 | roi = self.reshape_data(roi) 170 | if roi.shape[1] != image.shape[1] or roi.shape[2] != image.shape[2]: 171 | raise Exception('RoI size mismatch') 172 | else: 173 | roi = np.ones((1, image.shape[1], image.shape[2])) 174 | 175 | if isinstance(image, torch.Tensor): 176 | blob['image'] = image 177 | else: 178 | blob['image'] = ndarray_to_tensor(image, is_cuda=False) 179 | blob['density'] = ndarray_to_tensor(density_map, is_cuda=False) 180 | blob['roi'] = ndarray_to_tensor(roi, is_cuda=False) 181 | 182 | if self.is_label: 183 | blob['label'] = ndarray_to_tensor(self.get_label(blob['density']), is_cuda=False) 184 | 185 | if self.is_mask: 186 | blob['mask'] = ndarray_to_tensor(self.get_mask(blob['density'], blob['roi']), is_cuda=False) 187 | 188 | return blob 189 | 190 | def compute_label(self, count): 191 | if count == 0: 192 | return 0 193 | # label = int(min(max(np.floor(np.log2(density * 3200 / DOWNSAMPLE ** 2)), 0), NUMBER_OF_LABELS - 1)) 194 | label = int(min(max(np.floor(np.log2(count / 10)), 1), NUMBER_OF_LABELS - 1)) 195 | return label 196 | 197 | def get_label(self, density_map): 198 | # density_map torch.Tensor shape=(1, 1, h, w) 199 | if density_map.shape[0] != 1: 200 | raise Exception('invalid density map shape') 201 | # average_density = torch.mean(density_map) 202 | count = torch.sum(density_map) 203 | label = np.zeros(NUMBER_OF_LABELS, dtype=np.int) 204 | label[self.compute_label(count)] = 1 205 | return label 206 | 207 | def get_mask(self, ground_truth_map, roi, pool_size=8, bins=(0, 1.00 * dataset_average_density['ucf1_train1'])): 208 | # ground_truth numpy.ndarray shape=(1, 1, h, w) 209 | if len(bins) != NUMBER_OF_LABELS - 1 or len(bins) <= 1: 210 | raise Exception('invalid bins (%s)' % ', '.join([str(i) for i in bins])) 211 | 212 | if pool_size % 2 == 0: 213 | pad_size = (pool_size / 2, pool_size / 2 - 1, pool_size / 2, pool_size / 2 - 1) 214 | pad_size = *(int(i) for i in pad_size), 215 | else: 216 | pad_size = int((pool_size - 1) / 2) 217 | 218 | m = nn.Sequential(nn.ZeroPad2d(pad_size), 219 | nn.AvgPool2d(pool_size, stride=1, padding=0, count_include_pad=False)) 220 | pooled_map = m(ground_truth_map) 221 | 222 | if pooled_map.shape != ground_truth_map.shape: 223 | raise Exception('pooled map and ground truth map mismatch') 224 | 225 | resized_roi = functional.interpolate(roi.unsqueeze(0), scale_factor=1 / DOWNSAMPLE, mode='nearest')[0] 226 | 227 | if resized_roi.shape != pooled_map.shape: 228 | raise Exception('resized roi and pooled map mismatch') 229 | 230 | pooled_map = pooled_map * resized_roi 231 | 232 | mask_list = list() 233 | old_mask = None 234 | for bin in bins: 235 | this_mask = (pooled_map <= bin).to(torch.int64) 236 | if old_mask is not None: 237 | this_mask = this_mask * (1 - old_mask) 238 | old_mask = old_mask + this_mask 239 | else: 240 | old_mask = this_mask 241 | mask_list.append(this_mask) 242 | mask_list.append(1 - old_mask) 243 | final_mask = torch.cat(mask_list) 244 | return final_mask.data.numpy() 245 | 246 | def get_label_weights(self): 247 | return self.label_weights 248 | 249 | @staticmethod 250 | def downsample(density_map, downsample_value=1): 251 | # height and width of output density map are about 1/[downsample_value] times that of original density map 252 | import torch 253 | import torch.nn.functional as functional 254 | 255 | if density_map.shape[0] % downsample_value != 0 or density_map.shape[1] % downsample_value != 0: 256 | raise Exception('density map size is not suitable for downsample value') 257 | 258 | density_map = density_map.reshape((1, 1, density_map.shape[0], density_map.shape[1])) 259 | if downsample_value > 1: 260 | density_map_tensor = torch.tensor(density_map, dtype=torch.float32) 261 | density_map_tensor = functional.avg_pool2d(density_map_tensor, downsample_value, stride=downsample_value) 262 | density_map = density_map_tensor.data.cpu().numpy() 263 | density_map = density_map * downsample_value * downsample_value 264 | density_map = density_map.reshape((density_map.shape[2], density_map.shape[3])) 265 | 266 | return density_map 267 | 268 | @staticmethod 269 | def load_roi(path): 270 | roi_mat = scio.loadmat(path) 271 | roi = roi_mat['roi'] 272 | raw_mask = roi['mask'] 273 | mask = raw_mask[0, 0] 274 | mask = mask.astype(np.float32, copy=False) 275 | return mask 276 | 277 | @staticmethod 278 | def reshape_data(data): 279 | # data numpy.ndarray shape=(x, y) or (x, y, 3) 280 | # return numpy.ndarray shape=(1, x, y) or (3, x, y) 281 | data = data.astype(np.float32, copy=False) 282 | height = data.shape[0] 283 | width = data.shape[1] 284 | if len(data.shape) == 3 and data.shape[2] == 3: 285 | # image_r = data[:, :, 0] 286 | # image_g = data[:, :, 1] 287 | # image_b = data[:, :, 2] 288 | # image = np.zeros((3, height, width), dtype=np.float32) 289 | # image[0] = image_r 290 | # image[1] = image_g 291 | # image[2] = image_b 292 | data = np.moveaxis(data, 2, 0) 293 | reshaped_data = data.reshape((3, height, width)) 294 | elif len(data.shape) == 2: 295 | reshaped_data = data.reshape((1, height, width)) 296 | else: 297 | raise Exception('Invalid data shape.') 298 | 299 | return reshaped_data 300 | 301 | 302 | class Data(Dataset): 303 | def __init__(self, preload_data): 304 | # image_path: path of all image file 305 | # density_map_path: path of all density map file 306 | # roi_path: path of all region of interest file 307 | self.preload_data = preload_data 308 | 309 | def __len__(self): 310 | return self.preload_data.get_number_of_samples() 311 | 312 | def __getitem__(self, index): 313 | return self.preload_data.get_blob_by_index(index) 314 | 315 | 316 | def multithread_dataloader(data_config): 317 | # data_config: dict, a dictionay contains several datasets info, 318 | # key is dataset name, 319 | # value is a dict which contains is_preload and is_label and is_mask 320 | data_path = DataPath() 321 | 322 | data_dict = dict() 323 | 324 | for name in data_config: 325 | this_dataset_flag = data_config[name] 326 | is_preload = this_dataset_flag['preload'] 327 | if 'label' in this_dataset_flag: 328 | is_label = this_dataset_flag['label'] 329 | else: 330 | is_label = False 331 | if 'mask' in this_dataset_flag: 332 | is_mask = this_dataset_flag['mask'] 333 | else: 334 | is_mask = False 335 | if 'shuffle' in this_dataset_flag: 336 | is_shuffle = this_dataset_flag['shuffle'] 337 | else: 338 | is_shuffle = False 339 | if 'seed' in this_dataset_flag: 340 | random_seed = this_dataset_flag['seed'] 341 | else: 342 | random_seed = None 343 | if 'batch_size' in this_dataset_flag: 344 | batch_size = this_dataset_flag['batch_size'] 345 | else: 346 | batch_size = 1 347 | if 'transform' in this_dataset_flag: 348 | is_transform = this_dataset_flag['transform'] 349 | if 'transform_in_gray' in this_dataset_flag: 350 | is_transform_in_gray = this_dataset_flag['transform_in_gray'] 351 | else: 352 | is_transform_in_gray = False 353 | else: 354 | is_transform = False 355 | is_transform_in_gray = False 356 | 357 | 358 | if random_seed is not None: 359 | def worker_init_fn(x): 360 | seed = random_seed + x 361 | np.random.seed(seed) 362 | random.seed(seed) 363 | torch.manual_seed(seed) 364 | return 365 | else: 366 | worker_init_fn = None 367 | 368 | path = data_path.get_path(name) 369 | preload_data = PreloadData(path['image'], path['gt'], roi_path=path['roi'], is_preload=is_preload, is_label=is_label, is_mask=is_mask, is_transfrom=is_transform, is_transform_in_gray=is_transform_in_gray) 370 | this_data = Data(preload_data) 371 | this_dataloader = DataLoader(this_data, batch_size=batch_size, shuffle=is_shuffle, num_workers=8, drop_last=False, worker_init_fn=worker_init_fn) 372 | 373 | if is_label: 374 | number_of_samples = preload_data.get_number_of_samples() 375 | label_histogram = np.zeros(NUMBER_OF_LABELS) 376 | index = 0 377 | for blob in this_dataloader: 378 | labels = torch.argmax(blob['label'], dim=1, keepdim=True) 379 | for this_label in labels: 380 | label_histogram[this_label] += 1 381 | index += 1 382 | if index % 100 == 0: 383 | print('Built %6d of %d labels.' % (index, number_of_samples)) 384 | 385 | print('Completed building %d labels. Label histogram is %s' % (index, ' '.join([str(i) for i in label_histogram]))) 386 | label_weights = 1 - label_histogram / sum(label_histogram) 387 | label_weights = label_weights / sum(label_weights) 388 | else: 389 | label_weights = None 390 | 391 | this_dataset_dict = dict() 392 | this_dataset_dict['data'] = this_dataloader 393 | if is_label: 394 | this_dataset_dict['label_weights'] = ndarray_to_tensor(label_weights, is_cuda=False) 395 | else: 396 | this_dataset_dict['label_weights'] = None 397 | 398 | data_dict[name] = this_dataset_dict 399 | 400 | return data_dict 401 | -------------------------------------------------------------------------------- /src/data_path.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class DataPath: 5 | def __init__(self): 6 | self.base_path_list = list() 7 | self.base_path_list.append(r'/home/antec/PycharmProjects/') 8 | self.base_path_list.append(r'/media/antec/storage/PycharmProjects') 9 | 10 | self.data_path = dict() 11 | 12 | path = dict() 13 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_64_64/train' 14 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_64_64/train_den' 15 | path['roi'] = None 16 | self.data_path['shtA0RandomOverturn_64_64_train'] = path 17 | 18 | path = dict() 19 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_128/train' 20 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_128/train_den' 21 | path['roi'] = None 22 | self.data_path['shtA0RandFlip_128_128_train'] = path 23 | 24 | path = dict() 25 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_128/train' 26 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_128/train_den' 27 | path['roi'] = None 28 | self.data_path['shtA0RandFlip_128_128_trainAsVali'] = path 29 | 30 | path = dict() 31 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_128_256/train' 32 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_128_256/train_den' 33 | path['roi'] = None 34 | self.data_path['shtA0RandomOverturn_128_256_train'] = path 35 | 36 | path = dict() 37 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_128_256_more_than_one_pedestrain/train' 38 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_overturn_rgb_128_256_more_than_one_pedestrain/train_den' 39 | path['roi'] = None 40 | self.data_path['shtA0RandomOverturn_128_256_more1_train'] = path 41 | 42 | path = dict() 43 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_256_more_than_ten_pedestrian/train' 44 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_0_random_flip_rgb_128_256_more_than_ten_pedestrian/train_den' 45 | path['roi'] = None 46 | self.data_path['shtA0RandFlip_128_256_more10_train'] = path 47 | 48 | path = dict() 49 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb/test' 50 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb/test_den' 51 | path['roi'] = None 52 | self.data_path['shtA1Resize1_test'] = path 53 | 54 | path = dict() 55 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb/train' 56 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb/train_den' 57 | path['roi'] = None 58 | self.data_path['shtA1Resize1_train'] = path 59 | 60 | path = dict() 61 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb_times32/test' 62 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_resize_1_rgb_times32/test_den' 63 | path['roi'] = None 64 | self.data_path['shtA1Resize1Times32_test'] = path 65 | 66 | path = dict() 67 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_rgb/train' 68 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_rgb/train_den' 69 | path['roi'] = None 70 | self.data_path['shtA1_train'] = path 71 | 72 | path = dict() 73 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_rgb/test' 74 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_1_rgb/test_den' 75 | path['roi'] = None 76 | self.data_path['shtA1_test'] = path 77 | 78 | path = dict() 79 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/train' 80 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/train_den' 81 | path['roi'] = None 82 | self.data_path['shtA9RandomOverturn_train'] = path 83 | 84 | path = dict() 85 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/train_without_validation' 86 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/train_without_validation_den' 87 | path['roi'] = None 88 | self.data_path['shtA9RandFlip_trainNoVali'] = path 89 | 90 | path = dict() 91 | path['image'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/validation' 92 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb/validation_den' 93 | path['roi'] = None 94 | self.data_path['shtA9RandFlip_vali'] = path 95 | 96 | path = dict() 97 | path['image'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb_times32/train' 98 | path['gt'] = 'data/shanghaitech/formatted_trainval_15_4/shanghaitech_part_A_patches_9_random_overturn_rgb_times32/train_den' 99 | path['roi'] = None 100 | self.data_path['shtA9RandomOverturnTimes32_train'] = path 101 | 102 | path = dict() 103 | path['image'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_0_random_flip_128_128_rgb/train' 104 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_0_random_flip_128_128_rgb/train_den' 105 | path['roi'] = None 106 | self.data_path['shtB0RandFlip_128_128_train'] = path 107 | 108 | path = dict() 109 | path['image'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_0_random_flip_more1pedestrian_128_128_rgb/train' 110 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_0_random_flip_more1pedestrian_128_128_rgb/train_den' 111 | path['roi'] = None 112 | self.data_path['shtB0RandFlipMore1_128_128_train'] = path 113 | 114 | path = dict() 115 | path['image'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_1_rgb/train' 116 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_1_rgb/train_den' 117 | path['roi'] = None 118 | self.data_path['shtB1_train'] = path 119 | 120 | path = dict() 121 | path['image'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_1_rgb/test' 122 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_1_rgb/test_den' 123 | path['roi'] = None 124 | self.data_path['shtB1_test'] = path 125 | 126 | path = dict() 127 | path['image'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_9_random_flip_rgb/train' 128 | path['gt'] = r'data/shanghaitech/formatted_trainval_15_15/shanghaitech_part_B_patches_9_random_flip_rgb/train_den' 129 | path['roi'] = None 130 | self.data_path['shtB9RandFlip_train'] = path 131 | 132 | path = dict() 133 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/1/train' 134 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/1/train_den' 135 | path['roi'] = None 136 | self.data_path['ucf0RandFlip_128_128_train1'] = path 137 | 138 | path = dict() 139 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/2/train' 140 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/2/train_den' 141 | path['roi'] = None 142 | self.data_path['ucf0RandFlip_128_128_train2'] = path 143 | 144 | path = dict() 145 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/3/train' 146 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/3/train_den' 147 | path['roi'] = None 148 | self.data_path['ucf0RandFlip_128_128_train3'] = path 149 | 150 | path = dict() 151 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/4/train' 152 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/4/train_den' 153 | path['roi'] = None 154 | self.data_path['ucf0RandFlip_128_128_train4'] = path 155 | 156 | path = dict() 157 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/5/train' 158 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_0_random_rgb_flip_128_128/5/train_den' 159 | path['roi'] = None 160 | self.data_path['ucf0RandFlip_128_128_train5'] = path 161 | 162 | path = dict() 163 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/1/train' 164 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/1/train_den' 165 | path['roi'] = None 166 | self.data_path['ucf1_train1'] = path 167 | 168 | path = dict() 169 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/2/train' 170 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/2/train_den' 171 | path['roi'] = None 172 | self.data_path['ucf1_train2'] = path 173 | 174 | path = dict() 175 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/3/train' 176 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/3/train_den' 177 | path['roi'] = None 178 | self.data_path['ucf1_train3'] = path 179 | 180 | path = dict() 181 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/4/train' 182 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/4/train_den' 183 | path['roi'] = None 184 | self.data_path['ucf1_train4'] = path 185 | 186 | path = dict() 187 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/5/train' 188 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/5/train_den' 189 | path['roi'] = None 190 | self.data_path['ucf1_train5'] = path 191 | 192 | path = dict() 193 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/1/val' 194 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/1/val_den' 195 | path['roi'] = None 196 | self.data_path['ucf1_test1'] = path 197 | 198 | path = dict() 199 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/2/val' 200 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/2/val_den' 201 | path['roi'] = None 202 | self.data_path['ucf1_test2'] = path 203 | 204 | path = dict() 205 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/3/val' 206 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/3/val_den' 207 | path['roi'] = None 208 | self.data_path['ucf1_test3'] = path 209 | 210 | path = dict() 211 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/4/val' 212 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/4/val_den' 213 | path['roi'] = None 214 | self.data_path['ucf1_test4'] = path 215 | 216 | path = dict() 217 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/5/val' 218 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_1_rgb/5/val_den' 219 | path['roi'] = None 220 | self.data_path['ucf1_test5'] = path 221 | 222 | path = dict() 223 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/1/train' 224 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/1/train_den' 225 | path['roi'] = None 226 | self.data_path['ucf9RandFlip_train1'] = path 227 | 228 | path = dict() 229 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/2/train' 230 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/2/train_den' 231 | path['roi'] = None 232 | self.data_path['ucf9RandFlip_train2'] = path 233 | 234 | path = dict() 235 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/3/train' 236 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/3/train_den' 237 | path['roi'] = None 238 | self.data_path['ucf9RandFlip_train3'] = path 239 | 240 | path = dict() 241 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/4/train' 242 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/4/train_den' 243 | path['roi'] = None 244 | self.data_path['ucf9RandFlip_train4'] = path 245 | 246 | path = dict() 247 | path['image'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/5/train' 248 | path['gt'] = r'data/ucf_cc_50/formatted_trainval_15_4/ucf_cc_50_patches_9_random_rgb_overturn/5/train_den' 249 | path['roi'] = None 250 | self.data_path['ucf9RandFlip_train5'] = path 251 | 252 | path = dict() 253 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb_overturn/train' 254 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb_overturn/train_den' 255 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb_overturn/train_roi' 256 | self.data_path['we1Flip_train'] = path 257 | 258 | path = dict() 259 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/train' 260 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/train_den' 261 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/train_roi' 262 | self.data_path['we1_train'] = path 263 | 264 | path = dict() 265 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test/1' 266 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_den/1' 267 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_roi/all' 268 | self.data_path['we1_test1'] = path 269 | 270 | path = dict() 271 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test/2' 272 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_den/2' 273 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_roi/all' 274 | self.data_path['we1_test2'] = path 275 | 276 | path = dict() 277 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test/3' 278 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_den/3' 279 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_roi/all' 280 | self.data_path['we1_test3'] = path 281 | 282 | path = dict() 283 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test/4' 284 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_den/4' 285 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_roi/all' 286 | self.data_path['we1_test4'] = path 287 | 288 | path = dict() 289 | path['image'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test/5' 290 | path['gt'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_den/5' 291 | path['roi'] = r'data/WorldExpo10/formatted_trainval_15_4/worldexpo_patches_1_rgb/test_roi/all' 292 | self.data_path['we1_test5'] = path 293 | 294 | path = dict() 295 | path['image'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_overturn_resize_1_rgb/train_all_val_overturn' 296 | path['gt'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_overturn_resize_1_rgb/train_all_val_overturn_den' 297 | path['roi'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_overturn_resize_1_rgb/train_all_val_overturn_roi' 298 | self.data_path['tran1FlipResize1_trainAllValiFlip'] = path 299 | 300 | path = dict() 301 | path['image'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/val' 302 | path['gt'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/val_den' 303 | path['roi'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/val_roi' 304 | self.data_path['tran1Resize1_Vali'] = path 305 | 306 | path = dict() 307 | path['image'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/test' 308 | path['gt'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/test_den' 309 | path['roi'] = r'data/trancos/formatted_trainval_15_10/trancos_patches_1_resize_1_rgb/test_roi' 310 | self.data_path['tran1Resize1_test'] = path 311 | 312 | path = dict() 313 | path['image'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/train' 314 | path['gt'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/train_den' 315 | path['roi'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/train_roi' 316 | self.data_path['mall1Resize05_train'] = path 317 | 318 | path = dict() 319 | path['image'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/val' 320 | path['gt'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/val_den' 321 | path['roi'] = r'data/mall/formatted_trainval_15_4/mall_patches_1_resize_05_rgb/val_roi' 322 | self.data_path['mall1Resize05_val'] = path 323 | 324 | path = dict() 325 | path['image'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/train' 326 | path['gt'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/train_den' 327 | path['roi'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/train_roi' 328 | self.data_path['air1_train'] = path 329 | 330 | path = dict() 331 | path['image'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/test' 332 | path['gt'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/test_den' 333 | path['roi'] = r'data/airport/formatted_trainval_15_4/airport_patches_1_rgb/test_roi' 334 | self.data_path['air1_test'] = path 335 | 336 | path = dict() 337 | path['image'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_0_random_flip_rgb_128_128_resize1024/train' 338 | path['gt'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_0_random_flip_rgb_128_128_resize1024/train_den' 339 | path['roi'] = None 340 | self.data_path['ucfQnrf0RandFlip_128_128_resize1024_train'] = path 341 | 342 | path = dict() 343 | path['image'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_0_random_flip_rgb_128_256_more1_resize1024/train' 344 | path['gt'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_0_random_flip_rgb_128_256_more1_resize1024/train_den' 345 | path['roi'] = None 346 | self.data_path['ucfQnrf0RandFlip_128_256_more1Resize1024_train'] = path 347 | 348 | path = dict() 349 | path['image'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_1_rgb_resize1024/train' 350 | path['gt'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_1_rgb_resize1024/train_den' 351 | path['roi'] = None 352 | self.data_path['ucfQnrf1Resize1024_train'] = path 353 | 354 | path = dict() 355 | path['image'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_1_rgb_resize1024/test' 356 | path['gt'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_1_rgb_resize1024/test_den' 357 | path['roi'] = None 358 | self.data_path['ucfQnrf1Resize1024_test'] = path 359 | 360 | path = dict() 361 | path['image'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_9_random_flip_rgb_resize1024/train' 362 | path['gt'] = r'data/ucf_qnrf/kernel_15_4/ucf_qnrf_patches_9_random_flip_rgb_resize1024/train_den' 363 | path['roi'] = None 364 | self.data_path['ucfQnrf9RandFlipResize1024_train'] = path 365 | 366 | def get_path(self, name): 367 | data_path_dict = self.data_path[name] 368 | abs_path_dict = dict() 369 | 370 | is_dir = False 371 | 372 | for base_path in self.base_path_list: 373 | for key in data_path_dict: 374 | if data_path_dict[key] is not None: 375 | this_abs_path = os.path.join(base_path, data_path_dict[key]) 376 | if os.path.isdir(this_abs_path): 377 | is_dir = True 378 | abs_path_dict[key] = this_abs_path 379 | else: 380 | break 381 | else: 382 | abs_path_dict[key] = None 383 | 384 | if is_dir: 385 | break 386 | 387 | for key in data_path_dict: 388 | if not key in abs_path_dict: 389 | raise Exception('invalid key in absolute data path dict') 390 | 391 | return abs_path_dict 392 | -------------------------------------------------------------------------------- /src/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import os 6 | 7 | from src.crowd_count import CrowdCount 8 | from src import network 9 | from src.utils import ndarray_to_tensor 10 | from src.psnr import build_psnr 11 | from src.ssim import SSIM 12 | 13 | 14 | def evaluate_model(model_path, data): 15 | net = CrowdCount() 16 | network.load_net(model_path, net) 17 | net.cuda() 18 | net.eval() 19 | 20 | build_ssim = SSIM(window_size=11) 21 | 22 | game = GridAverageMeanAbsoluteError() 23 | 24 | mae = 0.0 25 | mse = 0.0 26 | psnr = 0.0 27 | ssim = 0.0 28 | game_0 = 0.0 29 | game_1 = 0.0 30 | game_2 = 0.0 31 | game_3 = 0.0 32 | index = 0 33 | 34 | for blob in data: 35 | image_data = blob['image'] 36 | ground_truth_data = blob['density'] 37 | roi = blob['roi'] 38 | # filename = blob['filename'] 39 | 40 | if image_data.shape[0] != 1: 41 | raise Exception('invalid image batch size (%d) for evaluation' % image_data.shape[0]) 42 | 43 | with torch.no_grad(): 44 | estimate_map, _, _ = net(image_data, roi) 45 | 46 | ground_truth_data = ground_truth_data.data.cpu().numpy() 47 | density_map = estimate_map.data.cpu().numpy() 48 | 49 | ground_truth_count = np.sum(ground_truth_data) 50 | estimate_count = np.sum(density_map) 51 | 52 | mae += abs(ground_truth_count - estimate_count) 53 | mse += (ground_truth_count - estimate_count) ** 2 54 | psnr += build_psnr(ground_truth_data, density_map) 55 | ssim += build_ssim(ndarray_to_tensor(ground_truth_data), ndarray_to_tensor(density_map)).item() 56 | game_0 += game.calculate_error(ground_truth_data, density_map, 0) 57 | game_1 += game.calculate_error(ground_truth_data, density_map, 1) 58 | game_2 += game.calculate_error(ground_truth_data, density_map, 2) 59 | game_3 += game.calculate_error(ground_truth_data, density_map, 3) 60 | index += 1 61 | 62 | result_dict = dict() 63 | result_dict['name'] = os.path.basename(model_path) 64 | result_dict['number'] = int(index) 65 | result_dict['mae'] = float(mae / index) 66 | result_dict['mse'] = float(np.sqrt(mse / index)) 67 | result_dict['psnr'] = float(psnr / index) 68 | result_dict['ssim'] = float(ssim / index) 69 | result_dict['game_0'] = float(game_0 / index) 70 | result_dict['game_1'] = float(game_1 / index) 71 | result_dict['game_2'] = float(game_2 / index) 72 | result_dict['game_3'] = float(game_3 / index) 73 | 74 | return result_dict 75 | 76 | 77 | class GridAverageMeanAbsoluteError: 78 | @staticmethod 79 | def calculate_error(ground_truth, estimate, L=0): 80 | # grid average mean absolute error 81 | # ground_truth Tensor: shape=(1, 1, h, w) 82 | # estimate Tensor: same shape of ground_truth 83 | ground_truth = ndarray_to_tensor(ground_truth, is_cuda=True) 84 | estimate = ndarray_to_tensor(estimate, is_cuda=True) 85 | height = ground_truth.shape[2] 86 | width = ground_truth.shape[3] 87 | times = math.sqrt(math.pow(4, L)) 88 | padding_height = int(math.ceil(height / times) * times - height) 89 | padding_width = int(math.ceil(width / times) * times - width) 90 | if padding_height != 0 or padding_width != 0: 91 | m = nn.ZeroPad2d((0, padding_width, 0, padding_height)) 92 | ground_truth = m(ground_truth) 93 | estimate = m(estimate) 94 | height = ground_truth.shape[2] 95 | width = ground_truth.shape[3] 96 | m = nn.AdaptiveAvgPool2d(int(times)) 97 | ground_truth = m(ground_truth) * (height / times) * (width / times) 98 | estimate = m(estimate) * (height / times) * (width / times) 99 | game = torch.sum(torch.abs(ground_truth - estimate)) 100 | return game.item() 101 | 102 | # @staticmethod 103 | # def calculate_error_not_include_pad(ground_truth, estimate, L=0): 104 | # # grid average mean absolute error 105 | # # ground_truth Tensor: shape=(1, 1, h, w) 106 | # # estimate Tensor: same shape of ground_truth 107 | # ground_truth = ndarray_to_tensor(ground_truth, is_cuda=True) 108 | # estimate = ndarray_to_tensor(estimate, is_cuda=True) 109 | # height = ground_truth.shape[2] 110 | # width = ground_truth.shape[3] 111 | # times = math.sqrt(math.pow(4, L)) 112 | # grid_height = int(math.ceil(height / times)) 113 | # grid_width = int(math.ceil(width / times)) 114 | # padding_height = int(math.ceil((grid_height * times - height) / 2)) 115 | # padding_width = int(math.ceil((grid_width * times - width) / 2)) 116 | # m = nn.AvgPool2d((grid_height, grid_width), stride=(grid_height, grid_width), padding=(padding_height, padding_width), count_include_pad=False) 117 | # ground_truth = m(ground_truth) * (height / times) * (width / times) 118 | # estimate = m(estimate) * (height / times) * (width / times) 119 | # game = torch.sum(torch.abs(ground_truth - estimate)) 120 | # return game.item() 121 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | import cv2 5 | import numpy as np 6 | 7 | from src.network import Conv2d, ConvTranspose2d 8 | from src.utils import ndarray_to_tensor 9 | from src.data_multithread_preload import DOWNSAMPLE 10 | 11 | 12 | class Model(nn.Module): 13 | def __init__(self): 14 | super(Model, self).__init__() 15 | 16 | self.prior = nn.Sequential(Conv2d(3, 64, 3, same_padding=True), 17 | Conv2d(64, 64, 3, same_padding=True), 18 | nn.MaxPool2d(2), 19 | Conv2d(64, 128, 3, same_padding=True), 20 | Conv2d(128, 128, 3, same_padding=True), 21 | nn.MaxPool2d(2), 22 | Conv2d(128, 256, 3, same_padding=True), 23 | Conv2d(256, 256, 3, same_padding=True), 24 | Conv2d(256, 256, 3, same_padding=True), 25 | nn.MaxPool2d(2), 26 | Conv2d(256, 512, 3, same_padding=True), 27 | Conv2d(512, 512, 3, same_padding=True), 28 | Conv2d(512, 512, 3, same_padding=True), 29 | nn.MaxPool2d(2), 30 | Conv2d(512, 512, 3, same_padding=True), 31 | Conv2d(512, 512, 3, same_padding=True), 32 | Conv2d(512, 512, 3, same_padding=True), 33 | Conv2d(512, 256, 1, same_padding=True), 34 | ConvTranspose2d(256, 128, 2, stride=2, padding=0), 35 | Conv2d(128, 128, 3, same_padding=True), 36 | Conv2d(128, 3, 1, same_padding=True)) 37 | 38 | self.vgg16 = nn.Sequential(Conv2d(3, 64, 3, same_padding=True), 39 | Conv2d(64, 64, 3, same_padding=True), 40 | nn.MaxPool2d(2), 41 | Conv2d(64, 128, 3, same_padding=True), 42 | Conv2d(128, 128, 3, same_padding=True), 43 | nn.MaxPool2d(2), 44 | Conv2d(128, 256, 3, same_padding=True), 45 | Conv2d(256, 256, 3, same_padding=True), 46 | Conv2d(256, 256, 3, same_padding=True), 47 | nn.MaxPool2d(2), 48 | Conv2d(256, 512, 3, same_padding=True), 49 | Conv2d(512, 512, 3, same_padding=True), 50 | Conv2d(512, 512, 3, same_padding=True), 51 | nn.MaxPool2d(2), 52 | Conv2d(512, 512, 3, same_padding=True), 53 | Conv2d(512, 512, 3, same_padding=True), 54 | Conv2d(512, 512, 3, same_padding=True), 55 | Conv2d(512, 256, 1, same_padding=True), 56 | ConvTranspose2d(256, 128, 2, stride=2, padding=0)) 57 | 58 | self.map = nn.Sequential(Conv2d(128, 128, 3, same_padding=True), 59 | Conv2d(128, 2, 1, same_padding=True)) 60 | 61 | self.scale = nn.Sequential(Conv2d(128, 128, 3, same_padding=True), 62 | Conv2d(128, 2, 1, same_padding=True, relu=False), 63 | nn.AdaptiveAvgPool2d(1), 64 | nn.Hardtanh(-1.0, 1.0)) 65 | 66 | def forward(self, im_data, roi=None): 67 | with torch.no_grad(): 68 | x_prior = self.prior(im_data) 69 | flag = torch.argmax(x_prior, dim=1, keepdim=True) 70 | 71 | background_mask = (flag == 0).to(torch.float32) 72 | foreground_mask = 1 - background_mask 73 | resized_foreground_mask = functional.interpolate(1 - background_mask, scale_factor=8.0, mode='nearest') 74 | 75 | # masks of foreground classes 76 | masks = None 77 | for i in range(1, x_prior.shape[1]): 78 | if masks is None: 79 | masks = (flag == i).to(torch.float32) 80 | else: 81 | masks = torch.cat((masks, (flag == i).to(torch.float32)), dim=1) 82 | 83 | dilate_size = 4 84 | if dilate_size > 1: 85 | _, number_of_classes, _, _ = masks.shape 86 | # pad mask for same size output 87 | if dilate_size % 2 == 0: 88 | pad_size = (dilate_size / 2, dilate_size / 2 - 1, dilate_size / 2, dilate_size / 2 - 1) 89 | pad_size = *(int(i) for i in pad_size), 90 | else: 91 | pad_size = int((dilate_size - 1) / 2) 92 | pad_size = (pad_size, pad_size, pad_size, pad_size) 93 | padded_mask = functional.pad(masks, pad_size, mode='constant', value=0) 94 | # dilate mask using convolution function 95 | padded_mask_list = torch.chunk(padded_mask, number_of_classes, dim=1) 96 | dilated_masks = None 97 | filters = torch.ones(1, 1, dilate_size, dilate_size).cuda() 98 | for i in range(number_of_classes): 99 | if dilated_masks is None: 100 | dilated_masks = torch.clamp(functional.conv2d(padded_mask_list[i], filters), 0, 1) * foreground_mask 101 | else: 102 | dilated_masks = torch.cat((dilated_masks, torch.clamp(functional.conv2d(padded_mask_list[i], filters), 0, 1) * foreground_mask), dim=1) 103 | else: 104 | dilated_masks = masks 105 | 106 | dilated_masks = torch.round(dilated_masks).to(torch.float32) 107 | 108 | x1 = self.vgg16(im_data * resized_foreground_mask) 109 | maps = self.map(x1) 110 | scales = self.scale(x1) + 1 111 | 112 | if dilated_masks.shape != maps.shape: 113 | raise Exception('mask and map mismatch') 114 | if dilated_masks.shape[1] != scales.shape[1]: 115 | raise Exception('mask and scale mismatch') 116 | 117 | flag = torch.sum(dilated_masks, 1, keepdim=True) + background_mask 118 | if torch.min(flag) < 1: # there should not be any zeros in flag 119 | raise Exception('invalid dilated masks') 120 | 121 | scaled_maps = maps * dilated_masks * scales 122 | 123 | scaled_map = torch.sum(scaled_maps, 1, keepdim=True) / flag 124 | density_map = torch.sum(scaled_map, 1, keepdim=True) 125 | 126 | resized_roi = functional.interpolate(roi, scale_factor=1 / DOWNSAMPLE, mode='nearest') 127 | density_map = density_map * resized_roi 128 | 129 | visual_dict = dict() 130 | visual_dict['density'] = density_map 131 | visual_dict['raw_maps'] = maps 132 | visual_dict['scaled_maps'] = scaled_maps 133 | visual_dict['masks'] = dilated_masks 134 | 135 | return density_map, foreground_mask, visual_dict -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class Conv2d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, same_padding=False, dilation=1, groups=1, relu=True, bn=False): 8 | super(Conv2d, self).__init__() 9 | padding = int((kernel_size - 1) * dilation / 2) if same_padding else 0 10 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, dilation=dilation, groups=groups) 11 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None 12 | self.relu = nn.ReLU(inplace=True) if relu else None 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.relu is not None: 19 | x = self.relu(x) 20 | return x 21 | 22 | 23 | class ConvTranspose2d(nn.Module): 24 | def __init__(self, in_channels, out_channels, kernel_size=2, stride=2, relu=True, padding=0, bn=False, groups=1): 25 | super(ConvTranspose2d, self).__init__() 26 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=padding, groups=groups) 27 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None 28 | self.relu = nn.ReLU(inplace=True) if relu else None 29 | 30 | def forward(self, x): 31 | x = self.conv(x) 32 | if self.bn is not None: 33 | x = self.bn(x) 34 | if self.relu is not None: 35 | x = self.relu(x) 36 | return x 37 | 38 | 39 | class FC(nn.Module): 40 | def __init__(self, in_features, out_features, relu=True): 41 | super(FC, self).__init__() 42 | self.fc = nn.Linear(in_features, out_features) 43 | self.relu = nn.ReLU(inplace=True) if relu else None 44 | 45 | def forward(self, x): 46 | x = self.fc(x) 47 | if self.relu is not None: 48 | x = self.relu(x) 49 | return x 50 | 51 | 52 | class Fire(nn.Module): 53 | def __init__(self, in_channel, squeeze_channel, expand1x1_channel, expand3x3_channel, dilation=1, bn=False): 54 | super(Fire, self).__init__() 55 | self.squeeze = nn.Conv2d(in_channel, squeeze_channel, kernel_size=1) 56 | self.squeeze_bn = nn.BatchNorm2d(squeeze_channel, eps=0.001, momentum=0, affine=True) if bn else None 57 | self.squeeze_relu = nn.ReLU(inplace=True) 58 | 59 | self.expand1x1 = nn.Conv2d(squeeze_channel, expand1x1_channel, kernel_size=1) 60 | self.expand1x1_bn = nn.BatchNorm2d(expand1x1_channel, eps=0.001, momentum=0, affine=True) if bn else None 61 | self.expand1x1_relu = nn.ReLU(inplace=True) 62 | 63 | self.expand3x3 = nn.Conv2d(squeeze_channel, expand3x3_channel, kernel_size=3, padding=dilation, dilation=dilation) 64 | self.expand3x3_bn = nn.BatchNorm2d(expand3x3_channel, eps=0.001, momentum=0, affine=True) if bn else None 65 | self.expand3x3_relu = nn.ReLU(inplace=True) 66 | 67 | def forward(self, x): 68 | x = self.squeeze(x) 69 | if self.squeeze_bn is not None: 70 | x = self.squeeze_bn(x) 71 | x = self.squeeze_relu(x) 72 | 73 | x1 = self.expand1x1(x) 74 | if self.expand1x1_bn is not None: 75 | x1 = self.expand1x1_bn(x1) 76 | x1 = self.expand1x1_relu(x1) 77 | 78 | x3 = self.expand3x3(x) 79 | if self.expand3x3_bn is not None: 80 | x3 = self.expand3x3_bn(x3) 81 | x3 = self.expand3x3_relu(x3) 82 | 83 | return torch.cat((x1, x3), 1) 84 | 85 | 86 | def save_net(fname, net): 87 | import h5py 88 | h5f = h5py.File(fname, mode='w') 89 | for k, v in net.state_dict().items(): 90 | h5f.create_dataset(k, data=v.cpu().numpy()) 91 | 92 | 93 | def load_net(fname, net): 94 | import h5py 95 | # print('load from file: %s' % fname) 96 | h5f = h5py.File(fname, mode='r') 97 | for k, v in net.state_dict().items(): 98 | param = torch.from_numpy(np.asarray(h5f[k])) 99 | v.copy_(param) 100 | 101 | 102 | def load_net_safe(fname, net): 103 | import h5py 104 | print('load from file: %s' % fname) 105 | h5f = h5py.File(fname, mode='r') 106 | for k, v in net.state_dict().items(): 107 | try: 108 | param = torch.from_numpy(np.asarray(h5f[k])) 109 | except KeyError: 110 | print('do not find %s in h5 file' % k) 111 | else: 112 | print('loading %s from h5 file' % k) 113 | v.copy_(param) 114 | 115 | 116 | def set_trainable(model, requires_grad): 117 | for param in model.parameters(): 118 | param.requires_grad = requires_grad 119 | 120 | 121 | def weights_normal_init(model, dev=0.01): 122 | if isinstance(model, list): 123 | for m in model: 124 | weights_normal_init(m, dev) 125 | else: 126 | for m in model.modules(): 127 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.BatchNorm2d): 128 | #print torch.sum(m.weight) 129 | m.weight.data.normal_(0.0, dev) 130 | if m.bias is not None: 131 | m.bias.data.fill_(0.0) 132 | elif isinstance(m, nn.Linear): 133 | m.weight.data.normal_(0.0, dev) 134 | 135 | -------------------------------------------------------------------------------- /src/psnr.py: -------------------------------------------------------------------------------- 1 | def build_psnr(ground_truth, estimate): 2 | # PSNR defination: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 3 | import numpy 4 | import math 5 | 6 | if numpy.max(ground_truth) == 0: 7 | ground_truth = ground_truth * 255.0 8 | else: 9 | ground_truth = ground_truth / numpy.max(ground_truth) * 255.0 10 | 11 | if numpy.max(estimate) == 0: 12 | estimate = estimate * 255.0 13 | else: 14 | estimate = estimate / numpy.max(estimate) * 255.0 15 | 16 | mse = numpy.mean((ground_truth - estimate) ** 2) 17 | if mse == 0: 18 | return 100 19 | PIXEL_MAX = 255.0 20 | return 10 * math.log10(PIXEL_MAX ** 2 / mse) -------------------------------------------------------------------------------- /src/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 10 | return gauss / gauss.sum() 11 | 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | 20 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 21 | # SSIM defination: https://en.wikipedia.org/wiki/Structural_similarity 22 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 23 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 24 | 25 | mu1_sq = mu1.pow(2) 26 | mu2_sq = mu2.pow(2) 27 | mu1_mu2 = mu1 * mu2 28 | 29 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 30 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 31 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 32 | 33 | C1 = 0.01 ** 2 34 | C2 = 0.03 ** 2 35 | 36 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 37 | 38 | if size_average: 39 | return ssim_map.mean() 40 | else: 41 | return ssim_map.mean(1).mean(1).mean(1) 42 | 43 | 44 | class SSIM(torch.nn.Module): 45 | def __init__(self, window_size=11, size_average=True): 46 | super(SSIM, self).__init__() 47 | self.window_size = window_size 48 | self.size_average = size_average 49 | self.channel = 1 50 | self.window = create_window(window_size, self.channel) 51 | 52 | def forward(self, img1, img2): 53 | (_, channel, _, _) = img1.size() 54 | 55 | if channel == self.channel and self.window.data.type() == img1.data.type(): 56 | window = self.window 57 | else: 58 | window = create_window(self.window_size, channel) 59 | 60 | if img1.is_cuda: 61 | window = window.cuda(img1.get_device()) 62 | window = window.type_as(img1) 63 | 64 | self.window = window 65 | self.channel = channel 66 | 67 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 68 | 69 | 70 | def ssim(img1, img2, window_size=11, size_average=True): 71 | (_, channel, _, _) = img1.size() 72 | window = create_window(window_size, channel) 73 | 74 | if img1.is_cuda: 75 | window = window.cuda(img1.get_device()) 76 | window = window.type_as(img1) 77 | 78 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from matplotlib import pyplot 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional 9 | import openpyxl as excel 10 | 11 | 12 | def ndarray_to_tensor(x, is_cuda=True, requires_grad=False, dtype=torch.float32): 13 | t = torch.tensor(x, dtype=dtype, requires_grad=requires_grad) 14 | if is_cuda: 15 | t = t.cuda() 16 | return t 17 | 18 | 19 | def log(path, log, mode='a', line=None, is_print=True): 20 | # line int i: add line i-th line of existing text. None: add line at the end of existing text 21 | if line is not None: 22 | with open(path, 'r') as file: 23 | exist_text_list = file.readlines() 24 | 25 | if not isinstance(line, int): 26 | raise Exception('invalid line') 27 | 28 | # add new lines 29 | line_now = line 30 | for l in log: 31 | exist_text_list.insert(line_now, l + '\n') 32 | line_now += 1 33 | if is_print: 34 | print(l, flush=True) 35 | 36 | # write to file 37 | with open(path, 'w') as file: 38 | file.writelines(exist_text_list) 39 | else: 40 | with open(path, mode) as file: 41 | for l in log: 42 | file.write(l + '\n') 43 | if is_print: 44 | print(l, flush=True) 45 | log[:] = [] 46 | 47 | 48 | def is_only_one_bool_is_true(*flag): 49 | count = 0 50 | for f in flag: 51 | if not isinstance(f, bool): 52 | raise Exception('not supported type') 53 | elif f: 54 | count += 1 55 | if count == 1: 56 | return True 57 | else: 58 | return False 59 | 60 | 61 | def show_matrix(matrix_list): 62 | # matrix a list of 2 dimension numpy.ndarray 63 | for matrix in matrix_list: 64 | pyplot.figure() 65 | pyplot.imshow(matrix) 66 | pyplot.show() 67 | return 68 | 69 | 70 | def compare_result(result_dict, best_result_dict, key_value, reverse=False): 71 | # Returns a dict of dictionaries containing the best error specified by the key_value 72 | # result_dict: dict 73 | # best_result_dict: dict 74 | # key_value: string 75 | # reverse: bool, False: return a result with smaller key value, True: return a result with larger key value 76 | final_result_dict = dict() 77 | for data_name in result_dict: 78 | result = result_dict[data_name] 79 | best_result = best_result_dict[data_name] 80 | 81 | if reverse: 82 | if result[key_value] > best_result[key_value]: 83 | final_result_dict[data_name] = result 84 | else: 85 | final_result_dict[data_name] = best_result 86 | else: 87 | if result[key_value] < best_result[key_value]: 88 | final_result_dict[data_name] = result 89 | else: 90 | final_result_dict[data_name] = best_result 91 | return final_result_dict 92 | 93 | 94 | def compare_mae(correct_cent_list, mse_list, model_name, best_correct_cent_list, best_mse_list, best_model_name_list): 95 | # choose best mean absolute error 96 | # correct_cent_list list 97 | # mse_list list 98 | # model_name string 99 | # best_correct_cent_list list 100 | # best_mse_list list 101 | # best_model_name_list list 102 | for i in range(len(correct_cent_list)): 103 | if correct_cent_list[i] < best_correct_cent_list[i]: 104 | best_correct_cent_list[i] = correct_cent_list[i] 105 | best_mse_list[i] = mse_list[i] 106 | best_model_name_list[i] = model_name 107 | return best_model_name_list, best_correct_cent_list, best_mse_list 108 | 109 | 110 | def compare_game(game_0_list, game_1_list, game_2_list, game_3_list, model_name, best_game_0_list, best_game_1_list, best_game_2_list, best_game_3_list, best_model_name_list): 111 | # choose best grid average mean absolute error 112 | # correct_cent_list list 113 | # mse_list list 114 | # model_name string 115 | # best_correct_cent_list list 116 | # best_mse_list list 117 | # best_model_name_list list 118 | for i in range(len(game_0_list)): 119 | if game_0_list[i] < best_game_0_list[i]: 120 | best_game_0_list[i] = game_0_list[i] 121 | best_game_1_list[i] = game_1_list[i] 122 | best_game_2_list[i] = game_2_list[i] 123 | best_game_3_list[i] = game_3_list[i] 124 | best_model_name_list[i] = model_name 125 | return best_model_name_list, best_game_0_list, best_game_1_list, best_game_2_list, best_game_3_list 126 | 127 | 128 | def compare_correct_cent(correct_cent_list, model_name, best_correct_cent_list, best_model_name_list): 129 | # choose higher correct cent 130 | # correct_cent_list list 131 | # model_name string 132 | # best_correct_cent_list list 133 | # best_model_name_list list 134 | for i in range(len(correct_cent_list)): 135 | if correct_cent_list[i] > best_correct_cent_list[i]: 136 | best_correct_cent_list[i] = correct_cent_list[i] 137 | best_model_name_list[i] = model_name 138 | return best_model_name_list, best_correct_cent_list 139 | 140 | 141 | def gray_to_bgr(gray_image, mode='jet'): 142 | # gray_image 2-D ndarray 143 | # output is a 3 channel uint8 rgb ndarray 144 | getColorMap = pyplot.get_cmap(mode) 145 | 146 | # gray_image = gray_image / np.max(gray_image) 147 | 148 | rgba_image = getColorMap(gray_image) 149 | rgb_image = np.delete(rgba_image, 3, 2) 150 | rgb_image = 255 * rgb_image 151 | 152 | return cv2.cvtColor(rgb_image.astype(np.uint8), cv2.COLOR_RGB2BGR) 153 | 154 | 155 | def make_path(path): 156 | if not isinstance(path, str): 157 | raise Exception('Path need to be a string.') 158 | if not os.path.exists(path): 159 | os.makedirs(path) 160 | 161 | 162 | def get_foreground_mask(ground_truth_map): 163 | # ground_truth_map numpy.ndarray shape=(1, 1, h, w) 164 | mask = np.zeros_like(ground_truth_map) 165 | mask[ground_truth_map > 0] = 1.0 166 | return mask 167 | 168 | 169 | def dilate_mask(mask, kernel_size, iterations=1, dtype=torch.float32, is_cuda=True): 170 | # masks type: torch.Tensor, shape is (1, 1, h, w) 171 | # kernel_size type: int or tuple of int 172 | if isinstance(kernel_size, int): 173 | kernel = np.ones((kernel_size, kernel_size)) 174 | else: 175 | raise Exception('invalid kernel_size type') 176 | 177 | mask = mask.data.cpu().numpy() 178 | 179 | # reshape (1, 1, h, w) to (h, w) 180 | mask = mask.reshape(mask.shape[2], mask.shape[3]) 181 | 182 | # dilate operation 183 | mask = cv2.dilate(mask, kernel, iterations=iterations) 184 | 185 | # reshape (h, w) to (1, 1, h, w) 186 | mask = mask.reshape(1, 1, mask.shape[0], mask.shape[1]) 187 | 188 | mask = torch.from_numpy(mask).to(dtype) 189 | if is_cuda: 190 | mask = mask.cuda() 191 | 192 | return mask 193 | 194 | 195 | def erode_mask(mask, kernel_size, iterations=1, dtype=torch.float32, is_cuda=True): 196 | # masks type: torch.Tensor, shape is (1, 1, h, w) 197 | # kernel_size type: int 198 | if isinstance(kernel_size, int): 199 | kernel = np.ones((kernel_size, kernel_size)) 200 | else: 201 | raise Exception('invalid kernel_size type') 202 | 203 | mask = mask.data.cpu().numpy() 204 | 205 | # reshape (1, 1, h, w) to (h, w) 206 | mask = mask.reshape(mask.shape[2], mask.shape[3]) 207 | 208 | # dilate operation 209 | mask = cv2.erode(mask, kernel, iterations=iterations) 210 | 211 | # reshape (h, w) to (1, 1, h, w) 212 | mask = mask.reshape(1, 1, mask.shape[0], mask.shape[1]) 213 | 214 | mask = torch.from_numpy(mask).to(dtype) 215 | if is_cuda: 216 | mask = mask.cuda() 217 | 218 | return mask 219 | 220 | 221 | def gaussian_kernel(shape=(15, 15), sigma=4): 222 | """ 223 | 2D gaussian mask - should give the same result as MATLAB's 224 | fspecial('gaussian',[shape],[sigma]) 225 | """ 226 | m, n = [(ss - 1.) / 2. for ss in shape] 227 | y,x = np.ogrid[-m:m+1, -n:n+1] 228 | h = np.exp(-(x ** 2 + y ** 2) / (2. * sigma ** 2)) 229 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 230 | sumh = h.sum() 231 | if sumh != 0: 232 | h /= sumh 233 | return h 234 | 235 | 236 | class ExcelLog: 237 | def __init__(self, path): 238 | # path: str, path of the excel file 239 | # datasets: list, name of every sheet 240 | # keys: list, name of every column 241 | if not isinstance(path, str): 242 | raise Exception('path should be a string') 243 | 244 | self.path = path 245 | self.alphabet = '_ABCDEFGHIJKLMNOPQRSTUVWXYZ' 246 | 247 | excel_book = excel.Workbook() 248 | 249 | # for i in range(len(self.datasets)): 250 | # excel_sheet = excel_book.create_sheet(self.datasets[i], i) 251 | # 252 | # for j in range(len(self.keys)): 253 | # excel_sheet[self.get_cell_name(j, 1)] = self.keys[j] 254 | 255 | excel_book.save(self.path) 256 | return 257 | 258 | def get_cell_name(self, column, row): 259 | if not (isinstance(column, int) or isinstance(row, int)): 260 | raise Exception('column and row should be integer') 261 | 262 | return self.alphabet[column] + str(row) 263 | 264 | def add_log(self, log_dict): 265 | # log_dict: dict, a dict contains several dicts, every dict contains information of one dataset 266 | if not isinstance(log_dict, dict): 267 | raise Exception('log_dict should be a dictionary') 268 | 269 | excel_book = excel.load_workbook(self.path) 270 | 271 | for dataset_name in log_dict: 272 | log = log_dict[dataset_name] 273 | try: 274 | excel_sheet = excel_book.get_sheet_by_name(dataset_name) 275 | except KeyError: 276 | excel_sheet = excel_book.create_sheet(dataset_name) 277 | column = 1 278 | for name in log: 279 | excel_sheet[self.get_cell_name(column, 1)] = name 280 | column += 1 281 | 282 | row = excel_sheet.max_row + 1 283 | 284 | column = 1 285 | for name in log: 286 | excel_sheet[self.get_cell_name(column, row)] = log[name] 287 | column += 1 288 | 289 | excel_book.save(self.path) 290 | return 291 | 292 | 293 | def calculate_game(ground_truth, estimate, L=0): 294 | # grid average mean absolute error 295 | # ground_truth Tensor: shape=(1, 1, h, w) 296 | # estimate Tensor: same shape of ground_truth 297 | ground_truth = ndarray_to_tensor(ground_truth, is_cuda=True) 298 | estimate = ndarray_to_tensor(estimate, is_cuda=True) 299 | height = ground_truth.shape[2] 300 | width = ground_truth.shape[3] 301 | times = math.sqrt(math.pow(4, L)) 302 | padding_height = int(math.ceil(height / times) * times - height) 303 | padding_width = int(math.ceil(width / times) * times - width) 304 | if padding_height != 0 or padding_width != 0: 305 | m = nn.ZeroPad2d((0, padding_width, 0, padding_height)) 306 | ground_truth = m(ground_truth) 307 | estimate = m(estimate) 308 | height = ground_truth.shape[2] 309 | width = ground_truth.shape[3] 310 | m = nn.AdaptiveAvgPool2d(int(times)) 311 | ground_truth = m(ground_truth) * (height / times) * (width / times) 312 | estimate = m(estimate) * (height / times) * (width / times) 313 | game = torch.sum(torch.abs(ground_truth - estimate)) 314 | return game.item() 315 | 316 | 317 | # def calculate_game_not_include_pad(ground_truth, estimate, L=0): 318 | # # grid average mean absolute error 319 | # # ground_truth Tensor: shape=(1, 1, h, w) 320 | # # estimate Tensor: same shape of ground_truth 321 | # ground_truth = ndarray_to_tensor(ground_truth, is_cuda=True) 322 | # estimate = ndarray_to_tensor(estimate, is_cuda=True) 323 | # height = ground_truth.shape[2] 324 | # width = ground_truth.shape[3] 325 | # times = math.sqrt(math.pow(4, L)) 326 | # grid_height = int(math.ceil(height / times)) 327 | # grid_width = int(math.ceil(width / times)) 328 | # padding_height = int(math.ceil((grid_height * times - height) / 2)) 329 | # padding_width = int(math.ceil((grid_width * times - width) / 2)) 330 | # m = nn.AvgPool2d((grid_height, grid_width), stride=(grid_height, grid_width), padding=(padding_height, padding_width), count_include_pad=False) 331 | # ground_truth = m(ground_truth) * (height / times) * (width / times) 332 | # estimate = m(estimate) * (height / times) * (width / times) 333 | # game = torch.sum(torch.abs(ground_truth - estimate)) 334 | # return game.item() 335 | 336 | 337 | def print_red(a_string): 338 | print('\033[91m' + a_string + '\033[0m') 339 | 340 | 341 | def build_block(x, size): 342 | # x shape=(1, c, h, w) 343 | height = x.shape[2] 344 | width = x.shape[3] 345 | padding_height = math.ceil((math.ceil(height / size) * size - height) / 2) 346 | padding_width = math.ceil((math.ceil(width / size) * size - width) / 2) 347 | return functional.avg_pool2d(x, size, stride=size, padding=(padding_height, padding_width), count_include_pad=True) * size * size -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | from src.utils import calculate_game 5 | from src.crowd_count import CrowdCount 6 | from src.data_multithread_preload import multithread_dataloader 7 | from src import network 8 | 9 | 10 | test_flag = dict() 11 | test_flag['preload'] = False 12 | test_flag['label'] = False 13 | test_flag['mask'] = False 14 | 15 | test_model_path = r'./final_model/shtechA.h5' 16 | # original_dataset_name = 'shtechA' 17 | test_data_config = dict() 18 | test_data_config['shtA1_test'] = test_flag.copy() 19 | 20 | # load data 21 | all_data = multithread_dataloader(test_data_config) 22 | 23 | net = CrowdCount() 24 | 25 | network.load_net(test_model_path, net) 26 | 27 | net.cuda() 28 | net.eval() 29 | 30 | total_forward_time = 0.0 31 | 32 | # calculate error on the test dataset 33 | for data_name in test_data_config: 34 | data = all_data[data_name]['data'] 35 | 36 | mae = 0.0 37 | mse = 0.0 38 | game_0 = 0.0 39 | game_1 = 0.0 40 | game_2 = 0.0 41 | game_3 = 0.0 42 | index = 0 43 | for blob in data: 44 | image_data = blob['image'] 45 | ground_truth_data = blob['density'] 46 | roi = blob['roi'] 47 | image_name = blob['image_name'][0] 48 | 49 | start_time = time.perf_counter() 50 | estimate_map, _, visual_dict = net(image_data, roi=roi) 51 | total_forward_time += time.perf_counter() - start_time 52 | 53 | ground_truth_map = ground_truth_data.data.cpu().numpy() 54 | estimate_map = estimate_map.data.cpu().numpy() 55 | 56 | ground_truth_count = np.sum(ground_truth_map) 57 | estimate_count = np.sum(estimate_map) 58 | 59 | mae += np.abs(ground_truth_count - estimate_count) 60 | mse += (ground_truth_count - estimate_count) ** 2 61 | game_0 += calculate_game(ground_truth_map, estimate_map, 0) 62 | game_1 += calculate_game(ground_truth_map, estimate_map, 1) 63 | game_2 += calculate_game(ground_truth_map, estimate_map, 2) 64 | game_3 += calculate_game(ground_truth_map, estimate_map, 3) 65 | index += 1 66 | 67 | mae = mae / index 68 | mse = np.sqrt(mse / index) 69 | game_0 = game_0 / index 70 | game_1 = game_1 / index 71 | game_2 = game_2 / index 72 | game_3 = game_3 / index 73 | print('mae: %.2f mse: %.2f game: %.2f %.2f %.2f %.2f' % (mae, mse, game_0, game_1, game_2, game_3)) 74 | 75 | print('total forward time is %f seconds of %d samples.' % (total_forward_time, index)) 76 | 77 | --------------------------------------------------------------------------------