├── requirements.txt ├── assets ├── bird.jpg ├── car.jpg ├── cat.jpg ├── dog.jpg ├── horse.jpg ├── sheep.jpg ├── street.jpg ├── aeroplane.jpg └── tvmonitor.jpg ├── data ├── voc.txt └── coco.txt ├── scripts ├── predict.sh ├── eval_yolo.sh ├── eval_coco.sh ├── eval_voc.sh ├── eval_ilsvrc.sh ├── train_yolo.sh ├── train_coco.sh ├── train_voc.sh └── train_ilsvrc.sh ├── net ├── __init__.py ├── centernet.py └── backbone │ └── resnet.py ├── core ├── dataset │ ├── __init__.py │ ├── yolo.py │ ├── ilsvrc.py │ ├── pascal.py │ ├── coco.py │ ├── utils.py │ └── dataset.py ├── loss.py ├── detect.py ├── helper.py └── map.py ├── tools ├── evaluate.py ├── predict.py ├── args.py └── train.py └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | torchvision==0.9.1 3 | opencv-python 4 | numpy -------------------------------------------------------------------------------- /assets/bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/bird.jpg -------------------------------------------------------------------------------- /assets/car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/car.jpg -------------------------------------------------------------------------------- /assets/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/cat.jpg -------------------------------------------------------------------------------- /assets/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/dog.jpg -------------------------------------------------------------------------------- /assets/horse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/horse.jpg -------------------------------------------------------------------------------- /assets/sheep.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/sheep.jpg -------------------------------------------------------------------------------- /assets/street.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/street.jpg -------------------------------------------------------------------------------- /assets/aeroplane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/aeroplane.jpg -------------------------------------------------------------------------------- /assets/tvmonitor.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Runist/torch_CenterNet/HEAD/assets/tvmonitor.jpg -------------------------------------------------------------------------------- /data/voc.txt: -------------------------------------------------------------------------------- 1 | aeroplane 2 | bicycle 3 | bird 4 | boat 5 | bottle 6 | bus 7 | car 8 | cat 9 | chair 10 | cow 11 | diningtable 12 | dog 13 | horse 14 | motorbike 15 | person 16 | 
pottedplant 17 | sheep 18 | sofa 19 | train 20 | tvmonitor -------------------------------------------------------------------------------- /scripts/predict.sh: -------------------------------------------------------------------------------- 1 | python tools/predict.py \ 2 | --gpu=0 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --num_classes=20 \ 6 | --classes_info_file="./data/voc.txt" \ 7 | --outputs_dir="./outputs/yolo_voc" \ 8 | --test_weight="./logs/yolo_voc/weights/epoch=149_loss=0.5854_val_loss=3.3467.pt" -------------------------------------------------------------------------------- /net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : __init__.py 3 | # @Author: Runist 4 | # @Time : 2022/3/29 9:16 5 | # @Software: PyCharm 6 | # @Brief: 7 | from .backbone import resnet 8 | from .centernet import CenterNet, CenterNetPoolingNMS 9 | 10 | __all__ = ['resnet', 'CenterNet', 'CenterNetPoolingNMS'] 11 | -------------------------------------------------------------------------------- /scripts/eval_yolo.sh: -------------------------------------------------------------------------------- 1 | python tools/evaluate.py \ 2 | --gpu=3 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_val_path="./data/voc_0712_val.txt" \ 6 | --dataset_format="yolo" \ 7 | --num_classes=20 \ 8 | --classes_info_file="./data/voc.txt" \ 9 | --outputs_dir="./outputs/yolo_voc/" \ 10 | --test_weight="./logs/yolo_voc/weights/epoch=149_loss=0.5854_val_loss=3.3467.pt" -------------------------------------------------------------------------------- /scripts/eval_coco.sh: -------------------------------------------------------------------------------- 1 | python tools/evaluate.py \ 2 | --gpu=0 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_val_path="./data/coco/annotations/instances_val2017.json" \ 6 | --image_val_dir="./data/coco/val2017" \ 7 | --dataset_format="coco" \ 8 | --num_classes=80 \ 9 | --classes_info_file="./data/coco.txt" \ 10 | --outputs_dir="./outputs/coco/" \ 11 | --test_weight="./logs/coco/weights/epoch=149_loss=1.3958_val_loss=2.8635.pt" -------------------------------------------------------------------------------- /scripts/eval_voc.sh: -------------------------------------------------------------------------------- 1 | python tools/evaluate.py \ 2 | --gpu=0 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_val_path="./data/VOC2012/ImageSets/Main/val.txt" \ 6 | --image_val_dir="./data/VOC2012/JPEGImages/" \ 7 | --annotation_val_dir="./data/VOC2012/Annotations/" \ 8 | --dataset_format="voc" \ 9 | --num_classes=20 \ 10 | --classes_info_file="./data/voc.txt" \ 11 | --outputs_dir="./outputs/voc/" \ 12 | --test_weight="./logs/voc/weights/epoch=149_loss=1.0000_val_loss=4.0783.pt" -------------------------------------------------------------------------------- /scripts/eval_ilsvrc.sh: -------------------------------------------------------------------------------- 1 | python tools/evaluate.py \ 2 | --gpu=2 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_val_path="./data/ImageNet/ILSVRC/ImageSets/DET/val.txt" \ 6 | --image_val_dir="./data/ImageNet/ILSVRC/Data/DET/val/" \ 7 | --annotation_val_dir="./data/ImageNet/ILSVRC/Annotations/DET/val/" \ 8 | --dataset_format="ilsvrc" \ 9 | --num_classes=200 \ 10 | --classes_info_file="./data/ilsvrc.txt" \ 11 | --outputs_dir="./outputs/ilsvrc/" \ 12 | 
--test_weight="./logs/yolo_ilsvrc/weights/epoch=198_loss=0.8236_val_loss=5.0576.pt" -------------------------------------------------------------------------------- /scripts/train_yolo.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py \ 2 | --gpu=0,1,2,3 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_train_path="./data/voc_0712_train.txt" \ 6 | --dataset_val_path="./data/voc_0712_val.txt" \ 7 | --dataset_format="yolo" \ 8 | --num_classes=20 \ 9 | --classes_info_file="./data/voc.txt" \ 10 | --num_workers=8 \ 11 | --pretrain_weight_path="resnet50.pth" \ 12 | --learn_rate_init=5e-4 \ 13 | --learn_rate_end=5e-6 \ 14 | --warmup_epochs=5 \ 15 | --freeze_epochs=50 \ 16 | --unfreeze_epochs=100 \ 17 | --freeze_batch_size=160 \ 18 | --unfreeze_batch_size=80 \ 19 | --logs_dir="./logs/yolo_voc_resnet50" -------------------------------------------------------------------------------- /core/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : __init__.py 3 | # @Author: Runist 4 | # @Time : 2022/4/13 11:46 5 | # @Software: PyCharm 6 | # @Brief: 7 | from .dataset import CenterNetDataset 8 | from .coco import CocoDataset 9 | from .pascal import PascalDataset 10 | from .yolo import YoloDataset 11 | from .ilsvrc import ILSVRCDataset 12 | from .utils import image_resize, preprocess_input, recover_input, gaussian_radius, draw_gaussian 13 | 14 | __all__ = ['CenterNetDataset', 'CocoDataset', 'PascalDataset', 'YoloDataset', 'ILSVRCDataset', 15 | 'image_resize', 'preprocess_input', 'recover_input', 'gaussian_radius', 'draw_gaussian'] 16 | -------------------------------------------------------------------------------- /tools/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : evaluate.py 3 | # @Author: Runist 4 | # @Time : 2022/3/29 10:57 5 | # @Software: PyCharm 6 | # @Brief: Evaluate map 7 | 8 | from args import args, dev, class_names 9 | from core.map import get_map 10 | from core.helper import remove_dir_and_create_dir 11 | from net.centernet import CenterNet 12 | 13 | import os 14 | import xml.etree.ElementTree as ET 15 | from PIL import Image 16 | from tqdm import tqdm 17 | import torch 18 | import numpy as np 19 | 20 | 21 | if __name__ == '__main__': 22 | model = torch.load(args.test_weight) 23 | 24 | output_files_path = os.path.join(args.outputs_dir, "map") 25 | remove_dir_and_create_dir(output_files_path) 26 | 27 | get_map(args, output_files_path, class_names, model, dev) 28 | -------------------------------------------------------------------------------- /scripts/train_coco.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py \ 2 | --gpu=0,1,2,3 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_train_path="./data/coco/annotations/instances_train2017.json" \ 6 | --dataset_val_path="./data/coco/annotations/instances_val2017.json" \ 7 | --image_train_dir="./data/coco/train2017" \ 8 | --image_val_dir="./data/coco/val2017" \ 9 | --dataset_format="coco" \ 10 | --num_classes=80 \ 11 | --classes_info_file="./data/coco.txt" \ 12 | --pretrain_weight_path="./logs/yolo_ilsvrc/weights/epoch=144_loss=0.9981_val_loss=4.5338.pt" \ 13 | --learn_rate_init=3e-4 \ 14 | --learn_rate_end=1e-6 \ 15 | --num_workers=8 \ 16 | --freeze_epochs=50 \ 17 | --unfreeze_epochs=100 \ 18 | --freeze_batch_size=160 \ 19 
| --unfreeze_batch_size=80 \ 20 | --logs_dir="./logs/coco" -------------------------------------------------------------------------------- /scripts/train_voc.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py \ 2 | --gpu=1,2 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_train_path="./data/VOCdevkit/VOC2007/ImageSets/Main/train.txt" \ 6 | --dataset_val_path="./data/VOCdevkit/VOC2007/ImageSets/Main/val.txt" \ 7 | --image_train_dir="./data/VOCdevkit/VOC2007/JPEGImages/" \ 8 | --image_val_dir="./data/VOCdevkit/VOC2007/JPEGImages/" \ 9 | --annotation_train_dir="./data/VOCdevkit/VOC2007/Annotations/" \ 10 | --annotation_val_dir="./data/VOCdevkit/VOC2007/Annotations/" \ 11 | --dataset_format="voc" \ 12 | --num_classes=20 \ 13 | --classes_info_file="./data/voc.txt" \ 14 | --num_workers=12 \ 15 | --pretrain_weight_path="./resnet50.pth" \ 16 | --learn_rate_init=1e-4 \ 17 | --learn_rate_end=1e-6 \ 18 | --freeze_epochs=50 \ 19 | --unfreeze_epochs=100 \ 20 | --freeze_batch_size=90 \ 21 | --unfreeze_batch_size=34 \ 22 | --logs_dir="./logs/voc" -------------------------------------------------------------------------------- /scripts/train_ilsvrc.sh: -------------------------------------------------------------------------------- 1 | python tools/train.py \ 2 | --gpu=0 \ 3 | --input_height=416 \ 4 | --input_width=416 \ 5 | --dataset_train_path="./data/ImageNet/ILSVRC/ImageSets/DET/train.txt" \ 6 | --dataset_val_path="./data/ImageNet/ILSVRC/ImageSets/DET/val.txt" \ 7 | --image_train_dir="./data/ImageNet/ILSVRC/Data/DET/train/" \ 8 | --image_val_dir="./data/ImageNet/ILSVRC/Data/DET/val/" \ 9 | --annotation_train_dir="./data/ImageNet/ILSVRC/Annotations/DET/train/" \ 10 | --annotation_val_dir="./data/ImageNet/ILSVRC/Annotations/DET/val/" \ 11 | --dataset_format="ilsvrc" \ 12 | --num_classes=200 \ 13 | --classes_info_file="./data/ilsvrc.txt" \ 14 | --num_workers=12 \ 15 | --pretrain_weight_path="./resnet50.pth" \ 16 | --learn_rate_init=1e-4 \ 17 | --learn_rate_end=1e-6 \ 18 | --freeze_epochs=50 \ 19 | --unfreeze_epochs=100 \ 20 | --freeze_batch_size=32 \ 21 | --unfreeze_batch_size=26 \ 22 | --logs_dir="./logs/ilsvrc" -------------------------------------------------------------------------------- /data/coco.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /tools/predict.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : predict.py 3 | # @Author: Runist 4 | # @Time : 2022/3/29 11:01 5 | # @Software: PyCharm 6 | # @Brief: Predict script 7 | 8 | from args import args, dev, class_names 9 | from core.detect import predict 10 | from core.helper import draw_bbox, remove_dir_and_create_dir 11 | 12 | from PIL import Image 13 | import cv2 as cv 14 | import torch 15 | import numpy as np 16 | import os 17 | from tqdm import tqdm 18 | import random 19 | 20 | 21 | if __name__ == '__main__': 22 | test_folder = "./assets" 23 | outputs_dir = "{}/images".format(args.outputs_dir) 24 | remove_dir_and_create_dir(outputs_dir) 25 | 26 | model = torch.load(args.test_weight) 27 | 28 | model.eval() 29 | with torch.no_grad(): 30 | for file in tqdm(os.listdir(test_folder)): 31 | image_path = os.path.join(test_folder, file) 32 | image = Image.open(image_path) 33 | image = np.array(image) 34 | 35 | outputs = predict(image, model, dev, args) 36 | 37 | if len(outputs) == 0: 38 | image = Image.fromarray(image) 39 | image.save("{}/{}".format(outputs_dir, file)) 40 | continue 41 | 42 | outputs = outputs.data.cpu().numpy() 43 | labels = outputs[:, 5] 44 | scores = outputs[:, 4] 45 | bboxes = outputs[:, :4] 46 | 47 | image = draw_bbox(image, bboxes, labels, class_names, scores=scores, show_name=True) 48 | 49 | image = Image.fromarray(image) 50 | image.save("{}/images/{}".format(args.outputs_dir, file)) 51 | -------------------------------------------------------------------------------- /core/dataset/yolo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : yolo.py 3 | # @Author: Runist 4 | # @Time : 2022/4/13 9:42 5 | # @Software: PyCharm 6 | # @Brief: 7 | from core.dataset import CenterNetDataset 8 | 9 | import os 10 | import numpy as np 11 | from PIL import Image 12 | 13 | 14 | class YoloDataset(CenterNetDataset): 15 | def load_annotations(self, annotations_path): 16 | """ 17 | Load image and label info from yolo format annotation file. 18 | The format is "image_path|x1,y1,x2,y2,l1|x1,y1,x2,y2,l2" 19 | 20 | Args: 21 | annotations_path: Annotation file path 22 | 23 | Returns: image and label id list 24 | 25 | """ 26 | with open(annotations_path, 'r', encoding='utf-8') as f: 27 | txt = f.readlines() 28 | annotations = [line.strip().strip("|") for line in txt] 29 | 30 | return annotations 31 | 32 | def parse_annotation(self, index): 33 | """ 34 | Parse self.annotation element and read image and bounding boxes. 35 | 36 | Args: 37 | index: index for self.annotation 38 | 39 | Returns: image, bboxes 40 | 41 | """ 42 | annotation = self.annotations[index] 43 | 44 | line = annotation.split("|") 45 | image_path = line[0] 46 | 47 | if not os.path.exists(image_path): 48 | raise KeyError("%s does not exist ... 
" % image_path) 49 | 50 | image = Image.open(image_path) 51 | image = np.array(image) 52 | 53 | if len(image.shape) == 2: 54 | image = np.expand_dims(image, axis=-1) 55 | image = image.repeat(3, axis=-1) 56 | if image.shape[-1] == 4: 57 | image = image[:, :, :-1] 58 | 59 | bboxes = np.array([list(map(lambda x: int(float(x)), box.split(','))) for box in line[1:]]) 60 | 61 | return image, bboxes 62 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : loss.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 21:10 5 | # @Software: PyCharm 6 | # @Brief: Loss function 7 | 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | def focal_loss(pred, target): 14 | """ 15 | classifier loss of focal loss 16 | Args: 17 | pred: heatmap of prediction 18 | target: heatmap of ground truth 19 | 20 | Returns: cls loss 21 | 22 | """ 23 | # Find every image positive points and negative points, 24 | # one bounding box corresponds to one positive point, 25 | # except positive points, other feature points are negative sample. 26 | pos_inds = target.eq(1).float() 27 | neg_inds = target.lt(1).float() 28 | 29 | # The negative samples near the positive sample feature point have smaller weights 30 | neg_weights = torch.pow(1 - target, 4) 31 | loss = 0 32 | pred = torch.clamp(pred, 1e-6, 1 - 1e-6) 33 | 34 | # Calculate Focal Loss. 35 | # The hard to classify sample weight is large, easy to classify sample weight is small. 36 | pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds 37 | neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_inds * neg_weights 38 | 39 | # Loss normalization is carried out 40 | num_pos = pos_inds.float().sum() 41 | pos_loss = pos_loss.sum() 42 | neg_loss = neg_loss.sum() 43 | 44 | if num_pos == 0: 45 | loss = loss - neg_loss 46 | else: 47 | loss = loss - (pos_loss + neg_loss) / num_pos 48 | 49 | return loss 50 | 51 | 52 | def l1_loss(pred, target, mask): 53 | """ 54 | Calculate l1 loss 55 | Args: 56 | pred: offset detection result 57 | target: offset ground truth 58 | mask: offset mask, only center point is 1, other place is 0 59 | 60 | Returns: l1 loss 61 | 62 | """ 63 | expand_mask = torch.unsqueeze(mask, -1).repeat(1, 1, 1, 2) 64 | 65 | # Don't calculate loss in the position without ground truth. 
66 | loss = F.l1_loss(pred * expand_mask, target * expand_mask, reduction='sum') 67 | 68 | loss = loss / (mask.sum() + 1e-7) 69 | 70 | return loss 71 | -------------------------------------------------------------------------------- /core/dataset/ilsvrc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : ilsvrc.py 3 | # @Author: Runist 4 | # @Time : 2022/4/28 9:49 5 | # @Software: PyCharm 6 | # @Brief: ILSVRC dataset 7 | from core.dataset import CenterNetDataset 8 | 9 | import os 10 | import numpy as np 11 | from PIL import Image 12 | import xml.etree.ElementTree as ET 13 | 14 | 15 | class ILSVRCDataset(CenterNetDataset): 16 | def __init__(self, image_dir, annotation_dir, class_names, annotation_path, input_shape, num_classes, is_train): 17 | self.image_dir = image_dir 18 | self.annotation_dir = annotation_dir 19 | self.class_names = class_names 20 | 21 | super().__init__(annotation_path, input_shape, num_classes, is_train) 22 | 23 | def load_annotations(self, annotations_path): 24 | """ 25 | Load image and label info from voc*.txt file. 26 | 27 | Args: 28 | annotations_path: Annotation file path 29 | 30 | Returns: image and label id list 31 | 32 | """ 33 | annotations = [] 34 | with open(annotations_path, 'r', encoding='utf-8') as f: 35 | txt = f.readlines() 36 | 37 | for line in txt: 38 | line = line.strip() 39 | if "extra" in line: 40 | continue 41 | annotations.append(line.split()[0]) 42 | 43 | return annotations 44 | 45 | def parse_annotation(self, index): 46 | """ 47 | Parse self.annotation element and read image and bounding boxes. 48 | 49 | Args: 50 | index: index for self.annotation 51 | 52 | Returns: image, bboxes 53 | 54 | """ 55 | path = self.annotations[index] 56 | 57 | image_path = os.path.join(self.image_dir, path + ".JPEG") 58 | xml = ET.parse(os.path.join(self.annotation_dir, path + ".xml")).getroot() 59 | 60 | bboxes = [] 61 | for obj in xml.iter("object"): 62 | 63 | name = obj.find("name").text.strip() 64 | bbox = obj.find("bndbox") 65 | 66 | xmin = int(float(bbox.find("xmin").text)) 67 | ymin = int(float(bbox.find("ymin").text)) 68 | xmax = int(float(bbox.find("xmax").text)) 69 | ymax = int(float(bbox.find("ymax").text)) 70 | bboxes.append([xmin, ymin, xmax, ymax, self.class_names.index(name)]) 71 | 72 | bboxes = np.array(bboxes) 73 | 74 | image = Image.open(image_path) 75 | image = np.array(image) 76 | if len(image.shape) == 2: 77 | image = np.expand_dims(image, axis=-1) 78 | image = image.repeat(3, axis=-1) 79 | if image.shape[-1] == 4: 80 | image = image[:, :, :-1] 81 | 82 | return image, bboxes 83 | -------------------------------------------------------------------------------- /core/dataset/pascal.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : pascal.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 17:32 5 | # @Software: PyCharm 6 | # @Brief: pascal VOC Dataset 7 | from core.dataset import CenterNetDataset 8 | 9 | import os 10 | import numpy as np 11 | from PIL import Image 12 | import xml.etree.ElementTree as ET 13 | 14 | 15 | class PascalDataset(CenterNetDataset): 16 | def __init__(self, image_dir, annotation_dir, class_names, annotation_path, input_shape, num_classes, is_train): 17 | self.image_dir = image_dir 18 | self.annotation_dir = annotation_dir 19 | self.class_names = class_names 20 | 21 | super().__init__(annotation_path, input_shape, num_classes, is_train) 22 | 23 | def load_annotations(self, 
annotations_path): 24 | """ 25 | Load image and label info from voc*.txt file. 26 | 27 | Args: 28 | annotations_path: Annotation file path 29 | 30 | Returns: image and label id list 31 | 32 | """ 33 | with open(annotations_path, 'r', encoding='utf-8') as f: 34 | txt = f.readlines() 35 | annotations = [line.strip().split()[0] for line in txt] 36 | return annotations 37 | 38 | def parse_annotation(self, index): 39 | """ 40 | Parse self.annotation element and read image and bounding boxes. 41 | 42 | Args: 43 | index: index for self.annotation 44 | 45 | Returns: image, bboxes 46 | 47 | """ 48 | img_id = self.annotations[index] 49 | 50 | image_path = os.path.join(self.image_dir, img_id + ".jpg") 51 | xml = ET.parse(os.path.join(self.annotation_dir, img_id + ".xml")).getroot() 52 | 53 | bboxes = [] 54 | for obj in xml.iter("object"): 55 | difficult = obj.find("difficult") 56 | if difficult is not None: 57 | difficult = int(difficult.text) == 1 58 | else: 59 | difficult = False 60 | 61 | if difficult: 62 | continue 63 | 64 | name = obj.find("name").text.strip() 65 | bbox = obj.find("bndbox") 66 | 67 | xmin = int(float(bbox.find("xmin").text)) 68 | ymin = int(float(bbox.find("ymin").text)) 69 | xmax = int(float(bbox.find("xmax").text)) 70 | ymax = int(float(bbox.find("ymax").text)) 71 | bboxes.append([xmin, ymin, xmax, ymax, self.class_names.index(name)]) 72 | 73 | bboxes = np.array(bboxes) 74 | 75 | image = Image.open(image_path) 76 | image = np.array(image) 77 | if len(image.shape) == 2: 78 | image = np.expand_dims(image, axis=-1) 79 | image = image.repeat(3, axis=-1) 80 | if image.shape[-1] == 4: 81 | image = image[:, :, :-1] 82 | 83 | return image, bboxes 84 | 85 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CenterNet 2 | 3 | ## Introduction 4 | 5 | ![f1.png](https://s2.loli.net/2022/04/27/Bv5V7DkbJOxnhGc.png) 6 | 7 | A PyTorch implementation of CenterNet, from the paper "[Objects as Points](https://arxiv.org/abs/1904.07850)". 8 | 9 | ## Quick start 10 | 11 | 1. Clone this repository 12 | 13 | ```shell 14 | git clone https://github.com/Runist/torch_CenterNet 15 | ``` 16 | 2. Install the dependencies of torch_CenterNet. 17 | 18 | ```shell 19 | cd torch_CenterNet 20 | pip install -r requirements.txt 21 | ``` 22 | 3. Download the Pascal VOC or COCO dataset. Create a new folder named "data" and add a symbolic link to your dataset. 23 | ```shell 24 | mkdir data 25 | cd data 26 | ln -s xxx/VOCdevkit VOCdevkit 27 | cd .. 28 | ``` 29 | 4. Prepare the classes information file and place it in the "data" directory. The txt file format is: 30 | ```shell 31 | aeroplane 32 | bicycle 33 | ... 34 | tvmonitor 35 | ``` 36 | 5. Configure the parameters in [tools/args.py](https://github.com/Runist/torch_CenterNet/blob/master/tools/args.py). 37 | 6. Start training your model. 38 | 39 | ```shell 40 | python tools/train.py 41 | ``` 42 | Or use the shell script to start: 43 | ```shell 44 | sh scripts/train_yolo.sh 45 | ``` 46 | 7. Open TensorBoard to watch the loss, learning rate, etc. You can also see the training progress and validation predictions. 47 | 48 | ```shell 49 | tensorboard --logdir ./weights/yolo_voc/log/summary 50 | ``` 51 | 52 | 8. After training, you can run *evaluate.py* to evaluate model performance. 53 | 54 | ```shell 55 | python tools/evaluate.py 56 | ``` 57 | Or use the shell script to start: 58 | ```shell 59 | sh scripts/eval_yolo.sh 60 | ``` 61 | 9. Get predictions from the model. 62 | 63 | ```shell 64 | python tools/predict.py 65 | ``` 66 | 67 | Or use the shell script to start: 68 | 69 | ```shell 70 | sh scripts/predict.sh 71 | ``` 72 | 73 | ![dog.jpg](https://s2.loli.net/2022/06/21/RM9fQGgKwumy8is.jpg) 74 | 75 | ## Train your dataset 76 | 77 | We provide three dataset formats for this repository: "yolo", "coco", and "voc". For "yolo" you need to create a new annotation file; its format is: 78 | 79 | ```shell 80 | image_path|1,95,240,336,19 81 | image_path|305,131,318,151,14|304,142,354,160,3 82 | ``` 83 | 84 | "coco" and "voc" follow the format of their original datasets. Also prepare the classes information file and place it in the "data" directory. 85 |
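As a rough illustration of the "yolo" format, the sketch below builds such an annotation file from VOC-style XML labels. It is only a minimal example, not part of this repository: the paths, the output file name, and the class list location are placeholders you would adapt to your own data.

```python
# Minimal sketch (not part of this repository): convert VOC-style XML labels into the
# "image_path|x1,y1,x2,y2,label|..." lines expected by the "yolo" dataset format.
# All paths below are placeholders — point them at your own images and annotations.
import os
import xml.etree.ElementTree as ET

class_names = [line.strip() for line in open("data/voc.txt", encoding="utf-8")]
image_dir = "data/VOCdevkit/VOC2007/JPEGImages"
annotation_dir = "data/VOCdevkit/VOC2007/Annotations"

with open("data/my_train.txt", "w", encoding="utf-8") as out_file:
    for xml_name in sorted(os.listdir(annotation_dir)):
        root = ET.parse(os.path.join(annotation_dir, xml_name)).getroot()

        boxes = []
        for obj in root.iter("object"):
            label = class_names.index(obj.find("name").text.strip())
            bndbox = obj.find("bndbox")
            coords = [int(float(bndbox.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")]
            boxes.append(",".join(map(str, coords + [label])))

        if boxes:
            image_path = os.path.join(image_dir, xml_name.replace(".xml", ".jpg"))
            out_file.write(image_path + "|" + "|".join(boxes) + "\n")
```

The resulting file can then be passed to --dataset_train_path together with --dataset_format="yolo", as in scripts/train_yolo.sh.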
86 | ## Performance 87 | 88 | | Train Dataset | Val Dataset | Weight | mAP 0.5 | mAP 0.5:0.95 | 89 | | ------------- | ----------- | ------------------------------------------------------------ | ------- | ------------- | 90 | | VOC07+12 | VOC-Test07 | [resnet50-CenterNet.pt](https://github.com/Runist/torch_CenterNet/releases/download/v1/resnet50-CenterNet.pt) | 0.622 | 0.436 | 91 | 92 | ## Reference 93 | 94 | We appreciate the work from the following repositories: 95 | 96 | - [bubbliiiing](https://github.com/bubbliiiing)/[centernet-pytorch](https://github.com/bubbliiiing/centernet-pytorch) 97 | 98 | - [katsura-jp](https://github.com/katsura-jp)/[pytorch-cosine-annealing-with-warmup](https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup) 99 | 100 | - [YunYang1994](https://github.com/YunYang1994)/[tensorflow-yolov3](https://github.com/YunYang1994/tensorflow-yolov3) 101 | 102 | ## License 103 | 104 | Code and datasets are released for non-commercial and research purposes **only**. For commercial purposes, please contact the authors. 105 |
-------------------------------------------------------------------------------- /core/dataset/coco.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : coco.py 3 | # @Author: Runist 4 | # @Time : 2022/4/12 11:06 5 | # @Software: PyCharm 6 | # @Brief: 7 | from core.dataset import CenterNetDataset 8 | 9 | import os 10 | import numpy as np 11 | from PIL import Image 12 | 13 | from pycocotools.coco import COCO 14 | 15 | 16 | class CocoDataset(CenterNetDataset): 17 | def __init__(self, data_dir, annotation_path, input_shape, num_classes, is_train): 18 | self.data_dir = data_dir 19 | super().__init__(annotation_path, input_shape, num_classes, is_train) 20 | 21 | def load_annotations(self, annotations_path): 22 | """ 23 | Load image and label info from coco*.json. 24 | 25 | Args: 26 | annotations_path: Annotation file path 27 | 28 | Returns: image id list 29 | """ 30 | self.coco = COCO(annotations_path) 31 | self.remove_useless_info() 32 | 33 | img_ids = self.coco.getImgIds() 34 | 35 | self.cat_ids = sorted(self.coco.getCatIds()) 36 | 37 | return img_ids 38 | 39 | def parse_annotation(self, index): 40 | """ 41 | Parse self.annotation element and read image and bounding boxes.
42 | 43 | Args: 44 | index: index for self.annotation 45 | 46 | Returns: image, bboxes 47 | 48 | """ 49 | img_id = self.annotations[index] 50 | img_ann = self.coco.loadImgs(img_id)[0] 51 | 52 | width = img_ann["width"] 53 | height = img_ann["height"] 54 | filename = img_ann["file_name"] 55 | image_path = os.path.join(self.data_dir, filename) 56 | 57 | ann_id = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False) 58 | annotation = self.coco.loadAnns(ann_id) 59 | 60 | bboxes = [] 61 | for obj in annotation: 62 | x1 = np.max((0, obj["bbox"][0])) 63 | y1 = np.max((0, obj["bbox"][1])) 64 | x2 = np.min((width, x1 + np.max((0, obj["bbox"][2])))) 65 | y2 = np.min((height, y1 + np.max((0, obj["bbox"][3])))) 66 | 67 | if obj["area"] > 0 and x2 >= x1 and y2 >= y1: 68 | obj["clean_bbox"] = [x1, y1, x2, y2] 69 | bboxes.append([x1, y1, x2, y2, self.cat_ids.index(obj["category_id"])]) 70 | 71 | bboxes = np.array(bboxes, np.int32) 72 | 73 | image = Image.open(image_path) 74 | image = np.array(image) 75 | if len(image.shape) == 2: 76 | image = np.expand_dims(image, axis=-1) 77 | image = image.repeat(3, axis=-1) 78 | if image.shape[-1] == 4: 79 | image = image[:, :, :-1] 80 | 81 | return image, bboxes 82 | 83 | def remove_useless_info(self): 84 | """ 85 | Remove useless info in coco dataset. COCO object is modified inplace. 86 | This function is mainly used for saving memory (save about 30% mem). 87 | 88 | Returns: None 89 | 90 | """ 91 | if isinstance(self.coco, COCO): 92 | dataset = self.coco.dataset 93 | dataset.pop("info", None) 94 | dataset.pop("licenses", None) 95 | for img in dataset["images"]: 96 | img.pop("license", None) 97 | img.pop("coco_url", None) 98 | img.pop("date_captured", None) 99 | img.pop("flickr_url", None) 100 | if "annotations" in self.coco.dataset: 101 | for anno in self.coco.dataset["annotations"]: 102 | anno.pop("segmentation", None) 103 | -------------------------------------------------------------------------------- /tools/args.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : args.py 3 | # @Author: Runist 4 | # @Time : 2022/3/29 9:46 5 | # @Software: PyCharm 6 | # @Brief: Code argument parser 7 | 8 | import argparse 9 | import warnings 10 | import os 11 | import torch 12 | import sys 13 | sys.path.append(os.getcwd()) 14 | 15 | from core.helper import seed_torch, get_class_names 16 | 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | parser = argparse.ArgumentParser() 21 | 22 | parser.add_argument('--gpu', type=str, default='1', help='Select gpu device.') 23 | 24 | parser.add_argument('--input_height', type=int, default=416, help='The height of model input.') 25 | parser.add_argument('--input_width', type=int, default=416, help='The width of model input.') 26 | parser.add_argument('--num_classes', type=int, default=20, help='The number of class.') 27 | 28 | parser.add_argument('--warmup_epochs', type=int, default=5, help='The number of freeze training epochs.') 29 | parser.add_argument('--freeze_epochs', type=int, default=50, help='The number of freeze training epochs.') 30 | parser.add_argument('--unfreeze_epochs', type=int, default=100, help='The number of unfreeze training epochs.') 31 | 32 | parser.add_argument('--freeze_batch_size', type=int, default=16, help='The number of examples per batch.') 33 | parser.add_argument('--unfreeze_batch_size', type=int, default=16, help='The number of examples per batch.') 34 | 35 | parser.add_argument('--learn_rate_init', type=float, default=2e-4, 36 | 
help='Initial value of cosine annealing learning rate.') 37 | parser.add_argument('--learn_rate_end', type=float, default=1e-6, 38 | help='End value of cosine annealing learning rate.') 39 | parser.add_argument('--num_workers', type=int, default=12, help='The number of torch dataloader thread.') 40 | 41 | parser.add_argument('--backbone', type=str, 42 | default="resnet50", 43 | choices=["resnet50", "resnet101"], 44 | help='The path of the pretrain weight.') 45 | parser.add_argument('--pretrain_weight_path', type=str, 46 | default=None, 47 | help='The path of the pretrain weight.') 48 | 49 | parser.add_argument('--dataset_train_path', type=str, 50 | default="", 51 | help='The file path of the train data.') 52 | parser.add_argument('--dataset_val_path', type=str, 53 | default="", 54 | help='The file path of the val data.') 55 | parser.add_argument('--image_train_dir', type=str, 56 | default="", 57 | help='The images directory of the train data.') 58 | parser.add_argument('--image_val_dir', type=str, 59 | default="", 60 | help='The images directory of the val data.') 61 | parser.add_argument('--annotation_train_dir', type=str, 62 | default="", 63 | help='The labels directory of the train data.') 64 | parser.add_argument('--annotation_val_dir', type=str, 65 | default="", 66 | help='The labels directory of the val data.') 67 | parser.add_argument('--dataset_format', type=str, 68 | default="voc", choices=["coco", "voc", "yolo", "ilsvrc"], 69 | help='The format of dataset, it will influence dataloader method.') 70 | 71 | parser.add_argument('--logs_dir', type=str, default="./logs/temp", 72 | help='The directory of saving weights and training log.') 73 | parser.add_argument('--outputs_dir', type=str, default='./outputs', 74 | help='The directory of the predict image.') 75 | 76 | parser.add_argument('--confidence', type=float, default=0.3, help='The number of class.') 77 | parser.add_argument('--classes_info_file', type=str, default="./data/voc.txt", 78 | help='The text that stores classification information.') 79 | parser.add_argument('--test_weight', type=str, help='The name of the model weight.') 80 | 81 | args = parser.parse_args() 82 | 83 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 84 | seed_torch(777) 85 | 86 | dev = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 87 | class_names = get_class_names(args.classes_info_file) 88 | -------------------------------------------------------------------------------- /core/dataset/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : utils.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 17:36 5 | # @Software: PyCharm 6 | # @Brief: 7 | 8 | 9 | import numpy as np 10 | import cv2 as cv 11 | import math 12 | 13 | 14 | def draw_gaussian(heatmap, center, radius, k=1): 15 | """ 16 | Get a heatmap of one class 17 | Args: 18 | heatmap: The heatmap of one class(storage in single channel) 19 | center: The location of object center 20 | radius: 2D Gaussian circle radius 21 | k: The magnification of the Gaussian 22 | 23 | Returns: heatmap 24 | 25 | """ 26 | diameter = 2 * radius + 1 27 | gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6) 28 | 29 | x, y = int(center[0]), int(center[1]) 30 | 31 | height, width = heatmap.shape[0:2] 32 | 33 | left, right = min(x, radius), min(width - x, radius + 1) 34 | top, bottom = min(y, radius), min(height - y, radius + 1) 35 | 36 | masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] 37 | masked_gaussian = 
gaussian[radius - top:radius + bottom, radius - left:radius + right] 38 | if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug 39 | np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) 40 | 41 | return heatmap 42 | 43 | 44 | def gaussian2D(shape, sigma=1): 45 | """ 46 | 2D Gaussian function 47 | Args: 48 | shape: (diameter, diameter) 49 | sigma: variance 50 | 51 | Returns: h 52 | 53 | """ 54 | m, n = [(ss - 1.) / 2. for ss in shape] 55 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 56 | 57 | h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) 58 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 59 | 60 | return h 61 | 62 | 63 | def gaussian_radius(det_size, min_overlap=0.7): 64 | """ 65 | Get gaussian circle radius. 66 | Args: 67 | det_size: (height, width) 68 | min_overlap: overlap minimum 69 | 70 | Returns: radius 71 | 72 | """ 73 | height, width = det_size 74 | 75 | a1 = 1 76 | b1 = (height + width) 77 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 78 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 79 | r1 = (b1 + sq1) / 2 80 | 81 | a2 = 4 82 | b2 = 2 * (height + width) 83 | c2 = (1 - min_overlap) * width * height 84 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 85 | r2 = (b2 + sq2) / 2 86 | 87 | a3 = 4 * min_overlap 88 | b3 = -2 * min_overlap * (height + width) 89 | c3 = (min_overlap - 1) * width * height 90 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 91 | r3 = (b3 + sq3) / 2 92 | 93 | return min(r1, r2, r3) 94 | 95 | 96 | def preprocess_input(image): 97 | mean = [0.40789655, 0.44719303, 0.47026116] 98 | std = [0.2886383, 0.27408165, 0.27809834] 99 | 100 | # mean = [0.485, 0.456, 0.406] 101 | # std = [0.229, 0.224, 0.225] 102 | 103 | image = (image / 255. - mean) / std 104 | image = np.transpose(image, (2, 0, 1)) 105 | return image 106 | 107 | 108 | def recover_input(image): 109 | mean = [0.40789655, 0.44719303, 0.47026116] 110 | std = [0.2886383, 0.27408165, 0.27809834] 111 | 112 | # mean = [0.485, 0.456, 0.406] 113 | # std = [0.229, 0.224, 0.225] 114 | 115 | image = np.transpose(image, (1, 2, 0)) 116 | image = (image * std + mean) * 255 117 | 118 | return image 119 | 120 | 121 | def image_resize(image, target_size, gt_boxes=None): 122 | ih, iw = target_size 123 | 124 | h, w = image.shape[:2] 125 | 126 | scale = min(iw/w, ih/h) 127 | nw, nh = int(scale * w), int(scale * h) 128 | image_resized = cv.resize(image, (nw, nh)) 129 | 130 | dw, dh = (iw - nw) // 2, (ih - nh) // 2 131 | image_paded = np.full(shape=[ih, iw, 3], fill_value=128.0) 132 | image_paded[dh:nh+dh, dw:nw+dw, :] = image_resized 133 | 134 | if gt_boxes is None: 135 | return image_paded 136 | elif gt_boxes.size == 0: 137 | # Use no label image to train 138 | return image_paded, gt_boxes 139 | else: 140 | gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] * scale + dw 141 | gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] * scale + dh 142 | return image_paded, gt_boxes 143 | -------------------------------------------------------------------------------- /core/detect.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : detect.py 3 | # @Author: Runist 4 | # @Time : 2022/3/29 11:37 5 | # @Software: PyCharm 6 | # @Brief: Detection function 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | from torchvision.ops import nms 12 | 13 | from net import CenterNetPoolingNMS 14 | from core.dataset import preprocess_input, image_resize 15 | 16 | 17 | def postprocess_output(hms, whs, offsets, confidence, dev): 18 | """ 19 | The post 
process of model output. 20 | Args: 21 | hms: heatmap 22 | whs: the height and width of bounding box 23 | offsets: center point offset 24 | confidence: the threshold of heatmap 25 | dev: torch device 26 | 27 | Returns: The list of bounding box(x, y, w, h, score, label). 28 | 29 | """ 30 | batch, output_h, output_w, c = hms.shape 31 | 32 | detections = [] 33 | for b in range(batch): 34 | # (h, w, c) -> (-1, c) 35 | heat_map = hms[b].view([-1, c]) 36 | # (h, w, 2) -> (-1, 2) 37 | wh = whs[b].view([-1, 2]) 38 | # (h, w, 2) -> (-1, 2) 39 | offset = offsets[b].view([-1, 2]) 40 | 41 | yv, xv = torch.meshgrid(torch.arange(0, output_h), torch.arange(0, output_w)) 42 | xv, yv = xv.flatten().float(), yv.flatten().float() 43 | 44 | xv = xv.to(dev) # x axis coordinate of feature point 45 | yv = yv.to(dev) # y axis coordinate of feature point 46 | 47 | # torch.max[0] max value 48 | # torch.max[1] index of max value 49 | score, label = torch.max(heat_map, dim=-1) 50 | mask = score > confidence 51 | 52 | # Choose height, width and offset by confidence mask 53 | wh_mask = wh[mask] 54 | offset_mask = offset[mask] 55 | 56 | if len(wh_mask) == 0: 57 | detections.append([]) 58 | continue 59 | 60 | # Adjust center of predict box 61 | xv_mask = torch.unsqueeze(xv[mask] + offset_mask[..., 0], -1) 62 | yv_mask = torch.unsqueeze(yv[mask] + offset_mask[..., 1], -1) 63 | 64 | # Get the (xmin, ymin, xmax, ymax) 65 | half_w, half_h = wh_mask[..., 0:1] / 2, wh_mask[..., 1:2] / 2 66 | bboxes = torch.cat([xv_mask - half_w, yv_mask - half_h, xv_mask + half_w, yv_mask + half_h], dim=1) 67 | 68 | # Bounding box coordinate normalize 69 | bboxes[:, [0, 2]] /= output_w 70 | bboxes[:, [1, 3]] /= output_h 71 | 72 | # Concatenate the prediction 73 | detect = torch.cat( 74 | [bboxes, torch.unsqueeze(score[mask], -1), torch.unsqueeze(label[mask], -1).float()], dim=-1) 75 | detections.append(detect) 76 | 77 | return detections 78 | 79 | 80 | def decode_bbox(prediction, input_shape, dev, image_shape=None, remove_pad=False, need_nms=False, nms_thres=0.4): 81 | """ 82 | Decode postprecess_output output 83 | Args: 84 | prediction: postprecess_output output 85 | input_shape: model input shape 86 | dev: torch device 87 | image_shape: image shape 88 | remove_pad: model input is padding image, you should set remove_pad=True if you want to remove this pad 89 | need_nms: whether use NMS to remove redundant detect box 90 | nms_thres: nms threshold 91 | 92 | Returns: The list of bounding box(x1, y1, x2, y2 score, label). 93 | 94 | """ 95 | output = [[] for _ in prediction] 96 | 97 | for b, detection in enumerate(prediction): 98 | if len(detection) == 0: 99 | continue 100 | 101 | if need_nms: 102 | keep = nms(detection[:, :4], detection[:, 4], nms_thres) 103 | detection = detection[keep] 104 | 105 | output[b].append(detection) 106 | 107 | output[b] = torch.cat(output[b]) 108 | if output[b] is not None: 109 | bboxes = output[b][:, 0:4] 110 | 111 | input_shape = torch.tensor(input_shape, device=dev) 112 | bboxes *= torch.cat([input_shape, input_shape], dim=-1) 113 | 114 | if remove_pad: 115 | assert image_shape is not None, \ 116 | "If remove_pad is True, image_shape must be set the shape of original image." 
117 | ih, iw = input_shape 118 | h, w = image_shape 119 | scale = min(iw/w, ih/h) 120 | nw, nh = int(scale * w), int(scale * h) 121 | dw, dh = (iw - nw) // 2, (ih - nh) // 2 122 | 123 | bboxes[:, [0, 2]] = (bboxes[:, [0, 2]] - dw) / scale 124 | bboxes[:, [1, 3]] = (bboxes[:, [1, 3]] - dh) / scale 125 | 126 | output[b][:, :4] = bboxes 127 | 128 | return output 129 | 130 | 131 | def predict(image, model, dev, args): 132 | """ 133 | Predict one image 134 | Args: 135 | image: input image 136 | model: CenterNet model 137 | dev: torch device 138 | args: ArgumentParser 139 | 140 | Returns: bounding box of one image(x1, y1, x2, y2 score, label). 141 | 142 | """ 143 | input_data = image_resize(image, (args.input_height, args.input_height)) 144 | input_data = preprocess_input(input_data) 145 | input_data = np.expand_dims(input_data, 0) 146 | 147 | input_data = torch.from_numpy(input_data.copy()).float() 148 | input_data = input_data.to(dev) 149 | 150 | hms, whs, offsets = model(input_data) 151 | hms = CenterNetPoolingNMS(kernel=3)(hms) 152 | 153 | hms = hms.permute(0, 2, 3, 1) 154 | whs = whs.permute(0, 2, 3, 1) 155 | offsets = offsets.permute(0, 2, 3, 1) 156 | 157 | outputs = postprocess_output(hms, whs, offsets, args.confidence, dev) 158 | outputs = decode_bbox(outputs, 159 | (args.input_height, args.input_height), 160 | dev, image_shape=image.shape[:2], remove_pad=True, 161 | need_nms=True, nms_thres=0.45) 162 | 163 | return outputs[0] 164 | -------------------------------------------------------------------------------- /net/centernet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : centernet.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 16:46 5 | # @Software: PyCharm 6 | # @Brief: CenterNet implement 7 | from torch import nn 8 | import math 9 | from net.backbone.resnet import resnet50, resnet101 10 | from core.loss import focal_loss, l1_loss 11 | 12 | 13 | class CenterNet(nn.Module): 14 | def __init__(self, backbone, num_classes=20): 15 | """ 16 | 17 | Args: 18 | backbone: string 19 | num_classes: int 20 | """ 21 | super(CenterNet, self).__init__() 22 | 23 | # h, w, 3 -> h/32, w/32, 2048 24 | if backbone == "resnet50": 25 | self.backbone = resnet50(num_classes, include_top=False) 26 | elif backbone == "resnet101": 27 | self.backbone = resnet101(num_classes, include_top=False) 28 | else: 29 | raise Exception("There is no {}.".format(backbone)) 30 | 31 | # h/32, w/32, 2048 -> h/4, w/4, 64 32 | self.decoder = CenterNetDecoder(2048) 33 | 34 | # feature height and width: h/4, w/4 35 | # hm channel: num_classes 36 | # wh channel: 2 37 | # offset channel: 2 38 | self.head = CenterNetHead(channel=64, num_classes=num_classes) 39 | 40 | for m in self.modules(): 41 | if isinstance(m, nn.Conv2d): 42 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 44 | m.weight.data.normal_(0, math.sqrt(2. / n)) 45 | elif isinstance(m, nn.BatchNorm2d): 46 | nn.init.constant_(m.weight, 1) 47 | nn.init.constant_(m.bias, 0) 48 | 49 | self.head.cls_head[-2].weight.data.fill_(0) 50 | self.head.cls_head[-2].bias.data.fill_(-2.19) 51 | 52 | def freeze_backbone(self): 53 | """ 54 | Freeze centernet backbone parameters. 
55 | Returns: None 56 | 57 | """ 58 | for param in self.backbone.parameters(): 59 | param.requires_grad = False 60 | for param in self.decoder.parameters(): 61 | param.requires_grad = False 62 | 63 | def unfreeze_backbone(self): 64 | """ 65 | Unfreeze centernet backbone parameters. 66 | Returns: None 67 | 68 | """ 69 | for param in self.backbone.parameters(): 70 | param.requires_grad = True 71 | for param in self.decoder.parameters(): 72 | param.requires_grad = True 73 | 74 | def forward(self, x, **kwargs): 75 | x = self.backbone(x) 76 | x = self.decoder(x) 77 | hms_pred, whs_pred, offsets_pred = self.head(x) 78 | 79 | if kwargs.get('mode') == "train": 80 | # Change model outputs tensor order 81 | hms_pred = hms_pred.permute(0, 2, 3, 1) 82 | whs_pred = whs_pred.permute(0, 2, 3, 1) 83 | offsets_pred = offsets_pred.permute(0, 2, 3, 1) 84 | 85 | hms_true, whs_true, offsets_true, offset_masks_true = kwargs.get('ground_truth_data') 86 | 87 | c_loss = focal_loss(hms_pred, hms_true) 88 | wh_loss = 0.1 * l1_loss(whs_pred, whs_true, offset_masks_true) 89 | off_loss = l1_loss(offsets_pred, offsets_true, offset_masks_true) 90 | 91 | loss = c_loss + wh_loss + off_loss 92 | 93 | # Using 3x3 kernel max pooling to filter the maximum response of heatmap 94 | hms_true = hms_true.permute(0, 3, 1, 2) 95 | hms_pred = hms_pred.permute(0, 3, 1, 2) 96 | hms_true = CenterNetPoolingNMS(kernel=3)(hms_true) 97 | hms_pred = CenterNetPoolingNMS(kernel=3)(hms_pred) 98 | hms_true = hms_true.permute(0, 2, 3, 1) 99 | hms_pred = hms_pred.permute(0, 2, 3, 1) 100 | 101 | return hms_pred, whs_pred, offsets_pred, loss, c_loss, wh_loss, off_loss, hms_true 102 | else: 103 | hms_pred = CenterNetPoolingNMS(kernel=3)(hms_pred) 104 | return hms_pred, whs_pred, offsets_pred 105 | 106 | 107 | class CenterNetDecoder(nn.Module): 108 | def __init__(self, in_channels, bn_momentum=0.1): 109 | super(CenterNetDecoder, self).__init__() 110 | self.bn_momentum = bn_momentum 111 | self.in_channels = in_channels 112 | self.deconv_with_bias = False 113 | 114 | # h/32, w/32, 2048 -> h/16, w/16, 256 -> h/8, w/8, 128 -> h/4, w/4, 64 115 | self.deconv_layers = self._make_deconv_layer( 116 | num_layers=3, 117 | num_filters=[256, 128, 64], 118 | num_kernels=[4, 4, 4], 119 | ) 120 | 121 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 122 | layers = [] 123 | for i in range(num_layers): 124 | kernel = num_kernels[i] 125 | num_filter = num_filters[i] 126 | 127 | layers.append( 128 | nn.ConvTranspose2d( 129 | in_channels=self.in_channels, 130 | out_channels=num_filter, 131 | kernel_size=kernel, 132 | stride=2, 133 | padding=1, 134 | output_padding=0, 135 | bias=self.deconv_with_bias)) 136 | layers.append(nn.BatchNorm2d(num_filter, momentum=self.bn_momentum)) 137 | layers.append(nn.ReLU(inplace=True)) 138 | self.in_channels = num_filter 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | return self.deconv_layers(x) 143 | 144 | 145 | class CenterNetHead(nn.Module): 146 | def __init__(self, num_classes=80, channel=64, bn_momentum=0.1): 147 | super(CenterNetHead, self).__init__() 148 | 149 | # heatmap 150 | self.cls_head = nn.Sequential( 151 | nn.Conv2d(64, channel, kernel_size=3, padding=1, bias=False), 152 | nn.BatchNorm2d(64, momentum=bn_momentum), 153 | nn.ReLU(inplace=True), 154 | nn.Conv2d(channel, num_classes, kernel_size=1, stride=1, padding=0), 155 | nn.Sigmoid() 156 | ) 157 | # bounding boxes height and width 158 | self.wh_head = nn.Sequential( 159 | nn.Conv2d(64, channel, kernel_size=3, padding=1, 
bias=False), 160 | nn.BatchNorm2d(64, momentum=bn_momentum), 161 | nn.ReLU(inplace=True), 162 | nn.Conv2d(channel, 2, kernel_size=1, stride=1, padding=0)) 163 | 164 | # center point offset 165 | self.offset_head = nn.Sequential( 166 | nn.Conv2d(64, channel, kernel_size=3, padding=1, bias=False), 167 | nn.BatchNorm2d(64, momentum=bn_momentum), 168 | nn.ReLU(inplace=True), 169 | nn.Conv2d(channel, 2, kernel_size=1, stride=1, padding=0)) 170 | 171 | def forward(self, x): 172 | hm = self.cls_head(x) 173 | wh = self.wh_head(x) 174 | offset = self.offset_head(x) 175 | 176 | return hm, wh, offset 177 | 178 | 179 | class CenterNetPoolingNMS(nn.Module): 180 | def __init__(self, kernel=3): 181 | """ 182 | To replace traditional nms method. Input is heatmap, the num of channel is num_classes, 183 | So one object center has strongest response, where use torch.max(heatmap, dim=-1), it only 184 | filter single pixel max value, the neighbour pixel still have strong response, so we should 185 | use max pooling stride=1 to filter this fake center point. 186 | Args: 187 | kernel: max pooling kernel size 188 | """ 189 | super(CenterNetPoolingNMS, self).__init__() 190 | self.pad = (kernel - 1) // 2 191 | self.max_pool = nn.MaxPool2d(kernel_size=kernel, stride=1, padding=(kernel - 1) // 2) 192 | 193 | def forward(self, x): 194 | xmax = self.max_pool(x) 195 | keep = (xmax == x).float() 196 | 197 | return x * keep 198 | -------------------------------------------------------------------------------- /net/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : resnet.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 17:03 5 | # @Software: PyCharm 6 | # @Brief: Resnet implement 7 | 8 | 9 | # -*- coding: utf-8 -*- 10 | # @File : model.py 11 | # @Author: Runist 12 | # @Time : 2021/10/28 10:20 13 | # @Software: PyCharm 14 | # @Brief: 15 | import torch.nn as nn 16 | import torch 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, 25 | kernel_size=3, stride=stride, padding=1, bias=False) 26 | self.bn1 = nn.BatchNorm2d(out_channel) 27 | self.relu = nn.ReLU() 28 | self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, 29 | kernel_size=3, stride=1, padding=1, bias=False) 30 | self.bn2 = nn.BatchNorm2d(out_channel) 31 | self.downsample = downsample 32 | 33 | def forward(self, x): 34 | identity = x 35 | if self.downsample is not None: 36 | identity = self.downsample(x) 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | out += identity 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, 55 | groups=1, width_per_group=64): 56 | super(Bottleneck, self).__init__() 57 | 58 | width = int(out_channel * (width_per_group / 64.)) * groups 59 | 60 | self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width, 61 | kernel_size=1, stride=1, bias=False) # squeeze channels 62 | self.bn1 = nn.BatchNorm2d(width) 63 | # ----------------------------------------- 64 | self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups, 65 | kernel_size=3, stride=stride, 
bias=False, padding=1) 66 | self.bn2 = nn.BatchNorm2d(width) 67 | # ----------------------------------------- 68 | self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion, 69 | kernel_size=1, stride=1, bias=False) # unsqueeze channels 70 | self.bn3 = nn.BatchNorm2d(out_channel*self.expansion) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | 74 | def forward(self, x): 75 | identity = x 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | out += identity 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, 99 | block, 100 | blocks_num, 101 | num_classes=1000, 102 | include_top=True, 103 | groups=1, 104 | width_per_group=64): 105 | super(ResNet, self).__init__() 106 | self.include_top = include_top 107 | self.in_channel = 64 108 | 109 | self.groups = groups 110 | self.width_per_group = width_per_group 111 | 112 | self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, 113 | padding=3, bias=False) 114 | self.bn1 = nn.BatchNorm2d(self.in_channel) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 117 | self.layer1 = self._make_layer(block, 64, blocks_num[0]) 118 | self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) 119 | self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) 120 | self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) 121 | if self.include_top: 122 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1) 123 | self.fc = nn.Linear(512 * block.expansion, num_classes) 124 | 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 128 | 129 | def _make_layer(self, block, channel, block_num, stride=1): 130 | downsample = None 131 | if stride != 1 or self.in_channel != channel * block.expansion: 132 | downsample = nn.Sequential( 133 | nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False), 134 | nn.BatchNorm2d(channel * block.expansion)) 135 | 136 | layers = [] 137 | layers.append(block(self.in_channel, 138 | channel, 139 | downsample=downsample, 140 | stride=stride, 141 | groups=self.groups, 142 | width_per_group=self.width_per_group)) 143 | self.in_channel = channel * block.expansion 144 | 145 | for _ in range(1, block_num): 146 | layers.append(block(self.in_channel, 147 | channel, 148 | groups=self.groups, 149 | width_per_group=self.width_per_group)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | 164 | if self.include_top: 165 | x = self.avgpool(x) 166 | x = torch.flatten(x, 1) 167 | x = self.fc(x) 168 | 169 | return x 170 | 171 | 172 | def resnet34(num_classes=1000, include_top=True): 173 | # https://download.pytorch.org/models/resnet34-333f7ec4.pth 174 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top) 175 | 176 | 177 | def resnet50(num_classes=1000, include_top=True): 178 | # 
https://download.pytorch.org/models/resnet50-19c8e357.pth 179 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top) 180 | 181 | 182 | def resnet101(num_classes=1000, include_top=True): 183 | # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth 184 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top) 185 | 186 | 187 | def resnext50_32x4d(num_classes=1000, include_top=True): 188 | # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth 189 | groups = 32 190 | width_per_group = 4 191 | return ResNet(Bottleneck, [3, 4, 6, 3], 192 | num_classes=num_classes, 193 | include_top=include_top, 194 | groups=groups, 195 | width_per_group=width_per_group) 196 | 197 | 198 | def resnext101_32x8d(num_classes=1000, include_top=True): 199 | # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth 200 | groups = 32 201 | width_per_group = 8 202 | return ResNet(Bottleneck, [3, 4, 23, 3], 203 | num_classes=num_classes, 204 | include_top=include_top, 205 | groups=groups, 206 | width_per_group=width_per_group) 207 | -------------------------------------------------------------------------------- /core/helper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : helper.py 3 | # @Author: Runist 4 | # @Time : 2022/3/30 12:00 5 | # @Software: PyCharm 6 | # @Brief: Helper functions 7 | 8 | from net import CenterNet 9 | from core.dataset import CocoDataset, PascalDataset, YoloDataset, ILSVRCDataset 10 | 11 | import torch 12 | from torch import nn 13 | import numpy as np 14 | import os 15 | import shutil 16 | import random 17 | import cv2 as cv 18 | 19 | 20 | def seed_torch(seed): 21 | """ 22 | Set all random seeds 23 | Args: 24 | seed: random seed 25 | 26 | Returns: None 27 | 28 | """ 29 | 30 | random.seed(seed) 31 | os.environ['PYTHONHASHSEED'] = str(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed(seed) 35 | torch.cuda.manual_seed_all(seed) 36 | torch.backends.cudnn.benchmark = False 37 | torch.backends.cudnn.deterministic = True 38 | 39 | 40 | def remove_dir_and_create_dir(dir_name, is_remove=True): 41 | """ 42 | Create a new folder. If the folder already exists, it is removed and recreated. 43 | Args: 44 | dir_name: path of folder 45 | is_remove: if True, remove the existing folder and create a new one 46 | 47 | Returns: None 48 | 49 | """ 50 | if not os.path.exists(dir_name): 51 | os.makedirs(dir_name) 52 | print(dir_name, "created.") 53 | else: 54 | if is_remove: 55 | shutil.rmtree(dir_name) 56 | os.makedirs(dir_name) 57 | print(dir_name, "created.") 58 | else: 59 | print(dir_name, "already exists.") 60 | 61 | 62 | def get_color_map(): 63 | """ 64 | Create color map. 65 | 66 | Returns: numpy array. 67 | 68 | """ 69 | color_map = np.zeros((256, 3), dtype=np.uint8) 70 | ind = np.arange(256, dtype=np.uint8) 71 | 72 | for shift in reversed(range(8)): 73 | for channel in range(3): 74 | color_map[:, channel] |= ((ind >> channel) & 1) << shift 75 | ind >>= 3 76 | 77 | return color_map 78 | 79 | 80 | def get_model(args, dev): 81 | """ 82 | Get CenterNet model.
83 | Args: 84 | args: ArgumentParser 85 | dev: torch dev 86 | 87 | Returns: CenterNet model 88 | 89 | """ 90 | device_ids = [i for i in range(len(args.gpu.split(',')))] 91 | 92 | model = CenterNet(args.backbone, num_classes=args.num_classes) 93 | 94 | if args.pretrain_weight_path is not None: 95 | print("Loading {}.".format(args.pretrain_weight_path)) 96 | model_state_dict = model.state_dict() 97 | pretrain_state_dict = torch.load(args.pretrain_weight_path) 98 | 99 | for k, v in pretrain_state_dict.items(): 100 | centernet_k = "backbone." + k 101 | if centernet_k in model_state_dict.keys(): 102 | model_state_dict[centernet_k] = v 103 | 104 | model.load_state_dict(model_state_dict) 105 | 106 | model = nn.DataParallel(model, device_ids=device_ids) 107 | model.to(dev) 108 | 109 | return model 110 | 111 | 112 | def get_class_names(classes_info_file): 113 | """ 114 | Get dataset class name 115 | Args: 116 | classes_info_file: class name file path 117 | 118 | Returns: 119 | 120 | """ 121 | class_names = [] 122 | with open(classes_info_file, 'r') as data: 123 | for name in data: 124 | name = name.strip('\n') 125 | if ":" in name: 126 | name = name.split(": ")[0] 127 | class_names.append(name) 128 | 129 | return class_names 130 | 131 | 132 | def get_dataset(args, class_names): 133 | """ 134 | Get CenterNet dataset. 135 | Args: 136 | args: ArgumentParser 137 | class_names: the name of class 138 | 139 | Returns: dataset 140 | 141 | """ 142 | if args.dataset_format == "voc": 143 | train_dataset = PascalDataset(args.image_train_dir, 144 | args.annotation_train_dir, 145 | class_names, 146 | args.dataset_train_path, (args.input_height, args.input_width), 147 | num_classes=args.num_classes, is_train=True) 148 | val_dataset = PascalDataset(args.image_val_dir, 149 | args.annotation_val_dir, 150 | class_names, 151 | args.dataset_val_path, (args.input_height, args.input_width), 152 | num_classes=args.num_classes, is_train=False) 153 | elif args.dataset_format == "coco": 154 | train_dataset = CocoDataset(args.image_train_dir, 155 | args.dataset_train_path, (args.input_height, args.input_width), 156 | num_classes=args.num_classes, is_train=True) 157 | val_dataset = CocoDataset(args.image_val_dir, 158 | args.dataset_val_path, (args.input_height, args.input_width), 159 | num_classes=args.num_classes, is_train=False) 160 | elif args.dataset_format == "yolo": 161 | train_dataset = YoloDataset(args.dataset_train_path, (args.input_height, args.input_width), 162 | num_classes=args.num_classes, is_train=True) 163 | val_dataset = YoloDataset(args.dataset_val_path, (args.input_height, args.input_width), 164 | num_classes=args.num_classes, is_train=False) 165 | elif args.dataset_format == "ilsvrc": 166 | train_dataset = ILSVRCDataset(args.image_train_dir, 167 | args.annotation_train_dir, 168 | class_names, 169 | args.dataset_train_path, (args.input_height, args.input_width), 170 | num_classes=args.num_classes, is_train=True) 171 | val_dataset = ILSVRCDataset(args.image_val_dir, 172 | args.annotation_val_dir, 173 | class_names, 174 | args.dataset_val_path, (args.input_height, args.input_width), 175 | num_classes=args.num_classes, is_train=False) 176 | else: 177 | raise Exception("There is no {} format for data parsing, you should choose one from 'yolo', 'coco', 'voc', 'ilsvrc'". 178 | format(args.dataset_format)) 179 | 180 | return train_dataset, val_dataset 181 | 182 | 183 | def draw_bbox(image, bboxes, labels, class_names, scores=None, show_name=False): 184 | """ 185 | Draw bounding box in image. 
186 | Args: 187 | image: image 188 | bboxes: coordinate of bounding box 189 | labels: the index of labels 190 | class_names: the names of class 191 | scores: bounding box confidence 192 | show_name: show class name if set true, otherwise show index of class 193 | 194 | Returns: draw result 195 | 196 | """ 197 | color_map = get_color_map() 198 | image_height, image_width = image.shape[:2] 199 | draw_image = image.copy() 200 | 201 | for i, c in list(enumerate(labels)): 202 | bbox = bboxes[i] 203 | c = int(c) 204 | color = [int(j) for j in color_map[c]] 205 | if show_name: 206 | predicted_class = class_names[c] 207 | else: 208 | predicted_class = c 209 | 210 | if scores is None: 211 | text = '{}'.format(predicted_class) 212 | else: 213 | score = scores[i] 214 | text = '{} {:.2f}'.format(predicted_class, score) 215 | 216 | x1, y1, x2, y2 = bbox 217 | 218 | x1 = max(0, np.floor(x1).astype(np.int32)) 219 | y1 = max(0, np.floor(y1).astype(np.int32)) 220 | x2 = min(image_width, np.floor(x2).astype(np.int32)) 221 | y2 = min(image_height, np.floor(y2).astype(np.int32)) 222 | 223 | thickness = int((image_height + image_width) / (np.sqrt(image_height**2 + image_width**2))) 224 | fontScale = 0.35 225 | 226 | t_size = cv.getTextSize(text, 0, fontScale, thickness=thickness * 2)[0] 227 | cv.rectangle(draw_image, (x1, y1), (x2, y2), color=color, thickness=thickness) 228 | cv.rectangle(draw_image, (x1, y1), (x1 + t_size[0], y1 - t_size[1]), color, -1) # filled 229 | cv.putText(draw_image, text, (x1, y1), cv.FONT_HERSHEY_SIMPLEX, 230 | fontScale, (255, 255, 255), thickness//2, lineType=cv.LINE_AA) 231 | 232 | return draw_image 233 | -------------------------------------------------------------------------------- /core/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : dataset.py 3 | # @Author: Runist 4 | # @Time : 2022/4/12 11:06 5 | # @Software: PyCharm 6 | # @Brief: 7 | from core.dataset.utils import image_resize, preprocess_input, gaussian_radius, draw_gaussian 8 | 9 | import os 10 | import cv2 as cv 11 | import random 12 | import math 13 | import numpy as np 14 | from PIL import Image, ImageEnhance 15 | from torch.utils.data.dataset import Dataset 16 | 17 | 18 | class CenterNetDataset(Dataset): 19 | def __init__(self, annotation_path, input_shape, num_classes, is_train): 20 | super(CenterNetDataset, self).__init__() 21 | self.stride = 4 22 | 23 | self.input_shape = input_shape 24 | self.output_shape = (input_shape[0] // self.stride, input_shape[1] // self.stride) 25 | self.num_classes = num_classes 26 | self.is_train = is_train 27 | 28 | self.annotations = self.load_annotations(annotation_path) 29 | 30 | def load_annotations(self, annotations_path): 31 | """ 32 | Load image and label info from annotation file. 33 | Must implement in subclass. You can do this by referring to 'yolo.py'. 34 | 35 | Args: 36 | annotations_path: Annotation file path 37 | 38 | Returns: list or dict type. 
Depends on how you read it in 'parse_annotation'. 39 | 40 | """ 41 | raise NotImplementedError('load_annotations method not implemented!') 42 | 43 | def __len__(self): 44 | return len(self.annotations) 45 | 46 | def __getitem__(self, index): 47 | batch_hm = np.zeros((self.output_shape[0], self.output_shape[1], self.num_classes), dtype=np.float32) 48 | batch_wh = np.zeros((self.output_shape[0], self.output_shape[1], 2), dtype=np.float32) 49 | batch_offset = np.zeros((self.output_shape[0], self.output_shape[1], 2), dtype=np.float32) 50 | batch_offset_mask = np.zeros((self.output_shape[0], self.output_shape[1]), dtype=np.float32) 51 | 52 | # Read image and bounding boxes 53 | image, bboxes = self.parse_annotation(index) 54 | 55 | if self.is_train: 56 | image, bboxes = self.data_augmentation(image, bboxes) 57 | 58 | # Image preprocess 59 | image, bboxes = image_resize(image, self.input_shape, bboxes) 60 | image = preprocess_input(image) 61 | 62 | # Clip bounding boxes 63 | clip_bboxes = [] 64 | labels = [] 65 | for bbox in bboxes: 66 | x1, y1, x2, y2, label = bbox 67 | 68 | if x2 <= x1 or y2 <= y1: 69 | # Don't use such boxes as this may cause nan loss. 70 | continue 71 | 72 | x1 = int(np.clip(x1, 0, self.input_shape[1])) 73 | y1 = int(np.clip(y1, 0, self.input_shape[0])) 74 | x2 = int(np.clip(x2, 0, self.input_shape[1])) 75 | y2 = int(np.clip(y2, 0, self.input_shape[0])) 76 | # Clipping coordinates between 0 and the image dimensions, as negative values 77 | # or values greater than the image dimensions may cause nan loss. 78 | clip_bboxes.append([x1, y1, x2, y2]) 79 | labels.append(label) 80 | 81 | bboxes = np.array(clip_bboxes) 82 | labels = np.array(labels) 83 | 84 | if len(bboxes) != 0: 85 | labels = np.array(labels, dtype=np.float32) 86 | bboxes = np.array(bboxes[:, :4], dtype=np.float32) 87 | bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]] / self.stride, a_min=0, a_max=self.output_shape[1]) 88 | bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]] / self.stride, a_min=0, a_max=self.output_shape[0]) 89 | 90 | for i in range(len(labels)): 91 | x1, y1, x2, y2 = bboxes[i] 92 | cls_id = int(labels[i]) 93 | 94 | h, w = y2 - y1, x2 - x1 95 | if h > 0 and w > 0: 96 | radius = gaussian_radius((math.ceil(h), math.ceil(w))) 97 | radius = max(0, int(radius)) 98 | 99 | # Compute the center of the ground-truth box on the feature map 100 | ct = np.array([(x1 + x2) / 2, (y1 + y2) / 2], dtype=np.float32) 101 | ct_int = ct.astype(np.int32) 102 | 103 | # Get Gaussian heat map 104 | batch_hm[:, :, cls_id] = draw_gaussian(batch_hm[:, :, cls_id], ct_int, radius) 105 | 106 | # Assign ground truth height and width 107 | batch_wh[ct_int[1], ct_int[0]] = 1. * w, 1. * h 108 | 109 | # Assign center point offset 110 | batch_offset[ct_int[1], ct_int[0]] = ct - ct_int 111 | 112 | # Set the corresponding mask to 1 113 | batch_offset_mask[ct_int[1], ct_int[0]] = 1 114 | 115 | return image, batch_hm, batch_wh, batch_offset, batch_offset_mask 116 | 117 | def parse_annotation(self, index): 118 | """ 119 | Parse an element of self.annotations and read the image and bounding boxes.
120 | 121 | Args: 122 | index: index for self.annotation 123 | 124 | Returns: image, bboxes 125 | 126 | """ 127 | 128 | raise NotImplementedError('parse_annotation method not implemented!') 129 | 130 | def data_augmentation(self, image, bboxes): 131 | if random.random() < 0.5: 132 | image, bboxes = self.random_horizontal_flip(image, bboxes) 133 | # if random.random() < 0.5: 134 | # image, bboxes = self.random_vertical_flip(image, bboxes) 135 | if random.random() < 0.5: 136 | image, bboxes = self.random_crop(image, bboxes) 137 | if random.random() < 0.5: 138 | image, bboxes = self.random_translate(image, bboxes) 139 | 140 | if random.random() < 0.5: 141 | image = Image.fromarray(image) 142 | enh_bri = ImageEnhance.Brightness(image) 143 | # brightness = [1, 0.5, 1.4] 144 | image = enh_bri.enhance(random.uniform(0.6, 1.4)) 145 | image = np.array(image) 146 | 147 | if random.random() < 0.5: 148 | image = Image.fromarray(image) 149 | enh_col = ImageEnhance.Color(image) 150 | # color = [0.7, 1.3, 1] 151 | image = enh_col.enhance(random.uniform(0.7, 1.3)) 152 | image = np.array(image) 153 | 154 | if random.random() < 0.5: 155 | image = Image.fromarray(image) 156 | enh_con = ImageEnhance.Contrast(image) 157 | # contrast = [0.7, 1, 1.3] 158 | image = enh_con.enhance(random.uniform(0.7, 1.3)) 159 | image = np.array(image) 160 | 161 | if random.random() < 0.5: 162 | image = Image.fromarray(image) 163 | enh_sha = ImageEnhance.Sharpness(image) 164 | # sharpness = [-0.5, 0, 1.0] 165 | image = enh_sha.enhance(random.uniform(0, 2.0)) 166 | image = np.array(image) 167 | 168 | return image, bboxes 169 | 170 | def random_horizontal_flip(self, image, bboxes): 171 | _, w, _ = image.shape 172 | image = image[:, ::-1, :] 173 | image = np.array(image) 174 | 175 | if bboxes.size != 0: 176 | bboxes[:, [0, 2]] = w - bboxes[:, [2, 0]] 177 | 178 | return image, bboxes 179 | 180 | def random_vertical_flip(self, image, bboxes): 181 | h, _, _ = image.shape 182 | image = image[::-1, :, :] 183 | image = np.array(image) 184 | 185 | if bboxes.size != 0: 186 | bboxes[:, [1, 3]] = h - bboxes[:, [3, 1]] 187 | 188 | return image, bboxes 189 | 190 | def random_crop(self, image, bboxes): 191 | if bboxes.size == 0: 192 | return image, bboxes 193 | 194 | h, w, _ = image.shape 195 | 196 | max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1) 197 | 198 | max_l_trans = max_bbox[0] 199 | max_u_trans = max_bbox[1] 200 | max_r_trans = w - max_bbox[2] 201 | max_d_trans = h - max_bbox[3] 202 | 203 | crop_xmin = max(0, int(max_bbox[0] - random.uniform(0, max_l_trans))) 204 | crop_ymin = max(0, int(max_bbox[1] - random.uniform(0, max_u_trans))) 205 | crop_xmax = min(w, int(max_bbox[2] + random.uniform(0, max_r_trans))) 206 | crop_ymax = min(h, int(max_bbox[3] + random.uniform(0, max_d_trans))) 207 | 208 | image = image[crop_ymin: crop_ymax, crop_xmin: crop_xmax] 209 | 210 | bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - crop_xmin 211 | bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - crop_ymin 212 | 213 | return image, bboxes 214 | 215 | def random_translate(self, image, bboxes): 216 | if bboxes.size == 0: 217 | return image, bboxes 218 | 219 | h, w, _ = image.shape 220 | 221 | max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0), np.max(bboxes[:, 2:4], axis=0)], axis=-1) 222 | max_l_trans = max_bbox[0] 223 | max_u_trans = max_bbox[1] 224 | max_r_trans = w - max_bbox[2] 225 | max_d_trans = h - max_bbox[3] 226 | 227 | tx = random.uniform(-(max_l_trans - 1), (max_r_trans - 1)) 228 | ty = 
random.uniform(-(max_u_trans - 1), (max_d_trans - 1)) 229 | 230 | M = np.array([[1, 0, tx], [0, 1, ty]]) 231 | image = cv.warpAffine(image, M, (w, h), borderValue=(128, 128, 128)) 232 | 233 | bboxes[:, [0, 2]] = bboxes[:, [0, 2]] + tx 234 | bboxes[:, [1, 3]] = bboxes[:, [1, 3]] + ty 235 | 236 | return image, bboxes 237 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : train.py 3 | # @Author: Runist 4 | # @Time : 2022/3/28 17:51 5 | # @Software: PyCharm 6 | # @Brief: train script 7 | from args import args, dev, class_names 8 | 9 | from core.loss import focal_loss, l1_loss 10 | from core.helper import remove_dir_and_create_dir, get_model, get_dataset, draw_bbox 11 | from core.detect import postprocess_output, decode_bbox 12 | from core.dataset import recover_input 13 | from net import CenterNetPoolingNMS 14 | 15 | from torch.utils.data import DataLoader 16 | from tensorboardX import SummaryWriter 17 | import torch.optim as optim 18 | from tqdm import tqdm 19 | import numpy as np 20 | import torch 21 | import os 22 | import math 23 | 24 | 25 | class CosineAnnealingWarmupRestarts(optim.lr_scheduler._LRScheduler): 26 | """ 27 | optimizer (Optimizer): Wrapped optimizer. 28 | first_cycle_steps (int): First cycle step size. 29 | cycle_mult(float): Cycle steps magnification. Default: -1. 30 | max_lr(float): First cycle's max learning rate. Default: 0.1. 31 | min_lr(float): Min learning rate. Default: 0.001. 32 | warmup_steps(int): Linear warmup step size. Default: 0. 33 | gamma(float): Decrease rate of max learning rate by cycle. Default: 1. 34 | last_epoch (int): The index of last epoch. Default: -1. 
35 | """ 36 | 37 | def __init__(self, 38 | optimizer: torch.optim.Optimizer, 39 | first_cycle_steps: int, 40 | cycle_mult: float = 1., 41 | max_lr: float = 0.1, 42 | min_lr: float = 0.001, 43 | warmup_steps: int = 0, 44 | gamma: float = 1., 45 | last_epoch: int = -1 46 | ): 47 | assert warmup_steps < first_cycle_steps 48 | 49 | self.first_cycle_steps = first_cycle_steps # first cycle step size 50 | self.cycle_mult = cycle_mult # cycle steps magnification 51 | self.base_max_lr = max_lr # first max learning rate 52 | self.max_lr = max_lr # max learning rate in the current cycle 53 | self.min_lr = min_lr # min learning rate 54 | self.warmup_steps = warmup_steps # warmup step size 55 | self.gamma = gamma # decrease rate of max learning rate by cycle 56 | 57 | self.cur_cycle_steps = first_cycle_steps # first cycle step size 58 | self.cycle = 0 # cycle count 59 | self.step_in_cycle = last_epoch # step size of the current cycle 60 | 61 | super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch) 62 | 63 | # set learning rate min_lr 64 | self.init_lr() 65 | 66 | def init_lr(self): 67 | self.base_lrs = [] 68 | for param_group in self.optimizer.param_groups: 69 | param_group['lr'] = self.min_lr 70 | self.base_lrs.append(self.min_lr) 71 | 72 | def get_lr(self): 73 | if self.step_in_cycle == -1: 74 | return self.base_lrs 75 | elif self.step_in_cycle < self.warmup_steps: 76 | return [(self.max_lr - base_lr)*self.step_in_cycle / self.warmup_steps + base_lr for base_lr in self.base_lrs] 77 | else: 78 | return [base_lr + (self.max_lr - base_lr) \ 79 | * (1 + math.cos(math.pi * (self.step_in_cycle-self.warmup_steps) \ 80 | / (self.cur_cycle_steps - self.warmup_steps))) / 2 81 | for base_lr in self.base_lrs] 82 | 83 | def step(self, epoch=None): 84 | if epoch is None: 85 | epoch = self.last_epoch + 1 86 | self.step_in_cycle = self.step_in_cycle + 1 87 | if self.step_in_cycle >= self.cur_cycle_steps: 88 | self.cycle += 1 89 | self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps 90 | self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps 91 | else: 92 | if epoch >= self.first_cycle_steps: 93 | if self.cycle_mult == 1.: 94 | self.step_in_cycle = epoch % self.first_cycle_steps 95 | self.cycle = epoch // self.first_cycle_steps 96 | else: 97 | n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult)) 98 | self.cycle = n 99 | self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1)) 100 | self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (n) 101 | else: 102 | self.cur_cycle_steps = self.first_cycle_steps 103 | self.step_in_cycle = epoch 104 | 105 | self.max_lr = self.base_max_lr * (self.gamma**self.cycle) 106 | self.last_epoch = math.floor(epoch) 107 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 108 | param_group['lr'] = lr 109 | 110 | 111 | def get_summary_image(images, 112 | hms_true, whs_true, offsets_true, 113 | hms_pred, whs_pred, offsets_pred, dev): 114 | 115 | summary_images = [] 116 | 117 | outputs_true = postprocess_output(hms_true, whs_true, offsets_true, args.confidence, dev) 118 | outputs_pred = postprocess_output(hms_pred, whs_pred, offsets_pred, args.confidence, dev) 119 | outputs_true = decode_bbox(outputs_true, 120 | (args.input_height, args.input_height), 121 | dev, need_nms=True, nms_thres=0.4) 122 | outputs_pred = decode_bbox(outputs_pred, 123 | (args.input_height, 
args.input_height), 124 | dev, need_nms=True, nms_thres=0.4) 125 | 126 | images = images.cpu().numpy() 127 | for i in range(len(images)): 128 | image = images[i] 129 | image = recover_input(image.copy()) 130 | 131 | output_true = outputs_true[i] 132 | output_pred = outputs_pred[i] 133 | 134 | if len(output_true) != 0: 135 | output_true = output_true.data.cpu().numpy() 136 | labels_true = output_true[:, 5] 137 | bboxes_true = output_true[:, :4] 138 | else: 139 | labels_true = [] 140 | bboxes_true = [] 141 | 142 | if len(output_pred) != 0: 143 | output_pred = output_pred.data.cpu().numpy() 144 | labels_pred = output_pred[:, 5] 145 | bboxes_pred = output_pred[:, :4] 146 | else: 147 | labels_pred = [] 148 | bboxes_pred = [] 149 | 150 | image_true = draw_bbox(image, bboxes_true, labels_true, class_names) 151 | image_pred = draw_bbox(image, bboxes_pred, labels_pred, class_names) 152 | 153 | summary_images.append(np.hstack((image_true, image_pred)).astype(np.uint8)) 154 | 155 | return summary_images 156 | 157 | 158 | def train_one_epochs(model, train_loader, epoch, optimizer, scheduler, dev, writer): 159 | global step 160 | 161 | model.train() 162 | tbar = tqdm(train_loader) 163 | 164 | total_loss = [] 165 | image_write_step = len(train_loader) 166 | 167 | for images, hms_true, whs_true, offsets_true, offset_masks_true in tbar: 168 | tbar.set_description("epoch {}".format(epoch)) 169 | 170 | # Set variables for training 171 | images = images.float().to(dev) 172 | hms_true = hms_true.float().to(dev) 173 | whs_true = whs_true.float().to(dev) 174 | offsets_true = offsets_true.float().to(dev) 175 | offset_masks_true = offset_masks_true.float().to(dev) 176 | 177 | # Zero the gradient 178 | optimizer.zero_grad() 179 | 180 | # Get model predictions, calculate loss 181 | training_output = model(images, mode='train', ground_truth_data=(hms_true, 182 | whs_true, 183 | offsets_true, 184 | offset_masks_true)) 185 | hms_pred, whs_pred, offsets_pred, loss, c_loss, wh_loss, off_loss, hms_true = training_output 186 | 187 | loss = loss.mean() 188 | c_loss = c_loss.mean() 189 | wh_loss = wh_loss.mean() 190 | off_loss = off_loss.mean() 191 | 192 | total_loss.append(loss.item()) 193 | 194 | if step % image_write_step == 0: 195 | summary_images = get_summary_image(images, 196 | hms_true, whs_true, offsets_true, 197 | hms_pred, whs_pred, offsets_pred, dev) 198 | for i, summary_image in enumerate(summary_images): 199 | writer.add_image('train_images_{}'.format(i), summary_image, global_step=step, dataformats="HWC") 200 | 201 | writer.add_scalar("loss", loss.item(), step) 202 | writer.add_scalar("c_loss", c_loss.item(), step) 203 | writer.add_scalar("wh_loss", wh_loss.item(), step) 204 | writer.add_scalar("offset_loss", off_loss.item(), step) 205 | writer.add_scalar("lr", optimizer.param_groups[0]['lr'], step) 206 | 207 | loss.backward() 208 | optimizer.step() 209 | scheduler.step() 210 | 211 | step += 1 212 | tbar.set_postfix(total_loss="{:.4f}".format(loss.item()), 213 | c_loss="{:.4f}".format(c_loss.item()), 214 | wh_loss="{:.4f}".format(wh_loss.item()), 215 | offset_loss="{:.4f}".format(off_loss.item())) 216 | 217 | # clear batch variables from memory 218 | del images, hms_true, whs_true, offsets_true, offset_masks_true 219 | 220 | return np.mean(total_loss) 221 | 222 | 223 | def eval_one_epochs(model, val_loader, epoch, dev, writer): 224 | 225 | model.eval() 226 | 227 | total_loss = [] 228 | total_c_loss = [] 229 | total_wh_loss = [] 230 | total_offset_loss = [] 231 | write_image = True 232 | 233 | with 
torch.no_grad(): 234 | for images, hms_true, whs_true, offsets_true, offset_masks_true in val_loader: 235 | 236 | # Set variables for training 237 | images = images.float().to(dev) 238 | hms_true = hms_true.float().to(dev) 239 | whs_true = whs_true.float().to(dev) 240 | offsets_true = offsets_true.float().to(dev) 241 | offset_masks_true = offset_masks_true.float().to(dev) 242 | 243 | # Get model predictions, calculate loss 244 | training_output = model(images, mode='train', ground_truth_data=(hms_true, 245 | whs_true, 246 | offsets_true, 247 | offset_masks_true)) 248 | hms_pred, whs_pred, offsets_pred, loss, c_loss, wh_loss, off_loss, hms_true = training_output 249 | 250 | loss = loss.mean() 251 | c_loss = c_loss.mean() 252 | wh_loss = wh_loss.mean() 253 | off_loss = off_loss.mean() 254 | 255 | total_loss.append(loss.item()) 256 | total_c_loss.append(c_loss.item()) 257 | total_wh_loss.append(wh_loss.item()) 258 | total_offset_loss.append(off_loss.item()) 259 | 260 | if write_image: 261 | write_image = False 262 | summary_images = get_summary_image(images, 263 | hms_true, whs_true, offsets_true, 264 | hms_pred, whs_pred, offsets_pred, dev) 265 | for i, summary_image in enumerate(summary_images): 266 | writer.add_image('val_images_{}'.format(i), summary_image, global_step=epoch, dataformats="HWC") 267 | 268 | # clear batch variables from memory 269 | del images, hms_true, whs_true, offsets_true, offset_masks_true 270 | 271 | writer.add_scalar("val_loss", np.mean(total_loss), epoch) 272 | writer.add_scalar("val_c_loss", np.mean(total_c_loss), epoch) 273 | writer.add_scalar("val_wh_loss", np.mean(total_wh_loss), epoch) 274 | writer.add_scalar("val_offset_loss", np.mean(total_offset_loss), epoch) 275 | 276 | return np.mean(total_loss) 277 | 278 | 279 | if __name__ == '__main__': 280 | remove_dir_and_create_dir(os.path.join(args.logs_dir, "weights"), is_remove=True) 281 | remove_dir_and_create_dir(os.path.join(args.logs_dir, "summary"), is_remove=True) 282 | 283 | model = get_model(args, dev) 284 | train_dataset, val_dataset = get_dataset(args, class_names) 285 | 286 | writer = SummaryWriter(os.path.join(args.logs_dir, "summary")) 287 | step = 0 288 | 289 | freeze_step = len(train_dataset) // args.freeze_batch_size 290 | unfreeze_step = len(train_dataset) // args.unfreeze_batch_size 291 | params = [p for p in model.parameters() if p.requires_grad] 292 | optimizer = optim.AdamW(params, args.learn_rate_init) 293 | scheduler = CosineAnnealingWarmupRestarts(optimizer, 294 | first_cycle_steps=args.freeze_epochs * freeze_step + args.unfreeze_epochs * unfreeze_step, 295 | max_lr=args.learn_rate_init, 296 | min_lr=args.learn_rate_end, 297 | warmup_steps=args.warmup_epochs * freeze_step) 298 | 299 | # freeze 300 | if args.freeze_epochs > 0: 301 | print("Freeze backbone and decoder, train {} epochs.".format(args.freeze_epochs)) 302 | model.module.freeze_backbone() 303 | 304 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.freeze_batch_size, 305 | num_workers=args.num_workers, pin_memory=True) 306 | val_loader = DataLoader(val_dataset, shuffle=False, batch_size=args.freeze_batch_size, 307 | num_workers=args.num_workers, pin_memory=True) 308 | 309 | for epoch in range(args.freeze_epochs): 310 | train_loss = train_one_epochs(model, train_loader, epoch, optimizer, scheduler, dev, writer) 311 | val_loss = eval_one_epochs(model, val_loader, epoch, dev, writer) 312 | print("=> loss: {:.4f} val_loss: {:.4f}".format(train_loss, val_loss)) 313 | torch.save(model, 314 | 
'{}/weights/epoch={}_loss={:.4f}_val_loss={:.4f}.pt'. 315 | format(args.logs_dir, epoch, train_loss, val_loss)) 316 | 317 | # unfreeze 318 | if args.unfreeze_epochs > 0: 319 | print("Unfreeze backbone and decoder, train {} epochs.".format(args.unfreeze_epochs)) 320 | model.module.unfreeze_backbone() 321 | 322 | train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.unfreeze_batch_size, 323 | num_workers=args.num_workers, pin_memory=True) 324 | val_loader = DataLoader(val_dataset, shuffle=False, batch_size=args.unfreeze_batch_size, 325 | num_workers=args.num_workers, pin_memory=True) 326 | 327 | for epoch in range(args.unfreeze_epochs): 328 | epoch = args.freeze_epochs + epoch 329 | train_loss = train_one_epochs(model, train_loader, epoch, optimizer, scheduler, dev, writer) 330 | val_loss = eval_one_epochs(model, val_loader, epoch, dev, writer) 331 | print("=> loss: {:.4f} val_loss: {:.4f}".format(train_loss, val_loss)) 332 | torch.save(model, 333 | '{}/weights/epoch={}_loss={:.4f}_val_loss={:.4f}.pt'. 334 | format(args.logs_dir, epoch, train_loss, val_loss)) -------------------------------------------------------------------------------- /core/map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : map.py 3 | # @Author: Runist 4 | # @Time : 2022/4/10 10:11 5 | # @Software: PyCharm 6 | # @Brief: Map function 7 | 8 | import json 9 | import os 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | from pycocotools.coco import COCO 16 | from pycocotools.cocoeval import COCOeval 17 | 18 | from core.detect import predict 19 | import xml.etree.ElementTree as ET 20 | 21 | 22 | def get_ground_truth(args, class_names): 23 | ground_truth = {} 24 | categories_info = [] 25 | 26 | for i, cls in enumerate(class_names): 27 | category = {} 28 | category['supercategory'] = cls 29 | category['name'] = cls 30 | category['id'] = i + 1 31 | categories_info.append(category) 32 | 33 | print("Creating ground-truth files ...") 34 | if args.dataset_format == "yolo": 35 | images_info, annotations_info = get_yolo_ground_truth(args.dataset_val_path) 36 | elif args.dataset_format == "coco": 37 | images_info, annotations_info = get_coco_ground_truth(args.dataset_val_path) 38 | elif args.dataset_format == "voc": 39 | images_info, annotations_info = get_voc_ground_truth(args, class_names) 40 | elif args.dataset_format == "ilsvrc": 41 | images_info, annotations_info = get_ilsvrc_ground_truth(args, class_names) 42 | else: 43 | raise Exception("There is no {} format for data parsing, you should choose one from 'yolo', 'coco', 'voc', 'ilsvrc'". 
44 | format(args.dataset_format)) 45 | 46 | ground_truth['images'] = images_info 47 | ground_truth['categories'] = categories_info 48 | ground_truth['annotations'] = annotations_info 49 | 50 | return ground_truth 51 | 52 | 53 | def get_map(args, output_files_path, class_names, model, dev): 54 | ground_truth = get_ground_truth(args, class_names) 55 | 56 | print("Creating detection-result files ...") 57 | if args.dataset_format == "yolo": 58 | detection_result = get_yolo_detection_result(args, model, dev) 59 | elif args.dataset_format == "coco": 60 | detection_result = get_coco_detection_result(args, model, dev) 61 | elif args.dataset_format == "voc": 62 | detection_result = get_voc_detection_result(args, model, dev) 63 | elif args.dataset_format == "ilsvrc": 64 | detection_result = get_ilsvrc_detection_result(args, model, dev) 65 | else: 66 | raise Exception("There is no {} format for data parsing, you should choose one from 'yolo', 'coco', 'voc', 'ilsvrc'". 67 | format(args.dataset_format)) 68 | 69 | print("Calculating map ...") 70 | gt_json_path = os.path.join(output_files_path, 'instances_gt.json') 71 | dr_json_path = os.path.join(output_files_path, 'instances_dr.json') 72 | 73 | with open(gt_json_path, "w") as f: 74 | json.dump(ground_truth, f, indent=4) 75 | 76 | with open(dr_json_path, "w") as f: 77 | json.dump(detection_result, f, indent=4) 78 | 79 | coco_gt = COCO(gt_json_path) 80 | coco_dt = coco_gt.loadRes(dr_json_path) 81 | coco_eval = COCOeval(coco_gt, coco_dt, 'bbox') 82 | coco_eval.evaluate() 83 | coco_eval.accumulate() 84 | coco_eval.summarize() 85 | 86 | # get all classes ap 87 | precisions = coco_eval.eval['precision'] 88 | 89 | # precision: (iou, recall, cls, area range, max dets) 90 | assert len(coco_gt.getCatIds()) == precisions.shape[2] 91 | 92 | for idx, catId in enumerate(coco_gt.getCatIds()): 93 | # area range index 0: all area ranges 94 | # max dets index -1: typically 100 per image 95 | name = coco_gt.loadCats(catId)[0] 96 | precision = precisions[:, :, idx, 0, -1] 97 | precision = precision[precision > -1] 98 | 99 | if precision.size: 100 | ap = np.mean(precision) 101 | else: 102 | ap = float('nan') 103 | 104 | print("{}: {:.1f}%".format(name["name"], float(ap) * 100)) 105 | 106 | 107 | def get_coco_ground_truth(annotations_path): 108 | coco = COCO(annotations_path) 109 | image_ids = coco.getImgIds() 110 | cat_ids = sorted(coco.getCatIds()) 111 | 112 | i = 0 113 | images_info = [] 114 | annotations_info = [] 115 | 116 | for image_id in tqdm(image_ids): 117 | image_info = {} 118 | 119 | image_info['file_name'] = str(image_id) + '.jpg' 120 | image_info['width'] = 1 121 | image_info['height'] = 1 122 | image_info['id'] = str(image_id) 123 | images_info.append(image_info) 124 | 125 | ann_id = coco.getAnnIds(imgIds=[int(image_id)], iscrowd=False) 126 | annotation = coco.loadAnns(ann_id) 127 | 128 | for obj in annotation: 129 | class_id = cat_ids.index(obj["category_id"]) 130 | xmin = np.max((0, obj["bbox"][0])) 131 | ymin = np.max((0, obj["bbox"][1])) 132 | width = np.max((0, obj["bbox"][2])) 133 | height = np.max((0, obj["bbox"][3])) 134 | 135 | annotation = {} 136 | if obj["area"] > 0 and width > 0 and height > 0: 137 | iscrowd = 0 138 | 139 | annotation['area'] = width * height 140 | annotation['category_id'] = int(class_id + 1) 141 | annotation['image_id'] = str(image_id) 142 | annotation['iscrowd'] = iscrowd 143 | annotation['bbox'] = [int(xmin), int(ymin), int(width), int(height)] 144 | annotation['id'] = i 145 | annotations_info.append(annotation) 146 | i += 1 
147 | 148 | return images_info, annotations_info 149 | 150 | 151 | def get_coco_detection_result(args, model, dev): 152 | coco = COCO(args.dataset_val_path) 153 | image_ids = coco.getImgIds() 154 | detection_result = [] 155 | 156 | for image_id in tqdm(image_ids): 157 | img_ann = coco.loadImgs(image_id)[0] 158 | 159 | filename = img_ann["file_name"] 160 | image_path = os.path.join(args.image_val_dir, filename) 161 | image = Image.open(image_path) 162 | image = np.array(image) 163 | 164 | if len(image.shape) == 2: 165 | image = np.expand_dims(image, axis=-1) 166 | image = image.repeat(3, axis=-1) 167 | 168 | model.eval() 169 | with torch.no_grad(): 170 | outputs = predict(image, model, dev, args) 171 | 172 | if len(outputs) == 0: 173 | continue 174 | 175 | outputs = outputs.data.cpu().numpy() 176 | labels = outputs[:, 5] 177 | scores = outputs[:, 4] 178 | bboxes = outputs[:, :4] 179 | 180 | for bbox, class_id, score in zip(bboxes, labels, scores): 181 | xmin, ymin, xmax, ymax = bbox 182 | 183 | result = {} 184 | result["image_id"] = str(image_id) 185 | result["category_id"] = class_id + 1 186 | result["bbox"] = [int(xmin), int(ymin), int(xmax - xmin), int(ymax - ymin)] 187 | result["score"] = float(score) 188 | detection_result.append(result) 189 | 190 | return detection_result 191 | 192 | 193 | def get_yolo_ground_truth(annotations_path): 194 | with open(annotations_path, 'r', encoding='utf-8') as f: 195 | txt = f.readlines() 196 | annotations = [line.strip('|') for line in txt if len(line.strip("|").split("|")[1:]) != 0] 197 | 198 | i = 0 199 | images_info = [] 200 | annotations_info = [] 201 | 202 | for annotation in tqdm(annotations): 203 | annotation = annotation.strip() 204 | line = annotation.split("|") 205 | 206 | image_path = line[0] 207 | image_id = os.path.split(image_path)[-1].split(".")[0] 208 | 209 | image_info = {} 210 | 211 | image_info['file_name'] = image_id + '.jpg' 212 | image_info['width'] = 1 213 | image_info['height'] = 1 214 | image_info['id'] = str(image_id) 215 | images_info.append(image_info) 216 | 217 | bboxes = np.array([list(map(lambda x: int(float(x)), box.split(','))) for box in line[1:]]) 218 | for bbox in bboxes: 219 | xmin, ymin, xmax, ymax, class_id = bbox 220 | xmin = int(xmin) 221 | ymin = int(ymin) 222 | xmax = int(xmax) 223 | ymax = int(ymax) 224 | class_id = int(class_id) 225 | 226 | w = xmax - xmin 227 | h = ymax - ymin 228 | 229 | annotation = {} 230 | annotation['area'] = w * h 231 | annotation['category_id'] = class_id + 1 232 | annotation['image_id'] = str(image_id) 233 | annotation['iscrowd'] = 0 234 | annotation['bbox'] = [xmin, ymin, w, h] 235 | annotation['id'] = i 236 | annotations_info.append(annotation) 237 | i += 1 238 | 239 | return images_info, annotations_info 240 | 241 | 242 | def get_yolo_detection_result(args, model, dev): 243 | with open(args.dataset_val_path, 'r', encoding='utf-8') as f: 244 | txt = f.readlines() 245 | annotations = [line.strip('|') for line in txt if len(line.strip("|").split("|")[1:]) != 0] 246 | 247 | detection_result = [] 248 | 249 | for annotation in tqdm(annotations): 250 | annotation = annotation.strip() 251 | line = annotation.split("|") 252 | 253 | image_path = line[0] 254 | image_id = os.path.split(image_path)[-1].split(".")[0] 255 | 256 | image = Image.open(image_path) 257 | image = np.array(image) 258 | 259 | model.eval() 260 | with torch.no_grad(): 261 | outputs = predict(image, model, dev, args) 262 | 263 | if len(outputs) == 0: 264 | continue 265 | 266 | outputs = outputs.data.cpu().numpy() 267 | 
labels = outputs[:, 5] 268 | scores = outputs[:, 4] 269 | bboxes = outputs[:, :4] 270 | 271 | for bbox, class_id, score in zip(bboxes, labels, scores): 272 | xmin, ymin, xmax, ymax = bbox 273 | 274 | result = {} 275 | result["image_id"] = str(image_id) 276 | result["category_id"] = class_id + 1 277 | result["bbox"] = [int(xmin), int(ymin), int(xmax - xmin), int(ymax - ymin)] 278 | result["score"] = float(score) 279 | detection_result.append(result) 280 | 281 | return detection_result 282 | 283 | 284 | def get_voc_ground_truth(args, class_names): 285 | with open(args.dataset_val_path, 'r', encoding='utf-8') as f: 286 | txt = f.readlines() 287 | image_ids = [line.strip().split()[0] for line in txt] 288 | 289 | i = 0 290 | images_info = [] 291 | annotations_info = [] 292 | 293 | for image_id in tqdm(image_ids): 294 | xml = ET.parse(os.path.join(args.annotation_val_dir, image_id + ".xml")).getroot() 295 | 296 | image_info = {} 297 | 298 | image_info['file_name'] = image_id + '.jpg' 299 | image_info['width'] = 1 300 | image_info['height'] = 1 301 | image_info['id'] = str(image_id) 302 | images_info.append(image_info) 303 | 304 | for obj in xml.iter("object"): 305 | difficult = obj.find("difficult") 306 | if difficult is not None: 307 | difficult = int(difficult.text) == 1 308 | else: 309 | difficult = False 310 | 311 | if difficult: 312 | continue 313 | 314 | name = obj.find("name").text.strip() 315 | bbox = obj.find("bndbox") 316 | 317 | x1 = int(float(bbox.find("xmin").text)) 318 | y1 = int(float(bbox.find("ymin").text)) 319 | x2 = int(float(bbox.find("xmax").text)) 320 | y2 = int(float(bbox.find("ymax").text)) 321 | w = x2 - x1 322 | h = y2 - y1 323 | 324 | cls_id = class_names.index(name) + 1 325 | 326 | annotation = {} 327 | annotation['area'] = w * h - 10.0 328 | annotation['category_id'] = cls_id 329 | annotation['image_id'] = str(image_id) 330 | annotation['iscrowd'] = 0 331 | annotation['bbox'] = [x1, y1, w, h] 332 | annotation['id'] = i 333 | annotations_info.append(annotation) 334 | i += 1 335 | 336 | return images_info, annotations_info 337 | 338 | 339 | def get_voc_detection_result(args, model, dev): 340 | with open(args.dataset_val_path, 'r', encoding='utf-8') as f: 341 | txt = f.readlines() 342 | image_ids = [line.strip().split()[0] for line in txt] 343 | 344 | detection_result = [] 345 | 346 | for image_id in tqdm(image_ids): 347 | image_path = os.path.join(args.image_val_dir, image_id + ".jpg") 348 | image = Image.open(image_path) 349 | image = np.array(image) 350 | model.eval() 351 | with torch.no_grad(): 352 | outputs = predict(image, model, dev, args) 353 | 354 | if len(outputs) == 0: 355 | continue 356 | 357 | outputs = outputs.data.cpu().numpy() 358 | labels = outputs[:, 5] 359 | scores = outputs[:, 4] 360 | bboxes = outputs[:, :4] 361 | 362 | for bbox, class_id, score in zip(bboxes, labels, scores): 363 | xmin, ymin, xmax, ymax = bbox 364 | 365 | result = {} 366 | result["image_id"] = str(image_id) 367 | result["category_id"] = class_id + 1 368 | result["bbox"] = [int(xmin), int(ymin), int(xmax - xmin), int(ymax - ymin)] 369 | result["score"] = float(score) 370 | detection_result.append(result) 371 | 372 | return detection_result 373 | 374 | 375 | def get_ilsvrc_ground_truth(args, class_names): 376 | image_ids = [] 377 | with open(args.dataset_val_path, 'r', encoding='utf-8') as f: 378 | txt = f.readlines() 379 | 380 | for line in txt: 381 | line = line.strip() 382 | if "extra" in line: 383 | continue 384 | image_ids.append(line.split()[0]) 385 | 386 | i = 0 387 | 
images_info = [] 388 | annotations_info = [] 389 | 390 | for image_id in tqdm(image_ids): 391 | xml = ET.parse(os.path.join(args.annotation_val_dir, image_id + ".xml")).getroot() 392 | 393 | image_info = {} 394 | 395 | image_info['file_name'] = image_id + '.JPEG' 396 | image_info['width'] = 1 397 | image_info['height'] = 1 398 | image_info['id'] = str(image_id) 399 | images_info.append(image_info) 400 | 401 | for obj in xml.iter("object"): 402 | name = obj.find("name").text.strip() 403 | bbox = obj.find("bndbox") 404 | 405 | x1 = int(float(bbox.find("xmin").text)) 406 | y1 = int(float(bbox.find("ymin").text)) 407 | x2 = int(float(bbox.find("xmax").text)) 408 | y2 = int(float(bbox.find("ymax").text)) 409 | w = x2 - x1 410 | h = y2 - y1 411 | 412 | cls_id = class_names.index(name) + 1 413 | 414 | annotation = {} 415 | annotation['area'] = w * h - 10.0 416 | annotation['category_id'] = cls_id 417 | annotation['image_id'] = str(image_id) 418 | annotation['iscrowd'] = 0 419 | annotation['bbox'] = [x1, y1, w, h] 420 | annotation['id'] = i 421 | annotations_info.append(annotation) 422 | i += 1 423 | 424 | return images_info, annotations_info 425 | 426 | 427 | def get_ilsvrc_detection_result(args, model, dev): 428 | image_ids = [] 429 | with open(args.dataset_val_path, 'r', encoding='utf-8') as f: 430 | txt = f.readlines() 431 | 432 | for line in txt: 433 | line = line.strip() 434 | if "extra" in line: 435 | continue 436 | image_ids.append(line.split()[0]) 437 | 438 | detection_result = [] 439 | 440 | for image_id in tqdm(image_ids): 441 | image_path = os.path.join(args.image_val_dir, image_id + ".JPEG") 442 | 443 | image = Image.open(image_path) 444 | image = np.array(image) 445 | if len(image.shape) == 2: 446 | image = np.expand_dims(image, axis=-1) 447 | image = image.repeat(3, axis=-1) 448 | if image.shape[-1] == 4: 449 | image = image[:, :, :-1] 450 | 451 | model.eval() 452 | with torch.no_grad(): 453 | outputs = predict(image, model, dev, args) 454 | 455 | if len(outputs) == 0: 456 | continue 457 | 458 | outputs = outputs.data.cpu().numpy() 459 | labels = outputs[:, 5] 460 | scores = outputs[:, 4] 461 | bboxes = outputs[:, :4] 462 | 463 | for bbox, class_id, score in zip(bboxes, labels, scores): 464 | xmin, ymin, xmax, ymax = bbox 465 | 466 | result = {} 467 | result["image_id"] = str(image_id) 468 | result["category_id"] = class_id + 1 469 | result["bbox"] = [int(xmin), int(ymin), int(xmax - xmin), int(ymax - ymin)] 470 | result["score"] = float(score) 471 | detection_result.append(result) 472 | 473 | return detection_result 474 | --------------------------------------------------------------------------------
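A minimal illustrative sketch of the COCO-style records that core/map.py assembles for every dataset format before handing them to pycocotools; the file names, class ids and box values below are hypothetical, not taken from any real dataset.

# Sketch only: mirrors the fields built by get_ground_truth() and the get_*_detection_result() helpers in core/map.py.
import json

image_info = {"file_name": "000001.jpg", "width": 1, "height": 1, "id": "000001"}  # width/height are dummy values, as in map.py
gt_annotation = {"area": 40 * 30, "category_id": 3, "image_id": "000001", "iscrowd": 0,
                 "bbox": [48, 240, 40, 30],  # [xmin, ymin, width, height]
                 "id": 0}
detection = {"image_id": "000001", "category_id": 3,
             "bbox": [50, 238, 38, 33], "score": 0.87}

ground_truth = {"images": [image_info],
                "categories": [{"supercategory": "car", "name": "car", "id": 3}],  # id = class_names.index(name) + 1; values here are made up
                "annotations": [gt_annotation]}

with open("instances_gt.json", "w") as f:
    json.dump(ground_truth, f, indent=4)
with open("instances_dr.json", "w") as f:
    json.dump([detection], f, indent=4)
# get_map() then evaluates with COCO("instances_gt.json"), coco_gt.loadRes("instances_dr.json") and COCOeval(coco_gt, coco_dt, 'bbox').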