├── utils ├── __init__.py ├── visualize.py ├── ap_eval.py ├── box_ops.py └── roi_align.py ├── dataset ├── __init__.py ├── transforms.py ├── data.py └── augmentation.py ├── pics ├── map.png ├── total_loss.png ├── roi_cls_loss.png ├── roi_loc_loss.png ├── rpn_cls_loss.png └── rpn_loc_loss.png ├── requirements.txt ├── .gitignore ├── model ├── roi_head.py ├── resnet.py ├── faster_rcnn.py └── rpn.py ├── README.md ├── train.py └── evaluate.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pics/map.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/map.png -------------------------------------------------------------------------------- /pics/total_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/total_loss.png -------------------------------------------------------------------------------- /pics/roi_cls_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/roi_cls_loss.png -------------------------------------------------------------------------------- /pics/roi_loc_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/roi_loc_loss.png -------------------------------------------------------------------------------- /pics/rpn_cls_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/rpn_cls_loss.png -------------------------------------------------------------------------------- /pics/rpn_loc_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jittor/TrafficSignDetection/HEAD/pics/rpn_loc_loss.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | Pillow 4 | matplotlib 5 | tqdm 6 | tensorboardX 7 | tensorboard 8 | jittor -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode 3 | __pycache__ 4 | __pycache__/__pycache__ 5 | runs 6 | tmp 7 | *.pkl 8 | *.json 9 | runs* 10 | test_imgs -------------------------------------------------------------------------------- /model/roi_head.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from jittor import nn,init 3 | import jittor as jt 4 | import numpy as np 5 | from utils.roi_align import ROIAlign 6 | 7 | class RoIHead(nn.Module): 8 | """ 9 | This class is used as a head for Faster R-CNN. 10 | This outputs class-wise localizations and classification based on feature 11 | maps in the given RoIs. 
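Internally each RoI is pooled to a fixed roi_size x roi_size grid with RoIAlign and then passed through two fully-connected layers before the class-wise localization and score branches.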
12 | 13 | Args: 14 | n_class (int): The number of classes possibly including the background. 15 | roi_size (int): Height and width of the feature maps after RoI-pooling. 16 | spatial_scale (float): Scale of the roi is resized. 17 | """ 18 | 19 | def __init__(self, in_channels,n_class, roi_size, spatial_scale,sampling_ratio): 20 | # n_class includes the background 21 | super(RoIHead, self).__init__() 22 | 23 | self.classifier = nn.Sequential( 24 | nn.Linear(in_channels * roi_size * roi_size, 4096), 25 | nn.ReLU(), 26 | nn.Linear(4096, 4096), 27 | nn.ReLU() 28 | ) 29 | self.cls_loc = nn.Linear(4096, n_class * 4) 30 | self.score = nn.Linear(4096, n_class) 31 | 32 | self.n_class = n_class 33 | self.roi_size = roi_size 34 | self.spatial_scale = spatial_scale 35 | self.roi = ROIAlign((self.roi_size, self.roi_size),self.spatial_scale,sampling_ratio=sampling_ratio) 36 | 37 | init.gauss_(self.cls_loc.weight,0,0.001) 38 | init.constant_(self.cls_loc.bias,0) 39 | init.gauss_(self.score.weight,0,0.01) 40 | init.constant_(self.score.bias,0) 41 | 42 | 43 | def execute(self, x, rois, roi_indices): 44 | """Forward the chain. 45 | 46 | We assume that there are :math:`N` batches. 47 | 48 | Args: 49 | x (Variable): 4D image variable. 50 | rois (Tensor): A bounding box array containing coordinates of 51 | proposal boxes. This is a concatenation of bounding box 52 | arrays from multiple images in the batch. 53 | Its shape is :math:`(R', 4)`. Given :math:`R_i` proposed 54 | RoIs from the :math:`i` th image, 55 | :math:`R' = \\sum _{i=1} ^ N R_i`. 56 | roi_indices (Tensor): An array containing indices of images to 57 | which bounding boxes correspond to. Its shape is :math:`(R',)`. 58 | 59 | """ 60 | indices_and_rois = jt.contrib.concat([roi_indices.unsqueeze(1), rois], dim=1) 61 | pool = self.roi(x, indices_and_rois) 62 | pool = pool.view(pool.shape[0], np.prod(pool.shape[1:]).item()) 63 | fc7 = self.classifier(pool) 64 | roi_cls_locs = self.cls_loc(fc7) 65 | roi_scores = self.score(fc7) 66 | return roi_cls_locs, roi_scores -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import jittor as jt 3 | import numpy as np 4 | import random 5 | import jittor.transform as T 6 | from PIL import Image 7 | 8 | class Compose(object): 9 | def __init__(self, transforms): 10 | self.transforms = transforms 11 | 12 | def __call__(self, image, target=None): 13 | for t in self.transforms: 14 | image,target = t(image, target) 15 | return image, target 16 | 17 | class Resize(object): 18 | def __init__(self, min_size, max_size): 19 | if not isinstance(min_size, (list, tuple)): 20 | min_size = (min_size,) 21 | self.min_size = min_size 22 | self.max_size = max_size 23 | 24 | def get_size(self, image_size): 25 | w, h = image_size 26 | size = random.choice(self.min_size) 27 | max_size = self.max_size 28 | if max_size is not None: 29 | min_original_size = float(min((w, h))) 30 | max_original_size = float(max((w, h))) 31 | if max_original_size / min_original_size * size > max_size: 32 | size = int(round(max_size * min_original_size / max_original_size)) 33 | 34 | if (w <= h and w == size) or (h <= w and h == size): 35 | return (h, w) 36 | 37 | if w < h: 38 | ow = size 39 | oh = int(size * h / w) 40 | else: 41 | oh = size 42 | ow = int(size * w / h) 43 | 44 | return (ow,oh) 45 | 46 | def __call__(self, image, target=None): 47 | size = self.get_size(image.size) 48 | image = 
image.resize(size,Image.BILINEAR) 49 | if target is not None: 50 | target.resize(image.size) 51 | return image, target 52 | 53 | 54 | class RandomHorizontalFlip(object): 55 | def __init__(self, prob=0.5): 56 | self.prob = prob 57 | 58 | def __call__(self, image, target): 59 | if random.random() < self.prob: 60 | image = image.transpose(Image.FLIP_LEFT_RIGHT) 61 | if target is not None: 62 | target.hflip() 63 | return image, target 64 | 65 | class ToTensor(object): 66 | def __call__(self, image, target): 67 | if isinstance(image, Image.Image): 68 | image = np.array(image).transpose((2,0,1))/255.0 69 | return image, target 70 | 71 | class Normalize(object): 72 | def __init__(self, mean, std, to_bgr255=True): 73 | self.mean = np.array(mean).reshape(3,1,1) 74 | self.std = np.array(std).reshape(3,1,1) 75 | self.to_bgr255 = to_bgr255 76 | 77 | def __call__(self, image, target=None): 78 | if self.to_bgr255: 79 | image = image[[2, 1, 0]] * 255 80 | image = (image-self.mean)/self.std 81 | return image, target 82 | 83 | 84 | def build_transforms(min_size=2048, 85 | max_size=2048, 86 | flip_horizontal_prob=0.5, 87 | mean=[102.9801, 115.9465, 122.7717], 88 | std = [1.,1.,1.], 89 | to_bgr255=True): 90 | 91 | 92 | transform = Compose([ 93 | Resize(min_size, max_size), 94 | RandomHorizontalFlip(flip_horizontal_prob), 95 | ToTensor(), 96 | Normalize(mean=mean, std=std, to_bgr255=to_bgr255), 97 | ]) 98 | return transform 99 | 100 | def train_transforms(): 101 | return build_transforms() 102 | 103 | def val_transforms(): 104 | return build_transforms(flip_horizontal_prob=0.0) -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import os 5 | 6 | from utils.box_ops import calculate_ious 7 | 8 | def draw_box(img,box,text,color): 9 | box = [int(x) for x in box] 10 | img = cv2.rectangle(img=img, pt1=tuple(box[0:2]), pt2=tuple(box[2:]), color=color, thickness=1) 11 | img = cv2.putText(img=img, text=text, org=(box[0],box[1]-5), fontFace=0, fontScale=0.5, color=color, thickness=1) 12 | return img 13 | 14 | 15 | def draw_boxes(img,boxes,labels,classnames,scores=None, color=(0,0,0)): 16 | if scores is None: 17 | scores = ['']*len(labels) 18 | for box,score,label in zip(boxes,scores,labels): 19 | box = [int(i) for i in box] 20 | text = classnames[label-1]+(f': {score:.2f}' if not isinstance(score,str) else score) 21 | img = draw_box(img,box,text,color) 22 | return img 23 | 24 | def visualize_result(img_file, 25 | pred_boxes, 26 | pred_scores, 27 | pred_labels, 28 | gt_boxes, 29 | gt_labels, 30 | classnames, 31 | iou_thresh=0.5, 32 | miss_color=(255,0,0), 33 | wrong_color=(0,255,0), 34 | surplus_color=(0,0,255), 35 | right_color=(0,255,255)): 36 | 37 | img = cv2.imread(img_file) 38 | 39 | detected = [False for _ in range(len(gt_boxes))] 40 | miss_boxes = [] 41 | wrong_boxes = [] 42 | surplus_boxes = [] 43 | right_boxes = [] 44 | 45 | # sort the box by scores 46 | ind = np.argsort(-pred_scores) 47 | pred_boxes = pred_boxes[ind,:] 48 | pred_scores = pred_scores[ind] 49 | pred_labels = pred_labels[ind] 50 | 51 | # add background 52 | classnames = ['background']+classnames 53 | 54 | for box,score,label in zip(pred_boxes,pred_scores,pred_labels): 55 | ioumax = 0. 
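# Match this prediction to its best-overlapping ground-truth box: if the IoU
# exceeds iou_thresh and that ground truth has not been matched yet, the
# detection is counted as right (same label) or wrong (different label);
# every other case (low IoU or a duplicate match) is treated as surplus.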
56 | if len(gt_boxes)>0: 57 | ioumax,jmax = calculate_ious(gt_boxes,box) 58 | if ioumax>iou_thresh: 59 | if not detected[jmax]: 60 | detected[jmax]=True 61 | if label == gt_labels[jmax]: 62 | right_boxes.append((box,f'{classnames[label]}:{int(score*100)}%')) 63 | else: 64 | wrong_boxes.append((box,f'{classnames[label]}->{classnames[gt_labels[jmax]]}')) 65 | else: 66 | surplus_boxes.append((box,f'{classnames[label]}:{int(score*100)}%')) 67 | else: 68 | surplus_boxes.append((box,f'{classnames[label]}:{int(score*100)}%')) 69 | 70 | for box,label,d in zip(gt_boxes,gt_labels,detected): 71 | if not d: 72 | miss_boxes.append((box,f'{classnames[label]}')) 73 | 74 | colors = [miss_color]*len(miss_boxes) + [wrong_color]*len(wrong_boxes) + [right_color]*len(right_boxes) + [surplus_color]*len(surplus_boxes) 75 | 76 | boxes = miss_boxes + wrong_boxes + right_boxes + surplus_boxes 77 | 78 | for (box,text),color in zip(boxes,colors): 79 | img = draw_box(img,box,text,color) 80 | 81 | # draw colors 82 | colors = [right_color,wrong_color,miss_color,surplus_color] 83 | texts = ['Detect Right','Detect Wrong Class','Missed Ground Truth','Surplus Detection'] 84 | for i,(color,text) in enumerate(zip(colors,texts)): 85 | img = cv2.rectangle(img=img, pt1=(0,i*30), pt2=(60,(i+1)*30), color=color, thickness=-1) 86 | img = cv2.putText(img=img, text=text, org=(70,(i+1)*30-5), fontFace=0, fontScale=0.8, color=color, thickness=2) 87 | return img 88 | 89 | 90 | def find_dir(data_dir,img_id): 91 | t_f = f"{data_dir}/test/{img_id}.jpg" 92 | tt_f = f"{data_dir}/train/{img_id}.jpg" 93 | o_f = f"{data_dir}/other/{img_id}.jpg" 94 | if os.path.exists(tt_f): 95 | return tt_f 96 | elif os.path.exists(t_f): 97 | return t_f 98 | elif os.path.exists(o_f): 99 | return o_f 100 | assert False,f"{img_id}.jpg is not exists" 101 | 102 | def save_visualize_image(data_dir,img_id,pred_boxes,pred_scores,pred_labels,gt_boxes,gt_labels,classnames): 103 | img_file = find_dir(data_dir,img_id) 104 | 105 | img = visualize_result(img_file,pred_boxes,pred_scores,pred_labels,gt_boxes,gt_labels,classnames) 106 | 107 | os.makedirs('test_imgs',exist_ok=True) 108 | cv2.imwrite(f'test_imgs/{img_id}.jpg',img) 109 | 110 | -------------------------------------------------------------------------------- /dataset/data.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import jittor as jt 3 | from jittor import dataset 4 | import numpy as np 5 | import json 6 | import os 7 | import glob 8 | from PIL import Image 9 | from utils.box_ops import BBox 10 | from .transforms import train_transforms,val_transforms 11 | 12 | def read_annotations(filename,filter_empty=False,classnames=None): 13 | annotations = json.load(open(filename)) 14 | if classnames is None: 15 | classnames = annotations['types'] 16 | imgs = annotations['imgs'] 17 | test_imgs = [] 18 | train_imgs = [] 19 | other_imgs = [] 20 | all_imgs = [] 21 | for img in imgs.values(): 22 | if filter_empty and len([o for o in img['objects'] if o['category'] in classnames])==0: 23 | continue 24 | path = img['path'] 25 | if 'train' in path: 26 | train_imgs.append(img) 27 | elif 'test' in path: 28 | test_imgs.append(img) 29 | else: 30 | other_imgs.append(img) 31 | all_imgs.append(img) 32 | return train_imgs,test_imgs,all_imgs,classnames 33 | 34 | 35 | class TrainDataset(dataset.Dataset): 36 | def __init__(self,data_dir,annos,classnames,transforms=None,batch_size=1,shuffle=False,num_workers=0): 37 | 
super(TrainDataset,self).__init__(batch_size=batch_size,shuffle=shuffle,num_workers=num_workers) 38 | self.total_len = len(annos) 39 | self.annos = annos 40 | self.data_dir = data_dir 41 | self.classnames = classnames 42 | self.transforms = transforms 43 | 44 | def __getitem__(self,index): 45 | anno = self.annos[index] 46 | img_path = os.path.join(self.data_dir,anno['path']) 47 | objects = [o for o in anno['objects'] if o['category'] in self.classnames] 48 | xyxy = [[o['bbox']['xmin'],o['bbox']['ymin'],o['bbox']['xmax'],o['bbox']['ymax']] for o in objects] 49 | labels = [self.classnames.index(o['category'])+1 for o in objects] 50 | img = Image.open(img_path) 51 | ori_img_size = img.size 52 | boxes = BBox(xyxy,img_size=img.size) 53 | labels = np.array(labels,dtype=np.int32) 54 | if self.transforms is not None: 55 | img,boxes = self.transforms(img,boxes) 56 | return img.astype(np.float32),boxes.bbox,labels,ori_img_size,anno['id'] 57 | 58 | def collate_batch(self,batch): 59 | imgs = [] 60 | boxes = [] 61 | labels = [] 62 | img_sizes = [] 63 | img_ids = [] 64 | for img,box,label,img_size,ID in batch: 65 | imgs.append(img) 66 | boxes.append(box.astype(np.float32)) 67 | labels.append(label) 68 | img_sizes.append(img_size) 69 | img_ids.append(ID) 70 | imgs = np.stack(imgs,axis=0) 71 | return imgs,boxes,labels,img_sizes,img_ids 72 | 73 | class TestDataset(dataset.Dataset): 74 | def __init__(self,img_dir,transforms=None,batch_size=1,shuffle=False,num_workers=0): 75 | super(TestDataset,self).__init__(batch_size=batch_size,shuffle=shuffle,num_workers=num_workers) 76 | img_list = list(glob.glob(os.path.join(img_dir,"*.jpg"))) 77 | self.total_len = len(img_list) 78 | self.img_list = img_list 79 | self.transforms = transforms 80 | 81 | def __getitem__(self,index): 82 | img_path = self.img_list[index] 83 | img = Image.open(img_path) 84 | ori_img_size = img.size 85 | if self.transforms is not None: 86 | img,_ = self.transforms(img,None) 87 | return img.astype(np.float32),ori_img_size,img_path.split("/")[-1].split(".jpg")[0] 88 | 89 | def collate_batch(self,batch): 90 | imgs = [] 91 | img_sizes = [] 92 | img_ids = [] 93 | for img,img_size,ID in batch: 94 | imgs.append(img) 95 | img_sizes.append(img_size) 96 | img_ids.append(ID) 97 | imgs = np.stack(imgs,axis=0) 98 | return imgs,img_sizes,img_ids 99 | 100 | 101 | def build_dataset(data_dir,anno_file,classnames,filter_empty=True,batch_size=1,shuffle=False,num_workers=0,is_train=False,use_all=False): 102 | train_imgs,test_imgs,all_imgs,classes = read_annotations(anno_file,filter_empty=filter_empty,classnames=classnames) 103 | if classnames is None: 104 | classnames = classes 105 | 106 | if is_train: 107 | annos = train_imgs 108 | transforms = train_transforms() 109 | else: 110 | annos = test_imgs 111 | transforms = val_transforms() 112 | if use_all: 113 | annos = all_imgs 114 | 115 | dataset = TrainDataset(data_dir,annos,classnames,transforms,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers) 116 | return dataset 117 | 118 | def build_testdataset(img_dir,batch_size=1,shuffle=False,num_workers=1): 119 | transforms = val_transforms() 120 | return TestDataset(img_dir,transforms=transforms,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers) 121 | -------------------------------------------------------------------------------- /utils/ap_eval.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import numpy as np 3 | 4 | def calculate_ious(gt_boxes,box): 5 | 6 | in_w = 
np.minimum(gt_boxes[:,2],box[2]) - np.maximum(gt_boxes[:,0],box[0]) 7 | in_h = np.minimum(gt_boxes[:,3],box[3]) - np.maximum(gt_boxes[:,1],box[1]) 8 | 9 | in_w = np.maximum(in_w,0) 10 | in_h = np.maximum(in_h,0) 11 | 12 | inter = in_w*in_h 13 | 14 | area1 = (gt_boxes[:,2]-gt_boxes[:,0])*(gt_boxes[:,3]-gt_boxes[:,1]) 15 | area2 = (box[2]-box[0])*(box[3]-box[1]) 16 | union = area1+area2-inter 17 | ious = inter / union 18 | jmax = np.argmax(ious) 19 | maxiou = ious[jmax] 20 | return maxiou,jmax 21 | 22 | def calculate_voc_ap(prec,rec,use_07_metric): 23 | if use_07_metric: 24 | # 11 point metric 25 | # http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.pdf (page 313) 26 | 27 | ap = 0. 28 | for t in np.arange(0., 1.1, 0.1): 29 | if np.sum(rec >= t) == 0: 30 | p = 0 31 | else: 32 | p = np.max(prec[rec >= t]) 33 | ap = ap + p / 11. 34 | else: 35 | # correct AP calculation (from VOC 2010 challenge) 36 | # http://host.robots.ox.ac.uk/pascal/VOC/voc2012/devkit_doc.pdf (page 12) 37 | 38 | # first append sentinel values at the end 39 | mrec = np.concatenate(([0.], rec, [1.])) 40 | mpre = np.concatenate(([0.], prec, [0.])) 41 | 42 | # compute the precision envelope 43 | for i in range(mpre.size - 1, 0, -1): 44 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 45 | 46 | # to calculate area under PR curve, look for points 47 | # where X axis (recall) changes value 48 | i = np.where(mrec[1:] != mrec[:-1])[0] 49 | 50 | # and sum (\Delta recall) * prec 51 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 52 | 53 | return ap 54 | 55 | def calculate_prec_recall(result,label,iou_thresh): 56 | GTBB = {} 57 | BB = [] 58 | BB_ids = [] 59 | BB_scores = [] 60 | n_gt_bboxes = 0 61 | for img_id,pred_boxes,pred_labels,pred_scores,gt_boxes,gt_labels in result: 62 | gts = gt_boxes[gt_labels==label,:] 63 | n_gt_bboxes+=gts.shape[0] 64 | GTBB[img_id]={ 65 | 'bboxes':gts, 66 | 'detected':[False for i in range(gts.shape[0])] 67 | } 68 | pred = pred_boxes[pred_labels==label,:] 69 | scores = pred_scores[pred_labels==label] 70 | ids = [img_id for i in range(scores.shape[0])] 71 | BB.append(pred) 72 | BB_ids.extend(ids) 73 | BB_scores.append(scores) 74 | 75 | if n_gt_bboxes==0: 76 | return None,None 77 | 78 | if len(BB_ids)>0: 79 | BB = np.concatenate(BB,axis=0) 80 | BB_scores = np.concatenate(BB_scores,axis=0) 81 | indexes = np.argsort(-BB_scores) 82 | BB = BB[indexes,:] 83 | BB_ids = np.array(BB_ids)[indexes] 84 | else: 85 | return np.array([0.]),np.array([0.]) 86 | 87 | n_pred = len(BB_ids) 88 | tp = np.zeros(n_pred) 89 | fp = np.zeros(n_pred) 90 | 91 | for d in range(n_pred): 92 | bb = BB[d,:] 93 | gt_boxes = GTBB[BB_ids[d]]['bboxes'] 94 | 95 | ioumax = 0.0 96 | if gt_boxes.shape[0]>0: 97 | ioumax,jmax = calculate_ious(gt_boxes,bb) 98 | if ioumax>iou_thresh: 99 | if not GTBB[BB_ids[d]]['detected'][jmax]: 100 | tp[d]=1. 101 | GTBB[BB_ids[d]]['detected'][jmax] = True 102 | else: 103 | fp[d]=1. 104 | else: 105 | fp[d] = 1. 106 | 107 | fp = np.cumsum(fp) 108 | tp = np.cumsum(tp) 109 | 110 | rec = tp / float(n_gt_bboxes) 111 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 112 | 113 | return prec,rec 114 | 115 | 116 | def calculate_VOC_mAP(result,classnames,iou_thresh=0.5,use_07_metric=True): 117 | ''' 118 | INPUTS: 119 | result: the result list of detections, 120 | it's like [(img_id,pred_boxes,pred_labels,pred_scores,gt_boxes,gt_labels),...,(...)] 121 | In Labels, 0 is background 122 | 123 | classnames: the class we used to calculate mAP, and "background" is not in it. 
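The label used for classnames[i] in the result arrays is i+1, since label 0 is reserved for background.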
124 | 125 | iou_thresh: A bounding box reported by an algorithm is considered 126 | correct if its area intersection over union with a ground 127 | truth bounding box is beyond 50%. 128 | 129 | use_07_metric: True means we use voc 07 challenge metric, False means use 10 metric 130 | 131 | OUTPUT: 132 | mAP: mean Average Precision 133 | ''' 134 | 135 | aps = [] 136 | all_aps = [] 137 | for i,classname in enumerate(classnames): 138 | # background is not in classnames and it's 0 139 | label = i+1 140 | prec,rec = calculate_prec_recall(result,label,iou_thresh) 141 | if prec is None: 142 | all_aps.append(None) 143 | continue 144 | ap = calculate_voc_ap(prec,rec,use_07_metric) 145 | all_aps.append(ap) 146 | aps.append(ap) 147 | 148 | if len(aps)>0: 149 | mAP = np.mean(aps) 150 | else: 151 | mAP = 0. 152 | return mAP,all_aps 153 | 154 | 155 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | 4 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 5 | conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) 6 | jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out") 7 | return conv 8 | 9 | def conv1x1(in_planes, out_planes, stride=1): 10 | conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 11 | jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out") 12 | return conv 13 | 14 | class Bottleneck(nn.Module): 15 | expansion = 4 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): 18 | super(Bottleneck, self).__init__() 19 | if (norm_layer is None): 20 | norm_layer = nn.BatchNorm 21 | width = (int((planes * (base_width / 64.0))) * groups) 22 | self.conv1 = conv1x1(inplanes, width) 23 | self.bn1 = norm_layer(width) 24 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 25 | self.bn2 = norm_layer(width) 26 | self.conv3 = conv1x1(width, (planes * self.expansion)) 27 | self.bn3 = norm_layer((planes * self.expansion)) 28 | self.relu = nn.Relu() 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def execute(self, x): 33 | identity = x 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | out = self.conv3(out) 41 | out = self.bn3(out) 42 | if (self.downsample is not None): 43 | identity = self.downsample(x) 44 | out += identity 45 | out = self.relu(out) 46 | return out 47 | 48 | class ResNet(nn.Module): 49 | 50 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): 51 | super(ResNet, self).__init__() 52 | if (norm_layer is None): 53 | norm_layer = nn.BatchNorm 54 | self._norm_layer = norm_layer 55 | self.inplanes = 64 56 | self.dilation = 1 57 | self.feat_stride = 16 58 | self.out_channels = 1024 59 | 60 | if (replace_stride_with_dilation is None): 61 | replace_stride_with_dilation = [False, False, False] 62 | if (len(replace_stride_with_dilation) != 3): 63 | raise ValueError('replace_stride_with_dilation should be None or a 3-element tuple, got {}'.format(replace_stride_with_dilation)) 64 | self.groups = groups 65 | self.base_width = width_per_group 66 | stride = 2 67 | self.conv1 = nn.Conv(3, 
self.inplanes, kernel_size=7, stride=stride, padding=3, bias=False) 68 | jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out") 69 | self.bn1 = norm_layer(self.inplanes) 70 | self.relu = nn.Relu() 71 | 72 | self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum') 73 | self.layer1 = self._make_layer(block, 64, layers[0]) 74 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride, dilate=replace_stride_with_dilation[0]) 75 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride, dilate=replace_stride_with_dilation[1]) 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 78 | norm_layer = self._norm_layer 79 | downsample = None 80 | previous_dilation = self.dilation 81 | if dilate: 82 | self.dilation *= stride 83 | stride = 1 84 | if ((stride != 1) or (self.inplanes != (planes * block.expansion))): 85 | downsample = nn.Sequential(conv1x1(self.inplanes, (planes * block.expansion), stride), norm_layer((planes * block.expansion))) 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) 88 | self.inplanes = (planes * block.expansion) 89 | for _ in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) 91 | return nn.Sequential(*layers) 92 | 93 | def _forward_impl(self, x): 94 | x = self.conv1(x) 95 | x = self.bn1(x) 96 | x = self.relu(x) 97 | x = self.maxpool(x) 98 | x = self.layer1(x) 99 | x = self.layer2(x) 100 | x = self.layer3(x) 101 | return x 102 | 103 | def execute(self, x): 104 | return self._forward_impl(x) 105 | 106 | def _resnet(block, layers, **kwargs): 107 | model = ResNet(block, layers, **kwargs) 108 | return model 109 | 110 | def Resnet50(pretrained=False, **kwargs): 111 | model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) 112 | if pretrained: model.load("jittorhub://resnet50.pkl") 113 | return model 114 | 115 | def Resnet101(pretrained=False, **kwargs): 116 | model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) 117 | if pretrained: model.load("jittorhub://resnet101.pkl") 118 | return model -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Traffic Sign Detection 2 | 本项目为2021年第一届“计图”人工智能算法挑战赛-交通标志检测的Baseline。深度学习框架为[Jittor](https://cg.cs.tsinghua.edu.cn/jittor/)。 3 | 4 | 比赛链接如下: 5 | 1. [2021年第一届“计图”人工智能算法挑战赛-交通标志检测](https://www.educoder.net/competitions/index/Jittor-1) 6 | 2. 
[2021年第一届“计图”人工智能算法挑战赛-狗细分类](https://www.educoder.net/competitions/index/Jittor-2) 7 | 8 | ### 项目文件 9 | ```shell 10 | TrafficSignDetection 11 | ├── dataset 12 | │ ├── augmentation.py # 数据增强方法 13 | │ ├── data.py # Dataset 14 | │ └── transforms.py # 图片数据处理 15 | ├── model 16 | │ ├── faster_rcnn.py 17 | │ ├── resnet.py 18 | │ ├── roi_head.py 19 | │ └── rpn.py 20 | ├── utils 21 | │ ├── ap_eval.py 22 | │ ├── box_ops.py 23 | │ ├── roi_align.py 24 | │ └── visualize.py 25 | ├── evaluate.py 26 | ├── train.py 27 | ├── README.md 28 | └── requirements.txt 29 | ``` 30 | 31 | ### 数据集 32 | 33 | ```shell 34 | # 下载训练数据集 35 | wget https://cg.cs.tsinghua.edu.cn/traffic-sign/tt100k_2021.zip 36 | unzip tt100k_2021.zip 37 | 38 | # 下载A榜数据集 39 | wget -O TT_TEST_A.zip https://cloud.tsinghua.edu.cn/f/14252d9ad07b4d7b86d4/?dl=1 40 | unzip TT_TEST_A.zip 41 | 42 | # 下载Baseline数据增强使用的部分marks 43 | cd tt100k_2021 44 | wget https://cg.cs.tsinghua.edu.cn/traffic-sign/re_marks.zip 45 | unzip re_marks.zip 46 | ``` 47 | ```shell 48 | tt100k_2021 49 | ├── annotations_all.json 50 | ├── marks.jpg 51 | ├── report.pdf 52 | ├── test_result.pkl 53 | ├── marks 54 | │ └── .png 55 | ├── re_marks 56 | │ └── .jpg 57 | ├── train 58 | │ └── .jpg 59 | ├── test 60 | │ └── .jpg 61 | └── other 62 | └── .jpg 63 | ``` 64 | ### 模型 65 | 66 | 检测模型为Faster RCNN, Backbone为ResNet。 67 | 68 | 目前的训练基于ResNet50。 69 | 70 | 训练参数如下: 71 | 1. 图片尺寸:2048\* 2048 72 | 2. Batch Size: 1 73 | 3. Epoch: 20 74 | 4. LR: 0.001 75 | 76 | ### 模型训练和测试 77 | #### 安装依赖库: 78 | ```shell 79 | python3 -m pip install -r requirements.txt 80 | ``` 81 | #### 模型训练: 82 | 训练之前先进行数据增强: 83 | ```shell 84 | python3 dataset/augmentation.py 85 | ``` 86 | 87 | 训练前需要设置数据集的位置,具体设置方式为**train.py**中Line18-22。 88 | Batch Size,lr,num_workers等等请参考**train.py**中的train函数。 89 | 训练过程的显存消耗如下: 90 | ```shell 91 | | GeForce RTX 3090 | 23629MiB / 24268MiB | 79% | 92 | ``` 93 | 因为图片尺寸较大,当显存较小时请缩小图片尺寸或者对图片进行裁剪,缩小图片尺寸代码在**dataset/transforms.py**里: 94 | 更改min_size和max_size即可 95 | ```python 96 | def build_transforms(min_size=2048, 97 | max_size=2048, 98 | flip_horizontal_prob=0.5, 99 | mean=[102.9801, 115.9465, 122.7717], 100 | std = [1.,1.,1.], 101 | to_bgr255=True): 102 | 103 | 104 | transform = Compose([ 105 | Resize(min_size, max_size), 106 | RandomHorizontalFlip(flip_horizontal_prob), 107 | ToTensor(), 108 | Normalize(mean=mean, std=std, to_bgr255=to_bgr255), 109 | ]) 110 | return transform 111 | ``` 112 | 设置好后训练脚本如下: 113 | ```shell 114 | python3 train.py --task=train 115 | ``` 116 | 本模型用3090来训练的,默认为单卡,如果使用多卡,请调整batch_size和lr之后使用,多卡训练如下所示: 117 | ```shell 118 | # 8卡训练 119 | mpirun -np 8 python3 train.py --task=train 120 | ``` 121 | Jittor多卡训练细节请参考:[https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-2-16-44-distributed/](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-2-16-44-distributed/) 122 | 123 | #### 模型验证: 124 | 由于训练过程使用全部数据,因此验证过程同样使用训练数据集。 125 | ```shell 126 | python3 train.py --task=test 127 | ``` 128 | 129 | #### 模型测试: 130 | 测试集没有Ground Truth。如测试A榜数据,则设置A榜图片的文件路径为数据路径。 131 | checkpoints链接:[https://cloud.tsinghua.edu.cn/d/0b03f9dedd674101bc94/](https://cloud.tsinghua.edu.cn/d/0b03f9dedd674101bc94/) 132 | ```shell 133 | python3 evaluate.py 134 | ``` 135 | 保存结果的格式为: 136 | ```json 137 | { 138 | "0.jpg": [ 139 | { 140 | "bbox": { 141 | "xmin": 1181.00341796875, 142 | "ymin": 935.8701171875, 143 | "xmax": 1200.79736328125, 144 | "ymax": 954.7010498046875 145 | }, 146 | "category": "w21", 147 | "score": 0.022070620208978653 148 | }, 149 | { 150 | "bbox": { 151 | "xmin": 1182.023193359375, 152 | "ymin": 
936.9432983398438, 153 | "xmax": 1203.480712890625, 154 | "ymax": 957.1759643554688 155 | }, 156 | "category": "w57", 157 | "score": 0.28726232051849365 158 | } 159 | ], 160 | "1.jpg": [ 161 | { 162 | "bbox": { 163 | "xmin": 164.0937042236328, 164 | "ymin": 699.9683837890625, 165 | "xmax": 200.88169860839844, 166 | "ymax": 759.8314208984375 167 | }, 168 | "category": "pne", 169 | "score": 0.4119001626968384 170 | } 171 | ] 172 | } 173 | ``` 174 | 175 | 176 | 注意细节: 177 | 1. 如果显存较小,存在显存不够的问题,可以考虑把2048\*2048的图片给切割成512\*512的小图片来训练和测试。 178 | 切成小图要考虑box是否被切开的情况,如何处理等等。 179 | 2. 如果显存爆了,模型推理会非常慢,因为此时大量数据存储在内存,而非显存。 180 | 3. 如果出现bug或者有任何困惑,请加入jittor群(QQ:761222083)一起交流 181 | 182 | 训练过程可视化: 183 | 184 | ```shell 185 | tensorboard --logdir=runs --bind_all 186 | ``` 187 | 可视化结果如下所示: 188 | ![](pics/map.png) 189 | ![](pics/roi_cls_loss.png) 190 | ![](pics/roi_loc_loss.png) 191 | ![](pics/rpn_cls_loss.png) 192 | ![](pics/rpn_loc_loss.png) 193 | ![](pics/total_loss.png) 194 | ### 可能改进的地方 195 | 1. 在ResNet后面增加FPN 196 | 2. 把ResNet换成ResNeXt 197 | 3. 根据交通标志的大小重新设置RPN的anchors 198 | 4. 使用其他更加有效的检测模型 199 | 5. .... 200 | 201 | ### 参考 202 | [1] https://github.com/endernewton/tf-faster-rcnn 203 | 204 | [2] https://github.com/chenyuntc/simple-faster-rcnn-pytorch 205 | 206 | [3] https://github.com/aarcosg/traffic-sign-detection 207 | 208 | [4] https://github.com/Cartucho/mAP 209 | 210 | [5] https://cg.cs.tsinghua.edu.cn/traffic-sign/tutorial.html 211 | 212 | [6] https://cg.cs.tsinghua.edu.cn/traffic-sign/ -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import jittor as jt 3 | import numpy as np 4 | from tqdm import tqdm 5 | from jittor import optim 6 | import argparse 7 | import sys 8 | import glob 9 | import pickle 10 | import os 11 | from tensorboardX import SummaryWriter 12 | 13 | from utils.ap_eval import calculate_VOC_mAP 14 | from utils.visualize import save_visualize_image 15 | from dataset.data import build_dataset 16 | from model.faster_rcnn import FasterRCNN 17 | 18 | DATA_DIR = '/data/lxl/dataset/tt100k/tt100k_2021' 19 | CLASSNAMES = ['i1', 'i10', 'i11', 'i12', 'i13', 'i14', 'i15', 'i2', 'i3', 'i4', 'i5', 'il100', 'il110', 'il50', 'il60', 'il70', 'il80', 'il90', 'ip', 'p1', 'p10', 'p11', 'p12', 'p13', 'p14', 'p15', 'p16', 'p17', 'p18', 'p19', 'p2', 'p20', 'p21', 'p23', 'p24', 'p25', 'p26', 'p27', 'p28', 'p3', 'p4', 'p5', 'p6', 'p7', 'p8', 'p9', 'pa10', 'pa12', 'pa13', 'pa14', 'pa8', 'pb', 'pc', 'pg', 'ph2', 'ph2.1', 'ph2.2', 'ph2.4', 'ph2.5', 'ph2.8', 'ph2.9', 'ph3', 'ph3.2', 'ph3.5', 'ph3.8', 'ph4', 'ph4.2', 'ph4.3', 'ph4.5', 'ph4.8', 'ph5', 'ph5.3', 'ph5.5', 'pl10', 'pl100', 'pl110', 'pl120', 'pl15', 'pl20', 'pl25', 'pl30', 'pl35', 'pl40', 'pl5', 'pl50', 'pl60', 'pl65', 'pl70', 'pl80', 'pl90', 'pm10', 'pm13', 'pm15', 'pm1.5', 'pm2', 'pm20', 'pm25', 'pm30', 'pm35', 'pm40', 'pm46', 'pm5', 'pm50', 'pm55', 'pm8', 'pn', 'pne', 'pr10', 'pr100', 'pr20', 'pr30', 'pr40', 'pr45', 'pr50', 'pr60', 'pr70', 'pr80', 'ps', 'pw2.5', 'pw3', 'pw3.2', 'pw3.5', 'pw4', 'pw4.2', 'pw4.5', 'w1', 'w10', 'w12', 'w13', 'w16', 'w18', 'w20', 'w21', 'w22', 'w24', 'w28', 'w3', 'w30', 'w31', 'w32', 'w34', 'w35', 'w37', 'w38', 'w41', 'w42', 'w43', 'w44', 'w45', 'w46', 'w47', 'w48', 'w49', 'w5', 'w50', 'w55', 'w56', 'w57', 'w58', 'w59', 'w60', 'w62', 'w63', 'w66', 'w8', 'i6', 'i7', 'i8', 'i9', 'p29', 'w29', 'w33', 'w36', 'w39', 'w4', 'w40', 'w51', 'w52', 'w53', 'w54', 'w6', 'w61', 'w64', 'w65', 
'w67', 'w7', 'w9', 'pd', 'pe', 'pnl', 'w11', 'w14', 'w15', 'w17', 'w19', 'w2', 'w23', 'w25', 'w26', 'w27', 'pm2.5', 'ph4.4', 'ph3.3', 'ph2.6', 'i4l', 'i2r', 'im', 'wc', 'pcr', 'pcl', 'pss', 'pbp', 'p1n', 'pbm', 'pt', 'pn-2', 'pclr', 'pcs', 'pcd', 'iz', 'pmb', 'pdd', 'pctl', 'ph1.8', 'pnlc', 'pmblr', 'phclr', 'phcs', 'pmr'] 20 | EPOCHS = 10 21 | save_checkpoint_path = f"{DATA_DIR}/checkpoints" 22 | 23 | def eval(val_dataset,faster_rcnn,epoch,is_display=False,is_save_result=True,score_thresh=0.01): 24 | faster_rcnn.eval() 25 | results = [] 26 | for batch_idx,(images,boxes,labels,image_sizes,img_ids) in tqdm(enumerate(val_dataset)): 27 | result = faster_rcnn.predict(images,score_thresh=score_thresh) 28 | for i in range(len(img_ids)): 29 | pred_boxes,pred_scores,pred_labels = result[i] 30 | gt_boxes = boxes[i] 31 | gt_labels = labels[i] 32 | img_id = img_ids[i] 33 | results.append((img_id.item(),pred_boxes.numpy(),pred_labels.numpy(),pred_scores.numpy(),gt_boxes.numpy(),gt_labels.numpy())) 34 | if is_display: 35 | save_visualize_image(DATA_DIR,img_id,pred_boxes.numpy(),pred_scores.numpy(),pred_labels.numpy(),gt_boxes.numpy(),gt_labels.numpy(),CLASSNAMES) 36 | if is_save_result: 37 | os.makedirs(save_checkpoint_path,exist_ok=True) 38 | pickle.dump(results,open(f"{save_checkpoint_path}/result_{epoch}.pkl","wb")) 39 | mAP,_ = calculate_VOC_mAP(results,CLASSNAMES,use_07_metric=False) 40 | return mAP 41 | 42 | 43 | def test(): 44 | val_dataset = build_dataset(data_dir=DATA_DIR, 45 | anno_file=f'{DATA_DIR}/annotations_aug.json', 46 | classnames=CLASSNAMES, 47 | batch_size=1, 48 | shuffle=False, 49 | num_workers=8, 50 | is_train=False) 51 | faster_rcnn = FasterRCNN(n_class=len(CLASSNAMES)+1) 52 | files = sorted(list(glob.glob(f'{save_checkpoint_path}/checkpoint*.pkl'))) 53 | f = files[-1] 54 | faster_rcnn.load(f) 55 | mAP = eval(val_dataset,faster_rcnn,0,is_display=True,is_save_result=False,score_thresh=0.5) 56 | print(mAP) 57 | 58 | def train(): 59 | train_dataset = build_dataset(data_dir=DATA_DIR, 60 | anno_file=f'{DATA_DIR}/annotations_aug.json', 61 | classnames=CLASSNAMES, 62 | batch_size=1, 63 | shuffle=True, 64 | num_workers=4, 65 | is_train=True, 66 | use_all=True) 67 | 68 | val_dataset = build_dataset(data_dir=DATA_DIR, 69 | anno_file=f'{DATA_DIR}/annotations_aug.json', 70 | classnames=CLASSNAMES, 71 | batch_size=1, 72 | shuffle=False, 73 | num_workers=4, 74 | is_train=False, 75 | use_all=True) 76 | 77 | faster_rcnn = FasterRCNN(n_class = len(CLASSNAMES)+1) 78 | 79 | optimizer = optim.SGD(faster_rcnn.parameters(),momentum=0.9,lr=0.001) 80 | 81 | writer = SummaryWriter() 82 | 83 | for epoch in range(EPOCHS): 84 | faster_rcnn.train() 85 | dataset_len = len(train_dataset) 86 | for batch_idx,(images,boxes,labels,image_sizes,img_ids) in tqdm(enumerate(train_dataset)): 87 | rpn_loc_loss,rpn_cls_loss,roi_loc_loss,roi_cls_loss,total_loss = faster_rcnn(images,boxes,labels) 88 | 89 | optimizer.step(total_loss) 90 | 91 | writer.add_scalar('rpn_cls_loss', rpn_cls_loss.item(), global_step=dataset_len*epoch+batch_idx) 92 | writer.add_scalar('rpn_loc_loss', rpn_loc_loss.item(), global_step=dataset_len*epoch+batch_idx) 93 | writer.add_scalar('roi_loc_loss', roi_loc_loss.item(), global_step=dataset_len*epoch+batch_idx) 94 | writer.add_scalar('roi_cls_loss', roi_cls_loss.item(), global_step=dataset_len*epoch+batch_idx) 95 | writer.add_scalar('total_loss', total_loss.item(), global_step=dataset_len*epoch+batch_idx) 96 | 97 | if batch_idx % 10 == 0: 98 | loss_str = '\nrpn_loc_loss: %s \nrpn_cls_loss: %s 
\nroi_loc_loss: %s \nroi_cls_loss: %s \ntotoal_loss: %s \n' 99 | print(loss_str % (rpn_loc_loss.item(),rpn_cls_loss.item(),roi_loc_loss.item(),roi_cls_loss.item(),total_loss.item())) 100 | 101 | mAP = eval(val_dataset,faster_rcnn,epoch) 102 | writer.add_scalar('map', mAP, global_step=epoch) 103 | os.makedirs(save_checkpoint_path,exist_ok=True) 104 | faster_rcnn.save(f"{save_checkpoint_path}/checkpoint_{epoch}.pkl") 105 | 106 | def main(): 107 | parser = argparse.ArgumentParser(description='Test a Faster R-CNN network') 108 | parser.add_argument('--task',help='Task(train,test)',default='test',type=str) 109 | parser.add_argument('--no_cuda', help='not use cuda', action='store_true') 110 | args = parser.parse_args() 111 | 112 | if not args.no_cuda: 113 | jt.flags.use_cuda=1 114 | 115 | if args.task == 'test': 116 | test() 117 | elif args.task == 'train': 118 | train() 119 | else: 120 | print(f"No this task: {args.task}") 121 | 122 | if __name__ == '__main__': 123 | main() 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import numpy as np 3 | 4 | class BBox(object): 5 | def __init__(self,bbox,img_size): 6 | # img_size: w,h box[N 4] x1,y1 x2 y2 7 | if not isinstance(bbox,np.ndarray): 8 | bbox = np.array(bbox,dtype=np.float32) 9 | if bbox.shape[0]==0: 10 | bbox = np.zeros((0,4)) 11 | self.img_size = img_size 12 | self.bbox = bbox 13 | 14 | def resize(self,size): 15 | rw,rh = size[0]/self.img_size[0],size[1]/self.img_size[1] 16 | self.img_size = size 17 | x1,y1,x2,y2 = np.split(self.bbox,4,axis=1) 18 | x1 = x1*rw 19 | x2 = x2*rw 20 | y1 = y1*rh 21 | y2 = y2*rh 22 | self.bbox = np.concatenate([x1,y1,x2,y2],axis=1) 23 | 24 | def hflip(self): 25 | x1,y1,x2,y2 = np.split(self.bbox,4,axis=1) 26 | w,h = self.img_size 27 | TO_REMOVE=1 28 | x1,x2 = w-x2-TO_REMOVE,w-x1-TO_REMOVE 29 | self.bbox = np.concatenate([x1,y1,x2,y2],axis=1) 30 | 31 | 32 | def generate_anchor_base(base_size=16, ratios=[0.5, 1, 2], 33 | anchor_scales=[8, 16, 32]): 34 | """Generate anchor base windows by enumerating aspect ratio and scales. 35 | 36 | Generate anchors that are scaled and modified to the given aspect ratios. 37 | Area of a scaled anchor is preserved when modifying to the given aspect 38 | ratio. 39 | 40 | :obj:`R = len(ratios) * len(anchor_scales)` anchors are generated by this 41 | function. 42 | The :obj:`i * len(anchor_scales) + j` th anchor corresponds to an anchor 43 | generated by :obj:`ratios[i]` and :obj:`anchor_scales[j]`. 44 | 45 | For example, if the scale is :math:`8` and the ratio is :math:`0.25`, 46 | the width and the height of the base window will be stretched by :math:`8`. 47 | For modifying the anchor to the given aspect ratio, 48 | the height is halved and the width is doubled. 49 | 50 | Args: 51 | base_size (number): The width and the height of the reference window. 52 | ratios (list of floats): This is ratios of width to height of 53 | the anchors. 54 | anchor_scales (list of numbers): This is areas of anchors. 55 | Those areas will be the product of the square of an element in 56 | :obj:`anchor_scales` and the original area of the reference 57 | window. 58 | 59 | Returns: 60 | ~numpy.ndarray: 61 | An array of shape :math:`(R, 4)`. 62 | Each element is a set of coordinates of a bounding box. 63 | The second axis corresponds to 64 | :math:`(y_{min}, x_{min}, y_{max}, x_{max})` of a bounding box. 
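For example, with the defaults (base_size=16, ratios=[0.5, 1, 2], anchor_scales=[8, 16, 32]) this returns 9 anchors; the window for ratio 0.5 and scale 8 has height 16 * 8 * sqrt(0.5) ≈ 90.5 and width 16 * 8 * sqrt(1/0.5) ≈ 181.0, so its area stays (16 * 8)^2.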
65 | 66 | """ 67 | py = base_size / 2. 68 | px = base_size / 2. 69 | 70 | anchor_base = np.zeros((len(ratios) * len(anchor_scales), 4), 71 | dtype=np.float32) 72 | for i in range(len(ratios)): 73 | for j in range(len(anchor_scales)): 74 | h = base_size * anchor_scales[j] * np.sqrt(ratios[i]) 75 | w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i]) 76 | 77 | index = i * len(anchor_scales) + j 78 | anchor_base[index, 0] = py - h / 2. 79 | anchor_base[index, 1] = px - w / 2. 80 | anchor_base[index, 2] = py + h / 2. 81 | anchor_base[index, 3] = px + w / 2. 82 | return anchor_base 83 | 84 | def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width): 85 | # Enumerate all shifted anchors: 86 | # 87 | # add A anchors (1, A, 4) to 88 | # cell K shifts (K, 1, 4) to get 89 | # shift anchors (K, A, 4) 90 | # reshape to (K*A, 4) shifted anchors 91 | # return (K*A, 4) 92 | 93 | shift_y = np.arange(0, height * feat_stride, feat_stride) 94 | shift_x = np.arange(0, width * feat_stride, feat_stride) 95 | shift_x, shift_y = np.meshgrid(shift_x, shift_y) 96 | shift = np.stack((shift_x.ravel(), shift_y.ravel(), 97 | shift_x.ravel(), shift_y.ravel()), axis=1) 98 | 99 | A = anchor_base.shape[0] 100 | K = shift.shape[0] 101 | anchor = anchor_base.reshape((1, A, 4)) + \ 102 | shift.reshape((1, K, 4)).transpose((1, 0, 2)) 103 | anchor = anchor.reshape((K * A, 4)).astype(np.float32) 104 | return anchor 105 | 106 | 107 | def loc2bbox(src_bbox,loc): 108 | if src_bbox.shape[0] == 0: 109 | return jt.zeros((0, 4), dtype=loc.dtype) 110 | 111 | src_width = src_bbox[:, 2:3] - src_bbox[:, 0:1] 112 | src_height = src_bbox[:, 3:4] - src_bbox[:, 1:2] 113 | src_center_x = src_bbox[:, 0:1] + 0.5 * src_width 114 | src_center_y = src_bbox[:, 1:2] + 0.5 * src_height 115 | 116 | dx = loc[:, 0:1] 117 | dy = loc[:, 1:2] 118 | dw = loc[:, 2:3] 119 | dh = loc[:, 3:4] 120 | 121 | center_x = dx*src_width+src_center_x 122 | center_y = dy*src_height+src_center_y 123 | 124 | w = jt.exp(dw.minimum(20.0)) * src_width 125 | h = jt.exp(dh.minimum(20.0)) * src_height 126 | 127 | x1,y1,x2,y2 = center_x-0.5*w, center_y-0.5*h, center_x+0.5*w, center_y+0.5*h 128 | 129 | dst_bbox = jt.contrib.concat([x1,y1,x2,y2],dim=1) 130 | 131 | return dst_bbox 132 | 133 | def bbox2loc(src_bbox,dst_bbox): 134 | width = src_bbox[:, 2:3] - src_bbox[:, 0:1] 135 | height = src_bbox[:, 3:4] - src_bbox[:, 1:2] 136 | center_x = src_bbox[:, 0:1] + 0.5 * width 137 | center_y = src_bbox[:, 1:2] + 0.5 * height 138 | 139 | base_width = dst_bbox[:, 2:3] - dst_bbox[:, 0:1] 140 | base_height = dst_bbox[:, 3:4] - dst_bbox[:, 1:2] 141 | base_center_x = dst_bbox[:, 0:1] + 0.5 * base_width 142 | base_center_y = dst_bbox[:, 1:2] + 0.5 * base_height 143 | 144 | eps = 1e-5 145 | height = jt.maximum(height, eps) 146 | width = jt.maximum(width, eps) 147 | 148 | dy = (base_center_y - center_y) / height 149 | dx = (base_center_x - center_x) / width 150 | 151 | dw = jt.log(base_width / width) 152 | dh = jt.log(base_height / height) 153 | 154 | loc = jt.contrib.concat([dx,dy,dw,dh],dim=1) 155 | return loc 156 | 157 | def bbox_iou(bbox_a, bbox_b): 158 | assert bbox_a.shape[1]==4 and bbox_b.shape[1]==4 159 | 160 | # top left 161 | tl = jt.maximum(bbox_a[:, :2].unsqueeze(1), bbox_b[:, :2]) 162 | # bottom right 163 | br = jt.minimum(bbox_a[:,2:].unsqueeze(1), bbox_b[:, 2:]) 164 | 165 | area_i = jt.prod(br - tl, dim=2) * (tl < br).all(dim=2) 166 | area_a = jt.prod(bbox_a[:, 2:] - bbox_a[:, :2], dim=1) 167 | area_b = jt.prod(bbox_b[:, 2:] - bbox_b[:, :2], dim=1) 168 | return 
area_i / (area_a.unsqueeze(1) + area_b - area_i) 169 | 170 | 171 | def calculate_ious(gt_boxes,box): 172 | 173 | in_w = np.minimum(gt_boxes[:,2],box[2]) - np.maximum(gt_boxes[:,0],box[0]) 174 | in_h = np.minimum(gt_boxes[:,3],box[3]) - np.maximum(gt_boxes[:,1],box[1]) 175 | 176 | in_w = np.maximum(in_w,0) 177 | in_h = np.maximum(in_h,0) 178 | 179 | inter = in_w*in_h 180 | 181 | area1 = (gt_boxes[:,2]-gt_boxes[:,0])*(gt_boxes[:,3]-gt_boxes[:,1]) 182 | area2 = (box[2]-box[0])*(box[3]-box[1]) 183 | union = area1+area2-inter 184 | ious = inter / union 185 | jmax = np.argmax(ious) 186 | maxiou = ious[jmax] 187 | return maxiou,jmax 188 | 189 | 190 | def _unmap(data, count, index, fill=0): 191 | # Unmap a subset of item (data) back to the original set of items (of 192 | # size count) 193 | ret_shape = list(data.shape) 194 | ret_shape[0]=count 195 | ret = jt.ones(ret_shape,dtype=data.dtype)*fill 196 | ret[index]=data 197 | return ret 198 | 199 | 200 | 201 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import glob 4 | import json 5 | import jittor as jt 6 | import numpy as np 7 | from tqdm import tqdm 8 | import matplotlib.pyplot as plt 9 | 10 | from dataset.data import build_testdataset 11 | from utils.ap_eval import calculate_VOC_mAP 12 | from model.faster_rcnn import FasterRCNN 13 | from utils.visualize import save_visualize_image 14 | 15 | 16 | CLASSNAMES = ['i1', 'i10', 'i11', 'i12', 'i13', 'i14', 'i15', 'i2', 'i3', 'i4', 'i5', 'il100', 'il110', 'il50', 'il60', 'il70', 'il80', 'il90', 'ip', 'p1', 'p10', 'p11', 'p12', 'p13', 'p14', 'p15', 'p16', 'p17', 'p18', 'p19', 'p2', 'p20', 'p21', 'p23', 'p24', 'p25', 'p26', 'p27', 'p28', 'p3', 'p4', 'p5', 'p6', 'p7', 'p8', 'p9', 'pa10', 'pa12', 'pa13', 'pa14', 'pa8', 'pb', 'pc', 'pg', 'ph2', 'ph2.1', 'ph2.2', 'ph2.4', 'ph2.5', 'ph2.8', 'ph2.9', 'ph3', 'ph3.2', 'ph3.5', 'ph3.8', 'ph4', 'ph4.2', 'ph4.3', 'ph4.5', 'ph4.8', 'ph5', 'ph5.3', 'ph5.5', 'pl10', 'pl100', 'pl110', 'pl120', 'pl15', 'pl20', 'pl25', 'pl30', 'pl35', 'pl40', 'pl5', 'pl50', 'pl60', 'pl65', 'pl70', 'pl80', 'pl90', 'pm10', 'pm13', 'pm15', 'pm1.5', 'pm2', 'pm20', 'pm25', 'pm30', 'pm35', 'pm40', 'pm46', 'pm5', 'pm50', 'pm55', 'pm8', 'pn', 'pne', 'pr10', 'pr100', 'pr20', 'pr30', 'pr40', 'pr45', 'pr50', 'pr60', 'pr70', 'pr80', 'ps', 'pw2.5', 'pw3', 'pw3.2', 'pw3.5', 'pw4', 'pw4.2', 'pw4.5', 'w1', 'w10', 'w12', 'w13', 'w16', 'w18', 'w20', 'w21', 'w22', 'w24', 'w28', 'w3', 'w30', 'w31', 'w32', 'w34', 'w35', 'w37', 'w38', 'w41', 'w42', 'w43', 'w44', 'w45', 'w46', 'w47', 'w48', 'w49', 'w5', 'w50', 'w55', 'w56', 'w57', 'w58', 'w59', 'w60', 'w62', 'w63', 'w66', 'w8', 'i6', 'i7', 'i8', 'i9', 'p29', 'w29', 'w33', 'w36', 'w39', 'w4', 'w40', 'w51', 'w52', 'w53', 'w54', 'w6', 'w61', 'w64', 'w65', 'w67', 'w7', 'w9', 'pd', 'pe', 'pnl', 'w11', 'w14', 'w15', 'w17', 'w19', 'w2', 'w23', 'w25', 'w26', 'w27', 'pm2.5', 'ph4.4', 'ph3.3', 'ph2.6', 'i4l', 'i2r', 'im', 'wc', 'pcr', 'pcl', 'pss', 'pbp', 'p1n', 'pbm', 'pt', 'pn-2', 'pclr', 'pcs', 'pcd', 'iz', 'pmb', 'pdd', 'pctl', 'ph1.8', 'pnlc', 'pmblr', 'phclr', 'phcs', 'pmr'] 17 | 18 | def run_images(img_dir,classnames,checkpoint_path,save_path): 19 | test_dataset = build_testdataset(img_dir=img_dir) 20 | 21 | faster_rcnn = FasterRCNN(n_class=len(classnames)+1) 22 | faster_rcnn.load(checkpoint_path) 23 | faster_rcnn.eval() 24 | 25 | results = {} 26 | for batch_idx,(images,image_sizes,img_ids) in 
tqdm(enumerate(test_dataset)): 27 | result = faster_rcnn.predict(images,score_thresh=0.01) 28 | for img_id,(pred_boxes,pred_scores,pred_labels) in zip(img_ids,result): 29 | objects = [] 30 | for box,label,score in zip(pred_boxes.numpy(),pred_labels.numpy(),pred_scores.numpy()): 31 | bbox = {"xmin":float(box[0]),"ymin":float(box[1]),"xmax":float(box[2]),"ymax":float(box[3])} 32 | category = classnames[label-1] 33 | score = float(score) 34 | objects.append({"bbox":bbox,"category":category,"score":score}) 35 | results[img_id+".jpg"]=objects 36 | 37 | os.makedirs(save_path,exist_ok=True) 38 | json.dump(results,open(os.path.join(save_path,"test_result.json"),"w")) 39 | 40 | def build_comparison(detection_f,gt_f,classnames=None): 41 | detections = json.load(open(detection_f)) 42 | gt_annos = json.load(open(gt_f)) 43 | gts = gt_annos["imgs"] 44 | if classnames is None: 45 | classnames = gts["types"] 46 | 47 | img_ids = set(gts.keys()) 48 | img_ids.update(detections.keys()) 49 | 50 | results = [] 51 | for img_id in img_ids: 52 | if not img_id in gts or len([o for o in gts[img_id]['objects'] if o["category"] in classnames])==0: 53 | gt_boxes = np.zeros((0,4)) 54 | gt_labels = np.zeros((0,)) 55 | else: 56 | gt = gts[img_id] 57 | objects = [o for o in gt['objects'] if o["category"] in classnames] 58 | gt_boxes = [[o['bbox']['xmin'],o['bbox']['ymin'],o['bbox']['xmax'],o['bbox']['ymax']] for o in objects] 59 | gt_labels = [classnames.index(o['category'])+1 for o in objects] 60 | gt_boxes = np.array(gt_boxes) 61 | gt_labels = np.array(gt_labels) 62 | 63 | w = gt_boxes[:,2]-gt_boxes[:,0] 64 | h = gt_boxes[:,3]-gt_boxes[:,1] 65 | use_range = (w>=16) & (h>=16) 66 | 67 | gt_boxes = gt_boxes[use_range,:] 68 | gt_labels = gt_labels[use_range] 69 | 70 | if img_id not in detections or len(detections[img_id])==0: 71 | pred_boxes = np.zeros((0,4)) 72 | pred_labels = np.zeros((0,)) 73 | pred_scores = np.zeros((0,)) 74 | else: 75 | objects = [o for o in detections[img_id] if o["category"] in classnames] 76 | pred_boxes = [[o['bbox']['xmin'],o['bbox']['ymin'],o['bbox']['xmax'],o['bbox']['ymax']] for o in objects] 77 | pred_labels = [classnames.index(o['category'])+1 for o in objects] 78 | pred_scores = [o["score"] for o in objects] 79 | 80 | pred_boxes = np.array(pred_boxes) 81 | pred_labels = np.array(pred_labels) 82 | pred_scores = np.array(pred_scores) 83 | 84 | results.append((img_id,pred_boxes,pred_labels,pred_scores,gt_boxes,gt_labels)) 85 | index = (pred_scores>0.5) 86 | if False and (index.sum()>0 or gt_labels.shape[0]>0): 87 | save_visualize_image(img_id,pred_boxes[index,:],pred_scores[index],pred_labels[index],gt_boxes,gt_labels,classnames) 88 | return results 89 | 90 | def filter_range(results,areaRange,is_filter_preds=True): 91 | filter_results = [] 92 | for img_id,pred_boxes,pred_labels,pred_scores,gt_boxes,gt_labels in results: 93 | if is_filter_preds: 94 | pred_w = pred_boxes[:,2]-pred_boxes[:,0] 95 | pred_h = pred_boxes[:,3]-pred_boxes[:,1] 96 | pred_area = pred_h*pred_w 97 | pred_used = (pred_area>=areaRange[0]) & (pred_area=areaRange[0]) & (gt_area 0).view(-1, 1).expand_as(in_weight)] = 1 24 | loc_loss = _smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma) 25 | # Normalize by total number of negtive and positive rois. 
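# (in_weight keeps only entries with gt_label > 0, so negative samples add
#  nothing to the numerator, while the denominator below counts every
#  sampled example with gt_label >= 0)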
26 | # ignore gt_label==-1 for rpn_loss 27 | loc_loss /= ((gt_label >= 0).sum().float()) 28 | return loc_loss 29 | 30 | 31 | class FasterRCNN(nn.Module): 32 | 33 | def __init__(self,n_class,backbone_name='resnet50'): 34 | super(FasterRCNN,self).__init__() 35 | if backbone_name=='resnet101': 36 | self.backbone = Resnet101(pretrained=True) 37 | elif backbone_name == 'resnet50': 38 | self.backbone = Resnet50(pretrained=True) 39 | else: 40 | assert False, f'{backbone_name} is not supported' 41 | 42 | self.n_class = n_class 43 | 44 | self.rpn = RegionProposalNetwork(in_channels=self.backbone.out_channels, 45 | mid_channels=512, 46 | ratios=[0.5, 1, 2], 47 | anchor_scales=[8, 16, 32], 48 | feat_stride=self.backbone.feat_stride, 49 | nms_thresh=0.7, 50 | n_train_pre_nms=12000, 51 | n_train_post_nms=2000, 52 | n_test_pre_nms=6000, 53 | n_test_post_nms=300, 54 | min_size=16,) 55 | 56 | self.anchor_target_creator = AnchorTargetCreator(n_sample=256, 57 | pos_iou_thresh=0.7, 58 | neg_iou_thresh=0.3, 59 | pos_ratio=0.5) 60 | 61 | self.proposal_target_creator = ProposalTargetCreator(n_sample=128, 62 | pos_ratio=0.25, 63 | pos_iou_thresh=0.5, 64 | neg_iou_thresh_hi=0.5, 65 | neg_iou_thresh_lo=0.0) 66 | 67 | self.head = RoIHead(in_channels=self.backbone.out_channels, 68 | n_class=n_class, 69 | roi_size=7, 70 | spatial_scale=1.0/self.backbone.feat_stride, 71 | sampling_ratio=0) 72 | 73 | self.rpn_sigma = 3. 74 | self.roi_sigma = 1. 75 | 76 | def execute(self,images,boxes=None,labels=None): 77 | # w,h 78 | img_size = (images.shape[-1],images.shape[-2]) 79 | features = self.backbone(images) 80 | if self.is_training(): 81 | return self._forward_train(features,img_size,boxes,labels) 82 | else: 83 | return self._forward_test(features,img_size) 84 | 85 | def _forward_train(self,features,img_size,boxes,labels): 86 | N = features.shape[0] 87 | rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(features, img_size) 88 | 89 | sample_rois = [] 90 | gt_roi_locs = [] 91 | gt_roi_labels = [] 92 | sample_roi_indexs = [] 93 | gt_rpn_locs = [] 94 | gt_rpn_labels = [] 95 | for i in range(N): 96 | index = jt.where(roi_indices == i)[0] 97 | roi = rois[index,:] 98 | box = boxes[i] 99 | label = labels[i] 100 | sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator(roi,box,label) 101 | sample_roi_index = i*jt.ones((sample_roi.shape[0],)) 102 | 103 | sample_rois.append(sample_roi) 104 | gt_roi_labels.append(gt_roi_label) 105 | gt_roi_locs.append(gt_roi_loc) 106 | sample_roi_indexs.append(sample_roi_index) 107 | 108 | gt_rpn_loc, gt_rpn_label = self.anchor_target_creator(box,anchor,img_size) 109 | gt_rpn_locs.append(gt_rpn_loc) 110 | gt_rpn_labels.append(gt_rpn_label) 111 | 112 | sample_roi_indexs = jt.contrib.concat(sample_roi_indexs,dim=0) 113 | sample_rois = jt.contrib.concat(sample_rois,dim=0) 114 | roi_cls_loc, roi_score = self.head(features,sample_rois,sample_roi_indexs) 115 | 116 | # ------------------ RPN losses -------------------# 117 | rpn_locs = rpn_locs.reshape(-1,4) 118 | rpn_scores = rpn_scores.reshape(-1,2) 119 | gt_rpn_labels = jt.contrib.concat(gt_rpn_labels,dim=0) 120 | gt_rpn_locs = jt.contrib.concat(gt_rpn_locs,dim=0) 121 | rpn_loc_loss = _fast_rcnn_loc_loss(rpn_locs,gt_rpn_locs,gt_rpn_labels,self.rpn_sigma) 122 | rpn_cls_loss = nn.cross_entropy_loss(rpn_scores[gt_rpn_labels>=0,:],gt_rpn_labels[gt_rpn_labels>=0]) 123 | 124 | # ------------------ ROI losses (fast rcnn loss) -------------------# 125 | gt_roi_locs = jt.contrib.concat(gt_roi_locs,dim=0) 126 | gt_roi_labels = 
jt.contrib.concat(gt_roi_labels,dim=0) 127 | n_sample = roi_cls_loc.shape[0] 128 | roi_cls_loc = roi_cls_loc.view(n_sample, np.prod(roi_cls_loc.shape[1:]).item()//4, 4) 129 | roi_loc = roi_cls_loc[jt.arange(0, n_sample).int32(), gt_roi_labels] 130 | roi_loc_loss = _fast_rcnn_loc_loss(roi_loc,gt_roi_locs,gt_roi_labels,self.roi_sigma) 131 | roi_cls_loss = nn.cross_entropy_loss(roi_score, gt_roi_labels) 132 | 133 | losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss] 134 | losses = losses + [sum(losses)] 135 | return losses 136 | 137 | def _forward_test(self,features,img_size): 138 | rpn_locs, rpn_scores, rois, roi_indices, anchor = self.rpn(features, img_size) 139 | roi_cls_locs, roi_scores = self.head(features, rois, roi_indices) 140 | return rpn_locs, rpn_scores,roi_cls_locs, roi_scores, rois, roi_indices 141 | 142 | def predict(self, images,score_thresh=0.7,nms_thresh = 0.3): 143 | N = images.shape[0] 144 | img_size = (images.shape[-1],images.shape[-2]) 145 | rpn_locs, rpn_scores,roi_cls_locs, roi_scores, rois, roi_indices = self.execute(images) 146 | roi_cls_locs = roi_cls_locs.reshape(roi_cls_locs.shape[0],-1,4) 147 | probs = nn.softmax(roi_scores,dim=-1) 148 | rois = rois.unsqueeze(1).repeat(1,self.n_class,1) 149 | cls_bbox = loc2bbox(rois.reshape(-1,4),roi_cls_locs.reshape(-1,4)) 150 | cls_bbox[:,0::2] = jt.clamp(cls_bbox[:,0::2],min_v=0,max_v=img_size[0]) 151 | cls_bbox[:,1::2] = jt.clamp(cls_bbox[:,1::2],min_v=0,max_v=img_size[1]) 152 | 153 | cls_bbox = cls_bbox.reshape(roi_cls_locs.shape) 154 | 155 | results = [] 156 | for i in range(N): 157 | index = jt.where(roi_indices==i)[0] 158 | score = probs[index,:] 159 | bbox = cls_bbox[index,:,:] 160 | boxes = [] 161 | scores = [] 162 | labels = [] 163 | for j in range(1,self.n_class): 164 | bbox_j = bbox[:,j,:] 165 | score_j = score[:,j] 166 | mask = jt.where(score_j>score_thresh)[0] 167 | bbox_j = bbox_j[mask,:] 168 | score_j = score_j[mask] 169 | dets = jt.contrib.concat([bbox_j,score_j.unsqueeze(1)],dim=1) 170 | keep = jt.nms(dets,nms_thresh) 171 | bbox_j = bbox_j[keep] 172 | score_j = score_j[keep] 173 | label_j = jt.ones_like(score_j).int32()*j 174 | boxes.append(bbox_j) 175 | scores.append(score_j) 176 | labels.append(label_j) 177 | 178 | boxes = jt.contrib.concat(boxes,dim=0) 179 | scores = jt.contrib.concat(scores,dim=0) 180 | labels = jt.contrib.concat(labels,dim=0) 181 | results.append((boxes,scores,labels)) 182 | 183 | return results 184 | 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import cv2 5 | import glob 6 | 7 | MAX_SIZE = 160 8 | MIN_SIZE = 30 9 | MAX_RATIO = 1.4 10 | MIN_RATIO = 0.6 11 | MAX_POSITION = 2048-240 12 | AUG_NUM=150 13 | 14 | def read_anno(file): 15 | if os.path.exists(file): 16 | return json.load(open(file)) 17 | return [] 18 | 19 | def overlaps(bbox_a,bbox_b): 20 | # top left 21 | tl = np.maximum(bbox_a[:, None,:2], bbox_b[:, :2]) 22 | # bottom right 23 | br = np.minimum(bbox_a[:,None,2:], bbox_b[:, 2:]) 24 | 25 | area_i = np.prod(br - tl, axis=2) * (tl < br).all(axis=2) 26 | area_a = np.prod(bbox_a[:, 2:] - bbox_a[:, :2], axis=1) 27 | area_b = np.prod(bbox_b[:, 2:] - bbox_b[:, :2], axis=1) 28 | return area_i / (area_a[:,None] + area_b - area_i) 29 | 30 | def filter_ious(boxes,iou_thresh=0.1): 31 | if boxes.shape[0]==0: 32 | return False 33 | ious = overlaps(boxes,boxes) 
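# Zero out each box's IoU with itself on the diagonal, then report True
# (forcing the caller to re-sample the boxes for this image) if any pair of
# pasted boxes overlaps by more than iou_thresh.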
34 | lx = [i for i in range(ious.shape[0])] 35 | ious[lx,lx]=0 36 | fails = np.sum(ious>iou_thresh) 37 | if fails>0: 38 | return True 39 | return False 40 | 41 | def _random_boxes(num_boxes): 42 | x1,y1 = np.random.uniform(0,MAX_POSITION,size=(num_boxes,1)),np.random.uniform(0,MAX_POSITION,size=(num_boxes,1)) 43 | w = np.random.uniform(MIN_SIZE,MAX_SIZE,size=(num_boxes,1)) 44 | h = w*np.random.uniform(MIN_RATIO,MAX_RATIO,size=(num_boxes,1)) 45 | boxes = np.concatenate([x1,y1,x1+w,y1+h],axis=1) 46 | return boxes 47 | 48 | def random_boxes(num_boxes,num_images,num_classes): 49 | boxes = _random_boxes(num_boxes) 50 | img_indexes = np.random.randint(0,num_images,size=(num_boxes,)) 51 | counts = 0 52 | while counts<100: 53 | aug_classes = np.random.randint(0,num_classes,size=(num_boxes,)) 54 | _,counts = np.unique(aug_classes,return_counts=True) 55 | counts = np.min(counts) 56 | 57 | img_ids = np.unique(img_indexes) 58 | for i in img_ids: 59 | r = True 60 | while r: 61 | img_boxes = boxes[img_indexes==i,:] 62 | n = img_boxes.shape[0] 63 | r = filter_ious(img_boxes) 64 | if r: 65 | boxes[img_indexes==i,:] = _random_boxes(n) 66 | 67 | return boxes,img_indexes,aug_classes 68 | 69 | def add_noise(img,sigma=0.02): 70 | noise = np.random.randn(img.shape[0],img.shape[1],1)*np.max(img) 71 | img = img.astype(np.float32) 72 | img = img+noise*sigma 73 | img = np.maximum(img,0) 74 | img = np.minimum(img,255) 75 | img = img.astype(np.uint8) 76 | return img 77 | 78 | def paste_mark(img_file,boxes,classes,need_aug,mark_images,save_file,is_add_noise=True): 79 | img = cv2.imread(img_file) 80 | for b,c in zip(boxes,classes): 81 | mark = mark_images[need_aug[c]] 82 | w,h = int(b[2]-b[0]),int(b[3]-b[1]) 83 | x,y = int(b[0]),int(b[1]) 84 | mark = cv2.resize(mark,(w,h)) 85 | if mark.shape[2]==4: 86 | ratio = (mark[:,:,3:]>0)*1.0 87 | mark = mark[:,:,:3] 88 | else: 89 | ratio = np.ones((mark.shape[0],mark.shape[1],1)) 90 | if is_add_noise: 91 | mark = add_noise(mark) 92 | img[y:y+h,x:x+w]=mark*ratio+img[y:y+h,x:x+w]*(1.0-ratio) 93 | cv2.imwrite(save_file,img) 94 | 95 | def find_dir(data_dir,img_id): 96 | t_f = f"{data_dir}/test/{img_id}.jpg" 97 | tt_f = f"{data_dir}/train/{img_id}.jpg" 98 | o_f = f"{data_dir}/other/{img_id}.jpg" 99 | if os.path.exists(tt_f): 100 | return f"train/{img_id}.jpg" 101 | elif os.path.exists(t_f): 102 | return f"test/{img_id}.jpg" 103 | elif os.path.exists(o_f): 104 | return f"other/{img_id}.jpg" 105 | assert False,f"{img_id}.jpg is not exists" 106 | 107 | def select_empty_images(other_dir,annos): 108 | other_files = list(glob.glob(f"{other_dir}/*.jpg")) 109 | empty_files = [] 110 | for f in other_files: 111 | img_id = f.split("/")[-1].split(".jpg")[0] 112 | if img_id not in annos or len(annos[img_id]["objects"])==0: 113 | empty_files.append(f) 114 | return empty_files 115 | 116 | def build_annotations(): 117 | data_dir = "/data/lxl/dataset/tt100k/tt100k_2021" 118 | mark_dir = f"{data_dir}/crop_marks" 119 | remark_dir = f"{data_dir}/re_marks" 120 | anno_f = f"{data_dir}/annotations_all.json" 121 | 122 | crop_marks(data_dir) 123 | data = json.load(open(anno_f)) 124 | annos = data["imgs"] 125 | classnames = data["types"] 126 | empty_files = select_empty_images(f"{data_dir}/other",annos) 127 | 128 | marks = {} 129 | for category in classnames: 130 | m_f = f"{mark_dir}/{category}.png" 131 | rm_f = f"{remark_dir}/{category}.jpg" 132 | a = os.path.exists(m_f) 133 | marks[category]= m_f if a else rm_f 134 | if not os.path.exists(marks[category]): 135 | assert False,category 136 | 137 | 
category_nums = {c:0 for c in classnames} 138 | for img_id,v in annos.items(): 139 | objects = v["objects"] 140 | for o in objects: 141 | category = o["category"] 142 | category_nums[category]+=1 143 | 144 | need_aug = [c for c,v in category_nums.items() if v<100] 145 | mark_images = {c:cv2.imread(marks[c],cv2.IMREAD_UNCHANGED) for c in need_aug} 146 | 147 | num_boxes = len(need_aug)*AUG_NUM 148 | num_images = len(empty_files) 149 | num_classes = len(need_aug) 150 | boxes,img_indexes,aug_classes = random_boxes(num_boxes,num_images,num_classes) 151 | 152 | count = 0 153 | annotations = {"imgs":{},"types":classnames} 154 | 155 | for img_i in np.unique(img_indexes): 156 | img_boxes = boxes[img_indexes==img_i,:] 157 | img_file = empty_files[img_i] 158 | file_name = img_file.split("/")[-1] 159 | img_id = file_name.split(".")[0] 160 | 161 | anno = {"path":f"augmentations/{file_name}","id":int(img_id)} 162 | classes = aug_classes[img_indexes==img_i] 163 | 164 | os.makedirs(f"{data_dir}/augmentations",exist_ok=True) 165 | save_file = img_file.replace("other","augmentations") 166 | paste_mark(img_file,img_boxes,classes,need_aug,mark_images,save_file) 167 | 168 | objects = [] 169 | for box,c in zip(img_boxes,classes): 170 | bbox = {"xmin":float(box[0]),"ymin":float(box[1]),"xmax":float(box[2]),"ymax":float(box[3])} 171 | objects.append({"bbox":bbox,"category":need_aug[c]}) 172 | anno["objects"]=objects 173 | annotations["imgs"][img_id]=anno 174 | 175 | count+=1 176 | print(count,"/",len(empty_files)) 177 | 178 | for img_id,v in annotations["imgs"].items(): 179 | if img_id not in annos: 180 | annos[img_id]=v 181 | else: 182 | assert len(annos[img_id]["objects"])==0 183 | annos[img_id] = v 184 | 185 | json.dump({"imgs":annos,"types":classnames},open(f"{data_dir}/annotations_aug.json","w")) 186 | 187 | def crop_marks(data_dir): 188 | img_files = glob.glob(f"{data_dir}/marks/*.png") 189 | for img_file in img_files: 190 | img = cv2.imread(img_file,cv2.IMREAD_UNCHANGED) 191 | alpha = img[:,:,3] 192 | ww = np.sum(alpha,axis=0) 193 | l = 0 194 | while ww[l]==0: 195 | l+=1 196 | r = ww.shape[0]-1 197 | while ww[r-1]==0: 198 | r-=1 199 | 200 | hh = np.sum(alpha,axis=1) 201 | 202 | t = 0 203 | while hh[t]==0: 204 | t+=1 205 | b = hh.shape[0]-1 206 | while hh[b-1]==0: 207 | b-=1 208 | img = img[t:b,l:r] 209 | os.makedirs(f"{data_dir}/crop_marks",exist_ok=True) 210 | cv2.imwrite(img_file.replace("marks","crop_marks"),img) 211 | 212 | def draw_box(img,box,text,color): 213 | box = [int(x) for x in box] 214 | img = cv2.rectangle(img=img, pt1=tuple(box[0:2]), pt2=tuple(box[2:]), color=color, thickness=1) 215 | img = cv2.putText(img=img, text=text, org=(box[0],box[1]-5), fontFace=0, fontScale=0.5, color=color, thickness=1) 216 | return img 217 | 218 | 219 | def draw_boxes(img,boxes,classnames,color=(255,0,0)): 220 | for box,label in zip(boxes,classnames): 221 | box = [int(i) for i in box] 222 | img = draw_box(img,box,label,color) 223 | return img 224 | 225 | def test(): 226 | data_dir = "/data/lxl/dataset/tt100k/tt100k_2021" 227 | tmp_dir = "./tmp" 228 | os.makedirs(tmp_dir,exist_ok=True) 229 | anno_f = data_dir+"/annotations_aug.json" 230 | annos = json.load(open(anno_f)) 231 | imgs = list(annos["imgs"].values()) 232 | imgs = [img for img in imgs if "augmentations" in img["path"]] 233 | indexes = [i for i in range(len(imgs))] 234 | np.random.shuffle(indexes) 235 | for i in indexes[:10]: 236 | data = imgs[i] 237 | path = data_dir+"/"+data["path"] 238 | objects = data["objects"] 239 | img = cv2.imread(path) 240 | 
boxes = [[o['bbox']['xmin'],o['bbox']['ymin'],o['bbox']['xmax'],o['bbox']['ymax']] for o in objects] 241 | labels = [o['category'] for o in objects] 242 | img = draw_boxes(img,boxes,labels) 243 | img_file = f"{tmp_dir}/{data['id']}.jpg" 244 | cv2.imwrite(img_file,img) 245 | 246 | def main(): 247 | np.random.seed(0) 248 | build_annotations() 249 | test() 250 | 251 | if __name__ == '__main__': 252 | main() -------------------------------------------------------------------------------- /utils/roi_align.py: -------------------------------------------------------------------------------- 1 | from jittor import nn 2 | import jittor as jt 3 | 4 | from jittor.misc import _pair 5 | 6 | CUDA_HEADER = r''' 7 | #include 8 | #include 9 | #include 10 | #define CUDA_1D_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) 11 | using namespace std; 12 | 13 | __device__ float bilinear_interpolate(const float* bottom_data, 14 | const int height, const int width, 15 | float y, float x, 16 | const int index /* index for debug only*/) { 17 | 18 | // deal with cases that inverse elements are out of feature map boundary 19 | if (y < -1.0 || y > height || x < -1.0 || x > width) { 20 | //empty 21 | return 0; 22 | } 23 | 24 | if (y <= 0) y = 0; 25 | if (x <= 0) x = 0; 26 | 27 | int y_low = (int) y; 28 | int x_low = (int) x; 29 | int y_high; 30 | int x_high; 31 | 32 | if (y_low >= height - 1) { 33 | y_high = y_low = height - 1; 34 | y = (float) y_low; 35 | } else { 36 | y_high = y_low + 1; 37 | } 38 | 39 | if (x_low >= width - 1) { 40 | x_high = x_low = width - 1; 41 | x = (float) x_low; 42 | } else { 43 | x_high = x_low + 1; 44 | } 45 | 46 | float ly = y - y_low; 47 | float lx = x - x_low; 48 | float hy = 1. - ly, hx = 1. - lx; 49 | // do bilinear interpolation 50 | float v1 = bottom_data[y_low * width + x_low]; 51 | float v2 = bottom_data[y_low * width + x_high]; 52 | float v3 = bottom_data[y_high * width + x_low]; 53 | float v4 = bottom_data[y_high * width + x_high]; 54 | float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 55 | 56 | float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 57 | 58 | return val; 59 | } 60 | 61 | __device__ void bilinear_interpolate_gradient( 62 | const int height, const int width, 63 | float y, float x, 64 | float & w1, float & w2, float & w3, float & w4, 65 | int & x_low, int & x_high, int & y_low, int & y_high, 66 | const int index /* index for debug only*/) { 67 | 68 | // deal with cases that inverse elements are out of feature map boundary 69 | if (y < -1.0 || y > height || x < -1.0 || x > width) { 70 | //empty 71 | w1 = w2 = w3 = w4 = 0.; 72 | x_low = x_high = y_low = y_high = -1; 73 | return; 74 | } 75 | 76 | if (y <= 0) y = 0; 77 | if (x <= 0) x = 0; 78 | 79 | y_low = (int) y; 80 | x_low = (int) x; 81 | 82 | if (y_low >= height - 1) { 83 | y_high = y_low = height - 1; 84 | y = (float) y_low; 85 | } else { 86 | y_high = y_low + 1; 87 | } 88 | 89 | if (x_low >= width - 1) { 90 | x_high = x_low = width - 1; 91 | x = (float) x_low; 92 | } else { 93 | x_high = x_low + 1; 94 | } 95 | 96 | float ly = y - y_low; 97 | float lx = x - x_low; 98 | float hy = 1. - ly, hx = 1. 
- lx; 99 | 100 | // reference in forward 101 | // float v1 = bottom_data[y_low * width + x_low]; 102 | // float v2 = bottom_data[y_low * width + x_high]; 103 | // float v3 = bottom_data[y_high * width + x_low]; 104 | // float v4 = bottom_data[y_high * width + x_high]; 105 | // float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 106 | 107 | w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; 108 | 109 | return; 110 | } 111 | 112 | ''' 113 | 114 | CUDA_SRC = r''' 115 | __global__ void RoIAlignForward(@ARGS_DEF,const int nthreads, const float* bottom_data, 116 | const int channels,const int height, const int width,const int pooled_height, const int pooled_width, 117 | const float* bottom_rois, float* top_data) { 118 | @PRECALC 119 | const float spatial_scale = @in2(0); 120 | const float sampling_ratio = @in2(1); 121 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 122 | // (n, c, ph, pw) is an element in the pooled output 123 | int pw = index % pooled_width; 124 | int ph = (index / pooled_width) % pooled_height; 125 | int c = (index / pooled_width / pooled_height) % channels; 126 | int n = index / pooled_width / pooled_height / channels; 127 | 128 | const float* offset_bottom_rois = bottom_rois + n * 5; 129 | int roi_batch_ind = (int)offset_bottom_rois[0]; 130 | 131 | // Do not using rounding; this implementation detail is critical 132 | auto roi_start_w = offset_bottom_rois[1] * spatial_scale; 133 | auto roi_start_h = offset_bottom_rois[2] * spatial_scale; 134 | auto roi_end_w = offset_bottom_rois[3] * spatial_scale; 135 | auto roi_end_h = offset_bottom_rois[4] * spatial_scale; 136 | 137 | // Force malformed ROIs to be 1x1 138 | auto roi_width = max(roi_end_w - roi_start_w, 1.); 139 | auto roi_height = max(roi_end_h - roi_start_h, 1.); 140 | auto bin_size_h = static_cast(roi_height) / static_cast(pooled_height); 141 | auto bin_size_w = static_cast(roi_width) / static_cast(pooled_width); 142 | 143 | const float* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; 144 | 145 | // We use roi_bin_grid to sample the grid and mimic integral 146 | int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 147 | int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); 148 | 149 | // We do average (integral) pooling inside a bin 150 | const float count = roi_bin_grid_h * roi_bin_grid_w; // e.g. 
= 4 151 | 152 | float output_val = 0.; 153 | for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 154 | { 155 | const float y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 156 | for (int ix = 0; ix < roi_bin_grid_w; ix ++) 157 | { 158 | const float x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); 159 | 160 | float val = bilinear_interpolate(offset_bottom_data, height, width, y, x, index); 161 | output_val += val; 162 | } 163 | } 164 | output_val /= count; 165 | 166 | top_data[index] = output_val; 167 | } 168 | } 169 | 170 | @alias(input,in0); 171 | @alias(rois,in1); 172 | @alias(output,out0); 173 | auto num_rois = rois_shape0; 174 | auto channels = input_shape1; 175 | auto height = input_shape2; 176 | auto width = input_shape3; 177 | auto pooled_height = output_shape2; 178 | auto pooled_width = output_shape3; 179 | 180 | auto output_size = num_rois * pooled_height * pooled_width * channels; 181 | const int total_count = in1_shape0 * out0_shape2 * out0_shape3 * in0_shape1; 182 | const int thread_per_block = 512L; 183 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 184 | RoIAlignForward<<>>(@ARGS,output_size,input_p,channels, 185 | height, width,pooled_height,pooled_width,rois_p,output_p); 186 | ''' 187 | 188 | CUDA_GRAD_SRC = [r''' 189 | __global__ void RoIAlignBackwardFeature(@ARGS_DEF,const int nthreads, const float* top_diff, 190 | const int num_rois, 191 | const int channels, const int height, const int width, 192 | const int pooled_height, const int pooled_width, 193 | float* bottom_diff, 194 | const float* bottom_rois) { 195 | @PRECALC 196 | 197 | @alias(input,in0) 198 | @alias(rois,in1) 199 | @alias(grad_input,out0) 200 | @alias(grad,dout) 201 | 202 | const float spatial_scale = @in2(0); 203 | const float sampling_ratio = @in2(1); 204 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 205 | // (n, c, ph, pw) is an element in the pooled output 206 | int pw = index % pooled_width; 207 | int ph = (index / pooled_width) % pooled_height; 208 | int c = (index / pooled_width / pooled_height) % channels; 209 | int n = index / pooled_width / pooled_height / channels; 210 | 211 | const float* offset_bottom_rois = bottom_rois + n * 5; 212 | int roi_batch_ind = offset_bottom_rois[0]; 213 | 214 | // Do not using rounding; this implementation detail is critical 215 | float roi_start_w = offset_bottom_rois[1] * spatial_scale; 216 | float roi_start_h = offset_bottom_rois[2] * spatial_scale; 217 | float roi_end_w = offset_bottom_rois[3] * spatial_scale; 218 | float roi_end_h = offset_bottom_rois[4] * spatial_scale; 219 | // float roi_start_w = round(offset_bottom_rois[1] * spatial_scale); 220 | // float roi_start_h = round(offset_bottom_rois[2] * spatial_scale); 221 | // float roi_end_w = round(offset_bottom_rois[3] * spatial_scale); 222 | // float roi_end_h = round(offset_bottom_rois[4] * spatial_scale); 223 | 224 | // Force malformed ROIs to be 1x1 225 | float roi_width = max(roi_end_w - roi_start_w, (float)1.); 226 | float roi_height = max(roi_end_h - roi_start_h, (float)1.); 227 | float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); 228 | float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); 229 | 230 | float* offset_bottom_diff = bottom_diff + (roi_batch_ind * channels + c) * height * width; 231 | 232 | int top_offset = (n * channels + c) * pooled_height * pooled_width; 233 | const float* 
offset_top_diff = top_diff + top_offset; 234 | const float top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; 235 | 236 | // We use roi_bin_grid to sample the grid and mimic integral 237 | int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 238 | int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); 239 | 240 | // We do average (integral) pooling inside a bin 241 | const float count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 242 | 243 | for (int iy = 0; iy < roi_bin_grid_h; iy ++) // e.g., iy = 0, 1 244 | { 245 | const float y = roi_start_h + ph * bin_size_h + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 246 | for (int ix = 0; ix < roi_bin_grid_w; ix ++) 247 | { 248 | const float x = roi_start_w + pw * bin_size_w + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); 249 | 250 | float w1, w2, w3, w4; 251 | int x_low, x_high, y_low, y_high; 252 | 253 | bilinear_interpolate_gradient(height, width, y, x, 254 | w1, w2, w3, w4, 255 | x_low, x_high, y_low, y_high, 256 | index); 257 | 258 | float g1 = top_diff_this_bin * w1 / count; 259 | float g2 = top_diff_this_bin * w2 / count; 260 | float g3 = top_diff_this_bin * w3 / count; 261 | float g4 = top_diff_this_bin * w4 / count; 262 | 263 | if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) 264 | { 265 | atomicAdd(offset_bottom_diff + y_low * width + x_low, static_cast(g1)); 266 | atomicAdd(offset_bottom_diff + y_low * width + x_high, static_cast(g2)); 267 | atomicAdd(offset_bottom_diff + y_high * width + x_low, static_cast(g3)); 268 | atomicAdd(offset_bottom_diff + y_high * width + x_high, static_cast(g4)); 269 | } // if 270 | } // ix 271 | } // iy 272 | } // CUDA_1D_KERNEL_LOOP 273 | } // RoIAlignBackward 274 | 275 | 276 | auto num_rois = rois_shape0; 277 | auto channels = input_shape1; 278 | auto height = input_shape2; 279 | auto width = input_shape3; 280 | auto pooled_height = grad_shape2; 281 | auto pooled_width = grad_shape3; 282 | 283 | auto output_size = num_rois * pooled_height * pooled_width * channels; 284 | cudaMemsetAsync(grad_input_p,0,grad_input->size); 285 | const int total_count = rois_shape0 * grad_shape2 * grad_shape3 * input_shape1; 286 | const int thread_per_block = 512; 287 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 288 | RoIAlignBackwardFeature<<>>(@ARGS,output_size,grad_p,num_rois, 289 | channels, 290 | height, 291 | width, 292 | pooled_height, 293 | pooled_width,grad_input_p,rois_p); 294 | ''','',''] 295 | 296 | def roi_align(input, rois, output_size, spatial_scale, sampling_ratio): 297 | output_size = _pair(output_size) 298 | options = jt.array([spatial_scale,sampling_ratio]) 299 | output_shapes = (rois.shape[0], input.shape[1], output_size[0], output_size[1]) 300 | inputs = [input,rois,options] 301 | output_types = input.dtype 302 | if rois.shape[0]==0: 303 | return jt.zeros(output_shapes,input.dtype) 304 | output = jt.code(output_shapes,output_types,inputs,cuda_header=CUDA_HEADER,cuda_src=CUDA_SRC,cuda_grad_src=CUDA_GRAD_SRC) 305 | return output 306 | 307 | 308 | class ROIAlign(nn.Module): 309 | def __init__(self, output_size, spatial_scale, sampling_ratio): 310 | super(ROIAlign, self).__init__() 311 | self.output_size = output_size 312 | self.spatial_scale = spatial_scale 313 | self.sampling_ratio = sampling_ratio 314 | 315 | def execute(self, input, rois): 316 | return roi_align(input, rois, self.output_size, 
self.spatial_scale, self.sampling_ratio) 317 | 318 | -------------------------------------------------------------------------------- /model/rpn.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn,init 3 | import numpy as np 4 | from utils.box_ops import _enumerate_shifted_anchor,bbox2loc,loc2bbox,generate_anchor_base,_unmap,bbox_iou 5 | 6 | class ProposalTargetCreator(nn.Module): 7 | """Assign ground truth bounding boxes to given RoIs. 8 | 9 | Args: 10 | n_sample (int): The number of sampled regions. 11 | pos_ratio (float): Fraction of regions that is labeled as a 12 | foreground. 13 | pos_iou_thresh (float): IoU threshold for a RoI to be considered as a 14 | foreground. 15 | neg_iou_thresh_hi (float): RoI is considered to be the background 16 | if IoU is in 17 | [:obj:`neg_iou_thresh_hi`, :obj:`neg_iou_thresh_hi`). 18 | neg_iou_thresh_lo (float): See above. 19 | 20 | """ 21 | 22 | def __init__(self, 23 | n_sample=128, 24 | pos_ratio=0.25, 25 | pos_iou_thresh=0.5, 26 | neg_iou_thresh_hi=0.5, 27 | neg_iou_thresh_lo=0.0 28 | ): 29 | super(ProposalTargetCreator,self).__init__() 30 | self.n_sample = n_sample 31 | self.pos_ratio = pos_ratio 32 | self.pos_iou_thresh = pos_iou_thresh 33 | self.neg_iou_thresh_hi = neg_iou_thresh_hi 34 | self.neg_iou_thresh_lo = neg_iou_thresh_lo # NOTE:default 0.1 in py-faster-rcnn 35 | 36 | 37 | def execute(self, roi, bbox, label): 38 | """Assigns ground truth to sampled proposals. 39 | 40 | This function samples total of :obj:`self.n_sample` RoIs 41 | from the combination of :obj:`roi` and :obj:`bbox`. 42 | The RoIs are assigned with the ground truth class labels as well as 43 | bounding box offsets and scales to match the ground truth bounding 44 | boxes. As many as :obj:`pos_ratio * self.n_sample` RoIs are 45 | sampled as foregrounds. 46 | 47 | Offsets and scales of bounding boxes are calculated using 48 | :func:`model.utils.bbox_tools.bbox2loc`. 49 | Also, types of input arrays and output arrays are same. 50 | 51 | Here are notations. 52 | 53 | * :math:`S` is the total number of sampled RoIs, which equals \ 54 | :obj:`self.n_sample`. 55 | * :math:`L` is number of object classes possibly including the \ 56 | background. 57 | 58 | Args: 59 | roi (array): Region of Interests (RoIs) from which we sample. 60 | Its shape is :math:`(R, 4)` 61 | bbox (array): The coordinates of ground truth bounding boxes. 62 | Its shape is :math:`(R', 4)`. 63 | label (array): Ground truth bounding box labels. Its shape 64 | is :math:`(R',)`. Its range is :math:`[0, L - 1]`, where 65 | :math:`L` is the number of foreground classes. 66 | 67 | Returns: 68 | (array, array, array): 69 | 70 | * **sample_roi**: Regions of interests that are sampled. \ 71 | Its shape is :math:`(S, 4)`. 72 | * **gt_roi_loc**: Offsets and scales to match \ 73 | the sampled RoIs to the ground truth bounding boxes. \ 74 | Its shape is :math:`(S, 4)`. 75 | * **gt_roi_label**: Labels assigned to sampled RoIs. Its shape is \ 76 | :math:`(S,)`. Its range is :math:`[0, L]`. The label with \ 77 | value 0 is the background. 78 | 79 | """ 80 | pos_roi_per_image = np.round(self.n_sample * self.pos_ratio) 81 | iou = bbox_iou(roi, bbox) 82 | gt_assignment,max_iou = iou.argmax(dim=1) 83 | # Offset range of classes from [0, n_fg_class - 1] to [1, n_fg_class]. 84 | # The label with value 0 is the background. 85 | gt_roi_label = label[gt_assignment] 86 | 87 | # Select foreground RoIs as those with >= pos_iou_thresh IoU. 
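        # At most pos_roi_per_image (= round(n_sample * pos_ratio)) of them are kept;
        # surplus positives are dropped by shuffling a numpy index array and truncating it.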
88 | pos_index = jt.where(max_iou >= self.pos_iou_thresh)[0] 89 | pos_roi_per_this_image = int(min(pos_roi_per_image, pos_index.shape[0])) 90 | if pos_index.shape[0] > 0: 91 | tmp_indexes = np.arange(0,pos_index.shape[0]) 92 | np.random.shuffle(tmp_indexes) 93 | tmp_indexes = tmp_indexes[:pos_roi_per_this_image] 94 | pos_index = pos_index[tmp_indexes] 95 | 96 | # Select background RoIs as those within 97 | # [neg_iou_thresh_lo, neg_iou_thresh_hi). 98 | neg_index = jt.where((max_iou < self.neg_iou_thresh_hi) & 99 | (max_iou >= self.neg_iou_thresh_lo))[0] 100 | neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image 101 | neg_roi_per_this_image = int(min(neg_roi_per_this_image, 102 | neg_index.shape[0])) 103 | if neg_index.shape[0] > 0: 104 | tmp_indexes = np.arange(0,neg_index.shape[0]) 105 | np.random.shuffle(tmp_indexes) 106 | tmp_indexes = tmp_indexes[:neg_roi_per_this_image] 107 | neg_index = neg_index[tmp_indexes] 108 | 109 | 110 | # The indices that we're selecting (both positive and negative). 111 | keep_index = jt.contrib.concat((pos_index, neg_index),dim=0) 112 | gt_roi_label = gt_roi_label[keep_index] 113 | gt_roi_label[pos_roi_per_this_image:] = 0 # negative labels --> 0 114 | sample_roi = roi[keep_index] 115 | 116 | # Compute offsets and scales to match sampled RoIs to the GTs. 117 | gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]]) 118 | 119 | return sample_roi, gt_roi_loc, gt_roi_label 120 | 121 | 122 | class AnchorTargetCreator(nn.Module): 123 | """Assign the ground truth bounding boxes to anchors. 124 | 125 | Args: 126 | n_sample (int): The number of regions to produce. 127 | pos_iou_thresh (float): Anchors with IoU above this 128 | threshold will be assigned as positive. 129 | neg_iou_thresh (float): Anchors with IoU below this 130 | threshold will be assigned as negative. 131 | pos_ratio (float): Ratio of positive regions in the 132 | sampled regions. 133 | 134 | """ 135 | 136 | def __init__(self, 137 | n_sample=256, 138 | pos_iou_thresh=0.7, 139 | neg_iou_thresh=0.3, 140 | pos_ratio=0.5): 141 | super(AnchorTargetCreator,self).__init__() 142 | self.n_sample = n_sample 143 | self.pos_iou_thresh = pos_iou_thresh 144 | self.neg_iou_thresh = neg_iou_thresh 145 | self.pos_ratio = pos_ratio 146 | 147 | def execute(self, bbox, anchor, img_size): 148 | """Assign ground truth supervision to sampled subset of anchors. 149 | 150 | Types of input arrays and output arrays are same. 151 | 152 | Here are notations. 153 | 154 | * :math:`S` is the number of anchors. 155 | * :math:`R` is the number of bounding boxes. 156 | 157 | Args: 158 | bbox (array): Coordinates of bounding boxes. Its shape is 159 | :math:`(R, 4)`. 160 | anchor (array): Coordinates of anchors. Its shape is 161 | :math:`(S, 4)`. 162 | img_size (tuple of ints): A tuple :obj:`H, W`, which 163 | is a tuple of height and width of an image. 164 | 165 | Returns: 166 | (array, array): 167 | 168 | #NOTE: it's scale not only offset 169 | * **loc**: Offsets and scales to match the anchors to \ 170 | the ground truth bounding boxes. Its shape is :math:`(S, 4)`. 171 | * **label**: Labels of anchors with values \ 172 | :obj:`(1=positive, 0=negative, -1=ignore)`. Its shape \ 173 | is :math:`(S,)`. 
174 | 175 | """ 176 | 177 | img_W, img_H = img_size 178 | 179 | n_anchor = len(anchor) 180 | inside_index = jt.where( 181 | (anchor[:, 0] >= 0) & 182 | (anchor[:, 1] >= 0) & 183 | (anchor[:, 2] <= img_W) & 184 | (anchor[:, 3] <= img_H) 185 | )[0] 186 | anchor = anchor[inside_index] 187 | argmax_ious, label = self._create_label(anchor, bbox) 188 | 189 | # compute bounding box regression targets 190 | loc = bbox2loc(anchor, bbox[argmax_ious]) 191 | 192 | # map up to original set of anchors 193 | label = _unmap(label, n_anchor, inside_index, fill=-1) 194 | loc = _unmap(loc, n_anchor, inside_index, fill=0) 195 | 196 | return loc, label 197 | 198 | def _create_label(self, anchor, bbox): 199 | # label: 1 is positive, 0 is negative, -1 is dont care 200 | label = -jt.ones((anchor.shape[0],), dtype="int32") 201 | 202 | argmax_ious, max_ious, gt_argmax_ious = self._calc_ious(anchor, bbox) 203 | 204 | # assign negative labels first so that positive labels can clobber them 205 | label[max_ious < self.neg_iou_thresh] = 0 206 | 207 | # positive label: for each gt, anchor with highest iou 208 | label[gt_argmax_ious] = 1 209 | 210 | # positive label: above threshold IOU 211 | label[max_ious >= self.pos_iou_thresh] = 1 212 | 213 | # subsample positive labels if we have too many 214 | n_pos = int(self.pos_ratio * self.n_sample) 215 | pos_index = jt.where(label == 1)[0] 216 | if len(pos_index) > n_pos: 217 | tmp_index = np.arange(0,pos_index.shape[0]) 218 | np.random.shuffle(tmp_index) 219 | disable_index = tmp_index[:pos_index.shape[0] - n_pos] 220 | disable_index = pos_index[disable_index] 221 | label[disable_index] = -1 222 | 223 | # subsample negative labels if we have too many 224 | n_neg = self.n_sample - jt.sum(label == 1).item() 225 | neg_index = jt.where(label == 0)[0] 226 | if len(neg_index) > n_neg: 227 | tmp_index = np.arange(0,neg_index.shape[0]) 228 | np.random.shuffle(tmp_index) 229 | disable_index = tmp_index[:neg_index.shape[0] - n_neg] 230 | disable_index = neg_index[disable_index] 231 | label[disable_index] = -1 232 | return argmax_ious, label 233 | 234 | def _calc_ious(self, anchor, bbox): 235 | # ious between the anchors and the gt boxes 236 | ious = bbox_iou(anchor, bbox) 237 | argmax_ious,max_ious = ious.argmax(dim=1) 238 | 239 | gt_argmax_ious,gt_max_ious = ious.argmax(dim=0) 240 | gt_argmax_ious = jt.where(ious == gt_max_ious)[0] 241 | 242 | return argmax_ious, max_ious, gt_argmax_ious 243 | 244 | 245 | 246 | class ProposalCreator(nn.Module): 247 | """Proposal regions are generated by calling this object. 248 | 249 | Args: 250 | nms_thresh (float): Threshold value used when calling NMS. 251 | n_train_pre_nms (int): Number of top scored bounding boxes 252 | to keep before passing to NMS in train mode. 253 | n_train_post_nms (int): Number of top scored bounding boxes 254 | to keep after passing to NMS in train mode. 255 | n_test_pre_nms (int): Number of top scored bounding boxes 256 | to keep before passing to NMS in test mode. 257 | n_test_post_nms (int): Number of top scored bounding boxes 258 | to keep after passing to NMS in test mode. 259 | force_cpu_nms (bool): If this is :obj:`True`, 260 | always use NMS in CPU mode. If :obj:`False`, 261 | the NMS mode is selected based on the type of inputs. 262 | min_size (int): A paramter to determine the threshold on 263 | discarding bounding boxes based on their sizes. 
264 | 265 | """ 266 | 267 | def __init__(self, 268 | nms_thresh=0.7, 269 | n_train_pre_nms=12000, 270 | n_train_post_nms=2000, 271 | n_test_pre_nms=6000, 272 | n_test_post_nms=300, 273 | min_size=16 274 | ): 275 | super(ProposalCreator,self).__init__() 276 | self.nms_thresh = nms_thresh 277 | self.n_train_pre_nms = n_train_pre_nms 278 | self.n_train_post_nms = n_train_post_nms 279 | self.n_test_pre_nms = n_test_pre_nms 280 | self.n_test_post_nms = n_test_post_nms 281 | self.min_size = min_size 282 | 283 | def execute(self, loc, score,anchor, img_size, scale=1.): 284 | """input should be ndarray 285 | Propose RoIs. 286 | 287 | Inputs :obj:`loc, score, anchor` refer to the same anchor when indexed 288 | by the same index. 289 | 290 | On notations, :math:`R` is the total number of anchors. This is equal 291 | to product of the height and the width of an image and the number of 292 | anchor bases per pixel. 293 | 294 | Type of the output is same as the inputs. 295 | 296 | Args: 297 | loc (array): Predicted offsets and scaling to anchors. 298 | Its shape is :math:`(R, 4)`. 299 | score (array): Predicted foreground probability for anchors. 300 | Its shape is :math:`(R,)`. 301 | anchor (array): Coordinates of anchors. Its shape is 302 | :math:`(R, 4)`. 303 | img_size (tuple of ints): A tuple :obj:`height, width`, 304 | which contains image size after scaling. 305 | scale (float): The scaling factor used to scale an image after 306 | reading it from a file. 307 | 308 | Returns: 309 | array: 310 | An array of coordinates of proposal boxes. 311 | Its shape is :math:`(S, 4)`. :math:`S` is less than 312 | :obj:`self.n_test_post_nms` in test time and less than 313 | :obj:`self.n_train_post_nms` in train time. :math:`S` depends on 314 | the size of the predicted bounding boxes and the number of 315 | bounding boxes discarded by NMS. 316 | 317 | """ 318 | # NOTE: when test, remember 319 | if self.is_training(): 320 | n_pre_nms = self.n_train_pre_nms 321 | n_post_nms = self.n_train_post_nms 322 | else: 323 | n_pre_nms = self.n_test_pre_nms 324 | n_post_nms = self.n_test_post_nms 325 | 326 | # Convert anchors into proposal via bbox transformations. 327 | roi = loc2bbox(anchor, loc) 328 | 329 | # Clip predicted boxes to image. 330 | roi[:,0] = jt.clamp(roi[:,0],min_v=0,max_v=img_size[0]) 331 | roi[:,2] = jt.clamp(roi[:,2],min_v=0,max_v=img_size[0]) 332 | 333 | roi[:,1] = jt.clamp(roi[:,1],min_v=0,max_v=img_size[1]) 334 | roi[:,3] = jt.clamp(roi[:,3],min_v=0,max_v=img_size[1]) 335 | 336 | # Remove predicted boxes with either height or width < threshold. 337 | min_size = self.min_size * scale 338 | hs = roi[:, 2] - roi[:, 0] 339 | ws = roi[:, 3] - roi[:, 1] 340 | keep = jt.where((hs >= min_size) & (ws >= min_size))[0] 341 | roi = roi[keep, :] 342 | score = score[keep] 343 | 344 | # Sort all (proposal, score) pairs by score from highest to lowest. 345 | # Take top pre_nms_topN (e.g. 6000). 346 | order,_ = jt.argsort(score, descending=True) 347 | if n_pre_nms > 0: 348 | order = order[:n_pre_nms] 349 | roi = roi[order, :] 350 | score = score[order] 351 | 352 | # Apply nms (e.g. threshold = 0.7). 353 | # Take after_nms_topN (e.g. 300). 
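        # The score is appended to each box as a fifth column so jt.nms can suppress
        # overlapping proposals in score order; only the first n_post_nms survivors are kept.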
354 | 355 | dets = jt.contrib.concat([roi,score.unsqueeze(1)],dim=1) 356 | keep = jt.nms(dets,self.nms_thresh) 357 | if n_post_nms > 0: 358 | keep = keep[:n_post_nms] 359 | roi = roi[keep] 360 | return roi 361 | 362 | 363 | 364 | class RegionProposalNetwork(nn.Module): 365 | 366 | def __init__(self, 367 | in_channels=512, 368 | mid_channels=512, 369 | ratios=[0.5, 1, 2], 370 | anchor_scales=[8, 16, 32], 371 | feat_stride=16, 372 | nms_thresh=0.7, 373 | n_train_pre_nms=12000, 374 | n_train_post_nms=2000, 375 | n_test_pre_nms=6000, 376 | n_test_post_nms=300, 377 | min_size=16, 378 | ): 379 | super(RegionProposalNetwork, self).__init__() 380 | self.anchor_base = generate_anchor_base( 381 | anchor_scales=anchor_scales, ratios=ratios) 382 | self.feat_stride = feat_stride 383 | self.proposal_layer = ProposalCreator(nms_thresh=nms_thresh, 384 | n_train_pre_nms=n_train_pre_nms, 385 | n_train_post_nms=n_train_post_nms, 386 | n_test_pre_nms=n_test_pre_nms, 387 | n_test_post_nms=n_test_post_nms, 388 | min_size=min_size) 389 | n_anchor = self.anchor_base.shape[0] 390 | self.conv1 = nn.Conv(in_channels, mid_channels, 3, 1, 1) 391 | self.score = nn.Conv(mid_channels, n_anchor * 2, 1, 1, 0) 392 | self.loc = nn.Conv(mid_channels, n_anchor * 4, 1, 1, 0) 393 | self._normal_init() 394 | 395 | 396 | def _normal_init(self): 397 | for var in [self.conv1,self.score,self.loc]: 398 | init.gauss_(var.weight,0,0.01) 399 | init.constant_(var.bias,0.0) 400 | 401 | def execute(self, x, img_size,scale=1.0): 402 | """Forward Region Proposal Network. 403 | 404 | Here are notations. 405 | 406 | * :math:`N` is batch size. 407 | * :math:`C` channel size of the input. 408 | * :math:`H` and :math:`W` are height and witdh of the input feature. 409 | * :math:`A` is number of anchors assigned to each pixel. 410 | 411 | Args: 412 | x : The Features extracted from images. 413 | Its shape is :math:`(N, C, H, W)`. 414 | img_size (tuple of ints): A tuple :obj:`height, width`, 415 | which contains image size after scaling. 416 | 417 | Returns: 418 | This is a tuple of five following values. 419 | 420 | * **rpn_locs**: Predicted bounding box offsets and scales for \ 421 | anchors. Its shape is :math:`(N, H W A, 4)`. 422 | * **rpn_scores**: Predicted foreground scores for \ 423 | anchors. Its shape is :math:`(N, H W A, 2)`. 424 | * **rois**: A bounding box array containing coordinates of \ 425 | proposal boxes. This is a concatenation of bounding box \ 426 | arrays from multiple images in the batch. \ 427 | Its shape is :math:`(R', 4)`. Given :math:`R_i` predicted \ 428 | bounding boxes from the :math:`i` th image, \ 429 | :math:`R' = \\sum _{i=1} ^ N R_i`. 430 | * **roi_indices**: An array containing indices of images to \ 431 | which RoIs correspond to. Its shape is :math:`(R',)`. 432 | * **anchor**: Coordinates of enumerated shifted anchors. \ 433 | Its shape is :math:`(H W A, 4)`. 
434 | 435 | """ 436 | n, _, hh, ww = x.shape 437 | anchor = _enumerate_shifted_anchor(self.anchor_base,self.feat_stride, hh, ww) 438 | anchor = jt.array(anchor) 439 | 440 | n_anchor = anchor.shape[0] // (hh * ww) 441 | h = nn.relu(self.conv1(x)) 442 | 443 | rpn_locs = self.loc(h) 444 | 445 | rpn_locs = rpn_locs.permute(0, 2, 3, 1).view(n, -1, 4) 446 | rpn_scores = self.score(h) 447 | rpn_scores = rpn_scores.permute(0, 2, 3, 1) 448 | rpn_softmax_scores = nn.softmax(rpn_scores.view(n, hh, ww, n_anchor, 2), dim=4) 449 | rpn_fg_scores = rpn_softmax_scores[:, :, :, :, 1] 450 | rpn_fg_scores = rpn_fg_scores.view(n, -1) 451 | rpn_scores = rpn_scores.view(n, -1, 2) 452 | rois = [] 453 | roi_indices = [] 454 | for i in range(n): 455 | roi = self.proposal_layer( 456 | rpn_locs[i], 457 | rpn_fg_scores[i], 458 | anchor, 459 | img_size, 460 | scale) 461 | batch_index = i * jt.ones((len(roi),), dtype='int32') 462 | rois.append(roi) 463 | roi_indices.append(batch_index) 464 | 465 | rois = jt.contrib.concat(rois, dim=0) 466 | roi_indices = jt.contrib.concat(roi_indices, dim=0) 467 | return rpn_locs, rpn_scores, rois, roi_indices, anchor 468 | 469 | 470 | --------------------------------------------------------------------------------
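A minimal inference sketch tying the classes above together (illustrative only: the `model.faster_rcnn` import path, the class count, the input resolution, and the random input are assumptions, and loading trained weights is omitted):

import jittor as jt
from model.faster_rcnn import FasterRCNN   # assumed module path for the FasterRCNN class above

jt.flags.use_cuda = 1            # the ROIAlign op above only provides a CUDA kernel
NUM_CLASSES = 221                # hypothetical class count, including background

model = FasterRCNN(n_class=NUM_CLASSES, backbone_name='resnet50')
model.eval()                     # predict() routes through _forward_test when not training

images = jt.random((1, 3, 800, 800))    # N x C x H x W, assumed already preprocessed
results = model.predict(images, score_thresh=0.7, nms_thresh=0.3)
for boxes, scores, labels in results:   # one (boxes, scores, labels) tuple per image
    print(boxes.shape, scores.shape, labels.shape)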