├── assets ├── example.png ├── netron.png └── tensorboard.png ├── .gitignore ├── configs ├── simple_model.py ├── resnet34_pretrained.py ├── resnet50_pretrained.py └── resnet34_custom.py ├── average_meter.py ├── helpers.py ├── LICENSE ├── summary_writer_opt.py ├── fetch_data.py ├── confusion_matrix.py ├── anchor_coverage.py ├── extended_collate.py ├── kitti_randomaccess.py ├── custom_models.py ├── requirements.txt ├── decode_detection.py ├── debug_tools.py ├── README.md ├── image_anno_transforms.py ├── average_precision.py ├── requirements_full.txt ├── data_writer.py ├── name_list_dataset.py ├── imagenet_models.py ├── box_utils.py ├── detection_models.py └── detect2d.py /assets/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrii-khizbullin/pytorch-detection/HEAD/assets/example.png -------------------------------------------------------------------------------- /assets/netron.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrii-khizbullin/pytorch-detection/HEAD/assets/netron.png -------------------------------------------------------------------------------- /assets/tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dmitrii-khizbullin/pytorch-detection/HEAD/assets/tensorboard.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.tar.gz 2 | *.pyc 3 | *.zip 4 | *.pth.tar 5 | logs/ 6 | train_val_split/ 7 | .idea/ 8 | kitti/ 9 | kitti 10 | runs/ 11 | -------------------------------------------------------------------------------- /configs/simple_model.py: -------------------------------------------------------------------------------- 1 | suffix = '_1x1' 2 | 3 | run_name = __name__.split('.')[-1] + suffix 4 | 5 | backbone_specs = { 6 | 'backbone_module': 'custom_models', 7 | 'backbone_function': 'simple_backbone', 8 | 'kwargs': {}, 9 | 'head_channel_multiplier': 64, 10 | } 11 | 12 | train_val_split_dir = 'train_val_split' 13 | 14 | epochs_before_val = 4 15 | -------------------------------------------------------------------------------- /configs/resnet34_pretrained.py: -------------------------------------------------------------------------------- 1 | suffix = '' 2 | 3 | run_name = __name__.split('.')[-1] + suffix 4 | 5 | backbone_specs = { 6 | 'backbone_module': 'imagenet_models', 7 | 'backbone_function': 'resnet34_backbone', 8 | 'kwargs': { 9 | 'pretrained': True, 10 | }, 11 | 'head_channel_multiplier': 128, 12 | } 13 | 14 | multibox_specs = { 15 | 'use_ohem': True 16 | } 17 | 18 | train_val_split_dir = 'train_val_split' 19 | 20 | epochs_before_val = 4 21 | -------------------------------------------------------------------------------- /configs/resnet50_pretrained.py: -------------------------------------------------------------------------------- 1 | suffix = '' 2 | 3 | run_name = __name__.split('.')[-1] + suffix 4 | 5 | backbone_specs = { 6 | 'backbone_module': 'imagenet_models', 7 | 'backbone_function': 'resnet50_backbone', 8 | 'kwargs': { 9 | 'pretrained': True, 10 | }, 11 | 'head_channel_multiplier': 128, 12 | } 13 | 14 | multibox_specs = { 15 | 'use_ohem': True 16 | } 17 | 18 | train_val_split_dir = 'train_val_split' 19 | 20 | epochs_before_val = 4 21 | 
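The config modules above are plain Python files whose attributes (run_name, backbone_specs, etc.) drive the training script. A minimal sketch of how such a config could be consumed is shown below; the actual loading code lives in detect2d.py, which is not included in this listing, so this loader is only an assumption:

```python
import importlib


def load_config_and_backbone(config_name):
    # Hypothetical loader sketch: config modules are imported as submodules of the
    # configs package (their run_name relies on __name__), e.g. "simple_model".
    cfg = importlib.import_module("configs." + config_name)
    specs = cfg.backbone_specs
    backbone_module = importlib.import_module(specs["backbone_module"])
    backbone_fn = getattr(backbone_module, specs["backbone_function"])
    backbone = backbone_fn(**specs["kwargs"])
    return cfg, backbone
```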
-------------------------------------------------------------------------------- /configs/resnet34_custom.py: -------------------------------------------------------------------------------- 1 | suffix = '' 2 | 3 | run_name = __name__.split('.')[-1] + suffix 4 | 5 | backbone_specs = { 6 | 'backbone_module': 'imagenet_models', 7 | 'backbone_function': 'resnet34_backbone', 8 | 'kwargs': { 9 | 'pretrained': False, 10 | 'channel_config': (1, 2, 2, 2), 11 | 'channel_multiplier': 64, 12 | }, 13 | 'head_channel_multiplier': 64, 14 | } 15 | 16 | train_val_split_dir = 'train_val_split' 17 | 18 | epochs_before_val = 4 19 | -------------------------------------------------------------------------------- /average_meter.py: -------------------------------------------------------------------------------- 1 | # Borrowed from pytorch/examples/imagenet/main.py 2 | 3 | class AverageMeter: 4 | """Computes and stores the average and current value""" 5 | def __init__(self): 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += val * n 17 | self.count += n 18 | self.avg = self.sum / self.count 19 | -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | 5 | IMG_EXTENSIONS = ['.png', '.PNG'] 6 | 7 | 8 | def is_image_file(filename): 9 | """Borrowed helper.""" 10 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 11 | 12 | 13 | def pil_loader(path): 14 | """Borrowed helper""" 15 | with open(path, 'rb') as f: 16 | with Image.open(f) as img: 17 | return img.convert('RGB') 18 | 19 | 20 | def clean_dir(dir): 21 | if not os.path.exists(dir): 22 | os.makedirs(dir) 23 | else: 24 | files = glob.glob(os.path.join(dir, '*')) 25 | for f in files: 26 | os.remove(f) 27 | 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Dmitrii Khizbullin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /summary_writer_opt.py: -------------------------------------------------------------------------------- 1 | # Author: Dmitrii Khizbullin 2 | 3 | import os 4 | import time 5 | from tensorboardX import SummaryWriter 6 | 7 | 8 | class SummaryWriterOpt: 9 | def __init__(self, enabled=True, log_root_dir="logs", suffix=None): 10 | self.enabled = enabled 11 | self.writer = None 12 | time_str = time.strftime("%Y.%m.%d_%H-%M-%S", time.gmtime()) 13 | if suffix is not None: 14 | log_dir = time_str + "_" + suffix 15 | else: 16 | log_dir = time_str 17 | self.log_dir = os.path.join(log_root_dir, log_dir) 18 | 19 | def _create_writer(self): 20 | if self.writer is None and self.enabled: 21 | self.writer = SummaryWriter(log_dir=self.log_dir) 22 | 23 | def add_scalar(self, *args): 24 | self._create_writer() 25 | if self.writer is not None: 26 | self.writer.add_scalar(*args) 27 | 28 | def add_image(self, *args, **kwargs): 29 | self._create_writer() 30 | if self.writer is not None: 31 | self.writer.add_image(*args, **kwargs) 32 | 33 | def add_histogram(self, *args, **kwargs): 34 | self._create_writer() 35 | if self.writer is not None: 36 | self.writer.add_histogram(*args, **kwargs) 37 | 38 | -------------------------------------------------------------------------------- /fetch_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import urllib.request 4 | import multiprocessing 5 | 6 | 7 | files = [ 8 | "data_object_calib.zip", 9 | "data_object_label_2.zip", 10 | "data_object_image_2.zip", 11 | # "data_object_velodyne.zip", 12 | ] 13 | 14 | # files = [ 15 | # "data_tracking_velodyne.zip", 16 | # "data_tracking_image_2.zip", 17 | # "data_tracking_oxts.zip", 18 | # "data_tracking_calib.zip", 19 | # "data_tracking_label_2.zip", 20 | # ] 21 | 22 | # files = [ 23 | # "data_odometry_velodyne.zip", 24 | # ] 25 | 26 | location = "https://s3.eu-central-1.amazonaws.com/avg-kitti/" 27 | 28 | dst_dir = "kitti/" 29 | os.makedirs(dst_dir, exist_ok=True) 30 | 31 | 32 | def download(file): 33 | url = location + file 34 | dst_file = os.path.join(dst_dir, file) 35 | 36 | if not os.path.exists(dst_file): 37 | print("Downloading", url) 38 | urllib.request.urlretrieve(url, dst_file) 39 | 40 | print("Unzipping", url) 41 | 42 | with zipfile.ZipFile(dst_file, 'r') as zip_ref: 43 | zip_ref.extractall(dst_dir) 44 | 45 | print("Done", url) 46 | 47 | 48 | def main(): 49 | pool = multiprocessing.Pool(len(files)) 50 | pool.map(download, files) 51 | 52 | print("Done!") 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # Create a plot of the confusion matrix 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def save_confusion_matrix(confusion_matrix, labelmap, path): 8 | plt.interactive(False) 9 | 10 | norm_conf = [] 11 | for i in confusion_matrix: 12 | a = 0 13 | tmp_arr = [] 14 | a = sum(i, 0) 15 | for j in i: 16 | tmp_arr.append(float(j)/float(a)) 17 | norm_conf.append(tmp_arr) 18 | 19 | fig = plt.figure() 20 | plt.clf() 21 | ax = fig.add_subplot(111) 22 | ax.set_aspect(1) 23 | res = ax.imshow( 24 | np.array(norm_conf), cmap=plt.cm.jet, 25 | interpolation='nearest') 26 | 27 | width, height = confusion_matrix.shape 28 | 29 | for x in range(width): 30 | for y in 
range(height): 31 | ax.annotate( 32 | str(confusion_matrix[x, y]), xy=(y, x), 33 | horizontalalignment='center', 34 | verticalalignment='center') 35 | 36 | cb = fig.colorbar(res) 37 | plt.xticks(range(width), labelmap[:width], rotation='vertical') 38 | plt.yticks(range(height), labelmap[:height]) 39 | plt.xlabel('Detected class') 40 | plt.ylabel('Annotation class') 41 | plt.subplots_adjust(bottom=0.32, left=0.15) 42 | plt.savefig(path, format='png', dpi=300) 43 | 44 | pass -------------------------------------------------------------------------------- /anchor_coverage.py: -------------------------------------------------------------------------------- 1 | # Author: Dmitrii Khizbullin 2 | # Builds a table of numbers of anchor boxes which cover each ground truth box 3 | 4 | import numpy as np 5 | 6 | 7 | class AnchorCoverage: 8 | def __init__(self): 9 | self.annotations = [] 10 | self.stats = [] 11 | 12 | def add_batch(self, annotations, stats): 13 | self.annotations.extend(annotations) 14 | self.stats.extend(stats) 15 | 16 | def print(self): 17 | import math 18 | from tabulate import tabulate 19 | 20 | num_height_bins = 12 21 | max_quantity = 20 22 | 23 | hist = np.zeros((num_height_bins, max_quantity), dtype=np.int) 24 | for frame_anno, frame_stat in zip(self.annotations, self.stats): 25 | for anno, stat in zip(frame_anno, frame_stat): 26 | height = anno['bbox'][3] - anno['bbox'][1] 27 | height_bin = min(math.floor(-math.log2(height)*2), num_height_bins-1) 28 | quantity_bin = min(stat, max_quantity-1) 29 | hist[height_bin, quantity_bin] += 1 30 | 31 | table = tabulate(hist, headers=list(range(max_quantity)), tablefmt="fancy_grid") 32 | print(table) 33 | fn = hist[:, 0].sum() 34 | pos = hist.sum() 35 | fnr = fn / pos 36 | print('Anchor FNR {:6f} = {}/{}'.format(fnr, fn, pos)) 37 | pass 38 | -------------------------------------------------------------------------------- /extended_collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | import re 4 | from torch._six import string_classes, int_classes 5 | 6 | 7 | def extended_collate(batch, depth=0, collate_first_n=2): 8 | """ 9 | Puts each data field into a tensor with outer dimension batch size. 10 | Dmitrii Khzibullin: iteratively collate only first 2 items: image and target. 
11 | """ 12 | 13 | depth += 1 14 | 15 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 16 | elem_type = type(batch[0]) 17 | if torch.is_tensor(batch[0]): 18 | out = None 19 | return torch.stack(batch, 0, out=out) 20 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 21 | and elem_type.__name__ != 'string_': 22 | elem = batch[0] 23 | if elem_type.__name__ == 'ndarray': 24 | # array of string classes and object 25 | if re.search('[SaUO]', elem.dtype.str) is not None: 26 | raise TypeError(error_msg.format(elem.dtype)) 27 | 28 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 29 | elif isinstance(batch[0], string_classes): 30 | return batch 31 | elif isinstance(batch[0], collections.Sequence): 32 | transposed = [v for v in zip(*batch)] 33 | if depth == 1: # collate image and target only 34 | num_first = collate_first_n 35 | else: 36 | num_first = len(transposed) 37 | transposed_process = transposed[:num_first] 38 | transposed_noprocess = transposed[num_first:] 39 | collated = [extended_collate(samples, depth=depth) for samples in transposed_process] 40 | merged = [*collated, *transposed_noprocess] 41 | return merged 42 | else: 43 | return batch 44 | 45 | -------------------------------------------------------------------------------- /kitti_randomaccess.py: -------------------------------------------------------------------------------- 1 | """Load Kitti samples""" 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def get_image(filename): 8 | """Function to read image files into arrays.""" 9 | return np.asarray(Image.open(filename), np.uint8) 10 | 11 | 12 | def get_image_pil(filename): 13 | """Function to read image files into arrays.""" 14 | return Image.open(filename) 15 | 16 | 17 | def get_velo_scan(filename): 18 | """Function to parse velodyne binary files into arrays.""" 19 | scan = np.fromfile(filename, dtype=np.float32) 20 | return scan.reshape((-1, 4)) 21 | 22 | 23 | def get_calib(filename): 24 | """Function to parse calibration text files into a dictionary.""" 25 | data = {} 26 | with open(filename, 'r') as f: 27 | lines = f.readlines() 28 | for i in range(7): 29 | key, value = lines[i].split(':', 1) 30 | data[key] = np.array([float(x) for x in value.split()]) 31 | return data 32 | 33 | 34 | def get_label(filename): 35 | """Function to parse label text files into a dictionary.""" 36 | data = [] 37 | with open(filename, 'r') as f: 38 | for line in f.readlines(): 39 | values = line.split() 40 | assert len(values) == 15 41 | obj = { 42 | 'type': str(values[0]), 43 | 'truncated': float(values[1]), 44 | 'occluded': int(values[2]), 45 | 'alpha': float(values[3]), 46 | 'bbox': np.array(values[4:8], dtype=float), 47 | 'dimensions': np.array(values[8:11], dtype=float), 48 | 'location': np.array(values[11:14], dtype=float), 49 | 'rotation_y': float(values[14]), 50 | } 51 | data.append(obj) 52 | return data 53 | 54 | 55 | -------------------------------------------------------------------------------- /custom_models.py: -------------------------------------------------------------------------------- 1 | # 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class ConvBlock(nn.Module): 7 | def __init__(self, inplanes, outplanes, kernel_size, stride=1): 8 | super().__init__() 9 | padding = (kernel_size - 1) // 2 10 | self.conv = nn.Conv2d(inplanes, outplanes, kernel_size, stride, padding=padding, bias=False) 11 | self.bn = nn.BatchNorm2d(outplanes) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | def forward(self, x): 15 | x = 
self.conv(x) 16 | x = self.bn(x) 17 | x = self.relu(x) 18 | return x 19 | 20 | 21 | class SimpleBackbone(nn.Module): 22 | 23 | def __init__(self): 24 | super().__init__() 25 | 26 | base_multiplier = 32 27 | 28 | self.config = ( 29 | (1.0, 7, 2, False), 30 | 31 | (2.0, 3, 2, False), 32 | (1.0, 1, 1, False), 33 | (2.0, 3, 1, True), # <- branch 34 | 35 | (2.0, 3, 2, False), 36 | (1.0, 1, 1, False), 37 | (2.0, 3, 1, True), # <- branch 38 | 39 | (2.0, 3, 2, False), 40 | (1.0, 1, 1, False), 41 | (2.0, 3, 1, True), # <- branch 42 | 43 | (2.0, 3, 2, False), 44 | (1.0, 1, 1, False), 45 | (2.0, 3, 1, True), # <- branch 46 | ) 47 | 48 | in_planes = 3 49 | self.layers = nn.ModuleList() 50 | for ch_mul, kernel_size, stride, is_branch in self.config: 51 | out_planes = int(base_multiplier * ch_mul) 52 | conv = ConvBlock(in_planes, out_planes, kernel_size=kernel_size, stride=stride) 53 | self.layers.append(conv) 54 | in_planes = out_planes 55 | 56 | def forward(self, x): 57 | 58 | branches = [] 59 | for conv, is_branch in zip(self.layers, [c[3] for c in self.config]): 60 | x = conv(x) 61 | if is_branch: 62 | branches.append(x) 63 | 64 | return branches 65 | 66 | 67 | def simple_backbone(**kwargs): 68 | """Constructs a simple backbone.""" 69 | model = SimpleBackbone(**kwargs) 70 | return model 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: osx-64 4 | blas=1.1=openblas 5 | bzip2=1.0.6=1 6 | ca-certificates=2018.8.24=ha4d7672_0 7 | cairo=1.14.12=he56eebe_3 8 | certifi=2018.8.24=py36_1001 9 | cffi=1.11.5=py36h6174b99_1 10 | ffmpeg=4.0.2=ha6a6e2b_0 11 | fontconfig=2.13.0=h8c010e7_5 12 | freetype=2.8.1=hfa320df_1 13 | gettext=0.19.8.1=h1f1d5ed_1 14 | giflib=5.1.4=h470a237_1 15 | glib=2.55.0=h464dc38_2 16 | gmp=6.1.2=hfc679d8_0 17 | gnutls=3.5.19=h2a4e5f8_1 18 | graphite2=1.3.12=h7d4d677_1 19 | harfbuzz=1.8.5=h2bb21d5_0 20 | hdf5=1.10.2=hc401514_2 21 | icu=58.2=hfc679d8_0 22 | intel-openmp=2019.0=118 23 | jasper=1.900.1=hff1ad4c_5 24 | jpeg=9c=h470a237_1 25 | libcxx=4.0.1=h579ed51_0 26 | libcxxabi=4.0.1=hebd6815_0 27 | libedit=3.1.20170329=hb402a30_2 28 | libffi=3.2.1=h475c297_4 29 | libgfortran=3.0.1=h93005f0_2 30 | libiconv=1.15=h470a237_3 31 | libopenblas=0.3.3=hdc02c5d_3 32 | libpng=1.6.34=he12f830_0 33 | libprotobuf=3.6.0=hd28b015_0 34 | libtiff=4.0.9=hcb84e12_2 35 | libwebp=0.5.2=7 36 | libxml2=2.9.8=h422b904_5 37 | mkl=2019.0=118 38 | mkl_fft=1.0.6=py36_0 39 | mkl_random=1.0.1=py36_0 40 | ncurses=6.1=h0a44026_0 41 | nettle=3.3=0 42 | ninja=1.8.2=py36h04f5b5a_1 43 | numpy=1.15.2=py36_blas_openblashd3ea46f_1 44 | numpy-base=1.15.2=py36ha711998_1 45 | olefile=0.46=py36_0 46 | openblas=0.2.20=8 47 | opencv=3.4.3=py36_blas_openblash553dce0_200 48 | openh264=1.7.0=0 49 | openssl=1.0.2p=h470a237_0 50 | pcre=8.41=hfc679d8_3 51 | pillow=5.2.0=py36h2dc6135_1 52 | pip=10.0.1=py36_0 53 | pixman=0.34.0=h470a237_3 54 | protobuf=3.6.0=py36hfc679d8_0 55 | pycparser=2.19=py36_0 56 | python=3.6.6=hc167b69_0 57 | pytorch=0.4.1=py36_cuda0.0_cudnn0.0_1 58 | pytorch-nightly-cpu=1.0.0.dev20181005=py3.6_0 59 | readline=7.0=h1de35cc_5 60 | setuptools=40.4.3=py36_0 61 | six=1.11.0=py36_1 62 | sqlite=3.25.2=ha441bb4_0 63 | tabulate=0.8.2=py36_0 64 | tensorboardx=1.4=py_0 65 | termcolor=1.1.0=py36_1 66 | tk=8.6.8=ha441bb4_0 67 | torchvision=0.2.1=py36_1 68 | wheel=0.32.0=py36_0 69 | 
x264=1!152.20180717=h470a237_1 70 | xz=5.2.4=h1de35cc_4 71 | zlib=1.2.11=hf3cbc9b_2 72 | -------------------------------------------------------------------------------- /decode_detection.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/amdegroot/ssd.pytorch/blob/master/layers/functions/detection.py 2 | # with minor modifications. 3 | 4 | import torch 5 | from box_utils import decode, nms 6 | 7 | 8 | class Detect: 9 | """At test time, Detect is the final layer of SSD. Decode location preds, 10 | apply non-maximum suppression to location predictions based on conf 11 | scores and threshold to a top_k number of output predictions for both 12 | confidence score and locations. 13 | """ 14 | def __init__(self, num_classes, bkg_label, top_k, nms_thresh, variance): 15 | self.num_classes = num_classes 16 | self.background_label = bkg_label 17 | self.top_k = top_k 18 | # Parameters used in nms. 19 | self.nms_thresh = nms_thresh 20 | if nms_thresh <= 0: 21 | raise ValueError('nms_threshold must be non negative.') 22 | self.variance = variance 23 | 24 | def forward(self, loc_data, conf_data, prior_data, conf_thresh): 25 | """ 26 | Args: 27 | loc_data: (tensor) Loc preds from loc layers 28 | Shape: [batch,num_priors*4] 29 | conf_data: (tensor) Shape: Conf preds from conf layers 30 | Shape: [batch*num_priors,num_classes] 31 | prior_data: (tensor) Prior boxes and variances from priorbox layers 32 | Shape: [1,num_priors,4] 33 | """ 34 | batch_size = loc_data.size(0) 35 | num_priors = prior_data.size(0) 36 | output = torch.zeros(batch_size, self.num_classes, self.top_k, 5) 37 | if loc_data.is_cuda: 38 | output = output.cuda() 39 | conf_preds = conf_data.transpose(2, 1) # group by classes 40 | 41 | # Decode predictions into bboxes. 42 | for i in range(batch_size): 43 | decoded_boxes = decode(loc_data[i], prior_data, self.variance) 44 | # For each class, perform nms 45 | conf_scores = conf_preds[i].clone() 46 | 47 | for cl in range(self.num_classes): 48 | c_mask = conf_scores[cl].gt(conf_thresh) 49 | scores = conf_scores[cl][c_mask] 50 | if scores.dim() == 0: 51 | continue 52 | l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes) 53 | thresholded_boxes = decoded_boxes[l_mask] 54 | if len(thresholded_boxes) > 0: 55 | boxes = thresholded_boxes.view(-1, 4) 56 | # idx of highest scoring and non-overlapping boxes per class 57 | ids, count = nms(boxes, scores, self.nms_thresh, self.top_k) 58 | output[i, cl, :count] = \ 59 | torch.cat((scores[ids[:count]].unsqueeze(1), 60 | boxes[ids[:count]]), 1) 61 | flt = output.contiguous().view(batch_size, -1, 5) 62 | _, idx = flt[:, :, 0].sort(1, descending=True) 63 | _, rank = idx.sort(1) 64 | flt[(rank < self.top_k).unsqueeze(-1).expand_as(flt)].fill_(0) 65 | return output 66 | -------------------------------------------------------------------------------- /debug_tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import itertools 4 | import numpy as np 5 | 6 | 7 | def dump_images( 8 | names, pil_images, annotations, detections, stats, 9 | labelmap, dir): 10 | """ 11 | Dumps images with bbox overlays to disk. 12 | 13 | :param names: batch of sample names 14 | :param pil_images: batch of original PIL images 15 | :param annotations: batch of annotations 16 | :param detections: batch of detections from NN 17 | :param stats: batch of debug info from a network. Keeps number of anchors that match particular GT box. 
18 | :param labelmap: names of classes 19 | :param dir: destination directory to save images 20 | :return: None 21 | """ 22 | 23 | det_color = (0, 255, 0) 24 | anno_color = (255, 0, 0) 25 | 26 | if annotations is None: annotations = [] 27 | if detections is None: detections = [] 28 | if stats is None: stats = [] 29 | 30 | try: 31 | for ib, (name, pil_img, anno, detection, stat) in \ 32 | enumerate(itertools.zip_longest(names, pil_images, annotations, detections, stats)): 33 | 34 | img = np.asarray(pil_img).copy() 35 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 36 | scale = [img.shape[1], img.shape[0], img.shape[1], img.shape[0]] 37 | if detection is not None: 38 | for icls, cls_det in enumerate(detection): 39 | for det in cls_det: 40 | conf = det[0] 41 | if conf > 0.0: 42 | bbox = det[1:] 43 | bbox_pix = bbox * scale 44 | type = labelmap[icls] 45 | cv2.rectangle( 46 | img, 47 | (int(bbox_pix[0]), int(bbox_pix[1])), 48 | (int(bbox_pix[2]), int(bbox_pix[3])), 49 | det_color, 1) 50 | cv2.putText( 51 | img, 52 | '{} {:.2f}'.format(type, conf), 53 | (int(bbox_pix[0]), int(bbox_pix[1])+10), 54 | cv2.FONT_HERSHEY_SIMPLEX, 55 | 0.4, 56 | det_color) 57 | 58 | if anno is not None and stat is not None: 59 | for obj, num_matches in zip(anno, stat): 60 | bbox = obj['bbox'] 61 | bbox_pix = bbox * scale 62 | cv2.rectangle( 63 | img, 64 | (int(bbox_pix[0]), int(bbox_pix[1])), 65 | (int(bbox_pix[2]), int(bbox_pix[3])), 66 | anno_color, 1) 67 | cv2.putText( 68 | img, 69 | obj['type'] + " M{}".format(num_matches), # M - number of matching anchors 70 | (int(bbox_pix[0]), int(bbox_pix[1])+10), 71 | cv2.FONT_HERSHEY_SIMPLEX, 72 | 0.4, 73 | anno_color) 74 | 75 | filename = name + '.png' 76 | cv2.imwrite(os.path.join(dir, filename), img) 77 | pass 78 | except Exception as e: 79 | pass 80 | 81 | pass 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SSD-inspired detection framework for Kitti 2 | 3 | [![BCH compliance](https://bettercodehub.com/edge/badge/Obs01ete/pytorch-detection?branch=master)](https://bettercodehub.com/) 4 | 5 | ## 2D object bounding box detection is solved using an implementation of the SSD detector. 6 | 7 | ### Dataset 8 | The whole Kitti dataset of 6000 images is split into two parts: 95% for training and 5% for validation. Random vertical and horizontal shift augmentation was applied to tackle the problem of cars appearing at the same positions within the frame. 9 | 10 | ### Neural network 11 | A working resolution of 256x512 is used as input to the neural network. ResNet-34 is used as the backbone, with VGG-like additional layers stacked on top of it. Three branches are taken from the backbone and another three from the additional layers. Branch configurations are as follows [channels, height, width]: 12 | - [128, 32, 64] 13 | - [256, 16, 32] 14 | - [512, 8, 16] 15 | - [512, 4, 8] 16 | - [256, 2, 4] 17 | - [256, 1, 2] 18 | 19 | ![](assets/netron.png) 20 | 21 | The ImageNet-pretrained model is taken from PyTorch’s examples; its weights were not frozen. 22 | Anchors are generated so that each cell has 6 anchors: 3 aspect ratios and 2 scales. 23 | The model is trained for 1000 epochs at batch size 32 and a base learning rate of 0.01 with a warm-up of 10 epochs. Online hard example mining (OHEM) with a 1:3 ratio is used to address foreground-background imbalance. Training took 18 hours on two 1080 Ti GPUs in multi-GPU mode.
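For illustration, here is a minimal sketch of the 6-anchor-per-cell layout described above (3 aspect ratios × 2 scales). The actual anchor generation lives in detection_models.py, which is not shown in this listing, so the specific ratio and scale values below are assumptions rather than the values used in training:

```python
import itertools
import math


def cell_anchors(base_scale, aspect_ratios=(0.5, 1.0, 2.0), scale_factors=(1.0, 1.26)):
    # Illustrative only: 3 aspect ratios x 2 extra scales = 6 (width, height)
    # anchor shapes per feature-map cell, expressed as fractions of the image size.
    anchors = []
    for ratio, factor in itertools.product(aspect_ratios, scale_factors):
        scale = base_scale * factor
        anchors.append((scale * math.sqrt(ratio), scale / math.sqrt(ratio)))
    return anchors


print(cell_anchors(0.1))  # 6 anchor shapes for one cell
```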
24 | The primary advantage of the provided framework and the PyTorch trainer class is their modular structure. For example, all shapes of the regression/classification branches are derived automatically from the input resolution and the backbone’s branch shapes. Also, calculation of the loss target from an annotation is encapsulated in the main class SingleShotDetector (file detection_models.py). The number of classes is set via the labelmap, and all the parameters are derived from this single source. I have tried to eliminate any possible code or parameter duplication, which is present in many open-source implementations of SSD. 25 | 26 | ### Training results 27 | The accuracy metric is implemented according to the Kitti specification: minimal IoU overlap of 0.7 for cars and 0.5 for other classes. Car average precision (AP) was chosen as the model’s target metric. The achieved AP for Car is 0.65, which is considerably lower than the state-of-the-art 91% on Kitti. 28 | 29 | 30 | |Class | AP | 31 | |------|-------| 32 | | Car | 0.654 | 33 | | Van | 0.677 | 34 | | Truck | 0.862 | 35 | | Pedestrian | 0.427 | 36 | | Person_sitting | 0.219 | 37 | | Cyclist | 0.509 | 38 | | Tram | 0.697 | 39 | | Misc | 0.583 | 40 | | mAP | 0.579 | 41 | 42 | The training loss curve shows that the training has almost converged. Longer training should yield about 5% higher AP. 43 | 44 | ![](assets/tensorboard.png) 45 | 46 | ### Testing 47 | A sample image with detection results is shown below. Blue boxes are ground truth, green boxes are detections. 48 | 49 | ![](assets/example.png) 50 | 51 | ### How to run 52 | 53 | ``` 54 | python detect2d.py 55 | ``` 56 | To generate detected bounding boxes, run 57 | ``` 58 | python detect2d.py --validate 59 | ``` 60 | Resulting images with overlays are saved to the `runs/resnet34_pretrained/detection_val_dump` folder. 61 | 62 | ## TODO list 63 | - [x] Faster-learning backbone to train from scratch 64 | - [ ] Add FPN 65 | - [x] Experiment management 66 | - [ ] Local configuration (.template) 67 | - [ ] Profile 68 | - [ ] Rewrite to albumentations or cv2 69 | - [x] Gradients and weights to tensorboard 70 | -------------------------------------------------------------------------------- /image_anno_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from PIL import Image, ImageOps 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def normalize_image(): 7 | """Zero-mean input image for faster training.""" 8 | return transforms.Normalize( 9 | mean=[0.5]*3, 10 | std=[0.25]*3) 11 | 12 | 13 | class RandomCropWithAnno: 14 | """ 15 | Crop the given PIL.Image at a random location along with annotations. 16 | Do not drop or crop annotations. 17 | 18 | Args: 19 | pad_val: percent of image width or height to pad 20 | """ 21 | 22 | def __init__(self, pad_val): 23 | self.pad_val = pad_val 24 | 25 | def __call__(self, img, anno): 26 | """ 27 | Args: 28 | img (PIL.Image): Image to be cropped. 29 | anno: corresponding annotation. 30 | 31 | Returns: 32 | PIL.Image: Cropped image + annotation.
33 | """ 34 | 35 | ow, oh = img.size 36 | 37 | border = tuple([int(ow*self.pad_val), int(oh*self.pad_val)] * 2) 38 | img_padded = ImageOps.expand(img, border=border, fill=0) 39 | 40 | w, h = img_padded.size 41 | 42 | x1 = random.randint(0, w - ow) 43 | y1 = random.randint(0, h - oh) 44 | img_out = img_padded.crop((x1, y1, x1 + ow, y1 + oh)) 45 | 46 | anno_offs_x = border[0] - x1 47 | anno_offs_y = border[1] - y1 48 | 49 | for obj in anno: 50 | bbox = obj['bbox'] 51 | bbox += [anno_offs_x, anno_offs_y] * 2 52 | 53 | return img_out, anno 54 | 55 | 56 | class RandomHorizontalFlipWithAnno: 57 | """Horizontally flip the given PIL.Image + annotation randomly with a probability of 0.5.""" 58 | 59 | def __call__(self, img, anno): 60 | """ 61 | Args: 62 | img (PIL.Image): Image to be flipped. 63 | anno: corresponding annotation. 64 | 65 | Returns: 66 | PIL.Image: Randomly flipped image + annotation. 67 | """ 68 | if random.random() < 0.5: 69 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 70 | width = img.width 71 | for obj in anno: 72 | bbox = obj['bbox'] 73 | new_left = width - bbox[2] 74 | new_right = width - bbox[0] 75 | bbox[0] = new_left 76 | bbox[2] = new_right 77 | return img, anno 78 | 79 | 80 | class ComposeVariadic: 81 | """ 82 | Composes several transforms together. Processes lists of image + annotation + whatever. 83 | 84 | Args: 85 | transforms (list of ``Transform`` objects): list of transforms to compose. 86 | """ 87 | 88 | def __init__(self, transforms): 89 | self.transforms = transforms 90 | 91 | def __call__(self, *args): 92 | for t in self.transforms: 93 | args = t(*args) 94 | return args 95 | 96 | 97 | class MapImageAndAnnoToInputWindow: 98 | """Transform to map any image + anno to network input format.""" 99 | 100 | def __init__(self, input_resolution): 101 | self.input_resolution = input_resolution 102 | self.transform = transforms.Compose([ 103 | transforms.Resize(input_resolution, interpolation=Image.BILINEAR), 104 | transforms.ToTensor(), 105 | normalize_image(), 106 | ]) 107 | 108 | def __call__(self, img, anno): 109 | img_out = self.transform(img) 110 | if anno is not None: 111 | anno_out = [] 112 | for obj in anno: 113 | obj_out = { 114 | 'type': obj['type'], 115 | 'bbox': obj['bbox'] / ([img.width, img.height] * 2) # left top right bottom 116 | } 117 | anno_out.append(obj_out) 118 | else: 119 | anno_out = None 120 | return img_out, anno_out 121 | 122 | -------------------------------------------------------------------------------- /average_precision.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def iou_point_np(box, boxes): 5 | """ 6 | Find intersection over union 7 | :param box: (tensor) One box [xmin, ymin, xmax, ymax], shape: [4]. 8 | :param boxes: (tensor) Shape:[N, 4]. 9 | :return: intersection over union. Shape: [N] 10 | """ 11 | 12 | A = np.maximum(box[:2], boxes[:, :2]) 13 | B = np.minimum(box[2:], boxes[:, 2:]) 14 | interArea = np.maximum(B[:, 0] - A[:, 0], 0) * np.maximum(B[:, 1] - A[:, 1], 0) 15 | boxArea = (box[2] - box[0]) * (box[3] - box[1]) 16 | 17 | boxesArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 18 | union = boxArea + boxesArea - interArea 19 | iou = interArea / union 20 | return iou 21 | 22 | 23 | class AveragePrecision: 24 | """Average precision calculation using sort-and-iterate algorithm (VOC12)""" 25 | 26 | def __init__(self, labelmap, iou_threshold_perclass): 27 | """ 28 | Ctor. 
29 | 30 | :param labelmap: list of strings - class names 31 | :param iou_threshold_perclass: intersection over union thresholds for each class 32 | """ 33 | 34 | self.labelmap = labelmap 35 | self.num_classes = len(labelmap) 36 | self.iou_threshold_perclass = iou_threshold_perclass 37 | self.annotation_list = [] 38 | self.detection_list = [] 39 | 40 | def add_batch(self, annotations, detections): 41 | """ 42 | Accumulate detection results and annotations from one batch. 43 | 44 | :param annotations: list [N] of list [C] of numpy arrays [Q, 4], where N - batch size, 45 | C - number of object classes (i.e. no including background), Q - quantity of annotated objects. 46 | Dimension of size 4 is decoded as a bbox in fractional left-top-right-bottom (LTRB) format. 47 | 48 | :param detections: list [N] of list [C] of numpy arrays [Q, 5], where N - batch size, 49 | C - number of object classes (i.e. no including background), Q - quantity of detected objects. 50 | Dimension of size 5 is decoded as [0] - confidence, [1:5] - bbox in fractional 51 | left-top-right-bottom (LTRB) format. 52 | """ 53 | 54 | self.annotation_list.extend(annotations) 55 | self.detection_list.extend(detections) 56 | 57 | def calculate_mAP(self): 58 | """Perform calculation of mAP and per-class APs""" 59 | 60 | AP_list = np.zeros((self.num_classes,), dtype=np.float64) 61 | 62 | for cls_idx in range(self.num_classes): 63 | 64 | true_positive_list = [] 65 | positive_list = [] 66 | conf_list = [] 67 | for det, anno in zip(self.detection_list, self.annotation_list): 68 | 69 | annotation = anno[cls_idx] 70 | prediction = det[cls_idx] 71 | iou_threshold = self.iou_threshold_perclass[cls_idx] 72 | 73 | if len(prediction) == 0: 74 | continue 75 | 76 | matched_gt = np.zeros((len(annotation),), dtype=np.int32) 77 | true_positives = np.zeros((len(prediction),), dtype=np.int32) 78 | 79 | predicted_confs = prediction[:, 0] 80 | predicted_boxes = prediction[:, 1:] 81 | 82 | for idx, true_bbox in enumerate(annotation): 83 | 84 | iou = iou_point_np(true_bbox, predicted_boxes) 85 | 86 | # find matching 87 | iou_max = np.max(iou) 88 | if iou_max > iou_threshold: 89 | matched_gt[idx] = 1 90 | true_positives[np.argmax(iou)] = 1 91 | 92 | true_positive_list.append(true_positives) 93 | positive_list.append(len(annotation)) 94 | conf_list.append(predicted_confs) 95 | 96 | # end loop over images 97 | 98 | true_positive = np.concatenate(true_positive_list, axis=0) 99 | positive = np.array(positive_list, dtype=np.int).sum() 100 | conf = np.concatenate(conf_list, axis=0) 101 | 102 | idx_sort = np.argsort(-conf) 103 | fn = 1 - true_positive[idx_sort] 104 | true_positive = np.cumsum(true_positive[idx_sort]) 105 | false_negative = np.cumsum(fn) 106 | 107 | precision = true_positive / (true_positive + false_negative + 1e-4) 108 | recall = true_positive / (positive + 1e-4) 109 | AP_val = np.sum((recall[1:] - recall[:-1]) * precision[1:]) 110 | AP_list[cls_idx] = AP_val 111 | 112 | pass 113 | 114 | # end for cls_idx 115 | 116 | mAP = float(AP_list.mean()) 117 | 118 | return mAP, AP_list 119 | -------------------------------------------------------------------------------- /requirements_full.txt: -------------------------------------------------------------------------------- 1 | # packages in environment at /Users/dmitry/anaconda3/envs/pytorch_10: 2 | # 3 | # Name Version Build Channel 4 | blas 1.1 openblas conda-forge 5 | bzip2 1.0.6 1 conda-forge 6 | ca-certificates 2018.8.24 ha4d7672_0 conda-forge 7 | cairo 1.14.12 he56eebe_3 conda-forge 8 | certifi 
2018.8.24 py36_1001 conda-forge 9 | cffi 1.11.5 py36h6174b99_1 10 | ffmpeg 4.0.2 ha6a6e2b_0 conda-forge 11 | fontconfig 2.13.0 h8c010e7_5 conda-forge 12 | freetype 2.8.1 hfa320df_1 conda-forge 13 | gettext 0.19.8.1 h1f1d5ed_1 conda-forge 14 | giflib 5.1.4 h470a237_1 conda-forge 15 | glib 2.55.0 h464dc38_2 conda-forge 16 | gmp 6.1.2 hfc679d8_0 conda-forge 17 | gnutls 3.5.19 h2a4e5f8_1 conda-forge 18 | graphite2 1.3.12 h7d4d677_1 conda-forge 19 | harfbuzz 1.8.5 h2bb21d5_0 conda-forge 20 | hdf5 1.10.2 hc401514_2 conda-forge 21 | icu 58.2 hfc679d8_0 conda-forge 22 | intel-openmp 2019.0 118 23 | jasper 1.900.1 hff1ad4c_5 conda-forge 24 | jpeg 9c h470a237_1 conda-forge 25 | libcxx 4.0.1 h579ed51_0 26 | libcxxabi 4.0.1 hebd6815_0 27 | libedit 3.1.20170329 hb402a30_2 28 | libffi 3.2.1 h475c297_4 29 | libgfortran 3.0.1 h93005f0_2 30 | libiconv 1.15 h470a237_3 conda-forge 31 | libopenblas 0.3.3 hdc02c5d_3 32 | libpng 1.6.34 he12f830_0 33 | libprotobuf 3.6.0 hd28b015_0 conda-forge 34 | libtiff 4.0.9 hcb84e12_2 35 | libwebp 0.5.2 7 conda-forge 36 | libxml2 2.9.8 h422b904_5 conda-forge 37 | mkl 2019.0 118 38 | mkl_fft 1.0.6 py36_0 conda-forge 39 | mkl_random 1.0.1 py36_0 conda-forge 40 | ncurses 6.1 h0a44026_0 41 | nettle 3.3 0 conda-forge 42 | ninja 1.8.2 py36h04f5b5a_1 43 | numpy 1.15.2 py36_blas_openblashd3ea46f_1 [blas_openblas] conda-forge 44 | numpy-base 1.15.2 py36ha711998_1 45 | olefile 0.46 py36_0 46 | openblas 0.2.20 8 conda-forge 47 | opencv 3.4.3 py36_blas_openblash553dce0_200 [blas_openblas] conda-forge 48 | openh264 1.7.0 0 conda-forge 49 | openssl 1.0.2p h470a237_0 conda-forge 50 | pcre 8.41 hfc679d8_3 conda-forge 51 | pillow 5.2.0 py36h2dc6135_1 conda-forge 52 | pip 10.0.1 py36_0 53 | pixman 0.34.0 h470a237_3 conda-forge 54 | protobuf 3.6.0 py36hfc679d8_0 conda-forge 55 | pycparser 2.19 py36_0 56 | python 3.6.6 hc167b69_0 57 | pytorch 0.4.1 py36_cuda0.0_cudnn0.0_1 pytorch 58 | pytorch-nightly-cpu 1.0.0.dev20181005 py3.6_0 pytorch 59 | readline 7.0 h1de35cc_5 60 | setuptools 40.4.3 py36_0 61 | six 1.11.0 py36_1 62 | sqlite 3.25.2 ha441bb4_0 63 | tabulate 0.8.2 py36_0 64 | tensorboardx 1.4 py_0 conda-forge 65 | termcolor 1.1.0 py36_1 66 | tk 8.6.8 ha441bb4_0 67 | torchvision 0.2.1 py36_1 pytorch 68 | wheel 0.32.0 py36_0 69 | x264 1!152.20180717 h470a237_1 conda-forge 70 | xz 5.2.4 h1de35cc_4 71 | zlib 1.2.11 hf3cbc9b_2 72 | -------------------------------------------------------------------------------- /data_writer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """Provides 'Writer', which which writes predictions to file.""" 4 | 5 | from __future__ import absolute_import, print_function 6 | import os 7 | import numpy as np 8 | 9 | 10 | class Writer: 11 | """Write results into same format as label .txt files.""" 12 | 13 | def __init__(self, data_path): 14 | """ 15 | Set the folder into which prediction file will be written to. 16 | """ 17 | if not os.path.exists(data_path): 18 | print("Data path %s created." % data_path) 19 | os.makedirs(data_path) 20 | self.data_path = data_path 21 | self._defaults = {'type': 'DontCare', 22 | 'truncated': np.nan, 23 | 'occluded': 3, 24 | 'alpha': np.nan, 25 | 'bbox': np.array([np.nan, np.nan, np.nan, np.nan]), 26 | 'dimensions': np.array([np.nan, np.nan, np.nan]), 27 | 'location': np.array([np.nan, np.nan, np.nan]), 28 | 'rotation_y': np.nan} 29 | 30 | def write(self, filename, labels): 31 | """ 32 | Function to write labels to file provided by filename (i.e. 
'000000.txt') 33 | labels, just like in Parser, is a list of dictionaries with the keys below. 34 | N.B. You need not provide all the keys! For example, if your task is to do 2D 35 | bounding box detection, you can simply add your predicted 'type' and 'bbox' to dict, 36 | and ignore the rest. These will be padded with 'nan'. 37 | 38 | #Values Key Description 39 | ---------------------------------------------------------------------------- 40 | 1 type Describes the type of object: 'Car', 'Van', 'Truck', 41 | 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 42 | 'Misc' or 'DontCare' 43 | 44 | 1 truncated Float from 0 (non-truncated) to 1 (truncated), where 45 | truncated refers to the object leaving image boundaries 46 | 47 | 1 occluded Integer (0,1,2,3) indicating occlusion state: 48 | 0 = fully visible, 1 = partly occluded 49 | 2 = largely occluded, 3 = unknown 50 | 51 | 1 alpha Observation angle of object, ranging [-pi..pi] 52 | 53 | 4 bbox 2D bounding box of object in the image (0-based index): 54 | contains left, top, right, bottom pixel coordinates 55 | 56 | 3 dimensions 3D object dimensions: height, width, length (in meters) 57 | 58 | 3 location 3D object location x,y,z in camera coordinates (in meters) 59 | 60 | 1 rotation_y Rotation ry around Y-axis in camera coordinates [-pi..pi] 61 | """ 62 | assert isinstance(filename, str) 63 | assert isinstance(labels, list) 64 | for l in labels: 65 | assert isinstance(l, dict) 66 | keys = ['type', 'truncated', 'occluded', 'alpha', 67 | 'bbox', 'dimensions', 'location', 'rotation_y'] 68 | 69 | with open(os.path.join(self.data_path, filename), 'w+') as f: 70 | for obj in labels: 71 | out = [] 72 | for key in keys: 73 | if key in obj: 74 | self._checkvalidity(key, obj[key]) 75 | out.append(self._tostring(obj[key])) 76 | else: 77 | out.append(self._tostring(self._getdefault(key))) 78 | line = ' '.join(out) + '\n' 79 | f.write(line) 80 | 81 | def _tostring(self, value): 82 | if isinstance(value, str): 83 | return value 84 | else: 85 | try: 86 | return ' '.join([str(x) for x in value]) 87 | except TypeError: 88 | return str(value) 89 | 90 | def _checkvalidity(self, key, value): 91 | if key == 'type': 92 | assert value in {'Car', 'Van', 'Truck', 'Pedestrian', 93 | 'Person_sitting', 'Cyclist', 'Tram', 'Misc', 'DontCare'} 94 | elif key == 'truncated': 95 | assert isinstance(value, float) 96 | assert value <= 1.0 97 | assert value >= 0.0 98 | elif key == 'occluded': 99 | assert isinstance(value, int) 100 | assert value in {0, 1, 2, 3} 101 | elif key == 'alpha': 102 | assert isinstance(value, float) 103 | assert value <= np.pi 104 | assert value >= -np.pi 105 | elif key == 'bbox': 106 | assert isinstance(value, np.ndarray) 107 | assert np.all(value >= 0) 108 | elif key == 'dimensions': 109 | assert isinstance(value, np.ndarray) 110 | assert np.all(value >= 0) 111 | elif key == 'location': 112 | assert isinstance(value, np.ndarray) 113 | assert np.all(value >= 0) 114 | elif key == 'rotation_y': 115 | assert isinstance(value, float) 116 | assert value <= np.pi 117 | assert value >= -np.pi 118 | else: 119 | raise IndexError 120 | 121 | def _getdefault(self, key): 122 | return self._defaults[key] 123 | -------------------------------------------------------------------------------- /name_list_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pickle 4 | 5 | import torch.utils.data as data 6 | 7 | import kitti_randomaccess 8 | from helpers import * 9 | 10 | 11 | class 
NameListDataset(data.Dataset): 12 | """Class to load the custom dataset with PyTorch's DataLoader""" 13 | 14 | def __init__( 15 | self, 16 | dataset_list, 17 | image_transform, 18 | image_and_anno_transform, 19 | map_to_network_input, 20 | build_target 21 | ): 22 | """ 23 | Ctor. 24 | 25 | :param dataset_list: list of strings 26 | :param image_transform: transformer for images only (anno not altered) 27 | :param image_and_anno_transform: transformer that alters anno as well 28 | :param map_to_network_input: transformer to input tensor 29 | :param build_target: functor to build target from anno 30 | """ 31 | 32 | self.dataset_list = dataset_list 33 | self.image_transform = image_transform 34 | self.image_and_anno_transform = image_and_anno_transform 35 | self.map_to_network_input = map_to_network_input 36 | self.build_target = build_target 37 | 38 | self._is_pil_image = True 39 | self.data_path = self.get_data_path() 40 | self.image_path = self.get_image_path() 41 | self.velo_path = os.path.join(self.data_path, 'velodyne') 42 | self.calib_path = os.path.join(self.data_path, 'calib') 43 | self.label_path = os.path.join(self.data_path, 'label_2') 44 | 45 | pass 46 | 47 | @staticmethod 48 | def get_data_path(): 49 | return 'kitti/training/' 50 | 51 | @staticmethod 52 | def get_image_path(): 53 | return os.path.join(NameListDataset.get_data_path(), 'image_2') 54 | 55 | @staticmethod 56 | def list_all_images(): 57 | """Scan over all samples in the dataset""" 58 | 59 | print('Start generation of a file list') 60 | 61 | names = [] 62 | for root, _, fnames in sorted(os.walk(NameListDataset.get_image_path())): 63 | for fname in sorted(fnames): 64 | if is_image_file(fname): 65 | # path = os.path.join(root, fname) 66 | nameonly = os.path.splitext(fname)[0] 67 | names.append(nameonly) 68 | 69 | print('End generation of a file list') 70 | 71 | return names 72 | 73 | @staticmethod 74 | def train_val_split(image_list, train_val_split_dir, fraction_for_val=0.05): 75 | """Prepare file lists for training and validation.""" 76 | 77 | train_num = int(len(image_list) * (1.0 - fraction_for_val)) 78 | train_list = image_list[:train_num] 79 | val_list = image_list[train_num:] 80 | 81 | def save_object(name, obj): 82 | path = os.path.join(train_val_split_dir, name + '.pkl') 83 | with open(path, 'wb') as output: 84 | pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL) 85 | 86 | save_object('train_list', train_list) 87 | save_object('val_list', val_list) 88 | 89 | pass 90 | 91 | @staticmethod 92 | def getLabelmap(): 93 | return ['Car', 'Van', 'Truck', 'Pedestrian', 'Person_sitting', 'Cyclist', 'Tram', 'Misc'] 94 | 95 | @staticmethod 96 | def leave_required_fields(anno): 97 | required_fields = ['type', 'bbox'] 98 | anno_out = [] 99 | for obj in anno: 100 | if obj['type'] != 'DontCare': 101 | obj_out = {} 102 | for f in obj.items(): 103 | if f[0] in required_fields: 104 | obj_out[f[0]] = f[1] 105 | anno_out.append(obj_out) 106 | return anno_out 107 | 108 | def __getitem__(self, index): 109 | """ 110 | Args: 111 | index (int): Index of a sample 112 | 113 | Returns: 114 | input_tensor: tensor to feed into neural network 115 | built_target: target tuple of tensors for loss calculation 116 | name: string name of the sample 117 | image: PIL image (to render overlays) 118 | anno: annotation prior to encoding 119 | stats: debug information (number of anchor overlaps for every GT box) 120 | """ 121 | 122 | # t0 = time.time() 123 | 124 | name = self.dataset_list[index] 125 | image, velo, calib, anno = self._getitem(name) 126 | 
127 | anno = self.leave_required_fields(anno) 128 | 129 | if self.image_transform is not None: 130 | image = self.image_transform(image) 131 | 132 | if self.image_and_anno_transform is not None: 133 | image, anno = self.image_and_anno_transform(image, anno) 134 | 135 | # print("image_and_anno_transform=", time.time()-t0) 136 | 137 | # t1 = time.time() 138 | input_tensor, anno = self.map_to_network_input(image, anno) 139 | # print("map_to_network_input=", time.time()-t1) 140 | 141 | # t2 = time.time() 142 | built_target, stats = self.build_target(anno) 143 | # print("build_target=", time.time()-t2) 144 | 145 | return input_tensor, built_target, name, image, anno, stats 146 | 147 | def _getitem(self, name, load_image=True, load_velodyne=False, load_calib=True, load_label=True): 148 | image = None 149 | if load_image: 150 | path = os.path.join(self.image_path, name+'.png') 151 | if self._is_pil_image: 152 | image = kitti_randomaccess.get_image_pil(path) 153 | else: 154 | image = kitti_randomaccess.get_image(path) 155 | 156 | velo = None 157 | if load_velodyne: 158 | path = os.path.join(self.velo_path, name+'.bin') 159 | velo = kitti_randomaccess.get_velo_scan(path) 160 | 161 | calib = None 162 | if load_calib: 163 | path = os.path.join(self.calib_path, name+'.txt') 164 | calib = kitti_randomaccess.get_calib(path) 165 | 166 | label = None 167 | if load_label: 168 | path = os.path.join(self.label_path, name+'.txt') 169 | label = kitti_randomaccess.get_label(path) 170 | 171 | return image, velo, calib, label 172 | 173 | def __len__(self): 174 | """ 175 | Args: 176 | none 177 | 178 | Returns: 179 | int: number of images in the dataset 180 | """ 181 | 182 | return len(self.dataset_list) 183 | 184 | 185 | -------------------------------------------------------------------------------- /imagenet_models.py: -------------------------------------------------------------------------------- 1 | # Borrowed and modified torchvision/models/resnet.py 2 | # See comments starting from "Dmitrii Khizbullin:" 3 | 4 | import torch.nn as nn 5 | import math 6 | import torch.utils.model_zoo as model_zoo 7 | from torch.nn.parameter import Parameter 8 | 9 | 10 | __all__ = ['resnet18_backbone', 'resnet34_backbone', 'resnet50_backbone'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | "3x3 convolution with padding" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | 
expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNetBackbone(nn.Module): 98 | 99 | def __init__(self, block, layers, channel_config=(1, 2, 4, 8), channel_multiplier=64): 100 | assert len(channel_config) == 4 101 | self.inplanes = channel_multiplier 102 | super().__init__() 103 | self.conv1 = nn.Conv2d(3, channel_multiplier, kernel_size=7, stride=2, padding=3, 104 | bias=False) 105 | self.bn1 = nn.BatchNorm2d(channel_multiplier) 106 | self.relu = nn.ReLU(inplace=True) 107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 108 | self.layer1 = self._make_layer(block, channel_multiplier*channel_config[0], layers[0]) 109 | self.layer2 = self._make_layer(block, channel_multiplier*channel_config[1], layers[1], stride=2) 110 | self.layer3 = self._make_layer(block, channel_multiplier*channel_config[2], layers[2], stride=2) 111 | self.layer4 = self._make_layer(block, channel_multiplier*channel_config[3], layers[3], stride=2) 112 | # Dmitrii Khizbullin: remove classification head 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | nn.Conv2d(self.inplanes, planes * block.expansion, 119 | kernel_size=1, stride=stride, bias=False), 120 | nn.BatchNorm2d(planes * block.expansion), 121 | ) 122 | 123 | layers = [] 124 | layers.append(block(self.inplanes, planes, stride, downsample)) 125 | self.inplanes = planes * block.expansion 126 | for i in range(1, blocks): 127 | layers.append(block(self.inplanes, planes)) 128 | 129 | return nn.Sequential(*layers) 130 | 131 | def forward(self, x): 132 | x = self.conv1(x) 133 | x = self.bn1(x) 134 | x = self.relu(x) 135 | x = self.maxpool(x) 136 | 137 | branches = [] 138 | x = self.layer1(x) 139 | branches.append(x) 140 | x = self.layer2(x) 141 | branches.append(x) 142 | x = self.layer3(x) 143 | branches.append(x) 144 | x = self.layer4(x) 145 | branches.append(x) 146 | 147 | return branches 148 | 149 | 150 | def load_state_dict(self, state_dict): 151 | """Copies parameters and buffers from :attr:`state_dict` into 152 | this module and its descendants. The keys of :attr:`state_dict` must 153 | exactly match the keys returned by this module's :func:`state_dict()` 154 | function. 155 | 156 | Arguments: 157 | state_dict (dict): A dict containing parameters and 158 | persistent buffers. 
159 | """ 160 | own_state = self.state_dict() 161 | for name, param in state_dict.items(): 162 | if name not in own_state: 163 | raise KeyError('unexpected key "{}" in state_dict' 164 | .format(name)) 165 | if isinstance(param, Parameter): 166 | # backwards compatibility for serialized parameters 167 | param = param.data 168 | try: 169 | own_state[name].copy_(param) 170 | except: 171 | # Dmitrii Khizbullin: skip weights if they cannot be loaded 172 | 173 | #print('While copying the parameter named {}, whose dimensions in the model are' 174 | # ' {} and whose dimensions in the checkpoint are {}, ...'.format( 175 | # name, own_state[name].size(), param.size())) 176 | #raise CustomException 177 | pass 178 | 179 | missing = set(own_state.keys()) - set(state_dict.keys()) 180 | if len(missing) > 0: 181 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 182 | 183 | 184 | def resnet18_backbone(pretrained=False, **kwargs): 185 | """Constructs a ResNet-18 backbone. 186 | 187 | Args: 188 | pretrained (bool): If True, returns a model pre-trained on ImageNet 189 | """ 190 | model = ResNetBackbone(BasicBlock, [2, 2, 2, 2], **kwargs) 191 | if pretrained: 192 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 193 | return model 194 | 195 | 196 | def resnet34_backbone(pretrained=False, **kwargs): 197 | """Constructs a ResNet-34 backbone. 198 | 199 | Args: 200 | pretrained (bool): If True, returns a model pre-trained on ImageNet 201 | """ 202 | model = ResNetBackbone(BasicBlock, [3, 4, 6, 3], **kwargs) 203 | if pretrained: 204 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 205 | return model 206 | 207 | 208 | def resnet50_backbone(pretrained=False, **kwargs): 209 | """Constructs a ResNet-50 backbone. 210 | 211 | Args: 212 | pretrained (bool): If True, returns a model pre-trained on ImageNet 213 | """ 214 | model = ResNetBackbone(Bottleneck, [3, 4, 6, 3], **kwargs) 215 | if pretrained: 216 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 217 | return model 218 | 219 | 220 | -------------------------------------------------------------------------------- /box_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # From https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py 4 | # with minor modifications. 5 | 6 | import torch 7 | 8 | 9 | def point_form(boxes): 10 | """ Convert prior_boxes to (xmin, ymin, xmax, ymax) 11 | representation for comparison to point form ground truth data. 12 | Args: 13 | boxes: (tensor) center-size default boxes from priorbox layers. 14 | Return: 15 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 16 | """ 17 | return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin 18 | boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax 19 | 20 | 21 | def center_size(boxes): 22 | """ Convert prior_boxes to (cx, cy, w, h) 23 | representation for comparison to center-size form ground truth data. 24 | Args: 25 | boxes: (tensor) point_form boxes 26 | Return: 27 | boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. 28 | """ 29 | return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy 30 | boxes[:, 2:] - boxes[:, :2], 1) # w, h 31 | 32 | 33 | def intersect(box_a, box_b): 34 | """ We resize both tensors to [A,B,2] without new malloc: 35 | [A,2] -> [A,1,2] -> [A,B,2] 36 | [B,2] -> [1,B,2] -> [A,B,2] 37 | Then we compute the area of intersect between box_a and box_b. 
38 | Args: 39 | box_a: (tensor) bounding boxes, Shape: [A,4]. 40 | box_b: (tensor) bounding boxes, Shape: [B,4]. 41 | Return: 42 | (tensor) intersection area, Shape: [A,B]. 43 | """ 44 | A = box_a.size(0) 45 | B = box_b.size(0) 46 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 47 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 48 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 49 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 50 | inter = torch.clamp((max_xy - min_xy), min=0) 51 | return inter[:, :, 0] * inter[:, :, 1] 52 | 53 | 54 | def jaccard(box_a, box_b): 55 | """Compute the jaccard overlap of two sets of boxes. The jaccard overlap 56 | is simply the intersection over union of two boxes. Here we operate on 57 | ground truth boxes and default boxes. 58 | E.g.: 59 | A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) 60 | Args: 61 | box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] 62 | box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] 63 | Return: 64 | jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] 65 | """ 66 | inter = intersect(box_a, box_b) 67 | area_a = ((box_a[:, 2]-box_a[:, 0]) * 68 | (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] 69 | area_b = ((box_b[:, 2]-box_b[:, 0]) * 70 | (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] 71 | union = area_a + area_b - inter 72 | return inter / union # [A,B] 73 | 74 | 75 | def match(threshold, truths, priors, variances, labels): 76 | """Match each prior box with the ground truth box of the highest jaccard 77 | overlap, encode the bounding boxes, then return the matched indices 78 | corresponding to both confidence and location preds. 79 | Args: 80 | threshold: (float) The overlap threshold used when matching boxes. 81 | truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. 82 | priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. 83 | variances: (list[float]) Variances of priorboxes used to encode 84 | the targets, e.g. (0.1, 0.2) for (cxcy, wh). 85 | labels: (tensor) All the class labels for the image, Shape: [num_obj]. 86 | Return: 87 | loc: (tensor) Tensor of encoded location targets. 88 | cls: (tensor) Tensor of matched indices for conf preds.
89 | num_matches: (tensor) Tensor of quantities of matched anchors (for debug) 90 | """ 91 | # jaccard index 92 | overlaps = jaccard( 93 | truths, 94 | point_form(priors) 95 | ) 96 | # (Bipartite Matching) 97 | # [1,num_objects] best prior for each ground truth 98 | best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) 99 | # [1,num_priors] best ground truth for each prior 100 | best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) 101 | best_truth_idx.squeeze_(0) 102 | best_truth_overlap.squeeze_(0) 103 | best_prior_idx.squeeze_(1) 104 | best_prior_overlap.squeeze_(1) 105 | best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior 106 | # TODO refactor: index best_prior_idx with long tensor 107 | # ensure every gt matches with its prior of max overlap 108 | for j in range(best_prior_idx.size(0)): 109 | best_truth_idx[best_prior_idx[j]] = j 110 | matches = truths[best_truth_idx] # Shape: [num_priors,4] 111 | cls = labels[best_truth_idx] + 1 # Shape: [num_priors] 112 | cls[best_truth_overlap < threshold] = 0 # label as background 113 | loc = encode(matches, priors, variances) 114 | 115 | # Dmitrii Khizbullin: return numbers of matches as well 116 | matches = overlaps > threshold 117 | num_matches = matches.long().sum(dim=1) 118 | 119 | return loc, cls, num_matches 120 | 121 | 122 | def encode(matched, priors, variances): 123 | """Encode the variances from the priorbox layers into the ground truth boxes 124 | we have matched (based on jaccard overlap) with the prior boxes. 125 | Args: 126 | matched: (tensor) Coords of ground truth for each prior in point-form 127 | Shape: [num_priors, 4]. 128 | priors: (tensor) Prior boxes in center-offset form 129 | Shape: [num_priors,4]. 130 | variances: (list[float]) Variances of priorboxes 131 | Return: 132 | encoded boxes (tensor), Shape: [num_priors, 4] 133 | """ 134 | 135 | # dist b/t match center and prior's center 136 | g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] 137 | # encode variance 138 | g_cxcy /= (variances[0] * priors[:, 2:]) 139 | # match wh / prior wh 140 | g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] 141 | g_wh = torch.log(g_wh) / variances[1] 142 | # return target for smooth_l1_loss 143 | return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] 144 | 145 | 146 | # Adapted from https://github.com/Hakuyume/chainer-ssd 147 | def decode(loc, priors, variances): 148 | """Decode locations from predictions using priors to undo 149 | the encoding we did for offset regression at train time. 150 | Args: 151 | loc (tensor): location predictions for loc layers, 152 | Shape: [num_priors,4] 153 | priors (tensor): Prior boxes in center-offset form. 154 | Shape: [num_priors,4]. 155 | variances: (list[float]) Variances of priorboxes 156 | Return: 157 | decoded bounding box predictions 158 | """ 159 | 160 | boxes = torch.cat(( 161 | priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], 162 | priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) 163 | boxes[:, :2] -= boxes[:, 2:] / 2 164 | boxes[:, 2:] += boxes[:, :2] 165 | return boxes 166 | 167 | 168 | def log_sum_exp(x): 169 | """Utility function for computing log_sum_exp while determining 170 | This will be used to determine unaveraged confidence loss across 171 | all examples in a batch. 
172 | Args: 173 | x (Variable(tensor)): conf_preds from conf layers 174 | """ 175 | x_max = x.data.max() 176 | return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max 177 | 178 | 179 | # Original author: Francisco Massa: 180 | # https://github.com/fmassa/object-detection.torch 181 | # Ported to PyTorch by Max deGroot (02/01/2017) 182 | def nms(boxes, scores, overlap=0.5, top_k=200): 183 | """Apply non-maximum suppression at test time to avoid detecting too many 184 | overlapping bounding boxes for a given object. 185 | Args: 186 | boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. 187 | scores: (tensor) The class predscores for the img, Shape:[num_priors]. 188 | overlap: (float) The overlap thresh for suppressing unnecessary boxes. 189 | top_k: (int) The Maximum number of box preds to consider. 190 | Return: 191 | The indices of the kept boxes with respect to num_priors. 192 | """ 193 | 194 | keep = scores.new(scores.size(0)).zero_().long() 195 | if boxes.numel() == 0: 196 | return keep, 0 197 | x1 = boxes[:, 0] 198 | y1 = boxes[:, 1] 199 | x2 = boxes[:, 2] 200 | y2 = boxes[:, 3] 201 | area = torch.mul(x2 - x1, y2 - y1) 202 | v, idx = scores.sort(0) # sort in ascending order 203 | # I = I[v >= 0.01] 204 | idx = idx[-top_k:] # indices of the top-k largest vals 205 | xx1 = boxes.new() 206 | yy1 = boxes.new() 207 | xx2 = boxes.new() 208 | yy2 = boxes.new() 209 | w = boxes.new() 210 | h = boxes.new() 211 | 212 | # keep = torch.Tensor() 213 | count = 0 214 | while idx.numel() > 0: 215 | i = idx[-1] # index of current largest val 216 | # keep.append(i) 217 | keep[count] = i 218 | count += 1 219 | if idx.size(0) == 1: 220 | break 221 | idx = idx[:-1] # remove kept element from view 222 | # load bboxes of next highest vals 223 | torch.index_select(x1, 0, idx, out=xx1) 224 | torch.index_select(y1, 0, idx, out=yy1) 225 | torch.index_select(x2, 0, idx, out=xx2) 226 | torch.index_select(y2, 0, idx, out=yy2) 227 | # store element-wise max with next highest score 228 | xx1 = torch.clamp(xx1, min=x1[i]) 229 | yy1 = torch.clamp(yy1, min=y1[i]) 230 | xx2 = torch.clamp(xx2, max=x2[i]) 231 | yy2 = torch.clamp(yy2, max=y2[i]) 232 | w.resize_as_(xx2) 233 | h.resize_as_(yy2) 234 | w = xx2 - xx1 235 | h = yy2 - yy1 236 | # check sizes of xx1 and xx2.. 
after each iteration 237 | w = torch.clamp(w, min=0.0) 238 | h = torch.clamp(h, min=0.0) 239 | inter = w*h 240 | # IoU = i / (area(a) + area(b) - i) 241 | rem_areas = torch.index_select(area, 0, idx) # load remaining areas) 242 | union = (rem_areas - inter) + area[i] 243 | IoU = inter/union # store result in iou 244 | # keep only elements with an IoU <= overlap 245 | idx = idx[IoU.le(overlap)] 246 | return keep, count 247 | -------------------------------------------------------------------------------- /detection_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import importlib 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | # import imagenet_models as models 12 | import box_utils 13 | import decode_detection 14 | 15 | 16 | class ConvLayer(nn.Module): 17 | """Basic convolution block of ResNet.""" 18 | def __init__(self, in_planes, out_planes, kernel_size, padding, stride): 19 | super().__init__() 20 | self.conv = nn.Conv2d( 21 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 22 | padding=padding, bias=False) 23 | self.bn = nn.BatchNorm2d(out_planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x): 27 | out = self.conv(x) 28 | out = self.bn(out) 29 | out = self.relu(out) 30 | return out 31 | 32 | 33 | class ExtraBlock(nn.Module): 34 | """Basic block of extra layers.""" 35 | 36 | def __init__(self, in_channels, out_channels): 37 | super().__init__() 38 | bottleneck = out_channels // 2 39 | self.conv1 = ConvLayer(in_channels, bottleneck, 1, 0, 1) 40 | self.conv2 = ConvLayer(bottleneck, out_channels, 3, 1, 2) 41 | 42 | def forward(self, x): 43 | out = self.conv1(x) 44 | out = self.conv2(out) 45 | return out 46 | 47 | 48 | class ExtraLayers(nn.Module): 49 | """Extra layers in VGG style to finish feature pyramid.""" 50 | 51 | def __init__(self, cfg, in_channels): 52 | super().__init__() 53 | 54 | layers = [] 55 | for k, out_channels in enumerate(cfg): 56 | extra_block = ExtraBlock(in_channels, out_channels) 57 | layers.append(extra_block) 58 | in_channels = out_channels 59 | 60 | self.blocks = nn.Sequential(*layers) 61 | 62 | def forward(self, x): 63 | branches = [] 64 | for block in self.blocks: 65 | x = block(x) 66 | branches.append(x) 67 | return branches 68 | 69 | 70 | class MultiboxLayers(nn.Module): 71 | """Detection head of SSD. 
It is customized for regular anchors across pyramid's scales.""" 72 | 73 | def __init__(self, final_branch_channels, labelmap, use_ohem=False): 74 | super().__init__() 75 | 76 | self.final_branch_channels = final_branch_channels 77 | self.labelmap = labelmap 78 | self.num_classes = len(self.labelmap) 79 | self.use_ohem = use_ohem 80 | 81 | sizes = [math.pow(2.0, 0.25), math.pow(2.0, 0.75)] 82 | self.iou_anchor_and_gt = 0.3 83 | 84 | aspect_ratios = [1.0, 1/2.0, 2.0] # AR=width/height 85 | 86 | anchor_list = [] 87 | for size in sizes: 88 | for aspect_ratio in aspect_ratios: 89 | anchor_list.append(np.array([size * aspect_ratio, size / aspect_ratio], dtype=np.float32)) 90 | self.anchors = np.stack(anchor_list, axis=0) 91 | 92 | self.variances = (0.1, 0.2) # (cxcy loc scale, wh loc scale) 93 | 94 | self.num_anchors_per_cell = 6 95 | self.bbox_regresion_size = 4 96 | self.detection_size = self.bbox_regresion_size + self.num_classes + 1 97 | self.cell_num_channels = self.num_anchors_per_cell*self.detection_size 98 | 99 | for i in range(len(self.final_branch_channels)): 100 | name = "multibox_branch_{}".format(i) 101 | conv = nn.Conv2d(final_branch_channels[i], self.cell_num_channels, kernel_size=3, padding=1) 102 | self.add_module(name, conv) 103 | 104 | self.detect_functor = decode_detection.Detect(self.num_classes, 0, 200, 0.45, self.variances) 105 | 106 | self.branch_resolutions = None 107 | 108 | pass 109 | 110 | def forward(self, tensor_list, is_probe=False): 111 | """ 112 | Forward. 113 | :param tensor_list: list of tensors (feature maps) to which to apply loc & cls regression 114 | :param is_probe: if the pass is a test one to generate anchors. Must be run once. 115 | :return: joint tensor of all encoded detections 116 | """ 117 | 118 | encoded_branches = [] 119 | for i, in_tensor in zip(range(len(self.final_branch_channels)), tensor_list): 120 | name = "multibox_branch_{}".format(i) 121 | conv = self.__getattr__(name) 122 | encoded_tensor = conv(in_tensor) 123 | encoded_branches.append(encoded_tensor) 124 | 125 | if is_probe: 126 | self.branch_resolutions = [v.size()[2:] for v in encoded_branches] 127 | self._generate_anchors() 128 | 129 | single_tensor = self._reshape_and_concat(encoded_branches) 130 | 131 | return single_tensor 132 | 133 | def _generate_anchors(self): 134 | """Generate anchors according to number of branches and featuremap resolutions.""" 135 | 136 | anchor_list = [] 137 | for resolution in self.branch_resolutions: 138 | for anchor in self.anchors: 139 | height = resolution[0] 140 | width = resolution[1] 141 | cell_width = 1.0 / width 142 | cell_height = 1.0 / height 143 | for row in range(height): 144 | for col in range(width): 145 | anchor_cx = (col + 0.5) * cell_width 146 | anchor_cy = (row + 0.5) * cell_height 147 | anchor_width = anchor[0] * cell_width 148 | anchor_height = anchor[1] * cell_height 149 | 150 | anchor_cxcywh = np.array([anchor_cx, anchor_cy, anchor_width, anchor_height], dtype=np.float32) 151 | anchor_list.append(anchor_cxcywh) 152 | anchors_cxcywh = np.stack(anchor_list, axis=0) 153 | anchors_cxcywh = torch.from_numpy(anchors_cxcywh) 154 | self.anchors_cxcywh = anchors_cxcywh 155 | self.register_buffer("anchors_cxcywh_cuda", anchors_cxcywh.clone()) 156 | pass 157 | 158 | def _reshape_and_concat(self, encoded_branches): 159 | """Transform separate branch outputs to a joint tensor.""" 160 | 161 | reshaped_branches = [] 162 | for branch in encoded_branches: 163 | s = branch.size() 164 | b = s[0] 165 | a = self.num_anchors_per_cell 166 | d = 
self.detection_size 167 | assert s[1] == a * d 168 | h = s[2] 169 | w = s[3] 170 | branch = branch.view(b, a, d, h, w) 171 | branch = branch.permute(0, 3, 4, 1, 2).contiguous() 172 | branch = branch.view(b, h*w*a, d) 173 | 174 | reshaped_branches.append(branch) 175 | 176 | encoded_tensor = torch.cat(reshaped_branches, dim=1) 177 | return encoded_tensor 178 | 179 | def build_target(self, anno): 180 | """ 181 | Building a target for loss calculation is incapsulated into the detection model class. 182 | Method to be called outside - in data loader threads. Must have no side effects on self object. 183 | 184 | :param anno: list of boxes with class ids 185 | :return: 186 | (loc, cls): encoded target: location regression and classification class 187 | loc: float tensor of shape (A, 4), A - total number of anchors 188 | cls: int tensor of shape (A,) of class labels, where 0 - background, 1 - class 0, etc 189 | matches: statistics of coverage of GT boxes by anchors 190 | """ 191 | 192 | anno = self._anno_class_names_to_ids(anno) 193 | 194 | if len(anno) > 0: 195 | gt_boxes = np.stack([obj['bbox'] for obj in anno], axis=0) 196 | gt_classes = np.stack([obj['class_id'] for obj in anno], axis=0).astype(np.int32) 197 | else: 198 | gt_boxes = np.zeros((0, 4), dtype=np.float32) 199 | gt_classes = np.zeros((0,), dtype=np.int32) 200 | 201 | gt_boxes = torch.from_numpy(gt_boxes) 202 | gt_classes = torch.from_numpy(gt_classes).long() 203 | 204 | loc, cls, matches = box_utils.match(self.iou_anchor_and_gt, gt_boxes, 205 | self.anchors_cxcywh, self.variances, gt_classes) 206 | 207 | return (loc, cls), matches 208 | 209 | def calculate_loss(self, encoded_prediction, encoded_target): 210 | """ 211 | Calculate total classification & localization loss of SSD. 212 | 213 | :param encoded_prediction: tensor [N, A, D], N-batch size, A-total number of anchors, D-detection size 214 | :param encoded_target: pair of (loc, cls). 
215 | loc: shape [N, A, R], R - bbox regression size = 4 216 | cls: shape [N, A, C], C - number of classes including background 217 | :return: 218 | loss: loss variable to optimize 219 | losses: dict of scalars to post to graphs 220 | """ 221 | 222 | pred_xywh = encoded_prediction[:, :, 0:4].contiguous() 223 | pred_class = encoded_prediction[:, :, 4:].contiguous() 224 | assert pred_class.shape[2] == 1 + self.num_classes 225 | 226 | target_xywh = encoded_target[0] 227 | target_class_indexes = encoded_target[1] 228 | if torch.cuda.is_available(): 229 | target_xywh = target_xywh.cuda() 230 | target_class_indexes = target_class_indexes.cuda() 231 | 232 | # determine positives 233 | bbox_matches_byte = target_class_indexes > 0 234 | bbox_matches = bbox_matches_byte.long() 235 | batch_size = bbox_matches.size(0) 236 | num_matches = bbox_matches.sum().item() 237 | 238 | # bbox loss only for positives 239 | bbox_mask = bbox_matches_byte.unsqueeze(2).expand_as(pred_xywh) 240 | bbox_denom = max(num_matches, 1) 241 | loc_loss = F.smooth_l1_loss(pred_xywh[bbox_mask], target_xywh[bbox_mask], reduction='sum') / bbox_denom 242 | 243 | pred_class_flat = pred_class.view(-1, pred_class.shape[-1]) 244 | 245 | target_class_indexes_flat = target_class_indexes.view(-1) 246 | 247 | # calculate cls losses for positives and negative without reduction 248 | cls_loss_vec = F.cross_entropy(pred_class_flat, target_class_indexes_flat, reduction='none') 249 | cls_loss_vec = cls_loss_vec.view(batch_size, -1) 250 | 251 | if self.use_ohem: 252 | # Online hard sample mining (OHEM) 253 | neg_to_pos_ratio = 3 # the same as in the original SSD 254 | virtual_min_positive_matches = 100 # value for NN to learn on images without annotations 255 | 256 | # determine negatives with biggest loss 257 | cls_loss_neg = cls_loss_vec * (bbox_matches_byte.float() - 1.0) 258 | _, idx = cls_loss_neg.sort(1) 259 | _, rank_idxes = idx.sort(1) 260 | num_pos = bbox_matches.sum(1) 261 | num_neg = neg_to_pos_ratio * num_pos 262 | neg_idx = rank_idxes < num_neg[:, None] 263 | 264 | # combine losses from positives and negatives 265 | num_bbox_matches = bbox_matches.sum(dim=1) 266 | contributors_to_loss_mask = bbox_matches_byte | neg_idx 267 | contributors_to_loss_mask = contributors_to_loss_mask.float() 268 | contributors_to_loss = cls_loss_vec * contributors_to_loss_mask.float() 269 | cls_loss_batch_total = contributors_to_loss.sum(dim=1) 270 | cls_loss_total = cls_loss_batch_total.sum() 271 | num_bbox_matches_total = num_bbox_matches.sum() 272 | cls_denom = max(num_bbox_matches_total.float().item(), virtual_min_positive_matches) 273 | cls_loss = cls_loss_total / cls_denom 274 | pass 275 | else: 276 | # Average loss over all anchors (worse convergence than with OHEM) 277 | cls_loss = cls_loss_vec.sum() / cls_loss_vec.shape[1] 278 | 279 | loc_loss_mult = 1.0 #0.2 280 | cls_loss_mult = 1.0 if self.use_ohem else 8.0 281 | loc_loss_weighted = loc_loss_mult * loc_loss 282 | cls_loss_weighted = cls_loss_mult * cls_loss 283 | loss = loc_loss_weighted + cls_loss_weighted 284 | 285 | loss_details = { 286 | "loc_loss": loc_loss_weighted, 287 | "cls_loss": cls_loss_weighted, 288 | "loss": loss 289 | } 290 | loss_details = {name: float(var.item()) for (name, var) in loss_details.items()} 291 | 292 | return loss, loss_details 293 | 294 | def calculate_detections(self, encoded_tensor, threshold): 295 | """ 296 | 297 | :param encoded_tensor: tensor [N, A, D], N-batch size, A-total number of anchors, D-detection size 298 | :param threshold: minimum confidence 
threshold for generated detections 299 | :return: list [N] of list [C] of numpy arrays [Q, 5], where N - batch size, 300 | C - number of object classes (i.e. no including background), Q - quantity of detected objects. 301 | Dimention of size 5 is decoded as [0] - confidence, [1:5] - bbox in fractional 302 | left-top-right-bottom (LTRB) format. 303 | """ 304 | 305 | #encoded_tensor = encoded_tensor.cpu() 306 | 307 | loc_var = encoded_tensor[:, :, :4] 308 | conf_var = encoded_tensor[:, :, 4:] 309 | 310 | loc_data = loc_var.data 311 | conf_data = F.softmax(conf_var, dim=2).data 312 | 313 | conf_data = conf_data[:, :, 1:].contiguous() # throw away BG row after softmax 314 | 315 | anchors_cxcywh = self.anchors_cxcywh_cuda 316 | 317 | detections = self.detect_functor.forward(loc_data, conf_data, anchors_cxcywh, threshold) 318 | 319 | detections = detections.cpu().numpy() 320 | 321 | det_varsize = [] 322 | for s in detections: 323 | c_varsize = [] 324 | for c in s: 325 | c = c[c[:, 0] > 0.0] 326 | c_varsize.append(c) 327 | det_varsize.append(c_varsize) 328 | 329 | return det_varsize 330 | 331 | def _anno_class_names_to_ids(self, anno): 332 | anno_out = [] 333 | for obj in anno: 334 | obj_out = { 335 | 'class_id': self.labelmap.index(obj['type']), 336 | 'bbox': obj['bbox'].astype(np.float32) 337 | } 338 | anno_out.append(obj_out) 339 | return anno_out 340 | 341 | def export_model_to_caffe(self, input_resolution): 342 | """ 343 | Export to Caffe. 344 | """ 345 | 346 | sys.path.insert(0, os.path.join("~/git/pytorch2caffe/")) 347 | sys.path.insert(0, "~/git/caffe_ssd_py3/build/install/python/") 348 | from pytorch2caffe import pytorch2caffe 349 | 350 | input_var = torch.rand(1, 3, int(input_resolution[0]), int(input_resolution[1])) 351 | encoded_var = self(input_var) 352 | pytorch2caffe( 353 | input_var, encoded_var, 354 | 'model.prototxt', 355 | 'model.caffemodel') 356 | pass 357 | 358 | 359 | class SingleShotDetector(nn.Module): 360 | def __init__(self, backbone_specs, multibox_specs, input_resolution, labelmap): 361 | """ 362 | Ctor. 
363 | 364 | :param input_resolution: input resolution (H, W) 365 | :param labelmap: list [C] of class name strings, where C - number of object classes (not including background) 366 | """ 367 | 368 | super().__init__() 369 | 370 | for c in input_resolution: 371 | assert c % 256 == 0 372 | 373 | self.labelmap = labelmap 374 | 375 | backbone_module = importlib.import_module(backbone_specs['backbone_module']) 376 | 377 | # Use Resnet-XX as a backbone 378 | backbone_create_func = getattr(backbone_module, backbone_specs['backbone_function']) 379 | self.backbone = backbone_create_func(**backbone_specs['kwargs']) 380 | channel_multiplier = backbone_specs['head_channel_multiplier'] 381 | 382 | self.backbone.eval() 383 | 384 | # probe backbone 385 | input_batch_shape = (1, 3, *input_resolution) 386 | input_tensor = torch.autograd.Variable(torch.rand(input_batch_shape)) 387 | backbone_out = self.backbone(input_tensor) 388 | backbone_last = backbone_out[-1] 389 | backbone_last_channels = backbone_last.shape[1] 390 | 391 | # create additional layers 392 | # extras_config = [512, 256, 256] 393 | extras_config = [v*channel_multiplier for v in (2, 2, 2)] 394 | self.extra_layers = ExtraLayers(extras_config, backbone_last_channels) 395 | self.extra_layers.eval() 396 | 397 | # probe extra layers 398 | extra_layers_out = self.extra_layers(backbone_last) 399 | 400 | # take only these last branches from backbone, all other branches come from additional layers 401 | self.num_last_backbone_branches = 3 402 | 403 | print("----- SSD branch configuration -----") 404 | for i, t in enumerate(backbone_out): 405 | print(t.shape, " <- branch" if len(backbone_out)-i <= self.num_last_backbone_branches else "") 406 | for t in extra_layers_out: 407 | print(t.shape, " <- branch") 408 | print("------------------------------------") 409 | 410 | # collect all branches in a tuple 411 | final_branches = (*backbone_out[-self.num_last_backbone_branches:], *extra_layers_out) 412 | final_branch_channels = [b.shape[1] for b in final_branches] 413 | 414 | # add multi-branch detection head on top of all branches 415 | self.multibox_layers = MultiboxLayers(final_branch_channels, self.labelmap, multibox_specs['use_ohem']) 416 | self.multibox_layers.eval() 417 | # probe multibox, save branch resolutions, generate anchors 418 | self.multibox_layers(final_branches, is_probe=True) 419 | 420 | if False: 421 | # probe the whole net 422 | encoded_tensor = self.forward(input_tensor) 423 | detections = self.get_detections(encoded_tensor, threshold=0.15) 424 | 425 | # export model 426 | if False: 427 | self.export_model_to_caffe(input_resolution) 428 | 429 | pass 430 | 431 | def forward(self, input_tensor_batch): 432 | """ 433 | Forward. 
434 | 435 | :param input_tensor_batch: batch of input images, tensor of shape [N, 3, H, W], where N - batch size, H - height, W - width 436 | :return: encoded detections - single tensor of shape [N, A, D], where A - total number of anchors concatenated over all branches, D = 4+1+num_classes 437 | """ 438 | 439 | backbone_branches = self.backbone(input_tensor_batch) 440 | 441 | # automatically derive resolution for extra layers 442 | backbone_last_branch = backbone_branches[-1] 443 | extra_branches = self.extra_layers(backbone_last_branch) 444 | 445 | # collect all branch feature maps in a tuple 446 | final_branches = (*backbone_branches[-self.num_last_backbone_branches:], *extra_branches) 447 | 448 | encoded_tensor = self.multibox_layers(final_branches) 449 | 450 | return encoded_tensor 451 | 452 | def get_loss(self, encoded_tensor, target): 453 | """Get loss for optimization.""" 454 | return self.multibox_layers.calculate_loss(encoded_tensor, target) 455 | 456 | def get_detections(self, encoded_tensor, threshold): 457 | """Get bbox detections in the final decoded format.""" 458 | return self.multibox_layers.calculate_detections(encoded_tensor, threshold) 459 | 460 | def build_target(self, anno): 461 | # Forward to multibox component 462 | return self.multibox_layers.build_target(anno) -------------------------------------------------------------------------------- /detect2d.py: -------------------------------------------------------------------------------- 1 | # Author: Dmitrii Khizbullin 2 | # A script to train a neural network for the 2D detection task. 3 | # Some code is borrowed from pytorch examples and torchvision. 4 | 5 | import os 6 | import sys 7 | import time 8 | import pickle 9 | import argparse 10 | import importlib 11 | import numpy as np 12 | from termcolor import colored 13 | 14 | import torch 15 | import torch.utils.data as data 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | import torch.backends.cudnn as cudnn 19 | 20 | from name_list_dataset import NameListDataset 21 | from summary_writer_opt import SummaryWriterOpt 22 | from helpers import * 23 | from average_meter import AverageMeter 24 | import detection_models 25 | from extended_collate import extended_collate 26 | import image_anno_transforms 27 | import average_precision 28 | from debug_tools import dump_images 29 | 30 | 31 | def default_input_traits(): 32 | """ 33 | Default resolutions for training and evaluation.
34 | """ 35 | return { 36 | "resolution": (256, 512) 37 | #"resolution": (512, 1024) 38 | #"resolution": (384, 1152) 39 | #"resolution": (256, 768) 40 | #"resolution": (384, 768) 41 | } 42 | 43 | 44 | def train_image_transform(): 45 | """PIL image transformation for training.""" 46 | return image_anno_transforms.ComposeVariadic([ 47 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), 48 | ]) 49 | 50 | 51 | def train_image_and_annotation_transform(): 52 | """Image+annotation synchronous transformation for training.""" 53 | return image_anno_transforms.ComposeVariadic([ 54 | image_anno_transforms.RandomHorizontalFlipWithAnno(), 55 | image_anno_transforms.RandomCropWithAnno(0.3), 56 | ]) 57 | 58 | 59 | class BuildTargetFunctor: 60 | """Functor to delegate model's target construction to preprocessing threads.""" 61 | 62 | def __init__(self, model): 63 | self.model = model 64 | 65 | def __call__(self, *args): 66 | return self.model.build_target(*args) 67 | 68 | 69 | def clip_gradient(model, clip_val, mode): 70 | """Clip the gradient.""" 71 | 72 | assert mode in ('by_max', 'by_norm') 73 | 74 | if mode == 'by_max': 75 | for p in model.parameters(): 76 | if p.grad is not None: 77 | mv = torch.max(torch.abs(p.grad.data)) 78 | if mv > clip_val: 79 | print(colored("Grad max {:.3f}".format(mv), "red")) 80 | p.grad.data.clamp_(-clip_val, clip_val) 81 | elif mode == 'by_norm': 82 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val) 83 | 84 | pass 85 | 86 | 87 | class Config: 88 | def __init__(self, attr_dict): 89 | for n, v in attr_dict.items(): 90 | self.__dict__[n] = v 91 | 92 | 93 | def import_config_by_name(config_name): 94 | config_module = importlib.import_module('configs.' + config_name) 95 | cfg_dict = {key: value for key, value in config_module.__dict__.items() if 96 | not (key.startswith('__') or key.startswith('_'))} 97 | cfg = Config(cfg_dict) 98 | return cfg 99 | 100 | 101 | class Trainer: 102 | """Class that performs train-validation loop to train a detection neural network.""" 103 | 104 | def __init__(self, config_name): 105 | """ 106 | Args: 107 | config_name: name of a configuration module to import 108 | """ 109 | 110 | print('Config name: {}'.format(config_name)) 111 | 112 | self.cfg = import_config_by_name(config_name) 113 | print(self.cfg) 114 | 115 | print('Start preparing dataset') 116 | self.prepare_dataset() 117 | print('Finished preparing dataset') 118 | 119 | print("torch.__version__=", torch.__version__) 120 | torchvision.set_image_backend('accimage') 121 | print("torchvision.get_image_backend()=", torchvision.get_image_backend()) 122 | 123 | self.epochs_to_train = 500 124 | self.base_learning_rate = 0.02 125 | self.lr_scales = ( 126 | (0, 0.1), # perform soft warm-up to reduce chance of divergence 127 | (2, 0.2), 128 | (4, 0.3), 129 | (6, 0.5), 130 | (8, 0.7), 131 | (10, 1.0), # main learning rate multiplier 132 | (int(0.90 * self.epochs_to_train), 0.1), 133 | (int(0.95 * self.epochs_to_train), 0.01), 134 | ) 135 | 136 | self.train_batch_size = 32 137 | self.val_batch_size = 32 138 | 139 | num_workers_train = 12 140 | num_workers_val = 12 141 | 142 | input_traits = default_input_traits() 143 | 144 | labelmap = NameListDataset.getLabelmap() 145 | 146 | model = detection_models.SingleShotDetector( 147 | self.cfg.backbone_specs, 148 | self.cfg.multibox_specs, 149 | input_traits['resolution'], 150 | labelmap) 151 | 152 | if True: 153 | model_dp = torch.nn.DataParallel(model) 154 | cudnn.benchmark = True 155 | else: 156 | model_dp = model
157 | 158 | if torch.cuda.is_available(): 159 | model_dp.cuda() 160 | 161 | self.model = model 162 | self.model_dp = model_dp 163 | 164 | build_target = BuildTargetFunctor(model) 165 | map_to_network_input = image_anno_transforms.MapImageAndAnnoToInputWindow(input_traits['resolution']) 166 | 167 | def load_list(name): 168 | path = os.path.join(self.cfg.train_val_split_dir, name + '.pkl') 169 | with open(path, 'rb') as input: 170 | return pickle.load(input) 171 | 172 | self.train_dataset = NameListDataset( 173 | dataset_list=load_list('train_list'), 174 | image_transform=train_image_transform(), 175 | image_and_anno_transform=train_image_and_annotation_transform(), 176 | map_to_network_input=map_to_network_input, 177 | build_target=build_target 178 | ) 179 | 180 | self.balanced_val_dataset = NameListDataset( 181 | dataset_list=load_list('val_list'), 182 | image_transform=None, 183 | image_and_anno_transform=None, 184 | map_to_network_input=map_to_network_input, 185 | build_target=build_target 186 | ) 187 | 188 | # Data loading and augmentation pipeline for training 189 | self.train_loader = torch.utils.data.DataLoader( 190 | self.train_dataset, batch_size=self.train_batch_size, shuffle=True, 191 | num_workers=num_workers_train, collate_fn=extended_collate, pin_memory=True) 192 | 193 | # Data loading and augmentation pipeline for validation 194 | self.val_loader = torch.utils.data.DataLoader( 195 | self.balanced_val_dataset, batch_size=self.val_batch_size, shuffle=False, 196 | num_workers=num_workers_val, collate_fn=extended_collate, pin_memory=True) 197 | 198 | self.optimizer = None 199 | self.learning_rate = None 200 | 201 | self.train_iter = 0 202 | self.epoch = 0 203 | self.best_performance_metric = None 204 | 205 | self.print_freq = 10 206 | 207 | self.writer = None 208 | 209 | self.run_dir = os.path.join('runs', self.cfg.run_name) 210 | os.makedirs(self.run_dir, exist_ok=True) 211 | self.snapshot_path = os.path.join(self.run_dir, self.cfg.run_name + '.pth.tar') 212 | 213 | pass 214 | 215 | def prepare_dataset(self): 216 | """Prepare dataset for training the detector. Done only once.""" 217 | 218 | if os.path.exists(self.cfg.train_val_split_dir): 219 | return 220 | 221 | image_list = NameListDataset.list_all_images() 222 | assert len(image_list) > 0 223 | 224 | os.makedirs(self.cfg.train_val_split_dir) 225 | 226 | NameListDataset.train_val_split(image_list, self.cfg.train_val_split_dir) 227 | 228 | @staticmethod 229 | def wrap_sample_with_variable(input, target, **kwargs): 230 | """Wrap tensor with Variable and push to cuda.""" 231 | if torch.cuda.is_available(): 232 | input = input.cuda(non_blocking=True) 233 | target = [t.cuda(non_blocking=True) for t in target] 234 | return input, target 235 | 236 | def train_epoch(self): 237 | """ 238 | Train the model for one epoch. 
239 | """ 240 | 241 | print("-------------- Train epoch ------------------") 242 | 243 | batch_time = AverageMeter() 244 | data_time = AverageMeter() 245 | forward_time = AverageMeter() 246 | loss_time = AverageMeter() 247 | backward_time = AverageMeter() 248 | loss_total_am = AverageMeter() 249 | loss_loc_am = AverageMeter() 250 | loss_cls_am = AverageMeter() 251 | 252 | # switch to training mode 253 | self.model_dp.train() 254 | 255 | is_lr_change = self.epoch in [epoch for epoch, _ in self.lr_scales] 256 | if self.optimizer is None or is_lr_change: 257 | scale = None 258 | if self.optimizer is None: 259 | scale = 1.0 260 | if is_lr_change: 261 | scale = [sc for epoch, sc in self.lr_scales if epoch == self.epoch][0] 262 | self.learning_rate = self.base_learning_rate * scale 263 | if self.optimizer is None: 264 | self.optimizer = torch.optim.SGD( 265 | self.model_dp.parameters(), self.learning_rate, 266 | momentum=0.9, 267 | weight_decay=0.0001) 268 | else: 269 | for param_group in self.optimizer.param_groups: 270 | param_group['lr'] = self.learning_rate 271 | 272 | do_dump_train_images = False 273 | detection_train_dump_dir = None 274 | if do_dump_train_images: 275 | detection_train_dump_dir = os.path.join(self.run_dir, 'detection_train_dump') 276 | clean_dir(detection_train_dump_dir) 277 | 278 | end = time.time() 279 | for batch_idx, sample in enumerate(self.train_loader): 280 | # measure data loading time 281 | data_time.update(time.time() - end) 282 | 283 | input, target, names, pil_images, annotations, stats = sample 284 | 285 | if do_dump_train_images: # and random.random() < 0.01: 286 | dump_images( 287 | names, pil_images, annotations, None, stats, 288 | self.model.labelmap, 289 | detection_train_dump_dir) 290 | 291 | input_var, target_var = self.wrap_sample_with_variable(input, target) 292 | 293 | # compute output 294 | forward_ts = time.time() 295 | encoded_tensor = self.model_dp(input_var) 296 | forward_time.update(time.time() - forward_ts) 297 | loss_ts = time.time() 298 | loss, loss_details = self.model.get_loss(encoded_tensor, target_var) 299 | loss_time.update(time.time() - loss_ts) 300 | 301 | # record loss 302 | loss_total_am.update(loss_details["loss"], input.size(0)) 303 | loss_loc_am.update(loss_details["loc_loss"], input.size(0)) 304 | loss_cls_am.update(loss_details["cls_loss"], input.size(0)) 305 | 306 | # compute gradient and do SGD step 307 | backward_ts = time.time() 308 | self.optimizer.zero_grad() 309 | loss.backward() 310 | clip_gradient(self.model, 2.0, 'by_max') 311 | self.optimizer.step() 312 | backward_time.update(time.time() - backward_ts) 313 | 314 | # measure elapsed time 315 | batch_time.update(time.time() - end) 316 | end = time.time() 317 | 318 | if batch_idx % self.print_freq == 0: 319 | print('Epoch: [{0}][{1}/{2}]\t' 320 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 321 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 322 | 'Forward {forward_time.val:.3f} ({forward_time.avg:.3f})\t' 323 | 'LossTime {loss_time.val:.3f} ({loss_time.avg:.3f})\t' 324 | 'Backward {backward_time.val:.3f} ({backward_time.avg:.3f})\t' 325 | 'Loss {loss_total_am.val:.4f} ({loss_total_am.avg:.4f})\t' 326 | 'Loss_loc {loss_loc_am.val:.4f} ({loss_loc_am.avg:.4f})\t' 327 | 'Loss_cls {loss_cls_am.val:.4f} ({loss_cls_am.avg:.4f})\t' 328 | .format( 329 | self.epoch, batch_idx, len(self.train_loader), 330 | batch_time=batch_time, data_time=data_time, 331 | forward_time=forward_time, loss_time=loss_time, backward_time=backward_time, 332 | loss_total_am=loss_total_am, 
loss_loc_am=loss_loc_am, loss_cls_am=loss_cls_am 333 | )) 334 | 335 | if self.train_iter % self.print_freq == 0: 336 | self.writer.add_scalar('train/loss', loss_total_am.avg, self.train_iter) 337 | self.writer.add_scalar('train/loss_loc', loss_loc_am.avg, self.train_iter) 338 | self.writer.add_scalar('train/loss_cls', loss_cls_am.avg, self.train_iter) 339 | self.writer.add_scalar('train/lr', self.learning_rate, self.train_iter) 340 | 341 | num_prints = self.train_iter // self.print_freq 342 | # print('num_prints=', num_prints) 343 | num_prints_rare = num_prints // 100 344 | # print('num_prints_rare=', num_prints_rare) 345 | if num_prints_rare == 0 and num_prints % 10 == 0 or num_prints % 100 == 0: 346 | print('save historgams') 347 | if self.train_iter > 0: 348 | import itertools 349 | named_parameters = itertools.chain( 350 | self.model.multibox_layers.named_parameters(), 351 | self.model.extra_layers.named_parameters(), 352 | ) 353 | for name, param in named_parameters: 354 | self.writer.add_histogram(name, param.detach().cpu().numpy(), self.train_iter, bins='fd') 355 | self.writer.add_histogram(name+'_grad', param.grad.detach().cpu().numpy(), self.train_iter, bins='fd') 356 | 357 | first_conv = list(self.model.backbone._modules.items())[0][1]._parameters['weight'] 358 | image_grid = torchvision.utils.make_grid(first_conv.detach().cpu(), normalize=True, scale_each=True) 359 | image_grid_grad = torchvision.utils.make_grid(first_conv.grad.detach().cpu(), normalize=True, scale_each=True) 360 | self.writer.add_image('layers0_conv', image_grid, self.train_iter) 361 | self.writer.add_image('layers0_conv_grad', image_grid_grad, self.train_iter) 362 | 363 | self.train_iter += 1 364 | pass 365 | 366 | self.epoch += 1 367 | 368 | def to_class_grouped_anno(self, batch_anno): 369 | """ 370 | Since annotations have all classes mixed together, need to group them by class to 371 | pass it to average precision calculation function. 372 | """ 373 | 374 | all_annotations = [] 375 | for anno in batch_anno: 376 | classes = [[] for i in range(len(self.model.labelmap))] 377 | for obj in anno: 378 | object_id = self.model.labelmap.index(obj["type"]) 379 | classes[object_id].append(obj["bbox"]) 380 | classes = [ 381 | np.stack(objs, axis=0) if len(objs) > 0 else np.empty((0, 4), dtype=np.float64) \ 382 | for objs in classes] 383 | all_annotations.append(classes) 384 | return all_annotations 385 | 386 | def validate(self, do_dump_images=False, save_checkpoint=False): 387 | """ 388 | Run validation on the current network state. 
389 | """ 390 | 391 | print("-------------- Validation ------------------") 392 | 393 | batch_time = AverageMeter() 394 | data_time = AverageMeter() 395 | loss_total_am = AverageMeter() 396 | loss_loc_am = AverageMeter() 397 | loss_cls_am = AverageMeter() 398 | 399 | # switch to evaluate mode 400 | self.model.eval() 401 | 402 | detection_val_dump_dir = os.path.join(self.run_dir, 'detection_val_dump') 403 | if do_dump_images: 404 | clean_dir(detection_val_dump_dir) 405 | 406 | iou_threshold_perclass = [0.7 if i == 0 else 0.5 for i in range(len(self.model.labelmap))] # Kitti 407 | 408 | ap_estimator = average_precision.AveragePrecision(self.model.labelmap, iou_threshold_perclass) 409 | 410 | end = time.time() 411 | for batch_idx, sample in enumerate(self.val_loader): 412 | # Measure data loading time 413 | data_time.update(time.time() - end) 414 | 415 | input, target, names, pil_images, annotations, stats = sample 416 | 417 | with torch.no_grad(): 418 | input_var, target_var = self.wrap_sample_with_variable(input, target, volatile=True) 419 | 420 | # Compute output tensor of the network 421 | encoded_tensor = self.model_dp(input_var) 422 | # Compute loss for logging only 423 | _, loss_details = self.model.get_loss(encoded_tensor, target_var) 424 | 425 | # Save annotation and detection results for further AP calculation 426 | class_grouped_anno = self.to_class_grouped_anno(annotations) 427 | detections_all = self.model.get_detections(encoded_tensor, 0.0) 428 | ap_estimator.add_batch(class_grouped_anno, detections_all) 429 | 430 | # Record loss 431 | loss_total_am.update(loss_details["loss"], input.size(0)) 432 | loss_loc_am.update(loss_details["loc_loss"], input.size(0)) 433 | loss_cls_am.update(loss_details["cls_loss"], input.size(0)) 434 | 435 | # Dump validation images with overlays for developer to subjectively estimate accuracy 436 | if do_dump_images: 437 | overlay_conf_threshold = 0.3 438 | detections_thr = self.model.get_detections(encoded_tensor, overlay_conf_threshold) 439 | dump_images( 440 | names, pil_images, annotations, detections_thr, stats, 441 | self.model.labelmap, 442 | detection_val_dump_dir) 443 | 444 | # Measure elapsed time 445 | batch_time.update(time.time() - end) 446 | end = time.time() 447 | 448 | if batch_idx % self.print_freq == 0: 449 | print('Validation: [{0}/{1}]\t' 450 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 451 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 452 | 'Loss {loss_total_am.val:.4f} ({loss_total_am.avg:.4f})\t' 453 | 'Loss_loc {loss_loc_am.val:.4f} ({loss_loc_am.avg:.4f})\t' 454 | 'Loss_cls {loss_cls_am.val:.4f} ({loss_cls_am.avg:.4f})\t' 455 | .format( 456 | batch_idx, len(self.val_loader), 457 | batch_time=batch_time, data_time=data_time, 458 | loss_total_am = loss_total_am, loss_loc_am = loss_loc_am, loss_cls_am = loss_cls_am 459 | )) 460 | 461 | # After coming over the while validation set, calculate individual average precision values and total mAP 462 | mAP, AP_list = ap_estimator.calculate_mAP() 463 | 464 | for ap, label in zip(AP_list, self.model.labelmap): 465 | print('{} {:.3f}'.format(label.ljust(20), ap)) 466 | print(' mAP - {mAP:.3f}'.format(mAP=mAP)) 467 | performance_metric = AP_list[self.model.labelmap.index('Car')] 468 | 469 | # Log to tensorboard 470 | if self.writer is not None: 471 | self.writer.add_scalar('val/mAP', mAP, self.train_iter) 472 | self.writer.add_scalar('val/performance_metric', performance_metric, self.train_iter) 473 | self.writer.add_scalar('val/loss', loss_total_am.avg, self.train_iter) 474 
| self.writer.add_scalar('val/loss_loc', loss_loc_am.avg, self.train_iter) 475 | self.writer.add_scalar('val/loss_cls', loss_cls_am.avg, self.train_iter) 476 | 477 | if save_checkpoint: 478 | # Remember best accuracy and save checkpoint 479 | is_best = performance_metric > self.best_performance_metric 480 | 481 | if is_best: 482 | self.best_performance_metric = performance_metric 483 | torch.save( 484 | {'state_dict': self.model.state_dict()}, 485 | self.snapshot_path) 486 | 487 | pass 488 | 489 | def load_checkpoint(self, checkpoint_path): 490 | """Load spesified snapshot to the network.""" 491 | if checkpoint_path is None: 492 | checkpoint_path = self.snapshot_path 493 | if os.path.exists(checkpoint_path): 494 | checkpoint = torch.load(checkpoint_path) 495 | self.model.load_state_dict(checkpoint['state_dict']) 496 | else: 497 | print("Checkpoint not found:", checkpoint_path) 498 | assert False, "No sense to test random weights" 499 | pass 500 | 501 | def print_anchor_coverage(self): 502 | from anchor_coverage import AnchorCoverage 503 | 504 | anchor_coverage = AnchorCoverage() 505 | 506 | for batch_idx, sample in enumerate(self.val_loader): 507 | input, target, names, pil_images, annotations, stats = sample 508 | anchor_coverage.add_batch(annotations, stats) 509 | 510 | anchor_coverage.print() 511 | 512 | def export(self): 513 | resolution_hw = default_input_traits()["resolution"] 514 | example_input = torch.rand((1, 3, *resolution_hw)) 515 | example_input_cuda = example_input.cuda() 516 | traced_model = torch.jit.trace(self.model, (example_input_cuda,)) 517 | print(traced_model) 518 | 519 | self.model.cpu() 520 | path = os.path.join(self.run_dir, self.cfg.run_name+'.onnx') 521 | torch.onnx.export(self.model, (example_input,), path, verbose=False) 522 | if torch.cuda.is_available(): 523 | self.model.cuda() 524 | # assert False 525 | pass 526 | 527 | def run(self): 528 | """ 529 | Launch training procedure. Performs training interleaved 530 | by validation according to the training schedule. 531 | """ 532 | 533 | self.print_anchor_coverage() 534 | 535 | self.best_performance_metric = 0.0 536 | 537 | do_dump_images = False 538 | 539 | self.writer = SummaryWriterOpt(enabled=True, suffix=self.cfg.run_name) 540 | 541 | # self.validate(do_dump_images=do_dump_images, save_checkpoint=False) 542 | 543 | num_epochs = 0 544 | do_process = True 545 | while do_process: 546 | for i in range(self.cfg.epochs_before_val): 547 | if num_epochs >= self.epochs_to_train: 548 | do_process = False 549 | break 550 | self.train_epoch() 551 | num_epochs += 1 552 | self.validate(do_dump_images=do_dump_images, save_checkpoint=True) 553 | pass 554 | 555 | 556 | def main(): 557 | """Entry point.""" 558 | 559 | default_config = 'resnet34_pretrained' 560 | # default_config = 'resnet34_custom' 561 | # default_config = 'simple_model' 562 | # default_config = 'resnet50_pretrained' 563 | 564 | parser = argparse.ArgumentParser(description="Training script for 2D detection") 565 | parser.add_argument("--validate", action='store_true') 566 | parser.add_argument("--config", default=default_config) 567 | parser.add_argument("--checkpoint_path", default=None) 568 | args = parser.parse_args() 569 | 570 | trainer = Trainer(args.config) 571 | 572 | if args.validate: 573 | 574 | print('Start validation') 575 | trainer.print_anchor_coverage() 576 | trainer.load_checkpoint(args.checkpoint_path) 577 | trainer.export() 578 | trainer.validate(do_dump_images=True, save_checkpoint=False) 579 | print('Finished validation. 
Done!') 580 | 581 | else: 582 | 583 | print('Start training') 584 | # trainer.load_checkpoint("runs/simple_model_1x1/simple_model_1x1_epochs452.pth.tar") 585 | trainer.run() 586 | print('Finished training. Done!') 587 | 588 | pass 589 | 590 | 591 | if __name__ == "__main__": 592 | try: 593 | main() 594 | except KeyboardInterrupt: 595 | print('Interrupted') 596 | try: 597 | sys.exit(0) 598 | except SystemExit: 599 | os._exit(0) 600 | --------------------------------------------------------------------------------
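Usage sketch, based on the argparse options in detect2d.py above; a KITTI-style dataset layout expected by NameListDataset and the packages from requirements.txt are assumed to be in place:

    # Train with the default config; the best checkpoint is written to runs/<run_name>/<run_name>.pth.tar
    python detect2d.py --config resnet34_pretrained

    # Validate a saved checkpoint and dump overlay images; --checkpoint_path falls back to the run's own snapshot when omitted
    python detect2d.py --validate --config resnet34_pretrained --checkpoint_path runs/resnet34_pretrained/resnet34_pretrained.pth.tar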
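A minimal round-trip check of the anchor encoding used by MultiboxLayers (a sketch with made-up example boxes; assumes box_utils.py from this repository is importable and PyTorch is installed):

    import torch
    import box_utils

    # One prior in center-size (cx, cy, w, h) form and one ground-truth box in
    # point (xmin, ymin, xmax, ymax) form, both in fractional image coordinates.
    priors = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
    gt_boxes = torch.tensor([[0.45, 0.35, 0.60, 0.70]])
    variances = (0.1, 0.2)  # same (cxcy, wh) scales as MultiboxLayers.variances

    loc_target = box_utils.encode(gt_boxes, priors, variances)  # regression target for this prior
    decoded = box_utils.decode(loc_target, priors, variances)   # decode() undoes encode()
    assert torch.allclose(decoded, gt_boxes, atol=1e-5)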