├── .gitignore
├── LICENSE
├── README.md
├── cannet.py
├── data_preparation
│   └── k_nearest_gaussian_kernel.py
├── my_dataset.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# my_ignore
checkpoints/
data/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2019 CommissarMa

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Context-Aware_Crowd_Counting-pytorch
This is a simple and clean unofficial implementation of the CVPR 2019 paper ["Context-Aware Crowd Counting"](https://arxiv.org/pdf/1811.10452.pdf).
# Installation
1. Install PyTorch 1.0.0 or later and Python 3.6 or later.
2. Install visdom
```pip
pip install visdom
```
3. Install tqdm
```pip
pip install tqdm
```
4. Clone this repository
```git
git clone https://github.com/CommissarMa/Context-Aware_Crowd_Counting-pytorch.git
```
We'll refer to the directory you cloned Context-Aware_Crowd_Counting-pytorch into as ROOT.
# Data Setup
1. Download the ShanghaiTech dataset from
Dropbox: [link](https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0) or Baidu Disk: [link](http://pan.baidu.com/s/1nuAYslz)
2. Put the ShanghaiTech dataset in ROOT and use "data_preparation/k_nearest_gaussian_kernel.py" to generate the ground-truth density maps. (Note that you need to modify the root path in the main function of "data_preparation/k_nearest_gaussian_kernel.py".)
# Training
1. Modify the root path in "train.py" according to your dataset location.
2. In a command line, start the visdom server:
```
python -m visdom.server
```
3. Run train.py
# Testing
1. Modify the root path in "test.py" according to your dataset location.
2. Run test.py to calculate the MAE over the test images, or to show an estimated density map.
# Other notes
At epoch 353 we reached an MAE comparable to the one reported in the paper: [Baidu Disk download with extraction code: yfwb](https://pan.baidu.com/s/1Y-nnVQoZgmgNjpHhE4y--Q) or [Dropbox link](https://www.dropbox.com/s/do3yf8hs841exha/cvpr2019_CAN_SHHA_353.pth?dl=0). Thanks to the author (Weizhe Liu) for his response by email. His homepage is [link](https://sites.google.com/view/weizheliu/home).
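For a quick sanity check outside `test.py`, the sketch below loads a trained checkpoint and estimates the count of a single image. This is a minimal sketch: the checkpoint filename matches the Dropbox download above, but the image path is an assumption; adjust both to your setup.
```python
import torch
from PIL import Image
from torchvision import transforms

from cannet import CANNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# load_weights=True skips the VGG16 download; all weights come from the checkpoint
model = CANNet(load_weights=True).to(device)
model.load_state_dict(torch.load("./checkpoints/cvpr2019_CAN_SHHA_353.pth", map_location=device))
model.eval()

transform = transforms.Compose([
    transforms.ToTensor(),  # [0,255] -> [0,1], HWC -> CHW, matching CrowdDataset
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = transform(Image.open("crowd.jpg").convert("RGB")).unsqueeze(0).to(device)  # path is an example

with torch.no_grad():
    density_map = model(img)  # shape: (1, 1, H/8, W/8)
print("estimated count: %.1f" % density_map.sum().item())
```
The estimated count is simply the integral (sum) of the predicted density map, so no post-processing is needed.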
--------------------------------------------------------------------------------
/cannet.py:
--------------------------------------------------------------------------------

import collections

import torch
import torch.nn as nn
from torchvision import models


class CANNet(nn.Module):
    def __init__(self, load_weights=False):
        super(CANNet, self).__init__()
        self.frontend_feat = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512]
        self.backend_feat = [512, 512, 512, 256, 128, 64]
        self.frontend = make_layers(self.frontend_feat)
        self.backend = make_layers(self.backend_feat, in_channels=1024, dilation=True)
        self.output_layer = nn.Conv2d(64, 1, kernel_size=1)
        self.conv1_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv1_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv2_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv2_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv3_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv3_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv6_1 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        self.conv6_2 = nn.Conv2d(512, 512, kernel_size=1, bias=False)
        if not load_weights:
            mod = models.vgg16(pretrained=True)
            self._initialize_weights()
            # print("VGG", list(mod.state_dict().items())[0][1])  # the VGG values we want
            fsd = collections.OrderedDict()
            # 10 convolutions * (weight, bias) = 20 parameters
            for i in range(len(self.frontend.state_dict().items())):
                temp_key = list(self.frontend.state_dict().items())[i][0]
                fsd[temp_key] = list(mod.state_dict().items())[i][1]
            self.frontend.load_state_dict(fsd)
            # print("Mine", list(self.frontend.state_dict().items())[0][1])  # verify the VGG values were copied into our network
            # self.frontend.state_dict().items()[i][1].data[:] = mod.state_dict().items()[i][1].data[:]  # Python 2.7 version

    def forward(self, x):
        fv = self.frontend(x)
        # interpolate(..., align_corners=False) is the modern equivalent of the
        # deprecated nn.functional.upsample used in the original code.
        # S=1
        ave1 = nn.functional.adaptive_avg_pool2d(fv, (1, 1))
        ave1 = self.conv1_1(ave1)
        # ave1 = nn.functional.relu(ave1)
        s1 = nn.functional.interpolate(ave1, size=(fv.shape[2], fv.shape[3]), mode='bilinear', align_corners=False)
        c1 = s1 - fv
        w1 = self.conv1_2(c1)
        w1 = torch.sigmoid(w1)
        # S=2
        ave2 = nn.functional.adaptive_avg_pool2d(fv, (2, 2))
        ave2 = self.conv2_1(ave2)
        # ave2 = nn.functional.relu(ave2)
        s2 = nn.functional.interpolate(ave2, size=(fv.shape[2], fv.shape[3]), mode='bilinear', align_corners=False)
        c2 = s2 - fv
        w2 = self.conv2_2(c2)
        w2 = torch.sigmoid(w2)
        # S=3
        ave3 = nn.functional.adaptive_avg_pool2d(fv, (3, 3))
        ave3 = self.conv3_1(ave3)
        # ave3 = nn.functional.relu(ave3)
        s3 = nn.functional.interpolate(ave3, size=(fv.shape[2], fv.shape[3]), mode='bilinear', align_corners=False)
        c3 = s3 - fv
        w3 = self.conv3_2(c3)
        w3 = torch.sigmoid(w3)
        # S=6
        ave6 = nn.functional.adaptive_avg_pool2d(fv, (6, 6))
        ave6 = self.conv6_1(ave6)
        # ave6 = nn.functional.relu(ave6)
        s6 = nn.functional.interpolate(ave6, size=(fv.shape[2], fv.shape[3]), mode='bilinear', align_corners=False)
        c6 = s6 - fv
        w6 = self.conv6_2(c6)
        w6 = torch.sigmoid(w6)

        fi = (w1 * s1 + w2 * s2 + w3 * s3 + w6 * s6) / (w1 + w2 + w3 + w6 + 1e-12)
        x = torch.cat((fv, fi), 1)

        x = self.backend(x)
        x = self.output_layer(x)
        return x
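
    # Contextual module recap (added comment, restating what forward() does):
    # for each scale S in {1, 2, 3, 6} it computes scale features
    # s_S = upsample(conv(avgpool_S(fv))), contrast features c_S = s_S - fv,
    # and attention weights w_S = sigmoid(conv(c_S)); the contextual feature is
    # the weighted average fi = sum_S(w_S * s_S) / sum_S(w_S), which is then
    # concatenated with fv and decoded by the dilated back end.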
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, in_channels=3, batch_norm=False, dilation=False):
    if dilation:
        d_rate = 2
    else:
        d_rate = 1
    layers = []
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=d_rate, dilation=d_rate)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


# testing
if __name__ == "__main__":
    cannet = CANNet().to('cuda')
    input_img = torch.ones((1, 3, 256, 256)).to('cuda')
    out = cannet(input_img)
    print(out.mean())

--------------------------------------------------------------------------------
/data_preparation/k_nearest_gaussian_kernel.py:
--------------------------------------------------------------------------------

import os
import glob

import numpy as np
import scipy
import scipy.io as io
import scipy.spatial  # needed for scipy.spatial.KDTree below
from scipy.ndimage.filters import gaussian_filter
from matplotlib import pyplot as plt
from matplotlib import cm as CM
import h5py
import PIL.Image as Image


# partly borrowed from https://github.com/davideverona/deep-crowd-counting_crowdnet
def gaussian_filter_density(img, points):
    '''
    This code uses k-nearest neighbors; it can take a minute or more to generate
    a density map for an image with a thousand people.

    img: the image whose shape (row, col) determines the shape of the density
         map. Note that the density map has no channel dimension.
    points: a two-dimensional list of pedestrian annotations in the order
            [[col, row], [col, row], ...].

    return:
    density: the density map we want, with the same height and width as the
             input image but only one channel.

    example:
    points: three pedestrians with annotations [[163, 53], [175, 64], [189, 74]].
    img shape: (768, 1024), where 768 is the row count and 1024 the column count.
    '''
    img_shape = [img.shape[0], img.shape[1]]
    print("Shape of current image:", img_shape, "- need to generate", len(points), "gaussian kernels in total.")
    density = np.zeros(img_shape, dtype=np.float32)
    gt_count = len(points)
    if gt_count == 0:
        return density

    leafsize = 2048
    # build kd-tree
    tree = scipy.spatial.KDTree(points.copy(), leafsize=leafsize)
    # query kd-tree: k=4 returns each point itself plus its 3 nearest neighbors
    distances, locations = tree.query(points, k=4)

    print('generate density...')
    for i, pt in enumerate(points):
        pt2d = np.zeros(img_shape, dtype=np.float32)
        if int(pt[1]) < img_shape[0] and int(pt[0]) < img_shape[1]:
            pt2d[int(pt[1]), int(pt[0])] = 1.
        else:
            continue
        if gt_count > 1:
            # geometry-adaptive kernel: sigma = 0.1 * (sum of the distances to the 3 nearest neighbors)
            sigma = (distances[i][1] + distances[i][2] + distances[i][3]) * 0.1
        else:
            sigma = np.average(np.array(img_shape)) / 2. / 2.  # case: a single point
        density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
    print('done.')
    return density
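
# --- illustrative sanity check (added; the numbers are made up) ---
# Geometry-adaptive bandwidth: for a head whose three nearest annotated
# neighbors lie 10, 12 and 14 pixels away, sigma = (10 + 12 + 14) * 0.1 = 3.6,
# so kernels are tight in dense crowds and wide in sparse ones. Each Gaussian
# integrates to ~1 (up to border truncation), hence density.sum() approximates
# the head count:
#
#   img = np.zeros((768, 1024, 3), dtype=np.uint8)
#   points = np.array([[163, 53], [175, 64], [189, 74]])  # [[col, row], ...]
#   density = gaussian_filter_density(img, points)
#   print(density.sum())  # ~= 3.0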

# test code
if __name__ == "__main__":
    # show an example of how to use gaussian_filter_density.
    root = 'D:\\workspaceMaZhenwei\\GithubProject\\Crowd_counting_from_scratch\\data'

    # now generate ShanghaiTech part A's ground truth
    part_A_train = os.path.join(root, 'part_A_final/train_data', 'images')
    part_A_test = os.path.join(root, 'part_A_final/test_data', 'images')
    # part_B_train = os.path.join(root, 'part_B_final/train_data', 'images')
    # part_B_test = os.path.join(root, 'part_B_final/test_data', 'images')
    path_sets = [part_A_train, part_A_test]

    img_paths = []
    for path in path_sets:
        for img_path in glob.glob(os.path.join(path, '*.jpg')):
            img_paths.append(img_path)

    for img_path in img_paths:
        print(img_path)
        mat = io.loadmat(img_path.replace('.jpg', '.mat').replace('images', 'ground_truth').replace('IMG_', 'GT_IMG_'))
        img = plt.imread(img_path)  # e.g. 768 rows * 1024 columns
        k = np.zeros((img.shape[0], img.shape[1]))
        points = mat["image_info"][0, 0][0, 0][0]  # N persons * 2 (col, row)
        k = gaussian_filter_density(img, points)
        # plt.imshow(k, cmap=CM.jet)
        # save density map to disk
        np.save(img_path.replace('.jpg', '.npy').replace('images', 'ground_truth'), k)

    '''
    # now see a sample from ShanghaiA
    plt.imshow(Image.open(img_paths[0]))

    gt_file = np.load(img_paths[0].replace('.jpg', '.npy').replace('images', 'ground_truth'))
    plt.imshow(gt_file, cmap=CM.jet)

    print(np.sum(gt_file))  # don't mind the slight deviation from the annotation count
    '''

--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------

import os
import random

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision import transforms


class CrowdDataset(Dataset):
    '''
    CrowdDataset
    '''
    def __init__(self, img_root, gt_dmap_root, gt_downsample=1, phase='train'):
        '''
        img_root: the root path of the images.
        gt_dmap_root: the root path of the ground-truth density maps.
        gt_downsample: default is 1, meaning the output of the deep model has
                       the same size as the input image.
        phase: 'train' or 'test'
        '''
        self.img_root = img_root
        self.gt_dmap_root = gt_dmap_root
        self.gt_downsample = gt_downsample
        self.phase = phase

        self.img_names = [filename for filename in os.listdir(img_root)
                          if os.path.isfile(os.path.join(img_root, filename))]
        self.n_samples = len(self.img_names)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        img_name = self.img_names[index]
        img = plt.imread(os.path.join(self.img_root, img_name)) / 255  # convert from [0,255] to [0,1]

        if len(img.shape) == 2:  # expand a grayscale image to three channels.
            img = img[:, :, np.newaxis]
            img = np.concatenate((img, img, img), 2)

        gt_dmap = np.load(os.path.join(self.gt_dmap_root, img_name.replace('.jpg', '.npy')))

        if random.randint(0, 1) == 1 and self.phase == 'train':
            img = img[:, ::-1]  # horizontal flip
            gt_dmap = gt_dmap[:, ::-1]  # horizontal flip
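
        # Note (added comment): cv2.resize averages pixel values, so shrinking
        # the density map by gt_downsample in each dimension divides its sum by
        # roughly gt_downsample**2; the multiplication below restores it so that
        # gt_dmap.sum() still equals the head count the loss should supervise.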
        if self.gt_downsample > 1:  # downsample image and density map to match the deep model.
            ds_rows = int(img.shape[0] // self.gt_downsample)
            ds_cols = int(img.shape[1] // self.gt_downsample)
            img = cv2.resize(img, (ds_cols * self.gt_downsample, ds_rows * self.gt_downsample))
            img = img.transpose((2, 0, 1))  # convert to order (channel, rows, cols)
            gt_dmap = cv2.resize(gt_dmap, (ds_cols, ds_rows))
            gt_dmap = gt_dmap[np.newaxis, :, :] * self.gt_downsample * self.gt_downsample

        img_tensor = torch.tensor(img, dtype=torch.float)
        img_tensor = transforms.functional.normalize(img_tensor, mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])
        gt_dmap_tensor = torch.tensor(gt_dmap, dtype=torch.float)

        return img_tensor, gt_dmap_tensor


# test code
if __name__ == "__main__":
    img_root = "./data/Shanghai_part_A/train_data/images"
    gt_dmap_root = "./data/Shanghai_part_A/train_data/ground_truth"
    dataset = CrowdDataset(img_root, gt_dmap_root, gt_downsample=8)
    for i, (img, gt_dmap) in enumerate(dataset):
        # plt.imshow(img)
        # plt.figure()
        # plt.imshow(gt_dmap)
        # plt.figure()
        # if i > 5:
        #     break
        print(img.shape, gt_dmap.shape)
        print(img.min(), img.max(), gt_dmap.min(), gt_dmap.max())
        break

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------

import torch
import matplotlib.pyplot as plt
import matplotlib.cm as CM
from tqdm import tqdm

from cannet import CANNet
from my_dataset import CrowdDataset


def cal_mae(img_root, gt_dmap_root, model_param_path):
    '''
    Calculate the MAE over the test data.
    img_root: the root of the test image data.
    gt_dmap_root: the root of the test ground-truth density-map data.
    model_param_path: the path of the trained CANNet parameters.
    '''
    device = torch.device("cuda")
    model = CANNet()
    model.load_state_dict(torch.load(model_param_path))
    model.to(device)
    dataset = CrowdDataset(img_root, gt_dmap_root, 8, phase='test')
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    model.eval()
    mae = 0
    with torch.no_grad():
        for i, (img, gt_dmap) in enumerate(tqdm(dataloader)):
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = model(img)
            mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
            del img, gt_dmap, et_dmap

    print("model_param_path:" + model_param_path + " mae:" + str(mae / len(dataloader)))
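
# MAE recap (added comment): the counting metric computed above is
#   MAE = (1/N) * sum_i |est_count_i - gt_count_i|
# where each count is the sum (integral) of the corresponding density map.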

def estimate_density_map(img_root, gt_dmap_root, model_param_path, index):
    '''
    Show one estimated density map.
    img_root: the root of the test image data.
    gt_dmap_root: the root of the test ground-truth density-map data.
    model_param_path: the path of the trained CANNet parameters.
    index: the order of the test image in the test dataset.
    '''
    device = torch.device("cuda")
    model = CANNet().to(device)
    model.load_state_dict(torch.load(model_param_path))
    dataset = CrowdDataset(img_root, gt_dmap_root, 8, phase='test')
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    model.eval()
    for i, (img, gt_dmap) in enumerate(dataloader):
        if i == index:
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = model(img).detach()
            et_dmap = et_dmap.squeeze(0).squeeze(0).cpu().numpy()
            print(et_dmap.shape)
            plt.imshow(et_dmap, cmap=CM.jet)
            break


if __name__ == "__main__":
    torch.backends.cudnn.enabled = False
    img_root = './data/Shanghai_part_A/test_data/images'
    gt_dmap_root = './data/Shanghai_part_A/test_data/ground_truth'
    model_param_path = './checkpoints/epoch_354.pth'
    cal_mae(img_root, gt_dmap_root, model_param_path)
    # estimate_density_map(img_root, gt_dmap_root, model_param_path, 3)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

import numpy as np
import time
import torch
import torch.nn as nn
import os
import visdom
import random
from tqdm import tqdm

from cannet import CANNet
from my_dataset import CrowdDataset

if __name__ == "__main__":
    # configuration
    train_image_root = './data/Shanghai_part_A/train_data/images'
    train_dmap_root = './data/Shanghai_part_A/train_data/ground_truth'
    test_image_root = './data/Shanghai_part_A/test_data/images'
    test_dmap_root = './data/Shanghai_part_A/test_data/ground_truth'
    gpu_or_cpu = 'cuda'  # use cuda or cpu
    lr = 1e-7
    batch_size = 1
    momentum = 0.95
    epochs = 20000
    steps = [-1, 1, 100, 150]
    scales = [1, 1, 1, 1]
    workers = 4
    seed = time.time()
    print_freq = 30

    vis = visdom.Visdom()
    device = torch.device(gpu_or_cpu)
    torch.cuda.manual_seed(int(seed))  # seed the CUDA RNG (cast the float timestamp to int)
    model = CANNet().to(device)
    criterion = nn.MSELoss(reduction='sum').to(device)  # equivalent to the deprecated size_average=False
    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=0)
    # optimizer = torch.optim.Adam(model.parameters(), lr)
    train_dataset = CrowdDataset(train_image_root, train_dmap_root, gt_downsample=8, phase='train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_dataset = CrowdDataset(test_image_root, test_dmap_root, gt_downsample=8, phase='test')
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

    if not os.path.exists('./checkpoints'):
        os.mkdir('./checkpoints')
    min_mae = 10000
    min_epoch = 0
    train_loss_list = []
    epoch_list = []
    test_error_list = []
    for epoch in range(0, epochs):
        # training phase
        model.train()
        epoch_loss = 0
        for i, (img, gt_dmap) in enumerate(tqdm(train_loader)):
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = model(img)
            # calculate loss
            loss = criterion(et_dmap, gt_dmap)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # print("epoch:", epoch, "loss:", epoch_loss / len(train_loader))
        epoch_list.append(epoch)
        train_loss_list.append(epoch_loss / len(train_loader))
        torch.save(model.state_dict(), './checkpoints/epoch_' + str(epoch) + ".pth")
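
        # Note (added comment): with reduction='sum' the loss -- and therefore
        # the gradient -- scales with the number of density-map pixels, which is
        # why the learning rate above is as small as 1e-7.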

        # testing phase
        model.eval()
        mae = 0
        for i, (img, gt_dmap) in enumerate(tqdm(test_loader)):
            img = img.to(device)
            gt_dmap = gt_dmap.to(device)
            # forward propagation
            et_dmap = model(img)
            mae += abs(et_dmap.data.sum() - gt_dmap.data.sum()).item()
            del img, gt_dmap, et_dmap
        if mae / len(test_loader) < min_mae:
            min_mae = mae / len(test_loader)
            min_epoch = epoch
        test_error_list.append(mae / len(test_loader))
        print("epoch:" + str(epoch) + " error:" + str(mae / len(test_loader)) + " min_mae:" + str(min_mae) + " min_epoch:" + str(min_epoch))
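
    # --- optional resume sketch (added; not part of the original script) ---
    # To continue training from a saved epoch, load its weights before the
    # training loop, e.g.:
    #   model.load_state_dict(torch.load('./checkpoints/epoch_353.pth'))
    # and evaluate the best epoch afterwards with test.py (see cal_mae).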