├── .gitattributes
├── README.md
├── dataset.py
├── image.py
├── img
│   ├── 1553080510253.png
│   ├── 1553080672186.png
│   ├── 1553081148694.png
│   ├── 1553081269333.png
│   ├── 1553083826276.png
│   ├── 1553084203041.png
│   ├── 1553084581750.png
│   ├── 1553085088872.png
│   ├── 1553085203128.png
│   ├── 1553085315569.png
│   ├── 1553085525410.png
│   ├── 1553085714302.png
│   ├── 1553085877665.png
│   ├── 1553086025902.png
│   ├── 1553086270822.png
│   ├── 1553086450949.png
│   ├── 1553088999563.png
│   ├── 1553089130159.png
│   ├── 1553089178172.png
│   ├── 1553089513587.png
│   ├── 1553123470473.png
│   ├── 1553156680043.png
│   ├── 1553157191632.png
│   ├── 1553157269984.png
│   ├── 1553157324880.png
│   ├── 1553157626748.png
│   ├── 1553157727196.png
│   ├── 1553157772668.png
│   ├── 1553158312997.png
│   ├── 1553158352180.png
│   ├── 1553158644963.png
│   ├── 1554628742996.png
│   ├── 1554628757778.png
│   ├── 1554628804461.png
│   ├── 1554628873425.png
│   ├── 1554628903758.png
│   ├── 1554628933425.png
│   ├── 1554628946463.png
│   └── 1554628958176.png
├── make_dataset.ipynb
├── make_dataset.py
├── make_model.ipynb
├── model.py
├── test_single-image.py
├── train.py
├── utils.py
├── val.ipynb
└── val.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # chapter5-learning_CSRNet
4 | 
5 | What **CSRNet** does: it counts the number of people in an image.
6 | 
7 | ------
8 | 
9 | I ran into several problems when setting up the environment by following the author's GitHub repository (https://github.com/leeyeehoo/CSRNet-pytorch). It took me a long time of debugging and searching on Baidu to solve them. I wrote this chapter to help everyone learn more easily and avoid the same detours.
10 | 
11 | In this article, I will walk you through debugging the code and visualizing each step.
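Before diving in, it helps to know what the ground truth actually is. Instead of a single number, each image gets a density map: every annotated head is spread into a small Gaussian blob whose values sum to 1, so summing the whole map gives the head count. This is also why `validate` in **train.py** compares `output.data.sum()` with `target.sum()`. The plain-Python sketch below only illustrates that property. `gaussian_kernel`, `density_map`, the fixed `sigma=2.0` and the 3-sigma truncation are my own illustrative choices, not code from this repository; the real pipeline uses `scipy.ndimage.filters.gaussian_filter` with an adaptive sigma.

```python
import math


def gaussian_kernel(sigma, radius):
    """Return a (2*radius+1) x (2*radius+1) Gaussian kernel normalised to sum to 1."""
    size = 2 * radius + 1
    kernel = [[math.exp(-((col - radius) ** 2 + (row - radius) ** 2) / (2.0 * sigma ** 2))
               for col in range(size)]
              for row in range(size)]
    total = sum(sum(r) for r in kernel)
    return [[v / total for v in r] for r in kernel]


def density_map(points, height, width, sigma=2.0):
    """Hypothetical helper: add one normalised Gaussian per (x, y) head point."""
    radius = int(3 * sigma)  # truncate the kernel at about 3 sigma (illustrative choice)
    kernel = gaussian_kernel(sigma, radius)
    density = [[0.0] * width for _ in range(height)]
    for x, y in points:
        for dy in range(-radius, radius + 1):
            for dx in range(-radius, radius + 1):
                row, col = y + dy, x + dx
                # Mass falling outside the image is dropped, similar to mode='constant'.
                if 0 <= row < height and 0 <= col < width:
                    density[row][col] += kernel[dy + radius][dx + radius]
    return density


heads = [(10, 10), (20, 12), (30, 25)]  # (x, y) = (column, row), like gt[i][0], gt[i][1]
dmap = density_map(heads, height=40, width=40)
print(round(sum(sum(r) for r in dmap), 6))  # 3.0
```

Heads close to the border lose the part of their blob that falls outside the image. On real images, the sum of the map can therefore be slightly below the annotated count.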
12 | 
13 | ![1553123470473](img/1553123470473.png)
14 | 
15 | ------
16 | 
17 | ## step1. install
18 | 
19 | For the detailed installation process, refer to the author's GitHub repository. Here I simply list the commands I ran.
20 | 
21 | ```
22 | conda create -n CSRNet python=3.6
23 | source activate CSRNet
24 | unzip CSRNet-pytorch-master.zip
25 | pip install -i https://pypi.tuna.tsinghua.edu.cn/simple torch torchvision
26 | pip install decorator "cloudpickle>=0.2.1" "dask[array]>=1.0.0" "matplotlib>=2.0.0" "networkx>=1.8" "scipy>=0.17.0" bleach "python-dateutil>=2.1"  # quoted so the shell does not treat ">" as output redirection
27 | unzip ShanghaiTech_Crowd_Counting_Dataset.zip
28 | jupyter nbconvert --to script make_dataset.ipynb  # convert the .ipynb file to a .py file
29 | ```
30 | 
31 | ## step2. make_dataset.py
32 | 
33 | I simply ran the command above to convert **make_dataset.ipynb** into **make_dataset.py**. Now you need to modify the contents of **make_dataset.py**.
34 | 
35 | Find the line where **root** is set, and add **def main():** on the line above it.
36 | 
37 | ![1553080510253](img/1553080510253.png)
38 | 
39 | Add these two lines at the end of **make_dataset.py**, and adjust the indentation of the code accordingly.
40 | 
41 | ![1553080672186](img/1553080672186.png)
42 | 
43 | There is an error in the author's source code that you need to fix.
44 | 
45 | Replace **pts = np.array(zip(np.nonzero(gt)[1], np.nonzero(gt)[0]))** with **pts = np.array(list(zip(np.nonzero(gt)[1], np.nonzero(gt)[0])))**. The original notebook was written for Python 2; in Python 3, `zip` returns an iterator, so it must be turned into a list before `np.array` can build an N×2 array of point coordinates from it.
46 | 
47 | ![1553081148694](img/1553081148694.png)
48 | 
49 | Then run **make_dataset.py**.
50 | 
51 | ![1553081269333](img/1553081269333.png)
52 | 
53 | ------
54 | 
55 | **The above is just a general summary. Next, we will run the code line by line and visualize the results.**
56 | 
57 | I will use this image as an example.
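Before stepping through the example, here is why step 2 wraps `zip` in `list`. In Python 2, `zip` returns a list. In Python 3 it returns a lazy `zip` iterator, which `np.array` cannot expand into an N×2 coordinate array; it wraps the iterator as a single object instead. The plain-Python sketch below shows the difference. `xs` and `ys` are made-up values standing in for `np.nonzero(gt)[1]` and `np.nonzero(gt)[0]`.

```python
# Made-up stand-ins for np.nonzero(gt)[1] (x, columns) and np.nonzero(gt)[0] (y, rows).
xs = [12, 40, 7]
ys = [5, 33, 21]

lazy = zip(xs, ys)
print(type(lazy).__name__)  # zip  -> an iterator, not a list of points

pts = list(zip(xs, ys))  # what the corrected line does before calling np.array
print(pts)               # [(12, 5), (40, 33), (7, 21)]

# An iterator is also exhausted after one pass, so it cannot be reused later.
print(list(lazy))        # [(12, 5), (40, 33), (7, 21)]
print(list(lazy))        # []
```

Once the pairs are in a list, `np.array(...)` produces the N×2 array of (x, y) points that is then passed to `scipy.spatial.KDTree`.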
58 | 
59 | ![1553083826276](img/1553083826276.png)
60 | 
61 | ```
62 | # coding: utf-8
63 | import h5py
64 | import scipy.io as io
65 | import PIL.Image as Image
66 | import numpy as np
67 | import os
68 | import glob
69 | from matplotlib import pyplot as plt
70 | from scipy.ndimage.filters import gaussian_filter
71 | import scipy
72 | import json
73 | from matplotlib import cm as CM
74 | from image import *
75 | from model import CSRNet
76 | import torch
77 | img_path='D:\\paper\\CSRNet\\CSRNet\\dataset\\Shanghai\\part_A_final\\train_data\\images\\IMG_21.jpg'
78 | mat='D:\\paper\\CSRNet\\CSRNet\\dataset\\Shanghai\\part_A_final\\train_data\\ground_truth\\GT_IMG_21.mat'
79 | mat = io.loadmat(mat)
80 | img= plt.imread(img_path)
81 | k = np.zeros((img.shape[0],img.shape[1]))
82 | ```
83 | 
84 | Here is the information about **k**, an all-zero array with the same height and width as the image:
85 | 
86 | ![1553084203041](img/1553084203041.png)
87 | 
88 | ```
89 | gt = mat["image_info"][0,0][0,0][0]
90 | ```
91 | 
92 | ![1553084581750](img/1553084581750.png)
93 | 
94 | ```
95 | for i in range(0,len(gt)):
96 |     if int(gt[i][1]) 1:
143 |         sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1
144 |     else:
145 |         sigma = np.average(np.array(gt.shape))/2./2. #case: 1 point
146 |     density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
147 | print('done.')
148 | ```
149 | 
150 | ![1553086270822](img/1553086270822.png)
151 | 
152 | ```
153 | k = density
154 | with h5py.File(img_path.replace('.jpg','.h5').replace('images','ground_truth'), 'w') as hf:
155 |     hf['density'] = k
156 | ```
157 | 
158 | ![1553086450949](img/1553086450949.png)
159 | 
160 | So far, we have generated the ground-truth density map for this image and saved it as an .h5 file.
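The `sigma` used for each head is geometry-adaptive. `tree.query(pts, k=4)` in `gaussian_filter_density` returns each point itself (distance 0) plus its three nearest neighbours. `sigma = (d1 + d2 + d3) * 0.1` is therefore 0.3 times the mean distance to those neighbours, so heads in dense regions get narrower kernels. The brute-force, plain-Python sketch below makes the arithmetic visible. It replaces the `scipy.spatial.KDTree` query with an O(N²) distance scan. `adaptive_sigmas` is a hypothetical helper, and its handling of images with only 2 or 3 heads is my own assumption. With `k=4`, the KDTree query pads missing neighbours with infinite distances, so the original code appears to assume at least 4 heads.

```python
import math


def adaptive_sigmas(points, shape, beta=0.3, k=3):
    """Sketch of the geometry-adaptive sigma rule for a list of (x, y) head points.

    sigma_i = beta * (mean distance from point i to its k nearest neighbours).
    With beta=0.3 and k=3 this equals 0.1 * (d1 + d2 + d3), the formula used above.
    """
    if not points:
        return []  # no heads: the original returns an all-zero density map
    if len(points) == 1:
        # Same fallback as the original: np.average(gt.shape) / 2 / 2
        return [sum(shape) / len(shape) / 2.0 / 2.0]

    sigmas = []
    for i, (x, y) in enumerate(points):
        # Brute-force O(N^2) scan instead of scipy.spatial.KDTree (illustration only).
        dists = sorted(math.hypot(x - px, y - py)
                       for j, (px, py) in enumerate(points) if j != i)
        nearest = dists[:k]  # fewer than k when there are only 2 or 3 heads (assumption)
        sigmas.append(beta * sum(nearest) / len(nearest))
    return sigmas


heads = [(0, 0), (1, 0), (2, 0), (3, 0)]  # four heads on a line, 1 pixel apart
print([round(s, 3) for s in adaptive_sigmas(heads, shape=(100, 200))])
# [0.6, 0.4, 0.4, 0.6]
```

The two inner heads have closer neighbours than the two end heads, so they get the smaller sigma. In a real crowd image, this means people far from the camera, who are packed tightly, are blurred less than isolated people in the foreground.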
At this point, I put all of the code above together as follows:
161 | 
162 | ```
163 | # coding: utf-8
164 | import h5py
165 | import scipy.io as io
166 | import PIL.Image as Image
167 | import numpy as np
168 | import os
169 | import glob
170 | from matplotlib import pyplot as plt
171 | from scipy.ndimage.filters import gaussian_filter
172 | import scipy
173 | import json
174 | from matplotlib import cm as CM
175 | from image import *
176 | from model import CSRNet
177 | import torch
178 | def gaussian_filter_density(gt):
179 |     print(gt.shape)
180 |     density = np.zeros(gt.shape, dtype=np.float32)
181 |     gt_count = np.count_nonzero(gt)
182 |     if gt_count == 0:
183 |         return density
184 | 
185 |     # pts = np.array(zip(np.nonzero(gt)[1], np.nonzero(gt)[0]))
186 |     pts = np.array(list(zip(np.nonzero(gt)[1], np.nonzero(gt)[0])))
187 |     leafsize = 2048
188 |     # build kdtree
189 |     tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize)
190 |     # query kdtree
191 |     distances, locations = tree.query(pts, k=4)
192 | 
193 |     print('generate density...')
194 |     for i, pt in enumerate(pts):
195 |         pt2d = np.zeros(gt.shape, dtype=np.float32)
196 |         pt2d[pt[1],pt[0]] = 1.
197 |         if gt_count > 1:
198 |             sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1
199 |         else:
200 |             sigma = np.average(np.array(gt.shape))/2./2. #case: 1 point
201 |         density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
202 |     print('done.')
203 |     return density
204 | img_path='D:\\paper\\CSRNet\\CSRNet\\dataset\\Shanghai\\part_A_final\\train_data\\images\\IMG_21.jpg'
205 | mat='D:\\paper\\CSRNet\\CSRNet\\dataset\\Shanghai\\part_A_final\\train_data\\ground_truth\\GT_IMG_21.mat'
206 | img_paths = []
207 | img_paths.append(img_path)
208 | for img_path in img_paths:
209 |     print(img_path)
210 |     mat = io.loadmat(mat)
211 |     img= plt.imread(img_path)
212 |     k = np.zeros((img.shape[0],img.shape[1]))
213 |     gt = mat["image_info"][0,0][0,0][0]
214 |     for i in range(0,len(gt)):
215 |         if int(gt[i][1])0.8:
34 |             target = np.fliplr(target)
35 |             img = img.transpose(Image.FLIP_LEFT_RIGHT)
36 | 
37 | 
38 | 
39 | 
40 |     target = cv2.resize(target,(target.shape[1]//8,target.shape[0]//8),interpolation = cv2.INTER_CUBIC)*64
41 | 
42 | 
43 |     return img,target
--------------------------------------------------------------------------------
/img/1553080510253.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553080510253.png
--------------------------------------------------------------------------------
/img/1553080672186.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553080672186.png
--------------------------------------------------------------------------------
/img/1553081148694.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553081148694.png
--------------------------------------------------------------------------------
/img/1553081269333.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553081269333.png -------------------------------------------------------------------------------- /img/1553083826276.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553083826276.png -------------------------------------------------------------------------------- /img/1553084203041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553084203041.png -------------------------------------------------------------------------------- /img/1553084581750.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553084581750.png -------------------------------------------------------------------------------- /img/1553085088872.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085088872.png -------------------------------------------------------------------------------- /img/1553085203128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085203128.png -------------------------------------------------------------------------------- /img/1553085315569.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085315569.png -------------------------------------------------------------------------------- /img/1553085525410.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085525410.png -------------------------------------------------------------------------------- /img/1553085714302.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085714302.png -------------------------------------------------------------------------------- /img/1553085877665.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553085877665.png -------------------------------------------------------------------------------- /img/1553086025902.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553086025902.png -------------------------------------------------------------------------------- /img/1553086270822.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553086270822.png -------------------------------------------------------------------------------- /img/1553086450949.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553086450949.png -------------------------------------------------------------------------------- /img/1553088999563.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553088999563.png -------------------------------------------------------------------------------- /img/1553089130159.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553089130159.png -------------------------------------------------------------------------------- /img/1553089178172.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553089178172.png -------------------------------------------------------------------------------- /img/1553089513587.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553089513587.png -------------------------------------------------------------------------------- /img/1553123470473.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553123470473.png -------------------------------------------------------------------------------- /img/1553156680043.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553156680043.png -------------------------------------------------------------------------------- /img/1553157191632.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157191632.png -------------------------------------------------------------------------------- /img/1553157269984.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157269984.png -------------------------------------------------------------------------------- /img/1553157324880.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157324880.png -------------------------------------------------------------------------------- /img/1553157626748.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157626748.png -------------------------------------------------------------------------------- /img/1553157727196.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157727196.png -------------------------------------------------------------------------------- /img/1553157772668.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553157772668.png -------------------------------------------------------------------------------- /img/1553158312997.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553158312997.png -------------------------------------------------------------------------------- /img/1553158352180.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553158352180.png -------------------------------------------------------------------------------- /img/1553158644963.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1553158644963.png -------------------------------------------------------------------------------- /img/1554628742996.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628742996.png -------------------------------------------------------------------------------- /img/1554628757778.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628757778.png -------------------------------------------------------------------------------- /img/1554628804461.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628804461.png -------------------------------------------------------------------------------- /img/1554628873425.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628873425.png -------------------------------------------------------------------------------- /img/1554628903758.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628903758.png -------------------------------------------------------------------------------- /img/1554628933425.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628933425.png -------------------------------------------------------------------------------- /img/1554628946463.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628946463.png -------------------------------------------------------------------------------- /img/1554628958176.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ruru-Xu/chapter5-learning_CSRNet/64472756bc173f1eadf38bb5d89aac30232c09dc/img/1554628958176.png -------------------------------------------------------------------------------- /make_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | 
"outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:127: RuntimeWarning: overflow encountered in add\n", 13 | " a = a + a\n", 14 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:129: RuntimeWarning: invalid value encountered in subtract\n", 15 | " temp1 = temp - a\n", 16 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:138: RuntimeWarning: invalid value encountered in subtract\n", 17 | " itemp = int_conv(temp-a)\n", 18 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:162: RuntimeWarning: overflow encountered in add\n", 19 | " a = a + a\n", 20 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:164: RuntimeWarning: invalid value encountered in subtract\n", 21 | " temp1 = temp - a\n", 22 | "/home/leeyh/anaconda2/lib/python2.7/site-packages/numpy/core/machar.py:171: RuntimeWarning: invalid value encountered in subtract\n", 23 | " if any(temp-a != zero):\n" 24 | ] 25 | } 26 | ], 27 | "source": [ 28 | "import h5py\n", 29 | "import scipy.io as io\n", 30 | "import PIL.Image as Image\n", 31 | "import numpy as np\n", 32 | "import os\n", 33 | "import glob\n", 34 | "from matplotlib import pyplot as plt\n", 35 | "from scipy.ndimage.filters import gaussian_filter \n", 36 | "import scipy\n", 37 | "import json\n", 38 | "from matplotlib import cm as CM\n", 39 | "from image import *\n", 40 | "from model import CSRNet\n", 41 | "import torch\n", 42 | "%matplotlib inline" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "collapsed": true 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "#this is borrowed from https://github.com/davideverona/deep-crowd-counting_crowdnet\n", 54 | "def gaussian_filter_density(gt):\n", 55 | " print gt.shape\n", 56 | " density = np.zeros(gt.shape, dtype=np.float32)\n", 57 | " gt_count = 
np.count_nonzero(gt)\n", 58 | " if gt_count == 0:\n", 59 | " return density\n", 60 | "\n", 61 | " pts = np.array(zip(np.nonzero(gt)[1], np.nonzero(gt)[0]))\n", 62 | " leafsize = 2048\n", 63 | " # build kdtree\n", 64 | " tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize)\n", 65 | " # query kdtree\n", 66 | " distances, locations = tree.query(pts, k=4)\n", 67 | "\n", 68 | " print 'generate density...'\n", 69 | " for i, pt in enumerate(pts):\n", 70 | " pt2d = np.zeros(gt.shape, dtype=np.float32)\n", 71 | " pt2d[pt[1],pt[0]] = 1.\n", 72 | " if gt_count > 1:\n", 73 | " sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1\n", 74 | " else:\n", 75 | " sigma = np.average(np.array(gt.shape))/2./2. #case: 1 point\n", 76 | " density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')\n", 77 | " print 'done.'\n", 78 | " return density" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 2, 84 | "metadata": { 85 | "collapsed": true 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "#set the root to the Shanghai dataset you download\n", 90 | "root = '/home/leeyh/Downloads/Shanghai/'" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": { 97 | "collapsed": true 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "#now generate the ShanghaiA's ground truth\n", 102 | "part_A_train = os.path.join(root,'part_A_final/train_data','images')\n", 103 | "part_A_test = os.path.join(root,'part_A_final/test_data','images')\n", 104 | "part_B_train = os.path.join(root,'part_B_final/train_data','images')\n", 105 | "part_B_test = os.path.join(root,'part_B_final/test_data','images')\n", 106 | "path_sets = [part_A_train,part_A_test]" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": 4, 112 | "metadata": { 113 | "collapsed": true 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "img_paths = []\n", 118 | "for path in path_sets:\n", 119 | " for img_path in 
glob.glob(os.path.join(path, '*.jpg')):\n", 120 | " img_paths.append(img_path)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": { 127 | "collapsed": true, 128 | "scrolled": true 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "for img_path in img_paths:\n", 133 | " print img_path\n", 134 | " mat = io.loadmat(img_path.replace('.jpg','.mat').replace('images','ground_truth').replace('IMG_','GT_IMG_'))\n", 135 | " img= plt.imread(img_path)\n", 136 | " k = np.zeros((img.shape[0],img.shape[1]))\n", 137 | " gt = mat[\"image_info\"][0,0][0,0][0]\n", 138 | " for i in range(0,len(gt)):\n", 139 | " if int(gt[i][1]) 1: 48 | sigma = (distances[i][1]+distances[i][2]+distances[i][3])*0.1 49 | else: 50 | sigma = np.average(np.array(gt.shape))/2./2. #case: 1 point 51 | density += scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant') 52 | print('done.') 53 | return density 54 | 55 | 56 | # In[2]: 57 | 58 | def main(): 59 | #set the root to the Shanghai dataset you download 60 | # root = '/home/leeyh/Downloads/Shanghai/' 61 | root = '/home/imc/XR/temp/CSRNet/dataset/Shanghai/' 62 | 63 | # In[3]: 64 | 65 | 66 | #now generate the ShanghaiA's ground truth 67 | part_A_train = os.path.join(root,'part_A_final/train_data','images') 68 | part_A_test = os.path.join(root,'part_A_final/test_data','images') 69 | part_B_train = os.path.join(root,'part_B_final/train_data','images') 70 | part_B_test = os.path.join(root,'part_B_final/test_data','images') 71 | path_sets = [part_A_train,part_A_test] 72 | 73 | 74 | # In[4]: 75 | 76 | 77 | img_paths = [] 78 | for path in path_sets: 79 | for img_path in glob.glob(os.path.join(path, '*.jpg')): 80 | img_paths.append(img_path) 81 | 82 | 83 | # In[ ]: 84 | 85 | 86 | for img_path in img_paths: 87 | print(img_path) 88 | mat = io.loadmat(img_path.replace('.jpg','.mat').replace('images','ground_truth').replace('IMG_','GT_IMG_')) 89 | img= plt.imread(img_path) 90 | k = 
np.zeros((img.shape[0],img.shape[1]))
91 |         gt = mat["image_info"][0,0][0,0][0]
92 |         for i in range(0,len(gt)):
93 |             if int(gt[i][1]) loading checkpoint '{}'".format(args.pre))
78 |             checkpoint = torch.load(args.pre)
79 |             args.start_epoch = checkpoint['epoch']
80 |             best_prec1 = checkpoint['best_prec1']
81 |             model.load_state_dict(checkpoint['state_dict'])
82 |             optimizer.load_state_dict(checkpoint['optimizer'])
83 |             print("=> loaded checkpoint '{}' (epoch {})"
84 |                   .format(args.pre, checkpoint['epoch']))
85 |         else:
86 |             print("=> no checkpoint found at '{}'".format(args.pre))
87 | 
88 |     for epoch in range(args.start_epoch, args.epochs):
89 | 
90 |         adjust_learning_rate(optimizer, epoch)
91 | 
92 |         train(train_list, model, criterion, optimizer, epoch)
93 |         prec1 = validate(val_list, model, criterion)
94 | 
95 |         is_best = prec1 < best_prec1
96 |         best_prec1 = min(prec1, best_prec1)
97 |         print(' * best MAE {mae:.3f} '
98 |               .format(mae=best_prec1))
99 |         save_checkpoint({
100 |             'epoch': epoch + 1,
101 |             'arch': args.pre,
102 |             'state_dict': model.state_dict(),
103 |             'best_prec1': best_prec1,
104 |             'optimizer' : optimizer.state_dict(),
105 |         }, is_best,args.task)
106 | 
107 | def train(train_list, model, criterion, optimizer, epoch):
108 | 
109 |     losses = AverageMeter()
110 |     batch_time = AverageMeter()
111 |     data_time = AverageMeter()
112 | 
113 | 
114 |     train_loader = torch.utils.data.DataLoader(
115 |         dataset.listDataset(train_list,
116 |                        shuffle=True,
117 |                        transform=transforms.Compose([
118 |                        transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],
119 |                                      std=[0.229, 0.224, 0.225]),
120 |                    ]),
121 |                        train=True,
122 |                        seen=model.seen,
123 |                        batch_size=args.batch_size,
124 |                        num_workers=args.workers),
125 |         batch_size=args.batch_size)
126 |     print('epoch %d, processed %d samples, lr %.10f' % (epoch, epoch * len(train_loader.dataset), args.lr))
127 | 
128 |     model.train()
129 |     end = time.time()
130 | 
131 |     for i,(img, target)in enumerate(train_loader):
132 |         data_time.update(time.time() - end)
133 | 
134 |         img = img.cuda()
135 |         img = Variable(img)
136 |         output = model(img)
137 | 
138 | 
139 | 
140 | 
141 |         target = target.type(torch.FloatTensor).unsqueeze(0).cuda()
142 |         target = Variable(target)
143 | 
144 | 
145 |         loss = criterion(output, target)
146 | 
147 |         losses.update(loss.item(), img.size(0))
148 |         optimizer.zero_grad()
149 |         loss.backward()
150 |         optimizer.step()
151 | 
152 |         batch_time.update(time.time() - end)
153 |         end = time.time()
154 | 
155 |         if i % args.print_freq == 0:
156 |             print('Epoch: [{0}][{1}/{2}]\t'
157 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
158 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
159 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
160 |                   .format(
161 |                    epoch, i, len(train_loader), batch_time=batch_time,
162 |                    data_time=data_time, loss=losses))
163 | 
164 | def validate(val_list, model, criterion):
165 |     print ('begin test')
166 |     test_loader = torch.utils.data.DataLoader(
167 |     dataset.listDataset(val_list,
168 |                    shuffle=False,
169 |                    transform=transforms.Compose([
170 |                        transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],
171 |                                      std=[0.229, 0.224, 0.225]),
172 |                    ]),  train=False),
173 |     batch_size=args.batch_size)
174 | 
175 |     model.eval()
176 | 
177 |     mae = 0
178 | 
179 |     for i,(img, target) in enumerate(test_loader):
180 |         img = img.cuda()
181 |         img = Variable(img)
182 |         output = model(img)
183 | 
184 |         mae += abs(output.data.sum()-target.sum().type(torch.FloatTensor).cuda())
185 | 
186 |     mae = mae/len(test_loader)
187 |     print(' * MAE {mae:.3f} '
188 |               .format(mae=mae))
189 | 
190 |     return mae
191 | 
192 | def adjust_learning_rate(optimizer, epoch):
193 |     """Reset the LR to args.original_lr, then multiply it by args.scales[i] for each milestone args.steps[i] (in order) that epoch has reached"""
194 | 
195 | 
196 |     args.lr = args.original_lr
197 | 
198 |     for i in range(len(args.steps)):
199 | 
200 |         scale = args.scales[i] if i < len(args.scales) else 1
201 | 
202 | 
203 |         if epoch >= args.steps[i]:
204 |             args.lr = args.lr * scale
205 |             if epoch == args.steps[i]:
206 |                 break
207 |         else:
208 |             break
209 |     for param_group in optimizer.param_groups:
210 |         param_group['lr'] = args.lr
211 | 
212 | class AverageMeter(object):
213 |     """Computes and stores the average and current value"""
214 |     def __init__(self):
215 |         self.reset()
216 | 
217 |     def reset(self):
218 |         self.val = 0
219 |         self.avg = 0
220 |         self.sum = 0
221 |         self.count = 0
222 | 
223 |     def update(self, val, n=1):
224 |         self.val = val
225 |         self.sum += val * n
226 |         self.count += n
227 |         self.avg = self.sum / self.count
228 | 
229 | if __name__ == '__main__':
230 |     main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np  # needed by load_net (np.asarray)
3 | import torch
4 | import shutil
5 | 
6 | def save_net(fname, net):
7 |     with h5py.File(fname, 'w') as h5f:
8 |         for k, v in net.state_dict().items():
9 |             h5f.create_dataset(k, data=v.cpu().numpy())
10 | def load_net(fname, net):
11 |     with h5py.File(fname, 'r') as h5f:
12 |         for k, v in net.state_dict().items():
13 |             param = torch.from_numpy(np.asarray(h5f[k]))
14 |             v.copy_(param)
15 | 
16 | def save_checkpoint(state, is_best,task_id, filename='checkpoint.pth.tar'):
17 |     torch.save(state, task_id+filename)
18 |     if is_best:
19 |         shutil.copyfile(task_id+filename, task_id+'model_best.pth.tar')
--------------------------------------------------------------------------------
/val.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 42,
6 |    "metadata": {
7 |     "collapsed": true
8 |    },
9 |    "outputs": [],
10 |    "source": [
11 |     "import h5py\n",
12 |     "import scipy.io as io\n",
13 |     "import PIL.Image as Image\n",
14 |     "import numpy as np\n",
15 |     "import os\n",
16 |     "import glob\n",
17 |     "from matplotlib import pyplot as plt\n",
18 |     "from scipy.ndimage.filters import gaussian_filter \n",
19 | 
"import scipy\n", 20 | "import json\n", 21 | "import torchvision.transforms.functional as F\n", 22 | "from matplotlib import cm as CM\n", 23 | "from image import *\n", 24 | "from model import CSRNet\n", 25 | "import torch\n", 26 | "%matplotlib inline" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 10, 32 | "metadata": { 33 | "collapsed": true 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "from torchvision import datasets, transforms\n", 38 | "transform=transforms.Compose([\n", 39 | " transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],\n", 40 | " std=[0.229, 0.224, 0.225]),\n", 41 | " ])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": { 48 | "collapsed": true 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "root = '/home/leeyh/Downloads/Shanghai/'" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "metadata": { 59 | "collapsed": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "#now generate the ShanghaiA's ground truth\n", 64 | "part_A_train = os.path.join(root,'part_A_final/train_data','images')\n", 65 | "part_A_test = os.path.join(root,'part_A_final/test_data','images')\n", 66 | "part_B_train = os.path.join(root,'part_B_final/train_data','images')\n", 67 | "part_B_test = os.path.join(root,'part_B_final/test_data','images')\n", 68 | "path_sets = [part_A_test]" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 5, 74 | "metadata": { 75 | "collapsed": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "img_paths = []\n", 80 | "for path in path_sets:\n", 81 | " for img_path in glob.glob(os.path.join(path, '*.jpg')):\n", 82 | " img_paths.append(img_path)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 6, 88 | "metadata": { 89 | "collapsed": true 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "model = CSRNet()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 7, 99 
| "metadata": { 100 | "collapsed": true 101 | }, 102 | "outputs": [], 103 | "source": [ 104 | "model = model.cuda()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 38, 110 | "metadata": { 111 | "collapsed": true 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "checkpoint = torch.load('model_best.pth.tar')" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 39, 121 | "metadata": { 122 | "collapsed": true 123 | }, 124 | "outputs": [], 125 | "source": [ 126 | "model.load_state_dict(checkpoint['state_dict'])" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 45, 132 | "metadata": { 133 | "scrolled": true 134 | }, 135 | "outputs": [ 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "0 15.50390625\n", 141 | "1 60.9075317383\n", 142 | "2 220.511169434\n", 143 | "3 239.312469482\n", 144 | "4 252.252349854\n", 145 | "5 272.965286255\n", 146 | "6 457.101577759\n", 147 | "7 651.92250061\n", 148 | "8 681.363113403\n", 149 | "9 785.472061157\n", 150 | "10 838.996322632\n", 151 | "11 1012.11277771\n", 152 | "12 1073.18791199\n", 153 | "13 1074.72886658\n", 154 | "14 1139.53701782\n", 155 | "15 1201.55630493\n", 156 | "16 1316.97366333\n", 157 | "17 1447.83328247\n", 158 | "18 1578.5967865\n", 159 | "19 1622.11299896\n", 160 | "20 1650.6510849\n", 161 | "21 1756.23152924\n", 162 | "22 1810.32888031\n", 163 | "23 1816.16628265\n", 164 | "24 1864.3579483\n", 165 | "25 1886.67713928\n", 166 | "26 1942.00212097\n", 167 | "27 1995.17939758\n", 168 | "28 2031.85020447\n", 169 | "29 2236.5879364\n", 170 | "30 2265.58026123\n", 171 | "31 2272.81892395\n", 172 | "32 2333.38215637\n", 173 | "33 2489.37367249\n", 174 | "34 2560.18891907\n", 175 | "35 2580.02906799\n", 176 | "36 2588.45735168\n", 177 | "37 2661.36177063\n", 178 | "38 2788.47312927\n", 179 | "39 2900.07542419\n", 180 | "40 2900.83190918\n", 181 | "41 2976.19485474\n", 182 | "42 2995.64665985\n", 183 | 
"43 3058.02416229\n", 184 | "44 3103.91133881\n", 185 | "45 3241.52921295\n", 186 | "46 3906.06101227\n", 187 | "47 3918.4709549\n", 188 | "48 3948.23314667\n", 189 | "49 3953.28383636\n", 190 | "50 3989.9439621\n", 191 | "51 4093.59087372\n", 192 | "52 4180.21788788\n", 193 | "53 4289.41963959\n", 194 | "54 4291.79859161\n", 195 | "55 4302.33109283\n", 196 | "56 4339.80475616\n", 197 | "57 4543.3482132\n", 198 | "58 4626.09952545\n", 199 | "59 4684.85929108\n", 200 | "60 4713.21773529\n", 201 | "61 4801.95433807\n", 202 | "62 5371.40599823\n", 203 | "63 5430.06401062\n", 204 | "64 5463.51008606\n", 205 | "65 5517.62242126\n", 206 | "66 5531.55604553\n", 207 | "67 5531.86631775\n", 208 | "68 5547.58416748\n", 209 | "69 5586.52706909\n", 210 | "70 5588.99980164\n", 211 | "71 5711.29048157\n", 212 | "72 5732.70922852\n", 213 | "73 5764.59197998\n", 214 | "74 5799.78912354\n", 215 | "75 6082.0055542\n", 216 | "76 6110.95211792\n", 217 | "77 6141.81124878\n", 218 | "78 6156.64276886\n", 219 | "79 6169.39850616\n", 220 | "80 6193.0460434\n", 221 | "81 6232.38686371\n", 222 | "82 6257.28891754\n", 223 | "83 6367.4381485\n", 224 | "84 6411.16867828\n", 225 | "85 6790.66916656\n", 226 | "86 6935.87509918\n", 227 | "87 7059.09088898\n", 228 | "88 7105.02100372\n", 229 | "89 7138.09949493\n", 230 | "90 7177.44655609\n", 231 | "91 7284.53981781\n", 232 | "92 7363.71459198\n", 233 | "93 7486.82701874\n", 234 | "94 7518.29531097\n", 235 | "95 7519.01638031\n", 236 | "96 7553.38800049\n", 237 | "97 7622.98153687\n", 238 | "98 7635.38021851\n", 239 | "99 7692.21868896\n", 240 | "100 7695.25094604\n", 241 | "101 7711.97100067\n", 242 | "102 7772.21665192\n", 243 | "103 7800.12631989\n", 244 | "104 7804.24352264\n", 245 | "105 7861.25572968\n", 246 | "106 7887.95476532\n", 247 | "107 7913.09929657\n", 248 | "108 7966.69908905\n", 249 | "109 7977.07558441\n", 250 | "110 7978.10308075\n", 251 | "111 8035.75792694\n", 252 | "112 8099.88143158\n", 253 | "113 8713.50301361\n", 254 | 
"114 8755.26369476\n", 255 | "115 8764.1200943\n", 256 | "116 8807.97383881\n", 257 | "117 8831.98430634\n", 258 | "118 8899.79549408\n", 259 | "119 9223.38906097\n", 260 | "120 9325.20153046\n", 261 | "121 9335.49025726\n", 262 | "122 9391.73867035\n", 263 | "123 9429.91559601\n", 264 | "124 9475.48241425\n", 265 | "125 9504.1687851\n", 266 | "126 9571.56279755\n", 267 | "127 9621.32411957\n", 268 | "128 9771.76332855\n", 269 | "129 9825.34178925\n", 270 | "130 9829.55644226\n", 271 | "131 9852.76200867\n", 272 | "132 9933.64692688\n", 273 | "133 9962.82582092\n", 274 | "134 10008.0011597\n", 275 | "135 10096.1471558\n", 276 | "136 10133.0110474\n", 277 | "137 10251.9801483\n", 278 | "138 10256.5829926\n", 279 | "139 10268.2446671\n", 280 | "140 10314.4458389\n", 281 | "141 10404.9172134\n", 282 | "142 10450.5001602\n", 283 | "143 10454.2003555\n", 284 | "144 10460.2950211\n", 285 | "145 10468.630188\n", 286 | "146 10513.0866699\n", 287 | "147 10540.4085999\n", 288 | "148 10969.8531189\n", 289 | "149 11258.8591614\n", 290 | "150 11272.8908997\n", 291 | "151 11405.8752747\n", 292 | "152 11432.7355957\n", 293 | "153 11438.5995789\n", 294 | "154 11497.5692749\n", 295 | "155 11842.559082\n", 296 | "156 11955.1886826\n", 297 | "157 12029.658989\n", 298 | "158 12082.984642\n", 299 | "159 12297.4541245\n", 300 | "160 12355.7072983\n", 301 | "161 12436.2029648\n", 302 | "162 12455.2758102\n", 303 | "163 12750.65905\n", 304 | "164 12854.170433\n", 305 | "165 12870.6474228\n", 306 | "166 12913.027977\n", 307 | "167 12924.1338921\n", 308 | "168 12944.4942436\n", 309 | "169 12986.8365135\n", 310 | "170 13030.5286522\n", 311 | "171 13076.5168724\n", 312 | "172 13136.3384972\n", 313 | "173 13241.9948387\n", 314 | "174 13256.391201\n", 315 | "175 13441.6517601\n", 316 | "176 13550.5347557\n", 317 | "177 13579.6995201\n", 318 | "178 13612.6000328\n", 319 | "179 13623.0231133\n", 320 | "180 13667.2316093\n", 321 | "181 13725.9745598\n", 322 | "75.4174426362\n" 323 | ] 324 | } 325 
| ], 326 | "source": [ 327 | "mae = 0\n", 328 | "for i in xrange(len(img_paths)):\n", 329 | " img = 255.0 * F.to_tensor(Image.open(img_paths[i]).convert('RGB'))\n", 330 | "\n", 331 | " img[0,:,:]=img[0,:,:]-92.8207477031\n", 332 | " img[1,:,:]=img[1,:,:]-95.2757037428\n", 333 | " img[2,:,:]=img[2,:,:]-104.877445883\n", 334 | " img = img.cuda()\n", 335 | " #img = transform(Image.open(img_paths[i]).convert('RGB')).cuda()\n", 336 | " gt_file = h5py.File(img_paths[i].replace('.jpg','.h5').replace('images','ground_truth'),'r')\n", 337 | " groundtruth = np.asarray(gt_file['density'])\n", 338 | " output = model(img.unsqueeze(0))\n", 339 | " mae += abs(output.detach().cpu().sum().numpy()-np.sum(groundtruth))\n", 340 | " print i,mae\n", 341 | "print mae/len(img_paths)" 342 | ] 343 | } 344 | ], 345 | "metadata": { 346 | "kernelspec": { 347 | "display_name": "Python 2", 348 | "language": "python", 349 | "name": "python2" 350 | }, 351 | "language_info": { 352 | "codemirror_mode": { 353 | "name": "ipython", 354 | "version": 2 355 | }, 356 | "file_extension": ".py", 357 | "mimetype": "text/x-python", 358 | "name": "python", 359 | "nbconvert_exporter": "python", 360 | "pygments_lexer": "ipython2", 361 | "version": "2.7.13" 362 | } 363 | }, 364 | "nbformat": 4, 365 | "nbformat_minor": 2 366 | } 367 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[42]: 5 | 6 | 7 | import h5py 8 | import scipy.io as io 9 | import PIL.Image as Image 10 | import numpy as np 11 | import os 12 | import glob 13 | from matplotlib import pyplot as plt 14 | from scipy.ndimage.filters import gaussian_filter 15 | import scipy 16 | import json 17 | import torchvision.transforms.functional as F 18 | from matplotlib import cm as CM 19 | from image import * 20 | from model import CSRNet 21 | import torch 22 | # 
# get_ipython().run_line_magic('matplotlib', 'inline') 23 | 24 | 25 | # In[10]: 26 | 27 | 28 | from torchvision import datasets, transforms 29 | transform=transforms.Compose([ 30 | transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]), 32 | ]) 33 | 34 | 35 | # In[3]: 36 | 37 | 38 | root = '/home/imc/XR/temp/CSRNet/dataset/Shanghai/'  # change to the location of your ShanghaiTech dataset 39 | 40 | 41 | # In[4]: 42 | 43 | 44 | # ShanghaiTech image folders; only the Part A test set is evaluated below 45 | part_A_train = os.path.join(root,'part_A_final/train_data','images') 46 | part_A_test = os.path.join(root,'part_A_final/test_data','images') 47 | part_B_train = os.path.join(root,'part_B_final/train_data','images') 48 | part_B_test = os.path.join(root,'part_B_final/test_data','images') 49 | path_sets = [part_A_test] 50 | 51 | 52 | # In[5]: 53 | 54 | 55 | img_paths = [] 56 | for path in path_sets: 57 | for img_path in glob.glob(os.path.join(path, '*.jpg')): 58 | img_paths.append(img_path) 59 | 60 | 61 | # In[6]: 62 | 63 | 64 | model = CSRNet() 65 | 66 | 67 | # In[7]: 68 | 69 | 70 | model = model.cuda() 71 | 72 | 73 | # In[38]: 74 | 75 | 76 | # checkpoint = torch.load('model_best.pth.tar') 77 | checkpoint = torch.load('dataset/Shanghai/PartAmodel_best.pth.tar') 78 | 79 | 80 | # In[39]: 81 | 82 | 83 | model.load_state_dict(checkpoint['state_dict']) 84 | 85 | 86 | # In[45]: 87 | 88 | 89 | mae = 0 90 | for i in range(len(img_paths)): 91 | # Preprocess exactly like validate() in train.py: ToTensor + ImageNet Normalize (the `transform` above). 92 | # val.ipynb instead subtracts per-channel means on the 0-255 scale without dividing by a std, 93 | # which does not match the transform used by validate() in train.py. 94 | img = transform(Image.open(img_paths[i]).convert('RGB')).cuda() 95 | gt_file = h5py.File(img_paths[i].replace('.jpg','.h5').replace('images','ground_truth'),'r') 96 | groundtruth = np.asarray(gt_file['density']) 97 | gt_file.close() 98 | output = model(img.unsqueeze(0)) 99 | mae += abs(output.detach().cpu().sum().numpy()-np.sum(groundtruth)) 100 | print(i,mae) 101 | 102 | 
print(mae/len(img_paths)) 103 | 104 | --------------------------------------------------------------------------------
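For reference, here is what `transforms.ToTensor()` followed by `transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])` does to an image. This is the `transform` defined in val.py and used by `validate()` in train.py. The snippet is a NumPy sketch of the arithmetic only, not torchvision's implementation, and the helper name `imagenet_normalize` is made up for illustration. It also shows why the manual preprocessing in val.ipynb (subtracting 92.82, 95.28 and 104.88 on the 0-255 scale, with no division by a standard deviation) feeds the network inputs on a different scale.

```python
import numpy as np

# ImageNet statistics, as passed to transforms.Normalize in train.py / val.py
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def imagenet_normalize(rgb_uint8):
    """Mimic ToTensor() + Normalize() on an H x W x 3 uint8 RGB array (hypothetical helper).

    Returns a float array of shape 3 x H x W (channels first, like the tensor torchvision returns).
    """
    x = rgb_uint8.astype(np.float64) / 255.0   # ToTensor: scale pixel values to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD     # Normalize: per-channel (x - mean) / std
    return x.transpose(2, 0, 1)                # ToTensor: HWC -> CHW


if __name__ == "__main__":
    black = np.zeros((2, 4, 3), dtype=np.uint8)
    out = imagenet_normalize(black)
    print(out.shape)       # channels first: (3, 2, 4)
    print(out[:, 0, 0])    # -mean / std for each channel
```

With this normalization, pixel values land roughly in the range -2 to 2.7, whereas the notebook's manual subtraction leaves them in the range of about -105 to 162.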
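Both `validate()` in train.py and the loop in val.py score the model the same way. The predicted head count for an image is the sum of its density map, and the reported MAE is the mean absolute difference between predicted and ground-truth counts over the test images. Below is a minimal NumPy sketch of that metric. `count_mae` is a hypothetical helper, and the toy arrays stand in for real density maps.

```python
import numpy as np


def count_mae(pred_maps, gt_maps):
    """Mean absolute error between predicted and ground-truth crowd counts (hypothetical helper).

    Each element of pred_maps / gt_maps is a 2-D density map; its sum is the count.
    The two maps of a pair may have different resolutions, because only their sums are compared.
    """
    if len(pred_maps) != len(gt_maps) or not pred_maps:
        raise ValueError("need the same, non-zero number of predicted and ground-truth maps")
    errors = [abs(float(p.sum()) - float(g.sum())) for p, g in zip(pred_maps, gt_maps)]
    return sum(errors) / len(errors)


if __name__ == "__main__":
    preds = [np.full((2, 2), 0.5), np.ones((3, 3))]          # counts 2.0 and 9.0
    gts = [np.ones((8, 8)) / 16.0, np.ones((6, 6)) / 4.0]    # counts 4.0 and 9.0
    print(count_mae(preds, gts))  # 1.0
```

Comparing sums rather than pixels is what lets the evaluation compare a low-resolution network output directly against a full-resolution ground-truth density map.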
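`adjust_learning_rate()` in train.py resets the rate to `args.original_lr` at every epoch. It then multiplies the rate by `args.scales[i]` for each milestone in `args.steps` that the current epoch has reached, stopping at the first milestone not yet reached. The sketch below reproduces that loop as a pure function, so the schedule can be inspected without an optimizer. `lr_for_epoch` and its parameters are hypothetical stand-ins for the global `args` fields.

```python
def lr_for_epoch(original_lr, steps, scales, epoch):
    """Replicate the milestone logic of adjust_learning_rate() in train.py (hypothetical helper).

    steps:  epochs at which the learning rate is rescaled (assumed sorted)
    scales: multiplier applied at each step; missing entries default to 1
    """
    lr = original_lr
    for i, step in enumerate(steps):
        scale = scales[i] if i < len(scales) else 1
        if epoch >= step:
            lr = lr * scale      # milestone reached: apply its scale
            if epoch == step:
                break            # exactly on this milestone: stop here
        else:
            break                # first milestone not reached yet: stop
    return lr


if __name__ == "__main__":
    # Example schedule (made up): divide by 10 at epochs 10 and 20
    for epoch in (0, 10, 15, 25):
        print(epoch, lr_for_epoch(1e-3, [10, 20], [0.1, 0.1], epoch))
```

If every entry of `scales` is 1, the function always returns `original_lr`, so the schedule is effectively constant regardless of `steps`.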