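├── .gitignore
├── README.md
├── assets
    ├── 1702.02359.pdf
    ├── model.png
    └── mscnn.png
├── data
    ├── download_malldataset.sh
    └── download_shanghaitech.sh
├── models
    └── tips.txt
├── requirements.txt
├── results
    ├── rst.png
    └── rst_shanghai.png
└── scripts
    ├── data.py
    ├── model.py
    ├── test.py
    ├── train.py
    └── utils.py


/.gitignore:
--------------------------------------------------------------------------------
.idea/
venv/
__pycache__/
*.rar
*.7z
*.jpg
mall_dataset/
ShanghaiTech/
*.zip
*.h5


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Multi-scale Convolution Neural Networks for Crowd Counting

>Updated March 2020
>Many people have recently asked about this project, which I had stopped working on for various reasons, so I found time to improve it again. The current code is set up mainly for training on the ShanghaiTech dataset, and the results are fairly good; see below for details.


## Project Overview
This project reproduces the paper [Multi-scale Convolution Neural Networks for Crowd Counting](https://arxiv.org/abs/1702.02359). There was no reasonably complete Keras-based reproduction, which made it inconvenient to quickly build higher-level systems on top of this model. The project uses the Keras API bundled with TensorFlow 2 (`tf.keras`); standalone Keras is not recommended. The code is designed to train and test on several datasets and the model generalizes well, but in practice the released model was trained mainly on ShanghaiTech.


## Dataset Download
- ShanghaiTech Dataset
    - [Download](https://drive.google.com/open?id=1CW6PiAnLSWuUBX-2tVqQO5-1TDdilJB1)
- Mall Dataset
    - [Download](https://drive.google.com/open?id=170bssJjE_UbGeGSc_s2WHGBtbDAZRd7t)
- The_UCF_CC_50 Dataset
    - [Download](https://drive.google.com/open?id=1MwfTXFQUTx_sqw-g-D7TDOox1S88XYVN)
- About these links
    - Official dataset links are not provided. All datasets are hosted on my Google Drive with sharing enabled; if you cannot access them, contact me by [email](mailto:luanshiyinyang@gmail.com) and I will provide Baidu Netdisk links.


## About the Paper
Motivated by recent progress in deep neural networks, and by the fact that existing network models are hard to optimize and computationally expensive, the paper proposes the multi-scale blob (MSB) module, an Inception-like structure, for feature extraction.

The authors propose the MSCNN architecture, which performs better than MCNN while using far fewer parameters, and compare it against models such as LBP+RR and MCNN+CCR.

The [paper PDF](/assets/1702.02359.pdf) is included in the repository.
## Environment Setup
- Python 3.6
- The required third-party packages are listed in [requirements](/requirements.txt)
- Change to the directory containing the requirements file and run `pip install -r requirements.txt` to set up the environment
- Running the scripts (from the `scripts` directory, because the scripts use relative paths such as `../data/` and `../models/`)
    - Training
        - From the command line
            - `python train.py -b 16`
        - For more options, run `python train.py -h`
    - Testing
        - From the command line
            - `python test.py -s yes`
        - For more options, run `python test.py -h`


## Model Construction
An earlier version showed a large drop in loss while predicting all-zero density maps. This was mainly caused by an unsuitable activation function in the regression layer, so, to get better convergence, the activation of the last layer was changed to:

$$output = \mathrm{ReLU}(\mathrm{Sigmoid}(x))$$

Because the sigmoid already maps every input into (0, 1), the trailing ReLU leaves its value unchanged, so in effect the output activation is the sigmoid.

The model is built with the Keras Functional API as follows; the full code is in `scripts/model.py`.
```python
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from model import MSB, MSB_mini  # multi-scale blob builders defined in scripts/model.py


def MSCNN(input_shape=(224, 224, 3)):
    """
    Build the model.
    The model in this paper is simple.
    :param input_shape: input image size
    :return: a Keras Model that outputs a 1-channel density map at 1/4 of the input resolution
    """
    input_tensor = Input(shape=input_shape)
    # block1
    x = Conv2D(filters=64, kernel_size=(9, 9), strides=1, padding='same', activation='relu')(input_tensor)
    # block2
    x = MSB(4*16)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    # block3
    x = MSB(4*32)(x)
    x = MSB(4*32)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    x = MSB_mini(3*64)(x)
    x = MSB_mini(3*64)(x)

    x = Conv2D(1000, (1, 1), activation='relu', kernel_regularizer=l2(5e-4))(x)

    x = Conv2D(1, (1, 1))(x)
    x = Activation('sigmoid')(x)
    x = Activation('relu')(x)

    model = Model(inputs=input_tensor, outputs=x)
    return model
```

**Note: the output layer cannot use a plain ReLU. Its output gets stuck in the "dead zone", so every prediction is 0 even though the loss keeps decreasing.**
### Architecture Diagram
![Figure from the paper](./assets/mscnn.png)
### Architecture Configuration
![Figure from the paper](./assets/model.png)


## Model Training
### Training Dataset
Training is done mainly on ShanghaiTech; other datasets can be trained or tested by wrapping them in a similar data loader.
### Results (model trained for only 5 epochs)
Crowd density estimates for 5 random images from the ShanghaiTech validation set are shown below. The model converges reasonably well; better results would require more careful training and tuning.

![](./results/rst_shanghai.png)


## Additional Notes
The trained weights can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1syEjwWVLjihEC5y2tQDdAA) (extraction code jppg) or [Google Drive](https://drive.google.com/drive/folders/1XsvQHXfXKprRrOIECsL740Z-1TDE9x-P?usp=sharing); after downloading, place them in the models folder (`scripts/test.py` loads `models/best_model_weights.h5`). The full code has been uploaded to my GitHub; stars and forks are welcome, and so are corrections.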
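
The training targets built by `get_pixels` in `scripts/data.py` are density maps. Each annotated head is rescaled onto a grid one quarter the size of the network input and added as a single unit. The grid is then blurred with a 15×15 Gaussian, so the map sums to roughly the head count. The NumPy-only sketch below illustrates that idea without OpenCV or the datasets. It is not the project's code: `gaussian_kernel` and `density_map` are hypothetical helpers, and the kernel's sigma uses the formula OpenCV applies when `GaussianBlur` is given sigma 0. The sketch pads with zeros, whereas `cv2.GaussianBlur` reflects at the borders, so the two differ slightly for heads near the edge.

```python
import numpy as np


def gaussian_kernel(ksize=15, sigma=None):
    """Hypothetical helper: a normalized 2-D Gaussian kernel that sums to 1."""
    if sigma is None:
        # Sigma OpenCV derives from ksize when GaussianBlur is called with sigma=0.
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8
    ax = np.arange(ksize) - (ksize - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2 * sigma ** 2))
    g /= g.sum()
    return np.outer(g, g)


def density_map(points, img_h, img_w, size, ksize=15):
    """Hypothetical helper: size x size density map from (x, y) head positions in image pixels."""
    pixels = np.zeros((size, size))
    scale_h, scale_w = size / img_h, size / img_w
    for x, y in points:
        col, row = int(x * scale_w), int(y * scale_h)
        if col >= size or row >= size:
            continue  # skip points that fall outside the grid after rescaling
        pixels[row, col] += 1
    # Blur by summing shifted copies weighted by the kernel (zero padding at the borders).
    k = gaussian_kernel(ksize)
    pad = ksize // 2
    padded = np.pad(pixels, pad, mode='constant')
    out = np.zeros_like(pixels)
    for dy in range(ksize):
        for dx in range(ksize):
            out += k[dy, dx] * padded[dy:dy + size, dx:dx + size]
    return out


# Hypothetical 640x480 image with three heads and one out-of-range point; 56 = 224 // 4.
heads = [(160.0, 120.0), (320.0, 240.0), (480.0, 360.0), (700.0, 100.0)]
dmap = density_map(heads, img_h=480, img_w=640, size=56)
print(dmap.shape, round(float(dmap.sum()), 6))
```

The three in-range heads lie far enough from the border to keep their full kernel mass, so the map sums to 3, and the fourth point is skipped. `scripts/test.py` estimates a count the same way, by summing the predicted map.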
--------------------------------------------------------------------------------
/assets/1702.02359.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luanshiyinyang/MSCNN/0600936a1239da34081c961d08c04e570f88bb97/assets/1702.02359.pdf
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luanshiyinyang/MSCNN/0600936a1239da34081c961d08c04e570f88bb97/assets/model.png
--------------------------------------------------------------------------------
/assets/mscnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luanshiyinyang/MSCNN/0600936a1239da34081c961d08c04e570f88bb97/assets/mscnn.png
--------------------------------------------------------------------------------
/data/download_malldataset.sh:
--------------------------------------------------------------------------------
wget http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/mall_dataset.zip
unzip -q mall_dataset.zip -d ./

--------------------------------------------------------------------------------
/data/download_shanghaitech.sh:
--------------------------------------------------------------------------------
# Expects ShanghaiTech_Crowd_Counting_Dataset.zip to be in this folder already
mkdir ShanghaiTech
unzip -q ShanghaiTech_Crowd_Counting_Dataset.zip
mv part_* ShanghaiTech/

--------------------------------------------------------------------------------
/models/tips.txt:
--------------------------------------------------------------------------------
models here

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tensorflow-gpu>=2.0.0
opencv-python
numpy
scipy
glob3
matplotlib
--------------------------------------------------------------------------------
/results/rst.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luanshiyinyang/MSCNN/0600936a1239da34081c961d08c04e570f88bb97/results/rst.png
--------------------------------------------------------------------------------
/results/rst_shanghai.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luanshiyinyang/MSCNN/0600936a1239da34081c961d08c04e570f88bb97/results/rst_shanghai.png
--------------------------------------------------------------------------------
/scripts/data.py:
--------------------------------------------------------------------------------
from scipy.io import loadmat
import glob
import cv2
import numpy as np


class MallDataset(object):

    def __init__(self):
        self.filenames = sorted(glob.glob('../data/mall_dataset/frames/*.jpg'), key=lambda x: int(x[-8:-4]))

    def get_train_num(self):
        return int(len(self.filenames) * 0.8)

    def get_valid_num(self):
        return len(self.filenames) - int(len(self.filenames) * 0.8)

    def get_annotation(self):
        """
        Read the annotations of all 2000 images: the number of people in each image and the coordinates of every person in each image.
        Annotations are ordered by image filename.
        :return:
        """
        mat_annotation = loadmat('../data/mall_dataset/mall_gt.mat')
        count_data, position_data = mat_annotation['count'], mat_annotation['frame'][0]
        return count_data, position_data

    def get_pixels(self, img, img_index, positions, size):
        """
        Generate the density map used as the network's target
        :param img
        :param img_index
        :param positions
        :param size side length of the density map (1/4 of the network input size)
        """
        h, w = img.shape[:-1]
        proportion_h, proportion_w = size / h, size / w  # ratio of the target size to the current image size
        pixels = np.zeros((size, size))

        for p in positions[img_index][0][0][0]:
            # Coordinates of each person
            now_x, now_y = int(p[0] * proportion_w), int(p[1] * proportion_h)  # rescale the coordinates to the target size
            if now_x >= size or now_y >= size:
                # Skip points that fall outside the map after rescaling
                print("Sorry skip the point, its index of all is {}".format(img_index))
            else:
                pixels[now_y, now_x] += 1
        pixels = cv2.GaussianBlur(pixels, (15, 15), 0)
        return pixels

    def get_img_data(self, index, size):
        """
        Read a source image and its density map
        :param index image index
        :param size side length of the network input
        :return:
        """
        _, positions = self.get_annotation()
        img = cv2.imread(self.filenames[index])
        density_map = np.expand_dims(self.get_pixels(img, index, positions, size // 4), axis=-1)
        img = cv2.resize(img, (size, size)) / 255.

        return img, density_map

    def gen_train(self, batch_size, size):
        """
        Training data generator
        :param batch_size:
        :param size:
        :return:
        """
        _, position = self.get_annotation()
        index_all = list(range(int(len(self.filenames) * 0.8)))  # indices of all training samples; by default the first 80% of the data is the training set

        i, n = 0, len(index_all)
        if batch_size > n:
            raise Exception('Batch size {} is larger than the number of dataset {}!'.format(batch_size, n))

        while True:
            if i + batch_size > n:
                # Not enough samples left for a full batch: reshuffle and start a new pass
                np.random.shuffle(index_all)
                i = 0
                continue
            batch_x, batch_y = [], []
            for j in range(i, i + batch_size):
                x, y = self.get_img_data(index_all[j], size)
                batch_x.append(x)
                batch_y.append(y)
            i += batch_size
            yield np.array(batch_x), np.array(batch_y)

    def gen_valid(self, batch_size, size):
        """
        Validation data generator (the last 20% of the frames)
        :param batch_size:
        :param size:
        :return:
        """
        _, position = self.get_annotation()
        index_all = list(range(int(len(self.filenames) * 0.8), len(self.filenames)))

        i, n = 0, len(index_all)
        if batch_size > n:
            raise Exception('Batch size {} is larger than the number of dataset {}!'.format(batch_size, n))

        while True:
            if i + batch_size > n:
                # Not enough samples left for a full batch: reshuffle and start a new pass
                np.random.shuffle(index_all)
                i = 0
                continue
            batch_x, batch_y = [], []
            for j in range(i, i + batch_size):
                x, y = self.get_img_data(index_all[j], size)
                batch_x.append(x)
                batch_y.append(y)
            i += batch_size

            yield np.array(batch_x), np.array(batch_y)

    def gen_all(self, pic_size):
        """
        Load the whole dataset into memory
        :param pic_size:
        :return:
        """
        x_data = []
        y_data = []
        for i in range(len(self.filenames)):
            image, map_ = self.get_img_data(i, pic_size)
            x_data.append(image)
            y_data.append(map_)
        x_data, y_data = np.array(x_data), np.array(y_data)
        return x_data, y_data


class ShanghaitechDataset(object):

    def __init__(self, part='A'):
        if part == 'A':
            self.folder = '../data/ShanghaiTech/part_A_final/'
        else:
            self.folder = '../data/ShanghaiTech/part_B_final/'

    def get_annotation(self, folder, index):
        """
        Read the annotation of one image
        :param folder path down to the split directory, e.g. part_A_final/train_data/
        :param index: image index, starting from 1
        :return:
        """
        mat_data = loadmat(folder + 'ground_truth/GT_IMG_{}.mat'.format(index))
        positions, count = mat_data['image_info'][0][0][0][0][0], mat_data['image_info'][0][0][0][0][1][0][0]
        return positions, count

    def get_pixels(self, folder, img, img_index, size):
        """
        Generate the density map used as the network's target
        :param folder current data directory (this dataset's directory layout is fairly nested)
        :param img original image
        :param img_index image number within the current directory, starting from 1
        :param size side length of the target map; for this model, 1/4 of the input size
        """
        positions, _ = self.get_annotation(folder, img_index)
        h, w = img.shape[0], img.shape[1]
        proportion_h, proportion_w = size / h, size / w  # ratio of the target size to the current image size
        pixels = np.zeros((size, size))

        for p in positions:
            # Coordinates of each person
            now_x, now_y = int(p[0] * proportion_w), int(p[1] * proportion_h)  # rescale the coordinates to the target size
            if now_x >= size or now_y >= size:
                # Skip points that fall outside the map after rescaling
                pass
                # print("Sorry skip the point, image index of all is {}".format(img_index))
            else:
                pixels[now_y, now_x] += 1

        pixels = cv2.GaussianBlur(pixels, (15, 15), 0)
        return pixels

    def gen_train(self, batch_size, size):
        """
        Training data generator
        :return:
        """
        folder = self.folder + 'train_data/'
        index_all = [i+1 for i in range(len(glob.glob(folder + 'images/*.jpg')))]

        i, n = 0, len(index_all)
        if batch_size > n:
            raise Exception('Batch size {} is larger than the number of dataset {}!'.format(batch_size, n))

        while True:
            if i + batch_size > n:
                # Not enough samples left for a full batch: reshuffle and start a new pass
                np.random.shuffle(index_all)
                i = 0
                continue
            batch_x, batch_y = [], []
            for j in range(i, i + batch_size):
                img = cv2.imread(folder + 'images/IMG_{}.jpg'.format(index_all[j]))
                density = np.expand_dims(self.get_pixels(folder, img, index_all[j], size // 4), axis=-1)
                img = cv2.resize(img, (size, size)) / 255.
                batch_x.append(img)
                batch_y.append(density)
            i += batch_size
            yield np.array(batch_x), np.array(batch_y)

    def gen_valid(self, batch_size, size):
        """
        Validation data generator (uses the test_data split)
        :return:
        """
        folder = self.folder + 'test_data/'
        index_all = [i + 1 for i in range(len(glob.glob(folder + 'images/*.jpg')))]

        i, n = 0, len(index_all)
        if batch_size > n:
            raise Exception('Batch size {} is larger than the number of dataset {}!'.format(batch_size, n))

        while True:
            if i + batch_size > n:
                # Not enough samples left for a full batch: reshuffle and start a new pass
                np.random.shuffle(index_all)
                i = 0
                continue
            batch_x, batch_y = [], []
            for j in range(i, i + batch_size):
                img = cv2.imread(folder + 'images/IMG_{}.jpg'.format(index_all[j]))
                density = np.expand_dims(self.get_pixels(folder, img, index_all[j], size // 4), axis=-1)
                img = cv2.resize(img, (size, size)) / 255.
                batch_x.append(img)
                batch_y.append(density)
            i += batch_size
            yield np.array(batch_x), np.array(batch_y)

    def get_train_num(self):
        return len(glob.glob(self.folder + 'train_data/' + 'images/*'))

    def get_valid_num(self):
        return len(glob.glob(self.folder + 'test_data/' + 'images/*'))


if __name__ == '__main__':
    # Pull one validation batch and print its shapes as a quick sanity check
    batch_x, batch_y = next(MallDataset().gen_valid(16, 224))
    print(batch_x.shape, batch_y.shape)

--------------------------------------------------------------------------------
/scripts/model.py:
--------------------------------------------------------------------------------
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Activation, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2


def MSB(filter_num):
    """Multi-scale blob: parallel 9x9, 7x7, 5x5 and 3x3 convolutions (filter_num filters each), concatenated and batch-normalized."""
    def f(x):
        params = {
            'strides': 1,
            'activation': 'relu',
            'padding': 'same',
            'kernel_regularizer': l2(5e-4)
        }
        x1 = Conv2D(filters=filter_num, kernel_size=(9, 9), **params)(x)
        x2 = Conv2D(filters=filter_num, kernel_size=(7, 7), **params)(x)
        x3 = Conv2D(filters=filter_num, kernel_size=(5, 5), **params)(x)
        x4 = Conv2D(filters=filter_num, kernel_size=(3, 3), **params)(x)
        x = concatenate([x1, x2, x3, x4])
        x = BatchNormalization()(x)
        return x
    return f


def MSB_mini(filter_num):
    """Reduced multi-scale blob without the 9x9 branch (7x7, 5x5, 3x3), followed by batch normalization and ReLU."""
    def f(x):
        params = {
            'strides': 1,
            'activation': 'relu',
            'padding': 'same',
            'kernel_regularizer': l2(5e-4)
        }
        x2 = Conv2D(filters=filter_num, kernel_size=(7, 7), **params)(x)
        x3 = Conv2D(filters=filter_num, kernel_size=(5, 5), **params)(x)
        x4 = Conv2D(filters=filter_num, kernel_size=(3, 3), **params)(x)
        x = concatenate([x2, x3, x4])
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        return x
    return f


def MSCNN(input_shape=(224, 224, 3)):
    """
    Build the model.
    The model in this paper is simple.
    :param input_shape: input image size
    :return: a Keras Model that outputs a 1-channel density map at 1/4 of the input resolution
    """
    input_tensor = Input(shape=input_shape)
    # block1
    x = Conv2D(filters=64, kernel_size=(9, 9), strides=1, padding='same', activation='relu')(input_tensor)
    # block2
    x = MSB(4*16)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
    # block3
    x = MSB(4*32)(x)
    x = MSB(4*32)(x)
    x = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)

    x = MSB_mini(3*64)(x)
    x = MSB_mini(3*64)(x)

    x = Conv2D(1000, (1, 1), activation='relu', kernel_regularizer=l2(5e-4))(x)

    # Regression head: one output channel, sigmoid followed by ReLU
    x = Conv2D(1, (1, 1))(x)
    x = Activation('sigmoid')(x)
    x = Activation('relu')(x)

    model = Model(inputs=input_tensor, outputs=x)
    return model


if __name__ == '__main__':
    mscnn = MSCNN((224, 224, 3))
    mscnn.summary()  # summary() prints the table itself and returns None

--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
from argparse import ArgumentParser
import matplotlib.pyplot as plt
import os
import random
import glob
import cv2
import numpy as np
from model import MSCNN
import scipy.io as sio


def parse_params():
    """
    Parse command-line arguments
    :return:
    """
    ap = ArgumentParser()
    ap.add_argument('-s', '--show', default='yes', help='if show test result map')
    args_ = ap.parse_args()
    args_ = vars(args_)
    return args_


def get_samples_malldataset(num):
    """
    Get test images from the Mall dataset
    :return:
    """
    def get_annotation():
        """
        Read the annotations of all 2000 images: the number of people in each image and the coordinates of every person in each image.
        Annotations are ordered by image filename.
        :return:
        """
        mat_annotation = sio.loadmat('../data/mall_dataset/mall_gt.mat')
        count_data, position_data = mat_annotation['count'], mat_annotation['frame'][0]
        return count_data, position_data
    counts_true, _ = get_annotation()
    # glob order is arbitrary: sort by frame number so indices match the annotation order
    filenames = sorted(glob.glob('../data/mall_dataset/frames/*.jpg'), key=lambda x: int(x[-8:-4]))
    datasize = len(filenames)
    # Pick test images from the validation split (the last 20% of the frames)
    samples_index = random.sample([i for i in range(int(0.8*datasize), datasize)], num)
    samples = [filenames[i] for i in samples_index]
    images = []
    counts = []
    for i in range(num):
        filename = samples[i]
        print(filename)
        print(samples_index[i])
        img = cv2.resize(cv2.imread(filename), (224, 224)) / 255.
        img = np.expand_dims(img, axis=0)
        images.append(img)
        counts.append(counts_true[samples_index[i]])
    return images, counts


def get_samples_shanghaitech(num):
    """
    Get test images from ShanghaiTech part A (test_data split)
    :return:
    """
    def get_annotation(index):
        """
        Read the annotation of one image: the number of people and the coordinates of every person.
        :return:
        """
        mat_data = sio.loadmat('../data/ShanghaiTech/part_A_final/test_data/ground_truth/GT_IMG_{}.mat'.format(index))
        position_data, count_data = mat_data['image_info'][0][0][0][0][0], mat_data['image_info'][0][0][0][0][1][0][0]
        return count_data, position_data

    datasize = len(glob.glob('../data/ShanghaiTech/part_A_final/test_data/images/*.jpg'))
    # Pick test images from the validation split; the files are numbered from IMG_1.jpg
    samples_index = random.sample([i for i in range(1, datasize + 1)], num)
    samples = ['../data/ShanghaiTech/part_A_final/test_data/images/IMG_{}.jpg'.format(i) for i in samples_index]
    images = []
    counts = []
    for i in range(num):
        filename = samples[i]
        img = cv2.resize(cv2.imread(filename), (224, 224)) / 255.
        img = np.expand_dims(img, axis=0)
        images.append(img)
        counts.append(get_annotation(samples_index[i])[0])
    return images, counts


def plot_sample(raw_images, maps, counts, true_counts):
    """
    Plot the 5 test images (top row) and their predicted density maps (bottom row)
    :return:
    """
    plt.figure(figsize=(15, 9))
    for i in range(len(maps)):
        plt.subplot(2, 5, i + 1)
        # OpenCV loads images in BGR order; reverse the channels for display
        plt.imshow(np.squeeze(raw_images[i], axis=0)[..., ::-1])
        plt.title('people true num {}'.format(int(true_counts[i])))
        plt.subplot(2, 5, i + 1 + 5)
        plt.imshow(maps[i])
        plt.title('people pred num {}'.format(counts[i]))
    plt.savefig('../results/rst.png')


def save_result(raw_images, maps, counts, args_, true_counts):
    """
    Save the density maps
    :return:
    """
    if not os.path.exists('../results'):
        os.mkdir('../results')
    # for i in range(len(maps)):
    #     cv2.imwrite('../results/sample_{}.jpg'.format(i), maps[i])
    if args_['show'] == 'yes':
        plot_sample(raw_images, maps, counts, true_counts)


def test(args_):
    """
    Evaluate the model
    :param args_:
    :return:
    """
    model = MSCNN((224, 224, 3))
    if os.path.exists('../models/best_model_weights.h5'):
        model.load_weights('../models/best_model_weights.h5')
        samples, true_counts = get_samples_shanghaitech(5)
        maps = []
        counts = []
        for sample in samples:
            # predict() returns shape (1, h, w, 1); keep the 2-D map so the blur runs over both spatial axes
            dmap = model.predict(sample)[0, :, :, 0]
            dmap = cv2.GaussianBlur(dmap, (15, 15), 0)
            counts.append(int(np.sum(dmap)))
            maps.append(dmap)
        save_result(samples, maps, counts, args_, true_counts)
    else:
        print("Sorry, cannot find model file in root_path/models/, please download my model or train your model")


if __name__ == '__main__':
    args = parse_params()
    test(args)

--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import warnings
warnings.filterwarnings('ignore')

from argparse import ArgumentParser
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers import SGD
from model import MSCNN
from data import MallDataset, ShanghaitechDataset
import tensorflow as tf


# tf.test.is_gpu_available() is deprecated; list the visible GPUs instead
if tf.config.experimental.list_physical_devices('GPU'):
    print("use gpu 0")
else:
    print("no gpu")


def parse_command_params():
    """
    Parse command-line arguments
    :return:
    """
    parser = ArgumentParser()
    parser.add_argument('-e', '--epochs', default=50, help='how many epochs to fit')
    parser.add_argument('-v', '--show', default='yes', help='if show training log')
    parser.add_argument('-b', '--batch', default=16, help='batch size of train')
    parser.add_argument('-d', '--dataset', default='shanghaitechdataset', help='which dataset to train')
    parser.add_argument('-p', '--pretrained', default='no', help='load your pretrained model in folder root/models')
    args_ = parser.parse_args()
    args_ = vars(args_)
    return args_


def get_callbacks():
    """
    Set up the training callbacks
    :return:
    """
    early_stopping = EarlyStopping(monitor='val_loss', patience=20)
    reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.0005, patience=5, min_lr=1e-7, verbose=True)
    if not os.path.exists('../models'):
        os.mkdir('../models')
    model_checkpoint = ModelCheckpoint('../models/best_model_weights.h5', monitor='val_loss',
                                       verbose=True, save_best_only=True, save_weights_only=True)
    callbacks = [early_stopping, reduce_lr, model_checkpoint]
    return callbacks


def train(args_):
    """
    Train the model
    :return:
    """
    model = MSCNN((224, 224, 3))
    model.compile(optimizer=SGD(learning_rate=3e-4, momentum=0.9), loss='mse')
    # load pretrained model
    if args_['pretrained'] == 'yes':
        model.load_weights('../models/best_model_weights.h5')
        print("load model from ../models/")

    callbacks = get_callbacks()

    # Stream the data: only one batch is held in memory at a time.
    # fit() accepts Python generators directly; fit_generator() is deprecated.
    batch_size = int(args_['batch'])
    if args_['dataset'] == 'malldataset':
        model.fit(MallDataset().gen_train(batch_size, 224),
                  steps_per_epoch=MallDataset().get_train_num() // batch_size,
                  validation_data=MallDataset().gen_valid(batch_size, 224),
                  validation_steps=MallDataset().get_valid_num() // batch_size,
                  epochs=int(args_['epochs']),
                  callbacks=callbacks)
    elif args_['dataset'] == 'shanghaitechdataset':
        model.fit(ShanghaitechDataset().gen_train(batch_size, 224),
                  steps_per_epoch=ShanghaitechDataset().get_train_num() // batch_size,
                  validation_data=ShanghaitechDataset().gen_valid(batch_size, 224),
                  validation_steps=ShanghaitechDataset().get_valid_num() // batch_size,
                  epochs=int(args_['epochs']),
                  callbacks=callbacks)
    else:
        print('not support this dataset')


if __name__ == '__main__':
    args = parse_command_params()
    train(args)

--------------------------------------------------------------------------------
/scripts/utils.py:
--------------------------------------------------------------------------------
def visualize_dmap(img, dmap):
    """
    Show an image above its density map
    :param img: BGR image of shape (H, W, 3), as loaded by OpenCV
    :param dmap: density map of shape (h, w, 1)
    :return:
    """
    import matplotlib.pyplot as plt
    import numpy as np
    plt.figure(figsize=(12, 8))
    plt.subplot(2, 1, 1)
    plt.imshow(img[..., ::-1])  # reverse BGR to RGB for display
    plt.subplot(2, 1, 2)
    plt.imshow(np.squeeze(dmap, axis=-1), cmap='gray')
    plt.show()


if __name__ == '__main__':
    from data import MallDataset
    x, y = MallDataset().get_img_data(0, size=224)
    visualize_dmap(x, y)

--------------------------------------------------------------------------------
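
The `gen_train` and `gen_valid` generators in `scripts/data.py` share one pattern. They walk a list of sample indices in fixed-size batches, then reshuffle and restart once the current pass cannot supply another full batch. Below is a minimal, framework-free sketch of just that index logic. It assumes a hypothetical helper `index_batches`, which is not part of the repository; the caller would turn each yielded index list into image and density-map arrays.

```python
import random


def index_batches(indices, batch_size, seed=None):
    """Hypothetical helper: yield lists of batch_size indices forever, reshuffling between passes."""
    indices = list(indices)
    if batch_size > len(indices):
        raise ValueError('Batch size {} is larger than the number of samples {}!'.format(batch_size, len(indices)))
    rng = random.Random(seed)
    i = 0
    while True:
        if i + batch_size > len(indices):
            # Not enough indices left for a full batch: start a new shuffled pass.
            rng.shuffle(indices)
            i = 0
        yield indices[i:i + batch_size]
        i += batch_size


# Usage: each yielded list would be mapped to images and density maps by the caller.
gen = index_batches(range(10), 4, seed=0)
for _ in range(3):
    print(next(gen))
```

Like the repository's generators, the first pass keeps the original order and later passes are shuffled. A `tf.keras.utils.Sequence` would be an alternative design, with an explicit length and an `on_epoch_end` hook for reshuffling.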