├── Figures
│   ├── KP_net.jpg
│   └── confusion_matrix.png
├── README.md
├── data
│   ├── VeRi
│   │   └── mean.pth.tar
│   ├── VehicleKeyPointData
│   │   ├── keypoint_test.txt
│   │   └── keypoint_train.txt
│   ├── __init__.py
│   └── datatools
│       ├── __init__.py
│       ├── transforms.py
│       └── veri_dataset.py
├── main.py
├── models
│   ├── KP_Orientation_Net.py
│   └── __init__.py
├── requirements.txt
└── tools
    ├── __init__.py
    ├── confusion_meter.py
    ├── paths.py
    ├── test.py
    ├── train.py
    └── utilities.py

/Figures/KP_net.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/Figures/KP_net.jpg
--------------------------------------------------------------------------------
/Figures/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/Figures/confusion_matrix.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vehicle Key-Point & Orientation Estimation
2 | 
3 | This repository contains the code for the vehicle key-point and orientation estimation network proposed in [A Dual-Path Model With Adaptive Attention For Vehicle Re-Identification](http://openaccess.thecvf.com/content_ICCV_2019/papers/Khorramshahi_A_Dual-Path_Model_With_Adaptive_Attention_for_Vehicle_Re-Identification_ICCV_2019_paper.pdf), which was accepted as an **oral presentation** at ICCV 2019. The code for the re-identification network is not included in this repository.
4 | 
5 | The code for vehicle key-point and orientation estimation has been released to facilitate future research in vehicle alignment, 3D vehicle modeling and vehicle speed estimation.
6 | 
7 | ## Vehicle Key-Point & Orientation Estimation Pipeline
8 | 
9 | The figure below shows the pipeline for predicting 20 vehicle landmarks and classifying the vehicle's orientation into one of 8 classes, all defined in [this repository](https://github.com/Zhongdao/VehicleReIDKeyPointData).
10 | 
11 | ![Pipeline](./Figures/KP_net.jpg)
12 | 
13 | Key-point estimation is done in two stages: in stage 1 the model produces coarse estimates of the key-point locations, and in stage 2 those coarse estimates are refined through an hourglass-like structure while a parallel branch predicts the orientation of the vehicle.
14 | 
15 | 
16 | ## Getting Started
17 | Clone this repository with the following command:
18 | 
19 | ```
20 | git clone https://github.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation
21 | ```
22 | 
23 | ## Requirements
24 | 
25 | The code is written in Python 2.7 with [PyTorch](https://pytorch.org) version 0.4.1. To install the dependencies run the following command:
26 | 
27 | ```
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | Then download the pre-trained models and the [Veri-776](https://vehiclereid.github.io/VeRi/) dataset and place them in the following directories:
32 | 
33 | - Put the `VeRi` folder containing the dataset into `./data/`. Alternatively, change the paths in `./tools/paths.py` according to your preference.
34 | 
35 | - Download the pre-trained stage 1 & stage 2 key-point models from [here](https://drive.google.com/file/d/1A4A8Xu6RbVHUK6Pq5QSmKQt7jS_p5OKx/view?usp=sharing) and [here](https://drive.google.com/file/d/1jZR1YuDOLiZ3lh0B_CFJQp2aWh-qUU5C/view?usp=sharing) and put them in the `./checkpoints/stage1/` and `./checkpoints/stage2/` directories.
36 | 
37 | ## Testing
38 | To test a trained model, specify the test phase, the use case (`stage1` for coarse key-point estimation, `stage2` for the entire model with fine key-point generation and orientation estimation) and the path to the trained model. This can be achieved by running the following commands:
39 | 
40 | ### Stage1
41 | 
42 | ```
43 | python main.py --phase test --use_case stage1 --resumed_ckpt PATH_TO_STAGE1_PRE_TRAINED_MODEL
44 | ```
45 | 
46 | ### Stage2
47 | 
48 | ```
49 | python main.py --phase test --use_case stage2 --resumed_ckpt PATH_TO_STAGE2_PRE_TRAINED_MODEL
50 | ```
51 | 
52 | The number of workers and the train/test batch sizes can be set through the arguments `--num_workers`, `--train_batch_size` and `--test_batch_size`. The code also supports multi-GPU training/testing, enabled by passing the `--mGPU` argument to the `main.py` script. If you wish to visualize the predicted key-points, pass the `--visualize` argument.
53 | 
54 | ## Training
55 | 
56 | ### Stage1
57 | 
58 | To train stage 1 of the model run the following command:
59 | 
60 | ```
61 | python main.py --phase train --use_case stage1 --mGPU --lr 0.0001 --epochs 15
62 | ```
63 | 
64 | After training, results can be found in `./checkpoints/stage1/TIME_STAMP_WHEN_TRAINING_STARTED`.
65 | 
66 | ### Stage2
67 | 
68 | To train the entire model run the following:
69 | 
70 | ```
71 | python main.py --phase train --use_case stage2 --mGPU --lr 0.0001 --epochs 15 --stage1_ckpt PATH_TO_THE_STAGE1_TRAINED_MODEL
72 | ```
73 | 
74 | Training results can be found in `./checkpoints/stage2/TIME_STAMP_WHEN_TRAINING_STARTED`.
75 | 
76 | ## Results
77 | 
78 | |                                      | Stage1 | Stage2 |
79 | | ------------------------------------ | ------ | ------ |
80 | | Key-point localization MSE (pixels)  | 1.95   | 1.56   |
81 | | Orientation classification accuracy  | -      | 84.44% |
82 | 
83 | Note that the localization MSE is calculated on the 56 × 56 heatmaps. The figure below is the confusion matrix for vehicle orientation estimation. In most cases the network classifies the orientation correctly; however, since there is no clear boundary between some of the defined orientation classes, *e.g.* left front and left, the network sometimes struggles to determine the correct class.
84 | 
85 | ![Orientation_Classification_accuracy](./Figures/confusion_matrix.png)
86 | 
87 | ## Cite
88 | 
89 | If you find this repository useful in your research please cite our paper:
90 | 
91 |     @InProceedings{Khorramshahi_2019_ICCV,
92 |     author = {Khorramshahi, Pirazh and Kumar, Amit and Peri, Neehar and Rambhatla, Sai Saketh and Chen, Jun-Cheng and Chellappa, Rama},
93 |     title = {A Dual-Path Model With Adaptive Attention for Vehicle Re-Identification},
94 |     booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
95 |     month = {October},
96 |     year = {2019}
97 |     }
98 | 
99 | ## Questions
100 | 
101 | If you have any questions regarding the model or the repository, send me an email at pkhorram@terpmail.umd.edu.
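
## Single-Image Inference Sketch

The released scripts evaluate on the full VeRi test split. For reference, the snippet below is a minimal, untested sketch of how the released pieces fit together for a single image. The image path and checkpoint path are placeholders, it assumes you run it from the repository root so the relative paths in `tools/paths.py` resolve, and the pre-processing mirrors what `VeriDataset.__getitem__` does:

```
import numpy as np
import torch
from skimage import io

from models import KP_Orientation_Net
from data.datatools import transforms
from tools.utilities import get_preds

# Build the stage 2 model and load a trained checkpoint.
net = KP_Orientation_Net.KeyPointModel()
checkpoint = torch.load('PATH_TO_STAGE2_PRE_TRAINED_MODEL', map_location='cpu')
net.load_state_dict(checkpoint['net_state_dict'])
net.eval()

# Pre-process one image the same way VeriDataset does.
normalize, rescale, to_tensor = transforms.Normalize(), transforms.Rescale(), transforms.ToTensor()
image = normalize(io.imread('PATH_TO_AN_IMAGE').astype(np.float))
image_in1, image_in2 = rescale(image)  # 224 x 224 and 56 x 56 versions
x1 = to_tensor(image_in1.transpose(2, 0, 1)).unsqueeze(0)
x2 = to_tensor(image_in2.transpose(2, 0, 1)).unsqueeze(0)

with torch.no_grad():
    coarse_kp, fine_kp, orientation = net(x1, x2)

keypoints = get_preds(fine_kp[:, :20, :, :])  # 1 x 20 x 2 key-point coordinates on the 56 x 56 grid
pose = torch.max(orientation, 1)[1].item()    # orientation class index in [0, 7]
```

Note that `get_preds` returns 1-indexed coordinates, and the orientation classes are ordered as in `tools/test.py` (`front`, `rear`, `left`, `left front`, `left rear`, `right`, `right front`, `right rear`).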
102 | -------------------------------------------------------------------------------- /data/VeRi/mean.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/data/VeRi/mean.pth.tar -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/data/__init__.py -------------------------------------------------------------------------------- /data/datatools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/data/datatools/__init__.py -------------------------------------------------------------------------------- /data/datatools/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage import transform 3 | import numpy as np 4 | from tools import paths 5 | 6 | 7 | class ToTensor(object): 8 | """ 9 | The object to transform numpy arrays to torch tensors 10 | """ 11 | def __call__(self, image): 12 | return torch.from_numpy(image).float() 13 | 14 | 15 | class Rescale(object): 16 | """ 17 | The object to rescale loaded images to desired output size 18 | """ 19 | def __init__(self, input_size1=(224, 224), input_size2=(56, 56)): 20 | self.input_size1 = input_size1 21 | self.input_size2 = input_size2 22 | 23 | def __call__(self, image): 24 | 25 | image_in2 = transform.resize(image, self.input_size2) 26 | image_in1 = transform.resize(image, self.input_size1) 27 | 28 | return image_in1, image_in2 29 | 30 | 31 | class Normalize(object): 32 | """ 33 | The object to normalize the images based on the Veri-776 training set mean and standard deviation 34 | """ 35 | def __init__(self, dataset_mean_std=paths.paths.VERI_MEAN_STD_FILE): 36 | mean_std = torch.load(dataset_mean_std) 37 | self.mean = mean_std['mean'].numpy() 38 | self.std = mean_std['std'].numpy() 39 | 40 | def __call__(self, image): 41 | for j in range(3): 42 | image[:, :, j] = (image[:, :, j] - self.mean[j]) / self.std[j] 43 | return image 44 | 45 | 46 | class Rotate(object): 47 | """ 48 | The object to rotate the input image with a desired angle 49 | """ 50 | def __init__(self): 51 | pass 52 | 53 | def __call__(self, image, theta=0): 54 | image = np.round(transform.rotate(image, theta, preserve_range=True)) 55 | return image 56 | 57 | 58 | class LRFlip(object): 59 | """ 60 | The object to horizontally mirror the input image 61 | """ 62 | def __init__(self): 63 | pass 64 | 65 | def __call__(self, image): 66 | return np.fliplr(image).copy() 67 | -------------------------------------------------------------------------------- /data/datatools/veri_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset 5 | import os 6 | import scipy.ndimage as ndimage 7 | from skimage import io 8 | from tools.paths import paths 9 | from . 
import transforms
10 | import pdb
11 | 
12 | 
13 | pose_flip_lr_dict = {'0': 0, '1': 1, '2': 5, '3': 6, '4': 7, '5': 2, '6': 3, '7': 4}  # orientation labels swapped under a horizontal flip
14 | 
15 | 
16 | class VeriDataset(Dataset):
17 |     """
18 |     Veri Dataset class
19 |     """
20 |     def __init__(self, phase='train', flip_probability=0, rotate_probability=0):
21 |         self.phase = phase
22 |         self.flip_probability = flip_probability
23 |         self.rotate_probability = rotate_probability
24 |         self.struct = ndimage.generate_binary_structure(2, 1)
25 | 
26 |         if self.phase == 'train':
27 |             txt_file = paths.VERI_KP_ANNOTATIONS_TRAINING_FILE
28 |         elif self.phase == 'test':
29 |             txt_file = paths.VERI_KP_ANNOTATIONS_TESTING_FILE
30 |         else:
31 |             raise NameError('Phase should be either "train" or "test"')
32 |         assert (os.path.exists(txt_file))
33 |         self.anno = [line.rstrip('\n') for line in open(txt_file)]
34 | 
35 |         # remove missing data files
36 |         no_data = []
37 |         for line in self.anno:
38 |             if not os.path.isfile(os.path.join(paths.VERI_DATA_PATH, line.split(' ')[0])):
39 |                 no_data.append(line)
40 |         for line in no_data:
41 |             self.anno.remove(line)
42 | 
43 |         self.Normalize = transforms.Normalize()
44 |         self.LRflip = transforms.LRFlip()
45 |         self.Rescale = transforms.Rescale()
46 |         self.ToTensor = transforms.ToTensor()
47 |         self.Rotate = transforms.Rotate()
48 |         self.key_point_distribution, self.pose_distribution = self._class_weights()
49 | 
50 |     def _class_weights(self):
51 |         """
52 |         Calculate the frequency of each class to help balance the training of the model.
53 |         :return: Inverse frequency of each pixel type class (20 key-points + background) and
54 |                  each vehicle orientation class
55 |         """
56 |         txt_file = paths.VERI_KP_ANNOTATIONS_TRAINING_FILE
57 |         anno = [line.rstrip('\n') for line in open(txt_file)]
58 | 
59 |         no_data = []
60 |         for line in anno:
61 |             if not os.path.isfile(os.path.join(paths.VERI_DATA_PATH, line.split(' ')[0])):
62 |                 no_data.append(line)
63 |         for line in no_data:
64 |             anno.remove(line)
65 | 
66 |         weights = np.zeros(21)
67 |         pose_distribution = np.zeros(8)
68 |         for line in anno:
69 |             cnt = 0
70 |             pose_distribution[int(line[-1])] += 1
71 |             for i in range(0, 20):
72 |                 coordinate = line.split(' ')[2 * i + 1: 2 * i + 3]
73 |                 if int(coordinate[0]) > -1:
74 |                     cnt += 1
75 |                     weights[i] += 1
76 |             weights[20] += 56 * 56 - cnt
77 | 
78 |         return torch.from_numpy((1 / weights) / (1 / weights).sum()),\
79 |                torch.from_numpy((1 / pose_distribution) / (1 / pose_distribution).sum())
80 | 
81 |     def __len__(self):
82 |         return len(self.anno)
83 | 
84 |     def __getitem__(self, item):
85 |         # load the image
86 |         im_path = os.path.join(paths.VERI_DATA_PATH, self.anno[item].split(' ')[0])
87 |         image = io.imread(im_path)
88 |         image = image.astype(np.float)
89 | 
90 |         image = self.Normalize(image)
91 |         image_in1, image_in2 = self.Rescale(image)
92 | 
93 |         H, W = image.shape[0], image.shape[1]
94 | 
95 |         pose = int(self.anno[item][-1])
96 | 
97 |         # load annotations
98 |         keypoints = np.zeros([20, 2])
99 | 
100 |         for i in range(0, 20):
101 |             keypoints[i] = [int(b) for b in self.anno[item].split(' ')[2 * i + 1: 2 * i + 3]]
102 | 
103 |         # Resize annotations to fit the modified image
104 |         keypoints[:, 0] = keypoints[:, 0] * (56 / W)
105 |         keypoints[:, 1] = keypoints[:, 1] * (56 / H)
106 | 
107 |         if self.phase == 'train':
108 | 
109 |             # LR Flipping
110 |             if np.random.rand() < self.flip_probability:
111 |                 image_in1, image_in2 = self.LRflip(image_in1), self.LRflip(image_in2)
112 |                 pose = pose_flip_lr_dict[str(pose)]
113 |                 keypoints[:, 0] = 56 - keypoints[:, 0]
114 | 
115 |             # Rotation
116 |             if 
np.random.rand() < self.rotate_probability: 117 | angle = int(np.random.rand() * 10 - 5 / 2) 118 | image_in1, image_in2 = self.Rotate(image_in1, theta=angle), self.Rotate(image_in2, theta=angle) 119 | angle_radian = np.pi * angle / 180 120 | R = np.array([[np.cos(angle_radian), np.sin(angle_radian)], 121 | [-np.sin(angle_radian), np.cos(angle_radian)]]) 122 | keypoints = np.matmul(R, keypoints.T - 28).T + 28 123 | 124 | gt_heatmaps = np.zeros([21, 56, 56]) 125 | pixel_class_label = np.ones([56, 56]) * 20 126 | 127 | for i, pt in enumerate(keypoints): 128 | if (55 >= pt[0] > 0) and (55 >= pt[1] > 0): 129 | gt_heatmaps[i][int(pt[1])][int(pt[0])] = 1 130 | pixel_class_label[int(pt[1])][int(pt[0])] = i 131 | 132 | # defining map corresponding to the background 133 | gt_heatmaps[20] = np.ones([56, 56]) 134 | 135 | for i in range(0, 20): 136 | gt_heatmaps[20] = gt_heatmaps[20] - gt_heatmaps[i] 137 | """ 138 | for i in range(20): 139 | gt_heatmaps[i] = \ 140 | ndimage.binary_dilation(gt_heatmaps[i], structure=self.struct, iterations=1).astype(gt_heatmaps.dtype) 141 | """ 142 | #pdb.set_trace() 143 | return self.ToTensor(image_in1.transpose(2, 0, 1)),\ 144 | self.ToTensor(image_in2.transpose(2, 0, 1)), \ 145 | torch.from_numpy(gt_heatmaps).float(), \ 146 | torch.from_numpy(pixel_class_label).long(), \ 147 | torch.tensor([pose]) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from models import KP_Orientation_Net 3 | import torch 4 | from tools import train, test 5 | import argparse 6 | import time 7 | 8 | 9 | def main(args): 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 14 | 15 | if args.use_case == 'stage1': 16 | net = KP_Orientation_Net.CoarseRegressor() 17 | elif args.use_case == 'stage2': 18 | net = KP_Orientation_Net.KeyPointModel() 19 | if args.phase == 'train': 20 | # Load the stage 1 checkpoint and freeze its weights 21 | net.coarse_estimator.load_state_dict(torch.load(args.stage1_ckpt)['net_state_dict']) 22 | for param in net.coarse_estimator.parameters(): 23 | param.requires_grad = False 24 | print('stage1 weights have been initialized with pre-trained weights and are frozen!') 25 | else: 26 | raise NameError('use case should be either "stage1" or "stage2"') 27 | 28 | net = net.to(device) 29 | 30 | print('Total number of Parameters = %s' % sum(p.numel() for p in net.parameters())) 31 | print('Total number of trainable Parameters = %s' % sum(p.numel() for p in net.parameters() if p.requires_grad)) 32 | 33 | if args.resume or args.phase == 'test': 34 | checkpoint = torch.load(args.resumed_ckpt) 35 | net.load_state_dict(checkpoint['net_state_dict']) 36 | print('Resumed Checkpoint :{} is Loaded!'.format(args.resumed_ckpt)) 37 | 38 | if torch.cuda.device_count() > 1 and args.mGPU: 39 | net = torch.nn.DataParallel(net) 40 | 41 | if args.phase == 'train': 42 | train.train(args, net) 43 | elif args.phase == 'test': 44 | net.eval() 45 | output = test.test(args, net) 46 | print(output['message']) 47 | else: 48 | raise NameError('phase should be either "train" or "test"') 49 | 50 | 51 | if __name__ == '__main__': 52 | 53 | parser = argparse.ArgumentParser('Key-Point and Orientation Estimation Network') 54 | parser.add_argument('--phase', default='train', type=str, choices=['train', 'test'], 55 | help='train/test mode selection', required=True) 56 | 
parser.add_argument('--use_case', default='stage1', type=str, choices=['stage1', 'stage2'],
57 |                         help='Coarse/Fine heatmap model training', required=True)
58 |     parser.add_argument('--rotate_probability', default=0, type=float)
59 |     parser.add_argument('--flip_probability', default=0, type=float)
60 |     parser.add_argument('--visualize', default=False, action='store_true',
61 |                         help='randomly visualize estimated key-points with respect to GT')
62 |     parser.add_argument('--train_batch_size', default=128, help='Size of Training Batch', type=int)
63 |     parser.add_argument('--test_batch_size', default=128, help='Size of Testing Batch', type=int)
64 |     parser.add_argument('--lr', default=0.0001, help='Learning Rate', type=float)
65 |     parser.add_argument('--num_workers', default=10, type=int)
66 |     parser.add_argument('--weight_decay', default=0, help='Optimizer Weight Decay', type=float)
67 |     parser.add_argument('--start_epoch', default=0, type=int)
68 |     parser.add_argument('--epochs', default=10, type=int)
69 |     parser.add_argument('--Lambda', default=0.1, type=float, help='The balance between fine/coarse losses')
70 |     parser.add_argument('--test_every_n_epoch', default=1, help='Testing Network after n epochs', type=int)
71 |     parser.add_argument('--resume', default=False, action='store_true', help='resume to specific checkpoint')
72 |     parser.add_argument('--mGPU', default=False, action='store_true', help='Multi GPU support')
73 |     parser.add_argument('--stage1_ckpt', default='', help='Path to the stage1 trained model', type=str,
74 |                         required=True if (parser.parse_known_args()[0].use_case == 'stage2' and
75 |                                           parser.parse_known_args()[0].phase == 'train') else False)
76 |     parser.add_argument('--resumed_ckpt', default='', help='Path to resume the checkpoint', type=str,
77 |                         required=True if parser.parse_known_args()[0].phase == 'test' or
78 |                         parser.parse_known_args()[0].resume else False)
79 |     if parser.parse_known_args()[0].phase == 'train':
80 |         parser.add_argument('--ckpt',
81 |                             default='./checkpoints/' + parser.parse_known_args()[0].use_case +
82 |                                     '/' + time.strftime("%Y-%m-%d-%H"),
83 |                             help='Path to save the checkpoints',
84 |                             type=str)
85 |     args = parser.parse_args()
86 |     main(args)
87 | 
--------------------------------------------------------------------------------
/models/KP_Orientation_Net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torchvision import models
3 | import torch
4 | 
5 | 
6 | class CoarseRegressor(nn.Module):
7 |     """
8 |     Coarse Key-Point Detector Network
9 |     """
10 |     def __init__(self, n1=20):
11 |         super(CoarseRegressor, self).__init__()
12 |         self.N1 = n1
13 |         # VGG Convolutional Layers
14 |         self.A1 = nn.Sequential(*list(list(models.vgg16_bn(pretrained=True).children())[0].children())[:7])
15 |         self.A2 = nn.Sequential(*list(list(models.vgg16_bn(pretrained=True).children())[0].children())[7:14])
16 |         self.A3 = nn.Sequential(*list(list(models.vgg16_bn(pretrained=True).children())[0].children())[14:24])
17 |         self.A4 = nn.Sequential(*list(list(models.vgg16_bn(pretrained=True).children())[0].children())[24:34])
18 |         self.A5 = nn.Sequential(*list(list(models.vgg16_bn(pretrained=True).children())[0].children())[34:])
19 |         # Coarse Regressors
20 |         self.A6 = nn.Sequential(nn.Conv2d(512, 512, 1, padding=0), nn.BatchNorm2d(512), nn.ReLU())
21 |         self.A6to7 = nn.Sequential(nn.Conv2d(512, self.N1 + 1, 1, padding=0), nn.BatchNorm2d(self.N1 + 1), nn.ReLU())
22 |         self.A3to7 = nn.Sequential(nn.Conv2d(256, self.N1 + 1, 1, padding=0), nn.BatchNorm2d(self.N1 + 1), nn.ReLU())
23 |         self.A4to7 = nn.Sequential(nn.Conv2d(512, self.N1 + 1, 1, padding=0), nn.BatchNorm2d(self.N1 + 1), nn.ReLU())
24 |         self.Up = nn.Upsample(scale_factor=2, mode='bilinear')
25 | 
26 |     def forward(self, x):
27 |         """
28 |         Coarse Key-Point regression forward pass
29 |         :param x: The input tensor of shape B * 3 * 224 * 224
30 |         :return: The predicted Heatmaps of 20 Key-Points plus the map for background. Shape = B * 21 * 56 * 56
31 |         """
32 |         x = self.A1(x)  # B * 64 * 112 * 112
33 |         x = self.A2(x)  # B * 128 * 56 * 56
34 |         x = self.A3(x)  # B * 256 * 28 * 28
35 |         res2 = self.A4(x)  # B * 512 * 14 * 14
36 |         return self.Up(self.Up(self.Up(self.A6to7(self.A6(self.A5(res2)))) + self.A4to7(res2)) + self.A3to7(x))
37 | 
38 | 
39 | class FineRegressor(nn.Module):
40 |     """
41 |     Key-Point Refinement Network
42 |     """
43 |     def __init__(self, n2=20):
44 |         super(FineRegressor, self).__init__()
45 |         self.N2 = n2
46 |         self.Normalize = nn.Softmax(dim=2)
47 |         self.MaxPool = nn.MaxPool2d(2, 2)
48 |         self.L1 = nn.Sequential(nn.Conv2d(24, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
49 |         self.HR1 = nn.Sequential(nn.Conv2d(64, 64, 7), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 128, 5),
50 |                                  nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(128, 256, 1), nn.BatchNorm2d(256),
51 |                                  nn.ReLU(), nn.ConvTranspose2d(256, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
52 |                                  nn.ConvTranspose2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
53 |         self.res1 = nn.Sequential(nn.Conv2d(64, 64, 1), nn.BatchNorm2d(64), nn.ReLU())
54 |         self.L2 = nn.Sequential(nn.ConvTranspose2d(64, self.N2 + 1, 7), nn.BatchNorm2d(self.N2 + 1), nn.ReLU(),
55 |                                 nn.Conv2d(self.N2 + 1, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
56 |         self.L3 = nn.Sequential(nn.Conv2d(64, 64, 7), nn.BatchNorm2d(64), nn.ReLU(), nn.Conv2d(64, 128, 5),
57 |                                 nn.BatchNorm2d(128), nn.ReLU(), nn.Conv2d(128, 256, 1), nn.BatchNorm2d(256), nn.ReLU())
58 |         self.res2 = nn.Sequential(nn.Conv2d(64, 64, 1), nn.BatchNorm2d(64), nn.ReLU())
59 |         self.L4 = nn.Sequential(nn.ConvTranspose2d(256, 128, 5), nn.BatchNorm2d(128), nn.ReLU(),
60 |                                 nn.ConvTranspose2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
61 |         self.L5 = nn.Sequential(nn.ConvTranspose2d(64, self.N2 + 1, 7), nn.BatchNorm2d(self.N2 + 1), nn.ReLU())
62 |         self.pose_branch1 = nn.Sequential(nn.Conv2d(256, 128, 7), nn.BatchNorm2d(128), nn.ReLU(),
63 |                                           nn.Conv2d(128, 64, 7), nn.BatchNorm2d(64), nn.ReLU())
64 |         self.pose_branch2 = nn.Sequential(nn.Conv2d(64, 32, 7), nn.BatchNorm2d(32), nn.ReLU())
65 |         self.FC = nn.Sequential(nn.Linear(2048, 256, bias=True), nn.Dropout(0.5), nn.Linear(256, 8, bias=True))
66 | 
67 |     def forward(self, x):
68 |         """
69 |         Key-Point Refinement forward pass
70 |         :param x: The input tensor of shape B * 24 * 56 * 56
71 |         :return: kp: The refined heatmaps of the 20 key-points plus background. Shape = B * 21 * 56 * 56
72 |                  pose: The predicted orientation of the vehicle
73 |         """
74 |         x = self.L1(x)  # B * 64 * 50 * 50
75 |         x = self.L2(self.res1(x) + self.HR1(x))  # B * 64 * 50 * 50
76 |         joint = self.L3(x)  # B * 256 * 40 * 40
77 |         # Key Point Estimation
78 |         kp = self.L5(self.res2(x) + self.L4(joint))  # B * 21 * 56 * 56
79 |         B, C, H, W = kp.shape
80 |         kp = self.Normalize(kp.view(B, C, W * H))
81 |         kp = kp.view(B, C, H, W)
82 |         # Orientation Estimation
83 |         pose = self.pose_branch1(joint)  # B * 64 * 28 * 28
84 |         pose = self.MaxPool(pose)  # B * 64 * 14 * 14
85 |         pose = self.pose_branch2(pose)  # B * 32 * 8 * 8
86 |         pose = pose.view(-1, 2048)  # B * 2048
87 |         pose = self.FC(pose)  # B * 8
88 |         return kp, pose
89 | 
90 | 
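# Shape walk-through of the full two-stage pipeline (batch size B), for reference:
#   CoarseRegressor:  x1 (B x 3 x 224 x 224) -> coarse_kp (B x 21 x 56 x 56)
#   concatenation:    [x2 (B x 3 x 56 x 56), coarse_kp] -> (B x 24 x 56 x 56)
#   FineRegressor:    (B x 24 x 56 x 56) -> fine_kp (B x 21 x 56 x 56), pose logits (B x 8)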
91 | class KeyPointModel(nn.Module):
92 |     """
93 |     End-to-End Key-Point Regression model
94 |     """
95 |     def __init__(self):
96 |         super(KeyPointModel, self).__init__()
97 |         self.coarse_estimator = CoarseRegressor()
98 |         self.refinement = FineRegressor()
99 | 
100 |     def forward(self, x1, x2):
101 |         """
102 |         Key-Point Estimation forward pass
103 |         :param x1: The input tensor of shape B * 3 * 224 * 224
104 |         :param x2: The input tensor of shape B * 3 * 56 * 56
105 |         :return: coarse_kp: The coarse heatmaps of size B * 21 * 56 * 56
106 |                  fine_kp: The refined heatmaps of size B * 21 * 56 * 56
107 |                  orientation: The predicted orientation of the vehicle of size B * 8
108 |         """
109 |         coarse_kp = self.coarse_estimator(x1)
110 |         x2 = torch.cat((x2, coarse_kp), dim=1)
111 |         fine_kp, orientation = self.refinement(x2)
112 | 
113 |         return coarse_kp, fine_kp, orientation
114 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/models/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm==4.19.4
2 | numpy==1.12.1
3 | scipy==1.1.0
4 | matplotlib==2.2.3
5 | torch==0.4.1
6 | scikit_image==0.14.0
7 | torchvision==0.2.1
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pirazh/Vehicle_Key_Point_Orientation_Estimation/f2ece3d82c7a4a778113481b6efb39601621c886/tools/__init__.py
--------------------------------------------------------------------------------
/tools/confusion_meter.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import itertools
4 | from subprocess import Popen, PIPE
5 | cmd = ' uname -n'
6 | proc = Popen(cmd, stdin=PIPE, stdout=PIPE, shell=True)
7 | stdout, stderr = proc.communicate()
8 | if not stdout.split('.')[0] == 'ramawks80':
9 |     plt.switch_backend('agg')  # non-interactive backend for machines without a display
10 | 
11 | 
12 | class ConfusionMeter:  # accumulates a confusion matrix over batches and renders it as a figure
13 | 
14 |     def __init__(self, labels=[], normalize=False, save_path=''):
15 |         self.save_path = save_path
16 |         self.labels = labels
17 |         self.num_classes = len(labels)
18 |         self.confusion_matrix = np.zeros([self.num_classes, self.num_classes])
19 |         self.normalize = normalize
20 | 
21 |     def update(self, predictions, labels):
22 |         assert predictions.shape == labels.shape
23 |         for i in range(predictions.size(0)):
24 |             self.confusion_matrix[labels[i].item()][predictions[i].item()] += 1
25 | 
26 |     def get_result(self):
27 |         if self.normalize:
28 |             return self.confusion_matrix / self.confusion_matrix.sum(1).clip(min=1e-10)[:, None]
29 |         else:
30 |             return self.confusion_matrix
31 | 
32 |     def save_confusion_matrix(self):
33 |         final_confusion = self.get_result()
34 |         plt.figure()
35 |         plt.imshow(final_confusion, interpolation='nearest', cmap=plt.cm.YlOrRd)
36 |         plt.colorbar()
37 |         tick_marks = np.arange(len(self.labels))
38 |         plt.xticks(tick_marks, self.labels, rotation=90, fontsize=8)
39 |         plt.yticks(tick_marks, self.labels, fontsize=8)
40 |         fmt = '.2f' if self.normalize else 'd'
41 |         thresh = final_confusion.mean()
42 |         for i, j in 
itertools.product(range(self.num_classes), range(self.num_classes)): 43 | plt.text(j, i, format(final_confusion[i, j], fmt), 44 | horizontalalignment="center", 45 | color="white" if final_confusion[i, j] > thresh else "black", fontsize=8) 46 | plt.ylabel('True label', fontsize=10) 47 | plt.xlabel('Predicted label', fontsize=10) 48 | plt.tight_layout() 49 | plt.savefig(self.save_path, dpi=600) 50 | plt.close() 51 | -------------------------------------------------------------------------------- /tools/paths.py: -------------------------------------------------------------------------------- 1 | class paths(): 2 | """ 3 | The path class that encapsulates all the required paths to run the scripts 4 | """ 5 | def __init__(self): 6 | pass 7 | 8 | VERI_DATA_PATH = './data' 9 | VERI_KP_ANNOTATIONS_TRAINING_FILE = './data/VehicleKeyPointData/keypoint_train.txt' 10 | VERI_KP_ANNOTATIONS_TESTING_FILE = './data/VehicleKeyPointData/keypoint_test.txt' 11 | VERI_MEAN_STD_FILE = './data/VeRi/mean.pth.tar' -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from data.datatools import veri_dataset 4 | from tools.utilities import sample_visualizer, accuracy as Acc 5 | from tools.confusion_meter import ConfusionMeter 6 | import os, sys 7 | from tqdm import tqdm 8 | import pdb 9 | 10 | Orientation_labels = ['front', 'rear', 'left', 'left front', 'left rear', 'right', 'right front', 'right rear'] 11 | 12 | 13 | def test(args, net, epoch=None): 14 | """ 15 | This is the function to test the trained key-point and orientation estimation model 16 | :param args: the object that encapsulates all the required settings 17 | :param net: the network to be tested 18 | :param epoch: the epoch number that this test is being done for 19 | :return: a dictionary that contains the result of the test 20 | """ 21 | test_set = veri_dataset.VeriDataset(phase='test') 22 | 23 | test_loader = DataLoader(test_set, shuffle=False, batch_size=args.test_batch_size, num_workers=args.num_workers) 24 | 25 | if args.use_case == 'stage1': 26 | coarse_error = 0.0 27 | with torch.no_grad(): 28 | with tqdm(total=len(test_loader), ncols=0, file=sys.stdout, desc='Stage 1 Evaluation...') as pbar: 29 | for i, in_batch in enumerate(test_loader): 30 | image_in1, _, gt_heatmaps, _, _ = in_batch 31 | 32 | if torch.cuda.is_available(): 33 | image_in1, gt_heatmaps = image_in1.cuda(), gt_heatmaps.cuda() 34 | 35 | coarse_kp = net(image_in1) 36 | 37 | if args.visualize: 38 | sample_visualizer(coarse_kp[:, :20, :, :], gt_heatmaps[:, :20, :, :], image_in1) 39 | 40 | coarse_error += Acc(coarse_kp, gt_heatmaps) 41 | pbar.set_postfix(coarse_kp_dist_from_gt=coarse_error / (i + 1)) 42 | pbar.update() 43 | 44 | coarse_error = coarse_error / len(test_loader) 45 | message = 'Stage 1 KP estimation error is {0:.3f} pixels in 56 by 56 grid.'.format(coarse_error) 46 | return {'message': message, 'coarse_error': coarse_error} 47 | 48 | elif args.use_case == 'stage2': 49 | coarse_error, fine_error, orientation_accuracy, total, correct, = 0.0, 0.0, 0.0, 0, 0 50 | if args.phase == 'train': 51 | save_path = os.path.join(args.ckpt, 'epoch_{}'.format(epoch + 1) + '.png') 52 | orientation_cmf = ConfusionMeter(normalize=True, save_path=save_path, labels=Orientation_labels) 53 | with torch.no_grad(): 54 | with tqdm(total=len(test_loader), ncols=0, file=sys.stdout, desc='Stage 2 Evaluation...') 
as pbar:
55 |                 for i, in_batch in enumerate(test_loader):
56 |                     image_in1, image_in2, gt_heatmaps, _, gt_orientation_label = in_batch
57 |                     if torch.cuda.is_available():
58 |                         image_in1, image_in2, gt_heatmaps, gt_orientation_label = image_in1.cuda(),\
59 |                                                                                   image_in2.cuda(),\
60 |                                                                                   gt_heatmaps.cuda(),\
61 |                                                                                   gt_orientation_label.cuda()
62 |                     coarse_kp, fine_kp, orientation = net(image_in1, image_in2)
63 |                     if args.visualize:
64 |                         sample_visualizer(fine_kp[:, :20, :, :], gt_heatmaps[:, :20, :, :], image_in1)
65 |                     _, predicted_orientation = torch.max(orientation.data, 1)
66 |                     if args.phase == 'train':
67 |                         orientation_cmf.update(predicted_orientation, gt_orientation_label.squeeze())
68 |                     total += gt_orientation_label.size(0)
69 |                     correct += (gt_orientation_label.squeeze() == predicted_orientation).sum().item()
70 |                     coarse_error += Acc(coarse_kp, gt_heatmaps)
71 |                     fine_error += Acc(fine_kp, gt_heatmaps)
72 |                     pbar.set_postfix(fine_kp_dist_from_gt=fine_error / (i + 1),
73 |                                      orientation_classification_accuracy=float(correct) / total * 100)
74 |                     pbar.update()
75 |             if not args.phase == 'test':
76 |                 orientation_cmf.save_confusion_matrix()
77 | 
78 |         coarse_error, fine_error, orientation_accuracy = \
79 |             coarse_error / len(test_loader), \
80 |             fine_error / len(test_loader), \
81 |             float(correct) / total * 100
82 | 
83 |         message = 'Stages 1 & 2 KP estimation errors are {0:.3f} & {1:.3f} pixels in 56 by 56 grid & Orientation' \
84 |                   ' classification accuracy is {2:.2f}%.'.format(coarse_error, fine_error, orientation_accuracy)
85 | 
86 |         return {'message': message,
87 |                 'coarse_error': coarse_error,
88 |                 'fine_error': fine_error,
89 |                 'orientation_accuracy': orientation_accuracy}
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | from tools.utilities import save_checkpoint, Chronometer
5 | from data.datatools import veri_dataset
6 | from torch.utils.data import DataLoader
7 | import os, sys
8 | from tqdm import tqdm
9 | from tools import test
10 | 
11 | 
12 | def train(args, net):
13 |     # Defining Training Set and Data Loader
14 |     train_set = veri_dataset.VeriDataset(phase=args.phase,
15 |                                          rotate_probability=args.rotate_probability,
16 |                                          flip_probability=args.flip_probability)
17 |     train_loader = DataLoader(train_set,
18 |                               batch_size=args.train_batch_size,
19 |                               shuffle=True,
20 |                               num_workers=args.num_workers)
21 | 
22 |     # Train Stage 1 for Coarse Heatmap Generation Using Pixel Based Classification
23 |     if args.use_case == 'stage1':
24 |         params = net.module.parameters() if args.mGPU and (torch.cuda.device_count() > 1) else net.parameters()
25 |         Heatmap_criterion = nn.CrossEntropyLoss(train_set.key_point_distribution.float().cuda())
26 | 
27 |         optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
28 |         timer = Chronometer()
29 |         start_epoch = args.start_epoch
30 | 
31 |         if not os.path.isdir(args.ckpt):
32 |             os.mkdir(args.ckpt)
33 | 
34 |         if args.resume:
35 |             checkpoint = torch.load(args.resumed_ckpt)
36 |             start_epoch = checkpoint['epoch']
37 | 
38 |         best_error = 1000
39 |         timer.set()
40 |         # Initiate Logger
41 |         with open(args.ckpt + '/logger.txt', 'w+') as f:
42 |             f.write('Training Session on ' + time.strftime("%Y%m%d-%H") + '\n')
43 |             # Write Used Arguments
44 |             f.write('Used Arguments:\n')
45 |             print('Used Arguments:')
46 |             for key in args.__dict__.keys():
47 |                 f.write(key + ':{}\n'.format(args.__dict__[key]))
48 |                 print(key + ':{}'.format(args.__dict__[key]))
49 | 
50 |             # Training Loop
51 |             for epoch in range(start_epoch, args.epochs):
52 |                 f.write('Epoch: {}'.format(epoch + 1))
53 |                 epoch_train_loss, is_best = 0.0, False
54 |                 with tqdm(total=len(train_loader), ncols=0, file=sys.stdout,
55 |                           desc='Epoch: {}'.format(epoch + 1)) as pbar:
56 | 
57 |                     for i, in_batch in enumerate(train_loader):
58 |                         optimizer.zero_grad()
59 |                         image_in1, _, _, gt_pixel_label, _ = in_batch
60 | 
61 |                         if torch.cuda.is_available():
62 |                             image_in1, gt_pixel_label = image_in1.cuda(), gt_pixel_label.cuda()
63 | 
64 |                         coarse_kp = net(image_in1)
65 |                         loss = Heatmap_criterion(coarse_kp, gt_pixel_label)
66 | 
67 |                         epoch_train_loss += loss.item()
68 |                         loss.backward()
69 |                         optimizer.step()
70 |                         pbar.set_postfix(coarse_Kp=loss.item())
71 |                         pbar.update()
72 | 
73 |                 epoch_train_loss = epoch_train_loss / len(train_loader)
74 |                 # save the checkpoint
75 |                 save_checkpoint(
76 |                     {'epoch': epoch + 1, 'net_state_dict': net.module.state_dict() if args.mGPU else net.state_dict()},
77 |                     is_best, filename=os.path.join(args.ckpt, 'checkpoint.pth.tar'),
78 |                     best_filename=os.path.join(args.ckpt, 'best_checkpoint.pth.tar'))
79 | 
80 |                 f.write(', Average Training Loss: {} '.format(epoch_train_loss))
81 |                 print('Average Epoch Loss = {}'.format(epoch_train_loss))
82 | 
83 |                 # Check Error of the Trained Model on test set
84 |                 if epoch % args.test_every_n_epoch == args.test_every_n_epoch - 1:
85 |                     print('Testing the network...')
86 |                     net.eval()
87 |                     output = test.test(args, net, epoch)
88 |                     print(output['message'])
89 |                     if output['coarse_error'] < best_error:
90 |                         best_error = output['coarse_error']
91 |                         is_best = True
92 |                         # save the checkpoint as best checkpoint so far
93 |                         save_checkpoint(
94 |                             {'epoch': epoch + 1,
95 |                              'net_state_dict': net.module.state_dict() if args.mGPU else net.state_dict()},
96 |                             is_best, filename=os.path.join(args.ckpt, 'checkpoint.pth.tar'),
97 |                             best_filename=os.path.join(args.ckpt, 'best_checkpoint.pth.tar'))
98 |                     f.write('\n')
99 |                     f.write(output['message'])
100 |                     f.write('\n')
101 |                     net.train()
102 | 
103 |             timer.stop()
104 |             f.write('Finished Training Session after {0} Epochs & {1} hours & {2} minutes, '
105 |                     'Best coarse error Achieved: {3:.2f} pixel in 56 by 56 grid \n'
106 |                     .format(args.epochs - start_epoch, int(timer.elapsed / 3600),
107 |                             int((timer.elapsed % 3600) / 60), best_error))
108 |             f.close()
109 | 
110 |         print('Finished Training Session after {0} Epochs & {1} hours & {2} minutes, '
111 |               'Best coarse error Achieved: {3:.2f} pixel in 56 by 56 grid \n'
112 |               .format(args.epochs - start_epoch, int(timer.elapsed / 3600),
113 |                       int((timer.elapsed % 3600) / 60), best_error))
114 | 
115 |     # Train Stage 2 for Fine Heatmap Regression and Orientation Estimation
116 |     elif args.use_case == 'stage2':
117 |         params = net.module.refinement.parameters() if args.mGPU and (torch.cuda.device_count() > 1) else net.refinement.parameters()
118 |         Heatmap_criterion = nn.MSELoss()
119 |         Orientation_criterion = nn.CrossEntropyLoss(train_set.pose_distribution.float().cuda())
120 | 
121 |         optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
122 |         timer = Chronometer()
123 |         start_epoch = args.start_epoch
124 | 
125 |         if not os.path.isdir(args.ckpt):
126 |             os.mkdir(args.ckpt)
127 | 
128 |         if args.resume:
129 |             checkpoint = torch.load(args.resumed_ckpt)
130 |             start_epoch = checkpoint['epoch']
131 | 
132 |         best_error, best_accuracy = 1000, 0
133 |         timer.set()
134 |         # Initiate Logger
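        # (the log file below records the used arguments, the per-epoch heatmap and
        #  orientation losses, and the periodic test results, like the stage 1 logger above)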
135 |         with open(args.ckpt + '/logger.txt', 'w+') as f:
136 |             f.write('Training Session on ' + time.strftime("%Y%m%d-%H") + '\n')
137 |             # Write Used Arguments
138 |             f.write('Used Arguments:\n')
139 |             print('Used Arguments:')
140 |             for key in args.__dict__.keys():
141 |                 f.write(key + ':{}\n'.format(args.__dict__[key]))
142 |                 print(key + ':{}'.format(args.__dict__[key]))
143 | 
144 |             # Training Loop
145 |             for epoch in range(start_epoch, args.epochs):
146 |                 f.write('Epoch: {}'.format(epoch + 1))
147 | 
148 |                 epoch_train_heatmap_loss, epoch_train_orientation_loss, is_best_orientation, is_best_kp = \
149 |                     0.0, 0.0, False, False
150 | 
151 |                 with tqdm(total=len(train_loader), ncols=0, file=sys.stdout,
152 |                           desc='Epoch: {}'.format(epoch + 1)) as pbar:
153 | 
154 |                     for i, in_batch in enumerate(train_loader):
155 |                         optimizer.zero_grad()
156 | 
157 |                         image_in1, image_in2, gt_heatmaps, _, gt_orientation_label = in_batch
158 | 
159 |                         if torch.cuda.is_available():
160 |                             image_in1, image_in2, gt_heatmaps, gt_orientation_label = image_in1.cuda(),\
161 |                                                                                       image_in2.cuda(),\
162 |                                                                                       gt_heatmaps.cuda(),\
163 |                                                                                       gt_orientation_label.cuda()
164 | 
165 |                         coarse_kp, fine_kp, orientation = net(image_in1, image_in2)
166 |                         heatmap_loss = Heatmap_criterion(fine_kp, gt_heatmaps)
167 |                         orientation_loss = Orientation_criterion(orientation, gt_orientation_label.squeeze())
168 |                         loss = heatmap_loss + args.Lambda * orientation_loss
169 | 
170 |                         epoch_train_heatmap_loss += heatmap_loss.item()
171 |                         epoch_train_orientation_loss += orientation_loss.item()
172 |                         loss.backward()
173 |                         optimizer.step()
174 |                         pbar.set_postfix(fine_kp_loss=heatmap_loss.item(), orientation_loss=orientation_loss.item())
175 |                         pbar.update()
176 | 
177 |                 epoch_train_heatmap_loss, epoch_train_orientation_loss = \
178 |                     epoch_train_heatmap_loss / len(train_loader), epoch_train_orientation_loss / len(train_loader)
179 | 
180 |                 # save the checkpoint
181 |                 save_checkpoint(
182 |                     {'epoch': epoch + 1, 'net_state_dict': net.module.state_dict() if args.mGPU else net.state_dict()},
183 |                     is_best=False, filename=os.path.join(args.ckpt, 'checkpoint.pth.tar'),
184 |                     best_filename=os.path.join(args.ckpt, 'best_checkpoint.pth.tar'))
185 | 
186 |                 f.write('Average Heatmap & Orientation Loss : {} and {}'.
187 |                         format(epoch_train_heatmap_loss, epoch_train_orientation_loss))
188 | 
189 |                 print('Average Heatmap & Orientation Loss : {} and {}'.
190 |                       format(epoch_train_heatmap_loss, epoch_train_orientation_loss))
191 | 
192 |                 # Check Error of the Trained Model on test set
193 |                 if epoch % args.test_every_n_epoch == args.test_every_n_epoch - 1:
194 |                     print('Testing the network...')
195 |                     net.eval()
196 |                     output = test.test(args, net, epoch)
197 |                     print(output['message'])
198 |                     if output['fine_error'] < best_error:
199 |                         best_error = output['fine_error']
200 |                         is_best_kp = True
201 |                     if output['orientation_accuracy'] > best_accuracy:
202 |                         best_accuracy = output['orientation_accuracy']
203 |                         is_best_orientation = True
204 | 
205 |                     # save the checkpoint as best checkpoint so far
206 |                     save_checkpoint(
207 |                         {'epoch': epoch + 1,
208 |                          'net_state_dict': net.module.state_dict() if args.mGPU else net.state_dict()},
209 |                         is_best_kp, filename=os.path.join(args.ckpt, 'checkpoint.pth.tar'),
210 |                         best_filename=os.path.join(args.ckpt, 'best_fine_kp_checkpoint.pth.tar'))
211 |                     save_checkpoint(
212 |                         {'epoch': epoch + 1,
213 |                          'net_state_dict': net.module.state_dict() if args.mGPU else net.state_dict()},
214 |                         is_best_orientation, filename=os.path.join(args.ckpt, 'checkpoint.pth.tar'),
215 |                         best_filename=os.path.join(args.ckpt, 'best_orientation_checkpoint.pth.tar'))
216 | 
217 |                     f.write('\n')
218 |                     f.write(output['message'])
219 |                     f.write('\n')
220 |                     net.train()
221 | 
222 |             timer.stop()
223 |             f.write('Finished training session after {0} epochs, {1} hours & {2} minutes, best fine error: '
224 |                     '{3:.2f} pixels in 56 by 56 grid, best orientation accuracy: {4:.2f}%.'
225 |                     .format(args.epochs - start_epoch, int(timer.elapsed / 3600), int((timer.elapsed % 3600) / 60),
226 |                             best_error, best_accuracy))
227 |             f.close()
228 | 
229 |         print('Finished training session after {0} epochs, {1} hours & {2} minutes, best fine error: '
230 |               '{3:.2f} pixels in 56 by 56 grid, best orientation accuracy: {4:.2f}%.'
231 |               .format(args.epochs - start_epoch, int(timer.elapsed / 3600), int((timer.elapsed % 3600) / 60),
232 |                       best_error, best_accuracy))
233 | 
234 | 
--------------------------------------------------------------------------------
/tools/utilities.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import torch
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import time
6 | 
7 | KP_labels = ['left-front wheel', 'left-back wheel', 'right-front wheel', 'right-back wheel', 'right fog lamp',
8 |              'left fog lamp', 'right headlight', 'left headlight', 'front auto logo', 'front license plate',
9 |              'left rear-view mirror', 'right rear-view mirror', 'right-front corner of vehicle top',
10 |              'left-front corner of vehicle top', 'left-back corner of vehicle top', 'right-back corner of vehicle top',
11 |              'left rear lamp', 'right rear lamp', 'rear auto logo', 'rear license plate']
12 | 
13 | 
14 | class Chronometer:
15 |     """
16 |     Chronometer class to time the code
17 |     """
18 |     def __init__(self):
19 |         self.elapsed = 0
20 |         self.start = 0
21 |         self.end = 0
22 | 
23 |     def set(self):
24 |         self.start = time.time()
25 | 
26 |     def stop(self):
27 |         self.end = time.time()
28 |         self.elapsed = (self.end - self.start)
29 | 
30 |     def reset(self):
31 |         self.start, self.end, self.elapsed = 0, 0, 0
32 | 
33 | 
34 | def save_checkpoint(state, is_best, filename='./checkpoint/checkpoint.pth.tar',
35 |                     best_filename='./checkpoint/model_best.pth.tar'):
36 |     """
37 |     Save trained model
38 |     :param state: the state dictionary to be saved
39 |     :param is_best: boolean that indicates whether this is the best checkpoint so far
40 |     :param filename: label of the current checkpoint
41 |     :param best_filename: label of the so far best checkpoint
42 |     """
43 |     torch.save(state, filename)
44 |     if is_best:
45 |         shutil.copyfile(filename, best_filename)
46 | 
47 | 
48 | def sample_visualizer(outputs, maps, inputs):
49 |     """
50 |     Randomly visualize the estimated key-points and their respective ground-truth maps
51 |     :param outputs: the estimated key-points
52 |     :param maps: the ground-truth maps
53 |     :param inputs: the tensor containing the normalized image data
54 |     :return: visualize the heatmaps
55 |     """
56 |     rand = np.random.randint(0, outputs.shape[0])
57 |     map_out, map1, sample_in = outputs[rand], maps[rand], inputs[rand]
58 | 
59 |     plt.figure(2)
60 |     plt.imshow(sample_in.transpose(0, 2).transpose(0, 1).cpu().numpy())
61 |     plt.pause(.05)
62 |     map_out = map_out.cpu().numpy()
63 |     map1 = map1.cpu().numpy()
64 | 
65 |     plt.figure(1)
66 |     for i in range(0, 20):
67 |         plt.subplot(4, 10, i + 1)
68 |         plt.imshow(map_out[i] / map_out.sum())
69 |         plt.xlabel(KP_labels[i], fontsize=5)
70 |         plt.xticks()
71 |         plt.subplot(4, 10, 21 + i)
72 |         plt.imshow(map1[i])
73 |         plt.xlabel(KP_labels[i], fontsize=5)
74 | 
75 |     plt.pause(.05)
76 |     plt.draw()
77 | 
78 | 
79 | def get_preds(heatmaps):
80 |     """
81 |     Get the coordinates of the hottest point in each heatmap
82 |     :param heatmaps: the heatmaps of shape Batch * Channel * Height * Width
83 |     :return: the coordinates of the hottest points
84 |     """
85 |     assert heatmaps.dim() == 4, 'Score maps should be 4-dim Batch, Channel, Height, Width'
86 | 
87 |     maxval, idx = torch.max(heatmaps.view(heatmaps.size(0), heatmaps.size(1), -1), 2)
88 |     maxval = maxval.view(heatmaps.size(0), heatmaps.size(1), 1)
89 |     idx = idx.view(heatmaps.size(0), heatmaps.size(1), 1) + 1
90 | 
91 |     preds = idx.repeat(1, 1, 2).float()
92 | 
93 |     preds[:, :, 0] = (preds[:, :, 0] - 1) % heatmaps.size(3) + 1
94 |     preds[:, :, 1] = torch.floor((preds[:, :, 1] - 1) / heatmaps.size(3)) + 1
95 | 
96 |     pred_mask = maxval.gt(0).repeat(1, 1, 2).float()
97 |     preds *= pred_mask
98 |     return preds
99 | 
100 | 
101 | def calc_dists(preds, target):
102 |     """
103 |     Calculate the average distance from predictions to their ground-truth
104 |     :param preds: predicted coordinates of key-points from estimations
105 |     :param target: predicted coordinates of key-points extracted from ground-truth maps
106 |     :return: the average distance
107 |     """
108 |     preds = preds.float()
109 |     target = target.float()
110 |     cnt = 0
111 |     dists = 0
112 |     for n in range(preds.size(0)):
113 |         for c in range(preds.size(1)):
114 |             if target[n, c, 0] > 1 and target[n, c, 1] > 1:
115 |                 dists += torch.dist(preds[n, c, :], target[n, c, :])
116 |                 cnt += 1
117 |     dists = dists / cnt
118 |     return dists
119 | 
120 | 
121 | def accuracy(output, target):
122 |     """
123 |     Calculate the accuracy of predicted key-points with respect to visible key-points in ground-truth
124 |     :param output: the estimated key-points of shape B * 21 (or 20) * 56 * 56
125 |     :param target: the gt key-points of shape B * 21 * 56 * 56
126 |     :return: the average distance of the hottest point from its ground-truth
127 |     """
128 |     preds = get_preds(output[:, :20, :, :])
129 |     gts = get_preds(target[:, :20, :, :])
130 |     dists = calc_dists(preds, gts)
131 |     return dists.item()
132 | 
133 | 
134 | 
--------------------------------------------------------------------------------