├── README.md
├── data
│   └── README.md
├── detector.py
├── load_data.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── _utils.py
│   ├── bbox_tools.py
│   ├── image_list.py
│   ├── modules.py
│   ├── roi_heads.py
│   ├── roi_layers
│   │   ├── __init__.py
│   │   ├── _utils.py
│   │   ├── poolers.py
│   │   ├── ps_roi_align.py
│   │   └── ps_roi_pool.py
│   ├── rpn.py
│   └── transform.py
├── thundernet
│   ├── ShufflenetV2.py
│   ├── module.py
│   └── snet.py
├── train.py
├── train.sh
└── utils
    └── losses.py

/README.md:
--------------------------------------------------------------------------------
1 | # Thundernet-pytorch
2 | ThunderNet object detection in PyTorch.
3 | 
4 | 
5 | 
6 | ## Dataset
7 | 
8 | We use the `COCO2017` dataset for training. The dataset is organized as follows:
9 | 
10 | ```
11 | COCO
12 | ├── annotations
13 | │   ├── instances_train2017.json
14 | │   └── instances_val2017.json
15 | └── images
16 |     ├── train2017
17 |     └── val2017
18 | ```
19 | 
20 | 
21 | 
22 | ## Training
23 | 
24 | Run the training script:
25 | 
26 | ```bash
27 | bash train.sh
28 | ```
29 | 
30 | 
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liguiyuan/Thundernet-pytorch/0d62cd55430d5ce55560c1efc43d552b2d0b6671/data/README.md
--------------------------------------------------------------------------------
/detector.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 | from __future__ import absolute_import
4 | 
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 | from torch.jit.annotations import Tuple, List, Dict, Optional
11 | 
12 | from thundernet.snet import SNet49
13 | from thundernet.module import CEM, SAM, RCNNSubNetHead, ThunderNetPredictor
14 | 
15 | from src.roi_layers.ps_roi_align import PSRoIAlign
16 | from src.bbox_tools import generate_anchors
17 | from src.rpn import AnchorGenerator
18 | from src.rpn import RegionProposalNetwork
19 | from src.rpn import RPNHead
20 | from src.roi_layers.poolers import MultiScaleRoIAlign
21 | from src.roi_heads import RoIHeads
22 | from src.transform import GeneralizedRCNNTransform
23 | 
24 | from collections import OrderedDict
25 | import warnings
26 | 
27 | 
28 | class DetectNet(nn.Module):
29 |     """
30 |     if your backbone returns a Tensor, featmap_names is expected to be ['0'].
31 |     More generally, the backbone should return an OrderedDict[Tensor],
32 |     and in featmap_names you can choose which feature maps to use.
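    Arguments (summarized from the constructor below):
        backbone (nn.Module): network used to compute the features; it must expose an
            ``out_channels`` attribute (245 for the SNet49 backbone used by ThunderNet()).
        num_classes (int): number of output classes. It should be None when
            ``box_predictor`` is specified, and set otherwise.
        rpn_*: NMS, IoU-threshold and sampling settings forwarded to RegionProposalNetwork.
        box_*: PS RoI Align, R-CNN subnet head/predictor and post-processing settings
            forwarded to RoIHeads.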
33 | """ 34 | def __init__(self, backbone, num_classes=None, 35 | # RPN parameters 36 | rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=100, 37 | rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, 38 | 39 | rpn_mns_thresh=0.7, 40 | rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, 41 | rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, 42 | 43 | # Box parameters 44 | box_ps_roi_align=None, box_head=None, box_predictor=None, 45 | box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, 46 | box_fg_iou_thresh=0.5,box_bg_iou_thresh=0.5, 47 | box_batch_size_per_image=512, box_positive_fraction=0.25, 48 | bbox_reg_weights=None): 49 | super(DetectNet, self).__init__() 50 | 51 | if not hasattr(backbone, "out_channels"): 52 | raise ValueError( 53 | "backbone should contain an attribute out_channels " 54 | "specifying the number of output channels (assumed to be the " 55 | "same for all the levels)") 56 | 57 | assert isinstance(box_ps_roi_align, (MultiScaleRoIAlign, type(None))) 58 | 59 | if num_classes is not None: 60 | if box_predictor is not None: 61 | raise ValueError("num_classes should be None when box_predictor is specified") 62 | else: 63 | if box_predictor is None: 64 | raise ValueError("num_classes should not be None when box_predictor " 65 | "is not specified") 66 | 67 | out_channels = backbone.out_channels # 245 68 | 69 | self.backbone = backbone 70 | 71 | self.cem = CEM() # CEM module 72 | self.sam = SAM() # SAM module 73 | 74 | # rpn 75 | anchor_sizes = ((32, 64, 128, 256, 512),) # anchor sizes 76 | aspect_ratios = ((0.5, 0.75, 1.0, 1.33, 2.0),) # aspect ratios, paper pyperparameters 77 | rpn_anchor_generator = AnchorGenerator(sizes=anchor_sizes, aspect_ratios=aspect_ratios) 78 | 79 | rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) 80 | 81 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 82 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 83 | 84 | self.rpn = RegionProposalNetwork( 85 | rpn_anchor_generator, rpn_head, 86 | rpn_fg_iou_thresh, rpn_bg_iou_thresh, 87 | rpn_batch_size_per_image, rpn_positive_fraction, 88 | rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_mns_thresh) 89 | 90 | 91 | # ps roi align 92 | if box_ps_roi_align is None: 93 | box_ps_roi_align = MultiScaleRoIAlign( 94 | featmap_names=['0', '1', '2', '3'], 95 | output_size=7, 96 | sampling_ratio=2) 97 | 98 | # R-CNN subnet 99 | if box_head is None: 100 | resolution = box_ps_roi_align.output_size[0] # size: (7, 7) 101 | representation_size = 1024 102 | box_out_channels = 5 103 | box_head = RCNNSubNetHead( 104 | box_out_channels * resolution ** 2, # 5 * 7 * 7 105 | representation_size) 106 | 107 | if box_predictor is None: 108 | representation_size = 1024 109 | box_predictor = ThunderNetPredictor(representation_size, num_classes) 110 | 111 | self.roi_heads = RoIHeads( 112 | box_ps_roi_align, box_head, box_predictor, 113 | box_fg_iou_thresh, box_bg_iou_thresh, 114 | box_batch_size_per_image, box_positive_fraction, 115 | bbox_reg_weights, 116 | box_score_thresh, box_nms_thresh, box_detections_per_img) 117 | 118 | self.transform = GeneralizedRCNNTransform() 119 | 120 | 121 | def forward(self, images, targets=None): 122 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) 123 | """ 124 | Arguments: 125 | images (list[Tensor]): images to be processed 126 | targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) 127 | Returns: 128 | result 
(list[BoxList] or dict[Tensor]): the output from the model. 129 | During training, it returns a dict[Tensor] which contains the losses. 130 | During testing, it returns list[BoxList] contains additional fields 131 | like `scores`, `labels` and `mask` (for Mask R-CNN models). 132 | """ 133 | 134 | original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) 135 | for img in images: 136 | val = img.shape[-2:] 137 | assert len(val) == 2 138 | original_image_sizes.append((val[0], val[1])) # (h, w) 139 | 140 | # backbone 141 | _, c4_feature, c5_feature = self.backbone(images) 142 | images, targets = self.transform(images, targets) # transform to list 143 | 144 | # cem 145 | cem_feature = self.cem(c4_feature, c5_feature) # [20, 20, 245] 146 | cem_feature_output = cem_feature 147 | 148 | if isinstance(cem_feature, torch.Tensor): 149 | cem_feature = OrderedDict([('0', cem_feature)]) 150 | 151 | # rpn 152 | proposals, proposal_losses, rpn_output = self.rpn(images, cem_feature, targets) 153 | 154 | # sam 155 | sam_feature = self.sam(rpn_output, cem_feature_output) 156 | 157 | if isinstance(sam_feature, torch.Tensor): 158 | sam_feature = OrderedDict([('0', sam_feature)]) 159 | 160 | detections, detector_losses = self.roi_heads(sam_feature, proposals, images.image_sizes, targets) 161 | #detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) # testing predict 162 | 163 | return detector_losses, proposal_losses 164 | 165 | def ThunderNet(): 166 | snet = SNet49() 167 | snet.out_channels = 245 168 | thundernet = DetectNet(snet, num_classes=80) 169 | 170 | return thundernet 171 | 172 | 173 | #if __name__ == '__main__': 174 | # thundernet = ThunderNet() 175 | # print('thundernet: ', thundernet) 176 | 177 | -------------------------------------------------------------------------------- /load_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import numpy as np 4 | import os 5 | import cv2 6 | 7 | import torch 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from pycocotools.coco import COCO 11 | 12 | 13 | class CocoDataset(Dataset): 14 | def __init__(self, root_dir, set_name='train2017', transform=None): 15 | super(CocoDataset, self).__init__() 16 | 17 | self.root_dir = root_dir 18 | self.set_name = set_name 19 | self.transform = transform 20 | 21 | self.coco = COCO(os.path.join(self.root_dir, 'annotations', 'instances_' + self.set_name + '.json')) 22 | self.image_ids = self.coco.getImgIds() 23 | 24 | self.load_classes() 25 | 26 | def load_classes(self): 27 | # load class names (name -> label) 28 | categories = self.coco.loadCats(self.coco.getCatIds()) 29 | categories.sort(key=lambda x: x['id']) 30 | 31 | self.classes = {} 32 | self.coco_labels = {} 33 | self.coco_labels_inverse = {} 34 | for c in categories: 35 | self.coco_labels[len(self.classes)] = c['id'] 36 | self.coco_labels_inverse[c['id']] = len(self.classes) 37 | self.classes[c['name']] = len(self.classes) 38 | 39 | # also load the reverse (label -> name) 40 | self.labels = {} 41 | for key, value in self.classes.items(): 42 | self.labels[value] = key 43 | 44 | def __len__(self): 45 | return len(self.image_ids) 46 | 47 | def __getitem__(self, idx): 48 | img = self.load_image(idx) 49 | annot = self.load_annotations(idx) 50 | sample = {'img': img, 'annot': annot} 51 | if self.transform: 52 | sample = self.transform(sample) 53 | 54 | return sample 55 | 56 | def load_image(self, 
image_index): 57 | image_info = self.coco.loadImgs(self.image_ids[image_index])[0] 58 | path = os.path.join(self.root_dir, 'images', self.set_name, image_info['file_name']) 59 | #img = Image.open(path).convert('RGB') 60 | # if len(img.size) == 2: 61 | # img = skimage.color.gray2rgb(img) 62 | img = cv2.imread(path) 63 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 64 | 65 | return img.astype(np.float32) / 255. 66 | 67 | def load_annotations(self, image_index): 68 | # get ground truth annotations 69 | annotations_ids = self.coco.getAnnIds(imgIds=self.image_ids[image_index], iscrowd=False) 70 | annotations = np.zeros((0, 5)) 71 | 72 | # some images appear to miss annotations 73 | if len(annotations_ids) == 0: 74 | return annotations 75 | 76 | # parse annotations 77 | coco_annotations = self.coco.loadAnns(annotations_ids) 78 | for idx, a in enumerate(coco_annotations): 79 | # some annotations have basically no width / height, skip them 80 | if a['bbox'][2] < 1 or a['bbox'][3] < 1: 81 | continue 82 | 83 | annotation = np.zeros((1, 5)) 84 | annotation[0, :4] = a['bbox'] 85 | annotation[0, 4] = self.coco_label_to_label(a['category_id']) 86 | annotations = np.append(annotations, annotation, axis=0) 87 | 88 | # transform from [x, y, w, h] to [x1, y1, x2, y2] 89 | annotations[:, 2] = annotations[:, 0] + annotations[:, 2] 90 | annotations[:, 3] = annotations[:, 1] + annotations[:, 3] 91 | 92 | return annotations 93 | 94 | 95 | def coco_label_to_label(self, coco_label): 96 | return self.coco_labels_inverse[coco_label] 97 | 98 | def label_to_coco_label(self, label): 99 | return self.coco_labels[label] 100 | 101 | def num_classes(self): 102 | return 80 103 | 104 | 105 | def collater(data): 106 | imgs = [s['img'] for s in data] 107 | annots = [s['annot'] for s in data] 108 | scales = [s['scale'] for s in data] 109 | 110 | imgs = torch.from_numpy(np.stack(imgs, axis=0)) 111 | max_num_annots = max(annot.shape[0] for annot in annots) 112 | 113 | if max_num_annots > 0: 114 | annot_padded = torch.ones((len(annots), max_num_annots, 5)) * -1 115 | 116 | if max_num_annots > 0: 117 | for idx, annot in enumerate(annots): 118 | if annot.shape[0] > 0: 119 | annot_padded[idx, :annot.shape[0], :] = annot 120 | else: 121 | annot_padded = torch.ones((len(annots), 1, 5)) * -1 122 | 123 | imgs = imgs.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 124 | 125 | return {'img': imgs, 'annot': annot_padded, 'scale': scales} 126 | 127 | 128 | class Resizer(object): 129 | """Convert ndarrays in sample to Tensors.""" 130 | def __call__(self, sample, common_size=320): 131 | image, annots = sample['img'], sample['annot'] 132 | height, width, _ = image.shape 133 | if height > width: 134 | scale = common_size / height 135 | resized_height = common_size 136 | resized_width = int(width * scale) 137 | else: 138 | scale = common_size / width 139 | resized_height = int(height * scale) 140 | resized_width = common_size 141 | 142 | image = cv2.resize(image, (resized_width, resized_height)) # image resize 143 | 144 | new_image = np.zeros((common_size, common_size, 3)) 145 | new_image[0:resized_height, 0:resized_width] = image 146 | 147 | annots[:, :4] *= scale # resize boxes, [x1, y1, x2, y2] 148 | 149 | return {'img': torch.from_numpy(new_image), 'annot': torch.from_numpy(annots), 'scale': scale} 150 | 151 | class Augmenter(object): 152 | """Convert ndarrays in sample to Tensors.""" 153 | 154 | def __call__(self, sample, flip_x=0.5): 155 | if np.random.rand() < flip_x: 156 | image, annots = sample['img'], sample['annot'] 157 | image = 
image[:, ::-1, :] # flip 158 | 159 | rows, cols, channels = image.shape 160 | 161 | x1 = annots[:, 0].copy() 162 | x2 = annots[:, 2].copy() 163 | 164 | x_tmp = x1.copy() 165 | 166 | annots[:, 0] = cols - x2 167 | annots[:, 2] = cols - x_tmp 168 | 169 | sample = {'img': image, 'annot': annots} 170 | 171 | return sample 172 | 173 | 174 | class Normalizer(object): 175 | 176 | def __init__(self): 177 | self.mean = np.array([[[0.485, 0.456, 0.406]]]) 178 | self.std = np.array([[[0.229, 0.224, 0.225]]]) 179 | 180 | def __call__(self, sample): 181 | image, annots = sample['img'], sample['annot'] 182 | 183 | return {'img': ((image.astype(np.float32) - self.mean) / self.std), 'annot': annots} 184 | 185 | 186 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 1.4.0 2 | torchvision 0.5.0 3 | tqdm 4 | pycocotools -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liguiyuan/Thundernet-pytorch/0d62cd55430d5ce55560c1efc43d552b2d0b6671/src/__init__.py -------------------------------------------------------------------------------- /src/_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.jit.annotations import List, Tuple 5 | from torch import Tensor 6 | import torchvision 7 | 8 | 9 | # TODO: https://github.com/pytorch/pytorch/issues/26727 10 | def zeros_like(tensor, dtype): 11 | # type: (Tensor, int) -> Tensor 12 | return torch.zeros_like(tensor, dtype=dtype, layout=tensor.layout, 13 | device=tensor.device, pin_memory=tensor.is_pinned()) 14 | 15 | 16 | @torch.jit.script 17 | class BalancedPositiveNegativeSampler(object): 18 | """ 19 | This class samples batches, ensuring that they contain a fixed proportion of positives 20 | """ 21 | 22 | def __init__(self, batch_size_per_image, positive_fraction): 23 | # type: (int, float) 24 | """ 25 | Arguments: 26 | batch_size_per_image (int): number of elements to be selected per image 27 | positive_fraction (float): percentace of positive elements per batch 28 | """ 29 | self.batch_size_per_image = batch_size_per_image 30 | self.positive_fraction = positive_fraction 31 | 32 | def __call__(self, matched_idxs): 33 | # type: (List[Tensor]) 34 | """ 35 | Arguments: 36 | matched idxs: list of tensors containing -1, 0 or positive values. 37 | Each tensor corresponds to a specific image. 38 | -1 values are ignored, 0 are considered as negatives and > 0 as 39 | positives. 40 | Returns: 41 | pos_idx (list[tensor]) 42 | neg_idx (list[tensor]) 43 | Returns two lists of binary masks for each image. 44 | The first list contains the positive elements that were selected, 45 | and the second list the negative example. 
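        Example (illustrative, using the RPN defaults from detector.py; matched_idxs is
        a list of per-image tensors of -1/0/positive match values as described above):
            >>> fg_bg_sampler = BalancedPositiveNegativeSampler(batch_size_per_image=256, positive_fraction=0.5)
            >>> pos_idx, neg_idx = fg_bg_sampler(matched_idxs)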
46 | """ 47 | pos_idx = [] 48 | neg_idx = [] 49 | for matched_idxs_per_image in matched_idxs: 50 | positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1) 51 | negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1) 52 | 53 | num_pos = int(self.batch_size_per_image * self.positive_fraction) 54 | # protect against not enough positive examples 55 | num_pos = min(positive.numel(), num_pos) 56 | num_neg = self.batch_size_per_image - num_pos 57 | # protect against not enough negative examples 58 | num_neg = min(negative.numel(), num_neg) 59 | 60 | # randomly select positive and negative examples 61 | perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] 62 | perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] 63 | 64 | pos_idx_per_image = positive[perm1] 65 | neg_idx_per_image = negative[perm2] 66 | 67 | # create binary mask from indices 68 | pos_idx_per_image_mask = zeros_like( 69 | matched_idxs_per_image, dtype=torch.uint8 70 | ) 71 | neg_idx_per_image_mask = zeros_like( 72 | matched_idxs_per_image, dtype=torch.uint8 73 | ) 74 | 75 | pos_idx_per_image_mask[pos_idx_per_image] = torch.tensor(1, dtype=torch.uint8) 76 | neg_idx_per_image_mask[neg_idx_per_image] = torch.tensor(1, dtype=torch.uint8) 77 | 78 | pos_idx.append(pos_idx_per_image_mask) 79 | neg_idx.append(neg_idx_per_image_mask) 80 | 81 | return pos_idx, neg_idx 82 | 83 | 84 | @torch.jit.script 85 | def encode_boxes(reference_boxes, proposals, weights): 86 | # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 87 | """ 88 | Encode a set of proposals with respect to some 89 | reference boxes 90 | Arguments: 91 | reference_boxes (Tensor): reference boxes 92 | proposals (Tensor): boxes to be encoded 93 | """ 94 | 95 | # perform some unpacking to make it JIT-fusion friendly 96 | wx = weights[0] 97 | wy = weights[1] 98 | ww = weights[2] 99 | wh = weights[3] 100 | 101 | proposals_x1 = proposals[:, 0].unsqueeze(1) 102 | proposals_y1 = proposals[:, 1].unsqueeze(1) 103 | proposals_x2 = proposals[:, 2].unsqueeze(1) 104 | proposals_y2 = proposals[:, 3].unsqueeze(1) 105 | 106 | reference_boxes_x1 = reference_boxes[:, 0].unsqueeze(1) 107 | reference_boxes_y1 = reference_boxes[:, 1].unsqueeze(1) 108 | reference_boxes_x2 = reference_boxes[:, 2].unsqueeze(1) 109 | reference_boxes_y2 = reference_boxes[:, 3].unsqueeze(1) 110 | 111 | # implementation starts here 112 | ex_widths = proposals_x2 - proposals_x1 113 | ex_heights = proposals_y2 - proposals_y1 114 | ex_ctr_x = proposals_x1 + 0.5 * ex_widths 115 | ex_ctr_y = proposals_y1 + 0.5 * ex_heights 116 | 117 | gt_widths = reference_boxes_x2 - reference_boxes_x1 118 | gt_heights = reference_boxes_y2 - reference_boxes_y1 119 | gt_ctr_x = reference_boxes_x1 + 0.5 * gt_widths 120 | gt_ctr_y = reference_boxes_y1 + 0.5 * gt_heights 121 | 122 | targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths 123 | targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights 124 | targets_dw = ww * torch.log(gt_widths / ex_widths) 125 | targets_dh = wh * torch.log(gt_heights / ex_heights) 126 | 127 | targets = torch.cat((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) 128 | return targets 129 | 130 | 131 | @torch.jit.script 132 | class BoxCoder(object): 133 | """ 134 | This class encodes and decodes a set of bounding boxes into 135 | the representation used for training the regressors. 136 | """ 137 | 138 | def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)): 139 | # type: (Tuple[float, float, float, float], float) 140 | """ 141 | Arguments: 142 | weights (4-element tuple) 143 | bbox_xform_clip (float) 144 | """ 145 | self.weights = weights 146 | self.bbox_xform_clip = bbox_xform_clip 147 | 148 | def encode(self, reference_boxes, proposals): 149 | # type: (List[Tensor], List[Tensor]) 150 | boxes_per_image = [len(b) for b in reference_boxes] 151 | reference_boxes = torch.cat(reference_boxes, dim=0) 152 | proposals = torch.cat(proposals, dim=0) 153 | targets = self.encode_single(reference_boxes, proposals) 154 | return targets.split(boxes_per_image, 0) 155 | 156 | def encode_single(self, reference_boxes, proposals): 157 | """ 158 | Encode a set of proposals with respect to some 159 | reference boxes 160 | Arguments: 161 | reference_boxes (Tensor): reference boxes 162 | proposals (Tensor): boxes to be encoded 163 | """ 164 | dtype = reference_boxes.dtype 165 | device = reference_boxes.device 166 | weights = torch.as_tensor(self.weights, dtype=dtype, device=device) 167 | targets = encode_boxes(reference_boxes, proposals, weights) 168 | 169 | return targets 170 | 171 | def decode(self, rel_codes, boxes): 172 | # type: (Tensor, List[Tensor]) 173 | assert isinstance(boxes, (list, tuple)) 174 | assert isinstance(rel_codes, torch.Tensor) 175 | boxes_per_image = [b.size(0) for b in boxes] 176 | concat_boxes = torch.cat(boxes, dim=0) 177 | box_sum = 0 178 | for val in boxes_per_image: 179 | box_sum += val 180 | pred_boxes = self.decode_single( 181 | rel_codes.reshape(box_sum, -1), concat_boxes 182 | ) 183 | return pred_boxes.reshape(box_sum, -1, 4) 184 | 185 | def decode_single(self, rel_codes, boxes): 186 | """ 187 | From a set of original boxes and encoded relative box offsets, 188 | get the decoded boxes. 189 | Arguments: 190 | rel_codes (Tensor): encoded boxes 191 | boxes (Tensor): reference boxes. 192 | """ 193 | 194 | boxes = boxes.to(rel_codes.dtype) 195 | 196 | widths = boxes[:, 2] - boxes[:, 0] 197 | heights = boxes[:, 3] - boxes[:, 1] 198 | ctr_x = boxes[:, 0] + 0.5 * widths 199 | ctr_y = boxes[:, 1] + 0.5 * heights 200 | 201 | wx, wy, ww, wh = self.weights 202 | dx = rel_codes[:, 0::4] / wx 203 | dy = rel_codes[:, 1::4] / wy 204 | dw = rel_codes[:, 2::4] / ww 205 | dh = rel_codes[:, 3::4] / wh 206 | 207 | # Prevent sending too large values into torch.exp() 208 | dw = torch.clamp(dw, max=self.bbox_xform_clip) 209 | dh = torch.clamp(dh, max=self.bbox_xform_clip) 210 | 211 | pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] 212 | pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] 213 | pred_w = torch.exp(dw) * widths[:, None] 214 | pred_h = torch.exp(dh) * heights[:, None] 215 | 216 | pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w 217 | pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h 218 | pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w 219 | pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h 220 | pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1) 221 | return pred_boxes 222 | 223 | 224 | @torch.jit.script 225 | class Matcher(object): 226 | """ 227 | This class assigns to each predicted "element" (e.g., a box) a ground-truth 228 | element. Each predicted element will have exactly zero or one matches; each 229 | ground-truth element may be assigned to zero or more predicted elements. 
230 | Matching is based on the MxN match_quality_matrix, that characterizes how well 231 | each (ground-truth, predicted)-pair match. For example, if the elements are 232 | boxes, the matrix may contain box IoU overlap values. 233 | The matcher returns a tensor of size N containing the index of the ground-truth 234 | element m that matches to prediction n. If there is no match, a negative value 235 | is returned. 236 | """ 237 | 238 | BELOW_LOW_THRESHOLD = -1 239 | BETWEEN_THRESHOLDS = -2 240 | 241 | __annotations__ = { 242 | 'BELOW_LOW_THRESHOLD': int, 243 | 'BETWEEN_THRESHOLDS': int, 244 | } 245 | 246 | def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): 247 | # type: (float, float, bool) 248 | """ 249 | Args: 250 | high_threshold (float): quality values greater than or equal to 251 | this value are candidate matches. 252 | low_threshold (float): a lower quality threshold used to stratify 253 | matches into three levels: 254 | 1) matches >= high_threshold 255 | 2) BETWEEN_THRESHOLDS matches in [low_threshold, high_threshold) 256 | 3) BELOW_LOW_THRESHOLD matches in [0, low_threshold) 257 | allow_low_quality_matches (bool): if True, produce additional matches 258 | for predictions that have only low-quality match candidates. See 259 | set_low_quality_matches_ for more details. 260 | """ 261 | self.BELOW_LOW_THRESHOLD = -1 262 | self.BETWEEN_THRESHOLDS = -2 263 | assert low_threshold <= high_threshold 264 | self.high_threshold = high_threshold 265 | self.low_threshold = low_threshold 266 | self.allow_low_quality_matches = allow_low_quality_matches 267 | 268 | def __call__(self, match_quality_matrix): 269 | """ 270 | Args: 271 | match_quality_matrix (Tensor[float]): an MxN tensor, containing the 272 | pairwise quality between M ground-truth elements and N predicted elements. 273 | Returns: 274 | matches (Tensor[int64]): an N tensor where N[i] is a matched gt in 275 | [0, M - 1] or a negative value indicating that prediction i could not 276 | be matched. 277 | """ 278 | if match_quality_matrix.numel() == 0: 279 | # empty targets or proposals not supported during training 280 | if match_quality_matrix.shape[0] == 0: 281 | raise ValueError( 282 | "No ground-truth boxes available for one of the images " 283 | "during training") 284 | else: 285 | raise ValueError( 286 | "No proposal boxes available for one of the images " 287 | "during training") 288 | 289 | # match_quality_matrix is M (gt) x N (predicted) 290 | # Max over gt elements (dim 0) to find best gt candidate for each prediction 291 | matched_vals, matches = match_quality_matrix.max(dim=0) 292 | if self.allow_low_quality_matches: 293 | all_matches = matches.clone() 294 | else: 295 | all_matches = None 296 | 297 | # Assign candidate matches with low quality to negative (unassigned) values 298 | below_low_threshold = matched_vals < self.low_threshold 299 | between_thresholds = (matched_vals >= self.low_threshold) & ( 300 | matched_vals < self.high_threshold 301 | ) 302 | matches[below_low_threshold] = torch.tensor(self.BELOW_LOW_THRESHOLD) 303 | matches[between_thresholds] = torch.tensor(self.BETWEEN_THRESHOLDS) 304 | 305 | if self.allow_low_quality_matches: 306 | assert all_matches is not None 307 | self.set_low_quality_matches_(matches, all_matches, match_quality_matrix) 308 | 309 | return matches 310 | 311 | def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): 312 | """ 313 | Produce additional matches for predictions that have only low-quality matches. 
314 | Specifically, for each ground-truth find the set of predictions that have 315 | maximum overlap with it (including ties); for each prediction in that set, if 316 | it is unmatched, then match it to the ground-truth with which it has the highest 317 | quality value. 318 | """ 319 | # For each gt, find the prediction with which it has highest quality 320 | highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) 321 | # Find highest quality match available, even if it is low, including ties 322 | gt_pred_pairs_of_highest_quality = torch.nonzero( 323 | match_quality_matrix == highest_quality_foreach_gt[:, None] 324 | ) 325 | # Example gt_pred_pairs_of_highest_quality: 326 | # tensor([[ 0, 39796], 327 | # [ 1, 32055], 328 | # [ 1, 32070], 329 | # [ 2, 39190], 330 | # [ 2, 40255], 331 | # [ 3, 40390], 332 | # [ 3, 41455], 333 | # [ 4, 45470], 334 | # [ 5, 45325], 335 | # [ 5, 46390]]) 336 | # Each row is a (gt index, prediction index) 337 | # Note how gt items 1, 2, 3, and 5 each have two ties 338 | 339 | pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1] 340 | matches[pred_inds_to_update] = all_matches[pred_inds_to_update] -------------------------------------------------------------------------------- /src/bbox_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def generate_anchors(base_size=16, ratios=[0.5, 1, 2], 4 | scales=2**np.arange(3, 6)): 5 | """ 6 | Generate anchor (reference) windows by enumerating aspect ratios X 7 | scales wrt a reference (0, 0, 15, 15) window. 8 | """ 9 | base_anchor = np.array([1, 1, base_size, base_size]) - 1 10 | ratio_anchors = _ratio_enum(base_anchor, ratios) 11 | anchors = np.vstack([_scale_enum(ratio_anchors[i, :], scales) 12 | for i in range(ratio_anchors.shape[0])]) 13 | return anchors 14 | 15 | def _whctrs(anchor): 16 | """ 17 | Return width, height, x center, and y center for an anchor (window). 18 | """ 19 | 20 | w = anchor[2] - anchor[0] + 1 21 | h = anchor[3] - anchor[1] + 1 22 | x_ctr = anchor[0] + 0.5 * (w - 1) 23 | y_ctr = anchor[1] + 0.5 * (h - 1) 24 | return w, h, x_ctr, y_ctr 25 | 26 | def _mkanchors(ws, hs, x_ctr, y_ctr): 27 | """ 28 | Given a vector of widths (ws) and heights (hs) around a center 29 | (x_ctr, y_ctr), output a set of anchors (windows). 30 | """ 31 | 32 | ws = ws[:, np.newaxis] 33 | hs = hs[:, np.newaxis] 34 | anchors = np.hstack((x_ctr - 0.5 * (ws - 1), 35 | y_ctr - 0.5 * (hs - 1), 36 | x_ctr + 0.5 * (ws - 1), 37 | y_ctr + 0.5 * (hs - 1))) 38 | return anchors 39 | 40 | def _ratio_enum(anchor, ratios): 41 | """ 42 | Enumerate a set of anchors for each aspect ratio wrt an anchor. 43 | """ 44 | 45 | w, h, x_ctr, y_ctr = _whctrs(anchor) 46 | size = w * h 47 | size_ratios = size / ratios 48 | ws = np.round(np.sqrt(size_ratios)) 49 | hs = np.round(ws * ratios) 50 | anchors = _mkanchors(ws, hs, x_ctr, y_ctr) 51 | return anchors 52 | 53 | def _scale_enum(anchor, scales): 54 | """ 55 | Enumerate a set of anchors for each scale wrt an anchor. 56 | """ 57 | 58 | w, h, x_ctr, y_ctr = _whctrs(anchor) 59 | ws = w * scales 60 | hs = h * scales 61 | anchors = _mkanchors(ws, hs, x_ctr, y_ctr) 62 | return anchors 63 | 64 | #if __name__ == '__main__': 65 | # anchor = generate_anchors() 66 | -------------------------------------------------------------------------------- /src/image_list.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | import torch
3 | from torch.jit.annotations import List, Tuple
4 | from torch import Tensor
5 | 
6 | 
7 | @torch.jit.script
8 | class ImageList(object):
9 |     """
10 |     Structure that holds a list of images (of possibly
11 |     varying sizes) as a single tensor.
12 |     This works by padding the images to the same size,
13 |     and storing in a field the original sizes of each image
14 |     """
15 | 
16 |     def __init__(self, tensors, image_sizes):
17 |         # type: (Tensor, List[Tuple[int, int]])
18 |         """
19 |         Arguments:
20 |             tensors (tensor)
21 |             image_sizes (list[tuple[int, int]])
22 |         """
23 |         self.tensors = tensors
24 |         self.image_sizes = image_sizes
25 | 
26 |     def to(self, device):
27 |         # type: (Device)  # noqa
28 |         cast_tensor = self.tensors.to(device)
29 |         return ImageList(cast_tensor, self.image_sizes)
--------------------------------------------------------------------------------
/src/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | # TODO
7 | class CEM(nn.Module):
8 |     def __init__(self):
9 |         super(CEM, self).__init__()
10 |         self.conv4 = nn.Conv2d(245, 245, kernel_size=1, stride=1, padding=0)  # 1x1 lateral conv on C4 (padding=0 keeps spatial size)
11 | 
12 |         # self.stage4 = stage(4)  # stage() is not defined in this file; left as a TODO
13 |         self.conv5 = nn.Conv2d(245, 245, kernel_size=1, stride=1, padding=0)  # 1x1 lateral conv on C5
14 |         self.conv5_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)  # 10x10 C5 -> 20x20 to match C4
15 | 
16 |         self.avg_pool = nn.AvgPool2d(10)
17 |         self.conv_glb = nn.Conv2d(245, 245, kernel_size=1, stride=1, padding=0)  # global-context branch
18 |         # self.broadcast = 
19 | 
20 |     def forward(self, inputs):
21 |         c4 = inputs[0]    # stage3 output feature map
22 |         c4_lat = self.conv4(c4)
23 | 
24 |         c5 = inputs[1]    # stage4 output feature map
25 |         c5_lat = self.conv5(c5)
26 |         c5_lat = self.conv5_upsample(c5_lat)
27 | 
28 |         c_glb = self.avg_pool(c5)
29 |         c_glb_lat = self.conv_glb(c_glb)  # [N, 245, 1, 1], broadcast-added below
30 | 
31 |         out = c4_lat + c5_lat + c_glb_lat
32 | 
33 |         return out
34 | 
35 | 
36 | 
37 | class SAM(nn.Module):
38 |     def __init__(self):
39 |         super(SAM, self).__init__()
40 |         self.conv1 = nn.Conv2d(245, 245, 1, 1, 0, bias=False)
41 |         self.bn = nn.BatchNorm2d(245)
42 | 
43 |     def forward(self, input):
44 |         cem = input[0]
45 |         rpn = input[1]
46 | 
47 |         sam = self.conv1(rpn)
48 |         sam = self.bn(sam)
49 |         sam = torch.sigmoid(sam)
50 |         out = cem * sam
51 | 
52 |         return out
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/src/roi_heads.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | 
4 | import torch.nn.functional as F
5 | from torch import nn, Tensor
6 | 
7 | from torchvision.ops import boxes as box_ops
8 | from torchvision.ops import misc as misc_nn_ops
9 | 
10 | from torchvision.ops import roi_align
11 | 
12 | #from . import _utils as det_utils
13 | from torchvision.models.detection import _utils as det_utils
14 | 
15 | from torch.jit.annotations import Optional, List, Dict, Tuple
16 | 
17 | 
18 | def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
19 |     # type: (Tensor, Tensor, List[Tensor], List[Tensor])
20 |     """
21 |     Computes the loss for Faster R-CNN.
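    The classification term is a cross-entropy loss over all sampled proposals; the box
    term is a smooth L1 loss computed only for proposals matched to a foreground class,
    normalized by the total number of sampled labels.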
22 | Arguments: 23 | class_logits (Tensor) 24 | box_regression (Tensor) 25 | labels (list[BoxList]) 26 | regression_targets (Tensor) 27 | Returns: 28 | classification_loss (Tensor) 29 | box_loss (Tensor) 30 | """ 31 | 32 | labels = torch.cat(labels, dim=0) 33 | regression_targets = torch.cat(regression_targets, dim=0) 34 | 35 | classification_loss = F.cross_entropy(class_logits, labels) 36 | 37 | # get indices that correspond to the regression targets for 38 | # the corresponding ground truth labels, to be used with 39 | # advanced indexing 40 | sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1) 41 | labels_pos = labels[sampled_pos_inds_subset] 42 | N, num_classes = class_logits.shape 43 | box_regression = box_regression.reshape(N, -1, 4) 44 | 45 | box_loss = F.smooth_l1_loss( 46 | box_regression[sampled_pos_inds_subset, labels_pos], 47 | regression_targets[sampled_pos_inds_subset], 48 | reduction="sum", 49 | ) 50 | box_loss = box_loss / labels.numel() 51 | 52 | return classification_loss, box_loss 53 | 54 | 55 | def maskrcnn_inference(x, labels): 56 | # type: (Tensor, List[Tensor]) 57 | """ 58 | From the results of the CNN, post process the masks 59 | by taking the mask corresponding to the class with max 60 | probability (which are of fixed size and directly output 61 | by the CNN) and return the masks in the mask field of the BoxList. 62 | Arguments: 63 | x (Tensor): the mask logits 64 | labels (list[BoxList]): bounding boxes that are used as 65 | reference, one for ech image 66 | Returns: 67 | results (list[BoxList]): one BoxList for each image, containing 68 | the extra field mask 69 | """ 70 | mask_prob = x.sigmoid() 71 | 72 | # select masks coresponding to the predicted classes 73 | num_masks = x.shape[0] 74 | boxes_per_image = [len(l) for l in labels] 75 | labels = torch.cat(labels) 76 | index = torch.arange(num_masks, device=labels.device) 77 | mask_prob = mask_prob[index, labels][:, None] 78 | 79 | if len(boxes_per_image) == 1: 80 | # TODO : remove when dynamic split supported in ONNX 81 | # and remove assignment to mask_prob_list, just assign to mask_prob 82 | mask_prob_list = [mask_prob] 83 | else: 84 | mask_prob_list = mask_prob.split(boxes_per_image, dim=0) 85 | 86 | return mask_prob_list 87 | 88 | 89 | def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): 90 | # type: (Tensor, Tensor, Tensor, int) 91 | """ 92 | Given segmentation masks and the bounding boxes corresponding 93 | to the location of the masks in the image, this function 94 | crops and resizes the masks in the position defined by the 95 | boxes. This prepares the masks for them to be fed to the 96 | loss computation as the targets. 
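    Returns a Tensor of shape [num_boxes, M, M] with one M x M mask target per box.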
97 | """ 98 | matched_idxs = matched_idxs.to(boxes) 99 | rois = torch.cat([matched_idxs[:, None], boxes], dim=1) 100 | gt_masks = gt_masks[:, None].to(rois) 101 | return roi_align(gt_masks, rois, (M, M), 1.)[:, 0] 102 | 103 | 104 | def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs): 105 | # type: (Tensor, List[Tensor], List[Tensor], List[Tensor], List[Tensor]) 106 | """ 107 | Arguments: 108 | proposals (list[BoxList]) 109 | mask_logits (Tensor) 110 | targets (list[BoxList]) 111 | Return: 112 | mask_loss (Tensor): scalar tensor containing the loss 113 | """ 114 | 115 | discretization_size = mask_logits.shape[-1] 116 | labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)] 117 | mask_targets = [ 118 | project_masks_on_boxes(m, p, i, discretization_size) 119 | for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) 120 | ] 121 | 122 | labels = torch.cat(labels, dim=0) 123 | mask_targets = torch.cat(mask_targets, dim=0) 124 | 125 | # torch.mean (in binary_cross_entropy_with_logits) doesn't 126 | # accept empty tensors, so handle it separately 127 | if mask_targets.numel() == 0: 128 | return mask_logits.sum() * 0 129 | 130 | mask_loss = F.binary_cross_entropy_with_logits( 131 | mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets 132 | ) 133 | return mask_loss 134 | 135 | 136 | def keypoints_to_heatmap(keypoints, rois, heatmap_size): 137 | # type: (Tensor, Tensor, int) 138 | offset_x = rois[:, 0] 139 | offset_y = rois[:, 1] 140 | scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) 141 | scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) 142 | 143 | offset_x = offset_x[:, None] 144 | offset_y = offset_y[:, None] 145 | scale_x = scale_x[:, None] 146 | scale_y = scale_y[:, None] 147 | 148 | x = keypoints[..., 0] 149 | y = keypoints[..., 1] 150 | 151 | x_boundary_inds = x == rois[:, 2][:, None] 152 | y_boundary_inds = y == rois[:, 3][:, None] 153 | 154 | x = (x - offset_x) * scale_x 155 | x = x.floor().long() 156 | y = (y - offset_y) * scale_y 157 | y = y.floor().long() 158 | 159 | x[x_boundary_inds] = torch.tensor(heatmap_size - 1) 160 | y[y_boundary_inds] = torch.tensor(heatmap_size - 1) 161 | 162 | valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) 163 | vis = keypoints[..., 2] > 0 164 | valid = (valid_loc & vis).long() 165 | 166 | lin_ind = y * heatmap_size + x 167 | heatmaps = lin_ind * valid 168 | 169 | return heatmaps, valid 170 | 171 | 172 | def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, 173 | widths_i, heights_i, offset_x_i, offset_y_i): 174 | num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64) 175 | 176 | width_correction = widths_i / roi_map_width 177 | height_correction = heights_i / roi_map_height 178 | 179 | roi_map = torch.nn.functional.interpolate( 180 | maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0] 181 | 182 | w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) 183 | pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) 184 | 185 | x_int = (pos % w) 186 | y_int = ((pos - x_int) / w) 187 | 188 | x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * \ 189 | width_correction.to(dtype=torch.float32) 190 | y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * \ 191 | height_correction.to(dtype=torch.float32) 192 | 193 | xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32) 194 | xy_preds_i_1 = y + 
offset_y_i.to(dtype=torch.float32) 195 | xy_preds_i_2 = torch.ones((xy_preds_i_1.shape), dtype=torch.float32) 196 | xy_preds_i = torch.stack([xy_preds_i_0.to(dtype=torch.float32), 197 | xy_preds_i_1.to(dtype=torch.float32), 198 | xy_preds_i_2.to(dtype=torch.float32)], 0) 199 | 200 | # TODO: simplify when indexing without rank will be supported by ONNX 201 | end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \ 202 | .index_select(2, x_int.to(dtype=torch.int64))[:num_keypoints, 0, 0] 203 | return xy_preds_i, end_scores_i 204 | 205 | 206 | @torch.jit.script 207 | def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil, 208 | widths, heights, offset_x, offset_y, num_keypoints): 209 | xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device) 210 | end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device) 211 | 212 | for i in range(int(rois.size(0))): 213 | xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(maps, maps[i], 214 | widths_ceil[i], heights_ceil[i], 215 | widths[i], heights[i], 216 | offset_x[i], offset_y[i]) 217 | xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), 218 | xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) 219 | end_scores = torch.cat((end_scores.to(dtype=torch.float32), 220 | end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0) 221 | return xy_preds, end_scores 222 | 223 | 224 | def heatmaps_to_keypoints(maps, rois): 225 | """Extract predicted keypoint locations from heatmaps. Output has shape 226 | (#rois, 4, #keypoints) with the 4 rows corresponding to (x, y, logit, prob) 227 | for each keypoint. 228 | """ 229 | # This function converts a discrete image coordinate in a HEATMAP_SIZE x 230 | # HEATMAP_SIZE image to a continuous keypoint coordinate. We maintain 231 | # consistency with keypoints_to_heatmap_labels by using the conversion from 232 | # Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a 233 | # continuous coordinate. 
234 | offset_x = rois[:, 0] 235 | offset_y = rois[:, 1] 236 | 237 | widths = rois[:, 2] - rois[:, 0] 238 | heights = rois[:, 3] - rois[:, 1] 239 | widths = widths.clamp(min=1) 240 | heights = heights.clamp(min=1) 241 | widths_ceil = widths.ceil() 242 | heights_ceil = heights.ceil() 243 | 244 | num_keypoints = maps.shape[1] 245 | 246 | if torchvision._is_tracing(): 247 | xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(maps, rois, 248 | widths_ceil, heights_ceil, widths, heights, 249 | offset_x, offset_y, 250 | torch.scalar_tensor(num_keypoints, dtype=torch.int64)) 251 | return xy_preds.permute(0, 2, 1), end_scores 252 | 253 | xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) 254 | end_scores = torch.zeros((len(rois), num_keypoints), dtype=torch.float32, device=maps.device) 255 | for i in range(len(rois)): 256 | roi_map_width = int(widths_ceil[i].item()) 257 | roi_map_height = int(heights_ceil[i].item()) 258 | width_correction = widths[i] / roi_map_width 259 | height_correction = heights[i] / roi_map_height 260 | roi_map = torch.nn.functional.interpolate( 261 | maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0] 262 | # roi_map_probs = scores_to_probs(roi_map.copy()) 263 | w = roi_map.shape[2] 264 | pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) 265 | 266 | x_int = pos % w 267 | y_int = (pos - x_int) // w 268 | # assert (roi_map_probs[k, y_int, x_int] == 269 | # roi_map_probs[k, :, :].max()) 270 | x = (x_int.float() + 0.5) * width_correction 271 | y = (y_int.float() + 0.5) * height_correction 272 | xy_preds[i, 0, :] = x + offset_x[i] 273 | xy_preds[i, 1, :] = y + offset_y[i] 274 | xy_preds[i, 2, :] = 1 275 | end_scores[i, :] = roi_map[torch.arange(num_keypoints), y_int, x_int] 276 | 277 | return xy_preds.permute(0, 2, 1), end_scores 278 | 279 | 280 | def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs): 281 | # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) 282 | N, K, H, W = keypoint_logits.shape 283 | assert H == W 284 | discretization_size = H 285 | heatmaps = [] 286 | valid = [] 287 | for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): 288 | kp = gt_kp_in_image[midx] 289 | heatmaps_per_image, valid_per_image = keypoints_to_heatmap( 290 | kp, proposals_per_image, discretization_size 291 | ) 292 | heatmaps.append(heatmaps_per_image.view(-1)) 293 | valid.append(valid_per_image.view(-1)) 294 | 295 | keypoint_targets = torch.cat(heatmaps, dim=0) 296 | valid = torch.cat(valid, dim=0).to(dtype=torch.uint8) 297 | valid = torch.nonzero(valid).squeeze(1) 298 | 299 | # torch.mean (in binary_cross_entropy_with_logits) does'nt 300 | # accept empty tensors, so handle it sepaartely 301 | if keypoint_targets.numel() == 0 or len(valid) == 0: 302 | return keypoint_logits.sum() * 0 303 | 304 | keypoint_logits = keypoint_logits.view(N * K, H * W) 305 | 306 | keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid]) 307 | return keypoint_loss 308 | 309 | 310 | def keypointrcnn_inference(x, boxes): 311 | # type: (Tensor, List[Tensor]) 312 | kp_probs = [] 313 | kp_scores = [] 314 | 315 | boxes_per_image = [box.size(0) for box in boxes] 316 | 317 | if len(boxes_per_image) == 1: 318 | # TODO : remove when dynamic split supported in ONNX 319 | kp_prob, scores = heatmaps_to_keypoints(x, boxes[0]) 320 | return [kp_prob], [scores] 321 | 322 | x2 = x.split(boxes_per_image, dim=0) 323 | 324 | for xx, bb 
in zip(x2, boxes): 325 | kp_prob, scores = heatmaps_to_keypoints(xx, bb) 326 | kp_probs.append(kp_prob) 327 | kp_scores.append(scores) 328 | 329 | return kp_probs, kp_scores 330 | 331 | 332 | def _onnx_expand_boxes(boxes, scale): 333 | # type: (Tensor, float) 334 | w_half = (boxes[:, 2] - boxes[:, 0]) * .5 335 | h_half = (boxes[:, 3] - boxes[:, 1]) * .5 336 | x_c = (boxes[:, 2] + boxes[:, 0]) * .5 337 | y_c = (boxes[:, 3] + boxes[:, 1]) * .5 338 | 339 | w_half = w_half.to(dtype=torch.float32) * scale 340 | h_half = h_half.to(dtype=torch.float32) * scale 341 | 342 | boxes_exp0 = x_c - w_half 343 | boxes_exp1 = y_c - h_half 344 | boxes_exp2 = x_c + w_half 345 | boxes_exp3 = y_c + h_half 346 | boxes_exp = torch.stack((boxes_exp0, boxes_exp1, boxes_exp2, boxes_exp3), 1) 347 | return boxes_exp 348 | 349 | 350 | # the next two functions should be merged inside Masker 351 | # but are kept here for the moment while we need them 352 | # temporarily for paste_mask_in_image 353 | def expand_boxes(boxes, scale): 354 | # type: (Tensor, float) 355 | if torchvision._is_tracing(): 356 | return _onnx_expand_boxes(boxes, scale) 357 | w_half = (boxes[:, 2] - boxes[:, 0]) * .5 358 | h_half = (boxes[:, 3] - boxes[:, 1]) * .5 359 | x_c = (boxes[:, 2] + boxes[:, 0]) * .5 360 | y_c = (boxes[:, 3] + boxes[:, 1]) * .5 361 | 362 | w_half *= scale 363 | h_half *= scale 364 | 365 | boxes_exp = torch.zeros_like(boxes) 366 | boxes_exp[:, 0] = x_c - w_half 367 | boxes_exp[:, 2] = x_c + w_half 368 | boxes_exp[:, 1] = y_c - h_half 369 | boxes_exp[:, 3] = y_c + h_half 370 | return boxes_exp 371 | 372 | 373 | @torch.jit.unused 374 | def expand_masks_tracing_scale(M, padding): 375 | # type: (int, int) -> float 376 | return torch.tensor(M + 2 * padding).to(torch.float32) / torch.tensor(M).to(torch.float32) 377 | 378 | 379 | def expand_masks(mask, padding): 380 | # type: (Tensor, int) 381 | M = mask.shape[-1] 382 | if torch._C._get_tracing_state(): # could not import is_tracing(), not sure why 383 | scale = expand_masks_tracing_scale(M, padding) 384 | else: 385 | scale = float(M + 2 * padding) / M 386 | padded_mask = torch.nn.functional.pad(mask, (padding,) * 4) 387 | return padded_mask, scale 388 | 389 | 390 | def paste_mask_in_image(mask, box, im_h, im_w): 391 | # type: (Tensor, Tensor, int, int) 392 | TO_REMOVE = 1 393 | w = int(box[2] - box[0] + TO_REMOVE) 394 | h = int(box[3] - box[1] + TO_REMOVE) 395 | w = max(w, 1) 396 | h = max(h, 1) 397 | 398 | # Set shape to [batchxCxHxW] 399 | mask = mask.expand((1, 1, -1, -1)) 400 | 401 | # Resize mask 402 | mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) 403 | mask = mask[0][0] 404 | 405 | im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) 406 | x_0 = max(box[0], 0) 407 | x_1 = min(box[2] + 1, im_w) 408 | y_0 = max(box[1], 0) 409 | y_1 = min(box[3] + 1, im_h) 410 | 411 | im_mask[y_0:y_1, x_0:x_1] = mask[ 412 | (y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0]) 413 | ] 414 | return im_mask 415 | 416 | 417 | def _onnx_paste_mask_in_image(mask, box, im_h, im_w): 418 | one = torch.ones(1, dtype=torch.int64) 419 | zero = torch.zeros(1, dtype=torch.int64) 420 | 421 | w = (box[2] - box[0] + one) 422 | h = (box[3] - box[1] + one) 423 | w = torch.max(torch.cat((w, one))) 424 | h = torch.max(torch.cat((h, one))) 425 | 426 | # Set shape to [batchxCxHxW] 427 | mask = mask.expand((1, 1, mask.size(0), mask.size(1))) 428 | 429 | # Resize mask 430 | mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), 
mode='bilinear', align_corners=False) 431 | mask = mask[0][0] 432 | 433 | x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) 434 | x_1 = torch.min(torch.cat((box[2].unsqueeze(0) + one, im_w.unsqueeze(0)))) 435 | y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero))) 436 | y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0)))) 437 | 438 | unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]), 439 | (x_0 - box[0]):(x_1 - box[0])] 440 | 441 | # TODO : replace below with a dynamic padding when support is added in ONNX 442 | 443 | # pad y 444 | zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1)) 445 | zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1)) 446 | concat_0 = torch.cat((zeros_y0, 447 | unpaded_im_mask.to(dtype=torch.float32), 448 | zeros_y1), 0)[0:im_h, :] 449 | # pad x 450 | zeros_x0 = torch.zeros(concat_0.size(0), x_0) 451 | zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1) 452 | im_mask = torch.cat((zeros_x0, 453 | concat_0, 454 | zeros_x1), 1)[:, :im_w] 455 | return im_mask 456 | 457 | 458 | @torch.jit.script 459 | def _onnx_paste_masks_in_image_loop(masks, boxes, im_h, im_w): 460 | res_append = torch.zeros(0, im_h, im_w) 461 | for i in range(masks.size(0)): 462 | mask_res = _onnx_paste_mask_in_image(masks[i][0], boxes[i], im_h, im_w) 463 | mask_res = mask_res.unsqueeze(0) 464 | res_append = torch.cat((res_append, mask_res)) 465 | return res_append 466 | 467 | 468 | def paste_masks_in_image(masks, boxes, img_shape, padding=1): 469 | # type: (Tensor, Tensor, Tuple[int, int], int) 470 | masks, scale = expand_masks(masks, padding=padding) 471 | boxes = expand_boxes(boxes, scale).to(dtype=torch.int64) 472 | im_h, im_w = img_shape 473 | 474 | if torchvision._is_tracing(): 475 | return _onnx_paste_masks_in_image_loop(masks, boxes, 476 | torch.scalar_tensor(im_h, dtype=torch.int64), 477 | torch.scalar_tensor(im_w, dtype=torch.int64))[:, None] 478 | res = [ 479 | paste_mask_in_image(m[0], b, im_h, im_w) 480 | for m, b in zip(masks, boxes) 481 | ] 482 | if len(res) > 0: 483 | ret = torch.stack(res, dim=0)[:, None] 484 | else: 485 | ret = masks.new_empty((0, 1, im_h, im_w)) 486 | return ret 487 | 488 | 489 | class RoIHeads(torch.nn.Module): 490 | __annotations__ = { 491 | 'box_coder': det_utils.BoxCoder, 492 | 'proposal_matcher': det_utils.Matcher, 493 | 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, 494 | } 495 | 496 | def __init__(self, 497 | box_roi_pool, 498 | box_head, 499 | box_predictor, 500 | # Faster R-CNN training 501 | fg_iou_thresh, bg_iou_thresh, 502 | batch_size_per_image, positive_fraction, 503 | bbox_reg_weights, 504 | # Faster R-CNN inference 505 | score_thresh, 506 | nms_thresh, 507 | detections_per_img, 508 | # Mask 509 | mask_roi_pool=None, 510 | mask_head=None, 511 | mask_predictor=None, 512 | keypoint_roi_pool=None, 513 | keypoint_head=None, 514 | keypoint_predictor=None, 515 | ): 516 | super(RoIHeads, self).__init__() 517 | 518 | self.box_similarity = box_ops.box_iou 519 | # assign ground-truth boxes for each proposal 520 | self.proposal_matcher = det_utils.Matcher( 521 | fg_iou_thresh, 522 | bg_iou_thresh, 523 | allow_low_quality_matches=False) 524 | 525 | self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( 526 | batch_size_per_image, 527 | positive_fraction) 528 | 529 | if bbox_reg_weights is None: 530 | bbox_reg_weights = (10., 10., 5., 5.) 
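        # (wx, wy, ww, wh) weights passed to det_utils.BoxCoder: they scale the
        # (dx, dy, dw, dh) regression targets during encoding and divide the
        # predicted offsets during decoding.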
531 | self.box_coder = det_utils.BoxCoder(bbox_reg_weights) 532 | 533 | self.box_roi_pool = box_roi_pool 534 | self.box_head = box_head 535 | self.box_predictor = box_predictor 536 | 537 | self.score_thresh = score_thresh 538 | self.nms_thresh = nms_thresh 539 | self.detections_per_img = detections_per_img 540 | 541 | self.mask_roi_pool = mask_roi_pool 542 | self.mask_head = mask_head 543 | self.mask_predictor = mask_predictor 544 | 545 | self.keypoint_roi_pool = keypoint_roi_pool 546 | self.keypoint_head = keypoint_head 547 | self.keypoint_predictor = keypoint_predictor 548 | 549 | def has_mask(self): 550 | if self.mask_roi_pool is None: 551 | return False 552 | if self.mask_head is None: 553 | return False 554 | if self.mask_predictor is None: 555 | return False 556 | return True 557 | 558 | def has_keypoint(self): 559 | if self.keypoint_roi_pool is None: 560 | return False 561 | if self.keypoint_head is None: 562 | return False 563 | if self.keypoint_predictor is None: 564 | return False 565 | return True 566 | 567 | def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): 568 | # type: (List[Tensor], List[Tensor], List[Tensor]) 569 | matched_idxs = [] 570 | labels = [] 571 | for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels): 572 | 573 | if gt_boxes_in_image.numel() == 0: 574 | # Background image 575 | device = proposals_in_image.device 576 | clamped_matched_idxs_in_image = torch.zeros( 577 | (proposals_in_image.shape[0],), dtype=torch.int64, device=device 578 | ) 579 | labels_in_image = torch.zeros( 580 | (proposals_in_image.shape[0],), dtype=torch.int64, device=device 581 | ) 582 | else: 583 | # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands 584 | match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image) 585 | matched_idxs_in_image = self.proposal_matcher(match_quality_matrix) 586 | 587 | clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0) 588 | 589 | labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image] 590 | labels_in_image = labels_in_image.to(dtype=torch.int64) 591 | 592 | # Label background (below the low threshold) 593 | bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD 594 | labels_in_image[bg_inds] = torch.tensor(0) 595 | 596 | # Label ignore proposals (between low and high thresholds) 597 | ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS 598 | labels_in_image[ignore_inds] = torch.tensor(-1) # -1 is ignored by sampler 599 | 600 | matched_idxs.append(clamped_matched_idxs_in_image) 601 | labels.append(labels_in_image) 602 | return matched_idxs, labels 603 | 604 | def subsample(self, labels): 605 | # type: (List[Tensor]) 606 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 607 | sampled_inds = [] 608 | for img_idx, (pos_inds_img, neg_inds_img) in enumerate( 609 | zip(sampled_pos_inds, sampled_neg_inds) 610 | ): 611 | img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1) 612 | sampled_inds.append(img_sampled_inds) 613 | return sampled_inds 614 | 615 | def add_gt_proposals(self, proposals, gt_boxes): 616 | # type: (List[Tensor], List[Tensor]) 617 | proposals = [ 618 | torch.cat((proposal, gt_box)) 619 | for proposal, gt_box in zip(proposals, gt_boxes) 620 | ] 621 | 622 | return proposals 623 | 624 | def DELTEME_all(self, the_list): 625 | # type: (List[bool]) 626 | for i in the_list: 627 | if not i: 628 | return False 629 | return True 630 
| 631 | def check_targets(self, targets): 632 | # type: (Optional[List[Dict[str, Tensor]]]) 633 | assert targets is not None 634 | assert self.DELTEME_all(["boxes" in t for t in targets]) 635 | assert self.DELTEME_all(["labels" in t for t in targets]) 636 | if self.has_mask(): 637 | assert self.DELTEME_all(["masks" in t for t in targets]) 638 | 639 | def select_training_samples(self, proposals, targets): 640 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) 641 | self.check_targets(targets) 642 | assert targets is not None 643 | dtype = proposals[0].dtype 644 | device = proposals[0].device 645 | 646 | gt_boxes = [t["boxes"].to(dtype) for t in targets] 647 | gt_labels = [t["labels"] for t in targets] 648 | 649 | # append ground-truth bboxes to propos 650 | proposals = self.add_gt_proposals(proposals, gt_boxes) 651 | 652 | # get matching gt indices for each proposal 653 | matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) 654 | # sample a fixed proportion of positive-negative proposals 655 | sampled_inds = self.subsample(labels) 656 | matched_gt_boxes = [] 657 | num_images = len(proposals) 658 | for img_id in range(num_images): 659 | img_sampled_inds = sampled_inds[img_id] 660 | proposals[img_id] = proposals[img_id][img_sampled_inds] 661 | labels[img_id] = labels[img_id][img_sampled_inds].cuda() 662 | matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds] 663 | 664 | gt_boxes_in_image = gt_boxes[img_id] 665 | if gt_boxes_in_image.numel() == 0: 666 | gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device) 667 | matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]]) 668 | 669 | regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) 670 | return proposals, matched_idxs, labels, regression_targets 671 | 672 | def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes): 673 | # type: (Tensor, Tensor, List[Tensor], List[Tuple[int, int]]) 674 | device = class_logits.device 675 | num_classes = class_logits.shape[-1] 676 | 677 | boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals] 678 | pred_boxes = self.box_coder.decode(box_regression, proposals) 679 | 680 | pred_scores = F.softmax(class_logits, -1) 681 | 682 | pred_boxes_list = pred_boxes.split(boxes_per_image, 0) 683 | pred_scores_list = pred_scores.split(boxes_per_image, 0) 684 | 685 | all_boxes = [] 686 | all_scores = [] 687 | all_labels = [] 688 | for boxes, scores, image_shape in zip(pred_boxes_list, pred_scores_list, image_shapes): 689 | boxes = box_ops.clip_boxes_to_image(boxes, image_shape) 690 | 691 | # create labels for each prediction 692 | labels = torch.arange(num_classes, device=device) 693 | labels = labels.view(1, -1).expand_as(scores) 694 | 695 | # remove predictions with the background label 696 | boxes = boxes[:, 1:] 697 | scores = scores[:, 1:] 698 | labels = labels[:, 1:] 699 | 700 | # batch everything, by making every class prediction be a separate instance 701 | boxes = boxes.reshape(-1, 4) 702 | scores = scores.reshape(-1) 703 | labels = labels.reshape(-1) 704 | 705 | # remove low scoring boxes 706 | inds = torch.nonzero(scores > self.score_thresh).squeeze(1) 707 | boxes, scores, labels = boxes[inds], scores[inds], labels[inds] 708 | 709 | # remove empty boxes 710 | keep = box_ops.remove_small_boxes(boxes, min_size=1e-2) 711 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 712 | 713 | # non-maximum suppression, independently done per class 714 | keep = 
box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) 715 | # keep only topk scoring predictions 716 | keep = keep[:self.detections_per_img] 717 | boxes, scores, labels = boxes[keep], scores[keep], labels[keep] 718 | 719 | all_boxes.append(boxes) 720 | all_scores.append(scores) 721 | all_labels.append(labels) 722 | 723 | return all_boxes, all_scores, all_labels 724 | 725 | def forward(self, features, proposals, image_shapes, targets=None): 726 | # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]], Optional[List[Dict[str, Tensor]]]) 727 | """ 728 | Arguments: 729 | features (List[Tensor]) 730 | proposals (List[Tensor[N, 4]]) 731 | image_shapes (List[Tuple[H, W]]) 732 | targets (List[Dict]) 733 | """ 734 | 735 | #print('targets type:', type(targets)) 736 | if targets is not None: 737 | for t in targets: 738 | if t["labels"].dtype != torch.int64: 739 | t["labels"] = t["labels"].type(torch.LongTensor) 740 | 741 | # TODO: https://github.com/pytorch/pytorch/issues/26731 742 | floating_point_types = (torch.float, torch.double, torch.half) 743 | assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' 744 | assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' 745 | 746 | if self.has_keypoint(): 747 | assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' 748 | 749 | #if self.training: 750 | proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) 751 | 752 | box_features = self.box_roi_pool(features, proposals, image_shapes) 753 | box_features = self.box_head(box_features) 754 | class_logits, box_regression = self.box_predictor(box_features) 755 | 756 | result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) 757 | losses = {} 758 | #if self.training: 759 | assert labels is not None and regression_targets is not None 760 | loss_classifier, loss_box_reg = fastrcnn_loss( 761 | class_logits, box_regression, labels, regression_targets) 762 | losses = { 763 | "loss_classifier": loss_classifier, 764 | "loss_box_reg": loss_box_reg 765 | } 766 | 767 | if self.has_mask(): 768 | mask_proposals = [p["boxes"] for p in result] 769 | #if self.training: 770 | assert matched_idxs is not None 771 | # during training, only focus on positive boxes 772 | num_images = len(proposals) 773 | mask_proposals = [] 774 | pos_matched_idxs = [] 775 | for img_id in range(num_images): 776 | pos = torch.nonzero(labels[img_id] > 0).squeeze(1) 777 | mask_proposals.append(proposals[img_id][pos]) 778 | pos_matched_idxs.append(matched_idxs[img_id][pos]) 779 | 780 | if self.mask_roi_pool is not None: 781 | mask_features = self.mask_roi_pool(features, mask_proposals, image_shapes) 782 | mask_features = self.mask_head(mask_features) 783 | mask_logits = self.mask_predictor(mask_features) 784 | else: 785 | mask_logits = torch.tensor(0) 786 | raise Exception("Expected mask_roi_pool to be not None") 787 | 788 | loss_mask = {} 789 | #if self.training: 790 | assert targets is not None 791 | assert pos_matched_idxs is not None 792 | assert mask_logits is not None 793 | 794 | gt_masks = [t["masks"] for t in targets] 795 | gt_labels = [t["labels"] for t in targets] 796 | rcnn_loss_mask = maskrcnn_loss( 797 | mask_logits, mask_proposals, 798 | gt_masks, gt_labels, pos_matched_idxs) 799 | loss_mask = { 800 | "loss_mask": rcnn_loss_mask 801 | } 802 | 803 | losses.update(loss_mask) 804 | 805 | # keep none checks in if conditional so torchscript will conditionally 806 | # compile each branch 807 | if 
self.keypoint_roi_pool is not None and self.keypoint_head is not None \ 808 | and self.keypoint_predictor is not None: 809 | keypoint_proposals = [p["boxes"] for p in result] 810 | 811 | #if self.training: 812 | # during training, only focus on positive boxes 813 | num_images = len(proposals) 814 | keypoint_proposals = [] 815 | pos_matched_idxs = [] 816 | assert matched_idxs is not None 817 | for img_id in range(num_images): 818 | pos = torch.nonzero(labels[img_id] > 0).squeeze(1) 819 | keypoint_proposals.append(proposals[img_id][pos]) 820 | pos_matched_idxs.append(matched_idxs[img_id][pos]) 821 | 822 | keypoint_features = self.keypoint_roi_pool(features, keypoint_proposals, image_shapes) 823 | keypoint_features = self.keypoint_head(keypoint_features) 824 | keypoint_logits = self.keypoint_predictor(keypoint_features) 825 | 826 | loss_keypoint = {} 827 | #if self.training: 828 | assert targets is not None 829 | assert pos_matched_idxs is not None 830 | 831 | gt_keypoints = [t["keypoints"] for t in targets] 832 | rcnn_loss_keypoint = keypointrcnn_loss( 833 | keypoint_logits, keypoint_proposals, 834 | gt_keypoints, pos_matched_idxs) 835 | loss_keypoint = { 836 | "loss_keypoint": rcnn_loss_keypoint 837 | } 838 | 839 | losses.update(loss_keypoint) 840 | 841 | return result, losses -------------------------------------------------------------------------------- /src/roi_layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liguiyuan/Thundernet-pytorch/0d62cd55430d5ce55560c1efc43d552b2d0b6671/src/roi_layers/__init__.py -------------------------------------------------------------------------------- /src/roi_layers/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.jit.annotations import List 4 | 5 | 6 | def _cat(tensors, dim=0): 7 | # type: (List[Tensor], int) -> Tensor 8 | """ 9 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 10 | """ 11 | # TODO add back the assert 12 | # assert isinstance(tensors, (list, tuple)) 13 | if len(tensors) == 1: 14 | return tensors[0] 15 | return torch.cat(tensors, dim) 16 | 17 | 18 | def convert_boxes_to_roi_format(boxes): 19 | # type: (List[Tensor]) -> Tensor 20 | concat_boxes = _cat([b for b in boxes], dim=0) 21 | temp = [] 22 | for i, b in enumerate(boxes): 23 | temp.append(torch.full_like(b[:, :1], i)) 24 | ids = _cat(temp, dim=0) 25 | rois = torch.cat([ids, concat_boxes], dim=1) 26 | return rois 27 | 28 | 29 | def check_roi_boxes_shape(boxes): 30 | if isinstance(boxes, list): 31 | for _tensor in boxes: 32 | assert _tensor.size(1) == 4, \ 33 | 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' 34 | elif isinstance(boxes, torch.Tensor): 35 | assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' 36 | else: 37 | assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]' 38 | return -------------------------------------------------------------------------------- /src/roi_layers/poolers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
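# A rough usage sketch (comments only) for the helpers in src/roi_layers/_utils.py above:
# convert_boxes_to_roi_format() turns per-image List[Tensor[L, 4]] boxes into the single
# Tensor[K, 5] "rois" layout consumed by the ps_roi_* ops in this package, prepending the
# batch index as column 0. The tensor values below are illustrative, not from the repo:
#
#     boxes = [torch.rand(3, 4) * 100, torch.rand(2, 4) * 100]   # two images: 3 and 2 boxes
#     rois  = convert_boxes_to_roi_format(boxes)                 # -> shape [5, 5]
#     rois[:, 0]                                                 # -> tensor([0., 0., 0., 1., 1.])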
2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, Tensor 5 | 6 | from .ps_roi_align import ps_roi_align, PSRoIAlign 7 | from torchvision.ops.boxes import box_area 8 | 9 | from torch.jit.annotations import Optional, List, Dict, Tuple 10 | import torchvision 11 | 12 | # copying result_idx_in_level to a specific index in result[] 13 | # is not supported by ONNX tracing yet. 14 | # _onnx_merge_levels() is an implementation supported by ONNX 15 | # that merges the levels to the right indices 16 | @torch.jit.unused 17 | def _onnx_merge_levels(levels, unmerged_results): 18 | # type: (Tensor, List[Tensor]) -> Tensor 19 | first_result = unmerged_results[0] 20 | dtype, device = first_result.dtype, first_result.device 21 | res = torch.zeros((levels.size(0), first_result.size(1), 22 | first_result.size(2), first_result.size(3)), 23 | dtype=dtype, device=device) 24 | for l in range(len(unmerged_results)): 25 | index = (levels == l).nonzero().view(-1, 1, 1, 1) 26 | index = index.expand(index.size(0), 27 | unmerged_results[l].size(1), 28 | unmerged_results[l].size(2), 29 | unmerged_results[l].size(3)) 30 | res = res.scatter(0, index, unmerged_results[l]) 31 | return res 32 | 33 | 34 | # TODO: (eellison) T54974082 https://github.com/pytorch/pytorch/issues/26744/pytorch/issues/26744 35 | def initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): 36 | # type: (int, int, int, int, float) 37 | return LevelMapper(k_min, k_max, canonical_scale, canonical_level, eps) 38 | 39 | 40 | @torch.jit.script 41 | class LevelMapper(object): 42 | """Determine which FPN level each RoI in a set of RoIs should map to based 43 | on the heuristic in the FPN paper. 44 | Arguments: 45 | k_min (int) 46 | k_max (int) 47 | canonical_scale (int) 48 | canonical_level (int) 49 | eps (float) 50 | """ 51 | 52 | def __init__(self, k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6): 53 | # type: (int, int, int, int, float) 54 | self.k_min = k_min 55 | self.k_max = k_max 56 | self.s0 = canonical_scale 57 | self.lvl0 = canonical_level 58 | self.eps = eps 59 | 60 | def __call__(self, boxlists): 61 | # type: (List[Tensor]) 62 | """ 63 | Arguments: 64 | boxlists (list[BoxList]) 65 | """ 66 | # Compute level ids 67 | s = torch.sqrt(torch.cat([box_area(boxlist) for boxlist in boxlists])) 68 | 69 | # Eqn.(1) in FPN paper 70 | target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype)) 71 | target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max) 72 | return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64) 73 | 74 | 75 | class MultiScaleRoIAlign(nn.Module): 76 | """ 77 | Multi-scale RoIAlign pooling, which is useful for detection with or without FPN. 78 | It infers the scale of the pooling via the heuristics present in the FPN paper. 79 | Arguments: 80 | featmap_names (List[str]): the names of the feature maps that will be used 81 | for the pooling. 
82 | output_size (List[Tuple[int, int]] or List[int]): output size for the pooled region 83 | sampling_ratio (int): sampling ratio for ROIAlign 84 | Examples:: 85 | >>> m = torchvision.ops.MultiScaleRoIAlign(['feat1', 'feat3'], 3, 2) 86 | >>> i = OrderedDict() 87 | >>> i['feat1'] = torch.rand(1, 5, 64, 64) 88 | >>> i['feat2'] = torch.rand(1, 5, 32, 32) # this feature won't be used in the pooling 89 | >>> i['feat3'] = torch.rand(1, 5, 16, 16) 90 | >>> # create some random bounding boxes 91 | >>> boxes = torch.rand(6, 4) * 256; boxes[:, 2:] += boxes[:, :2] 92 | >>> # original image size, before computing the feature maps 93 | >>> image_sizes = [(512, 512)] 94 | >>> output = m(i, [boxes], image_sizes) 95 | >>> print(output.shape) 96 | >>> torch.Size([6, 5, 3, 3]) 97 | """ 98 | 99 | __annotations__ = { 100 | 'scales': Optional[List[float]], 101 | 'map_levels': Optional[LevelMapper] 102 | } 103 | 104 | def __init__(self, featmap_names, output_size, sampling_ratio): 105 | super(MultiScaleRoIAlign, self).__init__() 106 | if isinstance(output_size, int): 107 | output_size = (output_size, output_size) 108 | self.featmap_names = featmap_names 109 | self.sampling_ratio = sampling_ratio 110 | self.output_size = tuple(output_size) 111 | self.scales = None 112 | self.map_levels = None 113 | 114 | def convert_to_roi_format(self, boxes): 115 | # type: (List[Tensor]) 116 | concat_boxes = torch.cat(boxes, dim=0) 117 | device, dtype = concat_boxes.device, concat_boxes.dtype 118 | ids = torch.cat( 119 | [ 120 | torch.full_like(b[:, :1], i, dtype=dtype, layout=torch.strided, device=device) 121 | for i, b in enumerate(boxes) 122 | ], 123 | dim=0, 124 | ) 125 | rois = torch.cat([ids, concat_boxes], dim=1) 126 | return rois 127 | 128 | def infer_scale(self, feature, original_size): 129 | # type: (Tensor, List[int]) 130 | # assumption: the scale is of the form 2 ** (-k), with k integer 131 | size = feature.shape[-2:] 132 | possible_scales = torch.jit.annotate(List[float], []) 133 | for s1, s2 in zip(size, original_size): 134 | approx_scale = float(s1) / float(s2) 135 | scale = 2 ** float(torch.tensor(approx_scale).log2().round()) 136 | possible_scales.append(scale) 137 | assert possible_scales[0] == possible_scales[1] 138 | return possible_scales[0] 139 | 140 | def setup_scales(self, features, image_shapes): 141 | # type: (List[Tensor], List[Tuple[int, int]]) 142 | assert len(image_shapes) != 0 143 | max_x = 0 144 | max_y = 0 145 | for shape in image_shapes: 146 | max_x = max(shape[0], max_x) 147 | max_y = max(shape[1], max_y) 148 | original_input_shape = (max_x, max_y) 149 | 150 | scales = [self.infer_scale(feat, original_input_shape) for feat in features] 151 | # get the levels in the feature map by leveraging the fact that the network always 152 | # downsamples by a factor of 2 at each level. 153 | lvl_min = -torch.log2(torch.tensor(scales[0], dtype=torch.float32)).item() 154 | lvl_max = -torch.log2(torch.tensor(scales[-1], dtype=torch.float32)).item() 155 | self.scales = scales 156 | self.map_levels = initLevelMapper(int(lvl_min), int(lvl_max)) 157 | 158 | def forward(self, x, boxes, image_shapes): 159 | # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) 160 | """ 161 | Arguments: 162 | x (OrderedDict[Tensor]): feature maps for each level. They are assumed to have 163 | all the same number of channels, but they can have different sizes. 
164 | boxes (List[Tensor[N, 4]]): boxes to be used to perform the pooling operation, in 165 | (x1, y1, x2, y2) format and in the image reference size, not the feature map 166 | reference. 167 | image_shapes (List[Tuple[height, width]]): the sizes of each image before they 168 | have been fed to a CNN to obtain feature maps. This allows us to infer the 169 | scale factor for each one of the levels to be pooled. 170 | Returns: 171 | result (Tensor) 172 | """ 173 | x_filtered = [] 174 | for k, v in x.items(): 175 | if k in self.featmap_names: 176 | x_filtered.append(v) 177 | num_levels = len(x_filtered) 178 | rois = self.convert_to_roi_format(boxes) 179 | if self.scales is None: 180 | self.setup_scales(x_filtered, image_shapes) 181 | 182 | scales = self.scales 183 | assert scales is not None 184 | 185 | if num_levels == 1: 186 | return ps_roi_align( 187 | x_filtered[0], rois, 188 | output_size=self.output_size, 189 | spatial_scale=scales[0], 190 | sampling_ratio=self.sampling_ratio 191 | ) 192 | 193 | mapper = self.map_levels 194 | assert mapper is not None 195 | 196 | levels = mapper(boxes) 197 | 198 | num_rois = len(rois) 199 | num_channels = x_filtered[0].shape[1] 200 | 201 | dtype, device = x_filtered[0].dtype, x_filtered[0].device 202 | result = torch.zeros( 203 | (num_rois, num_channels,) + self.output_size, 204 | dtype=dtype, 205 | device=device, 206 | ) 207 | 208 | tracing_results = [] 209 | for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)): 210 | idx_in_level = torch.nonzero(levels == level).squeeze(1) 211 | rois_per_level = rois[idx_in_level] 212 | 213 | result_idx_in_level = ps_roi_align( 214 | per_level_feature, rois_per_level, 215 | output_size=self.output_size, 216 | spatial_scale=scale, sampling_ratio=self.sampling_ratio) 217 | 218 | if torchvision._is_tracing(): 219 | tracing_results.append(result_idx_in_level.to(dtype)) 220 | else: 221 | result[idx_in_level] = result_idx_in_level 222 | 223 | if torchvision._is_tracing(): 224 | result = _onnx_merge_levels(levels, tracing_results) 225 | 226 | return result -------------------------------------------------------------------------------- /src/roi_layers/ps_roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from torch.nn.modules.utils import _pair 5 | from torch.jit.annotations import List 6 | 7 | from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape 8 | 9 | 10 | def ps_roi_align(input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1): 11 | # type: (Tensor, Tensor, int, float, int) -> Tensor 12 | """ 13 | Performs Position-Sensitive Region of Interest (RoI) Align operator 14 | mentioned in Light-Head R-CNN. 15 | Arguments: 16 | input (Tensor[N, C, H, W]): input tensor 17 | boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2) 18 | format where the regions will be taken from. If a single Tensor is passed, 19 | then the first column should contain the batch index. If a list of Tensors 20 | is passed, then each Tensor will correspond to the boxes for an element i 21 | in a batch 22 | output_size (int or Tuple[int, int]): the size of the output after the cropping 23 | is performed, as (height, width) 24 | spatial_scale (float): a scaling factor that maps the input coordinates to 25 | the box coordinates. 
Default: 1.0 26 | sampling_ratio (int): number of sampling points in the interpolation grid 27 | used to compute the output value of each pooled output bin. If > 0 28 | then exactly sampling_ratio x sampling_ratio grid points are used. 29 | If <= 0, then an adaptive number of grid points are used (computed as 30 | ceil(roi_width / pooled_w), and likewise for height). Default: -1 31 | Returns: 32 | output (Tensor[K, C, output_size[0], output_size[1]]) 33 | """ 34 | check_roi_boxes_shape(boxes) 35 | rois = boxes 36 | output_size = _pair(output_size) 37 | if not isinstance(rois, torch.Tensor): 38 | rois = convert_boxes_to_roi_format(rois) 39 | output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale, 40 | output_size[0], 41 | output_size[1], 42 | sampling_ratio) 43 | return output 44 | 45 | 46 | class PSRoIAlign(nn.Module): 47 | """ 48 | See ps_roi_align 49 | """ 50 | def __init__(self, output_size, spatial_scale, sampling_ratio): 51 | super(PSRoIAlign, self).__init__() 52 | self.output_size = output_size 53 | self.spatial_scale = spatial_scale 54 | self.sampling_ratio = sampling_ratio 55 | 56 | def forward(self, input, rois): 57 | return ps_roi_align(input, rois, self.output_size, self.spatial_scale, 58 | self.sampling_ratio) 59 | 60 | def __repr__(self): 61 | tmpstr = self.__class__.__name__ + '(' 62 | tmpstr += 'output_size=' + str(self.output_size) 63 | tmpstr += ', spatial_scale=' + str(self.spatial_scale) 64 | tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) 65 | tmpstr += ')' 66 | return tmpstr 67 | -------------------------------------------------------------------------------- /src/roi_layers/ps_roi_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | from torch.nn.modules.utils import _pair 5 | from torch.jit.annotations import List 6 | 7 | from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape 8 | 9 | 10 | def ps_roi_pool(input, boxes, output_size, spatial_scale=1.0): 11 | # type: (Tensor, Tensor, int, float) -> Tensor 12 | """ 13 | Performs Position-Sensitive Region of Interest (RoI) Pool operator 14 | described in R-FCN 15 | Arguments: 16 | input (Tensor[N, C, H, W]): input tensor 17 | boxes (Tensor[K, 5] or List[Tensor[L, 4]]): the box coordinates in (x1, y1, x2, y2) 18 | format where the regions will be taken from. If a single Tensor is passed, 19 | then the first column should contain the batch index. If a list of Tensors 20 | is passed, then each Tensor will correspond to the boxes for an element i 21 | in a batch 22 | output_size (int or Tuple[int, int]): the size of the output after the cropping 23 | is performed, as (height, width) 24 | spatial_scale (float): a scaling factor that maps the input coordinates to 25 | the box coordinates. 
Default: 1.0 26 | Returns: 27 | output (Tensor[K, C, output_size[0], output_size[1]]) 28 | """ 29 | check_roi_boxes_shape(boxes) 30 | rois = boxes 31 | output_size = _pair(output_size) 32 | if not isinstance(rois, torch.Tensor): 33 | rois = convert_boxes_to_roi_format(rois) 34 | output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, 35 | output_size[0], 36 | output_size[1]) 37 | return output 38 | 39 | 40 | class PSRoIPool(nn.Module): 41 | """ 42 | See ps_roi_pool 43 | """ 44 | def __init__(self, output_size, spatial_scale): 45 | super(PSRoIPool, self).__init__() 46 | self.output_size = output_size 47 | self.spatial_scale = spatial_scale 48 | 49 | def forward(self, input, rois): 50 | return ps_roi_pool(input, rois, self.output_size, self.spatial_scale) 51 | 52 | def __repr__(self): 53 | tmpstr = self.__class__.__name__ + '(' 54 | tmpstr += 'output_size=' + str(self.output_size) 55 | tmpstr += ', spatial_scale=' + str(self.spatial_scale) 56 | tmpstr += ')' 57 | return tmpstr 58 | -------------------------------------------------------------------------------- /src/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | from torch.nn import functional as F 4 | from torch import nn, Tensor 5 | 6 | import torchvision 7 | from torchvision.ops import boxes as box_ops 8 | 9 | from . import _utils as det_utils 10 | #from torchvision.models.detection import _utils as det_utils 11 | #from .image_list import ImageList 12 | from torchvision.models.detection.image_list import ImageList 13 | 14 | from torch.jit.annotations import List, Optional, Dict, Tuple 15 | import pandas as pd 16 | 17 | @torch.jit.unused 18 | def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): 19 | # type: (Tensor, int) -> Tuple[int, int] 20 | from torch.onnx import operators 21 | num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) 22 | # TODO : remove cast to IntTensor/num_anchors.dtype when 23 | # ONNX Runtime version is updated with ReduceMin int64 support 24 | pre_nms_top_n = torch.min(torch.cat( 25 | (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), 26 | num_anchors), 0).to(torch.int32)).to(num_anchors.dtype) 27 | 28 | return num_anchors, pre_nms_top_n 29 | 30 | 31 | class AnchorGenerator(nn.Module): 32 | __annotations__ = { 33 | "cell_anchors": Optional[List[torch.Tensor]], 34 | "_cache": Dict[str, List[torch.Tensor]] 35 | } 36 | 37 | """ 38 | Module that generates anchors for a set of feature maps and 39 | image sizes. 40 | The module support computing anchors at multiple sizes and aspect ratios 41 | per feature map. 42 | sizes and aspect_ratios should have the same number of elements, and it should 43 | correspond to the number of feature maps. 44 | sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, 45 | and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors 46 | per spatial location for feature map i. 
47 | Arguments: 48 | sizes (Tuple[Tuple[int]]): 49 | aspect_ratios (Tuple[Tuple[float]]): 50 | """ 51 | 52 | def __init__( 53 | self, 54 | sizes=(128, 256, 512), 55 | aspect_ratios=(0.5, 1.0, 2.0), 56 | ): 57 | super(AnchorGenerator, self).__init__() 58 | 59 | if not isinstance(sizes[0], (list, tuple)): 60 | # TODO change this 61 | sizes = tuple((s,) for s in sizes) 62 | if not isinstance(aspect_ratios[0], (list, tuple)): 63 | aspect_ratios = (aspect_ratios,) * len(sizes) 64 | 65 | assert len(sizes) == len(aspect_ratios) 66 | 67 | self.sizes = sizes 68 | self.aspect_ratios = aspect_ratios 69 | self.cell_anchors = None 70 | self._cache = {} 71 | 72 | # TODO: https://github.com/pytorch/pytorch/issues/26792 73 | # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. 74 | # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) 75 | def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): 76 | # type: (List[int], List[float], int, Device) # noqa: F821 77 | scales = torch.as_tensor(scales, dtype=dtype, device=device) 78 | aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) 79 | h_ratios = torch.sqrt(aspect_ratios) 80 | w_ratios = 1 / h_ratios 81 | 82 | ws = (w_ratios[:, None] * scales[None, :]).view(-1) 83 | hs = (h_ratios[:, None] * scales[None, :]).view(-1) 84 | 85 | base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 86 | return base_anchors.round() 87 | 88 | def set_cell_anchors(self, dtype, device): 89 | # type: (int, Device) -> None # noqa: F821 90 | if self.cell_anchors is not None: 91 | cell_anchors = self.cell_anchors 92 | assert cell_anchors is not None 93 | # suppose that all anchors have the same device 94 | # which is a valid assumption in the current state of the codebase 95 | if cell_anchors[0].device == device: 96 | return 97 | 98 | cell_anchors = [ 99 | self.generate_anchors( 100 | sizes, 101 | aspect_ratios, 102 | dtype, 103 | device 104 | ) 105 | for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) 106 | ] 107 | self.cell_anchors = cell_anchors 108 | 109 | def num_anchors_per_location(self): 110 | return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] 111 | 112 | # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), 113 | # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. 114 | def grid_anchors(self, grid_sizes, strides): 115 | # type: (List[List[int]], List[List[Tensor]]) 116 | anchors = [] 117 | cell_anchors = self.cell_anchors 118 | assert cell_anchors is not None 119 | 120 | for size, stride, base_anchors in zip( 121 | grid_sizes, strides, cell_anchors 122 | ): 123 | grid_height, grid_width = size 124 | stride_height, stride_width = stride 125 | device = base_anchors.device 126 | 127 | # For output anchor, compute [x_center, y_center, x_center, y_center] 128 | shifts_x = torch.arange( 129 | 0, grid_width, dtype=torch.float32, device=device 130 | ) * stride_width 131 | shifts_y = torch.arange( 132 | 0, grid_height, dtype=torch.float32, device=device 133 | ) * stride_height 134 | shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) 135 | shift_x = shift_x.reshape(-1) 136 | shift_y = shift_y.reshape(-1) 137 | shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) 138 | 139 | # For every (base anchor, output anchor) pair, 140 | # offset each zero-centered base anchor by the center of the output anchor. 
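# (shifts has shape [grid_h * grid_w, 4] and base_anchors has shape [A, 4]; the broadcast
#  (K, 1, 4) + (1, A, 4) below yields (K, A, 4), i.e. A anchors per grid cell, and
#  reshape(-1, 4) flattens that to the (K * A, 4) layout the RPN expects.)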
141 | anchors.append( 142 | (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) 143 | ) 144 | 145 | return anchors 146 | 147 | def cached_grid_anchors(self, grid_sizes, strides): 148 | # type: (List[List[int]], List[List[Tensor]]) 149 | key = str(grid_sizes) + str(strides) 150 | if key in self._cache: 151 | return self._cache[key] 152 | anchors = self.grid_anchors(grid_sizes, strides) 153 | self._cache[key] = anchors 154 | return anchors 155 | 156 | def forward(self, image_list, feature_maps): 157 | # type: (ImageList, List[Tensor]) 158 | 159 | grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) # grid_sizes: [torch.Size([20, 20])] 160 | #image_size = image_list.tensors.shape[-2:] 161 | image_size = image_list.tensors[0].shape[-2:] # image_size: torch.Size([320, 320]) 162 | 163 | dtype, device = feature_maps[0].dtype, feature_maps[0].device 164 | strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device), 165 | torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes] 166 | self.set_cell_anchors(dtype, device) 167 | anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) 168 | anchors = torch.jit.annotate(List[List[torch.Tensor]], []) 169 | for i, (image_height, image_width) in enumerate(image_list.image_sizes): 170 | anchors_in_image = [] 171 | for anchors_per_feature_map in anchors_over_all_feature_maps: 172 | anchors_in_image.append(anchors_per_feature_map) 173 | anchors.append(anchors_in_image) 174 | 175 | # per image generate 10000 regions 176 | anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] # type: list[Tensor] 177 | # Clear the cache in case that memory leaks. 178 | self._cache.clear() 179 | 180 | return anchors 181 | 182 | 183 | class RPNHead(nn.Module): 184 | """ 185 | Adds a simple RPN Head with classification and regression heads 186 | Arguments: 187 | in_channels (int): number of channels of the input feature 188 | num_anchors (int): number of anchors to be predicted 189 | """ 190 | 191 | def __init__(self, in_channels, num_anchors, rpn_channel=256): 192 | super(RPNHead, self).__init__() 193 | self.dw5x5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=2, groups=in_channels) 194 | self.conv = nn.Conv2d(in_channels, rpn_channel, kernel_size=1, stride=1, padding=0) 195 | self.cls_logits = nn.Conv2d(rpn_channel, num_anchors, kernel_size=1, stride=1) # output channel: num_anchors*2 ? why? 
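# (A single objectness channel per anchor is sufficient here: compute_loss() further down
#  trains it with binary_cross_entropy_with_logits, so the num_anchors * 2 output used by
#  older two-class-softmax RPN implementations is not needed.)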
196 | self.bbox_pred = nn.Conv2d(rpn_channel, num_anchors * 4, kernel_size=1, stride=1) 197 | 198 | for l in self.children(): 199 | torch.nn.init.normal_(l.weight, std=0.01) 200 | torch.nn.init.constant_(l.bias, 0) 201 | 202 | def forward(self, x): 203 | # type: (List[Tensor]) 204 | logits = [] 205 | bbox_reg = [] 206 | #rpn_sam = [] 207 | 208 | for feature in x: 209 | f1 = F.relu(self.dw5x5(feature)) 210 | f1 = F.relu(self.conv(f1)) # share feature: f1 211 | logits.append(self.cls_logits(f1)) 212 | bbox_reg.append(self.bbox_pred(f1)) 213 | #rpn_sam.append(f1) 214 | rpn_sam = f1 215 | return logits, bbox_reg, rpn_sam 216 | 217 | 218 | def permute_and_flatten(layer, N, A, C, H, W): 219 | # type: (Tensor, int, int, int, int, int) 220 | layer = layer.view(N, -1, C, H, W) 221 | layer = layer.permute(0, 3, 4, 1, 2) 222 | layer = layer.reshape(N, -1, C) 223 | return layer 224 | 225 | 226 | def concat_box_prediction_layers(box_cls, box_regression): 227 | # type: (List[Tensor], List[Tensor]) 228 | box_cls_flattened = [] 229 | box_regression_flattened = [] 230 | # for each feature level, permute the outputs to make them be in the 231 | # same format as the labels. Note that the labels are computed for 232 | # all feature levels concatenated, so we keep the same representation 233 | # for the objectness and the box_regression 234 | for box_cls_per_level, box_regression_per_level in zip( 235 | box_cls, box_regression 236 | ): 237 | N, AxC, H, W = box_cls_per_level.shape # [N, 25x1, 20, 20] 238 | Ax4 = box_regression_per_level.shape[1] # 25x4 = 100 239 | A = Ax4 // 4 # A = 25 240 | C = AxC // A # C = 1 241 | box_cls_per_level = permute_and_flatten( # shape: [N, 10000, 1] 242 | box_cls_per_level, N, A, C, H, W 243 | ) 244 | box_cls_flattened.append(box_cls_per_level) 245 | 246 | box_regression_per_level = permute_and_flatten( # shape: [N, 10000, 4] 247 | box_regression_per_level, N, A, 4, H, W 248 | ) 249 | box_regression_flattened.append(box_regression_per_level) 250 | # concatenate on the first dimension (representing the feature levels), to 251 | # take into account the way the labels were generated (with all feature maps 252 | # being concatenated as well) 253 | box_cls = torch.cat(box_cls_flattened, dim=1).flatten(0, -2) # shape: [N x 10000, 1] 254 | box_regression = torch.cat(box_regression_flattened, dim=1).reshape(-1, 4) # shape: [N x 10000, 4] 255 | return box_cls, box_regression 256 | 257 | 258 | class RegionProposalNetwork(torch.nn.Module): 259 | """ 260 | Implements Region Proposal Network (RPN). 261 | Arguments: 262 | anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature 263 | maps. 264 | head (nn.Module): module that computes the objectness and regression deltas 265 | fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be 266 | considered as positive during training of the RPN. 267 | bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be 268 | considered as negative during training of the RPN. 269 | batch_size_per_image (int): number of anchors that are sampled during training of the RPN 270 | for computing the loss 271 | positive_fraction (float): proportion of positive anchors in a mini-batch during training 272 | of the RPN 273 | pre_nms_top_n (Dict[int]): number of proposals to keep before applying NMS. 
It should 274 | contain two fields: training and testing, to allow for different values depending 275 | on training or evaluation 276 | post_nms_top_n (Dict[int]): number of proposals to keep after applying NMS. It should 277 | contain two fields: training and testing, to allow for different values depending 278 | on training or evaluation 279 | nms_thresh (float): NMS threshold used for postprocessing the RPN proposals 280 | """ 281 | __annotations__ = { 282 | 'box_coder': det_utils.BoxCoder, 283 | 'proposal_matcher': det_utils.Matcher, 284 | 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, 285 | 'pre_nms_top_n': Dict[str, int], 286 | 'post_nms_top_n': Dict[str, int], 287 | } 288 | 289 | def __init__(self, 290 | anchor_generator, 291 | head, 292 | # 293 | fg_iou_thresh, bg_iou_thresh, 294 | batch_size_per_image, positive_fraction, 295 | # 296 | pre_nms_top_n, post_nms_top_n, nms_thresh): 297 | super(RegionProposalNetwork, self).__init__() 298 | self.anchor_generator = anchor_generator 299 | self.head = head 300 | self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0)) 301 | 302 | # used during training 303 | self.box_similarity = box_ops.box_iou 304 | 305 | self.proposal_matcher = det_utils.Matcher( 306 | fg_iou_thresh, 307 | bg_iou_thresh, 308 | allow_low_quality_matches=True, 309 | ) 310 | 311 | self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( 312 | batch_size_per_image, positive_fraction 313 | ) 314 | # used during testing 315 | self._pre_nms_top_n = pre_nms_top_n 316 | self._post_nms_top_n = post_nms_top_n 317 | self.nms_thresh = nms_thresh 318 | self.min_size = 1e-3 319 | 320 | def pre_nms_top_n(self): 321 | if self.training: 322 | return self._pre_nms_top_n['training'] 323 | return self._pre_nms_top_n['testing'] 324 | 325 | def post_nms_top_n(self): 326 | if self.training: 327 | return self._post_nms_top_n['training'] 328 | return self._post_nms_top_n['testing'] 329 | 330 | def assign_targets_to_anchors(self, anchors, targets): 331 | # type: (List[Tensor], List[Dict[str, Tensor]]) 332 | labels = [] 333 | matched_gt_boxes = [] 334 | for anchors_per_image, targets_per_image in zip(anchors, targets): 335 | 336 | gt_boxes = targets_per_image["boxes"] 337 | gt_labels = targets_per_image["labels"] 338 | gt_boxes = gt_boxes[gt_labels[:] != -1] # select gt boxes, remove -1 value boxes. 
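# (proposal_matcher below returns, for every anchor, the index of its best-matching GT box,
#  or the sentinels BELOW_LOW_THRESHOLD (-1) / BETWEEN_THRESHOLDS (-2); these are mapped to
#  label 0 (background) and label -1 (ignored by the fg/bg sampler) respectively.)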
339 | 340 | # in this step, we only select hight quality boxes, and the label of each boxes 341 | # is only objection or background, don't need to classification which it belong to 342 | if gt_boxes.numel() == 0: 343 | # Background image (negative example) 344 | device = anchors_per_image.device 345 | matched_gt_boxes_per_image = torch.ones(anchors_per_image.shape, dtype=torch.float32, device=device) * (-1) 346 | labels_per_image = torch.ones((anchors_per_image.shape[0],), dtype=torch.float32, device=device) * (-1) 347 | else: 348 | match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image) 349 | matched_idxs = self.proposal_matcher(match_quality_matrix) 350 | 351 | # debug: each class count 352 | """ 353 | temp1 = matched_idxs.cpu().detach().numpy() 354 | count1 = pd.value_counts(temp1) 355 | print('each count1:') 356 | print(count1) 357 | """ 358 | 359 | # get the targets corresponding GT for each proposal 360 | # NB: need to clamp the indices because we can have a single 361 | # GT in the image, and matched_idxs can be -2, which goes 362 | # out of bounds 363 | matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)] 364 | 365 | labels_per_image = matched_idxs >= 0 366 | labels_per_image = labels_per_image.to(dtype=torch.float32) 367 | 368 | # Background (negative examples) 369 | bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD # -1 370 | labels_per_image[bg_indices] = torch.tensor(0.0) 371 | 372 | # discard indices that are between thresholds 373 | inds_to_discard = matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS # -2 374 | labels_per_image[inds_to_discard] = torch.tensor(-1.0) 375 | 376 | labels.append(labels_per_image) 377 | matched_gt_boxes.append(matched_gt_boxes_per_image) 378 | return labels, matched_gt_boxes 379 | 380 | def _get_top_n_idx(self, objectness, num_anchors_per_level): 381 | # type: (Tensor, List[int]) 382 | r = [] 383 | offset = 0 384 | for ob in objectness.split(num_anchors_per_level, 1): 385 | is_tracing = True 386 | #if torchvision._is_tracing(): 387 | if is_tracing: 388 | num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n()) 389 | else: 390 | num_anchors = ob.shape[1] 391 | pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors) 392 | _, top_n_idx = ob.topk(pre_nms_top_n, dim=1) 393 | r.append(top_n_idx + offset) 394 | offset += num_anchors 395 | return torch.cat(r, dim=1) 396 | 397 | def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level): 398 | # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) 399 | num_images = proposals.shape[0] 400 | device = proposals.device 401 | # do not backprop throught objectness 402 | objectness = objectness.detach() 403 | objectness = objectness.reshape(num_images, -1) # shape: [64, 10000] 404 | 405 | levels = [ 406 | torch.full((n,), idx, dtype=torch.int64, device=device) 407 | for idx, n in enumerate(num_anchors_per_level) 408 | ] 409 | levels = torch.cat(levels, 0) 410 | levels = levels.reshape(1, -1).expand_as(objectness) # shape: [64, 10000] 411 | 412 | # select top_n boxes independently per level before applying nms 413 | top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level) # shape: [64, 2000] 414 | 415 | image_range = torch.arange(num_images, device=device) 416 | batch_idx = image_range[:, None] 417 | 418 | objectness = objectness[batch_idx, top_n_idx] # shape: [64, 2000] 419 | levels = levels[batch_idx, top_n_idx] 420 | proposals = proposals[batch_idx, top_n_idx] # shape: [64, 2000, 4] 
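# (image_range[:, None] has shape [num_images, 1], so the advanced indexing above gathers,
#  for every image, only that image's own top-k entries from objectness, levels and
#  proposals in a single vectorised step.)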
421 | 422 | final_boxes = [] 423 | final_scores = [] 424 | for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes): 425 | boxes = box_ops.clip_boxes_to_image(boxes, img_shape) 426 | keep = box_ops.remove_small_boxes(boxes, self.min_size) 427 | boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] 428 | # non-maximum suppression, independently done per level 429 | keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) 430 | # keep only topk scoring predictions 431 | keep = keep[:self.post_nms_top_n()] 432 | boxes, scores = boxes[keep], scores[keep] 433 | final_boxes.append(boxes) 434 | final_scores.append(scores) 435 | return final_boxes, final_scores 436 | 437 | def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): 438 | # type: (Tensor, Tensor, List[Tensor], List[Tensor]) 439 | """ 440 | Arguments: 441 | objectness (Tensor) 442 | pred_bbox_deltas (Tensor) 443 | labels (List[Tensor]) 444 | regression_targets (List[Tensor]) 445 | Returns: 446 | objectness_loss (Tensor) 447 | box_loss (Tensor) 448 | """ 449 | 450 | sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) 451 | sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1) 452 | sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1) 453 | 454 | sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) 455 | 456 | objectness = objectness.flatten() 457 | 458 | labels = torch.cat(labels, dim=0) 459 | regression_targets = torch.cat(regression_targets, dim=0) 460 | 461 | box_loss = F.l1_loss( 462 | pred_bbox_deltas[sampled_pos_inds], 463 | regression_targets[sampled_pos_inds], 464 | reduction="sum", 465 | ) / (sampled_inds.numel()) 466 | 467 | objectness_loss = F.binary_cross_entropy_with_logits( 468 | objectness[sampled_inds], labels[sampled_inds] 469 | ) 470 | 471 | return objectness_loss, box_loss 472 | 473 | def forward(self, images, features, targets=None): 474 | # type: (ImageList, Dict[str, Tensor], Optional[List[Dict[str, Tensor]]]) 475 | """ 476 | Arguments: 477 | images (ImageList): images for which we want to compute the predictions 478 | features (List[Tensor]): features computed from the images that are 479 | used for computing the predictions. Each tensor in the list 480 | correspond to different feature levels 481 | targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional). 482 | If provided, each element in the dict should contain a field `boxes`, 483 | with the locations of the ground-truth boxes. 484 | Returns: 485 | boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per 486 | image. 487 | losses (Dict[Tensor]): the losses for the model during training. During 488 | testing, it is an empty dict. 
489 | """ 490 | # RPN uses all feature maps that are available 491 | features = list(features.values()) 492 | objectness, pred_bbox_deltas, rpn_sam_input = self.head(features) 493 | anchors = self.anchor_generator(images, features) 494 | 495 | num_images = len(anchors) # batch size 496 | num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] # [25, 20, 20] 497 | num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] # 25x20x20=10000 498 | objectness, pred_bbox_deltas = \ 499 | concat_box_prediction_layers(objectness, pred_bbox_deltas) 500 | # apply pred_bbox_deltas to anchors to obtain the decoded proposals 501 | # note that we detach the deltas because Faster R-CNN do not backprop through 502 | # the proposals 503 | proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) # shape: [640000, 1, 4] 504 | proposals = proposals.view(num_images, -1, 4) # shape: [64, 10000, 4] 505 | boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) 506 | 507 | losses = {} 508 | 509 | #if self.training: 510 | assert targets is not None 511 | labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) 512 | regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) 513 | loss_objectness, loss_rpn_box_reg = self.compute_loss( 514 | objectness, pred_bbox_deltas, labels, regression_targets) 515 | losses = { 516 | "loss_rpn_objectness": loss_objectness, 517 | "loss_rpn_box_reg": loss_rpn_box_reg, 518 | } 519 | return boxes, losses, rpn_sam_input -------------------------------------------------------------------------------- /src/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import torch 4 | from torch import nn, Tensor 5 | import torchvision 6 | from torch.jit.annotations import List, Tuple, Dict, Optional 7 | 8 | from torchvision.ops import misc as misc_nn_ops 9 | from .image_list import ImageList 10 | #from .roi_heads import paste_masks_in_image 11 | 12 | 13 | class GeneralizedRCNNTransform(nn.Module): 14 | """ 15 | Performs input / target transformation before feeding the data to a GeneralizedRCNN 16 | model. 
17 | The transformations it perform are: 18 | - input normalization (mean subtraction and std division) 19 | - input / target resizing to match min_size / max_size 20 | It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets 21 | """ 22 | 23 | def __init__(self): 24 | super(GeneralizedRCNNTransform, self).__init__() 25 | 26 | def forward(self, images, targets2=None): 27 | # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) 28 | 29 | targets = [] 30 | t2 = {} 31 | for t in targets2: 32 | t2["boxes"] = t[:, 0:4] 33 | t2["labels"] = t[:, 4] 34 | targets.append(t2.copy()) 35 | 36 | images = [img for img in images] 37 | 38 | image_sizes = [img.shape[-2:] for img in images] 39 | 40 | image_sizes_list = torch.jit.annotate(List[Tuple[int, int]], []) 41 | for image_size in image_sizes: 42 | assert len(image_size) == 2 43 | image_sizes_list.append((image_size[0], image_size[1])) 44 | 45 | image_list = ImageList(images, image_sizes_list) 46 | return image_list, targets 47 | 48 | 49 | def postprocess(self, result, image_shapes, original_image_sizes): 50 | # type: (List[Dict[str, Tensor]], List[Tuple[int, int]], List[Tuple[int, int]]) 51 | if self.training: 52 | return result 53 | for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): 54 | boxes = pred["boxes"] 55 | boxes = resize_boxes(boxes, im_s, o_im_s) 56 | result[i]["boxes"] = boxes 57 | if "masks" in pred: 58 | masks = pred["masks"] 59 | masks = paste_masks_in_image(masks, boxes, o_im_s) 60 | result[i]["masks"] = masks 61 | if "keypoints" in pred: 62 | keypoints = pred["keypoints"] 63 | keypoints = resize_keypoints(keypoints, im_s, o_im_s) 64 | result[i]["keypoints"] = keypoints 65 | return result 66 | 67 | 68 | def resize_keypoints(keypoints, original_size, new_size): 69 | # type: (Tensor, List[int], List[int]) 70 | ratios = [ 71 | torch.tensor(s, dtype=torch.float32, device=keypoints.device) / 72 | torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) 73 | for s, s_orig in zip(new_size, original_size) 74 | ] 75 | ratio_h, ratio_w = ratios 76 | resized_data = keypoints.clone() 77 | if torch._C._get_tracing_state(): 78 | resized_data_0 = resized_data[:, :, 0] * ratio_w 79 | resized_data_1 = resized_data[:, :, 1] * ratio_h 80 | resized_data = torch.stack((resized_data_0, resized_data_1, resized_data[:, :, 2]), dim=2) 81 | else: 82 | resized_data[..., 0] *= ratio_w 83 | resized_data[..., 1] *= ratio_h 84 | return resized_data 85 | 86 | 87 | def resize_boxes(boxes, original_size, new_size): 88 | # type: (Tensor, List[int], List[int]) 89 | ratios = [ 90 | torch.tensor(s, dtype=torch.float32, device=boxes.device) / 91 | torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) 92 | for s, s_orig in zip(new_size, original_size) 93 | ] 94 | ratio_height, ratio_width = ratios 95 | xmin, ymin, xmax, ymax = boxes.unbind(1) 96 | 97 | xmin = xmin * ratio_width 98 | xmax = xmax * ratio_width 99 | ymin = ymin * ratio_height 100 | ymax = ymax * ratio_height 101 | return torch.stack((xmin, ymin, xmax, ymax), dim=1) 102 | -------------------------------------------------------------------------------- /thundernet/ShufflenetV2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | import math 8 | 9 | def conv_bn(inp, oup, stride): 10 | return nn.Sequential( 11 | 
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def channel_shuffle(x, groups): 24 | batchsize, num_channels, height, width = x.data.size() 25 | 26 | channels_per_group = num_channels // groups 27 | 28 | # reshape 29 | x = x.view(batchsize, groups, channels_per_group, height, width) 30 | x = torch.transpose(x, 1, 2).contiguous() 31 | 32 | # flatten 33 | x = x.view(batchsize, -1, height, width) 34 | 35 | return x 36 | 37 | class InvertedResidual(nn.Module): 38 | def __init__(self, inp, oup, stride, benchmodel): 39 | super(InvertedResidual, self).__init__() 40 | self.benchmodel = benchmodel 41 | self.stride = stride 42 | assert stride in [1, 2] 43 | 44 | oup_inc = oup // 2 45 | 46 | if self.benchmodel == 1: # basic unit 47 | self.branch2 = nn.Sequential( 48 | # pw 49 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(oup_inc), 51 | nn.ReLU(inplace=True), 52 | # dw 53 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 54 | nn.BatchNorm2d(oup_inc), 55 | # pw-linear 56 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 57 | nn.BatchNorm2d(oup_inc), 58 | nn.ReLU(inplace=True), 59 | ) 60 | else: # down sample (2x) 61 | self.branch1 = nn.Sequential( 62 | # dw 63 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 64 | nn.BatchNorm2d(inp), 65 | # pw-linear 66 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 67 | nn.BatchNorm2d(oup_inc), 68 | nn.ReLU(inplace=True), 69 | ) 70 | 71 | self.branch2 = nn.Sequential( 72 | # pw 73 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 74 | nn.BatchNorm2d(oup_inc), 75 | nn.ReLU(inplace=True), 76 | # dw 77 | nn.Conv2d(oup_inc, oup_inc, 3, stride, 1, groups=oup_inc, bias=False), 78 | nn.BatchNorm2d(oup_inc), 79 | # pw-linear 80 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 81 | nn.BatchNorm2d(oup_inc), 82 | nn.ReLU(inplace=True), 83 | ) 84 | 85 | @staticmethod 86 | def _concat(x, out): 87 | # concatenate along channel axis 88 | return torch.cat((x, out), 1) 89 | 90 | def forward(self, x): 91 | if 1 == self.benchmodel: 92 | x1 = x[:, :(x.shape[1]//2), :, :] 93 | x2 = x[:, (x.shape[1]//2):, :, :] 94 | out = self._concat(x1, self.branch2(x2)) 95 | elif 2 == self.benchmodel: 96 | out = self._concat(self.branch1(x), self.branch2(x)) 97 | 98 | return channel_shuffle(out, 2) 99 | 100 | class ShuffleNetV2(nn.Module): 101 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 102 | super(ShuffleNetV2, self).__init__() 103 | assert input_size % 32 == 0 104 | 105 | self.stage_repeats = [4, 8, 4] 106 | # index 0 is invalid and should never be called. 107 | # only used for indexing convenience. 
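# (Layout of stage_out_channels: [unused index 0, conv1, stage2, stage3, stage4, conv5];
#  e.g. width_mult=1.0 gives a 24-channel conv1, 116/232/464-channel stages and a
#  1024-channel conv5, matching the loops below that read stage_out_channels[idxstage + 2].)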
108 | if width_mult == 0.5: 109 | self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] 110 | elif width_mult == 1.0: 111 | self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] 112 | elif width_mult == 1.5: 113 | self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] 114 | elif width_mult == 2.0: 115 | self.stage_out_channels = [-1, 24, 224, 488, 976, 2048] 116 | else: 117 | raise ValueError( 118 | """{} groups is not supported for 119 | 1x1 Grouped Convolutions""".format(num_groups)) 120 | 121 | # building first layer 122 | input_channel = self.stage_out_channels[1] 123 | self.conv1 = conv_bn(3, input_channel, 2) 124 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 125 | 126 | self.features = [] 127 | # building inverted residual blocks 128 | for idxstage in range(len(self.stage_repeats)): 129 | numrepeat = self.stage_repeats[idxstage] 130 | output_channel = self.stage_out_channels[idxstage+2] 131 | for i in range(numrepeat): 132 | if i == 0: 133 | # (inp, oup, stride, benchmodel) 134 | self.features.append(InvertedResidual(input_channel, output_channel, 2, 2)) 135 | else: 136 | self.features.append(InvertedResidual(input_channel, output_channel, 1, 1)) 137 | input_channel = output_channel 138 | 139 | # make it nn.Sequential 140 | self.features = nn.Sequential(*self.features) 141 | 142 | # building last several layers 143 | self.conv5 = conv_1x1_bn(input_channel, self.stage_out_channels[-1]) 144 | self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32))) 145 | 146 | # building classifier 147 | self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], n_class)) 148 | 149 | def forward(self, x): 150 | x = self.conv1(x) 151 | x = self.maxpool(x) 152 | x = self.features(x) # stage2, stage3, stage4 153 | x = self.conv5(x) 154 | x = self.globalpool(x) 155 | x = x.view(-1, self.stage_out_channels[-1]) 156 | x = self.classifier(x) 157 | return x 158 | 159 | def Snet(width_mult=1.): 160 | model = ShuffleNetV2(width_mult=width_mult) 161 | return model 162 | 163 | 164 | if __name__ == '__main__': 165 | snet = ShuffleNetV2() 166 | print(snet) 167 | 168 | -------------------------------------------------------------------------------- /thundernet/module.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CEM(nn.Module): 8 | def __init__(self): 9 | super(CEM, self).__init__() 10 | self.conv1 = nn.Conv2d(120, 245, kernel_size=1, stride=1, padding=0) 11 | self.conv2 = nn.Conv2d(512, 245, kernel_size=1, stride=1, padding=0) 12 | self.avg_pool = nn.AvgPool2d(10) 13 | self.conv3 = nn.Conv2d(512, 245, kernel_size=1, stride=1, padding=0) 14 | 15 | def forward(self, c4_feature, c5_feature): 16 | c4 = c4_feature 17 | c4_lat = self.conv1(c4) # output: [245, 20, 20] 18 | 19 | c5 = c5_feature 20 | c5_lat = self.conv2(c5) # output: [245, 10, 10] 21 | 22 | # upsample x2 23 | c5_lat = F.interpolate(input=c5_lat, size=[20, 20], mode="nearest") # output: [245, 20, 20] 24 | c_glb = self.avg_pool(c5) # output: [512, 1, 1] 25 | c_glb_lat = self.conv3(c_glb) # output: [245, 1, 1] 26 | 27 | out = c4_lat + c5_lat + c_glb_lat # output: [245, 20, 20] 28 | return out 29 | 30 | class SAM(nn.Module): 31 | def __init__(self): 32 | super(SAM, self).__init__() 33 | self.conv = nn.Conv2d(256, 245, 1, 1, 0, bias=False) # input channel = 245 ? 
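# (256 is the width of the RPN's shared feature after its 1x1 conv -- rpn_channel in
#  src/rpn.py -- and the 1x1 conv here projects it to 245 channels so the sigmoid
#  attention map can gate the 245-channel CEM feature elementwise in forward() below.)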
34 | self.bn = nn.BatchNorm2d(245) 35 | self.sigmoid = nn.Sigmoid() 36 | 37 | def forward(self, rpn_feature, cem_feature): 38 | cem = cem_feature # feature map of CEM: [245, 20, 20] 39 | rpn = rpn_feature # feature map of RPN: [256, 20, 20] 40 | 41 | sam = self.conv(rpn) 42 | sam = self.bn(sam) 43 | sam = self.sigmoid(sam) 44 | out = cem * sam # output: [245, 20, 20] 45 | return out 46 | 47 | 48 | class RCNNSubNetHead(nn.Module): 49 | """ 50 | Standard heads for FPN-based models 51 | Arguments: 52 | in_channels (int): number of input channels 53 | representation_size (int): size of the intermediate representation 54 | """ 55 | def __init__(self, in_channels, representation_size): 56 | super(RCNNSubNetHead, self).__init__() 57 | self.fc6 = nn.Linear(in_channels, representation_size) # in_channles: 7*7*5=245 representation_size:1024 58 | 59 | def forward(self, x): 60 | x = x.flatten(start_dim=1) 61 | x = F.relu(self.fc6(x)) 62 | return x 63 | 64 | class ThunderNetPredictor(nn.Module): 65 | """ 66 | Standard classification + bounding box regression layers 67 | for Fast R-CNN. 68 | Arguments: 69 | in_channels (int): number of input channels 70 | num_classes (int): number of output classes (including background) 71 | """ 72 | def __init__(self, in_channels, num_classes): 73 | super(ThunderNetPredictor, self).__init__() 74 | self.cls_score = nn.Linear(in_channels, num_classes) 75 | self.bbox_pred = nn.Linear(in_channels, num_classes * 4) 76 | 77 | def forward(self, x): # x: [1024, 1, 1] 78 | if x.dim() == 4: 79 | assert list(x.shape[2:]) == [1, 1] 80 | x = x.flatten(start_dim=1) 81 | scores = self.cls_score(x) 82 | bbox_deltas = self.bbox_pred(x) 83 | 84 | return scores, bbox_deltas -------------------------------------------------------------------------------- /thundernet/snet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from collections import OrderedDict 6 | from torch.nn import init 7 | import math 8 | 9 | def conv_bn(inp, oup, stride): 10 | return nn.Sequential( 11 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 12 | nn.BatchNorm2d(oup), 13 | nn.ReLU(inplace=True) 14 | ) 15 | 16 | def conv_1x1_bn(inp, oup): 17 | return nn.Sequential( 18 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 19 | nn.BatchNorm2d(oup), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def channel_shuffle(x, groups): 24 | batchsize, num_channels, height, width = x.data.size() 25 | channels_per_group = num_channels // groups 26 | 27 | # reshape 28 | x = x.view(batchsize, groups, channels_per_group, height, width) 29 | x = torch.transpose(x, 1, 2).contiguous() 30 | 31 | # flatten 32 | x = x.view(batchsize, -1, height, width) 33 | return x 34 | 35 | class InvertedResidual(nn.Module): 36 | def __init__(self, inp, oup, stride, benchmodel): 37 | super(InvertedResidual, self).__init__() 38 | self.benchmodel = benchmodel 39 | self.stride = stride 40 | assert stride in [1, 2] 41 | oup_inc = oup // 2 42 | 43 | if self.benchmodel == 1: # basic unit 44 | self.branch2 = nn.Sequential( 45 | # pw 46 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(oup_inc), 48 | nn.ReLU(inplace=True), 49 | # dw 50 | nn.Conv2d(oup_inc, oup_inc, 5, stride, padding=2, groups=oup_inc, bias=False), # padding=2 51 | nn.BatchNorm2d(oup_inc), 52 | # pw-linear 53 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 54 | nn.BatchNorm2d(oup_inc), 55 | nn.ReLU(inplace=True), 56 | ) 57 | 
else: # down sample (2x) 58 | self.branch1 = nn.Sequential( 59 | # dw 60 | nn.Conv2d(inp, inp, 5, stride, padding=2, groups=inp, bias=False), 61 | nn.BatchNorm2d(inp), 62 | # pw-linear 63 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 64 | nn.BatchNorm2d(oup_inc), 65 | nn.ReLU(inplace=True), 66 | ) 67 | 68 | self.branch2 = nn.Sequential( 69 | # pw 70 | nn.Conv2d(inp, oup_inc, 1, 1, 0, bias=False), 71 | nn.BatchNorm2d(oup_inc), 72 | nn.ReLU(inplace=True), 73 | # dw 74 | nn.Conv2d(oup_inc, oup_inc, 5, stride, padding=2, groups=oup_inc, bias=False), 75 | nn.BatchNorm2d(oup_inc), 76 | # pw-linear 77 | nn.Conv2d(oup_inc, oup_inc, 1, 1, 0, bias=False), 78 | nn.BatchNorm2d(oup_inc), 79 | nn.ReLU(inplace=True), 80 | ) 81 | 82 | @staticmethod 83 | def _concat(x, out): 84 | # concatenate along channel axis 85 | return torch.cat((x, out), 1) 86 | 87 | def forward(self, x): 88 | if 1 == self.benchmodel: 89 | x1 = x[:, :(x.shape[1]//2), :, :] 90 | x2 = x[:, (x.shape[1]//2):, :, :] 91 | out = self._concat(x1, self.branch2(x2)) 92 | elif 2 == self.benchmodel: 93 | out = self._concat(self.branch1(x), self.branch2(x)) 94 | 95 | return channel_shuffle(out, 2) 96 | 97 | class SNet49(nn.Module): 98 | def __init__(self, n_class=1024, input_size=224): 99 | super(SNet49, self).__init__() 100 | assert input_size % 32 == 0 101 | self.stage_repeats = [4, 8, 4] 102 | 103 | # index 0 is invalid and should never be called. 104 | # only used for indexing convenience. 105 | self.stage_out_channels = [-1, 24, 60, 120, 240, 512] 106 | 107 | # building first layer 108 | input_channel = self.stage_out_channels[1] 109 | self.conv1 = conv_bn(3, input_channel, 2) 110 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 111 | 112 | self.features1 = [] 113 | self.features2 = [] 114 | self.features3 = [] 115 | 116 | # stage2 117 | numrepeat = self.stage_repeats[0] 118 | output_channel = self.stage_out_channels[2] 119 | for i in range(numrepeat): 120 | if i == 0: 121 | # (inp, oup, stride, benchmodel) 122 | self.features1.append(InvertedResidual(input_channel, output_channel, 2, 2)) 123 | else: 124 | self.features1.append(InvertedResidual(input_channel, output_channel, 1, 1)) 125 | input_channel = output_channel 126 | 127 | # stage3 128 | numrepeat = self.stage_repeats[1] 129 | output_channel = self.stage_out_channels[3] 130 | for i in range(numrepeat): 131 | if i == 0: 132 | # (inp, oup, stride, benchmodel) 133 | self.features2.append(InvertedResidual(input_channel, output_channel, 2, 2)) 134 | else: 135 | self.features2.append(InvertedResidual(input_channel, output_channel, 1, 1)) 136 | input_channel = output_channel 137 | 138 | # stage4 139 | numrepeat = self.stage_repeats[2] 140 | output_channel = self.stage_out_channels[4] 141 | for i in range(numrepeat): 142 | if i == 0: 143 | # (inp, oup, stride, benchmodel) 144 | self.features3.append(InvertedResidual(input_channel, output_channel, 2, 2)) 145 | else: 146 | self.features3.append(InvertedResidual(input_channel, output_channel, 1, 1)) 147 | input_channel = output_channel 148 | 149 | # make it nn.Sequential 150 | self.features1 = nn.Sequential(*self.features1) 151 | self.features2 = nn.Sequential(*self.features2) 152 | self.features3 = nn.Sequential(*self.features3) 153 | 154 | # building last several layers 155 | self.conv5 = conv_1x1_bn(input_channel, self.stage_out_channels[-1]) 156 | self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32))) 157 | 158 | # building classifier 159 | self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-1], 
160 | 
161 |     def forward(self, x):
162 |         x = self.conv1(x)
163 |         x = self.maxpool(x)
164 |         x = self.features1(x)    # stage2
165 |         x = self.features2(x)    # stage3
166 |         out_c4 = x
167 | 
168 |         x = self.features3(x)    # stage4
169 |         x = self.conv5(x)
170 |         out_c5 = x
171 | 
172 |         x = self.globalpool(x)
173 |         x = x.view(-1, self.stage_out_channels[-1])
174 |         x = self.classifier(x)
175 | 
176 |         return x, out_c4, out_c5
177 | 
178 | 
179 | class SNet146(nn.Module):
180 |     def __init__(self, n_class=1000, input_size=224):
181 |         super(SNet146, self).__init__()
182 |         assert input_size % 32 == 0
183 | 
184 |         self.stage_repeats = [4, 8, 4]
185 | 
186 |         # index 0 is invalid and should never be called.
187 |         # only used for indexing convenience.
188 |         self.stage_out_channels = [-1, 24, 132, 264, 528, 1024]
189 | 
190 |         # building first layer
191 |         input_channel = self.stage_out_channels[1]
192 |         self.conv1 = conv_bn(3, input_channel, 2)
193 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
194 | 
195 |         self.features = []
196 |         # building inverted residual blocks
197 |         for idxstage in range(len(self.stage_repeats)):
198 |             numrepeat = self.stage_repeats[idxstage]
199 |             output_channel = self.stage_out_channels[idxstage+2]
200 |             for i in range(numrepeat):
201 |                 if i == 0:
202 |                     # (inp, oup, stride, benchmodel)
203 |                     self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
204 |                 else:
205 |                     self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
206 |                 input_channel = output_channel
207 | 
208 |         # make it nn.Sequential
209 |         self.features = nn.Sequential(*self.features)
210 | 
211 |         # building last several layers
212 |         self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
213 | 
214 |         # building classifier (SNet146 has no Conv5, so the classifier takes the 528-channel stage4 output)
215 |         self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-2], n_class))
216 | 
217 |     def forward(self, x):
218 |         x = self.conv1(x)
219 |         x = self.maxpool(x)
220 |         x = self.features(x)    # stage2, stage3, stage4
221 |         x = self.globalpool(x)
222 |         x = x.view(-1, self.stage_out_channels[-2])
223 |         x = self.classifier(x)
224 |         return x
225 | 
226 | class SNet535(nn.Module):
227 |     def __init__(self, n_class=1000, input_size=224):
228 |         super(SNet535, self).__init__()
229 |         assert input_size % 32 == 0
230 | 
231 |         self.stage_repeats = [4, 8, 4]
232 | 
233 |         # index 0 is invalid and should never be called.
234 |         # only used for indexing convenience.
235 |         self.stage_out_channels = [-1, 48, 248, 496, 992, 1024]
236 | 
237 |         # building first layer
238 |         input_channel = self.stage_out_channels[1]
239 |         self.conv1 = conv_bn(3, input_channel, 2)
240 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
241 | 
242 |         self.features = []
243 |         # building inverted residual blocks
244 |         for idxstage in range(len(self.stage_repeats)):
245 |             numrepeat = self.stage_repeats[idxstage]
246 |             output_channel = self.stage_out_channels[idxstage+2]
247 |             for i in range(numrepeat):
248 |                 if i == 0:
249 |                     # (inp, oup, stride, benchmodel)
250 |                     self.features.append(InvertedResidual(input_channel, output_channel, 2, 2))
251 |                 else:
252 |                     self.features.append(InvertedResidual(input_channel, output_channel, 1, 1))
253 |                 input_channel = output_channel
254 | 
255 |         # make it nn.Sequential
256 |         self.features = nn.Sequential(*self.features)
257 | 
258 |         # building last several layers
259 |         self.globalpool = nn.Sequential(nn.AvgPool2d(int(input_size/32)))
260 | 
261 |         # building classifier (SNet535 has no Conv5, so the classifier takes the 992-channel stage4 output)
262 |         self.classifier = nn.Sequential(nn.Linear(self.stage_out_channels[-2], n_class))
263 | 
264 |     def forward(self, x):
265 |         x = self.conv1(x)
266 |         x = self.maxpool(x)
267 |         x = self.features(x)    # stage2, stage3, stage4
268 |         x = self.globalpool(x)
269 |         x = x.view(-1, self.stage_out_channels[-2])
270 |         x = self.classifier(x)
271 |         return x
272 | 
273 | 
274 | """
275 | if __name__ == '__main__':
276 |     img = torch.randn(1, 3, 224, 224)
277 |     snet = SNet49()
278 |     feature, out1, out2 = snet(img)
279 |     print(snet)
280 | """
281 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import sys
6 | import os
7 | import numpy as np
8 | import argparse
9 | import torch
10 | 
11 | import torch.optim as optim
12 | from torchvision import transforms
13 | from torch.utils.data import DataLoader
14 | from tensorboardX import SummaryWriter
15 | 
16 | from detector import ThunderNet
17 | from load_data import CocoDataset, Resizer, Normalizer, Augmenter, collater
18 | from tqdm.autonotebook import tqdm
19 | 
20 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 | use_cuda = torch.cuda.is_available()
22 | print('device available: {}'.format(device))
23 | 
24 | def parse_args():
25 |     """
26 |     Parse input arguments
27 |     """
28 |     parser = argparse.ArgumentParser(description='Simple training parameters for ThunderNet.')
29 | 
30 |     parser.add_argument('--data_path', type=str, default='data/COCO', help='path to the dataset folder')
31 |     parser.add_argument('--batch_size', help='Batch size', type=int, default=64)
32 |     parser.add_argument('--epochs', help='Number of epochs', type=int, default=100)
33 |     parser.add_argument('--start_epoch', help='start epoch', type=int, default=1)
34 |     parser.add_argument('--gpus', help='Use CUDA on the listed devices', nargs='+', type=int, default=[])
35 |     parser.add_argument('--seed', help='Random seed', type=int, default=1234)
36 |     parser.add_argument('--saved_path', help='save path', type=str, default='./checkpoint')
37 |     parser.add_argument('--lr', help='learning rate', type=float, default=1e-4)
38 | 
39 |     if len(sys.argv) == 1:
40 |         parser.print_help()
41 |         sys.exit(1)
42 | 
43 |     args = parser.parse_args()
44 |     return args
45 | 
46 | def main(args=None):
47 |     transform_train = transforms.Compose([
48 |         Normalizer(),
49 |         Augmenter(),
50 |         Resizer()
51 |     ])
52 | 
53 |     transform_test = transforms.Compose([
54 |         Normalizer(),
55 |         Resizer()
56 |     ])
57 | 
58 |     num_gpus = 1
59 |     train_params = {
60 |         "batch_size": args.batch_size,
61 |         "shuffle": True,
62 |         "drop_last": True,
63 |         "collate_fn": collater,
64 |         #"num_workers": 1, # bug ???
65 |     }
66 | 
67 |     test_params = {
68 |         "batch_size": args.batch_size,
69 |         "shuffle": False,
70 |         "drop_last": False,
71 |         "collate_fn": collater,
72 |         #"num_workers": 1,
73 |     }
74 | 
75 |     train_set = CocoDataset(root_dir=args.data_path, set_name='train2017', transform=transform_train)
76 |     val_set = CocoDataset(root_dir=args.data_path, set_name='val2017', transform=transform_test)
77 | 
78 |     train_loader = DataLoader(dataset=train_set, **train_params)
79 |     test_loader = DataLoader(dataset=val_set, **test_params)
80 | 
81 |     num_iter = len(train_loader)
82 | 
83 |     model = ThunderNet()
84 | 
85 |     save_path = args.saved_path
86 |     if not os.path.isdir(args.saved_path):
87 |         os.makedirs(args.saved_path)
88 | 
89 |     if use_cuda:
90 |         torch.cuda.set_device(args.gpus[0])
91 |         torch.cuda.manual_seed(args.seed)
92 |         model = model.cuda()
93 |         #model = torch.nn.DataParallel(model)
94 | 
95 |     # optimizer
96 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=0.0005)
97 | 
98 |     # update lr
99 |     milestones = [500, 800, 1200, 1500]
100 |     scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)
101 | 
102 |     writer = SummaryWriter(log_dir='./checkpoint/summary')
103 | 
104 |     for epoch in range(args.start_epoch, args.epochs + 1):
105 |         train_loss = train(train_loader, model, optimizer, args, num_iter, epoch, scheduler)
106 |         test(test_loader, model)
107 | 
108 |         writer.add_scalar('train_loss', train_loss, epoch)
109 |         scheduler.step()
110 | 
111 |         save_name = '{}/thundernet_{}.pth.tar'.format(save_path, epoch)
112 |         save_checkpoint({
113 |             'epoch': epoch,
114 |             'state_dict': model.state_dict(),
115 |             'optimizer': optimizer.state_dict(),
116 |         }, filename=save_name)
117 | 
118 |     writer.export_scalars_to_json('./checkpoint/summary/' + 'pretrain' + 'all_scalars.json')
119 |     writer.close()
120 | 
121 | 
122 | def train(train_loader, model, optimizer, args, num_iter, epoch, scheduler):
123 |     model.train()
124 |     epoch_loss = []
125 | 
126 |     losses = {}
127 |     progress_bar = tqdm(train_loader)
128 |     for i, data in enumerate(progress_bar):
129 | 
130 |         input_data = data['img'].cuda().float()
131 |         input_labels = data['annot'].cuda()
132 | 
133 |         detector_losses, proposal_losses = model(input_data, input_labels)
134 | 
135 |         losses.update(detector_losses)
136 |         losses.update(proposal_losses)
137 |         #print(detector_losses)
138 |         #print(proposal_losses)
139 | 
140 |         total_loss = sum(loss for loss in losses.values())
141 | 
142 |         optimizer.zero_grad()
143 |         total_loss.backward()
144 |         optimizer.step()
145 | 
146 |         epoch_loss.append(total_loss.item())
147 |         if (i+1) % 50 == 0:
148 |             learning_rate = scheduler.get_last_lr()[0]    # get learning rate
149 |             detector_loss = sum(loss for loss in detector_losses.values())
150 |             proposal_loss = sum(loss for loss in proposal_losses.values())
151 | 
152 |             print('Epoch: {}/{} | Iter: {}/{} | total loss: {:.3f} | det loss: {:.3f} | proposal loss: {:.3f}'.format(
153 |                 epoch, args.epochs, (i+1), num_iter, total_loss.item(),
154 |                 detector_loss.item(), proposal_loss.item()))
155 | 
156 |     train_loss = np.mean(epoch_loss)
157 |     return train_loss
158 | 
159 | def test(test_loader, model):
160 |     model.eval()
161 |     all_loss = []
162 |     losses = {}
163 |     progress_bar = tqdm(test_loader)
164 |     for i, data in enumerate(progress_bar):
165 |         with torch.no_grad():
166 |             input_data = data['img'].cuda().float()
167 |             input_labels = data['annot'].cuda()
168 | 
169 |             detector_losses, proposal_losses = model(input_data, input_labels)
170 |             losses.update(detector_losses)
171 |             losses.update(proposal_losses)
172 | 
173 |             #print(detector_losses)
174 |             #print(proposal_losses)
175 |             total_loss = sum(loss for loss in losses.values())
176 |             #print('total loss: ', total_loss)
177 | 
178 |             all_loss.append(total_loss.item())
179 |             #cls_loss = cls_loss.mean()
180 |             #reg_loss = reg_loss.mean()
181 | 
182 |     mean_loss = np.mean(all_loss)
183 |     print('test loss: {:1.5f}'.format(mean_loss))
184 | 
185 | def save_checkpoint(state, filename):
186 |     print('save model: {}\n'.format(filename))
187 |     torch.save(state, filename)
188 | 
189 | if __name__ == '__main__':
190 |     args = parse_args()
191 |     main(args)
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #python train.py --data_path 'data/COCO' --batch_size 128 --epochs 100 --gpus 0
2 | python train.py --data_path 'data/COCO' \
3 |                 --batch_size 128 \
4 |                 --epochs 100 \
5 |                 --gpus 0
6 | 
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/liguiyuan/Thundernet-pytorch/0d62cd55430d5ce55560c1efc43d552b2d0b6671/utils/losses.py
--------------------------------------------------------------------------------
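
As a quick sanity check of the shapes produced by the modules above, the sketch below can be run from the repository root with PyTorch installed. It is a minimal illustration, not part of the training pipeline: the input resolution, the batch of 8 pooled RoIs, and `num_classes=81` (80 COCO classes plus background) are assumptions chosen only for the example.

```python
import torch

from thundernet.snet import SNet49
from thundernet.module import RCNNSubNetHead, ThunderNetPredictor

# SNet49 backbone: classification logits plus the C4/C5 feature maps
# that the detector builds its context enhancement on.
backbone = SNet49(n_class=1000, input_size=224)
img = torch.randn(1, 3, 224, 224)
logits, c4, c5 = backbone(img)
print(logits.shape)  # torch.Size([1, 1000])
print(c4.shape)      # torch.Size([1, 120, 14, 14])  stage3 output
print(c5.shape)      # torch.Size([1, 512, 7, 7])    conv5 output

# R-CNN subnet: 7x7 PS RoI aligned features with 5 channels per RoI
# (5 * 7 * 7 = 245) -> fc6 -> classification scores and box deltas.
rois = torch.randn(8, 5, 7, 7)                         # 8 pooled RoIs (illustrative)
head = RCNNSubNetHead(5 * 7 * 7, 1024)
predictor = ThunderNetPredictor(1024, num_classes=81)  # 80 COCO classes + background
feats = head(rois)                                     # [8, 1024]
scores, deltas = predictor(feats)                      # [8, 81], [8, 324]
print(scores.shape, deltas.shape)
```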