├── README.md
├── dataset.py
├── make_dataset.py
├── model.py
├── train.py
├── utils.py
└── val.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pyramid_Scale_Network
This is the PyTorch repo for "Exploit the potential of Multi-column architecture for Crowd Counting", which presents a straightforward, end-to-end, state-of-the-art architecture for crowd counting. We also recommend another work on crowd counting, [Deep Density-aware Count Regressor](https://github.com/GeorgeChenZJ/deepcount), which was accepted by ECAI 2020.

# Datasets
ShanghaiTech Dataset

# Prerequisites
We strongly recommend Anaconda as the environment.

Python: 3.6

PyTorch: 1.5.0

# Train & Test
1. `python make_dataset.py` # generate the ground-truth density maps; the ShanghaiTech dataset should be placed in the "datasets" directory
2. `python train.py` # train the model
3. `python val.py` # test the model

# Results
| Dataset | MAE  | MSE  |
| ------- | ---- | ---- |
| Part A  | 55.5 | 90.1 |
| Part B  | 6.8  | 10.7 |
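
# Dataset layout
The scripts expect the layout below (inferred from the paths used in `make_dataset.py`, `train.py` and `val.py`; the `density/` directories are created by `make_dataset.py`):

```
datasets/
└── ShanghaiTech_Dataset/
    ├── part_A_final/
    │   ├── train_data/
    │   │   ├── images/        # IMG_*.jpg
    │   │   ├── ground_truth/  # GT_IMG_*.mat
    │   │   └── density/       # IMG_*.h5 (generated)
    │   └── test_data/
    │       ├── images/
    │       ├── ground_truth/
    │       └── density/
    └── part_B_final/
        └── ... (same structure)
```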

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
crowd counting dataset
"""
from torch.utils.data import Dataset
from PIL import Image
import os
import glob
import numpy as np
import h5py
import cv2
import random
from torchvision.transforms import functional


################################################################################
# crowd counting dataset
################################################################################
class CrowdCountingDataset(Dataset):
    def __init__(self, dir_path, transforms, scale=8, mode='train'):
        """
        :param
            dir_path(str) -- the path of the image directory
            transforms -- torchvision transforms applied to each image
            scale(int) -- density map scale factor (the network downsamples by this factor)
            mode(str) -- 'train' or 'test'
        """
        self.transforms = transforms
        self.scale = scale
        self.mode = mode

        # acquire image paths
        self.img_paths = []
        for img_path in glob.glob(os.path.join(dir_path, '*.jpg')):
            self.img_paths.append(img_path)

    def __getitem__(self, index):
        ## -- load image -- ##
        img_path = self.img_paths[index]
        # read image
        img = Image.open(img_path).convert('RGB')
        # image size
        img_width, img_height = img.size

        ## -- load density map -- ##
        density_path = img_path.replace('.jpg', '.h5').replace('images', 'density')
        # read density map
        with h5py.File(density_path, 'r') as hf:
            density = np.asarray(hf['density'])

        if self.mode != 'train':
            # image
            img = self.transforms(img)
            # ground-truth count
            gt = np.sum(density)
            # downsample the density map by `scale`; multiplying by scale**2
            # keeps the total count (the sum over all pixels) unchanged
            density = cv2.resize(density,
                                 (density.shape[1] // self.scale, density.shape[0] // self.scale),
                                 interpolation=cv2.INTER_CUBIC) * (self.scale ** 2)
            density = density[np.newaxis, :, :]

            return img, gt, density

        # random resize: first guarantee a short side of at least 512,
        # then apply a random scale in [0.8, 1.2]
        short = min(img_width, img_height)
        if short < 512:
            scale = 512 / short
            img_width = round(img_width * scale)
            img_height = round(img_height * scale)
            img = img.resize((img_width, img_height), Image.BILINEAR)
            density = cv2.resize(density, (img_width, img_height), interpolation=cv2.INTER_LINEAR) / scale / scale
        scale = random.uniform(0.8, 1.2)
        img_width = round(img_width * scale)
        img_height = round(img_height * scale)
        img = img.resize((img_width, img_height), Image.BILINEAR)
        density = cv2.resize(density, (img_width, img_height), interpolation=cv2.INTER_LINEAR) / scale / scale

        # random crop
        h, w = 400, 400
        dh = random.randint(0, img_height - h)
        dw = random.randint(0, img_width - w)
        img = img.crop((dw, dh, dw + w, dh + h))
        density = density[dh:dh + h, dw:dw + w]

        # random horizontal flip (copy so the flipped array stays contiguous
        # for cv2/torch)
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            density = density[:, ::-1].copy()

        # random gamma
        if random.random() < 0.3:
            gamma = random.uniform(0.5, 1.5)
            img = functional.adjust_gamma(img, gamma)

        # random to gray
        if random.random() < 0.1:
            img = functional.to_grayscale(img, num_output_channels=3)

        img = self.transforms(img)
        density = cv2.resize(density, (density.shape[1] // self.scale, density.shape[0] // self.scale),
                             interpolation=cv2.INTER_LINEAR) * self.scale * self.scale
        density = density[np.newaxis, :, :]

        return img, density

    def __len__(self):
        return len(self.img_paths)
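
# --- Illustrative sketch (not part of the original repo): why the resizes in
# CrowdCountingDataset multiply by scale**2. Downsampling a density map by a
# factor s merges roughly s*s pixels into one, so the sum shrinks by ~s**2;
# rescaling restores the person count. ---
import cv2
import numpy as np

pts = np.zeros((400, 400), np.float32)
pts[100, 100] = pts[250, 300] = 1.0                 # two "heads"
density = cv2.GaussianBlur(pts, (0, 0), sigmaX=15)  # smooth density, sum ~= 2
s = 8
down = cv2.resize(density, (400 // s, 400 // s),
                  interpolation=cv2.INTER_LINEAR) * s * s
print(density.sum(), down.sum())                    # both close to 2.0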

--------------------------------------------------------------------------------
/make_dataset.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
data processing.
"""
from __future__ import print_function, division
import h5py
import scipy.io as io
import numpy as np
import glob
from matplotlib import pyplot as plt
from scipy.ndimage.filters import gaussian_filter
import scipy
import scipy.spatial
import os


################################################################################
# generate density maps of the ShanghaiTech dataset.
# the density maps are stored in the density directory,
# which is at the same level as the images directory.
################################################################################
def generate_density():
    # root directory
    root = 'datasets/ShanghaiTech_Dataset'
    # Part A train and test set directories
    part_A_train = os.path.join(root, 'part_A_final/train_data', 'images')
    part_A_test = os.path.join(root, 'part_A_final/test_data', 'images')
    # # Part B train and test set directories
    # part_B_train = os.path.join(root, 'part_B_final/train_data', 'images')
    # part_B_test = os.path.join(root, 'part_B_final/test_data', 'images')

    path_sets = [part_A_train, part_A_test]

    # create the density map storage directories
    for path in path_sets:
        dir_path, _ = os.path.split(path)
        path = os.path.join(dir_path, 'density')
        if not os.path.exists(path):
            os.makedirs(path)

    # acquire image paths
    img_paths = []
    for path in path_sets:
        for img_path in glob.glob(os.path.join(path, '*.jpg')):
            img_paths.append(img_path)

    # generate density maps
    for img_path in img_paths:
        print(img_path)
        mat = io.loadmat(img_path.replace('.jpg', '.mat').replace('images', 'ground_truth').replace('IMG_', 'GT_IMG_'))
        img = plt.imread(img_path)
        k = np.zeros((img.shape[0], img.shape[1]))
        gt = mat["image_info"][0, 0][0, 0][0]
        # place a unit impulse at every annotated head position inside the image
        for i in range(0, len(gt)):
            if int(gt[i][1]) < img.shape[0] and int(gt[i][0]) < img.shape[1]:
                k[int(gt[i][1]), int(gt[i][0])] = 1
        k = gaussian_filter_density(k)
        with h5py.File(img_path.replace('.jpg', '.h5').replace('images', 'density'), 'w') as hf:
            hf['density'] = k

    print("Over!!!")


################################################################################
# internal function.
################################################################################
def gaussian_filter_density(gt):
    print(gt.shape)
    density = np.zeros(gt.shape, dtype=np.float32)
    gt_count = np.count_nonzero(gt)
    if gt_count == 0:
        return density

    pts = np.array(list(zip(np.nonzero(gt)[1], np.nonzero(gt)[0])))
    leafsize = 2048
    # build kdtree
    tree = scipy.spatial.KDTree(pts.copy(), leafsize=leafsize)
    # query kdtree: distances to each point's 3 nearest neighbours
    # (k=4 because the query includes the point itself)
    distances, locations = tree.query(pts, k=4)

    print('generate density...')
    for i, pt in enumerate(pts):
        pt2d = np.zeros(gt.shape, dtype=np.float32)
        pt2d[pt[1], pt[0]] = 1.
        if gt_count > 1:
            # geometry-adaptive kernel: sigma = 0.3 * mean distance to the 3 nearest neighbours
            sigma = (distances[i][1] + distances[i][2] + distances[i][3]) * 0.1
        else:
            sigma = np.average(np.array(gt.shape)) / 2. / 2.  # case: 1 point
        density += gaussian_filter(pt2d, sigma, mode='constant')
    print('done.')
    return density


################################################################################
# main function
################################################################################
if __name__ == '__main__':
    # generate density maps
    generate_density()
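
# --- Illustrative sketch (not part of the original repo): the geometry-adaptive
# kernel on a tiny synthetic point map. Each impulse is blurred with a Gaussian
# whose sigma is 0.1 * (sum of distances to its 3 nearest neighbours), so the
# blur is tighter where heads are packed; the total mass stays near the count. ---
import numpy as np
from make_dataset import gaussian_filter_density

k = np.zeros((64, 64), dtype=np.float32)
for y, x in [(10, 10), (12, 14), (14, 10), (50, 50), (54, 54)]:  # a tight cluster and a looser pair
    k[y, x] = 1.0
density = gaussian_filter_density(k)
print(density.sum())   # ~5.0, up to truncation at the borders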

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
PSNet model
"""
from __future__ import print_function, division
import torch.nn as nn
import torch
from torch.utils import model_zoo
import numpy as np
import torch.nn.functional as F


################################################################################
# PSNet
################################################################################
class PSNet(nn.Module):
    def __init__(self):
        super(PSNet, self).__init__()
        self.vgg = VGG()
        self.dmp = BackEnd()

        self._load_vgg()

    def forward(self, input):
        input = self.vgg(input)
        dmp_out = self.dmp(*input)

        return dmp_out

    def _load_vgg(self):
        # copy the pretrained vgg16_bn weights into the front end; only the
        # first 10 conv layers (conv1_1 .. conv4_3) are used
        state_dict = model_zoo.load_url('https://download.pytorch.org/models/vgg16_bn-6c64b313.pth')
        old_name = [0, 1, 3, 4, 7, 8, 10, 11, 14, 15, 17, 18, 20, 21, 24, 25, 27, 28, 30, 31, 34, 35, 37, 38, 40, 41]
        new_name = ['1_1', '1_2', '2_1', '2_2', '3_1', '3_2', '3_3', '4_1', '4_2', '4_3', '5_1', '5_2', '5_3']
        new_dict = {}
        for i in range(10):
            new_dict['conv' + new_name[i] + '.conv.weight'] = \
                state_dict['features.' + str(old_name[2 * i]) + '.weight']
            new_dict['conv' + new_name[i] + '.conv.bias'] = \
                state_dict['features.' + str(old_name[2 * i]) + '.bias']
            new_dict['conv' + new_name[i] + '.bn.weight'] = \
                state_dict['features.' + str(old_name[2 * i + 1]) + '.weight']
            new_dict['conv' + new_name[i] + '.bn.bias'] = \
                state_dict['features.' + str(old_name[2 * i + 1]) + '.bias']
            new_dict['conv' + new_name[i] + '.bn.running_mean'] = \
                state_dict['features.' + str(old_name[2 * i + 1]) + '.running_mean']
            new_dict['conv' + new_name[i] + '.bn.running_var'] = \
                state_dict['features.' + str(old_name[2 * i + 1]) + '.running_var']

        self.vgg.load_state_dict(new_dict)

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv1_1 = BaseConv(3, 64, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv1_2 = BaseConv(64, 64, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv2_1 = BaseConv(64, 128, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv2_2 = BaseConv(128, 128, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv3_1 = BaseConv(128, 256, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv3_2 = BaseConv(256, 256, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv3_3 = BaseConv(256, 256, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv4_1 = BaseConv(256, 512, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv4_2 = BaseConv(512, 512, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv4_3 = BaseConv(512, 512, 3, 1, 1, activation=nn.ReLU(), use_bn=True)

    def forward(self, input):
        input = self.conv1_1(input)
        conv1_2 = self.conv1_2(input)

        input = self.pool(conv1_2)
        input = self.conv2_1(input)
        conv2_2 = self.conv2_2(input)

        input = self.pool(conv2_2)
        input = self.conv3_1(input)
        input = self.conv3_2(input)
        conv3_3 = self.conv3_3(input)

        input = self.pool(conv3_3)
        input = self.conv4_1(input)
        input = self.conv4_2(input)
        conv4_3 = self.conv4_3(input)

        return conv1_2, conv2_2, conv3_3, conv4_3


class BackEnd(nn.Module):
    def __init__(self):
        super(BackEnd, self).__init__()

        self.dense1 = DenseModule(512)
        self.dense2 = DenseModule(512)
        self.dense3 = DenseModule(512)

        self.conv1 = BaseConv(512, 256, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv2 = BaseConv(256, 128, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv3 = BaseConv(128, 64, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv4 = BaseConv(64, 1, 1, 1, activation=None, use_bn=False)

    def forward(self, *input):
        # only conv4_3 feeds the density branch; the shallower features are
        # unpacked here but not used
        conv1_2, conv2_2, conv3_3, conv4_3 = input

        input, attention_map_1 = self.dense1(conv4_3)
        input, attention_map_2 = self.dense2(input)
        input, attention_map_3 = self.dense3(input)

        input = self.conv1(input)
        input = self.conv2(input)
        input = self.conv3(input)
        input = self.conv4(input)

        return input, attention_map_1, attention_map_2, attention_map_3


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()

        self.conv1 = BaseConv(in_planes, round(in_planes // ratio), 1, 1, activation=nn.ReLU(), use_bn=False)
        self.conv2 = BaseConv(round(in_planes // ratio), in_planes, 1, 1, activation=nn.Sigmoid(), use_bn=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        return out


class DenseModule(nn.Module):
    def __init__(self, in_channels):
        super(DenseModule, self).__init__()

        self.conv3x3 = nn.Sequential(
            BaseConv(in_channels, in_channels // 4, 1, 1, activation=nn.ReLU(), use_bn=True),
            BaseConv(in_channels // 4, in_channels // 4, 3, 1, 1, activation=nn.ReLU(), use_bn=True))
        self.conv5x5 = nn.Sequential(
            BaseConv(in_channels, in_channels // 4, 1, 1, activation=nn.ReLU(), use_bn=True),
            BaseConv(in_channels // 4, in_channels // 4, 3, 1, 2, 2, activation=nn.ReLU(), use_bn=True))
        self.conv7x7 = nn.Sequential(
            BaseConv(in_channels, in_channels // 4, 1, 1, activation=nn.ReLU(), use_bn=True),
            BaseConv(in_channels // 4, in_channels // 4, 3, 1, 3, 3, activation=nn.ReLU(), use_bn=True))
        self.conv9x9 = nn.Sequential(
            BaseConv(in_channels, in_channels // 4, 1, 1, activation=nn.ReLU(), use_bn=True),
            BaseConv(in_channels // 4, in_channels // 4, 3, 1, 4, 4, activation=nn.ReLU(), use_bn=True))

        self.conv1 = BaseConv(in_channels // 2, in_channels // 4, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv2 = BaseConv(in_channels // 2, in_channels // 4, 3, 1, 1, activation=nn.ReLU(), use_bn=True)
        self.conv3 = BaseConv(in_channels // 2, in_channels // 4, 3, 1, 1, activation=nn.ReLU(), use_bn=True)

        self.att = ChannelAttention(in_channels)

        self.conv = BaseConv(in_channels, in_channels, 3, 1, 1, activation=nn.ReLU(), use_bn=True)

    def forward(self, input):
        conv3x3 = self.conv3x3(input)
        conv5x5 = self.conv5x5(input)
        conv7x7 = self.conv7x7(input)
        conv9x9 = self.conv9x9(input)

        # dense inter-scale connections: each branch is fused with its neighbour
        conv5x5 = self.conv1(torch.cat((conv3x3, conv5x5), dim=1))
        conv7x7 = self.conv2(torch.cat((conv5x5, conv7x7), dim=1))
        conv9x9 = self.conv3(torch.cat((conv7x7, conv9x9), dim=1))

        att = self.att(input)

        out = self.conv(torch.cat((conv3x3, conv5x5, conv7x7, conv9x9), dim=1))

        # per-scale attention map: the channel-wise mean of each branch
        attention_map = torch.cat((torch.mean(conv3x3, dim=1, keepdim=True),
                                   torch.mean(conv5x5, dim=1, keepdim=True),
                                   torch.mean(conv7x7, dim=1, keepdim=True),
                                   torch.mean(conv9x9, dim=1, keepdim=True)), dim=1)

        return out * att, attention_map
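
# --- Illustrative note (not part of the original file): the conv5x5/7x7/9x9
# branches in DenseModule use 3x3 kernels with dilation d = 2, 3, 4. A dilated
# kernel spans k_eff = k + (k - 1) * (d - 1) input pixels per side, which for
# k = 3 gives 5, 7 and 9 — hence the branch names. ---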

class BaseConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel, stride=1, padding=0, dilation=1, activation=None,
                 use_bn=False):
        super(BaseConv, self).__init__()
        self.use_bn = use_bn
        self.activation = activation
        self.conv = nn.Conv2d(in_channels, out_channels, kernel, stride, padding, dilation)
        self.conv.weight.data.normal_(0, 0.01)
        self.conv.bias.data.zero_()
        self.bn = nn.BatchNorm2d(out_channels)
        self.bn.weight.data.fill_(1)
        self.bn.bias.data.zero_()

    def forward(self, input):
        input = self.conv(input)
        if self.use_bn:
            input = self.bn(input)
        if self.activation:
            input = self.activation(input)

        return input
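
# --- Smoke-test sketch (not part of the original repo): shapes produced by
# PSNet. Three max-pools in the VGG front end downsample by 8, so a 400x400
# crop yields a 1x1x50x50 density map plus three 4-channel attention maps.
# Note that instantiating PSNet downloads the pretrained vgg16_bn weights. ---
import torch
from model import PSNet

model = PSNet().float().eval()
with torch.no_grad():
    density, att1, att2, att3 = model(torch.randn(1, 3, 400, 400))
print(density.shape, att1.shape)  # torch.Size([1, 1, 50, 50]) torch.Size([1, 4, 50, 50])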

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import transforms
import torch
import torch.nn as nn
import os
import numpy as np
import time
from utils import adjust_learning_rate, setup_seed
from model import PSNet
from dataset import CrowdCountingDataset

################################################################################
# configuration
################################################################################
# set random seed for reproducibility
manualSeed = 1
# manualSeed = random.randint(1, 10000)  # use if you want new results
# print("Random Seed: ", manualSeed)
setup_seed(manualSeed)
# choose to run on cpu or cuda
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# create a directory to store the model
if not os.path.isdir('Checkpoint'):
    os.mkdir('Checkpoint')
    os.mkdir('Checkpoint/models')


################################################################################
# train PSNet model to generate density map
################################################################################
def train():
    """
    train model
    """
    # set hyperparameters
    TRAIN_IMG_DIR = 'datasets/ShanghaiTech_Dataset/part_A_final/train_data/images'  # directory of the training set images
    TEST_IMG_DIR = 'datasets/ShanghaiTech_Dataset/part_A_final/test_data/images'    # directory of the test set images
    LR = 1e-4            # learning rate
    EPOCH = 460          # number of training epochs
    BATCH_SIZE = 12      # batch size
    start_epoch = 0      # start from epoch 0 or from the last checkpoint epoch
    resume = False       # whether to resume training from a checkpoint
    workers = 4          # number of workers for the dataloader
    hyper_param_D = 1    # weight of the attention loss term

    # best MAE, MSE
    # BEST_MAE = float("inf")
    BEST_MAE = 300
    # BEST_MSE = float("inf")
    BEST_MSE = 300

    # load data
    MEAN = [0.485, 0.456, 0.406]  # mean
    STD = [0.229, 0.224, 0.225]   # std
    normalize = transforms.Normalize(
        mean=MEAN,
        std=STD
    )
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize]
    )
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize]
    )
    # define trainloader
    train_dataset = CrowdCountingDataset(TRAIN_IMG_DIR, transforms=train_transform, scale=8, mode='train')
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=workers)
    # define valloader
    val_dataset = CrowdCountingDataset(TEST_IMG_DIR, transforms=val_transform, scale=8, mode='test')
    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=workers)

    # define model
    model = PSNet().float()
    model = model.to(device)

    # define optimizer
    optimizer = optim.Adam(model.parameters(), lr=LR)

    # resume training: load model weights from the checkpoint
    if resume:
        print('==> Resuming from checkpoint..')
        assert os.path.isdir('Checkpoint'), 'Error: no Checkpoint directory found!'
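        # the checkpoint dict (saved below after every epoch) holds:
        #   'net'   -- model.state_dict()
        #   'optim' -- optimizer.state_dict()
        #   'epoch' -- the last finished epoch
        #   'mae', 'mse' -- the best metrics so far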
        state = torch.load('Checkpoint/models/ckpt.pth')
        model.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optim'])
        start_epoch = state['epoch']
        BEST_MAE = state['mae']
        BEST_MSE = state['mse']

    # loss functions
    mseloss = nn.MSELoss(reduction='sum').to(device)
    cosloss = nn.CosineSimilarity(dim=1, eps=1e-6).to(device)

    # train model
    for epoch in range(start_epoch, EPOCH):
        print("####################################################################################")
        # learning rate scheduling strategy
        adjust_learning_rate(optimizer, epoch)
        print('Learning rate is {}'.format(optimizer.param_groups[0]['lr']))
        ############################
        # train
        ############################
        start_time = time.time()
        # train mode
        model.train()
        # loss
        sum_loss = 0.0
        sum_att_loss = 0.0
        sum_den_loss = 0.0
        # number of iterations
        cnt = 0
        for data in train_loader:
            cnt += 1

            # load data
            image, gt_density = data
            image, gt_density = image.float(), gt_density.float()
            image, gt_density = image.to(device), gt_density.to(device)

            # zero the gradients
            optimizer.zero_grad()

            # forward and backward propagation
            pr_density, attention_map_1, attention_map_2, attention_map_3 = model(image)
            # attention loss: for each pyramid-scale module, minimize the cosine
            # similarity between each scale's attention map and the mean of the
            # other three, pushing the four branches to attend differently
            attention_loss = 0.
            for attention_map in (attention_map_1, attention_map_2, attention_map_3):
                attention_map_sum = attention_map[:, 0:1] + attention_map[:, 1:2] + attention_map[:, 2:3] + \
                                    attention_map[:, 3:4]
                attention_loss_temp = 0.
                for i in range(4):
                    attention_loss_temp += torch.sum(cosloss(attention_map[:, i:(i + 1)].contiguous().view(image.size(0), -1),
                                                             ((attention_map_sum - attention_map[:, i:(i + 1)]) / 3).contiguous().view(image.size(0), -1))) / image.size(0)
                attention_loss += (attention_loss_temp / 4)
            attention_loss /= 3
            density_loss = mseloss(pr_density, gt_density) / image.size(0)
            loss = density_loss + hyper_param_D * attention_loss
            loss.backward()

            # gradient update
            optimizer.step()
            sum_loss += loss.item()
            sum_att_loss += attention_loss.item()
            sum_den_loss += density_loss.item()

            # print log
            if cnt % 5 == 0 or cnt == len(train_loader):
                print('[%d/%d]--[%d/%d]\tLoss: %.4f\tAtt_Loss: %.4f\tDen_Loss: %.4f'
                      % (epoch + 1, EPOCH, cnt, len(train_loader), sum_loss / cnt,
                         sum_att_loss / cnt, sum_den_loss / cnt))

        t_loss = sum_loss / cnt
        # save the latest checkpoint
        state = {
            'net': model.state_dict(),
            'optim': optimizer.state_dict(),
            'epoch': epoch,
            'mae': BEST_MAE,
            'mse': BEST_MSE
        }
        torch.save(state, 'Checkpoint/models/ckpt.pth')

        ############################
        # test
        ############################
        # test mode
        model.eval()
        # loss
        mae = 0.0
        mse = 0.0
        # number of iterations
        cnt = 0
        with torch.no_grad():
            for data in val_loader:
                cnt += 1

                # load data
                image, gt, gt_density = data
                image, gt_density = image.float(), gt_density.float()
                image, gt_density = image.to(device), gt_density.to(device)

                # forward propagation
                pr_density, attention_map_1, attention_map_2, attention_map_3 = model(image)

                # record real and predicted counts
                pr_density = pr_density.cpu().detach().numpy()
                gt_density = gt_density.cpu().detach().numpy()
                pr = np.sum(pr_density)
                gt = np.sum(gt_density)
                mae += np.abs(gt - pr)
                mse += np.abs(gt - pr) ** 2

        # calculate metrics
        mae_loss = mae / cnt
        mse_loss = np.sqrt(mse / cnt)
        # update best mse, mae
        BEST_MSE = min(BEST_MSE, mse_loss)
        if BEST_MAE > mae_loss:
            BEST_MAE = mae_loss
            # save the best model
            state = {
                'net': model.state_dict(),
                'optim': optimizer.state_dict(),
                'epoch': epoch,
                'mae': BEST_MAE,
                'mse': BEST_MSE
            }
            torch.save(state, 'Checkpoint/models/ckpt_best.pth')

        # print log
        print('[%d/%d]\ttime: %.2f[minute]\tLoss_T: %.4f\tMAE: %.4f\tMSE: %.4f\tBEST_MAE: %.4f\tBEST_MSE: %.4f'
              % (epoch + 1, EPOCH, (time.time() - start_time) / 60, t_loss, mae_loss, mse_loss, BEST_MAE, BEST_MSE))


################################################################################
# main function
################################################################################
if __name__ == '__main__':
    # train model
    train()
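
# --- Illustrative sketch (not part of the original repo): the attention loss
# above in miniature. For one attention map of shape (B, 4, H, W), each channel
# is compared via cosine similarity against the mean of the other three;
# training drives this similarity down so the four scale branches decorrelate. ---
import torch
import torch.nn as nn

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
att = torch.rand(2, 4, 50, 50)            # toy attention map, batch size 2
total = att.sum(dim=1, keepdim=True)      # sum over the 4 scale channels
loss = 0.
for i in range(4):
    rest = (total - att[:, i:i + 1]) / 3  # mean of the other three channels
    loss += cos(att[:, i:i + 1].reshape(2, -1),
                rest.reshape(2, -1)).sum() / 2  # average over the batch
loss /= 4
print(loss.item())  # high for random nonnegative maps; training pushes it down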

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
utils
"""
from __future__ import print_function, division
import random
import numpy as np
import torch


################################################################################
# change the learning rate according to the epoch.
################################################################################
def adjust_learning_rate(optimizer, epoch):
    # halve the learning rate every 100 epochs
    if (epoch + 1) % 100 == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.5


################################################################################
# set the random seed.
################################################################################
def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
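
# --- Illustrative sketch (not part of the original repo): the schedule that
# adjust_learning_rate implements. The learning rate is halved every 100
# epochs, so starting from LR = 1e-4 it steps through 1e-4, 5e-5, 2.5e-5, ... ---
lr = 1e-4
for epoch in range(460):
    if (epoch + 1) % 100 == 0:
        lr *= 0.5
    if epoch % 100 == 0:
        print(epoch, lr)
# 0 0.0001, 100 5e-05, 200 2.5e-05, 300 1.25e-05, 400 6.25e-06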

--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import os
import numpy as np
from dataset import CrowdCountingDataset
from model import PSNet


################################################################################
# test model
################################################################################
def val():
    """
    validate model
    """
    # set hyperparameters
    TEST_IMG_DIR = 'datasets/ShanghaiTech_Dataset/part_A_final/test_data/images'  # directory of the test set images
    workers = 2  # number of workers for the dataloader
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # choose to run on cpu or cuda

    # best MAE, MSE
    # BEST_MAE = float("inf")
    BEST_MAE = 300
    # BEST_MSE = float("inf")
    BEST_MSE = 300

    # load data
    MEAN = [0.485, 0.456, 0.406]  # mean
    STD = [0.229, 0.224, 0.225]   # std
    normalize = transforms.Normalize(
        mean=MEAN,
        std=STD
    )
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize]
    )
    # define valloader
    val_dataset = CrowdCountingDataset(TEST_IMG_DIR, transforms=val_transform, scale=8, mode='test')
    val_loader = DataLoader(val_dataset, batch_size=1, num_workers=workers)

    # define model
    model = PSNet().float()
    model = model.to(device)

    # load model weights
    print('==> Resuming from checkpoint..')
    assert os.path.isdir('Checkpoint'), 'Error: no Checkpoint directory found!'
    state = torch.load('Checkpoint/models/ckpt_best.pth')
    model.load_state_dict(state['net'])
    BEST_MAE = state['mae']
    BEST_MSE = state['mse']
    epoch = state['epoch']

    # loss function
    cosloss = nn.CosineSimilarity(dim=1, eps=1e-6).to(device)

    ############################
    # test
    ############################
    # test mode
    model.eval()
    # loss
    mae = 0.0
    mse = 0.0
    sum_att_loss = 0.0
    # number of iterations
    cnt = 0
    with torch.no_grad():
        for data in val_loader:
            cnt += 1

            # load data
            image, gt, gt_density = data
            image, gt_density = image.float(), gt_density.float()
            image, gt_density = image.to(device), gt_density.to(device)
            gt = gt.item()  # unwrap the 1-element tensor produced by the dataloader

            # forward propagation
            pr_density, attention_map_1, attention_map_2, attention_map_3 = model(image)

            # attention loss (same formulation as in train.py)
            attention_loss = 0.
            for attention_map in (attention_map_1, attention_map_2, attention_map_3):
                attention_map_sum = attention_map[:, 0:1] + attention_map[:, 1:2] + attention_map[:, 2:3] + \
                                    attention_map[:, 3:4]
                attention_loss_temp = 0.
                for i in range(4):
                    attention_loss_temp += torch.sum(
                        cosloss(attention_map[:, i:(i + 1)].contiguous().view(image.size(0), -1),
                                ((attention_map_sum - attention_map[:, i:(i + 1)]) / 3).contiguous().view(image.size(0), -1))) / image.size(0)
                attention_loss += (attention_loss_temp / 4)
            attention_loss /= 3
            sum_att_loss += attention_loss.item()

            # record real and predicted counts
            pr_density = pr_density.cpu().detach().numpy()
            pr = np.sum(pr_density)
            mae += np.abs(gt - pr)
            mse += np.abs(gt - pr) ** 2
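    # note: MAE is the mean absolute count error; the "MSE" reported below is
    # the root of the mean squared error (RMSE), as is conventional in the
    # crowd-counting literature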
    # calculate metrics
    mae_loss = mae / cnt
    mse_loss = np.sqrt(mse / cnt)
    att_loss = sum_att_loss / cnt

    # print log
    print('EPOCH: %d\tMAE: %.4f\tMSE: %.4f\tBEST_MAE: %.4f\tBEST_MSE: %.4f\tATT_LOSS: %.4f'
          % (epoch, mae_loss, mse_loss, BEST_MAE, BEST_MSE, att_loss))


################################################################################
# main function
################################################################################
if __name__ == '__main__':
    # test model
    val()

--------------------------------------------------------------------------------