├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── architecture
│   ├── crowd_count.py
│   ├── data_loader.py
│   ├── evaluate_model.py
│   ├── losses.py
│   ├── models.py
│   ├── network.py
│   ├── timer.py
│   └── utils.py
├── manage_data
│   ├── __pycache__
│   │   ├── data_augmentation.cpython-35.pyc
│   │   ├── dataset_loader.cpython-35.pyc
│   │   ├── get_density_map.cpython-35.pyc
│   │   ├── get_density_map.cpython-36.pyc
│   │   ├── misc.cpython-35.pyc
│   │   └── utils.cpython-35.pyc
│   ├── create_synthetic_dataset.py
│   ├── data_augmentation.py
│   ├── dataset_loader.py
│   ├── get_density_map.py
│   ├── get_density_map.pyc
│   ├── misc.py
│   └── utils.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | log 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 91 | # install all needed dependencies.
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Rodolfo Quispe 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | train-ucf-fold1: 2 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 24 --units ucf-fold1 --save-dir log/ACSCP 3 | train-ucf-fold2: 4 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 24 --units ucf-fold2 --save-dir log/ACSCP 5 | train-ucf-fold3: 6 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 24 --units ucf-fold3 --save-dir log/ACSCP 7 | train-ucf-fold4: 8 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 24 --units ucf-fold4 --save-dir log/ACSCP 9 | train-ucf-fold5: 10 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 24 --units ucf-fold5 --save-dir log/ACSCP 11 | test-ucf-fold1: 12 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 16 --units ucf-fold1 --save-dir log/ACSCP --evaluate-only --resume log/ACSCP/ucf-cc-50_people_thr_20_gt_mode_same/ucf-fold1/best_model.h5 --save-plots 13 | test-ucf-fold2: 14 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 16 --units ucf-fold2 --save-dir log/ACSCP --evaluate-only --resume log/ACSCP/ucf-cc-50_people_thr_20_gt_mode_same/ucf-fold2/best_model.h5 --save-plots --overlap-test 15 | test-ucf-fold3: 16 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 16 --units ucf-fold3 --save-dir log/ACSCP --evaluate-only --resume log/ACSCP/ucf-cc-50_people_thr_20_gt_mode_same/ucf-fold3/best_model.h5 --save-plots --overlap-test 17 | test-ucf-fold4: 18 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 16 --units ucf-fold4 --save-dir log/ACSCP --evaluate-only --resume log/ACSCP/ucf-cc-50_people_thr_20_gt_mode_same/ucf-fold4/best_model.h5 --save-plots --overlap-test 19 | test-ucf-fold5: 20 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 20 --train-batch 16 --units ucf-fold5 --save-dir log/ACSCP --evaluate-only --resume log/ACSCP/ucf-cc-50_people_thr_20_gt_mode_same/ucf-fold5/best_model.h5 --save-plots --overlap-test 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Crowd Counting via Adversarial Cross-Scale Consistency Pursuit IN PYTORCH (Unofficial) 2 | 3 | Implementation of the CVPR 2018 paper [Crowd Counting via Adversarial Cross-Scale Consistency Pursuit](http://openaccess.thecvf.com/content_cvpr_2018/papers/Shen_Crowd_Counting_via_CVPR_2018_paper.pdf) 4 | 5 | ## 1. Environment 6 | 7 | We used the following environment: 8 | 9 | * Python 3 10 | * PyTorch 11 | * OpenCV 12 | * NumPy 13 | * Matplotlib 14 | * Ubuntu 16.04 15 | 16 | You can also run the code using the docker image of [ufoym/deepo](https://hub.docker.com/r/ufoym/deepo). 17 | 18 | ## 2. Preparing data 19 | 20 | We make the UCF-CC-50 and ShanghaiTech datasets available [here](http://www.liv.ic.unicamp.br/~quispe/publications/data/data-crowd-counting.zip); download and unzip them into the root of the repo.
Directories should have the following hierarchy: 21 | 22 | ``` 23 | ROOT_OF_REPO 24 | data 25 | ucf_cc_50 26 | UCF_CC_50 27 | images 28 | labels 29 | ShanghaiTech 30 | part_A 31 | train_data 32 | images 33 | ground-truth 34 | test_data 35 | images 36 | ground-truth 37 | part_B 38 | train_data 39 | images 40 | ground-truth 41 | test_data 42 | images 43 | ground-truth 44 | ``` 45 | 46 | The code is designed so that data augmentation is computed before any other step and the results are stored on the hard drive. Thus, the first time you run the code it will take quite a long time. Augmented data is stored with the following hierarchy: 47 | 48 | ``` 49 | ROOT_OF_REPO 50 | data 51 | ucf_cc_50 52 | people_thr_0_gt_mode_same 53 | ShanghaiTech 54 | part_A 55 | people_thr_0_gt_mode_same 56 | part_B 57 | people_thr_0_gt_mode_same 58 | ``` 59 | 60 | ## 3. Training 61 | 62 | To train using UCF-CC-50 (with all folds) and save the results log in `log/ACSCP` you can run: 63 | 64 | ``` 65 | python3 train.py -d ucf-cc-50 --gt-mode same --people-thr 0 --train-batch 24 --save-dir log/ACSCP 66 | 67 | ``` 68 | 69 | If you want to run a specific fold or part, use the flag `--units`; check the `Makefile` for more examples. 70 | 71 | The training log is stored in `log_train.txt` inside the corresponding log/fold/part directory. 72 | 73 | ## 4. Testing 74 | 75 | After training you can re-load the trained weights (using the flag `--resume`) and use them for testing: 76 | 77 | ``` 78 | python3 train.py -d ucf-cc-50 --save-dir log/multi-stream --resume log/ACSCP/ucf-cc-50_people_thr_0_gt_mode_same --evaluate-only 79 | ``` 80 | 81 | The testing log is stored in `log_test.txt` inside the corresponding log/fold/part directory. You can also generate plots of the predictions using the flag `--save-plots`; results are stored in the directory `plot-results-test` inside the corresponding log/fold/part directory. 82 | 83 | ## 5. Final notes / TODO: 84 | 85 | * Results for UCF_CC_50 with this code are MAE 281.73 / MSE 415.56 (`--people-thr 20`). The results reported by the authors are MAE 291.0 / MSE 404.6. 86 | * Validation on other datasets may be added in the future. 87 | * Batch normalization is not used because it made learning unstable. 88 | * Ground truth is generated using a Gaussian of fixed size (see the sketch below). 89 | * The number of channels in the middle layer of the autoencoder is changed to 3, instead of 4. 90 | * The network receives images with 1 channel, instead of 3. 91 | * You can use the flag `--overlap-test` to overlap the sliding windows used for testing (as implemented by the authors).
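
As a rough illustration of the fixed-size Gaussian mentioned in the notes above, ground-truth density map generation boils down to the following minimal sketch (the `sigma` value is an illustrative assumption, not the exact one used in `manage_data/get_density_map.py`):

```
import numpy as np
from scipy.ndimage import gaussian_filter

def density_from_points(points, height, width, sigma=4.0):
    # Place a unit impulse at each annotated head location, then blur with a
    # fixed-size Gaussian; the map sums (approximately) to the people count.
    density = np.zeros((height, width), dtype=np.float32)
    for x, y in points:  # annotations as (column, row) pixel coordinates
        row = min(max(int(round(y)), 0), height - 1)
        col = min(max(int(round(x)), 0), width - 1)
        density[row, col] += 1.0
    return gaussian_filter(density, sigma)
```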
92 | -------------------------------------------------------------------------------- /architecture/crowd_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | from architecture import network 5 | from architecture.models import G_Large, G_Small, discriminator 6 | from architecture.losses import MSE, Perceptual, CSCP 7 | 8 | import numpy as np 9 | 10 | class CrowdCounter(nn.Module): 11 | def __init__(self): 12 | super(CrowdCounter, self).__init__() 13 | self.euclidean_loss = MSE() 14 | self.perceptual_loss = Perceptual() 15 | self.cscp_loss = CSCP() 16 | self.g_large = G_Large() 17 | self.g_small = G_Small() 18 | self.d_large = discriminator() 19 | self.d_small = discriminator() 20 | self.alpha_euclidean = 150.0 21 | self.alpha_perceptual = 150.0 22 | self.alpha_cscp = 0.0 23 | self.loss_gen_large = 0.0 24 | self.loss_gen_small = 0.0 25 | self.loss_dis_large = 0.0 26 | self.loss_dis_small = 0.0 27 | 28 | def adv_loss_generator(self, generator, discriminator, inputs): 29 | batch_size, _, _, _ = inputs.size() 30 | x = generator(inputs) 31 | fake_logits, _ = discriminator(inputs, x) 32 | ones = torch.ones(batch_size).cuda() 33 | loss = F.binary_cross_entropy_with_logits(fake_logits, ones) 34 | return x, loss 35 | 36 | def adv_loss_discriminator(self, generator, discriminator, inputs, targets): 37 | batch_size, _, _, _ = inputs.size() 38 | ones = torch.ones(batch_size) 39 | # swap some labels and smooth the labels 40 | idx = np.random.uniform(0, 1, batch_size) 41 | idx = np.argwhere(idx < 0.03).reshape(-1) 42 | ones += torch.tensor(np.random.uniform(-0.1, 0.1)) 43 | ones[idx] = 0 44 | zeros = torch.zeros(batch_size) 45 | ones = ones.cuda() 46 | zeros = zeros.cuda() 47 | 48 | x = generator(inputs) 49 | fake_logits, _ = discriminator(inputs, x) 50 | real_logits, _ = discriminator(inputs, targets) 51 | 52 | loss_fake = F.binary_cross_entropy_with_logits(fake_logits, zeros) 53 | loss_real = F.binary_cross_entropy_with_logits(real_logits, ones) 54 | loss = loss_fake + loss_real 55 | return x, loss 56 | 57 | def chunk_input(self, inputs, gt_data): 58 | chunks = torch.chunk(inputs, chunks = 2, dim = 2) 59 | inputs_1, inputs_2 = torch.chunk(chunks[0], chunks = 2, dim = 3) 60 | inputs_3, inputs_4 = torch.chunk(chunks[1], chunks = 2, dim = 3) 61 | 62 | chunks = torch.chunk(gt_data, chunks = 2, dim = 2) 63 | targets_1, targets_2 = torch.chunk(chunks[0], chunks = 2, dim = 3) 64 | targets_3, targets_4 = torch.chunk(chunks[1], chunks = 2, dim = 3) 65 | 66 | inputs_chunks = torch.cat((inputs_1, inputs_2, inputs_3, inputs_4), dim = 0) 67 | targets_chunks = torch.cat((targets_1, targets_2, targets_3, targets_4), dim = 0) 68 | 69 | return inputs_chunks, targets_chunks 70 | 71 | def forward(self, inputs, gt_data=None, epoch=0, mode="generator"): 72 | assert mode in list(["discriminator", "generator"]), ValueError("Invalid network mode '{}'".format(mode)) 73 | inputs = network.np_to_variable(inputs, is_cuda=True, is_training=self.training) 74 | if not self.training: 75 | g_l = self.g_large(inputs) 76 | else: 77 | gt_data = network.np_to_variable(gt_data, is_cuda=True, is_training=self.training) 78 | #chunk input data in 4 79 | inputs_chunks, gt_data_chunks = self.chunk_input(inputs, gt_data) 80 | 81 | if mode == "generator": 82 | # g_large 83 | x_l, self.loss_gen_large = self.adv_loss_generator(self.g_large, self.d_large, inputs) 84 | self.loss_gen_large += self.alpha_euclidean * 
self.euclidean_loss(x_l, gt_data) 85 | self.loss_gen_large += self.alpha_perceptual * self.perceptual_loss(x_l, gt_data) 86 | 87 | # g_small 88 | x_s, self.loss_gen_small = self.adv_loss_generator(self.g_small, self.d_small, inputs_chunks) 89 | self.loss_gen_small += self.alpha_euclidean * self.euclidean_loss(x_s, gt_data_chunks) 90 | self.loss_gen_small += self.alpha_perceptual * self.perceptual_loss(x_s, gt_data_chunks) 91 | 92 | if epoch >= 100: 93 | self.alpha_cscp = 10 94 | self.loss_gen_large += self.alpha_cscp * self.cscp_loss(x_l, x_s) 95 | self.loss_gen_small += self.alpha_cscp * self.cscp_loss(x_l, x_s) 96 | 97 | self.loss_gen = self.loss_gen_large + self.loss_gen_small 98 | else: 99 | #d_large 100 | x_l, self.loss_dis_large = self.adv_loss_discriminator(self.g_large, self.d_large, inputs, gt_data) 101 | 102 | #d_small 103 | x_s, self.loss_dis_small = self.adv_loss_discriminator(self.g_small, self.d_small, inputs_chunks, gt_data_chunks) 104 | g_l = x_l 105 | return g_l 106 | -------------------------------------------------------------------------------- /architecture/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import random 5 | import pandas as pd 6 | 7 | 8 | class ImageDataLoader(): 9 | def __init__(self, data_path, gt_path, shuffle=False, batch_size = 1, test_loader = False, img_width = 256, img_height = 256, test_overlap = False): 10 | self.data_path = data_path 11 | self.gt_path = gt_path 12 | self.batch_size = batch_size 13 | self.test_loader = test_loader 14 | self.img_width = img_width 15 | self.img_height = img_height 16 | self.test_overlap = test_overlap 17 | self.data_files = [filename for filename in os.listdir(data_path) \ 18 | if os.path.isfile(os.path.join(data_path,filename)) and os.path.splitext(filename)[1] == '.jpg'] 19 | self.data_files.sort() 20 | self.shuffle = shuffle 21 | if shuffle: 22 | random.seed(2468) 23 | self.num_samples = len(self.data_files) 24 | self.blob_list = {} 25 | self.id_list = np.arange(0,self.num_samples) 26 | 27 | def __iter__(self): 28 | if self.shuffle: 29 | random.shuffle(self.data_files) 30 | files = np.array(self.data_files) 31 | id_list = np.array(self.id_list) 32 | 33 | for ind in range(0, len(id_list), self.batch_size): 34 | idx = id_list[ind: ind + self.batch_size] 35 | fnames = files[idx] 36 | imgs = [] 37 | dens = [] 38 | dens_small = [] 39 | for fname in fnames: 40 | if not os.path.isfile(os.path.join(self.data_path,fname)): 41 | print("Error: file '{}' doesn't exist".format(os.path.join(self.data_path,fname))) 42 | img = cv2.imread(os.path.join(self.data_path,fname),0) 43 | img = img.astype(np.float32, copy=False) 44 | img = img.reshape((1,img.shape[0],img.shape[1])) 45 | 46 | den = np.load(os.path.join(self.gt_path,os.path.splitext(fname)[0] + '.npy')) 47 | den = den.astype(np.float32, copy=False) 48 | den = den.reshape((1, den.shape[0], den.shape[1])) 49 | 50 | if self.test_loader: #loader is for testing, so we divide the image into chunks of size (img_height, img_width) 51 | _, h, w = img.shape 52 | orig_shape = (h, w) 53 | # compute padding 54 | if self.test_overlap: 55 | padding_h = self.img_height - max(h % self.img_height, (h - self.img_height//2) % self.img_height) 56 | padding_w = self.img_width - max(w % self.img_width, (w - self.img_width//2) % self.img_width) 57 | else: 58 | padding_h = self.img_height - (h % self.img_height) 59 | padding_w = self.img_width - (w % self.img_width) 60 | 61 | # add padding
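# (note on the padding below: without --overlap-test it rounds h and w up to
# the next multiple of the window size, so the non-overlapping chunk loop
# further down tiles the padded image exactly; with --overlap-test the window
# slides by half its size, so the padding is computed against that half-window
# step instead. The same zero rows/columns are appended to the density map so
# that image and ground truth stay aligned.)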
62 | img = np.concatenate((img, np.zeros((img.shape[0], padding_h, img.shape[2]))), axis =1) 63 | den = np.concatenate((den, np.zeros((img.shape[0], padding_h, img.shape[2]))), axis =1) 64 | img = np.concatenate((img, np.zeros((img.shape[0], img.shape[1], padding_w))), axis =2) 65 | den = np.concatenate((den, np.zeros((img.shape[0], img.shape[1], padding_w))), axis =2) 66 | assert img.shape[1] % 2 == 0 and img.shape[2] % 2 == 0, "Inputs images must have even dimensions, found {}".format(img.shape) 67 | 68 | # create batch for test 69 | _, h, w = img.shape 70 | new_shape = (h, w) 71 | disp_height = self.img_height // 2 if self.test_overlap else self.img_height 72 | disp_width = self.img_width // 2 if self.test_overlap else self.img_width 73 | for i in range(0, h - self.img_height + 1, disp_height): 74 | for j in range(0, w - self.img_width + 1, disp_width): 75 | chunk_img = img[0, i:i + self.img_height, j:j + self.img_width] 76 | chunk_den = den[0, i:i + self.img_height, j:j + self.img_width] 77 | chunk_img = chunk_img.reshape((1, chunk_img.shape[0], chunk_img.shape[1])) 78 | chunk_den = chunk_den.reshape((1, chunk_den.shape[0], chunk_den.shape[1])) 79 | imgs.append(chunk_img) 80 | dens.append(chunk_den) 81 | else: 82 | imgs.append(img) 83 | dens.append(den) 84 | blob = {} 85 | blob['data']=np.array(imgs) 86 | blob['gt_density']=np.array(dens) 87 | blob['fname'] = np.array(fnames) 88 | blob['idx'] = np.array(idx) 89 | if self.test_loader: 90 | blob['orig_shape'] = np.array(orig_shape) 91 | blob['new_shape'] = np.array(new_shape) 92 | yield blob 93 | 94 | def get_num_samples(self): 95 | return self.num_samples 96 | 97 | def recontruct_test(self, img_batch, den_batch, orig_shape, new_shape): 98 | disp_height = self.img_height // 2 if self.test_overlap else self.img_height 99 | disp_width = self.img_width // 2 if self.test_overlap else self.img_width 100 | img = np.zeros(new_shape) 101 | cnt = np.zeros(new_shape) 102 | den = np.zeros(new_shape) 103 | ind = 0 104 | for i in range(0, new_shape[0] - self.img_height + 1, disp_height): 105 | for j in range(0, new_shape[1] - self.img_width + 1, disp_width): 106 | img[i:i + self.img_height, j:j + self.img_width] = img_batch[ind, 0] 107 | den[i:i + self.img_height, j:j + self.img_width] += den_batch[ind, 0] 108 | cnt[i:i + self.img_height, j:j + self.img_width] += 1 109 | ind += 1 110 | den /= cnt 111 | 112 | #crop to original shape 113 | img = img[:orig_shape[0], :orig_shape[1]].reshape((1, 1, orig_shape[0], orig_shape[1])) 114 | den = den[:orig_shape[0], :orig_shape[1]].reshape((1, 1, orig_shape[0], orig_shape[1])) 115 | return img, den -------------------------------------------------------------------------------- /architecture/evaluate_model.py: -------------------------------------------------------------------------------- 1 | from architecture.crowd_count import CrowdCounter 2 | import architecture.network as network 3 | import numpy as np 4 | import torch 5 | 6 | from manage_data.utils import Logger, mkdir_if_missing 7 | from architecture import utils 8 | 9 | def evaluate_model(trained_model, data_loader, epoch = 0, save_test_results = False, plot_save_dir = "/tmp/", den_factor = 1e3): 10 | net = CrowdCounter() 11 | network.load_net(trained_model, net) 12 | net.cuda() 13 | net.eval() 14 | mae = 0.0 15 | mse = 0.0 16 | 17 | for blob in data_loader: 18 | im_data = blob['data'] 19 | gt_data = blob['gt_density'] 20 | idx_data = blob['idx'] 21 | new_shape = blob['new_shape'] 22 | orig_shape = blob['orig_shape'] 23 | im_data_norm = im_data / 127.5 
- 1. #normalize between -1 and 1 24 | gt_data = gt_data * den_factor 25 | 26 | density_map = net(im_data_norm, epoch = epoch) 27 | density_map = density_map.data.cpu().numpy() 28 | density_map /= den_factor 29 | gt_data /= den_factor 30 | im_data, gt_data = data_loader.recontruct_test(im_data, gt_data, orig_shape, new_shape) 31 | _, density_map = data_loader.recontruct_test(im_data_norm, density_map, orig_shape, new_shape) 32 | gt_count = np.sum(gt_data) 33 | et_count = np.sum(density_map) 34 | print("image {} gt {:.3f} es {:.3f}".format(idx_data[0], gt_count, et_count)) 35 | mae += abs(gt_count-et_count) 36 | mse += ((gt_count-et_count)*(gt_count-et_count)) 37 | 38 | if save_test_results: 39 | print("Plotting results") 40 | mkdir_if_missing(plot_save_dir) 41 | utils.save_results(im_data, gt_data, density_map, idx_data, plot_save_dir) 42 | 43 | mae = mae/data_loader.get_num_samples() 44 | mse = np.sqrt(mse/data_loader.get_num_samples()) 45 | return mae,mse -------------------------------------------------------------------------------- /architecture/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import torchvision 8 | 9 | __all__ = ['MSE', 'Perceptual', 'CSCP', 'Adversarial'] 10 | 11 | class MSE(nn.Module): 12 | """ 13 | Computes MSE loss between inputs 14 | """ 15 | def __init__(self): 16 | super(MSE, self).__init__() 17 | self.mse = nn.MSELoss() 18 | def forward(self, inputs, targets): 19 | loss = self.mse(inputs, targets) 20 | return loss 21 | 22 | class Perceptual(nn.Module): 23 | """ 24 | Computes Perceptual loss between inputs 25 | """ 26 | def __init__(self): 27 | super(Perceptual, self).__init__() 28 | vgg = torchvision.models.vgg.vgg16(pretrained=True) 29 | self.vgg_layers = nn.Sequential(*list(vgg.features.children())[:9]) #recover up to relu2_2 layer 30 | 31 | def vgg_feature(self, x): 32 | # x has only one channel and vgg expects 3 33 | x = torch.cat((x, x, x), dim = 1) 34 | x = self.vgg_layers(x) 35 | return x 36 | 37 | def forward(self, inputs, targets): 38 | loss = F.mse_loss(self.vgg_feature(inputs), self.vgg_feature(targets)) 39 | return loss 40 | 41 | class CSCP(nn.Module): 42 | """ 43 | Implements Cross-Scale Consistency Pursuit Loss 44 | """ 45 | def __init__(self): 46 | super(CSCP, self).__init__() 47 | 48 | def forward(self, density_maps, density_chunks): 49 | batch_size, _, _, _ = density_maps.size() 50 | inputs_1 = density_chunks[:batch_size, :, :, :] 51 | inputs_2 = density_chunks[batch_size:2*batch_size, :, :, :] 52 | inputs_3 = density_chunks[2*batch_size:3*batch_size, :, :, :] 53 | inputs_4 = density_chunks[3*batch_size:4*batch_size, :, :, :] 54 | density_joined = torch.cat((torch.cat((inputs_1, inputs_2), dim = 3), torch.cat((inputs_3, inputs_4), dim = 3)), dim = 2) 55 | loss = F.mse_loss(density_maps, density_joined) 56 | return loss 57 | 58 | -------------------------------------------------------------------------------- /architecture/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | class Conv2d(nn.Module): 8 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bn=False, activation = 'leakyrelu', dropout = False): 9 | super(Conv2d, self).__init__() 10 | padding = 
int((kernel_size - 1) / 2) 11 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding) 12 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None 13 | self.dropout = nn.Dropout(p=0.5) if dropout else None 14 | if activation == 'leakyrelu': 15 | self.activation = nn.LeakyReLU(negative_slope = 0.2) 16 | elif activation == 'relu': 17 | self.activation = nn.ReLU() 18 | elif activation == 'tanh': 19 | self.activation = nn.Tanh() 20 | else: 21 | raise ValueError('Not a valid activation, received {}'.format(activation)) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | if self.bn is not None: 26 | x = self.bn(x) 27 | if self.dropout is not None: 28 | x = self.dropout(x) 29 | x = self.activation(x) 30 | return x 31 | 32 | class Deconv2d(nn.Module): 33 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, bn=False, activation = 'leakyrelu', dropout = False): 34 | super(Deconv2d, self).__init__() 35 | padding = int((kernel_size - 1) / 2) 36 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=padding) 37 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0, affine=True) if bn else None 38 | self.dropout = nn.Dropout(p=0.5) if dropout else None 39 | if activation == 'leakyrelu': 40 | self.activation = nn.LeakyReLU(negative_slope = 0.2) 41 | elif activation == 'relu': 42 | self.activation = nn.ReLU() 43 | elif activation == 'tanh': 44 | self.activation = nn.Tanh() 45 | else: 46 | raise ValueError('Not a valid activation, received {}'.format(activation)) 47 | 48 | def forward(self, x): 49 | x = self.conv(x) 50 | if self.bn is not None: 51 | x = self.bn(x) 52 | if self.dropout is not None: 53 | x = self.dropout(x) 54 | x = self.activation(x) 55 | return x 56 | 57 | class G_Large(nn.Module): 58 | def __init__(self): 59 | super(G_Large, self).__init__() 60 | self.encoder_1 = Conv2d(1, 64, 6, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 61 | self.encoder_2 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 62 | self.encoder_3 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 63 | self.encoder_4 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 64 | self.encoder_5 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 65 | self.encoder_6 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 66 | self.encoder_7 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 67 | self.encoder_8 = Conv2d(64, 64, 3, stride = 1, bn = False, activation = 'leakyrelu', dropout=False) 68 | 69 | self.decoder_1 = Deconv2d(64, 64, 3, stride = 1, bn = False, activation = 'relu', dropout = True) 70 | self.decoder_2 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = True) 71 | self.decoder_3 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = True) 72 | self.decoder_4 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 73 | self.decoder_5 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 74 | self.decoder_6 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 75 | self.decoder_7 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 76 | self.decoder_8 = Deconv2d(128, 1, 6, stride = 2, bn = False, 
activation = 'relu', dropout = False) 77 | 78 | def forward(self, x): 79 | e1 = self.encoder_1(x) 80 | e2 = self.encoder_2(e1) 81 | e3 = self.encoder_3(e2) 82 | e4 = self.encoder_4(e3) 83 | e5 = self.encoder_5(e4) 84 | e6 = self.encoder_6(e5) 85 | e7 = self.encoder_7(e6) 86 | e8 = self.encoder_8(e7) 87 | 88 | d = self.decoder_1(e8) 89 | d = torch.cat((d, e7), dim=1) 90 | d = self.decoder_2(d) 91 | d = torch.cat((d, e6), dim=1) 92 | d = self.decoder_3(d) 93 | d = torch.cat((d, e5), dim=1) 94 | d = self.decoder_4(d) 95 | d = torch.cat((d, e4), dim=1) 96 | d = self.decoder_5(d) 97 | d = torch.cat((d, e3), dim=1) 98 | d = self.decoder_6(d) 99 | d = torch.cat((d, e2), dim=1) 100 | d = self.decoder_7(d) 101 | d = torch.cat((d, e1), dim=1) 102 | d = self.decoder_8(d) 103 | 104 | return d 105 | 106 | 107 | class G_Small(nn.Module): 108 | def __init__(self): 109 | super(G_Small, self).__init__() 110 | self.encoder_1 = Conv2d(1, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 111 | self.encoder_2 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 112 | self.encoder_3 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 113 | self.encoder_4 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 114 | self.encoder_5 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 115 | self.encoder_6 = Conv2d(64, 64, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 116 | self.encoder_7 = Conv2d(64, 64, 3, stride = 1, bn = False, activation = 'leakyrelu', dropout=False) 117 | 118 | self.decoder_1 = Deconv2d(64, 64, 3, stride = 1, bn = False, activation = 'relu', dropout = True) 119 | self.decoder_2 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = True) 120 | self.decoder_3 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = True) 121 | self.decoder_4 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 122 | self.decoder_5 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 123 | self.decoder_6 = Deconv2d(128, 64, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 124 | self.decoder_7 = Deconv2d(128, 1, 4, stride = 2, bn = False, activation = 'relu', dropout = False) 125 | 126 | def forward(self, x): 127 | e1 = self.encoder_1(x) 128 | e2 = self.encoder_2(e1) 129 | e3 = self.encoder_3(e2) 130 | e4 = self.encoder_4(e3) 131 | e5 = self.encoder_5(e4) 132 | e6 = self.encoder_6(e5) 133 | e7 = self.encoder_7(e6) 134 | 135 | d = self.decoder_1(e7) 136 | d = torch.cat((d, e6), dim=1) 137 | d = self.decoder_2(d) 138 | d = torch.cat((d, e5), dim=1) 139 | d = self.decoder_3(d) 140 | d = torch.cat((d, e4), dim=1) 141 | d = self.decoder_4(d) 142 | d = torch.cat((d, e3), dim=1) 143 | d = self.decoder_5(d) 144 | d = torch.cat((d, e2), dim=1) 145 | d = self.decoder_6(d) 146 | d = torch.cat((d, e1), dim=1) 147 | d = self.decoder_7(d) 148 | 149 | return d 150 | 151 | class discriminator(nn.Module): 152 | def __init__(self): 153 | super(discriminator, self).__init__() 154 | self.f_1 = Conv2d(2, 48, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 155 | self.f_2 = Conv2d(48,96, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 156 | self.f_3 = Conv2d(96, 192, 4, stride = 2, bn = False, activation = 'leakyrelu', dropout=False) 157 | self.f_4 = Conv2d(192, 384, 3, stride = 1, bn = False, activation = 
'leakyrelu', dropout=False) 158 | self.f_5 = Conv2d(384, 1, 3, stride = 1, bn = False, activation = 'leakyrelu', dropout=False) 159 | 160 | def forward(self, images, den_maps): 161 | x = torch.cat((images, den_maps), dim = 1) 162 | x = self.f_1(x) 163 | x = self.f_2(x) 164 | x = self.f_3(x) 165 | x = self.f_4(x) 166 | x = self.f_5(x) 167 | 168 | logits = F.avg_pool2d(x, x.size()[2:]) 169 | logits = logits.view(-1) 170 | y = F.tanh(logits) 171 | return logits, y 172 | 173 | -------------------------------------------------------------------------------- /architecture/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def save_net(fname, net): 7 | import h5py 8 | h5f = h5py.File(fname, mode='w') 9 | for k, v in net.state_dict().items(): 10 | h5f.create_dataset(k, data=v.cpu().numpy()) 11 | 12 | 13 | def load_net(fname, net): 14 | import h5py 15 | h5f = h5py.File(fname, mode='r') 16 | for k, v in net.state_dict().items(): 17 | if k in h5f: #layer exists in saved model 18 | param = torch.from_numpy(np.asarray(h5f[k])) 19 | v.copy_(param) 20 | else: 21 | print("WARNING: saved model does not have layer {}".format(k)) 22 | 23 | 24 | def np_to_variable(x, is_cuda=True, is_training=False, dtype=torch.FloatTensor): 25 | if is_training: 26 | v = Variable(torch.from_numpy(x).type(dtype)) 27 | else: 28 | with torch.no_grad(): 29 | v = Variable(torch.from_numpy(x).type(dtype), requires_grad = False) 30 | if is_cuda: 31 | v = v.cuda() 32 | return v 33 | 34 | 35 | def set_trainable(model, requires_grad): 36 | for param in model.parameters(): 37 | param.requires_grad = requires_grad 38 | 39 | 40 | def weights_normal_init(model, dev=0.01): 41 | if isinstance(model, list): 42 | for m in model: 43 | weights_normal_init(m, dev) 44 | else: 45 | for m in model.modules(): 46 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 47 | #print torch.sum(m.weight) 48 | m.weight.data.normal_(0.0, dev) 49 | if m.bias is not None: 50 | m.bias.data.fill_(0.0) 51 | elif isinstance(m, nn.Linear): 52 | m.weight.data.normal_(0.0, dev) 53 | 54 | 55 | -------------------------------------------------------------------------------- /architecture/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer(object): 4 | def __init__(self): 5 | self.tot_time = 0. 6 | self.calls = 0 7 | self.start_time = 0. 8 | self.diff = 0. 9 | self.average_time = 0. 
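# typical usage of this helper:
#   timer = Timer()
#   timer.tic()
#   ...timed work...
#   elapsed = timer.toc(average=False)  # or timer.toc() for the running average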
10 | 11 | def tic(self): 12 | # using time.time instead of time.clock because time.clock 13 | # does not normalize for multithreading 14 | self.start_time = time.time() 15 | 16 | def toc(self, average=True): 17 | self.diff = time.time() - self.start_time 18 | self.tot_time += self.diff 19 | self.calls += 1 20 | self.average_time = self.tot_time / self.calls 21 | if average: 22 | return self.average_time 23 | else: 24 | return self.diff 25 | -------------------------------------------------------------------------------- /architecture/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import matplotlib as mpl 7 | if os.environ.get('DISPLAY','') == '': 8 | mpl.use('Agg') 9 | import matplotlib.pylab as plt 10 | import gc 11 | 12 | 13 | def save_results(img, gt_density_map, et_density_map, idx, output_dir): 14 | idx = idx[0] 15 | img = img[0, 0] 16 | gt_density_map = np.array(gt_density_map[0, 0]) 17 | et_density_map = et_density_map[0, 0] 18 | gt_count = np.sum(gt_density_map) 19 | et_count = np.sum(et_density_map) 20 | maxi = gt_density_map.max() 21 | if maxi != 0: 22 | gt_density_map = gt_density_map*(255. / maxi) 23 | et_density_map = et_density_map*(255. / maxi) 24 | #print("min, max GT - ET", gt_density_map.max(), gt_density_map.min(), et_density_map.max(), et_density_map.min()) 25 | 26 | if gt_density_map.shape[1] != img.shape[1]: 27 | gt_density_map = cv2.resize(gt_density_map, (img.shape[1], img.shape[0])) 28 | et_density_map = cv2.resize(et_density_map, (img.shape[1], img.shape[0])) 29 | 30 | fig = plt.figure(figsize = (30, 20)) 31 | a = fig.add_subplot(1, 3, 1) 32 | plt.imshow(img, cmap='gray') 33 | a.set_title('input') 34 | plt.axis('off') 35 | a = fig.add_subplot(1, 3, 2) 36 | plt.imshow(gt_density_map) 37 | a.set_title('ground truth {:.2f}'.format(gt_count)) 38 | plt.axis('off') 39 | a = fig.add_subplot(1, 3, 3) 40 | plt.imshow(et_density_map) 41 | a.set_title('estimated {:.2f}'.format(et_count)) 42 | plt.axis('off') 43 | 44 | img_file_name = os.path.join(output_dir, str(idx) + ".jpg") 45 | fig.savefig(img_file_name, bbox_inches='tight') 46 | fig.clf() 47 | plt.close() 48 | del a 49 | gc.collect() 50 | 51 | -------------------------------------------------------------------------------- /manage_data/__pycache__/data_augmentation.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/data_augmentation.cpython-35.pyc -------------------------------------------------------------------------------- /manage_data/__pycache__/dataset_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/dataset_loader.cpython-35.pyc -------------------------------------------------------------------------------- /manage_data/__pycache__/get_density_map.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/get_density_map.cpython-35.pyc -------------------------------------------------------------------------------- /manage_data/__pycache__/get_density_map.cpython-36.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/get_density_map.cpython-36.pyc -------------------------------------------------------------------------------- /manage_data/__pycache__/misc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/misc.cpython-35.pyc -------------------------------------------------------------------------------- /manage_data/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /manage_data/create_synthetic_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | THIS SCRIPT HAS NOT BEEN TESTED AND MAY HAVE ERRORS 3 | original author: @darwin 4 | """ 5 | 6 | import numpy as np 7 | import cv2 8 | import json, codecs 9 | from random import * 10 | from random import randint 11 | import scipy.io as sio 12 | import os 13 | import os.path as osp 14 | import matplotlib.pyplot as plt 15 | import sys 16 | 17 | from manage_data.utils import join_json, resize 18 | 19 | people = [] 20 | background = [] 21 | 22 | labelPeople = [] 23 | labelBackground = [] 24 | 25 | dataset = [] 26 | 27 | def readImages(fileName = 'input/FAKE/elements/', scale =10): 28 | # read figures of people and background 29 | file_names = os.listdir(fileName) 30 | file_names.sort() 31 | for file_name in file_names: 32 | if file_name[len(file_name) - 3:] != 'png' and file_name[len(file_name) - 3:] != 'jpg': 33 | continue 34 | image_name = fileName + file_name 35 | img = cv2.imread(image_name,0) 36 | if file_name[0]=='p': 37 | img = cv2.resize(img, (img.shape[1]//scale, img.shape[0]//scale)) 38 | people.append(img) 39 | else : 40 | background.append(img) 41 | 42 | def readJSON(fileName= '/JSON/', scale =10): 43 | # read labels(.json) of people and background 44 | file_names = os.listdir(fileName) 45 | file_names.sort() 46 | for file_name in file_names: 47 | if file_name[len(file_name) - 4:] != 'json': 48 | continue 49 | label_name = fileName + file_name 50 | if file_name[0]=='p': 51 | with open(label_name) as data_file: 52 | data = json.load(data_file) 53 | labelPeople.append(resize(data, scale)) 54 | else : 55 | with open(label_name) as data_file: 56 | data = json.load(data_file) 57 | labelBackground.append(data) 58 | 59 | def isValid(Y, X, img): 60 | return X >= 0 and X < img.shape[1] and Y >= 0 and Y < img.shape[0] [the remainder of this file was lost in extraction] -------------------------------------------------------------------------------- /manage_data/data_augmentation.py: -------------------------------------------------------------------------------- [lines 1-144 were lost in extraction; the dump resumes inside sliding_window()] 145 | if cnt >= people_thr: 146 | img_id += 1 147 | out_img_path = osp.join(out_img_dir, str(img_id).zfill(7) + '.jpg') 148 | out_lab_path = osp.join(out_lab_dir, str(img_id).zfill(7) + '.json') 149 | out_den_path = osp.join(out_den_dir, str(img_id).zfill(7) + '.npy') 150 | cv2.imwrite(out_img_path, new_img) 151 | np.save(out_den_path, new_den) 152 | data = str(json_to_string(new_labels)) 153 | with open(out_lab_path, 'w') as outfile: 154 | outfile.write(data) 155 | return img_id 156 | 157 | def augment(img_paths, label_paths, den_paths, out_img_dir, out_lab_dir, out_den_dir, slide_window_params, noise_params, light_params, add_original = False): 158 | print("Augmenting data, results
will be stored in '{}'".format(out_img_dir)) 159 | aug_img_id = 0 160 | for img_path, label_path, den_path in zip(img_paths, label_paths, den_paths): 161 | img = cv2.imread(img_path) 162 | den = np.load(den_path) 163 | label = json.load(open(label_path)) 164 | 165 | #sliding window for data augmentation 166 | aug_img_id = sliding_window(out_img_dir, out_lab_dir, out_den_dir, aug_img_id, img, label, den 167 | , displace = slide_window_params['displace'] 168 | , size_x = slide_window_params['size_x'] 169 | , size_y = slide_window_params['size_y'] 170 | , people_thr = slide_window_params['people_thr']) 171 | 172 | if add_original: 173 | #add original to augmented set 174 | aug_img_id += 1 175 | out_img_path = osp.join(out_img_dir, str(aug_img_id).zfill(7) + '.jpg') 176 | out_lab_path = osp.join(out_lab_dir, str(aug_img_id).zfill(7) + '.json') 177 | out_den_path = osp.join(out_den_dir, str(aug_img_id).zfill(7) + '.npy') 178 | cv2.imwrite(out_img_path, img) 179 | np.save(out_den_path, den) 180 | label = json_to_string(label) 181 | with open(out_lab_path, 'w') as outfile: 182 | outfile.write(label) 183 | 184 | #if slide_window_params['joint_patches']: 185 | # join_patches(out_img_dir, out_lab_dir, slide_window_params['numberPatch']) 186 | if noise_params['augment_noise']: 187 | aug_img_id = noise_augmentation(out_img_dir, out_lab_dir, out_den_dir, aug_img_id) 188 | if light_params['augment_light']: 189 | aug_img_id = bright_contrast_augmentation(out_img_dir, out_lab_dir, out_den_dir, light_params['bright'], light_params['contrast'], aug_img_id) 190 | print("{} images created after augmentation".format(aug_img_id)) 191 | 192 | 193 | 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /manage_data/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | import h5py 13 | from scipy.misc import imsave 14 | import scipy.io as sio 15 | 16 | from manage_data.get_density_map import create_density_map 17 | from manage_data.utils import mkdir_if_missing, copy_to_directory 18 | from manage_data.data_augmentation import augment 19 | 20 | """A single train/test unit (paths to train and test images and density maps)""" 21 | class train_test_unit(object): 22 | train_dir_img = "" 23 | train_dir_den = "" 24 | test_dir_img = "" 25 | test_dir_den = "" 26 | metadata = dict() 27 | def __init__(self, _train_dir_img, _train_dir_den, _test_dir_img, _test_dir_den, kwargs): 28 | self.train_dir_img = _train_dir_img 29 | self.train_dir_den = _train_dir_den 30 | self.test_dir_img = _test_dir_img 31 | self.test_dir_den = _test_dir_den 32 | self._check_before_run() 33 | self.metadata = kwargs 34 | 35 | def _check_before_run(self): 36 | """Check if all files are available before going deeper""" 37 | if not osp.exists(self.train_dir_img): 38 | raise RuntimeError("'{}' is not available".format(self.train_dir_img)) 39 | if not osp.exists(self.train_dir_den): 40 | raise RuntimeError("'{}' is not available".format(self.train_dir_den)) 41 | if not osp.exists(self.test_dir_img): 42 | raise RuntimeError("'{}' is not available".format(self.test_dir_img)) 43 | if not osp.exists(self.test_dir_den): 44 | raise RuntimeError("'{}' is not available".format(self.test_dir_den)) 45 | 46 | def to_string(self): 47 | return
"_".join([ str(key) + "_" + str(value) for key, value in sorted(self.metadata.items()) if key != 'name']) 48 | 49 | """Dataset classes""" 50 | 51 | """Crowd counting dataset""" 52 | 53 | class UCF_CC_50(object): 54 | root = 'data/ucf_cc_50/' 55 | ori_dir = osp.join(root, 'UCF_CC_50') 56 | ori_dir_lab = osp.join(ori_dir, 'labels') 57 | ori_dir_img = osp.join(ori_dir, 'images') 58 | 59 | ori_dir_den = osp.join(ori_dir, 'density_maps') #.npy files of density maps matrices 60 | augmented_dir = "" 61 | train_test_set = [] 62 | signature_args = ['people_thr', 'gt_mode'] 63 | metadata = dict() 64 | train_test_size = 5 65 | 66 | def __init__(self, force_create_den_maps = False, force_augmentation = False, **kwargs): 67 | self._check_before_run() 68 | self.metadata = kwargs 69 | self._create_original_density_maps(force_create_den_maps) 70 | self._create_train_test(force_augmentation, kwargs) 71 | 72 | def _create_original_density_maps(self, force_create_den_maps): 73 | if not osp.exists(self.ori_dir_den): 74 | os.makedirs(self.ori_dir_den) 75 | elif not force_create_den_maps: 76 | return 77 | create_density_map(self.ori_dir_img, self.ori_dir_lab, self.ori_dir_den, mode = self.metadata['gt_mode']) 78 | 79 | def _check_before_run(self): 80 | """Check if all files are available before going deeper""" 81 | if not osp.exists(self.root): 82 | raise RuntimeError("'{}' is not available".format(self.root)) 83 | if not osp.exists(self.ori_dir): 84 | raise RuntimeError("'{}' is not available".format(self.ori_dir)) 85 | if not osp.exists(self.ori_dir_img): 86 | raise RuntimeError("'{}' is not available".format(self.ori_dir_img)) 87 | if not osp.exists(self.ori_dir_lab): 88 | raise RuntimeError("'{}' is not available".format(self.ori_dir_lab)) 89 | 90 | def signature(self): 91 | return "_".join(["{}_{}".format(sign_elem, self.metadata[sign_elem]) for sign_elem in self.signature_args]) 92 | 93 | def _create_train_test(self, force_augmentation, kwargs): 94 | slide_window_params = {'displace' : kwargs['displace'], 'size_x' : kwargs['size_x'], 'size_y' : kwargs['size_y'], 'people_thr' : kwargs['people_thr']} 95 | noise_params = {'augment_noise' : kwargs['augment_noise']} 96 | light_params = {'augment_light' : kwargs['augment_light'], 'bright' : kwargs['bright'], 'contrast' : kwargs['contrast']} 97 | 98 | file_names = os.listdir(self.ori_dir_img) 99 | file_names.sort() 100 | img_names = [] 101 | img_ids = [] 102 | for file_name in file_names: 103 | file_extention = file_name.split('.')[-1] 104 | file_id = file_name[:len(file_name) - len(file_extention)] 105 | if file_extention != 'png' and file_extention != 'jpg': 106 | continue 107 | img_names.append(file_name) 108 | img_ids.append(file_id) 109 | if len(img_names) != 50: 110 | raise RuntimeError("UCF_CC_50 dataset expects 50 images, {} found".format(len(img_names))) 111 | 112 | self.augmented_dir = osp.join(self.root, self.signature()) 113 | augment_data = False 114 | if osp.exists(self.augmented_dir): 115 | print("'{}' already exists".format(self.augmented_dir)) 116 | if force_augmentation: 117 | augment_data = True 118 | print("augmenting data anyway") 119 | else: 120 | augment_data = False 121 | print("will not augmenting data") 122 | else: 123 | augment_data = True 124 | os.makedirs(self.augmented_dir) 125 | 126 | #using 5 fold cross validation protocol 127 | for fold in range(5): 128 | fold_dir = osp.join(self.augmented_dir, 'fold{}'.format(fold + 1)) 129 | aug_train_dir_img = osp.join(fold_dir, 'train_img') 130 | aug_train_dir_den = osp.join(fold_dir, 
'train_den') 131 | aug_train_dir_lab = osp.join(fold_dir, 'train_lab') 132 | fold_test_dir_img = osp.join(fold_dir, 'test_img') 133 | fold_test_dir_den = osp.join(fold_dir, 'test_den') 134 | fold_test_dir_lab = osp.join(fold_dir, 'test_lab') 135 | 136 | mkdir_if_missing(aug_train_dir_img) 137 | mkdir_if_missing(aug_train_dir_den) 138 | mkdir_if_missing(aug_train_dir_lab) 139 | mkdir_if_missing(fold_test_dir_img) 140 | mkdir_if_missing(fold_test_dir_den) 141 | mkdir_if_missing(fold_test_dir_lab) 142 | 143 | kwargs['name'] = 'ucf-fold{}'.format(fold + 1) 144 | train_test = train_test_unit(aug_train_dir_img, aug_train_dir_den, fold_test_dir_img, fold_test_dir_den, kwargs.copy()) 145 | self.train_test_set.append(train_test) 146 | 147 | if augment_data: 148 | test_img = img_names[fold * 10: (fold + 1) * 10] 149 | test_ids = img_ids[fold * 10: (fold + 1) * 10] 150 | test_den_paths = [osp.join(self.ori_dir_den, img_id + 'npy') for img_id in test_ids] 151 | test_lab_paths = [osp.join(self.ori_dir_lab, img_id + 'json') for img_id in test_ids] 152 | test_img_paths = [osp.join(self.ori_dir_img, img) for img in test_img] 153 | 154 | train_img = sorted(list(set(img_names) - set(test_img))) 155 | train_ids = sorted(list(set(img_ids) - set(test_ids))) 156 | train_den_paths = [osp.join(self.ori_dir_den, img_id + 'npy') for img_id in train_ids] 157 | train_lab_paths = [osp.join(self.ori_dir_lab, img_id + 'json') for img_id in train_ids] 158 | train_img_paths = [osp.join(self.ori_dir_img, img) for img in train_img] 159 | 160 | #augment train data 161 | print("Augmenting {}".format(kwargs['name'])) 162 | augment(train_img_paths, train_lab_paths, train_den_paths, aug_train_dir_img, aug_train_dir_lab, aug_train_dir_den, slide_window_params, noise_params, light_params) 163 | copy_to_directory(test_den_paths, fold_test_dir_den) 164 | copy_to_directory(test_lab_paths, fold_test_dir_lab) 165 | copy_to_directory(test_img_paths, fold_test_dir_img) 166 | 167 | class ShanghaiTech(object): 168 | root = 'data/ShanghaiTech/' 169 | ori_dir_partA = osp.join(root, 'part_A') 170 | ori_dir_partA_train = osp.join(ori_dir_partA, 'train_data') 171 | ori_dir_partA_train_mat = osp.join(ori_dir_partA_train, 'ground-truth') 172 | ori_dir_partA_train_img = osp.join(ori_dir_partA_train, 'images') 173 | ori_dir_partA_test = osp.join(ori_dir_partA, 'test_data') 174 | ori_dir_partA_test_mat = osp.join(ori_dir_partA_test, 'ground-truth') 175 | ori_dir_partA_test_img = osp.join(ori_dir_partA_test, 'images') 176 | 177 | ori_dir_partB = osp.join(root, 'part_B') 178 | ori_dir_partB_train = osp.join(ori_dir_partB, 'train_data') 179 | ori_dir_partB_train_mat = osp.join(ori_dir_partB_train, 'ground-truth') 180 | ori_dir_partB_train_img = osp.join(ori_dir_partB_train, 'images') 181 | ori_dir_partB_test = osp.join(ori_dir_partB, 'test_data') 182 | ori_dir_partB_test_mat = osp.join(ori_dir_partB_test, 'ground-truth') 183 | ori_dir_partB_test_img = osp.join(ori_dir_partB_test, 'images') 184 | 185 | #to be computed 186 | ori_dir_partA_train_lab = osp.join(ori_dir_partA_train, 'labels') 187 | ori_dir_partA_train_den = osp.join(ori_dir_partA_train, 'density_maps') 188 | ori_dir_partA_test_lab = osp.join(ori_dir_partA_test, 'labels') 189 | ori_dir_partA_test_den = osp.join(ori_dir_partA_test, 'density_maps') 190 | 191 | ori_dir_partB_train_lab = osp.join(ori_dir_partB_train, 'labels') 192 | ori_dir_partB_train_den = osp.join(ori_dir_partB_train, 'density_maps') 193 | ori_dir_partB_test_lab = osp.join(ori_dir_partB_test, 'labels') 194 | 
ori_dir_partB_test_den = osp.join(ori_dir_partB_test, 'density_maps') 195 | 196 | augmented_dir_partA = "" 197 | augmented_dir_partB = "" 198 | train_test_set = [] 199 | signature_args = ['people_thr', 'gt_mode'] 200 | metadata = dict() 201 | train_test_size = 2 202 | 203 | def __init__(self, force_create_den_maps = False, force_augmentation = False, **kwargs): 204 | self._check_before_run() 205 | self.metadata = kwargs 206 | self._create_labels() 207 | self._create_original_density_maps(force_create_den_maps) 208 | self._create_train_test(force_augmentation, kwargs) 209 | 210 | def _create_original_density_maps(self, force_create_den_maps): 211 | all_density_dirs_exist = osp.exists(self.ori_dir_partA_train_den) 212 | all_density_dirs_exist = all_density_dirs_exist and osp.exists(self.ori_dir_partA_test_den) 213 | all_density_dirs_exist = all_density_dirs_exist and osp.exists(self.ori_dir_partB_train_den) 214 | all_density_dirs_exist = all_density_dirs_exist and osp.exists(self.ori_dir_partB_test_den) 215 | if not all_density_dirs_exist: 216 | mkdir_if_missing(self.ori_dir_partA_train_den) 217 | mkdir_if_missing(self.ori_dir_partA_test_den) 218 | mkdir_if_missing(self.ori_dir_partB_train_den) 219 | mkdir_if_missing(self.ori_dir_partB_test_den) 220 | elif not force_create_den_maps: 221 | return 222 | create_density_map(self.ori_dir_partA_train_img, self.ori_dir_partA_train_lab, self.ori_dir_partA_train_den, mode = self.metadata['gt_mode']) 223 | create_density_map(self.ori_dir_partA_test_img, self.ori_dir_partA_test_lab, self.ori_dir_partA_test_den, mode = self.metadata['gt_mode']) 224 | create_density_map(self.ori_dir_partB_train_img, self.ori_dir_partB_train_lab, self.ori_dir_partB_train_den, mode = self.metadata['gt_mode']) 225 | create_density_map(self.ori_dir_partB_test_img, self.ori_dir_partB_test_lab, self.ori_dir_partB_test_den, mode = self.metadata['gt_mode']) 226 | 227 | def _check_before_run(self): 228 | """Check if all files are available before going deeper""" 229 | if not osp.exists(self.root): 230 | raise RuntimeError("'{}' is not available".format(self.root)) 231 | if not osp.exists(self.ori_dir_partA): 232 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA)) 233 | if not osp.exists(self.ori_dir_partB): 234 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB)) 235 | if not osp.exists(self.ori_dir_partA_train): 236 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA_train)) 237 | if not osp.exists(self.ori_dir_partB_train): 238 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_train)) 239 | if not osp.exists(self.ori_dir_partA_test): 240 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA_test)) 241 | if not osp.exists(self.ori_dir_partB_test): 242 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_test)) 243 | if not osp.exists(self.ori_dir_partA_train_img): 244 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA_train_img)) 245 | if not osp.exists(self.ori_dir_partA_train_mat): 246 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA_train_mat)) 247 | if not osp.exists(self.ori_dir_partB_train_img): 248 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_train_img)) 249 | if not osp.exists(self.ori_dir_partB_train_mat): 250 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_train_mat)) 251 | if not osp.exists(self.ori_dir_partA_test_img): 252 | raise RuntimeError("'{}' is not 
available".format(self.ori_dir_partA_test_img)) 253 | if not osp.exists(self.ori_dir_partA_test_mat): 254 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partA_test_mat)) 255 | if not osp.exists(self.ori_dir_partB_test_img): 256 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_test_img)) 257 | if not osp.exists(self.ori_dir_partB_test_mat): 258 | raise RuntimeError("'{}' is not available".format(self.ori_dir_partB_test_mat)) 259 | 260 | def _json_to_string(self, array): 261 | """ 262 | converts json to string specifically for shanghai tech dataset 263 | """ 264 | if len(array)==0: 265 | return '[]' 266 | line = '[' 267 | for i in range(len(array)): 268 | line += '{\"x\":'+str(array[i][0])+',\"y\":'+str(array[i][1])+'},' 269 | return line[0:len(line)-1]+']' 270 | 271 | def _convert_mat_to_json(self, in_dir, out_dir): 272 | """ 273 | converts every .mat file in in_dir to a .json equivalent in out_dir 274 | """ 275 | print("converting mat to json from {}".format(in_dir)) 276 | file_names = os.listdir(in_dir) 277 | for mat_file in file_names: 278 | mat_file_path = osp.join(in_dir, mat_file) 279 | file_extention = mat_file.split('.')[-1] 280 | file_id = mat_file[3:len(mat_file) - len(file_extention)] 281 | json_file_path = osp.join(out_dir, file_id + 'json') 282 | labels = sio.loadmat(mat_file_path) 283 | labels = labels['image_info'][0][0][0][0][0] 284 | labels = str(self._json_to_string(labels)) 285 | with open(json_file_path, 'w') as outfile: 286 | outfile.write(labels) 287 | 288 | def _create_labels(self): 289 | mkdir_if_missing(self.ori_dir_partA_train_lab) 290 | mkdir_if_missing(self.ori_dir_partA_test_lab) 291 | mkdir_if_missing(self.ori_dir_partB_train_lab) 292 | mkdir_if_missing(self.ori_dir_partB_test_lab) 293 | #check if number os files is equal 294 | if len(os.listdir(self.ori_dir_partA_train_mat)) != len(os.listdir(self.ori_dir_partA_train_lab)): 295 | self._convert_mat_to_json(self.ori_dir_partA_train_mat, self.ori_dir_partA_train_lab) 296 | if len(os.listdir(self.ori_dir_partA_test_mat)) != len(os.listdir(self.ori_dir_partA_test_lab)): 297 | self._convert_mat_to_json(self.ori_dir_partA_test_mat, self.ori_dir_partA_test_lab) 298 | if len(os.listdir(self.ori_dir_partB_train_mat)) != len(os.listdir(self.ori_dir_partB_train_lab)): 299 | self._convert_mat_to_json(self.ori_dir_partB_train_mat, self.ori_dir_partB_train_lab) 300 | if len(os.listdir(self.ori_dir_partB_test_mat)) != len(os.listdir(self.ori_dir_partB_test_lab)): 301 | self._convert_mat_to_json(self.ori_dir_partB_test_mat, self.ori_dir_partB_test_lab) 302 | 303 | def signature(self): 304 | return "_".join(["{}_{}".format(sign_elem, self.metadata[sign_elem]) for sign_elem in self.signature_args]) 305 | 306 | def _create_train_test(self, force_augmentation, kwargs): 307 | slide_window_params = {'displace' : kwargs['displace'], 'size_x' : kwargs['size_x'], 'size_y' : kwargs['size_y'], 'people_thr' : kwargs['people_thr']} 308 | noise_params = {'augment_noise' : kwargs['augment_noise']} 309 | light_params = {'augment_light' : kwargs['augment_light'], 'bright' : kwargs['bright'], 'contrast' : kwargs['contrast']} 310 | 311 | #shanghaiTech part A 312 | self.augmented_dir_partA = osp.join(self.ori_dir_partA, self.signature()) 313 | augment_data_A = False 314 | if osp.exists(self.augmented_dir_partA): 315 | print("'{}' already exists".format(self.augmented_dir_partA)) 316 | if force_augmentation: 317 | augment_data_A = True 318 | print("augmenting data anyway") 319 | else: 320 | augment_data_A = 
288 |     def _create_labels(self):
289 |         mkdir_if_missing(self.ori_dir_partA_train_lab)
290 |         mkdir_if_missing(self.ori_dir_partA_test_lab)
291 |         mkdir_if_missing(self.ori_dir_partB_train_lab)
292 |         mkdir_if_missing(self.ori_dir_partB_test_lab)
293 |         # check if the number of files is equal
294 |         if len(os.listdir(self.ori_dir_partA_train_mat)) != len(os.listdir(self.ori_dir_partA_train_lab)):
295 |             self._convert_mat_to_json(self.ori_dir_partA_train_mat, self.ori_dir_partA_train_lab)
296 |         if len(os.listdir(self.ori_dir_partA_test_mat)) != len(os.listdir(self.ori_dir_partA_test_lab)):
297 |             self._convert_mat_to_json(self.ori_dir_partA_test_mat, self.ori_dir_partA_test_lab)
298 |         if len(os.listdir(self.ori_dir_partB_train_mat)) != len(os.listdir(self.ori_dir_partB_train_lab)):
299 |             self._convert_mat_to_json(self.ori_dir_partB_train_mat, self.ori_dir_partB_train_lab)
300 |         if len(os.listdir(self.ori_dir_partB_test_mat)) != len(os.listdir(self.ori_dir_partB_test_lab)):
301 |             self._convert_mat_to_json(self.ori_dir_partB_test_mat, self.ori_dir_partB_test_lab)
302 | 
303 |     def signature(self):
304 |         return "_".join(["{}_{}".format(sign_elem, self.metadata[sign_elem]) for sign_elem in self.signature_args])
305 | 
306 |     def _create_train_test(self, force_augmentation, kwargs):
307 |         slide_window_params = {'displace': kwargs['displace'], 'size_x': kwargs['size_x'], 'size_y': kwargs['size_y'], 'people_thr': kwargs['people_thr']}
308 |         noise_params = {'augment_noise': kwargs['augment_noise']}
309 |         light_params = {'augment_light': kwargs['augment_light'], 'bright': kwargs['bright'], 'contrast': kwargs['contrast']}
310 | 
311 |         # ShanghaiTech part A
312 |         self.augmented_dir_partA = osp.join(self.ori_dir_partA, self.signature())
313 |         augment_data_A = False
314 |         if osp.exists(self.augmented_dir_partA):
315 |             print("'{}' already exists".format(self.augmented_dir_partA))
316 |             if force_augmentation:
317 |                 augment_data_A = True
318 |                 print("augmenting data anyway")
319 |             else:
320 |                 augment_data_A = False
321 |                 print("will not augment data")
322 |         else:
323 |             augment_data_A = True
324 |             os.makedirs(self.augmented_dir_partA)
325 | 
326 |         aug_dir_partA_img = osp.join(self.augmented_dir_partA, "train_img")
327 |         aug_dir_partA_den = osp.join(self.augmented_dir_partA, "train_den")
328 |         aug_dir_partA_lab = osp.join(self.augmented_dir_partA, "train_lab")
329 |         mkdir_if_missing(aug_dir_partA_img)
330 |         mkdir_if_missing(aug_dir_partA_den)
331 |         mkdir_if_missing(aug_dir_partA_lab)
332 | 
333 |         kwargs['name'] = 'shanghai-partA'
334 |         part_A_train_test = train_test_unit(aug_dir_partA_img, aug_dir_partA_den, self.ori_dir_partA_test_img, self.ori_dir_partA_test_den, kwargs.copy())
335 |         self.train_test_set.append(part_A_train_test)
336 | 
337 |         if augment_data_A:
338 |             ori_img_paths = [osp.join(self.ori_dir_partA_train_img, file_name) for file_name in sorted(os.listdir(self.ori_dir_partA_train_img))]
339 |             ori_lab_paths = [osp.join(self.ori_dir_partA_train_lab, file_name) for file_name in sorted(os.listdir(self.ori_dir_partA_train_lab))]
340 |             ori_den_paths = [osp.join(self.ori_dir_partA_train_den, file_name) for file_name in sorted(os.listdir(self.ori_dir_partA_train_den))]
341 |             augment(ori_img_paths, ori_lab_paths, ori_den_paths, aug_dir_partA_img, aug_dir_partA_lab, aug_dir_partA_den, slide_window_params, noise_params, light_params)
342 | 
343 |         # ShanghaiTech part B
344 |         self.augmented_dir_partB = osp.join(self.ori_dir_partB, self.signature())
345 |         augment_data_B = False
346 |         if osp.exists(self.augmented_dir_partB):
347 |             print("'{}' already exists".format(self.augmented_dir_partB))
348 |             if force_augmentation:
349 |                 augment_data_B = True
350 |                 print("augmenting data anyway")
351 |             else:
352 |                 augment_data_B = False
353 |                 print("will not augment data")
354 |         else:
355 |             augment_data_B = True
356 |             os.makedirs(self.augmented_dir_partB)
357 | 
358 |         aug_dir_partB_img = osp.join(self.augmented_dir_partB, "train_img")
359 |         aug_dir_partB_den = osp.join(self.augmented_dir_partB, "train_den")
360 |         aug_dir_partB_lab = osp.join(self.augmented_dir_partB, "train_lab")
361 |         mkdir_if_missing(aug_dir_partB_img)
362 |         mkdir_if_missing(aug_dir_partB_den)
363 |         mkdir_if_missing(aug_dir_partB_lab)
364 | 
365 |         kwargs['name'] = 'shanghai-partB'
366 |         part_B_train_test = train_test_unit(aug_dir_partB_img, aug_dir_partB_den, self.ori_dir_partB_test_img, self.ori_dir_partB_test_den, kwargs.copy())
367 |         self.train_test_set.append(part_B_train_test)
368 | 
369 |         if augment_data_B:
370 |             ori_img_paths = [osp.join(self.ori_dir_partB_train_img, file_name) for file_name in sorted(os.listdir(self.ori_dir_partB_train_img))]
371 |             ori_lab_paths = [osp.join(self.ori_dir_partB_train_lab, file_name) for file_name in sorted(os.listdir(self.ori_dir_partB_train_lab))]
372 |             ori_den_paths = [osp.join(self.ori_dir_partB_train_den, file_name) for file_name in sorted(os.listdir(self.ori_dir_partB_train_den))]
373 |             augment(ori_img_paths, ori_lab_paths, ori_den_paths, aug_dir_partB_img, aug_dir_partB_lab, aug_dir_partB_den, slide_window_params, noise_params, light_params)
374 | 
375 | """Create dataset"""
376 | 
377 | __factory = {
378 |     'ucf-cc-50': UCF_CC_50,
379 |     'shanghai-tech': ShanghaiTech
380 | }
381 | 
382 | def get_names():
383 |     return __factory.keys()
384 | 
385 | def init_dataset(name, force_create_den_maps = False, force_augmentation = False, **kwargs):
386 |     if name not in __factory.keys():
387 |         raise KeyError("Unknown dataset: {}".format(name))
388 |     return __factory[name](force_create_den_maps, force_augmentation, **kwargs)
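# Minimal usage sketch; the keyword values below are illustrative (they mirror
# train.py's defaults) and are not requirements of this factory:
#
#   dataset = init_dataset('shanghai-tech',
#                          force_create_den_maps=False, force_augmentation=False,
#                          gt_mode='same', displace=70, size_x=256, size_y=256,
#                          people_thr=0, augment_noise=True, augment_light=True,
#                          bright=10, contrast=10)
#   print([unit.metadata['name'] for unit in dataset.train_test_set])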
389 | 
390 | if __name__ == '__main__':
391 |     pass
392 | 
--------------------------------------------------------------------------------
/manage_data/get_density_map.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from scipy.signal import gaussian  # note: lives at scipy.signal.windows.gaussian in newer SciPy
4 | import json
5 | from matplotlib.image import imsave
6 | 
7 | import os.path as osp
8 | import os
9 | 
10 | 
11 | from manage_data.utils import cnt_overlaps
12 | 
13 | F_SZ = 15
14 | SIGMA = 15
15 | 
16 | VALID_GT_MODES = ['same']
17 | 
18 | def gauss_ker(sigma, shape):
19 |     gaussian_kernel = np.outer(gaussian(shape[0], std = sigma), gaussian(shape[1], std = sigma))
20 |     gaussian_kernel /= gaussian_kernel.sum()
21 |     return gaussian_kernel
22 | 
23 | def get_density_map_gaussian(img_shape, points, mode = 'same'):
24 |     """
25 |     Creates a density map of shape img_shape with a gaussian placed over each point.
26 | 
27 |     Inputs:
28 |     - img_shape: tuple with the height and width of the output density map
29 |     - points: positions of people's heads, given as (y, x) pairs
30 |     - mode: ["same", "k-nearest"]; with "same" all gaussian kernels have the same size, otherwise a k-nearest kernel is used.
31 | 
32 |     Outputs:
33 |     - density_map of shape img_shape
34 |     """
35 |     img_density = np.zeros(img_shape)
36 |     h, w = img_shape
37 |     for ind, point in enumerate(points):
38 |         kernel_size_y, kernel_size_x = F_SZ, F_SZ
39 |         sigma = 4  # fixed bandwidth; intentionally narrower than the module-level SIGMA
40 |         H = gauss_ker(sigma, [kernel_size_y, kernel_size_x])
41 |         x = min(w, max(1, (int)(abs(point[1]))))
42 |         y = min(h, max(1, (int)(abs(point[0]))))
43 |         if x > w or y > h:
44 |             continue
45 |         x1 = x - (int)(np.floor(kernel_size_x/2))
46 |         y1 = y - (int)(np.floor(kernel_size_y/2))
47 |         x2 = x + (int)(np.floor(kernel_size_x/2))
48 |         y2 = y + (int)(np.floor(kernel_size_y/2))
49 |         dfx1 = 0
50 |         dfy1 = 0
51 |         dfx2 = 0
52 |         dfy2 = 0
53 |         change_H = False
54 |         if x1 < 1:
55 |             dfx1 = abs(x1) + 1
56 |             x1 = 1
57 |             change_H = True
58 |         if y1 < 1:
59 |             dfy1 = abs(y1) + 1
60 |             y1 = 1
61 |             change_H = True
62 |         if x2 > w:
63 |             dfx2 = x2 - w
64 |             x2 = w
65 |             change_H = True
66 |         if y2 > h:
67 |             dfy2 = y2 - h
68 |             y2 = h
69 |             change_H = True
70 |         x1h = 1 + dfx1
71 |         y1h = 1 + dfy1
72 |         x2h = kernel_size_x - dfx2
73 |         y2h = kernel_size_y - dfy2
74 |         x1 = (int)(x1)
75 |         x2 = (int)(x2)
76 |         y1 = (int)(y1)
77 |         y2 = (int)(y2)
78 |         if change_H or y2 - y1 != kernel_size_y or x2 - x1 != kernel_size_x:
79 |             H = gauss_ker(sigma, [y2 - y1, x2 - x1])
80 |         img_density[y1:y2, x1:x2] += H
81 |     return img_density
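# Sanity-check sketch: every (possibly border-clipped) kernel is re-normalized
# to sum to 1, so the map should integrate to roughly the number of heads:
#
#   den = get_density_map_gaussian((240, 320), [[120, 160], [60, 80]])
#   assert abs(den.sum() - 2.0) < 1e-6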
82 | 
83 | def create_density_map(imgs_path, labels_path, density_maps_path, mode = 'same'):
84 |     """
85 |     Generates density map files (.npy) inside directory density_maps_path
86 | 
87 |     input:
88 | 
89 |     imgs_path: directory with original images (.jpg or .png)
90 |     labels_path: directory with data labels (.json)
91 |     density_maps_path: directory where generated density map (.npy) files are stored
92 |     mode: method used for generation of ground truth images
93 |     """
94 |     if mode not in VALID_GT_MODES:
95 |         raise RuntimeError("'{}' is an invalid mode for ground truth generation. Valid modes are: {}".format(mode, ', '.join(VALID_GT_MODES)))
96 |     file_names = os.listdir(imgs_path)
97 |     file_names.sort()
98 |     print("Creating density maps for '{}', {} images will be processed".format(imgs_path, len(file_names)))
99 | 
100 |     for file_name in file_names:
101 |         file_extension = file_name.split('.')[-1]
102 |         file_id = file_name[:len(file_name) - len(file_extension)]
103 |         if file_extension != 'png' and file_extension != 'jpg':
104 |             continue
105 |         file_path = osp.join(imgs_path, file_name)
106 |         label_path = osp.join(labels_path, file_id + 'json')
107 |         density_map_path = osp.join(density_maps_path, file_id + 'npy')
108 |         with open(label_path) as data_file:
109 |             labels = json.load(data_file)
110 |         points = []
111 |         for p in labels:
112 |             points.append([p['y'], p['x']])
113 |         img = cv2.imread(file_path)
114 |         img_den = get_density_map_gaussian(img.shape[:2], points, mode = mode)
115 |         np.save(density_map_path, img_den)
116 | 
117 | 
118 | 
119 | 
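# Usage sketch (directory names are placeholders following the ShanghaiTech
# layout used elsewhere in this repo):
#
#   create_density_map('part_A/train_data/images',
#                      'part_A/train_data/labels',
#                      'part_A/train_data/density_maps')
#
# For an image IMG_74.jpg this reads labels/IMG_74.json and writes
# density_maps/IMG_74.npy.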
--------------------------------------------------------------------------------
/manage_data/get_density_map.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RQuispeC/pytorch-ACSCP/1247fcc4f54f247ee63859adfddfd7c46d753142/manage_data/get_density_map.pyc
--------------------------------------------------------------------------------
/manage_data/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Set of functions used for debugging during development.
3 | """
4 | 
5 | import numpy as np
6 | import cv2
7 | import json
8 | from manage_data.get_density_map import get_density_map_gaussian
9 | import os
10 | import matplotlib as mpl
11 | if os.environ.get('DISPLAY', '') == '':
12 |     mpl.use('Agg')
13 | import matplotlib.pylab as plt
14 | import gc
15 | 
16 | def create_points(shape):
17 |     points = []
18 |     for i in range(shape[0]):
19 |         for j in range(shape[1]):
20 |             points.append([i, j])
21 |     return points
22 | 
23 | def max_value_gt():
24 |     shape = (15, 15)
25 |     points = create_points(shape)
26 |     img = get_density_map_gaussian(shape, points)  # ground-truth map over a dense grid of points
27 |     print(np.max(img))
28 | 
29 | def plot_img_gt():
30 |     base_dir = "/home/quispe/Documents/crowd-counting/code/"
31 | 
32 |     #data_dir = "data/ucf_cc_50/people_thr_20_gt_mode_face/fold1/"
33 |     #img = cv2.imread(base_dir + data_dir + "train_img/0000234.jpg", 0)
34 |     #lab = json.load(open(base_dir + data_dir + "train_lab/0000234.json"))
35 |     #gt = np.load(base_dir + data_dir + "train_den/0000234.npy")
36 | 
37 |     #data_dir = "data/ucf_cc_50/UCF_CC_50/"
38 |     #img = cv2.imread(base_dir + data_dir + "images/02.jpg", 0)
39 |     #lab = json.load(open(base_dir + data_dir + "labels/02.json"))
40 |     #gt = np.load(base_dir + data_dir + "density_maps/02.npy")
41 | 
42 |     data_dir = "data/ShanghaiTech/part_A/train_data/"
43 |     img = cv2.imread(base_dir + data_dir + "images/IMG_74.jpg", 0)
44 |     lab = json.load(open(base_dir + data_dir + "labels/IMG_74.json"))
45 |     gt = np.load(base_dir + data_dir + "density_maps/IMG_74.npy")
46 | 
47 | 
48 |     gt_cnt = np.sum(gt)
49 |     gt = gt * 255.0
50 | 
51 |     fig = plt.figure(figsize = (30, 20))
52 |     a = fig.add_subplot(1, 2, 1)
53 |     plt.imshow(img, cmap='gray')
54 |     a.set_title('input')
55 |     plt.axis('off')
56 | 
57 |     a = fig.add_subplot(1, 2, 2)
58 |     plt.imshow(gt)
59 |     a.set_title('sum {:.2f} -- ground truth {:.0f}'.format(gt_cnt, len(lab)))
60 |     plt.axis('off')
61 | 
62 |     fig.savefig("tmp.jpg", bbox_inches='tight')
63 |     fig.clf()
64 |     plt.close()
65 |     del a
66 |     gc.collect()
67 | 
68 | def plot_maps(origin_dir, output_dir):
69 |     files = os.listdir(origin_dir)
70 |     for file in files:
71 |         if file.split('.')[-1] != 'npy':
72 |             continue
73 |         file_name = os.path.join(origin_dir, file)
74 |         print(file_name)
75 |         image = np.load(file_name)
76 |         image = image / np.max(image) * 255
77 |         file_out = os.path.join(output_dir, file.split('.')[0] + '.jpg')
78 |         print(file_out)
79 | 
80 |         fig = plt.figure()
81 |         a = fig.add_subplot(1, 1, 1)
82 |         plt.imshow(image)
83 |         plt.axis('off')
84 | 
85 |         fig.savefig(file_out, bbox_inches='tight')
86 |         fig.clf()
87 |         plt.close()
88 |         del a
89 |         gc.collect()
90 | 
91 | def plot_loss(log_file, out_file, out_options = '1'):
92 |     f = open(log_file, "r")
93 |     loss = []
94 |     mae = []
95 |     mse = []
96 |     for line in f:
97 |         if line.startswith("Epoch:"):
98 |             for item in line.split(","):
99 |                 item = item.strip()
100 |                 num = float(item.split()[-1])
101 |                 if item.startswith("MAE"):
102 |                     mae.append(num)
103 |                 elif item.startswith("MSE"):
104 |                     mse.append(num)
105 |                 elif item.startswith("loss:"):
106 |                     loss.append(num)
107 |     assert len(loss) == len(mse) and len(mae) == len(mse), "Error in vector sizes mae: {}, mse: {}, loss: {}".format(len(mae), len(mse), len(loss))
108 |     epoch = np.arange(len(loss))
109 |     if out_options == '1' or out_options == '2':
110 |         plt.plot(epoch, loss, label = 'loss')
111 |     if out_options == '1' or out_options == '3':
112 |         plt.plot(epoch, mae, label = 'mae')
113 |         plt.plot(epoch, mse, label = 'mse')
114 |     plt.xlabel("epoch")
115 |     plt.legend(loc='upper left')
116 |     plt.savefig(out_file)
117 |     plt.clf()
118 |     plt.close()
119 |     gc.collect()
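# plot_loss parses the per-epoch summary lines that train.py prints, e.g.
# (numbers are illustrative):
#
#   Epoch: 12, MAE: 110.4183, MSE: 180.2125, ..., loss: 3.1415
#
# Only the comma-separated items starting with "MAE", "MSE" and "loss:" are
# read; the "loss gen/dis" items do not match the "loss:" prefix and are skipped.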
41 | """ 42 | def __init__(self, fpath=None): 43 | self.console = sys.stdout 44 | self.file = None 45 | if fpath is not None: 46 | mkdir_if_missing(os.path.dirname(fpath)) 47 | self.file = open(fpath, 'w') 48 | 49 | def __del__(self): 50 | self.close() 51 | 52 | def __enter__(self): 53 | pass 54 | 55 | def __exit__(self, *args): 56 | self.close() 57 | 58 | def write(self, msg): 59 | self.console.write(msg) 60 | if self.file is not None: 61 | self.file.write(msg) 62 | 63 | def flush(self): 64 | self.console.flush() 65 | if self.file is not None: 66 | self.file.flush() 67 | os.fsync(self.file.fileno()) 68 | 69 | def close(self): 70 | self.console.close() 71 | if self.file is not None: 72 | self.file.close() 73 | 74 | def intersec(first, second): 75 | overlap = False 76 | overlap = overlap or (first[0] <= second[0] and second[0] <= first[2] and first[1] <= second[1] and second[1] <= first[3]) 77 | overlap = overlap or (first[0] <= second[2] and second[2] <= first[2] and first[1] <= second[1] and second[1] <= first[3]) 78 | overlap = overlap or (first[0] <= second[0] and second[0] <= first[2] and first[1] <= second[3] and second[3] <= first[3]) 79 | overlap = overlap or (first[0] <= second[2] and second[2] <= first[2] and first[1] <= second[3] and second[3] <= first[3]) 80 | return overlap 81 | 82 | def cnt_overlaps(boxes): 83 | boxes_overlap = [] 84 | id_overlap = [] 85 | for ind_first, first in enumerate(boxes): 86 | cnt = 0 87 | overlap = [] 88 | for ind_second, second in enumerate(boxes): 89 | if ind_first != ind_second and intersec(first, second): 90 | cnt += 1 91 | overlap.append(ind_second) 92 | boxes_overlap.append(cnt) 93 | id_overlap.append(overlap) 94 | return boxes_overlap, id_overlap -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import torch 4 | import numpy as np 5 | import sys 6 | from torch.nn.utils import clip_grad_norm_ 7 | 8 | from architecture.crowd_count import CrowdCounter 9 | from architecture import network 10 | from architecture.data_loader import ImageDataLoader 11 | from architecture.timer import Timer 12 | from architecture import utils 13 | from architecture.evaluate_model import evaluate_model 14 | 15 | import argparse 16 | 17 | from manage_data import dataset_loader 18 | from manage_data.utils import Logger, mkdir_if_missing 19 | 20 | import time 21 | EPSILON = 1e-10 22 | MAXIMUM_CNT = 2.7675038540969217 23 | NORMALIZE_ADD = np.log(EPSILON) / 2.0 24 | 25 | parser = argparse.ArgumentParser(description='Train crowd counting network using data augmentation') 26 | # Datasets 27 | parser.add_argument('-d', '--dataset', type=str, default='ucf', 28 | choices=dataset_loader.get_names()) 29 | #Data augmentation hyperpameters 30 | parser.add_argument('--force-den-maps', action='store_true', help="force generation of dentisity maps for original dataset, by default it is generated only once") 31 | parser.add_argument('--force-augment', action='store_true', help="force generation of augmented data, by default it is generated only once") 32 | parser.add_argument('--displace', default=70, type=int,help="displacement for sliding window in data augmentation, default 70") 33 | parser.add_argument('--size-x', default=256, type=int, help="width of sliding window in data augmentation, default 200") 34 | parser.add_argument('--size-y', default=256, type=int, help="height of sliding window in data augmentation, 
default 300") 35 | parser.add_argument('--people-thr', default=0, type=int, help="threshold of people sliding window in data augmentation, default 200") 36 | parser.add_argument('--not-augment-noise', action='store_true', help="not use noise for data augmetnation, default True") 37 | parser.add_argument('--not-augment-light', action='store_true', help="not use bright & contrast for data augmetnation, default True") 38 | parser.add_argument('--bright', default=10, type=int, help="bright value for bright & contrast augmentation, defaul 10") 39 | parser.add_argument('--contrast', default=10, type=int, help="contrast value for bright & contrast augmentation, defaul 10") 40 | parser.add_argument('--gt-mode', type=str, default='same', help="mode for generation of ground thruth.") 41 | 42 | # Optimization options 43 | parser.add_argument('--max-epoch', default=500, type=int, 44 | help="maximum epochs to run") 45 | parser.add_argument('--start-epoch', default=0, type=int, 46 | help="manual epoch number (useful on restarts)") 47 | parser.add_argument('--lr', '--learning-rate', default=0.00005, type=float, 48 | help="initial learning rate") 49 | parser.add_argument('--beta1', default=0.5, type=float, 50 | help="training b1 for adam optimizer") 51 | parser.add_argument('--beta2', default=0.999, type=float, 52 | help="training b2 for adam optimizer") 53 | parser.add_argument('--train-batch', default=32, type=int, 54 | help="train batch size (default 32)") 55 | # Miscs 56 | parser.add_argument('--den-factor', type=float, default=1e3, help="factor to multiply for density maps to avoid too small values") 57 | parser.add_argument('--overlap-test', action='store_true', help="overlap the sliding windows for test") 58 | parser.add_argument('--seed', type=int, default=64678, help="manual seed") 59 | parser.add_argument('--resume', type=str, default='', metavar='PATH', help="root directory where part/fold of previous train are saved") 60 | parser.add_argument('--save-dir', type=str, default='log', help="path where results for each part/fold are saved") 61 | parser.add_argument('--units', type=str, default='', help="folds/parts units to be trained, be default all folds/parts are trained") 62 | parser.add_argument('--augment-only', action='store_true', help="run only data augmentation, default False") 63 | parser.add_argument('--evaluate-only', action='store_true', help="run only data validation, --resume arg is needed, default False") 64 | parser.add_argument('--save-plots', action='store_true', help="save plots of density map estimation (done only in test step), default False") 65 | 66 | args = parser.parse_args() 67 | 68 | def train(train_test_unit, out_dir_root): 69 | output_dir = osp.join(out_dir_root, train_test_unit.metadata['name']) 70 | mkdir_if_missing(output_dir) 71 | sys.stdout = Logger(osp.join(output_dir, 'log_train.txt')) 72 | print("==========\nArgs:{}\n==========".format(args)) 73 | 74 | dataset_name = train_test_unit.metadata['name'] 75 | train_path = train_test_unit.train_dir_img 76 | train_gt_path = train_test_unit.train_dir_den 77 | val_path =train_test_unit.test_dir_img 78 | val_gt_path = train_test_unit.test_dir_den 79 | 80 | #training configuration 81 | start_step = args.start_epoch 82 | end_step = args.max_epoch 83 | lr = args.lr 84 | 85 | #log frequency 86 | disp_interval = args.train_batch*20 87 | 88 | # ------------ 89 | rand_seed = args.seed 90 | if rand_seed is not None: 91 | np.random.seed(rand_seed) 92 | torch.manual_seed(rand_seed) 93 | torch.cuda.manual_seed(rand_seed) 94 | 95 | 
67 | 
68 | def train(train_test_unit, out_dir_root):
69 |     output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
70 |     mkdir_if_missing(output_dir)
71 |     sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'))
72 |     print("==========\nArgs:{}\n==========".format(args))
73 | 
74 |     dataset_name = train_test_unit.metadata['name']
75 |     train_path = train_test_unit.train_dir_img
76 |     train_gt_path = train_test_unit.train_dir_den
77 |     val_path = train_test_unit.test_dir_img
78 |     val_gt_path = train_test_unit.test_dir_den
79 | 
80 |     # training configuration
81 |     start_step = args.start_epoch
82 |     end_step = args.max_epoch
83 |     lr = args.lr
84 | 
85 |     # log frequency
86 |     disp_interval = args.train_batch * 20
87 | 
88 |     # ------------
89 |     rand_seed = args.seed
90 |     if rand_seed is not None:
91 |         np.random.seed(rand_seed)
92 |         torch.manual_seed(rand_seed)
93 |         torch.cuda.manual_seed(rand_seed)
94 | 
95 |     # load net
96 |     net = CrowdCounter()
97 |     if not args.resume:
98 |         network.weights_normal_init(net, dev=0.01)
99 |     else:
100 |         #network.weights_normal_init(net, dev=0.01) #init all layers in case of partial net load
101 |         if args.resume[-3:] == '.h5':
102 |             pretrained_model = args.resume
103 |         else:
104 |             resume_dir = osp.join(args.resume, train_test_unit.metadata['name'])
105 |             pretrained_model = osp.join(resume_dir, 'best_model.h5')
106 |         network.load_net(pretrained_model, net)
107 |         print('Will apply fine tuning over', pretrained_model)
108 |     net.cuda()
109 |     net.train()
110 | 
111 |     optimizer_d_large = torch.optim.Adam(filter(lambda p: p.requires_grad, net.d_large.parameters()), lr=lr, betas = (args.beta1, args.beta2))
112 |     optimizer_d_small = torch.optim.Adam(filter(lambda p: p.requires_grad, net.d_small.parameters()), lr=lr, betas = (args.beta1, args.beta2))
113 |     optimizer_g_large = torch.optim.Adam(filter(lambda p: p.requires_grad, net.g_large.parameters()), lr=lr, betas = (args.beta1, args.beta2))
114 |     optimizer_g_small = torch.optim.Adam(filter(lambda p: p.requires_grad, net.g_small.parameters()), lr=lr, betas = (args.beta1, args.beta2))
115 | 
116 |     # training
117 |     train_loss = 0
118 |     step_cnt = 0
119 |     re_cnt = False
120 |     t = Timer()
121 |     t.tic()
122 | 
123 |     # preprocess flags
124 |     overlap_test = True if args.overlap_test else False
125 | 
126 |     data_loader = ImageDataLoader(train_path, train_gt_path, shuffle=True, batch_size = args.train_batch, test_loader = False)
127 |     data_loader_val = ImageDataLoader(val_path, val_gt_path, shuffle=False, batch_size = 1, test_loader = True, img_width = args.size_x, img_height = args.size_y, test_overlap = overlap_test)
128 |     best_mae = sys.maxsize
129 | 
130 |     for epoch in range(start_step, end_step + 1):
131 |         step = 0
132 |         train_loss_gen_small = 0
133 |         train_loss_gen_large = 0
134 |         train_loss_dis_small = 0
135 |         train_loss_dis_large = 0
136 | 
137 |         for blob in data_loader:
138 |             step = step + args.train_batch
139 |             im_data = blob['data']
140 |             gt_data = blob['gt_density']
141 |             idx_data = blob['idx']
142 |             im_data_norm = im_data / 127.5 - 1.  # normalize between -1 and 1
143 |             gt_data = gt_data * args.den_factor
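            # Quick check of the mapping above: pixel 0 -> -1.0, 127.5 -> 0.0,
            # 255 -> 1.0. den_factor (default 1e3) scales the near-zero density
            # values up so the losses do not underflow; both the network output
            # and the ground truth are divided by it again further below.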
144 | 
145 |             optimizer_d_large.zero_grad()
146 |             optimizer_d_small.zero_grad()
147 |             density_map = net(im_data_norm, gt_data, epoch = epoch, mode = "discriminator")
148 |             loss_d_small = net.loss_dis_small
149 |             loss_d_large = net.loss_dis_large
150 |             loss_d_small.backward()
151 |             loss_d_large.backward()
152 |             optimizer_d_small.step()
153 |             optimizer_d_large.step()
154 | 
155 |             optimizer_g_large.zero_grad()
156 |             optimizer_g_small.zero_grad()
157 |             density_map = net(im_data_norm, gt_data, epoch = epoch, mode = "generator")
158 |             loss_g_small = net.loss_gen_small
159 |             loss_g_large = net.loss_gen_large
160 |             loss_g = net.loss_gen
161 |             loss_g.backward()  # loss_g_large + loss_g_small
162 |             optimizer_g_small.step()
163 |             optimizer_g_large.step()
164 | 
165 |             density_map /= args.den_factor
166 |             gt_data /= args.den_factor
167 | 
168 |             train_loss_gen_small += loss_g_small.data.item()
169 |             train_loss_gen_large += loss_g_large.data.item()
170 |             train_loss_dis_small += loss_d_small.data.item()
171 |             train_loss_dis_large += loss_d_large.data.item()
172 | 
173 |             step_cnt += 1
174 |             if step % disp_interval == 0:
175 |                 duration = t.toc(average=False)
176 |                 fps = step_cnt / duration
177 |                 density_map = density_map.data.cpu().numpy()
178 |                 train_batch_size = gt_data.shape[0]
179 |                 gt_count = np.sum(gt_data.reshape(train_batch_size, -1), axis = 1)
180 |                 et_count = np.sum(density_map.reshape(train_batch_size, -1), axis = 1)
181 | 
182 |                 if args.save_plots:
183 |                     plot_save_dir = osp.join(output_dir, 'plot-results-train/')
184 |                     mkdir_if_missing(plot_save_dir)
185 |                     utils.save_results(im_data, gt_data, density_map, idx_data, plot_save_dir)  # the parser defines no --loss flag, so no loss kwarg is passed here
186 | 
187 |                 print("epoch: {0}, step {1}/{5}, Time: {2:.4f}s, gt_cnt: {3:.4f}, et_cnt: {4:.4f}, mean_diff: {6:.4f}".format(epoch, step, 1./fps, gt_count[0], et_count[0], data_loader.num_samples, np.mean(np.abs(gt_count - et_count))))
188 |                 re_cnt = True
189 | 
190 |             if re_cnt:
191 |                 t.tic()
192 |                 re_cnt = False
193 | 
194 |         save_name = os.path.join(output_dir, '{}_{}_{}.h5'.format(train_test_unit.to_string(), dataset_name, epoch))
195 |         network.save_net(save_name, net)
196 | 
197 |         # calculate error on the validation dataset
198 |         mae, mse = evaluate_model(save_name, data_loader_val, epoch = epoch, den_factor = args.den_factor)
199 |         if mae < best_mae:
200 |             best_mae = mae
201 |             best_mse = mse
202 |             best_model = '{}_{}_{}.h5'.format(train_test_unit.to_string(), dataset_name, epoch)
203 |             network.save_net(os.path.join(output_dir, "best_model.h5"), net)
204 | 
205 |         print("Epoch: {0}, MAE: {1:.4f}, MSE: {2:.4f}, loss gen small: {3:.4f}, loss gen large: {4:.4f}, loss dis small: {5:.4f}, loss dis large: {6:.4f}, loss: {7:.4f}".format(epoch, mae, mse, train_loss_gen_small, train_loss_gen_large, train_loss_dis_small, train_loss_dis_large, train_loss_gen_small + train_loss_gen_large + train_loss_dis_small + train_loss_dis_large))
206 |         print("Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}".format(best_mae, best_mse, best_model))
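# Per-epoch checkpoints are written as <unit-signature>_<dataset>_<epoch>.h5,
# e.g. (hypothetical) people_thr_0_gt_mode_same_shanghai-partA_57.h5; the
# best-MAE epoch is additionally copied to best_model.h5, which test() below
# and --resume pick up by default.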
207 | 
208 | def test(train_test_unit, out_dir_root):
209 |     output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
210 |     mkdir_if_missing(output_dir)
211 |     sys.stdout = Logger(osp.join(output_dir, 'log_test.txt'))
212 |     print("==========\nArgs:{}\n==========".format(args))
213 | 
214 |     dataset_name = train_test_unit.metadata['name']
215 |     val_path = train_test_unit.test_dir_img
216 |     val_gt_path = train_test_unit.test_dir_den
217 | 
218 |     if not args.resume:
219 |         pretrained_model = osp.join(output_dir, 'best_model.h5')
220 |     else:
221 |         if args.resume[-3:] == '.h5':
222 |             pretrained_model = args.resume
223 |         else:
224 |             resume_dir = osp.join(args.resume, train_test_unit.metadata['name'])
225 |             pretrained_model = osp.join(resume_dir, 'best_model.h5')
226 |     print("Using {} for testing.".format(pretrained_model))
227 | 
228 |     overlap_test = True if args.overlap_test else False
229 | 
230 |     data_loader = ImageDataLoader(val_path, val_gt_path, shuffle=False, batch_size = 1, test_loader = True, img_width = args.size_x, img_height = args.size_y, test_overlap = overlap_test)
231 |     mae, mse = evaluate_model(pretrained_model, data_loader, save_test_results=args.save_plots, plot_save_dir=osp.join(output_dir, 'plot-results-test/'), den_factor = args.den_factor)
232 | 
233 |     print("MAE: {0:.4f}, MSE: {1:.4f}".format(mae, mse))
234 | 
235 | def main():
236 |     # augment data
237 | 
238 |     force_create_den_maps = True if args.force_den_maps else False
239 |     force_augmentation = True if args.force_augment else False
240 |     augment_noise = False if args.not_augment_noise else True
241 |     augment_light = False if args.not_augment_light else True
242 |     augment_only = True if args.augment_only else False
243 | 
244 | 
245 |     dataset = dataset_loader.init_dataset(name=args.dataset
246 |                                           , force_create_den_maps = force_create_den_maps
247 |                                           , force_augmentation = force_augmentation
248 |                                           # sliding window params
249 |                                           , gt_mode = args.gt_mode
250 |                                           , displace = args.displace
251 |                                           , size_x = args.size_x
252 |                                           , size_y = args.size_y
253 |                                           , people_thr = args.people_thr
254 |                                           # noise params
255 |                                           , augment_noise = augment_noise
256 |                                           # light params
257 |                                           , augment_light = augment_light
258 |                                           , bright = args.bright
259 |                                           , contrast = args.contrast)
260 | 
261 |     if augment_only:
262 |         set_units = [unit.metadata['name'] for unit in dataset.train_test_set]
263 |         print("Dataset train-test units are: {}".format(", ".join(set_units)))
264 |         print("Augment only - network will not be trained")
265 |         return
266 | 
267 |     metadata = "_".join([args.dataset, dataset.signature()])
268 |     out_dir_root = osp.join(args.save_dir, metadata)
269 | 
270 |     if args.units != '':
271 |         units_to_train = [name.strip() for name in args.units.split(',')]
272 |         set_units = [unit.metadata['name'] for unit in dataset.train_test_set]
273 |         print("Dataset train-test units are: {}".format(", ".join(set_units)))
274 |         set_units = set(set_units)
275 |         for unit in units_to_train:
276 |             if unit not in set_units:
277 |                 raise RuntimeError("Invalid '{}' train-test unit".format(unit))
278 |     else:
279 |         units_to_train = [unit.metadata['name'] for unit in dataset.train_test_set]
280 |         units_to_train = set(units_to_train)
281 |     for train_test in dataset.train_test_set:
282 |         if train_test.metadata['name'] in units_to_train:
283 |             if args.evaluate_only:
284 |                 print("Testing {}".format(train_test.metadata['name']))
285 |                 test(train_test, out_dir_root)
286 |             else:
287 |                 print("Training {}".format(train_test.metadata['name']))
288 |                 train(train_test, out_dir_root)
289 |                 print("Testing {}".format(train_test.metadata['name']))
290 |                 test(train_test, out_dir_root)
291 | 
292 | if __name__ == '__main__':
293 |     main()
294 | 
295 | 
--------------------------------------------------------------------------------