├── .gitignore ├── README.md ├── STMask.py ├── backbone.py ├── datasets ├── __init__.py ├── bbox_overlaps.py ├── concat_dataset.py ├── config.py ├── custom.py ├── extra_aug.py ├── loader │ ├── __init__.py │ ├── build_loader.py │ └── sampler.py ├── repeat_dataset.py ├── transforms.py ├── utils.py ├── version.py └── ytvos.py ├── environment.yml ├── eval.py ├── images └── overall1.png ├── layers ├── __init__.py ├── box_utils.py ├── display_gt_annotations.py ├── eval_utils.py ├── functions │ ├── TF_utils.py │ ├── __init__.py │ ├── detection.py │ ├── detection_TF.py │ ├── track.py │ └── track_TF.py ├── interpolate.py ├── mask_utils.py ├── modules │ ├── FPN.py │ ├── FastMaskIoUNet.py │ ├── Featurealign.py │ ├── __init__.py │ ├── make_net.py │ ├── multibox_loss.py │ ├── prediction_head.py │ ├── prediction_head_FC.py │ └── track_to_segment_head.py ├── output_utils.py ├── track_utils.py ├── train_output_utils.py └── visualization.py ├── scripts ├── augment_bbox.py ├── bbox_recall.py ├── cluster_bbox_sizes.py ├── compute_masks.py ├── convert_darknet.py ├── eval.sh ├── make_grid.py ├── optimize_bboxes.py ├── parse_eval.py ├── plot_loss.py ├── resume.sh ├── save_bboxes.py ├── train.sh └── unpack_statedict.py ├── train.py ├── utils ├── __init__.py ├── augmentations.py ├── cython_nms.pyx ├── functions.py ├── logger.py ├── nvinfo.py └── timer.py └── web ├── css ├── index.css ├── list.css ├── toggle.css └── viewer.css ├── dets ├── yolact_base.json ├── yolact_darknet53.json ├── yolact_im700.json └── yolact_resnet50.json ├── index.html ├── iou.html ├── scripts ├── index.js ├── iou.js ├── jquery.js ├── js.cookie.js ├── utils.js └── viewer.js ├── server.py └── viewer.html /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # dirs 31 | weights/ 32 | cocoapi/ 33 | external/ 34 | results/ 35 | .idea/ 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # Pyre type checker 136 | .pyre/ 137 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .utils import to_tensor, random_scale, show_ann, get_dataset 3 | 4 | from .custom import CustomDataset 5 | from .ytvos import YTVOSDataset 6 | from .loader import GroupSampler, DistributedGroupSampler, build_dataloader 7 | from .utils import to_tensor, random_scale, show_ann, get_dataset, prepare_data 8 | from .concat_dataset import ConcatDataset 9 | from .repeat_dataset import RepeatDataset 10 | from .extra_aug import ExtraAugmentation 11 | 12 | __all__ = [ 13 | 'cfg', 'MEANS', 'STD', 'set_cfg', 'set_dataset', 'detection_collate', 14 | 'CustomDataset', 'YTVOSDataset', 15 | 'GroupSampler', 'DistributedGroupSampler', 'build_dataloader', 16 | 'to_tensor', 'random_scale', 'show_ann', 'get_dataset', 'prepare_data', 17 | 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation' 18 | ] 19 | 20 | import torch 21 | def detection_collate(batch): 22 | """Custom collate fn for dealing with batches of images that have a different 23 | number of associated object annotations (bounding boxes). 24 | 25 | Arguments: 26 | batch: (tuple) A tuple of tensor images and (lists of annotations, masks) 27 | 28 | Return: 29 | A tuple containing: 30 | 1) (tensor) batch of images stacked on their 0 dim 31 | 2) (list, list, list) annotations for a given image are stacked 32 | on 0 dim. The output gt is a tuple of annotations and masks. 33 | """ 34 | batch_out = {} 35 | # batch_out['img'] = torch.cat([batch[i]['img'].data for i in range(batch_size)]) 36 | # if 'ref_imgs' in batch[0].keys(): 37 | # batch_out['ref_imgs'] = torch.cat([batch[i]['ref_imgs'].data for i in range(batch_size)]) 38 | 39 | for k in batch[0].keys(): 40 | batch_out[k] = [] 41 | 42 | for i in range(len(batch)): 43 | for k in batch_out.keys(): 44 | if isinstance(batch[i][k], list): 45 | batch_out[k].append(batch[i][k]) 46 | else: 47 | batch_out[k].append(batch[i][k].data) 48 | 49 | return batch_out 50 | 51 | -------------------------------------------------------------------------------- /datasets/bbox_overlaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou'): 5 | """Calculate the ious between each bbox of bboxes1 and bboxes2. 
6 | 7 | Args: 8 | bboxes1(ndarray): shape (n, 4) 9 | bboxes2(ndarray): shape (k, 4) 10 | mode(str): iou (intersection over union) or iof (intersection 11 | over foreground) 12 | 13 | Returns: 14 | ious(ndarray): shape (n, k) 15 | """ 16 | 17 | assert mode in ['iou', 'iof'] 18 | 19 | bboxes1 = bboxes1.astype(np.float32) 20 | bboxes2 = bboxes2.astype(np.float32) 21 | rows = bboxes1.shape[0] 22 | cols = bboxes2.shape[0] 23 | ious = np.zeros((rows, cols), dtype=np.float32) 24 | if rows * cols == 0: 25 | return ious 26 | exchange = False 27 | if bboxes1.shape[0] > bboxes2.shape[0]: 28 | bboxes1, bboxes2 = bboxes2, bboxes1 29 | ious = np.zeros((cols, rows), dtype=np.float32) 30 | exchange = True 31 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 32 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 33 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 34 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 35 | for i in range(bboxes1.shape[0]): 36 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 37 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 38 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 39 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 40 | overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( 41 | y_end - y_start + 1, 0) 42 | if mode == 'iou': 43 | union = area1[i] + area2 - overlap 44 | else: 45 | union = area1[i] if not exchange else area2 46 | ious[i, :] = overlap / union 47 | if exchange: 48 | ious = ious.T 49 | return ious 50 | -------------------------------------------------------------------------------- /datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 3 | 4 | 5 | class ConcatDataset(_ConcatDataset): 6 | """A wrapper of concatenated dataset. 7 | 8 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 9 | concat the group flag for image aspect ratio. 10 | 11 | Args: 12 | datasets (list[:obj:`Dataset`]): A list of datasets. 
13 | """ 14 | 15 | def __init__(self, datasets): 16 | super(ConcatDataset, self).__init__(datasets) 17 | self.CLASSES = datasets[0].CLASSES 18 | if hasattr(datasets[0], 'flag'): 19 | flags = [] 20 | for i in range(0, len(datasets)): 21 | flags.append(datasets[i].flag) 22 | self.flag = np.concatenate(flags) 23 | -------------------------------------------------------------------------------- /datasets/extra_aug.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | from numpy import random 4 | 5 | from datasets.bbox_overlaps import bbox_overlaps 6 | 7 | 8 | class PhotoMetricDistortion(object): 9 | 10 | def __init__(self, 11 | brightness_delta=32, 12 | contrast_range=(0.5, 1.5), 13 | saturation_range=(0.5, 1.5), 14 | hue_delta=18): 15 | self.brightness_delta = brightness_delta 16 | self.contrast_lower, self.contrast_upper = contrast_range 17 | self.saturation_lower, self.saturation_upper = saturation_range 18 | self.hue_delta = hue_delta 19 | 20 | def __call__(self, img, boxes, labels, masks, ids): 21 | # random brightness 22 | if random.randint(2): 23 | delta = random.uniform(-self.brightness_delta, 24 | self.brightness_delta) 25 | img += delta 26 | 27 | # mode == 0 --> do random contrast first 28 | # mode == 1 --> do random contrast last 29 | mode = random.randint(2) 30 | if mode == 1: 31 | if random.randint(2): 32 | alpha = random.uniform(self.contrast_lower, 33 | self.contrast_upper) 34 | img *= alpha 35 | 36 | # convert color from BGR to HSV 37 | img = mmcv.bgr2hsv(img) 38 | 39 | # random saturation 40 | if random.randint(2): 41 | img[..., 1] *= random.uniform(self.saturation_lower, 42 | self.saturation_upper) 43 | 44 | # random hue 45 | if random.randint(2): 46 | img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) 47 | img[..., 0][img[..., 0] > 360] -= 360 48 | img[..., 0][img[..., 0] < 0] += 360 49 | 50 | # convert color from HSV to BGR 51 | img = mmcv.hsv2bgr(img) 52 | 53 | # random contrast 54 | if mode == 0: 55 | if random.randint(2): 56 | alpha = random.uniform(self.contrast_lower, 57 | self.contrast_upper) 58 | img *= alpha 59 | 60 | # randomly swap channels 61 | if random.randint(2): 62 | img = img[..., random.permutation(3)] 63 | 64 | return img, boxes, labels, masks, ids 65 | 66 | 67 | class Expand(object): 68 | 69 | def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): 70 | if to_rgb: 71 | self.mean = mean[::-1] 72 | else: 73 | self.mean = mean 74 | self.min_ratio, self.max_ratio = ratio_range 75 | 76 | def __call__(self, img, boxes, labels, masks, ids): 77 | if random.randint(2): 78 | return img, boxes, labels, masks, ids 79 | 80 | h, w, c = img.shape 81 | ratio = random.uniform(self.min_ratio, self.max_ratio) 82 | expand_img = np.full((int(h * ratio), int(w * ratio), c), 83 | self.mean).astype(img.dtype) 84 | left = int(random.uniform(0, w * ratio - w)) 85 | top = int(random.uniform(0, h * ratio - h)) 86 | expand_img[top:top + h, left:left + w] = img 87 | img = mmcv.imresize(expand_img, (w, h), interpolation='nearest') 88 | boxes = np.rint((boxes + np.tile((left, top), 2)) / ratio).astype(boxes.dtype) 89 | expand_masks = [] 90 | for i in range(len(masks)): 91 | expand_mask_cur = np.full((int(h * ratio), int(w * ratio)), 0).astype(masks[i].dtype) 92 | expand_mask_cur[top:top + h, left:left + w] = masks[i] 93 | expand_masks.append(mmcv.imresize(expand_mask_cur, (w, h), interpolation='nearest')) 94 | masks = expand_masks 95 | return img, boxes, labels, masks, ids 96 | 97 | 98 | 
class RandomCrop(object): 99 | 100 | def __init__(self, 101 | min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), 102 | min_crop_size=0.3): 103 | # 1: return ori img 104 | self.sample_mode = (1, *min_ious, 0) 105 | self.min_crop_size = min_crop_size 106 | 107 | def __call__(self, img, boxes, labels, masks, ids): 108 | h, w, c = img.shape 109 | while True: 110 | mode = random.choice(self.sample_mode) 111 | if mode == 1: 112 | return img, boxes, labels, masks, ids 113 | 114 | min_iou = mode 115 | for i in range(50): 116 | new_w = random.uniform(self.min_crop_size * w, w) 117 | new_h = random.uniform(self.min_crop_size * h, h) 118 | 119 | # h / w in [0.5, 2] 120 | if new_h / new_w < 0.5 or new_h / new_w > 2: 121 | continue 122 | 123 | left = random.uniform(w - new_w) 124 | top = random.uniform(h - new_h) 125 | 126 | patch = np.array((int(left), int(top), int(left + new_w), 127 | int(top + new_h))) 128 | overlaps = bbox_overlaps( 129 | patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) 130 | if overlaps.min() < min_iou: 131 | continue 132 | 133 | # center of boxes should inside the crop img 134 | center = (boxes[:, :2] + boxes[:, 2:]) / 2 135 | mask = (center[:, 0] > patch[0]) * ( 136 | center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * ( 137 | center[:, 1] < patch[3]) 138 | if not mask.any(): 139 | continue 140 | boxes = boxes[mask] 141 | labels = labels[mask] 142 | masks = np.array(masks)[mask] 143 | ids = np.array(ids)[mask].tolist() 144 | 145 | # adjust boxes 146 | img_crop = np.zeros(img.shape) 147 | img_crop[patch[1]:patch[3], patch[0]:patch[2]] = img[patch[1]:patch[3], patch[0]:patch[2]] 148 | boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) 149 | boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) 150 | boxes -= np.tile(patch[:2], 2) 151 | 152 | # adjust masks 153 | masks_crop = np.zeros(masks.shape) 154 | masks_crop[:, patch[1]:patch[3], patch[0]:patch[2]] = masks[:, patch[1]:patch[3], patch[0]:patch[2]] 155 | masks_crop_list = [] 156 | for m in range(len(masks)): 157 | masks_crop_list.append(masks_crop[m]) 158 | return img_crop, boxes, labels, masks_crop_list, ids 159 | 160 | 161 | class ExtraAugmentation(object): 162 | 163 | def __init__(self, 164 | photo_metric_distortion=None, 165 | expand=None, 166 | random_crop=None): 167 | self.transforms = [] 168 | if photo_metric_distortion is not None: 169 | self.transforms.append( 170 | PhotoMetricDistortion(**photo_metric_distortion)) 171 | if expand is not None: 172 | self.transforms.append(Expand(**expand)) 173 | if random_crop is not None: 174 | self.transforms.append(RandomCrop(**random_crop)) 175 | 176 | def __call__(self, img, boxes, labels, masks, ids): 177 | img = img.astype(np.float32) 178 | for transform in self.transforms: 179 | img, boxes, labels, masks, ids = transform(img, boxes, labels, masks, ids) 180 | return img, boxes, labels, masks, ids 181 | -------------------------------------------------------------------------------- /datasets/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_loader import build_dataloader 2 | from .sampler import GroupSampler, DistributedGroupSampler 3 | 4 | __all__ = [ 5 | 'GroupSampler', 'DistributedGroupSampler', 'build_dataloader' 6 | ] 7 | -------------------------------------------------------------------------------- /datasets/loader/build_loader.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from mmcv.runner import get_dist_info 4 | from mmcv.parallel 
import collate 5 | from torch.utils.data import DataLoader 6 | import numpy as np 7 | from .sampler import GroupSampler, DistributedGroupSampler 8 | 9 | # https://github.com/pytorch/pytorch/issues/973 10 | import resource 11 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 12 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 13 | 14 | 15 | def build_dataloader(dataset, 16 | batch_size, 17 | num_workers, 18 | dist=True, 19 | shuffle_sampler=True, 20 | **kwargs): 21 | if dist: 22 | rank, world_size = get_dist_info() 23 | sampler = DistributedGroupSampler(dataset, batch_size, world_size, 24 | rank) 25 | 26 | else: 27 | if not kwargs.get('shuffle', True): 28 | sampler = None 29 | else: 30 | sampler = GroupSampler(dataset, batch_size, shuffle=shuffle_sampler) 31 | 32 | data_loader = DataLoader( 33 | dataset, 34 | batch_size=batch_size, 35 | sampler=sampler, 36 | num_workers=num_workers, 37 | collate_fn=collate, 38 | pin_memory=False, 39 | **kwargs) 40 | 41 | return data_loader 42 | -------------------------------------------------------------------------------- /datasets/loader/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import torch 5 | import numpy as np 6 | 7 | from torch.distributed import get_world_size, get_rank 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class GroupSampler(Sampler): 12 | 13 | def __init__(self, dataset, samples_per_gpu=1, shuffle=True): 14 | # assert hasattr(dataset, 'flag') 15 | self.dataset = dataset 16 | self.samples_per_gpu = samples_per_gpu 17 | self.shuffle = shuffle 18 | if hasattr(dataset, 'flag'): 19 | self.flag = dataset.flag.astype(np.int64) 20 | else: 21 | self.flag = np.ones(len(dataset)).astype(np.int64) 22 | self.group_sizes = np.bincount(self.flag) 23 | self.num_samples = 0 24 | for i, size in enumerate(self.group_sizes): 25 | self.num_samples += int(np.ceil( 26 | size / self.samples_per_gpu)) * self.samples_per_gpu 27 | 28 | def __iter__(self): 29 | indices = [] 30 | for i, size in enumerate(self.group_sizes): 31 | if size == 0: 32 | continue 33 | indice = np.where(self.flag == i)[0] 34 | assert len(indice) == size 35 | num_extra = int(np.ceil(size / self.samples_per_gpu) 36 | ) * self.samples_per_gpu - len(indice) 37 | if self.shuffle: 38 | np.random.shuffle(indice) 39 | indice = np.concatenate([indice, indice[:num_extra]]) 40 | indices.append(indice) 41 | indices = np.concatenate(indices) 42 | if self.shuffle: 43 | indices = [ 44 | indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu] 45 | for i in np.random.permutation( 46 | range(len(indices) // self.samples_per_gpu)) 47 | ] 48 | indices = np.concatenate(indices) 49 | indices = torch.from_numpy(indices).long() 50 | assert len(indices) == self.num_samples 51 | return iter(indices) 52 | 53 | def __len__(self): 54 | return self.num_samples 55 | 56 | 57 | class DistributedGroupSampler(Sampler): 58 | """Sampler that restricts data loading to a subset of the dataset. 59 | It is especially useful in conjunction with 60 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 61 | process can pass a DistributedSampler instance as a DataLoader sampler, 62 | and load a subset of the original dataset that is exclusive to it. 63 | .. note:: 64 | Dataset is assumed to be of constant size. 65 | Arguments: 66 | dataset: Dataset used for sampling. 67 | num_replicas (optional): Number of processes participating in 68 | distributed training. 
69 | rank (optional): Rank of the current process within num_replicas. 70 | """ 71 | 72 | def __init__(self, 73 | dataset, 74 | samples_per_gpu=1, 75 | num_replicas=None, 76 | rank=None): 77 | if num_replicas is None: 78 | num_replicas = get_world_size() 79 | if rank is None: 80 | rank = get_rank() 81 | self.dataset = dataset 82 | self.samples_per_gpu = samples_per_gpu 83 | self.num_replicas = num_replicas 84 | self.rank = rank 85 | self.epoch = 0 86 | 87 | assert hasattr(self.dataset, 'flag') 88 | self.flag = self.dataset.flag 89 | self.group_sizes = np.bincount(self.flag) 90 | 91 | self.num_samples = 0 92 | for i, j in enumerate(self.group_sizes): 93 | self.num_samples += int( 94 | math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu / 95 | self.num_replicas)) * self.samples_per_gpu 96 | self.total_size = self.num_samples * self.num_replicas 97 | 98 | def __iter__(self): 99 | # deterministically shuffle based on epoch 100 | g = torch.Generator() 101 | g.manual_seed(self.epoch) 102 | 103 | indices = [] 104 | for i, size in enumerate(self.group_sizes): 105 | if size > 0: 106 | indice = np.where(self.flag == i)[0] 107 | assert len(indice) == size 108 | a = torch.randperm(int(size), generator=g).tolist() 109 | indice = indice[torch.randperm(int(size), generator=g).tolist()].tolist() 110 | extra = int( 111 | math.ceil( 112 | size * 1.0 / self.samples_per_gpu / self.num_replicas) 113 | ) * self.samples_per_gpu * self.num_replicas - len(indice) 114 | indice += indice[:extra] 115 | indices += indice 116 | 117 | assert len(indices) == self.total_size 118 | 119 | indices = [ 120 | indices[j] for i in list( 121 | torch.randperm( 122 | len(indices) // self.samples_per_gpu, generator=g)) 123 | for j in range(i * self.samples_per_gpu, (i + 1) * 124 | self.samples_per_gpu) 125 | ] 126 | 127 | # subsample 128 | offset = self.num_samples * self.rank 129 | indices = indices[offset:offset + self.num_samples] 130 | assert len(indices) == self.num_samples 131 | 132 | return iter(indices) 133 | 134 | def __len__(self): 135 | return self.num_samples 136 | 137 | def set_epoch(self, epoch): 138 | self.epoch = epoch 139 | -------------------------------------------------------------------------------- /datasets/repeat_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RepeatDataset(object): 5 | 6 | def __init__(self, dataset, times): 7 | self.dataset = dataset 8 | self.times = times 9 | self.CLASSES = dataset.CLASSES 10 | if hasattr(self.dataset, 'flag'): 11 | self.flag = np.tile(self.dataset.flag, times) 12 | 13 | self._ori_len = len(self.dataset) 14 | 15 | def __getitem__(self, idx): 16 | return self.dataset[idx % self._ori_len] 17 | 18 | def __len__(self): 19 | return self.times * self._ori_len 20 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import torch 4 | 5 | __all__ = ['ImageTransform', 'BboxTransform', 'MaskTransform', 'Numpy2Tensor'] 6 | 7 | 8 | class ImageTransform(object): 9 | """Preprocess an image. 10 | 11 | 1. rescale the image to expected size 12 | 2. normalize the image 13 | 3. flip the image (if needed) 14 | 4. pad the image (if needed) 15 | 5. 
transpose to (c, h, w) 16 | """ 17 | 18 | def __init__(self, 19 | mean=(0, 0, 0), 20 | std=(1, 1, 1), 21 | to_rgb=True, 22 | size_divisor=None): 23 | self.mean = np.array(mean, dtype=np.float32) 24 | self.std = np.array(std, dtype=np.float32) 25 | self.to_rgb = to_rgb 26 | self.size_divisor = size_divisor 27 | 28 | def __call__(self, img, scale, flip=False, keep_ratio=True): 29 | if keep_ratio: 30 | img, scale_factor = mmcv.imrescale(img, scale, return_scale=True) 31 | else: 32 | img, w_scale, h_scale = mmcv.imresize( 33 | img, scale, return_scale=True) 34 | scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], 35 | dtype=np.float32) 36 | img_shape = img.shape 37 | img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb) 38 | if flip: 39 | img = mmcv.imflip(img) 40 | if self.size_divisor is not None: 41 | img = mmcv.impad_to_multiple(img, self.size_divisor) 42 | pad_shape = img.shape 43 | else: 44 | pad_shape = img_shape 45 | img = img.transpose(2, 0, 1) 46 | return img, img_shape, pad_shape, scale_factor 47 | 48 | 49 | def bbox_flip(bboxes, img_shape): 50 | """Flip bboxes horizontally. 51 | 52 | Args: 53 | bboxes(ndarray): shape (..., 4*k) 54 | img_shape(tuple): (height, width) 55 | """ 56 | assert bboxes.shape[-1] % 4 == 0 57 | w = img_shape[1] 58 | flipped = bboxes.copy() 59 | flipped[..., 0::4] = w - bboxes[..., 2::4] - 1 60 | flipped[..., 2::4] = w - bboxes[..., 0::4] - 1 61 | return flipped 62 | 63 | 64 | class BboxTransform(object): 65 | """Preprocess gt bboxes. 66 | 67 | 1. rescale bboxes according to image size 68 | 2. flip bboxes (if needed) 69 | 3. pad the first dimension to `max_num_gts` 70 | """ 71 | 72 | def __init__(self, max_num_gts=None): 73 | self.max_num_gts = max_num_gts 74 | 75 | def __call__(self, bboxes, img_shape, pad_shape, scale_factor, flip=False): 76 | gt_bboxes = bboxes * scale_factor 77 | if flip: 78 | gt_bboxes = bbox_flip(gt_bboxes, img_shape) 79 | # normalization [0, 1] [x1,y1,x2,y2] 80 | gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1])/pad_shape[1] 81 | gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0])/pad_shape[0] 82 | 83 | if self.max_num_gts is None: 84 | return gt_bboxes 85 | else: 86 | num_gts = gt_bboxes.shape[0] 87 | padded_bboxes = np.zeros((self.max_num_gts, 4), dtype=np.float32) 88 | padded_bboxes[:num_gts, :] = gt_bboxes 89 | return padded_bboxes 90 | 91 | 92 | class MaskTransform(object): 93 | """Preprocess masks. 94 | 95 | 1. resize masks to expected size and stack to a single array 96 | 2. flip the masks (if needed) 97 | 3. 
pad the masks (if needed) 98 | """ 99 | 100 | def __call__(self, masks, pad_shape, scale, flip=False, keep_ratio=True): 101 | if keep_ratio: 102 | masks = [ 103 | mmcv.imrescale(mask, scale, interpolation='nearest') 104 | for mask in masks 105 | ] 106 | else: 107 | masks = [ 108 | mmcv.imresize(mask, scale, interpolation='nearest') 109 | for mask in masks 110 | ] 111 | 112 | if flip: 113 | masks = [mask[:, ::-1] for mask in masks] 114 | padded_masks = [ 115 | mmcv.impad(mask, shape=pad_shape[:2], pad_val=0) for mask in masks 116 | ] 117 | padded_masks = np.stack(padded_masks, axis=0) 118 | return padded_masks 119 | 120 | 121 | class Numpy2Tensor(object): 122 | 123 | def __init__(self): 124 | pass 125 | 126 | def __call__(self, *args): 127 | if len(args) == 1: 128 | return torch.from_numpy(args[0]) 129 | else: 130 | return tuple([torch.from_numpy(np.array(array)) for array in args]) 131 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import Sequence 3 | import os 4 | 5 | import mmcv 6 | from mmcv.runner import obj_from_dict 7 | import torch 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from .concat_dataset import ConcatDataset 12 | from .repeat_dataset import RepeatDataset 13 | import datasets 14 | from torch.autograd import Variable 15 | import torch.nn.functional as F 16 | import random 17 | 18 | 19 | def to_tensor(data): 20 | """Convert objects of various python types to :obj:`torch.Tensor`. 21 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 22 | :class:`Sequence`, :class:`int` and :class:`float`. 23 | """ 24 | if isinstance(data, torch.Tensor): 25 | return data 26 | elif isinstance(data, np.ndarray): 27 | return torch.from_numpy(data) 28 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 29 | return torch.tensor(data) 30 | elif isinstance(data, int): 31 | return torch.LongTensor([data]) 32 | elif isinstance(data, float): 33 | return torch.FloatTensor([data]) 34 | else: 35 | raise TypeError('type {} cannot be converted to tensor.'.format( 36 | type(data))) 37 | 38 | 39 | def random_scale(img_scales, mode='range'): 40 | """Randomly select a scale from a list of scales or scale ranges. 41 | Args: 42 | img_scales (list[tuple]): Image scale or scale range. 43 | mode (str): "range" or "value". 44 | Returns: 45 | tuple: Sampled image scale. 
46 | """ 47 | num_scales = len(img_scales) 48 | if num_scales == 1: # fixed scale is specified 49 | img_scale = img_scales[0] 50 | elif num_scales == 2: # randomly sample a scale 51 | if mode == 'range': 52 | img_scale_long = [max(s) for s in img_scales] 53 | img_scale_short = [min(s) for s in img_scales] 54 | long_edge = np.random.randint( 55 | min(img_scale_long), 56 | max(img_scale_long) + 1) 57 | short_edge = np.random.randint( 58 | min(img_scale_short), 59 | max(img_scale_short) + 1) 60 | img_scale = (long_edge, short_edge) 61 | elif mode == 'range_keep_ratio': 62 | img_scale_long = [max(s) for s in img_scales] 63 | img_scale_short = [min(s) for s in img_scales] 64 | scale = np.random.rand(1) * (max(img_scale_long) / min(img_scale_long)-1) + 1 65 | img_scale = (int(min(img_scale_long) * scale), int(min(img_scale_short) * scale)) 66 | elif mode == 'value': 67 | img_scale = img_scales[np.random.randint(num_scales)] 68 | else: 69 | if mode != 'value': 70 | raise ValueError( 71 | 'Only "value" mode supports more than 2 image scales') 72 | img_scale = img_scales[np.random.randint(num_scales)] 73 | return img_scale 74 | 75 | 76 | def show_ann(coco, img, ann_info): 77 | plt.imshow(mmcv.bgr2rgb(img)) 78 | plt.axis('off') 79 | coco.showAnns(ann_info) 80 | plt.show() 81 | 82 | 83 | def get_dataset(data_cfg): 84 | data_cfg = vars(data_cfg) 85 | if data_cfg['type'] == 'RepeatDataset': 86 | return RepeatDataset( 87 | get_dataset(data_cfg['dataset']), data_cfg['times']) 88 | 89 | if isinstance(data_cfg['ann_file'], (list, tuple)): 90 | ann_files = data_cfg['ann_file'] 91 | num_dset = len(ann_files) 92 | else: 93 | ann_files = [data_cfg['ann_file']] 94 | num_dset = 1 95 | 96 | if 'proposal_file' in data_cfg.keys(): 97 | if isinstance(data_cfg['proposal_file'], (list, tuple)): 98 | proposal_files = data_cfg['proposal_file'] 99 | else: 100 | proposal_files = [data_cfg['proposal_file']] 101 | else: 102 | proposal_files = [None] * num_dset 103 | assert len(proposal_files) == num_dset 104 | 105 | if isinstance(data_cfg['img_prefix'], (list, tuple)): 106 | img_prefixes = data_cfg['img_prefix'] 107 | else: 108 | img_prefixes = [data_cfg['img_prefix']] * num_dset 109 | assert len(img_prefixes) == num_dset 110 | 111 | dsets = [] 112 | for i in range(num_dset): 113 | data_info = copy.deepcopy(data_cfg) 114 | data_info['ann_file'] = ann_files[i] 115 | data_info['proposal_file'] = proposal_files[i] 116 | data_info['img_prefix'] = img_prefixes[i] 117 | dset = obj_from_dict(data_info, datasets) 118 | dsets.append(dset) 119 | if len(dsets) > 1: 120 | dset = ConcatDataset(dsets) 121 | else: 122 | dset = dsets[0] 123 | return dset 124 | 125 | 126 | def prepare_data(data_batch, devices: list = None, allocation: list = None, batch_size=None, is_cuda=False, 127 | train_mode=True): 128 | if train_mode: 129 | with torch.no_grad(): 130 | if batch_size is None: 131 | batch_size = 1 132 | if devices is None: 133 | devices = ['cuda:0'] if is_cuda else ['cpu'] 134 | if allocation is None: 135 | allocation = [batch_size // len(devices)] * (len(devices) - 1) 136 | allocation.append(batch_size - sum(allocation)) # The rest might need more/less 137 | 138 | images_list = data_batch['img'] 139 | bboxes_list = data_batch['bboxes'] 140 | labels_list = data_batch['labels'] 141 | masks_list = data_batch['masks'] 142 | ids_list = data_batch['ids'] 143 | images_meta_list = data_batch['img_meta'] 144 | n_clip = images_list[0].size(0) 145 | 146 | split_images, split_bboxes, split_labels, split_masks, split_ids, split_images_meta = \ 147 | 
[[None for alloc in allocation] for _ in range(6)] 148 | for idx, device, alloc in zip(range(len(devices)), devices, allocation): 149 | split_images[idx] = gradinator(torch.stack(images_list[alloc * idx:alloc * (idx + 1)], dim=0).to(device)) 150 | for cur_idx in range(alloc): 151 | bboxes_list[alloc * idx + cur_idx] = [gradinator( 152 | bboxes_list[alloc * idx + cur_idx][i].to(device)) for i in range(n_clip)] 153 | labels_list[alloc * idx + cur_idx] = [gradinator( 154 | labels_list[alloc * idx + cur_idx][i].to(device)) for i in range(n_clip)] 155 | masks_list[alloc * idx + cur_idx] = [gradinator(masks_list[alloc * idx + cur_idx][i].to(device)) 156 | for i in range(n_clip)] 157 | ids_list[alloc * idx + cur_idx] = [gradinator(ids_list[alloc * idx + cur_idx][i].to(device)) 158 | for i in range(n_clip)] 159 | 160 | split_bboxes[idx] = bboxes_list[alloc * idx:alloc * (idx + 1)] 161 | split_labels[idx] = labels_list[alloc * idx:alloc * (idx + 1)] 162 | split_masks[idx] = masks_list[alloc * idx:alloc * (idx + 1)] 163 | split_ids[idx] = ids_list[alloc * idx:alloc * (idx + 1)] 164 | split_images_meta[idx] = images_meta_list[alloc * idx:alloc * (idx + 1)] 165 | 166 | return split_images, split_bboxes, split_labels, split_masks, split_ids, split_images_meta 167 | else: 168 | # [0] is downsample image [1, 3, 384, 640], [1] is original image [1, 3, 736, 1280] 169 | images = torch.stack([img[0].data for img in data_batch['img']], dim=0) 170 | images_meta = [img_meta[0].data for img_meta in data_batch['img_meta']] 171 | if 'ref_imgs' in data_batch.keys(): 172 | ref_images = torch.stack([ref_img[0].data for ref_img in data_batch['ref_imgs']], dim=0) 173 | ref_images_meta = [ref_img_meta[0].data for ref_img_meta in data_batch['ref_img_metas']] 174 | else: 175 | ref_images = None 176 | ref_images_meta = None 177 | 178 | if is_cuda: 179 | images = gradinator(images.cuda()) 180 | images_meta = images_meta 181 | if ref_images is not None: 182 | ref_images = gradinator(ref_images.cuda()) 183 | ref_images_meta = ref_images_meta 184 | else: 185 | images = gradinator(images) 186 | images_meta = images_meta 187 | if ref_images is not None: 188 | ref_images = gradinator(ref_images) 189 | ref_images_meta = ref_images_meta 190 | 191 | return images, images_meta, ref_images, ref_images_meta 192 | 193 | 194 | def gradinator(x): 195 | x.requires_grad = False 196 | return x 197 | 198 | 199 | def enforce_size(img, targets, masks, num_crowds, new_w, new_h): 200 | """ Ensures that the image is the given size without distorting aspect ratio. 
""" 201 | with torch.no_grad(): 202 | _, h, w = img.size() 203 | 204 | if h == new_h and w == new_w: 205 | return img, targets, masks, num_crowds 206 | 207 | # Resize the image so that it fits within new_w, new_h 208 | w_prime = new_w 209 | h_prime = h * new_w / w 210 | 211 | if h_prime > new_h: 212 | w_prime *= new_h / h_prime 213 | h_prime = new_h 214 | 215 | w_prime = int(w_prime) 216 | h_prime = int(h_prime) 217 | 218 | # Do all the resizing 219 | img = F.interpolate(img.unsqueeze(0), (h_prime, w_prime), mode='bilinear', align_corners=False) 220 | img.squeeze_(0) 221 | 222 | # Act like each object is a color channel 223 | masks = F.interpolate(masks.unsqueeze(0), (h_prime, w_prime), mode='bilinear', align_corners=False) 224 | masks.squeeze_(0) 225 | 226 | # Scale bounding boxes (this will put them in the top left corner in the case of padding) 227 | targets[:, [0, 2]] *= (w_prime / new_w) 228 | targets[:, [1, 3]] *= (h_prime / new_h) 229 | 230 | # Finally, pad everything to be the new_w, new_h 231 | pad_dims = (0, new_w - w_prime, 0, new_h - h_prime) 232 | img = F.pad(img, pad_dims, mode='constant', value=0) 233 | masks = F.pad(masks, pad_dims, mode='constant', value=0) 234 | 235 | return img, targets, masks, num_crowds 236 | 237 | -------------------------------------------------------------------------------- /datasets/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Tue Oct 15 09:56:22 2019 3 | 4 | __version__ = '0.5.6+53bec28' 5 | short_version = '0.5.6' 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # Installs dependencies for YOLACT managed by Anaconda. 2 | # Advantage is you get working CUDA+cuDNN+pytorch+torchvison versions. 3 | # 4 | # TODO: you must additionally install nVidia drivers, eg. on Ubuntu linux 5 | # `apt install nvidia-driver-440` (change the 440 for whatever version you need/have). 
6 | # 7 | name: STMask-env 8 | #prefix: /your/custom/path/envs/STMask-env 9 | channels: 10 | - conda-forge 11 | - pytorch 12 | - defaults 13 | dependencies: 14 | - python==3.7 15 | - pip 16 | - cython 17 | - pytorch::torchvision ==0.5.0 18 | - pytorch::pytorch ==1.4.0 19 | - cudatoolkit 20 | - cudnn 21 | - pytorch::cuda100 22 | - matplotlib 23 | - git 24 | - pip: 25 | - opencv-python 26 | - pillow <7.0 # bug PILLOW_VERSION in torchvision, must be < 7.0 until torchvision is upgraded 27 | - pycocotools 28 | - PyQt5 # needed on KDE/Qt envs for matplotlib 29 | 30 | 31 | -------------------------------------------------------------------------------- /images/overall1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinghanLi/STMask/b8ca9efbac6d57e30676a679514a0627b85a494e/images/overall1.png -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | from .modules import * 3 | -------------------------------------------------------------------------------- /layers/display_gt_annotations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import mmcv 5 | import os.path as osp 6 | import pycocotools.mask as mask_util 7 | from cocoapi.PythonAPI.pycocotools.ytvos import YTVOS 8 | import matplotlib.pyplot as plt 9 | from datasets import cfg, MEANS, STD 10 | import cv2 11 | 12 | 13 | def display_gt_ann(anno_file, img_prefix, save_path, mask_alpha=0.45): 14 | ytvosGt = YTVOS(anno_file) 15 | anns = ytvosGt.anns 16 | videos_info = ytvosGt.dataset['videos'] 17 | video_id = anns[3394]['video_id'] 18 | cat_id, bboxes, segm = [], [], [] 19 | n_vid = 0 20 | for idx, ann_id in enumerate(anns): 21 | video_id_cur = anns[ann_id]['video_id'] 22 | cat_id_cur = anns[ann_id]['category_id'] 23 | bboxes_cur = anns[ann_id]['bboxes'] 24 | segm_cur = anns[ann_id]['segmentations'] 25 | if video_id_cur == video_id: 26 | cat_id.append(cat_id_cur) 27 | bboxes.append(bboxes_cur) 28 | segm.append(segm_cur) 29 | else: 30 | vid_info = videos_info[n_vid] 31 | h, w = vid_info['height'], vid_info['width'] 32 | display_masks(n_vid, h, w, bboxes, segm, cat_id, vid_info, img_prefix, save_path, mask_alpha) 33 | n_vid += 1 34 | video_id = video_id_cur 35 | cat_id = [cat_id_cur] 36 | bboxes = [bboxes_cur] 37 | segm = [segm_cur] 38 | 39 | 40 | def display_masks(n_vid, h, w, bboxes, segm, cat_id, vid_info, img_prefix, save_path, mask_alpha=0.45): 41 | for frame_id in range(len(bboxes[0])): 42 | print(n_vid, frame_id) 43 | img_numpy = mmcv.imread(osp.join(img_prefix, vid_info['file_names'][frame_id])) 44 | img_numpy = img_numpy[:, :, (2, 1, 0)] / 255. 
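        # mmcv.imread returns a BGR image; the (2, 1, 0) index swaps channels to RGB and the
        # division by 255. rescales pixel values to [0, 1] for the matplotlib plotting below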
45 | img_numpy = np.clip(img_numpy, 0, 1) 46 | img_gpu = torch.Tensor(img_numpy).cuda() 47 | img_numpy = img_gpu.cpu().numpy() 48 | 49 | # plot masks 50 | masks, colors = [], [] 51 | for j in range(len(bboxes)): 52 | if segm[j][frame_id] is not None: 53 | # polygons to rle, rle to binary mask 54 | mask_rle = mask_util.frPyObjects(segm[j][frame_id], h, w) 55 | masks.append(mask_util.decode(mask_rle)) 56 | colors.append(np.array(get_color(j)).reshape([1, 1, 3])) 57 | 58 | if len(masks) == 0: 59 | img_numpy = np.clip(img_numpy * 255, 0, 255).astype(np.int32) 60 | else: 61 | masks = np.stack(masks, axis=0)[:, :, :, None] 62 | colors = np.stack(colors, axis=0) 63 | masks_color = np.repeat(masks, 3, axis=3) * colors * mask_alpha 64 | inv_alph_masks = masks * (-mask_alpha) + 1 65 | 66 | masks_color_summand = masks_color[0] 67 | if len(colors) > 1: 68 | inv_alph_cumul = inv_alph_masks[:(len(colors) - 1)].cumprod(0) 69 | masks_color_cumul = masks_color[1:] * inv_alph_cumul 70 | masks_color_summand += masks_color_cumul.sum(0) 71 | img_numpy = img_numpy * inv_alph_masks.prod(axis=0) + masks_color_summand 72 | img_numpy = np.clip(img_numpy*255, 0, 255).astype(np.int32) 73 | # img_numpy = cv2.cvtColor(img_numpy, cv2.COLOR_RGB2BGR) 74 | # img_numpy = cv2.cvtColor(np.float32(img_numpy), cv2.COLOR_RGB2GRAY) 75 | 76 | # plot bboxes and text 77 | for j in range(len(bboxes)): 78 | if bboxes[j][frame_id] is not None: 79 | color = get_color(j) 80 | x1, y1, w, h = bboxes[j][frame_id] 81 | # x1, x2 = cx - w / 2, cx + w / 2 82 | x2, y2 = x1 + w, y1 + h 83 | y1, x1, y2, x2 = int(y1), int(x1), int(y2), int(x2) 84 | cv2.rectangle(img_numpy, (x1, y1), (x2, y2), color, 1) 85 | 86 | _class = cfg.classes[cat_id[j] - 1] 87 | text_str = '%s' % _class 88 | font_face = cv2.FONT_HERSHEY_DUPLEX 89 | font_scale = 1 90 | font_thickness = 1 91 | text_w, text_h = cv2.getTextSize(text_str, font_face, font_scale, font_thickness)[0] 92 | text_pt = (max(x1, 50), max(y1 - 3, 50)) 93 | text_color = [255, 255, 255] 94 | cv2.rectangle(img_numpy, (max(int(x1), 5), max(int(y1), 5)), (int(x1 + text_w), int(y1 - text_h - 4)), color, -1) 95 | cv2.putText(img_numpy, text_str, text_pt, font_face, font_scale, text_color, font_thickness, 96 | cv2.LINE_AA) 97 | plt.imshow(img_numpy) 98 | plt.axis('off') 99 | plt.title(str([n_vid, frame_id])) 100 | plt.savefig(''.join([save_path, str([n_vid, frame_id]), '.png'])) 101 | plt.clf() 102 | 103 | 104 | # Quick and dirty lambda for selecting the color for a particular index 105 | # Also keeps track of a per-gpu color cache for maximum speed 106 | def get_color(j, norm=True): 107 | global color_cache 108 | color_idx = (j * 5) % len(cfg.COLORS) 109 | 110 | color = cfg.COLORS[color_idx] 111 | if norm: 112 | color = [color[0] / 255., color[1] / 255., color[2] / 255.] 
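        # normalized to [0, 1] so the color can be blended with images scaled to [0, 1]
        # (see display_masks above); otherwise the raw 0-255 values are returned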
113 | 114 | return color 115 | 116 | 117 | if __name__ == '__main__': 118 | anno_file = ''.join(['/home/lmh/Downloads/VIS/code/', cfg.valid_sub_dataset.ann_file[3:]]) 119 | img_prefix = cfg.valid_sub_dataset.img_prefix 120 | save_path = '/home/lmh/Downloads/VIS/code/yolact_JDT_VIS/results/gt_anno/' 121 | display_gt_ann(anno_file, img_prefix, save_path) -------------------------------------------------------------------------------- /layers/eval_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import numpy as np 4 | import mmcv 5 | import os 6 | import pycocotools.mask as mask_util 7 | from cocoapi.PythonAPI.pycocotools.ytvos import YTVOS 8 | from cocoapi.PythonAPI.pycocotools.ytvoseval import YTVOSeval 9 | import matplotlib as plt 10 | from datasets import cfg 11 | import cv2 12 | from utils.functions import SavePath 13 | 14 | 15 | def bbox2result_with_id(preds, img_meta, classes): 16 | """Convert detection results to a list of numpy arrays. 17 | 18 | Args: 19 | bboxes (Tensor): shape (n, 5) 20 | labels (Tensor): shape (n, ) 21 | classes (int): class category, including background class 22 | 23 | Returns: 24 | list(ndarray): bbox results of each class 25 | """ 26 | video_id, frame_id = img_meta['video_id'], img_meta['frame_id'] 27 | results = {'video_id': video_id, 'frame_id': frame_id} 28 | if preds['box'].shape[0] == 0: 29 | return results 30 | else: 31 | bboxes = preds['box'].cpu().numpy() 32 | if preds['class'] is not None: 33 | labels = preds['class'].cpu().numpy() 34 | # labels_all = preds['class_all'].cpu().numpy() 35 | else: 36 | labels = None 37 | scores = preds['score'].cpu().numpy() 38 | segms = preds['segm'] 39 | obj_ids = preds['box_ids'].cpu().numpy() 40 | if labels is not None: 41 | for bbox, label, score, segm, obj_id in zip(bboxes, labels, scores, segms, obj_ids): 42 | if obj_id >= 0: 43 | results[obj_id] = {'bbox': bbox, 'label': label, 'score': score, 'segm': segm, 44 | 'category': classes[label-1]} 45 | else: 46 | for bbox, score, segm, obj_id in zip(bboxes, scores, segms, obj_ids): 47 | if obj_id >= 0: 48 | results[obj_id] = {'bbox': bbox, 'score': score, 'segm': segm} 49 | 50 | return results 51 | 52 | 53 | def results2json_videoseg(results, out_file): 54 | json_results = [] 55 | vid_objs = {} 56 | size = len(results) 57 | 58 | for idx in range(size): 59 | # assume results is ordered 60 | 61 | vid_id, frame_id = results[idx]['video_id'], results[idx]['frame_id'] 62 | if idx == size - 1: 63 | is_last = True 64 | else: 65 | vid_id_next, frame_id_next = results[idx + 1]['video_id'], results[idx + 1]['frame_id'] 66 | is_last = vid_id_next != vid_id 67 | 68 | det = results[idx] 69 | for obj_id in det: 70 | if obj_id not in {'video_id', 'frame_id'}: 71 | bbox = det[obj_id]['bbox'] 72 | score = det[obj_id]['score'] 73 | segm = det[obj_id]['segm'] 74 | label = det[obj_id]['label'] 75 | # label_all = det[obj_id]['label_all'] 76 | if obj_id not in vid_objs: 77 | vid_objs[obj_id] = {'scores': [], 'cats': [], 'segms': {}} 78 | vid_objs[obj_id]['scores'].append(score) 79 | vid_objs[obj_id]['cats'].append(label) 80 | segm['counts'] = segm['counts'].decode() 81 | vid_objs[obj_id]['segms'][frame_id] = segm 82 | if is_last: 83 | # store results of the current video 84 | for obj_id, obj in vid_objs.items(): 85 | data = dict() 86 | 87 | data['video_id'] = vid_id 88 | data['score'] = np.array(obj['scores']).mean().item() 89 | # majority voting for sequence category 90 | # data['category_id'] = 
np.stack(obj['cats'], axis=0).sum(0).argmax().item()+1 91 | data['category_id'] = np.bincount(np.array(obj['cats'])).argmax().item() 92 | vid_seg = [] 93 | for fid in range(frame_id + 1): 94 | if fid in obj['segms']: 95 | vid_seg.append(obj['segms'][fid]) 96 | else: 97 | vid_seg.append(None) 98 | data['segmentations'] = vid_seg 99 | json_results.append(data) 100 | 101 | vid_objs = {} 102 | if not os.path.exists(out_file[:-13]): 103 | os.makedirs(out_file[:-13]) 104 | 105 | mmcv.dump(json_results, out_file) 106 | print('Done') 107 | 108 | 109 | def calc_metrics(anno_file, dt_file, output_file=None): 110 | ytvosGt = YTVOS(anno_file) 111 | ytvosDt = ytvosGt.loadRes(dt_file) 112 | 113 | E = YTVOSeval(ytvosGt, ytvosDt, iouType='segm', output_file=output_file) 114 | E.evaluate() 115 | E.accumulate() 116 | E.summarize() 117 | print('finish validation') 118 | 119 | return E.stats 120 | 121 | 122 | def ytvos_eval(result_file, result_types, ytvos, max_dets=(100, 300, 1000), save_path_valid_metrics=None): 123 | if mmcv.is_str(ytvos): 124 | ytvos = YTVOS(ytvos) 125 | assert isinstance(ytvos, YTVOS) 126 | 127 | if len(ytvos.anns) == 0: 128 | print("Annotations does not exist") 129 | return 130 | assert result_file.endswith('.json') 131 | ytvos_dets = ytvos.loadRes(result_file) 132 | 133 | vid_ids = ytvos.getVidIds() 134 | for res_type in result_types: 135 | iou_type = res_type 136 | ytvosEval = YTVOSeval(ytvos, ytvos_dets, iou_type, output_file=save_path_valid_metrics) 137 | ytvosEval.params.vidIds = vid_ids 138 | if res_type == 'proposal': 139 | ytvosEval.params.useCats = 0 140 | ytvosEval.params.maxDets = list(max_dets) 141 | ytvosEval.evaluate() 142 | ytvosEval.accumulate() 143 | ytvosEval.summarize() 144 | -------------------------------------------------------------------------------- /layers/functions/TF_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from layers.box_utils import jaccard, center_size, point_form, decode, crop, mask_iou 4 | from layers.mask_utils import generate_mask 5 | from layers.modules import correlate, bbox_feat_extractor 6 | from layers.visualization import display_box_shift, display_correlation_map_patch 7 | 8 | from datasets import cfg 9 | from utils import timer 10 | 11 | 12 | def CandidateShift(net, ref_candidate, next_candidate, img=None, img_meta=None, display=False): 13 | """ 14 | The function try to shift the candidates of reference frame to that of target frame. 15 | The most important step is to shift the bounding box of reference frame to that of target frame 16 | :param net: network 17 | :param next_candidate: features of the last layer to predict bounding box on target frame 18 | :param ref_candidate: the candidate dictionary that includes 'box', 'conf', 'mask_coeff', 'track' items. 
19 | :return: candidates on the target frame 20 | """ 21 | 22 | ref_candidate_shift = {} 23 | for k, v in next_candidate.items(): 24 | if k in {'proto', 'fpn_feat', 'T2S_feat'}: 25 | ref_candidate_shift[k] = v.clone() 26 | 27 | # we only use the features in the P3 layer to perform correlation operation 28 | T2S_feat_ref, T2S_feat_next = ref_candidate['T2S_feat'], next_candidate['T2S_feat'] 29 | fpn_feat_ref, fpn_feat_next = ref_candidate['fpn_feat'], next_candidate['fpn_feat'] 30 | x_corr = correlate(fpn_feat_ref, fpn_feat_next, patch_size=cfg.correlation_patch_size) 31 | concatenated_features = F.relu(torch.cat([x_corr, T2S_feat_ref, T2S_feat_next], dim=1)) 32 | 33 | box_ref = ref_candidate['box'].clone() 34 | feat_h, feat_w = fpn_feat_ref.size()[2:] 35 | bbox_feat_input = bbox_feat_extractor(concatenated_features, box_ref, feat_h, feat_w, 7) 36 | loc_ref_shift, mask_coeff_shift = net.TemporalNet(bbox_feat_input) 37 | box_ref_shift = decode(loc_ref_shift, center_size(box_ref)) 38 | mask_coeff_ref_shift = ref_candidate['mask_coeff'].clone() + mask_coeff_shift 39 | masks_ref_shift = generate_mask(next_candidate['proto'], mask_coeff_ref_shift, box_ref_shift) 40 | 41 | # display = 1 42 | if display: 43 | # display_correlation_map_patch(bbox_feat_input[:, :121], img_meta) 44 | display_box_shift(box_ref, box_ref_shift, mask_shift=masks_ref_shift, img_meta=img_meta, img_gpu=img) 45 | 46 | ref_candidate_shift['box'] = box_ref_shift.clone() 47 | ref_candidate_shift['score'] = ref_candidate['score'].clone() * 0.95 48 | ref_candidate_shift['mask_coeff'] = mask_coeff_ref_shift.clone() 49 | ref_candidate_shift['mask'] = masks_ref_shift.clone() 50 | 51 | return ref_candidate_shift 52 | 53 | 54 | def generate_candidate(predictions): 55 | batch_Size = predictions['loc'].size(0) 56 | candidate = [] 57 | prior_data = predictions['priors'].squeeze(0) 58 | for i in range(batch_Size): 59 | loc_data = predictions['loc'][i] 60 | conf_data = predictions['conf'][i] 61 | 62 | candidate_cur = {'T2S_feat': predictions['T2S_feat'][i].unsqueeze(0), 63 | 'fpn_feat': predictions['fpn_feat'][i].unsqueeze(0)} 64 | 65 | with timer.env('Detect'): 66 | decoded_boxes = decode(loc_data, prior_data) 67 | 68 | conf_data = conf_data.t().contiguous() 69 | conf_scores, _ = torch.max(conf_data[1:, :], dim=0) 70 | 71 | keep = (conf_scores > cfg.eval_conf_thresh) 72 | candidate_cur['proto'] = predictions['proto'][i] 73 | candidate_cur['conf'] = conf_data[:, keep].t() 74 | candidate_cur['box'] = decoded_boxes[keep, :] 75 | candidate_cur['mask_coeff'] = predictions['mask_coeff'][i][keep, :] 76 | candidate_cur['track'] = predictions['track'][i][keep, :] if cfg.train_track else None 77 | if cfg.train_centerness: 78 | candidate_cur['centerness'] = predictions['centerness'][i][keep].view(-1) 79 | 80 | candidate.append(candidate_cur) 81 | 82 | return candidate 83 | 84 | 85 | def merge_candidates(candidate, ref_candidate_clip_shift): 86 | merged_candidate = {} 87 | for k, v in candidate.items(): 88 | merged_candidate[k] = v.clone() 89 | 90 | for ref_candidate in ref_candidate_clip_shift: 91 | if ref_candidate['box'].nelement() > 0: 92 | for k, v in merged_candidate.items(): 93 | if k not in {'proto', 'T2S_feat', 'fpn_feat'}: 94 | merged_candidate[k] = torch.cat([v.clone(), ref_candidate[k].clone()], dim=0) 95 | 96 | return merged_candidate 97 | 98 | 99 | def compute_comp_scores(match_ll, bbox_scores, bbox_ious, mask_ious, label_delta, add_bbox_dummy=False, bbox_dummy_iou=0, 100 | match_coeff=None): 101 | # compute comprehensive matching score 
based on matchig likelihood, 102 | # bbox confidence, and ious 103 | if add_bbox_dummy: 104 | bbox_iou_dummy = torch.ones(bbox_ious.size(0), 1, 105 | device=torch.cuda.current_device()) * bbox_dummy_iou 106 | bbox_ious = torch.cat((bbox_iou_dummy, bbox_ious), dim=1) 107 | mask_ious = torch.cat((bbox_iou_dummy, mask_ious), dim=1) 108 | label_dummy = torch.ones(bbox_ious.size(0), 1, 109 | device=torch.cuda.current_device()) 110 | label_delta = torch.cat((label_dummy, label_delta), dim=1) 111 | 112 | if match_coeff is None: 113 | return match_ll 114 | else: 115 | # match coeff needs to be length of 4 116 | assert (len(match_coeff) == 4) 117 | return match_ll + match_coeff[0] * bbox_scores \ 118 | + match_coeff[1] * mask_ious \ 119 | + match_coeff[2] * bbox_ious \ 120 | + match_coeff[3] * label_delta 121 | 122 | 123 | -------------------------------------------------------------------------------- /layers/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import Detect 2 | from .detection_TF import Detect_TF 3 | from .track import Track 4 | from .track_TF import Track_TF 5 | from .TF_utils import CandidateShift, generate_candidate, merge_candidates 6 | 7 | 8 | __all__ = ['Detect', 'Detect_TF', 'Track', 'Track_TF', 9 | 'merge_candidates', 'CandidateShift', 'generate_candidate'] 10 | -------------------------------------------------------------------------------- /layers/functions/detection_TF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..box_utils import jaccard, mask_iou, crop 3 | from ..mask_utils import generate_mask 4 | from utils import timer 5 | from datasets import cfg 6 | 7 | 8 | class Detect_TF(object): 9 | """At test time, Detect is the final layer of SSD. Decode location preds, 10 | apply non-maximum suppression to location predictions based on conf 11 | scores and threshold to a top_k number of output predictions for both 12 | confidence score and locations, as the predicted masks. 13 | """ 14 | 15 | # TODO: Refactor this whole class away. It needs to go. 16 | 17 | def __init__(self, num_classes, bkg_label, top_k, conf_thresh, nms_thresh): 18 | self.num_classes = num_classes 19 | self.background_label = bkg_label 20 | self.top_k = top_k 21 | # Parameters used in nms. 22 | self.nms_thresh = nms_thresh 23 | if nms_thresh <= 0: 24 | raise ValueError('nms_threshold must be non negative.') 25 | self.conf_thresh = conf_thresh 26 | 27 | self.use_cross_class_nms = True 28 | self.use_fast_nms = True 29 | 30 | def __call__(self, net, candidates, is_output_candidate=False): 31 | """ 32 | Args: 33 | net: (tensor) Loc preds from loc layers 34 | Shape: [batch, num_priors, 4] 35 | candidate: (tensor) Shape: Conf preds from conf layers 36 | Shape: [batch, num_priors, num_classes] 37 | Returns: 38 | output of shape (batch_size, top_k, 1 + 1 + 4 + mask_dim) 39 | These outputs are in the order: class idx, confidence, bbox coords, and mask. 
40 | 41 | Note that the outputs are sorted only if cross_class_nms is False 42 | """ 43 | 44 | with timer.env('Detect'): 45 | results = [] 46 | 47 | for candidate in candidates: 48 | result = self.detect(candidate, is_output_candidate) 49 | if is_output_candidate: 50 | results.append(result) 51 | else: 52 | results.append({'detection': result, 'net': net}) 53 | 54 | return results 55 | 56 | def detect(self, candidate, is_output_candidate=False): 57 | """ Perform nms for only the max scoring class that isn't background (class 0) """ 58 | 59 | scores = candidate['conf'].t()[1:] # [n_class, n_dets] 60 | boxes = candidate['box'] 61 | centerness_scores = candidate['centerness'] 62 | mask_coeff = candidate['mask_coeff'] 63 | track = candidate['track'] 64 | proto_data = candidate['proto'] 65 | 66 | if boxes.size(0) == 0: 67 | out_aft_nms = {'box': boxes, 'mask_coeff': mask_coeff, 'class': torch.Tensor(), 'score': torch.Tensor()} 68 | 69 | else: 70 | 71 | if self.use_cross_class_nms: 72 | out_aft_nms = self.cc_fast_nms(boxes, mask_coeff, proto_data, track, scores, 73 | centerness_scores, self.nms_thresh, self.top_k) 74 | else: 75 | out_aft_nms = self.fast_nms(boxes, mask_coeff, proto_data, track, scores, centerness_scores, 76 | self.nms_thresh, self.top_k) 77 | 78 | if is_output_candidate: 79 | for k, v in candidate.items(): 80 | if k in {'fpn_feat', 'proto', 'T2S_feat', 'sem_seg'}: 81 | out_aft_nms[k] = v 82 | 83 | return out_aft_nms 84 | 85 | def cc_fast_nms(self, boxes, masks_coeff, proto_data, track, scores, centerness_scores, 86 | iou_threshold: float = 0.5, top_k: int = 200): 87 | 88 | scores, classes = scores.max(dim=0) 89 | 90 | if centerness_scores is not None: 91 | scores = scores * centerness_scores 92 | 93 | _, idx = scores.sort(0, descending=True) 94 | idx = idx[:top_k] 95 | 96 | if len(idx) == 0: 97 | out_after_NMS = {'box': torch.Tensor(), 'mask_coeff': torch.Tensor(), 'class': torch.Tensor(), 98 | 'score': torch.Tensor()} 99 | 100 | else: 101 | # Compute the pairwise IoU between the boxes 102 | boxes_idx = boxes[idx] 103 | iou = jaccard(boxes_idx, boxes_idx) 104 | if cfg.nms_as_miou: 105 | det_masks_soft = generate_mask(proto_data, masks_coeff, boxes) 106 | det_masks = det_masks_soft.gt(0.5).float() 107 | miou = mask_iou(det_masks[idx], det_masks[idx]) 108 | iou = 0.5 * iou + 0.5 * miou 109 | 110 | # Zero out the lower triangle of the cosine similarity matrix and diagonal 111 | iou = torch.triu(iou, diagonal=1) 112 | 113 | # Now that everything in the diagonal and below is zeroed out, if we take the max 114 | # of the IoU matrix along the columns, each column will represent the maximum IoU 115 | # between this element and every element with a higher score than this element. 116 | iou_max, _ = torch.max(iou, dim=0) 117 | 118 | # Now just filter out the ones greater than the threshold, i.e., only keep boxes that 119 | # don't have a higher scoring box that would supress it in normal NMS. 
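            # Illustrative example (assumed numbers): for three detections already sorted by
            # score with upper-triangular IoU matrix
            #   [[0.0, 0.7, 0.2],
            #    [0.0, 0.0, 0.6],
            #    [0.0, 0.0, 0.0]]
            # the column-wise max is [0.0, 0.7, 0.6], so with iou_threshold = 0.5 only the
            # highest-scoring detection survives.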
120 | idx_out = idx[iou_max <= iou_threshold] 121 | 122 | boxes = boxes[idx_out] 123 | masks_coeff = masks_coeff[idx_out] 124 | if track is not None: 125 | track = track[idx_out] 126 | if classes is not None: 127 | classes = classes[idx_out] + 1 128 | scores = scores[idx_out] 129 | if centerness_scores is not None: 130 | centerness_scores = centerness_scores[idx_out] 131 | 132 | out_after_NMS = {'box': boxes, 'mask_coeff': masks_coeff, 'track': track, 'class': classes, 133 | 'score': scores, 'centerness': centerness_scores} 134 | return out_after_NMS 135 | 136 | def fast_nms(self, boxes, masks_coeff, proto_data, track, conf, centerness_scores, 137 | iou_threshold: float = 0.5, top_k: int = 200, 138 | second_threshold: bool = True): 139 | 140 | if centerness_scores is not None: 141 | centerness_scores = centerness_scores.view(-1, 1) 142 | conf = conf * centerness_scores.t() 143 | 144 | scores, idx = conf.sort(1, descending=True) # [num_classes, num_dets] 145 | idx = idx[:, :top_k].contiguous() 146 | scores = scores[:, :top_k] 147 | 148 | if len(idx) == 0: 149 | out_after_NMS = {'box': torch.Tensor(), 'mask_coeff': torch.Tensor(), 'class': torch.Tensor(), 150 | 'score': torch.Tensor()} 151 | 152 | else: 153 | num_classes, num_dets = idx.size() 154 | boxes = boxes[idx.view(-1), :].view(num_classes, num_dets, 4) 155 | masks_coeff = masks_coeff[idx.view(-1), :].view(num_classes, num_dets, -1) 156 | if cfg.train_track: 157 | track = track[idx.view(-1), :].view(num_classes, num_dets, -1) 158 | if centerness_scores is not None: 159 | centerness_scores = centerness_scores[idx.view(-1), :].view(num_classes, num_dets, -1) 160 | 161 | iou = jaccard(boxes, boxes) # [num_classes, num_dets, num_dets] 162 | iou.triu_(diagonal=1) 163 | iou_max, _ = iou.max(dim=1) # [num_classes, num_dets] 164 | 165 | # Now just filter out the ones higher than the threshold 166 | keep = (iou_max <= iou_threshold) # [num_classes, num_dets] 167 | 168 | # We should also only keep detections over the confidence threshold, but at the cost of 169 | # maxing out your detection count for every image, you can just not do that. Because we 170 | # have such a minimal amount of computation per detection (matrix mulitplication only), 171 | # this increase doesn't affect us much (+0.2 mAP for 34 -> 33 fps), so we leave it out. 172 | # However, when you implement this in your method, you should do this second threshold. 
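            # Shape note (added for clarity): `keep` and `scores` are both
            # [num_classes, num_dets] here, so multiplying `keep` by the boolean mask
            # (scores > self.conf_thresh) acts as an element-wise logical AND: a detection
            # survives only if it passed NMS and is above the confidence threshold.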
173 | if second_threshold: 174 | keep *= (scores > self.conf_thresh) 175 | 176 | # Assign each kept detection to its corresponding class 177 | classes = torch.arange(num_classes, device=boxes.device)[:, None].expand_as(keep) 178 | classes = classes[keep] 179 | 180 | boxes = boxes[keep] 181 | masks_coeff = masks_coeff[keep] 182 | if cfg.train_track: 183 | track = track[keep] 184 | if centerness_scores is not None: 185 | centerness_scores = centerness_scores[keep] 186 | scores = scores[keep] 187 | 188 | # Only keep the top cfg.max_num_detections highest scores across all classes 189 | scores, idx = scores.sort(0, descending=True) 190 | idx = idx[:cfg.max_num_detections] 191 | scores = scores[:cfg.max_num_detections] 192 | 193 | classes = classes[idx] + 1 194 | boxes = boxes[idx] 195 | masks_coeff = masks_coeff[idx] 196 | if cfg.train_track: 197 | track = track[idx] 198 | if centerness_scores is not None: 199 | centerness_scores = centerness_scores[idx] 200 | 201 | out_after_NMS = {'box': boxes, 'mask_coeff': masks_coeff, 'track': track, 'class': classes, 202 | 'score': scores, 'centerness': centerness_scores} 203 | 204 | return out_after_NMS 205 | 206 | -------------------------------------------------------------------------------- /layers/functions/track.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ..box_utils import decode, jaccard, index2d, mask_iou, crop, center_size, DIoU 4 | from ..mask_utils import generate_mask 5 | from layers.mask_utils import generate_rel_coord 6 | from utils import timer 7 | from .TF_utils import compute_comp_scores 8 | 9 | from datasets import cfg 10 | 11 | import numpy as np 12 | 13 | import pyximport 14 | pyximport.install(setup_args={"include_dirs":np.get_include()}, reload_support=True) 15 | 16 | 17 | class Track(object): 18 | """At test time, Detect is the final layer of SSD. Decode location preds, 19 | apply non-maximum suppression to location predictions based on conf 20 | scores and threshold to a top_k number of output predictions for both 21 | confidence score and locations, as the predicted masks. 22 | """ 23 | # TODO: Refactor this whole class away. It needs to go. 24 | 25 | def __init__(self): 26 | self.prev_det_bbox = None 27 | self.prev_track_embed = None 28 | self.prev_det_labels = None 29 | self.prev_det_masks = None 30 | self.prev_det_masks_coeff = None 31 | self.prev_protos = None 32 | self.det_scores = None 33 | 34 | def __call__(self, pred_outs_after_NMS, img_meta): 35 | """ 36 | Args: 37 | loc_data: (tensor) Loc preds from loc layers 38 | Shape: [batch, num_priors, 4] 39 | 40 | Returns: 41 | output of shape (batch_size, top_k, 1 + 1 + 4 + mask_dim) 42 | These outputs are in the order: class idx, confidence, bbox coords, and mask. 
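        (Clarifying note, added here: in this codebase the input is the list of post-NMS
        detection dicts, one per image, and track() below adds a 'box_ids' entry that links
        each detection to an object identity from previous frames via track-embedding cosine
        similarity, box IoU, mask IoU and label consistency; the Args description above is
        carried over from the original Detect docstring.)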
43 | 44 | Note that the outputs are sorted only if cross_class_nms is False 45 | """ 46 | 47 | with timer.env('Track'): 48 | batch_size = len(pred_outs_after_NMS) 49 | 50 | for batch_idx in range(batch_size): 51 | detection = pred_outs_after_NMS[batch_idx]['detection'] 52 | pred_outs_after_NMS[batch_idx]['detection'] = self.track(detection, img_meta[batch_idx]) 53 | 54 | return pred_outs_after_NMS 55 | 56 | def track(self, detection, img_meta): 57 | 58 | # only support batch_size = 1 for video test 59 | is_first = img_meta['is_first'] 60 | if is_first: 61 | self.prev_det_bbox = None 62 | self.prev_track_embed = None 63 | self.prev_det_labels = None 64 | self.prev_det_masks = None 65 | self.prev_det_masks_coeff = None 66 | self.prev_protos = None 67 | self.det_scores = None 68 | # self.prev_track = {} 69 | 70 | if detection['class'].nelement() == 0: 71 | detection['box_ids'] = torch.tensor([], dtype=torch.int64) 72 | return detection 73 | 74 | # get bbox and class after NMS 75 | det_bbox = detection['box'] 76 | det_labels = detection['class'] # class 77 | det_score = detection['score'] 78 | det_masks_coff = detection['mask_coeff'] 79 | if cfg.train_track: 80 | det_track_embed = detection['track'] 81 | else: 82 | det_track_embed = F.normalize(det_masks_coff, dim=1) 83 | proto_data = detection['proto'] 84 | 85 | n_dets = det_bbox.size(0) 86 | mask_h, mask_w = proto_data.size()[:2] 87 | 88 | # get masks 89 | det_masks_soft = generate_mask(proto_data, det_masks_coff, det_bbox) 90 | det_masks = det_masks_soft.gt(0.5).float() 91 | detection['mask'] = det_masks 92 | 93 | # compared bboxes in current frame with bboxes in previous frame to achieve tracking 94 | if is_first or (not is_first and self.prev_det_bbox is None): 95 | det_obj_ids = torch.arange(det_bbox.size(0)) 96 | # save bbox and features for later matching 97 | self.prev_det_bbox = det_bbox 98 | self.prev_track_embed = det_track_embed 99 | self.prev_det_labels = det_labels.view(-1) 100 | self.prev_det_masks = det_masks 101 | self.prev_det_masks_coeff = det_masks_coff 102 | self.prev_protos = proto_data.unsqueeze(0).repeat(n_dets, 1, 1, 1) 103 | self.prev_scores = det_score 104 | # self.prev_track = {i: det_track_embed[i].view(1, -1) for i in range(det_bbox.size(0))} 105 | 106 | else: 107 | 108 | assert self.prev_track_embed is not None 109 | n_prev = self.prev_det_bbox.size(0) 110 | # only support one image at a time 111 | cos_sim = det_track_embed @ self.prev_track_embed.t() # [n_dets, n_prev], val in [-1, 1] 112 | cos_sim = torch.cat([torch.zeros(n_dets, 1), cos_sim], dim=1) 113 | cos_sim = (cos_sim + 1) / 2 # [0, 1] 114 | 115 | bbox_ious = jaccard(det_bbox, self.prev_det_bbox) 116 | mask_ious = mask_iou(det_masks, self.prev_det_masks) 117 | 118 | # compute comprehensive score 119 | label_delta = (self.prev_det_labels == det_labels.view(-1, 1)).float() 120 | comp_scores = compute_comp_scores(cos_sim, 121 | det_score.view(-1, 1), 122 | bbox_ious, 123 | mask_ious, 124 | label_delta, 125 | add_bbox_dummy=True, 126 | bbox_dummy_iou=0.3, 127 | match_coeff=cfg.match_coeff) 128 | 129 | match_likelihood, match_ids = torch.max(comp_scores, dim=1) 130 | # translate match_ids to det_obj_ids, assign new id to new objects 131 | # update tracking features/bboxes of exisiting object, 132 | # add tracking features/bboxes of new object 133 | match_ids = match_ids 134 | det_obj_ids = torch.ones(n_dets, dtype=torch.int32) * (-1) 135 | best_match_scores = torch.ones(n_prev) * (-1) 136 | best_match_idx = torch.ones(n_prev) * (-1) 137 | for idx, match_id 
in enumerate(match_ids): 138 | if match_id == 0: 139 | det_obj_ids[idx] = self.prev_det_masks.size(0) 140 | self.prev_track_embed = torch.cat([self.prev_track_embed, det_track_embed[idx][None]], dim=0) 141 | self.prev_det_bbox = torch.cat((self.prev_det_bbox, det_bbox[idx][None]), dim=0) 142 | if det_labels is not None: 143 | self.prev_det_labels = torch.cat((self.prev_det_labels, det_labels[idx][None]), dim=0) 144 | self.prev_det_masks = torch.cat((self.prev_det_masks, det_masks[idx][None]), dim=0) 145 | self.prev_det_masks_coeff = torch.cat((self.prev_det_masks_coeff, det_masks_coff[idx][None]), dim=0) 146 | self.prev_protos = torch.cat((self.prev_protos, proto_data[None]), dim=0) 147 | self.prev_scores = torch.cat((self.prev_scores, det_score[idx][None]), dim=0) 148 | # self.prev_track[self.prev_det_masks.size(0)-1] = det_track_embed[idx].view(1, -1) 149 | 150 | else: 151 | # multiple candidate might match with previous object, here we choose the one with 152 | # largest comprehensive score 153 | obj_id = match_id - 1 154 | match_score = det_score[idx] # match_likelihood[idx] 155 | if match_score > best_match_scores[obj_id]: 156 | if best_match_idx[obj_id] != -1: 157 | det_obj_ids[int(best_match_idx[obj_id])] = -1 158 | det_obj_ids[idx] = obj_id 159 | best_match_scores[obj_id] = match_score 160 | best_match_idx[obj_id] = idx 161 | # udpate feature 162 | if (mask_ious[idx] > 0.3).sum() < 2: 163 | if det_labels is not None: 164 | self.prev_det_labels[obj_id] = det_labels[idx] 165 | self.prev_track_embed[obj_id] = det_track_embed[idx] 166 | self.prev_det_bbox[obj_id] = det_bbox[idx] 167 | self.prev_det_masks[obj_id] = det_masks[idx] 168 | self.prev_det_masks_coeff[obj_id] = det_masks_coff[idx] 169 | self.prev_protos[obj_id] = proto_data 170 | self.prev_scores[obj_id] = det_score[idx] 171 | # self.prev_track[int(obj_id)] = torch.cat([self.prev_track[int(obj_id)], det_track_embed[idx][None]], dim=0) 172 | 173 | detection['box_ids'] = det_obj_ids 174 | if cfg.remove_false_inst: 175 | keep = det_obj_ids >= 0 176 | for k, v in detection.items(): 177 | if k not in {'proto', 'bbox_idx', 'priors', 'loc_t'}: 178 | detection[k] = detection[k][keep] 179 | 180 | return detection -------------------------------------------------------------------------------- /layers/functions/track_TF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from ..box_utils import jaccard, mask_iou 4 | from ..mask_utils import generate_mask 5 | from .TF_utils import CandidateShift, compute_comp_scores 6 | from utils import timer 7 | 8 | from datasets import cfg 9 | 10 | import numpy as np 11 | 12 | import pyximport 13 | pyximport.install(setup_args={"include_dirs":np.get_include()}, reload_support=True) 14 | 15 | 16 | class Track_TF(object): 17 | """At test time, Detect is the final layer of SSD. Decode location preds, 18 | apply non-maximum suppression to location predictions based on conf 19 | scores and threshold to a top_k number of output predictions for both 20 | confidence score and locations, as the predicted masks. 21 | """ 22 | # TODO: Refactor this whole class away. It needs to go. 
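    # Summary of the flow implemented below (comment added for readability):
    #   1. prev_candidate caches the previous frame's candidate dict, plus a 'tracked_mask'
    #      counter of how many consecutive frames each instance has only been tracked.
    #   2. On a new frame, CandidateShift warps the cached candidate to the current frame,
    #      and compute_comp_scores matches new detections to cached ones using track-embedding
    #      cosine similarity, box IoU, mask IoU and label consistency.
    #   3. Unmatched detections are appended as new instances; matched ones refresh the cached
    #      entries and reset their 'tracked_mask' counter to zero.
    #   4. Only instances with tracked_mask <= 10, a non-trivial mask area and a score above
    #      cfg.eval_conf_thresh are emitted for the current frame.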
23 | 24 | def __init__(self): 25 | self.prev_candidate = None 26 | 27 | def __call__(self, net, candidates, imgs_meta, imgs=None): 28 | """ 29 | Args: 30 | loc_data: (tensor) Loc preds from loc layers 31 | Shape: [batch, num_priors, 4] 32 | 33 | Returns: 34 | output of shape (batch_size, top_k, 1 + 1 + 4 + mask_dim) 35 | These outputs are in the order: class idx, confidence, bbox coords, and mask. 36 | 37 | Note that the outputs are sorted only if cross_class_nms is False 38 | """ 39 | 40 | with timer.env('Track'): 41 | results = [] 42 | 43 | # only support batch_size = 1 for video test 44 | for batch_idx, candidate in enumerate(candidates): 45 | result = self.track(net, candidate, imgs_meta[batch_idx], img=imgs[batch_idx]) 46 | results.append({'detection': result, 'net': net}) 47 | 48 | return results 49 | 50 | def track(self, net, candidate, img_meta, img=None): 51 | # only support batch_size = 1 for video test 52 | is_first = img_meta['is_first'] 53 | if is_first: 54 | self.prev_candidate = None 55 | 56 | if candidate['box'].nelement() == 0 and self.prev_candidate is None: 57 | return {'box': torch.Tensor(), 'mask_coeff': torch.Tensor(), 'class': torch.Tensor(), 58 | 'score': torch.Tensor(), 'box_ids': torch.Tensor()} 59 | 60 | else: 61 | if candidate['box'].nelement() == 0 and self.prev_candidate is not None: 62 | prev_candidate_shift = CandidateShift(net, self.prev_candidate, candidate, 63 | img=img, img_meta=img_meta) 64 | for k, v in prev_candidate_shift.items(): 65 | self.prev_candidate[k] = v.clone() 66 | self.prev_candidate['tracked_mask'] = self.prev_candidate['tracked_mask'] + 1 67 | else: 68 | 69 | # get bbox and class after NMS 70 | det_bbox = candidate['box'] 71 | det_score = candidate['score'] 72 | det_labels = candidate['class'] 73 | det_masks_coeff = candidate['mask_coeff'] 74 | if cfg.train_track: 75 | det_track_embed = candidate['track'] 76 | else: 77 | det_track_embed = F.normalize(det_masks_coeff, dim=1) 78 | 79 | n_dets = det_bbox.size(0) 80 | # get masks 81 | det_masks_soft = generate_mask(candidate['proto'], det_masks_coeff, det_bbox) 82 | candidate['mask'] = det_masks_soft 83 | det_masks = det_masks_soft.gt(0.5).float() 84 | 85 | # compared bboxes in current frame with bboxes in previous frame to achieve tracking 86 | if is_first or (not is_first and self.prev_candidate is None): 87 | # save bbox and features for later matching 88 | self.prev_candidate = dict() 89 | for k, v in candidate.items(): 90 | self.prev_candidate[k] = v 91 | self.prev_candidate['tracked_mask'] = torch.zeros(n_dets) 92 | 93 | else: 94 | 95 | assert self.prev_candidate is not None 96 | prev_candidate_shift = CandidateShift(net, self.prev_candidate, candidate, 97 | img=img, img_meta=img_meta) 98 | for k, v in prev_candidate_shift.items(): 99 | self.prev_candidate[k] = v.clone() 100 | self.prev_candidate['tracked_mask'] = self.prev_candidate['tracked_mask'] + 1 101 | 102 | n_prev = self.prev_candidate['box'].size(0) 103 | # only support one image at a time 104 | cos_sim = det_track_embed @ self.prev_candidate['track'].t() 105 | cos_sim = torch.cat([torch.zeros(n_dets, 1), cos_sim], dim=1) 106 | cos_sim = (cos_sim + 1) / 2 # [0, 1] 107 | 108 | bbox_ious = jaccard(det_bbox, self.prev_candidate['box']) 109 | prev_masks_shift = self.prev_candidate['mask'].gt(0.5).float() 110 | 111 | mask_ious = mask_iou(det_masks, prev_masks_shift) # [n_dets, n_prev] 112 | # print(img_meta['video_id'], img_meta['frame_id'], cos_sim[:, 1:], mask_ious) 113 | 114 | # compute comprehensive score 115 | prev_det_labels 
= self.prev_candidate['class'] 116 | label_delta = (prev_det_labels == det_labels.view(-1, 1)).float() 117 | comp_scores = compute_comp_scores(cos_sim, 118 | det_score.view(-1, 1), 119 | bbox_ious, 120 | mask_ious, 121 | label_delta, 122 | add_bbox_dummy=True, 123 | bbox_dummy_iou=0.3, 124 | match_coeff=cfg.match_coeff) 125 | match_likelihood, match_ids = torch.max(comp_scores, dim=1) 126 | # translate match_ids to det_obj_ids, assign new id to new objects 127 | # update tracking features/bboxes of exisiting object, 128 | # add tracking features/bboxes of new object 129 | det_obj_ids = torch.ones(n_dets, dtype=torch.int32) * (-1) 130 | best_match_scores = torch.ones(n_prev) * (-1) 131 | best_match_idx = torch.ones(n_prev) * (-1) 132 | for idx, match_id in enumerate(match_ids): 133 | if match_id == 0: 134 | det_obj_ids[idx] = self.prev_candidate['box'].size(0) 135 | for k, v in self.prev_candidate.items(): 136 | if k not in {'proto', 'T2S_feat', 'fpn_feat', 'tracked_mask'}: 137 | self.prev_candidate[k] = torch.cat([v, candidate[k][idx][None]], dim=0) 138 | self.prev_candidate['tracked_mask'] = torch.cat([self.prev_candidate['tracked_mask'], 139 | torch.zeros(1)], dim=0) 140 | 141 | else: 142 | # multiple candidate might match with previous object, here we choose the one with 143 | # largest comprehensive score 144 | obj_id = match_id - 1 145 | match_score = det_score[idx] # match_likelihood[idx] 146 | if match_score > best_match_scores[obj_id]: 147 | if best_match_idx[obj_id] != -1: 148 | det_obj_ids[int(best_match_idx[obj_id])] = -1 149 | det_obj_ids[idx] = obj_id 150 | best_match_scores[obj_id] = match_score 151 | best_match_idx[obj_id] = idx 152 | # udpate feature 153 | for k, v in self.prev_candidate.items(): 154 | if k not in {'proto', 'T2S_feat', 'fpn_feat', 'tracked_mask'}: 155 | self.prev_candidate[k][obj_id] = candidate[k][idx] 156 | self.prev_candidate['tracked_mask'][obj_id] = 0 157 | 158 | det_obj_ids = torch.arange(self.prev_candidate['box'].size(0)) 159 | # whether add some tracked masks 160 | cond1 = self.prev_candidate['tracked_mask'] <= 10 161 | # whether tracked masks are greater than a small threshold, which removes some false positives 162 | cond2 = self.prev_candidate['mask'].gt(0.5).sum([1, 2]) > 1 163 | # a declining weights (0.8) to remove some false positives that cased by consecutively track to segment 164 | cond3 = self.prev_candidate['score'].clone().detach() > cfg.eval_conf_thresh 165 | keep = cond1 & cond2 & cond3 166 | 167 | if keep.sum() == 0: 168 | detection = {'box': torch.Tensor(), 'mask_coeff': torch.Tensor(), 'class': torch.Tensor(), 169 | 'score': torch.Tensor(), 'box_ids': torch.Tensor()} 170 | else: 171 | 172 | detection = {'box': self.prev_candidate['box'][keep], 173 | 'mask_coeff': self.prev_candidate['mask_coeff'][keep], 174 | 'track': self.prev_candidate['track'][keep], 175 | 'class': self.prev_candidate['class'][keep], 176 | 'score': self.prev_candidate['score'][keep], 177 | 'centerness': self.prev_candidate['centerness'][keep], 178 | 'proto': candidate['proto'], 'mask': self.prev_candidate['mask'][keep], 179 | 'box_ids': det_obj_ids[keep]} 180 | 181 | return detection 182 | -------------------------------------------------------------------------------- /layers/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class InterpolateModule(nn.Module): 6 | """ 7 | This is a module version of F.interpolate (rip nn.Upsampling). 
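    For example, InterpolateModule(scale_factor=2, mode='bilinear', align_corners=False)
    inside an nn.Sequential upsamples its input exactly as the corresponding F.interpolate
    call would (this is how make_net builds its upsample layers).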
8 | Any arguments you give it just get passed along for the ride. 9 | """ 10 | 11 | def __init__(self, *args, **kwdargs): 12 | super().__init__() 13 | 14 | self.args = args 15 | self.kwdargs = kwdargs 16 | 17 | def forward(self, x): 18 | return F.interpolate(x, *self.args, **self.kwdargs) 19 | -------------------------------------------------------------------------------- /layers/mask_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn.functional as F 4 | import matplotlib.pyplot as plt 5 | from .box_utils import crop, crop_sipmask 6 | from datasets import cfg 7 | 8 | 9 | def generate_rel_coord(det_bbox, mask_h, mask_w, sigma_scale=2): 10 | ''' 11 | :param det_box: the centers of pos bboxes ==> [cx, cy] 12 | :param mask_h: height of pred_mask 13 | :param mask_w: weight of pred_mask 14 | :return: rel_coord ==> [num_pos, mask_h, mask_w, 2] 15 | ''' 16 | 17 | # generate relative coordinates 18 | num_pos = det_bbox.size(0) 19 | det_bbox_ori = det_bbox.new(num_pos, 4) 20 | det_bbox_ori[:, 0::2] = det_bbox[:, 0::2] * mask_w 21 | det_bbox_ori[:, 1::2] = det_bbox[:, 1::2] * mask_h 22 | x_range = torch.arange(mask_w) 23 | y_range = torch.arange(mask_h) 24 | y_grid, x_grid = torch.meshgrid(y_range, x_range) 25 | det_bbox_c = (det_bbox_ori[:, :2] + det_bbox_ori[:, 2:]) / 2 26 | cx, cy = torch.round(det_bbox_c[:, 0]), torch.round(det_bbox_c[:, 1]) 27 | y_rel_coord = (y_grid.float().unsqueeze(0).repeat(num_pos, 1, 1) - cy.view(-1, 1, 1)) ** 2 28 | x_rel_coord = (x_grid.float().unsqueeze(0).repeat(num_pos, 1, 1) - cx.view(-1, 1, 1)) ** 2 29 | 30 | # build 2D Normal distribution 31 | det_bbox_wh = det_bbox_ori[:, 2:] - det_bbox_ori[:, :2] 32 | rel_coord = [] 33 | for i in range(num_pos): 34 | if det_bbox_wh[i][0] * det_bbox_wh[i][1] / mask_h / mask_w < 0.1: 35 | sigma_scale = 0.5 * sigma_scale 36 | sigma_x, sigma_y = det_bbox_wh[i] / sigma_scale 37 | val = torch.exp(-0.5 * (x_rel_coord[i] / (sigma_x ** 2) + y_rel_coord[i] / (sigma_y ** 2))) 38 | rel_coord.append(val.unsqueeze(0)) 39 | 40 | return torch.cat(rel_coord, dim=0) 41 | 42 | 43 | def mask_head(protos, proto_coeff, num_mask_head, mask_dim=8, use_rela_coord=False, img_meta=None): 44 | """ 45 | :param protos: [1, n, h, w] 46 | :param proto_coeff: reshape as weigths and bias 47 | :return: [1, 1, h, w] 48 | """ 49 | 50 | # reshape proto_coef as weights and bias of filters 51 | if use_rela_coord: 52 | ch = mask_dim + 1 53 | else: 54 | ch = mask_dim 55 | ch2 = mask_dim * ch 56 | 57 | if num_mask_head == 1: 58 | weights1 = proto_coeff[:8].reshape(1, 8, 1, 1) 59 | bias1 = proto_coeff[-1].reshape(1) 60 | # FCN network for mask prediction 61 | pred_masks = F.conv2d(protos, weights1, bias1, stride=1, padding=0, dilation=1, groups=1) 62 | 63 | elif num_mask_head == 2: 64 | weights1 = proto_coeff[:ch2].reshape(mask_dim, ch, 1, 1) 65 | bias1 = proto_coeff[ch2:ch2 + mask_dim] 66 | weights2 = proto_coeff[ch2 + mask_dim:ch2 + 2 * mask_dim].reshape(1, mask_dim, 1, 1) 67 | bias2 = proto_coeff[-1].reshape(1) 68 | # FCN network for mask prediction 69 | protos1 = F.relu(F.conv2d(protos, weights1, bias1, stride=1, padding=0, dilation=1, groups=1)) 70 | pred_masks = F.conv2d(protos1, weights2, bias2, stride=1, padding=0, dilation=1, groups=1) 71 | 72 | # plot_protos(protos, pred_masks, img_meta, num=1) 73 | # plot_protos(protos1, pred_masks, img_meta, num=2) 74 | 75 | elif num_mask_head == 3: 76 | weights1 = proto_coeff[:ch2].reshape(mask_dim, ch, 1, 1) 77 | bias1 = 
proto_coeff[ch2:ch2 + mask_dim] 78 | weights2 = proto_coeff[ch2 + mask_dim:ch2 + mask_dim + mask_dim**2].reshape(mask_dim, mask_dim, 1, 1) 79 | bias2 = proto_coeff[ch2 + mask_dim + mask_dim**2:ch2 + mask_dim*2 + mask_dim**2] 80 | weights3 = proto_coeff[ch2 + mask_dim*2 + mask_dim**2:ch2 + mask_dim*3 + mask_dim**2].reshape(1, mask_dim, 1, 1) 81 | bias3 = proto_coeff[-1].reshape(1) 82 | # FCN network for mask prediction 83 | protos1 = F.relu(F.conv2d(protos, weights1, bias1, stride=1, padding=0, dilation=1, groups=1)) 84 | protos2 = F.relu(F.conv2d(protos1, weights2, bias2, stride=1, padding=0, dilation=1, groups=1)) 85 | pred_masks = F.conv2d(protos2, weights3, bias3, stride=1, padding=0, dilation=1, groups=1) 86 | 87 | return pred_masks 88 | 89 | 90 | def plot_protos(protos, pred_masks, img_meta, num): 91 | if protos.size(1) == 9: 92 | protos = torch.cat([protos, protos[:, -1, :, :].unsqueeze(1)], dim=1) 93 | elif protos.size(1) == 8: 94 | protos = torch.cat([protos, pred_masks, pred_masks], dim=1) 95 | proto_data = protos.squeeze(0) 96 | num_per_row = int(proto_data.size(0) / 2) 97 | proto_data_list = [] 98 | for r in range(2): 99 | proto_data_list.append( 100 | torch.cat([proto_data[i, :, :] * 5 for i in range(num_per_row * r, num_per_row * (r + 1))], dim=-1)) 101 | 102 | img = torch.cat(proto_data_list, dim=0) 103 | img = img / img.max() 104 | plt.imshow(img.cpu().detach().numpy()) 105 | plt.title([img_meta['video_id'], img_meta['frame_id'], 'protos']) 106 | plt.savefig(''.join(['results/results_0306/out_protos/', 107 | str((img_meta['video_id'], img_meta['frame_id'])), 108 | str(num), '.png'])) 109 | 110 | 111 | def generate_mask(proto_data, mask_coeff, bbox=None, use_sipmask=False): 112 | mask_coeff = cfg.mask_proto_coeff_activation(mask_coeff) 113 | # get masks 114 | if use_sipmask: 115 | pred_masks00 = cfg.mask_proto_mask_activation(proto_data @ mask_coeff[:, :32].t()) 116 | pred_masks01 = cfg.mask_proto_mask_activation(proto_data @ mask_coeff[:, 32:64].t()) 117 | pred_masks10 = cfg.mask_proto_mask_activation(proto_data @ mask_coeff[:, 64:96].t()) 118 | pred_masks11 = cfg.mask_proto_mask_activation(proto_data @ mask_coeff[:, 96:128].t()) 119 | pred_masks = crop_sipmask(pred_masks00, pred_masks01, pred_masks10, pred_masks11, bbox) 120 | else: 121 | pred_masks = proto_data @ mask_coeff.t() 122 | pred_masks = cfg.mask_proto_mask_activation(pred_masks) 123 | if bbox is not None: 124 | _, pred_masks = crop(pred_masks.squeeze(0), bbox) # [mask_h, mask_w, n] 125 | 126 | det_masks = pred_masks.permute(2, 0, 1).contiguous() # [n_masks, h, w] 127 | 128 | return det_masks 129 | -------------------------------------------------------------------------------- /layers/modules/FPN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List 5 | 6 | from datasets.config import cfg 7 | try: 8 | from dcn_v2 import DCN, DCNv2 9 | except ImportError: 10 | def DCN(*args, **kwdargs): 11 | raise Exception('DCN could not be imported. If you want to use YOLACT++ models, compile DCN. Check the README for instructions.') 12 | 13 | # As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules 14 | use_jit = torch.cuda.device_count() <= 1 15 | if not use_jit: 16 | print('Multiple GPUs detected! 
Turning off JIT.') 17 | 18 | ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module 19 | script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn 20 | 21 | 22 | class FPN(ScriptModuleWrapper): 23 | """ 24 | Implements a general version of the FPN introduced in 25 | https://arxiv.org/pdf/1612.03144.pdf 26 | 27 | Parameters (in cfg.fpn): 28 | - num_features (int): The number of output features in the fpn layers. 29 | - interpolation_mode (str): The mode to pass to F.interpolate. 30 | - num_downsample (int): The number of downsampled layers to add onto the selected layers. 31 | These extra layers are downsampled from the last selected layer. 32 | 33 | Args: 34 | - in_channels (list): For each conv layer you supply in the forward pass, 35 | how many features will it have? 36 | """ 37 | __constants__ = ['interpolation_mode', 'num_downsample', 'use_conv_downsample', 38 | 'lat_layers', 'pred_layers', 'downsample_layers'] 39 | 40 | def __init__(self, in_channels): 41 | super().__init__() 42 | 43 | self.lat_layers = nn.ModuleList([ 44 | nn.Conv2d(x, cfg.fpn.num_features, kernel_size=1) 45 | for x in reversed(in_channels) 46 | ]) 47 | 48 | # This is here for backwards compatability 49 | padding = 1 if cfg.fpn.pad else 0 50 | self.pred_layers = nn.ModuleList([ 51 | nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=padding) 52 | for _ in in_channels 53 | ]) 54 | 55 | if cfg.fpn.use_conv_downsample: 56 | self.downsample_layers = nn.ModuleList([ 57 | nn.Conv2d(cfg.fpn.num_features, cfg.fpn.num_features, kernel_size=3, padding=1, stride=2) 58 | for _ in range(cfg.fpn.num_downsample) 59 | ]) 60 | 61 | self.interpolation_mode = cfg.fpn.interpolation_mode 62 | self.num_downsample = cfg.fpn.num_downsample 63 | self.use_conv_downsample = cfg.fpn.use_conv_downsample 64 | self.relu_downsample_layers = cfg.fpn.relu_downsample_layers # yolact++ 65 | self.relu_pred_layers = cfg.fpn.relu_pred_layers # yolact++ 66 | 67 | @script_method_wrapper 68 | def forward(self, convouts: List[torch.Tensor]): 69 | """ 70 | Args: 71 | - convouts (list): A list of convouts for the corresponding layers in in_channels. 72 | Returns: 73 | - A list of FPN convouts in the same order as x with extra downsample layers if requested. 74 | """ 75 | 76 | out = [] 77 | x = torch.zeros(1, device=convouts[0].device) 78 | for i in range(len(convouts)): 79 | out.append(x) 80 | 81 | # For backward compatability, the conv layers are stored in reverse but the input and output is 82 | # given in the correct order. Thus, use j=-i-1 for the input and output and i for the conv layers. 83 | j = len(convouts) 84 | for lat_layer in self.lat_layers: 85 | j -= 1 86 | 87 | if j < len(convouts) - 1: 88 | _, _, h, w = convouts[j].size() 89 | x = F.interpolate(x, size=(h, w), mode=self.interpolation_mode, align_corners=False) 90 | 91 | x = x + lat_layer(convouts[j]) 92 | out[j] = x 93 | 94 | # This janky second loop is here because TorchScript. 95 | j = len(convouts) 96 | for pred_layer in self.pred_layers: 97 | j -= 1 98 | out[j] = F.relu(pred_layer(out[j])) 99 | 100 | # In the original paper, this takes care of P6 101 | if self.use_conv_downsample: 102 | for downsample_layer in self.downsample_layers: 103 | out.append(downsample_layer(out[-1])) 104 | else: 105 | for idx in range(self.num_downsample): 106 | # Note: this is an untested alternative to out.append(out[-1][:, :, ::2, ::2]). Thanks TorchScript. 
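                # Added note: max_pool2d with a 1x1 kernel and stride 2 simply keeps every
                # other spatial location, i.e. the same subsampling as out[-1][:, :, ::2, ::2].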
107 | out.append(nn.functional.max_pool2d(out[-1], 1, stride=2)) 108 | 109 | return out -------------------------------------------------------------------------------- /layers/modules/FastMaskIoUNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from datasets.config import cfg 5 | from .make_net import make_net 6 | 7 | try: 8 | from dcn_v2 import DCN, DCNv2 9 | except ImportError: 10 | def DCN(*args, **kwdargs): 11 | raise Exception('DCN could not be imported. If you want to use YOLACT++ models, compile DCN. Check the README for instructions.') 12 | 13 | # As of March 10, 2019, Pytorch DataParallel still doesn't support JIT Script Modules 14 | use_jit = torch.cuda.device_count() <= 1 15 | if not use_jit: 16 | print('Multiple GPUs detected! Turning off JIT.') 17 | 18 | ScriptModuleWrapper = torch.jit.ScriptModule if use_jit else nn.Module 19 | script_method_wrapper = torch.jit.script_method if use_jit else lambda fn, _rcn=None: fn 20 | 21 | 22 | class FastMaskIoUNet(ScriptModuleWrapper): 23 | 24 | def __init__(self): 25 | super().__init__() 26 | input_channels = 1 27 | last_layer = [(cfg.num_classes-1, 1, {})] 28 | self.maskiou_net, _ = make_net(input_channels, cfg.maskiou_net + last_layer, include_last_relu=True) 29 | 30 | def forward(self, x): 31 | x = self.maskiou_net(x) 32 | maskiou_p = F.max_pool2d(x, kernel_size=x.size()[2:]).squeeze(-1).squeeze(-1) 33 | 34 | return maskiou_p -------------------------------------------------------------------------------- /layers/modules/Featurealign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.ops import DeformConv2d 4 | 5 | 6 | class FeatureAlign(nn.Module): 7 | def __init__(self, 8 | in_channels, 9 | out_channels, 10 | kernel_size=(3, 3), 11 | deformable_groups=4, 12 | use_pred_offset=True): 13 | super(FeatureAlign, self).__init__() 14 | if isinstance(kernel_size, int): 15 | kernel_size = (kernel_size, kernel_size) 16 | self.kernel_size = kernel_size 17 | self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2) 18 | self.use_pred_offset = use_pred_offset 19 | 20 | if self.use_pred_offset: 21 | offset_channels = kernel_size[0] * kernel_size[1] * 2 22 | self.conv_offset = nn.Conv2d(4, 23 | deformable_groups * offset_channels, 24 | 1, 25 | bias=False) 26 | 27 | self.conv_adaption = DeformConv2d(in_channels, 28 | in_channels, 29 | kernel_size=self.kernel_size, 30 | padding=self.padding, 31 | deform_groups=deformable_groups) 32 | self.conv = nn.Conv2d(in_channels, out_channels, 33 | kernel_size=self.kernel_size, padding=self.padding) 34 | 35 | self.relu = nn.ReLU(inplace=True) 36 | # self.norm = nn.BatchNorm2d(in_channels) 37 | 38 | def init_weights(self, bias_value=0): 39 | torch.nn.init.normal_(self.conv_offset.weight, std=0.0) 40 | torch.nn.init.normal_(self.conv_adaption.weight, std=0.01) 41 | 42 | def forward(self, x, shape): 43 | if self.use_pred_offset: 44 | offset = self.conv_offset(shape.detach()) 45 | else: 46 | ks_h, ks_w = self.kernel_size 47 | batch_size = x.size(0) 48 | 49 | variances = [0.1, 0.2] 50 | # dx = 2*\delta x , dy = 2*\delta y 51 | dxy = shape[:, :2].view(batch_size, 2, -1) * variances[0] # [bs, 2, hw] 52 | dx = (dxy[:, 0] * ks_w).unsqueeze(1).repeat(1, ks_h * ks_w, 1) 53 | dy = (dxy[:, 1] * ks_h).unsqueeze(1).repeat(1, ks_h * ks_w, 1) 54 | 55 | # dw = exp(\delta w) - 1 56 | dwh = (shape[:, 
2:].view(batch_size, 2, -1) * variances[1]).exp() - 1 57 | 58 | # build offset for h 59 | dh_R = torch.arange(-ks_h // 2 + 1, ks_h // 2 + 1).float() 60 | dh_R = dh_R.view(-1, 1).repeat(1, ks_w) 61 | dh = dwh[:, 1].unsqueeze(1) * dh_R.view(1, -1, 1) 62 | # build offset for w 63 | dw_R = torch.arange(-ks_w // 2 + 1, ks_w // 2 + 1).float() 64 | dw_R = dw_R.repeat(ks_h) 65 | dw = dwh[:, 0].unsqueeze(1) * dw_R.view(1, -1, 1) 66 | 67 | # [dy1, dx1, dy2, dx2, ..., dyn, dxn] 68 | offset = torch.stack([dy + dh, dx + dw], dim=1).permute(0, 2, 1, 3).contiguous() 69 | offset = offset.view(batch_size, -1, x.size(2), x.size(3)) 70 | 71 | # x = self.conv_adaption(x, offset) 72 | x = self.relu(self.conv_adaption(x, offset)) 73 | x = self.conv(x) 74 | return x 75 | 76 | 77 | -------------------------------------------------------------------------------- /layers/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .multibox_loss import MultiBoxLoss 2 | from .FPN import FPN 3 | from .FastMaskIoUNet import FastMaskIoUNet 4 | from .track_to_segment_head import TemporalNet, correlate, bbox_feat_extractor 5 | from .prediction_head_FC import PredictionModule_FC 6 | from .prediction_head import PredictionModule 7 | from .make_net import make_net 8 | from .Featurealign import FeatureAlign 9 | 10 | __all__ = ['MultiBoxLoss', 'FPN', 'FastMaskIoUNet', 11 | 'TemporalNet', 'correlate', 'bbox_feat_extractor', 12 | 'PredictionModule', 'PredictionModule_FC', 'make_net', 'FeatureAlign'] 13 | -------------------------------------------------------------------------------- /layers/modules/make_net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from layers.interpolate import InterpolateModule 3 | 4 | 5 | def make_net(in_channels, conf, include_last_relu=True): 6 | """ 7 | A helper function to take a config setting and turn it into a network. 8 | Used by protonet and extrahead. Returns (network, out_channels) 9 | """ 10 | 11 | def make_layer(layer_cfg): 12 | nonlocal in_channels 13 | 14 | # Possible patterns: 15 | # ( 256, 3, {}) -> conv 16 | # ( 256,-2, {}) -> deconv 17 | # (None,-2, {}) -> bilinear interpolate 18 | # ('cat',[],{}) -> concat the subnetworks in the list 19 | # 20 | # You know it would have probably been simpler just to adopt a 'c' 'd' 'u' naming scheme. 21 | # Whatever, it's too late now. 22 | if isinstance(layer_cfg[0], str): 23 | layer_name = layer_cfg[0] 24 | 25 | if layer_name == 'cat': 26 | nets = [make_net(in_channels, x) for x in layer_cfg[1]] 27 | layer = Concat([net[0] for net in nets], layer_cfg[2]) 28 | num_channels = sum([net[1] for net in nets]) 29 | else: 30 | num_channels = layer_cfg[0] 31 | kernel_size = layer_cfg[1] 32 | 33 | if kernel_size > 0: 34 | layer = nn.Conv2d(in_channels, num_channels, kernel_size, **layer_cfg[2]) 35 | else: 36 | if num_channels is None: 37 | layer = InterpolateModule(scale_factor=-kernel_size, mode='bilinear', align_corners=False, 38 | **layer_cfg[2]) 39 | else: 40 | layer = nn.ConvTranspose2d(in_channels, num_channels, -kernel_size, **layer_cfg[2]) 41 | 42 | in_channels = num_channels if num_channels is not None else in_channels 43 | 44 | # Don't return a ReLU layer if we're doing an upsample. This probably doesn't affect anything 45 | # output-wise, but there's no need to go through a ReLU here. 
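        # Added note: if the preceding layer ended in a ReLU, its output is non-negative, and
        # bilinear upsampling of a non-negative map stays non-negative, so the extra ReLU after
        # an upsample is effectively a no-op.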
46 | # Commented out for backwards compatibility with previous models 47 | # if num_channels is None: 48 | # return [layer] 49 | # else: 50 | return [layer, nn.ReLU(inplace=True)] 51 | 52 | # Use sum to concat together all the component layer lists 53 | net = sum([make_layer(x) for x in conf], []) 54 | if not include_last_relu: 55 | net = net[:-1] 56 | 57 | return nn.Sequential(*(net)), in_channels 58 | 59 | 60 | -------------------------------------------------------------------------------- /layers/modules/prediction_head.py: -------------------------------------------------------------------------------- 1 | 2 | import torch, torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from datasets.config import cfg, mask_type 6 | 7 | from .make_net import make_net 8 | from .Featurealign import FeatureAlign 9 | from utils import timer 10 | from itertools import product 11 | from math import sqrt 12 | from mmcv.ops import DeformConv2d 13 | 14 | 15 | class PredictionModule(nn.Module): 16 | """ 17 | The (c) prediction module adapted from DSSD: 18 | https://arxiv.org/pdf/1701.06659.pdf 19 | 20 | Note that this is slightly different to the module in the paper 21 | because the Bottleneck block actually has a 3x3 convolution in 22 | the middle instead of a 1x1 convolution. Though, I really can't 23 | be arsed to implement it myself, and, who knows, this might be 24 | better. 25 | Args: 26 | - in_channels: The input feature size. 27 | - out_channels: The output feature size (must be a multiple of 4). 28 | - aspect_ratios: A list of lists of priorbox aspect ratios (one list per scale). 29 | - scales: A list of priorbox scales relative to this layer's convsize. 30 | For instance: If this layer has convouts of size 30x30 for 31 | an image of size 600x600, the 'default' (scale 32 | of 1) for this layer would produce bounding 33 | boxes with an area of 20x20px. If the scale is 34 | .5 on the other hand, this layer would consider 35 | bounding boxes with area 10x10px, etc. 36 | - parent: If parent is a PredictionModule, this module will use all the layers 37 | from parent instead of from this module. 
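    Compared to the prediction head it was adapted from, this version can additionally emit
    per-anchor track embeddings, centerness scores and instance coefficients, and can route
    the conf/track/mask branches through deformable FeatureAlign layers when the
    corresponding cfg.use_dcn_* flags are set; see forward() for the returned dict.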
38 | """ 39 | 40 | def __init__(self, in_channels, out_channels=1024, 41 | pred_aspect_ratios=None, pred_scales=None, parent=None, deform_groups=1): 42 | super().__init__() 43 | 44 | self.out_channels = out_channels 45 | self.num_classes = cfg.num_classes 46 | self.mask_dim = cfg.mask_dim 47 | self.num_priors = len(pred_aspect_ratios[0]) * len(pred_scales) 48 | self.embed_dim = cfg.embed_dim 49 | self.pred_aspect_ratios = pred_aspect_ratios 50 | self.pred_scales = pred_scales 51 | self.deform_groups = deform_groups 52 | self.parent = [parent] # Don't include this in the state dict 53 | self.num_heads = cfg.num_heads 54 | if cfg.use_sipmask: 55 | self.mask_dim = self.mask_dim * cfg.sipmask_head 56 | 57 | if cfg.mask_proto_prototypes_as_features: 58 | in_channels += self.mask_dim 59 | 60 | if parent is None: 61 | if cfg.extra_head_net is None: 62 | self.out_channels = in_channels 63 | else: 64 | self.upfeature, self.out_channels = make_net(in_channels, cfg.extra_head_net) 65 | 66 | self.bbox_layer = nn.Conv2d(out_channels, self.num_priors * 4, **cfg.head_layer_params) 67 | 68 | kernel_size = cfg.head_layer_params['kernel_size'] 69 | if cfg.train_class: 70 | if cfg.use_cascade_pred and cfg.use_dcn_class: 71 | self.conf_layer = FeatureAlign(self.out_channels, 72 | self.num_priors * self.num_classes, 73 | kernel_size=kernel_size, 74 | deformable_groups=self.deform_groups, 75 | use_pred_offset=cfg.use_pred_offset) 76 | else: 77 | self.conf_layer = nn.Conv2d(self.out_channels, self.num_priors * self.num_classes, 78 | **cfg.head_layer_params) 79 | 80 | if cfg.train_track: 81 | if cfg.use_cascade_pred and cfg.use_dcn_track: 82 | self.track_layer = FeatureAlign(self.out_channels, 83 | self.num_priors * self.embed_dim, 84 | kernel_size=kernel_size, 85 | deformable_groups=self.deform_groups, 86 | use_pred_offset=cfg.use_pred_offset) 87 | else: 88 | self.track_layer = nn.Conv2d(out_channels, self.num_priors * self.embed_dim, **cfg.head_layer_params) 89 | 90 | if cfg.use_cascade_pred and cfg.use_dcn_mask: 91 | self.mask_layer = FeatureAlign(self.out_channels, 92 | self.num_priors * self.mask_dim, 93 | kernel_size=kernel_size, 94 | deformable_groups=self.deform_groups, 95 | use_pred_offset=cfg.use_pred_offset) 96 | else: 97 | self.mask_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, **cfg.head_layer_params) 98 | 99 | if cfg.train_centerness: 100 | self.centerness_layer = nn.Conv2d(out_channels, self.num_priors, **cfg.head_layer_params) 101 | 102 | if cfg.use_instance_coeff: 103 | self.inst_layer = nn.Conv2d(out_channels, self.num_priors * cfg.num_instance_coeffs, 104 | **cfg.head_layer_params) 105 | 106 | # What is this ugly lambda doing in the middle of all this clean prediction module code? 107 | def make_extra(num_layers): 108 | if num_layers == 0: 109 | return lambda x: x 110 | else: 111 | # Looks more complicated than it is. 
This just creates an array of num_layers alternating conv-relu 112 | return nn.Sequential(*sum([[ 113 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 114 | nn.ReLU(inplace=True) 115 | ] for _ in range(num_layers)], [])) 116 | 117 | self.bbox_extra, self.conf_extra, self.mask_extra, self.track_extra = [make_extra(x) for x in cfg.extra_layers] 118 | 119 | if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_coeff_gate: 120 | self.gate_layer = nn.Conv2d(out_channels, self.num_priors * self.mask_dim, kernel_size=3, padding=1) 121 | 122 | def forward(self, x): 123 | """ 124 | Args: 125 | - x: The convOut from a layer in the backbone network 126 | Size: [batch_size, in_channels, conv_h, conv_w]) 127 | Returns a tuple (bbox_coords, class_confs, mask_output, prior_boxes) with sizes 128 | - bbox_coords: [batch_size, conv_h*conv_w*num_priors, 4] 129 | - class_confs: [batch_size, conv_h*conv_w*num_priors, num_classes] 130 | - mask_output: [batch_size, conv_h*conv_w*num_priors, mask_dim] 131 | - prior_boxes: [conv_h*conv_w*num_priors, 4] 132 | """ 133 | # In case we want to use another module's layers 134 | src = self if self.parent[0] is None else self.parent[0] 135 | 136 | bs, _, conv_h, conv_w = x.size() 137 | 138 | if cfg.extra_head_net is not None: 139 | x = src.upfeature(x) 140 | 141 | bbox_x = src.bbox_extra(x) 142 | conf_x = src.conf_extra(x) 143 | mask_x = src.mask_extra(x) 144 | track_x = src.track_extra(x) 145 | 146 | bbox = src.bbox_layer(bbox_x) 147 | if cfg.use_cascade_pred: 148 | offset = src.conv_offset(bbox.detach()) 149 | # o1, o2, offset_mask = torch.chunk(offset_all, 3, dim=1) 150 | # offset = torch.cat((o1, o2), dim=1) 151 | # offset_mask = offset.new_ones(bs, int(offset.size(1)/2), conv_h, conv_w) 152 | bbox = bbox.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 4) 153 | 154 | if cfg.train_class: 155 | if cfg.use_cascade_pred and cfg.use_dcn_class: 156 | conf = src.conf_layer(conf_x, offset) 157 | else: 158 | conf = src.conf_layer(conf_x) 159 | conf = conf.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.num_classes) 160 | 161 | if cfg.train_track: 162 | if cfg.use_cascade_pred and cfg.use_dcn_track: 163 | track = src.track_layer(track_x, offset) 164 | else: 165 | track = src.track_layer(track_x) 166 | track = track.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.embed_dim) 167 | 168 | if cfg.use_cascade_pred and cfg.use_dcn_mask: 169 | mask = src.mask_layer(mask_x, offset) 170 | else: 171 | mask = src.mask_layer(mask_x) 172 | mask = mask.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, self.mask_dim) 173 | 174 | if cfg.train_centerness: 175 | centerness = src.centerness_layer(bbox_x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 1) 176 | centerness = torch.tanh(centerness) 177 | 178 | if cfg.use_mask_scoring: 179 | score = src.score_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, 1) 180 | 181 | if cfg.use_instance_coeff: 182 | inst = src.inst_layer(x).permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, cfg.num_instance_coeffs) 183 | 184 | # See box_utils.decode for an explanation of this 185 | if cfg.use_yolo_regressors: 186 | bbox[:, :, :2] = torch.sigmoid(bbox[:, :, :2]) - 0.5 187 | bbox[:, :, 0] /= conv_w 188 | bbox[:, :, 1] /= conv_h 189 | 190 | if cfg.mask_proto_split_prototypes_by_head and cfg.mask_type == mask_type.lincomb: 191 | mask = F.pad(mask, (self.index * self.mask_dim, (self.num_heads - self.index - 1) * self.mask_dim), 192 | mode='constant', value=0) 193 | 194 | priors = 
self.make_priors(conv_h, conv_w, x.device) 195 | preds = {'loc': bbox, 'conf': conf, 'mask_coeff': mask, 'priors': priors} 196 | 197 | if cfg.train_centerness: 198 | preds['centerness'] = centerness 199 | 200 | if cfg.train_track: 201 | preds['track'] = F.normalize(track, dim=-1) 202 | 203 | if cfg.use_mask_scoring: 204 | preds['score'] = score 205 | 206 | if cfg.use_instance_coeff: 207 | preds['inst'] = inst 208 | 209 | if cfg.temporal_fusion_module: 210 | preds['T2S_feat'] = x 211 | 212 | return preds 213 | 214 | def make_priors(self, conv_h, conv_w, device): 215 | """ Note that priors are [x,y,width,height] where (x,y) is the center of the box. """ 216 | with timer.env('makepriors'): 217 | prior_data = [] 218 | # Iteration order is important (it has to sync up with the convout) 219 | for j, i in product(range(conv_h), range(conv_w)): 220 | # +0.5 because priors are in center-size notation 221 | x = (i + 0.5) / conv_w 222 | y = (j + 0.5) / conv_h 223 | 224 | for ars in self.pred_aspect_ratios: 225 | for scale in self.pred_scales: 226 | for ar in ars: 227 | # [1, 1/2, 2] 228 | ar = sqrt(ar) 229 | r = scale / self.pred_scales[0] * 3 230 | w = r * ar / conv_w 231 | h = r / ar / conv_h 232 | 233 | prior_data += [x, y, w, h] 234 | 235 | priors = torch.Tensor(prior_data, device=device).view(1, -1, 4).detach() 236 | priors.requires_grad = False 237 | 238 | return priors 239 | 240 | -------------------------------------------------------------------------------- /layers/modules/track_to_segment_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datasets.config import cfg 4 | from spatial_correlation_sampler import spatial_correlation_sample 5 | import torch.nn.functional as F 6 | from mmcv.ops import roi_align 7 | from layers.box_utils import sanitize_coordinates_hw 8 | 9 | 10 | class TemporalNet(nn.Module): 11 | def __init__(self, corr_channels, mask_proto_n=32): 12 | 13 | super().__init__() 14 | self.conv1 = nn.Conv2d(corr_channels, 512, kernel_size=3, padding=1) 15 | self.conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 16 | self.conv3 = nn.Conv2d(512, 1024, kernel_size=3, padding=1) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.pool = nn.AvgPool2d((7, 7), stride=1) 19 | self.fc = nn.Linear(1024, 4) 20 | if cfg.use_sipmask: 21 | self.fc_coeff = nn.Linear(1024, mask_proto_n * cfg.sipmask_head) 22 | else: 23 | self.fc_coeff = nn.Linear(1024, mask_proto_n) 24 | 25 | def forward(self, x): 26 | x = self.conv1(x) 27 | x = self.relu(x) 28 | x = self.conv2(x) 29 | x = self.relu(x) 30 | x = self.conv3(x) 31 | x = self.relu(x) 32 | x = self.pool(x) 33 | x = x.view(x.size(0), -1) 34 | x_reg = self.fc(x) 35 | x_coeff = self.fc_coeff(x) 36 | 37 | return x_reg, x_coeff 38 | 39 | 40 | def correlate(x1, x2, patch_size=11, dilation_patch=1): 41 | """ 42 | :param x1: features 1 43 | :param x2: features 2 44 | :param patch_size: the size of whole patch is used to calculate the correlation 45 | :return: 46 | """ 47 | 48 | # Output sizes oH and oW are no longer dependant of patch size, but only of kernel size and padding 49 | # patch_size is now the whole patch, and not only the radii. 50 | # stride1 is now stride and stride2 is dilation_patch, which behave like dilated convolutions 51 | # equivalent max_displacement is then dilation_patch * (patch_size - 1) / 2. 
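    # Added shape note: with the call below (kernel_size=1, stride=1, padding=0), two
    # [B, C, H, W] inputs produce a correlation volume of size [B, patch_size, patch_size, H, W],
    # which is flattened to [B, patch_size**2, H, W], divided by C, and passed through a leaky ReLU.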
52 | # to get the right parameters for FlowNetC, you would have 53 | out_corr = spatial_correlation_sample(x1, 54 | x2, 55 | kernel_size=1, 56 | patch_size=patch_size, 57 | stride=1, 58 | padding=0, 59 | dilation_patch=dilation_patch) 60 | b, ph, pw, h, w = out_corr.size() 61 | out_corr = out_corr.view(b, ph*pw, h, w) / x1.size(1) 62 | return F.leaky_relu_(out_corr, 0.1) 63 | 64 | 65 | def bbox_feat_extractor(feature_maps, boxes_w_norm, h, w, pool_size): 66 | """ 67 | feature_maps: size:1*C*h*w 68 | boxes: Mx5 float box with (x1, y1, x2, y2) **without normalization** 69 | """ 70 | # Currently only supports batch_size 1 71 | boxes = sanitize_coordinates_hw(boxes_w_norm, h, w) 72 | # boxes = boxes_w_norm 73 | 74 | # Crop and Resize 75 | # Result: [num_boxes, pool_height, pool_width, channels] 76 | box_ind = torch.zeros(boxes.size(0)) # index of bbox in batch 77 | if boxes.is_cuda: 78 | box_ind = box_ind.cuda() 79 | 80 | # CropAndResizeFunction needs batch dimension 81 | if len(feature_maps.size()) == 3: 82 | feature_maps = feature_maps.unsqueeze(0) 83 | 84 | # make crops: 85 | rois = torch.cat([box_ind.unsqueeze(1), boxes], dim=1) 86 | pooled_features = roi_align(feature_maps, rois, pool_size) 87 | 88 | return pooled_features 89 | 90 | -------------------------------------------------------------------------------- /layers/output_utils.py: -------------------------------------------------------------------------------- 1 | """ Contains functions used to sanitize and prepare the output of Yolact. """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import cv2 7 | import pycocotools.mask as mask_util 8 | from matplotlib.patches import Polygon 9 | 10 | from datasets import cfg, mask_type, MEANS, STD, activation_func 11 | from utils.augmentations import Resize 12 | from utils import timer 13 | from .box_utils import crop, sanitize_coordinates, center_size 14 | 15 | 16 | def postprocess_ytbvis(det_output, img_meta, interpolation_mode='bilinear', 17 | display_mask=False, visualize_lincomb=False, crop_masks=True, score_threshold=0, 18 | img_ids=None, mask_det_file=None): 19 | """ 20 | Postprocesses the output of Yolact on testing mode into a format that makes sense, 21 | accounting for all the possible configuration settings. 22 | 23 | Args: 24 | - det_output: The lost of dicts that Detect outputs. 25 | - w: The real with of the image. 26 | - h: The real height of the image. 27 | - batch_idx: If you have multiple images for this batch, the image's index in the batch. 28 | - interpolation_mode: Can be 'nearest' | 'area' | 'bilinear' (see torch.nn.functional.interpolate) 29 | 30 | Returns 4 torch Tensors (in the following order): 31 | - classes [num_det]: The class idx for each detection. 32 | - scores [num_det]: The confidence score for each detection. 33 | - boxes [num_det, 4]: The bounding box for each detection in absolute point form. 34 | - masks [num_det, h, w]: Full image masks for each detection. 
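    (Clarifying note, added here: in this codebase the function actually returns the detection
    dict itself, with boxes rescaled and sanitized to the output resolution and the masks
    stored under dets['segm'], either as binarized mask tensors or as COCO RLE encodings
    depending on display_mask; the "4 torch Tensors" description above appears to be carried
    over from the original YOLACT postprocess.)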
35 | """ 36 | 37 | net = det_output['net'] 38 | detection = det_output['detection'] 39 | dets = {} 40 | for k, v in detection.items(): 41 | dets[k] = v.clone() 42 | 43 | ori_h, ori_w = img_meta['ori_shape'][:2] 44 | img_h, img_w = img_meta['img_shape'][:2] 45 | pad_h, pad_w = img_meta['pad_shape'][:2] 46 | s_w, s_h = (img_w / pad_w, img_h / pad_h) 47 | 48 | if dets['box'].nelement() == 0: 49 | dets['segm'] = torch.Tensor() 50 | return dets 51 | 52 | # double check 53 | if score_threshold > 0: 54 | keep = dets['score'] > score_threshold 55 | 56 | for k in dets: 57 | if k not in {'proto', 'bbox_idx', 'priors', 'embed_vectors', 'box_shift'} and dets[k] is not None: 58 | dets[k] = dets[k][keep] 59 | 60 | # Undo the padding introduced with preserve_aspect_ratio 61 | if cfg.preserve_aspect_ratio and dets['score'].nelement() != 0: 62 | # Get rid of any detections whose centers are outside the image 63 | boxes = dets['box'] 64 | boxes = center_size(boxes) 65 | not_outside = ((boxes[:, 0] > s_w) + (boxes[:, 1] > s_h)) < 1 # not (a or b) 66 | for k in dets: 67 | if k not in {'proto', 'bbox_idx', 'priors', 'embed_vectors', 'box_shift'} and dets[k] is not None: 68 | dets[k] = dets[k][not_outside] 69 | 70 | if dets['score'].size(0) == 0: 71 | dets['segm'] = torch.Tensor() 72 | return dets 73 | 74 | # Actually extract everything from dets now 75 | boxes = dets['box'] 76 | masks_coeff = dets['mask_coeff'] 77 | masks = dets['mask'] 78 | proto_data = dets['proto'] 79 | # normlized_coeff = F.normalize(masks_coeff, dim=1) 80 | # sim = torch.mm(normlized_coeff, normlized_coeff.t()) 81 | 82 | if visualize_lincomb: 83 | display_lincomb(proto_data, masks_coeff, img_ids, mask_det_file) 84 | 85 | # Undo padding for masks 86 | masks = masks[:, :int(s_h*masks.size(1)), :int(s_w*masks.size(2))] 87 | # Scale masks up to the full image 88 | if cfg.preserve_aspect_ratio: 89 | masks = F.interpolate(masks.unsqueeze(0), (ori_h, ori_w), mode=interpolation_mode, 90 | align_corners=False).squeeze(0) 91 | else: 92 | masks = F.interpolate(masks.unsqueeze(0), (img_h, img_w), mode=interpolation_mode, 93 | align_corners=False).squeeze(0) 94 | # Binarize the masks 95 | masks.gt_(0.5) 96 | 97 | if display_mask: 98 | dets['segm'] = masks 99 | else: 100 | # segm annotation: png2rle 101 | masks_output_json = [] 102 | for i in range(masks.size(0)): 103 | cur_mask = mask_util.encode(np.array(masks[i].cpu(), order='F', dtype='uint8')) 104 | # masks[i, :, :] = torch.from_numpy(mask_util.decode(cur_mask)).cuda() 105 | masks_output_json.append(cur_mask) 106 | dets['segm'] = masks_output_json 107 | 108 | # Undo padding for bboxes 109 | boxes[:, 0::2] = boxes[:, 0::2] / s_w 110 | boxes[:, 1::2] = boxes[:, 1::2] / s_h 111 | # priors = dets['priors'] # [cx, cy, w, h] 112 | # priors[:, :2] = priors[:, :2] - priors[:, 2:]/2 113 | # priors[:, 2:] = priors[:, :2] + priors[:, 2:] 114 | # priors[:, 0::2] = priors[:, 0::2] / s_w 115 | # priors[:, 1::2] = priors[:, 1::2] / s_h 116 | 117 | if cfg.preserve_aspect_ratio: 118 | out_w = ori_w 119 | out_h = ori_h 120 | else: 121 | out_w = img_w 122 | out_h = img_h 123 | 124 | boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], out_w, cast=False) 125 | boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], out_h, cast=False) 126 | # priors[:, 0], priors[:, 2] = sanitize_coordinates(priors[:, 0], priors[:, 2], out_w, cast=False) 127 | # priors[:, 1], priors[:, 3] = sanitize_coordinates(priors[:, 1], priors[:, 3], out_h, cast=False) 128 | 129 | boxes = boxes.long() 130 | 
dets['box'] = boxes 131 | # dets['priors'] = priors.long() 132 | 133 | return dets 134 | 135 | 136 | def undo_image_transformation(img, img_meta, pad_h, pad_w, interpolation_mode='bilinear'): 137 | """ 138 | Takes a transformed image tensor and returns a numpy ndarray that is untransformed. 139 | Arguments w and h are the original height and width of the image. 140 | """ 141 | ori_h, ori_w = img_meta['ori_shape'][0:2] 142 | img_h, img_w = img_meta['img_shape'][0:2] 143 | s_w, s_h = (img_w / pad_w, img_h / pad_h) 144 | 145 | # Undo padding 146 | img = img[:, :int(s_h * img.size(1)), :int(s_w * img.size(2))] 147 | if cfg.preserve_aspect_ratio: 148 | img = F.interpolate(img.unsqueeze(0), (ori_h, ori_w), mode=interpolation_mode, 149 | align_corners=False).squeeze(0) 150 | else: 151 | img = F.interpolate(img.unsqueeze(0), (img_h, img_w), mode=interpolation_mode, 152 | align_corners=False).squeeze(0) 153 | 154 | img_numpy = img.permute(1, 2, 0).cpu().numpy() 155 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To BRG 156 | 157 | if cfg.backbone.transform.normalize: 158 | img_numpy = (img_numpy * np.array(STD) + np.array(MEANS)) / 255.0 159 | elif cfg.backbone.transform.subtract_means: 160 | img_numpy = (img_numpy / 255.0 + np.array(MEANS) / 255.0).astype(np.float32) 161 | 162 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To RGB 163 | img_numpy = np.clip(img_numpy, 0, 1) 164 | 165 | return img_numpy 166 | 167 | 168 | def display_lincomb(proto_data, masks, img_ids=None, mask_det_file=None): 169 | proto_data = proto_data.squeeze() 170 | out_masks = torch.matmul(proto_data, masks.t()) 171 | out_masks = cfg.mask_proto_mask_activation(out_masks) 172 | 173 | for kdx in range(1): 174 | jdx = kdx + 0 175 | import matplotlib.pyplot as plt 176 | coeffs = masks[jdx, :].cpu().numpy() 177 | idx = np.argsort(-np.abs(coeffs)) 178 | # plt.bar(list(range(idx.shape[0])), coeffs[idx]) 179 | # plt.show() 180 | 181 | coeffs_sort = coeffs[idx] 182 | arr_h, arr_w = (8, 4) 183 | proto_h, proto_w, _ = proto_data.size() 184 | arr_img = np.zeros([proto_h * arr_h, proto_w * arr_w]) 185 | arr_run = np.zeros([proto_h * arr_h, proto_w * arr_w]) 186 | test = torch.sum(proto_data, -1).cpu().numpy() 187 | 188 | for y in range(arr_h): 189 | for x in range(arr_w): 190 | i = arr_w * y + x 191 | 192 | if i == 0: 193 | running_total = proto_data[:, :, idx[i]].cpu().numpy() * coeffs_sort[i] 194 | else: 195 | running_total += proto_data[:, :, idx[i]].cpu().numpy() * coeffs_sort[i] 196 | 197 | running_total_nonlin = running_total 198 | if cfg.mask_proto_mask_activation == activation_func.sigmoid: 199 | running_total_nonlin = (1 / (1 + np.exp(-running_total_nonlin))) 200 | 201 | arr_img[y * proto_h:(y + 1) * proto_h, x * proto_w:(x + 1) * proto_w] = (proto_data[:, :, 202 | idx[i]] / torch.max( 203 | proto_data[:, :, idx[i]])).cpu().numpy() * coeffs_sort[i] 204 | arr_run[y * proto_h:(y + 1) * proto_h, x * proto_w:(x + 1) * proto_w] = ( 205 | running_total_nonlin > 0.5).astype(np.float) 206 | plt.imshow(arr_img) 207 | plt.axis('off') 208 | if img_ids is not None: 209 | plt.title(str(img_ids)) 210 | plt.savefig(''.join([mask_det_file[:-12], 'out_proto/', str(img_ids), 'protos.png'])) 211 | # plt.show() 212 | # plt.imshow(arr_run) 213 | # plt.show() 214 | # plt.imshow(test) 215 | # plt.show() 216 | 217 | for jdx in range(out_masks.size(2)): 218 | plt.imshow(out_masks[:, :, jdx].cpu().numpy()) 219 | if img_ids is not None: 220 | plt.title(str(img_ids)) 221 | plt.savefig(''.join([mask_det_file[:-12], 'out_proto/', str(img_ids), str(jdx), 'mask.png'])) 
222 | # plt.show() 223 | 224 | 225 | def display_fpn_outs(outs, img_ids=None, mask_det_file=None): 226 | 227 | for batch_idx in range(outs[0].size(0)): 228 | for idx in range(len(outs)): 229 | cur_out = outs[idx][batch_idx] 230 | import matplotlib.pyplot as plt 231 | arr_h, arr_w = (4, 4) 232 | _, h, w = cur_out.size() 233 | arr_img = np.zeros([h * arr_h, w * arr_w]) 234 | 235 | for y in range(arr_h): 236 | for x in range(arr_w): 237 | i = arr_w * y + x 238 | arr_img[y * h:(y + 1) * h, x * w:(x + 1) * w] = cur_out[i, :, :].cpu().numpy() 239 | 240 | plt.imshow(arr_img) 241 | if img_ids is not None: 242 | plt.title(str(img_ids)) 243 | plt.savefig(''.join([mask_det_file, str(img_ids), 'outs', str(batch_idx), str(idx), '.png'])) 244 | plt.show() 245 | -------------------------------------------------------------------------------- /layers/track_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | 4 | 5 | def split_bbox(bbox, idx_bbox_ori, nums_bbox_per_layer): 6 | num_layers = len(nums_bbox_per_layer) 7 | num_bboxes = len(idx_bbox_ori) 8 | for i in range(1, num_layers): 9 | nums_bbox_per_layer[i] = nums_bbox_per_layer[i-1] + nums_bbox_per_layer[i] 10 | 11 | split_bboxes = [[] for _ in range(num_layers)] 12 | split_bboxes_idx = [[] for _ in range(num_layers)] 13 | for i in range(num_bboxes): 14 | for j in range(num_layers): 15 | if idx_bbox_ori[i] < nums_bbox_per_layer[j]: 16 | split_bboxes[j].append(bbox[i].unsqueeze(0)) 17 | split_bboxes_idx[j].append(i) 18 | break 19 | 20 | for j in range(num_layers): 21 | if len(split_bboxes[j]) > 0: 22 | split_bboxes[j] = torch.cat(split_bboxes[j]) 23 | if j > 0: 24 | split_bboxes_idx[0] += split_bboxes_idx[j] 25 | 26 | return split_bboxes, split_bboxes_idx[0] 27 | 28 | 29 | def mask_iou(mask1, mask2): 30 | intersection = torch.sum(mask1 * mask2, dim=(0, 1)) 31 | area1 = torch.sum(mask1, dim=(0, 1)) 32 | area2 = torch.sum(mask2, dim=(0, 1)) 33 | union = (area1 + area2) - intersection 34 | ret = intersection / union 35 | return ret 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /layers/train_output_utils.py: -------------------------------------------------------------------------------- 1 | """ Contains functions used to sanitize and prepare the output of Yolact. 
""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from datasets import cfg, mask_type, MEANS, STD, activation_func 7 | from utils.augmentations import Resize 8 | from utils import timer 9 | from .box_utils import crop, sanitize_coordinates, center_size, decode 10 | import eval as eval_script 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | def display_train_output(images, predictions, conf_t, pids_t, gt_bboxes, gt_labels, gt_masks, ref_images, ref_bboxes, 15 | gt_pids, img_meta, epoch, iteration, path=None): 16 | setup_eval() 17 | loc_data = predictions['loc'] 18 | conf_data = predictions['conf'] 19 | mask_data = predictions['mask'] 20 | priors = predictions['priors'] 21 | priors = priors[0, :, :] 22 | match_score = predictions['track'] 23 | ref_boxes_n = predictions['ref_boxes_n'] 24 | if cfg.mask_type == mask_type.lincomb: 25 | proto_data = predictions['proto'] 26 | 27 | batch_size, _, h, w = images.size() 28 | 29 | if cfg.use_sigmoid_focal_loss: 30 | # Note: even though conf[0] exists, this mode doesn't train it so don't use it 31 | conf_data = torch.sigmoid(conf_data) 32 | elif cfg.use_objectness_score: 33 | # See focal_loss_sigmoid in multibox_loss.py for details 34 | objectness = torch.sigmoid(conf_data[:, :, 0]) 35 | conf_data[:, :, 1:] = objectness[:, :, None] * F.softmax(conf_data[:, :, 1:], -1) 36 | conf_data[:, :, 0] = 1 - objectness 37 | else: 38 | conf_data = F.softmax(conf_data, -1) 39 | 40 | # visualization 41 | pos = conf_t > 0 42 | for batch_idx in range(batch_size): 43 | # detection results 44 | dets_out = {} 45 | idx_pos = pos[batch_idx, :] == 1 46 | 47 | dets_out['score'], class_pred = conf_data[batch_idx, idx_pos, 1:].max(dim=1) 48 | dets_out['class_pred'] = class_pred + 1 # classes begins from 1 49 | dets_out['class'] = conf_t[batch_idx, idx_pos] 50 | dets_out['pids'] = pids_t[batch_idx, idx_pos] 51 | dets_out['box'] = decode(loc_data[batch_idx, idx_pos, :], priors[idx_pos]) 52 | dets_out['mask'] = mask_data[batch_idx, idx_pos, :] 53 | if cfg.mask_type == mask_type.lincomb: 54 | dets_out['proto'] = proto_data[batch_idx] 55 | 56 | img_numpy = eval_script.prep_display(dets_out, images[batch_idx], h, w, img_meta[batch_idx], display_mode='train') 57 | 58 | # gt results 59 | dets_out = {} 60 | dets_out['class'] = gt_labels[batch_idx] 61 | dets_out['box'] = gt_bboxes[batch_idx] 62 | dets_out['segm'] = gt_masks[batch_idx].type(torch.cuda.FloatTensor) 63 | dets_out['pids'] = gt_pids[batch_idx].type(torch.cuda.LongTensor) 64 | 65 | img_numpy_gt = eval_script.prep_display(dets_out, images[batch_idx], h, w, img_meta[batch_idx]) 66 | 67 | # gt results of the last frame 68 | dets_out = {} 69 | gt_class_last = [] 70 | for i in range(1, len(ref_bboxes[batch_idx])+1): 71 | if i in gt_pids[batch_idx].tolist(): 72 | gt_class_last.append(gt_labels[batch_idx][gt_pids[batch_idx].tolist().index(i)]) 73 | else: 74 | gt_class_last.append(-1) 75 | dets_out['class'] = torch.tensor(gt_class_last) 76 | 77 | dets_out['box'] = ref_bboxes[batch_idx] 78 | dets_out['pids'] = torch.arange(1, len(ref_bboxes[batch_idx])+1) 79 | 80 | img_numpy_gt_last = eval_script.prep_display(dets_out, ref_images[batch_idx], h, w, img_meta[batch_idx]) 81 | 82 | # show results and save figs 83 | plt.imshow(img_numpy) 84 | plt.title('train') 85 | plt.savefig(''.join([path, 'out/', str(epoch), '_', str(iteration), '_', str(batch_idx), '_train', '.png'])) 86 | plt.show() 87 | 88 | plt.imshow(img_numpy_gt) 89 | plt.title('gt') 90 | plt.savefig(''.join([path, 'out/', str(epoch), 
'_', str(iteration), '_', str(batch_idx), '_gt', '.png'])) 91 | plt.show() 92 | 93 | plt.imshow(img_numpy_gt_last) 94 | plt.title('gt_last') 95 | plt.savefig(''.join([path, 'out/', str(epoch), '_', str(iteration), '_', str(batch_idx), '_gt_last', '.png'])) 96 | plt.show() 97 | 98 | 99 | def setup_eval(): 100 | eval_script.parse_args(['--no_bar', 101 | '--output_json', 102 | ]) 103 | 104 | -------------------------------------------------------------------------------- /layers/visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | from datasets import cfg, mask_type, MEANS, STD 5 | import random 6 | from math import sqrt 7 | import matplotlib.pyplot as plt 8 | import mmcv 9 | import torch.nn.functional as F 10 | import os 11 | 12 | 13 | # Quick and dirty lambda for selecting the color for a particular index 14 | # Also keeps track of a per-gpu color cache for maximum speed 15 | def get_color(j, color_type, on_gpu=None, undo_transform=True): 16 | global color_cache 17 | color_idx = color_type[j] * 5 % len(cfg.COLORS) 18 | 19 | if on_gpu is not None and color_idx in color_cache[on_gpu]: 20 | return color_cache[on_gpu][color_idx] 21 | else: 22 | color = cfg.COLORS[color_idx] 23 | if not undo_transform: 24 | # The image might come in as RGB or BRG, depending 25 | color = (color[2], color[1], color[0]) 26 | if on_gpu is not None: 27 | color = torch.Tensor(color).to(on_gpu).float() / 255. 28 | color_cache[on_gpu][color_idx] = color 29 | return color 30 | 31 | 32 | def display_box_shift(box, box_shift, mask_shift=None, img_meta=None, img_gpu=None, conf=None, mask_alpha=0.45): 33 | if img_meta is None: 34 | video_id, frame_id = 0, 0 35 | else: 36 | video_id, frame_id = img_meta['video_id'], img_meta['frame_id'] 37 | 38 | save_dir = 'weights/YTVIS2019/weights_r50_new/box_shift/' 39 | save_dir = os.path.join(save_dir, str(video_id)) 40 | if not os.path.exists(save_dir): 41 | os.makedirs(save_dir) 42 | path = ''.join([save_dir, '/', str(frame_id), '.png']) 43 | 44 | # Make empty black image 45 | if img_gpu is None: 46 | h, w = 384, 640 47 | image = np.ones((h, w, 3), np.uint8) * 255 48 | else: 49 | 50 | h, w = img_gpu.size()[1:] 51 | img_numpy = img_gpu.squeeze(0).permute(1, 2, 0).cpu().numpy() 52 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To BRG 53 | img_numpy = (img_numpy * np.array(STD) + np.array(MEANS)) / 255.0 54 | # img_numpy = img_numpy[:, :, (2, 1, 0)] # To RGB 55 | img_numpy = np.clip(img_numpy, 0, 1) * 255 56 | img_gpu = torch.Tensor(img_numpy).cuda() 57 | 58 | # plot pred bbox 59 | color_type = range(box.size(0)) 60 | if mask_shift is not None: 61 | # Undo padding for masks 62 | # gt_masks_cur = gt_masks_cur[:, :img_h, :img_w].float() 63 | mask_shift = mask_shift.unsqueeze(-1).repeat(1, 1, 1, 3) 64 | # This is 1 everywhere except for 1-mask_alpha where the mask is 65 | inv_alph_masks = mask_shift.sum(0) * (-mask_alpha) + 1 66 | mask_shift_color = [] 67 | for i in range(box.size(0)): 68 | color = get_color(i, color_type, on_gpu=img_gpu.device.index).view(1, 1, 3) 69 | mask_shift_color.append(mask_shift[i] * color * mask_alpha) 70 | mask_shift_color = torch.stack(mask_shift_color, dim=0).sum(0) 71 | img_gpu = (img_gpu * inv_alph_masks + mask_shift_color) 72 | 73 | image = img_gpu.byte().cpu().numpy() 74 | 75 | if conf is not None: 76 | scores, classes = conf[:, 1:].max(dim=1) 77 | 78 | for i in range(box.size(0)): 79 | color = get_color(i, color_type) 80 | cv2.rectangle(image, (box[i, 0]*w, 
box[i, 1]*h), (box[i, 2]*w, box[i, 3]*h), color, 2) 81 | 82 | cv2.rectangle(image, (box_shift[i, 0] * w, box_shift[i, 1] * h), 83 | (box_shift[i, 2] * w, box_shift[i, 3] * h), color, 4) 84 | 85 | if conf is not None: 86 | text_str = '%s: %.2f' % (classes[i].item()+1, scores[i]) 87 | 88 | font_face = cv2.FONT_HERSHEY_DUPLEX 89 | font_scale = 0.5 90 | font_thickness = 1 91 | text_pt = (box_shift[i, 0]*w, box_shift[i, 1]*h - 3) 92 | text_color = [255, 255, 255] 93 | cv2.putText(image, text_str, text_pt, font_face, font_scale, text_color, font_thickness, cv2.LINE_AA) 94 | cv2.imwrite(path, image) 95 | 96 | 97 | def display_feature_align_dcn(detection, offset, loc_data, img_gpu=None, img_meta=None, use_yolo_regressors=False): 98 | h, w = 384, 640 99 | # Make empty black image 100 | if img_gpu is None: 101 | image = np.ones((h, w, 3), np.uint8) * 255 102 | else: 103 | img_numpy = img_gpu.squeeze(0).permute(1, 2, 0).cpu().numpy() 104 | img_numpy = img_numpy[:, :, (2, 1, 0)] # To BRG 105 | img_numpy = (img_numpy * np.array(STD) + np.array(MEANS)) / 255.0 106 | # img_numpy = img_numpy[:, :, (2, 1, 0)] # To RGB 107 | img_numpy = np.clip(img_numpy, 0, 1) 108 | img_gpu = torch.Tensor(img_numpy).cuda() 109 | image = (img_gpu * 255).byte().cpu().numpy() 110 | 111 | n_dets = detection['box'].size(0) 112 | n = 0 113 | p = detection['priors'][n] 114 | decoded_loc = detection['box'][n] 115 | id = detection['bbox_idx'][n] 116 | loc = loc_data[0, id, :] 117 | pixel_id = id // 3 118 | prior_id = id % 3 119 | if prior_id == 0: 120 | o = offset[0, :18, pixel_id] 121 | ks_h, ks_w = 3, 3 122 | grid_w = torch.tensor([-1, 0, 1] * ks_h) 123 | grid_h = torch.tensor([[-1], [0], [1]]).repeat(1, ks_w).view(-1) 124 | elif prior_id == 1: 125 | o = offset[0, 18:48, pixel_id] 126 | ks_h, ks_w = 3, 5 127 | grid_w = torch.tensor([-2, -1, 0, 1, 2] * ks_h) 128 | grid_h = torch.tensor([[-1], [0], [1]]).repeat(1, ks_w).view(-1) 129 | else: 130 | o = offset[0, 48:, pixel_id] 131 | ks_h, ks_w = 5, 3 132 | grid_w = torch.tensor([-1, 0, 1] * ks_h) 133 | grid_h = torch.tensor([[-2], [-1], [0], [1], [2]]).repeat(1, ks_w).view(-1) 134 | 135 | # thransfer the rectange to 9 points 136 | cx1, cy1, w1, h1 = p[0], p[1], p[2], p[3] 137 | dw1 = grid_w * w1 / (ks_w-1) + cx1 138 | dh1 = grid_h * h1 / (ks_h-1) + cy1 139 | 140 | dwh = p[2:] * ((loc.detach()[2:] * 0.2).exp() - 1) 141 | # regressed bounding boxes 142 | new_dh1 = dh1 + loc[1] * p[3] * 0.1 + dwh[1] / ks_h * grid_h 143 | new_dw1 = dw1 + loc[0] * p[2] * 0.1 + dwh[0] / ks_w * grid_w 144 | # points after the offsets of dcn 145 | new_dh2 = dh1 + o[::2].view(-1) * 0.5 * p[3] 146 | new_dw2 = dw1 + o[1::2].view(-1) * 0.5 * p[2] 147 | 148 | # Create a named colour 149 | blue = [255, 0, 0] # bgr 150 | purple = [128, 0, 128] 151 | red = [0, 0, 255] 152 | 153 | # plot pred bbox 154 | cv2.rectangle(image, (decoded_loc[0] * w, decoded_loc[1] * h), (decoded_loc[2] * w, decoded_loc[3] * h), 155 | blue, 2, lineType=8) 156 | 157 | # plot priors 158 | pxy1 = p[:2] - p[2:] / 2 159 | pxy2 = p[:2] + p[2:] / 2 160 | cv2.rectangle(image, (pxy1[0] * w, pxy1[1] * h), (pxy2[0] * w, pxy2[1] * h), 161 | purple, 2, lineType=8) 162 | for i in range(len(dw1)): 163 | cv2.circle(image, (new_dw2[i] * w, new_dh2[i] * h), radius=0, color=blue, thickness=10) 164 | cv2.circle(image, (new_dw1[i]*w, new_dh1[i]*h), radius=0, color=blue, thickness=6) 165 | cv2.circle(image, (dw1[i] * w, dh1[i] * h), radius=0, color=purple, thickness=6) 166 | 167 | if img_meta is not None: 168 | path = ''.join(['results/results_1024_2/FCB/', 
str(img_meta[0]['video_id']), '_', 169 | str(img_meta[0]['frame_id']), '.png']) 170 | else: 171 | path = 'results/results_1024_2/FCB/0.png' 172 | cv2.imwrite(path, image) 173 | 174 | 175 | def display_correlation_map_patch(x_corr, img_meta=None): 176 | if img_meta is not None: 177 | video_id, frame_id = img_meta['video_id'], img_meta['frame_id'] 178 | else: 179 | video_id, frame_id = 0, 0 180 | 181 | save_dir = 'weights/YTVIS2019/weights_r50_new/box_shift/' 182 | save_dir = os.path.join(save_dir, str(video_id)) 183 | if not os.path.exists(save_dir): 184 | os.makedirs(save_dir) 185 | # x_corr = x_corr[0, :, :18, :]**2 186 | bs, ch, h, w = x_corr.size() 187 | # x_corr = F.normalize(x_corr, dim=1) 188 | r = int(sqrt(ch)) 189 | for i in range(bs): 190 | x_corr_cur = x_corr[i] 191 | x_show = x_corr_cur.view(r, r, h, w) 192 | 193 | x_show = x_show.permute(0, 2, 1, 3).contiguous().view(h*r, r*w) 194 | x_numpy = x_show.detach().cpu().numpy() 195 | 196 | path_max = ''.join([save_dir, '/', str(frame_id), '_', str(i), '_max_corr_patch.png']) 197 | max_corr = x_corr_cur.max(dim=0)[0].detach().cpu().numpy() 198 | plt.imshow(max_corr) 199 | plt.savefig(path_max) 200 | 201 | path = ''.join([save_dir, '/', str(frame_id), '_', str(i), '_corr_patch.png']) 202 | plt.axis('off') 203 | plt.pcolormesh(x_numpy) 204 | plt.savefig(path) 205 | plt.clf() 206 | 207 | 208 | def display_correlation_map(x_corr, img_meta=None, idx=0): 209 | x_corr = x_corr[:, :36] 210 | bs, ch, h, w = x_corr.size() 211 | r = int(sqrt(ch)) 212 | x_show = x_corr.view(r, r, h, w).permute(0, 2, 1, 3).contiguous() 213 | x_show = x_show.view(h*r, r*w) 214 | x_numpy = (x_show).cpu().numpy() 215 | 216 | if img_meta is not None: 217 | path = ''.join(['results/results_1024_2/fea_ref/', str(img_meta[0]['video_id']), '_', 218 | str(img_meta[0]['frame_id']), '_', str(idx), '.png']) 219 | else: 220 | path = 'results/results_1024_2/fea_ref/0.png' 221 | 222 | plt.axis('off') 223 | plt.pcolormesh(x_numpy) 224 | plt.savefig(path) 225 | plt.clf() 226 | 227 | 228 | def display_embedding_map(matching_map_all, idx, img_meta=None): 229 | if img_meta is not None: 230 | path = ''.join(['results/results_1227_1/embedding_map/', str(img_meta['video_id']), '_', 231 | str(img_meta['frame_id']), '_', str(idx), '.png']) 232 | path2 = ''.join(['results/results_1227_1/embedding_map/', str(img_meta['video_id']), '_', 233 | str(img_meta['frame_id']), '_', str(idx), '_m.png']) 234 | 235 | else: 236 | path = 'results/results_1227_1/embedding_map/0.png' 237 | path2 = 'results/results_1227_1/embedding_map/0_m.png' 238 | 239 | matching_map_all = matching_map_all.squeeze(0) 240 | r, r, h, w = matching_map_all.size() 241 | # matching_map_mean = matching_map_all.view(r**2, h, w).mean(0) # / (r**2) 242 | matching_map, _ = matching_map_all.view(r ** 2, h, w).max(0) # / (r**2) 243 | x_show = matching_map_all.permute(0, 2, 1, 3).contiguous() 244 | x_show = x_show.view(h * r, r * w) 245 | x_numpy = (x_show[h*2:h*10, w*2:w*10]).cpu().numpy() 246 | 247 | plt.axis('off') 248 | plt.pcolormesh(mmcv.imflip(x_numpy, direction='vertical')) 249 | plt.savefig(path) 250 | plt.clf() 251 | 252 | matching_map_numpy = matching_map.squeeze(0).cpu().numpy() 253 | plt.axis('off') 254 | plt.imshow(matching_map_numpy) 255 | plt.savefig(path2) 256 | plt.clf() 257 | 258 | 259 | def display_shifted_masks(shifted_masks, img_meta=None): 260 | n, h, w = shifted_masks.size() 261 | 262 | for i in range(n): 263 | if img_meta is not None: 264 | path = ''.join(['results/results_1227_1/embedding_map/', 
str(img_meta['video_id']), '_', 265 | str(img_meta['frame_id']), '_', str(i), '_shifted_masks.png']) 266 | 267 | else: 268 | path = 'results/results_1227_1/fea_ref/0_shifted_mask.png' 269 | shifted_masks = shifted_masks.gt(0.3).float() 270 | shifted_masks_numpy = shifted_masks[i].cpu().numpy() 271 | plt.axis('off') 272 | plt.pcolormesh(mmcv.imflip(shifted_masks_numpy*10, direction='vertical')) 273 | plt.savefig(path) 274 | plt.clf() 275 | 276 | 277 | 278 | -------------------------------------------------------------------------------- /scripts/augment_bbox.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path as osp 3 | import json, pickle 4 | import sys 5 | from math import sqrt 6 | from itertools import product 7 | import torch 8 | from numpy import random 9 | 10 | import numpy as np 11 | 12 | 13 | max_image_size = 550 14 | augment_idx = 0 15 | dump_file = 'weights/bboxes_aug.pkl' 16 | box_file = 'weights/bboxes.pkl' 17 | 18 | def augment_boxes(bboxes): 19 | bboxes_rel = [] 20 | for box in bboxes: 21 | bboxes_rel.append(prep_box(box)) 22 | bboxes_rel = np.concatenate(bboxes_rel, axis=0) 23 | 24 | with open(dump_file, 'wb') as f: 25 | pickle.dump(bboxes_rel, f) 26 | 27 | def prep_box(box_list): 28 | global augment_idx 29 | boxes = np.array([box_list[2:]], dtype=np.float32) 30 | 31 | # Image width and height 32 | width, height = box_list[:2] 33 | 34 | # To point form 35 | boxes[:, 2:] += boxes[:, :2] 36 | 37 | 38 | # Expand 39 | ratio = random.uniform(1, 4) 40 | left = random.uniform(0, width*ratio - width) 41 | top = random.uniform(0, height*ratio - height) 42 | 43 | height *= ratio 44 | width *= ratio 45 | 46 | boxes[:, :2] += (int(left), int(top)) 47 | boxes[:, 2:] += (int(left), int(top)) 48 | 49 | 50 | # RandomSampleCrop 51 | height, width, boxes = random_sample_crop(height, width, boxes) 52 | 53 | 54 | # RandomMirror 55 | if random.randint(0, 2): 56 | boxes[:, 0::2] = width - boxes[:, 2::-2] 57 | 58 | 59 | # Resize 60 | boxes[:, [0, 2]] *= (max_image_size / width) 61 | boxes[:, [1, 3]] *= (max_image_size / height) 62 | width = height = max_image_size 63 | 64 | 65 | # ToPercentCoords 66 | boxes[:, [0, 2]] /= width 67 | boxes[:, [1, 3]] /= height 68 | 69 | if augment_idx % 50000 == 0: 70 | print('Current idx: %d' % augment_idx) 71 | 72 | augment_idx += 1 73 | 74 | return boxes 75 | 76 | 77 | 78 | 79 | sample_options = ( 80 | # using entire original input image 81 | None, 82 | # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9 83 | (0.1, None), 84 | (0.3, None), 85 | (0.7, None), 86 | (0.9, None), 87 | # randomly sample a patch 88 | (None, None), 89 | ) 90 | 91 | def intersect(box_a, box_b): 92 | max_xy = np.minimum(box_a[:, 2:], box_b[2:]) 93 | min_xy = np.maximum(box_a[:, :2], box_b[:2]) 94 | inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) 95 | return inter[:, 0] * inter[:, 1] 96 | 97 | 98 | def jaccard_numpy(box_a, box_b): 99 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 100 | is simply the intersection over union of two boxes. 
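    As a quick numeric check: two unit squares offset by half a side length in both
    x and y intersect in a 0.5 x 0.5 region, so their IoU is 0.25 / (1 + 1 - 0.25), about 0.14.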
101 | E.g.: 102 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 103 | Args: 104 | box_a: Multiple bounding boxes, Shape: [num_boxes,4] 105 | box_b: Single bounding box, Shape: [4] 106 | Return: 107 | jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] 108 | """ 109 | inter = intersect(box_a, box_b) 110 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 111 | (box_a[:, 3]-box_a[:, 1])) # [A,B] 112 | area_b = ((box_b[2]-box_b[0]) * 113 | (box_b[3]-box_b[1])) # [A,B] 114 | union = area_a + area_b - inter 115 | return inter / union # [A,B] 116 | 117 | 118 | def random_sample_crop(height, width, boxes=None): 119 | global sample_options 120 | 121 | while True: 122 | # randomly choose a mode 123 | mode = random.choice(sample_options) 124 | if mode is None: 125 | return height, width, boxes 126 | 127 | min_iou, max_iou = mode 128 | if min_iou is None: 129 | min_iou = float('-inf') 130 | if max_iou is None: 131 | max_iou = float('inf') 132 | 133 | for _ in range(50): 134 | w = random.uniform(0.3 * width, width) 135 | h = random.uniform(0.3 * height, height) 136 | 137 | if h / w < 0.5 or h / w > 2: 138 | continue 139 | 140 | left = random.uniform(0, width - w) 141 | top = random.uniform(0, height - h) 142 | 143 | rect = np.array([int(left), int(top), int(left+w), int(top+h)]) 144 | overlap = jaccard_numpy(boxes, rect) 145 | if overlap.min() < min_iou and max_iou < overlap.max(): 146 | continue 147 | 148 | centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 149 | 150 | m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) 151 | m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) 152 | mask = m1 * m2 153 | 154 | if not mask.any(): 155 | continue 156 | 157 | current_boxes = boxes[mask, :].copy() 158 | current_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2]) 159 | current_boxes[:, :2] -= rect[:2] 160 | current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:]) 161 | current_boxes[:, 2:] -= rect[:2] 162 | 163 | return h, w, current_boxes 164 | 165 | 166 | if __name__ == '__main__': 167 | 168 | with open(box_file, 'rb') as f: 169 | bboxes = pickle.load(f) 170 | 171 | augment_boxes(bboxes) 172 | -------------------------------------------------------------------------------- /scripts/bbox_recall.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script compiles all the bounding boxes in the training data and 3 | clusters them for each convout resolution on which they're used. 4 | 5 | Run this script from the Yolact root directory. 6 | """ 7 | 8 | import os.path as osp 9 | import json, pickle 10 | import sys 11 | from math import sqrt 12 | from itertools import product 13 | import torch 14 | import random 15 | 16 | import numpy as np 17 | 18 | dump_file = 'weights/bboxes.pkl' 19 | aug_file = 'weights/bboxes_aug.pkl' 20 | 21 | use_augmented_boxes = True 22 | 23 | 24 | def intersect(box_a, box_b): 25 | """ We resize both tensors to [A,B,2] without new malloc: 26 | [A,2] -> [A,1,2] -> [A,B,2] 27 | [B,2] -> [1,B,2] -> [A,B,2] 28 | Then we compute the area of intersect between box_a and box_b. 29 | Args: 30 | box_a: (tensor) bounding boxes, Shape: [A,4]. 31 | box_b: (tensor) bounding boxes, Shape: [B,4]. 32 | Return: 33 | (tensor) intersection area, Shape: [A,B]. 
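    In other words, entry (i, j) of the result is the intersection area of
    box_a[i] with box_b[j], with boxes given in point form [x1, y1, x2, y2].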
34 | """ 35 | A = box_a.size(0) 36 | B = box_b.size(0) 37 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 38 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 39 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 40 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 41 | inter = torch.clamp((max_xy - min_xy), min=0) 42 | return inter[:, :, 0] * inter[:, :, 1] 43 | 44 | 45 | def jaccard(box_a, box_b, iscrowd=False): 46 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 47 | is simply the intersection over union of two boxes. Here we operate on 48 | ground truth boxes and default boxes. If iscrowd=True, put the crowd in box_b. 49 | E.g.: 50 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 51 | Args: 52 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 53 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 54 | Return: 55 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 56 | """ 57 | inter = intersect(box_a, box_b) 58 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 59 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 60 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 61 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 62 | union = area_a + area_b - inter 63 | 64 | if iscrowd: 65 | return inter / area_a 66 | else: 67 | return inter / union # [A,B] 68 | 69 | # Also convert to point form 70 | def to_relative(bboxes): 71 | return np.concatenate((bboxes[:, 2:4] / bboxes[:, :2], (bboxes[:, 2:4] + bboxes[:, 4:]) / bboxes[:, :2]), axis=1) 72 | 73 | 74 | def make_priors(conv_size, scales, aspect_ratios): 75 | prior_data = [] 76 | conv_h = conv_size[0] 77 | conv_w = conv_size[1] 78 | 79 | # Iteration order is important (it has to sync up with the convout) 80 | for j, i in product(range(conv_h), range(conv_w)): 81 | x = (i + 0.5) / conv_w 82 | y = (j + 0.5) / conv_h 83 | 84 | for scale, ars in zip(scales, aspect_ratios): 85 | for ar in ars: 86 | w = scale * ar / conv_w 87 | h = scale / ar / conv_h 88 | 89 | # Point form 90 | prior_data += [x - w/2, y - h/2, x + w/2, y + h/2] 91 | 92 | return np.array(prior_data).reshape(-1, 4) 93 | 94 | # fixed_ssd_config 95 | # scales = [[3.5, 4.95], [3.6, 4.90], [3.3, 4.02], [2.7, 3.10], [2.1, 2.37], [2.1, 2.37], [1.8, 1.92]] 96 | # aspect_ratios = [ [[1, sqrt(2), 1/sqrt(2), sqrt(3), 1/sqrt(3)][:n], [1]] for n in [3, 5, 5, 5, 3, 3, 3] ] 97 | # conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 98 | 99 | scales = [[1.68, 2.91], 100 | [2.95, 2.22, 0.84], 101 | [2.23, 2.17, 3.12], 102 | [0.76, 1.94, 2.72], 103 | [2.10, 2.65], 104 | [1.80, 1.92]] 105 | aspect_ratios = [[[0.72, 0.96], [0.68, 1.17]], 106 | [[1.28, 0.66], [0.63, 1.23], [0.89, 1.40]], 107 | [[2.05, 1.24], [0.57, 0.83], [0.61, 1.15]], 108 | [[1.00, 2.21], [0.47, 1.60], [1.44, 0.79]], 109 | [[1.00, 1.41, 0.71, 1.73, 0.58], [1.08]], 110 | [[1.00, 1.41, 0.71, 1.73, 0.58], [1.00]]] 111 | conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 112 | 113 | # yrm33_config 114 | # scales = [ [5.3] ] * 5 115 | # aspect_ratios = [ [[1, 1/sqrt(2), sqrt(2)]] ]*5 116 | # conv_sizes = [(136, 136), (67, 67), (33, 33), (16, 16), (8, 8)] 117 | 118 | 119 | SMALL = 0 120 | MEDIUM = 1 121 | LARGE = 2 122 | 123 | if __name__ == '__main__': 124 | 125 | with open(dump_file, 'rb') as f: 126 | bboxes = pickle.load(f) 127 | 128 | sizes = [] 129 | smalls = [] 130 | for i in range(len(bboxes)): 131 | area = bboxes[i][4] * bboxes[i][5] 132 | if area < 32 ** 2: 133 | sizes.append(SMALL) 134 | 
smalls.append(area) 135 | elif area < 96 ** 2: 136 | sizes.append(MEDIUM) 137 | else: 138 | sizes.append(LARGE) 139 | 140 | # Each box is in the form [im_w, im_h, pos_x, pos_y, size_x, size_y] 141 | 142 | if use_augmented_boxes: 143 | with open(aug_file, 'rb') as f: 144 | bboxes_rel = pickle.load(f) 145 | else: 146 | bboxes_rel = to_relative(np.array(bboxes)) 147 | 148 | 149 | with torch.no_grad(): 150 | sizes = torch.Tensor(sizes) 151 | 152 | anchors = [make_priors(cs, s, ar) for cs, s, ar in zip(conv_sizes, scales, aspect_ratios)] 153 | anchors = np.concatenate(anchors, axis=0) 154 | anchors = torch.Tensor(anchors).cuda() 155 | 156 | bboxes_rel = torch.Tensor(bboxes_rel).cuda() 157 | perGTAnchorMax = torch.zeros(bboxes_rel.shape[0]).cuda() 158 | 159 | chunk_size = 1000 160 | for i in range((bboxes_rel.size(0) // chunk_size) + 1): 161 | start = i * chunk_size 162 | end = min((i + 1) * chunk_size, bboxes_rel.size(0)) 163 | 164 | ious = jaccard(bboxes_rel[start:end, :], anchors) 165 | maxes, maxidx = torch.max(ious, dim=1) 166 | 167 | perGTAnchorMax[start:end] = maxes 168 | 169 | 170 | hits = (perGTAnchorMax > 0.5).float() 171 | 172 | print('Total recall: %.2f' % (torch.sum(hits) / hits.size(0) * 100)) 173 | print() 174 | 175 | for i, metric in zip(range(3), ('small', 'medium', 'large')): 176 | _hits = hits[sizes == i] 177 | _size = (1 if _hits.size(0) == 0 else _hits.size(0)) 178 | print(metric + ' recall: %.2f' % ((torch.sum(_hits) / _size) * 100)) 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /scripts/cluster_bbox_sizes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script compiles all the bounding boxes in the training data and 3 | clusters them for each convout resolution on which they're used. 4 | 5 | Run this script from the Yolact root directory. 
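Example (assumed invocation): first dump the raw boxes with
    python3 scripts/save_bboxes.py
and then run
    python3 scripts/cluster_bbox_sizes.py
to print k-means scale and aspect-ratio clusters read from weights/bboxes.pkl.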
6 | """ 7 | 8 | import os.path as osp 9 | import json, pickle 10 | import sys 11 | 12 | import numpy as np 13 | import sklearn.cluster as cluster 14 | 15 | dump_file = 'weights/bboxes.pkl' 16 | max_size = 550 17 | 18 | num_scale_clusters = 5 19 | num_aspect_ratio_clusters = 3 20 | 21 | def to_relative(bboxes): 22 | return bboxes[:, 2:4] / bboxes[:, :2] 23 | 24 | def process(bboxes): 25 | return to_relative(bboxes) * max_size 26 | 27 | if __name__ == '__main__': 28 | 29 | with open(dump_file, 'rb') as f: 30 | bboxes = pickle.load(f) 31 | 32 | bboxes = np.array(bboxes) 33 | bboxes = process(bboxes) 34 | bboxes = bboxes[(bboxes[:, 0] > 1) * (bboxes[:, 1] > 1)] 35 | 36 | scale = np.sqrt(bboxes[:, 0] * bboxes[:, 1]).reshape(-1, 1) 37 | 38 | clusterer = cluster.KMeans(num_scale_clusters, random_state=99, n_jobs=4) 39 | assignments = clusterer.fit_predict(scale) 40 | counts = np.bincount(assignments) 41 | 42 | cluster_centers = clusterer.cluster_centers_ 43 | 44 | center_indices = list(range(num_scale_clusters)) 45 | center_indices.sort(key=lambda x: cluster_centers[x, 0]) 46 | 47 | for idx in center_indices: 48 | center = cluster_centers[idx, 0] 49 | boxes_for_center = bboxes[assignments == idx] 50 | aspect_ratios = (boxes_for_center[:,0] / boxes_for_center[:,1]).reshape(-1, 1) 51 | 52 | c = cluster.KMeans(num_aspect_ratio_clusters, random_state=idx, n_jobs=4) 53 | ca = c.fit_predict(aspect_ratios) 54 | cc = np.bincount(ca) 55 | 56 | c = list(c.cluster_centers_.reshape(-1)) 57 | cidx = list(range(num_aspect_ratio_clusters)) 58 | cidx.sort(key=lambda x: -cc[x]) 59 | 60 | # import code 61 | # code.interact(local=locals()) 62 | 63 | print('%.3f (%d) aspect ratios:' % (center, counts[idx])) 64 | for idx in cidx: 65 | print('\t%.2f (%d)' % (c[idx], cc[idx])) 66 | print() 67 | # exit() 68 | 69 | 70 | -------------------------------------------------------------------------------- /scripts/compute_masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import cv2 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | COLORS = ((255, 0, 0, 128), (0, 255, 0, 128), (0, 0, 255, 128), 8 | (0, 255, 255, 128), (255, 0, 255, 128), (255, 255, 0, 128)) 9 | 10 | def mask_iou(mask1, mask2): 11 | """ 12 | Inputs inputs are matricies of size _ x N. Output is size _1 x _2. 13 | Note: if iscrowd is True, then mask2 should be the crowd. 
14 | """ 15 | intersection = torch.matmul(mask1, mask2.t()) 16 | area1 = torch.sum(mask1, dim=1).view(1, -1) 17 | area2 = torch.sum(mask2, dim=1).view(1, -1) 18 | union = (area1.t() + area2) - intersection 19 | 20 | return intersection / union 21 | 22 | def paint_mask(img_numpy, mask, color): 23 | h, w, _ = img_numpy.shape 24 | img_numpy = img_numpy.copy() 25 | 26 | mask = np.tile(mask.reshape(h, w, 1), (1, 1, 3)) 27 | color_np = np.array(color[:3]).reshape(1, 1, 3) 28 | color_np = np.tile(color_np, (h, w, 1)) 29 | mask_color = mask * color_np 30 | 31 | mask_alpha = 0.3 32 | 33 | # Blend image and mask 34 | image_crop = img_numpy * mask 35 | img_numpy *= (1-mask) 36 | img_numpy += image_crop * (1-mask_alpha) + mask_color * mask_alpha 37 | 38 | return img_numpy 39 | 40 | # Inverse sigmoid 41 | def logit(x): 42 | return np.log(x / (1-x + 0.0001) + 0.0001) 43 | 44 | def sigmoid(x): 45 | return 1 / (1 + np.exp(-x)) 46 | 47 | img_fmt = '../data/coco/images/%012d.jpg' 48 | with open('info.txt', 'r') as f: 49 | img_id = int(f.read()) 50 | 51 | img = plt.imread(img_fmt % img_id).astype(np.float32) 52 | h, w, _ = img.shape 53 | 54 | gt_masks = np.load('gt.npy').astype(np.float32).transpose(1, 2, 0) 55 | proto_masks = np.load('proto.npy').astype(np.float32) 56 | 57 | proto_masks = torch.Tensor(proto_masks).permute(2, 0, 1).contiguous().unsqueeze(0) 58 | proto_masks = F.interpolate(proto_masks, (h, w), mode='bilinear', align_corners=False).squeeze(0) 59 | proto_masks = proto_masks.permute(1, 2, 0).numpy() 60 | 61 | # # A x = b 62 | ls_A = proto_masks.reshape(-1, proto_masks.shape[-1]) 63 | ls_b = gt_masks.reshape(-1, gt_masks.shape[-1]) 64 | 65 | # x is size [256, num_gt] 66 | x = np.linalg.lstsq(ls_A, ls_b, rcond=None)[0] 67 | 68 | approximated_masks = (np.matmul(proto_masks, x) > 0.5).astype(np.float32) 69 | 70 | num_gt = approximated_masks.shape[2] 71 | ious = mask_iou(torch.Tensor(approximated_masks.reshape(-1, num_gt).T), 72 | torch.Tensor(gt_masks.reshape(-1, num_gt).T)) 73 | 74 | ious = [int(ious[i, i].item() * 100) for i in range(num_gt)] 75 | ious.sort(key=lambda x: -x) 76 | 77 | print(ious) 78 | 79 | gt_img = img.copy() 80 | 81 | for i in range(num_gt): 82 | gt_img = paint_mask(gt_img, gt_masks[:, :, i], COLORS[i % len(COLORS)]) 83 | 84 | plt.imshow(gt_img / 255) 85 | plt.title('GT') 86 | plt.show() 87 | 88 | for i in range(num_gt): 89 | img = paint_mask(img, approximated_masks[:, :, i], COLORS[i % len(COLORS)]) 90 | 91 | plt.imshow(img / 255) 92 | plt.title('Approximated') 93 | plt.show() 94 | -------------------------------------------------------------------------------- /scripts/convert_darknet.py: -------------------------------------------------------------------------------- 1 | from backbone import DarkNetBackbone 2 | import h5py 3 | import torch 4 | 5 | f = h5py.File('darknet53.h5', 'r') 6 | m = f['model_weights'] 7 | 8 | yolo_keys = list(m.keys()) 9 | yolo_keys = [x for x in yolo_keys if len(m[x].keys()) > 0] 10 | yolo_keys.sort() 11 | 12 | sd = DarkNetBackbone().state_dict() 13 | 14 | sd_keys = list(sd.keys()) 15 | sd_keys.sort() 16 | 17 | # Note this won't work if there are 10 elements in some list but whatever that doesn't happen 18 | layer_keys = list(set(['.'.join(x.split('.')[:-2]) for x in sd_keys])) 19 | layer_keys.sort() 20 | 21 | # print([x for x in sd_keys if x.startswith(layer_keys[0])]) 22 | 23 | mapping = { 24 | '.0.weight' : ('conv2d_%d', 'kernel:0'), 25 | '.1.bias' : ('batch_normalization_%d', 'beta:0'), 26 | '.1.weight' : ('batch_normalization_%d', 'gamma:0'), 27 
| '.1.running_var' : ('batch_normalization_%d', 'moving_variance:0'), 28 | '.1.running_mean': ('batch_normalization_%d', 'moving_mean:0'), 29 | '.1.num_batches_tracked': None, 30 | } 31 | 32 | for i, layer_key in zip(range(1, len(layer_keys) + 1), layer_keys): 33 | # This is pretty inefficient but I don't care 34 | for weight_key in [x for x in sd_keys if x.startswith(layer_key)]: 35 | diff = weight_key[len(layer_key):] 36 | 37 | if mapping[diff] is not None: 38 | yolo_key = mapping[diff][0] % i 39 | sub_key = mapping[diff][1] 40 | 41 | yolo_weight = torch.Tensor(m[yolo_key][yolo_key][sub_key].value) 42 | if (len(yolo_weight.size()) == 4): 43 | yolo_weight = yolo_weight.permute(3, 2, 0, 1).contiguous() 44 | 45 | sd[weight_key] = yolo_weight 46 | 47 | torch.save(sd, 'weights/darknet53.pth') 48 | 49 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-small 3 | #SBATCH -t 2:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./eval.sh weights extra_args 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 eval.py --trained_model=$1 --no_bar $2 > logs/eval/$(basename -- $1).log 2>&1 15 | -------------------------------------------------------------------------------- /scripts/make_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math, random 3 | 4 | import matplotlib.pyplot as plt 5 | from matplotlib.widgets import Slider, Button 6 | 7 | 8 | fig, ax = plt.subplots() 9 | plt.subplots_adjust(bottom=0.24) 10 | im_handle = None 11 | 12 | save_path = 'grid.npy' 13 | 14 | center_x, center_y = (0.5, 0.5) 15 | grid_w, grid_h = (35, 35) 16 | spacing = 0 17 | scale = 4 18 | angle = 0 19 | grid = None 20 | 21 | all_grids = [] 22 | unique = False 23 | 24 | # A hack 25 | disable_render = False 26 | 27 | def render(): 28 | if disable_render: 29 | return 30 | 31 | x = np.tile(np.array(list(range(grid_w)), dtype=np.float).reshape(1, grid_w), [grid_h, 1]) - grid_w * center_x 32 | y = np.tile(np.array(list(range(grid_h)), dtype=np.float).reshape(grid_h, 1), [1, grid_w]) - grid_h * center_y 33 | 34 | x /= scale 35 | y /= scale 36 | 37 | a1 = angle + math.pi / 3 38 | a2 = -angle + math.pi / 3 39 | a3 = angle 40 | 41 | z1 = x * math.sin(a1) + y * math.cos(a1) 42 | z2 = x * math.sin(a2) - y * math.cos(a2) 43 | z3 = x * math.sin(a3) + y * math.cos(a3) 44 | 45 | s1 = np.square(np.sin(z1)) 46 | s2 = np.square(np.sin(z2)) 47 | s3 = np.square(np.sin(z3)) 48 | 49 | line_1 = np.exp(s1 * spacing) * s1 50 | line_2 = np.exp(s2 * spacing) * s2 51 | line_3 = np.exp(s3 * spacing) * s3 52 | 53 | global grid 54 | grid = np.clip(1 - (line_1 + line_2 + line_3) / 3, 0, 1) 55 | 56 | global im_handle 57 | if im_handle is None: 58 | im_handle = plt.imshow(grid) 59 | else: 60 | im_handle.set_data(grid) 61 | fig.canvas.draw_idle() 62 | 63 | def update_scale(val): 64 | global scale 65 | scale = val 66 | 67 | render() 68 | 69 | def update_angle(val): 70 | global angle 71 | angle = val 72 | 73 | render() 74 | 75 | def update_centerx(val): 76 | global center_x 77 | center_x = val 78 | 79 | render() 80 | 81 | def update_centery(val): 82 | global center_y 83 | center_y = val 84 | 85 | render() 86 | 87 | def update_spacing(val): 88 | global spacing 89 | spacing = val 90 | 91 | render() 92 | 93 | def randomize(val): 94 
| global center_x, center_y, spacing, scale, angle, disable_render 95 | 96 | center_x, center_y = (random.uniform(0, 1), random.uniform(0, 1)) 97 | spacing = random.uniform(-0.2, 2) 98 | scale = 4 * math.exp(random.uniform(-1, 1)) 99 | angle = random.uniform(-math.pi, math.pi) 100 | 101 | disable_render = True 102 | 103 | scale_slider.set_val(scale) 104 | angle_slider.set_val(angle) 105 | centx_slider.set_val(center_x) 106 | centy_slider.set_val(center_y) 107 | spaci_slider.set_val(spacing) 108 | 109 | disable_render = False 110 | 111 | render() 112 | 113 | def add(val): 114 | all_grids.append(grid) 115 | 116 | global unique 117 | if not unique: 118 | unique = test_uniqueness(np.stack(all_grids)) 119 | 120 | export_len_text.set_text('Num Grids: ' + str(len(all_grids))) 121 | fig.canvas.draw_idle() 122 | 123 | def add_randomize(val): 124 | add(val) 125 | randomize(val) 126 | 127 | def export(val): 128 | np.save(save_path, np.stack(all_grids)) 129 | print('Saved %d grids to "%s"' % (len(all_grids), save_path)) 130 | 131 | global unique 132 | unique = False 133 | all_grids.clear() 134 | 135 | export_len_text.set_text('Num Grids: ' + str(len(all_grids))) 136 | fig.canvas.draw_idle() 137 | 138 | def test_uniqueness(grids): 139 | # Grids shape [ngrids, h, w] 140 | grids = grids.reshape((-1, grid_h, grid_w)) 141 | 142 | for y in range(grid_h): 143 | for x in range(grid_h): 144 | pixel_features = grids[:, y, x] 145 | 146 | # l1 distance for this pixel with every other 147 | l1_dist = np.sum(np.abs(grids - np.tile(pixel_features, grid_h*grid_w).reshape((-1, grid_h, grid_w))), axis=0) 148 | 149 | # Equal if l1 distance is really small. Note that this will include this pixel 150 | num_equal = np.sum((l1_dist < 0.0001).astype(np.int32)) 151 | 152 | if num_equal > 1: 153 | print('Pixel at (%d, %d) has %d other pixel%s with the same representation.' 
% (x, y, num_equal-1, '' if num_equal==2 else 's')) 154 | return False 155 | 156 | print('Each pixel has a distinct representation.') 157 | return True 158 | 159 | 160 | 161 | render() 162 | 163 | axis = plt.axes([0.22, 0.19, 0.59, 0.03], facecolor='lightgoldenrodyellow') 164 | scale_slider = Slider(axis, 'Scale', 0.1, 20, valinit=scale, valstep=0.1) 165 | scale_slider.on_changed(update_scale) 166 | 167 | axis = plt.axes([0.22, 0.15, 0.59, 0.03], facecolor='lightgoldenrodyellow') 168 | angle_slider = Slider(axis, 'Angle', -math.pi, math.pi, valinit=angle, valstep=0.1) 169 | angle_slider.on_changed(update_angle) 170 | 171 | axis = plt.axes([0.22, 0.11, 0.59, 0.03], facecolor='lightgoldenrodyellow') 172 | centx_slider = Slider(axis, 'Center X', 0, 1, valinit=center_x, valstep=0.05) 173 | centx_slider.on_changed(update_centerx) 174 | 175 | axis = plt.axes([0.22, 0.07, 0.59, 0.03], facecolor='lightgoldenrodyellow') 176 | centy_slider = Slider(axis, 'Center Y', 0, 1, valinit=center_y, valstep=0.05) 177 | centy_slider.on_changed(update_centery) 178 | 179 | axis = plt.axes([0.22, 0.03, 0.59, 0.03], facecolor='lightgoldenrodyellow') 180 | spaci_slider = Slider(axis, 'Spacing', -1, 2, valinit=spacing, valstep=0.05) 181 | spaci_slider.on_changed(update_spacing) 182 | 183 | axis = plt.axes([0.8, 0.54, 0.15, 0.05], facecolor='lightgoldenrodyellow') 184 | rando_button = Button(axis, 'Randomize') 185 | rando_button.on_clicked(randomize) 186 | 187 | axis = plt.axes([0.8, 0.48, 0.15, 0.05], facecolor='lightgoldenrodyellow') 188 | addgr_button = Button(axis, 'Add') 189 | addgr_button.on_clicked(add) 190 | 191 | # Likely not a good way to do this but whatever 192 | export_len_text = plt.text(0, 3, 'Num Grids: 0') 193 | 194 | axis = plt.axes([0.8, 0.42, 0.15, 0.05], facecolor='lightgoldenrodyellow') 195 | addra_button = Button(axis, 'Add / Rand') 196 | addra_button.on_clicked(add_randomize) 197 | 198 | axis = plt.axes([0.8, 0.36, 0.15, 0.05], facecolor='lightgoldenrodyellow') 199 | saveg_button = Button(axis, 'Save') 200 | saveg_button.on_clicked(export) 201 | 202 | 203 | 204 | plt.show() 205 | -------------------------------------------------------------------------------- /scripts/optimize_bboxes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instead of clustering bbox widths and heights, this script 3 | directly optimizes average IoU across the training set given 4 | the specified number of anchor boxes. 5 | 6 | Run this script from the Yolact root directory. 7 | """ 8 | 9 | import pickle 10 | import random 11 | from itertools import product 12 | from math import sqrt 13 | 14 | import numpy as np 15 | import torch 16 | from scipy.optimize import minimize 17 | 18 | dump_file = 'weights/bboxes.pkl' 19 | aug_file = 'weights/bboxes_aug.pkl' 20 | 21 | use_augmented_boxes = True 22 | 23 | 24 | def intersect(box_a, box_b): 25 | """ We resize both tensors to [A,B,2] without new malloc: 26 | [A,2] -> [A,1,2] -> [A,B,2] 27 | [B,2] -> [1,B,2] -> [A,B,2] 28 | Then we compute the area of intersect between box_a and box_b. 29 | Args: 30 | box_a: (tensor) bounding boxes, Shape: [A,4]. 31 | box_b: (tensor) bounding boxes, Shape: [B,4]. 32 | Return: 33 | (tensor) intersection area, Shape: [A,B]. 
34 | """ 35 | A = box_a.size(0) 36 | B = box_b.size(0) 37 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 38 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 39 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 40 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 41 | inter = torch.clamp((max_xy - min_xy), min=0) 42 | return inter[:, :, 0] * inter[:, :, 1] 43 | 44 | 45 | def jaccard(box_a, box_b, iscrowd=False): 46 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 47 | is simply the intersection over union of two boxes. Here we operate on 48 | ground truth boxes and default boxes. If iscrowd=True, put the crowd in box_b. 49 | E.g.: 50 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 51 | Args: 52 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 53 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 54 | Return: 55 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 56 | """ 57 | inter = intersect(box_a, box_b) 58 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 59 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 60 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 61 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 62 | union = area_a + area_b - inter 63 | 64 | if iscrowd: 65 | return inter / area_a 66 | else: 67 | return inter / union # [A,B] 68 | 69 | # Also convert to point form 70 | def to_relative(bboxes): 71 | return np.concatenate((bboxes[:, 2:4] / bboxes[:, :2], (bboxes[:, 2:4] + bboxes[:, 4:]) / bboxes[:, :2]), axis=1) 72 | 73 | 74 | def make_priors(conv_size, scales, aspect_ratios): 75 | prior_data = [] 76 | conv_h = conv_size[0] 77 | conv_w = conv_size[1] 78 | 79 | # Iteration order is important (it has to sync up with the convout) 80 | for j, i in product(range(conv_h), range(conv_w)): 81 | x = (i + 0.5) / conv_w 82 | y = (j + 0.5) / conv_h 83 | 84 | for scale, ars in zip(scales, aspect_ratios): 85 | for ar in ars: 86 | w = scale * ar / conv_w 87 | h = scale / ar / conv_h 88 | 89 | # Point form 90 | prior_data += [x - w/2, y - h/2, x + w/2, y + h/2] 91 | return torch.Tensor(prior_data).view(-1, 4).cuda() 92 | 93 | 94 | 95 | scales = [[1.68, 2.91], [2.95, 2.22, 0.84], [2.17, 2.22, 3.22], [0.76, 2.06, 2.81], [5.33, 2.79], [13.69]] 96 | aspect_ratios = [[[0.72, 0.96], [0.68, 1.17]], [[1.30, 0.66], [0.63, 1.23], [0.87, 1.41]], [[1.96, 1.23], [0.58, 0.84], [0.61, 1.15]], [[19.79, 2.21], [0.47, 1.76], [1.38, 0.79]], [[4.79, 17.96], [1.04]], [[14.82]]] 97 | conv_sizes = [(35, 35), (18, 18), (9, 9), (5, 5), (3, 3), (2, 2)] 98 | 99 | optimize_scales = False 100 | 101 | batch_idx = 0 102 | 103 | 104 | def compute_hits(bboxes, anchors, iou_threshold=0.5): 105 | ious = jaccard(bboxes, anchors) 106 | perGTAnchorMax, _ = torch.max(ious, dim=1) 107 | 108 | return (perGTAnchorMax > iou_threshold) 109 | 110 | def compute_recall(hits, base_hits): 111 | hits = (hits | base_hits).float() 112 | return torch.sum(hits) / hits.size(0) 113 | 114 | 115 | def step(x, x_func, bboxes, base_hits, optim_idx): 116 | # This should set the scale and aspect ratio 117 | x_func(x, scales[optim_idx], aspect_ratios[optim_idx]) 118 | 119 | anchors = make_priors(conv_sizes[optim_idx], scales[optim_idx], aspect_ratios[optim_idx]) 120 | 121 | return -float(compute_recall(compute_hits(bboxes, anchors), base_hits).cpu()) 122 | 123 | 124 | def optimize(full_bboxes, optim_idx, batch_size=5000): 125 | global batch_idx, scales, aspect_ratios, conv_sizes 126 | 127 | start = batch_idx * batch_size 128 | 
end = min((batch_idx + 1) * batch_size, full_bboxes.size(0)) 129 | 130 | if batch_idx > (full_bboxes.size(0) // batch_size): 131 | batch_idx = 0 132 | 133 | bboxes = full_bboxes[start:end, :] 134 | 135 | anchor_base = [ 136 | make_priors(conv_sizes[idx], scales[idx], aspect_ratios[idx]) 137 | for idx in range(len(conv_sizes)) if idx != optim_idx] 138 | base_hits = compute_hits(bboxes, torch.cat(anchor_base, dim=0)) 139 | 140 | 141 | def set_x(x, scales, aspect_ratios): 142 | if optimize_scales: 143 | for i in range(len(scales)): 144 | scales[i] = max(x[i], 0) 145 | else: 146 | k = 0 147 | for i in range(len(aspect_ratios)): 148 | for j in range(len(aspect_ratios[i])): 149 | aspect_ratios[i][j] = x[k] 150 | k += 1 151 | 152 | 153 | res = minimize(step, x0=scales[optim_idx] if optimize_scales else sum(aspect_ratios[optim_idx], []), method='Powell', 154 | args = (set_x, bboxes, base_hits, optim_idx),) 155 | 156 | 157 | def pretty_str(x:list): 158 | if isinstance(x, list): 159 | return '[' + ', '.join([pretty_str(y) for y in x]) + ']' 160 | elif isinstance(x, np.ndarray): 161 | return pretty_str(list(x)) 162 | else: 163 | return '%.2f' % x 164 | 165 | if __name__ == '__main__': 166 | 167 | if use_augmented_boxes: 168 | with open(aug_file, 'rb') as f: 169 | bboxes = pickle.load(f) 170 | else: 171 | # Load widths and heights from a dump file. Obtain this with 172 | # python3 scripts/save_bboxes.py 173 | with open(dump_file, 'rb') as f: 174 | bboxes = pickle.load(f) 175 | 176 | bboxes = np.array(bboxes) 177 | bboxes = to_relative(bboxes) 178 | 179 | with torch.no_grad(): 180 | bboxes = torch.Tensor(bboxes).cuda() 181 | 182 | def print_out(): 183 | if optimize_scales: 184 | print('Scales: ' + pretty_str(scales)) 185 | else: 186 | print('Aspect Ratios: ' + pretty_str(aspect_ratios)) 187 | 188 | for p in range(10): 189 | print('(Sub Iteration) ', end='') 190 | for i in range(len(conv_sizes)): 191 | print('%d ' % i, end='', flush=True) 192 | optimize(bboxes, i) 193 | print('Done', end='\r') 194 | 195 | print('(Iteration %d) ' % p, end='') 196 | print_out() 197 | print() 198 | 199 | optimize_scales = not optimize_scales 200 | 201 | print('scales = ' + pretty_str(scales)) 202 | print('aspect_ratios = ' + pretty_str(aspect_ratios)) 203 | 204 | 205 | -------------------------------------------------------------------------------- /scripts/parse_eval.py: -------------------------------------------------------------------------------- 1 | import re, sys, os 2 | import matplotlib.pyplot as plt 3 | from matplotlib._color_data import XKCD_COLORS 4 | 5 | with open(sys.argv[1], 'r') as f: 6 | txt = f.read() 7 | 8 | txt, overall = txt.split('overall performance') 9 | 10 | class_names = [] 11 | mAP_overall = [] 12 | mAP_small = [] 13 | mAP_medium = [] 14 | mAP_large = [] 15 | 16 | for class_result in txt.split('evaluate category: ')[1:]: 17 | lines = class_result.split('\n') 18 | class_names.append(lines[0]) 19 | 20 | def grabMAP(string): 21 | return float(string.split('] = ')[1]) * 100 22 | 23 | mAP_overall.append(grabMAP(lines[ 7])) 24 | mAP_small .append(grabMAP(lines[10])) 25 | mAP_medium .append(grabMAP(lines[11])) 26 | mAP_large .append(grabMAP(lines[12])) 27 | 28 | mAP_map = { 29 | 'small': mAP_small, 30 | 'medium': mAP_medium, 31 | 'large': mAP_large, 32 | } 33 | 34 | if len(sys.argv) > 2: 35 | bars = plt.bar(class_names, mAP_map[sys.argv[2]]) 36 | plt.title(sys.argv[2] + ' mAP per class') 37 | else: 38 | bars = plt.bar(class_names, mAP_overall) 39 | plt.title('overall mAP per class') 40 | 41 | colors = 
list(XKCD_COLORS.values()) 42 | 43 | for idx, bar in enumerate(bars): 44 | # Mmm pseudorandom colors 45 | char_sum = sum([ord(char) for char in class_names[idx]]) 46 | bar.set_color(colors[char_sum % len(colors)]) 47 | 48 | plt.xticks(rotation='vertical') 49 | plt.show() 50 | -------------------------------------------------------------------------------- /scripts/plot_loss.py: -------------------------------------------------------------------------------- 1 | import re, sys, os 2 | import matplotlib.pyplot as plt 3 | 4 | from utils.functions import MovingAverage 5 | 6 | with open(sys.argv[1], 'r') as f: 7 | inp = f.read() 8 | 9 | patterns = { 10 | 'train': re.compile(r'\[\s*(?P\d+)\]\s*(?P\d+) \|\| B: (?P\S+) \| C: (?P\S+) \| M: (?P\S+) \|( S: (?P\S+) \|)? T: (?P\S+)'), 11 | 'val': re.compile(r'\s*(?P[a-z]+) \|\s*(?P\S+)') 12 | } 13 | data = {key: [] for key in patterns} 14 | 15 | for line in inp.split('\n'): 16 | for key, pattern in patterns.items(): 17 | f = pattern.search(line) 18 | 19 | if f is not None: 20 | datum = f.groupdict() 21 | for k, v in datum.items(): 22 | if v is not None: 23 | try: 24 | v = float(v) 25 | except ValueError: 26 | pass 27 | datum[k] = v 28 | 29 | if key == 'val': 30 | datum = (datum, data['train'][-1]) 31 | data[key].append(datum) 32 | break 33 | 34 | 35 | def smoother(y, interval=100): 36 | avg = MovingAverage(interval) 37 | 38 | for i in range(len(y)): 39 | avg.append(y[i]) 40 | y[i] = avg.get_avg() 41 | 42 | return y 43 | 44 | def plot_train(data): 45 | plt.title(os.path.basename(sys.argv[1]) + ' Training Loss') 46 | plt.xlabel('Iteration') 47 | plt.ylabel('Loss') 48 | 49 | loss_names = ['BBox Loss', 'Conf Loss', 'Mask Loss'] 50 | 51 | x = [x['iteration'] for x in data] 52 | plt.plot(x, smoother([y['b'] for y in data])) 53 | plt.plot(x, smoother([y['c'] for y in data])) 54 | plt.plot(x, smoother([y['m'] for y in data])) 55 | 56 | if data[0]['s'] is not None: 57 | plt.plot(x, smoother([y['s'] for y in data])) 58 | loss_names.append('Segmentation Loss') 59 | 60 | plt.legend(loss_names) 61 | plt.show() 62 | 63 | def plot_val(data): 64 | plt.title(os.path.basename(sys.argv[1]) + ' Validation mAP') 65 | plt.xlabel('Epoch') 66 | plt.ylabel('mAP') 67 | 68 | x = [x[1]['epoch'] for x in data if x[0]['type'] == 'box'] 69 | plt.plot(x, [x[0]['all'] for x in data if x[0]['type'] == 'box']) 70 | plt.plot(x, [x[0]['all'] for x in data if x[0]['type'] == 'mask']) 71 | 72 | plt.legend(['BBox mAP', 'Mask mAP']) 73 | plt.show() 74 | 75 | if len(sys.argv) > 2 and sys.argv[2] == 'val': 76 | plot_val(data['val']) 77 | else: 78 | plot_train(data['train']) 79 | -------------------------------------------------------------------------------- /scripts/resume.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-shared 3 | #SBATCH -t 48:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./resume.sh config batch_size resume_file 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 train.py --config $1 --batch_size $2 --resume=$3 --save_interval 5000 --start_iter=-1 >>logs/$1_log 2>&1 15 | -------------------------------------------------------------------------------- /scripts/save_bboxes.py: -------------------------------------------------------------------------------- 1 | """ This script transforms and saves bbox coordinates into a pickle object for easy loading. 
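Each saved entry has the form [img_w, img_h, x, y, box_w, box_h]: the COCO-style
xywh box in absolute pixel coordinates, prefixed with the image size.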
""" 2 | 3 | 4 | import os.path as osp 5 | import json, pickle 6 | import sys 7 | 8 | import numpy as np 9 | 10 | COCO_ROOT = osp.join('.', 'data/coco/') 11 | 12 | annotation_file = 'instances_train2017.json' 13 | annotation_path = osp.join(COCO_ROOT, 'annotations/', annotation_file) 14 | 15 | dump_file = 'weights/bboxes.pkl' 16 | 17 | with open(annotation_path, 'r') as f: 18 | annotations_json = json.load(f) 19 | 20 | annotations = annotations_json['annotations'] 21 | images = annotations_json['images'] 22 | images = {image['id']: image for image in images} 23 | bboxes = [] 24 | 25 | for ann in annotations: 26 | image = images[ann['image_id']] 27 | w,h = (image['width'], image['height']) 28 | 29 | if 'bbox' in ann: 30 | bboxes.append([w, h] + ann['bbox']) 31 | 32 | with open(dump_file, 'wb') as f: 33 | pickle.dump(bboxes, f) 34 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -p GPU-shared 3 | #SBATCH -t 48:00:00 4 | #SBATCH --gres=gpu:p100:1 5 | #SBATCH --no-requeue 6 | 7 | # Usage: ./train.sh config batch_size 8 | 9 | module load python/3.6.4_gcc5_np1.14.5 10 | module load cuda/9.0 11 | 12 | cd $SCRATCH/yolact 13 | 14 | python3 train.py --config $1 --batch_size $2 --save_interval 5000 &>logs/$1_log 15 | -------------------------------------------------------------------------------- /scripts/unpack_statedict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys, os 3 | 4 | # Usage python scripts/unpack_statedict.py path_to_pth out_folder/ 5 | # Make sure to include that slash after your out folder, since I can't 6 | # be arsed to do path concatenation so I'd rather type out this comment 7 | 8 | print('Loading state dict...') 9 | state = torch.load(sys.argv[1]) 10 | 11 | if not os.path.exists(sys.argv[2]): 12 | os.mkdir(sys.argv[2]) 13 | 14 | print('Saving stuff...') 15 | for key, val in state.items(): 16 | torch.save(val, sys.argv[2] + key) 17 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentations import SSDAugmentation -------------------------------------------------------------------------------- /utils/cython_nms.pyx: -------------------------------------------------------------------------------- 1 | ## Note: Figure out the license details later. 
2 | # 3 | # Based on: 4 | # -------------------------------------------------------- 5 | # Fast R-CNN 6 | # Copyright (c) 2015 Microsoft 7 | # Licensed under The MIT License [see LICENSE for details] 8 | # Written by Ross Girshick 9 | # -------------------------------------------------------- 10 | 11 | cimport cython 12 | import numpy as np 13 | cimport numpy as np 14 | 15 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b) nogil: 16 | return a if a >= b else b 17 | 18 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b) nogil: 19 | return a if a <= b else b 20 | 21 | @cython.boundscheck(False) 22 | @cython.cdivision(True) 23 | @cython.wraparound(False) 24 | def nms(np.ndarray[np.float32_t, ndim=2] dets, np.float32_t thresh): 25 | cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0] 26 | cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1] 27 | cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2] 28 | cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3] 29 | cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4] 30 | 31 | cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1) 32 | cdef np.ndarray[np.int64_t, ndim=1] order = scores.argsort()[::-1] 33 | 34 | cdef int ndets = dets.shape[0] 35 | cdef np.ndarray[np.int_t, ndim=1] suppressed = \ 36 | np.zeros((ndets), dtype=np.int) 37 | 38 | # nominal indices 39 | cdef int _i, _j 40 | # sorted indices 41 | cdef int i, j 42 | # temp variables for box i's (the box currently under consideration) 43 | cdef np.float32_t ix1, iy1, ix2, iy2, iarea 44 | # variables for computing overlap with box j (lower scoring box) 45 | cdef np.float32_t xx1, yy1, xx2, yy2 46 | cdef np.float32_t w, h 47 | cdef np.float32_t inter, ovr 48 | 49 | with nogil: 50 | for _i in range(ndets): 51 | i = order[_i] 52 | if suppressed[i] == 1: 53 | continue 54 | ix1 = x1[i] 55 | iy1 = y1[i] 56 | ix2 = x2[i] 57 | iy2 = y2[i] 58 | iarea = areas[i] 59 | for _j in range(_i + 1, ndets): 60 | j = order[_j] 61 | if suppressed[j] == 1: 62 | continue 63 | xx1 = max(ix1, x1[j]) 64 | yy1 = max(iy1, y1[j]) 65 | xx2 = min(ix2, x2[j]) 66 | yy2 = min(iy2, y2[j]) 67 | w = max(0.0, xx2 - xx1 + 1) 68 | h = max(0.0, yy2 - yy1 + 1) 69 | inter = w * h 70 | ovr = inter / (iarea + areas[j] - inter) 71 | if ovr >= thresh: 72 | suppressed[j] = 1 73 | 74 | return np.where(suppressed == 0)[0] 75 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch.nn as nn 4 | from collections import deque 5 | from pathlib import Path 6 | from layers.interpolate import InterpolateModule 7 | 8 | 9 | class MovingAverage(): 10 | """ Keeps an average window of the specified number of items. """ 11 | 12 | def __init__(self, max_window_size=1000): 13 | self.max_window_size = max_window_size 14 | self.reset() 15 | 16 | def add(self, elem): 17 | """ Adds an element to the window, removing the earliest element if necessary. """ 18 | if not math.isfinite(elem): 19 | print('Warning: Moving average ignored a value of %f' % elem) 20 | return 21 | 22 | self.window.append(elem) 23 | self.sum += elem 24 | 25 | if len(self.window) > self.max_window_size: 26 | self.sum -= self.window.popleft() 27 | 28 | def append(self, elem): 29 | """ Same as add just more pythonic. """ 30 | self.add(elem) 31 | 32 | def reset(self): 33 | """ Resets the MovingAverage to its initial state. 
""" 34 | self.window = deque() 35 | self.sum = 0 36 | 37 | def get_avg(self): 38 | """ Returns the average of the elements in the window. """ 39 | return self.sum / max(len(self.window), 1) 40 | 41 | def __str__(self): 42 | return str(self.get_avg()) 43 | 44 | def __repr__(self): 45 | return repr(self.get_avg()) 46 | 47 | 48 | class ProgressBar(): 49 | """ A simple progress bar that just outputs a string. """ 50 | 51 | def __init__(self, length, max_val): 52 | self.max_val = max_val 53 | self.length = length 54 | self.cur_val = 0 55 | 56 | self.cur_num_bars = -1 57 | self._update_str() 58 | 59 | def set_val(self, new_val): 60 | self.cur_val = new_val 61 | 62 | if self.cur_val > self.max_val: 63 | self.cur_val = self.max_val 64 | if self.cur_val < 0: 65 | self.cur_val = 0 66 | 67 | self._update_str() 68 | 69 | def is_finished(self): 70 | return self.cur_val == self.max_val 71 | 72 | def _update_str(self): 73 | num_bars = int(self.length * (self.cur_val / self.max_val)) 74 | 75 | if num_bars != self.cur_num_bars: 76 | self.cur_num_bars = num_bars 77 | # self.string = '█' * num_bars + '░' * (self.length - num_bars) 78 | self.string = '1' * num_bars + '0' * (self.length - num_bars) 79 | 80 | def __repr__(self): 81 | return self.string 82 | 83 | def __str__(self): 84 | return self.string 85 | 86 | 87 | def init_console(): 88 | """ 89 | Initialize the console to be able to use ANSI escape characters on Windows. 90 | """ 91 | if os.name == 'nt': 92 | from colorama import init 93 | init() 94 | 95 | 96 | class SavePath: 97 | """ 98 | Why is this a class? 99 | Why do I have a class for creating and parsing save paths? 100 | What am I doing with my life? 101 | """ 102 | 103 | def __init__(self, model_name:str, epoch:int, iteration:int): 104 | self.model_name = model_name 105 | self.epoch = epoch 106 | self.iteration = iteration 107 | 108 | def get_path(self, root:str=''): 109 | file_name = self.model_name + '_' + str(self.epoch) + '_' + str(self.iteration) + '.pth' 110 | return os.path.join(root, file_name) 111 | 112 | @staticmethod 113 | def from_str(path:str): 114 | file_name = os.path.basename(path) 115 | 116 | if file_name.endswith('.pth'): 117 | file_name = file_name[:-4] 118 | 119 | params = file_name.split('_') 120 | 121 | if file_name.endswith('interrupt'): 122 | params = params[:-1] 123 | 124 | model_name = '_'.join(params[:-2]) 125 | epoch = params[-2] 126 | iteration = params[-1] 127 | 128 | return SavePath(model_name, int(epoch), int(iteration)) 129 | 130 | @staticmethod 131 | def remove_interrupt(save_folder): 132 | for p in Path(save_folder).glob('*_interrupt.pth'): 133 | p.unlink() 134 | 135 | @staticmethod 136 | def get_interrupt(save_folder): 137 | for p in Path(save_folder).glob('*_interrupt.pth'): 138 | return str(p) 139 | return None 140 | 141 | @staticmethod 142 | def get_latest(save_folder, config): 143 | """ Note: config should be config.name. 
""" 144 | max_iter = -1 145 | max_name = None 146 | 147 | for p in Path(save_folder).glob(config + '_*'): 148 | path_name = str(p) 149 | 150 | try: 151 | save = SavePath.from_str(path_name) 152 | except: 153 | continue 154 | 155 | if save.model_name == config and save.iteration > max_iter: 156 | max_iter = save.iteration 157 | max_name = path_name 158 | 159 | return max_name 160 | -------------------------------------------------------------------------------- /utils/nvinfo.py: -------------------------------------------------------------------------------- 1 | # My version of nvgpu because nvgpu didn't have all the information I was looking for. 2 | import re 3 | import subprocess 4 | import shutil 5 | import os 6 | 7 | 8 | def gpu_info() -> list: 9 | """ 10 | Returns a dictionary of stats mined from nvidia-smi for each gpu in a list. 11 | Adapted from nvgpu: https://pypi.org/project/nvgpu/, but mine has more info. 12 | """ 13 | gpus = [line for line in _run_cmd(['nvidia-smi', '-L']) if line] 14 | gpu_infos = [re.match('GPU ([0-9]+): ([^(]+) \(UUID: ([^)]+)\)', gpu).groups() for gpu in gpus] 15 | gpu_infos = [dict(zip(['idx', 'name', 'uuid'], info)) for info in gpu_infos] 16 | gpu_count = len(gpus) 17 | 18 | lines = _run_cmd(['nvidia-smi']) 19 | selected_lines = lines[7:7 + 3 * gpu_count] 20 | for i in range(gpu_count): 21 | mem_used, mem_total = [int(m.strip().replace('MiB', '')) for m in 22 | selected_lines[3 * i + 1].split('|')[2].strip().split('/')] 23 | 24 | pw_tmp_info, mem_info, util_info = [x.strip() for x in selected_lines[3 * i + 1].split('|')[1:-1]] 25 | 26 | pw_tmp_info = [x[:-1] for x in pw_tmp_info.split(' ') if len(x) > 0] 27 | fan_speed, temperature, pwr_used, pwr_cap = [int(pw_tmp_info[i]) for i in (0, 1, 3, 5)] 28 | gpu_infos[i]['fan_spd'] = fan_speed 29 | gpu_infos[i]['temp'] = temperature 30 | gpu_infos[i]['pwr_used'] = pwr_used 31 | gpu_infos[i]['pwr_cap'] = pwr_cap 32 | 33 | mem_used, mem_total = [int(x) for x in mem_info.replace('MiB', '').split(' / ')] 34 | gpu_infos[i]['mem_used'] = mem_used 35 | gpu_infos[i]['mem_total'] = mem_total 36 | 37 | utilization = int(util_info.split(' ')[0][:-1]) 38 | gpu_infos[i]['util'] = utilization 39 | 40 | gpu_infos[i]['idx'] = int(gpu_infos[i]['idx']) 41 | 42 | return gpu_infos 43 | 44 | 45 | def nvsmi_available() -> bool: 46 | """ Returns whether or not nvidia-smi is present in this system's PATH. """ 47 | return shutil.which('nvidia-smi') is not None 48 | 49 | 50 | def visible_gpus() -> list: 51 | """ Returns a list of the indexes of all the gpus visible to pytorch. """ 52 | 53 | if 'CUDA_VISIBLE_DEVICES' not in os.environ: 54 | return list(range(len(gpu_info()))) 55 | else: 56 | return [int(x.strip()) for x in os.environ['CUDA_VISIBLE_DEVICES'].split(',')] 57 | 58 | 59 | def _run_cmd(cmd: list) -> list: 60 | """ Runs a command and returns a list of output lines. 
""" 61 | output = subprocess.check_output(cmd) 62 | output = output.decode('UTF-8') 63 | return output.split('\n') -------------------------------------------------------------------------------- /utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import defaultdict 3 | 4 | _total_times = defaultdict(lambda: 0) 5 | _start_times = defaultdict(lambda: -1) 6 | _disabled_names = set() 7 | _timer_stack = [] 8 | _running_timer = None 9 | _disable_all = False 10 | 11 | def disable_all(): 12 | global _disable_all 13 | _disable_all = True 14 | 15 | def enable_all(): 16 | global _disable_all 17 | _disable_all = False 18 | 19 | def disable(fn_name): 20 | """ Disables the given function name fom being considered for the average or outputted in print_stats. """ 21 | _disabled_names.add(fn_name) 22 | 23 | def enable(fn_name): 24 | """ Enables function names disabled by disable. """ 25 | _disabled_names.remove(fn_name) 26 | 27 | def reset(): 28 | """ Resets the current timer. Call this at the start of an iteration. """ 29 | global _running_timer 30 | _total_times.clear() 31 | _start_times.clear() 32 | _timer_stack.clear() 33 | _running_timer = None 34 | 35 | def start(fn_name, use_stack=True): 36 | """ 37 | Start timing the specific function. 38 | Note: If use_stack is True, only one timer can be active at a time. 39 | Once you stop this timer, the previous one will start again. 40 | """ 41 | global _running_timer, _disable_all 42 | 43 | if _disable_all: 44 | return 45 | 46 | if use_stack: 47 | if _running_timer is not None: 48 | stop(_running_timer, use_stack=False) 49 | _timer_stack.append(_running_timer) 50 | start(fn_name, use_stack=False) 51 | _running_timer = fn_name 52 | else: 53 | _start_times[fn_name] = time.perf_counter() 54 | 55 | def stop(fn_name=None, use_stack=True): 56 | """ 57 | If use_stack is True, this will stop the currently running timer and restore 58 | the previous timer on the stack if that exists. Note if use_stack is True, 59 | fn_name will be ignored. 60 | 61 | If use_stack is False, this will just stop timing the timer fn_name. 62 | """ 63 | global _running_timer, _disable_all 64 | 65 | if _disable_all: 66 | return 67 | 68 | if use_stack: 69 | if _running_timer is not None: 70 | stop(_running_timer, use_stack=False) 71 | if len(_timer_stack) > 0: 72 | _running_timer = _timer_stack.pop() 73 | start(_running_timer, use_stack=False) 74 | else: 75 | _running_timer = None 76 | else: 77 | print('Warning: timer stopped with no timer running!') 78 | else: 79 | if _start_times[fn_name] > -1: 80 | _total_times[fn_name] += time.perf_counter() - _start_times[fn_name] 81 | else: 82 | print('Warning: timer for %s stopped before starting!' % fn_name) 83 | 84 | 85 | def print_stats(): 86 | """ Prints the current timing information into a table. 
""" 87 | print() 88 | 89 | all_fn_names = [k for k in _total_times.keys() if k not in _disabled_names] 90 | 91 | max_name_width = max([len(k) for k in all_fn_names] + [4]) 92 | if max_name_width % 2 == 1: max_name_width += 1 93 | format_str = ' {:>%d} | {:>10.4f} ' % max_name_width 94 | 95 | header = (' {:^%d} | {:^10} ' % max_name_width).format('Name', 'Time (ms)') 96 | print(header) 97 | 98 | sep_idx = header.find('|') 99 | sep_text = ('-' * sep_idx) + '+' + '-' * (len(header)-sep_idx-1) 100 | print(sep_text) 101 | 102 | for name in all_fn_names: 103 | print(format_str.format(name, _total_times[name]*1000)) 104 | 105 | print(sep_text) 106 | print(format_str.format('Total', total_time()*1000)) 107 | print() 108 | 109 | def total_time(): 110 | """ Returns the total amount accumulated across all functions in seconds. """ 111 | return sum([elapsed_time for name, elapsed_time in _total_times.items() if name not in _disabled_names]) 112 | 113 | 114 | class env(): 115 | """ 116 | A class that lets you go: 117 | with timer.env(fn_name): 118 | # (...) 119 | That automatically manages a timer start and stop for you. 120 | """ 121 | 122 | def __init__(self, fn_name, use_stack=True): 123 | self.fn_name = fn_name 124 | self.use_stack = use_stack 125 | 126 | def __enter__(self): 127 | start(self.fn_name, use_stack=self.use_stack) 128 | 129 | def __exit__(self, e, ev, t): 130 | stop(self.fn_name, use_stack=self.use_stack) 131 | 132 | -------------------------------------------------------------------------------- /web/css/index.css: -------------------------------------------------------------------------------- 1 | 2 | /* 3 | Pallete: 4 | 5 | FFFFFF 6 | D2CBCB 7 | 7D8491 8 | 003459 9 | 274C77 10 | 161925 11 | */ 12 | 13 | * { box-sizing: border-box; } 14 | 15 | .big { 16 | font-size:72px; 17 | margin-bottom: 20px; 18 | } 19 | 20 | .list_wrapper { 21 | width: 500px; 22 | padding-top: 2px; 23 | padding-bottom: 20px; 24 | } 25 | 26 | 27 | body { 28 | margin:0; 29 | padding:0; 30 | vertical-align: top; 31 | 32 | background-color: #274C77; 33 | color: #ffffff; 34 | font-family: 'Open Sans', sans-serif; 35 | font-size: 24px; 36 | width: 100%; 37 | height: 99vh; 38 | 39 | display: grid; 40 | grid-template-areas: 41 | 'header' 42 | 'main' 43 | 'footer'; 44 | 45 | grid-template-rows: 100px auto 25px; 46 | 47 | text-align: center; 48 | } 49 | 50 | .box { 51 | background-color: #23395B; 52 | border-radius: 10px; 53 | } 54 | 55 | .header { grid-area: header; } 56 | .main { grid-area: main; } 57 | .footer { grid-area: footer; } 58 | 59 | span { 60 | margin:0; 61 | padding:0; 62 | vertical-align: top; 63 | } 64 | -------------------------------------------------------------------------------- /web/css/list.css: -------------------------------------------------------------------------------- 1 | ul { 2 | list-style-type: none; 3 | margin: 0; 4 | padding: 0; 5 | } 6 | 7 | li { 8 | /* font: 200 24px/1.5 Helvetica, Verdana, sans-serif; */ 9 | font-size: 22px; 10 | } 11 | 12 | li a { 13 | text-decoration: none; 14 | color: #fff; 15 | display: block; 16 | width: 100%; 17 | 18 | -webkit-transition: font-size 0.2s ease, background-color 0.2s ease; 19 | -moz-transition: font-size 0.2s ease, background-color 0.2s ease; 20 | -o-transition: font-size 0.2s ease, background-color 0.2s ease; 21 | -ms-transition: font-size 0.2s ease, background-color 0.2s ease; 22 | transition: font-size 0.2s ease, background-color 0.2s ease; 23 | } 24 | 25 | li a:hover { 26 | font-size: 30px; 27 | background: rgb(95, 138, 219); 28 | } 29 | 
-------------------------------------------------------------------------------- /web/css/toggle.css: -------------------------------------------------------------------------------- 1 | .switch { 2 | position: relative; 3 | top: 5; 4 | } 5 | 6 | .switch input {display:none;} 7 | 8 | .slider { 9 | position: relative; 10 | display: inline-block; 11 | width: 60px; 12 | height: 26px; 13 | cursor: pointer; 14 | top: 0; 15 | left: 0; 16 | right: 0; 17 | bottom: 0; 18 | background-color: #ccc; 19 | -webkit-transition: .4s; 20 | transition: .4s; 21 | } 22 | 23 | .slider:before { 24 | position: absolute; 25 | content: ""; 26 | height: 20px; 27 | width: 20px; 28 | left: 3px; 29 | bottom: 3px; 30 | background-color: white; 31 | -webkit-transition: .4s; 32 | transition: .4s; 33 | } 34 | 35 | input:checked + .slider { 36 | background-color: #2196F3; 37 | } 38 | 39 | input:focus + .slider { 40 | box-shadow: 0 0 1px #2196F3; 41 | } 42 | 43 | input:checked + .slider:before { 44 | -webkit-transform: translateX(34px); 45 | -ms-transform: translateX(34px); 46 | transform: translateX(34px); 47 | } 48 | 49 | /* Rounded sliders */ 50 | .slider.round { 51 | border-radius: 34px; 52 | } 53 | 54 | .slider.round:before { 55 | border-radius: 50%; 56 | } 57 | -------------------------------------------------------------------------------- /web/css/viewer.css: -------------------------------------------------------------------------------- 1 | 2 | .info { grid-area: info; } 3 | .image { grid-area: image; } 4 | .controls { grid-area: controls; } 5 | 6 | 7 | #viewer { 8 | display: grid; 9 | grid-template-areas: 'info image controls'; 10 | grid-template-columns: 1fr 2fr 1fr; 11 | grid-gap: 0; 12 | } 13 | 14 | #viewer > div.box { 15 | padding: 10px; 16 | margin: 0 10px 10px 10px; 17 | } 18 | 19 | .image_box { 20 | display: grid; 21 | grid-template-rows: max-content auto; 22 | grid-gap: 10px; 23 | } 24 | 25 | #image_idx, #config_name, .info_value { 26 | color: rgb(152, 160, 175); 27 | } 28 | 29 | .info_section { 30 | text-align: center; 31 | border-bottom: 1px solid #fff; 32 | } 33 | 34 | a { 35 | text-decoration: none; 36 | color: #fff; 37 | } 38 | 39 | a:hover { 40 | color: rgb(152, 160, 175); 41 | } 42 | 43 | .setting { 44 | display: grid; 45 | grid-template-areas: 'label value input'; 46 | grid-template-columns: max-content 30px 1fr; 47 | grid-gap: 20px; 48 | padding: 0 10px 0 10px; 49 | text-align: left; 50 | } 51 | .setting_label { grid-area: label; } 52 | .setting_input { 53 | grid-area: input; 54 | } 55 | .setting_value { 56 | grid-area: value; 57 | color: rgb(152, 160, 175); 58 | } 59 | 60 | .box_title { 61 | width: 100%; 62 | border-bottom: 1px solid #fff; 63 | } 64 | 65 | -------------------------------------------------------------------------------- /web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Configurations 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | Detections Viewer 20 | 21 |
22 |
23 |

Select a configuration

24 |
    25 |
26 |
27 |
28 | 29 | By Daniel Bolya 30 | 31 | 32 | -------------------------------------------------------------------------------- /web/iou.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | IoU thingy 11 | 12 | 13 | 14 | 15 | 16 | 17 | 35 | 36 | 37 | 38 |
39 | 40 | This text is displayed if your browser does not support HTML5 Canvas. 41 | 42 |
43 |
44 |

IoU:

45 |


46 |

Bbox manipulation sourced from here

47 |
48 | 49 | 50 | -------------------------------------------------------------------------------- /web/scripts/index.js: -------------------------------------------------------------------------------- 1 | $(document).ready(function() { 2 | // Load in det_index and fill the config list with the appropriate elements 3 | $.ajax({ 4 | url: 'detindex', 5 | dataType: 'text', 6 | success: function (data) { 7 | data = data.trim().split('\n'); 8 | for (let i = 0; i < data.length; i++) { 9 | name = data[i]; 10 | 11 | $('#config_list').append( 12 | '
  • ' + name + '
  • ' 13 | ); 14 | } 15 | } 16 | }); 17 | }); 18 | -------------------------------------------------------------------------------- /web/scripts/js.cookie.js: -------------------------------------------------------------------------------- 1 | /*! 2 | * JavaScript Cookie v2.2.0 3 | * https://github.com/js-cookie/js-cookie 4 | * 5 | * Copyright 2006, 2015 Klaus Hartl & Fagner Brack 6 | * Released under the MIT license 7 | */ 8 | ;(function (factory) { 9 | var registeredInModuleLoader = false; 10 | if (typeof define === 'function' && define.amd) { 11 | define(factory); 12 | registeredInModuleLoader = true; 13 | } 14 | if (typeof exports === 'object') { 15 | module.exports = factory(); 16 | registeredInModuleLoader = true; 17 | } 18 | if (!registeredInModuleLoader) { 19 | var OldCookies = window.Cookies; 20 | var api = window.Cookies = factory(); 21 | api.noConflict = function () { 22 | window.Cookies = OldCookies; 23 | return api; 24 | }; 25 | } 26 | }(function () { 27 | function extend () { 28 | var i = 0; 29 | var result = {}; 30 | for (; i < arguments.length; i++) { 31 | var attributes = arguments[ i ]; 32 | for (var key in attributes) { 33 | result[key] = attributes[key]; 34 | } 35 | } 36 | return result; 37 | } 38 | 39 | function init (converter) { 40 | function api (key, value, attributes) { 41 | var result; 42 | if (typeof document === 'undefined') { 43 | return; 44 | } 45 | 46 | // Write 47 | 48 | if (arguments.length > 1) { 49 | attributes = extend({ 50 | path: '/' 51 | }, api.defaults, attributes); 52 | 53 | if (typeof attributes.expires === 'number') { 54 | var expires = new Date(); 55 | expires.setMilliseconds(expires.getMilliseconds() + attributes.expires * 864e+5); 56 | attributes.expires = expires; 57 | } 58 | 59 | // We're using "expires" because "max-age" is not supported by IE 60 | attributes.expires = attributes.expires ? attributes.expires.toUTCString() : ''; 61 | 62 | try { 63 | result = JSON.stringify(value); 64 | if (/^[\{\[]/.test(result)) { 65 | value = result; 66 | } 67 | } catch (e) {} 68 | 69 | if (!converter.write) { 70 | value = encodeURIComponent(String(value)) 71 | .replace(/%(23|24|26|2B|3A|3C|3E|3D|2F|3F|40|5B|5D|5E|60|7B|7D|7C)/g, decodeURIComponent); 72 | } else { 73 | value = converter.write(value, key); 74 | } 75 | 76 | key = encodeURIComponent(String(key)); 77 | key = key.replace(/%(23|24|26|2B|5E|60|7C)/g, decodeURIComponent); 78 | key = key.replace(/[\(\)]/g, escape); 79 | 80 | var stringifiedAttributes = ''; 81 | 82 | for (var attributeName in attributes) { 83 | if (!attributes[attributeName]) { 84 | continue; 85 | } 86 | stringifiedAttributes += '; ' + attributeName; 87 | if (attributes[attributeName] === true) { 88 | continue; 89 | } 90 | stringifiedAttributes += '=' + attributes[attributeName]; 91 | } 92 | return (document.cookie = key + '=' + value + stringifiedAttributes); 93 | } 94 | 95 | // Read 96 | 97 | if (!key) { 98 | result = {}; 99 | } 100 | 101 | // To prevent the for loop in the first place assign an empty array 102 | // in case there are no cookies at all. Also prevents odd result when 103 | // calling "get()" 104 | var cookies = document.cookie ? 
document.cookie.split('; ') : []; 105 | var rdecode = /(%[0-9A-Z]{2})+/g; 106 | var i = 0; 107 | 108 | for (; i < cookies.length; i++) { 109 | var parts = cookies[i].split('='); 110 | var cookie = parts.slice(1).join('='); 111 | 112 | if (!this.json && cookie.charAt(0) === '"') { 113 | cookie = cookie.slice(1, -1); 114 | } 115 | 116 | try { 117 | var name = parts[0].replace(rdecode, decodeURIComponent); 118 | cookie = converter.read ? 119 | converter.read(cookie, name) : converter(cookie, name) || 120 | cookie.replace(rdecode, decodeURIComponent); 121 | 122 | if (this.json) { 123 | try { 124 | cookie = JSON.parse(cookie); 125 | } catch (e) {} 126 | } 127 | 128 | if (key === name) { 129 | result = cookie; 130 | break; 131 | } 132 | 133 | if (!key) { 134 | result[name] = cookie; 135 | } 136 | } catch (e) {} 137 | } 138 | 139 | return result; 140 | } 141 | 142 | api.set = api; 143 | api.get = function (key) { 144 | return api.call(api, key); 145 | }; 146 | api.getJSON = function () { 147 | return api.apply({ 148 | json: true 149 | }, [].slice.call(arguments)); 150 | }; 151 | api.defaults = {}; 152 | 153 | api.remove = function (key, attributes) { 154 | api(key, '', extend(attributes, { 155 | expires: -1 156 | })); 157 | }; 158 | 159 | api.withConverter = init; 160 | 161 | return api; 162 | } 163 | 164 | return init(function () {}); 165 | })); 166 | -------------------------------------------------------------------------------- /web/scripts/utils.js: -------------------------------------------------------------------------------- 1 | function load_RLE(rle_obj, fillColor=[255, 255, 255], alpha=255) { 2 | var h = rle_obj.size[0], w = rle_obj.size[1]; 3 | var counts = uncompress_RLE(rle_obj.counts); 4 | 5 | var buffer_size = (w*h*4); 6 | var buffer = new Uint8ClampedArray(w*h*4); 7 | var bufferIdx = 0; 8 | 9 | for (var countsIdx = 0; countsIdx < counts.length; countsIdx++) { 10 | while (counts[countsIdx] > 0) { 11 | // Kind of transpose the image as we go 12 | if (bufferIdx >= buffer_size) 13 | bufferIdx = (bufferIdx % buffer_size) + 4; 14 | 15 | buffer[bufferIdx+0] = fillColor[0]; 16 | buffer[bufferIdx+1] = fillColor[1]; 17 | buffer[bufferIdx+2] = fillColor[2]; 18 | buffer[bufferIdx+3] = alpha * (countsIdx % 2); 19 | 20 | bufferIdx += 4*w; 21 | counts[countsIdx]--; 22 | } 23 | } 24 | 25 | // Load into an off-screen canvas and return an image with that data 26 | var canvas = document.createElement('canvas'); 27 | var ctx = canvas.getContext('2d'); 28 | 29 | canvas.width = w; 30 | canvas.height = h; 31 | 32 | var idata = ctx.createImageData(w, h); 33 | idata.data.set(buffer); 34 | 35 | ctx.putImageData(idata, 0, 0); 36 | 37 | var img = new Image(); 38 | img.src = canvas.toDataURL(); 39 | 40 | return img; 41 | } 42 | 43 | function uncompress_RLE(rle_str) { 44 | // Don't ask me how this works--I'm just transcribing from the pycocotools c api. 45 | var p = 0, m = 0; 46 | var counts = Array(rle_str.lenght); 47 | 48 | while (p < rle_str.length) { 49 | var x=0, k=0, more=1; 50 | 51 | while (more) { 52 | var c = rle_str.charCodeAt(p) - 48; 53 | x |= (c & 0x1f) << 5*k; 54 | more = c & 0x20; 55 | p++; k++; 56 | if (!more && (c & 0x10)) 57 | x |= (-1 << 5*k); 58 | } 59 | 60 | if (m > 2) 61 | x += counts[m-2]; 62 | counts[m++] = (x >>> 0); 63 | } 64 | 65 | return counts; 66 | } 67 | 68 | function hexToRgb(hex) { 69 | var result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); 70 | return result ? 
[parseInt(result[1], 16), parseInt(result[2], 16), parseInt(result[3], 16)] : null; 71 | } 72 | -------------------------------------------------------------------------------- /web/scripts/viewer.js: -------------------------------------------------------------------------------- 1 | // Global variables so I remember them 2 | config_name = null; 3 | img_idx = null; 4 | 5 | img = null; 6 | dets = null; 7 | masks = null; 8 | 9 | // Must be in hex 10 | colors = ['#FF0000', '#FF7F00', '#00FF00', '#0000FF', '#4B0082', '#9400D3']; 11 | 12 | settings = { 13 | 'top_k': 5, 14 | 'font_height': 20, 15 | 'mask_alpha': 100, 16 | 17 | 'show_class': true, 18 | 'show_score': true, 19 | 'show_bbox': true, 20 | 'show_mask': true, 21 | 22 | 'show_one': false, 23 | } 24 | 25 | function save_settings() { 26 | Cookies.set('settings', settings); 27 | } 28 | 29 | function load_settings() { 30 | var new_settings = Cookies.getJSON('settings'); 31 | 32 | for (var key in new_settings) 33 | settings[key] = new_settings[key]; 34 | } 35 | 36 | $.urlParam = function(name){ 37 | var results = new RegExp('[\?&]' + name + '=([^&#]*)').exec(window.location.href); 38 | if (results==null){ 39 | return null; 40 | } 41 | else{ 42 | return decodeURI(results[1]) || 0; 43 | } 44 | } 45 | 46 | $(document).ready(function() { 47 | config_name = $.urlParam('config'); 48 | $('#config_name').html(config_name); 49 | 50 | img_idx = $.urlParam('idx'); 51 | if (img_idx === null) img_idx = 0; 52 | img_idx = parseInt(img_idx); 53 | 54 | load_settings(); 55 | 56 | $.getJSON('dets/' + config_name + '.json', function(data) { 57 | img_idx = (img_idx+data.images.length) % data.images.length; 58 | var info = data.info; 59 | var data = data.images[img_idx]; 60 | 61 | // These are globals on purpose 62 | dets = data.dets; 63 | img = new Image(); 64 | masks = Array(dets.length); 65 | 66 | img.onload = function() { render(); } 67 | img.src = 'image' + data.image_id; 68 | 69 | $('#image_name').html(data.image_id); 70 | $('#image_idx').html(img_idx); 71 | 72 | fill_info(info); 73 | fill_controls(); 74 | }); 75 | }); 76 | 77 | function is_object(val) { return val === Object(val); } 78 | 79 | function fill_info(info) { 80 | var html = ''; 81 | 82 | var add_item = function(item, val) { 83 | html += '' + item + '' 84 | html += ' ' 85 | html += '' + val + '' 86 | html += '
    ' 87 | } 88 | 89 | for (var item in info) { 90 | var val = info[item]; 91 | 92 | if (is_object(val)) { 93 | html += '' + item + '
    '; 94 | 95 | for (var item2 in val) 96 | add_item(item2, val[item2]); 97 | 98 | html += '
    ' 99 | } else add_item(item, val); 100 | } 101 | 102 | $('#info_box').html(html); 103 | } 104 | 105 | function fill_controls() { 106 | var html = ''; 107 | 108 | var append_html = function() { 109 | $('#control_box').append(html); 110 | html = ''; 111 | } 112 | 113 | var make_slider = function (name, setting, min, max) { 114 | settings[setting] = Math.min(max, settings[setting]); 115 | var value = settings[setting]; 116 | 117 | html += '
    '; 118 | html += '' + name + ''; 119 | html += ''; 120 | html += '' + value + ''; 121 | html += '
    '; 122 | append_html(); 123 | 124 | $('input#'+setting).change(function(e) { 125 | settings[setting] = $('input#'+setting).prop('value'); 126 | $('span#'+setting).html(settings[setting]); 127 | save_settings(); 128 | render(); 129 | }); 130 | } 131 | 132 | var make_toggle = function(name, setting) { 133 | html += '
    '; 134 | html += '' + name + ''; 135 | html += '
    '; 139 | append_html(); 140 | 141 | $('input#' + setting).change(function (e) { 142 | settings[setting] = $('input#' + setting).prop('checked'); 143 | save_settings(); 144 | render(); 145 | }); 146 | } 147 | 148 | 149 | make_slider('Top K', 'top_k', 1, dets.length); 150 | make_toggle('Show One', 'show_one'); 151 | html += '
    '; 152 | make_toggle('Show BBox', 'show_bbox'); 153 | make_toggle('Show Class', 'show_class'); 154 | make_toggle('Show Score', 'show_score'); 155 | html += '
    '; 156 | make_slider('Mask Alpha', 'mask_alpha', 0, 255); 157 | make_toggle('Show Mask', 'show_mask'); 158 | 159 | html += '

    '; 160 | html += 'Prev'; 161 | html += '   '; 162 | html += 'Next'; 163 | html += '

    '; 164 | html += 'Back'; 165 | 166 | append_html(); 167 | } 168 | 169 | function render() { 170 | var canvas = document.querySelector('#image_canvas'); 171 | var ctx = canvas.getContext('2d'); 172 | 173 | canvas.style.width='100%'; 174 | canvas.style.height='94%'; 175 | canvas.width = canvas.offsetWidth; 176 | canvas.height = canvas.offsetHeight; 177 | 178 | var scale = Math.min(canvas.width / img.width, canvas.height / img.height); 179 | 180 | var im_x = canvas.width/2-img.width*scale/2; 181 | var im_y = canvas.height/2-img.height*scale/2; 182 | ctx.translate(im_x, im_y); 183 | ctx.drawImage(img, 0, 0, img.width * scale, img.height * scale); 184 | 185 | var startIdx = Math.min(dets.length, settings.top_k)-1; 186 | var endIdx = (settings.show_one ? startIdx : 0); 187 | 188 | // Draw masks behind everything 189 | for (var i = startIdx; i >= endIdx; i--) { 190 | if (settings.show_mask) { 191 | var mask = masks[i]; 192 | if (typeof mask == 'undefined') { 193 | masks[i] = load_RLE(dets[i].mask, hexToRgb(colors[i % colors.length])); 194 | masks[i].onload = function() { render(); } 195 | } else { 196 | ctx.globalAlpha = settings.mask_alpha / 255; 197 | ctx.drawImage(mask, 0, 0, mask.width * scale, mask.height * scale); 198 | ctx.globalAlpha = 1; 199 | } 200 | } 201 | } 202 | 203 | for (var i = startIdx; i >= endIdx; i--) { 204 | ctx.strokeStyle = colors[i % colors.length]; 205 | ctx.fillStyle = ctx.strokeStyle; 206 | ctx.lineWidth = 4; 207 | ctx.font = settings.font_height + 'px sans-serif'; 208 | 209 | var x = dets[i].bbox[0] * scale; 210 | var y = dets[i].bbox[1] * scale; 211 | var w = dets[i].bbox[2] * scale; 212 | var h = dets[i].bbox[3] * scale; 213 | 214 | if (settings.show_bbox) { 215 | ctx.strokeRect(x, y, w, h); 216 | ctx.stroke(); 217 | } 218 | 219 | var text_array = [] 220 | if (settings.show_class) 221 | text_array.push(dets[i].category); 222 | if (settings.show_score) 223 | text_array.push(Math.round(dets[i].score * 1000) / 1000); 224 | 225 | if (text_array.length > 0) { 226 | var text = text_array.join(' '); 227 | 228 | text_w = ctx.measureText(text).width; 229 | ctx.fillRect(x-ctx.lineWidth/2, y-settings.font_height-8, text_w+ctx.lineWidth, settings.font_height+8); 230 | 231 | ctx.fillStyle = 'white'; 232 | ctx.fillText(text, x, y-8); 233 | } 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /web/server.py: -------------------------------------------------------------------------------- 1 | from http.server import SimpleHTTPRequestHandler, HTTPServer, HTTPStatus 2 | from pathlib import Path 3 | import os 4 | 5 | PORT = 6337 6 | IMAGE_PATH = '../data/coco/images/' 7 | IMAGE_FMT = '%012d.jpg' 8 | 9 | class Handler(SimpleHTTPRequestHandler): 10 | 11 | def do_GET(self): 12 | if self.path == '/detindex': 13 | self.send_str('\n'.join([p.name[:-5] for p in Path('dets/').glob('*.json')])) 14 | elif self.path.startswith('/image'): 15 | # Unsafe practices ahead! 
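# viewer.js requests frames as 'image' + image_id, so the translated path is
# split on the literal substring 'image' and the numeric suffix is formatted
# with IMAGE_FMT: a GET for '/image123' is served from IMAGE_PATH +
# '000000000123.jpg'. This only works if 'image' does not appear earlier in
# the directory path, hence the warning above.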
16 | path = self.translate_path(self.path).split('image') 17 | self.send_file(os.path.join(path[0], IMAGE_PATH, IMAGE_FMT % int(path[1]))) 18 | else: 19 | super().do_GET() 20 | 21 | def send_str(self, string): 22 | self.send_response(HTTPStatus.OK) 23 | self.send_header('Content-type', 'text/plain') 24 | self.send_header('Content-Length', str(len(string))) 25 | self.send_header('Last-Modified', self.date_time_string()) 26 | self.end_headers() 27 | 28 | self.wfile.write(string.encode()) 29 | 30 | def send_file(self, path): 31 | try: 32 | f = open(path, 'rb') 33 | except OSError: 34 | self.send_error(HTTPStatus.NOT_FOUND, "File not found") 35 | return 36 | 37 | try: 38 | self.send_response(HTTPStatus.OK) 39 | self.send_header("Content-type", self.guess_type(path)) 40 | fs = os.fstat(f.fileno()) 41 | self.send_header("Content-Length", str(fs[6])) 42 | self.send_header("Last-Modified", self.date_time_string(fs.st_mtime)) 43 | self.end_headers() 44 | 45 | self.copyfile(f, self.wfile) 46 | finally: 47 | f.close() 48 | 49 | def send_response(self, code, message=None): 50 | super().send_response(code, message) 51 | 52 | 53 | with HTTPServer(('', PORT), Handler) as httpd: 54 | print('Serving at port', PORT) 55 | try: 56 | httpd.serve_forever() 57 | except KeyboardInterrupt: 58 | pass 59 | -------------------------------------------------------------------------------- /web/viewer.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Detections Viewer 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | Detections Viewer 23 | 24 |
    25 |
    26 |
    Info  

    27 |
    28 |
    29 | 30 |
    31 |
     
    32 | 33 |
    34 | 35 |
    36 |
    Controls

    37 |
    38 |
    39 |
    40 | 41 | By Daniel Bolya 42 | 43 | 44 | --------------------------------------------------------------------------------
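A minimal launch sketch for the detections viewer above, assuming the JSON files under web/dets/ have been generated and the COCO images are available at ../data/coco/images/ relative to web/ (as configured in server.py):

cd web
python3 server.py    # serves the viewer on port 6337 (PORT in server.py)
# then open http://localhost:6337/index.html in a browser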