├── README.md
├── datasets
│   ├── __init__.py
│   └── crowd_dmap.py
├── imgs
│   └── arch.jpg
├── losses
│   ├── __init__.py
│   └── losses.py
├── models
│   ├── chsnet.py
│   ├── convolution_module.py
│   └── transformer_module.py
├── preprocess
│   ├── preprocess_dataset_sha_dmap.py
│   ├── preprocess_dataset_shb_dmap.py
│   └── preprocess_dataset_ucf_dmap.py
├── requirements.txt
├── run.sh
├── train.py
└── utils
    ├── __init__.py
    ├── chsnet_trainer.py
    ├── helper.py
    ├── logger.py
    └── trainer.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-head Supervision for Crowd Counting with Noisy Annotations
2 | 
3 | This is the official implementation of the ICASSP 2023 paper *Cross-head Supervision for Crowd Counting with Noisy Annotations*.
4 | ![](./imgs/arch.jpg)
5 | 
6 | ## To run
7 | 1. `pip install -r requirements.txt`
8 | 2. download the dataset to `./DATASET`
9 | 3. run the matching `preprocess_dataset_xxx_dmap.py` script in `./preprocess`
10 | 4. `sh run.sh`
11 | 
12 | ## Use wandb to record experiment results
13 | Refer to the [wandb quickstart](https://docs.wandb.ai/quickstart) for configuration.
--------------------------------------------------------------------------------
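For concreteness, the steps above expand to the following for ShanghaiTech Part A (a minimal sketch — the paths mirror the preprocessing defaults, and the training flags are taken from `run.sh`, which itself expects the processed data under `../DATASET`):

```bash
pip install -r requirements.txt
# assumes ShanghaiTech has been unpacked under ./DATASET/ShanghaiTech
python preprocess/preprocess_dataset_sha_dmap.py \
    --origin-dir ./DATASET/ShanghaiTech/part_A \
    --data-dir ./DATASET/SHA-train-test-dmapfix15
python train.py --tag CHSNet-sha --no-wandb --device 0 \
    --max-noisy-ratio 0.10 --max-weight-ratio 1.0 --scheduler cosine \
    --dcsize 2 --batch-size 8 --lr 4e-5 \
    --data-dir ./DATASET/SHA-train-test-dmapfix15 --val-start 200 --val-epoch 5
```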
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/crowd_dmap.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from glob import glob
4 | 
5 | import torch.utils.data as data
6 | import torchvision.transforms.functional as F
7 | from torchvision import transforms
8 | import numpy as np
9 | from PIL import Image
10 | 
11 | 
12 | def random_crop(im_h, im_w, crop_h, crop_w):
13 |     res_h = im_h - crop_h
14 |     res_w = im_w - crop_w
15 |     i = random.randint(0, res_h)
16 |     j = random.randint(0, res_w)
17 |     return i, j, crop_h, crop_w
18 | 
19 | 
20 | class Crowd(data.Dataset):
21 |     def __init__(self, root_path, crop_size,
22 |                  downsample_ratio, is_gray=False,
23 |                  method='train'):
24 | 
25 |         self.root_path = root_path
26 |         self.im_list = sorted(glob(os.path.join(self.root_path, '*.jpg')))
27 | 
28 |         if method not in ['train', 'val']:
29 |             raise Exception("not implemented")
30 |         self.method = method
31 | 
32 |         self.c_size = crop_size
33 |         self.d_ratio = downsample_ratio
34 |         assert self.c_size % self.d_ratio == 0
35 | 
36 |         if is_gray:
37 |             self.trans_img = transforms.Compose([
38 |                 transforms.ToTensor(),
39 |                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
40 |             ])
41 |         else:
42 |             self.trans_img = transforms.Compose([
43 |                 transforms.ToTensor(),
44 |                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
45 |             ])
46 |         self.trans_dmap = transforms.ToTensor()
47 | 
48 |     def __len__(self):
49 |         return len(self.im_list)
50 | 
51 |     def __getitem__(self, item):
52 |         img_path = self.im_list[item]
53 |         gd_path = img_path.replace('.jpg', '_dmap.npy')
54 |         name = os.path.basename(img_path).split('.')[0]
55 | 
56 |         try:
57 |             img = Image.open(img_path).convert('RGB')
58 |             dmap = np.load(gd_path)
59 |             dmap = dmap.astype(np.float32, copy=False)  # np.float64 -> np.float32 to save memory
60 |         except Exception:
61 |             raise Exception('Image open error {}'.format(name))
62 | 
63 |         if self.method == 'train':
64 |             return self.train_transform(img, dmap)
65 |         elif self.method == 'val':
66 |             return self.trans_img(img), np.sum(dmap), name
67 | 
68 |     def train_transform(self, img, dmap):
69 |         dmap = Image.fromarray(dmap)
70 |         wd, ht = img.size
71 |         # random gray scale augmentation
72 |         if random.random() > 0.88:
73 |             img = img.convert('L').convert('RGB')
74 | 
75 |         # rescale augmentation
76 |         re_size = random.random() * 0.5 + 0.75
77 |         wdd = int(wd*re_size)
78 |         htt = int(ht*re_size)
79 |         if min(wdd, htt) >= self.c_size:
80 |             raw_size = (wd, ht)
81 |             wd = wdd
82 |             ht = htt
83 |             img = img.resize((wd, ht))
84 |             dmap = dmap.resize((wd, ht))
85 |             ratio = (raw_size[0]*raw_size[1])/(wd*ht)
86 |             dmap = Image.fromarray(np.array(dmap) * ratio)
87 | 
88 |         # random crop augmentation
89 |         i, j, h, w = random_crop(ht, wd, self.c_size, self.c_size)
90 |         img = F.crop(img, i, j, h, w)
91 |         dmap = F.crop(dmap, i, j, h, w)
92 | 
93 |         # random horizontal flip
94 |         if random.random() > 0.5:
95 |             img = F.hflip(img)
96 |             dmap = F.hflip(dmap)
97 | 
98 |         return self.trans_img(img), self.trans_dmap(dmap)
99 | 
--------------------------------------------------------------------------------
/imgs/arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RaccoonDML/CHSNet/ecc84d5a23f23d94e4d4d70006c5fcd3c72a4570/imgs/arch.jpg
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | 
5 | 
6 | class CHSLoss(nn.Module):
7 |     def __init__(self, size=8, max_noisy_ratio=0.1, max_weight_ratio=1):
8 |         super().__init__()
9 |         self.avgpooling = nn.AvgPool2d(kernel_size=size)
10 |         self.tot = size * size
11 |         self.max_noisy_ratio = max_noisy_ratio
12 |         self.max_weight_ratio = max_weight_ratio
13 | 
14 |     def forward(self, dmap_conv, dmap_tran, gt_density, process):
15 |         weight = self.max_weight_ratio * process
16 |         noisy_ratio = self.max_noisy_ratio * process
17 | 
18 |         gt_density = self.avgpooling(gt_density) * self.tot
19 |         assert dmap_conv.size() == dmap_tran.size(), f'{dmap_conv.size()}, {dmap_tran.size()}'
20 |         b, c, h, w = dmap_conv.size()
21 |         dmap_conv = rearrange(dmap_conv, 'b c h w -> b (c h w)')
22 |         dmap_tran = rearrange(dmap_tran, 'b c h w -> b (c h w)')
23 |         dmap_gt = rearrange(gt_density, 'b c h w -> b (c h w)')
24 | 
25 |         error_conv_gt = torch.abs(dmap_gt - dmap_conv)
26 |         error_tran_gt = torch.abs(dmap_gt - dmap_tran)
27 | 
28 |         # weight: dmap_gt ---> dmap_conv/tran
29 |         combine_conv_gt = weight * dmap_conv + (1 - weight) * dmap_gt
30 |         combine_tran_gt = weight * dmap_tran + (1 - weight) * dmap_gt
31 | 
32 |         num = int(h * w * noisy_ratio)
33 |         if num < 1:
34 |             loss = torch.sum((dmap_conv - dmap_gt) ** 2) + torch.sum((dmap_tran - dmap_gt) ** 2)
35 |             return loss
36 | 
37 |         # the conv branch is supervised by tran+gt on its noisiest cells
38 |         v, _ = torch.topk(error_conv_gt, num, dim=-1, largest=True)
39 |         v_min = v.min(dim=-1).values
40 |         v_min = v_min.unsqueeze(-1)
41 |         supervision_from_tran = torch.where(torch.ge(error_conv_gt, v_min), combine_tran_gt, dmap_gt)
42 |         mse_conv = (dmap_conv - supervision_from_tran) ** 2
43 | 
44 |         # the tran branch is supervised by conv+gt on its noisiest cells
45 |         v, _ = torch.topk(error_tran_gt, num, dim=-1, largest=True)
46 |         v_min = v.min(dim=-1).values
47 |         v_min = v_min.unsqueeze(-1)
48 |         supervision_from_conv = torch.where(torch.ge(error_tran_gt, v_min), combine_conv_gt, dmap_gt)
49 |         mse_tran = (dmap_tran - supervision_from_conv) ** 2
50 | 
51 |         loss = torch.sum(mse_conv) + torch.sum(mse_tran)
52 |         return loss
53 | 
--------------------------------------------------------------------------------
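To make the cross-head swap concrete, here is a tiny hand-worked sketch of the selection-and-blend step for the conv branch, using hypothetical numbers (one batch row of four pooled cells, `num = 1` cell treated as noisy):

```python
import torch

dmap_gt   = torch.tensor([[1.0, 0.0, 2.0, 0.0]])  # pooled (possibly noisy) ground truth
dmap_conv = torch.tensor([[1.1, 0.9, 2.0, 0.0]])  # conv head: large error on cell 1
dmap_tran = torch.tensor([[1.0, 0.1, 1.9, 0.0]])  # transformer head
weight, num = 0.5, 1                              # blend weight; cells treated as noisy

error = torch.abs(dmap_gt - dmap_conv)            # [0.1, 0.9, 0.0, 0.0]
v, _ = torch.topk(error, num, dim=-1, largest=True)
v_min = v.min(dim=-1, keepdim=True).values        # threshold: the num-th largest error
combine = weight * dmap_tran + (1 - weight) * dmap_gt
target = torch.where(error >= v_min, combine, dmap_gt)
print(target)  # tensor([[1.0000, 0.0500, 2.0000, 0.0000]]) -- only cell 1 is blended
```

The `num` noisiest cells of the conv head are re-targeted toward the transformer head's prediction instead of the raw annotation; the mirror-image step does the same for the transformer head.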
/models/chsnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | import collections
5 | from models.transformer_module import Transformer
6 | from models.convolution_module import ConvBlock, OutputNet
7 | 
8 | 
9 | class CHSNet(nn.Module):
10 |     def __init__(self, dcsize=8, batch_norm=True, load_weights=False):
11 |         super().__init__()
12 |         self.scale_factor = 16//dcsize
13 |         self.encoder = nn.Sequential(
14 |             ConvBlock(cin=3, cout=64),
15 |             ConvBlock(cin=64, cout=64),
16 |             nn.AvgPool2d(kernel_size=2, stride=2),
17 |             ConvBlock(cin=64, cout=128),
18 |             ConvBlock(cin=128, cout=128),
19 |             nn.AvgPool2d(kernel_size=2, stride=2),
20 |             ConvBlock(cin=128, cout=256),
21 |             ConvBlock(cin=256, cout=256),
22 |             ConvBlock(cin=256, cout=256),
23 |             nn.AvgPool2d(kernel_size=2, stride=2),
24 |             ConvBlock(cin=256, cout=512),
25 |             ConvBlock(cin=512, cout=512),
26 |             ConvBlock(cin=512, cout=512),
27 |             nn.AvgPool2d(kernel_size=2, stride=2),
28 |             ConvBlock(cin=512, cout=512),
29 |             ConvBlock(cin=512, cout=512),
30 |             ConvBlock(cin=512, cout=512),
31 |         )
32 | 
33 |         self.tran_decoder = Transformer(layers=4)
34 |         self.tran_decoder_p2 = OutputNet(dim=512)
35 | 
36 |         self.conv_decoder = nn.Sequential(
37 |             ConvBlock(512, 512, 3, d_rate=2),
38 |             ConvBlock(512, 512, 3, d_rate=2),
39 |             ConvBlock(512, 512, 3, d_rate=2),
40 |             ConvBlock(512, 512, 3, d_rate=2),
41 |         )
42 |         self.conv_decoder_p2 = OutputNet(dim=512)
43 | 
44 |         self._initialize_weights()
45 |         if not load_weights:
46 |             if batch_norm:
47 |                 mod = torchvision.models.vgg16_bn(pretrained=True)
48 |             else:
49 |                 mod = torchvision.models.vgg16(pretrained=True)
50 |             fsd = collections.OrderedDict()
51 |             for i in range(len(self.encoder.state_dict().items())):  # copy VGG-16 weights into the encoder layer by layer
52 |                 temp_key = list(self.encoder.state_dict().items())[i][0]
53 |                 fsd[temp_key] = list(mod.state_dict().items())[i][1]
54 |             self.encoder.load_state_dict(fsd)
55 | 
56 |     def _initialize_weights(self):
57 |         for m in self.modules():
58 |             if isinstance(m, nn.Conv2d):
59 |                 nn.init.kaiming_normal_(m.weight)
60 |                 if m.bias is not None:
61 |                     nn.init.constant_(m.bias, 0)
62 |             elif isinstance(m, nn.BatchNorm2d):
63 |                 nn.init.constant_(m.weight, 1)
64 |                 nn.init.constant_(m.bias, 0)
65 | 
66 |     def forward(self, x):
67 |         raw_x = self.encoder(x)
68 |         bs, c, h, w = raw_x.shape
69 | 
70 |         # path-conv
71 |         x = self.conv_decoder(raw_x)
72 |         x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=True)
73 |         y1 = self.conv_decoder_p2(x)
74 | 
75 |         # path-transformer
76 |         x = raw_x.flatten(2).permute(2, 0, 1)  # bs c h w -> bs c hw -> hw bs c
77 |         x = self.tran_decoder(x, (h, w))
78 |         x = x.permute(1, 2, 0).view(bs, c, h, w)
79 |         x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode='bicubic', align_corners=True)
80 |         y2 = self.tran_decoder_p2(x)
81 | 
82 |         return y1, y2
83 | 
--------------------------------------------------------------------------------
/models/convolution_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | 
3 | 
4 | class ConvBlock(nn.Module):
5 |     """
6 |     Normal Conv Block with BN & ReLU
7 |     """
8 | 
9 |     def __init__(self, cin, cout, k_size=3, d_rate=1, batch_norm=True, res_link=False):
10 |         super().__init__()
11 |         self.res_link = res_link
12 |         if batch_norm:
13 |             self.body = nn.Sequential(
14 |                 nn.Conv2d(cin, cout, k_size, padding=d_rate, dilation=d_rate),
15 |                 nn.BatchNorm2d(cout),
16 |                 nn.ReLU(inplace=True),
17 |             )
18 |         else:
19 |             self.body = nn.Sequential(
20 |                 nn.Conv2d(cin, cout, k_size, padding=d_rate, dilation=d_rate),
21 |                 nn.ReLU(inplace=True),
22 |             )
23 | 
24 |     def forward(self, x):
25 |         if self.res_link:
26 |             return x + self.body(x)
27 |         else:
28 |             return self.body(x)
29 | 
30 | 
31 | class OutputNet(nn.Module):
32 |     def __init__(self, dim=512):
33 |         super().__init__()
34 |         self.conv1 = ConvBlock(dim, 256, 3)
35 |         self.conv2 = ConvBlock(256, 128, 3)
36 |         self.conv3 = ConvBlock(128, 64, 3)
37 |         self.conv4 = nn.Sequential(
38 |             nn.Conv2d(64, 1, kernel_size=1),
39 |             nn.ReLU(True),
40 |         )
41 | 
42 |     def forward(self, x):
43 |         x = self.conv1(x)
44 |         x = self.conv2(x)
45 |         x = self.conv3(x)
46 |         x = self.conv4(x)
47 |         return x
48 | 
--------------------------------------------------------------------------------
/models/transformer_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn, Tensor
4 | from typing import Optional
5 | import copy
6 | 
7 | 
8 | class GlobalMultiheadAttention(nn.Module):
9 |     def __init__(self, embed_dim, num_heads, dropout=0.):
10 |         super().__init__()
11 |         self.embed_dim = embed_dim  # 512
12 |         self.num_heads = num_heads  # 2
13 |         self.dropout = dropout
14 |         self.head_dim = embed_dim // num_heads  # 256
15 |         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
16 |         self.scaling = self.head_dim ** -0.5
17 | 
18 |         self.in_proj_q = nn.Linear(in_features=embed_dim, out_features=embed_dim)
19 |         self.in_proj_k = nn.Linear(in_features=embed_dim, out_features=embed_dim)
20 |         self.in_proj_v = nn.Linear(in_features=embed_dim, out_features=embed_dim)
21 |         self.out_proj = nn.Linear(embed_dim, embed_dim)
22 |         self.reset_parameters()
23 | 
24 |     def reset_parameters(self):
25 |         nn.init.xavier_uniform_(self.in_proj_q.weight)
26 |         nn.init.xavier_uniform_(self.in_proj_k.weight)
27 |         nn.init.xavier_uniform_(self.in_proj_v.weight)
28 |         nn.init.xavier_uniform_(self.out_proj.weight)
29 |         if self.in_proj_q.bias is not None:
30 |             nn.init.constant_(self.in_proj_q.bias, 0.)
31 |             nn.init.constant_(self.in_proj_k.bias, 0.)
32 |             nn.init.constant_(self.in_proj_v.bias, 0.)
33 |             nn.init.constant_(self.out_proj.bias, 0.)
34 | 
35 |     def forward(self, query, key, shape, value):
36 |         # query/key/value: [hw b c]; shape is unused here but kept for interface consistency
37 |         tgt_len, bsz, embed_dim = query.size()
38 |         assert embed_dim == self.embed_dim
39 |         assert list(query.size()) == [tgt_len, bsz, embed_dim]
40 |         assert key.size() == value.size()
41 | 
42 |         q = self.in_proj_q(query)  # [hw b c] * [c c] -> [hw b c]
43 |         k = self.in_proj_k(key)
44 |         v = self.in_proj_v(value)
45 |         q = q * self.scaling
46 | 
47 |         q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)  # [hw b c] -> [hw b*nh c//nh] -> [b*nh hw c//nh]
48 |         k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
49 |         v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
50 | 
51 |         attn_weights = torch.bmm(q, k.transpose(1, 2))  # [b*nh hw c//nh] @ [b*nh c//nh hw] -> [b*nh hw hw]
52 |         attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
53 |         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
54 | 
55 |         attn = torch.bmm(attn_weights, v)  # [b*nh hw hw] @ [b*nh hw c//nh] -> [b*nh hw c//nh]
56 |         attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)  # [hw b*nh c//nh] -> [hw b c]
57 |         attn = self.out_proj(attn)
58 |         return attn
59 | 
60 | 
61 | class TransformerEncoderLayer(nn.Module):
62 | 
63 |     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", normalize_before=False):
64 |         super().__init__()
65 |         self.self_attn = GlobalMultiheadAttention(d_model, nhead, dropout=dropout)
66 |         # Implementation of Feedforward model
67 |         self.linear1 = nn.Linear(d_model, dim_feedforward)
68 |         self.dropout = nn.Dropout(dropout)
69 |         self.linear2 = nn.Linear(dim_feedforward, d_model)
70 | 
71 |         self.norm1 = nn.LayerNorm(d_model)
72 |         self.norm2 = nn.LayerNorm(d_model)
73 |         self.dropout1 = nn.Dropout(dropout)
74 |         self.dropout2 = nn.Dropout(dropout)
75 |         self.activation = _get_activation_fn(activation)
76 | 
77 |     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
78 |         return tensor if pos is None else tensor + pos
79 | 
80 |     def forward(self, src, shape, pos: Optional[Tensor] = None):
81 |         q = k = self.with_pos_embed(src, pos)  # src: [hw b c]
82 |         src2 = self.self_attn(q, k, shape, src)
83 |         src = src + self.dropout1(src2)
84 |         src = self.norm1(src)
85 |         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
86 |         src = src + self.dropout2(src2)
87 |         src = self.norm2(src)
88 |         return src
89 | 
90 | 
91 | class Transformer(nn.Module):
92 |     def __init__(self, layers=4, dim=512, norm=None):
93 |         super().__init__()
94 |         d_model = dim
95 |         nhead = 2
96 |         dim_feedforward = 2048
97 |         dropout = 0.1
98 |         activation = "relu"
99 |         normalize_before = False
100 | 
101 |         self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, normalize_before))
102 |                                      for i in range(layers)])
103 |         self.norm = norm
104 | 
105 |     def forward(self, src, shape, pos: Optional[Tensor] = None):
106 |         output = src
107 |         for layer in self.layers:
108 |             output = layer(output, shape, pos)
109 |         if self.norm is not None:
110 |             output = self.norm(output)
111 |         return output
112 | 
113 | 
114 | def _get_activation_fn(activation):
115 |     """Return an activation function given a string"""
116 |     if activation == "relu":
117 |         return F.relu
118 |     if activation == "gelu":
119 |         return F.gelu
120 |     if activation == "glu":
121 |         return F.glu
122 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
123 | 
124 | 
--------------------------------------------------------------------------------
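A minimal sketch of how `models/chsnet.py` drives this module — encoder features `[b, c, h, w]` are flattened to a token sequence `[h*w, b, c]`, passed through the encoder-only Transformer, and folded back into a feature map (the sizes below are illustrative):

```python
import torch
from models.transformer_module import Transformer

feat = torch.randn(2, 512, 8, 8)             # e.g. a 128x128 crop after the /16 encoder
b, c, h, w = feat.shape
seq = feat.flatten(2).permute(2, 0, 1)       # [b c hw] -> [hw b c]
out = Transformer(layers=4)(seq, (h, w))     # self-attention over all h*w positions
out = out.permute(1, 2, 0).view(b, c, h, w)  # back to [b c h w]
print(out.shape)                             # torch.Size([2, 512, 8, 8])
```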
/preprocess/preprocess_dataset_sha_dmap.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from PIL import Image
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | import cv2
7 | import argparse
8 | 
9 | import scipy
10 | import scipy.spatial
11 | import scipy.ndimage
12 | 
13 | def cal_new_size(im_h, im_w, min_size, max_size):
14 |     if im_h < im_w:
15 |         if im_h < min_size:
16 |             ratio = 1.0 * min_size / im_h
17 |             im_h = min_size
18 |             im_w = round(im_w*ratio)
19 |         elif im_h > max_size:
20 |             ratio = 1.0 * max_size / im_h
21 |             im_h = max_size
22 |             im_w = round(im_w*ratio)
23 |         else:
24 |             ratio = 1.0
25 |     else:
26 |         if im_w < min_size:
27 |             ratio = 1.0 * min_size / im_w
28 |             im_w = min_size
29 |             im_h = round(im_h*ratio)
30 |         elif im_w > max_size:
31 |             ratio = 1.0 * max_size / im_w
32 |             im_w = max_size
33 |             im_h = round(im_h*ratio)
34 |         else:
35 |             ratio = 1.0
36 |     return im_h, im_w, ratio
37 | 
38 | 
39 | def find_dis(point):
40 |     square = np.sum(point*point, axis=1)
41 |     dis = np.sqrt(np.maximum(square[:, None] - 2*np.matmul(point, point.T) + square[None, :], 0.0))
42 |     # np.partition places the 4 smallest distances first (indices 0-3); index 0 is the point itself
43 |     dis = np.mean(np.partition(dis, 3, axis=1)[:, 1:4], axis=1, keepdims=True)
44 |     return dis
45 | 
46 | 
47 | # this is borrowed from https://github.com/davideverona/deep-crowd-counting_crowdnet
48 | def generate_adaptive_dmap_from_point(image, points):
49 |     im_w, im_h = image.size
50 |     dmap = np.zeros((im_w, im_h))
51 |     if len(points)==0:
52 |         return dmap.T
53 |     else:
54 |         for point in points:
55 |             pt2d = np.zeros((im_w, im_h), dtype=np.float32)
56 |             pt2d[int(point[0]), int(point[1])] += 1
57 |             sigma = min(point[2], 15)
58 |             tmp_dmap = scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
59 |             # rectify border error
60 |             ct = np.sum(tmp_dmap)
61 |             if abs(ct - 1) > 0.001:
62 |                 tmp_dmap *= (1 / ct)
63 |             dmap += tmp_dmap
64 |         return dmap.T
65 | 
66 | 
67 | def generate_dmap_from_point(image, points):
68 |     im_w, im_h = image.size
69 |     dmap = np.zeros((im_w, im_h))
70 |     for point in points:
71 |         dmap[min(int(point[0]), im_w-1), min(int(point[1]), im_h-1)] += 1
72 |     density_map = scipy.ndimage.filters.gaussian_filter(dmap, 15, mode='reflect')
73 |     assert(abs(len(points)-np.sum(dmap)) < 1e-2)
74 |     return density_map.T
75 | 
76 | 
77 | def generate_data(im_path):
78 |     im = Image.open(im_path)
79 |     im_w, im_h = im.size
80 |     mat_path = im_path.replace('images','ground-truth').replace('IMG','GT_IMG').replace('.jpg', '.mat')
81 |     points = loadmat(mat_path)
82 |     points = points["image_info"][0,0][0,0][0].astype(np.float32)
83 |     idx_mask = (points[:, 0] >= 0) * (points[:, 0] <= im_w) * (points[:, 1] >= 0) * (points[:, 1] <= im_h)
84 |     # filter out annotations that fall outside the image
85 |     points = points[idx_mask]
86 |     im_h, im_w, rr = cal_new_size(im_h, im_w, min_size, max_size)
87 |     im = np.array(im)
88 |     if rr != 1.0:
89 |         im = cv2.resize(im, (im_w, im_h), interpolation=cv2.INTER_CUBIC)
90 |         points = points * rr
91 |     return Image.fromarray(im), points
92 | 
93 | 
94 | def parse_args():
95 |     parser = argparse.ArgumentParser(description='Preprocess ShanghaiTech Part A')
96 |     parser.add_argument('--origin-dir', default='./DATASET/ShanghaiTech/part_A',
97 |                         help='original data directory')
98 |     parser.add_argument('--data-dir', default='./DATASET/SHA-train-test-dmapfix15',
99 |                         help='processed data directory')
100 |     args = parser.parse_args()
101 |     return args
102 | 
103 | if __name__ == '__main__':
104 |     args = parse_args()
105 |     save_dir = args.data_dir
106 |     min_size = 512
107 |     max_size = 2048
108 | 
109 |     for phase in ['train', 'test']:
110 |         sub_dir = os.path.join(args.origin_dir, phase+'_data')
111 |         sub_save_dir = os.path.join(save_dir, phase)
112 |         if not os.path.exists(sub_save_dir):
113 |             os.makedirs(sub_save_dir)
114 | 
115 |         im_list = glob(os.path.join(sub_dir, 'images', '*jpg'))
116 |         for im_path in im_list:
117 |             name = os.path.basename(im_path)
118 |             print(phase + '-' + name)
119 |             # resize the image; extract and filter the point annotations
120 |             im, points = generate_data(im_path)
121 | 
122 |             # save the image
123 |             im_save_path = os.path.join(sub_save_dir, name)
124 |             im.save(im_save_path)
125 | 
126 |             # generate and save the density map
127 |             dmap = generate_dmap_from_point(im, points)
128 |             # or dmap = generate_adaptive_dmap_from_point(im, points)
129 |             dmap_save_path = im_save_path.replace('.jpg', '_dmap.npy')
130 |             np.save(dmap_save_path, dmap)
131 | 
132 |             # save the point annotations
133 |             # if phase == 'train': # for MAN BL-loss point annotation
134 |             dis = find_dis(points)
135 |             points = np.concatenate((points, dis), axis=1)  # N,2 -> N,3
136 |             gd_save_path = im_save_path.replace('jpg', 'npy')
137 |             np.save(gd_save_path, points)
--------------------------------------------------------------------------------
/preprocess/preprocess_dataset_shb_dmap.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from PIL import Image
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | import cv2
7 | import argparse
8 | 
9 | import scipy
10 | import scipy.spatial
11 | import scipy.ndimage
12 | 
13 | def cal_new_size(im_h, im_w, min_size, max_size):
14 |     if im_h < im_w:
15 |         if im_h < min_size:
16 |             ratio = 1.0 * min_size / im_h
17 |             im_h = min_size
18 |             im_w = round(im_w*ratio)
19 |         elif im_h > max_size:
20 |             ratio = 1.0 * max_size / im_h
21 |             im_h = max_size
22 |             im_w = round(im_w*ratio)
23 |         else:
24 |             ratio = 1.0
25 |     else:
26 |         if im_w < min_size:
27 |             ratio = 1.0 * min_size / im_w
28 |             im_w = min_size
29 |             im_h = round(im_h*ratio)
30 |         elif im_w > max_size:
31 |             ratio = 1.0 * max_size / im_w
32 |             im_w = max_size
33 |             im_h = round(im_h*ratio)
34 |         else:
35 |             ratio = 1.0
36 |     return im_h, im_w, ratio
37 | 
38 | 
39 | def find_dis(point):
40 |     square = np.sum(point*point, axis=1)
41 |     dis = np.sqrt(np.maximum(square[:, None] - 2*np.matmul(point, point.T) + square[None, :], 0.0))
42 |     # np.partition places the 4 smallest distances first (indices 0-3); index 0 is the point itself
43 |     dis = np.mean(np.partition(dis, 3, axis=1)[:, 1:4], axis=1, keepdims=True)
44 |     return dis
45 | 
46 | 
47 | # this is borrowed from https://github.com/davideverona/deep-crowd-counting_crowdnet
48 | def generate_adaptive_dmap_from_point(image, points):
49 |     im_w, im_h = image.size
50 |     dmap = np.zeros((im_w, im_h))
51 |     if len(points)==0:
52 |         return dmap.T
53 |     else:
54 |         for point in points:
55 |             pt2d = np.zeros((im_w, im_h), dtype=np.float32)
56 |             pt2d[int(point[0]), int(point[1])] += 1
57 |             sigma = min(point[2], 15)
58 |             tmp_dmap = scipy.ndimage.filters.gaussian_filter(pt2d, sigma, mode='constant')
59 |             # rectify border error
60 |             ct = np.sum(tmp_dmap)
61 |             if abs(ct - 1) > 0.001:
62 |                 tmp_dmap *= (1 / ct)
63 |             dmap += tmp_dmap
64 |         return dmap.T
65 | 
66 | 
67 | def generate_dmap_from_point(image, points):
68 |     im_w, im_h = image.size
69 |     dmap = np.zeros((im_w, im_h))
70 |     for point in points:
71 |         dmap[min(int(point[0]), im_w-1), min(int(point[1]), im_h-1)] += 1
72 |     density_map = scipy.ndimage.filters.gaussian_filter(dmap, 15, mode='reflect')
73 |     assert(abs(len(points)-np.sum(dmap)) < 1e-2)
74 |     return density_map.T
75 | 
76 | 
77 | def generate_data(im_path):
78 |     im = Image.open(im_path)
79 |     im_w, im_h = im.size
80 |     mat_path = im_path.replace('images','ground-truth').replace('IMG','GT_IMG').replace('.jpg', '.mat')
81 |     points = loadmat(mat_path)
82 |     points = points["image_info"][0,0][0,0][0].astype(np.float32)
83 |     idx_mask = (points[:, 0] >= 0) * (points[:, 0] <= im_w) * (points[:, 1] >= 0) * (points[:, 1] <= im_h)
84 |     # filter out annotations that fall outside the image
85 |     points = points[idx_mask]
86 |     im_h, im_w, rr = cal_new_size(im_h, im_w, min_size, max_size)
87 |     im = np.array(im)
88 |     if rr != 1.0:
89 |         im = cv2.resize(im, (im_w, im_h), interpolation=cv2.INTER_CUBIC)
90 |         points = points * rr
91 |     return Image.fromarray(im), points
92 | 
93 | 
94 | def parse_args():
95 |     parser = argparse.ArgumentParser(description='Preprocess ShanghaiTech Part B')
96 |     parser.add_argument('--origin-dir', default='./DATASET/ShanghaiTech/part_B',
97 |                         help='original data directory')
98 |     parser.add_argument('--data-dir', default='./DATASET/SHB-train-test-dmapfix15',
99 |                         help='processed data directory')
100 |     args = parser.parse_args()
101 |     return args
102 | 
103 | if __name__ == '__main__':
104 |     args = parse_args()
105 |     save_dir = args.data_dir
106 |     min_size = 512
107 |     max_size = 2048
108 | 
109 |     for phase in ['train', 'test']:
110 |         sub_dir = os.path.join(args.origin_dir, phase+'_data')
111 |         sub_save_dir = os.path.join(save_dir, phase)
112 |         if not os.path.exists(sub_save_dir):
113 |             os.makedirs(sub_save_dir)
114 | 
115 |         im_list = glob(os.path.join(sub_dir, 'images', '*jpg'))
116 |         for im_path in im_list:
117 |             name = os.path.basename(im_path)
118 |             print(phase + '-' + name)
119 |             # resize the image; extract and filter the point annotations
120 |             im, points = generate_data(im_path)
121 | 
122 |             # save the image
123 |             im_save_path = os.path.join(sub_save_dir, name)
124 |             im.save(im_save_path)
125 | 
126 |             # generate and save the density map
127 |             dmap = generate_dmap_from_point(im, points)
128 |             # or dmap = generate_adaptive_dmap_from_point(im, points)
129 |             dmap_save_path = im_save_path.replace('.jpg', '_dmap.npy')
130 |             np.save(dmap_save_path, dmap)
131 | 
132 |             # save the point annotations
133 |             # if phase == 'train': # for MAN BL-loss point annotation
134 |             dis = find_dis(points)
135 |             points = np.concatenate((points, dis), axis=1)  # N,2 -> N,3
136 |             gd_save_path = im_save_path.replace('jpg', 'npy')
137 |             np.save(gd_save_path, points)
--------------------------------------------------------------------------------
/preprocess/preprocess_dataset_ucf_dmap.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from PIL import Image
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | import cv2
7 | import argparse
8 | 
9 | import scipy
10 | import scipy.spatial
11 | import scipy.ndimage
12 | 
13 | def cal_new_size(im_h, im_w, min_size, max_size):
14 |     if im_h < im_w:
15 |         if im_h < min_size:
16 |             ratio = 1.0 * min_size / im_h
17 |             im_h = min_size
18 |             im_w = round(im_w*ratio)
19 |         elif im_h > max_size:
20 |             ratio = 1.0 * max_size / im_h
21 |             im_h = max_size
22 |             im_w = round(im_w*ratio)
23 |         else:
24 |             ratio = 1.0
25 |     else:
26 |         if im_w < min_size:
27 |             ratio = 1.0 * min_size / im_w
28 |             im_w = min_size
29 |             im_h = round(im_h*ratio)
30 |         elif im_w > max_size:
31 |             ratio = 1.0 * max_size / im_w
32 |             im_w = max_size
33 |             im_h = round(im_h*ratio)
34 |         else:
35 |             ratio = 1.0
36 |     return im_h, im_w, ratio
37 | 
38 | 
39 | def find_dis(point):
40 |     square = np.sum(point*point, axis=1)
41 |     dis = np.sqrt(np.maximum(square[:, None] - 2*np.matmul(point, point.T) + square[None, :], 0.0))
42 |     dis = np.mean(np.partition(dis, 3, axis=1)[:, 1:4], axis=1, keepdims=True)
43 |     return dis
44 | 
45 | 
46 | def generate_dmap_from_point(image, points):
47 |     im_w, im_h = image.size
48 |     dmap = np.zeros((im_w, im_h))
49 |     for point in points:
50 |         dmap[min(int(point[0]), im_w-1), min(int(point[1]), im_h-1)] += 1
51 |     density_map = scipy.ndimage.filters.gaussian_filter(dmap, 15, mode='reflect')
52 |     assert(abs(len(points)-np.sum(dmap)) < 1e-2)
53 |     return density_map.T
54 | 
55 | 
56 | def generate_data(im_path):
57 |     im = Image.open(im_path)
58 |     im_w, im_h = im.size
59 |     mat_path = im_path.replace('.jpg', '_ann.mat')
60 |     points = loadmat(mat_path)['annPoints'].astype(np.float32)
61 |     idx_mask = (points[:, 0] >= 0) * (points[:, 0] <= im_w) * (points[:, 1] >= 0) * (points[:, 1] <= im_h)
62 |     points = points[idx_mask]
63 |     im_h, im_w, rr = cal_new_size(im_h, im_w, min_size, max_size)
64 |     im = np.array(im)
65 |     if rr != 1.0:
66 |         im = cv2.resize(im, (im_w, im_h), interpolation=cv2.INTER_CUBIC)
67 |         points = points * rr
68 |     return Image.fromarray(im), points
69 | 
70 | 
71 | def parse_args():
72 |     parser = argparse.ArgumentParser(description='Preprocess UCF-QNRF')
73 |     parser.add_argument('--origin-dir', default='./DATASET/UCF-QNRF_ECCV18',
74 |                         help='original data directory')
75 |     parser.add_argument('--data-dir', default='./DATASET/QNRF-trainfull-test-dmapfix15',
76 |                         help='processed data directory')
77 |     args = parser.parse_args()
78 |     return args
79 | 
80 | if __name__ == '__main__':
81 |     args = parse_args()
82 |     save_dir = args.data_dir
83 |     min_size = 512
84 |     max_size = 2048
85 | 
86 |     for phase in ['Train', 'Test']:
87 |         sub_dir = os.path.join(args.origin_dir, phase)
88 |         if phase == 'Train':
89 |             sub_save_dir = os.path.join(save_dir, 'train')
90 |             if not os.path.exists(sub_save_dir):
91 |                 os.makedirs(sub_save_dir)
92 |             im_list = glob(os.path.join(sub_dir, '*jpg'))
93 |             for im_path in im_list:
94 |                 name = os.path.basename(im_path)
95 |                 print(name)
96 |                 im, points = generate_data(im_path)
97 |                 dis = find_dis(points)
98 |                 points = np.concatenate((points, dis), axis=1)  # N,2 -> N,3
99 | 
100 |                 im_save_path = os.path.join(sub_save_dir, name)
101 |                 im.save(im_save_path)
102 |                 gd_save_path = im_save_path.replace('jpg', 'npy')
103 |                 np.save(gd_save_path, points)
104 | 
105 |                 dmap = generate_dmap_from_point(im, points)
106 |                 dmap_save_path = im_save_path.replace('.jpg', '_dmap.npy')
107 |                 np.save(dmap_save_path, dmap)
108 |         else:
109 |             sub_save_dir = os.path.join(save_dir, 'test')
110 |             if not os.path.exists(sub_save_dir):
111 |                 os.makedirs(sub_save_dir)
112 |             im_list = glob(os.path.join(sub_dir, '*jpg'))
113 |             for im_path in im_list:
114 |                 name = os.path.basename(im_path)
115 |                 print(name)
116 |                 im, points = generate_data(im_path)
117 | 
118 |                 im_save_path = os.path.join(sub_save_dir, name)
119 |                 im.save(im_save_path)
120 |                 gd_save_path = im_save_path.replace('jpg', 'npy')
121 |                 np.save(gd_save_path, points)
122 | 
123 |                 dmap = generate_dmap_from_point(im, points)
124 |                 dmap_save_path = im_save_path.replace('.jpg', '_dmap.npy')
125 |                 np.save(dmap_save_path, dmap)
--------------------------------------------------------------------------------
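A quick sanity check for the output of the three preprocessing scripts above — a sketch with an illustrative file name — is to verify that each saved density map still integrates to the number of annotated heads (the Gaussian uses `mode='reflect'`, so the mass stays inside the image):

```python
import numpy as np

# hypothetical preprocessed ShanghaiTech Part A training image
points = np.load('./DATASET/SHA-train-test-dmapfix15/train/IMG_1.npy')     # N x 3: x, y, mean 3-NN distance
dmap = np.load('./DATASET/SHA-train-test-dmapfix15/train/IMG_1_dmap.npy')  # full-resolution density map
print(len(points), dmap.sum())  # the two values should agree closely
```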
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.0
2 | torchvision==0.8.0
3 | tqdm==4.64.0
4 | einops==0.4.1
5 | wandb==0.13.3
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # train on QNRF
2 | python train.py --tag CHSNet-qnrf --no-wandb --device 0 --max-noisy-ratio 0.05 --max-weight-ratio 0.5 --scheduler cosine --dcsize 4 --batch-size 8 --lr 4e-5 --data-dir ../DATASET/QNRF-trainfull-test-dmapfix15 --val-start 200 --val-epoch 5
3 | # train on SHA
4 | python train.py --tag CHSNet-sha --no-wandb --device 1 --max-noisy-ratio 0.10 --max-weight-ratio 1.0 --scheduler cosine --dcsize 2 --batch-size 8 --lr 4e-5 --data-dir ../DATASET/SHA-train-test-dmapfix15 --val-start 200 --val-epoch 5
5 | # train on SHB
6 | python train.py --tag CHSNet-shb --no-wandb --device 2 --max-noisy-ratio 0.05 --max-weight-ratio 1.0 --scheduler cosine --dcsize 4 --batch-size 8 --lr 4e-5 --data-dir ../DATASET/SHB-train-test-dmapfix15 --val-start 200 --val-epoch 5
7 | 
8 | 
9 | 
--------------------------------------------------------------------------------
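A note on `--dcsize`: the encoder downsamples by 16 and the decoders upsample by `16 // dcsize`, so predictions come out at `1/dcsize` of the input resolution — exactly the resolution `CHSLoss` average-pools the ground-truth map down to. A small shape check (illustrative sizes; `load_weights=True` skips the VGG download):

```python
import torch
from models.chsnet import CHSNet
from losses.losses import CHSLoss

dcsize = 4
model = CHSNet(dcsize=dcsize, load_weights=True)
x = torch.randn(1, 3, 512, 512)                       # one training crop
y1, y2 = model(x)
print(y1.shape)                                       # torch.Size([1, 1, 128, 128]) = 512 / dcsize
gt = torch.rand(1, 1, 512, 512)                       # full-resolution density map
loss = CHSLoss(size=dcsize)(y1, y2, gt, process=0.5)  # pooled GT matches 128x128
```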
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | import random
6 | import wandb
7 | 
8 | from utils.chsnet_trainer import CHSNetTrainer
9 | 
10 | def parse_args():
11 |     parser = argparse.ArgumentParser(description='Train')
12 |     parser.add_argument('--tag', default='chsnet', help='tag of training')
13 |     parser.add_argument('--device', default='0', help='assign device')
14 |     parser.add_argument('--no-wandb', action='store_true', default=False, help='disable wandb logging')
15 | 
16 |     parser.add_argument('--data-dir', default=r'../DATASET/QNRF-trainfull-test-dmapfix15', help='training data directory')
17 |     parser.add_argument('--log-param', type=float, default=100.0, help='dmap scale factor')
18 |     parser.add_argument('--is-gray', action='store_true', default=False, help='whether the input image is gray')
19 |     parser.add_argument('--crop-size', type=int, default=512, help='the crop size of the train image')
20 |     parser.add_argument('--downsample-ratio', type=int, default=16, help='downsample ratio')
21 |     parser.add_argument('--dcsize', type=int, default=4, help='divide count size for density map')
22 |     parser.add_argument('--max-noisy-ratio', type=float, default=0.1, help='for chsloss')
23 |     parser.add_argument('--max-weight-ratio', type=float, default=1, help='for chsloss')
24 | 
25 |     parser.add_argument('--lr', type=float, default=4*1e-5, help='the initial learning rate')
26 |     parser.add_argument('--batch-size', type=int, default=1, help='train batch size')
27 |     parser.add_argument('--num-workers', type=int, default=4, help='number of data-loading workers')
28 |     parser.add_argument('--weight-decay', type=float, default=1e-5, help='the weight decay')
29 |     parser.add_argument('--max-epoch', type=int, default=1000, help='max training epoch')
30 |     parser.add_argument('--val-epoch', type=int, default=5, help='run validation every val-epoch epochs')
31 |     parser.add_argument('--val-start', type=int, default=200, help='the epoch at which to start validating')
32 | 
33 |     parser.add_argument('--scheduler', type=str, default='step', help='lr scheduler: step or cosine')
34 |     parser.add_argument('--step', type=int, default=400)
35 |     parser.add_argument('--gamma', type=float, default=0.5)
36 |     parser.add_argument('--t-max', type=int, default=200, help='for cosine scheduler')
37 |     parser.add_argument('--eta-min', type=float, default=4*1e-6, help='for cosine scheduler')
38 | 
39 |     parser.add_argument('--save-dir', default='./checkpoint', help='directory to save models.')
40 |     parser.add_argument('--save-all', action='store_true', default=False, help='whether to save every new best model')
41 |     parser.add_argument('--max-model-num', type=int, default=1, help='max number of checkpoints to keep')
42 |     parser.add_argument('--resume', default='', help='the path of resume training model')
43 |     args = parser.parse_args()
44 |     return args
45 | 
46 | 
47 | def setup_seed(seed):
48 |     torch.manual_seed(seed)
49 |     torch.cuda.manual_seed_all(seed)
50 |     np.random.seed(seed)
51 |     random.seed(seed)
52 |     if seed == 0:  # reproducible but slow
53 |         torch.backends.cudnn.benchmark = False  # False by default; slow
54 |         torch.backends.cudnn.deterministic = True  # use deterministic convolution algorithms; False by default
55 |     else:  # fast
56 |         torch.backends.cudnn.benchmark = True
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     setup_seed(43)
61 |     args = parse_args()
62 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()  # set visible gpu(s)
63 |     if args.no_wandb:
64 |         wandb.init(mode="disabled")
65 |     else:
66 |         wandb.init(project="CHSNet", name=args.tag, config=vars(args))
67 |     trainer = CHSNetTrainer(args)
68 |     trainer.setup()
69 |     trainer.train()
70 |     wandb.finish()
71 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/chsnet_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import logging
5 | from math import ceil
6 | 
7 | import torch
8 | from torch import optim
9 | from torch.utils.data import DataLoader
10 | 
11 | from datasets.crowd_dmap import Crowd
12 | from models.chsnet import CHSNet
13 | from losses.losses import CHSLoss
14 | from utils.trainer import Trainer
15 | from utils.helper import Save_Handle, AverageMeter
16 | 
17 | import numpy as np
18 | from tqdm import tqdm
19 | import wandb
20 | 
21 | # sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
22 | 
23 | 
24 | def train_collate(batch):
25 |     transposed_batch = list(zip(*batch))
26 |     images = torch.stack(transposed_batch[0], 0)
27 |     dmaps = torch.stack(transposed_batch[1], 0)
28 |     return images, dmaps
29 | 
30 | 
31 | class CHSNetTrainer(Trainer):
32 |     def setup(self):
33 |         """initialize the datasets, model, loss and optimizer"""
34 |         args = self.args
35 |         if torch.cuda.is_available():
36 |             self.device = torch.device("cuda")
37 |         else:
38 |             raise Exception("GPU is not available")
39 | 
40 |         train_datasets = Crowd(os.path.join(args.data_dir, 'train'),
41 |                                args.crop_size,
42 |                                args.downsample_ratio,
43 |                                args.is_gray, method='train')
44 |         train_dataloaders = DataLoader(train_datasets,
45 |                                        collate_fn=train_collate,
46 |                                        batch_size=args.batch_size,
47 |                                        shuffle=True,
48 |                                        num_workers=args.num_workers,
49 |                                        pin_memory=True)
50 |         val_datasets = Crowd(os.path.join(args.data_dir, 'test'), 512, 8, is_gray=False, method='val')
51 |         val_dataloaders = torch.utils.data.DataLoader(val_datasets, 1, shuffle=False,
52 |                                                       num_workers=args.num_workers, pin_memory=True)
53 |         self.dataloaders = {'train': train_dataloaders, 'val': val_dataloaders}
54 | 
55 |         self.model = CHSNet(dcsize=args.dcsize)
56 |         self.model.to(self.device)
57 | 
58 |         self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
59 | 
60 |         self.criterion = CHSLoss(size=args.dcsize, max_noisy_ratio=args.max_noisy_ratio,
                                  max_weight_ratio=args.max_weight_ratio)
61 | 
62 |         self.save_list = Save_Handle(max_num=args.max_model_num)
63 |         self.best_mae = np.inf
64 |         self.best_mse = np.inf
65 |         self.best_mae_at = 0
66 |         self.best_count = 0
67 | 
68 |         self.start_epoch = 0
69 |         if args.resume:
70 |             suf = args.resume.rsplit('.', 1)[-1]
71 |             if suf == 'tar':
72 |                 checkpoint = torch.load(args.resume, self.device)
73 |                 self.model.load_state_dict(checkpoint['model_state_dict'])
74 |                 self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
75 |                 self.start_epoch = checkpoint['epoch'] + 1
76 |                 self.best_mae = checkpoint['best_mae']
77 |                 self.best_mse = checkpoint['best_mse']
78 |                 self.best_mae_at = checkpoint['best_mae_at']
79 |             elif suf == 'pth':
80 |                 self.model.load_state_dict(torch.load(args.resume, self.device))
81 | 
82 |         if args.scheduler == 'step':
83 |             self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=args.step, gamma=args.gamma, last_epoch=self.start_epoch-1)
84 |         elif args.scheduler == 'cosine':
85 |             self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.t_max, eta_min=args.eta_min, last_epoch=self.start_epoch-1)
86 | 
87 |     def train(self):
88 |         args = self.args
89 |         self.epoch = None
90 |         # self.val_epoch()
91 |         for epoch in range(self.start_epoch, args.max_epoch):
92 |             logging.info('-' * 5 + 'Epoch {}/{}'.format(epoch, args.max_epoch - 1) + '-' * 5)
93 |             self.epoch = epoch
94 |             self.train_epoch()
95 |             self.scheduler.step()
96 |             if epoch >= args.val_start and (epoch % args.val_epoch == 0 or epoch == args.max_epoch - 1):
97 |                 self.val_epoch()
98 | 
99 |     def train_epoch(self):
100 |         epoch_loss = AverageMeter()
101 |         epoch_mae = AverageMeter()
102 |         epoch_mse = AverageMeter()
103 |         epoch_start = time.time()
104 |         self.model.train()
105 | 
106 |         # Iterate over data.
107 |         for inputs, targets in tqdm(self.dataloaders['train']):
108 |             inputs = inputs.to(self.device)
109 |             targets = targets.to(self.device) * self.args.log_param
110 | 
111 |             with torch.set_grad_enabled(True):
112 |                 dmap_conv, dmap_tran = self.model(inputs)
113 |                 loss = self.criterion(dmap_conv, dmap_tran, targets, self.epoch/self.args.max_epoch)
114 |                 dmap = (dmap_conv + dmap_tran) / 2.0
115 | 
116 |                 self.optimizer.zero_grad()
117 |                 loss.backward()
118 |                 self.optimizer.step()
119 | 
120 |                 N = inputs.size(0)
121 |                 pre_count = torch.sum(dmap.view(N, -1), dim=1).detach().cpu().numpy()
122 |                 gd_count = torch.sum(targets.view(N, -1), dim=1).detach().cpu().numpy()
123 |                 res = pre_count - gd_count
124 |                 epoch_loss.update(loss.item(), N)
125 |                 epoch_mse.update(np.mean(res * res), N)
126 |                 epoch_mae.update(np.mean(abs(res)), N)
127 | 
128 |         logging.info('Epoch {} Train, Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
129 |                      .format(self.epoch, epoch_loss.get_avg(), np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
130 |                              time.time() - epoch_start))
131 |         wandb.log({'Train/loss': epoch_loss.get_avg(),
132 |                    'Train/lr': self.scheduler.get_last_lr()[0],
133 |                    'Train/epoch_mae': epoch_mae.get_avg()}, step=self.epoch)
134 | 
135 |         model_state_dic = self.model.state_dict()
136 |         save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
137 |         torch.save({
138 |             'epoch': self.epoch,
139 |             'optimizer_state_dict': self.optimizer.state_dict(),
140 |             'model_state_dict': model_state_dic,
141 |             'best_mae': self.best_mae,
142 |             'best_mse': self.best_mse,
143 |             'best_mae_at': self.best_mae_at,
144 |         }, save_path)
145 |         self.save_list.append(save_path)  # control the number of saved models
146 | 
147 |     def val_epoch(self):
148 |         epoch_start = time.time()
149 |         self.model.eval()
150 |         epoch_res = []
151 | 
152 |         for inputs, count, name in tqdm(self.dataloaders['val']):
153 |             inputs = inputs.to(self.device)
154 |             # inputs are images with different sizes
155 |             b, c, h, w = inputs.shape
156 |             h, w = int(h), int(w)
157 |             assert b == 1, 'the batch size should be 1 in validation mode'
158 | 
159 |             max_size = 2000
160 |             if h > max_size or w > max_size:
161 |                 h_stride = int(ceil(1.0 * h / max_size))
162 |                 w_stride = int(ceil(1.0 * w / max_size))
163 |                 h_step = h // h_stride
164 |                 w_step = w // w_stride
165 |                 input_list = []
166 |                 for i in range(h_stride):
167 |                     for j in range(w_stride):
168 |                         h_start = i * h_step
169 |                         if i != h_stride - 1:
170 |                             h_end = (i + 1) * h_step
171 |                         else:
172 |                             h_end = h
173 |                         w_start = j * w_step
174 |                         if j != w_stride - 1:
175 |                             w_end = (j + 1) * w_step
176 |                         else:
177 |                             w_end = w
178 |                         input_list.append(inputs[:, :, h_start:h_end, w_start:w_end])
179 |                 with torch.set_grad_enabled(False):
180 |                     pre_count = 0.0
181 |                     for patch in input_list:
182 |                         output1, output2 = self.model(patch)
183 |                         pre_count += (torch.sum(output1) + torch.sum(output2)) / 2
184 |             else:
185 |                 with torch.set_grad_enabled(False):
186 |                     output1, output2 = self.model(inputs)
187 |                     pre_count = (torch.sum(output1) + torch.sum(output2)) / 2
188 | 
189 |             epoch_res.append(count[0].item() - pre_count.item() / self.args.log_param)
190 | 
191 |         epoch_res = np.array(epoch_res)
192 |         mse = np.sqrt(np.mean(np.square(epoch_res)))
193 |         mae = np.mean(np.abs(epoch_res))
194 | 
195 |         logging.info('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
196 |                      .format(self.epoch, mse, mae, time.time() - epoch_start))
197 | 
198 |         model_state_dic = self.model.state_dict()
199 |         if mae < self.best_mae:
200 |             self.best_mse = mse
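The `.tar` checkpoints written above can restart training, while the `.pth` best models carry weights only; `train.py --resume` dispatches on the suffix. A short sketch (the paths are illustrative — `save_dir` is `./checkpoint/<MMDD>_<tag>`, see `utils/trainer.py`):

```bash
# resume a full training state (weights, optimizer, epoch, best metrics)
python train.py --tag chsnet --resume ./checkpoint/0101_chsnet/399_ckpt.tar
# or warm-start from best weights only
python train.py --tag chsnet --resume ./checkpoint/0101_chsnet/best_model.pth
```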
201 |             self.best_mae = mae
202 |             self.best_mae_at = self.epoch
203 |             logging.info("SAVE best mse {:.2f} mae {:.2f} model @epoch {}".format(self.best_mse, self.best_mae, self.epoch))
204 |             if self.args.save_all:
205 |                 torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count)))
206 |                 self.best_count += 1
207 |             else:
208 |                 torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model.pth'))
209 | 
210 |         logging.info("best mae {:.2f} mse {:.2f} @epoch {}".format(self.best_mae, self.best_mse, self.best_mae_at))
211 | 
212 |         if self.epoch is not None:
213 |             wandb.log({'Val/bestMAE': self.best_mae,
214 |                        'Val/MAE': mae,
215 |                        'Val/MSE': mse,
216 |                        }, step=self.epoch)
217 | 
218 | 
219 | 
--------------------------------------------------------------------------------
/utils/helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | class Save_Handle(object):
5 |     """keep at most max_num checkpoint files on disk"""
6 |     def __init__(self, max_num):
7 |         self.save_list = []
8 |         self.max_num = max_num
9 | 
10 |     def append(self, save_path):
11 |         if len(self.save_list) < self.max_num:
12 |             self.save_list.append(save_path)
13 |         else:
14 |             remove_path = self.save_list[0]
15 |             del self.save_list[0]
16 |             self.save_list.append(save_path)
17 |             if os.path.exists(remove_path):
18 |                 os.remove(remove_path)
19 | 
20 | 
21 | class AverageMeter(object):
22 |     """Computes and stores the average and current value"""
23 |     def __init__(self):
24 |         self.reset()
25 | 
26 |     def reset(self):
27 |         self.val = 0
28 |         self.avg = 0
29 |         self.sum = 0
30 |         self.count = 0
31 | 
32 |     def update(self, val, n=1):
33 |         self.val = val
34 |         self.sum += val * n
35 |         self.count += n
36 |         self.avg = 1.0 * self.sum / self.count
37 | 
38 |     def get_avg(self):
39 |         return self.avg
40 | 
41 |     def get_count(self):
42 |         return self.count
43 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | def setlogger(path):
4 |     logger = logging.getLogger()
5 |     logger.setLevel(logging.INFO)
6 |     logFormatter = logging.Formatter("%(asctime)s %(message)s",
7 |                                      "%m-%d %H:%M:%S")
8 | 
9 |     fileHandler = logging.FileHandler(path)
10 |     fileHandler.setFormatter(logFormatter)
11 |     logger.addHandler(fileHandler)
12 | 
13 |     consoleHandler = logging.StreamHandler()
14 |     consoleHandler.setFormatter(logFormatter)
15 |     logger.addHandler(consoleHandler)
16 | 
17 | 
--------------------------------------------------------------------------------
/utils/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from datetime import datetime
4 | from utils.logger import setlogger
5 | 
6 | 
7 | class Trainer(object):
8 |     def __init__(self, args):
9 |         # sub_dir = datetime.strftime(datetime.now(), '%m%d-%H%M%S')  # prepare saving path
10 |         sub_dir = datetime.strftime(datetime.now(), '%m%d')  # prepare saving path
11 |         self.save_dir = os.path.join(args.save_dir, sub_dir+'_'+args.tag)
12 |         if not os.path.exists(self.save_dir):
13 |             os.makedirs(self.save_dir)
14 |         setlogger(os.path.join(self.save_dir, 'train.log'))  # set logger
15 |         for k, v in args.__dict__.items():  # save args
16 |             logging.info("{}: {}".format(k, v))
17 |         self.args = args
18 | 
19 |     def setup(self):
20 |         """initialize the datasets, model, loss and optimizer"""
21 |         pass
22 | 
23 |     def train(self):
24 |         """run the training loop"""
25 |         pass
26
| --------------------------------------------------------------------------------
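The repository ships no standalone inference script; the following is a minimal counting sketch under stated assumptions — the checkpoint and image paths are illustrative, `dcsize` must match the value used in training (see `run.sh`), the normalization mirrors `datasets/crowd_dmap.py`, and the division by 100 undoes the default `--log-param` scaling from `train.py`:

```python
import torch
from PIL import Image
from torchvision import transforms
from models.chsnet import CHSNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CHSNet(dcsize=4, load_weights=True)  # load_weights=True skips the VGG download
model.load_state_dict(torch.load('./checkpoint/0101_chsnet/best_model.pth', map_location=device))
model.to(device).eval()

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = trans(Image.open('demo.jpg').convert('RGB')).unsqueeze(0).to(device)

with torch.no_grad():
    y1, y2 = model(img)                             # density maps from the two heads
count = ((y1.sum() + y2.sum()) / 2 / 100.0).item()  # average heads, undo log_param=100
print(f'estimated count: {count:.1f}')
```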