├── .gitignore
├── .gitmodules
├── README.md
├── config.py
├── data.py
├── evaluate.py
├── export.py
├── hourglass.py
├── imgs
│   ├── 000019.jpg
│   └── 000019
│       └── image.png
├── loss.py
├── main.py
├── optim.py
├── requirements.txt
├── train.py
├── transform.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | *.sh
3 | *.pth
4 | __pycache__
5 | DATA/
6 | WEIGHTS/
7 | .ipynb_checkpoints/
8 | arial.ttf
9 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mAP"]
2 | path = mAP
3 | url = https://github.com/tyui592/mAP.git
4 | [submodule "PytorchToCpp"]
5 | path = PytorchToCpp
6 | url = https://github.com/tyui592/PytorchToCpp.git
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Real Time Helmet Detection
2 | ==
3 | This is a PyTorch implementation of a helmet detector based on [CenterNet](https://arxiv.org/abs/1904.07850).
4 |
5 | I used the [SafetyHelmetWearing-Dataset (SHWD)](https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset) to detect helmets and persons.
6 |
7 | I will continue to update the entries listed in [TODO](https://github.com/tyui592/Real_Time_Helmet_Detection/wiki/TODO) for research in the [`nightly`](https://github.com/tyui592/Real_Time_Helmet_Detection/tree/nightly) branch.
8 |
9 | [![Open RTHD in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-zz9z9_irTNtvsHBxVefsyT-lV0isBbJ?usp=sharing)
10 |
11 | ## Requirements
12 | - [imgaug](https://github.com/aleju/imgaug) (v0.4.0)
13 | - [torch](https://pytorch.org/) (v1.6.0)
14 | - [torchvision](https://pytorch.org/) (v0.7.0)
15 | - [torchsummary](https://github.com/sksq96/pytorch-summary)
16 | - [requirements.txt](./requirements.txt)
17 |
18 | ## Features
19 | - [Automatic Mixed Precision (AMP)](https://pytorch.org/docs/stable/amp.html)
20 | - [Distributed Data Parallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)
21 | - [TorchScript](https://pytorch.org/docs/stable/jit.html#mixing-tracing-and-scripting)
22 |
23 | ## Usage
24 | The [SafetyHelmetWearing-Dataset (SHWD)](https://github.com/njvisionpower/Safety-Helmet-Wearing-Dataset) is needed to train the detector [[Download](https://drive.google.com/file/d/1qWm7rrwvjAWs1slymbrLaCf7Q-wnGLEX/view)].
25 |
26 | The trained model [weights](https://github.com/tyui592/Real_Time_Helmet_Detection/releases/download/v0.0/check_point.pth) and [demo app](https://github.com/tyui592/Real_Time_Helmet_Detection/releases/download/v0.0/main) are available in [release (v0.0)](https://github.com/tyui592/Real_Time_Helmet_Detection/releases/tag/v0.0).
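For a quick start, the repository can be set up as sketched below. This sequence is not part of the original instructions: the package versions follow the Requirements list above, and the availability of `git`, `pip`, and `wget` is assumed, so adjust it to your environment.

```bash
# clone together with the mAP and PytorchToCpp submodules declared in .gitmodules
$ git clone --recurse-submodules https://github.com/tyui592/Real_Time_Helmet_Detection.git
$ cd Real_Time_Helmet_Detection

# install the dependencies listed under Requirements (versions assumed from the list above)
$ pip install imgaug==0.4.0 torch==1.6.0 torchvision==0.7.0 torchsummary

# optional: download the released checkpoint to the path used by the example scripts below
$ mkdir -p WEIGHTS
$ wget -P ./WEIGHTS https://github.com/tyui592/Real_Time_Helmet_Detection/releases/download/v0.0/check_point.pth
```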
27 |
28 | ### Example Scripts
29 |
30 | #### Train
31 | ```bash
32 | $ python main.py --train-flag --gpu-no 0 --data ./DATA/VOC2028/ --save-path ./WEIGHTS/ --amp
33 | ```
34 |
35 | #### Test
36 | ```bash
37 | $ python main.py --gpu-no 0 --model-load ./WEIGHTS/check_point.pth --data ./DATA/VOC2028 --imsize 512 --save-path ./WEIGHTS/results --batch-size 8
38 | ```
39 |
40 | #### Measure mAP
41 | ```bash
42 | $ cd mAP
43 | mAP$ python main.py -na -np --dr ../WEIGHTS/results/txt/
44 | ```
45 |
46 | #### Demo
47 | ```bash
48 | $ python evaluate.py --gpu-no 0 --model-load ./WEIGHTS/check_point.pth --data ./imgs/000019.jpg --imsize 512 --save-path ./imgs/000019 --topk 100 --conf-th 0.2 --nms-th 0.2 --fontsize 0
49 | ```
50 |
51 | ## Results
52 | | Input | Output |
53 | | --- | --- |
54 | | ![input](./imgs/000019.jpg) | ![output](./imgs/000019/image.png) |
55 |
56 | **Performance**
57 |
58 | | Helmet (AP) | Person (AP) | mAP |
59 | | --- | --- | --- |
60 | | 88.16 % | 88.71 % | 88.43 % |
61 |
62 | The model was trained with the example scripts above.
63 | The performance may be improved by increasing the model size (e.g., `--num-stack`, `--increase-ch`, ...) or by searching the hyperparameters (e.g., `--hm-weight`, `--lr`, ...) more carefully.
64 |
65 | Loading a PyTorch Model in C++
66 | --
67 |
68 | 1. Create the TorchScript code of the detector
69 | ```bash
70 | $ python export.py --model-load ./WEIGHTS/check_point.pth --nms-th 0.5 --topk 100
71 | ```
72 |
73 | 2. [Build App](https://github.com/tyui592/PytorchToCpp)
74 |
75 |
76 | 3. Run App (Speed: 100 FPS @ 512x512, 1080 Ti)
77 | ```bash
78 | PytorchToCpp/build$ ./main -m ../../jit_traced_model_gpu.pth -i ../../imgs/000019.jpg
79 | ```
80 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 |
5 | import torch
6 | import random
7 | import numpy
8 |
9 | from utils import save_pickle, load_pickle
10 |
11 | def build_parser():
12 |     parser = argparse.ArgumentParser()
13 |
14 |     parser.add_argument('--gpu-no', type=int, nargs='+',
15 |                         help='GPU ID(s) to use, 0~N: GPU, -1: CPU', default=[0])
16 |     parser.add_argument('--random-seed', type=int,
17 |                         help='random seed for reproducible experiments', default=777)
18 |
19 |     # train
20 |     parser.add_argument('--train-flag', action='store_true',
21 |                         help='set this flag for training', default=False)
22 |     parser.add_argument('--data', type=str,
23 |                         help='data path for training or evaluation', default=None)
24 |     parser.add_argument('--batch-size', type=int,
25 |                         help='batch size', default=16)
26 |     parser.add_argument('--sub-divisions', type=int,
27 |                         help='optimize every N iterations for gradient accumulation', default=1)
28 |     parser.add_argument('--start-epoch', type=int,
29 |                         help='start epoch', default=0)
30 |     parser.add_argument('--end-epoch', type=int,
31 |                         help='end epoch', default=100)
32 |     parser.add_argument('--num-workers', type=int,
33 |                         help='number of workers for data loading', default=8)
34 |
35 |     # amp
36 |     parser.add_argument('--amp', action='store_true',
37 |                         help='automatic mixed precision flag', default=False)
38 |
39 |     # ddp (distributed data parallel)
40 |     parser.add_argument('--world-size', type=int,
41 |                         help='for distributed data parallel', default=1)
42 |     parser.add_argument('--rank', type=int,
43 |                         help='for distributed data parallel', default=0)
44 |     parser.add_argument('--dist-backend', type=str,
45 |                         help='for distributed data
parallel', default='nccl') 46 | parser.add_argument('--dist-url', type=str, 47 | help='for distributed data parallel', default='tcp://localhost:29500') 48 | 49 | # evaluation and demo 50 | parser.add_argument('--imsize', type=int, 51 | help='when evaluation or demo run, the image resized by imsize x imsize', default=None) 52 | parser.add_argument('--topk', type=int, 53 | help='extract topk peak predictions', default=100) 54 | parser.add_argument('--conf-th', type=float, 55 | help='confidence threshold', default=0.0) 56 | parser.add_argument('--nms-th', type=float, 57 | help='nms threshold', default=0.5) 58 | parser.add_argument('--pool-size', type=int, 59 | help='pool(it is used to find peak value in heatmap) size', default=3) 60 | parser.add_argument('--model-load', type=str, 61 | help='check_point path', default=None) 62 | parser.add_argument('--nms', type=str, 63 | help='select nms algorithm (nms | soft-nms)', default='nms') 64 | parser.add_argument('--fontsize', type=int, 65 | help='fontsize for demo, 0: dont write score and class in the image', default=10) 66 | 67 | # augmentation 68 | # reference: https://imgaug.readthedocs.io/en/latest/source/examples_bounding_boxes.html#a-simple-example 69 | parser.add_argument('--crop-percent', type=float, nargs='+', 70 | help='range(min, max), how many crop the image', default=[0.0, 0.1]) 71 | parser.add_argument('--color-multiply', type=float, nargs='+', 72 | help='range(min, max), how many adjust the brightness', default=[1.2, 1.5]) 73 | parser.add_argument('--translate-percent', type=float, 74 | help='ratio, how many translate the image', default=0.1) 75 | parser.add_argument('--affine-scale', type=float, nargs='+', 76 | help='range(min, ratio), how many scaling the image', default=[0.5, 1.5]) 77 | parser.add_argument('--multiscale_flag', action='store_true', 78 | help='training with multi-resolution images the resolution is randomly selected per every iteration', default=False) 79 | parser.add_argument('--multiscale', type=int, nargs='+', 80 | help='[min, max, step] if multiscale_flag set False network train with the max size', default=[320, 512, 64]) 81 | 82 | # loss 83 | parser.add_argument('--hm-weight', type=float, 84 | help='heat map loss weight', default=1.0) 85 | parser.add_argument('--offset-weight', type=float, 86 | help='offset loss weight', default=1.0) 87 | parser.add_argument('--size-weight', type=float, 88 | help='size(wh) loss weight', default=0.1) 89 | parser.add_argument('--focal-alpha', type=float, 90 | help='alpha for focal loss(heatmap)', default=2.0) 91 | parser.add_argument('--focal-beta', type=float, 92 | help='beta for focal loss(heatmap)', default=4.0) 93 | 94 | # network 95 | parser.add_argument('--scale_factor', type=int, 96 | help='downsampling scale from image to heatmap', default=4) 97 | parser.add_argument('--num-cls', type=int, 98 | help='number of classes', default=2) 99 | parser.add_argument('--pretrained', type=str, 100 | help='select pretrained backbone (scratch | imagenet)', default='imagenet') 101 | parser.add_argument('--normalized-coord', action='store_true', 102 | help='predict normalized(relative) offset ans size of bounding box', default=False) 103 | ## backbone - hourglass 104 | parser.add_argument('--num-stack', type=int, 105 | help='number of stack in hourglass network', default=1) 106 | parser.add_argument('--hourglass-inch', type=int, 107 | help='number of channels for hougrglass networks', default=128) 108 | parser.add_argument('--increase-ch', type=int, 109 | help='in the hougralss network, 
more deep layer has more channels by this factor', default=0) 110 | parser.add_argument('--activation', type=str, 111 | help='activation funciton', default='ReLU') 112 | parser.add_argument('--pool', type=str, 113 | help='pooling function', default='Max') 114 | ## neck 115 | parser.add_argument('--neck-activation', type=str, 116 | help='activation funciton', default='ReLU') 117 | parser.add_argument('--neck-pool', type=str, 118 | help='pooling function (None | SPP)', default='None') 119 | 120 | # optimization 121 | parser.add_argument('--lr', type=float, 122 | help='learning rate, select lr more carefully when use amp', default=5e-4) 123 | parser.add_argument('--optim', type=str, 124 | help='optimization algorithm', default='Adam') 125 | parser.add_argument('--lr-milestone', type=int, nargs='+', 126 | help='epoch for adjust lr', default=[50, 90]) 127 | parser.add_argument('--lr-gamma', type=float, 128 | help='scale factor', default=0.1) 129 | 130 | # log 131 | parser.add_argument('--print-interval', type=int, 132 | help='print logs every N iterations', default=100) 133 | parser.add_argument('--save-path', type=str, 134 | help='path to save results', default='./WEIGHTS/') 135 | 136 | return parser.parse_args() 137 | 138 | 139 | def get_arguments(): 140 | args = build_parser() 141 | # set random seed for reproducible experiments 142 | # reference: https://github.com/pytorch/pytorch/issues/7068 143 | random.seed(args.random_seed) 144 | numpy.random.seed(args.random_seed) 145 | torch.manual_seed(args.random_seed) 146 | torch.cuda.manual_seed(args.random_seed) 147 | torch.cuda.manual_seed_all(args.random_seed) 148 | 149 | # these flags can affect performance, selec carefully 150 | # torch.backends.cudnn.deterministic = True 151 | # torch.backends.cudnn.benchmark = False 152 | 153 | os.makedirs(args.save_path, exist_ok=True) 154 | if args.train_flag: 155 | os.makedirs(os.path.join(args.save_path, 'training_log'), exist_ok=True) 156 | else: 157 | loaded_args = load_pickle(os.path.join(os.path.dirname(args.model_load), 'argument.pickle')) 158 | args = update_arguments_for_eval(args, loaded_args) 159 | 160 | # cuda setting 161 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 162 | os.environ['CUDA_VISIBLE_DEVICES'] = ', '.join(map(str, args.gpu_no)) 163 | 164 | with open(os.path.join(args.save_path, 'argument.txt'), 'w') as f: 165 | for key, value in sorted(vars(args).items()): 166 | f.write('%s: %s'%(key, value) + '\n') 167 | 168 | save_pickle(os.path.join(args.save_path, 'argument.pickle'), args) 169 | return args 170 | 171 | def update_arguments_for_eval(old, new): 172 | targets = ['scale_factor', 'num_cls', 'pretrained', 'normalized_coord', 173 | 'num_stack', 'hourglass_inch', 'increase_ch', 'activation', 'pool', 174 | 'neck_activation', 'neck_pool'] 175 | 176 | for target in targets: 177 | old.__dict__[target] = new.__dict__[target] 178 | 179 | return old 180 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import collections 4 | import numpy as np 5 | from PIL import Image 6 | import xml.etree.ElementTree as ET 7 | 8 | import imgaug.augmenters as iaa 9 | from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage 10 | 11 | import torch 12 | import torchvision.transforms.functional as TF 13 | 14 | from transform import box2hm 15 | from utils import get_normalizer 16 | 17 | CLASS2INDEX = {'hat': 0, 'person': 1, 'dog': 0} 18 | 
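# NOTE (added comment, not in the original source): besides 'hat' and 'person', CLASS2INDEX above
# also maps the label 'dog' to class 0. Some SHWD annotation XMLs appear to contain objects tagged
# 'dog'; folding them into class 0 keeps the CLASS2INDEX lookup in _parse_voc_dict() from raising
# a KeyError, while INDEX2CLASS below still reports only the two real classes.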
INDEX2CLASS = {0: 'hat', 1: 'person'} 19 | CLASS2COLOR = {0: (255, 0, 0), 1: (0, 255, 0)} 20 | 21 | # reference: https://pytorch.org/docs/stable/_modules/torchvision/datasets/voc.html#VOCDetection 22 | class VOC: 23 | def __init__(self, root, transform, image_set, pretrained, normalized_coord, num_cls): 24 | self.transform = transform 25 | self.image_set = image_set 26 | self.normalize = get_normalizer(pretrained=pretrained) 27 | self.normalized_coord = normalized_coord 28 | self.num_cls = num_cls 29 | 30 | image_dir = os.path.join(root, 'JPEGImages') 31 | annotation_dir = os.path.join(root, 'Annotations') 32 | splits_dir = os.path.join(root, 'ImageSets/Main') 33 | 34 | split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt') 35 | with open(os.path.join(split_f), 'r') as f: 36 | file_names = [x.strip() for x in f.readlines()] 37 | 38 | self.images = [os.path.join(image_dir, x + '.jpg') for x in file_names] 39 | self.annotations = [os.path.join(annotation_dir, x + '.xml') for x in file_names] 40 | 41 | assert len(self.images) == len(self.annotations) 42 | print('%s: %d images are loaded from %s'%(time.ctime(), len(self.images), root)) 43 | 44 | def __len__(self): 45 | return len(self.images) 46 | 47 | def __getitem__(self, index): 48 | img_pil = Image.open(self.images[index]).convert('RGB') 49 | voc_dict = self._parse_voc_xml(ET.parse(self.annotations[index]).getroot()) 50 | box_lst, id_lst = self._parse_voc_dict(voc_dict) 51 | 52 | img_np, bbs_iaa = self._type_cast(img_pil, box_lst, id_lst) 53 | return img_np, bbs_iaa, voc_dict 54 | 55 | def _parse_voc_dict(self, voc_dict): 56 | box_lst, id_lst = [], [] 57 | for obj in voc_dict['annotation']['object']: 58 | id_lst.append(CLASS2INDEX[obj['name'].lower()]) 59 | 60 | bndbox = obj['bndbox'] 61 | box = bndbox['xmin'], bndbox['ymin'], bndbox['xmax'], bndbox['ymax'] 62 | box_lst.append(list(map(int, box))) 63 | return box_lst, id_lst 64 | 65 | def _parse_voc_xml(self, node): 66 | voc_dict = {} 67 | children = list(node) 68 | if children: 69 | def_dic = collections.defaultdict(list) 70 | for dc in map(self._parse_voc_xml, children): 71 | for ind, v in dc.items(): 72 | def_dic[ind].append(v) 73 | if node.tag == 'annotation': 74 | def_dic['object'] = [def_dic['object']] 75 | voc_dict = {node.tag: {ind: v[0] if len(v) == 1 else v for ind, v in def_dic.items()}} 76 | if node.text: 77 | text = node.text.strip() 78 | if not children: 79 | voc_dict[node.tag] = text 80 | return voc_dict 81 | 82 | def _type_cast(self, img_pil, box_lst, id_lst): 83 | # change type of data to use `imgaug` 84 | img_np = np.asarray(img_pil) 85 | 86 | bbs = [] 87 | for (x1, y1, x2, y2), label in zip(box_lst, id_lst): 88 | bbs.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2, label=label)) 89 | bbs_iaa = BoundingBoxesOnImage(bbs, shape=img_np.shape) 90 | 91 | return img_np, bbs_iaa 92 | 93 | def collate_fn(self, batch): 94 | img_np_lst, bbs_iaa_lst, voc_dict_lst = list(zip(*batch)) 95 | 96 | # batch-wise image augmentation 97 | img_np_lst, bbs_iaa_lst = self.transform(img_np_lst, bbs_iaa_lst) 98 | 99 | # type casting for heatmap 100 | batch_bbs_lst, batch_id_lst = [], [] 101 | for bbs_iaa in bbs_iaa_lst: 102 | temp_bbs_lst, temp_id_lst = [], [] 103 | for bbs in bbs_iaa.bounding_boxes: 104 | temp_bbs_lst.append(bbs.coords.flatten()) 105 | temp_id_lst.append(bbs.label) 106 | batch_bbs_lst.append(temp_bbs_lst) 107 | batch_id_lst.append(temp_id_lst) 108 | 109 | # make heatmap 110 | heatmap_lst, offset_lst, wh_lst, mask_lst = [],[],[],[] 111 | for bbs_lst, id_lst in 
zip(batch_bbs_lst, batch_id_lst): 112 | heatmap, offset, wh, mask = box2hm(bbs_lst, id_lst, bbs_iaa_lst[0].shape[:2], num_cls=self.num_cls, normalized=self.normalized_coord) 113 | heatmap_lst.append(heatmap) 114 | offset_lst.append(offset) 115 | wh_lst.append(wh) 116 | mask_lst.append(mask) 117 | 118 | # to tensor 119 | batch_img_ten = torch.stack([self.normalize(TF.to_tensor(img_np)) for img_np in img_np_lst]) 120 | batch_heatmap_ten = torch.stack([torch.Tensor(heatmap) for heatmap in heatmap_lst]) 121 | batch_offset_ten = torch.stack([torch.Tensor(offset) for offset in offset_lst]) 122 | batch_wh_ten = torch.stack([torch.Tensor(wh) for wh in wh_lst]) 123 | batch_mask_ten = torch.stack([torch.Tensor(mask) for mask in mask_lst]) 124 | 125 | return batch_img_ten, batch_heatmap_ten, batch_offset_ten, batch_wh_ten, batch_mask_ten, voc_dict_lst 126 | 127 | class TrainAugmentor: 128 | def __init__(self, crop_percent=(0.0, 0.1), color_multiply=(1.2, 1.5), translate_percent=0.1, 129 | affine_scale=(0.5, 1.5), multiscale_flag=False, multiscale=[320, 608, 32]): 130 | 131 | self.multiscale_flag = multiscale_flag 132 | self.multiscale_min = multiscale[0] 133 | self.multiscale_max = multiscale[1] 134 | self.multiscale_step = multiscale[2] 135 | 136 | self.seq = iaa.Sequential([ 137 | iaa.Multiply(color_multiply), 138 | iaa.Affine( 139 | translate_percent=translate_percent, 140 | scale=affine_scale 141 | ), 142 | iaa.Crop(percent=crop_percent), 143 | iaa.Fliplr(0.5), 144 | ]) 145 | 146 | return None 147 | 148 | def __call__(self, img_np_lst, bbs_iaa_lst): 149 | # transform 150 | img_np_lst, bbs_iaa_lst = self.seq(images=img_np_lst, bounding_boxes=bbs_iaa_lst) 151 | bbs_iaa_lst = [bbs_iaa.remove_out_of_image().clip_out_of_image() for bbs_iaa in bbs_iaa_lst] 152 | 153 | if self.multiscale_flag: 154 | target_size = np.random.choice(range(self.multiscale_min, self.multiscale_max, self.multiscale_step)) 155 | else: 156 | target_size = self.multiscale_max 157 | resize = iaa.Resize(target_size) 158 | 159 | img_np_lst, bbs_iaa_lst = resize(images=img_np_lst, bounding_boxes=bbs_iaa_lst) 160 | 161 | return img_np_lst, bbs_iaa_lst 162 | 163 | class TestAugmentor: 164 | def __init__(self, imsize): 165 | self.seq = iaa.Resize(imsize) 166 | 167 | def __call__(self, img_np_lst, bbs_iaa_lst): 168 | img_np_lst, bbs_iaa_lst = self.seq(images=img_np_lst, bounding_boxes=bbs_iaa_lst) 169 | 170 | return img_np_lst, bbs_iaa_lst 171 | 172 | def load_dataset(args): 173 | if args.train_flag: 174 | transform = TrainAugmentor(crop_percent = tuple(args.crop_percent), 175 | color_multiply = tuple(args.color_multiply), 176 | translate_percent = args.translate_percent, 177 | affine_scale = tuple(args.affine_scale), 178 | multiscale_flag = args.multiscale_flag, 179 | multiscale = args.multiscale) 180 | else: 181 | transform = TestAugmentor(imsize=args.imsize) 182 | 183 | dataset = VOC(root = args.data, 184 | transform = transform, 185 | image_set = 'trainval' if args.train_flag else 'test', 186 | pretrained = args.pretrained, 187 | normalized_coord = args.normalized_coord, 188 | num_cls = args.num_cls) 189 | return dataset 190 | 191 | 192 | if __name__ == '__main__': 193 | from utils import ten2pil 194 | 195 | dataset = VOC(root='./DATA/VOC2028/', 196 | transform=TrainAugmentor(), 197 | image_set='trainval', 198 | pretrained='imagenet', 199 | normalized_coord=False, 200 | num_cls=2) 201 | 202 | dataloader = torch.utils.data.DataLoader(dataset = dataset, 203 | batch_size = 10, 204 | shuffle = False, 205 | num_workers = 1, 206 | 
collate_fn = dataset.collate_fn) 207 | 208 | image, heatmap, offset, wh, mask, info = next(iter(dataloader)) 209 | 210 | image_pil = ten2pil(image, 'imagenet') 211 | image_pil.save('image.png') 212 | 213 | heatmap_pil_1 = ten2pil(heatmap[:, 0:1, :, :], None) 214 | heatmap_pil_2 = ten2pil(heatmap[:, 1:2, :, :], None) 215 | heatmap_pil_1.save('heatmap_0.png') 216 | heatmap_pil_2.save('heatmap_1.png') 217 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | from tqdm import tqdm 5 | from collections import defaultdict 6 | 7 | import torch 8 | import torchvision 9 | 10 | from data import load_dataset 11 | from train import load_network 12 | from utils import AverageMeter, save_pickle 13 | from transform import hm2box 14 | 15 | def single_device_evaluate(args): 16 | device = torch.device('cpu' if -1 in args.gpu_no else 'cuda') 17 | print('%s: Use %s for evaluation'%(time.ctime(), 'CPU' if -1 in args.gpu_no else 'GPU: %d'%args.gpu_no[0])) 18 | 19 | # load network 20 | network, _, _, _ = load_network(args, device) 21 | 22 | # perdict 23 | predictor = Prediction(network = network, 24 | topk = args.topk, 25 | scale_factor = args.scale_factor, 26 | conf_th = args.conf_th, 27 | nms = args.nms, 28 | nms_th = args.nms_th, 29 | normalized_coord = args.normalized_coord).to(device) 30 | 31 | # load dataset 32 | dataset = load_dataset(args) 33 | dataloader = torch.utils.data.DataLoader(dataset = dataset, 34 | batch_size = args.batch_size, 35 | shuffle = False, 36 | num_workers = args.num_workers, 37 | collate_fn = dataset.collate_fn) 38 | 39 | # evaluate 40 | predictions = evaluate_step(dataloader, predictor, device, args) 41 | 42 | # save the result 43 | save_pickle(os.path.join(args.save_path, 'prediction_results.pickle'), predictions) 44 | 45 | # make text file to measure mAP 46 | save_path = os.path.join(args.save_path, 'txt') 47 | os.makedirs(save_path, exist_ok=True) 48 | for filename, prediction in predictions.items(): 49 | cls_ids, scores, boxes = prediction[:, 0], prediction[:, 1], prediction[:, 2:] 50 | 51 | filename = os.path.splitext(filename)[0] + '.txt' 52 | with open(os.path.join(save_path, filename), 'w') as f: 53 | for i in range(cls_ids.shape[0]): 54 | f.write('%d %f %d %d %d %d\n'%(cls_ids[i], scores[i], boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3])) 55 | 56 | return None 57 | 58 | def evaluate_step(dataloader, predictor, device, args): 59 | time_logger = defaultdict(AverageMeter) 60 | 61 | prediction_results = {} 62 | 63 | predictor.eval() 64 | 65 | tictoc = time.time() 66 | for image, gt_heatmap, gt_offset, gt_size, gt_mask, gt_dict in tqdm(dataloader): 67 | time_logger['data'].update(time.time() - tictoc) 68 | 69 | tictoc = time.time() 70 | box_lst, cls_lst, score_lst = predictor(image.to(device)) 71 | time_logger['forward'].update(time.time() - tictoc) 72 | 73 | for b in range(image.shape[0]): 74 | boxes, clss, scores, gt_info = box_lst[b], cls_lst[b], score_lst[b], gt_dict[b] 75 | 76 | # resize to origin image size 77 | origin_size = int(gt_info['annotation']['size']['width']), int(gt_info['annotation']['size']['height']) 78 | resized_size = args.imsize, args.imsize 79 | _boxes = resize_box_to_original_scale(boxes.detach().cpu().numpy(), origin_size, resized_size) 80 | 81 | # make single array 82 | clss_np = clss.detach().cpu().numpy()[:, np.newaxis] 83 | scores_np = 
scores.detach().cpu().numpy()[:, np.newaxis] 84 | boxes_np = np.asarray(_boxes) 85 | 86 | # save prediction results 87 | if boxes_np.shape[0] != 0: 88 | pred = np.hstack([clss_np, scores_np, boxes_np]) 89 | else: 90 | pred = np.zeros((0, 6)) 91 | prediction_results[gt_info['annotation']['filename']] = pred 92 | 93 | tictoc = time.time() 94 | _log = '%s: Evaluation'%(time.ctime()) 95 | _log += ', Time(ms) [data: %6.2f'%(time_logger['data'].avg * 1000) 96 | _log += ', forward: %6.2f]'%(time_logger['forward'].avg * 1000) 97 | print(_log) 98 | 99 | return prediction_results 100 | 101 | def resize_box_to_original_scale(boxes, original_size, transformed_size): 102 | origin_width, origin_height = original_size 103 | trans_width, trans_height = transformed_size 104 | 105 | rw = origin_width / trans_width 106 | rh = origin_height / trans_height 107 | 108 | resized_boxes = [] 109 | for xmin, ymin, xmax, ymax in boxes: 110 | resized_boxes.append([xmin * rw, ymin * rh, xmax * rw, ymax * rh]) 111 | 112 | return resized_boxes 113 | 114 | class Prediction(torch.nn.Module): 115 | def __init__(self, network, topk, scale_factor, conf_th, nms, nms_th, normalized_coord=False): 116 | super(Prediction, self).__init__() 117 | 118 | self.network = network 119 | self.topk = topk 120 | self.scale_factor = scale_factor 121 | self.conf_th = conf_th 122 | self.nms = nms 123 | self.nms_th = nms_th 124 | self.normalized_coord = normalized_coord 125 | 126 | def forward(self, x): 127 | ''' x: input tensor (b, c, h, w) ''' 128 | box_lst, cls_lst, score_lst = [], [], [] 129 | 130 | batch_output = self.network(x) # b, n, num_cls+4, h, w 131 | num_cls = batch_output.size(2) - 4 132 | # batch-wise 133 | for outputs in batch_output.split(1, dim=0): 134 | # stack(scale)-wise 135 | stack_boxes, stack_clss, stack_scores = [], [], [] 136 | for output in outputs.split(1, dim=1): 137 | output.squeeze_(1) 138 | heatmap, offset, wh = output.split([num_cls,2,2], dim=1) 139 | heatmap = torch.sigmoid(heatmap) 140 | if self.normalized_coord: 141 | offset = torch.sigmoid(offset) 142 | wh = torch.sigmoid(wh) 143 | 144 | boxes, clss, scores = hm2box(heatmap = heatmap.squeeze_(0), 145 | offset = offset.squeeze_(0), 146 | wh = wh.squeeze_(0), 147 | scale_factor = self.scale_factor, 148 | topk = self.topk, 149 | conf_th = self.conf_th, 150 | normalized = self.normalized_coord) 151 | stack_boxes.append(boxes) 152 | stack_clss.append(clss) 153 | stack_scores.append(scores) 154 | 155 | # non maximum suppression 156 | boxes, clss, scores = self.nonmaximum_supression(torch.cat(stack_boxes, dim=0), 157 | torch.cat(stack_clss, dim=0), 158 | torch.cat(stack_scores, dim=0)) 159 | 160 | # append boxes per batch 161 | box_lst.append(boxes) 162 | cls_lst.append(clss) 163 | score_lst.append(scores) 164 | 165 | return box_lst, cls_lst, score_lst 166 | 167 | def nonmaximum_supression(self, boxes, clss, scores): 168 | ''' 169 | boxes: tensor (N, 4) 170 | clss: tensor (N) 171 | scores: tensor (N) 172 | ''' 173 | if self.nms == 'nms': 174 | unique_indices = torchvision.ops.nms(boxes, scores, self.nms_th) 175 | 176 | elif self.nms == 'soft-nms': 177 | unique_indices = soft_nms_pytorch(boxes, scores, thresh=self.conf_th) 178 | 179 | else: 180 | raise NotImplementedError('Not expected nms algorithm: %s'%self.nms) 181 | 182 | return boxes[unique_indices], clss[unique_indices], scores[unique_indices] 183 | 184 | def soft_nms_pytorch(dets, box_scores, sigma=0.5, thresh=0.001, cuda=0): 185 | """ Author: Richard Fang(github.com/DocF) 186 | Build a pytorch implement of 
Soft NMS algorithm. 187 | # Augments 188 | dets: boxes coordinate tensor (format:[x1, y1, x2, y2]) 189 | box_scores: box score tensors 190 | sigma: variance of Gaussian function 191 | thresh: score thresh 192 | cuda: CUDA flag 193 | # Return 194 | the index of the selected boxes 195 | """ 196 | 197 | # Indexes concatenate boxes with the last column 198 | N = dets.shape[0] 199 | if cuda: 200 | indexes = torch.arange(0, N, dtype=torch.float).cuda().view(N, 1) 201 | else: 202 | indexes = torch.arange(0, N, dtype=torch.float).view(N, 1) 203 | dets = torch.cat((dets, indexes), dim=1) 204 | 205 | # The order of boxes coordinate is [y1,x1,y2,x2] 206 | x1 = dets[:, 0] 207 | y1 = dets[:, 1] 208 | x2 = dets[:, 2] 209 | y2 = dets[:, 3] 210 | scores = box_scores 211 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 212 | 213 | for i in range(N): 214 | # intermediate parameters for later parameters exchange 215 | tscore = scores[i].clone() 216 | pos = i + 1 217 | 218 | if i != N - 1: 219 | maxscore, maxpos = torch.max(scores[pos:], dim=0) 220 | if tscore < maxscore: 221 | dets[i], dets[maxpos.item() + i + 1] = dets[maxpos.item() + i + 1].clone(), dets[i].clone() 222 | scores[i], scores[maxpos.item() + i + 1] = scores[maxpos.item() + i + 1].clone(), scores[i].clone() 223 | areas[i], areas[maxpos + i + 1] = areas[maxpos + i + 1].clone(), areas[i].clone() 224 | 225 | # IoU calculate 226 | xx1 = np.maximum(dets[i, 0].to("cpu").detach().numpy(), dets[pos:, 0].to("cpu").detach().numpy()) 227 | yy1 = np.maximum(dets[i, 1].to("cpu").detach().numpy(), dets[pos:, 1].to("cpu").detach().numpy()) 228 | xx2 = np.minimum(dets[i, 2].to("cpu").detach().numpy(), dets[pos:, 2].to("cpu").detach().numpy()) 229 | yy2 = np.minimum(dets[i, 3].to("cpu").detach().numpy(), dets[pos:, 3].to("cpu").detach().numpy()) 230 | 231 | w = np.maximum(0.0, xx2 - xx1 + 1) 232 | h = np.maximum(0.0, yy2 - yy1 + 1) 233 | inter = torch.tensor(w * h).cuda() if cuda else torch.tensor(w * h) 234 | ovr = torch.div(inter, (areas[i] + areas[pos:] - inter)) 235 | 236 | # Gaussian decay 237 | weight = torch.exp(-(ovr * ovr) / sigma) 238 | scores[pos:] = weight * scores[pos:] 239 | 240 | # select the boxes and keep the corresponding indexes 241 | keep = dets[:, 4][scores > thresh].long() 242 | 243 | return keep 244 | 245 | if __name__ == '__main__': 246 | from config import get_arguments 247 | from data import INDEX2CLASS, CLASS2COLOR 248 | from utils import imload, draw_box, write_text 249 | 250 | args = get_arguments() 251 | 252 | device = torch.device('cpu' if -1 in args.gpu_no else 'cuda') 253 | 254 | network, _, _, _ = load_network(args, device) 255 | predictor = Prediction(network = network, 256 | topk = args.topk, 257 | scale_factor = args.scale_factor, 258 | conf_th = args.conf_th, 259 | nms = args.nms, 260 | nms_th = args.nms_th, 261 | normalized_coord = args.normalized_coord).to(device) 262 | predictor.eval() 263 | 264 | # single image prediction 265 | img_ten, img_pil, origin_size = imload(args.data, args.pretrained, args.imsize) 266 | box_ten, cls_ten, score_ten = predictor(img_ten.to(device)) 267 | box_lst, cls_lst, score_lst = box_ten[0].tolist(), cls_ten[0].tolist(), score_ten[0].tolist() 268 | 269 | # clamp outside image 270 | box_lst = [list(map(lambda x: max(0, min(x, args.imsize)), box)) for box in box_lst] 271 | 272 | # draw box, class and score per prediction 273 | for i, (box, cls, score) in enumerate(zip(box_lst, cls_lst, score_lst)): 274 | img_pil = draw_box(img_pil, box, color=CLASS2COLOR[cls]) 275 | if args.fontsize > 0: 276 | text = 
'%s: %1.2f'%(INDEX2CLASS[cls], score) 277 | coord = [box[0], box[1]-args.fontsize] 278 | img_pil = write_text(img_pil, text, coord, fontsize=args.fontsize) 279 | 280 | # resize origin scale of image 281 | xmin, ymin, xmax, ymax = box 282 | xmin = xmin*origin_size[0]/args.imsize 283 | ymin = ymin*origin_size[1]/args.imsize 284 | xmax = xmax*origin_size[0]/args.imsize 285 | ymax = ymax*origin_size[0]/args.imsize 286 | 287 | print('%s: Index: %3d, Class: %7s, Score: %1.2f, Box: %4d, %4d, %4d, %4d'%(time.ctime(), i, INDEX2CLASS[cls], score, xmin, ymin, xmax, ymax)) 288 | 289 | # resize to origin size and save the result image 290 | img_pil.resize(origin_size).save(os.path.join(args.save_path, 'image.png')) 291 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | from transform import hm2box 6 | from utils import get_normalizer 7 | 8 | class Export(torch.nn.Module): 9 | def __init__(self, network, topk, scale_factor, conf_th, nms_th, normalized_coord=False): 10 | super(Export, self).__init__() 11 | 12 | self.network = network 13 | self.topk = topk 14 | self.scale_factor = scale_factor 15 | self.conf_th = conf_th 16 | self.nms_th = nms_th 17 | self.normalized_coord = normalized_coord 18 | 19 | def forward(self, x): 20 | # x: input tensor (1, c, h, w) 21 | # output: 1, n, 6, h, w 22 | batch_output = self.network(x) 23 | 24 | # batch-wise for loop 25 | for outputs in batch_output.split(1, dim=0): 26 | 27 | # stack-wise for loop 28 | stack_boxes, stack_clss, stack_scores = [], [], [] 29 | for output in outputs.split(1, dim=1): 30 | output.squeeze_(1) 31 | heatmap, offset, wh = output.split([2,2,2], dim=1) 32 | heatmap = torch.sigmoid(heatmap) 33 | if self.normalized_coord: 34 | offset = torch.sigmoid(offset) 35 | wh = torch.sigmoid(wh) 36 | 37 | boxes, clss, scores = hm2box(heatmap = heatmap.squeeze_(0), 38 | offset = offset.squeeze_(0), 39 | wh = wh.squeeze_(0), 40 | scale_factor = self.scale_factor, 41 | topk = self.topk, 42 | conf_th = self.conf_th, 43 | normalized = self.normalized_coord) 44 | stack_boxes.append(boxes) 45 | stack_clss.append(clss) 46 | stack_scores.append(scores) 47 | 48 | boxes = torch.cat(stack_boxes, dim=0) 49 | clss = torch.cat(stack_clss, dim=0) 50 | scores = torch.cat(stack_scores, dim=0) 51 | 52 | # non maximum suppression 53 | boxes, clss, scores = self.nms(boxes, clss, scores, self.nms_th) 54 | 55 | return boxes, clss, scores 56 | 57 | def nms(self, boxes, clss, scores, threshold): 58 | ''' 59 | boxes: tensor (N, 4) 60 | clss: tensor (N) 61 | scores: tensor (N) 62 | threshold: float 63 | ''' 64 | unique_indices = nms_pytorch(boxes, scores, threshold) 65 | 66 | return boxes[unique_indices], clss[unique_indices], scores[unique_indices] 67 | 68 | @torch.jit.script 69 | def nms_pytorch(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) -> torch.Tensor: 70 | indices = torch.argsort(scores, descending=True) 71 | _boxes = boxes[indices] 72 | _scores = scores[indices] 73 | 74 | for i in range(_boxes.shape[0]-1): 75 | if _scores[i] == 0: 76 | continue 77 | xmin, ymin, xmax, ymax = torch.split(_boxes[i], 1, 0) 78 | 79 | _xmin, _ymin, _xmax, _ymax = torch.split(_boxes[i+1:], 1, 1) 80 | 81 | # intersection area 82 | x1 = torch.max(xmin, _xmin) 83 | y1 = torch.max(ymin, _ymin) 84 | x2 = torch.min(xmax, _xmax) 85 | y2 = torch.min(ymax, _ymax) 86 | w = torch.clamp((x2 - x1 + 1), 
min=0) 87 | h = torch.clamp((y2 - y1 + 1), min=0) 88 | 89 | area = (xmax - xmin + 1) * (ymax - ymin + 1) 90 | _area = (_xmax - _xmin + 1) * (_ymax - _ymin + 1) 91 | overlap = w * h 92 | 93 | iou = overlap / (area + _area - overlap) 94 | 95 | _scores[i+1:] = _scores[i+1:] * (iou.squeeze() < threshold).float() 96 | 97 | return indices[_scores>0].long() 98 | 99 | if __name__ == '__main__': 100 | from train import load_network 101 | from config import build_parser 102 | 103 | do_test = False # for debug 104 | device = torch.device('cpu') 105 | 106 | args = build_parser() 107 | 108 | # load network 109 | network, _, _, _ = load_network(args, device) 110 | # perdict 111 | #pre = Preprocess() 112 | predictor = Export(network = network, 113 | topk = args.topk, 114 | scale_factor = args.scale_factor, 115 | conf_th = args.conf_th, 116 | nms_th = args.nms_th, 117 | normalized_coord = args.normalized_coord).to(device) 118 | predictor.eval() 119 | 120 | ##################### model save at cpu ##################### 121 | x = torch.randn(1, 3, 512, 512) 122 | traced_model_cpu = torch.jit.trace(predictor.cpu(), x.cpu()) 123 | torch.jit.save(traced_model_cpu, "jit_traced_model_cpu.pth") 124 | print("Model saved at cpu") 125 | 126 | ##################### model save at gpu ##################### 127 | x = torch.randn(1, 3, 512, 512) 128 | traced_model_cpu = torch.jit.trace(predictor.cuda(), x.cuda()) 129 | torch.jit.save(traced_model_cpu, "jit_traced_model_gpu.pth") 130 | print("Model saved at gpu") 131 | 132 | if do_test: 133 | import cv2 134 | normalizer = get_normalizer(pretrained=args.pretrained) 135 | x = cv2.cvtColor(cv2.imread('../0.jpg'), cv2.COLOR_BGR2RGB) 136 | x = cv2.resize(x, dsize=(512, 512), interpolation=cv2.INTER_AREA) 137 | x = torch.tensor(x) # x: HxWxC, 0.0 ~ 255.0 138 | x = x.permute(2, 0, 1)/255.0 139 | x = normalizer(x).unsqueeze(0) 140 | 141 | box_lst, cls_lst, score_lst = predictor(x.to(device)) 142 | for i in range(box_lst.shape[0]): 143 | print(', '.join(map(str, box_lst[i].tolist())), ',', cls_lst[i].item(), ',', score_lst[i].item()) 144 | 145 | ############ check the output of python and traced models ################ 146 | x = torch.ones(1, 3, 512, 512) 147 | box_lst, cls_lst, score_lst = predictor(x.to(device)) 148 | 149 | traced_model = torch.jit.trace(predictor, torch.randn(1, 3, 512, 512)) 150 | x = torch.ones(1, 3, 512, 512) 151 | box_lst2, cls_lst2, score_lst2 = traced_model(x) 152 | print('output python == output jit: ', torch.all(torch.eq(box_lst, box_lst2))) 153 | 154 | -------------------------------------------------------------------------------- /hourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Mish(nn.Module): 7 | def __init__(self): 8 | super(Mish, self).__init__() 9 | 10 | def forward(self, x): 11 | return x * (torch.tanh(F.softplus(x))) 12 | 13 | 14 | class Activation(nn.Module): 15 | def __init__(self, activation: str): 16 | super(Activation, self).__init__() 17 | 18 | if activation == 'ReLU': 19 | self.activation = nn.ReLU() 20 | 21 | elif activation == 'LReLU': 22 | self.activation = nn.LeakyReLU() 23 | 24 | elif activation == 'PReLU': 25 | self.activation = nn.PReLU() 26 | 27 | elif activation == 'Linear': 28 | self.activation = nn.Identity() 29 | 30 | elif activation == 'Mish': 31 | self.activation = Mish() 32 | 33 | elif activation == 'Sigmoid': 34 | self.activation = nn.Sigmoid() 35 | 36 | elif activation == 'CELU': 
37 | self.activation = nn.CELU() 38 | 39 | else: 40 | raise NotImplementedError("Not expected activation: %s"%activation) 41 | 42 | def forward(self, x): 43 | return self.activation(x) 44 | 45 | 46 | class SPP(nn.Module): 47 | # Convolutional SPP network 48 | # Reference: https://github.com/WongKinYiu/PyTorch_YOLOv4 49 | def __init__(self, ch=128, kernel_sizes=[5, 9, 13], stride=1): 50 | super(SPP, self).__init__() 51 | _ch = ch //2 52 | # convolution layers to deal with increased channels 53 | self.conv1 = nn.Conv2d(ch, _ch, 1, 1, bias=False) 54 | self.conv2 = nn.Conv2d(_ch*(len(kernel_sizes)+1), ch, 1, 1, bias=False) 55 | 56 | self.pooling_layers = nn.ModuleList() 57 | for kernel_size in kernel_sizes: 58 | self.pooling_layers.append(nn.MaxPool2d(kernel_size, stride, (kernel_size-1)//2)) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | y = [x] 63 | for pooling_layer in self.pooling_layers: 64 | y.append(pooling_layer(x)) 65 | return self.conv2(torch.cat(y, dim=1)) 66 | 67 | 68 | class Pool(nn.Module): 69 | def __init__(self, channel: int, pool: str): 70 | super(Pool, self).__init__() 71 | if pool == 'Max': 72 | self.pool = nn.MaxPool2d(2, 2) 73 | 74 | elif pool == 'Avg': 75 | self.pool = nn.AvgPool2d(2, 2) 76 | 77 | elif pool == 'Conv': 78 | self.pool = nn.Conv2d(channel, channel, kernel_size=2, stride=2) 79 | 80 | elif pool == 'SPP': 81 | # NOTE: SPP does not reduce the resolution. It's output has 4 times the number of input channels. 82 | self.pool = SPP(channel) 83 | 84 | elif pool == 'None': 85 | self.pool = nn.Identity() 86 | 87 | else: 88 | raise NotImplementedError("Not expected pool: %s"%pool) 89 | 90 | def forward(self, x): 91 | return self.pool(x) 92 | 93 | 94 | class Convolution(nn.Module): 95 | def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, stride: int = 1, bias: bool = True, bn: bool = False, activation: str = 'ReLU'): 96 | super(Convolution, self).__init__() 97 | 98 | self.activation = Activation(activation) 99 | 100 | self.convolution = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=(kernel_size-1)//2, bias=bias) 101 | 102 | if bn: 103 | self.bn = nn.BatchNorm2d(out_ch, affine=True, track_running_stats=True) 104 | else: 105 | self.bn = nn.Identity() 106 | 107 | def forward(self, x): 108 | return self.activation(self.bn(self.convolution(x))) 109 | 110 | 111 | class Residual(nn.Module): 112 | def __init__(self, in_ch: int, out_ch: int, kernel_size:int = 3, stride: int = 1, activation: str = 'ReLU'): 113 | super(Residual, self).__init__() 114 | 115 | self.activation = Activation(activation) 116 | 117 | self.conv1 = Convolution(in_ch, out_ch, kernel_size, stride, bias=False, bn=True, activation=activation) 118 | self.conv2 = Convolution(out_ch, out_ch, kernel_size, stride, bias=False, bn=True, activation='Linear') 119 | 120 | if in_ch != out_ch: 121 | self.skip = Convolution(in_ch, out_ch, kernel_size=1, stride=stride, bias=False, bn=True, activation='Linear') 122 | else: 123 | self.skip = nn.Identity() 124 | 125 | def forward(self, x): 126 | y = self.conv2(self.conv1(x)) 127 | return self.activation(y + self.skip(x)) 128 | 129 | 130 | class Hourglass(nn.Module): 131 | def __init__(self, num_layer: int, in_ch: int, increase_ch: int = 0, activation: str = 'ReLU', pool: str = 'Max'): 132 | super(Hourglass, self).__init__() 133 | mid_ch = in_ch + increase_ch 134 | 135 | self.up1 = Residual(in_ch, in_ch, activation=activation) 136 | self.pool1 = Pool(in_ch, pool=pool) 137 | _in_ch = in_ch * 4 if pool == 'SPP' else in_ch 138 | 139 | self.low1 = 
Residual(_in_ch, mid_ch, activation=activation) 140 | # initialize the hourglass layers recursively 141 | if num_layer > 1: 142 | self.low2 = Hourglass(num_layer-1, mid_ch, increase_ch, activation=activation, pool=pool) 143 | else: 144 | self.low2 = Residual(mid_ch, mid_ch, activation=activation) 145 | 146 | self.low3 = Residual(mid_ch, in_ch, activation=activation) 147 | self.up2 = nn.Upsample(scale_factor=2, mode='nearest') 148 | 149 | def forward(self, x): 150 | up1 = self.up1(x) 151 | pool1 = self.pool1(x) 152 | low1 = self.low1(pool1) 153 | low2 = self.low2(low1) 154 | low3 = self.low3(low2) 155 | up2 = self.up2(low3) 156 | return up1 + up2 157 | 158 | 159 | class PreLayer(nn.Module): 160 | def __init__(self, in_ch: int = 3, mid_ch: int = 128, out_ch: int = 5, activation: str = 'ReLU', pool: str = 'Max'): 161 | super(PreLayer, self).__init__() 162 | layers = [] 163 | layers.append(Convolution(in_ch=in_ch, out_ch=64, kernel_size=7, stride=2, bias=True, bn=True, activation=activation)) 164 | layers.append(Residual(in_ch=64, out_ch=mid_ch)) 165 | layers.append(Pool(channel=mid_ch, pool=pool)) 166 | _mid_ch = mid_ch * 4 if pool == 'SPP' else mid_ch 167 | layers.append(Residual(in_ch=_mid_ch, out_ch=mid_ch)) 168 | layers.append(Residual(in_ch=mid_ch, out_ch=out_ch)) 169 | 170 | self.layers = nn.Sequential(*layers) 171 | 172 | def forward(self, x): 173 | return self.layers(x) 174 | 175 | 176 | class Neck(nn.Module): 177 | def __init__(self, ch: int = 128, activation: str = 'ReLU', pool: str = 'None'): 178 | super(Neck, self).__init__() 179 | layers = [] 180 | layers.append(Pool(ch, pool)) 181 | layers.append(Convolution(in_ch=ch, out_ch=ch, kernel_size=1, bn=True, activation=activation)) 182 | layers.append(Residual(ch, ch)) 183 | self.layers = nn.Sequential(*layers) 184 | 185 | def forward(self, x): 186 | return self.layers(x) 187 | 188 | 189 | class Head(nn.Module): 190 | def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 1, stride: int = 1, bias: bool = True, bn: bool = False, activation: str = 'Linear'): 191 | super(Head, self).__init__() 192 | self.layer = Convolution(in_ch=in_ch, out_ch=out_ch, kernel_size=kernel_size, stride=stride, bias=bias, bn=bn, activation=activation) 193 | 194 | def forward(self, x): 195 | return self.layer(x) 196 | 197 | 198 | class StackedHourglass(nn.Module): 199 | def __init__(self, num_stack: int, in_ch: int, out_ch: int, increase_ch: int = 0, activation: str = 'ReLU', pool: str = 'Max', neck_activation: str = 'ReLU', neck_pool: str = 'None'): 200 | super(StackedHourglass, self).__init__() 201 | 202 | # downsample the resolution of input (1 --> 1/4(scale_factor)) 203 | self.pre_layer = PreLayer(in_ch=3, mid_ch=128, out_ch=in_ch, activation=activation, pool=pool) 204 | 205 | # hourglass modules (backbone) 206 | self.hourglass_lst = nn.ModuleList([Hourglass(num_layer=4, in_ch=in_ch, increase_ch=increase_ch, activation=activation, pool=pool) for _ in range(num_stack)]) 207 | 208 | # feature layer (neck) 209 | self.neck_lst = nn.ModuleList([Neck(in_ch, neck_activation, neck_pool) for _ in range(num_stack)]) 210 | 211 | # prediction layer (head) 212 | self.head_lst = nn.ModuleList([Head(in_ch=in_ch, out_ch=out_ch, kernel_size=1, stride=1, bias=True, bn=False, activation='Linear') for _ in range(num_stack)]) 213 | 214 | # merge intermediate hourglass features 215 | self.merge_feature = nn.ModuleList([Convolution(in_ch=in_ch, out_ch=in_ch, kernel_size=1, stride=1, bias=True, bn=False, activation='Linear') for _ in range(num_stack-1)]) 216 | 217 | # 
merger intermediate hourglass feature and prediction 218 | self.merge_prediction = nn.ModuleList([Convolution(in_ch=out_ch, out_ch=in_ch, kernel_size=1, stride=1, bias=True, bn=False, activation='Linear') for _ in range(num_stack-1)]) 219 | 220 | self.num_stack = num_stack 221 | 222 | 223 | def forward(self, x): 224 | x = self.pre_layer(x) 225 | 226 | intermediate_predictions = [] 227 | for i in range(len(self.hourglass_lst)): 228 | hg = self.hourglass_lst[i](x) 229 | feature = self.neck_lst[i](hg) 230 | prediction = self.head_lst[i](feature) 231 | 232 | intermediate_predictions.append(prediction) 233 | 234 | if i < len(self.hourglass_lst) - 1: 235 | x = x + self.merge_feature[i](feature) + self.merge_prediction[i](prediction) 236 | 237 | return torch.stack(intermediate_predictions, dim=1) 238 | 239 | 240 | if __name__ == '__main__': 241 | # Stacked Hourglass module test 242 | stacked_hourglass = StackedHourglass(num_stack=2, in_ch=128, out_ch=5, increase_ch=0, activation='ReLU', pool='Max', neck_activation='ReLU', neck_pool='None') 243 | print(stacked_hourglass) 244 | stacked_hourglass.eval() 245 | x = torch.randn(2, 3, 512, 512) 246 | out = stacked_hourglass(x) 247 | num_param = sum([params.numel() for params in stacked_hourglass.parameters()]) 248 | print('Stacked Hourglass (%d params) input: (%s), output: (%s)'%(num_param, x.shape, out.shape)) 249 | 250 | # test jit 251 | scripted_sh = torch.jit.trace(stacked_hourglass, x) 252 | x2 = torch.ones(1, 3, 512, 512) 253 | out1 = stacked_hourglass(x2) 254 | x2 = torch.ones(1, 3, 512, 512) 255 | out2 = scripted_sh(x2) 256 | print('Jit test: ', torch.all(torch.eq(out1, out2))) 257 | -------------------------------------------------------------------------------- /imgs/000019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Real_Time_Helmet_Detection/a806bf098990f1d454fc0d5027a95be0935ed3c7/imgs/000019.jpg -------------------------------------------------------------------------------- /imgs/000019/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyui592/Real_Time_Helmet_Detection/a806bf098990f1d454fc0d5027a95be0935ed3c7/imgs/000019/image.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils import load_pickle, save_pickle 5 | 6 | class LossCalculator(nn.Module): 7 | def __init__(self, hm_weight, offset_weight, size_weight, focal_alpha, focal_beta): 8 | super(LossCalculator, self).__init__() 9 | self.log = {'hm': [], 'offset': [], 'size': [], 'total': []} 10 | 11 | self.l1_criterion = NormedL1Loss() 12 | self.focal_criterion = FocalLoss(alpha=focal_alpha, beta=focal_beta) 13 | 14 | self.hm_weight = hm_weight 15 | self.offset_weight = offset_weight 16 | self.size_weight = size_weight 17 | 18 | def forward(self, phm, poff, psize, ghm, goff, gsize, mask): 19 | hm_loss = self.focal_criterion(phm, ghm, mask) 20 | offset_loss = self.l1_criterion(poff, goff, mask) 21 | size_loss = self.l1_criterion(psize, gsize, mask) 22 | 23 | total_loss = hm_loss * self.hm_weight + \ 24 | offset_loss * self.offset_weight + \ 25 | size_loss * self.size_weight 26 | 27 | self.log['hm'].append(hm_loss.item()) 28 | self.log['offset'].append(offset_loss.item()) 29 | self.log['size'].append(size_loss.item()) 30 | 
self.log['total'].append(total_loss.item()) 31 | 32 | return total_loss 33 | 34 | def get_log(self, length=100): 35 | log = [] 36 | for key in ['hm', 'offset', 'size', 'total']: 37 | if len(self.log[key]) < length: 38 | length = len(self.log[key]) 39 | log.append('%s: %5.2f'%(key, sum(self.log[key][-length:]) / length)) 40 | return ', '.join(log) 41 | 42 | class NormedL1Loss(nn.Module): 43 | def __init__(self): 44 | super(NormedL1Loss, self).__init__() 45 | 46 | def forward(self, pred, gt, mask): 47 | loss = torch.abs(pred * mask - gt * mask) 48 | loss = torch.sum(loss, dim=[1,2,3]).mean() 49 | num_pos = torch.sum(mask).clamp(1, 1e30) 50 | return loss / num_pos 51 | 52 | class FocalLoss(nn.Module): 53 | def __init__(self, alpha, beta): 54 | super(FocalLoss, self).__init__() 55 | self.alpha = alpha 56 | self.beta = beta 57 | 58 | def forward(self, pred, gt, mask, eps=1e-7): 59 | neg_inds = torch.ones_like(mask) - mask 60 | 61 | neg_weights = torch.pow(1 - gt, self.beta) 62 | 63 | pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, self.alpha) * mask 64 | neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, self.alpha) * neg_weights * neg_inds 65 | 66 | pos_loss = pos_loss.sum(dim=[1,2,3]).mean() 67 | neg_loss = neg_loss.sum(dim=[1,2,3]).mean() 68 | num_pos = mask.sum().clamp(1, 1e30) 69 | return -(pos_loss + neg_loss) / num_pos 70 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from datetime import timedelta 4 | 5 | from config import get_arguments 6 | from train import distributed_device_train 7 | from evaluate import single_device_evaluate 8 | 9 | if __name__ == '__main__': 10 | args = get_arguments() 11 | 12 | tictoc = time.time() 13 | if args.train_flag: 14 | distributed_device_train(args) 15 | else: 16 | single_device_evaluate(args) 17 | print('%s: Process is Done During %s'%(time.ctime(), str(timedelta(seconds=(time.time() - tictoc))))) 18 | -------------------------------------------------------------------------------- /optim.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | def get_optimizer(network, lr, lr_milestone, lr_gamma): 4 | optimizer = optim.Adam(network.parameters(), lr=lr) 5 | 6 | scheduler = None 7 | if lr_milestone is not None: 8 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer = optimizer, 9 | milestones = lr_milestone, 10 | gamma = lr_gamma) 11 | 12 | return optimizer, scheduler 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 5 | _libgcc_mutex=0.1=main 6 | alabaster=0.7.12=py37_0 7 | anaconda=2019.07=py37_0 8 | anaconda-client=1.7.2=py37_0 9 | anaconda-navigator=1.9.7=py37_0 10 | anaconda-project=0.8.3=py_0 11 | asn1crypto=0.24.0=py37_0 12 | astroid=2.2.5=py37_0 13 | astropy=3.2.1=py37h7b6447c_0 14 | atomicwrites=1.3.0=py37_1 15 | attrs=19.1.0=py37_1 16 | babel=2.7.0=py_0 17 | backcall=0.1.0=py37_0 18 | backports=1.0=py_2 19 | backports.functools_lru_cache=1.5=py_2 20 | backports.os=0.1.1=py37_0 21 | backports.shutil_get_terminal_size=1.0.0=py37_2 22 | backports.tempfile=1.0=py_1 23 | backports.weakref=1.0.post1=py_1 24 | 
beautifulsoup4=4.7.1=py37_1 25 | bitarray=0.9.3=py37h7b6447c_0 26 | bkcharts=0.2=py37_0 27 | blas=1.0=mkl 28 | bleach=3.1.0=py37_0 29 | blosc=1.16.3=hd408876_0 30 | bokeh=1.2.0=py37_0 31 | boto=2.49.0=py37_0 32 | bottleneck=1.2.1=py37h035aef0_1 33 | bzip2=1.0.8=h7b6447c_0 34 | ca-certificates=2019.5.15=0 35 | cairo=1.14.12=h8948797_3 36 | certifi=2019.6.16=py37_0 37 | cffi=1.12.3=py37h2e261b9_0 38 | chardet=3.0.4=py37_1 39 | click=7.0=py37_0 40 | cloudpickle=1.2.1=py_0 41 | clyent=1.2.2=py37_1 42 | colorama=0.4.1=py37_0 43 | conda=4.7.10=py37_0 44 | conda-build=3.18.8=py37_0 45 | conda-env=2.6.0=1 46 | conda-package-handling=1.3.11=py37_0 47 | conda-verify=3.4.2=py_1 48 | contextlib2=0.5.5=py37_0 49 | cryptography=2.7=py37h1ba5d50_0 50 | curl=7.65.2=hbc83047_0 51 | cycler=0.10.0=py37_0 52 | cython=0.29.12=py37he6710b0_0 53 | cytoolz=0.10.0=py37h7b6447c_0 54 | dask=2.1.0=py_0 55 | dask-core=2.1.0=py_0 56 | dbus=1.13.6=h746ee38_0 57 | decorator=4.4.0=py37_1 58 | defusedxml=0.6.0=py_0 59 | distributed=2.1.0=py_0 60 | docutils=0.14=py37_0 61 | entrypoints=0.3=py37_0 62 | et_xmlfile=1.0.1=py37_0 63 | expat=2.2.6=he6710b0_0 64 | fastcache=1.1.0=py37h7b6447c_0 65 | filelock=3.0.12=py_0 66 | flask=1.1.1=py_0 67 | fontconfig=2.13.0=h9420a91_0 68 | freetype=2.9.1=h8a8886c_1 69 | fribidi=1.0.5=h7b6447c_0 70 | future=0.17.1=py37_0 71 | get_terminal_size=1.0.0=haa9412d_0 72 | gevent=1.4.0=py37h7b6447c_0 73 | glib=2.56.2=hd408876_0 74 | glob2=0.7=py_0 75 | gmp=6.1.2=h6c8ec71_1 76 | gmpy2=2.0.8=py37h10f8cd9_2 77 | graphite2=1.3.13=h23475e2_0 78 | greenlet=0.4.15=py37h7b6447c_0 79 | gst-plugins-base=1.14.0=hbbd80ab_1 80 | gstreamer=1.14.0=hb453b48_1 81 | h5py=2.9.0=py37h7918eee_0 82 | harfbuzz=1.8.8=hffaf4a1_0 83 | hdf5=1.10.4=hb1b8bf9_0 84 | heapdict=1.0.0=py37_2 85 | html5lib=1.0.1=py37_0 86 | icu=58.2=h9c2bf20_1 87 | idna=2.8=py37_0 88 | imageio=2.5.0=py37_0 89 | imagesize=1.1.0=py37_0 90 | importlib_metadata=0.17=py37_1 91 | intel-openmp=2019.4=243 92 | ipykernel=5.1.1=py37h39e3cac_0 93 | ipython=7.6.1=py37h39e3cac_0 94 | ipython_genutils=0.2.0=py37_0 95 | ipywidgets=7.5.0=py_0 96 | isort=4.3.21=py37_0 97 | itsdangerous=1.1.0=py37_0 98 | jbig=2.1=hdba287a_0 99 | jdcal=1.4.1=py_0 100 | jedi=0.13.3=py37_0 101 | jeepney=0.4=py37_0 102 | jinja2=2.10.1=py37_0 103 | joblib=0.13.2=py37_0 104 | jpeg=9b=h024ee3a_2 105 | json5=0.8.4=py_0 106 | jsonschema=3.0.1=py37_0 107 | jupyter=1.0.0=py37_7 108 | jupyter_client=5.3.1=py_0 109 | jupyter_console=6.0.0=py37_0 110 | jupyter_core=4.5.0=py_0 111 | jupyterlab=1.0.2=py37hf63ae98_0 112 | jupyterlab_server=1.0.0=py_0 113 | keyring=18.0.0=py37_0 114 | kiwisolver=1.1.0=py37he6710b0_0 115 | krb5=1.16.1=h173b8e3_7 116 | lazy-object-proxy=1.4.1=py37h7b6447c_0 117 | libarchive=3.3.3=h5d8350f_5 118 | libcurl=7.65.2=h20c2e04_0 119 | libedit=3.1.20181209=hc058e9b_0 120 | libffi=3.2.1=hd88cf55_4 121 | libgcc-ng=9.1.0=hdf63c60_0 122 | libgfortran-ng=7.3.0=hdf63c60_0 123 | liblief=0.9.0=h7725739_2 124 | libpng=1.6.37=hbc83047_0 125 | libsodium=1.0.16=h1bed415_0 126 | libssh2=1.8.2=h1ba5d50_0 127 | libstdcxx-ng=9.1.0=hdf63c60_0 128 | libtiff=4.0.10=h2733197_2 129 | libtool=2.4.6=h7b6447c_5 130 | libuuid=1.0.3=h1bed415_2 131 | libxcb=1.13=h1bed415_1 132 | libxml2=2.9.9=hea5a465_1 133 | libxslt=1.1.33=h7d1a2b0_0 134 | llvmlite=0.29.0=py37hd408876_0 135 | locket=0.2.0=py37_1 136 | lxml=4.3.4=py37hefd8a0e_0 137 | lz4-c=1.8.1.2=h14c3975_0 138 | lzo=2.10=h49e0be7_2 139 | markupsafe=1.1.1=py37h7b6447c_0 140 | matplotlib=3.1.0=py37h5429711_0 141 | mccabe=0.6.1=py37_1 142 | 
mistune=0.8.4=py37h7b6447c_0 143 | mkl=2019.4=243 144 | mkl-service=2.0.2=py37h7b6447c_0 145 | mkl_fft=1.0.12=py37ha843d7b_0 146 | mkl_random=1.0.2=py37hd81dba3_0 147 | mock=3.0.5=py37_0 148 | more-itertools=7.0.0=py37_0 149 | mpc=1.1.0=h10f8cd9_1 150 | mpfr=4.0.1=hdf1c602_3 151 | mpmath=1.1.0=py37_0 152 | msgpack-python=0.6.1=py37hfd86e86_1 153 | multipledispatch=0.6.0=py37_0 154 | navigator-updater=0.2.1=py37_0 155 | nbconvert=5.5.0=py_0 156 | nbformat=4.4.0=py37_0 157 | ncurses=6.1=he6710b0_1 158 | networkx=2.3=py_0 159 | nltk=3.4.4=py37_0 160 | nose=1.3.7=py37_2 161 | notebook=6.0.0=py37_0 162 | numba=0.44.1=py37h962f231_0 163 | numexpr=2.6.9=py37h9e4a6bb_0 164 | numpy=1.16.4=py37h7e9f1db_0 165 | numpy-base=1.16.4=py37hde5b4d6_0 166 | numpydoc=0.9.1=py_0 167 | olefile=0.46=py37_0 168 | opencv-python=4.2.0.34=pypi_0 169 | openpyxl=2.6.2=py_0 170 | openssl=1.1.1c=h7b6447c_1 171 | packaging=19.0=py37_0 172 | pandas=0.24.2=py37he6710b0_0 173 | pandoc=2.2.3.2=0 174 | pandocfilters=1.4.2=py37_1 175 | pango=1.42.4=h049681c_0 176 | parso=0.5.0=py_0 177 | partd=1.0.0=py_0 178 | patchelf=0.9=he6710b0_3 179 | path.py=12.0.1=py_0 180 | pathlib2=2.3.4=py37_0 181 | patsy=0.5.1=py37_0 182 | pcre=8.43=he6710b0_0 183 | pep8=1.7.1=py37_0 184 | pexpect=4.7.0=py37_0 185 | pickleshare=0.7.5=py37_0 186 | pillow=6.1.0=py37h34e0f95_0 187 | pip=19.1.1=py37_0 188 | pixman=0.38.0=h7b6447c_0 189 | pkginfo=1.5.0.1=py37_0 190 | pluggy=0.12.0=py_0 191 | ply=3.11=py37_0 192 | prometheus_client=0.7.1=py_0 193 | prompt_toolkit=2.0.9=py37_0 194 | psutil=5.6.3=py37h7b6447c_0 195 | ptyprocess=0.6.0=py37_0 196 | py=1.8.0=py37_0 197 | py-lief=0.9.0=py37h7725739_2 198 | pycodestyle=2.5.0=py37_0 199 | pycosat=0.6.3=py37h14c3975_0 200 | pycparser=2.19=py37_0 201 | pycrypto=2.6.1=py37h14c3975_9 202 | pycurl=7.43.0.3=py37h1ba5d50_0 203 | pyflakes=2.1.1=py37_0 204 | pygments=2.4.2=py_0 205 | pylint=2.3.1=py37_0 206 | pyodbc=4.0.26=py37he6710b0_0 207 | pyopenssl=19.0.0=py37_0 208 | pyparsing=2.4.0=py_0 209 | pyqt=5.9.2=py37h05f1152_2 210 | pyrsistent=0.14.11=py37h7b6447c_0 211 | pysocks=1.7.0=py37_0 212 | pytables=3.5.2=py37h71ec239_1 213 | pytest=5.0.1=py37_0 214 | pytest-arraydiff=0.3=py37h39e3cac_0 215 | pytest-astropy=0.5.0=py37_0 216 | pytest-doctestplus=0.3.0=py37_0 217 | pytest-openfiles=0.3.2=py37_0 218 | pytest-remotedata=0.3.1=py37_0 219 | python=3.7.3=h0371630_0 220 | python-dateutil=2.8.0=py37_0 221 | python-libarchive-c=2.8=py37_11 222 | pytz=2019.1=py_0 223 | pywavelets=1.0.3=py37hdd07704_1 224 | pyyaml=5.1.1=py37h7b6447c_0 225 | pyzmq=18.0.0=py37he6710b0_0 226 | qt=5.9.7=h5867ecd_1 227 | qtawesome=0.5.7=py37_1 228 | qtconsole=4.5.1=py_0 229 | qtpy=1.8.0=py_0 230 | readline=7.0=h7b6447c_5 231 | requests=2.22.0=py37_0 232 | rope=0.14.0=py_0 233 | ruamel_yaml=0.15.46=py37h14c3975_0 234 | scikit-image=0.15.0=py37he6710b0_0 235 | scikit-learn=0.21.2=py37hd81dba3_0 236 | scipy=1.3.0=py37h7c811a0_0 237 | seaborn=0.9.0=py37_0 238 | secretstorage=3.1.1=py37_0 239 | send2trash=1.5.0=py37_0 240 | setuptools=41.0.1=py37_0 241 | simplegeneric=0.8.1=py37_2 242 | singledispatch=3.4.0.3=py37_0 243 | sip=4.19.8=py37hf484d3e_0 244 | six=1.12.0=py37_0 245 | snappy=1.1.7=hbae5bb6_3 246 | snowballstemmer=1.9.0=py_0 247 | sortedcollections=1.1.2=py37_0 248 | sortedcontainers=2.1.0=py37_0 249 | soupsieve=1.8=py37_0 250 | sphinx=2.1.2=py_0 251 | sphinxcontrib=1.0=py37_1 252 | sphinxcontrib-applehelp=1.0.1=py_0 253 | sphinxcontrib-devhelp=1.0.1=py_0 254 | sphinxcontrib-htmlhelp=1.0.2=py_0 255 | sphinxcontrib-jsmath=1.0.1=py_0 256 | 
sphinxcontrib-qthelp=1.0.2=py_0 257 | sphinxcontrib-serializinghtml=1.1.3=py_0 258 | sphinxcontrib-websupport=1.1.2=py_0 259 | spyder=3.3.6=py37_0 260 | spyder-kernels=0.5.1=py37_0 261 | sqlalchemy=1.3.5=py37h7b6447c_0 262 | sqlite=3.29.0=h7b6447c_0 263 | statsmodels=0.10.0=py37hdd07704_0 264 | sympy=1.4=py37_0 265 | tblib=1.4.0=py_0 266 | terminado=0.8.2=py37_0 267 | testpath=0.4.2=py37_0 268 | tk=8.6.8=hbc83047_0 269 | toolz=0.10.0=py_0 270 | tornado=6.0.3=py37h7b6447c_0 271 | tqdm=4.32.1=py_0 272 | traitlets=4.3.2=py37_0 273 | unicodecsv=0.14.1=py37_0 274 | unixodbc=2.3.7=h14c3975_0 275 | urllib3=1.24.2=py37_0 276 | wcwidth=0.1.7=py37_0 277 | webencodings=0.5.1=py37_1 278 | werkzeug=0.15.4=py_0 279 | wheel=0.33.4=py37_0 280 | widgetsnbextension=3.5.0=py37_0 281 | wrapt=1.11.2=py37h7b6447c_0 282 | wurlitzer=1.0.2=py37_0 283 | xlrd=1.2.0=py37_0 284 | xlsxwriter=1.1.8=py_0 285 | xlwt=1.3.0=py37_0 286 | xz=5.2.4=h14c3975_4 287 | yaml=0.1.7=had09818_2 288 | zeromq=4.3.1=he6710b0_3 289 | zict=1.0.0=py_0 290 | zipp=0.5.1=py_0 291 | zlib=1.2.11=h7b6447c_3 292 | zstd=1.3.7=h0b5b093_0 293 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import defaultdict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | import torch.utils.data.distributed 12 | from torch.cuda.amp import autocast, GradScaler 13 | 14 | from hourglass import StackedHourglass 15 | from loss import LossCalculator 16 | from optim import get_optimizer 17 | from data import load_dataset 18 | from utils import AverageMeter, blend_heatmap 19 | 20 | from torchsummary import summary 21 | 22 | 23 | def distributed_device_train(args): 24 | ngpus_per_node = torch.cuda.device_count() 25 | 26 | args.world_size = ngpus_per_node * args.world_size 27 | 28 | mp.spawn(distributed_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 29 | 30 | return None 31 | 32 | def distributed_worker(device, ngpus_per_node, args): 33 | torch.cuda.set_device(device) 34 | cudnn.benchmark = True 35 | print('%s: Use GPU: %d for training'%(time.ctime(), args.gpu_no[device])) 36 | 37 | rank = args.rank * ngpus_per_node + device 38 | batch_size = int(args.batch_size / ngpus_per_node) 39 | num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node) 40 | 41 | # init process for distributed training 42 | dist.init_process_group(backend = args.dist_backend, 43 | init_method = args.dist_url, 44 | world_size = args.world_size, 45 | rank = rank) 46 | 47 | # load network 48 | network, optimizer, scheduler, loss_calculator = load_network(args, device) 49 | if device == 0: 50 | summary(network, input_size=(3, 512, 512)) 51 | 52 | # load dataset 53 | dataset = load_dataset(args) 54 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 55 | dataloader = torch.utils.data.DataLoader(dataset = dataset, 56 | batch_size = batch_size, 57 | num_workers = num_workers, 58 | pin_memory = True, 59 | sampler = sampler, 60 | collate_fn = dataset.collate_fn) 61 | 62 | # gradient scaler for automatic mixed precision 63 | scaler = GradScaler() if args.amp else None 64 | 65 | # training 66 | for epoch in range(args.start_epoch, args.end_epoch): 67 | sampler.set_epoch(epoch) 68 | 69 | # train one epoch 70 | train_step(dataloader, network, 
loss_calculator, optimizer, scheduler, scaler, epoch, device, args) 71 | 72 | # adjust learning rate 73 | scheduler.step() 74 | 75 | # save network 76 | if rank % ngpus_per_node == 0: 77 | torch.save({'epoch': epoch+1, 78 | 'state_dict': network.module.state_dict() if hasattr(network, 'module') else network.state_dict(), 79 | 'optimizer': optimizer.state_dict(), 80 | 'scheduler': scheduler.state_dict(), 81 | 'scaler': scaler.state_dict() if scaler is not None else None, 82 | 'loss_log': loss_calculator.log}, os.path.join(args.save_path, 'check_point_%d.pth'%(epoch+1))) 83 | 84 | return None 85 | 86 | def train_step(dataloader, network, loss_calculator, optimizer, scheduler, scaler, epoch, device, args): 87 | time_logger = defaultdict(AverageMeter) 88 | 89 | network.train() 90 | 91 | tictoc = time.time() 92 | for iteration, (image, gt_heatmap, gt_offset, gt_size, gt_mask, gt_dict) in enumerate(dataloader, 1): 93 | time_logger['data'].update(time.time() - tictoc) 94 | 95 | # forward 96 | autocast_flag = True if scaler is not None else False 97 | with autocast(enabled=autocast_flag): 98 | tictoc = time.time() 99 | outputs = network(image.to(device)) 100 | time_logger['forward'].update(time.time() - tictoc) 101 | 102 | ## calculate losses per scales 103 | tictoc = time.time() 104 | total_loss = 0 105 | for output in outputs.split(1, dim=1): 106 | output.squeeze_(1) 107 | pred_heatmap, pred_offset, pred_size = output.split([args.num_cls, 2, 2], dim=1) 108 | pred_heatmap = torch.sigmoid(pred_heatmap) 109 | if args.normalized_coord: 110 | pred_offset = torch.sigmoid(pred_offset) 111 | pred_size = torch.sigmoid(pred_size) 112 | 113 | _total_loss = loss_calculator(pred_heatmap, 114 | pred_offset, 115 | pred_size, 116 | gt_heatmap.to(device), 117 | gt_offset.to(device), 118 | gt_size.to(device), 119 | gt_mask.to(device)) 120 | total_loss += _total_loss 121 | time_logger['loss'].update(time.time() - tictoc) 122 | 123 | # gradient accumulation 124 | optimizatoin_flag = (iteration % args.sub_divisions == 0) or (iteration == len(dataloader)) 125 | 126 | # backward 127 | tictoc = time.time() 128 | if scaler is not None: 129 | scaler.scale(total_loss).backward() 130 | if optimizatoin_flag: 131 | scaler.step(optimizer) 132 | scaler.update() 133 | else: 134 | total_loss.backward() 135 | if optimizatoin_flag: 136 | optimizer.step() 137 | 138 | if optimizatoin_flag: 139 | optimizer.zero_grad() 140 | time_logger['backward'].update(time.time() - tictoc) 141 | 142 | # loging 143 | if (iteration % args.print_interval == 0) and (device == 0): 144 | loss_log = loss_calculator.get_log() 145 | _log = '%s: Epoch [%2d/%2d]'%(time.ctime(), epoch, args.end_epoch) 146 | _log += ', Iteration [%4d/%4d]'%(iteration, len(dataloader)) 147 | _log += ', Loss [%s]'%(loss_log) 148 | _log += ', Time(ms) [data: %6.2f'%(time_logger['data'].avg * 1000) 149 | _log += ', forward: %6.2f'%(time_logger['forward'].avg * 1000) 150 | _log += ', backward: %6.2f'%(time_logger['backward'].avg * 1000) 151 | _log += ', loss: %6.2f]'%(time_logger['loss'].avg * 1000) 152 | print(_log) 153 | 154 | # save blended image 155 | blended_pred = blend_heatmap(image[0], pred_heatmap[0], args.pretrained) 156 | blended_gt = blend_heatmap(image[0], gt_heatmap[0], args.pretrained) 157 | blended_pred.save(os.path.join(args.save_path, 'training_log', 'training_pred.png')) 158 | blended_gt.save(os.path.join(args.save_path, 'training_log', 'training_gt.png')) 159 | 160 | tictoc = time.time() 161 | 162 | return None 163 | 164 | def load_network(args, device): 165 | 
network = StackedHourglass(num_stack = args.num_stack, 166 | in_ch = args.hourglass_inch, 167 | out_ch = args.num_cls+4, 168 | increase_ch = args.increase_ch, 169 | activation = args.activation, 170 | pool = args.pool, 171 | neck_activation = args.neck_activation, 172 | neck_pool = args.neck_pool).to(device) 173 | 174 | if len(args.gpu_no) > 1 and args.train_flag: 175 | network = torch.nn.parallel.DistributedDataParallel(network, device_ids=[device]) 176 | 177 | optimizer, scheduler, loss_calculator = None, None, None 178 | if args.train_flag: 179 | optimizer, scheduler = get_optimizer(network = network, 180 | lr = args.lr, 181 | lr_milestone = args.lr_milestone, 182 | lr_gamma = args.lr_gamma) 183 | 184 | loss_calculator = LossCalculator(hm_weight = args.hm_weight, 185 | offset_weight = args.offset_weight, 186 | size_weight = args.size_weight, 187 | focal_alpha = args.focal_alpha, 188 | focal_beta = args.focal_beta).to(device) 189 | 190 | if args.model_load: 191 | check_point = torch.load(args.model_load, map_location=device) 192 | network.load_state_dict(check_point['state_dict']) 193 | print('%s: Weights are loaded from %s'%(time.ctime(), args.model_load)) 194 | 195 | if args.train_flag: 196 | optimizer.load_state_dict(check_point['optimizer']) 197 | loss_calculator.log = check_point['loss_log'] 198 | if scheduler is not None: 199 | scheduler.load_state_dict(check_point['scheduler']) 200 | 201 | return network, optimizer, scheduler, loss_calculator 202 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def box2hm(boxes, labels, imsize, scale_factor=4, num_cls=2, normalized=False): 5 | width, height = imsize[0]//scale_factor, imsize[1]//scale_factor 6 | heat_map = np.zeros((num_cls, height, width), dtype=np.float32) 7 | offset_map = np.zeros((2, height, width), dtype=np.float32) 8 | size_map = np.zeros((2, height, width), dtype=np.float32) 9 | mask = np.zeros((1, height, width), dtype=np.float32) 10 | 11 | if boxes is None: 12 | return heat_map, offset_map, size_map, mask 13 | 14 | for box, label in zip(boxes, labels): 15 | if box is None: 16 | continue 17 | # sclae change (image to feature map) 18 | xmin, ymin, xmax, ymax = [val/scale_factor for val in box] 19 | 20 | # center point 21 | xcen, ycen = (xmax+xmin)/2, (ymax+ymin)/2 22 | 23 | # index of heat map 24 | xind, yind = int(xcen), int(ycen) 25 | 26 | # set mask 27 | mask[:, yind, xind] = 1.0 28 | 29 | # offset, size 30 | xoff, yoff = xcen - xind, ycen - yind 31 | xsize, ysize = xmax - xmin, ymax - ymin 32 | 33 | if normalized: 34 | xoff, yoff = xoff/scale_factor, yoff/scale_factor 35 | xsize, ysize = xsize/width, ysize/height 36 | 37 | # assign offset, size and confidence 38 | offset_map[:, yind, xind] = np.array([xoff, yoff]) 39 | size_map[:, yind, xind] = np.array([xsize, ysize]) 40 | 41 | # heatmap 42 | radius = ((xcen-xmin)**2 + (ycen-ymin)**2)**0.5 43 | draw_gaussian(heat_map[label], (xind, yind), radius) 44 | 45 | return heat_map, offset_map, size_map, mask 46 | 47 | 48 | def gaussian2D(shape, sigma=1): 49 | m, n = list(map(int, shape)) 50 | y, x = np.ogrid[-m:m+1,-n:n+1] 51 | 52 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 53 | return h 54 | 55 | 56 | def draw_gaussian(heatmap, center, radius): 57 | gaussian = gaussian2D((radius, radius), sigma=radius/3) 58 | radius = int(radius) 59 | 60 | x, y = center 61 | 62 | height, width = 
heatmap.shape[0:2] 63 | 64 | left, right = min(x, radius), min(width - x, radius + 1) 65 | top, bottom = min(y, radius), min(height - y, radius + 1) 66 | 67 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 68 | masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right] 69 | 70 | np.maximum(masked_heatmap, masked_gaussian, out=masked_heatmap) 71 | 72 | 73 | def hm2box(heatmap, offset, wh, scale_factor=4, topk=10, conf_th=0.3, normalized=False): 74 | height, width = heatmap.shape[-2:] 75 | 76 | max_pool = torch.nn.MaxPool2d(3, stride=1, padding=3//2) 77 | 78 | isPeak = max_pool(heatmap) == heatmap 79 | peakmap = heatmap * isPeak 80 | 81 | scores, indices = peakmap.flatten().topk(topk) 82 | 83 | clss = torch.floor_divide(indices, (height*width)) 84 | inds = torch.fmod(indices, (height*width)) 85 | yinds = torch.floor_divide(inds, width) 86 | xinds = torch.fmod(inds, width) 87 | 88 | xoffs = offset[0, yinds, xinds] 89 | xsizs = wh[0, yinds, xinds] 90 | 91 | yoffs = offset[1, yinds, xinds] 92 | ysizs = wh[1, yinds, xinds] 93 | 94 | if normalized: 95 | xoffs = xoffs * scale_factor 96 | yoffs = yoffs * scale_factor 97 | xsizs = xsizs * width 98 | ysizs = ysizs * height 99 | 100 | xmin = (xinds + xoffs - xsizs/2) * scale_factor 101 | ymin = (yinds + yoffs - ysizs/2) * scale_factor 102 | xmax = (xinds + xoffs + xsizs/2) * scale_factor 103 | ymax = (yinds + yoffs + ysizs/2) * scale_factor 104 | 105 | boxes = torch.stack([xmin, ymin, xmax, ymax], dim=1) # Tensor: topk x 4 106 | 107 | # confidence thresholding 108 | over_threshold = scores >= conf_th 109 | 110 | return boxes[over_threshold], clss[over_threshold], scores[over_threshold] 111 | 112 | if __name__ == '__main__': 113 | boxes = [[10, 20, 100, 200]] 114 | labels = [1] 115 | imsize = 512, 512 116 | normalized = True 117 | print('original box: ', boxes) 118 | 119 | # numpy 120 | heatmap_np, offset_np, wh_np, mask = box2hm(boxes, labels, imsize, normalized=normalized) 121 | print('Shape of heatmap: ', heatmap_np.shape) 122 | print('Value of heatmap: ', heatmap_np[:, 27, 13]) 123 | print('Value of offset: ', offset_np[:, 27, 13]) 124 | print('Value of wh: ', wh_np[:, 27, 13]) 125 | 126 | # numpy to tensor 127 | heatmap_ten = torch.from_numpy(heatmap_np) 128 | offset_ten = torch.from_numpy(offset_np) 129 | wh_ten = torch.from_numpy(wh_np) 130 | _boxes, _labels, _scores = hm2box(heatmap_ten, offset_ten, wh_ten, normalized=normalized) 131 | print('re-calculated box: ', _boxes.tolist()) 132 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from PIL import Image, ImageDraw, ImageFont 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.transforms.functional as TF 8 | 9 | def save_pickle(path, data): 10 | with open(path, 'wb') as f: 11 | pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) 12 | return None 13 | 14 | def load_pickle(path): 15 | with open(path, 'rb') as f: 16 | data = pickle.load(f) 17 | return data 18 | 19 | class AverageMeter: 20 | def __init__(self): 21 | self.reset() 22 | 23 | def reset(self): 24 | self.avg = 0 25 | self.sum = 0 26 | self.count = 0 27 | 28 | def update(self, val, n = 1): 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | def ten2pil(tensor, pretrained): 34 | if pretrained is None: 35 | denormalize = lambda x: x 36 | else: 37 | denormalize = 
get_normalizer(denormalize = True, pretrained = pretrained) 38 | if tensor.is_cuda: 39 | tensor = tensor.cpu() 40 | tensor = torchvision.utils.make_grid(tensor, pad_value=0.5) 41 | image = TF.to_pil_image(denormalize(tensor).clamp_(0.0, 1.0)) 42 | return image 43 | 44 | def draw_box(pil, box, width=2, color=(0, 0, 255)): 45 | draw = ImageDraw.Draw(pil) 46 | draw.rectangle(list(map(int, box)), width=width, outline=color, fill=None) 47 | return pil 48 | 49 | def write_text(pil, text, coordinate, fontsize=15, fontcolor='red'): 50 | draw = ImageDraw.Draw(pil) 51 | font = ImageFont.truetype('arial.ttf', size=fontsize) 52 | draw.text(coordinate, text, fill=fontcolor, font=font) 53 | return pil 54 | 55 | def get_normalizer(pretrained, denormalize = False): 56 | if pretrained.lower() == "imagenet": 57 | MEAN = [0.485, 0.456, 0.406] 58 | STD = [0.229, 0.224, 0.225] 59 | elif pretrained.lower() == "scratch": 60 | MEAN = [0.5, 0.5, 0.5] 61 | STD = [0.5, 0.5, 0.5] 62 | else: 63 | raise NotImplementedError("Not expected dataset pretrained parameter: %s"%pretrained) 64 | 65 | if denormalize: 66 | MEAN = [-mean/std for mean, std in zip(MEAN, STD)] 67 | STD = [1/std for std in STD] 68 | return torchvision.transforms.Normalize(mean=MEAN, std=STD) 69 | 70 | def blend_heatmap(image, heatmap, pretrained): 71 | image_pil = ten2pil(image.detach().cpu(), pretrained=pretrained) 72 | 73 | for c in range(heatmap.shape[0]): 74 | heatmap_rgb = [np.zeros(heatmap.shape[1:], dtype=np.uint8)]*2 75 | 76 | _heatmap = heatmap[c] 77 | _heatmap_np = _heatmap.detach().cpu().numpy() * 255 78 | _heatmap_np = _heatmap_np.astype(np.uint8) 79 | 80 | # gray to rgb 81 | heatmap_rgb.insert(c, _heatmap_np) 82 | 83 | heatmap_pil = Image.fromarray(np.stack(heatmap_rgb, axis=-1)).resize(image_pil.size).convert('RGB') 84 | image_pil = Image.blend(image_pil, heatmap_pil, 0.3) 85 | return image_pil 86 | 87 | def imload(path, pretrained, size=None): 88 | img_pil = Image.open(path).convert('RGB') 89 | origin_size = img_pil.size 90 | if size: 91 | img_pil = img_pil.resize((size, size)) 92 | normalizer = get_normalizer(pretrained=pretrained) 93 | img_ten = normalizer(TF.to_tensor(img_pil)).unsqueeze(0) 94 | return img_ten, img_pil, origin_size 95 | --------------------------------------------------------------------------------
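
For orientation, the sketch below strings the pieces above together for a single image: it rebuilds the detector with the `StackedHourglass` arguments used in `load_network`, loads a checkpoint saved by `train.py`, and decodes boxes with `hm2box` from `transform.py`. The concrete hyperparameter values (`num_stack=2`, `in_ch=256`, `increase_ch=0`, the activation/pool choices), the `imagenet` normalization, and the omission of NMS are assumptions made only for illustration; `evaluate.py` remains the reference demo entry point.

```python
# inference_sketch.py -- illustrative only; evaluate.py is the reference implementation.
# Assumptions: a checkpoint produced by train.py, 2 classes (helmet / person),
# --pretrained imagenet normalization, default output stride 4, --normalized-coord not set,
# and hourglass hyperparameters (num_stack, in_ch, ...) guessed rather than taken from config.py.
import torch

from hourglass import StackedHourglass
from transform import hm2box
from utils import imload, draw_box

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Same head layout as load_network: num_cls heatmap channels + 2 offset + 2 size channels.
num_cls = 2
network = StackedHourglass(num_stack=2, in_ch=256, out_ch=num_cls + 4,
                           increase_ch=0, activation='relu', pool='max',
                           neck_activation='relu', neck_pool='max').to(device)

check_point = torch.load('./WEIGHTS/check_point.pth', map_location=device)
network.load_state_dict(check_point['state_dict'])
network.eval()

# Load and normalize a single image at the training resolution (512 x 512).
image, image_pil, origin_size = imload('./imgs/000019.jpg', pretrained='imagenet', size=512)

with torch.no_grad():
    outputs = network(image.to(device))      # N x num_stack x (num_cls + 4) x H/4 x W/4
    output = outputs[0, -1]                  # last hourglass stage of the single batch item
    heatmap, offset, wh = output.split([num_cls, 2, 2], dim=0)
    heatmap = torch.sigmoid(heatmap)         # training applies sigmoid to the heatmap logits

# Decode peaks back to boxes in the resized (512 x 512) image space; no NMS in this sketch.
boxes, labels, scores = hm2box(heatmap.cpu(), offset.cpu(), wh.cpu(),
                               scale_factor=4, topk=100, conf_th=0.2)

for box in boxes.tolist():
    draw_box(image_pil, box)
image_pil.save('prediction.png')
```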